LCOV - code coverage report
Current view: top level - nntrainer/compiler - bn_realizer.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 100.0 % 23 23
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 2 2

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2022 Jijoong Moon <jijoong.moon@samsung.com>
       4              :  *
       5              :  * @file bn_realizer.cpp
       6              :  * @date 13 April 2022
       7              :  * @brief NNTrainer graph realizer which remove batch normalization layer for
       8              :  * inference
       9              :  * @see https://github.com/nnstreamer/nntrainer
      10              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
      11              :  * @bug No known bugs except for NYI items
      12              :  */
      13              : #include <bn_realizer.h>
      14              : #include <connection.h>
      15              : #include <layer_node.h>
      16              : #include <nntrainer_error.h>
      17              : #include <nntrainer_log.h>
      18              : 
      19              : #include <algorithm>
      20              : #include <stdexcept>
      21              : #include <unordered_map>
      22              : 
      23              : namespace nntrainer {
      24              : 
      25              : static constexpr size_t SINGLE_INOUT_IDX = 0;
      26              : 
      27            2 : GraphRepresentation BnRealizer::realize(const GraphRepresentation &reference) {
      28              :   std::unordered_map<std::string, LayerNode *> existing_nodes;
      29              :   std::vector<LayerNode *> bn_layers;
      30              : 
      31            2 :   std::transform(
      32              :     reference.begin(), reference.end(),
      33              :     std::inserter(existing_nodes, existing_nodes.end()),
      34           42 :     [](auto &node) { return std::pair(node->getName(), node.get()); });
      35              : 
      36           23 :   for (auto &node : reference) {
      37           42 :     if (istrequal(node->getType(), "batch_normalization")) {
      38            5 :       bn_layers.push_back(node.get());
      39              :     }
      40              :   }
      41              : 
      42            7 :   for (auto iter = bn_layers.begin(); iter != bn_layers.end(); ++iter) {
      43            5 :     auto node = (*iter);
      44            5 :     auto &input_name = node->getInputConnectionName(SINGLE_INOUT_IDX);
      45            5 :     auto input_node = existing_nodes.at(input_name);
      46              : 
      47           10 :     for (unsigned int i = 0; i < input_node->getNumOutputConnections(); ++i) {
      48           10 :       if (istrequal(node->getName(),
      49              :                     input_node->getOutputConnection(i)->getName())) {
      50            5 :         input_node->setOutputConnection(
      51              :           i, node->getOutputConnection(SINGLE_INOUT_IDX)->getName(),
      52              :           SINGLE_INOUT_IDX);
      53              :       }
      54              :     }
      55              : 
      56            5 :     auto &output_name = node->getOutputConnection(SINGLE_INOUT_IDX)->getName();
      57            5 :     auto output_node = existing_nodes.at(output_name);
      58              : 
      59           10 :     for (unsigned int i = 0; i < output_node->getNumInputConnections(); ++i) {
      60           10 :       if (istrequal(node->getName(), output_node->getInputConnectionName(i))) {
      61            5 :         output_node->setInputConnectionName(
      62              :           i, node->getInputConnectionName(SINGLE_INOUT_IDX));
      63              :       }
      64              :     }
      65              :   }
      66              : 
      67              :   GraphRepresentation processed;
      68           23 :   for (auto &node : reference) {
      69           42 :     if (!istrequal(node->getType(), "batch_normalization")) {
      70           16 :       processed.push_back(node);
      71              :     }
      72              :   }
      73              : 
      74            2 :   return processed;
      75            2 : }
      76              : 
      77              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1