LCOV - code coverage report
Current view: top level - nntrainer/layers - grucell.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 92.7 % 259 240
Test Date: 2025-12-14 20:38:17 Functions: 90.9 % 11 10

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
       4              :  *
       5              :  * @file   grucell.cpp
       6              :  * @date   28 Oct 2021
       7              :  * @brief  This is Gated Recurrent Unit Cell Layer Class of Neural Network
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author hyeonseok lee <hs89.lee@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  * h_prev --------d1------->[*]-------d0----->[+]---d0--> h
      13              :  * d_h_prev |  |             |                 | d0      dh
      14              :  *          | d14            | d2        d3    |
      15              :  *          |  |             +-----[1-]------>[*]
      16              :  *          | [*]<---+ d15   |d5               | d6
      17              :  *          |  |     |reset_g| update_gate     | memory_cell
      18              :  *          |  |    [sig]   [sig]            [tanh]
      19              :  *          |  |     |d16    | d7              |d8
      20              :  *          |  |    [+]      [+]              [+]
      21              :  *          |  |    / \d16   |  \ d7          / \ d8
      22              :  *          |  |  Whhr Wxhr Whhz Wxhz       Whhg Wxhg
      23              :  *          |  |  |d17  |d13 |d12 |d11       |d10 | d9
      24              :  *          +- |--+------|---+    |          |    |
      25              :  *             +---------|--------|----------+    |
      26              :  *   xs------------------+--------+---------------+
      27              :  */
      28              : 
      29              : #include <cmath>
      30              : 
      31              : #include <grucell.h>
      32              : #include <lazy_tensor.h>
      33              : #include <nntrainer_error.h>
      34              : #include <nntrainer_log.h>
      35              : #include <node_exporter.h>
      36              : #include <util_func.h>
      37              : 
      38              : #include <layer_context.h>
      39              : 
      40              : namespace nntrainer {
      41              : 
/**
 * @brief  GRUCell single time-step forward computation.
 *
 * Fills @p zrg with the gate activations [update(z) | reset(r) | candidate(g)]
 * and computes the next hidden state as
 *   hidden_state = z * prev_hidden_state + (1 - z) * g
 * following the data-flow diagram at the top of this file.
 *
 * @param unit                number of hidden units
 * @param batch_size          batch size of the input
 * @param disable_bias        if true, skip every bias addition
 * @param integrate_bias      if true use the single bias_h, otherwise add
 *                            bias_ih and bias_hh separately
 * @param reset_after         if true, the reset gate is applied AFTER the
 *                            hidden-to-hidden matmul for the candidate
 *                            (the keras/cuDNN-style "reset_after" variant);
 *                            otherwise it is applied to prev_hidden_state first
 * @param acti_func           activation for the candidate/memory cell
 * @param recurrent_acti_func activation for the update/reset gates
 * @param input               input tensor [batch_size, 1, 1, feature_size]
 * @param prev_hidden_state   previous hidden state [batch_size, 1, 1, unit]
 * @param hidden_state        (out) next hidden state [batch_size, 1, 1, unit]
 * @param weight_ih           input-to-hidden weights, columns ordered z|r|g
 * @param weight_hh           hidden-to-hidden weights, columns ordered z|r|g
 * @param bias_h              integrated bias (valid only when integrate_bias)
 * @param bias_ih             input bias (valid only when !integrate_bias)
 * @param bias_hh             hidden bias (valid only when !integrate_bias)
 * @param zrg                 (out) scratch holding all gate activations,
 *                            [batch_size, 1, 1, NUM_GATE * unit]
 */
static void grucell_forwarding(
  const unsigned int unit, const unsigned int batch_size,
  const bool disable_bias, const bool integrate_bias, const bool reset_after,
  ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
  const Tensor &prev_hidden_state, Tensor &hidden_state,
  const Tensor &weight_ih, const Tensor &weight_hh, const Tensor &bias_h,
  const Tensor &bias_ih, const Tensor &bias_hh, Tensor &zrg) {
  // zrg <- input . weight_ih : pre-activations for all three gates at once.
  input.dot(weight_ih, zrg);

  // Views into zrg: first 2*unit columns are update|reset, last unit columns
  // are the candidate (memory cell). These alias zrg's storage.
  Tensor update_reset_gate =
    zrg.getSharedDataTensor({batch_size, 1, 1, 2 * unit}, 0, false);
  Tensor memory_cell =
    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);

  // Strided column slices of weight_hh are copied into contiguous tensors so
  // that the subsequent dot() calls can run on contiguous memory.
  Tensor weight_hh_update_reset_gate;
  Tensor weight_hh_memory_cell;
  weight_hh_update_reset_gate.copy_with_stride(
    weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false));
  weight_hh_memory_cell.copy_with_stride(
    weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false));

  // Add recurrent contribution for z|r, then the bias variant in use.
  update_reset_gate.add_i_strided(
    prev_hidden_state.dot(weight_hh_update_reset_gate));
  if (!disable_bias) {
    if (integrate_bias) {
      const Tensor bias_h_update_reset_gate =
        bias_h.getSharedDataTensor({2 * unit}, 0);
      update_reset_gate.add_i(bias_h_update_reset_gate);
    } else {
      const Tensor bias_ih_update_reset_gate =
        bias_ih.getSharedDataTensor({2 * unit}, 0);
      update_reset_gate.add_i(bias_ih_update_reset_gate);
      const Tensor bias_hh_update_reset_gate =
        bias_hh.getSharedDataTensor({2 * unit}, 0);
      update_reset_gate.add_i(bias_hh_update_reset_gate);
    }
  }

  // Gate nonlinearity (e.g. sigmoid) applied in place on the z|r slice.
  recurrent_acti_func.run_fn(update_reset_gate, update_reset_gate);

  Tensor update_gate =
    update_reset_gate.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
  Tensor reset_gate = update_reset_gate.getSharedDataTensor(
    {batch_size, 1, 1, unit}, unit, false);

  Tensor temp;
  if (reset_after) {
    // reset_after: g += r * (h_prev . Whh_g [+ bias_hh_g])
    prev_hidden_state.dot(weight_hh_memory_cell, temp);
    if (!disable_bias && !integrate_bias) {
      // bias_hh for the candidate must be inside the reset-gate product in
      // this variant, hence it is added before multiplying by reset_gate.
      const Tensor bias_hh_memory_cell =
        bias_hh.getSharedDataTensor({unit}, 2 * unit);
      temp.add_i(bias_hh_memory_cell);
    }
    temp.multiply_i_strided(reset_gate);
    memory_cell.add_i_strided(temp);
  } else {
    // classic: g += (r * h_prev) . Whh_g [+ bias_hh_g]
    reset_gate.multiply_strided(prev_hidden_state, temp);
    memory_cell.add_i_strided(temp.dot(weight_hh_memory_cell));
    if (!disable_bias && !integrate_bias) {
      const Tensor bias_hh_memory_cell =
        bias_hh.getSharedDataTensor({unit}, 2 * unit);
      memory_cell.add_i(bias_hh_memory_cell);
    }
  }
  // Remaining candidate bias (bias_h or bias_ih) is added outside the
  // reset-gate product in both variants.
  if (!disable_bias) {
    if (integrate_bias) {
      const Tensor bias_h_memory_cell =
        bias_h.getSharedDataTensor({unit}, 2 * unit);
      memory_cell.add_i(bias_h_memory_cell);
    } else {
      const Tensor bias_ih_memory_cell =
        bias_ih.getSharedDataTensor({unit}, 2 * unit);
      memory_cell.add_i(bias_ih_memory_cell);
    }
  }

  // Candidate nonlinearity (e.g. tanh) applied in place.
  acti_func.run_fn(memory_cell, memory_cell);

  // h = z * h_prev + (1 - z) * g ; second multiply accumulates with beta=1.
  update_gate.multiply_strided(prev_hidden_state, hidden_state);
  temp = update_gate.multiply(-1.0).add(1.0);
  memory_cell.multiply_strided(temp, hidden_state, 1.0f);
}
     128              : 
/**
 * @brief  GRUCell single time-step gradient computation.
 *
 * Backpropagates d_hidden_state through the forward graph drawn at the top of
 * this file; the dN labels in the inline comments refer to the edges of that
 * diagram. Accumulates (does not overwrite) into d_weight_ih, d_weight_hh and
 * the bias gradients, writes d_zrg, and produces d_prev_hidden_state.
 *
 * @param unit                number of hidden units
 * @param batch_size          batch size
 * @param disable_bias        if true, no bias gradients are produced
 * @param integrate_bias      selects d_bias_h vs d_bias_ih/d_bias_hh
 * @param reset_after         must match the value used in forwarding
 * @param acti_func           candidate activation (for run_prime_fn)
 * @param recurrent_acti_func gate activation (for run_prime_fn)
 * @param input               forward input [batch_size, 1, 1, feature_size]
 * @param prev_hidden_state   forward previous hidden state
 * @param d_prev_hidden_state (out) gradient w.r.t. previous hidden state
 * @param d_hidden_state      incoming gradient w.r.t. the output hidden state
 * @param d_weight_ih         (accum) gradient of input-to-hidden weights
 * @param weight_hh           forward hidden-to-hidden weights
 * @param d_weight_hh         (accum) gradient of hidden-to-hidden weights
 * @param d_bias_h            (accum) integrated-bias gradient
 * @param d_bias_ih           (accum) input-bias gradient
 * @param bias_hh             forward hidden bias (needed to rebuild the
 *                            reset_after intermediate)
 * @param d_bias_hh           (accum) hidden-bias gradient
 * @param zrg                 forward gate activations (saved in forwarding)
 * @param d_zrg               (out) gradient w.r.t. the gate pre-activations
 */
static void grucell_calcGradient(
  const unsigned int unit, const unsigned int batch_size,
  const bool disable_bias, const bool integrate_bias, const bool reset_after,
  ActiFunc &acti_func, ActiFunc &recurrent_acti_func, const Tensor &input,
  const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
  const Tensor &d_hidden_state, Tensor &d_weight_ih, const Tensor &weight_hh,
  Tensor &d_weight_hh, Tensor &d_bias_h, Tensor &d_bias_ih,
  const Tensor &bias_hh, Tensor &d_bias_hh, const Tensor &zrg, Tensor &d_zrg) {
  // Column-slice views into the weight gradient: z|r block and g block.
  Tensor d_weight_hh_update_reset_gate =
    d_weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false);
  Tensor d_weight_hh_memory_cell =
    d_weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false);

  // Saved forward activations: z, r, g slices of zrg.
  Tensor update_gate =
    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
  Tensor reset_gate =
    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
  Tensor memory_cell =
    zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);

  // Matching gradient slices of d_zrg (aliases of d_zrg storage).
  Tensor d_update_gate =
    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 0, false);
  Tensor d_reset_gate =
    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, unit, false);
  Tensor d_memory_cell =
    d_zrg.getSharedDataTensor({batch_size, 1, 1, unit}, 2 * unit, false);

  // Backprop through h = z * h_prev + (1 - z) * g.
  d_hidden_state.multiply_strided(
    update_gate, d_prev_hidden_state); // d_prev_hidden_state = d1
  d_hidden_state.multiply_strided(prev_hidden_state,
                                  d_update_gate); // d_update_gate = d2
  d_update_gate.add_i_strided(d_hidden_state.multiply_strided(memory_cell),
                              -1.0f); // d_update_gate = d5
  update_gate.multiply(-1.0, d_memory_cell);
  d_memory_cell.add_i(1.0);
  d_memory_cell.multiply_i_strided(d_hidden_state); // d_memory_cell = d6

  // Backprop through the gate/candidate nonlinearities.
  recurrent_acti_func.run_prime_fn(update_gate, d_update_gate,
                                   d_update_gate); // d_update_gate = d7
  acti_func.run_prime_fn(memory_cell, d_memory_cell,
                         d_memory_cell); // d_memory_cell = d8

  Tensor d_update_reset_gate = d_zrg.getSharedDataTensor(
    {batch_size, 1, 1, 2 * unit}, 0, false); // d_update_gate+d_reset_gate

  // Contiguous copies of the strided weight_hh column blocks for dot().
  Tensor weight_hh_memory_cell;
  weight_hh_memory_cell.copy_with_stride(
    weight_hh.getSharedDataTensor({unit, unit}, 2 * unit, false));
  Tensor weight_hh_update_reset_gate;
  weight_hh_update_reset_gate.copy_with_stride(
    weight_hh.getSharedDataTensor({unit, 2 * unit}, 0, false));

  Tensor temp = Tensor(batch_size, 1, 1, unit);
  // d_memory_cell aliases strided d_zrg storage; dot() needs contiguous data.
  Tensor d_memory_cell_contiguous;
  d_memory_cell_contiguous.copy_with_stride(d_memory_cell);

  if (reset_after) {
    // Recompute the forward intermediate h_prev . Whh_g [+ bias_hh_g] to
    // obtain the reset-gate gradient (it was multiplied by r in forwarding).
    prev_hidden_state.dot(weight_hh_memory_cell, temp);
    if (!disable_bias && !integrate_bias) {
      const Tensor bias_hh_memory_cell =
        bias_hh.getSharedDataTensor({unit}, 2 * unit);
      temp.add_i(bias_hh_memory_cell);
    }
    d_memory_cell_contiguous.multiply_strided(
      temp, d_reset_gate); // d_reset_gate = d15

    // reset temp: d_memory_cell_contiguous * reset_gate for
    // d_bias_hh_memory_cell, d_prev_hidden_state and d_weight_hh_memory_cell
    d_memory_cell_contiguous.multiply_strided(reset_gate, temp);
    if (!disable_bias && !integrate_bias) {
      Tensor d_bias_hh_memory_cell =
        d_bias_hh.getSharedDataTensor({unit}, 2 * unit);
      temp.sum(0, d_bias_hh_memory_cell, 1.0, 1.0); // accumulate over batch
    }
    temp.dot(weight_hh_memory_cell, d_prev_hidden_state, false, true,
             1.0); // d_prev_hidden_state = d1 + d14
    d_weight_hh_memory_cell.add_i_strided(
      prev_hidden_state.dot(temp, true, false));
  } else {
    // Classic variant: candidate consumed (r * h_prev) . Whh_g.
    if (!disable_bias && !integrate_bias) {
      Tensor d_bias_hh_memory_cell =
        d_bias_hh.getSharedDataTensor({unit}, 2 * unit);
      d_memory_cell.sum(0, d_bias_hh_memory_cell, 1.0, 1.0);
    }

    d_memory_cell_contiguous.dot(weight_hh_memory_cell, temp, false, true);
    d_reset_gate.multiply_strided(prev_hidden_state, d_reset_gate);
    temp.multiply_strided(prev_hidden_state, d_reset_gate);
    temp.multiply_strided(reset_gate, d_prev_hidden_state, 1.0f);

    // reset temp: reset_gate * prev_hidden_state for and
    // d_weight_hh_memory_cell
    reset_gate.multiply_strided(prev_hidden_state, temp);
    d_weight_hh_memory_cell.add_i_strided(
      temp.dot(d_memory_cell_contiguous, true, false));
  }

  // Backprop the reset-gate nonlinearity.
  recurrent_acti_func.run_prime_fn(reset_gate, d_reset_gate,
                                   d_reset_gate); // d_reset_gate = d16

  // Bias gradients: sum d_zrg over the batch into whichever bias set is used.
  if (!disable_bias) {
    if (integrate_bias) {
      d_zrg.sum(0, d_bias_h, 1.0, 1.0);
    } else {
      d_zrg.sum(0, d_bias_ih, 1.0, 1.0);
      Tensor d_bias_hh_update_reset_gate =
        d_bias_hh.getSharedDataTensor({2 * unit}, 0);
      // z|r part of bias_hh sees the same upstream gradient as bias_ih;
      // the g part was handled above inside the reset_after branch logic.
      d_bias_hh_update_reset_gate.add_i(
        d_zrg.sum(0).getSharedDataTensor({2 * unit}, 0));
    }
  }

  // Weight and input-side gradients plus recurrent path into h_prev.
  Tensor d_update_reset_gate_contiguous;
  d_update_reset_gate_contiguous.copy_with_stride(d_update_reset_gate);
  d_weight_hh_update_reset_gate.add_i_strided(
    prev_hidden_state.dot(d_update_reset_gate_contiguous, true, false));
  input.dot(d_zrg, d_weight_ih, true, false, 1.0f);
  d_update_reset_gate_contiguous.dot(
    weight_hh_update_reset_gate, d_prev_hidden_state, false, true,
    1.0); // d_prev_hidden_state = d1 + d14 + d12 + d17
}
     253              : 
     254              : enum GRUCellParams {
     255              :   weight_ih,
     256              :   weight_hh,
     257              :   bias_h,
     258              :   bias_ih,
     259              :   bias_hh,
     260              :   zrg,
     261              :   dropout_mask
     262              : };
     263              : 
// Todo: handle with strided tensor more efficiently and reduce temporary
// tensors

/**
 * @brief Construct a GRUCell layer with default properties:
 *        hidden-state activation defaults to tanh and recurrent (gate)
 *        activation defaults to sigmoid; actual ActiFunc objects are wired
 *        up later in finalize(). wt_idx entries start as an invalid sentinel
 *        (unsigned max) until finalize() requests the tensors.
 */
GRUCellLayer::GRUCellLayer() :
  LayerImpl(),
  grucell_props(props::Unit(), props::IntegrateBias(), props::ResetAfter(),
                props::HiddenStateActivation() = ActivationType::ACT_TANH,
                props::RecurrentActivation() = ActivationType::ACT_SIGMOID,
                props::DropOutRate()),
  acti_func(ActivationType::ACT_NONE, true),
  recurrent_acti_func(ActivationType::ACT_NONE, true),
  epsilon(1e-3f) {
  // Sentinel: any wt_idx slot read before finalize() is clearly invalid.
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}
     277              : 
     278           39 : void GRUCellLayer::finalize(InitLayerContext &context) {
     279              :   const Initializer weight_initializer =
     280           39 :     std::get<props::WeightInitializer>(*layer_impl_props).get();
     281              :   const Initializer bias_initializer =
     282           39 :     std::get<props::BiasInitializer>(*layer_impl_props).get();
     283              :   const WeightRegularizer weight_regularizer =
     284           39 :     std::get<props::WeightRegularizer>(*layer_impl_props).get();
     285              :   const float weight_regularizer_constant =
     286           39 :     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
     287              :   auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
     288              :   auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
     289              :   const bool disable_bias =
     290           39 :     std::get<props::DisableBias>(*layer_impl_props).get();
     291              : 
     292           39 :   const unsigned int unit = std::get<props::Unit>(grucell_props).get();
     293              :   const bool integrate_bias =
     294           39 :     std::get<props::IntegrateBias>(grucell_props).get();
     295              :   const ActivationType hidden_state_activation_type =
     296           39 :     std::get<props::HiddenStateActivation>(grucell_props).get();
     297              :   const ActivationType recurrent_activation_type =
     298           39 :     std::get<props::RecurrentActivation>(grucell_props).get();
     299           39 :   const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();
     300              : 
     301           39 :   NNTR_THROW_IF(context.getNumInputs() != 2, std::invalid_argument)
     302              :     << "GRUCell layer expects 2 inputs(one for the input and hidden state for "
     303            0 :        "the other) but got " +
     304            0 :          std::to_string(context.getNumInputs()) + " input(s)";
     305              : 
     306              :   // input_dim = [ batch_size, 1, 1, feature_size ]
     307              :   const TensorDim &input_dim = context.getInputDimensions()[0];
     308           39 :   NNTR_THROW_IF(input_dim.channel() != 1 && input_dim.height() != 1,
     309              :                 std::invalid_argument)
     310              :     << "Input must be single time dimension for GRUCell(shape should be "
     311              :        "[batch_size, 1, 1, feature_size]";
     312              :   // input_hidden_state_dim = [ batch_size, 1, 1, unit ]
     313              :   const TensorDim &input_hidden_state_dim =
     314              :     context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
     315           39 :   NNTR_THROW_IF(input_hidden_state_dim.channel() != 1 ||
     316              :                   input_hidden_state_dim.height() != 1,
     317              :                 std::invalid_argument)
     318              :     << "Input hidden state's dimension should be [batch, 1, 1, unit] for "
     319              :        "GRUCell";
     320              : 
     321           39 :   const unsigned int batch_size = input_dim.batch();
     322           39 :   const unsigned int feature_size = input_dim.width();
     323              : 
     324              :   // output_dim = [ batch_size, 1, 1, unit ]
     325           39 :   TensorDim output_dim(batch_size, 1, 1, unit);
     326           39 :   context.setOutputDimensions({output_dim});
     327              : 
     328              :   // weight_initializer can be set seperately. weight_ih initializer,
     329              :   // weight_hh initializer kernel initializer & recurrent_initializer in keras
     330              :   // for now, it is set same way.
     331              : 
     332              :   // - weight_ih ( input to hidden )
     333              :   // weight_ih_dim : [ 1, 1, feature_size, NUMGATE * unit ] -> z, r, g
     334           39 :   TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
     335           39 :   wt_idx[GRUCellParams::weight_ih] = context.requestWeight(
     336              :     weight_ih_dim, weight_initializer, weight_regularizer,
     337              :     weight_regularizer_constant, weight_decay, "weight_ih", true);
     338              :   // - weight_hh ( hidden to hidden )
     339              :   // weight_hh_dim : [ 1, 1, unit, NUM_GATE * unit ] -> z, r, g
     340           39 :   TensorDim weight_hh_dim({unit, NUM_GATE * unit});
     341           78 :   wt_idx[GRUCellParams::weight_hh] = context.requestWeight(
     342              :     weight_hh_dim, weight_initializer, weight_regularizer,
     343              :     weight_regularizer_constant, weight_decay, "weight_hh", true);
     344           39 :   if (!disable_bias) {
     345           39 :     if (integrate_bias) {
     346              :       // - bias_h ( input bias, hidden bias are integrate to 1 bias )
     347              :       // bias_h_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
     348            2 :       TensorDim bias_h_dim({NUM_GATE * unit});
     349            2 :       wt_idx[GRUCellParams::bias_h] = context.requestWeight(
     350              :         bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
     351              :         "bias_h", true);
     352              :     } else {
     353              :       // - bias_ih ( input bias )
     354              :       // bias_ih_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
     355           37 :       TensorDim bias_ih_dim({NUM_GATE * unit});
     356           37 :       wt_idx[GRUCellParams::bias_ih] = context.requestWeight(
     357              :         bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     358              :         bias_decay, "bias_ih", true);
     359              :       // - bias_hh ( hidden bias )
     360              :       // bias_hh_dim : [ 1, 1, 1, NUM_GATE * unit ] -> z, r, g
     361           37 :       TensorDim bias_hh_dim({NUM_GATE * unit});
     362           74 :       wt_idx[GRUCellParams::bias_hh] = context.requestWeight(
     363              :         bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     364              :         bias_decay, "bias_hh", true);
     365              :     }
     366              :   }
     367              : 
     368              :   // zrg_dim = [ batch_size, 1, 1, NUM_GATE * unit ]
     369           39 :   TensorDim zrg_dim(batch_size, 1, 1, NUM_GATE * unit);
     370           39 :   wt_idx[GRUCellParams::zrg] =
     371           39 :     context.requestTensor(zrg_dim, "zrg", Initializer::NONE, true,
     372              :                           TensorLifespan::ITERATION_LIFESPAN);
     373              : 
     374           39 :   if (dropout_rate > epsilon) {
     375              :     // dropout_mask_dim = [ batch_size, 1, 1, unit ]
     376            0 :     TensorDim dropout_mask_dim(batch_size, 1, 1, unit);
     377            0 :     wt_idx[GRUCellParams::dropout_mask] =
     378            0 :       context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
     379              :                             false, TensorLifespan::ITERATION_LIFESPAN);
     380              :   }
     381              : 
     382           39 :   acti_func.setActiFunc(hidden_state_activation_type);
     383           39 :   recurrent_acti_func.setActiFunc(recurrent_activation_type);
     384           39 : }
     385              : 
     386          244 : void GRUCellLayer::setProperty(const std::vector<std::string> &values) {
     387          244 :   auto remain_props = loadProperties(values, grucell_props);
     388          243 :   LayerImpl::setProperty(remain_props);
     389          243 : }
     390              : 
/**
 * @brief Export this layer's configuration: base LayerImpl properties first,
 *        then the GRUCell-specific properties.
 *
 * @param exporter exporter that collects the serialized result
 * @param method   export method/format to use
 */
void GRUCellLayer::exportTo(Exporter &exporter,
                            const ml::train::ExportMethods &method) const {
  LayerImpl::exportTo(exporter, method);
  exporter.saveResult(grucell_props, method, this);
}
     396              : 
/**
 * @brief Forward one GRUCell step: gathers properties, inputs, weights and
 *        the zrg scratch from the context, delegates the math to
 *        grucell_forwarding(), then optionally applies inverted dropout to
 *        the output hidden state during training.
 *
 * @param context  run context providing inputs/outputs/weights/tensors
 * @param training true while training; dropout mask is only applied then
 */
void GRUCellLayer::forwarding(RunLayerContext &context, bool training) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(grucell_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(grucell_props).get();
  const bool reset_after = std::get<props::ResetAfter>(grucell_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();

  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
  const Tensor &prev_hidden_state =
    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
  // hidden_state == output in grucell
  Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT);

  const unsigned int batch_size = input.getDim().batch();

  const Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
  const Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
  // `empty` backs the unused bias references so grucell_forwarding can take
  // all three bias parameters unconditionally; it outlives the call.
  Tensor empty;
  const Tensor &bias_h = !disable_bias && integrate_bias
                           ? context.getWeight(wt_idx[GRUCellParams::bias_h])
                           : empty;
  const Tensor &bias_ih = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[GRUCellParams::bias_ih])
                            : empty;
  const Tensor &bias_hh = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
                            : empty;

  Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);

  grucell_forwarding(unit, batch_size, disable_bias, integrate_bias,
                     reset_after, acti_func, recurrent_acti_func, input,
                     prev_hidden_state, hidden_state, weight_ih, weight_hh,
                     bias_h, bias_ih, bias_hh, zrg);

  // Dropout is applied to the output hidden state only while training; the
  // mask is stored so calcGradient can reuse it for the backward pass.
  if (dropout_rate > epsilon && training) {
    Tensor mask = context.getTensor(wt_idx[GRUCellParams::dropout_mask]);
    mask.dropout_mask(dropout_rate);
    hidden_state.multiply_i(mask);
  }
}
     441              : 
     442           27 : void GRUCellLayer::calcDerivative(RunLayerContext &context) {
     443              :   Tensor &outgoing_derivative =
     444           27 :     context.getOutgoingDerivative(INOUT_INDEX::INPUT);
     445           27 :   const Tensor &weight_ih = context.getWeight(wt_idx[GRUCellParams::weight_ih]);
     446           27 :   const Tensor &d_zrg = context.getTensorGrad(wt_idx[GRUCellParams::zrg]);
     447              : 
     448           27 :   d_zrg.dot(weight_ih, outgoing_derivative, false, true);
     449           27 : }
     450              : 
/**
 * @brief Compute weight/bias gradients and the derivative w.r.t. the previous
 *        hidden state for one GRUCell step.
 *
 * Zeroes each gradient tensor on its first access in the iteration (gradients
 * are accumulated across timesteps otherwise), applies the stored dropout
 * mask to the incoming derivative when dropout is enabled, and delegates the
 * math to grucell_calcGradient().
 *
 * @param context run context providing inputs, derivatives, weights and
 *                their gradients
 */
void GRUCellLayer::calcGradient(RunLayerContext &context) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(grucell_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(grucell_props).get();
  const bool reset_after = std::get<props::ResetAfter>(grucell_props).get();
  const float dropout_rate = std::get<props::DropOutRate>(grucell_props).get();

  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
  const Tensor &prev_hidden_state =
    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
  Tensor &d_prev_hidden_state =
    context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
  const Tensor &incoming_derivative =
    context.getIncomingDerivative(INOUT_INDEX::OUTPUT);

  const unsigned int batch_size = input.getDim().batch();

  Tensor &d_weight_ih = context.getWeightGrad(wt_idx[GRUCellParams::weight_ih]);
  const Tensor &weight_hh = context.getWeight(wt_idx[GRUCellParams::weight_hh]);
  Tensor &d_weight_hh = context.getWeightGrad(wt_idx[GRUCellParams::weight_hh]);

  // `empty` backs unused bias references (same pattern as forwarding()).
  Tensor empty;
  Tensor &d_bias_h = !disable_bias && integrate_bias
                       ? context.getWeightGrad(wt_idx[GRUCellParams::bias_h])
                       : empty;
  Tensor &d_bias_ih = !disable_bias && !integrate_bias
                        ? context.getWeightGrad(wt_idx[GRUCellParams::bias_ih])
                        : empty;
  const Tensor &bias_hh = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[GRUCellParams::bias_hh])
                            : empty;
  Tensor &d_bias_hh = !disable_bias && !integrate_bias
                        ? context.getWeightGrad(wt_idx[GRUCellParams::bias_hh])
                        : empty;

  const Tensor &zrg = context.getTensor(wt_idx[GRUCellParams::zrg]);
  Tensor &d_zrg = context.getTensorGrad(wt_idx[GRUCellParams::zrg]);

  // Gradients accumulate across accesses within an iteration; clear each one
  // only the first time it is touched.
  if (context.isGradientFirstAccess(wt_idx[GRUCellParams::weight_ih])) {
    d_weight_ih.setZero();
  }
  if (context.isGradientFirstAccess(wt_idx[GRUCellParams::weight_hh])) {
    d_weight_hh.setZero();
  }
  if (!disable_bias) {
    if (integrate_bias) {
      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_h])) {
        d_bias_h.setZero();
      }
    } else {
      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_ih])) {
        d_bias_ih.setZero();
      }
      if (context.isGradientFirstAccess(wt_idx[GRUCellParams::bias_hh])) {
        d_bias_hh.setZero();
      }
    }
  }

  // When dropout was applied in forwarding, the incoming derivative must be
  // masked with the same stored mask before backpropagating.
  Tensor incoming_derivative_masked(batch_size, 1, 1, unit);
  if (dropout_rate > epsilon) {
    incoming_derivative.multiply_strided(
      context.getTensor(wt_idx[GRUCellParams::dropout_mask]),
      incoming_derivative_masked);
  }

  grucell_calcGradient(
    unit, batch_size, disable_bias, integrate_bias, reset_after, acti_func,
    recurrent_acti_func, input, prev_hidden_state, d_prev_hidden_state,
    dropout_rate > epsilon ? incoming_derivative_masked : incoming_derivative,
    d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih, bias_hh,
    d_bias_hh, zrg, d_zrg);
}
     527              : 
     528           32 : void GRUCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
     529           32 :   const float dropout_rate = std::get<props::DropOutRate>(grucell_props);
     530              : 
     531           32 :   context.updateTensor(wt_idx[GRUCellParams::zrg], batch);
     532              : 
     533           32 :   if (dropout_rate > epsilon) {
     534            0 :     context.updateTensor(wt_idx[GRUCellParams::dropout_mask], batch);
     535              :   }
     536           32 : }
     537              : 
     538              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1