Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file input_realizer.cpp
6 : * @date 14 October 2021
7 : * @brief NNTrainer graph realizer which remaps input to the external graph
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : */
12 : #include <connection.h>
13 : #include <input_realizer.h>
14 : #include <layer_node.h>
15 : #include <nntrainer_error.h>
16 : #include <nntrainer_log.h>
17 :
18 : #include <algorithm>
19 : #include <stdexcept>
20 : #include <unordered_map>
21 :
22 : namespace nntrainer {
23 55 : InputRealizer::InputRealizer(const std::vector<Connection> &start_conns,
24 55 : const std::vector<Connection> &input_conns) :
25 55 : start_conns(start_conns),
26 55 : input_conns(input_conns) {
27 56 : NNTR_THROW_IF(start_conns.size() != input_conns.size(), std::invalid_argument)
28 : << "start connection size is not same input_conns size";
29 56 : }
30 :
31 105 : InputRealizer::~InputRealizer() {}
32 :
33 : GraphRepresentation
34 54 : InputRealizer::realize(const GraphRepresentation &reference) {
35 : std::unordered_map<std::string, LayerNode *> existing_nodes;
36 :
37 54 : std::transform(
38 : reference.begin(), reference.end(),
39 : std::inserter(existing_nodes, existing_nodes.end()),
40 174 : [](auto &node) { return std::pair(node->getName(), node.get()); });
41 :
42 244 : for (unsigned i = 0u, sz = start_conns.size(); i < sz; ++i) {
43 192 : const auto &sc = start_conns[i];
44 : const auto &ic = input_conns[i];
45 192 : auto node = existing_nodes.at(sc.getName());
46 :
47 192 : auto num_connection = node->getNumInputConnections();
48 192 : if (num_connection == 0) {
49 4 : NNTR_THROW_IF(sc.getIndex() != 0, std::invalid_argument)
50 2 : << "start connection: " << sc.toString()
51 : << " not defined and num connection of that node is empty, although "
52 : "start connection of index zero is allowed";
53 12 : node->setProperty({"input_layers=" + ic.toString()});
54 : } else {
55 190 : NNTR_THROW_IF(sc.getIndex() >= num_connection, std::invalid_argument)
56 2 : << "start connection: " << sc.toString()
57 : << " not defined, num connection: " << num_connection;
58 187 : node->setInputConnectionName(sc.getIndex(), ic.getName());
59 187 : node->setInputConnectionIndex(sc.getIndex(), ic.getIndex());
60 : }
61 : }
62 :
63 104 : return reference;
64 : }
65 :
66 : } // namespace nntrainer
|