LCOV - code coverage report
Current view: top level - nntrainer/graph - graph_core.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 90.9 % 11 10
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 3 3

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file    graph_core.h
       6              :  * @date    12 May 2020
       7              :  * @see     https://github.com/nnstreamer/nntrainer
       8              :  * @author  Jijoong Moon <jijoong.moon@samsung.com>
       9              :  * @author  Parichay Kapoor <pk.kapoor@samsung.com>
      10              :  * @bug     No known bugs except for NYI items
      11              :  * @brief   This is Graph Core Class for Neural Network
      12              :  *
      13              :  */
      14              : 
      15              : #ifndef __GRAPH_CORE_H__
      16              : #define __GRAPH_CORE_H__
      17              : #ifdef __cplusplus
      18              : 
      19              : #include <list>
      20              : #include <map>
      21              : #include <memory>
      22              : #include <stack>
      23              : #include <unordered_map>
      24              : #include <unordered_set>
      25              : #include <vector>
      26              : 
      27              : #include <graph_node.h>
      28              : 
      29              : namespace nntrainer {
      30              : 
      31              : /**
      32              :  * @class   GraphCore
      33              :  * @brief   Graph Core Class which provides core graph functionalities
      34              :  */
      35              : class GraphCore {
      36              : 
      37              : public:
      38              :   /**
      39              :    * @brief     Constructor of Graph Core Class
      40              :    */
      41         1543 :   GraphCore() : node_names(), def_name_count(0) {}
      42              : 
      43              :   /**
      44              :    * @brief     Destructor of Graph Core Class
      45              :    *
      46              :    */
      47         7668 :   ~GraphCore() = default;
      48              : 
      49              :   /**
      50              :    * @brief Add the given node into Graph
      51              :    * @param[in] node shared_ptr of node
      52              :    */
      53              :   void addNode(std::shared_ptr<GraphNode> node, bool ensure_name = true);
      54              : 
      55              :   /**
      56              :    * @brief getter of number of nodes
      57              :    * @return number of nodes
      58              :    */
      59        28276 :   unsigned int size() const { return node_list.size(); }
      60              : 
      61              :   /**
      62              :    * @brief get if the graph is empty
      63              :    * @return true if empty, else false
      64              :    */
      65              :   bool empty() const { return node_list.empty(); }
      66              : 
      67              :   /**
      68              :    * @brief     Swap function for the class
      69              :    */
      70          540 :   friend void swap(GraphCore &lhs, GraphCore &rhs) {
      71              :     using std::swap;
      72              : 
      73              :     swap(lhs.node_list, rhs.node_list);
      74              :     swap(lhs.node_map, rhs.node_map);
      75              :     swap(lhs.Sorted, rhs.Sorted);
      76              :     swap(lhs.node_names, rhs.node_names);
      77              :     swap(lhs.def_name_count, rhs.def_name_count);
      78          540 :   }
      79              : 
      80              :   /**
      81              :    * @brief getter of GraphNode with index number
      82              :    * @param[in] ith index of the node
      83              :    * @return GraphNode
      84              :    */
      85              :   const std::shared_ptr<GraphNode> &getNode(unsigned int ith) const;
      86              : 
      87              :   /**
      88              :    * @brief getter of Sorted GraphNode with index number
      89              :    * @param[in] ith index into the sorted node list
      90              :    * @return GraphNode
      91              :    */
      92              :   const std::shared_ptr<GraphNode> &getSortedNode(unsigned int ith) const;
      93              : 
      94              :   /**
      95              :    * @brief getter of Sorted GraphNode index with name
      96              :    * @param[in] name layer name
      97              :    * @return index
      98              :    */
      99              :   const unsigned int getSortedNodeIdx(const std::string &name) const;
     100              : 
     101              :   /**
     102              :    * @brief getter of GraphNode with node name
     103              :    * @param[in] name node name
     104              :    * @retval GraphNode
     105              :    */
     106              :   const std::shared_ptr<GraphNode> &getNode(const std::string &name) const;
     107              : 
     108              :   /**
     109              :    * @brief     get begin iterator for the forwarding
     110              :    * @retval    const iterator marking the begin of forwarding
     111              :    * @note      this function should not be used when node_list is empty.
     112              :    * If node_list is empty, the end iterator of node_list is dereferenced,
     113              :    * which is undefined behavior.
     114              :    */
     115              :   template <
     116              :     typename T = GraphNode,
     117              :     std::enable_if_t<std::is_base_of<GraphNode, T>::value, T> * = nullptr>
     118              :   inline graph_const_iterator<T> cbegin() const {
     119        41435 :     if (Sorted.empty())
     120              :       return graph_const_iterator<T>(&(*node_list.cbegin()));
     121              :     else
     122              :       return graph_const_iterator<T>(&(*Sorted.cbegin()));
     123              :   }
     124              : 
     125              :   /**
     126              :    * @brief     get end iterator for the forwarding
     127              :    * @retval    const iterator marking the end of forwarding
     128              :    * @note      this function should not be used when node_list is empty.
     129              :    * If node_list is empty, the end iterator of node_list is dereferenced,
     130              :    * which is undefined behavior.
     131              :    */
     132              :   template <
     133              :     typename T = GraphNode,
     134              :     std::enable_if_t<std::is_base_of<GraphNode, T>::value, T> * = nullptr>
     135              :   inline graph_const_iterator<T> cend() const {
     136       121318 :     if (Sorted.empty())
     137              :       return graph_const_iterator<T>(&(*node_list.cbegin())) + node_list.size();
     138              :     else
     139              :       return graph_const_iterator<T>(&(*Sorted.cbegin())) + Sorted.size();
     140              :   }
     141              : 
     142              :   /**
     143              :    * @brief     get begin iterator for the backwarding
     144              :    * @retval    const reverse iterator marking the begin of backwarding
     145              :    */
     146              :   template <
     147              :     typename T = GraphNode,
     148              :     std::enable_if_t<std::is_base_of<GraphNode, T>::value, T> * = nullptr>
     149              :   inline graph_const_reverse_iterator<T> crbegin() const {
     150              :     return graph_const_reverse_iterator<T>(cend<T>());
     151              :   }
     152              : 
     153              :   /**
     154              :    * @brief     get end iterator for the backwarding
     155              :    * @retval    const reverse iterator marking the end of backwarding
     156              :    */
     157              :   template <
     158              :     typename T = GraphNode,
     159              :     std::enable_if_t<std::is_base_of<GraphNode, T>::value, T> * = nullptr>
     160              :   inline graph_const_reverse_iterator<T> crend() const {
     161              :     return graph_const_reverse_iterator<T>(cbegin<T>());
     162              :   }
     163              : 
     164              :   /**
     165              :    * @brief Sort the nodes topologically (Depth First Search) to define the calculation order
     166              :    */
     167              :   void topologicalSort();
     168              : 
     169              :   /**
     170              :    * @brief     Copy the graph (NOTE: currently only resizes node_list; node contents are not copied)
     171              :    * @param[in] from Graph Object to copy
     172              :    * @retval    Graph Object copied
     173              :    */
     174              :   GraphCore &copy(GraphCore &from) {
     175            0 :     node_list.resize(from.node_list.size());
     176              :     if (this != &from) {
     177              :       //      for (unsigned int i = 0; i < node_list.size(); ++i)
     178              :       //        node_list[i]->copy(from.node_list[i]);
     179              :     }
     180              :     return *this;
     181              :   }
     182              : 
     183              :   /**
     184              :    * @brief     Ensure that node has a name.
     185              :    * @param[in] node GraphNode whose name is to be ensured to be valid
     186              :    * @param[in] prefix Prefix to be attached to the node name
     187              :    * @param[in] postfix Postfix to be attached to the node name
     188              :    * @param[in] force_rename If the node must be forcefully rename
     189              :    * @details   Ensures that the node has a unique and a valid name. A valid
     190              :    * name pre-assigned to the node can be changed if force_rename is enabled.
     191              :    */
     192              :   void ensureName(GraphNode &node, const std::string &prefix = "",
     193              :                   const std::string &postfix = "", bool force_rename = false);
     194              : 
     195              :   /**
     196              :    * @brief   Replace graph node in node_list
     197              :    * @param   from Graph node to be replaced
     198              :    * @param   to Graph node to replace
     199              :    */
     200              :   void replaceNode(std::shared_ptr<GraphNode> from,
     201              :                    std::shared_ptr<GraphNode> to);
     202              : 
     203              :   /**
     204              :    * @brief   getter of graph input nodes with index number
     205              :    * @param   idx
     206              :    * @return  graph node of input node
     207              :    */
     208              :   const std::shared_ptr<GraphNode> &getInputNode(unsigned int idx) const {
     209              :     return input_list[idx];
     210              :   }
     211              : 
     212              :   /**
     213              :    * @brief   getter of number of input nodes
     214              :    * @return  number of input nodes
     215              :    */
     216              :   unsigned int getNumInputNodes() const { return input_list.size(); }
     217              : 
     218              :   /**
     219              :    * @brief   getter of graph output nodes with index number
     220              :    * @param   idx
     221              :    * @return  graph node of output node
     222              :    */
     223              :   const std::shared_ptr<GraphNode> &getOutputNode(unsigned int idx) const {
     224         7505 :     return output_list[idx];
     225              :   }
     226              : 
     227              :   /**
     228              :    * @brief   getter of number of output nodes
     229              :    * @return  number of output nodes
     230              :    */
     231        14958 :   unsigned int getNumOutputNodes() const { return output_list.size(); }
     232              : 
     233              :   /**
     234              :    * @brief       replace output node
     235              :    * @param idx   output node index to be replaced
     236              :    * @param node  graph node shared pointer to replace
     237              :    */
     238              :   void replaceOutputNode(unsigned int idx, std::shared_ptr<GraphNode> node) {
     239              :     output_list[idx] = node;
     240              :   }
     241              : 
     242              :   /**
     243              :    * @brief find which nodes are input or output nodes in the graph
     244              :    */
     245              :   void realizeInputOutputNode();
     246              : 
     247              :   /**
     248              :    * @brief     Verify if the node exists
     249              :    */
     250              :   inline bool verifyNode(const std::string &name) {
     251         8477 :     if (node_names.find(name) == node_names.end())
     252              :       return false;
     253              :     return true;
     254              :   }
     255              : 
     256              : private:
     257              :   std::vector<std::shared_ptr<GraphNode>> input_list;
     258              :   std::vector<std::shared_ptr<GraphNode>> output_list;
     259              :   std::vector<std::shared_ptr<GraphNode>> node_list; /**< Unordered Node List */
     260              :   std::unordered_map<std::string, int> node_map;     /**< Unordered Node map  */
     261              :   std::unordered_map<std::string, int>
     262              :     sorted_node_map;                              /**< Unordered Node map  */
     263              :   std::vector<std::shared_ptr<GraphNode>> Sorted; /**< Ordered Node List  */
     264              : 
     265              :   std::unordered_set<std::string>
     266              :     node_names;       /**< Set containing all the names of nodes in the model */
     267              :   int def_name_count; /**< Count assigned to node names declared by default */
     268              : 
     269              :   /**
     270              :    * @brief     topological sort
     271              :    * @param[in] ith index of GraphNode
     272              :    * @param[in] visited temp list
     273              :    * @param[in] Stack stack for Node list to visit.
     274              :    */
     275              :   void
     276              :   topologicalSortUtil(std::vector<std::list<std::shared_ptr<GraphNode>>> &adj,
     277              :                       unsigned int ith, std::vector<bool> &visited,
     278              :                       std::stack<std::shared_ptr<GraphNode>> &Stack);
     279              : 
     280              :   /**
     281              :    * @brief Add given GraphNode to the Graph
     282              :    * @param[in] node shared_ptr of GraphNode
     283              :    */
     284              :   void addGraphNode(std::shared_ptr<GraphNode> node);
     285              : 
     286              :   /**
     287              :    * @brief     make adjacency list for the current graph
     288              :    */
     289              :   void
     290              :   makeAdjacencyList(std::vector<std::list<std::shared_ptr<GraphNode>>> &adj);
     291              : 
     292              :   /**
     293              :    * @brief     Get index of the node with given name
     294              :    *
     295              :    * @param     name Name of the node
     296              :    * @return    internal index of the node
     297              :    */
     298              :   unsigned int getNodeIdx(const std::string &name);
     299              : };
     300              : 
     301              : } // namespace nntrainer
     302              : 
     303              : #endif /* __cplusplus */
     304              : #endif /* __GRAPH_CORE_H__ */
        

Generated by: LCOV version 2.0-1