LCOV - code coverage report
Current view: top level - nntrainer/compiler - recurrent_realizer.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 94.5 % 128 121
Test Date: 2025-12-14 20:38:17 Functions: 90.5 % 21 19

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file recurrent_realizer.cpp
       6              :  * @date 12 October 2021
       7              :  * @brief NNTrainer graph realizer to create unrolled graph from a graph
       8              :  * realizer
       9              :  * @see https://github.com/nnstreamer/nntrainer
      10              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      11              :  * @bug No known bugs except for NYI items
      12              :  */
      13              : #include <algorithm>
      14              : #include <iterator>
      15              : #include <stdexcept>
      16              : #include <string>
      17              : 
      18              : #include <base_properties.h>
      19              : #include <common_properties.h>
      20              : #include <connection.h>
      21              : #include <input_layer.h>
      22              : #include <layer_node.h>
      23              : #include <nntrainer_error.h>
      24              : #include <node_exporter.h>
      25              : #include <recurrent_realizer.h>
      26              : #include <remap_realizer.h>
      27              : #include <rnncell.h>
      28              : #include <util_func.h>
      29              : #include <zoneout_lstmcell.h>
      30              : 
      31              : namespace nntrainer {
      32              : 
      33              : namespace props {
      34              : 
      35              : /**
      36              :  * @brief Property check unroll_for
      37              :  *
      38              :  */
      39           57 : class UnrollFor final : public PositiveIntegerProperty {
      40              : public:
      41              :   UnrollFor(const unsigned &value = 1);
      42              :   static constexpr const char *key = "unroll_for";
      43              :   using prop_tag = uint_prop_tag;
      44              : };
      45              : 
      46           57 : UnrollFor::UnrollFor(const unsigned &value) { set(value); }
      47              : 
      48              : /**
      49              :  * @brief dynamic time sequence property, use this to set and check if dynamic
      50              :  * time sequence is enabled.
      51              :  *
      52              :  */
      53           57 : class DynamicTimeSequence final : public nntrainer::Property<bool> {
      54              : public:
      55              :   /**
      56              :    * @brief Construct a new DynamicTimeSequence object
      57              :    *
      58              :    */
      59           57 :   DynamicTimeSequence(bool val = true) : nntrainer::Property<bool>(val) {}
      60              :   static constexpr const char *key = "dynamic_time_seq";
      61              :   using prop_tag = bool_prop_tag;
      62              : };
      63              : 
      64              : /**
      65              :  * @brief Property for recurrent inputs
      66              :  *
      67              :  */
      68          756 : class RecurrentInput final : public Property<Connection> {
      69              : public:
      70              :   /**
      71              :    * @brief Construct a new Recurrent Input object
      72              :    *
      73              :    */
      74              :   RecurrentInput();
      75              : 
      76              :   /**
      77              :    * @brief Construct a new Recurrent Input object
      78              :    *
      79              :    * @param name name
      80              :    */
      81              :   RecurrentInput(const Connection &name);
      82              :   static constexpr const char *key = "recurrent_input";
      83              :   using prop_tag = connection_prop_tag;
      84              : };
      85              : 
      86          194 : RecurrentInput::RecurrentInput() {}
      87          112 : RecurrentInput::RecurrentInput(const Connection &con) { set(con); };
      88              : 
      89              : /**
      90              :  * @brief Property for recurrent outputs
      91              :  *
      92              :  */
      93          582 : class RecurrentOutput final : public Property<Connection> {
      94              : public:
      95              :   /**
      96              :    * @brief Construct a new Recurrent Output object
      97              :    *
      98              :    */
      99              :   RecurrentOutput();
     100              : 
     101              :   /**
     102              :    * @brief Construct a new Recurrent Output object
     103              :    *
     104              :    * @param name name
     105              :    */
     106              :   RecurrentOutput(const Connection &name);
     107              :   static constexpr const char *key = "recurrent_output";
     108              :   using prop_tag = connection_prop_tag;
     109              : };
     110              : 
     111          194 : RecurrentOutput::RecurrentOutput() {}
     112            0 : RecurrentOutput::RecurrentOutput(const Connection &con) { set(con); };
     113              : } // namespace props
     114              : 
     115           57 : RecurrentRealizer::RecurrentRealizer(const std::vector<std::string> &properties,
     116              :                                      const std::vector<Connection> &input_conns,
     117           57 :                                      const std::vector<Connection> &end_conns) :
     118           57 :   input_layers(),
     119              :   end_info(),
     120           57 :   sequenced_return_conns(),
     121              :   recurrent_props(new PropTypes(
     122           57 :     std::vector<props::RecurrentInput>(), std::vector<props::RecurrentOutput>(),
     123          114 :     std::vector<props::AsSequence>(), props::UnrollFor(1),
     124          171 :     std::vector<props::InputIsSequence>(), props::DynamicTimeSequence(false))) {
     125           57 :   auto left = loadProperties(properties, *recurrent_props);
     126              : 
     127           57 :   std::transform(
     128              :     input_conns.begin(), input_conns.end(),
     129           57 :     std::inserter(this->input_layers, this->input_layers.begin()),
     130              :     [](const Connection &c) -> const auto & { return c.getName(); });
     131              : 
     132              :   /// build end info.
     133              :   /// eg)
     134              :   /// end_layers: a(0), a(3), b(0) becomes
     135              :   /// end_info: {{a, 3}, {b, 0}}
     136              :   /// end_layers: a(1), b(3), c(0) becomes
     137              :   /// end_info: {{a, 1}, {b, 3}, {c, 0}}
     138          115 :   for (unsigned i = 0u, sz = end_conns.size(); i < sz; ++i) {
     139           58 :     const auto &name = end_conns[i].getName();
     140           58 :     const auto &idx = end_conns[i].getIndex();
     141              :     auto iter =
     142           58 :       std::find_if(end_info.begin(), end_info.end(),
     143            1 :                    [&name](auto &info) { return info.first == name; });
     144           58 :     if (iter == end_info.end()) {
     145           57 :       end_info.emplace_back(name, idx);
     146              :     } else {
     147            2 :       iter->second = std::max(iter->second, idx);
     148              :     }
     149              :   }
     150              : 
     151              :   auto &[inputs, outputs, as_sequence, unroll_for, input_is_seq,
     152              :          dynamic_time_seq] = *recurrent_props;
     153              : 
     154           57 :   NNTR_THROW_IF(inputs.empty() || inputs.size() != outputs.size(),
     155              :                 std::invalid_argument)
     156              :     << "recurrent inputs and outputs must not be empty and 1:1 map but given "
     157              :        "different size. input: "
     158              :     << inputs.size() << " output: " << outputs.size();
     159              : 
     160              :   /// @todo Deal as sequence as proper connection with identity layer
     161          111 :   NNTR_THROW_IF(!std::all_of(as_sequence.begin(), as_sequence.end(),
     162              :                              [&end_conns](const Connection &seq) {
     163              :                                return std::find(end_conns.begin(),
     164              :                                                 end_conns.end(),
     165              :                                                 seq) != end_conns.end();
     166              :                              }),
     167              :                 std::invalid_argument)
     168              :     << "as_sequence property must be subset of end_layers";
     169              : 
     170          111 :   for (auto &name : as_sequence) {
     171           54 :     sequenced_return_conns.emplace(name.get());
     172              :   };
     173              : 
     174              :   sequenced_input =
     175           57 :     std::unordered_set<std::string>(input_is_seq.begin(), input_is_seq.end());
     176              : 
     177           58 :   for (auto &seq_input : sequenced_input) {
     178            0 :     NNTR_THROW_IF(input_layers.count(seq_input) == 0, std::invalid_argument)
     179              :       << seq_input
     180              :       << " is not found inside input_layers, inputIsSequence argument must be "
     181              :          "subset of inputs";
     182              :   }
     183              : 
     184           57 :   NNTR_THROW_IF(!left.empty(), std::invalid_argument)
     185              :     << "There is unparsed properties";
     186              : 
     187          251 :   for (unsigned i = 0, sz = inputs.size(); i < sz; ++i) {
     188          388 :     recurrent_info.emplace(inputs.at(i).get(), outputs.at(i).get());
     189              :   }
     190           57 : }
     191              : 
     192              : /**
     193              :  * @brief if node is of recurrent type, set time step and max time step
     194              :  *
     195              :  * @param node node
     196              :  * @param time_step time step
     197              :  * @param max_time_step max time step
     198              :  */
     199          212 : static void propagateTimestep(LayerNode *node, unsigned int time_step,
     200              :                               unsigned int max_time_step) {
     201              : 
     202              :   /** @todo add an interface to check if a layer supports a property */
     203          212 :   auto is_recurrent_type = [](LayerNode *node) {
     204          212 :     return node->getType() == ZoneoutLSTMCellLayer::type;
     205              :   };
     206              : 
     207          212 :   if (is_recurrent_type(node)) {
     208          540 :     node->setProperty({"max_timestep=" + std::to_string(max_time_step),
     209          216 :                        "timestep=" + std::to_string(time_step)});
     210              :   }
     211              : 
     212          212 :   return;
     213            0 : }
     214              : 
     215            0 : RecurrentRealizer::RecurrentRealizer(
     216            0 :   const char *ini_path, const std::vector<std::string> &external_input_layers) {
     217              :   /// @todo delegate to RecurrentRealizer(
     218              :   // const std::vector<std::string> &properties,
     219              :   // const std::vector<std::string> &external_input_layers)
     220              :   /// NYI!
     221            0 : }
     222              : 
     223          165 : RecurrentRealizer::~RecurrentRealizer() {}
     224              : 
     225              : GraphRepresentation
     226           57 : RecurrentRealizer::realize(const GraphRepresentation &reference) {
     227              : 
     228              :   auto step0_verify_and_prepare = []() {
     229              :     /// empty intended
     230              :   };
     231              : 
     232              :   /**
     233              :    * @brief maps input place holder to given name otherwise scopped to suffix
     234              :    * "/0"
     235              :    *
     236              :    */
     237              :   auto step1_connect_external_input =
     238           57 :     [this](const GraphRepresentation &reference_, unsigned max_time_idx) {
     239           57 :       RemapRealizer input_mapper([this](std::string &id) {
     240          332 :         if (input_layers.count(id) == 0) {
     241              :           id += "/0";
     242          197 :         } else if (sequenced_input.count(id) != 0) {
     243              :           id += "/0";
     244              :         }
     245           57 :       });
     246              : 
     247           57 :       auto nodes = input_mapper.realize(reference_);
     248          153 :       for (auto &node : nodes) {
     249           96 :         propagateTimestep(node.get(), 0, max_time_idx);
     250              :         /// #1744, quick fix, add shared_from to every node
     251          384 :         node->setProperty({"shared_from=" + node->getName()});
     252              :       }
     253              : 
     254           57 :       return nodes;
     255          153 :     };
     256              : 
     257              :   /**
     258              :    * @brief Create a single time step. Used inside step2_unroll.
     259              :    *
     260              :    */
     261           64 :   auto create_step = [this](const GraphRepresentation &reference_,
     262              :                             unsigned time_idx, unsigned max_time_idx) {
     263              :     GraphRepresentation step;
     264           64 :     step.reserve(reference_.size());
     265              : 
     266          286 :     auto replace_time_idx = [](std::string &name, unsigned time_idx) {
     267              :       auto pos = name.find_last_of('/');
     268          286 :       if (pos != std::string::npos) {
     269          572 :         name.replace(pos + 1, std::string::npos, std::to_string(time_idx));
     270              :       }
     271          286 :     };
     272          180 :     for (auto &node : reference_) {
     273          116 :       auto new_node = node->cloneConfiguration();
     274              : 
     275              :       /// 1. remap identifiers to $name/$idx
     276          116 :       new_node->remapIdentifiers(
     277          116 :         [this, time_idx, replace_time_idx](std::string &id) {
     278          494 :           if (input_layers.count(id) == 0) {
     279          286 :             replace_time_idx(id, time_idx);
     280              :           }
     281          494 :         });
     282              : 
     283              :       /// 2. override first output name to $name/$idx - 1
     284          480 :       for (auto &[recurrent_input, recurrent_output] : recurrent_info) {
     285          728 :         if (node->getName() != recurrent_input.getName() + "/0") {
     286          160 :           continue;
     287              :         }
     288          204 :         new_node->setInputConnectionIndex(recurrent_input.getIndex(),
     289              :                                           recurrent_output.getIndex());
     290          204 :         new_node->setInputConnectionName(recurrent_input.getIndex(),
     291          408 :                                          recurrent_output.getName() + "/" +
     292          408 :                                            std::to_string(time_idx - 1));
     293              :       }
     294              :       /// 3. set shared_from
     295          464 :       new_node->setProperty({"shared_from=" + node->getName()});
     296              :       /// 4. if recurrent layer type set timestep property
     297          116 :       propagateTimestep(new_node.get(), time_idx, max_time_idx);
     298              : 
     299          116 :       step.push_back(std::move(new_node));
     300              :     }
     301           64 :     return step;
     302          116 :   };
     303              : 
     304              :   /**
     305              :    * @brief unroll the graph by calling create_step()
     306              :    *
     307              :    */
     308           57 :   auto step2_unroll = [create_step](const GraphRepresentation &reference_,
     309              :                                     unsigned unroll_for_) {
     310           57 :     GraphRepresentation processed(reference_.begin(), reference_.end());
     311           57 :     processed.reserve(reference_.size() * unroll_for_);
     312              : 
     313          121 :     for (unsigned int i = 1; i < unroll_for_; ++i) {
     314           64 :       auto step = create_step(reference_, i, unroll_for_);
     315           64 :       processed.insert(processed.end(), step.begin(), step.end());
     316           64 :     }
     317              : 
     318           57 :     return processed;
     319            0 :   };
     320              : 
     321              :   /**
     322              :    * @brief case when return sequence is true, concat layer is added to
     323              :    * aggregate all the output
     324              :    *
     325              :    */
     326           54 :   auto concat_output = [](const GraphRepresentation &reference_,
     327              :                           const Connection &con, unsigned unroll_for,
     328              :                           const std::string &new_layer_name) {
     329           54 :     GraphRepresentation processed(reference_.begin(), reference_.end());
     330              : 
     331              :     std::vector<props::RecurrentInput> conns;
     332          166 :     for (unsigned int i = 0; i < unroll_for; ++i) {
     333          112 :       conns.emplace_back(Connection{
     334          224 :         con.getName() + "/" + std::to_string(i),
     335              :         con.getIndex(),
     336              :       });
     337              :     }
     338              :     /// @todo have axis in concat layer
     339              :     /// @todo this has to be wrapped with identity layer as #1793
     340          108 :     auto node = createLayerNode(
     341          324 :       "concat", {"name=" + new_layer_name, "input_layers=" + to_string(conns)});
     342           54 :     processed.push_back(std::move(node));
     343              : 
     344           54 :     return processed;
     345          108 :   };
     346              : 
     347              :   /**
     348              :    * @brief create identity layer with output name by either creating concat
     349              :    * layer or directly using the connection, the number of inputs connection
     350              :    * have is depending on the end_conns max.
     351              :    *
     352              :    * eg)
     353              :    * layer A outputs a, b, c, d
     354              :    *
     355              :    * if end_layers=A(0),A(2)
     356              :    *    as_sequence=A(0)
     357              :    * realizer cannot know there is d so this is ignored. It is okay because user
     358              :    * didn't specify to use it anyway
     359              :    *
     360              :    * [A]
     361              :    * type=identity
     362              :    * input_layers=A_concat_0, A(1), A(2)
     363              :    *
     364              :    */
     365          114 :   auto step3_connect_output = [this, concat_output](
     366              :                                 const GraphRepresentation &reference_,
     367              :                                 unsigned unroll_for) {
     368              :     /// @note below is inefficient way of processing nodes. consider optimize
     369              :     /// below as needed by calling remap realizer only once
     370           57 :     auto processed = reference_;
     371          114 :     for (auto [name, max_idx] : end_info) {
     372              : 
     373              :       std::vector<props::InputConnection> out_node_inputs;
     374              : 
     375          115 :       for (auto i = 0u; i <= max_idx; ++i) {
     376              : 
     377           58 :         if (auto con = Connection(name, i); sequenced_return_conns.count(con)) {
     378          108 :           auto concat_name = name + "/concat_" + std::to_string(i);
     379           54 :           processed = concat_output(processed, con, unroll_for, concat_name);
     380              :           // create concat connection name,
     381          108 :           out_node_inputs.emplace_back(Connection(concat_name, 0));
     382              :         } else {
     383            8 :           auto last_layer_name = name + "/" + std::to_string(unroll_for - 1);
     384            8 :           out_node_inputs.emplace_back(Connection(last_layer_name, i));
     385              :         }
     386              :       }
     387              : 
     388          114 :       auto alias_layer = createLayerNode(
     389              :         "identity",
     390          342 :         {"name=" + name, "input_layers=" + to_string(out_node_inputs)});
     391           57 :       processed.push_back(std::move(alias_layer));
     392           57 :     }
     393              : 
     394           57 :     return processed;
     395          114 :   };
     396              : 
     397           57 :   auto unroll_for = std::get<props::UnrollFor>(*recurrent_props).get();
     398              :   step0_verify_and_prepare();
     399           57 :   auto processed = step1_connect_external_input(reference, unroll_for);
     400           57 :   processed = step2_unroll(processed, unroll_for);
     401          114 :   return step3_connect_output(processed, unroll_for);
     402           57 : }
     403              : 
     404              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1