Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file graph_core.h
6 : * @date 12 May 2020
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is Graph Core Class for Neural Network
12 : *
13 : */
14 :
15 : #ifndef __GRAPH_CORE_H__
16 : #define __GRAPH_CORE_H__
17 : #ifdef __cplusplus
18 :
19 : #include <list>
20 : #include <map>
21 : #include <memory>
22 : #include <stack>
23 : #include <unordered_map>
24 : #include <unordered_set>
25 : #include <vector>
26 :
27 : #include <graph_node.h>
28 :
29 : namespace nntrainer {
30 :
31 : /**
32 : * @class Graph Core Class
33 : * @brief Graph Core Class which provides core graph functionalities
34 : */
35 : class GraphCore {
36 :
37 : public:
38 : /**
39 : * @brief Constructor of Graph Core Class
40 : */
41 1543 : GraphCore() : node_names(), def_name_count(0) {}
42 :
43 : /**
44 : * @brief Destructor of Graph Core Class
45 : *
46 : */
47 7668 : ~GraphCore() = default;
48 :
49 : /**
50 : * @brief Add the given node into Graph
51 : * @param[in] node shared_ptr of node
52 : */
53 : void addNode(std::shared_ptr<GraphNode> node, bool ensure_name = true);
54 :
55 : /**
56 : * @brief getter of number of nodes
57 : * @param[out] number of nodes
58 : */
59 28276 : unsigned int size() const { return node_list.size(); }
60 :
61 : /**
62 : * @brief get if the graph is empty
63 : * @param[out] true if empty, else false
64 : */
65 : bool empty() const { return node_list.empty(); }
66 :
67 : /**
68 : * @brief Swap function for the class
69 : */
70 540 : friend void swap(GraphCore &lhs, GraphCore &rhs) {
71 : using std::swap;
72 :
73 : swap(lhs.node_list, rhs.node_list);
74 : swap(lhs.node_map, rhs.node_map);
75 : swap(lhs.Sorted, rhs.Sorted);
76 : swap(lhs.node_names, rhs.node_names);
77 : swap(lhs.def_name_count, rhs.def_name_count);
78 540 : }
79 :
80 : /**
81 : * @brief getter of GraphNode with index number
82 : * @param[in] index
83 : * @ret GraphNode
84 : */
85 : const std::shared_ptr<GraphNode> &getNode(unsigned int ith) const;
86 :
87 : /**
88 : * @brief getter of Sorted GraphNode with index number
89 : * @param[in] index
90 : * @ret GraphNode
91 : */
92 : const std::shared_ptr<GraphNode> &getSortedNode(unsigned int ith) const;
93 :
94 : /**
95 : * @brief getter of Sorted GraphNode index with name
96 : * @param[in] layer name
97 : * @ret index
98 : */
99 : const unsigned int getSortedNodeIdx(const std::string &name) const;
100 :
101 : /**
102 : * @brief getter of GraphNode with node name
103 : * @param[in] node name
104 : * @retval GraphNode
105 : */
106 : const std::shared_ptr<GraphNode> &getNode(const std::string &name) const;
107 :
108 : /**
109 : * @brief get begin iterator for the forwarding
110 : * @retval const iterator marking the begin of forwarding
111 : * @note this function should not be used when node_list is empty.
112 : * if node_list is empty, end iterator is dereferenced,
113 : * This action is undefined behavior.
114 : */
115 : template <
116 : typename T = GraphNode,
117 : std::enable_if_t<std::is_base_of<GraphNode, T>::value, T> * = nullptr>
118 : inline graph_const_iterator<T> cbegin() const {
119 41435 : if (Sorted.empty())
120 : return graph_const_iterator<T>(&(*node_list.cbegin()));
121 : else
122 : return graph_const_iterator<T>(&(*Sorted.cbegin()));
123 : }
124 :
125 : /**
126 : * @brief get end iterator for the forwarding
127 : * @retval const iterator marking the end of forwarding
128 : * @note this function should not be used when node_list is empty.
129 : * if node_list is empty, end iterator is dereferenced,
130 : * This action is undefined behavior.
131 : */
132 : template <
133 : typename T = GraphNode,
134 : std::enable_if_t<std::is_base_of<GraphNode, T>::value, T> * = nullptr>
135 : inline graph_const_iterator<T> cend() const {
136 121318 : if (Sorted.empty())
137 : return graph_const_iterator<T>(&(*node_list.cbegin())) + node_list.size();
138 : else
139 : return graph_const_iterator<T>(&(*Sorted.cbegin())) + Sorted.size();
140 : }
141 :
142 : /**
143 : * @brief get begin iterator for the backwarding
144 : * @retval const reverse iterator marking the begin of backwarding
145 : */
146 : template <
147 : typename T = GraphNode,
148 : std::enable_if_t<std::is_base_of<GraphNode, T>::value, T> * = nullptr>
149 : inline graph_const_reverse_iterator<T> crbegin() const {
150 : return graph_const_reverse_iterator<T>(cend<T>());
151 : }
152 :
153 : /**
154 : * @brief get end iterator for the backwarding
155 : * @retval const reverse iterator marking the end of backwarding
156 : */
157 : template <
158 : typename T = GraphNode,
159 : std::enable_if_t<std::is_base_of<GraphNode, T>::value, T> * = nullptr>
160 : inline graph_const_reverse_iterator<T> crend() const {
161 : return graph_const_reverse_iterator<T>(cbegin<T>());
162 : }
163 :
164 : /**
165 : * @brief Sorting and Define order to calculate : Depth First Search
166 : */
167 : void topologicalSort();
168 :
169 : /**
170 : * @brief Copy the graph
171 : * @param[in] from Graph Object to copy
172 : * @retval Graph Object copyed
173 : */
174 : GraphCore ©(GraphCore &from) {
175 0 : node_list.resize(from.node_list.size());
176 : if (this != &from) {
177 : // for (unsigned int i = 0; i < node_list.size(); ++i)
178 : // node_list[i]->copy(from.node_list[i]);
179 : }
180 : return *this;
181 : }
182 :
183 : /**
184 : * @brief Ensure that node has a name.
185 : * @param[in] node GraphNode whose name is to be ensured to be valid
186 : * @param[in] prefix Prefix to be attached to the node name
187 : * @param[in] postfix Postfix to be attached to the node name
188 : * @param[in] force_rename If the node must be forcefully rename
189 : * @details Ensures that the node has a unique and a valid name. A valid
190 : * name pre-assigned to the node can be changed if force_rename is enabled.
191 : */
192 : void ensureName(GraphNode &node, const std::string &prefix = "",
193 : const std::string &postfix = "", bool force_rename = false);
194 :
195 : /**
196 : * @brief Replace graph node in node_list
197 : * @param from Graph node to be replaced
198 : * @param to Graph node to replace
199 : */
200 : void replaceNode(std::shared_ptr<GraphNode> from,
201 : std::shared_ptr<GraphNode> to);
202 :
203 : /**
204 : * @brief getter of graph input nodes with index number
205 : * @param idx
206 : * @return graph node of input node
207 : */
208 : const std::shared_ptr<GraphNode> &getInputNode(unsigned int idx) const {
209 : return input_list[idx];
210 : }
211 :
212 : /**
213 : * @brief getter of number of input nodes
214 : * @return number of input nodes
215 : */
216 : unsigned int getNumInputNodes() const { return input_list.size(); }
217 :
218 : /**
219 : * @brief getter of graph output nodes with index number
220 : * @param idx
221 : * @return graph node of output node
222 : */
223 : const std::shared_ptr<GraphNode> &getOutputNode(unsigned int idx) const {
224 7505 : return output_list[idx];
225 : }
226 :
227 : /**
228 : * @brief getter of number of output nodes
229 : * @return number of output nodes
230 : */
231 14958 : unsigned int getNumOutputNodes() const { return output_list.size(); }
232 :
233 : /**
234 : * @brief replace output node
235 : * @param idx output node index to be replaced
236 : * @param node graph node shared pointer to replace
237 : */
238 : void replaceOutputNode(unsigned int idx, std::shared_ptr<GraphNode> node) {
239 : output_list[idx] = node;
240 : }
241 :
242 : /**
243 : * @brief find which node is a input or output node in graph
244 : */
245 : void realizeInputOutputNode();
246 :
247 : /**
248 : * @brief Verify if the node exists
249 : */
250 : inline bool verifyNode(const std::string &name) {
251 8477 : if (node_names.find(name) == node_names.end())
252 : return false;
253 : return true;
254 : }
255 :
256 : private:
257 : std::vector<std::shared_ptr<GraphNode>> input_list;
258 : std::vector<std::shared_ptr<GraphNode>> output_list;
259 : std::vector<std::shared_ptr<GraphNode>> node_list; /**< Unordered Node List */
260 : std::unordered_map<std::string, int> node_map; /**< Unordered Node map */
261 : std::unordered_map<std::string, int>
262 : sorted_node_map; /**< Unordered Node map */
263 : std::vector<std::shared_ptr<GraphNode>> Sorted; /**< Ordered Node List */
264 :
265 : std::unordered_set<std::string>
266 : node_names; /**< Set containing all the names of nodes in the model */
267 : int def_name_count; /**< Count assigned to node names declared by default */
268 :
269 : /**
270 : * @brief topological sort
271 : * @param[in] ith index of GraphNode
272 : * @param[in] visited temp list
273 : * @param[in] stack for Node list to visit.
274 : */
275 : void
276 : topologicalSortUtil(std::vector<std::list<std::shared_ptr<GraphNode>>> &adj,
277 : unsigned int ith, std::vector<bool> &visited,
278 : std::stack<std::shared_ptr<GraphNode>> &Stack);
279 :
280 : /**
281 : * @brief Add given GraphNode to the Graph
282 : * @param[in] node shared_ptr of GraphNode
283 : */
284 : void addGraphNode(std::shared_ptr<GraphNode> node);
285 :
286 : /**
287 : * @brief make adjancency list for the current graph
288 : */
289 : void
290 : makeAdjacencyList(std::vector<std::list<std::shared_ptr<GraphNode>>> &adj);
291 :
292 : /**
293 : * @brief Get index of the node with given name
294 : *
295 : * @param name Name of the node
296 : * @return internal index of the node
297 : */
298 : unsigned int getNodeIdx(const std::string &name);
299 : };
300 :
301 : } // namespace nntrainer
302 :
303 : #endif /* __cplusplus */
304 : #endif /* __GRAPH_CORE_H__ */
|