LCOV - code coverage report
Current view: top level - nntrainer/compiler - tflite_opnode.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 67.9 % 28 19
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 1 1

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file tflite_opnode.h
       6              :  * @date 28 April 2021
       7              :  * @brief contains tflite opnode which has information to convert to tflite file
       8              :  * @see https://github.com/nnstreamer/nntrainer
       9              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      10              :  * @author Donghak Park <donghak.park@samsung.com>
      11              :  * @bug No known bugs except for NYI items
      12              :  */
      13              : 
      14              : #ifndef __TFLITE_OPNODE_H__
      15              : #define __TFLITE_OPNODE_H__
      16              : 
      17              : #include <functional>
      18              : #include <utility>
      19              : #include <vector>
      20              : 
      21              : #include <tensor.h>
      22              : #include <tf_schema_generated.h>
      23              : 
      24              : namespace nntrainer {
      25              : 
      26              : class LayerNode;
      27              : class RunLayerContext;
      28              : /**
      29              :  * @brief tensorflow operational node representation. This class contains,
      30              :  * information to build operation flatbuffer
      31              :  *
      32              :  */
      33              : class TfOpNode {
      34              : public:
      35              :   using Variables = std::vector<const Tensor *>;
      36              : 
      37              :   using TransformFn =
      38              :     std::function<std::vector<Tensor>(std::vector<const Tensor *> &)>;
      39              : 
      40              :   /**
      41              :    * @brief Construct a new Tf object
      42              :    *
      43              :    */
      44              :   TfOpNode();
      45              : 
      46              :   /**
      47              :    * @brief finalize tf op node will be transformed to required variables
      48              :    * in this phase, weights are merged into inputs
      49              :    *
      50              :    */
      51              :   void finalize();
      52              : 
      53              :   /**
      54              :    * @brief Set common informations from layer node
      55              :    *
      56              :    * @param layer node layer node
      57              :    */
      58              :   void setLayerNode(const LayerNode &layer);
      59              : 
      60              :   /**
      61              :    * @brief Set the Weight Transform Fn object
      62              :    *
      63              :    * @param fn fn will be called before get
      64              :    */
      65              :   void setWeightTransformFn(TransformFn fn);
      66              : 
      67              :   /**
      68              :    * @brief Set the Input Transform Fn object
      69              :    *
      70              :    * @param fn fn will be called before get
      71              :    */
      72              :   void setInputTransformFn(TransformFn fn);
      73              : 
      74              :   /**
      75              :    * @brief Set the Op Type object
      76              :    *
      77              :    * @param op_type_ operation type
      78              :    */
      79           22 :   void setOpType(tflite::BuiltinOperator op_type_) { op_type = op_type_; }
      80              : 
      81              :   /**
      82              :    * @brief Set the Builtin Options object,
      83              :    * @note this can go private, export from a layer and fill this out
      84              :    *
      85              :    * @param builtin_option_type_ builtin option type
      86              :    * @param builtin_ops_ flatbuffer offset of builtin_ops
      87              :    */
      88              :   void setBuiltinOptions(tflite::BuiltinOptions builtin_option_type_,
      89              :                          const flatbuffers::Offset<void> &builtin_ops_);
      90              : 
      91              :   /**
      92              :    * @brief Set the Need Reorder Weight object
      93              :    *
      94              :    */
      95            4 :   void setNeedReorderWeight() { need_reorder_weight = true; }
      96              : 
      97              :   /**
      98              :    * @brief Set the To Be Removed object
      99              :    *
     100              :    */
     101            0 :   void setToBeRemoved(bool to_be_removed) { is_to_be_removed = to_be_removed; }
     102              : 
     103              :   /**
     104              :    * @brief Set the Trainable object
     105              :    *
     106              :    */
     107           14 :   void setTrainable(bool trainable) { is_trainable = trainable; }
     108              : 
     109              :   /**
     110              :    * @brief Set the Inputs object
     111              :    *
     112              :    * @param inputs_
     113              :    */
     114            0 :   void setInputs(const Variables &inputs_) { inputs = inputs_; }
     115              : 
     116              :   /**
     117              :    * @brief Set the Outputs object
     118              :    *
     119              :    * @param outputs_
     120              :    */
     121            0 :   void setOutputs(const Variables &outputs_) { outputs = outputs_; }
     122              : 
     123              :   /**
     124              :    * @brief Set the Weights object
     125              :    *
     126              :    * @param weights_
     127              :    */
     128              :   void setWeights(Variables weights_, bool weight_transpose = false);
     129              :   /**
     130              :    * @brief Replace the Weights object
     131              :    *
     132              :    * @param weights_
     133              :    */
     134            0 :   void replaceWeights(const Variables &weights_) { weights = weights_; }
     135              :   /**
     136              :    * @brief Set(Append) the Props object
     137              :    *
     138              :    * @param value
     139              :    */
     140            6 :   void AppendProps(const int &value) { props_vector.push_back(value); }
     141              : 
     142              :   /**
     143              :    * @brief Set(Append) the Additional Props object
     144              :    *
     145              :    * @param value
     146              :    */
     147              :   void AppendAdditionalProps(const float &value) {
     148            0 :     additional_props.push_back(value);
     149              :   }
     150              : 
     151              :   /**
     152              :    * @brief Reorder Weight in case of NCHW --> NHWC
     153              :    *
     154              :    * @param node_count
     155              :    */
     156              :   void weightReorder(unsigned int node_count);
     157              : 
     158              :   /**
     159              :    * @brief Get the Inputs object
     160              :    *
     161              :    * @return Variables& inputs
     162              :    */
     163           10 :   Variables &getInputs() { return inputs; }
     164              : 
     165              :   /**
     166              :    * @brief Get the weights object
     167              :    *
     168              :    * @return const Variables& weights
     169              :    */
     170           22 :   const Variables &getWeights() const { return weights; }
     171              : 
     172              :   /**
     173              :    * @brief Get the weights object
     174              :    *
     175              :    * @return Variables& weights
     176              :    */
     177           22 :   Variables &getWeights() { return weights; }
     178              : 
     179              :   /**
     180              :    * @brief Get the Inputs object
     181              :    *
     182              :    * @return const Variables& inputs
     183              :    */
     184            5 :   const Variables &getInputs() const { return inputs; }
     185              : 
     186              :   /**
     187              :    * @brief Get the Outputs object
     188              :    *
     189              :    * @return Variables&
     190              :    */
     191           27 :   Variables &getOutputs() { return outputs; }
     192              : 
     193              :   /**
     194              :    * @brief Get the Outputs object
     195              :    *
     196              :    * @return const Variables& outputs
     197              :    */
     198           22 :   const Variables &getOutputs() const { return outputs; }
     199              : 
     200              :   /**
     201              :    * @brief check if this op node is model input
     202              :    *
     203              :    * @retval true if op node is model input
     204              :    * @retval false if op node is not model input
     205              :    */
     206           66 :   bool isInputNode() const { return is_input; }
     207              : 
     208              :   /**
     209              :    * @brief check if this op node is model output
     210              :    *
     211              :    * @retval true if op node is model output
     212              :    * @retval false if op node is not model output
     213              :    */
     214           22 :   bool isOutputNode() const { return is_output; }
     215              : 
     216              :   /**
     217              :    * @brief check if this op node is virtual node
     218              :    *
     219              :    * virtual node is a node that will not be exported
     220              :    */
     221           83 :   bool isVirtualNode() const { return is_virtual; }
     222              : 
     223              :   /**
     224              :    * @brief check if this layer need to reorder
     225              :    *
     226              :    * @return true if weight need to reorder
     227              :    * @return false if reordering is not required
     228              :    */
     229              :   bool isNeedReorder() const { return need_reorder_weight; }
     230              : 
     231              :   /**
     232              :    * @brief check if this layer is trainable
     233              :    *
     234              :    * @return true if layer(OpNode) trainable
     235              :    * @return false if layer(OpNode) non-trainable
     236              :    */
     237           17 :   bool isTrainable() const { return is_trainable; }
     238              : 
     239              :   /**
     240              :    * @brief check if this layer is to be removed
     241              :    *
     242              :    * @return true
     243              :    * @return false
     244              :    */
     245           22 :   bool isToBeRemoved() const { return is_to_be_removed; }
     246              : 
     247              :   /**
     248              :    * @brief Get the Props Vector
     249              :    *
     250              :    * @return const std::vector<int> props_vector
     251              :    */
     252            0 :   std::vector<int> getProps() const { return props_vector; }
     253              : 
     254              :   /**
     255              :    * @brief Get the Additional Props Vector
     256              :    *
     257              :    * @return const std::vector<float> additional_props
     258              :    */
     259            0 :   std::vector<float> getAdditionalProps() const { return additional_props; }
     260              : 
     261              :   /**
     262              :    * @brief Get the Op Type object
     263              :    *
     264              :    * @return const tflite::BuiltinOperator
     265              :    */
     266           95 :   const tflite::BuiltinOperator getOpType() const { return op_type; }
     267              : 
     268              :   /**
     269              :    * @brief Get the Op Type object
     270              :    *
     271              :    * @return const tflite::BuiltinOperator
     272              :    */
     273              :   const tflite::BuiltinOptions getOptionType() const {
     274           66 :     return builtin_option_type;
     275              :   }
     276              : 
     277              :   /**
     278              :    * @brief Get the Op Options object
     279              :    * @param f Flatbuffer Builder
     280              :    * @retval const tflite::Offset<void>
     281              :    */
     282              :   flatbuffers::Offset<void> getBuiltinOps() const;
     283              : 
     284              :   /**
     285              :    * @brief Get input nodes
     286              :    *
     287              :    * @return const std::vector<TfOpNode *> &input_nodes
     288              :    */
     289              :   const std::vector<TfOpNode *> &getInputNodes() const { return input_nodes; }
     290              : 
     291              :   /**
     292              :    * @brief Set arity
     293              :    *
     294              :    * @param value value to set
     295              :    */
     296           22 :   void arity(size_t value) { input_nodes.resize(value); }
     297              : 
     298              :   /**
     299              :    * @brief Get arity
     300              :    *
     301              :    * @return const unsigned input_nodes.size()
     302              :    */
     303            0 :   const unsigned arity() const { return input_nodes.size(); }
     304              : 
     305              :   /**
     306              :    * @brief Set n-th argument of the node
     307              :    *
     308              :    * @param index argument index to set
     309              :    * @param node the node to be argument
     310              :    */
     311           17 :   void setArg(size_t index, TfOpNode *node) { input_nodes.at(index) = node; }
     312              : 
     313              :   /**
     314              :    * @brief Get n-th argument of the node
     315              :    *
     316              :    * @return TfOpNode *input_nodes.at(index)
     317              :    */
     318            0 :   TfOpNode *arg(size_t index) const { return input_nodes.at(index); }
     319              : 
     320              : private:
     321              :   Variables inputs;                    /**< input variables */
     322              :   Variables outputs;                   /**< output variables */
     323              :   Variables weights;                   /**< weight variables */
     324              :   std::vector<TfOpNode *> input_nodes; /**< input nodes */
     325              :   std::vector<int> props_vector;       /**< props vector */
     326              :   std::vector<float> additional_props; /**< additional props vector */
     327              : 
     328              :   /**
     329              :    * Q) Why do we need input transform?
     330              :    * A) To transform the nntrainer input data format(NCHW) to tflite
     331              :    *format(NHWC)
     332              :    **/
     333              :   TransformFn weight_transform; /**< weight transforms */
     334              :   TransformFn input_transform;  /**< input transforms */
     335              : 
     336              :   bool is_input;            /**< true if given input is input; */
     337              :   bool is_output;           /**< true if given output is output; */
     338              :   bool is_virtual;          /**< true if given node is virtual; */
     339              :   bool is_trainable;        /**< true if given node has weight and trainable */
     340              :   bool is_to_be_removed;    /**< true if given node is to be removed */
     341              :   bool need_reorder_weight; /**< true if given node need to reorder weight; */
     342              : 
     343              :   /// @todo change to shared_ptr or unique_ptr
     344              :   /// why? the addresses of existing tensors in the vector could become invalid
     345              :   /// due to memory reallocation
     346              :   std::vector<Tensor>
     347              :     node_owned_variable; /**< when node should be transformed it's own type, it
     348              :                           * needs to be owned by someone, so @a TfOpNode owns
     349              :                           * those orphaned tensors until the instance is
     350              :                           * destroyed */
     351              : 
     352              :   tflite::BuiltinOperator op_type;
     353              : 
     354              :   /// retrieve this from export_to
     355              :   flatbuffers::Offset<void> builtin_ops;
     356              :   tflite::BuiltinOptions builtin_option_type;
     357              : };
     358              : 
     359              : } // namespace nntrainer
     360              : 
     361              : #endif // __TFLITE_OPNODE_H__
        

Generated by: LCOV version 2.0-1