Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2022 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file bn_realizer.cpp
6 : * @date 13 April 2022
7 : * @brief NNTrainer graph realizer which remove batch normalization layer for
8 : * inference
9 : * @see https://github.com/nnstreamer/nntrainer
10 : * @author Jijoong Moon <jijoong.moon@samsung.com>
11 : * @bug No known bugs except for NYI items
12 : */
13 : #include <bn_realizer.h>
14 : #include <connection.h>
15 : #include <layer_node.h>
16 : #include <nntrainer_error.h>
17 : #include <nntrainer_log.h>
18 :
19 : #include <algorithm>
20 : #include <stdexcept>
21 : #include <unordered_map>
22 :
23 : namespace nntrainer {
24 :
25 : static constexpr size_t SINGLE_INOUT_IDX = 0;
26 :
27 2 : GraphRepresentation BnRealizer::realize(const GraphRepresentation &reference) {
28 : std::unordered_map<std::string, LayerNode *> existing_nodes;
29 : std::vector<LayerNode *> bn_layers;
30 :
31 2 : std::transform(
32 : reference.begin(), reference.end(),
33 : std::inserter(existing_nodes, existing_nodes.end()),
34 42 : [](auto &node) { return std::pair(node->getName(), node.get()); });
35 :
36 23 : for (auto &node : reference) {
37 42 : if (istrequal(node->getType(), "batch_normalization")) {
38 5 : bn_layers.push_back(node.get());
39 : }
40 : }
41 :
42 7 : for (auto iter = bn_layers.begin(); iter != bn_layers.end(); ++iter) {
43 5 : auto node = (*iter);
44 5 : auto &input_name = node->getInputConnectionName(SINGLE_INOUT_IDX);
45 5 : auto input_node = existing_nodes.at(input_name);
46 :
47 10 : for (unsigned int i = 0; i < input_node->getNumOutputConnections(); ++i) {
48 10 : if (istrequal(node->getName(),
49 : input_node->getOutputConnection(i)->getName())) {
50 5 : input_node->setOutputConnection(
51 : i, node->getOutputConnection(SINGLE_INOUT_IDX)->getName(),
52 : SINGLE_INOUT_IDX);
53 : }
54 : }
55 :
56 5 : auto &output_name = node->getOutputConnection(SINGLE_INOUT_IDX)->getName();
57 5 : auto output_node = existing_nodes.at(output_name);
58 :
59 10 : for (unsigned int i = 0; i < output_node->getNumInputConnections(); ++i) {
60 10 : if (istrequal(node->getName(), output_node->getInputConnectionName(i))) {
61 5 : output_node->setInputConnectionName(
62 : i, node->getInputConnectionName(SINGLE_INOUT_IDX));
63 : }
64 : }
65 : }
66 :
67 : GraphRepresentation processed;
68 23 : for (auto &node : reference) {
69 42 : if (!istrequal(node->getType(), "batch_normalization")) {
70 16 : processed.push_back(node);
71 : }
72 : }
73 :
74 2 : return processed;
75 2 : }
76 :
77 : } // namespace nntrainer
|