LCOV - code coverage report
Current view: top level - nntrainer/compiler - multiout_realizer.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 97.6 % 41 40
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 4 4

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file multiout_realizer.cpp
       6              :  * @date 17 November 2021
       7              :  * @brief NNTrainer graph realizer which realizes a connection used multiple times as an actual multiout node
       8              :  * @see https://github.com/nnstreamer/nntrainer
       9              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      10              :  * @bug No known bugs except for NYI items
      11              :  */
      12              : #include <stdexcept>
      13              : #include <string>
      14              : #include <unordered_map>
      15              : #include <unordered_set>
      16              : #include <utility>
      17              : 
      18              : #include <common_properties.h>
      19              : #include <compiler_fwd.h>
      20              : #include <connection.h>
      21              : #include <layer_node.h>
      22              : #include <multiout_realizer.h>
      23              : #include <remap_realizer.h>
      24              : 
      25              : namespace nntrainer {
      26         1400 : MultioutRealizer::~MultioutRealizer() {} /// empty body: no explicit cleanup is done here
      27              : 
      28              : GraphRepresentation /// returns a graph built from @a reference in which each connection read more than once is fanned out through a generated "multiout" node
      29          701 : MultioutRealizer::realize(const GraphRepresentation &reference) {
      30          701 :   GraphRepresentation processed(reference.begin(), reference.end());
      31              : 
      32              :   std::unordered_map<Connection, unsigned> freq_map;
      33              :   std::unordered_set<std::string> node_names;
      34              :   std::vector<Connection> connections;
      35              : 
      36              :   /// 1. check node names for clashes (std::invalid_argument via NNTR_THROW_IF), count how often each input connection is referenced, and keep the connections in first-seen order
      37         4661 :   for (auto &node : reference) {
      38         7922 :     NNTR_THROW_IF(node_names.count(node->getName()), std::invalid_argument)
      39            2 :       << "node name clashes: " << node->getName();
      40         7920 :     node_names.emplace(node->getName());
      41              : 
      42         8050 :     for (unsigned int i = 0, num_nodes = node->getNumInputConnections();
      43         8050 :          i < num_nodes; ++i) {
      44              :       Connection c(node->getInputConnectionName(i),
      45         4090 :                    node->getInputConnectionIndex(i));
      46         4090 :       [[maybe_unused]] auto [iter, result] = freq_map.try_emplace(c, 0);
      47         4090 :       if (result)
      48         3881 :         connections.push_back(c);
      49         4090 :       iter->second++;
      50              :     }
      51              :   }
      52              : 
      53              :   /// 2. for each connection, if it is referenced multiple times, create a
      54              :   /// multiout node for it and remap its consumers to that node's output indices
      55              :   std::unordered_map<
      56              :     std::string /**< original id */,
      57              :     std::vector<std::shared_ptr<LayerNode>> /**< created node */>
      58              :     multiout_nodes;
      59              : 
      60         4581 :   for (auto &con : connections) {
      61         3881 :     unsigned freq = freq_map[con];
      62              :     /// @note freq < 1 should never happen because a map entry is only created
      63              :     /// when a connection is seen. Even if it did, the multiout realizer does
      64              :     /// not check whether a connection is dangling or is actually an output, so
      65              :     /// nothing is verified at this point. Some other class must check that the
      66              :     /// given graph is well formed.
      67         3881 :     if (freq <= 1) {
      68         3741 :       continue;
      69              :     }
      70              : 
      71              :     std::string id = con.getName();
      72          140 :     auto idx = con.getIndex();
      73              : 
      74          140 :     std::stringstream ss;
      75              :     /// name pattern: {connection_name}/generated_out_{index}; "_" is appended until the name is not already taken
      76              :     ss << id << "/generated_out_" << idx;
      77          140 :     while (node_names.count(ss.str()) != 0) {
      78            0 :       ss << "_";
      79              :     }
      80              :     auto multiout_name = ss.str();
      81              : 
      82          560 :     multiout_nodes[id].push_back(createLayerNode(
      83          280 :       "multiout", {"name=" + multiout_name, "input_layers=" + con.toString()}));
      84              :     node_names.emplace(multiout_name);
      85              : 
      86          140 :     unsigned input_count = 0;
      87         2901 :     RemapRealizer remapper([&id, &multiout_name, idx,
      88              :                             &input_count](std::string &id_, unsigned &idx_) {
      89         2621 :       if (id_ == id && idx_ == idx) {
      90          349 :         id_ = multiout_name;
      91          349 :         idx_ = input_count++; /// each matching reference gets the next multiout output index, in the order this callback is invoked
      92              :       }
      93          140 :     });
      94              : 
      95          140 :     processed = remapper.realize(processed); /// presumably runs the callback on every input connection of every node -- confirm in RemapRealizer
      96          280 :   }
      97              : 
      98              :   /// 3. emit every node followed by the multiout nodes generated for its
      99              :   /// outputs, so each generated node sits next to the node it fans out
     100              :   GraphRepresentation ret;
     101          700 :   ret.reserve(processed.size());
     102         4659 :   for (auto &node : processed) {
     103         3959 :     ret.push_back(node);
     104        11877 :     auto ranges = multiout_nodes[node->getName()]; /// NOTE(review): operator[] inserts an empty entry for names without multiout nodes, and `auto` copies the vector
     105         4099 :     for (auto &it : ranges) {
     106          140 :       ret.push_back(it);
     107              :     }
     108         3959 :   }
     109              : 
     110          700 :   return ret;
     111         1542 : }
     112              : 
     113              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1