Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file tflite_opnode.h
6 : * @date 28 April 2021
7 : * @brief contains tflite opnode which has information to convert to tflite file
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @author Donghak Park <donghak.park@samsung.com>
11 : * @bug No known bugs except for NYI items
12 : */
13 :
14 : #ifndef __TFLITE_OPNODE_H__
15 : #define __TFLITE_OPNODE_H__
16 :
17 : #include <functional>
18 : #include <utility>
19 : #include <vector>
20 :
21 : #include <tensor.h>
22 : #include <tf_schema_generated.h>
23 :
24 : namespace nntrainer {
25 :
26 : class LayerNode;
27 : class RunLayerContext;
28 : /**
29 : * @brief tensorflow operational node representation. This class contains,
30 : * information to build operation flatbuffer
31 : *
32 : */
33 : class TfOpNode {
34 : public:
35 : using Variables = std::vector<const Tensor *>;
36 :
37 : using TransformFn =
38 : std::function<std::vector<Tensor>(std::vector<const Tensor *> &)>;
39 :
40 : /**
41 : * @brief Construct a new Tf object
42 : *
43 : */
44 : TfOpNode();
45 :
46 : /**
47 : * @brief finalize tf op node will be transformed to required variables
48 : * in this phase, weights are merged into inputs
49 : *
50 : */
51 : void finalize();
52 :
53 : /**
54 : * @brief Set common informations from layer node
55 : *
56 : * @param layer node layer node
57 : */
58 : void setLayerNode(const LayerNode &layer);
59 :
60 : /**
61 : * @brief Set the Weight Transform Fn object
62 : *
63 : * @param fn fn will be called before get
64 : */
65 : void setWeightTransformFn(TransformFn fn);
66 :
67 : /**
68 : * @brief Set the Input Transform Fn object
69 : *
70 : * @param fn fn will be called before get
71 : */
72 : void setInputTransformFn(TransformFn fn);
73 :
74 : /**
75 : * @brief Set the Op Type object
76 : *
77 : * @param op_type_ operation type
78 : */
79 22 : void setOpType(tflite::BuiltinOperator op_type_) { op_type = op_type_; }
80 :
81 : /**
82 : * @brief Set the Builtin Options object,
83 : * @note this can go private, export from a layer and fill this out
84 : *
85 : * @param builtin_option_type_ builtin option type
86 : * @param builtin_ops_ flatbuffer offset of builtin_ops
87 : */
88 : void setBuiltinOptions(tflite::BuiltinOptions builtin_option_type_,
89 : const flatbuffers::Offset<void> &builtin_ops_);
90 :
91 : /**
92 : * @brief Set the Need Reorder Weight object
93 : *
94 : */
95 4 : void setNeedReorderWeight() { need_reorder_weight = true; }
96 :
97 : /**
98 : * @brief Set the To Be Removed object
99 : *
100 : */
101 0 : void setToBeRemoved(bool to_be_removed) { is_to_be_removed = to_be_removed; }
102 :
103 : /**
104 : * @brief Set the Trainable object
105 : *
106 : */
107 14 : void setTrainable(bool trainable) { is_trainable = trainable; }
108 :
109 : /**
110 : * @brief Set the Inputs object
111 : *
112 : * @param inputs_
113 : */
114 0 : void setInputs(const Variables &inputs_) { inputs = inputs_; }
115 :
116 : /**
117 : * @brief Set the Outputs object
118 : *
119 : * @param outputs_
120 : */
121 0 : void setOutputs(const Variables &outputs_) { outputs = outputs_; }
122 :
123 : /**
124 : * @brief Set the Weights object
125 : *
126 : * @param weights_
127 : */
128 : void setWeights(Variables weights_, bool weight_transpose = false);
129 : /**
130 : * @brief Replace the Weights object
131 : *
132 : * @param weights_
133 : */
134 0 : void replaceWeights(const Variables &weights_) { weights = weights_; }
135 : /**
136 : * @brief Set(Append) the Props object
137 : *
138 : * @param value
139 : */
140 6 : void AppendProps(const int &value) { props_vector.push_back(value); }
141 :
142 : /**
143 : * @brief Set(Append) the Additional Props object
144 : *
145 : * @param value
146 : */
147 : void AppendAdditionalProps(const float &value) {
148 0 : additional_props.push_back(value);
149 : }
150 :
151 : /**
152 : * @brief Reorder Weight in case of NCHW --> NHWC
153 : *
154 : * @param node_count
155 : */
156 : void weightReorder(unsigned int node_count);
157 :
158 : /**
159 : * @brief Get the Inputs object
160 : *
161 : * @return Variables& inputs
162 : */
163 10 : Variables &getInputs() { return inputs; }
164 :
165 : /**
166 : * @brief Get the weights object
167 : *
168 : * @return const Variables& weights
169 : */
170 22 : const Variables &getWeights() const { return weights; }
171 :
172 : /**
173 : * @brief Get the weights object
174 : *
175 : * @return Variables& weights
176 : */
177 22 : Variables &getWeights() { return weights; }
178 :
179 : /**
180 : * @brief Get the Inputs object
181 : *
182 : * @return const Variables& inputs
183 : */
184 5 : const Variables &getInputs() const { return inputs; }
185 :
186 : /**
187 : * @brief Get the Outputs object
188 : *
189 : * @return Variables&
190 : */
191 27 : Variables &getOutputs() { return outputs; }
192 :
193 : /**
194 : * @brief Get the Outputs object
195 : *
196 : * @return const Variables& outputs
197 : */
198 22 : const Variables &getOutputs() const { return outputs; }
199 :
200 : /**
201 : * @brief check if this op node is model input
202 : *
203 : * @retval true if op node is model input
204 : * @retval false if op node is not model input
205 : */
206 66 : bool isInputNode() const { return is_input; }
207 :
208 : /**
209 : * @brief check if this op node is model output
210 : *
211 : * @retval true if op node is model output
212 : * @retval false if op node is not model output
213 : */
214 22 : bool isOutputNode() const { return is_output; }
215 :
216 : /**
217 : * @brief check if this op node is virtual node
218 : *
219 : * virtual node is a node that will not be exported
220 : */
221 83 : bool isVirtualNode() const { return is_virtual; }
222 :
223 : /**
224 : * @brief check if this layer need to reorder
225 : *
226 : * @return true if weight need to reorder
227 : * @return false if reordering is not required
228 : */
229 : bool isNeedReorder() const { return need_reorder_weight; }
230 :
231 : /**
232 : * @brief check if this layer is trainable
233 : *
234 : * @return true if layer(OpNode) trainable
235 : * @return false if layer(OpNode) non-trainable
236 : */
237 17 : bool isTrainable() const { return is_trainable; }
238 :
239 : /**
240 : * @brief check if this layer is to be removed
241 : *
242 : * @return true
243 : * @return false
244 : */
245 22 : bool isToBeRemoved() const { return is_to_be_removed; }
246 :
247 : /**
248 : * @brief Get the Props Vector
249 : *
250 : * @return const std::vector<int> props_vector
251 : */
252 0 : std::vector<int> getProps() const { return props_vector; }
253 :
254 : /**
255 : * @brief Get the Additional Props Vector
256 : *
257 : * @return const std::vector<float> additional_props
258 : */
259 0 : std::vector<float> getAdditionalProps() const { return additional_props; }
260 :
261 : /**
262 : * @brief Get the Op Type object
263 : *
264 : * @return const tflite::BuiltinOperator
265 : */
266 95 : const tflite::BuiltinOperator getOpType() const { return op_type; }
267 :
268 : /**
269 : * @brief Get the Op Type object
270 : *
271 : * @return const tflite::BuiltinOperator
272 : */
273 : const tflite::BuiltinOptions getOptionType() const {
274 66 : return builtin_option_type;
275 : }
276 :
277 : /**
278 : * @brief Get the Op Options object
279 : * @param f Flatbuffer Builder
280 : * @retval const tflite::Offset<void>
281 : */
282 : flatbuffers::Offset<void> getBuiltinOps() const;
283 :
284 : /**
285 : * @brief Get input nodes
286 : *
287 : * @return const std::vector<TfOpNode *> &input_nodes
288 : */
289 : const std::vector<TfOpNode *> &getInputNodes() const { return input_nodes; }
290 :
291 : /**
292 : * @brief Set arity
293 : *
294 : * @param value value to set
295 : */
296 22 : void arity(size_t value) { input_nodes.resize(value); }
297 :
298 : /**
299 : * @brief Get arity
300 : *
301 : * @return const unsigned input_nodes.size()
302 : */
303 0 : const unsigned arity() const { return input_nodes.size(); }
304 :
305 : /**
306 : * @brief Set n-th argument of the node
307 : *
308 : * @param index argument index to set
309 : * @param node the node to be argument
310 : */
311 17 : void setArg(size_t index, TfOpNode *node) { input_nodes.at(index) = node; }
312 :
313 : /**
314 : * @brief Get n-th argument of the node
315 : *
316 : * @return TfOpNode *input_nodes.at(index)
317 : */
318 0 : TfOpNode *arg(size_t index) const { return input_nodes.at(index); }
319 :
320 : private:
321 : Variables inputs; /**< input variables */
322 : Variables outputs; /**< output variables */
323 : Variables weights; /**< weight variables */
324 : std::vector<TfOpNode *> input_nodes; /**< input nodes */
325 : std::vector<int> props_vector; /**< props vector */
326 : std::vector<float> additional_props; /**< additional props vector */
327 :
328 : /**
329 : * Q) Why do we need input transform?
330 : * A) To transform the nntrainer input data format(NCHW) to tflite
331 : *format(NHWC)
332 : **/
333 : TransformFn weight_transform; /**< weight transforms */
334 : TransformFn input_transform; /**< input transforms */
335 :
336 : bool is_input; /**< true if given input is input; */
337 : bool is_output; /**< true if given output is output; */
338 : bool is_virtual; /**< true if given node is virtual; */
339 : bool is_trainable; /**< true if given node has weight and trainable */
340 : bool is_to_be_removed; /**< true if given node is to be removed */
341 : bool need_reorder_weight; /**< true if given node need to reorder weight; */
342 :
343 : /// @todo change to shared_ptr or unique_ptr
344 : /// why? the addresses of existing tensors in the vector could become invalid
345 : /// due to memory reallocation
346 : std::vector<Tensor>
347 : node_owned_variable; /**< when node should be transformed it's own type, it
348 : * needs to be owned by someone, so @a TfOpNode owns
349 : * those orphaned tensors until the instance is
350 : * destroyed */
351 :
352 : tflite::BuiltinOperator op_type;
353 :
354 : /// retrieve this from export_to
355 : flatbuffers::Offset<void> builtin_ops;
356 : tflite::BuiltinOptions builtin_option_type;
357 : };
358 :
359 : } // namespace nntrainer
360 :
361 : #endif // __TFLITE_OPNODE_H__
|