Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file graph_node.h
6 : * @date 1 April 2021
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief This is the graph node interface for c++ API
11 : */
12 :
13 : #ifndef __GRAPH_NODE_H__
14 : #define __GRAPH_NODE_H__
15 :
#include <cstddef>
#include <iterator>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
20 :
21 : namespace nntrainer {
22 :
/**
 * @class GraphNode
 * @brief Abstract base class for every node stored in a graph
 * @details Declares what the graph needs to know about a node: its name and
 * type, whether it is trainable, its input/output connections and the order
 * in which its operations are executed.
 */
class GraphNode {
public:
  /**
   * @brief Provides the time/order at which the node will be executed.
   * @details This time will be finalized once the graph has been calculated.
   * Each of the four elements holds the order at which the corresponding
   * operation of this node is executed:
   * 1. Forwarding
   * 2. calcGradient
   * 3. calcDerivative
   * 4. ApplyGradient
   * One constraint is that the elements must be sorted in ascending order.
   * This ensures that the operations are executed in the order listed above.
   */
  using ExecutionOrder =
    std::tuple<unsigned int, unsigned int, unsigned int, unsigned int>;

  /**
   * @brief Destructor of GraphNode
   */
  virtual ~GraphNode() = default;

  /**
   * @brief Get the Name of the underlying object
   *
   * @return std::string Name of the underlying object
   * @note name of each node in the graph must be unique
   */
  virtual const std::string getName() const = 0;

  /**
   * @brief Set the Name of the underlying object
   *
   * @param[in] name Name for the underlying object
   * @note name of each node in the graph must be unique, and the caller must
   * ensure that the given name is not already used by another node
   */
  virtual void setName(const std::string &name) = 0;

  /**
   * @brief Get the Type of the underlying object
   *
   * @return const std::string type representation
   */
  virtual const std::string getType() const = 0;

  /**
   * @brief Get the trainable parameter
   *
   * @return bool true if the node is trainable, false otherwise
   */
  virtual bool getTrainable() const = 0;

  /**
   * @brief Get the input connections for this node
   *
   * @return list of names of the nodes which form input connections
   */
  virtual const std::vector<std::string> getInputConnections() const = 0;

  /**
   * @brief Get the output connections for this node
   *
   * @return list of names of the nodes which form output connections
   */
  virtual const std::vector<std::string> getOutputConnections() const = 0;

  /**
   * @brief get the execution order/location of this node
   *
   * @retval the execution order/location of this node
   * @details The four values are the orders for forwarding, calcGradient,
   * calcDerivative and ApplyGradient respectively (see ExecutionOrder)
   */
  virtual ExecutionOrder getExecutionOrder() const = 0;

  /**
   * @brief set the execution order/location of this node
   *
   * @param exec_order_ the execution order/location of this node
   * @details The four values are the orders for forwarding, calcGradient,
   * calcDerivative and ApplyGradient respectively (see ExecutionOrder)
   */
  virtual void setExecutionOrder(ExecutionOrder exec_order_) = 0;
};
113 :
/**
 * @brief Iterator for GraphNode which return const
 * std::shared_ptr<LayerNodeType> object upon realize
 *
 * @details Every dereference static_pointer_casts the stored GraphNode
 * pointer to LayerNodeType. No runtime type check is performed, so the caller
 * must guarantee that every node in the range really is a LayerNodeType.
 *
 * @note GraphNodeType is to enable for both GraphNode and const GraphNode
 */
template <typename LayerNodeType, typename GraphNodeType>
class GraphNodeIterator {
  GraphNodeType *p; /** underlying object of GraphNode */

public:
  /**
   * @brief iterator_traits types definition
   *
   * @note these are not required to be explicitly defined now, but maintains
   * forward compatibility for c++17 and later
   *
   * @note value_type, pointer and reference are different from standard
   * iterator: dereferencing yields a (casted) shared_ptr BY VALUE, so
   * `reference` is a value type too. Declaring it as a real reference would
   * make std::reverse_iterator::operator* / operator[] (which return
   * `reference` in C++17) bind a reference to the temporary returned by
   * operator*, leaving it dangling.
   */
  typedef const std::shared_ptr<LayerNodeType> value_type;
  typedef std::random_access_iterator_tag iterator_category;
  typedef std::ptrdiff_t difference_type;
  typedef const std::shared_ptr<LayerNodeType> *pointer;
  typedef const std::shared_ptr<LayerNodeType> reference;

  /**
   * @brief Construct a singular iterator that does not point to any node
   * @note Must not be dereferenced or moved; only assignment is valid.
   */
  GraphNodeIterator() : p(nullptr) {}

  /**
   * @brief Construct a new Graph Node Iterator object
   *
   * @param x underlying object of GraphNode
   */
  GraphNodeIterator(GraphNodeType *x) : p(x) {}

  /**
   * @brief dereference operator
   *
   * @return value_type the current node casted to LayerNodeType
   * @note this is different from standard iterator (returns by value)
   */
  value_type operator*() const {
    return std::static_pointer_cast<LayerNodeType>(*p);
  }

  /**
   * @brief member access operator
   *
   * @return value_type the current node casted to LayerNodeType; its own
   * operator-> is then applied, so `it->member` works as expected
   * @note this is different from standard iterator
   */
  value_type operator->() const {
    return std::static_pointer_cast<LayerNodeType>(*p);
  }

  /**
   * @brief subscript operator
   *
   * @param offset offset from the current position
   * @return value_type the node at the given offset casted to LayerNodeType
   */
  value_type operator[](const difference_type offset) const {
    return *(*this + offset);
  }

  /**
   * @brief == comparison operator override
   *
   * @param lhs iterator lhs
   * @param rhs iterator rhs
   * @retval true if match
   * @retval false if mismatch
   */
  friend bool operator==(GraphNodeIterator const &lhs,
                         GraphNodeIterator const &rhs) {
    return lhs.p == rhs.p;
  }

  /**
   * @brief != comparison operator override
   *
   * @param lhs iterator lhs
   * @param rhs iterator rhs
   * @retval true if mismatch
   * @retval false if match
   */
  friend bool operator!=(GraphNodeIterator const &lhs,
                         GraphNodeIterator const &rhs) {
    return lhs.p != rhs.p;
  }

  /**
   * @brief < comparison operator override
   *
   * @param lhs iterator lhs
   * @param rhs iterator rhs
   * @retval true if left is strictly before the right value
   */
  friend bool operator<(GraphNodeIterator const &lhs,
                        GraphNodeIterator const &rhs) {
    return lhs.p < rhs.p;
  }

  /**
   * @brief <= comparison operator override
   *
   * @param lhs iterator lhs
   * @param rhs iterator rhs
   * @retval true if left is less than or equal to the right value
   * @retval false if left is greater than the right value
   */
  friend bool operator<=(GraphNodeIterator const &lhs,
                         GraphNodeIterator const &rhs) {
    return lhs.p <= rhs.p;
  }

  /**
   * @brief > comparison operator override
   *
   * @param lhs iterator lhs
   * @param rhs iterator rhs
   * @retval true if left is strictly after the right value
   */
  friend bool operator>(GraphNodeIterator const &lhs,
                        GraphNodeIterator const &rhs) {
    return lhs.p > rhs.p;
  }

  /**
   * @brief >= comparison operator override
   *
   * @param lhs iterator lhs
   * @param rhs iterator rhs
   * @retval true if left is greater than or equal to the right value
   */
  friend bool operator>=(GraphNodeIterator const &lhs,
                         GraphNodeIterator const &rhs) {
    return lhs.p >= rhs.p;
  }

  /**
   * @brief override for ++ operator
   *
   * @return GraphNodeIterator&
   */
  GraphNodeIterator &operator++() {
    p += 1;
    return *this;
  }

  /**
   * @brief override for operator++
   *
   * @return GraphNodeIterator the position before the increment
   */
  GraphNodeIterator operator++(int) {
    GraphNodeIterator temp(p);
    p += 1;
    return temp;
  }

  /**
   * @brief override for -- operator
   *
   * @return GraphNodeIterator&
   */
  GraphNodeIterator &operator--() {
    p -= 1;
    return *this;
  }

  /**
   * @brief override for operator--
   *
   * @return GraphNodeIterator the position before the decrement
   */
  GraphNodeIterator operator--(int) {
    GraphNodeIterator temp(p);
    p -= 1;
    return temp;
  }

  /**
   * @brief override for subtract operator
   *
   * @param offset offset to subtract
   * @return GraphNodeIterator
   */
  GraphNodeIterator operator-(const difference_type offset) const {
    return GraphNodeIterator(p - offset);
  }

  /**
   * @brief override for subtract operator
   *
   * @param other iterator to subtract
   * @return difference_type distance from other to this
   */
  difference_type operator-(const GraphNodeIterator &other) const {
    return p - other.p;
  }

  /**
   * @brief override for subtract and return result operator
   *
   * @param offset offset to subtract
   * @return GraphNodeIterator&
   */
  GraphNodeIterator &operator-=(const difference_type offset) {
    p -= offset;
    return *this;
  }

  /**
   * @brief override for add operator
   *
   * @param offset offset to add
   * @return GraphNodeIterator
   */
  GraphNodeIterator operator+(const difference_type offset) const {
    return GraphNodeIterator(p + offset);
  }

  /**
   * @brief override for add operator with the offset on the left (n + it)
   *
   * @param offset offset to add
   * @param it iterator to advance
   * @return GraphNodeIterator
   */
  friend GraphNodeIterator operator+(const difference_type offset,
                                     const GraphNodeIterator &it) {
    return it + offset;
  }

  /**
   * @brief override for add and return result operator
   *
   * @param offset offset to add
   * @return GraphNodeIterator&
   */
  GraphNodeIterator &operator+=(const difference_type offset) {
    p += offset;
    return *this;
  }
};
303 :
/**
 * @brief Reverse Iterator for GraphNode which return LayerNode object upon
 * realize
 *
 * @details std::reverse_iterator<T_iterator> is reused for the position
 * bookkeeping (`current`), but dereference is re-implemented to return
 * T_iterator::value_type by value. The arithmetic operators are re-declared
 * so that they return GraphNodeReverseIterator instead of the
 * std::reverse_iterator base. Otherwise expressions such as `*(it + 1)`,
 * `*++it` or `it[n]` would fall back to std::reverse_iterator::operator*,
 * which returns iterator_traits<T_iterator>::reference and can bind a
 * reference to a temporary.
 *
 * @note This just extends GraphNodeIterator and is limited by its
 * functionality.
 */
template <typename T_iterator>
class GraphNodeReverseIterator : public std::reverse_iterator<T_iterator> {
  using base_type = std::reverse_iterator<T_iterator>;

public:
  using difference_type = typename base_type::difference_type;

  /**
   * @brief Construct a new Graph Node Reverse Iterator object
   *
   * @param iter Iterator one past the element the reverse iterator refers to
   */
  explicit GraphNodeReverseIterator(T_iterator iter) : base_type(iter) {}

  /**
   * @brief dereference operator
   *
   * @return T_iterator::value_type the element before `current`
   * @note this is different from standard iterator (returns by value)
   */
  typename T_iterator::value_type operator*() const {
    auto temp = this->current - 1;
    return *temp;
  }

  /**
   * @brief member access operator
   *
   * @return T_iterator::value_type the element before `current`
   * @note this is different from standard iterator
   */
  typename T_iterator::value_type operator->() const {
    auto temp = this->current - 1;
    return *temp;
  }

  /**
   * @brief subscript operator
   *
   * @param offset offset from the current position (in reverse direction)
   * @return T_iterator::value_type the element at the given offset, by value
   */
  typename T_iterator::value_type operator[](difference_type offset) const {
    return *(*this + offset);
  }

  /**
   * @brief pre-increment: move one element towards the beginning
   *
   * @return GraphNodeReverseIterator&
   */
  GraphNodeReverseIterator &operator++() {
    --this->current;
    return *this;
  }

  /**
   * @brief post-increment
   *
   * @return GraphNodeReverseIterator the position before the increment
   */
  GraphNodeReverseIterator operator++(int) {
    GraphNodeReverseIterator temp(*this);
    --this->current;
    return temp;
  }

  /**
   * @brief pre-decrement: move one element towards the end
   *
   * @return GraphNodeReverseIterator&
   */
  GraphNodeReverseIterator &operator--() {
    ++this->current;
    return *this;
  }

  /**
   * @brief post-decrement
   *
   * @return GraphNodeReverseIterator the position before the decrement
   */
  GraphNodeReverseIterator operator--(int) {
    GraphNodeReverseIterator temp(*this);
    ++this->current;
    return temp;
  }

  /**
   * @brief advance by offset (in reverse direction)
   *
   * @param offset offset to add
   * @return GraphNodeReverseIterator
   */
  GraphNodeReverseIterator operator+(difference_type offset) const {
    return GraphNodeReverseIterator(this->current - offset);
  }

  /**
   * @brief step back by offset (in reverse direction)
   *
   * @param offset offset to subtract
   * @return GraphNodeReverseIterator
   */
  GraphNodeReverseIterator operator-(difference_type offset) const {
    return GraphNodeReverseIterator(this->current + offset);
  }

  /**
   * @brief advance in place by offset (in reverse direction)
   *
   * @param offset offset to add
   * @return GraphNodeReverseIterator&
   */
  GraphNodeReverseIterator &operator+=(difference_type offset) {
    this->current -= offset;
    return *this;
  }

  /**
   * @brief step back in place by offset (in reverse direction)
   *
   * @param offset offset to subtract
   * @return GraphNodeReverseIterator&
   */
  GraphNodeReverseIterator &operator-=(difference_type offset) {
    this->current += offset;
    return *this;
  }
};
344 :
345 : /**
346 : * @brief Iterators to traverse the graph
347 : */
348 : template <class LayerNodeType>
349 : using graph_const_iterator =
350 : GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>;
351 :
352 : /**
353 : * @brief Iterators to traverse the graph
354 : */
355 : template <class LayerNodeType>
356 : using graph_const_reverse_iterator = GraphNodeReverseIterator<
357 : GraphNodeIterator<LayerNodeType, const std::shared_ptr<GraphNode>>>;
358 :
359 : } // namespace nntrainer
360 : #endif // __GRAPH_NODE_H__
|