LCOV - code coverage report
Current view: top level - nntrainer/compiler - slice_realizer.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 96.7 % 61 59
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 7 7

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file slice_realizer.cpp
       6              :  * @date 14 October 2021
       7              :  * @brief NNTrainer graph realizer which slices the graph representation
       8              :  * @see https://github.com/nnstreamer/nntrainer
       9              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      10              :  * @bug No known bugs except for NYI items
      11              :  */
      12              : 
      13              : #include <connection.h>
      14              : #include <iterator>
      15              : #include <layer_node.h>
      16              : #include <slice_realizer.h>
      17              : 
      18              : #include <unordered_map>
      19              : 
      20              : namespace nntrainer {
      21              : 
      22           57 : SliceRealizer::SliceRealizer(const std::vector<Connection> &start_layers,
      23           57 :                              const std::vector<Connection> &end_layers) {
      24              :   /// discard index information as it is not needed as it is not really needed
      25           57 :   this->start_layers.reserve(start_layers.size());
      26              : 
      27              :   std::transform(
      28              :     start_layers.begin(), start_layers.end(),
      29              :     std::back_inserter(this->start_layers),
      30              :     [](const Connection &c) -> const auto & { return c.getName(); });
      31              : 
      32           57 :   std::transform(
      33              :     end_layers.begin(), end_layers.end(),
      34           57 :     std::inserter(this->end_layers, this->end_layers.begin()),
      35              :     [](const Connection &c) -> const auto & { return c.getName(); });
      36           57 : }
      37              : 
      38          109 : SliceRealizer::~SliceRealizer() {}
      39              : 
      40              : GraphRepresentation
      41           57 : SliceRealizer::realize(const GraphRepresentation &reference) {
      42          564 :   struct NodeInfo {
      43            0 :     NodeInfo() : NodeInfo(nullptr) {}
      44          282 :     NodeInfo(std::shared_ptr<LayerNode> node) :
      45              :       node(node),
      46          282 :       is_visited(false),
      47          282 :       to_be_added(false) {}
      48              :     std::shared_ptr<LayerNode> node; /**< set this if not visited */
      49              :     bool is_visited;                 /**< set this if visited */
      50              :     bool to_be_added;                /**< set this if it is to be added */
      51              :     std::vector<std::string> children;
      52              : 
      53              :     LayerNode *operator->() { return node.get(); }
      54              :   };
      55              : 
      56              :   /** @note mp has to be ordered map to keep the ordering of the nodes in the
      57              :    * graph */
      58              :   std::unordered_map<std::string, NodeInfo> mp; /// map point
      59              : 
      60           57 :   std::transform(
      61              :     reference.begin(), reference.end(), std::inserter(mp, mp.end()),
      62          282 :     [](std::shared_ptr<LayerNode> node) {
      63          564 :       return std::pair<std::string, NodeInfo>(node->getName(), node);
      64              :     });
      65              : 
      66           57 :   auto cur_start_layers = start_layers;
      67              :   auto cur_end_layers = end_layers;
      68              : 
      69              :   /** setup children before filling in the end layers */
      70           57 :   std::for_each(reference.begin(), reference.end(),
      71          282 :                 [&mp](std::shared_ptr<LayerNode> node) {
      72          282 :                   auto node_name = node->getName();
      73              : 
      74          510 :                   for (auto i = 0u, num_node = node->getNumInputConnections();
      75          510 :                        i < num_node; ++i) {
      76          228 :                     const auto &parent = node->getInputConnectionName(i);
      77          228 :                     mp.at(parent).children.push_back(node_name);
      78              :                   };
      79          282 :                 });
      80              : 
      81           57 :   if (cur_start_layers.empty()) {
      82            3 :     for (auto &node : mp) {
      83            2 :       if (node.second.node->getNumInputConnections() == 0) {
      84            3 :         cur_start_layers.push_back(node.second.node->getName());
      85              :       }
      86              :     }
      87              :   }
      88              : 
      89           57 :   if (cur_end_layers.empty()) {
      90            3 :     for (auto &node : mp) {
      91            2 :       if (node.second.children.size() == 0) {
      92            0 :         cur_end_layers.insert(node.first);
      93              :       }
      94              :     }
      95              :   }
      96              : 
      97           57 :   if (cur_start_layers.empty()) {
      98            1 :     throw std::runtime_error("No start layer is found, graph has a loop.");
      99              :   }
     100              : 
     101           56 :   if (cur_end_layers.empty()) {
     102            1 :     throw std::runtime_error("No end layer is found, graph has a loop.");
     103              :   }
     104              : 
     105              :   std::vector<std::string> dfs_stack;
     106              : 
     107              :   /** if the give node is the end node in the graph */
     108              :   auto is_end_node = [&cur_end_layers](const std::string &name) {
     109           91 :     auto iter = cur_end_layers.find(name);
     110              :     return iter != cur_end_layers.end();
     111           55 :   };
     112              : 
     113              :   /** add node to be included to subgraph */
     114              :   auto update_processed = [&mp](const std::string &name) {
     115          193 :     auto &node_info = mp.at(name);
     116           38 :     node_info.to_be_added = true;
     117          193 :   };
     118              : 
     119              :   /** dfs function to perform depth-first search recursively with tracking */
     120              :   std::function<void(const std::string &name)> dfs =
     121           55 :     [&dfs, &mp, &dfs_stack, &is_end_node,
     122              :      &update_processed](const std::string &name) {
     123          228 :       auto &node_info = mp.at(name);
     124              :       /** if node already added or end node, add the current stack to be added
     125              :        * to the subgraph */
     126          228 :       if (node_info.to_be_added || is_end_node(name)) {
     127          193 :         std::for_each(dfs_stack.begin(), dfs_stack.end(), update_processed);
     128          193 :         update_processed(name);
     129              :       }
     130              : 
     131              :       /** if node is visited, return */
     132          228 :       if (node_info.is_visited) {
     133              :         return;
     134              :       }
     135              : 
     136           91 :       node_info.is_visited = true;
     137           91 :       dfs_stack.push_back(name);
     138              :       /** run dfs on all the children */
     139          128 :       for (auto const &child : node_info.children) {
     140           37 :         dfs(child);
     141              :       }
     142           91 :       dfs_stack.pop_back();
     143           55 :     };
     144              : 
     145              :   /** run dfs from all the starting layers */
     146          246 :   for (auto &name : cur_start_layers) {
     147              :     dfs(name);
     148              :   }
     149              : 
     150              :   /** created the subgraph */
     151              :   GraphRepresentation subgraph;
     152              :   /** @note: iterate over reference than over mp to ensure the correct ordering
     153              :    * of layers */
     154          333 :   for (auto &node : reference) {
     155          556 :     if (mp[node->getName()].to_be_added) {
     156           89 :       subgraph.push_back(node);
     157              :     }
     158              :   }
     159              : 
     160           56 :   NNTR_THROW_IF(subgraph.empty(), std::invalid_argument)
     161              :     << "After slice, there is no node left, please check if configuration is "
     162              :        "correct";
     163              : 
     164           54 :   return subgraph;
     165          113 : }
     166              : 
     167              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1