LCOV - code coverage report
Current view: top level - nntrainer/compiler - tflite_interpreter.cpp (source / functions)
Test: coverage_filtered.info
Test Date: 2025-12-14 20:38:17

             Coverage    Total    Hit
Lines:         73.6 %      318    234
Functions:     96.2 %       26     25

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file tflite_interpreter.cpp
       6              :  * @date 12 April 2021
       7              :  * @brief NNTrainer *.tflite Interpreter
       8              :  * @see https://github.com/nnstreamer/nntrainer
       9              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      10              :  * @author Donghak Park <donghak.park@samsung.com>
      11              :  * @bug No known bugs except for NYI items
      12              :  */
      13              : #include <tflite_interpreter.h>
      14              : 
      15              : #include <algorithm>
      16              : #include <fstream>
      17              : #include <map>
      18              : #include <memory>
      19              : #include <regex>
      20              : #include <set>
      21              : #include <string>
      22              : #include <tuple>
      23              : #include <type_traits>
      24              : #include <utility>
      25              : 
      26              : #include <bn_realizer.h>
      27              : #include <fc_layer.h>
      28              : #include <layer_node.h>
      29              : #include <tflite_export_realizer.h>
      30              : #include <nntrainer_error.h>
      31              : #include <node_exporter.h>
      32              : #include <tensor.h>
      33              : #include <tf_schema_generated.h>
      34              : #include <tflite_opnode.h>
      35              : 
      36              : static constexpr const char *FUNC_TAG = "[TFLITE INTERPRETER] ";
      37              : 
       38              : // These variables are needed to create tflite nodes
      39              : nntrainer::TfOpNode::Variables new_variable;
      40              : nntrainer::Tensor new_weight_add[50];
      41              : unsigned int new_alloc_tensors_idx = 0;
      42              : 
      43              : namespace nntrainer {
      44              : 
      45              : namespace {
      46              : /**
       47              :  * @brief after finishing building, call this to save the model to a file
      48              :  *
      49              :  * @param builder flatbuffer builder
       50              :  * @param out path of the output file
      51              :  */
      52            5 : void builder2file(const flatbuffers::FlatBufferBuilder &builder,
      53              :                   const std::string &out) {
      54            5 :   uint8_t *buf = builder.GetBufferPointer();
      55            5 :   size_t size = builder.GetSize();
      56              :   flatbuffers::Verifier v(buf, size);
      57            5 :   NNTR_THROW_IF(!tflite::VerifyModelBuffer(v), std::invalid_argument)
      58              :     << FUNC_TAG << "Verifying serialized model failed";
      59            5 :   std::ofstream os(out, std::ios_base::binary);
      60              : 
      61              :   const size_t error_buflen = 100;
      62              :   char error_buf[error_buflen];
      63            5 :   NNTR_THROW_IF(!os.good(), std::invalid_argument)
      64              :     << FUNC_TAG << "failed to open, reason: "
      65            0 :     << SAFE_STRERROR(errno, error_buf, error_buflen);
      66              : 
      67            5 :   std::streamsize sz = static_cast<std::streamsize>(builder.GetSize());
      68              :   NNTR_THROW_IF(sz < 0, std::invalid_argument)
      69              :     << FUNC_TAG << "builder size: " << builder.GetSize()
      70              :     << " is too big. It cannot be represented by std::streamsize";
      71              : 
      72            5 :   os.write((char *)builder.GetBufferPointer(), sz);
      73            5 :   os.close();
      74            5 : }
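                      : 
                      : // Illustrative call pattern for the helper above (serialize() below is the
                      : // real call site); the model offset is assumed to be already built:
                      : //
                      : //   flatbuffers::FlatBufferBuilder fbb;
                      : //   /* ... build the model ... */
                      : //   fbb.Finish(model, tflite::ModelIdentifier());
                      : //   builder2file(fbb, "model.tflite");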
      75              : 
      76              : /**
      77              :  * @brief get predecessor nodes
      78              :  *
      79              :  * @param node the node from which to get predecessor nodes
      80              :  * @note virtual nodes are ignored
      81              :  */
      82           17 : std::vector<const TfOpNode *> getPredNodes(const TfOpNode &node) {
      83              :   std::vector<const TfOpNode *> predNodes;
      84              : 
      85           34 :   for (auto input : node.getInputNodes()) {
      86           17 :     const TfOpNode *pred = input;
      87           17 :     while (pred->isVirtualNode()) {
      88              :       /// Assume that virtual nodes have single input
      89            0 :       assert(pred->arity() == 1);
      90            0 :       pred = pred->arg(0);
      91              :     }
      92           17 :     predNodes.push_back(pred);
      93              :   }
      94           17 :   return predNodes;
      95            0 : }
      96              : 
      97              : using TfOpNodes = std::vector<std::unique_ptr<TfOpNode>>;
      98              : 
      99              : /**
     100              :  * @brief Bidirectional Index map
     101              :  *
      102              :  * @tparam KeyType type of an underlying hashable value; note that keys will
      103              :  * be copied, so use this only for pointers and primitive values that are okay
      104              :  * to copy
      105              :  * @tparam DataType data type to be stored inside the vector; if not given,
      106              :  * the same as KeyType
     107              :  */
     108              : template <typename KeyType, typename DataType = KeyType>
     109              : class BidirectionalIndexMap {
     110              : public:
     111              :   /**
      112              :    * @brief add a datapoint to the map
      113              :    *
      114              :    * @param key key used to search for the data
      115              :    * @param data data to be added if the key is not yet present; the data
      116              :    * will be copied.
     117              :    */
     118           75 :   void addDataWhenNotFound(KeyType key, DataType data) {
     119              :     auto search = key2index.find(key);
     120              : 
     121           75 :     if (search == key2index.end()) {
     122           68 :       key2index[key] = index2data.size();
     123           68 :       index2data.push_back(data);
     124              :     }
     125           75 :   }
     126              : 
     127              :   /**
      128              :    * @brief add a datapoint to the map when the key and data types are the same
     129              :    *
     130              :    * @param key key/data to add
     131              :    */
     132              :   void addDataWhenNotFound(KeyType key) {
     133              :     static_assert(std::is_same<KeyType, DataType>::value == true,
     134              :                   "key type and data type are different!");
     135           22 :     addDataWhenNotFound(key, key);
     136              :   }
     137              : 
     138              :   /**
     139              :    * @brief Get the Index of the data
     140              :    *
     141              :    * @param key data that will be the key
     142              :    * @return unsigned int index
     143              :    */
     144           65 :   unsigned int getIndex(const KeyType &key) const {
     145              :     auto search = key2index.find(key);
     146              : 
     147           65 :     NNTR_THROW_IF(search == key2index.end(), std::invalid_argument)
     148              :       << FUNC_TAG << "Cannot find index for key: " << key;
     149              : 
     150           65 :     return search->second;
     151              :   }
     152              : 
     153              :   /**
     154              :    * @brief Get the Data object
     155              :    *
      156              :    * @param index index to be searched
      157              :    * @return DataType datapoint at the given index
     158              :    */
     159              :   DataType getData(unsigned int index) const {
     160              :     NNTR_THROW_IF(index >= index2data.size(), std::invalid_argument)
     161              :       << FUNC_TAG << "Cannot find data for index: " << index;
     162              : 
     163              :     return index2data[index];
     164              :   }
     165              : 
     166              :   /**
      167              :    * @brief Get the underlying data vector
      168              :    *
      169              :    * @return const std::vector<DataType>& underlying data
     170              :    */
     171              :   const std::vector<DataType> &getData() const { return index2data; }
     172              : 
     173              : private:
     174              :   std::unordered_map<KeyType, unsigned int> key2index; /**< key -> index map */
     175              :   std::vector<DataType> index2data;                    /**< index -> data map */
     176              : };
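                      : 
                      : // Illustrative sketch (not part of the measured source): the map assigns
                      : // indices in insertion order and ignores duplicate keys, which is how the
                      : // opcode map below deduplicates tflite::BuiltinOperator values.
                      : //
                      : //   BidirectionalIndexMap<tflite::BuiltinOperator> opcode_map;
                      : //   opcode_map.addDataWhenNotFound(tflite::BuiltinOperator_FULLY_CONNECTED);
                      : //   opcode_map.addDataWhenNotFound(tflite::BuiltinOperator_RELU);
                      : //   opcode_map.addDataWhenNotFound(
                      : //     tflite::BuiltinOperator_FULLY_CONNECTED); // no-op, already mapped
                      : //   assert(opcode_map.getIndex(tflite::BuiltinOperator_RELU) == 1);
                      : //   assert(opcode_map.getData().size() == 2);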
     177              : 
     178              : /**
      179              :  * @brief tensorflow operation index map; this class manages the mapping
      180              :  * between operations and their indices
     181              :  *
     182              :  */
     183              : class TfOpIdxMap {
     184              : public:
     185              :   using Buffer = std::pair<size_t, const float *>;
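                      :   /// a Buffer pairs {byte_size, data}: statically sized weights are
                      :   /// registered as {variable->bytes(), data}, while dynamic (activation)
                      :   /// tensors are registered as {0, data} so that no payload is serialized
                      :   /// for them (see update_buffer_map below)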
     186              : 
     187            5 :   TfOpIdxMap(const TfOpNodes &nodes) {
     188              :     auto &opcode_map = getIndexMap<tflite::BuiltinOperator>();
     189              :     auto update_opcode = [&opcode_map](tflite::BuiltinOperator opcode) {
     190              :       opcode_map.addDataWhenNotFound(opcode);
     191           22 :     };
     192              : 
     193              :     auto &buffer_map = getIndexMap<const float *, Buffer>();
     194            5 :     buffer_map.addDataWhenNotFound(
     195            5 :       nullptr, {0, empty_buffer}); // this represents undefined buffer
     196            5 :     buffer_map.addDataWhenNotFound(
     197              :       empty_buffer, {0, empty_buffer}); // this represents empty buffer
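                      :     // note: the tflite schema reserves buffer 0 as an always-empty sentinel
                      :     // for tensors without data, which the two entries above satisfy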
     198              : 
     199           49 :     auto update_buffer_map = [&buffer_map](const TfOpNode::Variables &variables,
     200              :                                            bool dynamic) {
     201           92 :       for (auto &variable : variables) {
     202           43 :         const float *buf = variable->getData();
     203           43 :         assert(buf != nullptr);
     204           43 :         auto byte_size = dynamic ? 0 : variable->bytes();
     205           43 :         buffer_map.addDataWhenNotFound(buf, {byte_size, buf});
     206              :       }
     207           54 :     };
     208              : 
     209              :     auto register_tensors =
     210           49 :       [&tensors = this->tensors](const TfOpNode::Variables &variables) {
     211           97 :         for (auto &variable : variables) {
     212           48 :           auto tensor_it = std::find(tensors.begin(), tensors.end(), variable);
     213           48 :           if (tensor_it == tensors.end()) {
     214           48 :             tensors.push_back(variable);
     215              :           }
     216              :         }
     217            5 :       };
     218              : 
     219           27 :     for (auto &op_node : nodes) {
     220           22 :       if (op_node->isVirtualNode())
     221            0 :         continue;
     222              :       update_opcode(op_node->getOpType());
     223              : 
     224           22 :       if (op_node->isInputNode()) {
     225              :         /**
     226              :          * Q) Why only register graph input tensor?
     227              :          *
      228              :          * A) tflite needs only one tensor between nodes. Therefore,
      229              :          * basically, no inputs are considered except the graph input, which
      230              :          * doesn't have a FROM node.
     231              :          **/
     232            5 :         register_tensors(op_node->getInputs());
     233              :         /**
     234              :          * Q) Why only update second input of the input node?
     235              :          *
      236              :          * A) 1. graph input nodes should be Transpose operators that change
      237              :          * the data format from NCHW to NHWC.
      238              :          *    2. a Transpose operator has two inputs - the input to be
      239              :          * transposed (input[0]) and a 1d permute vector (input[1]).
      240              :          *    3. input[0] has a nullptr data pointer, which can't be added to
      241              :          * buffer_map. But input[0] should have its own buffer, and it will
      242              :          * be considered when the tflite buffers are built.
     243              :          **/
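                      :         // e.g. for NCHW -> NHWC, input[1] would hold the permutation
                      :         // {0, 2, 3, 1} (batch, height, width, channel); illustrative values,
                      :         // the actual vector comes from the realized Transpose node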
     244            5 :         assert(op_node->getInputs()[0]->getData() == nullptr);
     245            5 :         update_buffer_map({op_node->getInputs()[1]}, false);
     246              :       }
     247           22 :       register_tensors(op_node->getWeights());
     248           22 :       update_buffer_map(op_node->getWeights(), false);
     249              : 
     250           22 :       register_tensors(op_node->getOutputs());
     251           22 :       update_buffer_map(op_node->getOutputs(), true);
     252              :     }
     253              : 
     254           10 :     auto update_model_io_to = [this](const TfOpNode::Variables &variables,
     255              :                                      std::vector<int> &v) {
     256           25 :       for (auto &variable : variables) {
     257           30 :         if (variable->getName().find("nntrainer_internal_perm") !=
     258              :             std::string::npos)
     259            5 :           continue;
     260           10 :         v.push_back(this->getTensorIndex(variable));
     261              :       }
     262           15 :     };
     263              : 
     264           27 :     for (auto &op_node : nodes) {
     265           22 :       if (op_node->isVirtualNode())
     266            0 :         continue;
     267           22 :       if (op_node->isInputNode()) {
     268            5 :         update_model_io_to(op_node->getInputs(), inputs);
     269              :       }
     270           22 :       if (op_node->isOutputNode()) {
     271            5 :         update_model_io_to(op_node->getOutputs(), outputs);
     272              :       }
     273              :     }
     274            5 :   }
     275              : 
     276              :   template <typename KeyType, typename DataType = KeyType>
     277              :   BidirectionalIndexMap<KeyType, DataType> &getIndexMap() {
     278              :     return std::get<BidirectionalIndexMap<KeyType, DataType>>(maps);
     279              :   }
     280              : 
     281              :   template <typename KeyType, typename DataType = KeyType>
     282              :   const BidirectionalIndexMap<KeyType, DataType> &getIndexMap() const {
     283              :     return std::get<BidirectionalIndexMap<KeyType, DataType>>(maps);
     284              :   }
     285              : 
     286              :   const float *get_empty_buffer() const { return empty_buffer; }
     287              : 
     288              :   const std::vector<int> &getInputs() const { return inputs; }
     289              : 
     290              :   const std::vector<int> &getOutputs() const { return outputs; }
     291              : 
     292              :   const std::vector<const Tensor *> &getTensors() const { return tensors; }
     293              : 
     294           75 :   std::ptrdiff_t getTensorIndex(const Tensor *tensor) const {
     295           75 :     auto tensor_it = std::find(tensors.begin(), tensors.end(), tensor);
     296           75 :     NNTR_THROW_IF(tensor_it == tensors.cend(), std::invalid_argument)
     297            0 :       << FUNC_TAG << "Cannot find index for tensor: " << tensor->getName();
     298           75 :     return std::distance(tensors.begin(), tensor_it);
     299              :   }
     300              : 
     301              : private:
     302              :   float empty_buffer[0]; /**< reserved uninitialized tensor points to this
     303              :                             buffer */
     304              : 
     305              :   std::tuple<BidirectionalIndexMap<const float *, Buffer>,   /**< buffer map
     306              :                                                               */
     307              :              BidirectionalIndexMap<tflite::BuiltinOperator>> /**< opcode map
     308              :                                                               */
     309              :     maps;
     310              : 
     311              :   std::vector<int> inputs;
     312              :   std::vector<int> outputs;
     313              :   /// since it is used as a tensor index, the order is important
     314              :   std::vector<const Tensor *> tensors;
     315              : };
     316              : 
     317            5 : TfOpNodes buildOpNodes(const GraphRepresentation &representation,
     318              :                        flatbuffers::FlatBufferBuilder &fbb) {
     319              :   TfOpNodes nodes;
     320              :   /// @todo TfOpNode needs to have LayerNode pointer
     321              :   std::map<TfOpNode *, const LayerNode *> tf_to_layer;
     322              :   std::map<const LayerNode *, TfOpNode *> layer_to_tf;
     323              : 
      324              :   /// @todo look ahead of the layers to find nodes that can be fused;
     325              :   /// we will need to have a dedicated builder
     326           27 :   for (auto iter = representation.cbegin(); iter != representation.cend();
     327              :        iter++) {
     328              :     const auto &ln = *iter;
     329              : 
     330           22 :     Exporter e(&fbb);
     331           22 :     ln->exportTo(e, ml::train::ExportMethods::METHOD_TFLITE);
     332           22 :     auto export_output = e.getResult<ml::train::ExportMethods::METHOD_TFLITE>();
     333              : 
     334           22 :     if (export_output.get()->getWeights().size() == 0) {
     335              :       export_output.get()->setTrainable(false);
     336              :     }
     337              : 
     338           22 :     nodes.emplace_back(std::move(export_output));
     339           22 :     tf_to_layer.insert({nodes.back().get(), ln.get()});
     340           22 :     layer_to_tf.insert({ln.get(), nodes.back().get()});
     341           22 :   }
     342              : 
     343              :   int node_count = 0;
     344              :   bool is_local_first = true;
      345              :   /** is_local_first : marks the first FC layer after a channel-related layer.
      346              :    * For example:
     347              :    * : Input -> Conv -> Conv -> Flatten -> [FC]:local_first
     348              :    * : Input -> Conv -> Flatten -> [FC]:local_first -> Conv -> Flatten ->
     349              :    * [FC]:local_first
     350              :    */
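                      :   // rationale (illustrative): flattening an NCHW activation and flattening
                      :   // its NHWC-transposed counterpart enumerate elements in different orders,
                      :   // so the first FC weight after the layout change must have its rows
                      :   // permuted to match.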
     351              : 
     352              :   // set reorder weight flag for FullyConnected layer
     353           27 :   for (auto &n : nodes) {
     354              :     auto tf_node = n.get();
     355              : 
     356              :     if (tf_node->getOptionType() ==
     357              :           tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions &&
     358           22 :         node_count != 0 && is_local_first) {
     359              :       tf_node->setNeedReorderWeight();
     360              :       is_local_first = false;
     361              :     }
     362              : 
     363           22 :     if (is_local_first == false &&
     364              :         tf_node->getOptionType() !=
     365              :           tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions) {
     366              :       is_local_first = true;
     367              :     }
     368              : 
     369           22 :     node_count++;
     370              :   }
     371              : 
     372              :   /// set arity of TfOpNodes
     373           27 :   for (auto &n : nodes) {
     374              :     auto tf_node = n.get();
     375              :     auto searched_layer = tf_to_layer.find(tf_node);
     376           22 :     if (searched_layer == tf_to_layer.end())
     377            0 :       throw std::runtime_error("Cannot find layer for TfOpNode");
     378           22 :     auto layer_node = searched_layer->second;
     379              :     auto layer_node_inputs = layer_node->getInputConnections();
     380              : 
     381              :     /// assume that the TfOpNode and the LayerNode have a one-to-one
     382              :     /// relationship
     383              :     tf_node->arity(layer_node_inputs.size());
     384           39 :     for (size_t index = 0; index < layer_node_inputs.size(); index++) {
     385              :       auto input_layer_name = layer_node_inputs[index];
     386           17 :       auto input_layer_node_iterator = std::find_if(
     387              :         representation.begin(), representation.end(),
     388           53 :         [&input_layer_name](std::shared_ptr<nntrainer::LayerNode> node) {
     389          106 :           return istrequal(node.get()->getName(), input_layer_name);
     390              :         });
     391              : 
     392           17 :       if (input_layer_node_iterator != representation.end()) {
     393              :         auto input_layer_node = input_layer_node_iterator->get();
     394           17 :         if (layer_to_tf.find(input_layer_node) != layer_to_tf.end()) {
     395           17 :           tf_node->setArg(index, layer_to_tf.find(input_layer_node)->second);
     396              :         }
     397              :       }
     398              :     }
     399           22 :   }
     400              : 
     401              :   node_count = 0;
     402           27 :   for (auto &n : nodes) {
     403              :     auto tf_node = n.get();
     404           22 :     if (tf_node->getOptionType() ==
     405              :         tflite::BuiltinOptions::BuiltinOptions_FullyConnectedOptions) {
     406            6 :       tf_node->weightReorder(node_count);
     407              :     }
     408              : 
     409              :     if (tf_node->getOpType() ==
     410            2 :           tflite::BuiltinOperator::BuiltinOperator_CONV_2D &&
     411            2 :         nodes.at(node_count + 1).get()->getOpType() ==
     412           22 :           tflite::BuiltinOperator::BuiltinOperator_MUL &&
     413            0 :         nodes.at(node_count + 2).get()->getOpType() ==
     414              :           tflite::BuiltinOperator::BuiltinOperator_RELU) {
     415              :       // Fuse Conv2D + Mul(Batch Norm) + ReLU to Conv2D
     416              : 
     417              :       auto props = tf_node->getProps();
     418              :       auto tf_padding = tflite::Padding_SAME;
     419              : 
     420            0 :       if (props[0] == 1) {
     421              :         tf_padding = tflite::Padding_VALID;
     422              :       }
     423              :       auto new_options =
     424            0 :         tflite::CreateConv2DOptions(fbb, tf_padding, props[1], props[2],
     425              :                                     tflite::ActivationFunctionType_RELU)
     426            0 :           .Union();
     427            0 :       tf_node->setBuiltinOptions(tflite::BuiltinOptions_Conv2DOptions,
     428              :                                  new_options);
      429              :       // after fusing, mark the ReLU node to be removed
     430              :       nodes.at(node_count + 2).get()->setToBeRemoved(true);
     431            0 :     }
     432              : 
     433           22 :     if (node_count < 1) {
     434              :       node_count++;
     435            5 :       continue;
     436              :     } else {
     437           17 :       if (nodes.at(node_count - 1).get()->isTrainable() == true &&
     438              :           tf_node->getOpType() == tflite::BuiltinOperator_MUL) {
     439              : 
     440              :         // Fused weight(conv)
     441              :         // = weight(conv) * (weight(bn) / sqrt(var(bn) + eps))
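                      :         // and, as computed on conv_bias below,
                      :         // Fused bias(conv)
                      :         // = (bias(conv) - mean(bn)) * (weight(bn) / sqrt(var(bn) + eps))
                      :         //   + bias(bn)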
     442              : 
     443            0 :         auto conv_weights = nodes.at(node_count - 1).get()->getWeights();
     444            0 :         Tensor conv_weight(conv_weights.at(0)->getDim());
     445            0 :         Tensor conv_bias(conv_weights.at(1)->getDim());
     446            0 :         conv_weight.copyData(conv_weights.at(0)->clone());
     447            0 :         conv_bias.copyData(conv_weights.at(1)->clone());
     448              : 
     449            0 :         auto mul_weights = tf_node->getWeights();
     450            0 :         auto mul_mean = mul_weights.at(0)->clone().transpose("1:2:0");
     451            0 :         auto mul_var = mul_weights.at(1)->clone().transpose("1:2:0");
     452            0 :         auto mul_weight = mul_weights.at(2)->clone().transpose("1:2:0");
     453            0 :         auto mul_bias = mul_weights.at(3)->clone().transpose("1:2:0");
     454            0 :         auto mul_epsilon = tf_node->getAdditionalProps().at(0);
     455              : 
      456              :         // compute weight(bn) / sqrt(var(bn) + eps)
     457            0 :         mul_var.add_i(mul_epsilon);
     458            0 :         mul_var.pow_i(-0.5f);
     459            0 :         mul_weight.multiply_i(mul_var);
     460              : 
     461            0 :         Tensor reshape_mul_weight(mul_weight.getDim());
     462            0 :         reshape_mul_weight.copy(mul_weight);
     463            0 :         reshape_mul_weight.reshape(
     464            0 :           TensorDim{mul_weight.getDim().width(), 1, 1, 1});
     465            0 :         conv_weight.multiply_i(reshape_mul_weight);
     466              : 
     467            0 :         conv_bias.subtract_i(mul_mean);
     468            0 :         conv_bias.multiply_i(mul_weight);
     469            0 :         conv_bias.add_i(mul_bias);
     470              : 
     471              :         TfOpNode::Variables conv_new_weights;
     472            0 :         conv_new_weights.push_back(&conv_weight);
     473            0 :         conv_new_weights.push_back(&conv_bias);
     474            0 :         nodes.at(node_count - 1).get()->setWeights(conv_new_weights);
      475              :         // set the mul node (which implements batch normalization) to be removed
     476              :         n->setToBeRemoved(true);
     477            0 :       }
     478              :     }
     479           17 :     node_count++;
     480              :   }
     481              : 
     482            5 :   return nodes;
     483            0 : }
     484              : 
     485              : flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>>
     486            5 : buildBuffers(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
     487              :   const auto &buffers =
     488              :     map.getIndexMap<const float *, TfOpIdxMap::Buffer>().getData();
     489              : 
     490              :   std::vector<flatbuffers::Offset<tflite::Buffer>> fb_buffers;
     491            5 :   fb_buffers.reserve(buffers.size());
     492              : 
     493           56 :   auto create_buffer_offset = [&fbb](const TfOpIdxMap::Buffer &buffer) {
     494           56 :     if (buffer.first == 0) {
     495           35 :       return tflite::CreateBuffer(fbb);
     496              :     }
     497              : 
     498           21 :     auto data = fbb.CreateVector(
     499           21 :       reinterpret_cast<const uint8_t *>(buffer.second), buffer.first);
     500              : 
     501           21 :     return tflite::CreateBuffer(fbb, data);
     502            5 :   };
     503              : 
     504            5 :   std::transform(buffers.begin(), buffers.end(), std::back_inserter(fb_buffers),
     505              :                  create_buffer_offset);
     506              : 
      507              :   // add a buffer for each graph input
     508           10 :   for (unsigned index = 0; index < map.getInputs().size(); index++) {
     509           10 :     fb_buffers.push_back(create_buffer_offset({0, nullptr}));
     510              :   }
     511            5 :   return fbb.CreateVector(fb_buffers);
     512            5 : }
     513              : 
     514              : flatbuffers::Offset<
     515              :   flatbuffers::Vector<flatbuffers::Offset<tflite::OperatorCode>>>
     516            5 : buildOperatorCodes(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
     517              :   const auto &op_codes = map.getIndexMap<tflite::BuiltinOperator>().getData();
     518              : 
     519              :   std::vector<flatbuffers::Offset<tflite::OperatorCode>> fb_op_codes;
     520            5 :   fb_op_codes.reserve(op_codes.size());
     521              : 
     522           17 :   auto create_op_offset = [&fbb](const tflite::BuiltinOperator &op,
     523              :                                  int32_t version = 1) {
     524           17 :     tflite::OperatorCodeBuilder builder(fbb);
     525           17 :     builder.add_deprecated_builtin_code(static_cast<int8_t>(op));
     526              :     /// @todo find reason why version field is not shown
     527              :     /// on json when version is 1 (other versions are fine)
     528              :     builder.add_version(version);
     529           17 :     builder.add_builtin_code(op);
     530           17 :     return builder.Finish();
     531            5 :   };
     532              : 
     533            5 :   std::transform(op_codes.begin(), op_codes.end(),
     534              :                  std::back_inserter(fb_op_codes), create_op_offset);
     535              : 
     536            5 :   return fbb.CreateVector(fb_op_codes);
     537            5 : }
     538              : 
     539              : flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Tensor>>>
     540            5 : buildTensors(const TfOpIdxMap &map, flatbuffers::FlatBufferBuilder &fbb) {
     541              :   /// @todo: the actual (squeezed) tensor dimension must be known before
     542              :   /// coming here. For now, it is directly guessed for the fc layer
     543              :   const auto &variables = map.getTensors();
     544              :   const auto &buffer_map = map.getIndexMap<const float *, TfOpIdxMap::Buffer>();
     545            5 :   auto graph_input_offset = map.getInputs().size() - 1;
     546              : 
     547              :   std::vector<flatbuffers::Offset<tflite::Tensor>> fb_tensors;
     548            5 :   fb_tensors.reserve(variables.size());
     549              : 
     550           48 :   auto create_tensor = [&fbb, &buffer_map,
     551              :                         &graph_input_offset](const Tensor *var) {
     552           48 :     auto dim = var->getDim();
     553           48 :     bool need_shape_signature = dim.is_dynamic();
     554           48 :     std::vector<int32_t> eff_dim = dim.getEffectiveDimension();
     555           48 :     auto shape = fbb.CreateVector(eff_dim);
     556              : 
     557              :     decltype(shape) shape_sig;
     558           48 :     if (need_shape_signature) {
     559           20 :       std::vector<int32_t> dyn_dim = dim.getEffectiveDimension(true);
     560           20 :       shape_sig = fbb.CreateVector(dyn_dim);
     561           20 :     }
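                      :     // tflite marks dynamic dimensions with -1 in the shape signature, e.g. a
                      :     // batch-dynamic tensor may serialize shape {1, 28, 28, 1} together with
                      :     // shape_signature {-1, 28, 28, 1} (illustrative values)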
     562              : 
      563              :     /// change this to var->getName() once tensors have their own names
     564           96 :     auto name = fbb.CreateString("nntrainer_converted" + var->getName());
     565              : 
      566              :     /// only graph inputs have a nullptr data pointer.
     567              :     unsigned int buffer_idx =
     568              :       var->getData() == nullptr
     569            5 :         ? buffer_map.getData().size() - graph_input_offset--
     570           48 :         : buffer_map.getIndex(var->getData());
     571              : 
     572           48 :     tflite::TensorBuilder builder(fbb);
     573           48 :     builder.add_name(name);
     574              :     builder.add_buffer(buffer_idx);
     575              :     /// @todo support more data types
      576              :     /// @note this is a workaround because nntrainer tensors allow only the
      577              :     /// float dtype
     578           96 :     if (var->getName().find("nntrainer_internal_perm") != std::string::npos) {
     579              :       builder.add_type(tflite::TensorType_INT32);
     580              :     } else
     581              :       builder.add_type(tflite::TensorType_FLOAT32);
     582           48 :     builder.add_shape(shape);
     583           48 :     if (need_shape_signature) {
     584           20 :       builder.add_shape_signature(shape_sig);
     585              :     }
     586           48 :     return builder.Finish();
     587           53 :   };
     588              : 
     589            5 :   std::transform(variables.begin(), variables.end(),
     590              :                  std::back_inserter(fb_tensors), create_tensor);
     591              : 
     592            5 :   return fbb.CreateVector(fb_tensors);
     593            5 : }
     594              : 
     595              : flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::Operator>>>
     596            5 : buildOperators(const TfOpNodes &nodes, const TfOpIdxMap &map,
     597              :                flatbuffers::FlatBufferBuilder &fbb) {
     598              : 
      599              :   /// this lambda maps variables to a list of indices in the map
     600           66 :   auto variables_to_idx_vector = [&map](const TfOpNode::Variables &v) {
     601              :     std::vector<int> idx_vector;
     602           66 :     idx_vector.reserve(v.size());
     603              : 
     604           66 :     std::transform(
     605              :       v.begin(), v.end(), std::back_inserter(idx_vector),
     606           65 :       [&map](const Tensor *variable) { return map.getTensorIndex(variable); });
     607           66 :     return idx_vector;
     608            0 :   };
     609              : 
     610           22 :   auto create_operator = [&fbb, &map,
     611              :                           &variables_to_idx_vector](const TfOpNode &node) {
     612           22 :     auto &index_map = map.getIndexMap<tflite::BuiltinOperator>();
     613              : 
     614           22 :     auto op_code = index_map.getIndex(node.getOpType());
     615              :     std::vector<int> inputs;
     616           22 :     if (node.isInputNode()) {
     617            5 :       inputs = variables_to_idx_vector(node.getInputs());
     618              :     } else {
     619              :       /**
     620              :        *  Q) Why find a tensor that shares a buffer with input tensor?
     621              :        *
     622              :        *  A) the tflite needs only one tensor between nodes. Therefore,
      623              :        *  A) tflite needs only one tensor between nodes. Therefore, the
      624              :        *  parent's output tensor is used as the tflite tensor that shares
      625              :        *  its buffer with this node's input.
     626              :       TfOpNode::Variables input_tensors;
     627           34 :       for (auto parent_node : getPredNodes(node)) {
     628           34 :         for (auto parent_out : parent_node->getOutputs()) {
     629           34 :           for (auto in : node.getInputs()) {
     630              :             /// second condition is a workaround
     631              :             /// Transpose op output tensor originally had nullptr data pointer
      632              :             /// but it has since been allocated (parent_out->getData()). The
      633              :             /// tensor that shared its buffer hasn't been, so it still holds
      634              :             /// nullptr (in->getData()).
     635              :             /// @todo remove this workaround
     636           20 :             if (parent_out->getData() == in->getData() ||
     637            3 :                 (in->getData() == nullptr && parent_out->getData())) {
     638           17 :               if (std::find(input_tensors.begin(), input_tensors.end(),
     639              :                             parent_out) != input_tensors.end())
     640            0 :                 continue;
     641           17 :               input_tensors.push_back(parent_out);
     642              :             }
     643              :           }
     644              :         }
     645           17 :       }
     646           17 :       inputs = variables_to_idx_vector(input_tensors);
     647           17 :     }
     648           22 :     auto weights = variables_to_idx_vector(node.getWeights());
     649              : 
      650              :     /// weights are part of the inputs in tflite
     651           22 :     inputs.insert(inputs.end(), weights.begin(), weights.end());
     652              : 
     653           22 :     auto outputs = variables_to_idx_vector(node.getOutputs());
     654              : 
     655           22 :     auto fb_inputs = fbb.CreateVector(inputs);
     656           22 :     auto fb_outputs = fbb.CreateVector(outputs);
     657           22 :     auto fb_options = node.getBuiltinOps();
     658              : 
     659           22 :     tflite::OperatorBuilder builder(fbb);
     660              :     builder.add_opcode_index(op_code);
     661              :     builder.add_builtin_options_type(node.getOptionType());
     662           22 :     builder.add_builtin_options(fb_options);
     663           22 :     builder.add_inputs(fb_inputs);
     664           22 :     builder.add_outputs(fb_outputs);
     665           22 :     return builder.Finish();
     666           27 :   };
     667              : 
     668              :   std::vector<flatbuffers::Offset<tflite::Operator>> v;
     669            5 :   v.reserve(nodes.size());
     670              : 
     671           27 :   for (auto &node : nodes) {
     672           22 :     if (node->isVirtualNode())
     673            0 :       continue;
     674           22 :     auto op = create_operator(*node);
     675           22 :     v.push_back(op);
     676              :   }
     677              : 
     678            5 :   return fbb.CreateVector(v);
     679            5 : }
     680              : 
     681              : flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<tflite::SubGraph>>>
     682            5 : buildSubGraphs(const TfOpNodes &nodes, const TfOpIdxMap &map,
     683              :                flatbuffers::FlatBufferBuilder &fbb) {
     684              : 
     685            5 :   auto tensors = buildTensors(map, fbb);
     686            5 :   auto ops = buildOperators(nodes, map, fbb);
     687              : 
      688              :   /// @todo extract this into a buildSubgraph() helper once there can be more than one subgraph
     689            5 :   auto name = fbb.CreateString("main");
     690              :   auto inputs = fbb.CreateVector(map.getInputs());
     691              :   auto outputs = fbb.CreateVector(map.getOutputs());
     692              : 
     693              :   auto builder = tflite::SubGraphBuilder(fbb);
     694            5 :   builder.add_tensors(tensors);
     695            5 :   builder.add_inputs(inputs);
     696            5 :   builder.add_outputs(outputs);
     697            5 :   builder.add_name(name);
     698            5 :   builder.add_operators(ops);
     699            5 :   auto subgraph = builder.Finish();
     700              : 
     701              :   std::vector<flatbuffers::Offset<tflite::SubGraph>> subgraphs;
     702            5 :   subgraphs.reserve(1);
     703            5 :   subgraphs.push_back(subgraph);
     704              : 
     705            5 :   return fbb.CreateVector(subgraphs);
     706            5 : }
     707              : 
     708              : } // namespace
     709              : 
     710            5 : TfOpNodes buildRealizedOpNodes(TfOpNodes &nodes,
     711              :                                flatbuffers::FlatBufferBuilder &fbb) {
     712              :   TfOpNodes realized_nodes;
     713              : 
     714              :   bool set_input = false;
     715              :   unsigned int node_count = 0;
     716              : 
     717           27 :   for (auto &node : nodes) {
      718           22 :     if (set_input) { // the previous node was newly inserted; wire up its I/O
     719              :       node->setArg(0, realized_nodes.back().get());
     720              :       realized_nodes.back()->setOutputs(node->getInputs());
     721              :       set_input = false;
     722              :     }
     723              : 
     724           22 :     if (node->isToBeRemoved() == true) {
      725              :       // remove the node; assume that the input node is never removed
     726              :       realized_nodes.back().get()->setOutputs(
     727            0 :         nodes.at(node_count)->getOutputs());
     728            0 :       nodes.at(node_count + 1)->setArg(0, realized_nodes.back().get());
     729              :       nodes.at(node_count + 1)->setInputs(nodes.at(node_count)->getInputs());
     730              :     } else {
     731              :       realized_nodes.push_back(std::move(node));
     732              : 
     733           22 :       if (realized_nodes.back().get()->getOpType() ==
      734              :           tflite::BuiltinOperator_MUL) { // fuse MUL + ADD (non-trainable)
     735              :         /**
     736              :           y = x * (gamma / sqrt(variance + epsilon)) +
     737              :           (beta - mean * gamma / sqrt(variance + epsilon))
     738              :         */
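                      :         // realized below as MUL(x, scale) followed by a newly inserted
                      :         // ADD(shift) whose options fuse the downstream ReLU activation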
     739            0 :         auto removed_weights = realized_nodes.back().get()->getWeights();
     740            0 :         auto mul_mean = removed_weights.at(0)->clone();
     741            0 :         auto mul_variance = removed_weights.at(1)->clone();
     742            0 :         auto mul_gamma = removed_weights.at(2)->clone();
     743            0 :         auto mul_beta = removed_weights.at(3)->clone();
     744              :         auto mul_epsilon =
     745            0 :           realized_nodes.back().get()->getAdditionalProps().at(0);
     746              : 
     747              :         std::unique_ptr<Tensor> new_mul_weight =
     748            0 :           std::make_unique<Tensor>(mul_gamma.getDim());
     749            0 :         new_mul_weight->allocate();
     750            0 :         new_mul_weight->copy(mul_gamma);
     751              : 
     752              :         // new_mul_weight = (gamma / sqrt(variance + epsilon))
     753            0 :         mul_variance.add_i(mul_epsilon);
     754            0 :         mul_variance.pow_i(-0.5f);
     755            0 :         new_mul_weight->multiply_i(mul_variance);
     756              : 
     757              :         // beta =  (beta - mean * gamma / sqrt(variance + epsilon))
     758            0 :         Tensor sub_result(new_mul_weight->getDim());
     759            0 :         sub_result.allocate();
     760            0 :         sub_result.copyData(*new_mul_weight);
     761              : 
     762            0 :         mul_mean.multiply_i(sub_result);
     763            0 :         mul_beta.subtract_i(mul_mean);
     764            0 :         new_mul_weight->setName("MUL");
     765            0 :         for (auto weight : removed_weights) {
     766            0 :           delete weight;
     767              :         }
     768              :         removed_weights.clear();
     769            0 :         removed_weights.push_back(new_mul_weight.release());
     770              : 
     771              :         realized_nodes.back().get()->replaceWeights(removed_weights);
     772            0 :         realized_nodes.back().get()->setWeights(removed_weights, true);
     773              : 
     774              :         // Insert Add layer into Graph
     775            0 :         std::unique_ptr<TfOpNode> tf_node = std::make_unique<TfOpNode>();
     776              :         tf_node->setInputs(realized_nodes.back()->getOutputs());
     777              :         tf_node->setOpType(tflite::BuiltinOperator_ADD);
     778              :         auto options =
     779            0 :           tflite::CreateAddOptions(fbb, tflite::ActivationFunctionType_RELU)
     780            0 :             .Union();
     781              : 
     782            0 :         new_weight_add[new_alloc_tensors_idx].allocate();
     783            0 :         new_weight_add[new_alloc_tensors_idx].copy(mul_beta);
     784            0 :         std::string name = "ADD_tensor";
     785            0 :         new_weight_add[new_alloc_tensors_idx].setName(name);
     786              : 
     787              :         new_variable.clear();
     788            0 :         new_variable.emplace_back(&new_weight_add[new_alloc_tensors_idx]);
     789            0 :         new_alloc_tensors_idx++;
     790              : 
     791              :         tf_node->replaceWeights(new_variable);
     792            0 :         tf_node->setWeights(new_variable, true);
     793            0 :         tf_node->setBuiltinOptions(tflite::BuiltinOptions_AddOptions, options);
     794              : 
     795            0 :         nodes.at(node_count + 1)
     796              :           .get()
     797              :           ->setToBeRemoved(true); // remove ReLU Layer and Fuse with Add
     798              : 
     799              :         auto mul_node = realized_nodes.back().get();
     800              :         tf_node->arity(1);
     801              :         tf_node->setArg(0, mul_node);
     802              : 
     803              :         realized_nodes.push_back(std::move(tf_node));
     804              :         set_input = true;
     805            0 :       }
     806              :     }
     807           22 :     node_count++;
     808              :   }
     809              : 
     810            5 :   return realized_nodes;
     811            0 : }
     812              : 
     813            5 : void TfliteInterpreter::serialize(const GraphRepresentation &representation,
     814              :                                   const std::string &out) {
     815              : 
      816              :   /// 1. remove loss and dropout layers from the GraphRepresentation
     817              :   TfliteExportRealizer tflite_realizer({});
     818            5 :   GraphRepresentation graph_loss = tflite_realizer.realize(representation);
     819            5 :   GraphRepresentation graph = tflite_realizer.realize_dropout(graph_loss);
     820              : 
     821              :   /// 2. The graph must have weights, input dims, output dims set
     822              :   flatbuffers::FlatBufferBuilder fbb;
     823              : 
     824            5 :   auto opNodes = buildOpNodes(graph, fbb);
     825            5 :   auto converted_opNodes = buildRealizedOpNodes(opNodes, fbb);
     826              : 
     827            5 :   TfOpIdxMap map(converted_opNodes); /// build TfOpIdxMap from opNodes
     828            5 :   auto opcodes = buildOperatorCodes(map, fbb);
     829            5 :   auto subgraphs = buildSubGraphs(converted_opNodes, map, fbb);
     830            5 :   auto buffers = buildBuffers(map, fbb);
     831            5 :   auto desc = fbb.CreateString("This file is generated from NNTrainer");
     832              :   tflite::ModelBuilder model_builder(fbb);
     833              : 
     834            5 :   model_builder.add_operator_codes(opcodes);
     835            5 :   model_builder.add_subgraphs(subgraphs);
     836            5 :   model_builder.add_buffers(buffers);
     837              :   model_builder.add_version(3);
     838            5 :   model_builder.add_description(desc);
     839              :   auto model = model_builder.Finish();
     840              : 
     841              :   fbb.Finish(model, tflite::ModelIdentifier());
     842            5 :   builder2file(fbb, out);
     843            5 : }
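                      : 
                      : // Illustrative usage sketch (not part of the measured source); `graph` is
                      : // assumed to be a GraphRepresentation with weights and I/O dims already set:
                      : //
                      : //   nntrainer::TfliteInterpreter interpreter;
                      : //   interpreter.serialize(graph, "model.tflite");
                      : //   // deserialize() is NYI and currently returns an empty representation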
     844              : 
     845            0 : GraphRepresentation TfliteInterpreter::deserialize(const std::string &in) {
     846              :   /// ======== list of things to consider ========
     847              :   /// we need to reconstruct some properties from the shape
     848              :   /// eg) units are not saved as a property
     849              : 
     850              :   /** NYI! */
     851            0 :   return {};
     852              : }
     853              : 
     854              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1