Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file gru.cpp
6 : * @date 17 March 2021
7 : * @brief This is Gated Recurrent Unit Layer Class of Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jijoong Moon <jijoong.moon@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : * h_prev --------d1------->[*]-------d0----->[+]---d0--> h
13 : * dh_nx | | | | d0 dh
14 : * | d14 | d2 d3 |
15 : * | | +-----[1-]------>[*]
16 : * | [*]<---+ d15 |d5 | d6
17 : * | | |rt | zt |gt
18 : * | | [sig] [sig] [tanh]
19 : * | | |d16 | d7 |d8
20 : * | | [+] [+] [+]
21 : * | | / \d16 | \ d7 / \ d8
22 : * | | Whhr Wxhr Whhz Wxhz Whhg Wxhg
23 : * | | |d17 |d13 |d12 |d11 |d10 | d9
24 : * +- |--+------|---+ | | |
25 : * +---------|--------|----------+ |
26 : * xs------------------+--------+---------------+
27 : */
28 :
29 : #include <cmath>
30 : #include <gru.h>
31 : #include <layer_context.h>
32 : #include <nntrainer_error.h>
33 : #include <nntrainer_log.h>
34 : #include <node_exporter.h>
35 : #include <util_func.h>
36 :
37 : namespace nntrainer {
38 :
39 : static constexpr size_t SINGLE_INOUT_IDX = 0;
40 :
/**
 * @brief Indices into wt_idx for the weights/tensors this layer requests
 *        from the layer context in finalize().
 */
enum GRUParams {
  weight_ih,    /**< input-to-hidden weight, gates concatenated as z, r, g */
  weight_hh,    /**< hidden-to-hidden weight, gates concatenated as z, r, g */
  bias_h,       /**< single integrated bias (used when IntegrateBias is set) */
  bias_ih,      /**< input bias (used when IntegrateBias is not set) */
  bias_hh,      /**< hidden bias (used when IntegrateBias is not set) */
  hidden_state, /**< per-timestep hidden state [batch, 1, max_timestep, unit] */
  zrg,          /**< z/r/g gate activations per timestep */
  h_prev,       /**< t == 0 previous hidden state, zeroed each forward pass */
  dropout_mask  /**< dropout mask (requested only when dropout_rate > epsilon) */
};
52 :
53 86 : GRULayer::GRULayer() :
54 : LayerImpl(),
55 430 : gru_props(props::Unit(),
56 172 : props::HiddenStateActivation() = ActivationType::ACT_TANH,
57 172 : props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
58 172 : props::ReturnSequences(), props::DropOutRate(),
59 172 : props::IntegrateBias(), props::ResetAfter()),
60 86 : acti_func(ActivationType::ACT_NONE, true),
61 86 : recurrent_acti_func(ActivationType::ACT_NONE, true),
62 172 : epsilon(1e-3f) {
63 : wt_idx.fill(std::numeric_limits<unsigned>::max());
64 86 : }
65 :
66 72 : void GRULayer::finalize(InitLayerContext &context) {
67 : const Initializer weight_initializer =
68 72 : std::get<props::WeightInitializer>(*layer_impl_props).get();
69 : const Initializer bias_initializer =
70 72 : std::get<props::BiasInitializer>(*layer_impl_props).get();
71 : const WeightRegularizer weight_regularizer =
72 72 : std::get<props::WeightRegularizer>(*layer_impl_props).get();
73 : const float weight_regularizer_constant =
74 72 : std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
75 : auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
76 : auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
77 : const bool disable_bias =
78 72 : std::get<props::DisableBias>(*layer_impl_props).get();
79 :
80 72 : const unsigned int unit = std::get<props::Unit>(gru_props).get();
81 : ActivationType hidden_state_activation_type =
82 72 : std::get<props::HiddenStateActivation>(gru_props).get();
83 : ActivationType recurrent_activation_type =
84 72 : std::get<props::RecurrentActivation>(gru_props).get();
85 : const bool return_sequences =
86 72 : std::get<props::ReturnSequences>(gru_props).get();
87 72 : const float dropout_rate = std::get<props::DropOutRate>(gru_props).get();
88 72 : const bool integrate_bias = std::get<props::IntegrateBias>(gru_props).get();
89 :
90 72 : NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
91 : << "GRU layer takes only one input";
92 :
93 : // input_dim = [ batch, 1, time_iteration, feature_size ]
94 : const TensorDim &input_dim = context.getInputDimensions()[0];
95 72 : const unsigned int batch_size = input_dim.batch();
96 72 : const unsigned int max_timestep = input_dim.height();
97 72 : NNTR_THROW_IF(max_timestep < 1, std::runtime_error)
98 : << "max timestep must be greator than 0 in gru layer.";
99 72 : const unsigned int feature_size = input_dim.width();
100 :
101 : // if return_sequences == False :
102 : // output_dim = [ batch, 1, 1, unit ]
103 : // else:
104 : // output_dim = [ batch, 1, time_iteration, unit ]
105 : TensorDim output_dim(
106 112 : {batch_size, 1, return_sequences ? max_timestep : 1, unit});
107 72 : context.setOutputDimensions({output_dim});
108 :
109 : // weight_initializer can be set seperately. weight_ih initializer,
110 : // weight_hh initializer kernel initializer & recurrent_initializer in keras
111 : // for now, it is set same way.
112 :
113 : // - weight_ih ( input to hidden )
114 : // weight_ih_dim : [ 1, 1, feature_size, NUMGATE * unit ] -> z, r, g
115 72 : TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
116 72 : wt_idx[GRUParams::weight_ih] = context.requestWeight(
117 : weight_ih_dim, weight_initializer, weight_regularizer,
118 : weight_regularizer_constant, weight_decay, "weight_ih", true);
119 : // - weight_hh ( hidden to hidden )
120 : // weight_hh_dim : [ 1, 1, unit, NUM_GATE * unit ] -> z, r, g
121 72 : TensorDim weight_hh_dim({unit, NUM_GATE * unit});
122 144 : wt_idx[GRUParams::weight_hh] = context.requestWeight(
123 : weight_hh_dim, weight_initializer, weight_regularizer,
124 : weight_regularizer_constant, weight_decay, "weight_hh", true);
125 72 : if (!disable_bias) {
126 72 : if (integrate_bias) {
127 : // - bias_h ( input bias, hidden bias are integrate to 1 bias )
128 : // bias_h_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
129 38 : TensorDim bias_h_dim({NUM_GATE * unit});
130 38 : wt_idx[GRUParams::bias_h] = context.requestWeight(
131 : bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
132 : "bias_h", true);
133 : } else {
134 : // - bias_ih ( input bias )
135 : // bias_ih_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
136 34 : TensorDim bias_ih_dim({NUM_GATE * unit});
137 34 : wt_idx[GRUParams::bias_ih] = context.requestWeight(
138 : bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
139 : bias_decay, "bias_ih", true);
140 : // - bias_hh ( hidden bias )
141 : // bias_hh_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
142 34 : TensorDim bias_hh_dim({NUM_GATE * unit});
143 68 : wt_idx[GRUParams::bias_hh] = context.requestWeight(
144 : bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
145 : bias_decay, "bias_hh", true);
146 : }
147 : }
148 :
149 : // hidden_state_dim = [ batch, 1, max_timestep, unit ]
150 72 : TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit);
151 72 : wt_idx[GRUParams::hidden_state] =
152 144 : context.requestTensor(hidden_state_dim, "hidden_state", Initializer::NONE,
153 : true, TensorLifespan::ITERATION_LIFESPAN);
154 :
155 : // zrg_dim = [ batch, 1, max_timestep, NUM_GATE * unit ]
156 72 : TensorDim zrg_dim(batch_size, 1, max_timestep, NUM_GATE * unit);
157 72 : wt_idx[GRUParams::zrg] =
158 144 : context.requestTensor(zrg_dim, "zrg", Initializer::NONE, true,
159 : TensorLifespan::ITERATION_LIFESPAN);
160 :
161 : // h_prev_dim = [ batch, 1, 1, unit ]
162 72 : TensorDim h_prev_dim = TensorDim({batch_size, 1, 1, unit});
163 72 : wt_idx[GRUParams::h_prev] =
164 72 : context.requestTensor(h_prev_dim, "h_prev", Initializer::NONE, false,
165 : TensorLifespan::FORWARD_FUNC_LIFESPAN);
166 :
167 72 : if (dropout_rate > epsilon) {
168 0 : TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit);
169 0 : wt_idx[GRUParams::dropout_mask] =
170 0 : context.requestTensor(output_dim, "dropout_mask", Initializer::NONE,
171 : false, TensorLifespan::ITERATION_LIFESPAN);
172 : }
173 :
174 72 : acti_func.setActiFunc(hidden_state_activation_type);
175 72 : recurrent_acti_func.setActiFunc(recurrent_activation_type);
176 72 : }
177 :
178 353 : void GRULayer::setProperty(const std::vector<std::string> &values) {
179 353 : auto remain_props = loadProperties(values, gru_props);
180 352 : LayerImpl::setProperty(remain_props);
181 352 : }
182 :
183 28 : void GRULayer::exportTo(Exporter &exporter,
184 : const ml::train::ExportMethods &method) const {
185 28 : LayerImpl::exportTo(exporter, method);
186 28 : exporter.saveResult(gru_props, method, this);
187 28 : }
188 :
/**
 * @brief Forward pass of the GRU layer.
 *
 * For each batch and each timestep t:
 *   zt = sigma(W_hz.h_prev + W_xz.xs)          (update gate)
 *   rt = sigma(W_hr.h_prev + W_xr.xs)          (reset gate)
 *   gt = tanh((h_prev * rt).W_hg + W_xg.xs)    (candidate state)
 *   h  = (1 - zt) * gt + zt * h_prev
 * where the reset gate is applied either after (reset_after, keras/cuDNN
 * variant) or before the hidden-to-hidden matmul.
 *
 * When ReturnSequences is false only the last timestep's hidden state is
 * copied to the output; otherwise the whole sequence is copied.
 *
 * @param context  runtime context holding inputs/outputs/weights/tensors
 * @param training true during training (enables dropout masking)
 */
void GRULayer::forwarding(RunLayerContext &context, bool training) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(gru_props).get();
  const bool return_sequences =
    std::get<props::ReturnSequences>(gru_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(gru_props).get();
  const bool integrate_bias = std::get<props::IntegrateBias>(gru_props).get();
  const bool reset_after = std::get<props::ResetAfter>(gru_props).get();

  Tensor &input = context.getInput(SINGLE_INOUT_IDX);
  const TensorDim &input_dim = input.getDim();
  const unsigned int batch_size = input_dim.batch();
  const unsigned int max_timestep = input_dim.height();
  const unsigned int feature_size = input_dim.width();
  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);

  const Tensor &weight_ih = context.getWeight(wt_idx[GRUParams::weight_ih]);
  const Tensor &weight_hh = context.getWeight(wt_idx[GRUParams::weight_hh]);
  // `empty` is the placeholder bound when a bias variant is not in use
  Tensor empty;
  Tensor &bias_h = !disable_bias && integrate_bias
                     ? context.getWeight(wt_idx[GRUParams::bias_h])
                     : empty;
  Tensor &bias_ih = !disable_bias && !integrate_bias
                      ? context.getWeight(wt_idx[GRUParams::bias_ih])
                      : empty;
  Tensor &bias_hh = !disable_bias && !integrate_bias
                      ? context.getWeight(wt_idx[GRUParams::bias_hh])
                      : empty;

  Tensor &hidden_state = context.getTensor(wt_idx[GRUParams::hidden_state]);
  Tensor &zrg = context.getTensor(wt_idx[GRUParams::zrg]);
  Tensor &h_prev = context.getTensor(wt_idx[GRUParams::h_prev]);

  hidden_state.setZero();
  zrg.setZero();
  h_prev.setZero(); // zero initial hidden state for t == 0

  Tensor prev_hs;
  Tensor hs;

  // zt = sigma(W_hz.h_prev + W_xz.xs)
  // rt = sigma(W_hr.h_prev + W_xr.xs)
  // gt = tanh((h_prev*rt).W_hr + W_xg.xs)
  // h_nx = (1-zt)*gt + zt*h_prev

  for (unsigned int b = 0; b < batch_size; ++b) {
    Tensor islice = input.getBatchSlice(b, 1);
    Tensor oslice = hidden_state.getBatchSlice(b, 1);
    Tensor zrg_ = zrg.getBatchSlice(b, 1);

    for (unsigned int t = 0; t < max_timestep; ++t) {
      // xs: input features for this timestep
      Tensor xs = islice.getSharedDataTensor({feature_size}, t * feature_size);

      /** @todo verify this dropout working */
      // if (dropout_rate > 0.0 && training) {
      //   xs.multiply_i(xs.dropout_mask(dropout_rate));
      // }
      // hs is a view into hidden_state: writing it stores h for timestep t
      hs = oslice.getSharedDataTensor({unit}, t * unit);
      Tensor zrg_t =
        zrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);

      if (t > 0) {
        prev_hs = oslice.getSharedDataTensor({unit}, (t - 1) * unit);
      } else {
        prev_hs = h_prev.getBatchSlice(b, 1); // zero state at t == 0
      }

      xs.dot(weight_ih, zrg_t); // x_z, x_r, x_g

      // ztrt: the z and r gate pre-activations (first 2 * unit of zrg_t)
      Tensor ztrt = zrg_t.getSharedDataTensor({unit * 2}, 0);

      // weight_hh is laid out [unit, 3 * unit]; slice out the z/r columns
      // and the g columns (copy_with_stride materializes the strided view)
      Tensor w_hh;
      w_hh.copy_with_stride(
        weight_hh.getSharedDataTensor({1, 1, unit, unit * 2}, 0, false));
      Tensor w_g;
      w_g.copy_with_stride(
        weight_hh.getSharedDataTensor({1, 1, unit, unit}, unit * 2, false));

      // gt: candidate-state pre-activation (last unit of zrg_t)
      Tensor gt = zrg_t.getSharedDataTensor({unit}, unit * 2);

      ztrt.add_i(prev_hs.dot(w_hh));
      if (!disable_bias) {
        if (integrate_bias) {
          Tensor ztrt_bias_h = bias_h.getSharedDataTensor({unit * 2}, 0);
          ztrt.add_i(ztrt_bias_h);
        } else {
          Tensor ztrt_bias_ih = bias_ih.getSharedDataTensor({unit * 2}, 0);
          ztrt.add_i(ztrt_bias_ih);
          Tensor ztrt_bias_hh = bias_hh.getSharedDataTensor({unit * 2}, 0);
          ztrt.add_i(ztrt_bias_hh);
        }
      }

      recurrent_acti_func.run_fn(ztrt, ztrt);

      Tensor zt = ztrt.getSharedDataTensor({unit}, 0);
      Tensor rt = ztrt.getSharedDataTensor({unit}, unit);

      Tensor temp;
      if (reset_after) {
        // reset_after: gt += rt * (prev_hs . W_hg + bias_hh_g)
        prev_hs.dot(w_g, temp);
        if (!disable_bias && !integrate_bias) {
          Tensor bias_hh_g = bias_hh.getSharedDataTensor({unit}, 2 * unit);
          temp.add_i(bias_hh_g);
        }
        temp.multiply_i(rt);
        gt.add_i(temp);
      } else {
        // reset before matmul: gt += (rt * prev_hs) . W_hg + bias_hh_g
        rt.multiply(prev_hs, temp);
        temp.dot(w_g, gt, false, false, 1.0f);
        if (!disable_bias && !integrate_bias) {
          Tensor bias_hh_g = bias_hh.getSharedDataTensor({unit}, 2 * unit);
          gt.add_i(bias_hh_g);
        }
      }
      if (!disable_bias) {
        if (integrate_bias) {
          Tensor gt_bias_h = bias_h.getSharedDataTensor({unit}, unit * 2);
          gt.add_i(gt_bias_h);
        } else {
          Tensor gt_bias_ih = bias_ih.getSharedDataTensor({unit}, unit * 2);
          gt.add_i(gt_bias_ih);
        }
      }

      acti_func.run_fn(gt, gt);

      // h = zt * prev_hs + (1 - zt) * gt
      zt.multiply(prev_hs, hs);
      temp = zt.multiply(-1.0).add(1.0);
      hs.add_i(gt.multiply(temp));

      if (dropout_rate > epsilon && training) {
        Tensor mask_ = context.getTensor(wt_idx[GRUParams::dropout_mask])
                         .getBatchSlice(b, 1);
        Tensor msk = mask_.getSharedDataTensor({unit}, t * unit);
        msk.dropout_mask(dropout_rate);
        hs.multiply_i(msk);
      }
    }
  }

  if (!return_sequences) {
    // copy only the final timestep of each batch into the output
    for (unsigned int batch = 0; batch < batch_size; ++batch) {
      Tensor dest = output.getSharedDataTensor({unit}, batch * unit);
      Tensor src = hidden_state.getSharedDataTensor(
        {unit}, batch * unit * max_timestep + (max_timestep - 1) * unit);
      dest.copy(src);
    }
  } else {
    output.copy(hidden_state);
  }
}
343 :
344 166 : void GRULayer::calcDerivative(RunLayerContext &context) {
345 166 : Tensor &zrg_derivative = context.getTensorGrad(wt_idx[GRUParams::zrg]);
346 166 : Tensor &weight_ih = context.getWeight(wt_idx[GRUParams::weight_ih]);
347 166 : Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
348 :
349 166 : zrg_derivative.dot(weight_ih, outgoing_derivative, false, true);
350 166 : }
351 :
/**
 * @brief Backward pass: accumulate weight/bias gradients through time.
 *
 * Walks each batch's timesteps in reverse, propagating the hidden-state
 * derivative (dh) back through the GRU cell (see the d0..d17 labels in the
 * diagram at the top of this file). Gate derivatives are written into
 * d_zrg, which calcDerivative() later uses for the input gradient.
 *
 * The z/r and g parts of djdweight_hh are accumulated in separate
 * temporaries (djdweight_hh_zr, djdweight_hh_g) and interleaved back into
 * the [unit, NUM_GATE * unit] layout of djdweight_hh at the end.
 *
 * @param context runtime context holding tensors, weights and gradients
 */
void GRULayer::calcGradient(RunLayerContext &context) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(gru_props).get();
  const bool return_sequences =
    std::get<props::ReturnSequences>(gru_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(gru_props).get();
  const bool integrate_bias = std::get<props::IntegrateBias>(gru_props).get();
  const bool reset_after = std::get<props::ResetAfter>(gru_props).get();

  Tensor &input = context.getInput(SINGLE_INOUT_IDX);
  const TensorDim &input_dim = input.getDim();
  const unsigned int batch_size = input_dim.batch();
  const unsigned int max_timestep = input_dim.height();
  const unsigned int feature_size = input_dim.width();
  const Tensor &incoming_derivative =
    context.getIncomingDerivative(SINGLE_INOUT_IDX);

  Tensor &djdweight_ih = context.getWeightGrad(wt_idx[GRUParams::weight_ih]);
  Tensor &weight_hh = context.getWeight(wt_idx[GRUParams::weight_hh]);
  Tensor &djdweight_hh = context.getWeightGrad(wt_idx[GRUParams::weight_hh]);
  // `empty` is the placeholder bound when a bias variant is not in use
  Tensor empty;
  Tensor &djdbias_h = !disable_bias && integrate_bias
                        ? context.getWeightGrad(wt_idx[GRUParams::bias_h])
                        : empty;
  Tensor &djdbias_ih = !disable_bias && !integrate_bias
                         ? context.getWeightGrad(wt_idx[GRUParams::bias_ih])
                         : empty;
  Tensor &bias_hh = !disable_bias && !integrate_bias
                      ? context.getWeight(wt_idx[GRUParams::bias_hh])
                      : empty;
  Tensor &djdbias_hh = !disable_bias && !integrate_bias
                         ? context.getWeightGrad(wt_idx[GRUParams::bias_hh])
                         : empty;

  // separate accumulators for the z/r columns and the g columns of W_hh
  Tensor djdweight_hh_zr = Tensor({1, 1, unit, unit * 2}, true);
  Tensor djdweight_hh_g = Tensor({1, 1, unit, unit}, true);
  Tensor &hidden_state_derivative =
    context.getTensorGrad(wt_idx[GRUParams::hidden_state]);
  Tensor &hidden_state = context.getTensor(wt_idx[GRUParams::hidden_state]);
  Tensor &zrg = context.getTensor(wt_idx[GRUParams::zrg]);
  Tensor &d_zrg = context.getTensorGrad(wt_idx[GRUParams::zrg]);

  djdweight_ih.setZero();
  djdweight_hh_zr.setZero();
  djdweight_hh_g.setZero();
  if (!disable_bias) {
    if (integrate_bias) {
      djdbias_h.setZero();
    } else {
      djdbias_ih.setZero();
      djdbias_hh.setZero();
    }
  }

  hidden_state_derivative.setZero();
  d_zrg.setZero();

  if (!return_sequences) {
    // only the last timestep received an output gradient
    for (unsigned int batch = 0; batch < batch_size; ++batch) {
      Tensor dest = hidden_state_derivative.getSharedDataTensor(
        {unit}, batch * unit * max_timestep + (max_timestep - 1) * unit);
      Tensor src =
        incoming_derivative.getSharedDataTensor({unit}, batch * unit);
      dest.copy(src);
    }
  } else {
    hidden_state_derivative.copy(incoming_derivative);
  }

  if (dropout_rate > epsilon) {
    // apply the same mask used in forwarding
    hidden_state_derivative.multiply_i(
      context.getTensor(wt_idx[GRUParams::dropout_mask]));
  }

  // dh_nx: gradient flowing to the previous timestep's hidden state
  Tensor dh_nx = Tensor(unit);

  for (unsigned int b = 0; b < batch_size; ++b) {
    Tensor deriv_t = hidden_state_derivative.getBatchSlice(b, 1);
    Tensor xs_t = input.getBatchSlice(b, 1);
    Tensor hs_t = hidden_state.getBatchSlice(b, 1);

    dh_nx.setZero();

    Tensor dh;
    Tensor prev_hs;
    Tensor xs;
    Tensor dzrg_ = d_zrg.getBatchSlice(b, 1);
    Tensor zrg_ = zrg.getBatchSlice(b, 1);

    // iterate timesteps in reverse (t = max_timestep - 1 .. 0)
    for (unsigned int t = max_timestep; t-- > 0;) {
      dh = deriv_t.getSharedDataTensor({unit}, t * unit);
      xs = xs_t.getSharedDataTensor({feature_size}, t * feature_size);

      Tensor dzrg_t =
        dzrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
      Tensor zrg_t =
        zrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);

      if (t == 0) {
        // initial hidden state was zero in forwarding
        prev_hs = Tensor(unit);
        prev_hs.setZero();
      } else {
        prev_hs = hs_t.getSharedDataTensor({unit}, (t - 1) * unit);
      }
      if (t < max_timestep - 1) {
        // add the gradient propagated from the later timestep
        dh.add_i(dh_nx);
      }

      Tensor dhz = dzrg_t.getSharedDataTensor({unit}, 0);
      Tensor dhr = dzrg_t.getSharedDataTensor({unit}, unit);
      Tensor dhg = dzrg_t.getSharedDataTensor({unit}, unit * 2);

      Tensor zt = zrg_t.getSharedDataTensor({unit}, 0);
      Tensor rt = zrg_t.getSharedDataTensor({unit}, unit);
      Tensor gt = zrg_t.getSharedDataTensor({unit}, unit * 2);

      zt.multiply(dh, dh_nx);          // dh_nx = d1
      dh.multiply(prev_hs, dhz);       // dhz = d2
      dhz.subtract_i(gt.multiply(dh)); // dhz = d5
      zt.multiply(-1.0, dhg);
      dhg.add_i(1.0);
      dhg.multiply_i(dh); // dhg = d6

      recurrent_acti_func.run_prime_fn(zt, dhz, dhz); // dhz = d7
      acti_func.run_prime_fn(gt, dhg, dhg);           // dhg = d8

      Tensor dhzr = dzrg_t.getSharedDataTensor({unit * 2}, 0); // dhz+dhr

      // strided views into weight_hh: g columns and z/r columns
      Tensor wg_hh;
      wg_hh.copy_with_stride(
        weight_hh.getSharedDataTensor({1, 1, unit, unit}, unit * 2, false));
      Tensor wzr_hh;
      wzr_hh.copy_with_stride(
        weight_hh.getSharedDataTensor({1, 1, unit, unit * 2}, 0, false));

      Tensor temp = Tensor(unit);

      if (reset_after) {
        prev_hs.dot(wg_hh, temp);
        if (!disable_bias && !integrate_bias) {
          const Tensor bias_hh_g =
            bias_hh.getSharedDataTensor({unit}, 2 * unit);
          temp.add_i(bias_hh_g);
        }
        dhg.multiply(temp, dhr);

        // reset temp: dhg * rt for djdbias_hh_g, dh_nx and djdweight_hh_g
        dhg.multiply(rt, temp);
        if (!disable_bias && !integrate_bias) {
          Tensor djdbias_hh_g =
            djdbias_hh.getSharedDataTensor({unit}, 2 * unit);
          djdbias_hh_g.add_i(temp);
        }
        temp.dot(wg_hh, dh_nx, false, true, 1.0f); // dh_nx = d1 + d14
        djdweight_hh_g.add_i(prev_hs.dot(temp, true, false));
      } else {
        if (!disable_bias && !integrate_bias) {
          Tensor djdbias_hh_g =
            djdbias_hh.getSharedDataTensor({unit}, 2 * unit);
          djdbias_hh_g.add_i(dhg);
        }

        dhg.dot(wg_hh, temp, false, true); // temp = d10
        temp.multiply(prev_hs, dhr);       // dhr = d15s
        temp.multiply_i(rt);               // temp=d14
        dh_nx.add_i(temp);                 // dh_nx = d1 + d14

        // reset temp : prev_hs * rt for djdweight_hh_g
        rt.multiply(prev_hs, temp);
        temp.dot(dhg, djdweight_hh_g, true, false, 1.0f);
      }

      recurrent_acti_func.run_prime_fn(rt, dhr, dhr); // dhr = d16

      if (!disable_bias) {
        if (integrate_bias) {
          djdbias_h.add_i(dzrg_t); // dzrg_t = d7+d16+d8
        } else {
          djdbias_ih.add_i(dzrg_t); // dzrg_t = d7+d16+d8
          Tensor djdbias_hh_zr = djdbias_hh.getSharedDataTensor({2 * unit}, 0);
          djdbias_hh_zr.add_i(dzrg_t.getSharedDataTensor({2 * unit}, 0));
        }
      }

      djdweight_hh_zr.add_i(prev_hs.dot(dhzr, true, false));
      xs.dot(dzrg_t, djdweight_ih, true, false, 1.0f);
      dhzr.dot(wzr_hh, dh_nx, false, true, 1.0); // dh_nx = d1 + d14 + d12 + d17
    }
  }
  // scatter the z/r accumulator rows back into djdweight_hh's gate layout
  for (unsigned int h = 0; h < unit; ++h) {
    float *data = (float *)djdweight_hh_zr.getAddress(h * unit * 2);
    float *rdata = (float *)djdweight_hh.getAddress(h * unit * NUM_GATE);
    std::copy(data, data + unit * 2, rdata);
  }

  // scatter the g accumulator rows into the trailing columns of each row
  for (unsigned int h = 0; h < unit; ++h) {
    float *data = (float *)djdweight_hh_g.getAddress(h * unit);
    float *rdata =
      (float *)djdweight_hh.getAddress(h * unit * NUM_GATE + unit * 2);
    std::copy(data, data + unit, rdata);
  }
}
556 :
557 24 : void GRULayer::setBatch(RunLayerContext &context, unsigned int batch) {
558 24 : context.updateTensor(wt_idx[GRUParams::hidden_state], batch);
559 24 : context.updateTensor(wt_idx[GRUParams::zrg], batch);
560 24 : context.updateTensor(wt_idx[GRUParams::h_prev], batch);
561 :
562 24 : if (std::get<props::DropOutRate>(gru_props).get() > epsilon) {
563 0 : context.updateTensor(wt_idx[GRUParams::dropout_mask], batch);
564 : }
565 24 : }
566 :
567 : } // namespace nntrainer
|