// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
 *
 * @file tflite_interpreter.cpp
 * @date 12 April 2021
 * @brief NNTrainer *.tflite Interpreter
 * @see https://github.com/nnstreamer/nntrainer
 * @author Jihoon Lee <jhoon.it.lee@samsung.com>
 * @author Donghak Park <donghak.park@samsung.com>
 * @bug No known bugs except for NYI items
 */
#include <tflite_interpreter.h>

#include <algorithm>
#include <cassert>
#include <fstream>
#include <map>
#include <memory>
#include <regex>
#include <set>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <bn_realizer.h>
#include <fc_layer.h>
#include <layer_node.h>
#include <nntrainer_error.h>
#include <node_exporter.h>
#include <tensor.h>
#include <tf_schema_generated.h>
#include <tflite_export_realizer.h>
#include <tflite_opnode.h>

static constexpr const char *FUNC_TAG = "[TFLITE INTERPRETER] ";

/// These globals hold tensors that are created while building tflite nodes
/// (e.g., the fused batch-normalization ADD weights); they must outlive the
/// flatbuffer export.
/// @todo replace the fixed 50-entry array (its index is not bounds-checked)
/// with a growable, interpreter-owned container
nntrainer::TfOpNode::Variables new_variable;
nntrainer::Tensor new_weight_add[50];
unsigned int new_alloc_tensors_idx = 0;

namespace nntrainer {

namespace {
/**
 * @brief after the build is finished, call this to save the buffer to a file
 *
 * @param builder flatbuffer builder
 * @param out output file path
 */
void builder2file(const flatbuffers::FlatBufferBuilder &builder,
                  const std::string &out) {
  uint8_t *buf = builder.GetBufferPointer();
  size_t size = builder.GetSize();
  flatbuffers::Verifier v(buf, size);
  NNTR_THROW_IF(!tflite::VerifyModelBuffer(v), std::invalid_argument)
    << FUNC_TAG << "Verifying serialized model failed";
  std::ofstream os(out, std::ios_base::binary);

  const size_t error_buflen = 100;
  char error_buf[error_buflen];
  NNTR_THROW_IF(!os.good(), std::invalid_argument)
    << FUNC_TAG << "failed to open, reason: "
    << SAFE_STRERROR(errno, error_buf, error_buflen);

  std::streamsize sz = static_cast<std::streamsize>(builder.GetSize());
  NNTR_THROW_IF(sz < 0, std::invalid_argument)
    << FUNC_TAG << "builder size: " << builder.GetSize()
    << " is too big. It cannot be represented by std::streamsize";

  os.write(reinterpret_cast<const char *>(buf), sz);
  os.close();
}

/**
 * @brief get predecessor nodes
 *
 * @param node the node from which to get predecessor nodes
 * @note virtual nodes are ignored
 */
std::vector<const TfOpNode *> getPredNodes(const TfOpNode &node) {
  std::vector<const TfOpNode *> predNodes;

  for (auto input : node.getInputNodes()) {
    const TfOpNode *pred = input;
    while (pred->isVirtualNode()) {
      /// Assume that virtual nodes have a single input
      assert(pred->arity() == 1);
      pred = pred->arg(0);
    }
    predNodes.push_back(pred);
  }
  return predNodes;
}

using TfOpNodes = std::vector<std::unique_ptr<TfOpNode>>;

/**
 * @brief Bidirectional index map
 *
 * @tparam KeyType type of the underlying hashable value; note that keys are
 * copied, so use this only for pointers and primitive values that are cheap
 * to copy
 * @tparam DataType data type to be stored inside the vector; if not given,
 * same as KeyType
 */
template <typename KeyType, typename DataType = KeyType>
class BidirectionalIndexMap {
public:
  /**
   * @brief add a data point to the map
   *
   * @param key key used to search for the data
   * @param data data to be added; if the key has no occurrence yet, data
   * will be copied
   */
  void addDataWhenNotFound(KeyType key, DataType data) {
    auto search = key2index.find(key);

    if (search == key2index.end()) {
      key2index[key] = index2data.size();
      index2data.push_back(data);
    }
  }

  /**
   * @brief add a data point to the map when the key and data types are the
   * same
   *
   * @param key key/data to add
   */
  void addDataWhenNotFound(KeyType key) {
    static_assert(std::is_same<KeyType, DataType>::value == true,
                  "key type and data type are different!");
    addDataWhenNotFound(key, key);
  }

  /**
   * @brief Get the index of the data
   *
   * @param key data that will be the key
   * @return unsigned int index
   */
  unsigned int getIndex(const KeyType &key) const {
    auto search = key2index.find(key);

    NNTR_THROW_IF(search == key2index.end(), std::invalid_argument)
      << FUNC_TAG << "Cannot find index for key: " << key;

    return search->second;
  }

  /**
   * @brief Get the data object
   *
   * @param index index to be searched
   * @return DataType data point
   */
  DataType getData(unsigned int index) const {
    NNTR_THROW_IF(index >= index2data.size(), std::invalid_argument)
      << FUNC_TAG << "Cannot find data for index: " << index;

    return index2data[index];
  }

  /**
   * @brief Get the underlying data
   *
   * @return const std::vector<DataType>& underlying data
   */
  const std::vector<DataType> &getData() const { return index2data; }

private:
  std::unordered_map<KeyType, unsigned int> key2index; /**< key -> index map */
  std::vector<DataType> index2data;                    /**< index -> data map */
};
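
/*
 * Usage sketch of BidirectionalIndexMap (illustrative only): indices follow
 * insertion order and duplicate keys are ignored.
 *
 *   BidirectionalIndexMap<int> m;
 *   m.addDataWhenNotFound(7); // stored at index 0
 *   m.addDataWhenNotFound(9); // stored at index 1
 *   m.addDataWhenNotFound(7); // already present: no-op
 *   assert(m.getIndex(9) == 1 && m.getData(0) == 7);
 */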

/**
 * @brief tensorflow operation index map; this class manages operation index
 * mapping
 *
 */
class TfOpIdxMap {
public:
  using Buffer = std::pair<size_t, const float *>;

  TfOpIdxMap(const TfOpNodes &nodes) {
    auto &opcode_map = getIndexMap<tflite::BuiltinOperator>();
    auto update_opcode = [&opcode_map](tflite::BuiltinOperator opcode) {
      opcode_map.addDataWhenNotFound(opcode);
    };

    auto &buffer_map = getIndexMap<const float *, Buffer>();
    buffer_map.addDataWhenNotFound(
      nullptr, {0, empty_buffer}); // this represents an undefined buffer
    buffer_map.addDataWhenNotFound(
      empty_buffer, {0, empty_buffer}); // this represents an empty buffer

    auto update_buffer_map = [&buffer_map](const TfOpNode::Variables &variables,
                                           bool dynamic) {
      for (auto &variable : variables) {
        const float *buf = variable->getData();
        assert(buf != nullptr);
        auto byte_size = dynamic ? 0 : variable->bytes();
        buffer_map.addDataWhenNotFound(buf, {byte_size, buf});
      }
    };

    auto register_tensors =
      [&tensors = this->tensors](const TfOpNode::Variables &variables) {
        for (auto &variable : variables) {
          auto tensor_it = std::find(tensors.begin(), tensors.end(), variable);
          if (tensor_it == tensors.end()) {
            tensors.push_back(variable);
          }
        }
      };

    for (auto &op_node : nodes) {
      if (op_node->isVirtualNode())
        continue;
      update_opcode(op_node->getOpType());

      if (op_node->isInputNode()) {
        /**
         * Q) Why register only the graph input tensor?
         *
         * A) tflite needs only one tensor between nodes. Therefore, no
         * inputs are considered except the graph input, which has no FROM
         * node.
         **/
        register_tensors(op_node->getInputs());
        /**
         * Q) Why update only the second input of the input node?
         *
         * A) 1. graph input nodes should be Transpose operators that change
         * the data format from NCHW to NHWC.
         * 2. a Transpose operator has two inputs: the input to be
         * transposed (input[0]) and a 1d permute vector (input[1]).
         * 3. input[0] has a nullptr data pointer, which can't be added to
         * buffer_map, but input[0] should have its own buffer; it is
         * handled when the tflite buffers are built.
         **/
        assert(op_node->getInputs()[0]->getData() == nullptr);
        update_buffer_map({op_node->getInputs()[1]}, false);
      }
      register_tensors(op_node->getWeights());
      update_buffer_map(op_node->getWeights(), false);

      register_tensors(op_node->getOutputs());
      update_buffer_map(op_node->getOutputs(), true);
    }

    auto update_model_io_to = [this](const TfOpNode::Variables &variables,
                                     std::vector<int> &v) {
      for (auto &variable : variables) {
        if (variable->getName().find("nntrainer_internal_perm") !=
            std::string::npos)
          continue;
        v.push_back(this->getTensorIndex(variable));
      }
    };

    for (auto &op_node : nodes) {
      if (op_node->isVirtualNode())
        continue;
      if (op_node->isInputNode()) {
        update_model_io_to(op_node->getInputs(), inputs);
      }
      if (op_node->isOutputNode()) {
        update_model_io_to(op_node->getOutputs(), outputs);
      }
    }
  }

  template <typename KeyType, typename DataType = KeyType>
  BidirectionalIndexMap<KeyType, DataType> &getIndexMap() {
    return std::get<BidirectionalIndexMap<KeyType, DataType>>(maps);
  }

  template <typename KeyType, typename DataType = KeyType>
  const BidirectionalIndexMap<KeyType, DataType> &getIndexMap() const {
    return std::get<BidirectionalIndexMap<KeyType, DataType>>(maps);
  }

  const float *get_empty_buffer() const { return empty_buffer; }

  const std::vector<int> &getInputs() const { return inputs; }

  const std::vector<int> &getOutputs() const { return outputs; }

  const std::vector<const Tensor *> &getTensors() const { return tensors; }

  std::ptrdiff_t getTensorIndex(const Tensor *tensor) const {
    auto tensor_it = std::find(tensors.begin(), tensors.end(), tensor);
    NNTR_THROW_IF(tensor_it == tensors.cend(), std::invalid_argument)
      << FUNC_TAG << "Cannot find index for tensor: " << tensor->getName();
    return std::distance(tensors.begin(), tensor_it);
  }

private:
  /// reserved uninitialized tensors point at this buffer; a zero-length
  /// array is a compiler extension used here only as a unique sentinel
  /// address
  float empty_buffer[0];

  std::tuple<BidirectionalIndexMap<const float *, Buffer>,  /**< buffer map */
             BidirectionalIndexMap<tflite::BuiltinOperator>> /**< opcode map */
    maps;

  std::vector<int> inputs;
  std::vector<int> outputs;
  /// since this is used as a tensor index, the order is important
  std::vector<const Tensor *> tensors;
};
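
/*
 * Note on the resulting index layout (a summary of the invariants above,
 * not an external contract): buffer index 0 holds the "undefined" sentinel
 * and index 1 the shared empty buffer, so every real weight/output buffer
 * gets an index >= 2; tensor indices simply follow first-registration order
 * (per node: graph inputs for the input node, then weights, then outputs).
 */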

TfOpNodes buildOpNodes(const GraphRepresentation &representation,
                       flatbuffers::FlatBufferBuilder &fbb) {
  TfOpNodes nodes;
  /// @todo TfOpNode needs to have a LayerNode pointer
  std::map<TfOpNode *, const LayerNode *> tf_to_layer;
  std::map<const LayerNode *, TfOpNode *> layer_to_tf;

  /// @todo look ahead of layers to get nodes that can be fused;
  /// we will need a dedicated builder for that
  for (auto iter = representation.cbegin(); iter != representation.cend();
       iter++) {
    const auto &ln = *iter;

    Exporter e(&fbb);
    ln->exportTo(e, ml::train::ExportMethods::METHOD_TFLITE);
    auto export_output = e.getResult<ml::train::ExportMethods::METHOD_TFLITE>();

    if (export_output->getWeights().empty()) {
      export_output->setTrainable(false);
    }

    nodes.emplace_back(std::move(export_output));
    tf_to_layer.insert({nodes.back().get(), ln.get()});
    layer_to_tf.insert({ln.get(), nodes.back().get()});
  }

  int node_count = 0;
  bool is_local_first = true;
  /** is_local_first : the first FC layer after a channel-related layer.
   * For example
   * : Input -> Conv -> Conv -> Flatten -> [FC]:local_first
   * : Input -> Conv -> Flatten -> [FC]:local_first -> Conv -> Flatten ->
   * [FC]:local_first
   */
  // set the reorder-weight flag for FullyConnected layers
  for (auto &n : nodes) {
    auto tf_node = n.get();

    if (tf_node->getOptionType() ==
          tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions &&
        node_count != 0 && is_local_first) {
      tf_node->setNeedReorderWeight();
      is_local_first = false;
    }

    if (is_local_first == false &&
        tf_node->getOptionType() !=
          tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions) {
      is_local_first = true;
    }

    node_count++;
  }

  /// set arity of TfOpNodes
  for (auto &n : nodes) {
    auto tf_node = n.get();
    auto searched_layer = tf_to_layer.find(tf_node);
    if (searched_layer == tf_to_layer.end())
      throw std::runtime_error("Cannot find layer for TfOpNode");
    auto layer_node = searched_layer->second;
    auto layer_node_inputs = layer_node->getInputConnections();

    /// assume that the TfOpNode and the LayerNode have a one-to-one
    /// relationship
    tf_node->arity(layer_node_inputs.size());
    for (size_t index = 0; index < layer_node_inputs.size(); index++) {
      auto input_layer_name = layer_node_inputs[index];
      auto input_layer_node_iterator = std::find_if(
        representation.begin(), representation.end(),
        [&input_layer_name](std::shared_ptr<nntrainer::LayerNode> node) {
          return istrequal(node.get()->getName(), input_layer_name);
        });

      if (input_layer_node_iterator != representation.end()) {
        auto input_layer_node = input_layer_node_iterator->get();
        if (layer_to_tf.find(input_layer_node) != layer_to_tf.end()) {
          tf_node->setArg(index, layer_to_tf.find(input_layer_node)->second);
        }
      }
    }
  }

  node_count = 0;
  for (auto &n : nodes) {
    auto tf_node = n.get();
    if (tf_node->getOptionType() ==
        tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions) {
      tf_node->weightReorder(node_count);
    }

    if (tf_node->getOpType() ==
          tflite::BuiltinOperator::BuiltinOperator_CONV_2D &&
        static_cast<size_t>(node_count + 2) < nodes.size() &&
        nodes.at(node_count + 1)->getOpType() ==
          tflite::BuiltinOperator::BuiltinOperator_MUL &&
        nodes.at(node_count + 2)->getOpType() ==
          tflite::BuiltinOperator::BuiltinOperator_RELU) {
      // Fuse Conv2D + Mul (Batch Norm) + ReLU into a single Conv2D

      auto props = tf_node->getProps();
      auto tf_padding = tflite::Padding_SAME;

      if (props[0] == 1) {
        tf_padding = tflite::Padding_VALID;
      }
      auto new_options =
        tflite::CreateConv2DOptions(fbb, tf_padding, props[1], props[2],
                                    tflite::ActivationFunctionType_RELU)
          .Union();
      tf_node->setBuiltinOptions(tflite::BuiltinOptions_Conv2DOptions,
                                 new_options);
      // after fusing, mark the ReLU node to be removed
      nodes.at(node_count + 2)->setToBeRemoved(true);
    }

    if (node_count < 1) {
      node_count++;
      continue;
    } else {
      if (nodes.at(node_count - 1)->isTrainable() == true &&
          tf_node->getOpType() == tflite::BuiltinOperator_MUL) {
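        // BN folding, derived from
        //   bn(x) = gamma * (x - mean) / sqrt(var + eps) + beta.
        // With scale s = gamma / sqrt(var + eps), a convolution followed by
        // BN becomes s * (W*x + b - mean) + beta = (s*W)*x + (b - mean)*s + beta,
        // so the fused kernel is W' = s*W and the fused bias is
        // b' = (b - mean)*s + beta, which is what the code below computes.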

        // Fused weight(conv)
        // = weight(conv) * (weight(bn) / sqrt(var(bn) + eps))

        auto conv_weights = nodes.at(node_count - 1)->getWeights();
        Tensor conv_weight(conv_weights.at(0)->getDim());
        Tensor conv_bias(conv_weights.at(1)->getDim());
        conv_weight.copyData(conv_weights.at(0)->clone());
        conv_bias.copyData(conv_weights.at(1)->clone());

        auto mul_weights = tf_node->getWeights();
        auto mul_mean = mul_weights.at(0)->clone().transpose("1:2:0");
        auto mul_var = mul_weights.at(1)->clone().transpose("1:2:0");
        auto mul_weight = mul_weights.at(2)->clone().transpose("1:2:0");
        auto mul_bias = mul_weights.at(3)->clone().transpose("1:2:0");
        auto mul_epsilon = tf_node->getAdditionalProps().at(0);

        // run sqrt(var(bn) + eps)
        mul_var.add_i(mul_epsilon);
        mul_var.pow_i(-0.5f);
        mul_weight.multiply_i(mul_var);

        Tensor reshape_mul_weight(mul_weight.getDim());
        reshape_mul_weight.copy(mul_weight);
        reshape_mul_weight.reshape(
          TensorDim{mul_weight.getDim().width(), 1, 1, 1});
        conv_weight.multiply_i(reshape_mul_weight);

        conv_bias.subtract_i(mul_mean);
        conv_bias.multiply_i(mul_weight);
        conv_bias.add_i(mul_bias);

        TfOpNode::Variables conv_new_weights;
        conv_new_weights.push_back(&conv_weight);
        conv_new_weights.push_back(&conv_bias);
        nodes.at(node_count - 1)->setWeights(conv_new_weights);
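        // Note: conv_weight and conv_bias are stack-local; this relies on
        // TfOpNode::setWeights() copying the tensor data rather than keeping
        // the raw pointers (otherwise they would dangle after this scope).
        // @todo verify and enforce this assumption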
        // set the mul node to be removed (the mul comes from batch
        // normalization)
        n->setToBeRemoved(true);
      }
    }
    node_count++;
  }

  return nodes;
}

flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>>
buildBuffers(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
  const auto &buffers =
    map.getIndexMap<const float *, TfOpIdxMap::Buffer>().getData();

  std::vector<flatbuffers::Offset<tflite::Buffer>> fb_buffers;
  fb_buffers.reserve(buffers.size());

  auto create_buffer_offset = [&fbb](const TfOpIdxMap::Buffer &buffer) {
    if (buffer.first == 0) {
      return tflite::CreateBuffer(fbb);
    }

    auto data = fbb.CreateVector(
      reinterpret_cast<const uint8_t *>(buffer.second), buffer.first);

    return tflite::CreateBuffer(fbb, data);
  };
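
  // Note (tflite convention): buffer index 0 must be an empty sentinel
  // buffer; TfOpIdxMap registers its "undefined" entry first so that it
  // lands at index 0 here.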
  std::transform(buffers.begin(), buffers.end(), std::back_inserter(fb_buffers),
                 create_buffer_offset);

  // append one empty buffer per graph input; graph inputs own no constant
  // data and buildTensors() maps them to these trailing slots
  for (unsigned index = 0; index < map.getInputs().size(); index++) {
    fb_buffers.push_back(create_buffer_offset({0, nullptr}));
  }
  return fbb.CreateVector(fb_buffers);
}

flatbuffers::Offset<
  flatbuffers::Vector<flatbuffers::Offset<tflite::OperatorCode>>>
buildOperatorCodes(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
  const auto &op_codes = map.getIndexMap<tflite::BuiltinOperator>().getData();

  std::vector<flatbuffers::Offset<tflite::OperatorCode>> fb_op_codes;
  fb_op_codes.reserve(op_codes.size());

  auto create_op_offset = [&fbb](const tflite::BuiltinOperator &op,
                                 int32_t version = 1) {
    tflite::OperatorCodeBuilder builder(fbb);
    builder.add_deprecated_builtin_code(static_cast<int8_t>(op));
527 : /// on json when version is 1 (other versions are fine)
528 : builder.add_version(version);
529 17 : builder.add_builtin_code(op);
530 17 : return builder.Finish();
531 5 : };
532 :
533 5 : std::transform(op_codes.begin(), op_codes.end(),
534 : std::back_inserter(fb_op_codes), create_op_offset);
535 :
536 5 : return fbb.CreateVector(fb_op_codes);
537 5 : }

flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>>
buildTensors(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
  /// @todo the actual (squeezed) tensor dimension must be known before
  /// coming here. For now, it is directly guessed for the fc layer
  const auto &variables = map.getTensors();
  const auto &buffer_map = map.getIndexMap<const float *, TfOpIdxMap::Buffer>();
  auto graph_input_offset = map.getInputs().size() - 1;

  std::vector<flatbuffers::Offset<tflite::Tensor>> fb_tensors;
  fb_tensors.reserve(variables.size());

  auto create_tensor = [&fbb, &buffer_map,
                        &graph_input_offset](const Tensor *var) {
    auto dim = var->getDim();
    bool need_shape_signature = dim.is_dynamic();
    std::vector<int32_t> eff_dim = dim.getEffectiveDimension();
    auto shape = fbb.CreateVector(eff_dim);

    decltype(shape) shape_sig;
    if (need_shape_signature) {
      std::vector<int32_t> dyn_dim = dim.getEffectiveDimension(true);
      shape_sig = fbb.CreateVector(dyn_dim);
    }

    /// change this to var->getName() once tensors have their own names
    auto name = fbb.CreateString("nntrainer_converted" + var->getName());
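
    /// graph inputs own no constant data, so they are mapped to the extra
    /// buffers appended at the end in buildBuffers(). For a single graph
    /// input this picks the trailing slot; the arithmetic for multiple
    /// graph inputs looks suspect. @todo verify multi-input indexing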
    /// only graph inputs have a nullptr data pointer.
    unsigned int buffer_idx =
      var->getData() == nullptr
        ? buffer_map.getData().size() - graph_input_offset--
        : buffer_map.getIndex(var->getData());

    tflite::TensorBuilder builder(fbb);
    builder.add_name(name);
    builder.add_buffer(buffer_idx);
    /// @todo support more data types
    /// @note this is a workaround because nntrainer tensors allow only the
    /// float dtype
    if (var->getName().find("nntrainer_internal_perm") != std::string::npos) {
      builder.add_type(tflite::TensorType_INT32);
    } else {
      builder.add_type(tflite::TensorType_FLOAT32);
    }
    builder.add_shape(shape);
    if (need_shape_signature) {
      builder.add_shape_signature(shape_sig);
    }
    return builder.Finish();
  };

  std::transform(variables.begin(), variables.end(),
                 std::back_inserter(fb_tensors), create_tensor);

  return fbb.CreateVector(fb_tensors);
}

flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Operator>>>
buildOperators(const TfOpNodes &nodes, const TfOpIdxMap &map,
               flatbuffers::FlatBufferBuilder &fbb) {

  /// this lambda maps variables to a list of indices in the map
  auto variables_to_idx_vector = [&map](const TfOpNode::Variables &v) {
    std::vector<int> idx_vector;
    idx_vector.reserve(v.size());

    std::transform(
      v.begin(), v.end(), std::back_inserter(idx_vector),
      [&map](const Tensor *variable) { return map.getTensorIndex(variable); });
    return idx_vector;
  };

  auto create_operator = [&fbb, &map,
                          &variables_to_idx_vector](const TfOpNode &node) {
    auto &index_map = map.getIndexMap<tflite::BuiltinOperator>();

    auto op_code = index_map.getIndex(node.getOpType());
    std::vector<int> inputs;
    if (node.isInputNode()) {
      inputs = variables_to_idx_vector(node.getInputs());
    } else {
      /**
       * Q) Why find a tensor that shares a buffer with the input tensor?
       *
       * A) tflite needs only one tensor between nodes. Therefore, output
       * tensors are reused as the tflite tensors that share their buffers
       * with the inputs.
       **/
      TfOpNode::Variables input_tensors;
      for (auto parent_node : getPredNodes(node)) {
        for (auto parent_out : parent_node->getOutputs()) {
          for (auto in : node.getInputs()) {
            /// the second condition is a workaround: the Transpose op output
            /// tensor originally had a nullptr data pointer but has since
            /// been allocated (parent_out->getData()), while the tensor that
            /// shares its buffer has not been updated, so it still holds
            /// nullptr (in->getData()).
            /// @todo remove this workaround
            if (parent_out->getData() == in->getData() ||
                (in->getData() == nullptr && parent_out->getData())) {
              if (std::find(input_tensors.begin(), input_tensors.end(),
                            parent_out) != input_tensors.end())
                continue;
              input_tensors.push_back(parent_out);
            }
          }
        }
      }
      inputs = variables_to_idx_vector(input_tensors);
    }
    auto weights = variables_to_idx_vector(node.getWeights());

    /// weights are part of the inputs in tflite
    inputs.insert(inputs.end(), weights.begin(), weights.end());

    auto outputs = variables_to_idx_vector(node.getOutputs());

    auto fb_inputs = fbb.CreateVector(inputs);
    auto fb_outputs = fbb.CreateVector(outputs);
    auto fb_options = node.getBuiltinOps();

    tflite::OperatorBuilder builder(fbb);
    builder.add_opcode_index(op_code);
    builder.add_builtin_options_type(node.getOptionType());
    builder.add_builtin_options(fb_options);
    builder.add_inputs(fb_inputs);
    builder.add_outputs(fb_outputs);
    return builder.Finish();
  };

  std::vector<flatbuffers::Offset<tflite::Operator>> v;
  v.reserve(nodes.size());

  for (auto &node : nodes) {
    if (node->isVirtualNode())
      continue;
    auto op = create_operator(*node);
    v.push_back(op);
  }

  return fbb.CreateVector(v);
}

flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::SubGraph>>>
buildSubGraphs(const TfOpNodes &nodes, const TfOpIdxMap &map,
               flatbuffers::FlatBufferBuilder &fbb) {

  auto tensors = buildTensors(map, fbb);
  auto ops = buildOperators(nodes, map, fbb);

  /// @todo extract this to a buildSubgraph helper once there can be more
  /// than one subgraph
  auto name = fbb.CreateString("main");
  auto inputs = fbb.CreateVector(map.getInputs());
  auto outputs = fbb.CreateVector(map.getOutputs());

  auto builder = tflite::SubGraphBuilder(fbb);
  builder.add_tensors(tensors);
  builder.add_inputs(inputs);
  builder.add_outputs(outputs);
  builder.add_name(name);
  builder.add_operators(ops);
  auto subgraph = builder.Finish();

  std::vector<flatbuffers::Offset<tflite::SubGraph>> subgraphs;
  subgraphs.reserve(1);
  subgraphs.push_back(subgraph);

  return fbb.CreateVector(subgraphs);
}

} // namespace
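
/**
 * @brief Realize the exported op nodes for tflite export: fold a
 * non-trainable batch normalization (a MUL node) into a single MUL weight
 * plus an inserted ADD node with fused ReLU activation, and drop nodes
 * marked for removal while rewiring their inputs and outputs.
 *
 * @param nodes op nodes built from the graph representation
 * @param fbb flatbuffer builder used to create the new builtin options
 * @return TfOpNodes realized op nodes
 */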
TfOpNodes buildRealizedOpNodes(TfOpNodes &nodes,
                               flatbuffers::FlatBufferBuilder &fbb) {
  TfOpNodes realized_nodes;

  bool set_input = false;
  unsigned int node_count = 0;

  for (auto &node : nodes) {
    if (set_input) { // if the previous node was newly added, wire input/output
      node->setArg(0, realized_nodes.back().get());
      realized_nodes.back()->setOutputs(node->getInputs());
      set_input = false;
    }

    if (node->isToBeRemoved() == true) {
      // remove the node; assume that an input node is never removed
      realized_nodes.back()->setOutputs(nodes.at(node_count)->getOutputs());
      nodes.at(node_count + 1)->setArg(0, realized_nodes.back().get());
      nodes.at(node_count + 1)->setInputs(nodes.at(node_count)->getInputs());
    } else {
      realized_nodes.push_back(std::move(node));

      if (realized_nodes.back()->getOpType() ==
          tflite::BuiltinOperator_MUL) { // fused MUL + ADD (non-trainable BN)
        /**
         * y = x * (gamma / sqrt(variance + epsilon)) +
         *     (beta - mean * gamma / sqrt(variance + epsilon))
         */
        auto removed_weights = realized_nodes.back()->getWeights();
        auto mul_mean = removed_weights.at(0)->clone();
        auto mul_variance = removed_weights.at(1)->clone();
        auto mul_gamma = removed_weights.at(2)->clone();
        auto mul_beta = removed_weights.at(3)->clone();
        auto mul_epsilon = realized_nodes.back()->getAdditionalProps().at(0);

        std::unique_ptr<Tensor> new_mul_weight =
          std::make_unique<Tensor>(mul_gamma.getDim());
        new_mul_weight->allocate();
        new_mul_weight->copy(mul_gamma);

        // new_mul_weight = gamma / sqrt(variance + epsilon)
        mul_variance.add_i(mul_epsilon);
        mul_variance.pow_i(-0.5f);
        new_mul_weight->multiply_i(mul_variance);

        // beta = beta - mean * gamma / sqrt(variance + epsilon)
        Tensor sub_result(new_mul_weight->getDim());
        sub_result.allocate();
        sub_result.copyData(*new_mul_weight);

        mul_mean.multiply_i(sub_result);
        mul_beta.subtract_i(mul_mean);
        new_mul_weight->setName("MUL");
        for (auto weight : removed_weights) {
          delete weight;
        }
        removed_weights.clear();
        removed_weights.push_back(new_mul_weight.release());

        realized_nodes.back()->replaceWeights(removed_weights);
        realized_nodes.back()->setWeights(removed_weights, true);

        // insert an ADD node into the graph to apply the folded bias term
        std::unique_ptr<TfOpNode> tf_node = std::make_unique<TfOpNode>();
        tf_node->setInputs(realized_nodes.back()->getOutputs());
        tf_node->setOpType(tflite::BuiltinOperator_ADD);
        auto options =
          tflite::CreateAddOptions(fbb, tflite::ActivationFunctionType_RELU)
            .Union();

        new_weight_add[new_alloc_tensors_idx].allocate();
        new_weight_add[new_alloc_tensors_idx].copy(mul_beta);
        std::string name = "ADD_tensor";
        new_weight_add[new_alloc_tensors_idx].setName(name);

        new_variable.clear();
        new_variable.emplace_back(&new_weight_add[new_alloc_tensors_idx]);
        new_alloc_tensors_idx++;

        tf_node->replaceWeights(new_variable);
        tf_node->setWeights(new_variable, true);
        tf_node->setBuiltinOptions(tflite::BuiltinOptions_AddOptions, options);

        // remove the following ReLU layer; its activation is fused into ADD
        nodes.at(node_count + 1)->setToBeRemoved(true);

        auto mul_node = realized_nodes.back().get();
        tf_node->arity(1);
        tf_node->setArg(0, mul_node);

        realized_nodes.push_back(std::move(tf_node));
        set_input = true;
      }
    }
    node_count++;
  }

  return realized_nodes;
}

void TfliteInterpreter::serialize(const GraphRepresentation &representation,
                                  const std::string &out) {

  /// 1. remove loss and dropout layers from the GraphRepresentation
  TfliteExportRealizer tflite_realizer({});
  GraphRepresentation graph_loss = tflite_realizer.realize(representation);
  GraphRepresentation graph = tflite_realizer.realize_dropout(graph_loss);

  /// 2. the graph must have weights, input dims, and output dims set
  flatbuffers::FlatBufferBuilder fbb;

  auto opNodes = buildOpNodes(graph, fbb);
  auto converted_opNodes = buildRealizedOpNodes(opNodes, fbb);

  TfOpIdxMap map(converted_opNodes); /// build TfOpIdxMap from opNodes
  auto opcodes = buildOperatorCodes(map, fbb);
  auto subgraphs = buildSubGraphs(converted_opNodes, map, fbb);
  auto buffers = buildBuffers(map, fbb);
  auto desc = fbb.CreateString("This file is generated from NNTrainer");
  tflite::ModelBuilder model_builder(fbb);

  model_builder.add_operator_codes(opcodes);
  model_builder.add_subgraphs(subgraphs);
  model_builder.add_buffers(buffers);
  model_builder.add_version(3);
  model_builder.add_description(desc);
  auto model = model_builder.Finish();

  fbb.Finish(model, tflite::ModelIdentifier());
  builder2file(fbb, out);
}
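
// Usage sketch (illustrative; assumes a compiled model whose graph already
// has weights and input/output dims set, per step 2 above):
//
//   nntrainer::TfliteInterpreter interpreter;
//   interpreter.serialize(model_graph, "model.tflite");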

GraphRepresentation TfliteInterpreter::deserialize(const std::string &in) {
  /// ======== list of things to consider ========
  /// we need to reconstruct some properties from the shape
  /// e.g., units are not saved as a property

  /** NYI! */
  return {};
}

} // namespace nntrainer