LCOV - code coverage report
Current view: top level - nntrainer/layers - lstmcell_core.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 100.0 % 77 77
Test Date: 2025-12-14 20:38:17 Functions: 85.7 % 7 6

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
       4              :  *
       5              :  * @file   lstmcell_core.cpp
       6              :  * @date   25 November 2021
       7              :  * @brief  This is lstm core class.
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author hyeonseok lee <hs89.lee@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : 
      14              : #include <lstmcell_core.h>
      15              : #include <nntrainer_error.h>
      16              : #include <nntrainer_log.h>
      17              : 
      18              : namespace nntrainer {
      19              : 
/**
 * @brief Default constructor.
 *
 * Registers the LSTM core properties (unit count, bias integration mode,
 * and the two activation selections). HiddenStateActivation defaults to
 * tanh and RecurrentActivation (gates) to sigmoid, matching the standard
 * LSTM formulation. Both ActiFunc members are constructed with ACT_NONE
 * here; presumably the concrete functions are bound once the properties
 * are finalized — confirm against finalize() in the subclasses.
 */
LSTMCore::LSTMCore() :
  LayerImpl(),
  lstmcore_props(props::Unit(), props::IntegrateBias(),
                 props::HiddenStateActivation() = ActivationType::ACT_TANH,
                 props::RecurrentActivation() = ActivationType::ACT_SIGMOID),
  acti_func(ActivationType::ACT_NONE, true),
  recurrent_acti_func(ActivationType::ACT_NONE, true),
  epsilon(1e-3f) {} // small constant; exact use is outside this file's view
      28              : 
      29         1244 : void LSTMCore::forwardLSTM(const unsigned int batch_size,
      30              :                            const unsigned int unit, const bool disable_bias,
      31              :                            const bool integrate_bias, ActiFunc &acti_func,
      32              :                            ActiFunc &recurrent_acti_func, const Tensor &input,
      33              :                            const Tensor &prev_hidden_state,
      34              :                            const Tensor &prev_cell_state, Tensor &hidden_state,
      35              :                            Tensor &cell_state, const Tensor &weight_ih,
      36              :                            const Tensor &weight_hh, const Tensor &bias_h,
      37              :                            const Tensor &bias_ih, const Tensor &bias_hh,
      38              :                            Tensor &ifgo) {
      39         1244 :   input.dot(weight_ih, ifgo);
      40         1244 :   prev_hidden_state.dot(weight_hh, ifgo, false, false, 1.0);
      41         1244 :   if (!disable_bias) {
      42         1244 :     if (integrate_bias) {
      43          560 :       ifgo.add_i(bias_h);
      44              :     } else {
      45          684 :       ifgo.add_i(bias_ih);
      46          684 :       ifgo.add_i(bias_hh);
      47              :     }
      48              :   }
      49              : 
      50         1244 :   TensorDim::TensorType tensor_type = ifgo.getTensorType();
      51              : 
      52              :   Tensor input_forget_gate = ifgo.getSharedDataTensor(
      53         1244 :     {batch_size, 1, 1, unit * 2, tensor_type}, 0, false);
      54              :   Tensor input_gate =
      55         1244 :     ifgo.getSharedDataTensor({batch_size, 1, 1, unit, tensor_type}, 0, false);
      56              :   Tensor forget_gate = ifgo.getSharedDataTensor(
      57         1244 :     {batch_size, 1, 1, unit, tensor_type}, unit, false);
      58              :   Tensor memory_cell = ifgo.getSharedDataTensor(
      59         1244 :     {batch_size, 1, 1, unit, tensor_type}, unit * 2, false);
      60              :   Tensor output_gate = ifgo.getSharedDataTensor(
      61         1244 :     {batch_size, 1, 1, unit, tensor_type}, unit * 3, false);
      62              : 
      63              :   recurrent_acti_func.run_fn(input_forget_gate, input_forget_gate);
      64              :   recurrent_acti_func.run_fn(output_gate, output_gate);
      65              :   acti_func.run_fn(memory_cell, memory_cell);
      66              : 
      67         1244 :   prev_cell_state.multiply_strided(forget_gate, cell_state);
      68         1244 :   memory_cell.multiply_strided(input_gate, cell_state, 1.0f);
      69              : 
      70              :   acti_func.run_fn(cell_state, hidden_state);
      71         1244 :   hidden_state.multiply_i_strided(output_gate);
      72         1244 : }
      73              : 
/**
 * @brief Back-propagate the gate gradients to the layer input.
 *
 * Multiplies the gate-pre-activation gradients by the transposed
 * input-to-hidden weights (d_ifgo . weight_ih^T) into
 * outgoing_derivative. alpha is forwarded to Tensor::dot —
 * NOTE(review): presumably it scales the accumulation into the existing
 * output; confirm against the Tensor::dot signature.
 *
 * @param outgoing_derivative gradient w.r.t. the step's input (output)
 * @param weight_ih input-to-hidden weight matrix
 * @param d_ifgo gradients of the concatenated i/f/g/o pre-activations
 * @param alpha scale factor passed through to Tensor::dot
 */
void LSTMCore::calcDerivativeLSTM(Tensor &outgoing_derivative,
                                  const Tensor &weight_ih, const Tensor &d_ifgo,
                                  const float alpha) {
  d_ifgo.dot(weight_ih, outgoing_derivative, false, true, alpha);
}
      79              : 
      80          606 : void LSTMCore::calcGradientLSTM(
      81              :   const unsigned int batch_size, const unsigned int unit,
      82              :   const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
      83              :   ActiFunc &recurrent_acti_func, const Tensor &input,
      84              :   const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
      85              :   const Tensor &prev_cell_state, Tensor &d_prev_cell_state,
      86              :   const Tensor &d_hidden_state, const Tensor &cell_state,
      87              :   const Tensor &d_cell_state, Tensor &d_weight_ih, const Tensor &weight_hh,
      88              :   Tensor &d_weight_hh, Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh,
      89              :   const Tensor &ifgo, Tensor &d_ifgo) {
      90          606 :   TensorDim::TensorType tensor_type = ifgo.getTensorType();
      91              :   Tensor input_forget_gate = ifgo.getSharedDataTensor(
      92          606 :     {batch_size, 1, 1, unit * 2, tensor_type}, 0, false);
      93              :   Tensor input_gate =
      94          606 :     ifgo.getSharedDataTensor({batch_size, 1, 1, unit, tensor_type}, 0, false);
      95              :   Tensor forget_gate = ifgo.getSharedDataTensor(
      96          606 :     {batch_size, 1, 1, unit, tensor_type}, unit, false);
      97              :   Tensor memory_cell = ifgo.getSharedDataTensor(
      98          606 :     {batch_size, 1, 1, unit, tensor_type}, unit * 2, false);
      99              :   Tensor output_gate = ifgo.getSharedDataTensor(
     100          606 :     {batch_size, 1, 1, unit, tensor_type}, unit * 3, false);
     101              : 
     102              :   Tensor d_input_forget_gate = d_ifgo.getSharedDataTensor(
     103          606 :     {batch_size, 1, 1, unit * 2, tensor_type}, 0, false);
     104              :   Tensor d_input_gate =
     105          606 :     d_ifgo.getSharedDataTensor({batch_size, 1, 1, unit, tensor_type}, 0, false);
     106              :   Tensor d_forget_gate = d_ifgo.getSharedDataTensor(
     107          606 :     {batch_size, 1, 1, unit, tensor_type}, unit, false);
     108              :   Tensor d_memory_cell = d_ifgo.getSharedDataTensor(
     109          606 :     {batch_size, 1, 1, unit, tensor_type}, unit * 2, false);
     110              :   Tensor d_output_gate = d_ifgo.getSharedDataTensor(
     111          606 :     {batch_size, 1, 1, unit, tensor_type}, unit * 3, false);
     112              : 
     113              :   Tensor activated_cell_state = Tensor(
     114         1212 :     "activated_cell_state", cell_state.getFormat(), cell_state.getDataType());
     115              : 
     116              :   acti_func.run_fn(cell_state, activated_cell_state);
     117          606 :   d_hidden_state.multiply_strided(activated_cell_state, d_output_gate);
     118          606 :   acti_func.run_prime_fn(activated_cell_state, d_prev_cell_state,
     119              :                          d_hidden_state);
     120          606 :   d_prev_cell_state.multiply_i_strided(output_gate);
     121          606 :   d_prev_cell_state.add_i(d_cell_state);
     122              : 
     123          606 :   d_prev_cell_state.multiply_strided(input_gate, d_memory_cell);
     124          606 :   d_prev_cell_state.multiply_strided(memory_cell, d_input_gate);
     125              : 
     126          606 :   d_prev_cell_state.multiply_strided(prev_cell_state, d_forget_gate);
     127          606 :   d_prev_cell_state.multiply_i_strided(forget_gate);
     128              : 
     129          606 :   recurrent_acti_func.run_prime_fn(output_gate, d_output_gate, d_output_gate);
     130          606 :   recurrent_acti_func.run_prime_fn(input_forget_gate, d_input_forget_gate,
     131              :                                    d_input_forget_gate);
     132          606 :   acti_func.run_prime_fn(memory_cell, d_memory_cell, d_memory_cell);
     133              : 
     134          606 :   if (!disable_bias) {
     135          606 :     if (integrate_bias) {
     136          264 :       d_ifgo.sum(0, d_bias_h, 1.0f, 1.0f);
     137              :     } else {
     138          342 :       d_ifgo.sum(0, d_bias_ih, 1.0f, 1.0f);
     139          342 :       d_ifgo.sum(0, d_bias_hh, 1.0f, 1.0f);
     140              :     }
     141              :   }
     142              : 
     143          606 :   if (input.batch() != 1) {
     144           19 :     input.dot(d_ifgo, d_weight_ih, true, false, 1.0f);
     145              :   } else {
     146              : 
     147         1960 :     for (unsigned int i = 0; i < d_weight_ih.height(); ++i) {
     148         1373 :       unsigned int out_width = d_weight_ih.width();
     149         1373 :       d_weight_ih.add_i_partial(out_width, i * out_width, d_ifgo, 1, 1, input,
     150              :                                 i);
     151              :     }
     152              :   }
     153              : 
     154          606 :   if (prev_hidden_state.batch() != 1) {
     155           19 :     prev_hidden_state.dot(d_ifgo, d_weight_hh, true, false, 1.0f);
     156              :   } else {
     157         1912 :     for (unsigned int i = 0; i < d_weight_hh.height(); ++i) {
     158         1325 :       unsigned int out_width = d_weight_hh.width();
     159         1325 :       d_weight_hh.add_i_partial(out_width, i * out_width, d_ifgo, 1, 1,
     160              :                                 prev_hidden_state, i);
     161              :     }
     162              :   }
     163          606 :   d_ifgo.dot(weight_hh, d_prev_hidden_state, false, true);
     164          606 : }
     165              : 
     166         2097 : void LSTMCore::setProperty(const std::vector<std::string> &values) {
     167              :   const std::vector<std::string> &remain_props =
     168         2097 :     loadProperties(values, lstmcore_props);
     169         2097 :   LayerImpl::setProperty(remain_props);
     170         2097 : }
     171              : 
/**
 * @brief Export this layer's configuration via the given exporter.
 *
 * First exports the common LayerImpl properties, then records the
 * LSTM-core-specific properties for the requested export method.
 *
 * @param exporter exporter collecting the results
 * @param method export format/method to save for
 */
void LSTMCore::exportTo(Exporter &exporter,
                        const ml::train::ExportMethods &method) const {
  LayerImpl::exportTo(exporter, method);
  exporter.saveResult(lstmcore_props, method, this);
}
     177              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1