// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
 *
 * @file lstmcell.cpp
 * @date 17 March 2021
 * @brief This is LSTMCell Layer Class of Neural Network
 * @see https://github.com/nnstreamer/nntrainer
 * @author Parichay Kapoor <pk.kapoor@samsung.com>
 * @bug No known bugs except for NYI items
 *
 */

#include <layer_context.h>
#include <lstmcell.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>

namespace nntrainer {

enum LSTMCellParams {
  weight_ih,
  weight_hh,
  bias_h,
  bias_ih,
  bias_hh,
  ifgo,
  dropout_mask
};
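
/**
 * Note: `ifgo` packs the pre-activations of the four LSTM gates along the
 * last axis in the order [ input | forget | cell(g) | output ], so the
 * related tensors have width NUM_GATE x unit (NUM_GATE is assumed to be 4,
 * provided by LSTMCore). The weight and bias layouts requested in
 * finalize() follow the same ordering.
 */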

LSTMCellLayer::LSTMCellLayer() : lstmcell_props(props::DropOutRate()) {
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}

void LSTMCellLayer::finalize(InitLayerContext &context) {
  const Initializer weight_initializer =
    std::get<props::WeightInitializer>(*layer_impl_props).get();
  const Initializer bias_initializer =
    std::get<props::BiasInitializer>(*layer_impl_props).get();
  const WeightRegularizer weight_regularizer =
    std::get<props::WeightRegularizer>(*layer_impl_props).get();
  const float weight_regularizer_constant =
    std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  NNTR_THROW_IF(std::get<props::Unit>(lstmcore_props).empty(),
                std::invalid_argument)
    << "unit property missing for lstmcell layer";
  const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(lstmcore_props).get();
  const ActivationType hidden_state_activation_type =
    std::get<props::HiddenStateActivation>(lstmcore_props).get();
  const ActivationType recurrent_activation_type =
    std::get<props::RecurrentActivation>(lstmcore_props).get();

  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();

  NNTR_THROW_IF(context.getNumInputs() != 3, std::invalid_argument)
    << "LSTMCell layer expects 3 inputs (one for the input and two for the "
       "hidden/cell state) but got " +
         std::to_string(context.getNumInputs()) + " input(s)";

  // input_dim = [ batch_size, 1, 1, feature_size ]
  const TensorDim &input_dim = context.getInputDimensions()[INOUT_INDEX::INPUT];
  NNTR_THROW_IF(input_dim.channel() != 1 || input_dim.height() != 1,
                std::invalid_argument)
    << "Input must be a single time step for LSTMCell (shape should be "
       "[batch_size, 1, 1, feature_size])";
  // input_hidden_state_dim = [ batch, 1, 1, unit ]
  const TensorDim &input_hidden_state_dim =
    context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
  NNTR_THROW_IF(input_hidden_state_dim.channel() != 1 ||
                  input_hidden_state_dim.height() != 1,
                std::invalid_argument)
    << "Input hidden state's dimension should be [batch, 1, 1, unit] for "
       "LSTMCell";
  // input_cell_state_dim = [ batch, 1, 1, unit ]
  const TensorDim &input_cell_state_dim =
    context.getInputDimensions()[INOUT_INDEX::INPUT_CELL_STATE];
  NNTR_THROW_IF(input_cell_state_dim.channel() != 1 ||
                  input_cell_state_dim.height() != 1,
                std::invalid_argument)
    << "Input cell state's dimension should be [batch, 1, 1, unit] for "
       "LSTMCell";
  const unsigned int batch_size = input_dim.batch();
  const unsigned int feature_size = input_dim.width();

  TensorDim::TensorType weight_tensor_type = {context.getFormat(),
                                              context.getWeightDataType()};

  // output_hidden_state_dim = [ batch_size, 1, 1, unit ]
  const TensorDim output_hidden_state_dim = input_hidden_state_dim;
  // output_cell_state_dim = [ batch_size, 1, 1, unit ]
  const TensorDim output_cell_state_dim = input_cell_state_dim;

  std::vector<VarGradSpecV2> out_specs;
  out_specs.push_back(
    InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
                              TensorLifespan::FORWARD_FUNC_LIFESPAN));
  out_specs.push_back(
    InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
                              TensorLifespan::FORWARD_GRAD_LIFESPAN));
  context.requestOutputs(std::move(out_specs));
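
  /*
   * Lifespan note (a reading of the specs above): the hidden-state output
   * is only needed while forwarding, but the cell-state output is read
   * again in calcGradient() via getOutput(), so it is requested with
   * FORWARD_GRAD_LIFESPAN to keep it alive until the backward pass.
   */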

  // The weight_ih and weight_hh initializers could be set separately
  // (kernel_initializer and recurrent_initializer in Keras); for now both
  // use the same weight_initializer.

  // - weight_ih ( input to hidden )
  //   : [ 1, 1, feature_size, NUM_GATE x unit ] -> i, f, g, o
  TensorDim weight_ih_dim({feature_size, NUM_GATE * unit}, weight_tensor_type);
  wt_idx[LSTMCellParams::weight_ih] = context.requestWeight(
    weight_ih_dim, weight_initializer, weight_regularizer,
    weight_regularizer_constant, weight_decay, "weight_ih", true);
  // - weight_hh ( hidden to hidden )
  //   : [ 1, 1, unit, NUM_GATE x unit ] -> i, f, g, o
  TensorDim weight_hh_dim({unit, NUM_GATE * unit}, weight_tensor_type);
  wt_idx[LSTMCellParams::weight_hh] = context.requestWeight(
    weight_hh_dim, weight_initializer, weight_regularizer,
    weight_regularizer_constant, weight_decay, "weight_hh", true);
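  /*
   * Assuming the standard LSTM pre-activation x*W_ih + h*W_hh + b_ih + b_hh,
   * the two bias terms enter as a plain sum, so an integrated bias is
   * mathematically equivalent with b_h = b_ih + b_hh; integrate_bias only
   * selects which parameter layout is materialized.
   */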
  if (!disable_bias) {
    if (integrate_bias) {
      // - bias_h ( the input and hidden biases integrated into one bias )
      //   : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
      TensorDim bias_h_dim({NUM_GATE * unit}, weight_tensor_type);
      wt_idx[LSTMCellParams::bias_h] = context.requestWeight(
        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
        "bias_h", true);
    } else {
      // - bias_ih ( input bias )
      //   : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
      TensorDim bias_ih_dim({NUM_GATE * unit}, weight_tensor_type);
      wt_idx[LSTMCellParams::bias_ih] = context.requestWeight(
        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
        bias_decay, "bias_ih", true);
      // - bias_hh ( hidden bias )
      //   : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
      TensorDim bias_hh_dim({NUM_GATE * unit}, weight_tensor_type);
      wt_idx[LSTMCellParams::bias_hh] = context.requestWeight(
        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
        bias_decay, "bias_hh", true);
    }
  }

  /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */
  const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit,
                           weight_tensor_type);
  wt_idx[LSTMCellParams::ifgo] =
    context.requestTensor(ifgo_dim, "ifgo", Initializer::NONE, true,
                          TensorLifespan::ITERATION_LIFESPAN);
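  // The gate pre-activations are cached with ITERATION_LIFESPAN because
  // calcGradient() re-reads them (and writes their gradient) after forwarding.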

  if (dropout_rate > epsilon) {
    // dropout_mask_dim = [ batch_size, 1, 1, unit ]
    const TensorDim dropout_mask_dim(batch_size, 1, 1, unit,
                                     weight_tensor_type);
    wt_idx[LSTMCellParams::dropout_mask] =
      context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
                            false, TensorLifespan::ITERATION_LIFESPAN);
  }

  if (context.getActivationDataType() == TensorDim::DataType::FP32) {
    acti_func.setActiFunc<float>(hidden_state_activation_type);
    recurrent_acti_func.setActiFunc<float>(recurrent_activation_type);
  } else if (context.getActivationDataType() == TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    acti_func.setActiFunc<_FP16>(hidden_state_activation_type);
    recurrent_acti_func.setActiFunc<_FP16>(recurrent_activation_type);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}

void LSTMCellLayer::setProperty(const std::vector<std::string> &values) {
  const std::vector<std::string> &remain_props =
    loadProperties(values, lstmcell_props);
  LSTMCore::setProperty(remain_props);
}

void LSTMCellLayer::exportTo(Exporter &exporter,
                             const ml::train::ExportMethods &method) const {
  LSTMCore::exportTo(exporter, method);
  exporter.saveResult(lstmcell_props, method, this);
}

void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(lstmcore_props).get();

  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();

  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
  const Tensor &prev_hidden_state =
    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
  const Tensor &prev_cell_state =
    context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
  Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
  Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);

  const unsigned int batch_size = input.getDim().batch();

  const Tensor &weight_ih =
    context.getWeight(wt_idx[LSTMCellParams::weight_ih]);
  const Tensor &weight_hh =
    context.getWeight(wt_idx[LSTMCellParams::weight_hh]);

  Tensor empty =
    Tensor("empty", weight_ih.getFormat(), weight_ih.getDataType());

  const Tensor &bias_h = !disable_bias && integrate_bias
                           ? context.getWeight(wt_idx[LSTMCellParams::bias_h])
                           : empty;
  const Tensor &bias_ih = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[LSTMCellParams::bias_ih])
                            : empty;
  const Tensor &bias_hh = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[LSTMCellParams::bias_hh])
                            : empty;

  Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);

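  /*
   * forwardLSTM (implemented in LSTMCore) is expected to compute the
   * standard LSTM cell update in the i, f, g, o gate order:
   *   [i|f|g|o] = x * W_ih + h_prev * W_hh + bias
   *   i, f, o  <- recurrent_acti_func(.)   (sigmoid by default)
   *   g        <- acti_func(.)             (tanh by default)
   *   c        = f * c_prev + i * g        (elementwise)
   *   h        = o * acti_func(c)          (elementwise)
   * This summary is an assumption from the common formulation; the actual
   * activations are configurable via the layer properties.
   */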
  forwardLSTM(batch_size, unit, disable_bias, integrate_bias, acti_func,
              recurrent_acti_func, input, prev_hidden_state, prev_cell_state,
              hidden_state, cell_state, weight_ih, weight_hh, bias_h, bias_ih,
              bias_hh, ifgo);

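  // Recurrent dropout is applied only while training; a fresh mask is
  // sampled on every forward call and reused in calcGradient().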
  if (dropout_rate > epsilon && training) {
    Tensor &dropout_mask =
      context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
    dropout_mask.dropout_mask(dropout_rate);
    hidden_state.multiply_i(dropout_mask);
  }
}

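/**
 * calcDerivative only propagates the gate gradient back to the layer input;
 * reading calcDerivativeLSTM's arguments, this amounts to
 *   dL/dx = d_ifgo * W_ih^T.
 * The derivatives w.r.t. the previous hidden/cell state are produced in
 * calcGradient below.
 */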
void LSTMCellLayer::calcDerivative(RunLayerContext &context) {
  Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMCellParams::ifgo]);
  const Tensor &weight_ih =
    context.getWeight(wt_idx[LSTMCellParams::weight_ih]);
  Tensor &outgoing_derivative =
    context.getOutgoingDerivative(INOUT_INDEX::INPUT);

  calcDerivativeLSTM(outgoing_derivative, weight_ih, d_ifgo);
}

void LSTMCellLayer::calcGradient(RunLayerContext &context) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(lstmcore_props).get();

  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();

  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
  const Tensor &prev_hidden_state =
    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
  Tensor &d_prev_hidden_state =
    context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
  const Tensor &prev_cell_state =
    context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
  Tensor &d_prev_cell_state =
    context.getOutgoingDerivative(INOUT_INDEX::INPUT_CELL_STATE);
  const Tensor &d_hidden_state =
    context.getIncomingDerivative(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
  const Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
  const Tensor &d_cell_state =
    context.getIncomingDerivative(INOUT_INDEX::OUTPUT_CELL_STATE);

  unsigned int batch_size = input.getDim().batch();

  Tensor &d_weight_ih =
    context.getWeightGrad(wt_idx[LSTMCellParams::weight_ih]);
  const Tensor &weight_hh =
    context.getWeight(wt_idx[LSTMCellParams::weight_hh]);
  Tensor &d_weight_hh =
    context.getWeightGrad(wt_idx[LSTMCellParams::weight_hh]);

  Tensor empty =
    Tensor("empty", weight_hh.getFormat(), weight_hh.getDataType());

  Tensor &d_bias_h = !disable_bias && integrate_bias
                       ? context.getWeightGrad(wt_idx[LSTMCellParams::bias_h])
                       : empty;
  Tensor &d_bias_ih = !disable_bias && !integrate_bias
                        ? context.getWeightGrad(wt_idx[LSTMCellParams::bias_ih])
                        : empty;
  Tensor &d_bias_hh = !disable_bias && !integrate_bias
                        ? context.getWeightGrad(wt_idx[LSTMCellParams::bias_hh])
                        : empty;

  const Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);
  Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMCellParams::ifgo]);

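  // The cell may be shared across unrolled time steps, in which case weight
  // gradients accumulate over steps; zero them only on the first access
  // within an iteration.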
  if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::weight_ih])) {
    d_weight_ih.setZero();
  }
  if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::weight_hh])) {
    d_weight_hh.setZero();
  }
  if (!disable_bias) {
    if (integrate_bias) {
      if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_h])) {
        d_bias_h.setZero();
      }
    } else {
      if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_ih])) {
        d_bias_ih.setZero();
      }
      if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_hh])) {
        d_bias_hh.setZero();
      }
    }
  }

  Tensor d_hidden_state_masked = Tensor(
    "d_hidden_state_masked", weight_hh.getFormat(), weight_hh.getDataType());

  if (dropout_rate > epsilon) {
    Tensor &dropout_mask =
      context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
    d_hidden_state.multiply(dropout_mask, d_hidden_state_masked);
  }

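  /*
   * calcGradientLSTM (LSTMCore) is expected to run the usual LSTM backward
   * pass, roughly (all products elementwise):
   *   d_o = d_h * acti(c)        d_c += d_h * o * acti'(c)
   *   d_f = d_c * c_prev         d_i = d_c * g        d_g = d_c * i
   *   d_prev_cell_state = d_c * f
   * followed by the matrix products that accumulate d_weight_ih /
   * d_weight_hh and produce d_prev_hidden_state via weight_hh. This sketch
   * follows the common formulation and is not a spec of LSTMCore.
   */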
  calcGradientLSTM(batch_size, unit, disable_bias, integrate_bias, acti_func,
                   recurrent_acti_func, input, prev_hidden_state,
                   d_prev_hidden_state, prev_cell_state, d_prev_cell_state,
                   dropout_rate > epsilon ? d_hidden_state_masked
                                          : d_hidden_state,
                   cell_state, d_cell_state, d_weight_ih, weight_hh,
                   d_weight_hh, d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
}

void LSTMCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();
  context.updateTensor(wt_idx[LSTMCellParams::ifgo], batch);
  if (dropout_rate > epsilon) {
    context.updateTensor(wt_idx[LSTMCellParams::dropout_mask], batch);
  }
}

} // namespace nntrainer
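
/*
 * Usage sketch (an assumption for illustration; property keys are presumed
 * to follow the props used above, e.g. props::Unit -> "unit",
 * props::DropOutRate -> "dropout"):
 *
 *   auto cell = ml::train::createLayer(
 *     "lstmcell", {"unit=32", "dropout=0.1", "integrate_bias=false"});
 *
 * The layer then expects three inputs per step ordered as
 * [x_t, h_{t-1}, c_{t-1}] and emits [h_t, c_t].
 */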