Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file layer_node.h
6 : * @date 1 April 2021
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
9 : * @author Debadri Samaddar <s.debadri@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is the layer node for network graph
12 : *
13 : * @todo Add printPreset support
14 : *
15 : * @details LayerNode provides a node wrapper around the Layer class to form a
16 : * GraphNode. Each layer is wrapped with LayerNode in order to add it to a
17 : * graph. Each LayerNode contains only 1 layer inside. LayerNode also intercepts
18 : * certain properties of the layer which are either related to graph related
19 : * connections (input_connections, output_connections, activation, flatten,
20 : * distribute, name) or essential for the description of the layer (trainable,
21 : * input_dims) itself. These properties, if needed by the layer object, are
22 : * provided access to via LayerContext.
23 : */
24 :
25 : #ifndef __LAYER_NODE_H__
26 : #define __LAYER_NODE_H__
27 :
28 : #include <memory>
29 : #include <tuple>
30 : #include <vector>
31 :
32 : #include <graph_node.h>
33 : #include <layer.h>
34 : #include <layer_context.h>
35 : #include <layer_devel.h>
36 : #include <weight.h>
37 :
38 : namespace nntrainer {
39 :
40 : class Layer;
41 : class Connection;
42 : class Exporter;
43 : class ContextData;
44 :
45 : namespace props {
46 : class Name;
47 : class Distribute;
48 : class Flatten;
49 : class Loss;
50 : class InputShape;
51 : class Activation;
52 : class SharedFrom;
53 : class InputConnection;
54 : class ClipGradByGlobalNorm;
55 : class Packed;
56 : class LossScaleForMixed;
57 : class ComputeEngine;
58 : } // namespace props
59 :
60 : /**
61 : * @class LayerNode class
62 : * @brief layer node class for the graph
63 : */
64 : class LayerNode final : public ml::train::Layer, public GraphNode {
65 : public:
66 : /**
67 : * @brief Constructor of LayerNode class for v2
68 : * @param l layer to wrap with, the ownership is transferred to layer node
69 : *
70 : */
71 : LayerNode(std::unique_ptr<nntrainer::Layer> &&l);
72 :
73 : /**
74 : * @brief Destructor of LayerNode Class
75 : *
76 : */
77 : ~LayerNode();
78 :
79 : /**
80 : * Support all the interface requirements by ml::train::Layer
81 : */
82 :
83 : /**
84 : * @brief Get the layer type
85 : *
86 : * @return const std::string type representation
87 : */
88 : const std::string getType() const override;
89 :
90 : /**
91 : * @brief set Property of layer
92 : * @param[in] properties values of property
93 : * @retval #ML_ERROR_NONE Successful.
94 : * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
95 : * @details This function accepts vector of properties in the format -
96 : * { std::string property_name=property_val, ...}
97 : *
98 : */
99 : void setProperty(const std::vector<std::string> &properties) override;
100 :
101 : std::string getProperty(const std::string &key) override;
102 :
103 : /**
104 : * @brief Get name of the layer
105 : *
106 : * @retval name of the layer
107 : * @note This name is unique to this layer in a model
108 : * @note This name might be changed once this layer is added to the model
109 : * to keep the name unique to the model
110 : */
111 : const std::string getName() const override;
112 :
113 : /**
114 : * Support all the interface requirements by nntrainer::GraphNode
115 : */
116 :
117 : /**
118 : * @brief set name of layer
119 : *
120 : * @param[in] name Name of the layer
121 : */
122 8796 : void setName(const std::string &name) override {
123 17592 : setProperty({"name=" + name});
124 17592 : }
125 :
126 : /**
127 : * @brief set weight and activation data type of layer
128 : *
129 : * @param[in] weight data type, activation data type
130 : */
131 : void setDataType(const TensorDim::DataType w_type,
132 : const TensorDim::DataType a_type) {
133 4407 : data_type = {w_type, a_type};
134 : }
135 :
136 : /**
137 : * @brief Get the Weight Data Type
138 : *
139 : * @return TensorDim::DataType weight data type
140 : */
141 3134 : const TensorDim::DataType getWeightDataType() const { return data_type[0]; }
142 :
143 : /**
144 : * @brief Get the Activation Data Type
145 : *
146 : * @return TensorDim::DataType activation data type
147 : */
148 : const TensorDim::DataType getActivationDataType() const {
149 : return data_type[1];
150 : }
151 :
152 : /**
153 : * @brief Get the Input Connection Index object
154 : *
155 : * @param nth nth input
156 : * @throws if nth is out of range of getNumInputConnection()
157 : * @return const unsigned index
158 : */
159 : const unsigned getInputConnectionIndex(unsigned nth) const;
160 :
161 : /**
162 : * @brief Get the Input Connection Name object
163 : *
164 : * @param nth nth input
165 : * @throws if nth is out of range of getNumInputConnection()
166 : * @return const std::string& name
167 : */
168 : const std::string &getInputConnectionName(unsigned nth) const;
169 :
170 : /**
171 : * @brief Set the Input Connection Index object
172 : *
173 : * @param nth nth input
174 : * @param index index to set
175 : * @throws if nth is out of range of getNumInputConnection()
176 : */
177 : void setInputConnectionIndex(unsigned nth, unsigned index);
178 :
179 : /**
180 : * @brief Set the Input Connection Name object
181 : *
182 : * @param nth nth input
183 : * @param name name to set
184 : * @throws if nth is out of range of getNumInputConnection()
185 : * @throws if new identifier is invalid
186 : */
187 : void setInputConnectionName(unsigned nth, const std::string &name);
188 :
189 : /**
190 : * @brief Get the output connection object
191 : *
192 : * @param nth nth input
193 : * @throws if nth is out of range of getNumOutputConnection()
194 : * @return Connection * view of a connection, null means this does not exist
195 : */
196 : const Connection *getOutputConnection(unsigned nth) const;
197 :
198 : /**
199 : * @brief Set the Output Connection
200 : * @note Each output must be identified only ONCE.
201 : * @note when nth comes, getNumOutput() expands to nth + 1 as resize occurs.
202 : * Please also notice a non-identified intermediate output (or mismatch between
203 : * the actual number of out tensors and outputs) is allowed but will produce a
204 : * warning; this implies that the output is not used elsewhere.
205 : * @throw std::invalid_argument when trying to identify an output
206 : * more than once
207 : *
208 : * @param nth nth output
209 : * @param name name of the output bound connection
210 : * @param index index of the output bound connection
211 : */
212 : void setOutputConnection(unsigned nth, const std::string &name,
213 : unsigned index);
214 :
215 : /**
216 : * @brief set the compute engine for this node
217 : * @param compute_engine compute engine to use (CPU/GPU)
218 : */
219 : void setComputeEngine(const ml::train::LayerComputeEngine &compute_engine =
220 : ml::train::LayerComputeEngine::CPU);
221 :
222 : /**
223 : * @brief Get the input connections for this node
224 : *
225 : * @return list of name of the nodes which form input connections
226 : */
227 9827 : const std::vector<std::string> getInputConnections() const override {
228 19254 : return getInputLayers();
229 : }
230 :
231 : /**
232 : * @brief Get the output connections for this node
233 : *
234 : * @return list of name of the nodes which form output connections
235 : */
236 4385 : const std::vector<std::string> getOutputConnections() const override {
237 9494 : return getOutputLayers();
238 : }
239 :
240 : /**
241 : * @brief get the execution order/location of this node
242 : *
243 : * @retval the execution order/location of this node
244 : */
245 128597 : ExecutionOrder getExecutionOrder() const override { return exec_order; }
246 :
247 : /**
248 : * @brief set the execution order/location of this node
249 : *
250 : * @param exec_order the execution order/location of this node
251 : */
252 0 : void setExecutionOrder(ExecutionOrder exec_order_) override {
253 : exec_order = exec_order_;
254 0 : }
255 :
256 : /**
257 : * Support all the interface requirements by nntrainer::Layer
258 : */
259 :
260 : /**
261 : * @brief Finalize creating the layer node
262 : *
263 : * @param input_dims input dimension provided to be used to set output
264 : * dimensions if not empty. This function must set output dimensions in
265 : * the given context. Further, context can be used to request weights for the
266 : * layer, and any extra tensor required for the operation of the layer.
267 : * @note After calling this it is not allowed to
268 : * change properties.
269 : * @note No memory allocation must be performed in the initialization
270 : * step. Any tensor memory required must be requested to the context which
271 : * will be made available during execution of the layer with the context.
272 : * @note configureRunContext() is expected to be called right after this.
273 : */
274 : InitLayerContext
275 : finalize(const std::vector<TensorDim> &input_dims = {},
276 : std::array<std::string, 3> tensor_type = {"NCHW", "FP32", "FP32"},
277 : ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN);
278 :
279 : /**
280 : * @brief Refinalize creating the layer node
281 : *
282 : * @param input_dims input dimension provided to be used to set output
283 : * dimensions if not empty. This function must set output dimensions in
284 : * the given context. Further, context can be used to request weights for the
285 : * layer, and any extra tensor required for the operation of the layer.
286 : * @note After calling this it is not allowed to
287 : * change properties.
288 : * @note No memory allocation must be performed in the reinitialization
289 : * step. Any tensor memory required must be requested to the context which
290 : * will be made available during execution of the layer with the context.
291 : * @note configureRunContext() is expected to be called right after this.
292 : */
293 : InitLayerContext refinalize(const std::vector<TensorDim> &input_dims = {});
294 :
295 0 : void initialize() override { layer->initialize(*run_context); }
296 :
297 : /**
298 : * @brief Forward Propagation of a layer
299 : * @param training true if training, false if inference
300 : *
301 : * @details context provides access to the weights (if any), inputs,
302 : * outputs, and tensors (if any) for the layer. Input and output dimensions
303 : * can be accessed from the inputs/outputs tensors themselves.
304 : */
305 : void forwarding(bool training = true);
306 :
307 : /**
308 : * @brief Incremental forward Propagation of a layer
309 : * @param from start step
310 : * @param to end step
311 : * @param training true if training, false if inference
312 : *
313 : * @details context provides access to the weights (if any), inputs,
314 : * outputs, and tensors (if any) for the layer. Input and output dimensions
315 : * can be accessed from the inputs/outputs tensors themselves.
316 : */
317 : void incremental_forwarding(unsigned int from, unsigned int to,
318 : bool training = true);
319 :
320 : /**
321 : * @brief calc the derivative to be passed to the previous layer
322 : *
323 : * @details context provides access to the weights (if any), inputs,
324 : * outputs, and tensors (if any) for the layer. Input and output dimensions
325 : * can be accessed from the inputs/outputs tensors themselves.
326 : */
327 : void calcDerivative();
328 :
329 : /**
330 : * @brief Calculate the derivative of a layer
331 : * @details context provides access to the weights (if any), inputs,
332 : * outputs, and tensors (if any) for the layer. Input and output dimensions
333 : * can be accessed from the inputs/outputs tensors themselves.
334 : */
335 : void calcGradient();
336 :
337 : /**
338 : * @brief this function helps exporting the layer in a predefined format,
339 : * while working around an issue caused by templated function type erasure
340 : *
341 : * @param exporter exporter that contains exporting logic
342 : * @param method enum value to identify how it should be exported to
343 : */
344 : void exportTo(Exporter &exporter,
345 : const ml::train::ExportMethods &method) const;
346 :
347 : /**
348 : * @brief Set the batch for the layer
349 : * @param batch Batch value to be set
350 : * @details Update the run context based on the updated batch size if required
351 : */
352 : void setBatch(unsigned int batch);
353 :
354 : /**
355 : * @brief Update the tensors dimensions of the layer by input dimensions
356 : * @param input_dimensions input dimensions of the layer
357 : * @details Update the dimensions of inputs, outputs, weights, tensors based
358 : * on the input dimensions
359 : */
360 : void updateTensorsByInputDimensions(std::vector<TensorDim> input_dimensions);
361 :
362 : /**
363 : * @brief If the current layer can support in-place
364 : * @return true if inplace, else false
365 : */
366 : bool supportInPlace() const;
367 :
368 : /**
369 : * @brief Initialize the in-place settings of the layer
370 : * @details If it is a layer that supports in-place, the default in-place type
371 : * is NONE_RESTRICTING, but if there is a RESTRICTING type among the input
372 : * layers, it is set to NONE in the network_graph.cpp.
373 : * Layers with exceptional behavior such as No-Operation layers should
374 : * override this function.
375 : * @return InPlaceType
376 : */
377 : InPlaceType initializeInPlace();
378 : /**
379 : * @brief Notify that this layer will execute in-place
380 : *
381 : * @param val in place state for the layer
382 : */
383 2565 : void setInPlaceType(InPlaceType val) {
384 2565 : if (val != InPlaceType::NONE && !supportInPlace())
385 0 : throw std::runtime_error("Error setting layer to work in-place");
386 :
387 2565 : inplace_type = val;
388 2565 : }
389 :
390 : /**
391 : * @brief Get if the layer is going to execute in-place
392 : *
393 : * @return Inplace type for the layer
394 : */
395 13953 : InPlaceType getInPlaceType() const { return inplace_type; }
396 :
397 : /**
398 : * @brief Get the inplace direction for the layer
399 : *
400 : * @return InPlaceDirection
401 : */
402 : InPlaceDirection getInPlaceDirection() const;
403 :
404 : /**
405 : * @brief check if this layer requires label to be passed
406 : * @return true if requires a label when training, else false
407 : * @note if requireLabel() == true means, for now, that it is endpoint of
408 : * a graph(numOutlayers == 0). label will be fed to the gradient of hidden
409 : * if requireLabel is true
410 : */
411 : bool requireLabel() const;
412 :
413 : /**
414 : * Add rest of the helper interfaces required by other internal classes
415 : */
416 :
417 : /**
418 : * @brief Get whether the underlying layer supports backwarding
419 : *
420 : * @return boolean true if backwarding is supported, else false
421 : */
422 7233 : bool supportBackwarding() const { return getLayer()->supportBackwarding(); }
423 :
424 : /**
425 : * Support interfaces for the properties intercepted from layer
426 : */
427 :
428 : /**
429 : * @brief Get the trainable property of the underlying object
430 : *
431 : * @return boolean true if trainable, else false
432 : */
433 : bool getTrainable() const override;
434 :
435 : /**
436 : * @brief get if the output of this layer must be flatten
437 : * @retval flatten value
438 : */
439 : bool getFlatten() const;
440 :
441 : /**
442 : * @brief Get the Shared From property of the layer node
443 : *
444 : * @return std::string node name where the weights are borrowed
445 : */
446 : std::string getSharedFrom() const;
447 :
448 : /**
449 : * @brief get distribute for this layer
450 : * @retval dist to enable/disable distribute
451 : */
452 : bool getDistribute() const;
453 :
454 : /**
455 : * @brief get activation for this layer
456 : * @retval activation type to be realized for this layer
457 : */
458 : ActivationType getActivationToBeRealized() const;
459 :
460 : /**
461 : * @brief Activation Type Getter
462 : * @retval Activation Type.
463 : */
464 : ActivationType getActivationType() const;
465 :
466 : /**
467 : * @brief Get number of input connections
468 : * @retval number of inputs
469 : */
470 : unsigned int getNumInputConnections() const;
471 :
472 : /**
473 : * @brief Get number of output connections
474 : * @retval number of outputs
475 : */
476 : unsigned int getNumOutputConnections() const;
477 :
478 : /**
479 : * @brief Get number of inputs
480 : * @retval number of inputs
481 : */
482 3702 : unsigned int getNumInputs() const {
483 3703 : NNTR_THROW_IF(!run_context, std::runtime_error)
484 : << __func__ << " layer needs to be finalized first!";
485 3701 : return run_context->getNumInputs();
486 : }
487 :
488 : /**
489 : * @brief Get number of outputs
490 : * @retval number of outputs
491 : */
492 18478 : unsigned int getNumOutputs() const {
493 18479 : NNTR_THROW_IF(!run_context, std::runtime_error)
494 : << __func__ << " layer needs to be finalized first!";
495 18477 : return run_context->getNumOutputs();
496 : }
497 :
498 : /**
499 : * @brief Get the number of weights
500 : *
501 : * @return unsigned int number of weights
502 : */
503 2396 : unsigned int getNumWeights() const {
504 2397 : NNTR_THROW_IF(!run_context, std::runtime_error)
505 : << __func__ << " layer needs to be finalized first!";
506 2395 : return run_context->getNumWeights();
507 : }
508 :
509 : /**
510 : * @brief Set the Output Layers object
511 : *
512 : * @param layers Name of the layers
513 : */
514 : void setOutputLayers(const std::vector<std::string> &layers);
515 :
516 : /**
517 : * @brief check if input shape property is set
518 : *
519 : * @return bool true if input shape property has set
520 : */
521 : bool hasInputShapeProperty() const;
522 :
523 : /**
524 : * @brief Get the input dimension
525 : * @return TensorDim dimension of the input
526 : */
527 : const std::vector<TensorDim> getInputDimensions() const;
528 :
529 : /**
530 : * @brief Get the output dimension
531 : * @return TensorDim dimension of the output
532 : */
533 : const std::vector<TensorDim> getOutputDimensions() const;
534 : /**
535 : * @brief Get the Weight object
536 : * currently, only unittest uses this func.
537 : *
538 : * @param idx Identifier of the weight
539 : * @return Weight& Reference to the weight
540 : */
541 4185 : Weight getWeightWrapper(unsigned int idx) {
542 4186 : NNTR_THROW_IF(!run_context, std::runtime_error)
543 : << __func__ << " layer needs to be finalized first!";
544 4184 : if (run_context->weightHasGradient(idx)) {
545 : return Weight(
546 3992 : run_context->getWeight(idx), run_context->getWeightGrad(idx),
547 7984 : run_context->getWeightFP32(idx), run_context->getWeightName(idx));
548 : } else {
549 768 : return Weight(run_context->getWeight(idx), Tensor(), Tensor(),
550 384 : run_context->getWeightName(idx));
551 : }
552 : }
553 :
554 : /**
555 : * @brief Get the Weight object
556 : *
557 : * @param idx Identifier of the weight
558 : * @return Tensor& Reference to the weight tensor
559 : */
560 1 : Weight &getWeightObject(unsigned int idx) {
561 2 : NNTR_THROW_IF(!run_context, std::runtime_error)
562 : << __func__ << " layer needs to be finalized first!";
563 0 : return run_context->getWeightObject(idx);
564 : }
565 :
566 : /**
567 : * @brief Get the Weight tensor object
568 : *
569 : * @param idx Identifier of the weight
570 : * @return Tensor& Reference to the weight tensor
571 : */
572 2634 : Tensor &getWeight(unsigned int idx) {
573 2635 : NNTR_THROW_IF(!run_context, std::runtime_error)
574 : << __func__ << " layer needs to be finalized first!";
575 2633 : return run_context->getWeight(idx);
576 : }
577 :
578 : /**
579 : * @brief Get the Weight Gradient tensor object
580 : *
581 : * @param idx Identifier of the weight
582 : * @return Tensor& Reference to the weight grad tensor
583 : */
584 2303 : Tensor &getWeightGrad(unsigned int idx) {
585 2304 : NNTR_THROW_IF(!run_context, std::runtime_error)
586 : << __func__ << " layer needs to be finalized first!";
587 2302 : return run_context->getWeightGrad(idx);
588 : }
589 :
590 : /**
591 : * @brief Get the Weight object name
592 : *
593 : * @param idx Identifier of the weight
594 : * @return const std::string &Name of the weight
595 : */
596 3 : const std::string &getWeightName(unsigned int idx) override {
597 4 : NNTR_THROW_IF(!run_context, std::runtime_error)
598 : << __func__ << " layer needs to be finalized first!";
599 2 : return run_context->getWeightName(idx);
600 : }
601 :
602 : /**
603 : * @brief Get weight data of the layer
604 : * @retval weight data of the layer
605 : * @note nntrainer assign the vector and if there is no weights, the size
606 : * of vector is zero
607 : * @note layer needs to be finalized before called.
608 : */
609 5 : const std::vector<float *> getWeights() override {
610 6 : NNTR_THROW_IF(!run_context, std::runtime_error)
611 : << __func__ << " layer needs to be finalized first!";
612 :
613 : std::vector<float *> weights;
614 12 : for (unsigned int idx = 0; idx < getNumWeights(); ++idx) {
615 :
616 8 : if (getWeight(idx).getDataType() ==
617 : ml::train::TensorDim::DataType::FP16) {
618 : #ifdef ENABLE_FP16
619 : _FP16 *data = getWeight(idx).getData<_FP16>();
620 : float *d = new float[getWeight(idx).size()]();
621 : weights.emplace_back(d);
622 : for (unsigned int i = 0; i < getWeight(idx).size(); ++i) {
623 : weights[idx][i] = static_cast<float>(data[i]);
624 : }
625 : #else
626 0 : throw std::runtime_error("enable-fp16 is not set");
627 : #endif
628 : } else {
629 16 : weights.emplace_back(getWeight(idx).getData());
630 : }
631 : }
632 4 : return weights;
633 0 : }
634 :
635 : /**
636 : * @brief Get weight data of the layer
637 : * @param[out] weights : float * array to store weight data
638 : * @param[out] weights_dim : TensorDim for each weights
639 : * @note nntrainer assign the vector and if there is no weights, the size
640 : * of vector is zero
641 : * @note layer needs to be finalized before called.
642 : */
643 1 : void getWeights(std::vector<float *> &weights,
644 : std::vector<TensorDim> &weight_dim) override {
645 1 : NNTR_THROW_IF(!run_context, std::runtime_error)
646 : << __func__ << " layer needs to be finalized first!";
647 :
648 : std::vector<int *> weights_dim;
649 3 : for (unsigned int idx = 0; idx < getNumWeights(); ++idx) {
650 2 : TensorDim d = getWeight(idx).getDim();
651 4 : weights.emplace_back(getWeight(idx).getData());
652 2 : weight_dim.emplace_back(d);
653 : }
654 1 : return;
655 1 : }
656 : #ifdef ENABLE_FP16
657 : /**
658 : * @brief Get weight data of the layer
659 : * @retval weight data of the layer
660 : * @note nntrainer assign the vector and if there is no weights, the size
661 : * of vector is zero
662 : * @note layer needs to be finalized before called.
663 : */
664 : const std::vector<_FP16 *> getFP16Weights() override {
665 : NNTR_THROW_IF(!run_context, std::runtime_error)
666 : << __func__ << " layer needs to be finalized first!";
667 :
668 : std::vector<_FP16 *> weights;
669 : for (unsigned int idx = 0; idx < getNumWeights(); ++idx) {
670 : weights.emplace_back(getWeight(idx).getData<_FP16>());
671 : }
672 : return weights;
673 : }
674 :
675 : /**
676 : * @brief Get weight data of the layer
677 : * @param[out] weights : float * array to store weight data
678 : * @param[out] weights_dim : TensorDim for each weights
679 : * @note nntrainer assign the vector and if there is no weights, the size
680 : * of vector is zero
681 : * @note layer needs to be finalized before called.
682 : */
683 : void getFP16Weights(std::vector<_FP16 *> &weights,
684 : std::vector<TensorDim> &weight_dim) override {
685 : NNTR_THROW_IF(!run_context, std::runtime_error)
686 : << __func__ << " layer needs to be finalized first!";
687 :
688 : std::vector<int *> weights_dim;
689 : for (unsigned int idx = 0; idx < getNumWeights(); ++idx) {
690 : TensorDim d = getWeight(idx).getDim();
691 : weights.emplace_back(getWeight(idx).getData<_FP16>());
692 : weight_dim.emplace_back(d);
693 : }
694 : return;
695 : }
696 : #endif
697 :
698 : /**
699 : * @brief Set weight data of the layer
700 : * @note Size of vector must be the same with number of weights.
701 : * @note layer needs to be finalized before called.
702 : */
703 : void setWeights(const std::vector<float *> weights) override;
704 :
705 : /**
706 : * @brief Get the Input tensor object
707 : *
708 : * @param idx Identifier of the input
709 : * @return Tensor& Reference to the input grad tensor
710 : */
711 971 : Tensor &getInput(unsigned int idx) {
712 972 : NNTR_THROW_IF(!run_context, std::runtime_error)
713 : << __func__ << " layer needs to be finalized first!";
714 970 : return run_context->getInput(idx);
715 : }
716 :
717 : /**
718 : * @brief Get the Input Grad tensor object
719 : *
720 : * @param idx Identifier of the input
721 : * @return Tensor& Reference to the input grad tensor
722 : */
723 1362 : Tensor &getInputGrad(unsigned int idx) {
724 1363 : NNTR_THROW_IF(!run_context, std::runtime_error)
725 : << __func__ << " layer needs to be finalized first!";
726 1361 : return run_context->getInputGrad(idx);
727 : }
728 :
729 : /**
730 : * @brief Get the Output tensor object
731 : *
732 : * @param idx Identifier of the output
733 : * @return Tensor& Reference to the output tensor
734 : */
735 9552 : Tensor &getOutput(unsigned int idx) {
736 9553 : NNTR_THROW_IF(!run_context, std::runtime_error)
737 : << __func__ << " layer needs to be finalized first!";
738 9551 : return run_context->getOutput(idx);
739 : }
740 :
741 : /**
742 : * @brief Get the Output Grad tensor object
743 : *
744 : * @param idx Identifier of the output
745 : * @return Tensor& Reference to the output grad tensor
746 : */
747 631 : const Tensor getOutputGrad(unsigned int idx) const {
748 632 : NNTR_THROW_IF(!run_context, std::runtime_error)
749 : << __func__ << " layer needs to be finalized first!";
750 630 : return run_context->getOutputGrad(idx);
751 : }
752 :
753 : /**
754 : * @brief Get the Output Grad unsafe
755 : *
756 : * @param idx Identifier of the output
757 : * @return Tensor& Reference to the output grad tensor
758 : */
759 : const Tensor &getOutputGradUnsafe(unsigned int idx) const {
760 : return run_context->getOutputGradUnsafe(idx);
761 : }
762 :
763 : /**
764 : * @brief read layer Weight & Bias data from file
765 : * @param file input file stream
766 : * @param fsu fsu type
767 : * @param mode Execution mode
768 : * @param opt_var read optimizer variables
769 : */
770 : void read(std::ifstream &file, bool opt_var = false,
771 : ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN,
772 : bool fsu = false, size_t start_offset = 0,
773 : bool read_from_offset = false, int file_fd = -1);
774 :
775 : /**
776 : * @brief read layer Weight & Bias data from file
777 : * @param src input file/mmaped stream
778 : * @param fsu fsu type
779 : * @param mode Execution mode
780 : * @param opt_var read optimizer variables
781 : */
782 : void read(ReadSource src, bool opt_var = false,
783 : ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN,
784 : bool fsu = false, size_t start_offset = 0,
785 : bool read_from_offset = false);
786 :
787 : /**
788 : * @brief save layer Weight & Bias data from file
789 : * @param file output file stream
790 : * @param opt_var save optimizer variables if true
791 : */
792 : void
793 : save(std::ofstream &file, bool opt_var = false,
794 : ml::train::ExecutionMode mode = ml::train::ExecutionMode::TRAIN) const;
795 :
796 : /**
797 : * @brief clear optimizer variable to initial state
798 : *
799 : */
800 : void clearOptVar();
801 :
802 : /**
803 : * @brief get loss for the layer
804 : * @return loss of the layer
805 : */
806 : float getLoss() const;
807 :
808 : #ifdef PROFILE
809 : int forward_event_key;
810 : int calc_deriv_event_key;
811 : int calc_grad_event_key;
812 : #endif
813 :
814 : /**
815 : * @brief Overriding output stream for layers and it's derived class
816 : */
817 : friend std::ostream &operator<<(std::ostream &out, const LayerNode &l);
818 :
819 : /**
820 : * @brief Get run layer context
821 : *
822 : * @retval run layer context
823 : */
824 25245 : RunLayerContext &getRunContext() {
825 25248 : NNTR_THROW_IF(!run_context, std::runtime_error)
826 : << __func__ << " layer needs to be configured first!";
827 25242 : return *run_context;
828 : }
829 :
830 : /**
831 : * @brief Get run layer context
832 : *
833 : * @retval run layer context
834 : */
835 22 : const RunLayerContext &getRunContext() const {
836 22 : NNTR_THROW_IF(!run_context, std::runtime_error)
837 : << __func__ << " layer needs to be configured first!";
838 22 : return *run_context;
839 : }
840 :
841 : #ifdef ENABLE_TEST
842 : /**
843 : * @brief Get init layer context
844 : *
845 : * @retval init layer context
846 : */
847 32 : InitLayerContext &getInitContext() {
848 32 : NNTR_THROW_IF(!init_context, std::runtime_error)
849 : << __func__ << " layer needs to be finalized first!";
850 32 : return *init_context;
851 : }
852 :
853 : /**
854 : * @brief Get init layer context
855 : *
856 : * @retval init layer context
857 : */
858 : const InitLayerContext &getInitContext() const {
859 : NNTR_THROW_IF(!init_context, std::runtime_error)
860 : << __func__ << " layer needs to be finalized first!";
861 : return *init_context;
862 : }
863 : #endif // ENABLE_TEST
864 :
865 : /**
866 : * @brief check if layer is finalized
867 : *
868 : * @retval bool true if the layer is finalized else false
869 : */
870 : bool isFinalized() const {
871 23164 : if (!run_context)
872 0 : return false;
873 :
874 : return true;
875 : }
876 :
877 : /**
878 : * @brief Set the Run Context object with given tensor packs
879 : *
880 : * @param weights weights
881 : * @param inputs inputs
882 : * @param outputs outputs
883 : * @param tensors tensors
884 : */
885 : void configureRunContext(const std::vector<Weight *> &weights,
886 : const std::vector<Var_Grad *> &inputs,
887 : const std::vector<Var_Grad *> &outputs,
888 : const std::vector<Var_Grad *> &tensors,
889 : float loss_scale,
890 : std::shared_ptr<ContextData> ct_data = nullptr);
891 :
892 : /**
893 : * @brief Preset modes for printing summary for the layer
894 : */
895 : enum class PrintPreset {
896 : PRINT_NONE = 0, /**< Print nothing */
897 : PRINT_SUMMARY, /**< Print preset including summary information */
898 : PRINT_SUMMARY_META, /**< Print summary preset that includes meta information
899 : */
900 : PRINT_ALL /**< Print everything possible */
901 : };
902 :
903 : /**
904 : * @brief print using PrintPreset
905 : *
906 : * @param out ostream
907 : * @param preset preset to be used
908 : */
909 : void printPreset(std::ostream &out,
910 : PrintPreset preset = PrintPreset::PRINT_SUMMARY);
911 :
912 : /**
913 : * @brief remap identifier inside layer node
914 : *
915 : * @param remap_fn function to remap
916 : */
917 : void remapIdentifiers(std::function<void(std::string &)> remap_fn);
918 :
919 : /**
920 : * @brief remap connections(input, output layers ) inside layer node
921 : *
922 : * @param remap_fn function to remap
923 : */
924 : void
925 : remapConnections(std::function<void(std::string &, unsigned &)> remap_fn);
926 :
927 : /**
928 : * @brief create the same node with same properties and types
929 : *
930 : * @note this must be done before finalize() as finalize has some potential to
931 : * change some properties
932 : * @return LayerNode newly created node
933 : */
934 : std::unique_ptr<LayerNode> cloneConfiguration();
935 :
936 : /**
937 : * @brief Set if the layer needs to do derivative calculation
938 : *
939 : * @param nb true if the layer needs to do backwarding, else false
940 : */
941 3972 : void needsCalcDerivative(bool nb) {
942 7939 : NNTR_THROW_IF(nb && !supportBackwarding(), std::invalid_argument)
943 0 : << "[Layer] " << getName()
944 : << " does not support backwarding but is needed";
945 3972 : needs_calc_derivative = nb;
946 3972 : }
947 :
948 : /**
949 : * @brief Set if the layer output needs reinitialization for mixed precision
950 : *
951 : * @param nb true if the layer needs to do reinitialization, else false
952 : */
953 : void reStoreData(bool nb) {
954 27303 : needs_restore_data = nb;
955 : run_context->reStoreData(nb);
956 : }
957 :
  /**
   * @brief Set if the layer needs to do calculation of gradients
   *
   * @param nb true if the layer needs to do calcGradient, else false
   */
  void needsCalcGradient(bool nb) { needs_calc_gradient = nb; }
964 :
  /**
   * @brief Get if the layer needs to do calculation of derivatives
   *
   * @return true if the layer needs to do backwarding, else false
   */
  bool needsCalcDerivative() { return needs_calc_derivative; }
971 :
  /**
   * @brief Get if the layer needs to do calculation of gradient
   *
   * @return true if the layer needs to do calcGradient, else false
   */
  bool needsCalcGradient() { return needs_calc_gradient; }
978 :
  /**
   * @brief Get if the layer output needs restoration / reinitialization
   * (used for mixed precision)
   *
   * @return true if the layer needs to restore its data, else false
   */
  bool reStoreData() { return needs_restore_data; }
985 :
986 4435 : std::string getComputeEngineType() {
987 : auto size = props::ComputeEngineTypeInfo::EnumList.size();
988 : auto data = std::data(props::ComputeEngineTypeInfo::EnumList);
989 4435 : for (unsigned i = 0; i < size; ++i) {
990 4435 : if (data[i] == compute_engine) {
991 4435 : return props::ComputeEngineTypeInfo::EnumStr[i];
992 : }
993 : }
994 0 : return "cpu";
995 : }
996 :
private:
  /**
   * @brief Get the input layer names for this node
   *
   * @return const std::vector<std::string> names of the input layers
   */
  const std::vector<std::string> getInputLayers() const;

  /**
   * @brief Get the output layer names for this node
   *
   * @return const std::vector<std::string> names of the output layers
   */
  const std::vector<std::string> getOutputLayers() const;
1011 :
  std::unique_ptr<nntrainer::Layer>
    layer; /**< The actual layer object owned by this graph node */

  InPlaceType inplace_type; /**< store if the current layer is going to operate
                               in-place */
  bool needs_calc_derivative; /**< cache if this layer needs to do
                                 calcDerivative */
  bool needs_calc_gradient; /**< cache if this layer needs to do calcGradient */

  std::vector<std::unique_ptr<Connection>>
    output_connections; /**< output connections of this node */

  /**
   * @brief compute_engine Information about the compute backend being used
   * (defaults to CPU)
   */
  ml::train::LayerComputeEngine compute_engine =
    ml::train::LayerComputeEngine::CPU;

#ifdef ENABLE_TEST
  /**
   * @brief Init context which is stored for debugging issue
   *
   * @note init context is stored only for testing purpose
   */
  std::unique_ptr<InitLayerContext> init_context;
#endif // ENABLE_TEST

  std::unique_ptr<RunLayerContext>
    run_context; /**< context required for running/execution of the layer. This
    will also contain the properties of the layer. The properties will be copied
    upon final creation. Editing properties of the layer after init will not
    change the properties in the context/graph unless intended. */

  using PropsType =
    std::tuple<props::Name, props::Distribute, props::Trainable,
               std::vector<props::InputConnection>,
               std::vector<props::InputShape>, props::SharedFrom,
               props::ClipGradByGlobalNorm, props::Packed, props::WeightDtype,
               props::LossScaleForMixed, props::ComputeEngine>;

  using RealizationPropsType = std::tuple<props::Flatten, props::Activation>;
  /** these realization properties result in addition of new layers, hence they
   * are skipped in generation of model architecture as the corresponding layer
   * itself is added. Distribute is also a property which is realized, but as it
   * doesn't add a new layer, it is saved. */

  /**
   * These properties are set for the layer by the user but are intercepted
   * and used in the node which forms the basic element of the graph.
   */
  std::unique_ptr<PropsType> layer_node_props; /**< properties for the node */
  std::unique_ptr<RealizationPropsType>
    layer_node_props_realization; /**< realization properties for the node */
  std::unique_ptr<props::Loss> loss; /**< loss attached to this node */
  ExecutionOrder exec_order; /**< order/location of execution for this node
                                in forward and backwarding operations */

  bool needs_restore_data; /**< cache if this layer needs restoration of its
                              output data (mixed precision) */

  std::array<TensorDim::DataType, 2> data_type; /**< two cached data types --
                                   presumably activation and weight dtypes;
                                   TODO(review): confirm slot order in usage */
1074 :
  /**
   * @brief Get the effective layer managed by this layer node (const)
   *
   * @details this is the layer inside the distribution layer if this layer
   * node is distributed.
   */
  const nntrainer::Layer *getLayer() const;

  /**
   * @brief Get the effective layer managed by this layer node
   *
   * @details this is the layer inside the distribution layer if this layer
   * node is distributed.
   */
  nntrainer::Layer *getLayer();
1090 :
  /**
   * @brief anchor point to override if PRINT_SHAPE_INFO is enabled for
   * Layer::print()
   *
   * @param out output stream to print to
   */
  void printShapeInfo(std::ostream &out);

  /**
   * @brief anchor point to override if PRINT_METRIC is enabled for
   * Layer::print()
   *
   * @param out output stream to print to
   */
  void printMetric(std::ostream &out);

  /**
   * @brief Print layer related information. Do not override without clear
   * reason. It is recommended to override printShapeInfo, printPropertiesMeta,
   * printProperties, printMetric instead
   * @param[in] out outstream
   * @param[in] flags combination of LayerPrintOption
   */
  void print(std::ostream &out, unsigned int flags = 0);
1111 : };
1112 :
/**
 * @brief LayerNode creator with constructor
 *
 * @param[in] type Type of the layer to be constructed
 * @param[in] properties Properties of the layer
 * @return std::unique_ptr<LayerNode> newly created node
 */
std::unique_ptr<LayerNode>
createLayerNode(const ml::train::LayerType &type,
                const std::vector<std::string> &properties = {});
1123 :
/**
 * @brief LayerNode creator with constructor
 *
 * @param[in] type Type of the layer to be constructed (string identifier)
 * @param[in] properties Properties of the layer
 * @return std::unique_ptr<LayerNode> newly created node
 */
std::unique_ptr<LayerNode>
createLayerNode(const std::string &type,
                const std::vector<std::string> &properties = {});
1133 :
/**
 * @brief LayerNode creator with constructor
 *
 * @param[in] layer Already constructed layer (ownership is transferred)
 * @param[in] properties Properties of the layer
 * @return std::unique_ptr<LayerNode> newly created node
 */
std::unique_ptr<LayerNode>
createLayerNode(std::unique_ptr<nntrainer::Layer> &&layer,
                const std::vector<std::string> &properties);
1144 :
1145 : } // namespace nntrainer
1146 : #endif // __LAYER_NODE_H__
|