LCOV - code coverage report
Current view: top level - nntrainer/layers - lstm.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 73.8 % 447 330
Test Date: 2025-12-14 20:38:17 Functions: 83.3 % 12 10

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Jijoong Moon <jijoong.moon@samsung.com>
       4              :  *
       5              :  * @file   lstm.cpp
       6              :  * @date   17 March 2021
       7              :  * @brief  This is Long Short-Term Memory Layer Class of Neural Network
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : 
      14              : #include <layer_context.h>
      15              : #include <lstm.h>
      16              : #include <nntr_threads.h>
      17              : #include <nntrainer_error.h>
      18              : #include <nntrainer_log.h>
      19              : #include <node_exporter.h>
      20              : 
      21              : namespace nntrainer {
      22              : 
      23              : static constexpr size_t SINGLE_INOUT_IDX = 0;
      24              : 
      25              : enum LSTMParams {
      26              :   weight_ih,
      27              :   weight_hh,
      28              :   bias_h,
      29              :   bias_ih,
      30              :   bias_hh,
      31              :   hidden_state,
      32              :   cell_state,
      33              :   ifgo,
      34              :   reverse_weight_ih,
      35              :   reverse_weight_hh,
      36              :   reverse_bias_h,
      37              :   reverse_bias_ih,
      38              :   reverse_bias_hh,
      39              :   reverse_hidden_state,
      40              :   reverse_cell_state,
      41              :   reverse_ifgo,
      42              :   dropout_mask
      43              : };
      44              : 
      45          189 : void LSTMLayer::forwardingBatchFirstLSTM(
      46              :   unsigned int NUM_GATE, const unsigned int batch_size,
      47              :   const unsigned int feature_size, const bool disable_bias,
      48              :   const unsigned int unit, const bool integrate_bias, ActiFunc &acti_func,
      49              :   ActiFunc &recurrent_acti_func, const bool enable_dropout,
      50              :   const float dropout_rate, const unsigned int max_timestep, const bool reverse,
      51              :   const Tensor &input_, const Tensor &weight_ih, const Tensor &weight_hh,
      52              :   const Tensor &bias_h, const Tensor &bias_ih, const Tensor &bias_hh,
      53              :   Tensor &hidden_state_, Tensor &cell_state_, Tensor &ifgo_,
      54              :   const Tensor &mask_) {
      55          189 :   hidden_state_.setZero();
      56          189 :   cell_state_.setZero();
      57          189 :   TensorDim::TensorType tensor_type = weight_ih.getTensorType();
      58          189 :   TensorDim input_tensor_dim({feature_size}, tensor_type);
      59          189 :   TensorDim unit_tensor_dim({unit}, tensor_type);
      60          189 :   TensorDim num_gate_unit_tensor_dim({NUM_GATE * unit}, tensor_type);
      61              : 
      62          591 :   for (unsigned int batch = 0; batch < batch_size; ++batch) {
      63          402 :     const Tensor input_sample = input_.getBatchSlice(batch, 1);
      64          402 :     Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
      65          402 :     Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
      66          402 :     Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
      67              : 
      68         1281 :     for (unsigned int t = 0; t < max_timestep; ++t) {
      69              :       Tensor input = input_sample.getSharedDataTensor(
      70          879 :         input_tensor_dim, (reverse ? max_timestep - 1 - t : t) * feature_size);
      71              : 
      72              :       Tensor prev_hidden_state = Tensor(
      73          879 :         "prev_hidden_state", weight_ih.getFormat(), weight_ih.getDataType());
      74              : 
      75          879 :       if (!t) {
      76          804 :         prev_hidden_state = Tensor(unit, tensor_type);
      77          402 :         prev_hidden_state.setZero();
      78              :       } else {
      79         1431 :         prev_hidden_state = hidden_state_sample.getSharedDataTensor(
      80          477 :           unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
      81              :       }
      82              :       Tensor hidden_state = hidden_state_sample.getSharedDataTensor(
      83          879 :         unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
      84          879 :       Tensor prev_cell_state;
      85          879 :       if (!t) {
      86          804 :         prev_cell_state = Tensor(unit, tensor_type);
      87          402 :         prev_cell_state.setZero();
      88              :       } else {
      89         1431 :         prev_cell_state = cell_state_sample.getSharedDataTensor(
      90          477 :           unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
      91              :       }
      92              :       Tensor cell_state = cell_state_sample.getSharedDataTensor(
      93          879 :         unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
      94              :       Tensor ifgo = ifgo_sample.getSharedDataTensor(
      95              :         num_gate_unit_tensor_dim,
      96          879 :         (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
      97              : 
      98          879 :       forwardLSTM(1, unit, disable_bias, integrate_bias, acti_func,
      99              :                   recurrent_acti_func, input, prev_hidden_state,
     100              :                   prev_cell_state, hidden_state, cell_state, weight_ih,
     101              :                   weight_hh, bias_h, bias_ih, bias_hh, ifgo);
     102              : 
     103          879 :       if (enable_dropout) {
     104            0 :         Tensor mask_sample = mask_.getBatchSlice(batch, 1);
     105              :         Tensor mask =
     106            0 :           mask_sample.getSharedDataTensor(unit_tensor_dim, t * unit);
     107            0 :         mask.dropout_mask(dropout_rate);
     108            0 :         hidden_state.multiply_i(mask);
     109            0 :       }
     110          879 :     }
     111          402 :   }
     112          189 : }
     113              : 
     114          110 : void LSTMLayer::calcGradientBatchFirstLSTM(
     115              :   unsigned int NUM_GATE, const unsigned int batch_size,
     116              :   const unsigned int feature_size, const bool disable_bias,
     117              :   const unsigned int unit, const bool integrate_bias, ActiFunc &acti_func,
     118              :   ActiFunc &recurrent_acti_func, const bool return_sequences,
     119              :   const bool bidirectional, const bool enable_dropout, const float dropout_rate,
     120              :   const unsigned int max_timestep, const bool reverse, const Tensor &input_,
     121              :   const Tensor &incoming_derivative, Tensor &d_weight_ih,
     122              :   const Tensor &weight_hh, Tensor &d_weight_hh, Tensor &d_bias_h,
     123              :   Tensor &d_bias_ih, Tensor &d_bias_hh, const Tensor &hidden_state_,
     124              :   Tensor &d_hidden_state_, const Tensor &cell_state_, Tensor &d_cell_state_,
     125              :   const Tensor &ifgo_, Tensor &d_ifgo_, const Tensor &mask_) {
     126          110 :   const unsigned int bidirectional_constant = bidirectional ? 2 : 1;
     127              : 
     128          110 :   d_weight_ih.setZero();
     129          110 :   d_weight_hh.setZero();
     130          110 :   if (!disable_bias) {
     131          110 :     if (integrate_bias) {
     132           83 :       d_bias_h.setZero();
     133              :     } else {
     134           27 :       d_bias_ih.setZero();
     135           27 :       d_bias_hh.setZero();
     136              :     }
     137              :   }
     138              : 
     139          110 :   d_cell_state_.setZero();
     140          110 :   d_hidden_state_.setZero();
     141              : 
     142          110 :   TensorDim::TensorType tensor_type = weight_hh.getTensorType();
     143          110 :   TensorDim unit_tensor_dim({unit}, tensor_type);
     144          110 :   TensorDim feature_size_tensor_dim({feature_size}, tensor_type);
     145          110 :   TensorDim num_gate_tensor_dim({NUM_GATE * unit}, tensor_type);
     146              : 
     147          110 :   if (return_sequences && !bidirectional && !reverse) {
     148           57 :     if (incoming_derivative.getDataType() == TensorDim::DataType::FP32) {
     149           57 :       std::copy(incoming_derivative.getData<float>(),
     150           57 :                 incoming_derivative.getData<float>() +
     151           57 :                   incoming_derivative.size(),
     152              :                 d_hidden_state_.getData<float>());
     153            0 :     } else if (incoming_derivative.getDataType() == TensorDim::DataType::FP16) {
     154              : #ifdef ENABLE_FP16
     155              :       std::copy(incoming_derivative.getData<_FP16>(),
     156              :                 incoming_derivative.getData<_FP16>() +
     157              :                   incoming_derivative.size(),
     158              :                 d_hidden_state_.getData<_FP16>());
     159              : #else
     160            0 :       throw std::invalid_argument("Error: enable-fp16 is not enabled");
     161              : #endif
     162              :     }
     163              :   } else {
     164           53 :     unsigned int end_timestep = return_sequences ? max_timestep : 1;
     165          157 :     for (unsigned int batch = 0; batch < batch_size; ++batch) {
     166          262 :       for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
     167              :         Tensor d_hidden_state_sample = d_hidden_state_.getSharedDataTensor(
     168          316 :           unit_tensor_dim, batch * max_timestep * unit +
     169          158 :                              (return_sequences ? 0 : max_timestep - 1) * unit +
     170          316 :                              timestep * unit);
     171              :         Tensor incoming_derivative_sample =
     172              :           incoming_derivative.getSharedDataTensor(
     173          158 :             unit_tensor_dim, batch * (return_sequences ? max_timestep : 1) *
     174          158 :                                  bidirectional_constant * unit +
     175          158 :                                timestep * bidirectional_constant * unit +
     176          370 :                                (reverse ? unit : 0));
     177          158 :         d_hidden_state_sample.add_i(incoming_derivative_sample);
     178          158 :       }
     179              :     }
     180              :   }
     181              : 
     182          110 :   if (enable_dropout) {
     183            0 :     d_hidden_state_.multiply_i(mask_);
     184              :   }
     185              : 
     186          110 :   auto workers = ParallelBatch(batch_size);
     187              : 
     188          110 :   if (workers.getNumWorkers() > 1) {
     189              : 
     190            0 :     TensorDim weight_ih_d = d_weight_ih.getDim();
     191            0 :     TensorDim weight_hh_d = d_weight_hh.getDim();
     192              : 
     193            0 :     TensorDim bias_ih_d = d_bias_ih.getDim();
     194            0 :     TensorDim bias_hh_d = d_bias_hh.getDim();
     195            0 :     TensorDim bias_h_d = d_bias_h.getDim();
     196              : 
     197            0 :     weight_ih_d.batch(workers.getNumWorkers());
     198            0 :     weight_hh_d.batch(workers.getNumWorkers());
     199            0 :     bias_ih_d.batch(workers.getNumWorkers());
     200            0 :     bias_hh_d.batch(workers.getNumWorkers());
     201            0 :     bias_h_d.batch(workers.getNumWorkers());
     202              : 
     203            0 :     Tensor sub_d_weight_ih = Tensor(weight_ih_d);
     204            0 :     Tensor sub_d_weight_hh = Tensor(weight_hh_d);
     205            0 :     Tensor sub_d_bias_ih = Tensor(bias_ih_d);
     206            0 :     Tensor sub_d_bias_hh = Tensor(bias_hh_d);
     207            0 :     Tensor sub_d_bias_h = Tensor(bias_h_d);
     208              : 
     209            0 :     sub_d_weight_ih.setZero();
     210            0 :     sub_d_weight_hh.setZero();
     211            0 :     sub_d_bias_ih.setZero();
     212            0 :     sub_d_bias_hh.setZero();
     213            0 :     sub_d_bias_h.setZero();
     214              : 
     215            0 :     auto batch_job = [&](unsigned int s, unsigned int e, unsigned int pid,
     216              :                          void *user_data) {
     217            0 :       for (unsigned int batch = s; batch < e; ++batch) {
     218            0 :         const Tensor input_sample = input_.getBatchSlice(batch, 1);
     219              : 
     220              :         const Tensor hidden_state_sample =
     221            0 :           hidden_state_.getBatchSlice(batch, 1);
     222            0 :         Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
     223            0 :         const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
     224            0 :         Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
     225              : 
     226            0 :         const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
     227            0 :         Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
     228              : 
     229            0 :         Tensor input;
     230            0 :         Tensor prev_hidden_state;
     231            0 :         Tensor d_prev_hidden_state;
     232            0 :         Tensor prev_cell_state;
     233            0 :         Tensor d_prev_cell_state;
     234            0 :         Tensor d_hidden_state;
     235            0 :         Tensor cell_state;
     236            0 :         Tensor d_cell_state;
     237              : 
     238            0 :         Tensor p_d_weight_ih = sub_d_weight_ih.getBatchSlice(pid, 1);
     239            0 :         Tensor p_d_weight_hh = sub_d_weight_hh.getBatchSlice(pid, 1);
     240            0 :         Tensor p_d_bias_ih = sub_d_bias_ih.getBatchSlice(pid, 1);
     241            0 :         Tensor p_d_bias_hh = sub_d_bias_hh.getBatchSlice(pid, 1);
     242            0 :         Tensor p_d_bias_h = sub_d_bias_h.getBatchSlice(pid, 1);
     243              : 
     244            0 :         for (int t = max_timestep - 1; t > -1; t--) {
     245            0 :           input = input_sample.getSharedDataTensor(
     246              :             feature_size_tensor_dim,
     247            0 :             (reverse ? max_timestep - 1 - t : t) * feature_size);
     248              : 
     249            0 :           if (!t) {
     250            0 :             prev_hidden_state = Tensor(unit, tensor_type);
     251            0 :             prev_hidden_state.setZero();
     252            0 :             d_prev_hidden_state = Tensor(unit, tensor_type);
     253            0 :             d_prev_hidden_state.setZero();
     254              :           } else {
     255            0 :             prev_hidden_state = hidden_state_sample.getSharedDataTensor(
     256            0 :               unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
     257            0 :             d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
     258            0 :               unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
     259              :           }
     260            0 :           d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
     261            0 :             unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
     262              : 
     263            0 :           if (!t) {
     264            0 :             prev_cell_state = Tensor(unit, tensor_type);
     265            0 :             prev_cell_state.setZero();
     266            0 :             d_prev_cell_state = Tensor(unit, tensor_type);
     267            0 :             d_prev_cell_state.setZero();
     268              :           } else {
     269            0 :             prev_cell_state = cell_state_sample.getSharedDataTensor(
     270            0 :               unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
     271            0 :             d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
     272            0 :               unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
     273              :           }
     274            0 :           cell_state = cell_state_sample.getSharedDataTensor(
     275            0 :             unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
     276            0 :           d_cell_state = d_cell_state_sample.getSharedDataTensor(
     277            0 :             unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
     278              : 
     279              :           Tensor ifgo = ifgo_sample.getSharedDataTensor(
     280              :             num_gate_tensor_dim,
     281            0 :             (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
     282              :           Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
     283              :             num_gate_tensor_dim,
     284            0 :             (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
     285              : 
     286              :           // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
     287              :           // already have precalculated values from incomming derivatives
     288              :           Tensor d_prev_hidden_state_temp =
     289              :             Tensor("d_prev_hidden_state_temp", tensor_type.format,
     290            0 :                    tensor_type.data_type);
     291              : 
     292            0 :           calcGradientLSTM(
     293            0 :             1, unit, disable_bias, integrate_bias, acti_func,
     294              :             recurrent_acti_func, input, prev_hidden_state,
     295              :             d_prev_hidden_state_temp, prev_cell_state, d_prev_cell_state,
     296              :             d_hidden_state, cell_state, d_cell_state, p_d_weight_ih, weight_hh,
     297              :             p_d_weight_hh, p_d_bias_h, p_d_bias_ih, p_d_bias_hh, ifgo, d_ifgo);
     298              : 
     299            0 :           d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
     300            0 :         }
     301            0 :       }
     302            0 :     };
     303              : 
     304            0 :     workers.setCallback(batch_job, nullptr);
     305            0 :     workers.run();
     306              : 
     307            0 :     for (unsigned int b = 0; b < workers.getNumWorkers(); ++b) {
     308              : 
     309            0 :       Tensor p_d_weight_ih = sub_d_weight_ih.getBatchSlice(b, 1);
     310            0 :       Tensor p_d_weight_hh = sub_d_weight_hh.getBatchSlice(b, 1);
     311            0 :       Tensor p_d_bias_ih = sub_d_bias_ih.getBatchSlice(b, 1);
     312            0 :       Tensor p_d_bias_hh = sub_d_bias_hh.getBatchSlice(b, 1);
     313            0 :       Tensor p_d_bias_h = sub_d_bias_h.getBatchSlice(b, 1);
     314              : 
     315            0 :       d_weight_ih.add_i(p_d_weight_ih);
     316            0 :       d_weight_hh.add_i(p_d_weight_hh);
     317              : 
     318            0 :       if (!disable_bias) {
     319            0 :         if (integrate_bias) {
     320            0 :           d_bias_h.add_i(p_d_bias_h);
     321              :         } else {
     322            0 :           d_bias_ih.add_i(p_d_bias_ih);
     323            0 :           d_bias_hh.add_i(p_d_bias_hh);
     324              :         }
     325              :       }
     326            0 :     }
     327              : 
     328            0 :   } else {
     329          319 :     for (unsigned int batch = 0; batch < batch_size; ++batch) {
     330          209 :       const Tensor input_sample = input_.getBatchSlice(batch, 1);
     331              : 
     332          209 :       const Tensor hidden_state_sample = hidden_state_.getBatchSlice(batch, 1);
     333          209 :       Tensor d_hidden_state_sample = d_hidden_state_.getBatchSlice(batch, 1);
     334          209 :       const Tensor cell_state_sample = cell_state_.getBatchSlice(batch, 1);
     335          209 :       Tensor d_cell_state_sample = d_cell_state_.getBatchSlice(batch, 1);
     336              : 
     337          209 :       const Tensor ifgo_sample = ifgo_.getBatchSlice(batch, 1);
     338          209 :       Tensor d_ifgo_sample = d_ifgo_.getBatchSlice(batch, 1);
     339              : 
     340          209 :       Tensor input;
     341          209 :       Tensor prev_hidden_state;
     342          209 :       Tensor d_prev_hidden_state;
     343          209 :       Tensor prev_cell_state;
     344          209 :       Tensor d_prev_cell_state;
     345          209 :       Tensor d_hidden_state;
     346          209 :       Tensor cell_state;
     347          209 :       Tensor d_cell_state;
     348              : 
     349          634 :       for (int t = max_timestep - 1; t > -1; t--) {
     350          850 :         input = input_sample.getSharedDataTensor(
     351              :           feature_size_tensor_dim,
     352          425 :           (reverse ? max_timestep - 1 - t : t) * feature_size);
     353              : 
     354          425 :         if (!t) {
     355          627 :           prev_hidden_state = Tensor(unit, tensor_type);
     356          209 :           prev_hidden_state.setZero();
     357          627 :           d_prev_hidden_state = Tensor(unit, tensor_type);
     358          209 :           d_prev_hidden_state.setZero();
     359              :         } else {
     360          432 :           prev_hidden_state = hidden_state_sample.getSharedDataTensor(
     361          216 :             unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
     362          648 :           d_prev_hidden_state = d_hidden_state_sample.getSharedDataTensor(
     363          216 :             unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
     364              :         }
     365          850 :         d_hidden_state = d_hidden_state_sample.getSharedDataTensor(
     366          425 :           unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
     367              : 
     368          425 :         if (!t) {
     369          627 :           prev_cell_state = Tensor(unit, tensor_type);
     370          209 :           prev_cell_state.setZero();
     371          627 :           d_prev_cell_state = Tensor(unit, tensor_type);
     372          209 :           d_prev_cell_state.setZero();
     373              :         } else {
     374          432 :           prev_cell_state = cell_state_sample.getSharedDataTensor(
     375          216 :             unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
     376          648 :           d_prev_cell_state = d_cell_state_sample.getSharedDataTensor(
     377          216 :             unit_tensor_dim, (reverse ? (max_timestep - t) : (t - 1)) * unit);
     378              :         }
     379          850 :         cell_state = cell_state_sample.getSharedDataTensor(
     380          425 :           unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
     381          850 :         d_cell_state = d_cell_state_sample.getSharedDataTensor(
     382          425 :           unit_tensor_dim, (reverse ? max_timestep - 1 - t : t) * unit);
     383              : 
     384              :         Tensor ifgo = ifgo_sample.getSharedDataTensor(
     385              :           num_gate_tensor_dim,
     386          425 :           (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
     387              :         Tensor d_ifgo = d_ifgo_sample.getSharedDataTensor(
     388              :           num_gate_tensor_dim,
     389          425 :           (reverse ? max_timestep - 1 - t : t) * NUM_GATE * unit);
     390              : 
     391              :         // Temporary variable for d_prev_hidden_state. d_prev_hidden_state
     392              :         // already have precalculated values from incomming derivatives
     393              :         Tensor d_prev_hidden_state_temp =
     394              :           Tensor("d_prev_hidden_state_temp", tensor_type.format,
     395          425 :                  tensor_type.data_type);
     396              : 
     397          425 :         calcGradientLSTM(1, unit, disable_bias, integrate_bias, acti_func,
     398              :                          recurrent_acti_func, input, prev_hidden_state,
     399              :                          d_prev_hidden_state_temp, prev_cell_state,
     400              :                          d_prev_cell_state, d_hidden_state, cell_state,
     401              :                          d_cell_state, d_weight_ih, weight_hh, d_weight_hh,
     402              :                          d_bias_h, d_bias_ih, d_bias_hh, ifgo, d_ifgo);
     403          425 :         d_prev_hidden_state.add_i(d_prev_hidden_state_temp);
     404          425 :       }
     405          209 :     }
     406              :   }
     407          110 : }
     408              : 
     409           76 : LSTMLayer::LSTMLayer() :
     410              :   LSTMCore(),
     411           76 :   lstm_props(props::ReturnSequences(), props::Bidirectional(),
     412          152 :              props::DropOutRate(), props::MaxTimestep()) {
     413              :   wt_idx.fill(std::numeric_limits<unsigned>::max());
     414           76 : }
     415              : 
     416           62 : void LSTMLayer::finalize(InitLayerContext &context) {
     417              :   const Initializer weight_initializer =
     418           62 :     std::get<props::WeightInitializer>(*layer_impl_props).get();
     419              :   const Initializer bias_initializer =
     420           62 :     std::get<props::BiasInitializer>(*layer_impl_props).get();
     421              :   const nntrainer::WeightRegularizer weight_regularizer =
     422           62 :     std::get<props::WeightRegularizer>(*layer_impl_props).get();
     423              :   const float weight_regularizer_constant =
     424           62 :     std::get<props::WeightRegularizerConstant>(*layer_impl_props).get();
     425              :   auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
     426              :   auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
     427              :   const bool disable_bias =
     428           62 :     std::get<props::DisableBias>(*layer_impl_props).get();
     429              : 
     430           62 :   NNTR_THROW_IF(std::get<props::Unit>(lstmcore_props).empty(),
     431              :                 std::invalid_argument)
     432              :     << "unit property missing for lstm layer";
     433           62 :   const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
     434              :   const bool integrate_bias =
     435           62 :     std::get<props::IntegrateBias>(lstmcore_props).get();
     436              :   const ActivationType hidden_state_activation_type =
     437           62 :     std::get<props::HiddenStateActivation>(lstmcore_props).get();
     438              :   const ActivationType recurrent_activation_type =
     439           62 :     std::get<props::RecurrentActivation>(lstmcore_props).get();
     440              : 
     441              :   const bool return_sequences =
     442           62 :     std::get<props::ReturnSequences>(lstm_props).get();
     443           62 :   const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
     444           62 :   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
     445              : 
     446           62 :   if (context.getNumInputs() != 1) {
     447            0 :     throw std::invalid_argument("LSTM layer takes only one input");
     448              :   }
     449              : 
     450              :   // input_dim = [ batch_size, 1, time_iteration, feature_size ]
     451              :   const TensorDim &input_dim = context.getInputDimensions()[SINGLE_INOUT_IDX];
     452           62 :   if (input_dim.channel() != 1) {
     453              :     throw std::invalid_argument(
     454              :       "Input must be single channel dimension for LSTM (shape should be "
     455            0 :       "[batch_size, 1, time_iteration, feature_size])");
     456              :   }
     457           62 :   const unsigned int batch_size = input_dim.batch();
     458           62 :   unsigned int max_timestep = input_dim.height();
     459           62 :   if (!std::get<props::MaxTimestep>(lstm_props).empty())
     460           26 :     max_timestep =
     461           52 :       std::max(max_timestep, std::get<props::MaxTimestep>(lstm_props).get());
     462           62 :   NNTR_THROW_IF(max_timestep < 1, std::runtime_error)
     463              :     << "max timestep must be greator than 0 in lstm layer.";
     464           62 :   std::get<props::MaxTimestep>(lstm_props).set(max_timestep);
     465           62 :   const unsigned int feature_size = input_dim.width();
     466              : 
     467              :   // output_dim = [ batch_size, 1, return_sequences ? time_iteration : 1,
     468              :   // bidirectional ? 2 * unit : unit ]
     469              :   TensorDim::TensorType activation_tensor_type = {
     470              :     context.getFormat(), context.getActivationDataType()};
     471              : 
     472              :   TensorDim::TensorType weight_tensor_type = {context.getFormat(),
     473              :                                               context.getWeightDataType()};
     474           44 :   const TensorDim output_dim(batch_size, 1, return_sequences ? max_timestep : 1,
     475           12 :                              bidirectional ? 2 * unit : unit,
     476           62 :                              activation_tensor_type);
     477           62 :   context.setOutputDimensions({output_dim});
     478              : 
     479              :   // weight_initializer can be set seperately. weight_ih initializer,
     480              :   // weight_hh initializer kernel initializer & recurrent_initializer in
     481              :   // keras for now, it is set same way.
     482              : 
     483              :   // weight_ih ( input to hidden ) : [ 1, 1, feature_size, NUM_GATE * unit ]
     484              :   // -> i, f, g, o
     485           62 :   const TensorDim weight_ih_dim({feature_size, NUM_GATE * unit},
     486           62 :                                 weight_tensor_type);
     487           62 :   wt_idx[LSTMParams::weight_ih] = context.requestWeight(
     488              :     weight_ih_dim, weight_initializer, weight_regularizer,
     489              :     weight_regularizer_constant, weight_decay, "weight_ih", true);
     490              :   // weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE * unit ] -> i,
     491              :   // f, g, o
     492           62 :   const TensorDim weight_hh_dim({unit, NUM_GATE * unit}, weight_tensor_type);
     493          124 :   wt_idx[LSTMParams::weight_hh] = context.requestWeight(
     494              :     weight_hh_dim, weight_initializer, weight_regularizer,
     495              :     weight_regularizer_constant, weight_decay, "weight_hh", true);
     496           62 :   if (!disable_bias) {
     497           62 :     if (integrate_bias) {
     498              :       // bias_h ( input bias, hidden bias are integrate to 1 bias ) : [ 1,
     499              :       // 1, 1, NUM_GATE * unit ] -> i, f, g, o
     500           34 :       const TensorDim bias_h_dim({NUM_GATE * unit}, weight_tensor_type);
     501           34 :       wt_idx[LSTMParams::bias_h] = context.requestWeight(
     502              :         bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
     503              :         "bias_h", true);
     504              :     } else {
     505              :       // bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
     506           28 :       const TensorDim bias_ih_dim({NUM_GATE * unit}, weight_tensor_type);
     507           28 :       wt_idx[LSTMParams::bias_ih] = context.requestWeight(
     508              :         bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     509              :         bias_decay, "bias_ih", true);
     510              :       // bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
     511           56 :       wt_idx[LSTMParams::bias_hh] = context.requestWeight(
     512              :         bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     513              :         bias_decay, "bias_hh", true);
     514              :     }
     515              :   }
     516              : 
     517              :   // hidden_state_dim : [ batch_size, 1, max_timestep, unit ]
     518              :   const TensorDim hidden_state_dim(batch_size, 1, max_timestep, unit,
     519           62 :                                    activation_tensor_type);
     520              : 
     521           62 :   wt_idx[LSTMParams::hidden_state] =
     522          124 :     context.requestTensor(hidden_state_dim, "hidden_state", Initializer::NONE,
     523              :                           true, TensorLifespan::ITERATION_LIFESPAN);
     524              :   // cell_state_dim : [ batch_size, 1, max_timestep, unit ]
     525              :   const TensorDim cell_state_dim(batch_size, 1, max_timestep, unit,
     526           62 :                                  activation_tensor_type);
     527              : 
     528           62 :   wt_idx[LSTMParams::cell_state] =
     529          124 :     context.requestTensor(cell_state_dim, "cell_state", Initializer::NONE, true,
     530              :                           TensorLifespan::ITERATION_LIFESPAN);
     531              : 
     532              :   // ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
     533              :   const TensorDim ifgo_dim(batch_size, 1, max_timestep, NUM_GATE * unit,
     534           62 :                            activation_tensor_type);
     535              : 
     536           62 :   wt_idx[LSTMParams::ifgo] =
     537           62 :     context.requestTensor(ifgo_dim, "ifgo", Initializer::NONE, true,
     538              :                           TensorLifespan::ITERATION_LIFESPAN);
     539              : 
     540           62 :   if (bidirectional) {
     541              :     // weight_initializer can be set seperately. weight_ih initializer,
     542              :     // weight_hh initializer kernel initializer & recurrent_initializer in
     543              :     // keras for now, it is set same way.
     544              : 
     545              :     // reverse_weight_ih ( input to hidden ) : [ 1, 1, feature_size,
     546              :     // NUM_GATE * unit ] -> i, f, g, o
     547              :     const TensorDim reverse_weight_ih_dim({feature_size, NUM_GATE * unit},
     548           12 :                                           weight_tensor_type);
     549           24 :     wt_idx[LSTMParams::reverse_weight_ih] = context.requestWeight(
     550              :       reverse_weight_ih_dim, weight_initializer, weight_regularizer,
     551              :       weight_regularizer_constant, weight_decay, "reverse_weight_ih", true);
     552              :     // reverse_weight_hh ( hidden to hidden ) : [ 1, 1, unit, NUM_GATE *
     553              :     // unit ]
     554              :     // -> i, f, g, o
     555              :     const TensorDim reverse_weight_hh_dim({unit, NUM_GATE * unit},
     556           12 :                                           weight_tensor_type);
     557           24 :     wt_idx[LSTMParams::reverse_weight_hh] = context.requestWeight(
     558              :       reverse_weight_hh_dim, weight_initializer, weight_regularizer,
     559              :       weight_regularizer_constant, weight_decay, "reverse_weight_hh", true);
     560           12 :     if (!disable_bias) {
     561           12 :       if (integrate_bias) {
     562              :         // reverse_bias_h ( input bias, hidden bias are integrate to 1 bias
     563              :         // ) : [ 1, 1, 1, NUM_GATE * unit ] -> i, f, g, o
     564              :         const TensorDim reverse_bias_h_dim({NUM_GATE * unit},
     565            0 :                                            weight_tensor_type);
     566            0 :         wt_idx[LSTMParams::reverse_bias_h] = context.requestWeight(
     567              :           reverse_bias_h_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     568              :           bias_decay, "reverse_bias_h", true);
     569              :       } else {
     570              :         // reverse_bias_ih ( input bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
     571              :         // i, f, g, o
     572              :         const TensorDim reverse_bias_ih_dim({NUM_GATE * unit},
     573           12 :                                             weight_tensor_type);
     574           12 :         wt_idx[LSTMParams::reverse_bias_ih] = context.requestWeight(
     575              :           reverse_bias_ih_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     576              :           bias_decay, "reverse_bias_ih", true);
     577              :         // reverse_bias_hh ( hidden bias ) : [ 1, 1, 1, NUM_GATE * unit ] ->
     578              :         // i, f, g, o
     579              :         const TensorDim reverse_bias_hh_dim({NUM_GATE * unit},
     580           12 :                                             weight_tensor_type);
     581           24 :         wt_idx[LSTMParams::reverse_bias_hh] = context.requestWeight(
     582              :           reverse_bias_hh_dim, bias_initializer, WeightRegularizer::NONE, 1.0f,
     583              :           bias_decay, "reverse_bias_hh", true);
     584              :       }
     585              :     }
     586              : 
     587              :     // reverse_hidden_state_dim : [ batch_size, 1, max_timestep, unit ]
     588              :     const TensorDim reverse_hidden_state_dim(batch_size, 1, max_timestep, unit,
     589           12 :                                              activation_tensor_type);
     590           12 :     wt_idx[LSTMParams::reverse_hidden_state] = context.requestTensor(
     591              :       reverse_hidden_state_dim, "reverse_hidden_state", Initializer::NONE, true,
     592              :       TensorLifespan::ITERATION_LIFESPAN);
     593              :     // reverse_cell_state_dim : [ batch_size, 1, max_timestep, unit ]
     594              :     const TensorDim reverse_cell_state_dim(batch_size, 1, max_timestep, unit,
     595           12 :                                            activation_tensor_type);
     596           12 :     wt_idx[LSTMParams::reverse_cell_state] = context.requestTensor(
     597              :       reverse_cell_state_dim, "reverse_cell_state", Initializer::NONE, true,
     598              :       TensorLifespan::ITERATION_LIFESPAN);
     599              : 
     600              :     // reverse_ifgo_dim : [ batch_size, 1, max_timestep, NUM_GATE * unit ]
     601              :     const TensorDim reverse_ifgo_dim(batch_size, 1, max_timestep,
     602           12 :                                      NUM_GATE * unit, activation_tensor_type);
     603           12 :     wt_idx[LSTMParams::reverse_ifgo] =
     604           24 :       context.requestTensor(reverse_ifgo_dim, "reverse_ifgo", Initializer::NONE,
     605              :                             true, TensorLifespan::ITERATION_LIFESPAN);
     606              :   }
     607              : 
     608           62 :   if (dropout_rate > epsilon) {
     609              :     // dropout_mask_dim = [ batch, 1, time_iteration, unit ]
     610              :     const TensorDim dropout_mask_dim(batch_size, 1, max_timestep, unit,
     611            0 :                                      activation_tensor_type);
     612            0 :     wt_idx[LSTMParams::dropout_mask] =
     613            0 :       context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
     614              :                             false, TensorLifespan::ITERATION_LIFESPAN);
     615              :   }
     616              : 
     617           62 :   if (context.getActivationDataType() == TensorDim::DataType::FP32) {
     618           62 :     acti_func.setActiFunc<float>(hidden_state_activation_type);
     619           62 :     recurrent_acti_func.setActiFunc<float>(recurrent_activation_type);
     620            0 :   } else if (context.getActivationDataType() == TensorDim::DataType::FP16) {
     621              : #ifdef ENABLE_FP16
     622              :     acti_func.setActiFunc<_FP16>(hidden_state_activation_type);
     623              :     recurrent_acti_func.setActiFunc<_FP16>(recurrent_activation_type);
     624              : #else
     625            0 :     throw std::invalid_argument("Error: enable-fp16 is not enabled");
     626              : #endif
     627              :   }
     628           62 : }
     629              : 
     630          337 : void LSTMLayer::setProperty(const std::vector<std::string> &values) {
     631              :   const std::vector<std::string> &remain_props =
     632          337 :     loadProperties(values, lstm_props);
     633          336 :   LSTMCore::setProperty(remain_props);
     634          336 : }
     635              : 
     636           26 : void LSTMLayer::exportTo(Exporter &exporter,
     637              :                          const ml::train::ExportMethods &method) const {
     638           26 :   LSTMCore::exportTo(exporter, method);
     639           26 :   exporter.saveResult(lstm_props, method, this);
     640           26 : }
     641              : 
     642          171 : void LSTMLayer::forwarding(RunLayerContext &context, bool training) {
     643              :   const bool disable_bias =
     644          171 :     std::get<props::DisableBias>(*layer_impl_props).get();
     645              : 
     646          171 :   const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
     647              :   const bool integrate_bias =
     648          171 :     std::get<props::IntegrateBias>(lstmcore_props).get();
     649              : 
     650              :   const bool return_sequences =
     651          171 :     std::get<props::ReturnSequences>(lstm_props).get();
     652          171 :   const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
     653          171 :   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
     654              :   const unsigned int max_timestep =
     655          171 :     std::get<props::MaxTimestep>(lstm_props).get();
     656              : 
     657          171 :   const unsigned int bidirectional_constant = bidirectional ? 2 : 1;
     658          171 :   bool enable_dropout = dropout_rate > epsilon && training;
     659              : 
     660          171 :   const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
     661          171 :   const TensorDim input_dim = input.getDim();
     662          171 :   const unsigned int batch_size = input_dim.batch();
     663          171 :   const unsigned int feature_size = input_dim.width();
     664          171 :   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
     665              : 
     666          171 :   const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]);
     667          171 :   const Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]);
     668              : 
     669              :   Tensor empty =
     670          171 :     Tensor("empty", weight_ih.getFormat(), weight_ih.getDataType());
     671              : 
     672          171 :   const Tensor &bias_h = !disable_bias && integrate_bias
     673          171 :                            ? context.getWeight(wt_idx[LSTMParams::bias_h])
     674              :                            : empty;
     675              :   const Tensor &bias_ih = !disable_bias && !integrate_bias
     676          171 :                             ? context.getWeight(wt_idx[LSTMParams::bias_ih])
     677              :                             : empty;
     678              :   const Tensor &bias_hh = !disable_bias && !integrate_bias
     679          171 :                             ? context.getWeight(wt_idx[LSTMParams::bias_hh])
     680              :                             : empty;
     681              : 
     682          171 :   Tensor &hidden_state = context.getTensor(wt_idx[LSTMParams::hidden_state]);
     683          171 :   Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]);
     684          171 :   Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]);
     685              : 
     686              :   Tensor &mask = enable_dropout
     687          171 :                    ? context.getTensor(wt_idx[LSTMParams::dropout_mask])
     688              :                    : empty;
     689          171 :   forwardingBatchFirstLSTM(NUM_GATE, batch_size, feature_size, disable_bias,
     690          171 :                            unit, integrate_bias, acti_func, recurrent_acti_func,
     691              :                            enable_dropout, dropout_rate, max_timestep, false,
     692              :                            input, weight_ih, weight_hh, bias_h, bias_ih,
     693              :                            bias_hh, hidden_state, cell_state, ifgo, mask);
     694          171 :   if (bidirectional) {
     695              :     const Tensor &reverse_weight_ih =
     696           18 :       context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]);
     697              :     const Tensor &reverse_weight_hh =
     698           18 :       context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]);
     699              :     const Tensor &reverse_bias_h =
     700              :       !disable_bias && integrate_bias
     701           18 :         ? context.getWeight(wt_idx[LSTMParams::reverse_bias_h])
     702              :         : empty;
     703              :     const Tensor &reverse_bias_ih =
     704              :       !disable_bias && !integrate_bias
     705           18 :         ? context.getWeight(wt_idx[LSTMParams::reverse_bias_ih])
     706              :         : empty;
     707              :     const Tensor &reverse_bias_hh =
     708              :       !disable_bias && !integrate_bias
     709           18 :         ? context.getWeight(wt_idx[LSTMParams::reverse_bias_hh])
     710              :         : empty;
     711              : 
     712              :     Tensor &reverse_hidden_state =
     713           18 :       context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
     714              :     Tensor &reverse_cell_state =
     715           18 :       context.getTensor(wt_idx[LSTMParams::reverse_cell_state]);
     716           18 :     Tensor &reverse_ifgo = context.getTensor(wt_idx[LSTMParams::reverse_ifgo]);
     717              : 
     718           18 :     forwardingBatchFirstLSTM(
     719              :       NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
     720              :       acti_func, recurrent_acti_func, enable_dropout, dropout_rate,
     721              :       max_timestep, true, input, reverse_weight_ih, reverse_weight_hh,
     722              :       reverse_bias_h, reverse_bias_ih, reverse_bias_hh, reverse_hidden_state,
     723              :       reverse_cell_state, reverse_ifgo, mask);
     724              :   }
     725              : 
     726          171 :   if (return_sequences && !bidirectional) {
     727           98 :     if (hidden_state.getDataType() == TensorDim::DataType::FP32) {
     728           98 :       std::copy(hidden_state.getData<float>(),
     729           98 :                 hidden_state.getData<float>() + hidden_state.size(),
     730              :                 output.getData<float>());
     731            0 :     } else if (hidden_state.getDataType() == TensorDim::DataType::FP16) {
     732              : #ifdef ENABLE_FP16
     733              :       std::copy(hidden_state.getData<_FP16>(),
     734              :                 hidden_state.getData<_FP16>() + hidden_state.size(),
     735              :                 output.getData<_FP16>());
     736              : #else
     737            0 :       throw std::invalid_argument("Error: enable-fp16 is not enabled");
     738              : #endif
     739              :     }
     740              :   } else {
     741           73 :     unsigned int end_timestep = return_sequences ? max_timestep : 1;
     742           73 :     if (hidden_state.getDataType() == TensorDim::DataType::FP32) {
     743          217 :       for (unsigned int batch = 0; batch < batch_size; ++batch) {
     744          342 :         for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
     745          198 :           float *hidden_state_data = hidden_state.getAddress<float>(
     746          198 :             batch * max_timestep * unit +
     747          198 :             (return_sequences ? 0 : (max_timestep - 1) * unit) +
     748              :             timestep * unit);
     749          198 :           float *output_data = output.getAddress<float>(
     750          198 :             batch * (return_sequences ? max_timestep : 1) *
     751          198 :               bidirectional_constant * unit +
     752              :             timestep * bidirectional_constant * unit);
     753          198 :           std::copy(hidden_state_data, hidden_state_data + unit, output_data);
     754              : 
     755          198 :           if (bidirectional) {
     756              :             Tensor &reverse_hidden_state =
     757          108 :               context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
     758              :             float *reverse_hidden_state_data =
     759              :               reverse_hidden_state.getAddress<float>(
     760              :                 batch * max_timestep * unit +
     761              :                 (return_sequences ? 0 : (max_timestep - 1) * unit) +
     762              :                 timestep * unit);
     763          108 :             std::copy(reverse_hidden_state_data,
     764              :                       reverse_hidden_state_data + unit, output_data + unit);
     765              :           }
     766              :         }
     767              :       }
     768            0 :     } else if (hidden_state.getDataType() == TensorDim::DataType::FP16) {
     769              : #ifdef ENABLE_FP16
     770              :       for (unsigned int batch = 0; batch < batch_size; ++batch) {
     771              :         for (unsigned int timestep = 0; timestep < end_timestep; ++timestep) {
     772              :           _FP16 *hidden_state_data = hidden_state.getAddress<_FP16>(
     773              :             batch * max_timestep * unit +
     774              :             (return_sequences ? 0 : (max_timestep - 1) * unit) +
     775              :             timestep * unit);
     776              :           _FP16 *output_data = output.getAddress<_FP16>(
     777              :             batch * (return_sequences ? max_timestep : 1) *
     778              :               bidirectional_constant * unit +
     779              :             timestep * bidirectional_constant * unit);
     780              :           std::copy(hidden_state_data, hidden_state_data + unit, output_data);
     781              : 
     782              :           if (bidirectional) {
     783              :             Tensor &reverse_hidden_state =
     784              :               context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
     785              :             _FP16 *reverse_hidden_state_data =
     786              :               reverse_hidden_state.getAddress<_FP16>(
     787              :                 batch * max_timestep * unit +
     788              :                 (return_sequences ? 0 : (max_timestep - 1) * unit) +
     789              :                 timestep * unit);
     790              :             std::copy(reverse_hidden_state_data,
     791              :                       reverse_hidden_state_data + unit, output_data + unit);
     792              :           }
     793              :         }
     794              :       }
     795              : #else
     796            0 :       throw std::invalid_argument("Error: enable-fp16 is not enabled");
     797              : #endif
     798              :     }
     799              :   }
     800          171 : }
     801              : 
     802          101 : void LSTMLayer::calcDerivative(RunLayerContext &context) {
     803          101 :   const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
     804              : 
     805          101 :   Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
     806          101 :   const Tensor &weight_ih = context.getWeight(wt_idx[LSTMParams::weight_ih]);
     807          101 :   const Tensor &d_ifgos = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);
     808              : 
     809          101 :   calcDerivativeLSTM(outgoing_derivative, weight_ih, d_ifgos);
     810              : 
     811          101 :   if (bidirectional) {
     812              :     const Tensor &reverse_weight_ih =
     813            9 :       context.getWeight(wt_idx[LSTMParams::reverse_weight_ih]);
     814              :     const Tensor &reverse_d_ifgos =
     815            9 :       context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]);
     816              : 
     817            9 :     calcDerivativeLSTM(outgoing_derivative, reverse_weight_ih, reverse_d_ifgos,
     818              :                        1.0f);
     819              :   }
     820          101 : }
     821              : 
     822          101 : void LSTMLayer::calcGradient(RunLayerContext &context) {
     823              :   const bool disable_bias =
     824          101 :     std::get<props::DisableBias>(*layer_impl_props).get();
     825              : 
     826          101 :   const unsigned int unit = std::get<props::Unit>(lstmcore_props).get();
     827              :   const bool integrate_bias =
     828          101 :     std::get<props::IntegrateBias>(lstmcore_props).get();
     829              : 
     830              :   const bool return_sequences =
     831          101 :     std::get<props::ReturnSequences>(lstm_props).get();
     832          101 :   const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
     833          101 :   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
     834              :   const unsigned int max_timestep =
     835          101 :     std::get<props::MaxTimestep>(lstm_props).get();
     836              : 
     837          101 :   bool enable_dropout = dropout_rate > epsilon;
     838              : 
     839          101 :   const Tensor &input = context.getInput(SINGLE_INOUT_IDX);
     840              :   const Tensor &incoming_derivative =
     841          101 :     context.getIncomingDerivative(SINGLE_INOUT_IDX);
     842          101 :   const TensorDim input_dim = input.getDim();
     843          101 :   const unsigned int batch_size = input_dim.batch();
     844          101 :   const unsigned int feature_size = input_dim.width();
     845              : 
     846          101 :   Tensor &d_weight_ih = context.getWeightGrad(wt_idx[LSTMParams::weight_ih]);
     847          101 :   const Tensor &weight_hh = context.getWeight(wt_idx[LSTMParams::weight_hh]);
     848          101 :   Tensor &d_weight_hh = context.getWeightGrad(wt_idx[LSTMParams::weight_hh]);
     849              : 
     850              :   Tensor empty =
     851          101 :     Tensor("empty", weight_hh.getFormat(), weight_hh.getDataType());
     852              : 
     853          101 :   Tensor &d_bias_h = !disable_bias && integrate_bias
     854          101 :                        ? context.getWeightGrad(wt_idx[LSTMParams::bias_h])
     855              :                        : empty;
     856              :   Tensor &d_bias_ih = !disable_bias && !integrate_bias
     857          101 :                         ? context.getWeightGrad(wt_idx[LSTMParams::bias_ih])
     858              :                         : empty;
     859              :   Tensor &d_bias_hh = !disable_bias && !integrate_bias
     860          101 :                         ? context.getWeightGrad(wt_idx[LSTMParams::bias_hh])
     861              :                         : empty;
     862              : 
     863              :   const Tensor &hidden_state =
     864          101 :     context.getTensor(wt_idx[LSTMParams::hidden_state]);
     865              :   Tensor &d_hidden_state =
     866          101 :     context.getTensorGrad(wt_idx[LSTMParams::hidden_state]);
     867          101 :   const Tensor &cell_state = context.getTensor(wt_idx[LSTMParams::cell_state]);
     868          101 :   Tensor &d_cell_state = context.getTensorGrad(wt_idx[LSTMParams::cell_state]);
     869              : 
     870          101 :   const Tensor &ifgo = context.getTensor(wt_idx[LSTMParams::ifgo]);
     871          101 :   Tensor &d_ifgo = context.getTensorGrad(wt_idx[LSTMParams::ifgo]);
     872              : 
     873              :   const Tensor &mask = enable_dropout
     874          101 :                          ? context.getTensor(wt_idx[LSTMParams::dropout_mask])
     875              :                          : empty;
     876              : 
     877          101 :   calcGradientBatchFirstLSTM(
     878              :     NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
     879          101 :     acti_func, recurrent_acti_func, return_sequences, bidirectional,
     880              :     enable_dropout, dropout_rate, max_timestep, false, input,
     881              :     incoming_derivative, d_weight_ih, weight_hh, d_weight_hh, d_bias_h,
     882              :     d_bias_ih, d_bias_hh, hidden_state, d_hidden_state, cell_state,
     883              :     d_cell_state, ifgo, d_ifgo, mask);
     884              : 
     885          101 :   if (bidirectional) {
     886              :     Tensor &reverse_d_weight_ih =
     887            9 :       context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_ih]);
     888              :     const Tensor &reverse_weight_hh =
     889            9 :       context.getWeight(wt_idx[LSTMParams::reverse_weight_hh]);
     890              :     Tensor &reverse_d_weight_hh =
     891            9 :       context.getWeightGrad(wt_idx[LSTMParams::reverse_weight_hh]);
     892              :     Tensor &reverse_d_bias_h =
     893              :       !disable_bias && integrate_bias
     894            9 :         ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_h])
     895              :         : empty;
     896              :     Tensor &reverse_d_bias_ih =
     897              :       !disable_bias && !integrate_bias
     898            9 :         ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_ih])
     899              :         : empty;
     900              :     Tensor &reverse_d_bias_hh =
     901              :       !disable_bias && !integrate_bias
     902            9 :         ? context.getWeightGrad(wt_idx[LSTMParams::reverse_bias_hh])
     903              :         : empty;
     904              : 
     905              :     const Tensor &reverse_hidden_state =
     906            9 :       context.getTensor(wt_idx[LSTMParams::reverse_hidden_state]);
     907              :     Tensor &reverse_d_hidden_state =
     908            9 :       context.getTensorGrad(wt_idx[LSTMParams::reverse_hidden_state]);
     909              :     const Tensor &reverse_cell_state =
     910            9 :       context.getTensor(wt_idx[LSTMParams::reverse_cell_state]);
     911              :     Tensor &reverse_d_cell_state =
     912            9 :       context.getTensorGrad(wt_idx[LSTMParams::reverse_cell_state]);
     913              : 
     914              :     const Tensor &reverse_ifgo =
     915            9 :       context.getTensor(wt_idx[LSTMParams::reverse_ifgo]);
     916              :     Tensor &reverse_d_ifgo =
     917            9 :       context.getTensorGrad(wt_idx[LSTMParams::reverse_ifgo]);
     918              : 
     919            9 :     calcGradientBatchFirstLSTM(
     920              :       NUM_GATE, batch_size, feature_size, disable_bias, unit, integrate_bias,
     921              :       acti_func, recurrent_acti_func, return_sequences, bidirectional,
     922              :       enable_dropout, dropout_rate, max_timestep, true, input,
     923              :       incoming_derivative, reverse_d_weight_ih, reverse_weight_hh,
     924              :       reverse_d_weight_hh, reverse_d_bias_h, reverse_d_bias_ih,
     925              :       reverse_d_bias_hh, reverse_hidden_state, reverse_d_hidden_state,
     926              :       reverse_cell_state, reverse_d_cell_state, reverse_ifgo, reverse_d_ifgo,
     927              :       mask);
     928              :   }
     929          101 : }
     930              : 
     931           36 : void LSTMLayer::setBatch(RunLayerContext &context, unsigned int batch) {
     932           36 :   const bool bidirectional = std::get<props::Bidirectional>(lstm_props).get();
     933           36 :   const float dropout_rate = std::get<props::DropOutRate>(lstm_props).get();
     934              : 
     935           36 :   context.updateTensor(wt_idx[LSTMParams::hidden_state], batch);
     936           36 :   context.updateTensor(wt_idx[LSTMParams::cell_state], batch);
     937           36 :   context.updateTensor(wt_idx[LSTMParams::ifgo], batch);
     938              : 
     939           36 :   if (bidirectional) {
     940           12 :     context.updateTensor(wt_idx[LSTMParams::reverse_hidden_state], batch);
     941           12 :     context.updateTensor(wt_idx[LSTMParams::reverse_cell_state], batch);
     942           12 :     context.updateTensor(wt_idx[LSTMParams::reverse_ifgo], batch);
     943              :   }
     944              : 
     945           36 :   if (dropout_rate > epsilon) {
     946            0 :     context.updateTensor(wt_idx[LSTMParams::dropout_mask], batch);
     947              :   }
     948           36 : }
     949              : 
     950              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1