Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file graph_core.cpp
6 : * @date 12 May 2020
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief Implementation of the GraphCore class for Neural Network
12 : *
13 : */
14 :
15 : #include <algorithm>
16 : #include <sstream>
17 :
18 : #include <graph_core.h>
19 : #include <nntrainer_error.h>
20 : #include <nntrainer_log.h>
21 :
22 : namespace nntrainer {
23 :
24 8615 : void GraphCore::addGraphNode(std::shared_ptr<GraphNode> node) {
25 8615 : node_list.push_back(node);
26 17230 : node_map[node->getName()] = node_list.size() - 1;
27 8615 : }
28 :
29 4469 : const std::shared_ptr<GraphNode> &GraphCore::getNode(unsigned int ith) const {
30 4469 : return node_list.at(ith);
31 : }
32 :
33 : const std::shared_ptr<GraphNode> &
34 20332 : GraphCore::getSortedNode(unsigned int ith) const {
35 20332 : return Sorted.at(ith);
36 : }
37 :
38 0 : const unsigned int GraphCore::getSortedNodeIdx(const std::string &name) const {
39 0 : return sorted_node_map.at(name);
40 : }
41 :
42 633 : void GraphCore::makeAdjacencyList(
43 : std::vector<std::list<std::shared_ptr<GraphNode>>> &adj) {
44 : /** initialize the adj list */
45 5102 : for (auto &node : node_list) {
46 13407 : adj.push_back(std::list<std::shared_ptr<GraphNode>>({node}));
47 : }
48 :
49 : /** make the connections */
50 5102 : for (auto &node : node_list) {
51 9075 : for (auto const &in_conn : node->getInputConnections()) {
52 4606 : unsigned int to_node_id = getNodeIdx(in_conn);
53 4606 : adj[to_node_id].push_back(node);
54 4469 : }
55 : }
56 633 : }
57 :
58 4469 : void GraphCore::topologicalSortUtil(
59 : std::vector<std::list<std::shared_ptr<GraphNode>>> &adj, unsigned int ith,
60 : std::vector<bool> &visited,
61 : std::stack<std::shared_ptr<GraphNode>> &dfs_stack) {
62 : visited[ith] = true;
63 :
64 : std::list<std::shared_ptr<GraphNode>>::iterator i;
65 13544 : for (i = adj[ith].begin(); i != adj[ith].end(); ++i) {
66 18150 : auto index = getNodeIdx((*i)->getName());
67 9075 : if (!visited[index])
68 3485 : topologicalSortUtil(adj, index, visited, dfs_stack);
69 : }
70 :
71 4469 : dfs_stack.push(getNode(ith));
72 4469 : }
73 :
74 633 : void GraphCore::topologicalSort() {
75 : std::vector<std::list<std::shared_ptr<GraphNode>>> adj;
76 : std::stack<std::shared_ptr<GraphNode>> dfs_stack;
77 633 : std::vector<bool> visited(node_list.size(), false);
78 :
79 633 : makeAdjacencyList(adj);
80 : Sorted.clear();
81 :
82 : // Quite likely this is not needed - verify this
83 : // TODO : After make node list of graph, we have to find root. (That means it
84 : // should be the only one input for now.). Need to support multiple input and
85 : // support search.
86 :
87 5102 : for (unsigned int i = 0; i < adj.size(); ++i) {
88 4469 : if (visited[i] == false) {
89 984 : topologicalSortUtil(adj, i, visited, dfs_stack);
90 : }
91 : }
92 :
93 5102 : while (dfs_stack.empty() == false) {
94 4477 : Sorted.push_back(dfs_stack.top());
95 : dfs_stack.pop();
96 : }
97 :
98 633 : if (Sorted.size() != node_list.size())
99 0 : throw std::runtime_error("Internal error in topologicalSort");
100 : unsigned int idx = 0;
101 5102 : for (auto &n : Sorted) {
102 8938 : sorted_node_map[n->getName()] = idx;
103 4469 : idx++;
104 : }
105 633 : }
106 :
107 : const std::shared_ptr<GraphNode> &
108 12923 : GraphCore::getNode(const std::string &name) const {
109 12920 : return node_list.at(node_map.at(name));
110 : }
111 :
112 8615 : void GraphCore::addNode(std::shared_ptr<GraphNode> node, bool ensure_name) {
113 : /** Ensure that the node has a name and is unique */
114 8615 : if (ensure_name)
115 16992 : ensureName(*node);
116 :
117 : /** Insert the node to the graph */
118 8615 : addGraphNode(node);
119 8615 : }
120 :
121 8796 : void GraphCore::ensureName(GraphNode &node, const std::string &prefix_,
122 : const std::string &postfix_, bool force_rename) {
123 26388 : auto to_lower = [](const std::string &str) -> std::string {
124 : std::string ret = str;
125 : std::transform(ret.begin(), ret.end(), ret.begin(),
126 109887 : [](unsigned char c) { return std::tolower(c); });
127 26388 : return ret;
128 : };
129 :
130 8796 : std::string orig_name = to_lower(node.getName());
131 8796 : std::string prefix = to_lower(prefix_);
132 8796 : std::string postfix = to_lower(postfix_);
133 :
134 : bool orig_name_empty = orig_name.empty();
135 : /** If node already has name which is unique and valid, and force is
136 : * disabled, then nothing to do.
137 : */
138 8796 : if (!orig_name_empty && !force_rename && !verifyNode(orig_name)) {
139 8465 : node.setName(orig_name);
140 : node_names.emplace(orig_name);
141 : return;
142 : }
143 :
144 : /** If just prefix with node name makes it unique - directly set the name */
145 331 : if (!orig_name_empty) {
146 12 : std::string direct_name = prefix + orig_name + postfix;
147 : if (!verifyNode(direct_name)) {
148 0 : node.setName(direct_name);
149 : node_names.emplace(direct_name);
150 : return;
151 : }
152 : }
153 :
154 : std::unordered_set<std::string>::iterator iter;
155 : std::string name;
156 331 : if (orig_name_empty) {
157 650 : orig_name = node.getType();
158 : }
159 :
160 662 : std::string direct_name = prefix + orig_name + postfix;
161 :
162 : do {
163 662 : name = direct_name + std::to_string(def_name_count++);
164 : iter = node_names.find(name);
165 331 : } while (iter != node_names.end());
166 :
167 331 : node.setName(name);
168 : node_names.emplace(name);
169 : }
170 :
171 181 : void GraphCore::replaceNode(std::shared_ptr<GraphNode> from,
172 : std::shared_ptr<GraphNode> to) {
173 362 : if (node_map.find(from->getName()) == node_map.end())
174 0 : throw std::invalid_argument("Graph node to be replaced is missing");
175 362 : if (node_map.find(to->getName()) != node_map.end())
176 0 : throw std::invalid_argument("Nodes in the graph must be unique");
177 :
178 181 : unsigned int from_idx = getNodeIdx(from->getName());
179 181 : node_list[from_idx] = to;
180 181 : node_map.erase(from->getName());
181 362 : node_map[to->getName()] = from_idx;
182 181 : }
183 :
184 642 : void GraphCore::realizeInputOutputNode() {
185 5027 : for (auto iter = cbegin(); iter != cend(); ++iter) {
186 8770 : if (iter->getInputConnections().size() == 0) {
187 1986 : input_list.push_back(*iter);
188 : }
189 8770 : if (iter->getOutputConnections().size() == 0) {
190 1322 : output_list.push_back(*iter);
191 : }
192 : }
193 642 : }
194 :
195 13862 : unsigned int GraphCore::getNodeIdx(const std::string &name) {
196 13862 : return node_map.at(name);
197 : }
198 :
199 : } /* namespace nntrainer */
|