Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file multiout_realizer.h
6 : * @date 17 November 2021
7 : * @brief NNTrainer graph realizer which realizes multiout to actual node
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : */
12 : #include <stdexcept>
13 : #include <string>
14 : #include <unordered_map>
15 : #include <unordered_set>
16 : #include <utility>
17 :
18 : #include <common_properties.h>
19 : #include <compiler_fwd.h>
20 : #include <connection.h>
21 : #include <layer_node.h>
22 : #include <multiout_realizer.h>
23 : #include <remap_realizer.h>
24 :
25 : namespace nntrainer {
26 1400 : MultioutRealizer::~MultioutRealizer() {}
27 :
28 : GraphRepresentation
29 701 : MultioutRealizer::realize(const GraphRepresentation &reference) {
30 701 : GraphRepresentation processed(reference.begin(), reference.end());
31 :
32 : std::unordered_map<Connection, unsigned> freq_map;
33 : std::unordered_set<std::string> node_names;
34 : std::vector<Connection> connections;
35 :
36 : /// 1. build frequency map and connection names
37 4661 : for (auto &node : reference) {
38 7922 : NNTR_THROW_IF(node_names.count(node->getName()), std::invalid_argument)
39 2 : << "node name clashes: " << node->getName();
40 7920 : node_names.emplace(node->getName());
41 :
42 8050 : for (unsigned int i = 0, num_nodes = node->getNumInputConnections();
43 8050 : i < num_nodes; ++i) {
44 : Connection c(node->getInputConnectionName(i),
45 4090 : node->getInputConnectionIndex(i));
46 4090 : [[maybe_unused]] auto [iter, result] = freq_map.try_emplace(c, 0);
47 4090 : if (result)
48 3881 : connections.push_back(c);
49 4090 : iter->second++;
50 : }
51 : }
52 :
53 : /// 2. for each connection names, if a connection is referenced multiple
54 : /// times, create multioutput node and remap to multi output node index
55 : std::unordered_map<
56 : std::string /**< original id */,
57 : std::vector<std::shared_ptr<LayerNode>> /**< created node */>
58 : multiout_nodes;
59 :
60 4581 : for (auto &con : connections) {
61 3881 : unsigned freq = freq_map[con];
62 : /// @note freq < 1 should never happen as the map entry is not created.
63 : /// but if it happens multiout realizer is not interested in checking if it
64 : /// is a dangled or actually an output. So there is no assurance done at
65 : /// this point. Some other class must check if the given graph is formed in
66 : /// a correct way.
67 3881 : if (freq <= 1) {
68 3741 : continue;
69 : }
70 :
71 : std::string id = con.getName();
72 140 : auto idx = con.getIndex();
73 :
74 140 : std::stringstream ss;
75 : /// {connection_name}/generated_out_{index}
76 : ss << id << "/generated_out_" << idx;
77 140 : while (node_names.count(ss.str()) != 0) {
78 0 : ss << "_";
79 : }
80 : auto multiout_name = ss.str();
81 :
82 560 : multiout_nodes[id].push_back(createLayerNode(
83 280 : "multiout", {"name=" + multiout_name, "input_layers=" + con.toString()}));
84 : node_names.emplace(multiout_name);
85 :
86 140 : unsigned input_count = 0;
87 2901 : RemapRealizer remapper([&id, &multiout_name, idx,
88 : &input_count](std::string &id_, unsigned &idx_) {
89 2621 : if (id_ == id && idx_ == idx) {
90 349 : id_ = multiout_name;
91 349 : idx_ = input_count++;
92 : }
93 140 : });
94 :
95 140 : processed = remapper.realize(processed);
96 280 : }
97 :
98 : /// 3. insert multiout_nodes close to the original node to make the
99 : /// realization more sensible
100 : GraphRepresentation ret;
101 700 : ret.reserve(processed.size());
102 4659 : for (auto &node : processed) {
103 3959 : ret.push_back(node);
104 11877 : auto ranges = multiout_nodes[node->getName()];
105 4099 : for (auto &it : ranges) {
106 140 : ret.push_back(it);
107 : }
108 3959 : }
109 :
110 700 : return ret;
111 1542 : }
112 :
113 : } // namespace nntrainer
|