LCOV - code coverage report
Current view: top level - nntrainer/layers - lstmcell.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 90.1 % 142 128
Test Date: 2025-12-14 20:38:17 Functions: 88.9 % 9 8

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   lstmcell.cpp
       6              :  * @date   17 March 2021
       7              :  * @brief  This is LSTMCell Layer Class of Neural Network
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : 
      14              : #include <layer_context.h>
      15              : #include <lstmcell.h>
      16              : #include <nntrainer_error.h>
      17              : #include <nntrainer_log.h>
      18              : #include <node_exporter.h>
      19              : 
      20              : namespace nntrainer {
      21              : 
      22              : enum LSTMCellParams {
      23              :   weight_ih,
      24              :   weight_hh,
      25              :   bias_h,
      26              :   bias_ih,
      27              :   bias_hh,
      28              :   ifgo,
      29              :   dropout_mask
      30              : };
      31              : 
/**
 * @brief Construct an LSTMCell layer.
 *
 * Registers the single cell-specific property (DropOutRate) and marks every
 * wt_idx slot as "not requested yet" (max unsigned acts as the sentinel).
 */
LSTMCellLayer::LSTMCellLayer() : lstmcell_props(props::DropOutRate()) {
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}
      35              : 
/**
 * @brief Finalize the LSTMCell layer.
 *
 * Validates the three input dimensions (input, previous hidden state,
 * previous cell state), requests the two outputs (hidden state, cell state),
 * requests the trainable weights/biases and the intermediate ifgo tensor,
 * and configures the activation functions for the selected data type.
 *
 * @param context initialization-time layer context used to request
 *                outputs, weights and tensors
 * @throws std::invalid_argument if the unit property is missing, the number
 *         of inputs is not 3, any input shape is not [batch, 1, 1, *], or
 *         (without ENABLE_FP16) an FP16 activation data type is requested
 */
void LSTMCellLayer::finalize(InitLayerContext &context) {
  const Initializer weight_initializer =
    std::get<props::WeightInitializer>(*layer_impl_props).get();
  const Initializer bias_initializer =
    std::get<props::BiasInitializer>(*layer_impl_props).get();
  const WeightRegularizer weight_regularizer =
    std::get<props::WeightRegularizer>(*layer_impl_props).get();
  const float weight_regularizer_constant =
    std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  // unit (hidden size) is mandatory; there is no sensible default
  NNTR_THROW_IF(std::get<props::Unit>(lstmcore_props).empty(),
                std::invalid_argument)
    << "unit property missing for lstmcell layer";
  const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(lstmcore_props).get();
  const ActivationType hidden_state_activation_type =
    std::get<props::HiddenStateActivation>(lstmcore_props).get();
  const ActivationType recurrent_activation_type =
    std::get<props::RecurrentActivation>(lstmcore_props).get();

  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();

  NNTR_THROW_IF(context.getNumInputs() != 3, std::invalid_argument)
    << "LSTMCell layer expects 3 inputs(one for the input and two for the "
       "hidden/cell state) but got " +
         std::to_string(context.getNumInputs()) + " input(s)";

  // input_dim = [ batch_size, 1, 1, feature_size ]
  const TensorDim &input_dim = context.getInputDimensions()[INOUT_INDEX::INPUT];
  NNTR_THROW_IF(input_dim.channel() != 1 || input_dim.height() != 1,
                std::invalid_argument)
    << "Input must be single time dimension for LSTMCell (shape should be "
       "[batch_size, 1, 1, feature_size])";
  // input_hidden_state_dim = [ batch, 1, 1, unit ]
  const TensorDim &input_hidden_state_dim =
    context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
  NNTR_THROW_IF(input_hidden_state_dim.channel() != 1 ||
                  input_hidden_state_dim.height() != 1,
                std::invalid_argument)
    << "Input hidden state's dimension should be [batch, 1, 1, unit] for "
       "LSTMCell";
  // input_cell_state_dim = [ batch, 1, 1, unit ]
  const TensorDim &input_cell_state_dim =
    context.getInputDimensions()[INOUT_INDEX::INPUT_CELL_STATE];
  NNTR_THROW_IF(input_cell_state_dim.channel() != 1 ||
                  input_cell_state_dim.height() != 1,
                std::invalid_argument)
    << "Input cell state's dimension should be [batch, 1, 1, unit] for "
       "LSTMCell";
  const unsigned int batch_size = input_dim.batch();
  const unsigned int feature_size = input_dim.width();

  TensorDim::TensorType weight_tensor_type = {context.getFormat(),
                                              context.getWeightDataType()};

  // output shapes mirror the incoming state shapes
  // output_hidden_state_dim = [ batch_size, 1, 1, unit ]
  const TensorDim output_hidden_state_dim = input_hidden_state_dim;
  // output_cell_state_dim = [ batch_size, 1, 1, unit ]
  const TensorDim output_cell_state_dim = input_cell_state_dim;

  std::vector<VarGradSpecV2> out_specs;
  out_specs.push_back(
    InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
                              TensorLifespan::FORWARD_FUNC_LIFESPAN));
  // cell state gets a longer lifespan: calcGradient() reads it back via
  // getOutput() during the backward pass
  out_specs.push_back(
    InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
                              TensorLifespan::FORWARD_GRAD_LIFESPAN));
  context.requestOutputs(std::move(out_specs));

  // weight_initializer can be set separately. weight_ih initializer,
  // weight_hh initializer kernel initializer & recurrent_initializer in keras
  // for now, it is set same way.

  // - weight_ih ( input to hidden )
  //  : [ 1, 1, feature_size, NUM_GATE x unit ] -> i, f, g, o
  TensorDim weight_ih_dim({feature_size, NUM_GATE * unit}, weight_tensor_type);
  wt_idx[LSTMCellParams::weight_ih] = context.requestWeight(
    weight_ih_dim, weight_initializer, weight_regularizer,
    weight_regularizer_constant, weight_decay, "weight_ih", true);
  // - weight_hh ( hidden to hidden )
  //  : [ 1, 1, unit, NUM_GATE x unit ] -> i, f, g, o
  TensorDim weight_hh_dim({unit, NUM_GATE * unit}, weight_tensor_type);
  wt_idx[LSTMCellParams::weight_hh] = context.requestWeight(
    weight_hh_dim, weight_initializer, weight_regularizer,
    weight_regularizer_constant, weight_decay, "weight_hh", true);
  if (!disable_bias) {
    if (integrate_bias) {
      // - bias_h ( input bias, hidden bias are integrate to 1 bias )
      //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
      TensorDim bias_h_dim({NUM_GATE * unit}, weight_tensor_type);
      wt_idx[LSTMCellParams::bias_h] = context.requestWeight(
        bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
        "bias_h", true);
    } else {
      // - bias_ih ( input bias )
      //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
      TensorDim bias_ih_dim({NUM_GATE * unit}, weight_tensor_type);
      wt_idx[LSTMCellParams::bias_ih] = context.requestWeight(
        bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
        bias_decay, "bias_ih", true);
      // - bias_hh ( hidden bias )
      //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g, o
      TensorDim bias_hh_dim({NUM_GATE * unit}, weight_tensor_type);
      wt_idx[LSTMCellParams::bias_hh] = context.requestWeight(
        bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
        bias_decay, "bias_hh", true);
    }
  }

  /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit ] */
  // NOTE(review): ifgo is allocated with the weight tensor type rather than
  // the activation tensor type — presumably intentional; confirm.
  const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit,
                           weight_tensor_type);
  wt_idx[LSTMCellParams::ifgo] =
    context.requestTensor(ifgo_dim, "ifgo", Initializer::NONE, true,
                          TensorLifespan::ITERATION_LIFESPAN);

  // mask tensor is only needed when dropout is actually enabled
  if (dropout_rate > epsilon) {
    // dropout_mask_dim = [ batch_size, 1, 1, unit ]
    const TensorDim dropout_mask_dim(batch_size, 1, 1, unit,
                                     weight_tensor_type);
    wt_idx[LSTMCellParams::dropout_mask] =
      context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
                            false, TensorLifespan::ITERATION_LIFESPAN);
  }

  if (context.getActivationDataType() == TensorDim::DataType::FP32) {
    acti_func.setActiFunc<float>(hidden_state_activation_type);
    recurrent_acti_func.setActiFunc<float>(recurrent_activation_type);
  } else if (context.getActivationDataType() == TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    acti_func.setActiFunc<_FP16>(hidden_state_activation_type);
    recurrent_acti_func.setActiFunc<_FP16>(recurrent_activation_type);
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
  }
}
     178              : 
     179          250 : void LSTMCellLayer::setProperty(const std::vector<std::string> &values) {
     180              :   const std::vector<std::string> &remain_props =
     181          250 :     loadProperties(values, lstmcell_props);
     182          249 :   LSTMCore::setProperty(remain_props);
     183          249 : }
     184              : 
/**
 * @brief Export this layer's properties.
 *
 * Delegates the common LSTM core properties to the base class first, then
 * appends the cell-specific properties (DropOutRate).
 */
void LSTMCellLayer::exportTo(Exporter &exporter,
                             const ml::train::ExportMethods &method) const {
  LSTMCore::exportTo(exporter, method);
  exporter.saveResult(lstmcell_props, method, this);
}
     190              : 
/**
 * @brief Forward propagation for one LSTM cell step.
 *
 * Gathers the input and the previous hidden/cell state, delegates the gate
 * computation to forwardLSTM() (which also fills the ifgo scratch tensor for
 * the backward pass), and applies dropout to the new hidden state when
 * training and a dropout rate is configured.
 *
 * @param context  run-time layer context providing inputs/outputs/weights
 * @param training true during the training pass (enables dropout)
 */
void LSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
  const bool disable_bias =
    std::get<props::DisableBias>(*layer_impl_props).get();

  const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
  const bool integrate_bias =
    std::get<props::IntegrateBias>(lstmcore_props).get();

  const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props).get();

  const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
  const Tensor &prev_hidden_state =
    context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
  const Tensor &prev_cell_state =
    context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
  Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
  Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);

  const unsigned int batch_size = input.getDim().batch();

  const Tensor &weight_ih =
    context.getWeight(wt_idx[LSTMCellParams::weight_ih]);
  const Tensor &weight_hh =
    context.getWeight(wt_idx[LSTMCellParams::weight_hh]);

  // placeholder bound to whichever bias variants are not in use, so a
  // reference can always be passed to forwardLSTM()
  Tensor empty =
    Tensor("empty", weight_ih.getFormat(), weight_ih.getDataType());

  const Tensor &bias_h = !disable_bias && integrate_bias
                           ? context.getWeight(wt_idx[LSTMCellParams::bias_h])
                           : empty;
  const Tensor &bias_ih = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[LSTMCellParams::bias_ih])
                            : empty;
  const Tensor &bias_hh = !disable_bias && !integrate_bias
                            ? context.getWeight(wt_idx[LSTMCellParams::bias_hh])
                            : empty;

  Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);

  forwardLSTM(batch_size, unit, disable_bias, integrate_bias, acti_func,
              recurrent_acti_func, input, prev_hidden_state, prev_cell_state,
              hidden_state, cell_state, weight_ih, weight_hh, bias_h, bias_ih,
              bias_hh, ifgo);

  // dropout: regenerate the mask and scale the hidden state in place;
  // applied only while training
  if (dropout_rate > epsilon && training) {
    Tensor &dropout_mask =
      context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
    dropout_mask.dropout_mask(dropout_rate);
    hidden_state.multiply_i(dropout_mask);
  }
}
     243              : 
     244           19 : void LSTMCellLayer::calcDerivative(RunLayerContext &context) {
     245           19 :   Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMCellParams::ifgo]);
     246              :   const Tensor &weight_ih =
     247           19 :     context.getWeight(wt_idx[LSTMCellParams::weight_ih]);
     248              :   Tensor &outgoing_derivative =
     249           19 :     context.getOutgoingDerivative(INOUT_INDEX::INPUT);
     250              : 
     251           19 :   calcDerivativeLSTM(outgoing_derivative, weight_ih, d_ifgo);
     252           19 : }
     253              : 
     254           19 : void LSTMCellLayer::calcGradient(RunLayerContext &context) {
     255              :   const bool disable_bias =
     256           19 :     std::get<props::DisableBias>(*layer_impl_props).get();
     257              : 
     258           19 :   const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
     259              :   const bool integrate_bias =
     260           19 :     std::get<props::IntegrateBias>(lstmcore_props).get();
     261              : 
     262           19 :   const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
     263              : 
     264           19 :   const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
     265              :   const Tensor &prev_hidden_state =
     266           19 :     context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
     267              :   Tensor &d_prev_hidden_state =
     268           19 :     context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
     269              :   const Tensor &prev_cell_state =
     270           19 :     context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
     271              :   Tensor &d_prev_cell_state =
     272           19 :     context.getOutgoingDerivative(INOUT_INDEX::INPUT_CELL_STATE);
     273              :   const Tensor &d_hidden_state =
     274           19 :     context.getIncomingDerivative(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
     275           19 :   const Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
     276              :   const Tensor &d_cell_state =
     277           19 :     context.getIncomingDerivative(INOUT_INDEX::OUTPUT_CELL_STATE);
     278              : 
     279           19 :   unsigned int batch_size = input.getDim().batch();
     280              : 
     281              :   Tensor &d_weight_ih =
     282           19 :     context.getWeightGrad(wt_idx[LSTMCellParams::weight_ih]);
     283              :   const Tensor &weight_hh =
     284           19 :     context.getWeight(wt_idx[LSTMCellParams::weight_hh]);
     285              :   Tensor &d_weight_hh =
     286           19 :     context.getWeightGrad(wt_idx[LSTMCellParams::weight_hh]);
     287              : 
     288              :   Tensor empty =
     289           19 :     Tensor("empty", weight_hh.getFormat(), weight_hh.getDataType());
     290              : 
     291           19 :   Tensor &d_bias_h = !disable_bias && integrate_bias
     292           19 :                        ? context.getWeightGrad(wt_idx[LSTMCellParams::bias_h])
     293              :                        : empty;
     294              :   Tensor &d_bias_ih = !disable_bias && !integrate_bias
     295           19 :                         ? context.getWeightGrad(wt_idx[LSTMCellParams::bias_ih])
     296              :                         : empty;
     297              :   Tensor &d_bias_hh = !disable_bias && !integrate_bias
     298           19 :                         ? context.getWeightGrad(wt_idx[LSTMCellParams::bias_hh])
     299              :                         : empty;
     300              : 
     301           19 :   const Tensor &ifgo = context.getTensor(wt_idx[LSTMCellParams::ifgo]);
     302           19 :   Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMCellParams::ifgo]);
     303              : 
     304           19 :   if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::weight_ih])) {
     305            9 :     d_weight_ih.setZero();
     306              :   }
     307           19 :   if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::weight_hh])) {
     308            9 :     d_weight_hh.setZero();
     309              :   }
     310           19 :   if (!disable_bias) {
     311           19 :     if (integrate_bias) {
     312            1 :       if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_h])) {
     313            0 :         d_bias_h.setZero();
     314              :       }
     315              :     } else {
     316           18 :       if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_ih])) {
     317            9 :         d_bias_ih.setZero();
     318              :       }
     319           18 :       if (context.isGradientFirstAccess(wt_idx[LSTMCellParams::bias_hh])) {
     320            9 :         d_bias_hh.setZero();
     321              :       }
     322              :     }
     323              :   }
     324              : 
     325              :   Tensor d_hidden_state_masked = Tensor(
     326           19 :     "d_hidden_state_masked", weight_hh.getFormat(), weight_hh.getDataType());
     327              : 
     328           19 :   if (dropout_rate > epsilon) {
     329              :     Tensor &dropout_mask =
     330            0 :       context.getTensor(wt_idx[LSTMCellParams::dropout_mask]);
     331            0 :     d_hidden_state.multiply(dropout_mask, d_hidden_state_masked);
     332              :   }
     333              : 
     334           38 :   calcGradientLSTM(batch_size, unit, disable_bias, integrate_bias, acti_func,
     335           19 :                    recurrent_acti_func, input, prev_hidden_state,
     336              :                    d_prev_hidden_state, prev_cell_state, d_prev_cell_state,
     337           19 :                    dropout_rate > epsilon ? d_hidden_state_masked
     338              :                                           : d_hidden_state,
     339              :                    cell_state, d_cell_state, d_weight_ih, weight_hh,
     340              :                    d_weight_hh, d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
     341           19 : }
     342              : 
     343           24 : void LSTMCellLayer::setBatch(RunLayerContext &context, unsigned int batch) {
     344           24 :   const float dropout_rate = std::get<props::DropOutRate>(lstmcell_props);
     345           24 :   context.updateTensor(wt_idx[LSTMCellParams::ifgo], batch);
     346           24 :   if (dropout_rate > epsilon) {
     347            0 :     context.updateTensor(wt_idx[LSTMCellParams::dropout_mask], batch);
     348              :   }
     349           24 : }
     350              : 
     351              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1