LCOV - code coverage report
Current view: top level - nntrainer/layers - rnn.cpp (source / functions)
Test:      coverage_filtered.info
Test Date: 2025-12-14 20:38:17

              Coverage    Total    Hit
  Lines:        89.5 %      181    162
  Functions:    88.9 %        9      8

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Jijoong Moon <jijoong.moon@samsung.com>
       4              :  *
       5              :  * @file   rnn.cpp
       6              :  * @date   17 March 2021
       7              :  * @brief  This is the Recurrent Layer class of the Neural Network
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : 
      14              : #include <cmath>
      15              : #include <layer_context.h>
      16              : #include <nntrainer_error.h>
      17              : #include <nntrainer_log.h>
      18              : #include <node_exporter.h>
      19              : #include <rnn.h>
      20              : #include <util_func.h>
      21              : 
      22              : namespace nntrainer {
      23              : 
      24              : static constexpr size_t SINGLE_INOUT_IDX = 0;
      25              : 
      26              : // - weight_ih ( input to hidden )
      27              : // - weight_hh ( hidden to hidden )
      28              : // - bias_h ( input bias, hidden bias )
      29              : // - bias_ih ( input bias )
      30              : // - bias_hh ( hidden bias )
      31              : enum RNNParams {
      32              :   weight_ih,
      33              :   weight_hh,
      34              :   bias_h,
      35              :   bias_ih,
      36              :   bias_hh,
      37              :   hidden_state,
      38              :   dropout_mask
      39              : };
      40              : 
      41           47 : RNNLayer::RNNLayer() :
      42              :   LayerImpl(),
      43           94 :   rnn_props(
      44          141 :     props::Unit(), props::HiddenStateActivation() = ActivationType::ACT_TANH,
      45           94 :     props::ReturnSequences(), props::DropOutRate(), props::IntegrateBias()),
      46           47 :   acti_func(ActivationType::ACT_NONE, true),
      47           94 :   epsilon(1e-3f) {
      48              :   wt_idx.fill(std::numeric_limits<unsigned>::max());
      49           47 : }
      50              : 
      51           33 : void RNNLayer::finalize(InitLayerContext &context) {
      52              :   const nntrainer::WeightRegularizer weight_regularizer =
      53           33 :     std::get<props::WeightRegularizer>(*layer_impl_props);
      54              :   const float weight_regularizer_constant =
      55           33 :     std::get<props::WeightRegularizerConstant>(*layer_impl_props);
      56              :   const Initializer weight_initializer =
      57           33 :     std::get<props::WeightInitializer>(*layer_impl_props);
      58              :   const Initializer bias_initializer =
      59           33 :     std::get<props::BiasInitializer>(*layer_impl_props);
      60              :   auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
      61              :   auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
      62              :   const bool disable_bias =
      63           33 :     std::get<props::DisableBias>(*layer_impl_props).get();
      64              : 
      65           33 :   const unsigned int unit = std::get<props::Unit>(rnn_props).get();
      66              :   const nntrainer::ActivationType hidden_state_activation_type =
      67           33 :     std::get<props::HiddenStateActivation>(rnn_props).get();
      68              :   const bool return_sequences =
      69           33 :     std::get<props::ReturnSequences>(rnn_props).get();
      70           33 :   const float dropout_rate = std::get<props::DropOutRate>(rnn_props).get();
      71           33 :   const bool integrate_bias = std::get<props::IntegrateBias>(rnn_props).get();
      72              : 
      73           33 :   NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
      74              :     << "RNN layer takes only one input";
      75              : 
      76              :   // input_dim = [ batch, 1, time_iteration, feature_size ]
      77              :   const TensorDim &input_dim = context.getInputDimensions()[SINGLE_INOUT_IDX];
      78           33 :   const unsigned int batch_size = input_dim.batch();
      79           33 :   const unsigned int max_timestep = input_dim.height();
      80           33 :   NNTR_THROW_IF(max_timestep < 1, std::runtime_error)
      81              :     << "max timestep must be greator than 0 in rnn layer.";
      82           33 :   const unsigned int feature_size = input_dim.width();
      83              : 
      84              :   // output_dim = [ batch, 1, (return_sequences ? time_iteration : 1), unit ]
      85              :   const TensorDim output_dim(batch_size, 1, return_sequences ? max_timestep : 1,
      86           49 :                              unit);
      87              : 
      88           33 :   context.setOutputDimensions({output_dim});
      89              : 
      90              :   // The weight initializer could be set separately for weight_ih and
      91              :   // weight_hh (kernel_initializer & recurrent_initializer in Keras);
      92              :   // for now, both are set the same way.
      93              : 
      94              :   // weight_ih_dim : [ 1, 1, feature_size, unit ]
      95           33 :   const TensorDim weight_ih_dim({feature_size, unit});
      96           33 :   wt_idx[RNNParams::weight_ih] = context.requestWeight(
      97              :     weight_ih_dim, weight_initializer, weight_regularizer,
      98              :     weight_regularizer_constant, weight_decay, "weight_ih", true);
      99              :   // weight_hh_dim : [ 1, 1, unit, unit ]
     100           33 :   const TensorDim weight_hh_dim({unit, unit});
     101           66 :   wt_idx[RNNParams::weight_hh] = context.requestWeight(
     102              :     weight_hh_dim, weight_initializer, weight_regularizer,
     103              :     weight_regularizer_constant, weight_decay, "weight_hh", true);
     104           33 :   if (!disable_bias) {
     105           33 :     if (integrate_bias) {
     106              :       // bias_h_dim : [ 1, 1, 1, unit ]
     107           29 :       const TensorDim bias_h_dim({unit});
     108           29 :       wt_idx[RNNParams::bias_h] = context.requestWeight(
     109              :         bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
     110              :         "bias_h", true);
     111              :     } else {
     112              :       // bias_ih_dim : [ 1, 1, 1, unit ]
     113            4 :       const TensorDim bias_ih_dim({unit});
     114            4 :       wt_idx[RNNParams::bias_ih] = context.requestWeight(
     115              :         bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     116              :         bias_decay, "bias_ih", true);
     117              :       // bias_hh_dim : [ 1, 1, 1, unit ]
     118            4 :       const TensorDim bias_hh_dim({unit});
     119            8 :       wt_idx[RNNParams::bias_hh] = context.requestWeight(
     120              :         bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     121              :         bias_decay, "bias_hh", true);
     122              :     }
     123              :   }
     124              : 
     125              :   // We would not need this tensor if we reused net_hidden[0], but then the
     126              :   // unit test fails because the data is modified during gradient calculation.
     127              :   // TODO : We could control this with something like a #define to save memory
     128              : 
     129              :   // hidden_state_dim : [ batch_size, 1, max_timestep, unit ]
     130           33 :   const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit);
     131           33 :   wt_idx[RNNParams::hidden_state] =
     132           33 :     context.requestTensor(hidden_state_dim, "hidden_state", Initializer::NONE,
     133              :                           true, TensorLifespan::ITERATION_LIFESPAN);
     134              : 
     135           33 :   if (dropout_rate > epsilon) {
     136              :     // dropout_mask_dim = [ batch, 1, (return_sequences ? time_iteration : 1),
     137              :     // unit ]
     138              :     const TensorDim dropout_mask_dim(batch_size, 1,
     139            0 :                                      return_sequences ? max_timestep : 1, unit);
     140            0 :     wt_idx[RNNParams::dropout_mask] =
     141            0 :       context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
     142              :                             false, TensorLifespan::ITERATION_LIFESPAN);
     143              :   }
     144              : 
     145           33 :   acti_func.setActiFunc(hidden_state_activation_type);
     146              : 
     147           33 :   if (!acti_func.supportInPlace())
     148              :     throw exception::not_supported(
     149            0 :       "Out of place activation functions not supported");
     150           33 : }
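// Worked example of the shapes requested above (the numbers are illustrative):
// with input_dim = [ batch=3, 1, time=5, feature=4 ] and unit=16, finalize()
// requests
//   weight_ih    : [ 1, 1, 4, 16 ]
//   weight_hh    : [ 1, 1, 16, 16 ]
//   bias_h       : [ 1, 1, 1, 16 ]   (bias_ih / bias_hh likewise when not integrated)
//   hidden_state : [ 3, 1, 5, 16 ]
//   output       : [ 3, 1, 5, 16 ] if return_sequences, else [ 3, 1, 1, 16 ]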
     151              : 
     152          188 : void RNNLayer::setProperty(const std::vector<std::string> &values) {
     153              :   const std::vector<std::string> &remain_props =
     154          188 :     loadProperties(values, rnn_props);
     155          187 :   LayerImpl::setProperty(remain_props);
     156          187 : }
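// Usage sketch for the property path above. The keys are assumptions derived
// from the property names in rnn_props and are shown for illustration only:
//
//   rnn_layer.setProperty({"unit=16", "hidden_state_activation=tanh",
//                          "return_sequences=true", "dropout=0.1"});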
     157              : 
     158           14 : void RNNLayer::exportTo(Exporter &exporter,
     159              :                         const ml::train::ExportMethods &method) const {
     160           14 :   LayerImpl::exportTo(exporter, method);
     161           14 :   exporter.saveResult(rnn_props, method, this);
     162           14 : }
     163              : 
     164          110 : void RNNLayer::forwarding(RunLayerContext &context, bool training) {
     165              :   const bool disable_bias =
     166          110 :     std::get<props::DisableBias>(*layer_impl_props).get();
     167              : 
     168          110 :   const unsigned int unit = std::get<props::Unit>(rnn_props).get();
     169              :   const bool return_sequences =
     170          110 :     std::get<props::ReturnSequences>(rnn_props).get();
     171          110 :   const float dropout_rate = std::get<props::DropOutRate>(rnn_props).get();
     172          110 :   const bool integrate_bias = std::get<props::IntegrateBias>(rnn_props).get();
     173              : 
     174          110 :   const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
     175          110 :   const TensorDim &input_dim = input.getDim();
     176          110 :   const unsigned int batch_size = input_dim.batch();
     177          110 :   const unsigned int max_timestep = input_dim.height();
     178          110 :   const unsigned int feature_size = input_dim.width();
     179          110 :   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
     180              : 
     181          110 :   const Tensor &weight_ih = context.getWeight(wt_idx[RNNParams::weight_ih]);
     182          110 :   const Tensor &weight_hh = context.getWeight(wt_idx[RNNParams::weight_hh]);
     183          110 :   Tensor empty;
     184          110 :   Tensor &bias_h = !disable_bias && integrate_bias
     185          110 :                      ? context.getWeight(wt_idx[RNNParams::bias_h])
     186              :                      : empty;
     187              :   Tensor &bias_ih = !disable_bias && !integrate_bias
     188          110 :                       ? context.getWeight(wt_idx[RNNParams::bias_ih])
     189              :                       : empty;
     190              :   Tensor &bias_hh = !disable_bias && !integrate_bias
     191          110 :                       ? context.getWeight(wt_idx[RNNParams::bias_hh])
     192              :                       : empty;
     193              : 
     194          110 :   Tensor &hidden_state = context.getTensor(wt_idx[RNNParams::hidden_state]);
     195              : 
     196              :   // TODO: swap batch and timestep index with transpose
     197          275 :   for (unsigned int batch = 0; batch < batch_size; ++batch) {
     198          165 :     Tensor input_slice = input.getBatchSlice(batch, 1);
     199          165 :     Tensor hidden_state_slice = hidden_state.getBatchSlice(batch, 1);
     200              : 
     201          465 :     for (unsigned int timestep = 0; timestep < max_timestep; ++timestep) {
     202              :       Tensor in = input_slice.getSharedDataTensor({feature_size},
     203          300 :                                                   timestep * feature_size);
     204              :       Tensor hs =
     205          300 :         hidden_state_slice.getSharedDataTensor({unit}, timestep * unit);
     206              : 
     207          300 :       in.dot(weight_ih, hs);
     208          300 :       if (!disable_bias) {
     209          300 :         if (integrate_bias) {
     210          300 :           hs.add_i(bias_h);
     211              :         } else {
     212            0 :           hs.add_i(bias_ih);
     213            0 :           hs.add_i(bias_hh);
     214              :         }
     215              :       }
     216              : 
     217          300 :       if (timestep) {
     218              :         Tensor prev_hs =
     219          135 :           hidden_state_slice.getSharedDataTensor({unit}, (timestep - 1) * unit);
     220          135 :         prev_hs.dot(weight_hh, hs, false, false, 1.0);
     221          135 :       }
     222              : 
     223              :       // In-place calculation for activation
     224              :       acti_func.run_fn(hs, hs);
     225              : 
     226          300 :       if (dropout_rate > epsilon && training) {
     227            0 :         Tensor dropout_mask = context.getTensor(wt_idx[RNNParams::dropout_mask])
     228            0 :                                 .getBatchSlice(batch, 1);
     229              :         Tensor dropout_mask_t =
     230            0 :           dropout_mask.getSharedDataTensor({unit}, timestep * unit);
     231            0 :         dropout_mask_t.dropout_mask(dropout_rate);
     232            0 :         hs.multiply_i(dropout_mask_t);
     233            0 :       }
     234          300 :     }
     235          165 :   }
     236              : 
     237          110 :   if (!return_sequences) {
     238          125 :     for (unsigned int batch = 0; batch < input_dim.batch(); ++batch) {
     239           75 :       float *hidden_state_data = hidden_state.getAddress<float>(
     240           75 :         batch * unit * max_timestep + (max_timestep - 1) * unit);
     241              :       float *output_data = output.getAddress<float>(batch * unit);
     242           75 :       std::copy(hidden_state_data, hidden_state_data + unit, output_data);
     243              :     }
     244              :   } else {
     245           60 :     output.copy(hidden_state);
     246              :   }
     247          110 : }
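// A minimal, self-contained sketch of the per-timestep recurrence evaluated in
// forwarding() above, written against plain float buffers rather than nntrainer
// Tensors (the names, the row-major layout, and the integrated bias here are
// illustrative assumptions, not library API; dropout is omitted):
//
//   #include <cmath>
//   #include <vector>
//
//   // x: [T x F] input, w_ih: [F x U], w_hh: [U x U], b: [U], h: [T x U]
//   void rnn_forward(const std::vector<float> &x, const std::vector<float> &w_ih,
//                    const std::vector<float> &w_hh, const std::vector<float> &b,
//                    unsigned T, unsigned F, unsigned U, std::vector<float> &h) {
//     for (unsigned t = 0; t < T; ++t) {
//       for (unsigned u = 0; u < U; ++u) {
//         float acc = b[u];                      // integrated bias
//         for (unsigned f = 0; f < F; ++f)       // x_t . w_ih
//           acc += x[t * F + f] * w_ih[f * U + u];
//         if (t)                                 // h_{t-1} . w_hh
//           for (unsigned k = 0; k < U; ++k)
//             acc += h[(t - 1) * U + k] * w_hh[k * U + u];
//         h[t * U + u] = std::tanh(acc);         // default ACT_TANH activation
//       }
//     }
//   }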
     248              : 
     249           78 : void RNNLayer::calcDerivative(RunLayerContext &context) {
     250              :   const Tensor &hidden_state_derivative =
     251           78 :     context.getTensorGrad(wt_idx[RNNParams::hidden_state]);
     252           78 :   const Tensor &weight = context.getWeight(wt_idx[RNNParams::weight_ih]);
     253           78 :   Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
     254              : 
     255           78 :   hidden_state_derivative.dot(weight, outgoing_derivative, false, true);
     256           78 : }
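// In equation form, calcDerivative() above computes, for all timesteps at once,
//   dL/dx_t = dL/dh_t . weight_ih^T
// (the final `true` argument transposes weight_ih in the dot product).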
     257              : 
     258           78 : void RNNLayer::calcGradient(RunLayerContext &context) {
     259              :   const bool disable_bias =
     260           78 :     std::get<props::DisableBias>(*layer_impl_props).get();
     261              : 
     262           78 :   const unsigned int unit = std::get<props::Unit>(rnn_props).get();
     263              :   const bool return_sequences =
     264           78 :     std::get<props::ReturnSequences>(rnn_props).get();
     265           78 :   const float dropout_rate = std::get<props::DropOutRate>(rnn_props).get();
     266           78 :   const bool integrate_bias = std::get<props::IntegrateBias>(rnn_props).get();
     267              : 
     268           78 :   Tensor &input = context.getInput(SINGLE_INOUT_IDX);
     269           78 :   const TensorDim &input_dim = input.getDim();
     270           78 :   const unsigned int batch_size = input_dim.batch();
     271           78 :   const unsigned int max_timestep = input_dim.height();
     272              :   const Tensor &incoming_derivative =
     273           78 :     context.getIncomingDerivative(SINGLE_INOUT_IDX);
     274              : 
     275           78 :   Tensor &djdweight_ih = context.getWeightGrad(wt_idx[RNNParams::weight_ih]);
     276           78 :   Tensor &weight_hh = context.getWeight(wt_idx[RNNParams::weight_hh]);
     277           78 :   Tensor &djdweight_hh = context.getWeightGrad(wt_idx[RNNParams::weight_hh]);
     278           78 :   Tensor empty;
     279           78 :   Tensor &djdbias_h = !disable_bias && integrate_bias
     280           78 :                         ? context.getWeightGrad(wt_idx[RNNParams::bias_h])
     281              :                         : empty;
     282              :   Tensor &djdbias_ih = !disable_bias && !integrate_bias
     283           78 :                          ? context.getWeightGrad(wt_idx[RNNParams::bias_ih])
     284              :                          : empty;
     285              :   Tensor &djdbias_hh = !disable_bias && !integrate_bias
     286           78 :                          ? context.getWeightGrad(wt_idx[RNNParams::bias_hh])
     287              :                          : empty;
     288              : 
     289              :   Tensor &hidden_state_derivative =
     290           78 :     context.getTensorGrad(wt_idx[RNNParams::hidden_state]);
     291              : 
     292           78 :   djdweight_ih.setZero();
     293           78 :   djdweight_hh.setZero();
     294           78 :   if (!disable_bias) {
     295           78 :     if (integrate_bias) {
     296           78 :       djdbias_h.setZero();
     297              :     } else {
     298            0 :       djdbias_ih.setZero();
     299            0 :       djdbias_hh.setZero();
     300              :     }
     301              :   }
     302           78 :   hidden_state_derivative.setZero();
     303              : 
     304           78 :   if (!return_sequences) {
     305           81 :     for (unsigned int batch = 0; batch < batch_size; ++batch) {
     306              :       float *hidden_state_derivative_data =
     307           47 :         hidden_state_derivative.getAddress<float>(batch * unit * max_timestep +
     308           47 :                                                   (max_timestep - 1) * unit);
     309              :       const float *incoming_derivative_data =
     310              :         (float *)incoming_derivative.getAddress<float>(batch * unit);
     311           47 :       std::copy(incoming_derivative_data, incoming_derivative_data + unit,
     312              :                 hidden_state_derivative_data);
     313              :     }
     314              :   } else {
     315           44 :     hidden_state_derivative.copy(incoming_derivative);
     316              :   }
     317              : 
     318           78 :   if (dropout_rate > epsilon) {
     319            0 :     hidden_state_derivative.multiply_i(
     320            0 :       context.getTensor(wt_idx[RNNParams::dropout_mask]));
     321              :   }
     322              : 
     323           78 :   Tensor &hidden_state = context.getTensor(wt_idx[RNNParams::hidden_state]);
     324              : 
     325          191 :   for (unsigned int batch = 0; batch < batch_size; ++batch) {
     326          113 :     Tensor deriv_t = hidden_state_derivative.getBatchSlice(batch, 1);
     327          113 :     Tensor input_t = input.getBatchSlice(batch, 1);
     328          113 :     Tensor hidden_state_t = hidden_state.getBatchSlice(batch, 1);
     329              : 
     330          325 :     for (unsigned int timestep = max_timestep; timestep-- > 0;) {
     331              :       Tensor dh = deriv_t.getSharedDataTensor(
     332          212 :         TensorDim(1, 1, 1, deriv_t.width()), timestep * deriv_t.width());
     333              :       Tensor xs = input_t.getSharedDataTensor(
     334          212 :         TensorDim(1, 1, 1, input_t.width()), timestep * input_t.width());
     335              :       Tensor hs = hidden_state_t.getSharedDataTensor(
     336              :         TensorDim(1, 1, 1, hidden_state_t.width()),
     337          212 :         timestep * hidden_state_t.width());
     338              : 
     339          212 :       acti_func.run_prime_fn(hs, dh, dh);
     340          212 :       if (!disable_bias) {
     341          212 :         if (integrate_bias) {
     342          212 :           djdbias_h.add_i(dh);
     343              :         } else {
     344            0 :           djdbias_ih.add_i(dh);
     345            0 :           djdbias_hh.add_i(dh);
     346              :         }
     347              :       }
     348          212 :       xs.dot(dh, djdweight_ih, true, false, 1.0);
     349              : 
     350          212 :       if (timestep) {
     351              :         Tensor prev_hs = hidden_state_t.getSharedDataTensor(
     352              :           TensorDim(1, 1, 1, hidden_state_t.width()),
     353           99 :           (timestep - 1) * hidden_state_t.width());
     354              :         Tensor dh_t_1 =
     355              :           deriv_t.getSharedDataTensor(TensorDim(1, 1, 1, deriv_t.width()),
     356           99 :                                       (timestep - 1) * deriv_t.width());
     357           99 :         prev_hs.dot(dh, djdweight_hh, true, false, 1.0);
     358           99 :         dh.dot(weight_hh, dh_t_1, false, true, 1.0);
     359           99 :       }
     360          212 :     }
     361          113 :   }
     362           78 : }
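// In equation form, the reverse-time loop above accumulates, per timestep
// (dh_t here is already scaled by the activation derivative via run_prime_fn):
//   dJ/dweight_ih += x_t^T . dh_t
//   dJ/dweight_hh += h_{t-1}^T . dh_t
//   dJ/dbias      += dh_t
//   dh_{t-1}      += dh_t . weight_hh^T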
     363              : 
     364           12 : void RNNLayer::setBatch(RunLayerContext &context, unsigned int batch) {
     365           12 :   context.updateTensor(wt_idx[RNNParams::hidden_state], batch);
     366              : 
     367           12 :   if (std::get<props::DropOutRate>(rnn_props).get() > epsilon) {
     368            0 :     context.updateTensor(wt_idx[RNNParams::dropout_mask], batch);
     369              :   }
     370           12 : }
     371              : 
     372              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1