LCOV - code coverage report
Current view: top level - nntrainer/graph - graph_node.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 88.9 % 9 8
Test Date: 2025-12-14 20:38:17 Functions: - 0 0

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   graph_node.h
       6              :  * @date   1 April 2021
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
       9              :  * @bug    No known bugs except for NYI items
      10              :  * @brief  This is the graph node interface for c++ API
      11              :  */
      12              : 
      13              : #ifndef __GRAPH_NODE_H__
      14              : #define __GRAPH_NODE_H__
      15              : 
#include <cstddef>
#include <iterator>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
      20              : 
      21              : namespace nntrainer {
      22              : 
      23              : /**
      24              :  * @class   Layer Base class for the graph node
      25              :  * @brief   Base class for all layers
      26              :  */
      27              : class GraphNode {
      28              : public:
      29              :   /**
      30              :    * @brief Provides the time/order at which the node will be executed.
      31              :    * @details This time will be finalized once the graph has been calculated.
      32              :    * Each element indicates the orders with which the below operations
      33              :    * for each node are executed:
      34              :    * 1. Forwarding
      35              :    * 2. calcGradient
      36              :    * 3. calcDerivative
      37              :    * 4. ApplyGradient
      38              :    * One constraint is that they must be sorted in ascending order.
      39              :    * This ensures that the operations are executed in the order of their
      40              :    * listing.
      41              :    */
      42              :   typedef std::tuple<unsigned int, unsigned int, unsigned int, unsigned int>
      43              :     ExecutionOrder;
      44              : 
      45              :   /**
      46              :    * @brief     Destructor of Layer Class
      47              :    */
      48              :   virtual ~GraphNode() = default;
      49              : 
      50              :   /**
      51              :    * @brief     Get the Name of the underlying object
      52              :    *
      53              :    * @return std::string Name of the underlying object
      54              :    * @note name of each node in the graph must be unique
      55              :    */
      56              :   virtual const std::string getName() const = 0;
      57              : 
      58              :   /**
      59              :    * @brief     Set the Name of the underlying object
      60              :    *
      61              :    * @param[in] std::string Name for the underlying object
      62              :    * @note name of each node in the graph must be unique, and caller must ensure
      63              :    * that
      64              :    */
      65              :   virtual void setName(const std::string &name) = 0;
      66              : 
      67              :   /**
      68              :    * @brief     Get the Type of the underlying object
      69              :    *
      70              :    * @return const std::string type representation
      71              :    */
      72              :   virtual const std::string getType() const = 0;
      73              : 
      74              :   /**
      75              :    * @brief     Get the trainable parameter
      76              :    *
      77              :    * @return bool true / false
      78              :    */
      79              :   virtual bool getTrainable() const = 0;
      80              : 
      81              :   /**
      82              :    * @brief     Get the input connections for this node
      83              :    *
      84              :    * @return list of name of the nodes which form input connections
      85              :    */
      86              :   virtual const std::vector<std::string> getInputConnections() const = 0;
      87              : 
      88              :   /**
      89              :    * @brief     Get the output connections for this node
      90              :    *
      91              :    * @return list of name of the nodes which form output connections
      92              :    */
      93              :   virtual const std::vector<std::string> getOutputConnections() const = 0;
      94              : 
      95              :   /**
      96              :    * @brief     get the execution order/location of this node
      97              :    *
      98              :    * @retval    the execution order/location of this node
      99              :    * @details   The two values represents the value for forward and backward
     100              :    * respectively
     101              :    */
     102              :   virtual ExecutionOrder getExecutionOrder() const = 0;
     103              : 
     104              :   /**
     105              :    * @brief     set the execution order/location of this node
     106              :    *
     107              :    * @param     exec_order the execution order/location of this node
     108              :    * @details   The two values represents the value for forward and backward
     109              :    * respectively
     110              :    */
     111              :   virtual void setExecutionOrder(ExecutionOrder exec_order_) = 0;
     112              : };
     113              : 
     114              : /**
     115              :  * @brief   Iterator for GraphNode which return const
     116              :  * std::shared_ptr<LayerNodeType> object upon realize
     117              :  *
     118              :  * @note    This does not include the complete list of required functions. Add
     119              :  * them as per need.
     120              :  *
     121              :  * @note    GraphNodeType is to enable for both GraphNode and const GraphNode
     122              :  */
     123              : template <typename LayerNodeType, typename GraphNodeType>
     124              : class GraphNodeIterator {
     125              :   GraphNodeType *p; /** underlying object of GraphNode */
     126              : 
     127              : public:
     128              :   /**
     129              :    * @brief   iterator_traits types definition
     130              :    *
     131              :    * @note    these are not required to be explicitly defined now, but maintains
     132              :    *          forward compatibility for c++17 and later
     133              :    *
     134              :    * @note    value_type, pointer and reference are different from standard
     135              :    * iterator
     136              :    */
     137              :   typedef const std::shared_ptr<LayerNodeType> value_type;
     138              :   typedef std::random_access_iterator_tag iterator_category;
     139              :   typedef std::ptrdiff_t difference_type;
     140              :   typedef const std::shared_ptr<LayerNodeType> *pointer;
     141              :   typedef const std::shared_ptr<LayerNodeType> &reference;
     142              : 
     143              :   /**
     144              :    * @brief Construct a new Graph Node Iterator object
     145              :    *
     146              :    * @param x underlying object of GraphNode
     147              :    */
     148              :   GraphNodeIterator(GraphNodeType *x) : p(x) {}
     149              : 
     150              :   /**
     151              :    * @brief reference operator
     152              :    *
     153              :    * @return value_type
     154              :    * @note this is different from standard iterator
     155              :    */
     156              :   value_type operator*() const {
     157              :     return std::static_pointer_cast<LayerNodeType>(*p);
     158              :   }
     159              : 
     160              :   /**
     161              :    * @brief pointer operator
     162              :    *
     163              :    * @return value_type
     164              :    * @note this is different from standard iterator
     165              :    */
     166              :   value_type operator->() const {
     167              :     return std::static_pointer_cast<LayerNodeType>(*p);
     168              :   }
     169              : 
     170              :   /**
     171              :    * @brief == comparison operator override
     172              :    *
     173              :    * @param lhs iterator lhs
     174              :    * @param rhs iterator rhs
     175              :    * @retval true if match
     176              :    * @retval false if mismatch
     177              :    */
     178              :   friend bool operator==(GraphNodeIterator const &lhs,
     179              :                          GraphNodeIterator const &rhs) {
     180              :     return lhs.p == rhs.p;
     181              :   }
     182              : 
     183              :   /**
     184              :    * @brief != comparison operator override
     185              :    *
     186              :    * @param lhs iterator lhs
     187              :    * @param rhs iterator rhs
     188              :    * @retval true if mismatch
     189              :    * @retval false if match
     190              :    */
     191              :   friend bool operator!=(GraphNodeIterator const &lhs,
     192              :                          GraphNodeIterator const &rhs) {
     193          103 :     return lhs.p != rhs.p;
     194              :   }
     195              : 
     196              :   /**
     197              :    * @brief <= comparison operator override
     198              :    *
     199              :    * @param lhs iterator lhs
     200              :    * @param rhs iterator rhs
     201              :    * @retval true if left is less than or equal to the right value
     202              :    * @retval false if left is greater than the right value
     203              :    */
     204              :   friend bool operator<=(GraphNodeIterator const &lhs,
     205              :                          GraphNodeIterator const &rhs) {
     206              :     return lhs.p <= rhs.p;
     207              :   }
     208              : 
     209              :   /**
     210              :    * @brief override for ++ operator
     211              :    *
     212              :    * @return GraphNodeIterator&
     213              :    */
     214              :   GraphNodeIterator &operator++() {
     215         7038 :     p += 1;
     216         6950 :     return *this;
     217              :   }
     218              : 
     219              :   /**
     220              :    * @brief override for operator++
     221              :    *
     222              :    * @return GraphNodeIterator
     223              :    */
     224              :   GraphNodeIterator operator++(int) {
     225              :     GraphNodeIterator temp(p);
     226        83532 :     p += 1;
     227        83532 :     return temp;
     228              :   }
     229              : 
     230              :   /**
     231              :    * @brief override for -- operator
     232              :    *
     233              :    * @return GraphNodeIterator&
     234              :    */
     235              :   GraphNodeIterator &operator--() {
     236        29262 :     p -= 1;
     237              :     return *this;
     238              :   }
     239              : 
     240              :   /**
     241              :    * @brief override for operator--
     242              :    *
     243              :    * @return GraphNodeIterator
     244              :    */
     245              :   GraphNodeIterator operator--(int) {
     246              :     GraphNodeIterator temp(p);
     247              :     p -= 1;
     248              :     return temp;
     249              :   }
     250              : 
     251              :   /**
     252              :    * @brief override for subtract operator
     253              :    *
     254              :    * @param offset offset to subtract
     255              :    * @return GraphNodeIterator
     256              :    */
     257              :   GraphNodeIterator operator-(const difference_type offset) const {
     258         1081 :     return GraphNodeIterator(p - offset);
     259              :   }
     260              : 
     261              :   /**
     262              :    * @brief override for subtract operator
     263              :    *
     264              :    * @param other iterator to subtract
     265              :    * @return difference_type
     266              :    */
     267              :   difference_type operator-(const GraphNodeIterator &other) const {
     268         4484 :     return p - other.p;
     269              :   }
     270              : 
     271              :   /**
     272              :    * @brief override for subtract and return result operator
     273              :    *
     274              :    * @param offset offset to subtract
     275              :    * @return GraphNodeIterator&
     276              :    */
     277              :   GraphNodeIterator &operator-=(const difference_type offset) {
     278              :     p -= offset;
     279              :     return *this;
     280              :   }
     281              : 
     282              :   /**
     283              :    * @brief override for add operator
     284              :    *
     285              :    * @param offset offset to add
     286              :    * @return GraphNodeIterator
     287              :    */
     288              :   GraphNodeIterator operator+(const difference_type offset) const {
     289            0 :     return GraphNodeIterator(p + offset);
     290              :   }
     291              : 
     292              :   /**
     293              :    * @brief override for add and return result operator
     294              :    *
     295              :    * @param offset offset to add
     296              :    * @return GraphNodeIterator&
     297              :    */
     298              :   GraphNodeIterator &operator+=(const difference_type offset) {
     299              :     p += offset;
     300              :     return *this;
     301              :   }
     302              : };
     303              : 
     304              : /**
     305              :  * @brief   Reverse Iterator for GraphNode which return LayerNode object upon
     306              :  * realize
     307              :  *
     308              :  * @note    This just extends GraphNodeIterator and is limited by its
     309              :  * functionality.
     310              :  */
     311              : template <typename T_iterator>
     312              : class GraphNodeReverseIterator : public std::reverse_iterator<T_iterator> {
     313              : public:
     314              :   /**
     315              :    * @brief Construct a new Graph Node Reverse Iterator object
     316              :    *
     317              :    * @param iter Iterator
     318              :    */
     319              :   explicit GraphNodeReverseIterator(T_iterator iter) :
     320              :     std::reverse_iterator<T_iterator>(iter) {}
     321              : 
     322              :   /**
     323              :    *  @brief reference operator
     324              :    *
     325              :    * @return T_iterator::value_type
     326              :    * @note this is different from standard iterator
     327              :    */
     328              :   typename T_iterator::value_type operator*() const {
     329              :     auto temp = std::reverse_iterator<T_iterator>::current - 1;
     330              :     return *temp;
     331              :   }
     332              : 
     333              :   /**
     334              :    *  @brief pointer operator
     335              :    *
     336              :    * @return T_iterator::value_type
     337              :    * @note this is different from standard iterator
     338              :    */
     339              :   typename T_iterator::value_type operator->() const {
     340              :     auto temp = std::reverse_iterator<T_iterator>::current - 1;
     341              :     return *temp;
     342              :   }
     343              : };
     344              : 
     345              : /**
     346              :  * @brief     Iterators to traverse the graph
     347              :  */
     348              : template <class LayerNodeType>
     349              : using graph_const_iterator =
     350              :   GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>;
     351              : 
     352              : /**
     353              :  * @brief     Iterators to traverse the graph
     354              :  */
     355              : template <class LayerNodeType>
     356              : using graph_const_reverse_iterator = GraphNodeReverseIterator<
     357              :   GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>>;
     358              : 
     359              : } // namespace nntrainer
     360              : #endif // __GRAPH_NODE_H__
        

Generated by: LCOV version 2.0-1