LCOV - code coverage report
Current view: top level - nntrainer/graph - graph_core.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 92.9 % 84 78
Test Date: 2025-12-14 20:38:17 Functions: 92.9 % 14 13

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file    graph_core.cpp
       6              :  * @date    12 May 2020
       7              :  * @see     https://github.com/nnstreamer/nntrainer
       8              :  * @author  Jijoong Moon <jijoong.moon@samsung.com>
       9              :  * @author  Parichay Kapoor <pk.kapoor@samsung.com>
      10              :  * @bug     No known bugs except for NYI items
      11              :  * @brief   This is Graph Core Class for Neural Network
      12              :  *
      13              :  */
      14              : 
      15              : #include <algorithm>
      16              : #include <sstream>
      17              : 
      18              : #include <graph_core.h>
      19              : #include <nntrainer_error.h>
      20              : #include <nntrainer_log.h>
      21              : 
      22              : namespace nntrainer {
      23              : 
      24         8615 : void GraphCore::addGraphNode(std::shared_ptr<GraphNode> node) {
      25         8615 :   node_list.push_back(node);
      26        17230 :   node_map[node->getName()] = node_list.size() - 1;
      27         8615 : }
      28              : 
      29         4469 : const std::shared_ptr<GraphNode> &GraphCore::getNode(unsigned int ith) const {
      30         4469 :   return node_list.at(ith);
      31              : }
      32              : 
      33              : const std::shared_ptr<GraphNode> &
      34        20332 : GraphCore::getSortedNode(unsigned int ith) const {
      35        20332 :   return Sorted.at(ith);
      36              : }
      37              : 
      38            0 : const unsigned int GraphCore::getSortedNodeIdx(const std::string &name) const {
      39            0 :   return sorted_node_map.at(name);
      40              : }
      41              : 
      42          633 : void GraphCore::makeAdjacencyList(
      43              :   std::vector<std::list<std::shared_ptr<GraphNode>>> &adj) {
      44              :   /** initialize the adj list */
      45         5102 :   for (auto &node : node_list) {
      46        13407 :     adj.push_back(std::list<std::shared_ptr<GraphNode>>({node}));
      47              :   }
      48              : 
      49              :   /** make the connections */
      50         5102 :   for (auto &node : node_list) {
      51         9075 :     for (auto const &in_conn : node->getInputConnections()) {
      52         4606 :       unsigned int to_node_id = getNodeIdx(in_conn);
      53         4606 :       adj[to_node_id].push_back(node);
      54         4469 :     }
      55              :   }
      56          633 : }
      57              : 
      58         4469 : void GraphCore::topologicalSortUtil(
      59              :   std::vector<std::list<std::shared_ptr<GraphNode>>> &adj, unsigned int ith,
      60              :   std::vector<bool> &visited,
      61              :   std::stack<std::shared_ptr<GraphNode>> &dfs_stack) {
      62              :   visited[ith] = true;
      63              : 
      64              :   std::list<std::shared_ptr<GraphNode>>::iterator i;
      65        13544 :   for (i = adj[ith].begin(); i != adj[ith].end(); ++i) {
      66        18150 :     auto index = getNodeIdx((*i)->getName());
      67         9075 :     if (!visited[index])
      68         3485 :       topologicalSortUtil(adj, index, visited, dfs_stack);
      69              :   }
      70              : 
      71         4469 :   dfs_stack.push(getNode(ith));
      72         4469 : }
      73              : 
      74          633 : void GraphCore::topologicalSort() {
      75              :   std::vector<std::list<std::shared_ptr<GraphNode>>> adj;
      76              :   std::stack<std::shared_ptr<GraphNode>> dfs_stack;
      77          633 :   std::vector<bool> visited(node_list.size(), false);
      78              : 
      79          633 :   makeAdjacencyList(adj);
      80              :   Sorted.clear();
      81              : 
      82              :   // Quite likely this is not needed - verify this
      83              :   // TODO : After make node list of graph, we have to find root. (That means it
      84              :   // should be the only one input for now.). Need to support multiple input and
      85              :   // support search.
      86              : 
      87         5102 :   for (unsigned int i = 0; i < adj.size(); ++i) {
      88         4469 :     if (visited[i] == false) {
      89          984 :       topologicalSortUtil(adj, i, visited, dfs_stack);
      90              :     }
      91              :   }
      92              : 
      93         5102 :   while (dfs_stack.empty() == false) {
      94         4477 :     Sorted.push_back(dfs_stack.top());
      95              :     dfs_stack.pop();
      96              :   }
      97              : 
      98          633 :   if (Sorted.size() != node_list.size())
      99            0 :     throw std::runtime_error("Internal error in topologicalSort");
     100              :   unsigned int idx = 0;
     101         5102 :   for (auto &n : Sorted) {
     102         8938 :     sorted_node_map[n->getName()] = idx;
     103         4469 :     idx++;
     104              :   }
     105          633 : }
     106              : 
     107              : const std::shared_ptr<GraphNode> &
     108        12923 : GraphCore::getNode(const std::string &name) const {
     109        12920 :   return node_list.at(node_map.at(name));
     110              : }
     111              : 
     112         8615 : void GraphCore::addNode(std::shared_ptr<GraphNode> node, bool ensure_name) {
     113              :   /** Ensure that the node has a name and is unique */
     114         8615 :   if (ensure_name)
     115        16992 :     ensureName(*node);
     116              : 
     117              :   /** Insert the node to the graph */
     118         8615 :   addGraphNode(node);
     119         8615 : }
     120              : 
     121         8796 : void GraphCore::ensureName(GraphNode &node, const std::string &prefix_,
     122              :                            const std::string &postfix_, bool force_rename) {
     123        26388 :   auto to_lower = [](const std::string &str) -> std::string {
     124              :     std::string ret = str;
     125              :     std::transform(ret.begin(), ret.end(), ret.begin(),
     126       109887 :                    [](unsigned char c) { return std::tolower(c); });
     127        26388 :     return ret;
     128              :   };
     129              : 
     130         8796 :   std::string orig_name = to_lower(node.getName());
     131         8796 :   std::string prefix = to_lower(prefix_);
     132         8796 :   std::string postfix = to_lower(postfix_);
     133              : 
     134              :   bool orig_name_empty = orig_name.empty();
     135              :   /** If node already has name which is unique and valid, and force is
     136              :    * disabled, then nothing to do.
     137              :    */
     138         8796 :   if (!orig_name_empty && !force_rename && !verifyNode(orig_name)) {
     139         8465 :     node.setName(orig_name);
     140              :     node_names.emplace(orig_name);
     141              :     return;
     142              :   }
     143              : 
     144              :   /** If just prefix with node name makes it unique - directly set the name */
     145          331 :   if (!orig_name_empty) {
     146           12 :     std::string direct_name = prefix + orig_name + postfix;
     147              :     if (!verifyNode(direct_name)) {
     148            0 :       node.setName(direct_name);
     149              :       node_names.emplace(direct_name);
     150              :       return;
     151              :     }
     152              :   }
     153              : 
     154              :   std::unordered_set<std::string>::iterator iter;
     155              :   std::string name;
     156          331 :   if (orig_name_empty) {
     157          650 :     orig_name = node.getType();
     158              :   }
     159              : 
     160          662 :   std::string direct_name = prefix + orig_name + postfix;
     161              : 
     162              :   do {
     163          662 :     name = direct_name + std::to_string(def_name_count++);
     164              :     iter = node_names.find(name);
     165          331 :   } while (iter != node_names.end());
     166              : 
     167          331 :   node.setName(name);
     168              :   node_names.emplace(name);
     169              : }
     170              : 
     171          181 : void GraphCore::replaceNode(std::shared_ptr<GraphNode> from,
     172              :                             std::shared_ptr<GraphNode> to) {
     173          362 :   if (node_map.find(from->getName()) == node_map.end())
     174            0 :     throw std::invalid_argument("Graph node to be replaced is missing");
     175          362 :   if (node_map.find(to->getName()) != node_map.end())
     176            0 :     throw std::invalid_argument("Nodes in the graph must be unique");
     177              : 
     178          181 :   unsigned int from_idx = getNodeIdx(from->getName());
     179          181 :   node_list[from_idx] = to;
     180          181 :   node_map.erase(from->getName());
     181          362 :   node_map[to->getName()] = from_idx;
     182          181 : }
     183              : 
     184          642 : void GraphCore::realizeInputOutputNode() {
     185         5027 :   for (auto iter = cbegin(); iter != cend(); ++iter) {
     186         8770 :     if (iter->getInputConnections().size() == 0) {
     187         1986 :       input_list.push_back(*iter);
     188              :     }
     189         8770 :     if (iter->getOutputConnections().size() == 0) {
     190         1322 :       output_list.push_back(*iter);
     191              :     }
     192              :   }
     193          642 : }
     194              : 
     195        13862 : unsigned int GraphCore::getNodeIdx(const std::string &name) {
     196        13862 :   return node_map.at(name);
     197              : }
     198              : 
     199              : } /* namespace nntrainer */
        

Generated by: LCOV version 2.0-1