Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file flatten_realizer.cpp
6 : * @date 09 October 2021
7 : * @brief NNTrainer graph realizer which realizes flatten=true to actual node
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : */
12 : #include <flatten_realizer.h>
13 : #include <remap_realizer.h>
14 : #include <unordered_map>
15 :
16 : #include <flatten_layer.h>
17 : #include <layer_node.h>
18 :
19 : namespace nntrainer {
20 :
21 1396 : FlattenRealizer::~FlattenRealizer() {}
22 :
23 : GraphRepresentation
24 697 : FlattenRealizer::realize(const GraphRepresentation &reference) {
25 : GraphRepresentation processed;
26 697 : processed.reserve(reference.size());
27 :
28 : std::unordered_map<std::string /**< layer_name */,
29 : std::string /**< flatten_layer_name */>
30 : remap_table;
31 : std::unordered_map<std::string /**< temp_layer_name */,
32 : std::string /**< layer_name */>
33 : recovery_table;
34 : std::vector<LayerNode *> flatten_nodes;
35 :
36 4757 : for (auto &node : reference) {
37 : /// @note: [node] type=flatten; flatten=true; is awkward but allowed.
38 : /// There is no reason to prohibit this.
39 4060 : processed.push_back(node);
40 4060 : if (node->getFlatten() && !node->getDistribute()) {
41 2 : node->setProperty({"flatten=false"});
42 :
43 2 : auto layer_name = node->getName();
44 :
45 2 : auto flatten_name = layer_name + "/flatten_realized";
46 2 : auto temp_name = flatten_name + "/temp";
47 :
48 2 : remap_table.insert({layer_name, flatten_name});
49 2 : recovery_table.insert({temp_name, layer_name});
50 :
51 : auto flatten_node =
52 6 : createLayerNode(FlattenLayer::type, {"name=" + flatten_name});
53 4 : flatten_node->setProperty({"input_layers=" + temp_name});
54 2 : processed.push_back(std::move(flatten_node));
55 : }
56 : }
57 : processed =
58 1394 : RemapRealizer([&remap_table](std::string &name, unsigned &idx) {
59 4193 : if (auto iter = remap_table.find(name); iter != remap_table.end()) {
60 1 : name = iter->second;
61 : }
62 4193 : })
63 1393 : .realize(processed);
64 : processed =
65 1392 : RemapRealizer([&recovery_table](std::string &name, unsigned &idx) {
66 4193 : if (auto iter = recovery_table.find(name); iter != recovery_table.end()) {
67 2 : name = iter->second;
68 : }
69 4193 : })
70 696 : .realize(processed);
71 :
72 696 : return processed;
73 702 : }
74 : } // namespace nntrainer
|