Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file recurrent_realizer.cpp
6 : * @date 12 October 2021
7 : * @brief NNTrainer graph realizer to create an unrolled graph from a
8 : * recurrent graph
9 : * @see https://github.com/nnstreamer/nntrainer
10 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
11 : * @bug No known bugs except for NYI items
12 : */
13 : #include <algorithm>
14 : #include <iterator>
15 : #include <stdexcept>
16 : #include <string>
17 :
18 : #include <base_properties.h>
19 : #include <common_properties.h>
20 : #include <connection.h>
21 : #include <input_layer.h>
22 : #include <layer_node.h>
23 : #include <nntrainer_error.h>
24 : #include <node_exporter.h>
25 : #include <recurrent_realizer.h>
26 : #include <remap_realizer.h>
27 : #include <rnncell.h>
28 : #include <util_func.h>
29 : #include <zoneout_lstmcell.h>
30 :
31 : namespace nntrainer {
32 :
33 : namespace props {
34 :
35 : /**
36 : * @brief Property check unroll_for
37 : *
38 : */
39 57 : class UnrollFor final : public PositiveIntegerProperty {
40 : public:
41 : UnrollFor(const unsigned &value = 1);
42 : static constexpr const char *key = "unroll_for";
43 : using prop_tag = uint_prop_tag;
44 : };
45 :
46 57 : UnrollFor::UnrollFor(const unsigned &value) { set(value); }
47 :
48 : /**
49 : * @brief dynamic time sequence property; use this to set and check whether
50 : * dynamic time sequence is enabled.
51 : *
52 : */
53 57 : class DynamicTimeSequence final : public nntrainer::Property<bool> {
54 : public:
55 : /**
56 : * @brief Construct a new DynamicTimeSequence object
57 : *
58 : */
59 57 : DynamicTimeSequence(bool val = true) : nntrainer::Property<bool>(val) {}
60 : static constexpr const char *key = "dynamic_time_seq";
61 : using prop_tag = bool_prop_tag;
62 : };
63 :
64 : /**
65 : * @brief Property for recurrent inputs
66 : *
67 : */
68 756 : class RecurrentInput final : public Property<Connection> {
69 : public:
70 : /**
71 : * @brief Construct a new Recurrent Input object
72 : *
73 : */
74 : RecurrentInput();
75 :
76 : /**
77 : * @brief Construct a new Recurrent Input object
78 : *
79 : * @param name name
80 : */
81 : RecurrentInput(const Connection &name);
82 : static constexpr const char *key = "recurrent_input";
83 : using prop_tag = connection_prop_tag;
84 : };
85 :
86 194 : RecurrentInput::RecurrentInput() {}
87 112 : RecurrentInput::RecurrentInput(const Connection &con) { set(con); };
88 :
89 : /**
90 : * @brief Property for recurrent outputs
91 : *
92 : */
93 582 : class RecurrentOutput final : public Property<Connection> {
94 : public:
95 : /**
96 : * @brief Construct a new Recurrent Output object
97 : *
98 : */
99 : RecurrentOutput();
100 :
101 : /**
102 : * @brief Construct a new Recurrent Output object
103 : *
104 : * @param name name
105 : */
106 : RecurrentOutput(const Connection &name);
107 : static constexpr const char *key = "recurrent_output";
108 : using prop_tag = connection_prop_tag;
109 : };
110 :
111 194 : RecurrentOutput::RecurrentOutput() {}
112 0 : RecurrentOutput::RecurrentOutput(const Connection &con) { set(con); };
113 : } // namespace props
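/**
 * @note Each property above is consumed from a plain "key=value" string, e.g.
 * "unroll_for=3", "recurrent_input=lstm(0)" or "recurrent_output=lstm(0)",
 * where connections use the "name(index)" syntax and "lstm" is a hypothetical
 * layer name. Any string that none of the properties recognizes is rejected
 * by the constructor below.
 */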
114 :
115 57 : RecurrentRealizer::RecurrentRealizer(const std::vector<std::string> &properties,
116 : const std::vector<Connection> &input_conns,
117 57 : const std::vector<Connection> &end_conns) :
118 57 : input_layers(),
119 : end_info(),
120 57 : sequenced_return_conns(),
121 : recurrent_props(new PropTypes(
122 57 : std::vector<props::RecurrentInput>(), std::vector<props::RecurrentOutput>(),
123 114 : std::vector<props::AsSequence>(), props::UnrollFor(1),
124 171 : std::vector<props::InputIsSequence>(), props::DynamicTimeSequence(false))) {
125 57 : auto left = loadProperties(properties, *recurrent_props);
126 :
127 57 : std::transform(
128 : input_conns.begin(), input_conns.end(),
129 57 : std::inserter(this->input_layers, this->input_layers.begin()),
130 : [](const Connection &c) -> const auto & { return c.getName(); });
131 :
132 : /// build end info.
133 : /// eg)
134 : /// end_layers: a(0), a(3), b(0) becomes
135 : /// end_info: {{a, 3}, {b, 0}}
136 : /// end_layers: a(1), b(3), c(0) becomes
137 : /// end_info: {{a, 1}, {b, 3}, {c, 0}}
138 115 : for (unsigned i = 0u, sz = end_conns.size(); i < sz; ++i) {
139 58 : const auto &name = end_conns[i].getName();
140 58 : const auto &idx = end_conns[i].getIndex();
141 : auto iter =
142 58 : std::find_if(end_info.begin(), end_info.end(),
143 1 : [&name](auto &info) { return info.first == name; });
144 58 : if (iter == end_info.end()) {
145 57 : end_info.emplace_back(name, idx);
146 : } else {
147 2 : iter->second = std::max(iter->second, idx);
148 : }
149 : }
150 :
151 : auto &[inputs, outputs, as_sequence, unroll_for, input_is_seq,
152 : dynamic_time_seq] = *recurrent_props;
153 :
154 57 : NNTR_THROW_IF(inputs.empty() || inputs.size() != outputs.size(),
155 : std::invalid_argument)
156 : << "recurrent inputs and outputs must not be empty and 1:1 map but given "
157 : "different size. input: "
158 : << inputs.size() << " output: " << outputs.size();
159 :
160 : /// @todo Deal as sequence as proper connection with identity layer
161 111 : NNTR_THROW_IF(!std::all_of(as_sequence.begin(), as_sequence.end(),
162 : [&end_conns](const Connection &seq) {
163 : return std::find(end_conns.begin(),
164 : end_conns.end(),
165 : seq) != end_conns.end();
166 : }),
167 : std::invalid_argument)
168 : << "as_sequence property must be subset of end_layers";
169 :
170 111 : for (auto &name : as_sequence) {
171 54 : sequenced_return_conns.emplace(name.get());
172 : };
173 :
174 : sequenced_input =
175 57 : std::unordered_set<std::string>(input_is_seq.begin(), input_is_seq.end());
176 :
177 58 : for (auto &seq_input : sequenced_input) {
178 0 : NNTR_THROW_IF(input_layers.count(seq_input) == 0, std::invalid_argument)
179 : << seq_input
180 : << " is not found inside input_layers, inputIsSequence argument must be "
181 : "subset of inputs";
182 : }
183 :
184 57 : NNTR_THROW_IF(!left.empty(), std::invalid_argument)
185 : << "There is unparsed properties";
186 :
187 251 : for (unsigned i = 0, sz = inputs.size(); i < sz; ++i) {
188 388 : recurrent_info.emplace(inputs.at(i).get(), outputs.at(i).get());
189 : }
190 57 : }
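/**
 * @note A minimal usage sketch of the constructor above and realize() below.
 * The layer type, layer names, layer properties and include paths are
 * illustrative assumptions, not a canonical recipe.
 * @code
 * #include <connection.h>
 * #include <layer_node.h>
 * #include <recurrent_realizer.h>
 *
 * using namespace nntrainer;
 *
 * // reference graph describing a single time step (hypothetical layer):
 * // input connection 0 is the external source, connection 1 is recurrent
 * GraphRepresentation one_step;
 * one_step.push_back(
 *   createLayerNode("rnncell", {"name=rnn", "unit=4",
 *                               "input_layers=source,rnn"}));
 *
 * // unroll three steps; input connection 1 of "rnn" is fed by the previous
 * // step's output 0
 * RecurrentRealizer realizer({"unroll_for=3", "recurrent_input=rnn(1)",
 *                             "recurrent_output=rnn(0)"},
 *                            {Connection("source", 0)}, // external inputs
 *                            {Connection("rnn", 0)});   // end connections
 *
 * GraphRepresentation unrolled = realizer.realize(one_step);
 * @endcode
 */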
191 :
192 : /**
193 : * @brief if node is of recurrent type, set time step and max time step
194 : *
195 : * @param node node
196 : * @param time_step time step
197 : * @param max_time_step max time step
198 : */
199 212 : static void propagateTimestep(LayerNode *node, unsigned int time_step,
200 : unsigned int max_time_step) {
201 :
202 : /** @todo add an interface to check if a layer supports a property */
203 212 : auto is_recurrent_type = [](LayerNode *node) {
204 212 : return node->getType() == ZoneoutLSTMCellLayer::type;
205 : };
206 :
207 212 : if (is_recurrent_type(node)) {
208 540 : node->setProperty({"max_timestep=" + std::to_string(max_time_step),
209 216 : "timestep=" + std::to_string(time_step)});
210 : }
211 :
212 212 : return;
213 0 : }
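/**
 * @note For a zoneout LSTM cell unrolled at time step 2 of 5 this amounts to
 * node->setProperty({"max_timestep=5", "timestep=2"}); nodes of any other
 * type are left untouched.
 */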
214 :
215 0 : RecurrentRealizer::RecurrentRealizer(
216 0 : const char *ini_path, const std::vector<std::string> &external_input_layers) {
217 : /// @todo delegate to RecurrentRealizer(
218 : // const std::vector<std::string> &properties,
219 : // const std::vector<std::string> &external_input_layers)
220 : /// NYI!
221 0 : }
222 :
223 165 : RecurrentRealizer::~RecurrentRealizer() {}
224 :
225 : GraphRepresentation
226 57 : RecurrentRealizer::realize(const GraphRepresentation &reference) {
227 :
228 : auto step0_verify_and_prepare = []() {
229 : /// intentionally empty
230 : };
231 :
232 : /**
233 : * @brief maps external inputs to their given names; all other identifiers
234 : * (and sequenced inputs) are scoped with the "/0" suffix
235 : *
236 : */
237 : auto step1_connect_external_input =
238 57 : [this](const GraphRepresentation &reference_, unsigned max_time_idx) {
239 57 : RemapRealizer input_mapper([this](std::string &id) {
240 332 : if (input_layers.count(id) == 0) {
241 : id += "/0";
242 197 : } else if (sequenced_input.count(id) != 0) {
243 : id += "/0";
244 : }
245 57 : });
246 :
247 57 : auto nodes = input_mapper.realize(reference_);
248 153 : for (auto &node : nodes) {
249 96 : propagateTimestep(node.get(), 0, max_time_idx);
250 : /// #1744, quick fix, add shared_from to every node
251 384 : node->setProperty({"shared_from=" + node->getName()});
252 : }
253 :
254 57 : return nodes;
255 153 : };
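/// e.g. with input_layers = {"source"}: an internal node "rnn" is renamed to
/// "rnn/0", the external input "source" is kept as-is, and a sequenced input
/// would also be scoped to "source/0" (hypothetical names).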
256 :
257 : /**
258 : * @brief Create a single time step. Used inside step2_unroll.
259 : *
260 : */
261 64 : auto create_step = [this](const GraphRepresentation &reference_,
262 : unsigned time_idx, unsigned max_time_idx) {
263 : GraphRepresentation step;
264 64 : step.reserve(reference_.size());
265 :
266 286 : auto replace_time_idx = [](std::string &name, unsigned time_idx) {
267 : auto pos = name.find_last_of('/');
268 286 : if (pos != std::string::npos) {
269 572 : name.replace(pos + 1, std::string::npos, std::to_string(time_idx));
270 : }
271 286 : };
272 180 : for (auto &node : reference_) {
273 116 : auto new_node = node->cloneConfiguration();
274 :
275 : /// 1. remap identifiers to $name/$idx
276 116 : new_node->remapIdentifiers(
277 116 : [this, time_idx, replace_time_idx](std::string &id) {
278 494 : if (input_layers.count(id) == 0) {
279 286 : replace_time_idx(id, time_idx);
280 : }
281 494 : });
282 :
283 : /// 2. rewire the recurrent input connection to $recurrent_output/$(idx - 1)
284 480 : for (auto &[recurrent_input, recurrent_output] : recurrent_info) {
285 728 : if (node->getName() != recurrent_input.getName() + "/0") {
286 160 : continue;
287 : }
288 204 : new_node->setInputConnectionIndex(recurrent_input.getIndex(),
289 : recurrent_output.getIndex());
290 204 : new_node->setInputConnectionName(recurrent_input.getIndex(),
291 408 : recurrent_output.getName() + "/" +
292 408 : std::to_string(time_idx - 1));
293 : }
294 : /// 3. set shared_from
295 464 : new_node->setProperty({"shared_from=" + node->getName()});
296 : /// 4. if recurrent layer type set timestep property
297 116 : propagateTimestep(new_node.get(), time_idx, max_time_idx);
298 :
299 116 : step.push_back(std::move(new_node));
300 : }
301 64 : return step;
302 116 : };
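/// e.g. at time_idx=2 the clone of "rnn/0" becomes "rnn/2", its recurrent
/// input connection is rewired to the step-1 output "<recurrent_output>/1",
/// and shared_from stays "rnn/0" so parameters are shared across steps.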
303 :
304 : /**
305 : * @brief unroll the graph by calling create_step()
306 : *
307 : */
308 57 : auto step2_unroll = [create_step](const GraphRepresentation &reference_,
309 : unsigned unroll_for_) {
310 57 : GraphRepresentation processed(reference_.begin(), reference_.end());
311 57 : processed.reserve(reference_.size() * unroll_for_);
312 :
313 121 : for (unsigned int i = 1; i < unroll_for_; ++i) {
314 64 : auto step = create_step(reference_, i, unroll_for_);
315 64 : processed.insert(processed.end(), step.begin(), step.end());
316 64 : }
317 :
318 57 : return processed;
319 0 : };
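/// e.g. with unroll_for=3 the already-scoped reference serves as step 0 and
/// create_step appends copies for steps 1 and 2.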
320 :
321 : /**
322 : * @brief when return sequence is requested, a concat layer is added to
323 : * aggregate the output of every unrolled time step
324 : *
325 : */
326 54 : auto concat_output = [](const GraphRepresentation &reference_,
327 : const Connection &con, unsigned unroll_for,
328 : const std::string &new_layer_name) {
329 54 : GraphRepresentation processed(reference_.begin(), reference_.end());
330 :
331 : std::vector<props::RecurrentInput> conns;
332 166 : for (unsigned int i = 0; i < unroll_for; ++i) {
333 112 : conns.emplace_back(Connection{
334 224 : con.getName() + "/" + std::to_string(i),
335 : con.getIndex(),
336 : });
337 : }
338 : /// @todo have axis in concat layer
339 : /// @todo this has to be wrapped with identity layer as #1793
340 108 : auto node = createLayerNode(
341 324 : "concat", {"name=" + new_layer_name, "input_layers=" + to_string(conns)});
342 54 : processed.push_back(std::move(node));
343 :
344 54 : return processed;
345 108 : };
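/// e.g. for con=rnn(0) and unroll_for=3 the added concat node is connected to
/// rnn/0(0), rnn/1(0) and rnn/2(0).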
346 :
347 : /**
348 : * @brief create an identity layer for each end layer, fed either by a concat
349 : * layer (for sequenced returns) or directly by the last-step connection; the
350 : * number of its inputs depends on the maximum index requested in end_conns.
351 : *
352 : * eg)
353 : * layer A outputs a, b, c, d
354 : *
355 : * if end_layers=A(0),A(2)
356 : * as_sequence=A(0)
357 : * the realizer cannot know that d is there, so it is ignored. This is fine
358 : * because the user did not ask for it anyway.
359 : *
360 : * [A]
361 : * type=identity
362 : * input_layers=A/concat_0(0), A/<unroll_for-1>(1), A/<unroll_for-1>(2)
363 : *
364 : */
365 114 : auto step3_connect_output = [this, concat_output](
366 : const GraphRepresentation &reference_,
367 : unsigned unroll_for) {
368 : /// @note below is an inefficient way of processing nodes; consider
369 : /// optimizing as needed by calling the remap realizer only once
370 57 : auto processed = reference_;
371 114 : for (auto [name, max_idx] : end_info) {
372 :
373 : std::vector<props::InputConnection> out_node_inputs;
374 :
375 115 : for (auto i = 0u; i <= max_idx; ++i) {
376 :
377 58 : if (auto con = Connection(name, i); sequenced_return_conns.count(con)) {
378 108 : auto concat_name = name + "/concat_" + std::to_string(i);
379 54 : processed = concat_output(processed, con, unroll_for, concat_name);
380 : // create concat connection name,
381 108 : out_node_inputs.emplace_back(Connection(concat_name, 0));
382 : } else {
383 8 : auto last_layer_name = name + "/" + std::to_string(unroll_for - 1);
384 8 : out_node_inputs.emplace_back(Connection(last_layer_name, i));
385 : }
386 : }
387 :
388 114 : auto alias_layer = createLayerNode(
389 : "identity",
390 342 : {"name=" + name, "input_layers=" + to_string(out_node_inputs)});
391 57 : processed.push_back(std::move(alias_layer));
392 57 : }
393 :
394 57 : return processed;
395 114 : };
396 :
397 57 : auto unroll_for = std::get<props::UnrollFor>(*recurrent_props).get();
398 : step0_verify_and_prepare();
399 57 : auto processed = step1_connect_external_input(reference, unroll_for);
400 57 : processed = step2_unroll(processed, unroll_for);
401 114 : return step3_connect_output(processed, unroll_for);
402 57 : }
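/**
 * @note Informal summary of the realized graph, assuming unroll_for=N: every
 * reference node appears N times as "<node>/0" ... "<node>/N-1" sharing
 * parameters via shared_from, every as_sequence end connection gains a
 * "<name>/concat_<index>" node, and every end layer is finally exposed under
 * its original "<name>" through an identity node.
 */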
403 :
404 : } // namespace nntrainer