Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file slice_realizer.cpp
6 : * @date 14 October 2021
7 : * @brief NNTrainer graph realizer which slices the graph representation
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : */
12 :
#include <connection.h>
#include <layer_node.h>
#include <slice_realizer.h>

#include <algorithm>
#include <functional>
#include <iterator>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
19 :
20 : namespace nntrainer {
21 :
22 57 : SliceRealizer::SliceRealizer(const std::vector<Connection> &start_layers,
23 57 : const std::vector<Connection> &end_layers) {
24 : /// discard index information as it is not needed as it is not really needed
25 57 : this->start_layers.reserve(start_layers.size());
26 :
27 : std::transform(
28 : start_layers.begin(), start_layers.end(),
29 : std::back_inserter(this->start_layers),
30 : [](const Connection &c) -> const auto & { return c.getName(); });
31 :
32 57 : std::transform(
33 : end_layers.begin(), end_layers.end(),
34 57 : std::inserter(this->end_layers, this->end_layers.begin()),
35 : [](const Connection &c) -> const auto & { return c.getName(); });
36 57 : }
37 :
38 109 : SliceRealizer::~SliceRealizer() {}
39 :
40 : GraphRepresentation
41 57 : SliceRealizer::realize(const GraphRepresentation &reference) {
42 564 : struct NodeInfo {
43 0 : NodeInfo() : NodeInfo(nullptr) {}
44 282 : NodeInfo(std::shared_ptr<LayerNode> node) :
45 : node(node),
46 282 : is_visited(false),
47 282 : to_be_added(false) {}
48 : std::shared_ptr<LayerNode> node; /**< set this if not visited */
49 : bool is_visited; /**< set this if visited */
50 : bool to_be_added; /**< set this if it is to be added */
51 : std::vector<std::string> children;
52 :
53 : LayerNode *operator->() { return node.get(); }
54 : };
55 :
56 : /** @note mp has to be ordered map to keep the ordering of the nodes in the
57 : * graph */
58 : std::unordered_map<std::string, NodeInfo> mp; /// map point
59 :
60 57 : std::transform(
61 : reference.begin(), reference.end(), std::inserter(mp, mp.end()),
62 282 : [](std::shared_ptr<LayerNode> node) {
63 564 : return std::pair<std::string, NodeInfo>(node->getName(), node);
64 : });
65 :
66 57 : auto cur_start_layers = start_layers;
67 : auto cur_end_layers = end_layers;
68 :
69 : /** setup children before filling in the end layers */
70 57 : std::for_each(reference.begin(), reference.end(),
71 282 : [&mp](std::shared_ptr<LayerNode> node) {
72 282 : auto node_name = node->getName();
73 :
74 510 : for (auto i = 0u, num_node = node->getNumInputConnections();
75 510 : i < num_node; ++i) {
76 228 : const auto &parent = node->getInputConnectionName(i);
77 228 : mp.at(parent).children.push_back(node_name);
78 : };
79 282 : });
80 :
81 57 : if (cur_start_layers.empty()) {
82 3 : for (auto &node : mp) {
83 2 : if (node.second.node->getNumInputConnections() == 0) {
84 3 : cur_start_layers.push_back(node.second.node->getName());
85 : }
86 : }
87 : }
88 :
89 57 : if (cur_end_layers.empty()) {
90 3 : for (auto &node : mp) {
91 2 : if (node.second.children.size() == 0) {
92 0 : cur_end_layers.insert(node.first);
93 : }
94 : }
95 : }
96 :
97 57 : if (cur_start_layers.empty()) {
98 1 : throw std::runtime_error("No start layer is found, graph has a loop.");
99 : }
100 :
101 56 : if (cur_end_layers.empty()) {
102 1 : throw std::runtime_error("No end layer is found, graph has a loop.");
103 : }
104 :
105 : std::vector<std::string> dfs_stack;
106 :
107 : /** if the give node is the end node in the graph */
108 : auto is_end_node = [&cur_end_layers](const std::string &name) {
109 91 : auto iter = cur_end_layers.find(name);
110 : return iter != cur_end_layers.end();
111 55 : };
112 :
113 : /** add node to be included to subgraph */
114 : auto update_processed = [&mp](const std::string &name) {
115 193 : auto &node_info = mp.at(name);
116 38 : node_info.to_be_added = true;
117 193 : };
118 :
119 : /** dfs function to perform depth-first search recursively with tracking */
120 : std::function<void(const std::string &name)> dfs =
121 55 : [&dfs, &mp, &dfs_stack, &is_end_node,
122 : &update_processed](const std::string &name) {
123 228 : auto &node_info = mp.at(name);
124 : /** if node already added or end node, add the current stack to be added
125 : * to the subgraph */
126 228 : if (node_info.to_be_added || is_end_node(name)) {
127 193 : std::for_each(dfs_stack.begin(), dfs_stack.end(), update_processed);
128 193 : update_processed(name);
129 : }
130 :
131 : /** if node is visited, return */
132 228 : if (node_info.is_visited) {
133 : return;
134 : }
135 :
136 91 : node_info.is_visited = true;
137 91 : dfs_stack.push_back(name);
138 : /** run dfs on all the children */
139 128 : for (auto const &child : node_info.children) {
140 37 : dfs(child);
141 : }
142 91 : dfs_stack.pop_back();
143 55 : };
144 :
145 : /** run dfs from all the starting layers */
146 246 : for (auto &name : cur_start_layers) {
147 : dfs(name);
148 : }
149 :
150 : /** created the subgraph */
151 : GraphRepresentation subgraph;
152 : /** @note: iterate over reference than over mp to ensure the correct ordering
153 : * of layers */
154 333 : for (auto &node : reference) {
155 556 : if (mp[node->getName()].to_be_added) {
156 89 : subgraph.push_back(node);
157 : }
158 : }
159 :
160 56 : NNTR_THROW_IF(subgraph.empty(), std::invalid_argument)
161 : << "After slice, there is no node left, please check if configuration is "
162 : "correct";
163 :
164 54 : return subgraph;
165 113 : }
166 :
167 : } // namespace nntrainer
|