Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
4 : *
5 : * @file grucell.cpp
6 : * @date 28 Oct 2021
7 : * @brief This is Gated Recurrent Unit Cell Layer Class of Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author hyeonseok lee <hs89.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : * h_prev --------d1------->[*]-------d0----->[+]---d0--> h
13 : * d_h_prev | | | | d0 dh
14 : * | d14 | d2 d3 |
15 : * | | +-----[1-]------>[*]
16 : * | [*]<---+ d15 |d5 | d6
17 : * | | |reset_g| update_gate | memory_cell
18 : * | | [sig] [sig] [tanh]
19 : * | | |d16 | d7 |d8
20 : * | | [+] [+] [+]
21 : * | | / \d16 | \ d7 / \ d8
22 : * | | Whhr Wxhr Whhz Wxhz Whhg Wxhg
23 : * | | |d17 |d13 |d12 |d11 |d10 | d9
24 : * +- |--+------|---+ | | |
25 : * +---------|--------|----------+ |
26 : * xs------------------+--------+---------------+
27 : */
28 :
29 : #include <cmath>
30 :
31 : #include <grucell.h>
32 : #include <lazy_tensor.h>
33 : #include <nntrainer_error.h>
34 : #include <nntrainer_log.h>
35 : #include <node_exporter.h>
36 : #include <util_func.h>
37 :
38 : #include <layer_context.h>
39 :
40 : namespace nntrainer {
41 :
/**
 * @brief Forward one GRU cell step.
 *
 * Computes zrg = input . weight_ih, adds the recurrent contribution and the
 * configured biases, applies the gate activations, and blends the previous
 * hidden state with the candidate (memory cell) into hidden_state.
 *
 * zrg layout along the last axis: [ update(z) | reset(r) | candidate(g) ],
 * each slice of width `unit` (offsets 0, unit, 2 * unit below).
 *
 * @param unit                number of hidden units
 * @param batch_size          batch size of the input
 * @param disable_bias        if true, no bias term is added at all
 * @param integrate_bias      if true, the single bias_h is used; otherwise
 *                            bias_ih and bias_hh are used
 * @param reset_after         if true, the reset gate is applied after the
 *                            recurrent matmul; otherwise before it
 * @param acti_func           activation for the candidate (memory cell)
 * @param recurrent_acti_func activation for the update/reset gates
 * @param input               input tensor (x)
 * @param prev_hidden_state   previous hidden state (h_prev)
 * @param hidden_state        [out] new hidden state (h)
 * @param weight_ih           input-to-hidden weights, width NUM_GATE * unit
 * @param weight_hh           hidden-to-hidden weights, width NUM_GATE * unit
 * @param bias_h              integrated bias (read only when integrate_bias)
 * @param bias_ih             input bias (read only when !integrate_bias)
 * @param bias_hh             hidden bias (read only when !integrate_bias)
 * @param zrg                 [out] gate pre/post-activations,
 *                            [ batch_size, 1, 1, NUM_GATE * unit ]
 */
static void grucell_forwarding(
  const unsigned int unit, const unsigned int batch_size,
  const bool disable_bias, const bool integrate_bias, const bool reset_after,
  ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
  const Tensor &prev_hidden_state, Tensor &hidden_state,
  const Tensor &weight_ih, const Tensor &weight_hh, const Tensor &bias_h,
  const Tensor &bias_ih, const Tensor &bias_hh, Tensor &zrg) {
  // zrg <- x . W_ih : pre-activations for all three gates at once
  input.dot(weight_ih, zrg);

  // Views into zrg: [z|r] occupies the first 2 * unit, g the last unit.
  // These alias zrg's memory; writes below update zrg in place.
  Tensor update_reset_gate =
    zrg.getSharedDataTensor({batch_size, 1, 1, 2 * unit}, 0, false);
  Tensor memory_cell =
    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);

  // Materialize contiguous copies of the W_hh sub-blocks (the shared views
  // are strided, so they are copied before use in dot()).
  Tensor weight_hh_update_reset_gate;
  Tensor weight_hh_memory_cell;
  weight_hh_update_reset_gate.copy_with_stride(
    weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false));
  weight_hh_memory_cell.copy_with_stride(
    weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false));

  // [z|r] += h_prev . W_hh[z|r]
  update_reset_gate.add_i_strided(
    prev_hidden_state.dot(weight_hh_update_reset_gate));
  if (!disable_bias) {
    if (integrate_bias) {
      const Tensor bias_h_update_reset_gate =
        bias_h.getSharedDataTensor({2 * unit}, 0);
      update_reset_gate.add_i(bias_h_update_reset_gate);
    } else {
      // Separate input/hidden biases are both added for the z/r gates.
      const Tensor bias_ih_update_reset_gate =
        bias_ih.getSharedDataTensor({2 * unit}, 0);
      update_reset_gate.add_i(bias_ih_update_reset_gate);
      const Tensor bias_hh_update_reset_gate =
        bias_hh.getSharedDataTensor({2 * unit}, 0);
      update_reset_gate.add_i(bias_hh_update_reset_gate);
    }
  }

  // Gate activation (e.g. sigmoid) applied in place on [z|r].
  recurrent_acti_func.run_fn(update_reset_gate, update_reset_gate);

  Tensor update_gate =
    update_reset_gate.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
  Tensor reset_gate = update_reset_gate.getSharedDataTensor(
    {batch_size, 1, 1, unit}, unit, false);

  Tensor temp;
  if (reset_after) {
    // reset_after: g += r * (h_prev . W_hh[g] + b_hh[g])
    // i.e. the reset gate multiplies the already-biased recurrent term.
    prev_hidden_state.dot(weight_hh_memory_cell, temp);
    if (!disable_bias && !integrate_bias) {
      const Tensor bias_hh_memory_cell =
        bias_hh.getSharedDataTensor({unit}, 2 * unit);
      temp.add_i(bias_hh_memory_cell);
    }
    temp.multiply_i_strided(reset_gate);
    memory_cell.add_i_strided(temp);
  } else {
    // reset before: g += (r * h_prev) . W_hh[g] (+ b_hh[g])
    reset_gate.multiply_strided(prev_hidden_state, temp);
    memory_cell.add_i_strided(temp.dot(weight_hh_memory_cell));
    if (!disable_bias && !integrate_bias) {
      const Tensor bias_hh_memory_cell =
        bias_hh.getSharedDataTensor({unit}, 2 * unit);
      memory_cell.add_i(bias_hh_memory_cell);
    }
  }
  if (!disable_bias) {
    if (integrate_bias) {
      const Tensor bias_h_memory_cell =
        bias_h.getSharedDataTensor({unit}, 2 * unit);
      memory_cell.add_i(bias_h_memory_cell);
    } else {
      const Tensor bias_ih_memory_cell =
        bias_ih.getSharedDataTensor({unit}, 2 * unit);
      memory_cell.add_i(bias_ih_memory_cell);
    }
  }

  // Candidate activation (e.g. tanh) applied in place on g.
  acti_func.run_fn(memory_cell, memory_cell);

  // h = z * h_prev + (1 - z) * g
  // (second multiply_strided uses beta = 1.0f to accumulate onto the
  // z * h_prev term already stored in hidden_state)
  update_gate.multiply_strided(prev_hidden_state, hidden_state);
  temp = update_gate.multiply(-1.0).add(1.0);
  memory_cell.multiply_strided(temp, hidden_state, 1.0f);
}
128 :
/**
 * @brief Backward pass of one GRU cell step (gradient computation).
 *
 * Consumes the saved forward gate outputs (zrg) and the incoming hidden
 * state derivative, and accumulates gradients for the weights, the biases,
 * and the previous hidden state. The dN labels in the comments refer to the
 * dataflow diagram in this file's header comment.
 *
 * Note: weight/bias gradient tensors are ACCUMULATED into (sum/add with
 * beta 1.0), so callers are expected to zero them on first access.
 *
 * @param unit                number of hidden units
 * @param batch_size          batch size of the input
 * @param disable_bias        if true, bias gradients are not computed
 * @param integrate_bias      selects d_bias_h vs d_bias_ih/d_bias_hh
 * @param reset_after         must match the value used in forwarding
 * @param acti_func           candidate activation (for its derivative)
 * @param recurrent_acti_func gate activation (for its derivative)
 * @param input               forward input (x)
 * @param prev_hidden_state   forward previous hidden state (h_prev)
 * @param d_prev_hidden_state [out] gradient w.r.t. h_prev
 * @param d_hidden_state      incoming gradient w.r.t. h
 * @param d_weight_ih         [out, accumulated] gradient of weight_ih
 * @param weight_hh           hidden-to-hidden weights
 * @param d_weight_hh         [out, accumulated] gradient of weight_hh
 * @param d_bias_h            [out, accumulated] gradient of bias_h
 * @param d_bias_ih           [out, accumulated] gradient of bias_ih
 * @param bias_hh             hidden bias (read only when !integrate_bias)
 * @param d_bias_hh           [out, accumulated] gradient of bias_hh
 * @param zrg                 gate outputs saved by forwarding
 * @param d_zrg               [out] gradient w.r.t. the gate pre-activations
 */
static void grucell_calcGradient(
  const unsigned int unit, const unsigned int batch_size,
  const bool disable_bias, const bool integrate_bias, const bool reset_after,
  ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
  const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
  const Tensor &d_hidden_state, Tensor &d_weight_ih, const Tensor &weight_hh,
  Tensor &d_weight_hh, Tensor &d_bias_h, Tensor &d_bias_ih,
  const Tensor &bias_hh, Tensor &d_bias_hh, const Tensor &zrg, Tensor &d_zrg) {
  // Aliased views into the packed gradient/weight tensors; writes through
  // these views update the underlying tensors in place.
  Tensor d_weight_hh_update_reset_gate =
    d_weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false);
  Tensor d_weight_hh_memory_cell =
    d_weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false);

  // Forward gate outputs: z, r, g slices of zrg.
  Tensor update_gate =
    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
  Tensor reset_gate =
    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
  Tensor memory_cell =
    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);

  // Matching gradient slices of d_zrg.
  Tensor d_update_gate =
    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
  Tensor d_reset_gate =
    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
  Tensor d_memory_cell =
    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);

  // Backprop through h = z * h_prev + (1 - z) * g.
  d_hidden_state.multiply_strided(
    update_gate, d_prev_hidden_state); // d_prev_hidden_state = d1
  d_hidden_state.multiply_strided(prev_hidden_state,
                                  d_update_gate); // d_update_gate = d2
  d_update_gate.add_i_strided(d_hidden_state.multiply_strided(memory_cell),
                              -1.0f); // d_update_gate = d5
  update_gate.multiply(-1.0, d_memory_cell);
  d_memory_cell.add_i(1.0); // d_memory_cell = 1 - z
  d_memory_cell.multiply_i_strided(d_hidden_state); // d_memory_cell = d6

  // Backprop through the gate/candidate activations.
  recurrent_acti_func.run_prime_fn(update_gate, d_update_gate,
                                   d_update_gate); // d_update_gate = d7
  acti_func.run_prime_fn(memory_cell, d_memory_cell,
                         d_memory_cell); // d_memory_cell = d8

  Tensor d_update_reset_gate = d_zrg.getSharedDataTensor(
    {batch_size, 1, 1, 2 * unit}, 0, false); // d_update_gate+d_reset_gate

  // Contiguous copies of the strided W_hh sub-blocks, needed for dot().
  Tensor weight_hh_memory_cell;
  weight_hh_memory_cell.copy_with_stride(
    weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false));
  Tensor weight_hh_update_reset_gate;
  weight_hh_update_reset_gate.copy_with_stride(
    weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false));

  Tensor temp = Tensor(batch_size, 1, 1, unit);
  Tensor d_memory_cell_contiguous;
  d_memory_cell_contiguous.copy_with_stride(d_memory_cell);

  if (reset_after) {
    // Recompute the recurrent candidate term (h_prev . W_hh[g] + b_hh[g])
    // that the reset gate multiplied during forwarding.
    prev_hidden_state.dot(weight_hh_memory_cell, temp);
    if (!disable_bias && !integrate_bias) {
      const Tensor bias_hh_memory_cell =
        bias_hh.getSharedDataTensor({unit}, 2 * unit);
      temp.add_i(bias_hh_memory_cell);
    }
    d_memory_cell_contiguous.multiply_strided(
      temp, d_reset_gate); // d_reset_gate = d15

    // reset temp: d_memory_cell_contiguous * reset_gate for
    // d_bias_hh_memory_cell, d_prev_hidden_state and d_weight_hh_memory_cell
    d_memory_cell_contiguous.multiply_strided(reset_gate, temp);
    if (!disable_bias && !integrate_bias) {
      Tensor d_bias_hh_memory_cell =
        d_bias_hh.getSharedDataTensor({unit}, 2 * unit);
      temp.sum(0, d_bias_hh_memory_cell, 1.0, 1.0); // accumulate over batch
    }
    temp.dot(weight_hh_memory_cell, d_prev_hidden_state, false, true,
             1.0); // d_prev_hidden_state = d1 + d14
    d_weight_hh_memory_cell.add_i_strided(
      prev_hidden_state.dot(temp, true, false));
  } else {
    // reset-before variant: forward used (r * h_prev) . W_hh[g].
    if (!disable_bias && !integrate_bias) {
      Tensor d_bias_hh_memory_cell =
        d_bias_hh.getSharedDataTensor({unit}, 2 * unit);
      d_memory_cell.sum(0, d_bias_hh_memory_cell, 1.0, 1.0);
    }

    d_memory_cell_contiguous.dot(weight_hh_memory_cell, temp, false, true);
    temp.multiply_strided(prev_hidden_state, d_reset_gate);
    temp.multiply_strided(reset_gate, d_prev_hidden_state, 1.0f);

    // reset temp: reset_gate * prev_hidden_state for
    // d_weight_hh_memory_cell
    reset_gate.multiply_strided(prev_hidden_state, temp);
    d_weight_hh_memory_cell.add_i_strided(
      temp.dot(d_memory_cell_contiguous, true, false));
  }

  recurrent_acti_func.run_prime_fn(reset_gate, d_reset_gate,
                                   d_reset_gate); // d_reset_gate = d16

  // Bias gradients: sum d_zrg over the batch axis into the proper bias grad.
  if (!disable_bias) {
    if (integrate_bias) {
      d_zrg.sum(0, d_bias_h, 1.0, 1.0);
    } else {
      d_zrg.sum(0, d_bias_ih, 1.0, 1.0);
      Tensor d_bias_hh_update_reset_gate =
        d_bias_hh.getSharedDataTensor({2 * unit}, 0);
      d_bias_hh_update_reset_gate.add_i(
        d_zrg.sum(0).getSharedDataTensor({2 * unit}, 0));
    }
  }

  // Weight gradients and the remaining h_prev gradient contributions.
  Tensor d_update_reset_gate_contiguous;
  d_update_reset_gate_contiguous.copy_with_stride(d_update_reset_gate);
  d_weight_hh_update_reset_gate.add_i_strided(
    prev_hidden_state.dot(d_update_reset_gate_contiguous, true, false));
  input.dot(d_zrg, d_weight_ih, true, false, 1.0f);
  d_update_reset_gate_contiguous.dot(
    weight_hh_update_reset_gate, d_prev_hidden_state, false, true,
    1.0); // d_prev_hidden_state = d1 + d14 + d12 + d17
}
253 :
/**
 * @brief Indices into wt_idx for the weights/tensors this layer requests
 *        from the layer context. Order must stay fixed: the enumerator
 *        values are used as array indices.
 */
enum GRUCellParams {
  weight_ih,   /**< input-to-hidden weight */
  weight_hh,   /**< hidden-to-hidden weight */
  bias_h,      /**< integrated bias (used when integrate_bias is true) */
  bias_ih,     /**< input bias (used when integrate_bias is false) */
  bias_hh,     /**< hidden bias (used when integrate_bias is false) */
  zrg,         /**< saved gate outputs: update(z), reset(r), candidate(g) */
  dropout_mask /**< dropout mask tensor (requested only when dropout > 0) */
};
263 :
264 : // Todo: handle with strided tensor more efficiently and reduce temporary
265 : // tensors
/**
 * @brief Default constructor. Sets the GRUCell property defaults
 *        (tanh hidden-state activation, sigmoid recurrent activation) and
 *        marks all tensor indices as unrequested. The real activation
 *        functions are bound later in finalize().
 */
GRUCellLayer::GRUCellLayer() :
  LayerImpl(),
  grucell_props(props::Unit(), props::IntegrateBias(), props::ResetAfter(),
                props::HiddenStateActivation() = ActivationType::ACT_TANH,
                props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
                props::DropOutRate()),
  acti_func(ActivationType::ACT_NONE, true),
  recurrent_acti_func(ActivationType::ACT_NONE, true),
  // epsilon is the threshold used to decide whether dropout is enabled
  // (comparisons `dropout_rate > epsilon` below).
  epsilon(1e-3f) {
  // Sentinel: an index of unsigned max means "not requested".
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}
277 :
278 39 : void GRUCellLayer::finalize(InitLayerContext &context) {
279 : const Initializer weight_initializer =
280 39 : std::get<props::WeightInitializer>(*layer_impl_props).get();
281 : const Initializer bias_initializer =
282 39 : std::get<props::BiasInitializer>(*layer_impl_props).get();
283 : const WeightRegularizer weight_regularizer =
284 39 : std::get<props::WeightRegularizer>(*layer_impl_props).get();
285 : const float weight_regularizer_constant =
286 39 : std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
287 : auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
288 : auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
289 : const bool disable_bias =
290 39 : std::get<props::DisableBias>(*layer_impl_props).get();
291 :
292 39 : const unsigned int unit = std::get<props::Unit>(grucell_props).get();
293 : const bool integrate_bias =
294 39 : std::get<props::IntegrateBias>(grucell_props).get();
295 : const ActivationType hidden_state_activation_type =
296 39 : std::get<props::HiddenStateActivation>(grucell_props).get();
297 : const ActivationType recurrent_activation_type =
298 39 : std::get<props::RecurrentActivation>(grucell_props).get();
299 39 : const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
300 :
301 39 : NNTR_THROW_IF(context.getNumInputs() != 2, std::invalid_argument)
302 : << "GRUCell layer expects 2 inputs(one for the input and hidden state for "
303 0 : "the other) but got " +
304 0 : std::to_string(context.getNumInputs()) + " input(s)";
305 :
306 : // input_dim = [ batch_size, 1, 1, feature_size ]
307 : const TensorDim &input_dim = context.getInputDimensions()[0];
308 39 : NNTR_THROW_IF(input_dim.channel() != 1 && input_dim.height() != 1,
309 : std::invalid_argument)
310 : << "Input must be single time dimension for GRUCell(shape should be "
311 : "[batch_size, 1, 1, feature_size]";
312 : // input_hidden_state_dim = [ batch_size, 1, 1, unit ]
313 : const TensorDim &input_hidden_state_dim =
314 : context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
315 39 : NNTR_THROW_IF(input_hidden_state_dim.channel() != 1 ||
316 : input_hidden_state_dim.height() != 1,
317 : std::invalid_argument)
318 : << "Input hidden state's dimension should be [batch, 1, 1, unit] for "
319 : "GRUCell";
320 :
321 39 : const unsigned int batch_size = input_dim.batch();
322 39 : const unsigned int feature_size = input_dim.width();
323 :
324 : // output_dim = [ batch_size, 1, 1, unit ]
325 39 : TensorDim output_dim(batch_size, 1, 1, unit);
326 39 : context.setOutputDimensions({output_dim});
327 :
328 : // weight_initializer can be set seperately. weight_ih initializer,
329 : // weight_hh initializer kernel initializer & recurrent_initializer in keras
330 : // for now, it is set same way.
331 :
332 : // - weight_ih ( input to hidden )
333 : // weight_ih_dim : [ 1, 1, feature_size, NUMGATE * unit ] -> z, r, g
334 39 : TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
335 39 : wt_idx[GRUCellParams::weight_ih] = context.requestWeight(
336 : weight_ih_dim, weight_initializer, weight_regularizer,
337 : weight_regularizer_constant, weight_decay, "weight_ih", true);
338 : // - weight_hh ( hidden to hidden )
339 : // weight_hh_dim : [ 1, 1, unit, NUM_GATE * unit ] -> z, r, g
340 39 : TensorDim weight_hh_dim({unit, NUM_GATE * unit});
341 78 : wt_idx[GRUCellParams::weight_hh] = context.requestWeight(
342 : weight_hh_dim, weight_initializer, weight_regularizer,
343 : weight_regularizer_constant, weight_decay, "weight_hh", true);
344 39 : if (!disable_bias) {
345 39 : if (integrate_bias) {
346 : // - bias_h ( input bias, hidden bias are integrate to 1 bias )
347 : // bias_h_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
348 2 : TensorDim bias_h_dim({NUM_GATE * unit});
349 2 : wt_idx[GRUCellParams::bias_h] = context.requestWeight(
350 : bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
351 : "bias_h", true);
352 : } else {
353 : // - bias_ih ( input bias )
354 : // bias_ih_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
355 37 : TensorDim bias_ih_dim({NUM_GATE * unit});
356 37 : wt_idx[GRUCellParams::bias_ih] = context.requestWeight(
357 : bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
358 : bias_decay, "bias_ih", true);
359 : // - bias_hh ( hidden bias )
360 : // bias_hh_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
361 37 : TensorDim bias_hh_dim({NUM_GATE * unit});
362 74 : wt_idx[GRUCellParams::bias_hh] = context.requestWeight(
363 : bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
364 : bias_decay, "bias_hh", true);
365 : }
366 : }
367 :
368 : // zrg_dim = [ batch_size, 1, 1, NUM_GATE * unit ]
369 39 : TensorDim zrg_dim(batch_size, 1, 1, NUM_GATE * unit);
370 39 : wt_idx[GRUCellParams::zrg] =
371 39 : context.requestTensor(zrg_dim, "zrg", Initializer::NONE, true,
372 : TensorLifespan::ITERATION_LIFESPAN);
373 :
374 39 : if (dropout_rate > epsilon) {
375 : // dropout_mask_dim = [ batch_size, 1, 1, unit ]
376 0 : TensorDim dropout_mask_dim(batch_size, 1, 1, unit);
377 0 : wt_idx[GRUCellParams::dropout_mask] =
378 0 : context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
379 : false, TensorLifespan::ITERATION_LIFESPAN);
380 : }
381 :
382 39 : acti_func.setActiFunc(hidden_state_activation_type);
383 39 : recurrent_acti_func.setActiFunc(recurrent_activation_type);
384 39 : }
385 :
386 244 : void GRUCellLayer::setProperty(const std::vector<std::string> &values) {
387 244 : auto remain_props = loadProperties(values, grucell_props);
388 243 : LayerImpl::setProperty(remain_props);
389 243 : }
390 :
/**
 * @brief Export this layer: base LayerImpl properties are written first,
 *        then the GRUCell-specific properties.
 */
void GRUCellLayer::exportTo(Exporter &exporter,
                            const ml::train::ExportMethods &method) const {
  LayerImpl::exportTo(exporter, method);
  exporter.saveResult(grucell_props, method, this);
}
396 :
/**
 * @brief Forward step: gathers the inputs/weights from the context, runs
 *        grucell_forwarding into the output hidden state, and applies
 *        dropout (training only) when configured.
 *
 * @param context  run layer context providing inputs/weights/tensors
 * @param training true during training; dropout is applied only then
 */
void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(grucell_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(grucell_props).get();
  const bool reset_after = std::get<props::ResetAfter>(grucell_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();

  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
  const Tensor &prev_hidden_state =
    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
  // hidden_state == output in grucell
  Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT);

  const unsigned int batch_size = input.getDim().batch();

  // Biases are only requested when enabled (see finalize()); the unused
  // ones are bound to `empty` so grucell_forwarding never reads them.
  const Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
  const Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
  Tensor empty;
  const Tensor &bias_h = !disable_bias && integrate_bias
                           ? context.getWeight(wt_idx[GRUCellParams::bias_h])
                           : empty;
  const Tensor &bias_ih = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[GRUCellParams::bias_ih])
                            : empty;
  const Tensor &bias_hh = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
                            : empty;

  // zrg keeps the gate outputs so calcGradient can reuse them.
  Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);

  grucell_forwarding(unit, batch_size, disable_bias, integrate_bias,
                     reset_after, acti_func, recurrent_acti_func, input,
                     prev_hidden_state, hidden_state, weight_ih, weight_hh,
                     bias_h, bias_ih, bias_hh, zrg);

  // Apply dropout on the new hidden state during training only; the mask is
  // stored so calcGradient can mask the incoming derivative consistently.
  if (dropout_rate > epsilon && training) {
    Tensor mask = context.getTensor(wt_idx[GRUCellParams::dropout_mask]);
    mask.dropout_mask(dropout_rate);
    hidden_state.multiply_i(mask);
  }
}
441 :
442 27 : void GRUCellLayer::calcDerivative(RunLayerContext &context) {
443 : Tensor &outgoing_derivative =
444 27 : context.getOutgoingDerivative(INOUT_INDEX::INPUT);
445 27 : const Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
446 27 : const Tensor &d_zrg = context.getTensorGrad(wt_idx[GRUCellParams::zrg]);
447 :
448 27 : d_zrg.dot(weight_ih, outgoing_derivative, false, true);
449 27 : }
450 :
/**
 * @brief Gradient step: zeroes the weight/bias gradients on their first
 *        access in the iteration, masks the incoming derivative when
 *        dropout is active, and delegates to grucell_calcGradient which
 *        ACCUMULATES into the gradient tensors.
 *
 * @param context run layer context providing tensors and gradients
 */
void GRUCellLayer::calcGradient(RunLayerContext &context) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(grucell_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(grucell_props).get();
  const bool reset_after = std::get<props::ResetAfter>(grucell_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();

  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
  const Tensor &prev_hidden_state =
    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
  Tensor &d_prev_hidden_state =
    context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
  const Tensor &incoming_derivative =
    context.getIncomingDerivative(INOUT_INDEX::OUTPUT);

  const unsigned int batch_size = input.getDim().batch();

  Tensor &d_weight_ih = context.getWeightGrad(wt_idx[GRUCellParams::weight_ih]);
  const Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
  Tensor &d_weight_hh = context.getWeightGrad(wt_idx[GRUCellParams::weight_hh]);

  // Bias (grad) tensors were only requested when enabled (see finalize());
  // disabled ones are bound to `empty` and never touched below.
  Tensor empty;
  Tensor &d_bias_h = !disable_bias && integrate_bias
                       ? context.getWeightGrad(wt_idx[GRUCellParams::bias_h])
                       : empty;
  Tensor &d_bias_ih = !disable_bias && !integrate_bias
                        ? context.getWeightGrad(wt_idx[GRUCellParams::bias_ih])
                        : empty;
  const Tensor &bias_hh = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
                            : empty;
  Tensor &d_bias_hh = !disable_bias && !integrate_bias
                        ? context.getWeightGrad(wt_idx[GRUCellParams::bias_hh])
                        : empty;

  const Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);
  Tensor &d_zrg = context.getTensorGrad(wt_idx[GRUCellParams::zrg]);

  // Gradients are accumulated across the iteration, so each gradient tensor
  // must be zeroed the first time it is touched.
  if (context.isGradientFirstAccess(wt_idx[GRUCellParams::weight_ih])) {
    d_weight_ih.setZero();
  }
  if (context.isGradientFirstAccess(wt_idx[GRUCellParams::weight_hh])) {
    d_weight_hh.setZero();
  }
  if (!disable_bias) {
    if (integrate_bias) {
      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_h])) {
        d_bias_h.setZero();
      }
    } else {
      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_ih])) {
        d_bias_ih.setZero();
      }
      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_hh])) {
        d_bias_hh.setZero();
      }
    }
  }

  // When dropout was applied in forwarding, the incoming derivative must be
  // masked with the same dropout mask before backpropagation.
  Tensor incoming_derivative_masked(batch_size, 1, 1, unit);
  if (dropout_rate > epsilon) {
    incoming_derivative.multiply_strided(
      context.getTensor(wt_idx[GRUCellParams::dropout_mask]),
      incoming_derivative_masked);
  }

  grucell_calcGradient(
    unit, batch_size, disable_bias, integrate_bias, reset_after, acti_func,
    recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
    dropout_rate > epsilon ? incoming_derivative_masked : incoming_derivative,
    d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih, bias_hh,
    d_bias_hh, zrg, d_zrg);
}
527 :
528 32 : void GRUCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
529 32 : const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
530 :
531 32 : context.updateTensor(wt_idx[GRUCellParams::zrg], batch);
532 :
533 32 : if (dropout_rate > epsilon) {
534 0 : context.updateTensor(wt_idx[GRUCellParams::dropout_mask], batch);
535 : }
536 32 : }
537 :
538 : } // namespace nntrainer
|