LCOV - code coverage report
Current view: top level - nntrainer/compiler - tflite_export_realizer.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 100.0 % 47 47
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 4 4

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2022 seongwoo <mhs4670go@naver.com>
       4              :  *
       5              :  * @file tflite_export_realizer.cpp
       6              :  * @date 18 July 2025
       7              :  * @brief NNTrainer graph realizer which removes loss and dropout layers for inference (TFLite export)
       8              :  * @see https://github.com/nnstreamer/nntrainer
       9              :  * @author seongwoo <mhs4670go@naver.com>
      10              :  * @author donghak park <donghak.park@samsung.com>
      11              :  * @bug No known bugs except for NYI items
      12              :  */
      13              : 
      14              : #include <algorithm>
      15              : #include <cassert>
      16              : #include <connection.h>
      17              : #include <layer_node.h>
      18              : #include <nntrainer_error.h>
      19              : #include <nntrainer_log.h>
      20              : #include <set>
      21              : #include <stdexcept>
      22              : #include <string>
      23              : #include <tflite_export_realizer.h>
      24              : #include <unordered_map>
      25              : 
      26              : namespace nntrainer {
      27              : 
      28              : static constexpr size_t SINGLE_IN_IDX = 0;
      29              : 
      30              : GraphRepresentation
      31            6 : TfliteExportRealizer::realize(const GraphRepresentation &reference) {
      32              :   /// @todo support more loss layers
      33              :   /// @note Some layers need to consider not removing all semantics.
      34              :   /// For example, When CrossEntropySigmoidLossLayer needs to be removed,
      35              :   /// sigmoid computation shouldn't be removed.
      36            6 :   static const std::set<std::string> loss_type = {"mse"};
      37              :   std::unordered_map<std::string, LayerNode *> existing_nodes;
      38              :   std::vector<LayerNode *> loss_layers;
      39              : 
      40            6 :   std::transform(
      41              :     reference.begin(), reference.end(),
      42              :     std::inserter(existing_nodes, existing_nodes.end()),
      43           58 :     [](auto &node) { return std::pair(node->getName(), node.get()); });
      44              : 
      45           35 :   for (auto &node : reference) {
      46           58 :     if (loss_type.find(node->getType()) != loss_type.end()) {
      47            1 :       loss_layers.push_back(node.get());
      48              :     }
      49              :   }
      50              : 
      51            7 :   for (auto iter = loss_layers.begin(); iter != loss_layers.end(); ++iter) {
      52            1 :     auto loss_node = (*iter);
      53            1 :     assert(loss_node->getNumInputConnections() == 1);
      54            1 :     auto &input_name = loss_node->getInputConnectionName(SINGLE_IN_IDX);
      55            1 :     auto input_node = existing_nodes.at(input_name);
      56            2 :     for (unsigned int i = 0; i < input_node->getNumOutputConnections(); ++i) {
      57            2 :       if (istrequal(loss_node->getName(),
      58              :                     input_node->getOutputConnection(i)->getName())) {
      59              :         /// Assume that loss layers don't have output connections
      60            1 :         assert(loss_node->getOutputConnections().size() == 0);
      61            1 :         input_node->setOutputLayers({});
      62              :       }
      63              :     }
      64              :   }
      65              : 
      66              :   GraphRepresentation processed;
      67           35 :   for (auto &node : reference) {
      68           58 :     if (loss_type.find(node->getType()) == loss_type.end()) {
      69           28 :       processed.push_back(node);
      70              :     }
      71              :   }
      72              : 
      73            6 :   return processed;
      74            6 : }
      75              : 
      76              : GraphRepresentation
      77            5 : TfliteExportRealizer::realize_dropout(const GraphRepresentation &reference) {
      78            5 :   static const std::set<std::string> dropout_type = {"dropout"};
      79              :   std::unordered_map<std::string, LayerNode *> existing_nodes;
      80              :   std::vector<LayerNode *> dropout_layers;
      81              : 
      82            5 :   std::transform(
      83              :     reference.begin(), reference.end(),
      84              :     std::inserter(existing_nodes, existing_nodes.end()),
      85           46 :     [](auto &node) { return std::pair(node->getName(), node.get()); });
      86              : 
      87              :   // find dropout layer and push to vector
      88           28 :   for (auto &node : reference) {
      89           46 :     if (dropout_type.find(node->getType()) != dropout_type.end()) {
      90            1 :       dropout_layers.push_back(node.get());
      91              :     }
      92              :   }
      93              : 
      94            6 :   for (auto iter = dropout_layers.begin(); iter != dropout_layers.end();
      95              :        ++iter) {
      96            1 :     auto node = (*iter);
      97            1 :     auto &input_name = node->getInputConnectionName(SINGLE_IN_IDX);
      98            1 :     auto input_node = existing_nodes.at(input_name);
      99              : 
     100            2 :     for (unsigned int i = 0; i < input_node->getNumOutputConnections(); ++i) {
     101            2 :       if (istrequal(node->getName(),
     102              :                     input_node->getOutputConnection(i)->getName())) {
     103            1 :         input_node->setOutputConnection(
     104              :           i, node->getOutputConnection(i)->getName(), SINGLE_IN_IDX);
     105              :       }
     106            1 :       input_node->getOutput(SINGLE_IN_IDX)
     107            2 :         .setData(node->getOutput(SINGLE_IN_IDX).getMemoryData());
     108              :     }
     109              : 
     110            1 :     auto &output_name = node->getOutputConnection(SINGLE_IN_IDX)->getName();
     111            1 :     auto output_node = existing_nodes.at(output_name);
     112              : 
     113            2 :     for (unsigned int i = 0; i < output_node->getNumInputConnections(); ++i) {
     114            2 :       if (istrequal(node->getName(), output_node->getInputConnectionName(i))) {
     115            1 :         output_node->setInputConnectionName(i, node->getInputConnectionName(i));
     116              :       }
     117              :     }
     118              :   }
     119              : 
     120              :   GraphRepresentation processed;
     121           28 :   for (auto &node : reference) {
     122           46 :     if (dropout_type.find(node->getType()) == dropout_type.end()) {
     123           22 :       processed.push_back(node);
     124              :     }
     125              :   }
     126              : 
     127            5 :   return processed;
     128            5 : }
     129              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1