LCOV - code coverage report
Current view: top level - nntrainer/layers - zoneout_lstmcell.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 89.1 % 211 188
Test Date: 2025-12-14 20:38:17 Functions: 81.8 % 11 9

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
       4              :  *
       5              :  * @file   zoneout_lstmcell.cpp
       6              :  * @date   30 November 2021
       7              :  * @brief  This is ZoneoutLSTMCell Layer Class of Neural Network
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  *         https://arxiv.org/pdf/1606.01305.pdf
      10              :  *         https://github.com/teganmaharaj/zoneout
      11              :  * @author hyeonseok lee <hs89.lee@samsung.com>
      12              :  * @bug    No known bugs except for NYI items
      13              :  *
      14              :  */
      15              : 
      16              : #include <layer_context.h>
      17              : #include <nntrainer_error.h>
      18              : #include <nntrainer_log.h>
      19              : #include <node_exporter.h>
      20              : #include <zoneout_lstmcell.h>
      21              : 
      22              : namespace nntrainer {
      23              : 
      24              : enum ZoneoutLSTMParams {
      25              :   weight_ih,
      26              :   weight_hh,
      27              :   bias_h,
      28              :   bias_ih,
      29              :   bias_hh,
      30              :   ifgo,
      31              :   hidden_state_zoneout_mask,
      32              :   cell_state_zoneout_mask,
      33              :   lstm_cell_state,
      34              : };
      35              : 
      36          270 : ZoneoutLSTMCellLayer::ZoneoutLSTMCellLayer() :
      37          270 :   zoneout_lstmcell_props(HiddenStateZoneOutRate(), CellStateZoneOutRate(),
      38          540 :                          Test(), props::MaxTimestep(), props::Timestep()) {
      39              :   wt_idx.fill(std::numeric_limits<unsigned>::max());
      40          270 : }
      41              : 
      42          270 : bool ZoneoutLSTMCellLayer::HiddenStateZoneOutRate::isValid(
      43              :   const float &value) const {
      44          270 :   if (value < 0.0f || value > 1.0f) {
      45              :     return false;
      46              :   } else {
      47          270 :     return true;
      48              :   }
      49              : }
      50              : 
      51          270 : bool ZoneoutLSTMCellLayer::CellStateZoneOutRate::isValid(
      52              :   const float &value) const {
      53          270 :   if (value < 0.0f || value > 1.0f) {
      54              :     return false;
      55              :   } else {
      56          270 :     return true;
      57              :   }
      58              : }
      59              : 
      60          216 : void ZoneoutLSTMCellLayer::finalize(InitLayerContext &context) {
      61              :   const Initializer weight_initializer =
      62          216 :     std::get<props::WeightInitializer>(*layer_impl_props).get();
      63              :   const Initializer bias_initializer =
      64          216 :     std::get<props::BiasInitializer>(*layer_impl_props).get();
      65              :   const WeightRegularizer weight_regularizer =
      66          216 :     std::get<props::WeightRegularizer>(*layer_impl_props).get();
      67              :   const float weight_regularizer_constant =
      68          216 :     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
      69              :   auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
      70              :   auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
      71              :   const bool disable_bias =
      72          216 :     std::get<props::DisableBias>(*layer_impl_props).get();
      73              : 
      74          216 :   NNTR_THROW_IF(std::get<props::Unit>(lstmcore_props).empty(),
      75              :                 std::invalid_argument)
      76              :     << "unit property missing for zoneout_lstmcell layer";
      77          216 :   const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
      78              :   const bool integrate_bias =
      79          216 :     std::get<props::IntegrateBias>(lstmcore_props).get();
      80              :   const ActivationType hidden_state_activation_type =
      81          216 :     std::get<props::HiddenStateActivation>(lstmcore_props).get();
      82              :   const ActivationType recurrent_activation_type =
      83          216 :     std::get<props::RecurrentActivation>(lstmcore_props).get();
      84              : 
      85          216 :   const bool test = std::get<Test>(zoneout_lstmcell_props).get();
      86              :   const unsigned int max_timestep =
      87          216 :     std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
      88              : 
      89          216 :   NNTR_THROW_IF(context.getNumInputs() != 3, std::invalid_argument)
      90              :     << "Number of input is not 3. ZoneoutLSTMCellLayer should takes 3 inputs";
      91              : 
      92              :   // input_dim = [ batch_size, 1, 1, feature_size ]
      93              :   const TensorDim &input_dim = context.getInputDimensions()[INOUT_INDEX::INPUT];
      94          216 :   if (input_dim.channel() != 1 || input_dim.height() != 1)
      95              :     throw std::invalid_argument("Input must be single time dimension for "
      96              :                                 "ZoneoutLSTMCell (shape should be "
      97            0 :                                 "[batch_size, 1, 1, feature_size])");
      98              :   // input_hidden_state_dim = [ batch_size, 1, 1, unit ]
      99              :   const TensorDim &input_hidden_state_dim =
     100              :     context.getInputDimensions()[INOUT_INDEX::INPUT_HIDDEN_STATE];
     101          432 :   if (input_hidden_state_dim.channel() != 1 ||
     102          216 :       input_hidden_state_dim.height() != 1) {
     103              :     throw std::invalid_argument(
     104              :       "Input hidden state's dimension should be"
     105            0 :       "[batch_size, 1, 1, unit] for zoneout LSTMcell");
     106              :   }
     107              :   // input_cell_state_dim = [ batch_size, 1, 1, unit ]
     108              :   const TensorDim &input_cell_state_dim =
     109              :     context.getInputDimensions()[INOUT_INDEX::INPUT_CELL_STATE];
     110          432 :   if (input_cell_state_dim.channel() != 1 ||
     111          216 :       input_cell_state_dim.height() != 1) {
     112              :     throw std::invalid_argument(
     113              :       "Input cell state's dimension should be"
     114            0 :       "[batch_size, 1, 1, unit] for zoneout LSTMcell");
     115              :   }
     116          216 :   const unsigned int batch_size = input_dim.batch();
     117          216 :   const unsigned int feature_size = input_dim.width();
     118              : 
     119              :   // output_hidden_state_dim = [ batch_size, 1, 1, unit ]
     120          216 :   const TensorDim output_hidden_state_dim = input_hidden_state_dim;
     121              :   // output_cell_state_dim = [ batch_size, 1, 1, unit ]
     122          216 :   const TensorDim output_cell_state_dim = input_cell_state_dim;
     123              : 
     124              :   std::vector<VarGradSpecV2> out_specs;
     125              :   /// note: those out spec can be forward func, but for the test, it is being
     126              :   /// kept to forward deriv lifespan
     127              :   out_specs.push_back(
     128          216 :     InitLayerContext::outSpec(output_hidden_state_dim, "output_hidden_state",
     129              :                               TensorLifespan::FORWARD_DERIV_LIFESPAN));
     130              :   //////////////////////////  TensorLifespan::FORWARD_FUNC_LIFESPAN));
     131              :   out_specs.push_back(
     132          216 :     InitLayerContext::outSpec(output_cell_state_dim, "output_cell_state",
     133              :                               TensorLifespan::FORWARD_DERIV_LIFESPAN));
     134              :   // cell_state_zoneout_rate > epsilon ? TensorLifespan::FORWARD_FUNC_LIFESPAN
     135              :   //                                   :
     136              :   //                                   TensorLifespan::FORWARD_GRAD_LIFESPAN));
     137          216 :   context.requestOutputs(std::move(out_specs));
     138              : 
     139              :   // weight_initializer can be set seperately.
     140              :   // weight_ih initializer, weight_hh initializer
     141              :   // kernel initializer & recurrent_initializer in
     142              :   // keras for now, it is set same way.
     143              : 
     144              :   // - weight_ih ( input to hidden )
     145              :   //  : [ 1, 1, feature_size, NUM_GATE x unit ] ->
     146              :   //  i, f, g, o
     147          216 :   TensorDim weight_ih_dim({feature_size, NUM_GATE * unit});
     148          216 :   wt_idx[ZoneoutLSTMParams::weight_ih] = context.requestWeight(
     149              :     weight_ih_dim, weight_initializer, weight_regularizer,
     150              :     weight_regularizer_constant, weight_decay, "weight_ih", true);
     151              :   // - weight_hh ( hidden to hidden )
     152              :   //  : [ 1, 1, unit, NUM_GATE x unit ] -> i, f, g,
     153              :   //  o
     154          216 :   TensorDim weight_hh_dim({unit, NUM_GATE * unit});
     155          432 :   wt_idx[ZoneoutLSTMParams::weight_hh] = context.requestWeight(
     156              :     weight_hh_dim, weight_initializer, weight_regularizer,
     157              :     weight_regularizer_constant, weight_decay, "weight_hh", true);
     158          216 :   if (!disable_bias) {
     159          216 :     if (integrate_bias) {
     160              :       // - bias_h ( input bias, hidden bias are
     161              :       // integrate to 1 bias )
     162              :       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
     163              :       //  o
     164            0 :       TensorDim bias_h_dim({NUM_GATE * unit});
     165            0 :       wt_idx[ZoneoutLSTMParams::bias_h] = context.requestWeight(
     166              :         bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
     167              :         "bias_h", true);
     168              :     } else {
     169              :       // - bias_ih ( input bias )
     170              :       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
     171              :       //  o
     172          216 :       TensorDim bias_ih_dim({NUM_GATE * unit});
     173          216 :       wt_idx[ZoneoutLSTMParams::bias_ih] = context.requestWeight(
     174              :         bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     175              :         bias_decay, "bias_ih", true);
     176              :       // - bias_hh ( hidden bias )
     177              :       //  : [ 1, 1, 1, NUM_GATE x unit ] -> i, f, g,
     178              :       //  o
     179          216 :       TensorDim bias_hh_dim({NUM_GATE * unit});
     180          432 :       wt_idx[ZoneoutLSTMParams::bias_hh] = context.requestWeight(
     181              :         bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     182              :         bias_decay, "bias_hh", true);
     183              :     }
     184              :   }
     185              : 
     186              :   /** ifgo_dim = [ batch_size, 1, 1, NUM_GATE * unit
     187              :    * ] */
     188          216 :   const TensorDim ifgo_dim(batch_size, 1, 1, NUM_GATE * unit);
     189          216 :   wt_idx[ZoneoutLSTMParams::ifgo] =
     190          432 :     context.requestTensor(ifgo_dim, "ifgo", Initializer::NONE, true,
     191              :                           TensorLifespan::ITERATION_LIFESPAN);
     192              : 
     193              :   // hidden_state_zoneout_mask_dim = [ max_timestep
     194              :   // * batch_size, 1, 1, unit ]
     195          216 :   const TensorDim hidden_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
     196          216 :                                                 unit);
     197          216 :   if (test) {
     198          216 :     wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
     199          432 :       context.requestWeight(hidden_state_zoneout_mask_dim, Initializer::NONE,
     200              :                             WeightRegularizer::NONE, 1.0f, 0.0f,
     201              :                             "hidden_state_zoneout_mask", false);
     202              :   } else {
     203            0 :     wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask] =
     204            0 :       context.requestTensor(hidden_state_zoneout_mask_dim,
     205              :                             "hidden_state_zoneout_mask", Initializer::NONE,
     206              :                             false, TensorLifespan::ITERATION_LIFESPAN, false);
     207              :   }
     208              : 
     209              :   // cell_state_zoneout_mask_dim = [ max_timestep * batch_size, 1, 1, unit ]
     210              :   const TensorDim cell_state_zoneout_mask_dim(max_timestep * batch_size, 1, 1,
     211          216 :                                               unit);
     212          216 :   if (test) {
     213          216 :     wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestWeight(
     214              :       cell_state_zoneout_mask_dim, Initializer::NONE, WeightRegularizer::NONE,
     215              :       1.0f, 0.0f, "cell_state_zoneout_mask", false);
     216              :   } else {
     217            0 :     wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask] = context.requestTensor(
     218              :       cell_state_zoneout_mask_dim, "cell_state_zoneout_mask", Initializer::NONE,
     219              :       false, TensorLifespan::ITERATION_LIFESPAN, false);
     220              :   }
     221              : 
     222              :   // lstm_cell_state_dim = [ batch_size, 1, 1, unit ]
     223          216 :   const TensorDim lstm_cell_state_dim(batch_size, 1, 1, unit);
     224          216 :   wt_idx[ZoneoutLSTMParams::lstm_cell_state] = context.requestTensor(
     225              :     lstm_cell_state_dim, "lstm_cell_state", Initializer::NONE, true,
     226              :     TensorLifespan::ITERATION_LIFESPAN);
     227              : 
     228          216 :   acti_func.setActiFunc(hidden_state_activation_type);
     229          216 :   recurrent_acti_func.setActiFunc(recurrent_activation_type);
     230          216 : }
     231              : 
     232         1512 : void ZoneoutLSTMCellLayer::setProperty(const std::vector<std::string> &values) {
     233              :   const std::vector<std::string> &remain_props =
     234         1512 :     loadProperties(values, zoneout_lstmcell_props);
     235         1512 :   LSTMCore::setProperty(remain_props);
     236         1512 : }
     237              : 
     238          216 : void ZoneoutLSTMCellLayer::exportTo(
     239              :   Exporter &exporter, const ml::train::ExportMethods &method) const {
     240          216 :   LSTMCore::exportTo(exporter, method);
     241          216 :   exporter.saveResult(zoneout_lstmcell_props, method, this);
     242          216 : }
     243              : 
     244          324 : void ZoneoutLSTMCellLayer::forwarding(RunLayerContext &context, bool training) {
     245              :   const bool disable_bias =
     246          324 :     std::get<props::DisableBias>(*layer_impl_props).get();
     247              : 
     248          324 :   const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
     249              :   const bool integrate_bias =
     250          324 :     std::get<props::IntegrateBias>(lstmcore_props).get();
     251              : 
     252              :   const float hidden_state_zoneout_rate =
     253          324 :     std::get<HiddenStateZoneOutRate>(zoneout_lstmcell_props).get();
     254              :   const float cell_state_zoneout_rate =
     255          324 :     std::get<CellStateZoneOutRate>(zoneout_lstmcell_props).get();
     256          324 :   const bool test = std::get<Test>(zoneout_lstmcell_props).get();
     257              :   const unsigned int max_timestep =
     258          324 :     std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
     259              :   const unsigned int timestep =
     260          324 :     std::get<props::Timestep>(zoneout_lstmcell_props).get();
     261              : 
     262          324 :   const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
     263              :   const Tensor &prev_hidden_state =
     264          324 :     context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
     265              :   const Tensor &prev_cell_state =
     266          324 :     context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
     267          324 :   Tensor &hidden_state = context.getOutput(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
     268          324 :   Tensor &cell_state = context.getOutput(INOUT_INDEX::OUTPUT_CELL_STATE);
     269              : 
     270          324 :   const unsigned int batch_size = input.getDim().batch();
     271              : 
     272              :   const Tensor &weight_ih =
     273          324 :     context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
     274              :   const Tensor &weight_hh =
     275          324 :     context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
     276          324 :   Tensor empty;
     277              :   const Tensor &bias_h =
     278          324 :     !disable_bias && integrate_bias
     279          324 :       ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_h])
     280              :       : empty;
     281              :   const Tensor &bias_ih =
     282              :     !disable_bias && !integrate_bias
     283          324 :       ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_ih])
     284              :       : empty;
     285              :   const Tensor &bias_hh =
     286              :     !disable_bias && !integrate_bias
     287          324 :       ? context.getWeight(wt_idx[ZoneoutLSTMParams::bias_hh])
     288              :       : empty;
     289              : 
     290          324 :   Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
     291              :   Tensor &lstm_cell_state =
     292          324 :     context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
     293              : 
     294          324 :   forwardLSTM(batch_size, unit, disable_bias, integrate_bias, acti_func,
     295          324 :               recurrent_acti_func, input, prev_hidden_state, prev_cell_state,
     296              :               hidden_state, lstm_cell_state, weight_ih, weight_hh, bias_h,
     297              :               bias_ih, bias_hh, ifgo);
     298              : 
     299          324 :   if (training) {
     300              :     Tensor &hs_zoneout_mask =
     301          162 :       test ? context.getWeight(
     302              :                wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
     303            0 :            : context.getTensor(
     304              :                wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
     305          162 :     hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
     306              :     Tensor hidden_state_zoneout_mask =
     307          162 :       hs_zoneout_mask.getBatchSlice(timestep, 1);
     308          162 :     hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
     309          162 :     Tensor prev_hidden_state_zoneout_mask;
     310          162 :     if (!test) {
     311              :       prev_hidden_state_zoneout_mask =
     312            0 :         hidden_state_zoneout_mask.zoneout_mask(hidden_state_zoneout_rate);
     313              :     } else {
     314          162 :       hidden_state_zoneout_mask.multiply(-1.0f, prev_hidden_state_zoneout_mask);
     315          162 :       prev_hidden_state_zoneout_mask.add_i(1.0f);
     316              :     }
     317              : 
     318          162 :     hidden_state.multiply_i(hidden_state_zoneout_mask);
     319          162 :     prev_hidden_state.multiply(prev_hidden_state_zoneout_mask, hidden_state,
     320              :                                1.0f);
     321              :     Tensor &cs_zoneout_mask =
     322              :       test
     323          162 :         ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
     324            0 :         : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
     325          162 :     cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
     326          162 :     Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1);
     327          162 :     cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
     328          162 :     Tensor prev_cell_state_zoneout_mask;
     329          162 :     if (!test) {
     330              :       prev_cell_state_zoneout_mask =
     331            0 :         cell_state_zoneout_mask.zoneout_mask(cell_state_zoneout_rate);
     332              :     } else {
     333          162 :       cell_state_zoneout_mask.multiply(-1.0f, prev_cell_state_zoneout_mask);
     334          162 :       prev_cell_state_zoneout_mask.add_i(1.0f);
     335              :     }
     336              : 
     337          162 :     lstm_cell_state.multiply(cell_state_zoneout_mask, cell_state);
     338          162 :     prev_cell_state.multiply(prev_cell_state_zoneout_mask, cell_state, 1.0f);
     339          162 :   }
     340              :   // Todo: zoneout at inference
     341          324 : }
     342              : 
     343          162 : void ZoneoutLSTMCellLayer::calcDerivative(RunLayerContext &context) {
     344              :   Tensor &outgoing_derivative =
     345          162 :     context.getOutgoingDerivative(INOUT_INDEX::INPUT);
     346              :   const Tensor &weight_ih =
     347          162 :     context.getWeight(wt_idx[ZoneoutLSTMParams::weight_ih]);
     348          162 :   const Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
     349              : 
     350          162 :   calcDerivativeLSTM(outgoing_derivative, weight_ih, d_ifgo);
     351          162 : }
     352              : 
     353          162 : void ZoneoutLSTMCellLayer::calcGradient(RunLayerContext &context) {
     354              :   const bool disable_bias =
     355          162 :     std::get<props::DisableBias>(*layer_impl_props).get();
     356              : 
     357          162 :   const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
     358              :   const bool integrate_bias =
     359          162 :     std::get<props::IntegrateBias>(lstmcore_props).get();
     360              : 
     361          162 :   const bool test = std::get<Test>(zoneout_lstmcell_props).get();
     362              :   const unsigned int max_timestep =
     363          162 :     std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
     364              :   const unsigned int timestep =
     365          162 :     std::get<props::Timestep>(zoneout_lstmcell_props).get();
     366              : 
     367          162 :   const Tensor &input = context.getInput(INOUT_INDEX::INPUT);
     368              :   const Tensor &prev_hidden_state =
     369          162 :     context.getInput(INOUT_INDEX::INPUT_HIDDEN_STATE);
     370              :   Tensor &d_prev_hidden_state =
     371          162 :     context.getOutgoingDerivative(INOUT_INDEX::INPUT_HIDDEN_STATE);
     372              :   const Tensor &prev_cell_state =
     373          162 :     context.getInput(INOUT_INDEX::INPUT_CELL_STATE);
     374              :   Tensor &d_prev_cell_state =
     375          162 :     context.getOutgoingDerivative(INOUT_INDEX::INPUT_CELL_STATE);
     376              :   const Tensor &d_hidden_state =
     377          162 :     context.getIncomingDerivative(INOUT_INDEX::OUTPUT_HIDDEN_STATE);
     378              :   const Tensor &d_cell_state =
     379          162 :     context.getIncomingDerivative(INOUT_INDEX::OUTPUT_CELL_STATE);
     380              : 
     381          162 :   unsigned int batch_size = input.getDim().batch();
     382              : 
     383              :   Tensor &d_weight_ih =
     384          162 :     context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_ih]);
     385              :   const Tensor &weight_hh =
     386          162 :     context.getWeight(wt_idx[ZoneoutLSTMParams::weight_hh]);
     387              :   Tensor &d_weight_hh =
     388          162 :     context.getWeightGrad(wt_idx[ZoneoutLSTMParams::weight_hh]);
     389          162 :   Tensor empty;
     390              :   Tensor &d_bias_h =
     391          162 :     !disable_bias && integrate_bias
     392          162 :       ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_h])
     393              :       : empty;
     394              :   Tensor &d_bias_ih =
     395              :     !disable_bias && !integrate_bias
     396          162 :       ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_ih])
     397              :       : empty;
     398              :   Tensor &d_bias_hh =
     399              :     !disable_bias && !integrate_bias
     400          162 :       ? context.getWeightGrad(wt_idx[ZoneoutLSTMParams::bias_hh])
     401              :       : empty;
     402              : 
     403          162 :   Tensor &ifgo = context.getTensor(wt_idx[ZoneoutLSTMParams::ifgo]);
     404          162 :   Tensor &d_ifgo = context.getTensorGrad(wt_idx[ZoneoutLSTMParams::ifgo]);
     405              : 
     406              :   const Tensor &lstm_cell_state =
     407          162 :     context.getTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
     408              :   Tensor &d_lstm_cell_state =
     409          162 :     context.getTensorGrad(wt_idx[ZoneoutLSTMParams::lstm_cell_state]);
     410              : 
     411          162 :   if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::weight_ih])) {
     412           81 :     d_weight_ih.setZero();
     413              :   }
     414          162 :   if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::weight_hh])) {
     415           81 :     d_weight_hh.setZero();
     416              :   }
     417          162 :   if (!disable_bias) {
     418          162 :     if (integrate_bias) {
     419            0 :       if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_h])) {
     420            0 :         d_bias_h.setZero();
     421              :       }
     422              :     } else {
     423          162 :       if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_ih])) {
     424           81 :         d_bias_ih.setZero();
     425              :       }
     426          162 :       if (context.isGradientFirstAccess(wt_idx[ZoneoutLSTMParams::bias_hh])) {
     427           81 :         d_bias_hh.setZero();
     428              :       }
     429              :     }
     430              :   }
     431              : 
     432          162 :   Tensor d_prev_hidden_state_residual;
     433          162 :   Tensor d_hidden_state_masked;
     434              :   Tensor &hs_zoneout_mask =
     435              :     test
     436          162 :       ? context.getWeight(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask])
     437            0 :       : context.getTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask]);
     438          162 :   hs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
     439          162 :   Tensor hidden_state_zoneout_mask = hs_zoneout_mask.getBatchSlice(timestep, 1);
     440          162 :   hidden_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
     441              :   Tensor prev_hidden_state_zoneout_mask = hidden_state_zoneout_mask.apply(
     442          162 :     (std::function<float(float)>)[epsilon = epsilon](float x) {
     443          324 :       return x < epsilon;
     444          162 :     });
     445              : 
     446          162 :   d_hidden_state.multiply(prev_hidden_state_zoneout_mask,
     447              :                           d_prev_hidden_state_residual);
     448          162 :   d_hidden_state.multiply(hidden_state_zoneout_mask, d_hidden_state_masked);
     449              : 
     450          162 :   Tensor d_prev_cell_state_residual;
     451              :   Tensor &cs_zoneout_mask =
     452              :     test
     453          162 :       ? context.getWeight(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask])
     454            0 :       : context.getTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask]);
     455          162 :   cs_zoneout_mask.reshape({max_timestep, 1, batch_size, unit});
     456          162 :   Tensor cell_state_zoneout_mask = cs_zoneout_mask.getBatchSlice(timestep, 1);
     457          162 :   cell_state_zoneout_mask.reshape({batch_size, 1, 1, unit});
     458              :   Tensor prev_cell_state_zoneout_mask = cell_state_zoneout_mask.apply(
     459          162 :     (std::function<float(float)>)[epsilon = epsilon](float x) {
     460          324 :       return x < epsilon;
     461          162 :     });
     462              : 
     463          162 :   d_cell_state.multiply(prev_cell_state_zoneout_mask,
     464              :                         d_prev_cell_state_residual);
     465          162 :   d_cell_state.multiply(cell_state_zoneout_mask, d_lstm_cell_state);
     466              : 
     467          162 :   calcGradientLSTM(batch_size, unit, disable_bias, integrate_bias, acti_func,
     468          162 :                    recurrent_acti_func, input, prev_hidden_state,
     469              :                    d_prev_hidden_state, prev_cell_state, d_prev_cell_state,
     470              :                    d_hidden_state_masked, lstm_cell_state, d_lstm_cell_state,
     471              :                    d_weight_ih, weight_hh, d_weight_hh, d_bias_h, d_bias_ih,
     472              :                    d_bias_hh, ifgo, d_ifgo);
     473              : 
     474          162 :   d_prev_hidden_state.add_i(d_prev_hidden_state_residual);
     475          162 :   d_prev_cell_state.add_i(d_prev_cell_state_residual);
     476          162 : }
     477              : 
     478            0 : void ZoneoutLSTMCellLayer::setBatch(RunLayerContext &context,
     479              :                                     unsigned int batch) {
     480              :   const unsigned int max_timestep =
     481            0 :     std::get<props::MaxTimestep>(zoneout_lstmcell_props).get();
     482              : 
     483            0 :   context.updateTensor(wt_idx[ZoneoutLSTMParams::ifgo], batch);
     484              : 
     485            0 :   context.updateTensor(wt_idx[ZoneoutLSTMParams::hidden_state_zoneout_mask],
     486              :                        max_timestep * batch);
     487            0 :   context.updateTensor(wt_idx[ZoneoutLSTMParams::cell_state_zoneout_mask],
     488              :                        max_timestep * batch);
     489            0 :   context.updateTensor(wt_idx[ZoneoutLSTMParams::lstm_cell_state], batch);
     490            0 : }
     491              : 
     492              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1