LCOV - code coverage report
Current view: top level - nntrainer/layers - gru.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 94.3 % 318 300
Test Date: 2025-12-14 20:38:17 Functions: 88.9 % 9 8

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
       4              :  *
       5              :  * @file   gru.cpp
       6              :  * @date   17 March 2021
       7              :  * @brief  This is Gated Recurrent Unit Layer Class of Neural Network
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  * h_prev --------d1------->[*]-------d0----->[+]---d0--> h
      13              :  * dh_nx    |  |             |                 | d0      dh
      14              :  *          | d14            | d2        d3    |
      15              :  *          |  |             +-----[1-]------>[*]
      16              :  *          | [*]<---+ d15   |d5               | d6
      17              :  *          |  |     |rt     | zt              |gt
      18              :  *          |  |    [sig]   [sig]            [tanh]
      19              :  *          |  |     |d16    | d7              |d8
      20              :  *          |  |    [+]      [+]              [+]
      21              :  *          |  |    / \d16   |  \ d7          / \ d8
      22              :  *          |  |  Whhr Wxhr Whhz Wxhz       Whhg Wxhg
      23              :  *          |  |  |d17  |d13 |d12 |d11       |d10 | d9
      24              :  *          +- |--+------|---+    |          |    |
      25              :  *             +---------|--------|----------+    |
      26              :  *   xs------------------+--------+---------------+
      27              :  */
      28              : 
      29              : #include <cmath>
      30              : #include <gru.h>
      31              : #include <layer_context.h>
      32              : #include <nntrainer_error.h>
      33              : #include <nntrainer_log.h>
      34              : #include <node_exporter.h>
      35              : #include <util_func.h>
      36              : 
namespace nntrainer {

// Index of the single input/output tensor of this layer.
static constexpr size_t SINGLE_INOUT_IDX = 0;

/**
 * @brief Indices into wt_idx[] for every weight / temporary tensor this
 *        layer requests from the context. Order is significant: the implicit
 *        enumerator values are used as array positions.
 */
enum GRUParams {
  weight_ih,    // input-to-hidden weight  [feature_size, NUM_GATE * unit]
  weight_hh,    // hidden-to-hidden weight [unit, NUM_GATE * unit]
  bias_h,       // fused bias (used when integrate_bias == true)
  bias_ih,      // input bias  (used when integrate_bias == false)
  bias_hh,      // hidden bias (used when integrate_bias == false)
  hidden_state, // per-timestep hidden states kept for backward pass
  zrg,          // per-timestep gate pre/post activations (z, r, g)
  h_prev,       // zero-initialized hidden state for t == 0 (forward only)
  dropout_mask  // dropout mask, requested only when dropout_rate > epsilon
};
      52              : 
/**
 * @brief Default-construct a GRU layer.
 *
 * Properties default to tanh hidden-state activation and sigmoid recurrent
 * activation (the Keras defaults). The activation functors are created in a
 * no-op state and configured later in finalize(). wt_idx entries start as
 * an invalid sentinel so an unset index is detectable.
 */
GRULayer::GRULayer() :
  LayerImpl(),
  gru_props(props::Unit(),
            props::HiddenStateActivation() = ActivationType::ACT_TANH,
            props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
            props::ReturnSequences(), props::DropOutRate(),
            props::IntegrateBias(), props::ResetAfter()),
  acti_func(ActivationType::ACT_NONE, true),
  recurrent_acti_func(ActivationType::ACT_NONE, true),
  epsilon(1e-3f) {
  // Sentinel: any wt_idx slot left untouched by finalize() stays invalid.
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}
      65              : 
      66           72 : void GRULayer::finalize(InitLayerContext &context) {
      67              :   const Initializer weight_initializer =
      68           72 :     std::get<props::WeightInitializer>(*layer_impl_props).get();
      69              :   const Initializer bias_initializer =
      70           72 :     std::get<props::BiasInitializer>(*layer_impl_props).get();
      71              :   const WeightRegularizer weight_regularizer =
      72           72 :     std::get<props::WeightRegularizer>(*layer_impl_props).get();
      73              :   const float weight_regularizer_constant =
      74           72 :     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
      75              :   auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
      76              :   auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
      77              :   const bool disable_bias =
      78           72 :     std::get<props::DisableBias>(*layer_impl_props).get();
      79              : 
      80           72 :   const unsigned int unit = std::get<props::Unit>(gru_props).get();
      81              :   ActivationType hidden_state_activation_type =
      82           72 :     std::get<props::HiddenStateActivation>(gru_props).get();
      83              :   ActivationType recurrent_activation_type =
      84           72 :     std::get<props::RecurrentActivation>(gru_props).get();
      85              :   const bool return_sequences =
      86           72 :     std::get<props::ReturnSequences>(gru_props).get();
      87           72 :   const float dropout_rate = std::get<props::DropOutRate>(gru_props).get();
      88           72 :   const bool integrate_bias = std::get<props::IntegrateBias>(gru_props).get();
      89              : 
      90           72 :   NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
      91              :     << "GRU layer takes only one input";
      92              : 
      93              :   // input_dim = [ batch, 1, time_iteration, feature_size ]
      94              :   const TensorDim &input_dim = context.getInputDimensions()[0];
      95           72 :   const unsigned int batch_size = input_dim.batch();
      96           72 :   const unsigned int max_timestep = input_dim.height();
      97           72 :   NNTR_THROW_IF(max_timestep < 1, std::runtime_error)
      98              :     << "max timestep must be greator than 0 in gru layer.";
      99           72 :   const unsigned int feature_size = input_dim.width();
     100              : 
     101              :   // if return_sequences == False :
     102              :   //      output_dim = [ batch, 1, 1, unit ]
     103              :   // else:
     104              :   //      output_dim = [ batch, 1, time_iteration, unit ]
     105              :   TensorDim output_dim(
     106          112 :     {batch_size, 1, return_sequences ? max_timestep : 1, unit});
     107           72 :   context.setOutputDimensions({output_dim});
     108              : 
     109              :   // weight_initializer can be set seperately. weight_ih initializer,
     110              :   // weight_hh initializer kernel initializer & recurrent_initializer in keras
     111              :   // for now, it is set same way.
     112              : 
     113              :   // - weight_ih ( input to hidden )
     114              :   // weight_ih_dim : [ 1, 1, feature_size, NUMGATE * unit ] -> z, r, g
     115           72 :   TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
     116           72 :   wt_idx[GRUParams::weight_ih] = context.requestWeight(
     117              :     weight_ih_dim, weight_initializer, weight_regularizer,
     118              :     weight_regularizer_constant, weight_decay, "weight_ih", true);
     119              :   // - weight_hh ( hidden to hidden )
     120              :   // weight_hh_dim : [ 1, 1, unit, NUM_GATE * unit ] -> z, r, g
     121           72 :   TensorDim weight_hh_dim({unit, NUM_GATE * unit});
     122          144 :   wt_idx[GRUParams::weight_hh] = context.requestWeight(
     123              :     weight_hh_dim, weight_initializer, weight_regularizer,
     124              :     weight_regularizer_constant, weight_decay, "weight_hh", true);
     125           72 :   if (!disable_bias) {
     126           72 :     if (integrate_bias) {
     127              :       // - bias_h ( input bias, hidden bias are integrate to 1 bias )
     128              :       // bias_h_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
     129           38 :       TensorDim bias_h_dim({NUM_GATE * unit});
     130           38 :       wt_idx[GRUParams::bias_h] = context.requestWeight(
     131              :         bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
     132              :         "bias_h", true);
     133              :     } else {
     134              :       // - bias_ih ( input bias )
     135              :       // bias_ih_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
     136           34 :       TensorDim bias_ih_dim({NUM_GATE * unit});
     137           34 :       wt_idx[GRUParams::bias_ih] = context.requestWeight(
     138              :         bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     139              :         bias_decay, "bias_ih", true);
     140              :       // - bias_hh ( hidden bias )
     141              :       // bias_hh_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
     142           34 :       TensorDim bias_hh_dim({NUM_GATE * unit});
     143           68 :       wt_idx[GRUParams::bias_hh] = context.requestWeight(
     144              :         bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     145              :         bias_decay, "bias_hh", true);
     146              :     }
     147              :   }
     148              : 
     149              :   // hidden_state_dim = [ batch, 1, max_timestep, unit ]
     150           72 :   TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit);
     151           72 :   wt_idx[GRUParams::hidden_state] =
     152          144 :     context.requestTensor(hidden_state_dim, "hidden_state", Initializer::NONE,
     153              :                           true, TensorLifespan::ITERATION_LIFESPAN);
     154              : 
     155              :   // zrg_dim = [ batch, 1, max_timestep, NUM_GATE * unit ]
     156           72 :   TensorDim zrg_dim(batch_size, 1, max_timestep, NUM_GATE * unit);
     157           72 :   wt_idx[GRUParams::zrg] =
     158          144 :     context.requestTensor(zrg_dim, "zrg", Initializer::NONE, true,
     159              :                           TensorLifespan::ITERATION_LIFESPAN);
     160              : 
     161              :   // h_prev_dim = [ batch, 1, 1, unit ]
     162           72 :   TensorDim h_prev_dim = TensorDim({batch_size, 1, 1, unit});
     163           72 :   wt_idx[GRUParams::h_prev] =
     164           72 :     context.requestTensor(h_prev_dim, "h_prev", Initializer::NONE, false,
     165              :                           TensorLifespan::FORWARD_FUNC_LIFESPAN);
     166              : 
     167           72 :   if (dropout_rate > epsilon) {
     168            0 :     TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit);
     169            0 :     wt_idx[GRUParams::dropout_mask] =
     170            0 :       context.requestTensor(output_dim, "dropout_mask", Initializer::NONE,
     171              :                             false, TensorLifespan::ITERATION_LIFESPAN);
     172              :   }
     173              : 
     174           72 :   acti_func.setActiFunc(hidden_state_activation_type);
     175           72 :   recurrent_acti_func.setActiFunc(recurrent_activation_type);
     176           72 : }
     177              : 
     178          353 : void GRULayer::setProperty(const std::vector<std::string> &values) {
     179          353 :   auto remain_props = loadProperties(values, gru_props);
     180          352 :   LayerImpl::setProperty(remain_props);
     181          352 : }
     182              : 
/**
 * @brief Export this layer's properties via the given exporter.
 *
 * Base-class (common) properties are exported first, then the GRU-specific
 * property tuple is saved for the requested export method.
 *
 * @param exporter exporter to save the result into
 * @param method   export format/method to use
 */
void GRULayer::exportTo(Exporter &exporter,
                        const ml::train::ExportMethods &method) const {
  LayerImpl::exportTo(exporter, method);
  exporter.saveResult(gru_props, method, this);
}
     188              : 
/**
 * @brief Forward pass of the GRU over all batches and timesteps.
 *
 * Per timestep computes (see the diagram at the top of this file):
 *   zt = sigma(W_hz.h_prev + W_xz.xs)          (update gate)
 *   rt = sigma(W_hr.h_prev + W_xr.xs)          (reset gate)
 *   gt = tanh(... reset applied ...)            (candidate state)
 *   h  = zt * h_prev + (1 - zt) * gt
 * Gate pre/post activations are stored in the `zrg` tensor and the hidden
 * states in `hidden_state` so the backward pass can reuse them.
 *
 * @param context  run context providing inputs/outputs/weights/tensors
 * @param training true while training (enables dropout mask application)
 */
void GRULayer::forwarding(RunLayerContext &context, bool training) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(gru_props).get();
  const bool return_sequences =
    std::get<props::ReturnSequences>(gru_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(gru_props).get();
  const bool integrate_bias = std::get<props::IntegrateBias>(gru_props).get();
  const bool reset_after = std::get<props::ResetAfter>(gru_props).get();

  Tensor &input = context.getInput(SINGLE_INOUT_IDX);
  const TensorDim &input_dim = input.getDim();
  const unsigned int batch_size = input_dim.batch();
  const unsigned int max_timestep = input_dim.height();
  const unsigned int feature_size = input_dim.width();
  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);

  const Tensor &weight_ih = context.getWeight(wt_idx[GRUParams::weight_ih]);
  const Tensor &weight_hh = context.getWeight(wt_idx[GRUParams::weight_hh]);
  // `empty` backs the unused bias references so the ternaries below can bind
  // a Tensor& regardless of the bias configuration.
  Tensor empty;
  Tensor &bias_h = !disable_bias && integrate_bias
                     ? context.getWeight(wt_idx[GRUParams::bias_h])
                     : empty;
  Tensor &bias_ih = !disable_bias && !integrate_bias
                      ? context.getWeight(wt_idx[GRUParams::bias_ih])
                      : empty;
  Tensor &bias_hh = !disable_bias && !integrate_bias
                      ? context.getWeight(wt_idx[GRUParams::bias_hh])
                      : empty;

  Tensor &hidden_state = context.getTensor(wt_idx[GRUParams::hidden_state]);
  Tensor &zrg = context.getTensor(wt_idx[GRUParams::zrg]);
  Tensor &h_prev = context.getTensor(wt_idx[GRUParams::h_prev]);

  hidden_state.setZero();
  zrg.setZero();
  h_prev.setZero();

  Tensor prev_hs;
  Tensor hs;

  // zt = sigma(W_hz.h_prev + W_xz.xs)
  // rt = sigma(W_hr.h_prev + W_xr.xs)
  // gt = tanh((h_prev*rt).W_hr + W_xg.xs)
  // h_nx = (1-zt)*gt + zt*h_prev

  for (unsigned int b = 0; b < batch_size; ++b) {
    Tensor islice = input.getBatchSlice(b, 1);
    Tensor oslice = hidden_state.getBatchSlice(b, 1);
    Tensor zrg_ = zrg.getBatchSlice(b, 1);

    for (unsigned int t = 0; t < max_timestep; ++t) {
      // xs: input feature vector of this timestep (shared view, no copy)
      Tensor xs = islice.getSharedDataTensor({feature_size}, t * feature_size);

      /** @todo verify this dropout working */
      // if (dropout_rate > 0.0 && training) {
      //   xs.multiply_i(xs.dropout_mask(dropout_rate));
      // }
      // hs / zrg_t are views into the persistent tensors, so writes here
      // are kept for the backward pass.
      hs = oslice.getSharedDataTensor({unit}, t * unit);
      Tensor zrg_t =
        zrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);

      if (t > 0) {
        prev_hs = oslice.getSharedDataTensor({unit}, (t - 1) * unit);
      } else {
        // First step: previous hidden state is the zeroed h_prev buffer.
        prev_hs = h_prev.getBatchSlice(b, 1);
      }

      xs.dot(weight_ih, zrg_t); // x_z, x_r, x_g

      // ztrt: the z and r gate pre-activations (first 2*unit of zrg_t)
      Tensor ztrt = zrg_t.getSharedDataTensor({unit * 2}, 0);

      // Split weight_hh column-wise into the z/r part and the g part.
      // copy_with_stride materializes the strided column slice.
      Tensor w_hh;
      w_hh.copy_with_stride(
        weight_hh.getSharedDataTensor({1, 1, unit, unit * 2}, 0, false));
      Tensor w_g;
      w_g.copy_with_stride(
        weight_hh.getSharedDataTensor({1, 1, unit, unit}, unit * 2, false));

      Tensor gt = zrg_t.getSharedDataTensor({unit}, unit * 2);

      // Recurrent contribution + biases for z and r.
      ztrt.add_i(prev_hs.dot(w_hh));
      if (!disable_bias) {
        if (integrate_bias) {
          Tensor ztrt_bias_h = bias_h.getSharedDataTensor({unit * 2}, 0);
          ztrt.add_i(ztrt_bias_h);
        } else {
          Tensor ztrt_bias_ih = bias_ih.getSharedDataTensor({unit * 2}, 0);
          ztrt.add_i(ztrt_bias_ih);
          Tensor ztrt_bias_hh = bias_hh.getSharedDataTensor({unit * 2}, 0);
          ztrt.add_i(ztrt_bias_hh);
        }
      }

      // In-place sigmoid (by default) over z and r.
      recurrent_acti_func.run_fn(ztrt, ztrt);

      Tensor zt = ztrt.getSharedDataTensor({unit}, 0);
      Tensor rt = ztrt.getSharedDataTensor({unit}, unit);

      // Candidate gate g: the reset gate is applied either after the
      // recurrent matmul (reset_after, cuDNN/Keras-default variant) or
      // before it (classic GRU).
      Tensor temp;
      if (reset_after) {
        prev_hs.dot(w_g, temp);
        if (!disable_bias && !integrate_bias) {
          Tensor bias_hh_g = bias_hh.getSharedDataTensor({unit}, 2 * unit);
          temp.add_i(bias_hh_g);
        }
        temp.multiply_i(rt);
        gt.add_i(temp);
      } else {
        rt.multiply(prev_hs, temp);
        temp.dot(w_g, gt, false, false, 1.0f);
        if (!disable_bias && !integrate_bias) {
          Tensor bias_hh_g = bias_hh.getSharedDataTensor({unit}, 2 * unit);
          gt.add_i(bias_hh_g);
        }
      }
      if (!disable_bias) {
        if (integrate_bias) {
          Tensor gt_bias_h = bias_h.getSharedDataTensor({unit}, unit * 2);
          gt.add_i(gt_bias_h);
        } else {
          Tensor gt_bias_ih = bias_ih.getSharedDataTensor({unit}, unit * 2);
          gt.add_i(gt_bias_ih);
        }
      }

      // In-place tanh (by default) over g.
      acti_func.run_fn(gt, gt);

      // h = zt * h_prev + (1 - zt) * gt
      zt.multiply(prev_hs, hs);
      temp = zt.multiply(-1.0).add(1.0);
      hs.add_i(gt.multiply(temp));

      if (dropout_rate > epsilon && training) {
        // Generate and apply the dropout mask for this timestep; the mask is
        // stored so calcGradient() can reuse it.
        Tensor mask_ = context.getTensor(wt_idx[GRUParams::dropout_mask])
                         .getBatchSlice(b, 1);
        Tensor msk = mask_.getSharedDataTensor({unit}, t * unit);
        msk.dropout_mask(dropout_rate);
        hs.multiply_i(msk);
      }
    }
  }

  // Emit either only the last timestep or the full sequence.
  if (!return_sequences) {
    for (unsigned int batch = 0; batch < batch_size; ++batch) {
      Tensor dest = output.getSharedDataTensor({unit}, batch * unit);
      Tensor src = hidden_state.getSharedDataTensor(
        {unit}, batch * unit * max_timestep + (max_timestep - 1) * unit);
      dest.copy(src);
    }
  } else {
    output.copy(hidden_state);
  }
}
     343              : 
     344          166 : void GRULayer::calcDerivative(RunLayerContext &context) {
     345          166 :   Tensor &zrg_derivative = context.getTensorGrad(wt_idx[GRUParams::zrg]);
     346          166 :   Tensor &weight_ih = context.getWeight(wt_idx[GRUParams::weight_ih]);
     347          166 :   Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
     348              : 
     349          166 :   zrg_derivative.dot(weight_ih, outgoing_derivative, false, true);
     350          166 : }
     351              : 
/**
 * @brief Backward pass: accumulate weight/bias gradients over all batches
 *        and timesteps (BPTT).
 *
 * Walks timesteps in reverse, reusing the gate activations (`zrg`) and
 * hidden states stored by forwarding(). The d* names map onto the edge
 * labels in the dataflow diagram at the top of this file (d1..d17).
 * Gradients for weight_hh are accumulated in two temporaries
 * (djdweight_hh_zr, djdweight_hh_g) and interleaved back into the
 * [unit, NUM_GATE * unit] layout at the end.
 *
 * @param context run context providing inputs, weights and gradient tensors
 */
void GRULayer::calcGradient(RunLayerContext &context) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(gru_props).get();
  const bool return_sequences =
    std::get<props::ReturnSequences>(gru_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(gru_props).get();
  const bool integrate_bias = std::get<props::IntegrateBias>(gru_props).get();
  const bool reset_after = std::get<props::ResetAfter>(gru_props).get();

  Tensor &input = context.getInput(SINGLE_INOUT_IDX);
  const TensorDim &input_dim = input.getDim();
  const unsigned int batch_size = input_dim.batch();
  const unsigned int max_timestep = input_dim.height();
  const unsigned int feature_size = input_dim.width();
  const Tensor &incoming_derivative =
    context.getIncomingDerivative(SINGLE_INOUT_IDX);

  Tensor &djdweight_ih = context.getWeightGrad(wt_idx[GRUParams::weight_ih]);
  Tensor &weight_hh = context.getWeight(wt_idx[GRUParams::weight_hh]);
  Tensor &djdweight_hh = context.getWeightGrad(wt_idx[GRUParams::weight_hh]);
  // `empty` backs the unused bias references so the ternaries below can bind
  // a Tensor& regardless of the bias configuration.
  Tensor empty;
  Tensor &djdbias_h = !disable_bias && integrate_bias
                        ? context.getWeightGrad(wt_idx[GRUParams::bias_h])
                        : empty;
  Tensor &djdbias_ih = !disable_bias && !integrate_bias
                         ? context.getWeightGrad(wt_idx[GRUParams::bias_ih])
                         : empty;
  Tensor &bias_hh = !disable_bias && !integrate_bias
                      ? context.getWeight(wt_idx[GRUParams::bias_hh])
                      : empty;
  Tensor &djdbias_hh = !disable_bias && !integrate_bias
                         ? context.getWeightGrad(wt_idx[GRUParams::bias_hh])
                         : empty;

  // Split accumulators for the z/r and g columns of djdweight_hh; merged
  // into the interleaved layout after the loops below.
  Tensor djdweight_hh_zr = Tensor({1, 1, unit, unit * 2}, true);
  Tensor djdweight_hh_g = Tensor({1, 1, unit, unit}, true);
  Tensor &hidden_state_derivative =
    context.getTensorGrad(wt_idx[GRUParams::hidden_state]);
  Tensor &hidden_state = context.getTensor(wt_idx[GRUParams::hidden_state]);
  Tensor &zrg = context.getTensor(wt_idx[GRUParams::zrg]);
  Tensor &d_zrg = context.getTensorGrad(wt_idx[GRUParams::zrg]);

  djdweight_ih.setZero();
  djdweight_hh_zr.setZero();
  djdweight_hh_g.setZero();
  if (!disable_bias) {
    if (integrate_bias) {
      djdbias_h.setZero();
    } else {
      djdbias_ih.setZero();
      djdbias_hh.setZero();
    }
  }

  hidden_state_derivative.setZero();
  d_zrg.setZero();

  // Seed the hidden-state derivative from the incoming derivative: only the
  // last timestep when the layer emitted a single state, otherwise all steps.
  if (!return_sequences) {
    for (unsigned int batch = 0; batch < batch_size; ++batch) {
      Tensor dest = hidden_state_derivative.getSharedDataTensor(
        {unit}, batch * unit * max_timestep + (max_timestep - 1) * unit);
      Tensor src =
        incoming_derivative.getSharedDataTensor({unit}, batch * unit);
      dest.copy(src);
    }
  } else {
    hidden_state_derivative.copy(incoming_derivative);
  }

  if (dropout_rate > epsilon) {
    // Replay the dropout mask stored during forwarding.
    hidden_state_derivative.multiply_i(
      context.getTensor(wt_idx[GRUParams::dropout_mask]));
  }

  // dh_nx carries the gradient flowing into the previous timestep's hidden
  // state; reset per batch, accumulated across the reverse time loop.
  Tensor dh_nx = Tensor(unit);

  for (unsigned int b = 0; b < batch_size; ++b) {
    Tensor deriv_t = hidden_state_derivative.getBatchSlice(b, 1);
    Tensor xs_t = input.getBatchSlice(b, 1);
    Tensor hs_t = hidden_state.getBatchSlice(b, 1);

    dh_nx.setZero();

    Tensor dh;
    Tensor prev_hs;
    Tensor xs;
    Tensor dzrg_ = d_zrg.getBatchSlice(b, 1);
    Tensor zrg_ = zrg.getBatchSlice(b, 1);

    // Reverse time iteration (t = max_timestep-1 .. 0).
    for (unsigned int t = max_timestep; t-- > 0;) {
      dh = deriv_t.getSharedDataTensor({unit}, t * unit);
      xs = xs_t.getSharedDataTensor({feature_size}, t * feature_size);

      Tensor dzrg_t =
        dzrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);
      Tensor zrg_t =
        zrg_.getSharedDataTensor({unit * NUM_GATE}, unit * t * NUM_GATE);

      if (t == 0) {
        // Initial hidden state was zero during the forward pass.
        prev_hs = Tensor(unit);
        prev_hs.setZero();
      } else {
        prev_hs = hs_t.getSharedDataTensor({unit}, (t - 1) * unit);
      }
      if (t < max_timestep - 1) {
        // Fold in the gradient propagated back from timestep t+1.
        dh.add_i(dh_nx);
      }

      // Views into the gate-gradient slice: z, r, g components.
      Tensor dhz = dzrg_t.getSharedDataTensor({unit}, 0);
      Tensor dhr = dzrg_t.getSharedDataTensor({unit}, unit);
      Tensor dhg = dzrg_t.getSharedDataTensor({unit}, unit * 2);

      // Forward-pass gate activations of this timestep.
      Tensor zt = zrg_t.getSharedDataTensor({unit}, 0);
      Tensor rt = zrg_t.getSharedDataTensor({unit}, unit);
      Tensor gt = zrg_t.getSharedDataTensor({unit}, unit * 2);

      zt.multiply(dh, dh_nx);          // dh_nx = d1
      dh.multiply(prev_hs, dhz);       // dhz = d2
      dhz.subtract_i(gt.multiply(dh)); // dhz = d5
      zt.multiply(-1.0, dhg);
      dhg.add_i(1.0);
      dhg.multiply_i(dh); // dhg = d6

      recurrent_acti_func.run_prime_fn(zt, dhz, dhz); // dhz = d7
      acti_func.run_prime_fn(gt, dhg, dhg);           // dhg = d8

      Tensor dhzr = dzrg_t.getSharedDataTensor({unit * 2}, 0); // dhz+dhr

      // Column slices of weight_hh (g part and z/r part), materialized from
      // the strided views.
      Tensor wg_hh;
      wg_hh.copy_with_stride(
        weight_hh.getSharedDataTensor({1, 1, unit, unit}, unit * 2, false));
      Tensor wzr_hh;
      wzr_hh.copy_with_stride(
        weight_hh.getSharedDataTensor({1, 1, unit, unit * 2}, 0, false));

      Tensor temp = Tensor(unit);

      // Candidate-gate backward: mirrors the reset_after branching of the
      // forward pass.
      if (reset_after) {
        prev_hs.dot(wg_hh, temp);
        if (!disable_bias && !integrate_bias) {
          const Tensor bias_hh_g =
            bias_hh.getSharedDataTensor({unit}, 2 * unit);
          temp.add_i(bias_hh_g);
        }
        dhg.multiply(temp, dhr);

        // reset temp: dhg * rt for djdbias_hh_g, dh_nx and djdweight_hh_g
        dhg.multiply(rt, temp);
        if (!disable_bias && !integrate_bias) {
          Tensor djdbias_hh_g =
            djdbias_hh.getSharedDataTensor({unit}, 2 * unit);
          djdbias_hh_g.add_i(temp);
        }
        temp.dot(wg_hh, dh_nx, false, true, 1.0f); // dh_nx = d1 + d14
        djdweight_hh_g.add_i(prev_hs.dot(temp, true, false));
      } else {
        if (!disable_bias && !integrate_bias) {
          Tensor djdbias_hh_g =
            djdbias_hh.getSharedDataTensor({unit}, 2 * unit);
          djdbias_hh_g.add_i(dhg);
        }

        dhg.dot(wg_hh, temp, false, true); // temp = d10
        temp.multiply(prev_hs, dhr);       // dhr = d15s
        temp.multiply_i(rt);               // temp=d14
        dh_nx.add_i(temp);                 //  dh_nx = d1 + d14

        // reset temp : prev_hs * rt for djdweight_hh_g
        rt.multiply(prev_hs, temp);
        temp.dot(dhg, djdweight_hh_g, true, false, 1.0f);
      }

      recurrent_acti_func.run_prime_fn(rt, dhr, dhr); // dhr = d16

      // Bias gradients accumulate the full gate-gradient slice.
      if (!disable_bias) {
        if (integrate_bias) {
          djdbias_h.add_i(dzrg_t); // dzrg_t = d7+d16+d8
        } else {
          djdbias_ih.add_i(dzrg_t); // dzrg_t = d7+d16+d8
          Tensor djdbias_hh_zr = djdbias_hh.getSharedDataTensor({2 * unit}, 0);
          djdbias_hh_zr.add_i(dzrg_t.getSharedDataTensor({2 * unit}, 0));
        }
      }

      djdweight_hh_zr.add_i(prev_hs.dot(dhzr, true, false));
      xs.dot(dzrg_t, djdweight_ih, true, false, 1.0f);
      dhzr.dot(wzr_hh, dh_nx, false, true, 1.0); // dh_nx = d1 + d14 + d12 + d17
    }
  }
  // Interleave the z/r accumulator rows back into djdweight_hh's
  // [unit, NUM_GATE * unit] layout (first 2*unit columns of each row).
  for (unsigned int h = 0; h < unit; ++h) {
    float *data = (float *)djdweight_hh_zr.getAddress(h * unit * 2);
    float *rdata = (float *)djdweight_hh.getAddress(h * unit * NUM_GATE);
    std::copy(data, data + unit * 2, rdata);
  }

  // Interleave the g accumulator rows (last `unit` columns of each row).
  for (unsigned int h = 0; h < unit; ++h) {
    float *data = (float *)djdweight_hh_g.getAddress(h * unit);
    float *rdata =
      (float *)djdweight_hh.getAddress(h * unit * NUM_GATE + unit * 2);
    std::copy(data, data + unit, rdata);
  }
}
     556              : 
     557           24 : void GRULayer::setBatch(RunLayerContext &context, unsigned int batch) {
     558           24 :   context.updateTensor(wt_idx[GRUParams::hidden_state], batch);
     559           24 :   context.updateTensor(wt_idx[GRUParams::zrg], batch);
     560           24 :   context.updateTensor(wt_idx[GRUParams::h_prev], batch);
     561              : 
     562           24 :   if (std::get<props::DropOutRate>(gru_props).get() > epsilon) {
     563            0 :     context.updateTensor(wt_idx[GRUParams::dropout_mask], batch);
     564              :   }
     565           24 : }
     566              : 
     567              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1