// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
 *
 * @file   lstm.cpp
 * @date   17 March 2021
 * @brief  This is Long Short-Term Memory Layer Class of Neural Network
 * @see    https://github.com/nnstreamer/nntrainer
 * @author Jijoong Moon <jijoong.moon@samsung.com>
 * @bug    No known bugs except for NYI items
 *
 */
#include <layer_context.h>
#include <lstm.h>
#include <nntr_threads.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>

namespace nntrainer {

static constexpr size_t SINGLE_INOUT_IDX = 0;

enum LSTMParams {
  weight_ih,
  weight_hh,
  bias_h,
  bias_ih,
  bias_hh,
  hidden_state,
  cell_state,
  ifgo,
  reverse_weight_ih,
  reverse_weight_hh,
  reverse_bias_h,
  reverse_bias_ih,
  reverse_bias_hh,
  reverse_hidden_state,
  reverse_cell_state,
  reverse_ifgo,
  dropout_mask
};

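/**
 * @brief Run a batch-first LSTM forward pass over every batch sample and
 * timestep. When @p reverse is true the input is consumed from the last
 * timestep backward, which is how the reverse direction of a bidirectional
 * LSTM is computed.
 */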
void LSTMLayer::forwardingBatchFirstLSTM(
  unsigned int NUM_GATE, const unsigned int batch_size,
  const unsigned int feature_size, const bool disable_bias,
  const unsigned int unit, const bool integrate_bias, ActiFunc &acti_func,
  ActiFunc &recurrent_acti_func, const bool enable_dropout,
  const float dropout_rate, const unsigned int max_timestep, const bool reverse,
  const Tensor &input_, const Tensor &weight_ih, const Tensor &weight_hh,
  const Tensor &bias_h, const Tensor &bias_ih, const Tensor &bias_hh,
  Tensor &hidden_state_, Tensor &cell_state_, Tensor &ifgo_,
  const Tensor &mask_) {
  hidden_state_.setZero();
  cell_state_.setZero();
  TensorDim::TensorType tensor_type = weight_ih.getTensorType();
  TensorDim input_tensor_dim({feature_size}, tensor_type);
  TensorDim unit_tensor_dim({unit}, tensor_type);
  TensorDim num_gate_unit_tensor_dim({NUM_GATE * unit}, tensor_type);

  for (unsigned int batch = 0; batch < batch_size; ++batch) {
    const Tensor input_sample = input_.getBatchSlice(batch, 1);
    Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
    Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
    Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);

    for (unsigned int t = 0; t < max_timestep; ++t) {
      Tensor input = input_sample.getSharedDataTensor(
        input_tensor_dim, (reverse ? max_timestep - 1 - t : t) * feature_size);

      Tensor prev_hidden_state = Tensor(
        "prev_hidden_state", weight_ih.getFormat(), weight_ih.getDataType());

      if (!t) {
        prev_hidden_state = Tensor(unit, tensor_type);
        prev_hidden_state.setZero();
      } else {
        prev_hidden_state = hidden_state_sample.getSharedDataTensor(
          unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
      }
      Tensor hidden_state = hidden_state_sample.getSharedDataTensor(
        unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
      Tensor prev_cell_state;
      if (!t) {
        prev_cell_state = Tensor(unit, tensor_type);
        prev_cell_state.setZero();
      } else {
        prev_cell_state = cell_state_sample.getSharedDataTensor(
          unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
      }
      Tensor cell_state = cell_state_sample.getSharedDataTensor(
        unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
      Tensor ifgo = ifgo_sample.getSharedDataTensor(
        num_gate_unit_tensor_dim,
        (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);

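      // Single LSTM step (standard formulation; the actual math lives in
      // LSTMCore::forwardLSTM). With sigma = recurrent_acti_func and
      // phi = acti_func:
      //   [i, f, g, o] = input * weight_ih + prev_hidden_state * weight_hh
      //                  + bias
      //   cell_state   = sigma(f) . prev_cell_state + sigma(i) . phi(g)
      //   hidden_state = sigma(o) . phi(cell_state)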
      forwardLSTM(1, unit, disable_bias, integrate_bias, acti_func,
                  recurrent_acti_func, input, prev_hidden_state,
                  prev_cell_state, hidden_state, cell_state, weight_ih,
                  weight_hh, bias_h, bias_ih, bias_hh, ifgo);

      if (enable_dropout) {
        Tensor mask_sample = mask_.getBatchSlice(batch, 1);
        Tensor mask =
          mask_sample.getSharedDataTensor(unit_tensor_dim, t * unit);
        mask.dropout_mask(dropout_rate);
        hidden_state.multiply_i(mask);
      }
    }
  }
}

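/**
 * @brief Accumulate the weight and bias gradients of a batch-first LSTM with
 * backpropagation through time: the incoming derivative is first scattered
 * into d_hidden_state, then every timestep is walked backward and
 * LSTMCore::calcGradientLSTM accumulates into the d_* tensors.
 */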
void LSTMLayer::calcGradientBatchFirstLSTM(
  unsigned int NUM_GATE, const unsigned int batch_size,
  const unsigned int feature_size, const bool disable_bias,
  const unsigned int unit, const bool integrate_bias, ActiFunc &acti_func,
  ActiFunc &recurrent_acti_func, const bool return_sequences,
  const bool bidirectional, const bool enable_dropout, const float dropout_rate,
  const unsigned int max_timestep, const bool reverse, const Tensor &input_,
  const Tensor &incoming_derivative, Tensor &d_weight_ih,
  const Tensor &weight_hh, Tensor &d_weight_hh, Tensor &d_bias_h,
  Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &hidden_state_,
  Tensor &d_hidden_state_, const Tensor &cell_state_, Tensor &d_cell_state_,
  const Tensor &ifgo_, Tensor &d_ifgo_, const Tensor &mask_) {
  const unsigned int bidirectional_constant = bidirectional ? 2 : 1;

  d_weight_ih.setZero();
  d_weight_hh.setZero();
  if (!disable_bias) {
    if (integrate_bias) {
      d_bias_h.setZero();
    } else {
      d_bias_ih.setZero();
      d_bias_hh.setZero();
    }
  }

  d_cell_state_.setZero();
  d_hidden_state_.setZero();

  TensorDim::TensorType tensor_type = weight_hh.getTensorType();
  TensorDim unit_tensor_dim({unit}, tensor_type);
  TensorDim feature_size_tensor_dim({feature_size}, tensor_type);
  TensorDim num_gate_tensor_dim({NUM_GATE * unit}, tensor_type);

  if (return_sequences && !bidirectional && !reverse) {
    if (incoming_derivative.getDataType() == TensorDim::DataType::FP32) {
      std::copy(incoming_derivative.getData<float>(),
                incoming_derivative.getData<float>() +
                  incoming_derivative.size(),
                d_hidden_state_.getData<float>());
    } else if (incoming_derivative.getDataType() == TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
      std::copy(incoming_derivative.getData<_FP16>(),
                incoming_derivative.getData<_FP16>() +
                  incoming_derivative.size(),
                d_hidden_state_.getData<_FP16>());
#else
      throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
    }
  } else {
    unsigned int end_timestep = return_sequences ? max_timestep : 1;
    for (unsigned int batch = 0; batch < batch_size; ++batch) {
      for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
        Tensor d_hidden_state_sample = d_hidden_state_.getSharedDataTensor(
          unit_tensor_dim, batch * max_timestep * unit +
                             (return_sequences ? 0 : max_timestep - 1) * unit +
                             timestep * unit);
        Tensor incoming_derivative_sample =
          incoming_derivative.getSharedDataTensor(
            unit_tensor_dim, batch * (return_sequences ? max_timestep : 1) *
                                 bidirectional_constant * unit +
                               timestep * bidirectional_constant * unit +
                               (reverse ? unit : 0));
        d_hidden_state_sample.add_i(incoming_derivative_sample);
      }
    }
  }

  if (enable_dropout) {
    d_hidden_state_.multiply_i(mask_);
  }

  auto workers = ParallelBatch(batch_size);

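  // With multiple workers, each worker accumulates into its own per-worker
  // copy of the gradient tensors (one batch slice per worker) so that no
  // locking is needed; the copies are reduced into d_weight_* / d_bias_*
  // after the parallel run.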
  if (workers.getNumWorkers() > 1) {

    TensorDim weight_ih_d = d_weight_ih.getDim();
    TensorDim weight_hh_d = d_weight_hh.getDim();

    TensorDim bias_ih_d = d_bias_ih.getDim();
    TensorDim bias_hh_d = d_bias_hh.getDim();
    TensorDim bias_h_d = d_bias_h.getDim();

    weight_ih_d.batch(workers.getNumWorkers());
    weight_hh_d.batch(workers.getNumWorkers());
    bias_ih_d.batch(workers.getNumWorkers());
    bias_hh_d.batch(workers.getNumWorkers());
    bias_h_d.batch(workers.getNumWorkers());

    Tensor sub_d_weight_ih = Tensor(weight_ih_d);
    Tensor sub_d_weight_hh = Tensor(weight_hh_d);
    Tensor sub_d_bias_ih = Tensor(bias_ih_d);
    Tensor sub_d_bias_hh = Tensor(bias_hh_d);
    Tensor sub_d_bias_h = Tensor(bias_h_d);

    sub_d_weight_ih.setZero();
    sub_d_weight_hh.setZero();
    sub_d_bias_ih.setZero();
    sub_d_bias_hh.setZero();
    sub_d_bias_h.setZero();

    auto batch_job = [&](unsigned int s, unsigned int e, unsigned int pid,
                         void *user_data) {
      for (unsigned int batch = s; batch < e; ++batch) {
        const Tensor input_sample = input_.getBatchSlice(batch, 1);

        const Tensor hidden_state_sample =
          hidden_state_.getBatchSlice(batch, 1);
        Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
        const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
        Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);

        const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
        Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);

        Tensor input;
        Tensor prev_hidden_state;
        Tensor d_prev_hidden_state;
        Tensor prev_cell_state;
        Tensor d_prev_cell_state;
        Tensor d_hidden_state;
        Tensor cell_state;
        Tensor d_cell_state;

        Tensor p_d_weight_ih = sub_d_weight_ih.getBatchSlice(pid, 1);
        Tensor p_d_weight_hh = sub_d_weight_hh.getBatchSlice(pid, 1);
        Tensor p_d_bias_ih = sub_d_bias_ih.getBatchSlice(pid, 1);
        Tensor p_d_bias_hh = sub_d_bias_hh.getBatchSlice(pid, 1);
        Tensor p_d_bias_h = sub_d_bias_h.getBatchSlice(pid, 1);

        for (int t = max_timestep - 1; t > -1; t--) {
          input = input_sample.getSharedDataTensor(
            feature_size_tensor_dim,
            (reverse ? max_timestep - 1 - t : t) * feature_size);

          if (!t) {
            prev_hidden_state = Tensor(unit, tensor_type);
            prev_hidden_state.setZero();
            d_prev_hidden_state = Tensor(unit, tensor_type);
            d_prev_hidden_state.setZero();
          } else {
            prev_hidden_state = hidden_state_sample.getSharedDataTensor(
              unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
            d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
              unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
          }
          d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
            unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);

          if (!t) {
            prev_cell_state = Tensor(unit, tensor_type);
            prev_cell_state.setZero();
            d_prev_cell_state = Tensor(unit, tensor_type);
            d_prev_cell_state.setZero();
          } else {
            prev_cell_state = cell_state_sample.getSharedDataTensor(
              unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
            d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
              unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
          }
          cell_state = cell_state_sample.getSharedDataTensor(
            unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
          d_cell_state = d_cell_state_sample.getSharedDataTensor(
            unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);

          Tensor ifgo = ifgo_sample.getSharedDataTensor(
            num_gate_tensor_dim,
            (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
          Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
            num_gate_tensor_dim,
            (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);

          // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
          // already holds precalculated values from the incoming derivatives.
          Tensor d_prev_hidden_state_temp =
            Tensor("d_prev_hidden_state_temp", tensor_type.format,
                   tensor_type.data_type);

          calcGradientLSTM(
            1, unit, disable_bias, integrate_bias, acti_func,
            recurrent_acti_func, input, prev_hidden_state,
            d_prev_hidden_state_temp, prev_cell_state, d_prev_cell_state,
            d_hidden_state, cell_state, d_cell_state, p_d_weight_ih, weight_hh,
            p_d_weight_hh, p_d_bias_h, p_d_bias_ih, p_d_bias_hh, ifgo, d_ifgo);

          d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
        }
      }
    };

    workers.setCallback(batch_job, nullptr);
    workers.run();

    for (unsigned int b = 0; b < workers.getNumWorkers(); ++b) {

      Tensor p_d_weight_ih = sub_d_weight_ih.getBatchSlice(b, 1);
      Tensor p_d_weight_hh = sub_d_weight_hh.getBatchSlice(b, 1);
      Tensor p_d_bias_ih = sub_d_bias_ih.getBatchSlice(b, 1);
      Tensor p_d_bias_hh = sub_d_bias_hh.getBatchSlice(b, 1);
      Tensor p_d_bias_h = sub_d_bias_h.getBatchSlice(b, 1);

      d_weight_ih.add_i(p_d_weight_ih);
      d_weight_hh.add_i(p_d_weight_hh);

      if (!disable_bias) {
        if (integrate_bias) {
          d_bias_h.add_i(p_d_bias_h);
        } else {
          d_bias_ih.add_i(p_d_bias_ih);
          d_bias_hh.add_i(p_d_bias_hh);
        }
      }
    }

  } else {
    for (unsigned int batch = 0; batch < batch_size; ++batch) {
      const Tensor input_sample = input_.getBatchSlice(batch, 1);

      const Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
      Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
      const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
      Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);

      const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
      Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);

      Tensor input;
      Tensor prev_hidden_state;
      Tensor d_prev_hidden_state;
      Tensor prev_cell_state;
      Tensor d_prev_cell_state;
      Tensor d_hidden_state;
      Tensor cell_state;
      Tensor d_cell_state;

      for (int t = max_timestep - 1; t > -1; t--) {
        input = input_sample.getSharedDataTensor(
          feature_size_tensor_dim,
          (reverse ? max_timestep - 1 - t : t) * feature_size);

        if (!t) {
          prev_hidden_state = Tensor(unit, tensor_type);
          prev_hidden_state.setZero();
          d_prev_hidden_state = Tensor(unit, tensor_type);
          d_prev_hidden_state.setZero();
        } else {
          prev_hidden_state = hidden_state_sample.getSharedDataTensor(
            unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
          d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
            unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
        }
        d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
          unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);

        if (!t) {
          prev_cell_state = Tensor(unit, tensor_type);
          prev_cell_state.setZero();
          d_prev_cell_state = Tensor(unit, tensor_type);
          d_prev_cell_state.setZero();
        } else {
          prev_cell_state = cell_state_sample.getSharedDataTensor(
            unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
          d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
            unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
        }
        cell_state = cell_state_sample.getSharedDataTensor(
          unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
        d_cell_state = d_cell_state_sample.getSharedDataTensor(
          unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);

        Tensor ifgo = ifgo_sample.getSharedDataTensor(
          num_gate_tensor_dim,
          (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
        Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
          num_gate_tensor_dim,
          (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);

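        // One backward LSTM step (standard BPTT; the math lives in
        // LSTMCore::calcGradientLSTM): gate derivatives d_ifgo are recovered
        // from d_hidden_state / d_cell_state through the activation
        // derivatives, then accumulated into the weight and bias gradients
        // and propagated to d_prev_hidden_state / d_prev_cell_state.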
        // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
        // already holds precalculated values from the incoming derivatives.
        Tensor d_prev_hidden_state_temp =
          Tensor("d_prev_hidden_state_temp", tensor_type.format,
                 tensor_type.data_type);

        calcGradientLSTM(1, unit, disable_bias, integrate_bias, acti_func,
                         recurrent_acti_func, input, prev_hidden_state,
                         d_prev_hidden_state_temp, prev_cell_state,
                         d_prev_cell_state, d_hidden_state, cell_state,
                         d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
                         d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
        d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
      }
    }
  }
}

LSTMLayer::LSTMLayer() :
  LSTMCore(),
  lstm_props(props::ReturnSequences(), props::Bidirectional(),
             props::DropOutRate(), props::MaxTimestep()) {
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}

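/**
 * @brief Validate the layer configuration and request every weight and
 * cached tensor: weights, biases, hidden/cell state, ifgo, and, when
 * applicable, their reverse-direction counterparts and the dropout mask.
 */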
void LSTMLayer::finalize(InitLayerContext &context) {
  const Initializer weight_initializer =
    std::get<props::WeightInitializer>(*layer_impl_props).get();
  const Initializer bias_initializer =
    std::get<props::BiasInitializer>(*layer_impl_props).get();
  const nntrainer::WeightRegularizer weight_regularizer =
    std::get<props::WeightRegularizer>(*layer_impl_props).get();
  const float weight_regularizer_constant =
    std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  NNTR_THROW_IF(std::get<props::Unit>(lstmcore_props).empty(),
                std::invalid_argument)
    << "unit property missing for lstm layer";
  const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(lstmcore_props).get();
  const ActivationType hidden_state_activation_type =
    std::get<props::HiddenStateActivation>(lstmcore_props).get();
  const ActivationType recurrent_activation_type =
    std::get<props::RecurrentActivation>(lstmcore_props).get();

  const bool return_sequences =
    std::get<props::ReturnSequences>(lstm_props).get();
  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();

  if (context.getNumInputs() != 1) {
    throw std::invalid_argument("LSTM layer takes only one input");
  }

  // input_dim = [ batch_size, 1, time_iteration, feature_size ]
  const TensorDim &input_dim = context.getInputDimensions()[SINGLE_INOUT_IDX];
  if (input_dim.channel() != 1) {
    throw std::invalid_argument(
      "Input must be single channel dimension for LSTM (shape should be "
      "[batch_size, 1, time_iteration, feature_size])");
  }
  const unsigned int batch_size = input_dim.batch();
  unsigned int max_timestep = input_dim.height();
  if (!std::get<props::MaxTimestep>(lstm_props).empty())
    max_timestep =
      std::max(max_timestep, std::get<props::MaxTimestep>(lstm_props).get());
  NNTR_THROW_IF(max_timestep < 1, std::runtime_error)
    << "max timestep must be greater than 0 in lstm layer.";
  std::get<props::MaxTimestep>(lstm_props).set(max_timestep);
  const unsigned int feature_size = input_dim.width();

  // output_dim = [ batch_size, 1, return_sequences ? time_iteration : 1,
  // bidirectional ? 2 * unit : unit ]
  TensorDim::TensorType activation_tensor_type = {
    context.getFormat(), context.getActivationDataType()};

  TensorDim::TensorType weight_tensor_type = {context.getFormat(),
                                              context.getWeightDataType()};
  const TensorDim output_dim(batch_size, 1, return_sequences ? max_timestep : 1,
                             bidirectional ? 2 * unit : unit,
                             activation_tensor_type);
  context.setOutputDimensions({output_dim});

  // weight_initializer can be set separately: the weight_ih and weight_hh
  // initializers correspond to the kernel_initializer and
  // recurrent_initializer in Keras. For now, both are set the same way.

  // weight_ih ( input to hidden ) : [ 1, 1, feature_size, NUM_GATE * unit ]
  // -> i, f, g, o
  const TensorDim weight_ih_dim({feature_size, NUM_GATE * unit},
                                weight_tensor_type);
  wt_idx[LSTMParams::weight_ih] = context.requestWeight(
    weight_ih_dim, weight_initializer, weight_regularizer,
    weight_regularizer_constant, weight_decay, "weight_ih", true);
  // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i,
  // f, g, o
  const TensorDim weight_hh_dim({unit, NUM_GATE * unit}, weight_tensor_type);
  wt_idx[LSTMParams::weight_hh] = context.requestWeight(
    weight_hh_dim, weight_initializer, weight_regularizer,
    weight_regularizer_constant, weight_decay, "weight_hh", true);
  if (!disable_bias) {
    if (integrate_bias) {
      // bias_h ( input bias and hidden bias integrated into one bias ) :
      // [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
      const TensorDim bias_h_dim({NUM_GATE * unit}, weight_tensor_type);
      wt_idx[LSTMParams::bias_h] = context.requestWeight(
        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
        "bias_h", true);
    } else {
      // bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
      const TensorDim bias_ih_dim({NUM_GATE * unit}, weight_tensor_type);
      wt_idx[LSTMParams::bias_ih] = context.requestWeight(
        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
        bias_decay, "bias_ih", true);
      // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
      wt_idx[LSTMParams::bias_hh] = context.requestWeight(
        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
        bias_decay, "bias_hh", true);
    }
  }

  // hidden_state_dim : [ batch_size, 1, max_timestep, unit ]
  const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit,
                                   activation_tensor_type);

  wt_idx[LSTMParams::hidden_state] =
    context.requestTensor(hidden_state_dim, "hidden_state", Initializer::NONE,
                          true, TensorLifespan::ITERATION_LIFESPAN);
  // cell_state_dim : [ batch_size, 1, max_timestep, unit ]
  const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit,
                                 activation_tensor_type);

  wt_idx[LSTMParams::cell_state] =
    context.requestTensor(cell_state_dim, "cell_state", Initializer::NONE, true,
                          TensorLifespan::ITERATION_LIFESPAN);

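  // The ifgo tensor caches the gate values (input, forget, cell candidate
  // "g", output) for every timestep, so the backward pass can reuse them
  // instead of re-running the forward pass.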
  // ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
  const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit,
                           activation_tensor_type);

  wt_idx[LSTMParams::ifgo] =
    context.requestTensor(ifgo_dim, "ifgo", Initializer::NONE, true,
                          TensorLifespan::ITERATION_LIFESPAN);

  if (bidirectional) {
    // weight_initializer can be set separately: the weight_ih and weight_hh
    // initializers correspond to the kernel_initializer and
    // recurrent_initializer in Keras. For now, both are set the same way.

    // reverse_weight_ih ( input to hidden ) : [ 1, 1, feature_size,
    // NUM_GATE * unit ] -> i, f, g, o
    const TensorDim reverse_weight_ih_dim({feature_size, NUM_GATE * unit},
                                          weight_tensor_type);
    wt_idx[LSTMParams::reverse_weight_ih] = context.requestWeight(
      reverse_weight_ih_dim, weight_initializer, weight_regularizer,
      weight_regularizer_constant, weight_decay, "reverse_weight_ih", true);
    // reverse_weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ]
    // -> i, f, g, o
    const TensorDim reverse_weight_hh_dim({unit, NUM_GATE * unit},
                                          weight_tensor_type);
    wt_idx[LSTMParams::reverse_weight_hh] = context.requestWeight(
      reverse_weight_hh_dim, weight_initializer, weight_regularizer,
      weight_regularizer_constant, weight_decay, "reverse_weight_hh", true);
    if (!disable_bias) {
      if (integrate_bias) {
        // reverse_bias_h ( input bias and hidden bias integrated into one
        // bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
        const TensorDim reverse_bias_h_dim({NUM_GATE * unit},
                                           weight_tensor_type);
        wt_idx[LSTMParams::reverse_bias_h] = context.requestWeight(
          reverse_bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
          bias_decay, "reverse_bias_h", true);
      } else {
        // reverse_bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
        // i, f, g, o
        const TensorDim reverse_bias_ih_dim({NUM_GATE * unit},
                                            weight_tensor_type);
        wt_idx[LSTMParams::reverse_bias_ih] = context.requestWeight(
          reverse_bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
          bias_decay, "reverse_bias_ih", true);
        // reverse_bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
        // i, f, g, o
        const TensorDim reverse_bias_hh_dim({NUM_GATE * unit},
                                            weight_tensor_type);
        wt_idx[LSTMParams::reverse_bias_hh] = context.requestWeight(
          reverse_bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
          bias_decay, "reverse_bias_hh", true);
      }
    }

    // reverse_hidden_state_dim : [ batch_size, 1, max_timestep, unit ]
    const TensorDim reverse_hidden_state_dim(batch_size, 1, max_timestep, unit,
                                             activation_tensor_type);
    wt_idx[LSTMParams::reverse_hidden_state] = context.requestTensor(
      reverse_hidden_state_dim, "reverse_hidden_state", Initializer::NONE, true,
      TensorLifespan::ITERATION_LIFESPAN);
    // reverse_cell_state_dim : [ batch_size, 1, max_timestep, unit ]
    const TensorDim reverse_cell_state_dim(batch_size, 1, max_timestep, unit,
                                           activation_tensor_type);
    wt_idx[LSTMParams::reverse_cell_state] = context.requestTensor(
      reverse_cell_state_dim, "reverse_cell_state", Initializer::NONE, true,
      TensorLifespan::ITERATION_LIFESPAN);

    // reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
    const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep,
                                     NUM_GATE * unit, activation_tensor_type);
    wt_idx[LSTMParams::reverse_ifgo] =
      context.requestTensor(reverse_ifgo_dim, "reverse_ifgo", Initializer::NONE,
                            true, TensorLifespan::ITERATION_LIFESPAN);
  }

  if (dropout_rate > epsilon) {
    // dropout_mask_dim = [ batch, 1, time_iteration, unit ]
    const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit,
                                     activation_tensor_type);
    wt_idx[LSTMParams::dropout_mask] =
      context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
                            false, TensorLifespan::ITERATION_LIFESPAN);
  }

  if (context.getActivationDataType() == TensorDim::DataType::FP32) {
    acti_func.setActiFunc<float>(hidden_state_activation_type);
    recurrent_acti_func.setActiFunc<float>(recurrent_activation_type);
  } else if (context.getActivationDataType() == TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    acti_func.setActiFunc<_FP16>(hidden_state_activation_type);
    recurrent_acti_func.setActiFunc<_FP16>(recurrent_activation_type);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}

void LSTMLayer::setProperty(const std::vector<std::string> &values) {
  const std::vector<std::string> &remain_props =
    loadProperties(values, lstm_props);
  LSTMCore::setProperty(remain_props);
}

void LSTMLayer::exportTo(Exporter &exporter,
                         const ml::train::ExportMethods &method) const {
  LSTMCore::exportTo(exporter, method);
  exporter.saveResult(lstm_props, method, this);
}

void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(lstmcore_props).get();

  const bool return_sequences =
    std::get<props::ReturnSequences>(lstm_props).get();
  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
  const unsigned int max_timestep =
    std::get<props::MaxTimestep>(lstm_props).get();

  const unsigned int bidirectional_constant = bidirectional ? 2 : 1;
  bool enable_dropout = dropout_rate > epsilon && training;

  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
  const TensorDim input_dim = input.getDim();
  const unsigned int batch_size = input_dim.batch();
  const unsigned int feature_size = input_dim.width();
  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);

  const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]);
  const Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]);

  Tensor empty =
    Tensor("empty", weight_ih.getFormat(), weight_ih.getDataType());

  const Tensor &bias_h = !disable_bias && integrate_bias
                           ? context.getWeight(wt_idx[LSTMParams::bias_h])
                           : empty;
  const Tensor &bias_ih = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[LSTMParams::bias_ih])
                            : empty;
  const Tensor &bias_hh = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[LSTMParams::bias_hh])
                            : empty;

  Tensor &hidden_state = context.getTensor(wt_idx[LSTMParams::hidden_state]);
  Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]);
  Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]);

  Tensor &mask = enable_dropout
                   ? context.getTensor(wt_idx[LSTMParams::dropout_mask])
                   : empty;
  forwardingBatchFirstLSTM(NUM_GATE, batch_size, feature_size, disable_bias,
                           unit, integrate_bias, acti_func, recurrent_acti_func,
                           enable_dropout, dropout_rate, max_timestep, false,
                           input, weight_ih, weight_hh, bias_h, bias_ih,
                           bias_hh, hidden_state, cell_state, ifgo, mask);
  if (bidirectional) {
    const Tensor &reverse_weight_ih =
      context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]);
    const Tensor &reverse_weight_hh =
      context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]);
    const Tensor &reverse_bias_h =
      !disable_bias && integrate_bias
        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_h])
        : empty;
    const Tensor &reverse_bias_ih =
      !disable_bias && !integrate_bias
        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_ih])
        : empty;
    const Tensor &reverse_bias_hh =
      !disable_bias && !integrate_bias
        ? context.getWeight(wt_idx[LSTMParams::reverse_bias_hh])
        : empty;

    Tensor &reverse_hidden_state =
      context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
    Tensor &reverse_cell_state =
      context.getTensor(wt_idx[LSTMParams::reverse_cell_state]);
    Tensor &reverse_ifgo = context.getTensor(wt_idx[LSTMParams::reverse_ifgo]);

    forwardingBatchFirstLSTM(
      NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
      acti_func, recurrent_acti_func, enable_dropout, dropout_rate,
      max_timestep, true, input, reverse_weight_ih, reverse_weight_hh,
      reverse_bias_h, reverse_bias_ih, reverse_bias_hh, reverse_hidden_state,
      reverse_cell_state, reverse_ifgo, mask);
  }

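  // Assemble the output. When return_sequences is set without bidirection
  // the hidden state buffer can be copied wholesale; otherwise the forward
  // (and, if bidirectional, reverse) hidden states are copied per batch and
  // timestep so that the two directions are concatenated along the unit axis
  // as [h_forward, h_reverse].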
  if (return_sequences && !bidirectional) {
    if (hidden_state.getDataType() == TensorDim::DataType::FP32) {
      std::copy(hidden_state.getData<float>(),
                hidden_state.getData<float>() + hidden_state.size(),
                output.getData<float>());
    } else if (hidden_state.getDataType() == TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
      std::copy(hidden_state.getData<_FP16>(),
                hidden_state.getData<_FP16>() + hidden_state.size(),
                output.getData<_FP16>());
#else
      throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
    }
  } else {
    unsigned int end_timestep = return_sequences ? max_timestep : 1;
    if (hidden_state.getDataType() == TensorDim::DataType::FP32) {
      for (unsigned int batch = 0; batch < batch_size; ++batch) {
        for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
          float *hidden_state_data = hidden_state.getAddress<float>(
            batch * max_timestep * unit +
            (return_sequences ? 0 : (max_timestep - 1) * unit) +
            timestep * unit);
          float *output_data = output.getAddress<float>(
            batch * (return_sequences ? max_timestep : 1) *
              bidirectional_constant * unit +
            timestep * bidirectional_constant * unit);
          std::copy(hidden_state_data, hidden_state_data + unit, output_data);

          if (bidirectional) {
            Tensor &reverse_hidden_state =
              context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
            float *reverse_hidden_state_data =
              reverse_hidden_state.getAddress<float>(
                batch * max_timestep * unit +
                (return_sequences ? 0 : (max_timestep - 1) * unit) +
                timestep * unit);
            std::copy(reverse_hidden_state_data,
                      reverse_hidden_state_data + unit, output_data + unit);
          }
        }
      }
    } else if (hidden_state.getDataType() == TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
      for (unsigned int batch = 0; batch < batch_size; ++batch) {
        for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
          _FP16 *hidden_state_data = hidden_state.getAddress<_FP16>(
            batch * max_timestep * unit +
            (return_sequences ? 0 : (max_timestep - 1) * unit) +
            timestep * unit);
          _FP16 *output_data = output.getAddress<_FP16>(
            batch * (return_sequences ? max_timestep : 1) *
              bidirectional_constant * unit +
            timestep * bidirectional_constant * unit);
          std::copy(hidden_state_data, hidden_state_data + unit, output_data);

          if (bidirectional) {
            Tensor &reverse_hidden_state =
              context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
            _FP16 *reverse_hidden_state_data =
              reverse_hidden_state.getAddress<_FP16>(
                batch * max_timestep * unit +
                (return_sequences ? 0 : (max_timestep - 1) * unit) +
                timestep * unit);
            std::copy(reverse_hidden_state_data,
                      reverse_hidden_state_data + unit, output_data + unit);
          }
        }
      }
#else
      throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
    }
  }
}

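/**
 * @brief The derivative toward the previous layer depends only on the gate
 * derivatives and weight_ih, so each direction is projected back through its
 * own weight_ih; the reverse direction accumulates into the same outgoing
 * derivative.
 */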
void LSTMLayer::calcDerivative(RunLayerContext &context) {
  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();

  Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
  const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]);
  const Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);

  calcDerivativeLSTM(outgoing_derivative, weight_ih, d_ifgos);

  if (bidirectional) {
    const Tensor &reverse_weight_ih =
      context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]);
    const Tensor &reverse_d_ifgos =
      context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]);

    calcDerivativeLSTM(outgoing_derivative, reverse_weight_ih, reverse_d_ifgos,
                       1.0f);
  }
}

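/**
 * @brief Compute the weight and bias gradients for the forward direction,
 * and, when bidirectional, repeat the same batch-first BPTT pass with the
 * reverse-direction weights and cached states.
 */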
void LSTMLayer::calcGradient(RunLayerContext &context) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(lstmcore_props).get();

  const bool return_sequences =
    std::get<props::ReturnSequences>(lstm_props).get();
  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
  const unsigned int max_timestep =
    std::get<props::MaxTimestep>(lstm_props).get();

  bool enable_dropout = dropout_rate > epsilon;

  const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
  const Tensor &incoming_derivative =
    context.getIncomingDerivative(SINGLE_INOUT_IDX);
  const TensorDim input_dim = input.getDim();
  const unsigned int batch_size = input_dim.batch();
  const unsigned int feature_size = input_dim.width();

  Tensor &d_weight_ih = context.getWeightGrad(wt_idx[LSTMParams::weight_ih]);
  const Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]);
  Tensor &d_weight_hh = context.getWeightGrad(wt_idx[LSTMParams::weight_hh]);

  Tensor empty =
    Tensor("empty", weight_hh.getFormat(), weight_hh.getDataType());

  Tensor &d_bias_h = !disable_bias && integrate_bias
                       ? context.getWeightGrad(wt_idx[LSTMParams::bias_h])
                       : empty;
  Tensor &d_bias_ih = !disable_bias && !integrate_bias
                        ? context.getWeightGrad(wt_idx[LSTMParams::bias_ih])
                        : empty;
  Tensor &d_bias_hh = !disable_bias && !integrate_bias
                        ? context.getWeightGrad(wt_idx[LSTMParams::bias_hh])
                        : empty;

  const Tensor &hidden_state =
    context.getTensor(wt_idx[LSTMParams::hidden_state]);
  Tensor &d_hidden_state =
    context.getTensorGrad(wt_idx[LSTMParams::hidden_state]);
  const Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]);
  Tensor &d_cell_state = context.getTensorGrad(wt_idx[LSTMParams::cell_state]);

  const Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]);
  Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);

  const Tensor &mask = enable_dropout
                         ? context.getTensor(wt_idx[LSTMParams::dropout_mask])
                         : empty;

  calcGradientBatchFirstLSTM(
    NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
    acti_func, recurrent_acti_func, return_sequences, bidirectional,
    enable_dropout, dropout_rate, max_timestep, false, input,
    incoming_derivative, d_weight_ih, weight_hh, d_weight_hh, d_bias_h,
    d_bias_ih, d_bias_hh, hidden_state, d_hidden_state, cell_state,
    d_cell_state, ifgo, d_ifgo, mask);

  if (bidirectional) {
    Tensor &reverse_d_weight_ih =
      context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_ih]);
    const Tensor &reverse_weight_hh =
      context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]);
    Tensor &reverse_d_weight_hh =
      context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_hh]);
    Tensor &reverse_d_bias_h =
      !disable_bias && integrate_bias
        ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_h])
        : empty;
    Tensor &reverse_d_bias_ih =
      !disable_bias && !integrate_bias
        ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_ih])
        : empty;
    Tensor &reverse_d_bias_hh =
      !disable_bias && !integrate_bias
        ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_hh])
        : empty;

    const Tensor &reverse_hidden_state =
      context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
    Tensor &reverse_d_hidden_state =
      context.getTensorGrad(wt_idx[LSTMParams::reverse_hidden_state]);
    const Tensor &reverse_cell_state =
      context.getTensor(wt_idx[LSTMParams::reverse_cell_state]);
    Tensor &reverse_d_cell_state =
      context.getTensorGrad(wt_idx[LSTMParams::reverse_cell_state]);

    const Tensor &reverse_ifgo =
      context.getTensor(wt_idx[LSTMParams::reverse_ifgo]);
    Tensor &reverse_d_ifgo =
      context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]);

    calcGradientBatchFirstLSTM(
      NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
      acti_func, recurrent_acti_func, return_sequences, bidirectional,
      enable_dropout, dropout_rate, max_timestep, true, input,
      incoming_derivative, reverse_d_weight_ih, reverse_weight_hh,
      reverse_d_weight_hh, reverse_d_bias_h, reverse_d_bias_ih,
      reverse_d_bias_hh, reverse_hidden_state, reverse_d_hidden_state,
      reverse_cell_state, reverse_d_cell_state, reverse_ifgo, reverse_d_ifgo,
      mask);
  }
}

void LSTMLayer::setBatch(RunLayerContext &context, unsigned int batch) {
  const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();

  context.updateTensor(wt_idx[LSTMParams::hidden_state], batch);
  context.updateTensor(wt_idx[LSTMParams::cell_state], batch);
  context.updateTensor(wt_idx[LSTMParams::ifgo], batch);

  if (bidirectional) {
    context.updateTensor(wt_idx[LSTMParams::reverse_hidden_state], batch);
    context.updateTensor(wt_idx[LSTMParams::reverse_cell_state], batch);
    context.updateTensor(wt_idx[LSTMParams::reverse_ifgo], batch);
  }

  if (dropout_rate > epsilon) {
    context.updateTensor(wt_idx[LSTMParams::dropout_mask], batch);
  }
}

} // namespace nntrainer
|