// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2022 seongwoo <mhs4670go@naver.com>
 *
 * @file tflite_export_realizer.cpp
 * @date 18 July 2025
 * @brief NNTrainer graph realizer which removes loss and dropout layers for
 * inference (TFLite export)
 * @see https://github.com/nnstreamer/nntrainer
 * @author seongwoo <mhs4670go@naver.com>
 * @author donghak park <donghak.park@samsung.com>
 * @bug No known bugs except for NYI items
 */

#include <algorithm>
#include <cassert>
#include <connection.h>
#include <layer_node.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <set>
#include <stdexcept>
#include <string>
#include <tflite_export_realizer.h>
#include <unordered_map>

namespace nntrainer {

static constexpr size_t SINGLE_IN_IDX = 0;

GraphRepresentation
TfliteExportRealizer::realize(const GraphRepresentation &reference) {
  /// @todo support more loss layers
  /// @note Some loss layers carry computation that must be kept.
  /// For example, when CrossEntropySigmoidLossLayer needs to be removed,
  /// the sigmoid computation shouldn't be removed.
  static const std::set<std::string> loss_type = {"mse"};
  std::unordered_map<std::string, LayerNode *> existing_nodes;
  std::vector<LayerNode *> loss_layers;

  /// Build a name -> node lookup table for the whole graph.
  std::transform(
    reference.begin(), reference.end(),
    std::inserter(existing_nodes, existing_nodes.end()),
    [](auto &node) { return std::pair(node->getName(), node.get()); });

  /// Collect every node whose type is a known loss type.
  for (auto &node : reference) {
    if (loss_type.find(node->getType()) != loss_type.end()) {
      loss_layers.push_back(node.get());
    }
  }

  /// Detach each loss node from its producer by clearing the producer's
  /// output connections that point to it.
  for (auto iter = loss_layers.begin(); iter != loss_layers.end(); ++iter) {
    auto loss_node = (*iter);
    assert(loss_node->getNumInputConnections() == 1);
    auto &input_name = loss_node->getInputConnectionName(SINGLE_IN_IDX);
    auto input_node = existing_nodes.at(input_name);
    for (unsigned int i = 0; i < input_node->getNumOutputConnections(); ++i) {
      if (istrequal(loss_node->getName(),
                    input_node->getOutputConnection(i)->getName())) {
        /// Assume that loss layers don't have output connections
        assert(loss_node->getOutputConnections().size() == 0);
        input_node->setOutputLayers({});
      }
    }
  }

  /// Rebuild the graph representation without the loss nodes.
  GraphRepresentation processed;
  for (auto &node : reference) {
    if (loss_type.find(node->getType()) == loss_type.end()) {
      processed.push_back(node);
    }
  }

  return processed;
}
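
/**
 * @note Illustrative usage sketch only; the real call site lives in the
 * TFLite exporter and the variable names below are assumptions rather than
 * code from this repository.
 * @code
 *   GraphRepresentation graph = ...; // training graph ending in an "mse" node
 *   TfliteExportRealizer realizer;
 *   GraphRepresentation inference_graph = realizer.realize(graph);
 *   // inference_graph keeps every node except the "mse" loss node, and the
 *   // loss node's producer no longer lists it as an output connection.
 * @endcode
 */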

GraphRepresentation
TfliteExportRealizer::realize_dropout(const GraphRepresentation &reference) {
  static const std::set<std::string> dropout_type = {"dropout"};
  std::unordered_map<std::string, LayerNode *> existing_nodes;
  std::vector<LayerNode *> dropout_layers;

  /// Build a name -> node lookup table for the whole graph.
  std::transform(
    reference.begin(), reference.end(),
    std::inserter(existing_nodes, existing_nodes.end()),
    [](auto &node) { return std::pair(node->getName(), node.get()); });

  /// Find dropout layers and collect them.
  for (auto &node : reference) {
    if (dropout_type.find(node->getType()) != dropout_type.end()) {
      dropout_layers.push_back(node.get());
    }
  }

  /// Bypass each dropout node by rewiring its producer directly to its
  /// consumer and sharing the output tensor memory.
  for (auto iter = dropout_layers.begin(); iter != dropout_layers.end();
       ++iter) {
    auto node = (*iter);
    auto &input_name = node->getInputConnectionName(SINGLE_IN_IDX);
    auto input_node = existing_nodes.at(input_name);

    /// Redirect the producer's output connection from the dropout node to
    /// the dropout node's consumer.
    for (unsigned int i = 0; i < input_node->getNumOutputConnections(); ++i) {
      if (istrequal(node->getName(),
                    input_node->getOutputConnection(i)->getName())) {
        input_node->setOutputConnection(
          i, node->getOutputConnection(i)->getName(), SINGLE_IN_IDX);
      }
      input_node->getOutput(SINGLE_IN_IDX)
        .setData(node->getOutput(SINGLE_IN_IDX).getMemoryData());
    }

    /// Redirect the consumer's input connection from the dropout node back
    /// to the producer.
    auto &output_name = node->getOutputConnection(SINGLE_IN_IDX)->getName();
    auto output_node = existing_nodes.at(output_name);

    for (unsigned int i = 0; i < output_node->getNumInputConnections(); ++i) {
      if (istrequal(node->getName(), output_node->getInputConnectionName(i))) {
        output_node->setInputConnectionName(i, node->getInputConnectionName(i));
      }
    }
  }

  /// Rebuild the graph representation without the dropout nodes.
  GraphRepresentation processed;
  for (auto &node : reference) {
    if (dropout_type.find(node->getType()) == dropout_type.end()) {
      processed.push_back(node);
    }
  }

  return processed;
}
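
/**
 * @note Illustrative effect sketch only; the layer names are assumptions. For
 * a chain `fc1 -> dropout0 -> fc2`, realize_dropout redirects fc1's output
 * connection to fc2, lets fc1 reuse dropout0's output memory, redirects
 * fc2's input connection back to fc1, and returns the graph without
 * dropout0, i.e. `fc1 -> fc2`.
 * @code
 *   GraphRepresentation graph = ...; // e.g. fc1 -> dropout0 -> fc2
 *   TfliteExportRealizer realizer;
 *   GraphRepresentation exported = realizer.realize_dropout(graph);
 *   // exported: fc1 -> fc2 (dropout0 removed)
 * @endcode
 */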
} // namespace nntrainer