Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
4 : *
5 : * @file zoneout_lstmcell.cpp
6 : * @date 30 November 2021
7 : * @brief This is ZoneoutLSTMCell Layer Class of Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * https://arxiv.org/pdf/1606.01305.pdf
10 : * https://github.com/teganmaharaj/zoneout
11 : * @author hyeonseok lee <hs89.lee@samsung.com>
12 : * @bug No known bugs except for NYI items
13 : *
14 : */
15 :
16 : #include <layer_context.h>
17 : #include <nntrainer_error.h>
18 : #include <nntrainer_log.h>
19 : #include <node_exporter.h>
20 : #include <zoneout_lstmcell.h>
21 :
22 : namespace nntrainer {
23 :
/**
 * @brief Slot indices into wt_idx for every weight / tensor requested in
 * ZoneoutLSTMCellLayer::finalize(). Values are spelled out because they are
 * used directly as array subscripts.
 */
enum ZoneoutLSTMParams {
  weight_ih = 0,                 /**< input-to-hidden weight */
  weight_hh = 1,                 /**< hidden-to-hidden weight */
  bias_h = 2,                    /**< single bias (integrate_bias == true) */
  bias_ih = 3,                   /**< input bias (integrate_bias == false) */
  bias_hh = 4,                   /**< hidden bias (integrate_bias == false) */
  ifgo = 5,                      /**< per-gate (i, f, g, o) buffer */
  hidden_state_zoneout_mask = 6, /**< zoneout mask for the hidden state */
  cell_state_zoneout_mask = 7,   /**< zoneout mask for the cell state */
  lstm_cell_state = 8,           /**< LSTM cell state before zoneout mixing */
};
35 :
36 270 : ZoneoutLSTMCellLayer::ZoneoutLSTMCellLayer() :
37 270 : zoneout_lstmcell_props(HiddenStateZoneOutRate(), CellStateZoneOutRate(),
38 540 : Test(), props::MaxTimestep(), props::Timestep()) {
39 : wt_idx.fill(std::numeric_limits<unsigned>::max());
40 270 : }
41 :
42 270 : bool ZoneoutLSTMCellLayer::HiddenStateZoneOutRate::isValid(
43 : const float &value) const {
44 270 : if (value < 0.0f || value > 1.0f) {
45 : return false;
46 : } else {
47 270 : return true;
48 : }
49 : }
50 :
51 270 : bool ZoneoutLSTMCellLayer::CellStateZoneOutRate::isValid(
52 : const float &value) const {
53 270 : if (value < 0.0f || value > 1.0f) {
54 : return false;
55 : } else {
56 270 : return true;
57 : }
58 : }
59 :
60 216 : void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
61 : const Initializer weight_initializer =
62 216 : std::get<props::WeightInitializer>(*layer_impl_props).get();
63 : const Initializer bias_initializer =
64 216 : std::get<props::BiasInitializer>(*layer_impl_props).get();
65 : const WeightRegularizer weight_regularizer =
66 216 : std::get<props::WeightRegularizer>(*layer_impl_props).get();
67 : const float weight_regularizer_constant =
68 216 : std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
69 : auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
70 : auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
71 : const bool disable_bias =
72 216 : std::get<props::DisableBias>(*layer_impl_props).get();
73 :
74 216 : NNTR_THROW_IF(std::get<props::Unit>(lstmcore_props).empty(),
75 : std::invalid_argument)
76 : << "unit property missing for zoneout_lstmcell layer";
77 216 : const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
78 : const bool integrate_bias =
79 216 : std::get<props::IntegrateBias>(lstmcore_props).get();
80 : const ActivationType hidden_state_activation_type =
81 216 : std::get<props::HiddenStateActivation>(lstmcore_props).get();
82 : const ActivationType recurrent_activation_type =
83 216 : std::get<props::RecurrentActivation>(lstmcore_props).get();
84 :
85 216 : const bool test = std::get<Test>(zoneout_lstmcell_props).get();
86 : const unsigned int max_timestep =
87 216 : std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
88 :
89 216 : NNTR_THROW_IF(context.getNumInputs() != 3, std::invalid_argument)
90 : << "Number of input is not 3. ZoneoutLSTMCellLayer should takes 3 inputs";
91 :
92 : // input_dim = [ batch_size, 1, 1, feature_size ]
93 : const TensorDim &input_dim = context.getInputDimensions()[INOUT_INDEX::INPUT];
94 216 : if (input_dim.channel() != 1 || input_dim.height() != 1)
95 : throw std::invalid_argument("Input must be single time dimension for "
96 : "ZoneoutLSTMCell (shape should be "
97 0 : "[batch_size, 1, 1, feature_size])");
98 : // input_hidden_state_dim = [ batch_size, 1, 1, unit ]
99 : const TensorDim &input_hidden_state_dim =
100 : context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
101 432 : if (input_hidden_state_dim.channel() != 1 ||
102 216 : input_hidden_state_dim.height() != 1) {
103 : throw std::invalid_argument(
104 : "Input hidden state's dimension should be"
105 0 : "[batch_size, 1, 1, unit] for zoneout LSTMcell");
106 : }
107 : // input_cell_state_dim = [ batch_size, 1, 1, unit ]
108 : const TensorDim &input_cell_state_dim =
109 : context.getInputDimensions()[INOUT_INDEX::INPUT_CELL_STATE];
110 432 : if (input_cell_state_dim.channel() != 1 ||
111 216 : input_cell_state_dim.height() != 1) {
112 : throw std::invalid_argument(
113 : "Input cell state's dimension should be"
114 0 : "[batch_size, 1, 1, unit] for zoneout LSTMcell");
115 : }
116 216 : const unsigned int batch_size = input_dim.batch();
117 216 : const unsigned int feature_size = input_dim.width();
118 :
119 : // output_hidden_state_dim = [ batch_size, 1, 1, unit ]
120 216 : const TensorDim output_hidden_state_dim = input_hidden_state_dim;
121 : // output_cell_state_dim = [ batch_size, 1, 1, unit ]
122 216 : const TensorDim output_cell_state_dim = input_cell_state_dim;
123 :
124 : std::vector<VarGradSpecV2> out_specs;
125 : /// note: those out spec can be forward func, but for the test, it is being
126 : /// kept to forward deriv lifespan
127 : out_specs.push_back(
128 216 : InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
129 : TensorLifespan::FORWARD_DERIV_LIFESPAN));
130 : ////////////////////////// TensorLifespan::FORWARD_FUNC_LIFESPAN));
131 : out_specs.push_back(
132 216 : InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
133 : TensorLifespan::FORWARD_DERIV_LIFESPAN));
134 : // cell_state_zoneout_rate > epsilon ? TensorLifespan::FORWARD_FUNC_LIFESPAN
135 : // :
136 : // TensorLifespan::FORWARD_GRAD_LIFESPAN));
137 216 : context.requestOutputs(std::move(out_specs));
138 :
139 : // weight_initializer can be set seperately.
140 : // weight_ih initializer, weight_hh initializer
141 : // kernel initializer & recurrent_initializer in
142 : // keras for now, it is set same way.
143 :
144 : // - weight_ih ( input to hidden )
145 : // : [ 1, 1, feature_size, NUM_GATE x unit ] ->
146 : // i, f, g, o
147 216 : TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
148 216 : wt_idx[ZoneoutLSTMParams::weight_ih] = context.requestWeight(
149 : weight_ih_dim, weight_initializer, weight_regularizer,
150 : weight_regularizer_constant, weight_decay, "weight_ih", true);
151 : // - weight_hh ( hidden to hidden )
152 : // : [ 1, 1, unit, NUM_GATE x unit ] -> i, f, g,
153 : // o
154 216 : TensorDim weight_hh_dim({unit, NUM_GATE * unit});
155 432 : wt_idx[ZoneoutLSTMParams::weight_hh] = context.requestWeight(
156 : weight_hh_dim, weight_initializer, weight_regularizer,
157 : weight_regularizer_constant, weight_decay, "weight_hh", true);
158 216 : if (!disable_bias) {
159 216 : if (integrate_bias) {
160 : // - bias_h ( input bias, hidden bias are
161 : // integrate to 1 bias )
162 : // : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
163 : // o
164 0 : TensorDim bias_h_dim({NUM_GATE * unit});
165 0 : wt_idx[ZoneoutLSTMParams::bias_h] = context.requestWeight(
166 : bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
167 : "bias_h", true);
168 : } else {
169 : // - bias_ih ( input bias )
170 : // : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
171 : // o
172 216 : TensorDim bias_ih_dim({NUM_GATE * unit});
173 216 : wt_idx[ZoneoutLSTMParams::bias_ih] = context.requestWeight(
174 : bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
175 : bias_decay, "bias_ih", true);
176 : // - bias_hh ( hidden bias )
177 : // : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
178 : // o
179 216 : TensorDim bias_hh_dim({NUM_GATE * unit});
180 432 : wt_idx[ZoneoutLSTMParams::bias_hh] = context.requestWeight(
181 : bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
182 : bias_decay, "bias_hh", true);
183 : }
184 : }
185 :
186 : /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit
187 : * ] */
188 216 : const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
189 216 : wt_idx[ZoneoutLSTMParams::ifgo] =
190 432 : context.requestTensor(ifgo_dim, "ifgo", Initializer::NONE, true,
191 : TensorLifespan::ITERATION_LIFESPAN);
192 :
193 : // hidden_state_zoneout_mask_dim = [ max_timestep
194 : // * batch_size, 1, 1, unit ]
195 216 : const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
196 216 : unit);
197 216 : if (test) {
198 216 : wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
199 432 : context.requestWeight(hidden_state_zoneout_mask_dim, Initializer::NONE,
200 : WeightRegularizer::NONE, 1.0f, 0.0f,
201 : "hidden_state_zoneout_mask", false);
202 : } else {
203 0 : wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
204 0 : context.requestTensor(hidden_state_zoneout_mask_dim,
205 : "hidden_state_zoneout_mask", Initializer::NONE,
206 : false, TensorLifespan::ITERATION_LIFESPAN, false);
207 : }
208 :
209 : // cell_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
210 : const TensorDim cell_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
211 216 : unit);
212 216 : if (test) {
213 216 : wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestWeight(
214 : cell_state_zoneout_mask_dim, Initializer::NONE, WeightRegularizer::NONE,
215 : 1.0f, 0.0f, "cell_state_zoneout_mask", false);
216 : } else {
217 0 : wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor(
218 : cell_state_zoneout_mask_dim, "cell_state_zoneout_mask", Initializer::NONE,
219 : false, TensorLifespan::ITERATION_LIFESPAN, false);
220 : }
221 :
222 : // lstm_cell_state_dim = [ batch_size, 1, 1, unit ]
223 216 : const TensorDim lstm_cell_state_dim(batch_size, 1, 1, unit);
224 216 : wt_idx[ZoneoutLSTMParams::lstm_cell_state] = context.requestTensor(
225 : lstm_cell_state_dim, "lstm_cell_state", Initializer::NONE, true,
226 : TensorLifespan::ITERATION_LIFESPAN);
227 :
228 216 : acti_func.setActiFunc(hidden_state_activation_type);
229 216 : recurrent_acti_func.setActiFunc(recurrent_activation_type);
230 216 : }
231 :
232 1512 : void ZoneoutLSTMCellLayer::setProperty(const std::vector<std::string> &values) {
233 : const std::vector<std::string> &remain_props =
234 1512 : loadProperties(values, zoneout_lstmcell_props);
235 1512 : LSTMCore::setProperty(remain_props);
236 1512 : }
237 :
238 216 : void ZoneoutLSTMCellLayer::exportTo(
239 : Exporter &exporter, const ml::train::ExportMethods &method) const {
240 216 : LSTMCore::exportTo(exporter, method);
241 216 : exporter.saveResult(zoneout_lstmcell_props, method, this);
242 216 : }
243 :
244 324 : void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
245 : const bool disable_bias =
246 324 : std::get<props::DisableBias>(*layer_impl_props).get();
247 :
248 324 : const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
249 : const bool integrate_bias =
250 324 : std::get<props::IntegrateBias>(lstmcore_props).get();
251 :
252 : const float hidden_state_zoneout_rate =
253 324 : std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props).get();
254 : const float cell_state_zoneout_rate =
255 324 : std::get<CellStateZoneOutRate>(zoneout_lstmcell_props).get();
256 324 : const bool test = std::get<Test>(zoneout_lstmcell_props).get();
257 : const unsigned int max_timestep =
258 324 : std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
259 : const unsigned int timestep =
260 324 : std::get<props::Timestep>(zoneout_lstmcell_props).get();
261 :
262 324 : const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
263 : const Tensor &prev_hidden_state =
264 324 : context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
265 : const Tensor &prev_cell_state =
266 324 : context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
267 324 : Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
268 324 : Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
269 :
270 324 : const unsigned int batch_size = input.getDim().batch();
271 :
272 : const Tensor &weight_ih =
273 324 : context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
274 : const Tensor &weight_hh =
275 324 : context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
276 324 : Tensor empty;
277 : const Tensor &bias_h =
278 324 : !disable_bias && integrate_bias
279 324 : ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_h])
280 : : empty;
281 : const Tensor &bias_ih =
282 : !disable_bias && !integrate_bias
283 324 : ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_ih])
284 : : empty;
285 : const Tensor &bias_hh =
286 : !disable_bias && !integrate_bias
287 324 : ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_hh])
288 : : empty;
289 :
290 324 : Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
291 : Tensor &lstm_cell_state =
292 324 : context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
293 :
294 324 : forwardLSTM(batch_size, unit, disable_bias, integrate_bias, acti_func,
295 324 : recurrent_acti_func, input, prev_hidden_state, prev_cell_state,
296 : hidden_state, lstm_cell_state, weight_ih, weight_hh, bias_h,
297 : bias_ih, bias_hh, ifgo);
298 :
299 324 : if (training) {
300 : Tensor &hs_zoneout_mask =
301 162 : test ? context.getWeight(
302 : wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
303 0 : : context.getTensor(
304 : wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
305 162 : hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
306 : Tensor hidden_state_zoneout_mask =
307 162 : hs_zoneout_mask.getBatchSlice(timestep, 1);
308 162 : hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
309 162 : Tensor prev_hidden_state_zoneout_mask;
310 162 : if (!test) {
311 : prev_hidden_state_zoneout_mask =
312 0 : hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
313 : } else {
314 162 : hidden_state_zoneout_mask.multiply(-1.0f, prev_hidden_state_zoneout_mask);
315 162 : prev_hidden_state_zoneout_mask.add_i(1.0f);
316 : }
317 :
318 162 : hidden_state.multiply_i(hidden_state_zoneout_mask);
319 162 : prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, hidden_state,
320 : 1.0f);
321 : Tensor &cs_zoneout_mask =
322 : test
323 162 : ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
324 0 : : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
325 162 : cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
326 162 : Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1);
327 162 : cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
328 162 : Tensor prev_cell_state_zoneout_mask;
329 162 : if (!test) {
330 : prev_cell_state_zoneout_mask =
331 0 : cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
332 : } else {
333 162 : cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
334 162 : prev_cell_state_zoneout_mask.add_i(1.0f);
335 : }
336 :
337 162 : lstm_cell_state.multiply(cell_state_zoneout_mask, cell_state);
338 162 : prev_cell_state.multiply(prev_cell_state_zoneout_mask, cell_state, 1.0f);
339 162 : }
340 : // Todo: zoneout at inference
341 324 : }
342 :
343 162 : void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
344 : Tensor &outgoing_derivative =
345 162 : context.getOutgoingDerivative(INOUT_INDEX::INPUT);
346 : const Tensor &weight_ih =
347 162 : context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
348 162 : const Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
349 :
350 162 : calcDerivativeLSTM(outgoing_derivative, weight_ih, d_ifgo);
351 162 : }
352 :
353 162 : void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
354 : const bool disable_bias =
355 162 : std::get<props::DisableBias>(*layer_impl_props).get();
356 :
357 162 : const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
358 : const bool integrate_bias =
359 162 : std::get<props::IntegrateBias>(lstmcore_props).get();
360 :
361 162 : const bool test = std::get<Test>(zoneout_lstmcell_props).get();
362 : const unsigned int max_timestep =
363 162 : std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
364 : const unsigned int timestep =
365 162 : std::get<props::Timestep>(zoneout_lstmcell_props).get();
366 :
367 162 : const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
368 : const Tensor &prev_hidden_state =
369 162 : context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
370 : Tensor &d_prev_hidden_state =
371 162 : context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
372 : const Tensor &prev_cell_state =
373 162 : context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
374 : Tensor &d_prev_cell_state =
375 162 : context.getOutgoingDerivative(INOUT_INDEX::INPUT_CELL_STATE);
376 : const Tensor &d_hidden_state =
377 162 : context.getIncomingDerivative(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
378 : const Tensor &d_cell_state =
379 162 : context.getIncomingDerivative(INOUT_INDEX::OUTPUT_CELL_STATE);
380 :
381 162 : unsigned int batch_size = input.getDim().batch();
382 :
383 : Tensor &d_weight_ih =
384 162 : context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]);
385 : const Tensor &weight_hh =
386 162 : context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
387 : Tensor &d_weight_hh =
388 162 : context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]);
389 162 : Tensor empty;
390 : Tensor &d_bias_h =
391 162 : !disable_bias && integrate_bias
392 162 : ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_h])
393 : : empty;
394 : Tensor &d_bias_ih =
395 : !disable_bias && !integrate_bias
396 162 : ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih])
397 : : empty;
398 : Tensor &d_bias_hh =
399 : !disable_bias && !integrate_bias
400 162 : ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh])
401 : : empty;
402 :
403 162 : Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
404 162 : Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
405 :
406 : const Tensor &lstm_cell_state =
407 162 : context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
408 : Tensor &d_lstm_cell_state =
409 162 : context.getTensorGrad(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
410 :
411 162 : if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::weight_ih])) {
412 81 : d_weight_ih.setZero();
413 : }
414 162 : if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::weight_hh])) {
415 81 : d_weight_hh.setZero();
416 : }
417 162 : if (!disable_bias) {
418 162 : if (integrate_bias) {
419 0 : if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_h])) {
420 0 : d_bias_h.setZero();
421 : }
422 : } else {
423 162 : if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_ih])) {
424 81 : d_bias_ih.setZero();
425 : }
426 162 : if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_hh])) {
427 81 : d_bias_hh.setZero();
428 : }
429 : }
430 : }
431 :
432 162 : Tensor d_prev_hidden_state_residual;
433 162 : Tensor d_hidden_state_masked;
434 : Tensor &hs_zoneout_mask =
435 : test
436 162 : ? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
437 0 : : context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
438 162 : hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
439 162 : Tensor hidden_state_zoneout_mask = hs_zoneout_mask.getBatchSlice(timestep, 1);
440 162 : hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
441 : Tensor prev_hidden_state_zoneout_mask = hidden_state_zoneout_mask.apply(
442 162 : (std::function<float(float)>)[epsilon = epsilon](float x) {
443 324 : return x < epsilon;
444 162 : });
445 :
446 162 : d_hidden_state.multiply(prev_hidden_state_zoneout_mask,
447 : d_prev_hidden_state_residual);
448 162 : d_hidden_state.multiply(hidden_state_zoneout_mask, d_hidden_state_masked);
449 :
450 162 : Tensor d_prev_cell_state_residual;
451 : Tensor &cs_zoneout_mask =
452 : test
453 162 : ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
454 0 : : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
455 162 : cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
456 162 : Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1);
457 162 : cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
458 : Tensor prev_cell_state_zoneout_mask = cell_state_zoneout_mask.apply(
459 162 : (std::function<float(float)>)[epsilon = epsilon](float x) {
460 324 : return x < epsilon;
461 162 : });
462 :
463 162 : d_cell_state.multiply(prev_cell_state_zoneout_mask,
464 : d_prev_cell_state_residual);
465 162 : d_cell_state.multiply(cell_state_zoneout_mask, d_lstm_cell_state);
466 :
467 162 : calcGradientLSTM(batch_size, unit, disable_bias, integrate_bias, acti_func,
468 162 : recurrent_acti_func, input, prev_hidden_state,
469 : d_prev_hidden_state, prev_cell_state, d_prev_cell_state,
470 : d_hidden_state_masked, lstm_cell_state, d_lstm_cell_state,
471 : d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
472 : d_bias_hh, ifgo, d_ifgo);
473 :
474 162 : d_prev_hidden_state.add_i(d_prev_hidden_state_residual);
475 162 : d_prev_cell_state.add_i(d_prev_cell_state_residual);
476 162 : }
477 :
478 0 : void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
479 : unsigned int batch) {
480 : const unsigned int max_timestep =
481 0 : std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
482 :
483 0 : context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], batch);
484 :
485 0 : context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
486 : max_timestep * batch);
487 0 : context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask],
488 : max_timestep * batch);
489 0 : context.updateTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state], batch);
490 0 : }
491 :
492 : } // namespace nntrainer
|