Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file time_dist.h
6 : * @date 01 April 2021
7 : * @brief This is Time Distributed Layer Class of Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jijoong Moon <jijoong.moon@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __TIME_DIST_H__
15 : #define __TIME_DIST_H__
16 : #ifdef __cplusplus
17 :
18 : #include <layer_devel.h>
19 : #include <weight.h>
20 :
21 : namespace nntrainer {
22 :
23 : /**
24 : * @class TimeDistLayer
25 : * @brief Time Distribution Layer
26 : */
27 : class TimeDistLayer : public Layer {
28 : public:
29 : /**
30 : * @brief Constructor of Time Distribution Layer
31 : */
32 2 : TimeDistLayer() : Layer() {
33 10 : for (unsigned int i = 0; i < 4; ++i) {
34 8 : positions[i] = nullptr;
35 : }
36 : }
37 :
38 : /**
39 : * @brief Destructor of Time Distributed Layer
40 : */
41 4 : ~TimeDistLayer() = default;
42 :
43 : /**
44 : * @brief Move constructor.
45 : * @param[in] rhs TimeDistLayer to be moved.
46 : */
47 : TimeDistLayer(TimeDistLayer &&rhs) noexcept = default;
48 :
49 : /**
50 : * @brief Move assignment operator.
51 : * @param[in] rhs TimeDistLayer to be moved.
52 : */
53 : TimeDistLayer &operator=(TimeDistLayer &&rhs) = default;
54 :
55 : /**
56 : * @copydoc Layer::finalize(InitLayerContext &context)
57 : */
58 : void finalize(InitLayerContext &context) override;
59 :
60 : /**
61 : * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
62 : */
63 : void forwarding(RunLayerContext &context, bool training) override;
64 :
65 : /**
66 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
67 : */
68 : void calcDerivative(RunLayerContext &context) override;
69 :
70 : /**
71 : * @copydoc Layer::calcGradient(RunLayerContext &context)
72 : */
73 : void calcGradient(RunLayerContext &context) override;
74 :
75 : /**
76 : * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
77 : * method)
78 : */
79 0 : void exportTo(Exporter &exporter,
80 : const ml::train::ExportMethods &method) const override {
81 0 : dist_layer->exportTo(exporter, method);
82 0 : }
83 :
84 : /**
85 : * @copydoc Layer::getType()
86 : */
87 22 : const std::string getType() const override { return TimeDistLayer::type; };
88 :
89 : /**
90 : * @copydoc Layer::supportBackwarding()
91 : */
92 0 : bool supportBackwarding() const override {
93 0 : return dist_layer->supportBackwarding();
94 : }
95 :
96 : /**
97 : * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
98 : */
99 : void setBatch(RunLayerContext &context, unsigned int batch) override;
100 :
101 : /**
102 : * @copydoc Layer::setProperty(const PropertyType type, const std::string
103 : * &value)
104 : */
105 0 : void setProperty(const std::vector<std::string> &values) override {
106 : /**
107 : * @note assumption: name of the dist_layer is set via setName() and not
108 : * with setProperty()
109 : */
110 0 : if (!values.empty())
111 0 : dist_layer->setProperty(values);
112 0 : }
113 :
114 : /**
115 : * @copydoc Layer::requireLabel()
116 : */
117 4 : virtual bool requireLabel() const override {
118 4 : return dist_layer->requireLabel();
119 : }
120 :
121 : /**
122 : * @brief set distribute layer
123 : * @param[in] l layer to distribute along time, ownership is transferred
124 : */
125 : void setDistLayer(std::unique_ptr<Layer> &&l) { dist_layer = std::move(l); }
126 :
127 : /**
128 : * @brief get distribute layer
129 : * @retval Layer* non-owning pointer to dist_layer
130 : */
131 : Layer *getDistLayer() { return dist_layer.get(); };
132 :
133 : /**
134 : * @brief get distribute layer (const)
135 : * @retval const Layer* non-owning pointer to dist_layer
136 : */
137 : const Layer *getDistLayer() const { return dist_layer.get(); };
138 :
139 : static constexpr const char *type = "time_dist";
140 :
141 : private:
142 : /**
143 : * @brief Layer to be distributed through time
144 : */
145 : std::unique_ptr<Layer> dist_layer;
146 : std::vector<Weight> weights_wrapper;
147 : std::vector<Var_Grad> tensors_wrapper;
148 :
149 : /**
150 : * @brief pointer values of each input/output tensor to compare position
151 : */
152 : float *positions[4];
153 :
154 : /**
155 : * @brief Transpose Input and Output Tensors to avoid duplication because
156 : * of memory optimization
157 : * It transposes the net_input.getVariableRef, net_input.getGradientRef,
158 : * net_hidden.getVariableRef and net_hidden.getGradientRef.
159 : *
160 : * @param context Run layer context
161 : */
162 : void transposeInOut(RunLayerContext &context);
163 :
164 : /**
165 : * @brief get transposed Tensor according to time iteration axis
166 : * [b, 1, h, w] to [h, 1, b, w]
167 : * @param[in] m Tensor
168 : * @retval Tensor transposed Tensor
169 : */
170 : static Tensor transposeTensor(Tensor &m);
171 :
172 : /**
173 : * @brief calculate the pointer of each input and output tensors
174 : *
175 : * @param context Run layer context
176 : */
177 : void setPosition(RunLayerContext &context);
178 :
179 : /**
180 : * @brief Fill weights from the given context
181 : *
182 : * @param context The given context
183 : */
184 : void fillWeightsFromContext(RunLayerContext &context);
185 :
186 : /**
187 : * @brief Get the Weights for Context object
188 : *
189 : * @return std::vector<Weight *> The list of weights
190 : */
191 : std::vector<Weight *> getWeightsForContext();
192 :
193 : /**
194 : * @brief Fill tensors from the given context
195 : *
196 : * @param context The given context
197 : */
198 : void fillTensorsFromContext(RunLayerContext &context);
199 :
200 : /**
201 : * @brief Get the Tensors for Context object
202 : *
203 : * @return std::vector<Var_Grad *> The list of tensors
204 : */
205 : std::vector<Var_Grad *> getTensorsForContext();
206 :
207 : /**
208 : * @brief Clean the values filled from context
209 : *
210 : * @note This is necessary to ensure that all the references to the stored
211 : * tensors are cleared for the memory to be released after run is complete.
212 : *
213 : */
214 0 : void clearFromContext() {
215 : weights_wrapper.clear();
216 : tensors_wrapper.clear();
217 0 : }
218 :
219 : /**
220 : * @brief Fill init context from the given dist context
221 : *
222 : * @param context context to be set/filled
223 : * @param dist_context context from which to be filled
224 : */
225 : void fillLayerInitContext(InitLayerContext &context,
226 : const InitLayerContext &dist_context);
227 : };
228 : } // namespace nntrainer
229 :
230 : #endif /* __cplusplus */
231 : #endif /* __TIME_DIST_H__ */
|