LCOV - code coverage report
Current view: top level - nntrainer/layers - lstmcell_core.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 100.0 % 1 1
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 1 1

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
       4              :  *
       5              :  * @file   lstmcell_core.h
       6              :  * @date   25 November 2021
       7              :  * @brief  This is lstm core class.
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author hyeonseok lee <hs89.lee@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : 
      14              : #ifndef __LSTMCELLCORE_H__
      15              : #define __LSTMCELLCORE_H__
      16              : #ifdef __cplusplus
      17              : 
      18              : #include <acti_func.h>
      19              : #include <common.h>
      20              : #include <layer_impl.h>
      21              : #include <node_exporter.h>
      22              : 
      23              : namespace nntrainer {
      24              : 
       25              : /**
       26              :  * @class   LSTMCore
       27              :  * @brief   Core LSTM cell computations (forward, derivative, gradient)
       28              :  */
       29              : class LSTMCore : public LayerImpl {
       30              : public:
       31              :   /**
       32              :    * @brief     Constructor of LSTMCore
       33              :    */
       34              :   LSTMCore();
       35              : 
       36              :   /**
       37              :    * @brief     Destructor of LSTMCore
       38              :    */
       39          838 :   ~LSTMCore() = default;
       40              : 
       41              :   /**
       42              :    * @brief lstm cell forwarding implementation
       43              :    *
       44              :    * @param batch_size batch size
       45              :    * @param unit number of output neurons
       46              :    * @param disable_bias whether to disable bias or not
       47              :    * @param integrate_bias integrate bias_ih, bias_hh to bias_h
       48              :    * @param acti_func activation function for memory cell, cell state
       49              :    * @param recurrent_acti_func activation function for input/output/forget
       50              :    * gate
       51              :    * @param input input
       52              :    * @param prev_hidden_state previous hidden state
       53              :    * @param prev_cell_state previous cell state
       54              :    * @param hidden_state hidden state
       55              :    * @param cell_state cell state
       56              :    * @param weight_ih weight for input to hidden
       57              :    * @param weight_hh weight for hidden to hidden
       58              :    * @param bias_h bias for input and hidden.
       59              :    * @param bias_ih bias for input
       60              :    * @param bias_hh bias for hidden
       61              :    * @param ifgo input gate, forget gate, memory cell, output gate
       62              :    */
       63              :   void forwardLSTM(const unsigned int batch_size, const unsigned int unit,
       64              :                    const bool disable_bias, const bool integrate_bias,
       65              :                    ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
       66              :                    const Tensor &input, const Tensor &prev_hidden_state,
       67              :                    const Tensor &prev_cell_state, Tensor &hidden_state,
       68              :                    Tensor &cell_state, const Tensor &weight_ih,
       69              :                    const Tensor &weight_hh, const Tensor &bias_h,
       70              :                    const Tensor &bias_ih, const Tensor &bias_hh, Tensor &ifgo);
       71              : 
       72              :   /**
       73              :    * @brief lstm cell calculate derivative implementation
       74              :    *
       75              :    * @param outgoing_derivative derivative for input
       76              :    * @param weight_ih weight for input to hidden
       77              :    * @param d_ifgo gradient for input gate, forget gate, memory cell, output
       78              :    * gate
       79              :    * @param alpha scale factor applied to outgoing_derivative (default 0.0f)
       80              :    */
       81              :   void calcDerivativeLSTM(Tensor &outgoing_derivative, const Tensor &weight_ih,
       82              :                           const Tensor &d_ifgo, const float alpha = 0.0f);
       83              : 
       84              :   /**
       85              :    * @brief lstm cell calculate gradient implementation
       86              :    *
       87              :    * @param batch_size batch size
       88              :    * @param unit number of output neurons
       89              :    * @param disable_bias whether to disable bias or not
       90              :    * @param integrate_bias integrate bias_ih, bias_hh to bias_h
       91              :    * @param acti_func activation function for memory cell, cell state
       92              :    * @param recurrent_acti_func activation function for input/output/forget
       93              :    * gate
       94              :    * @param input input
       95              :    * @param prev_hidden_state previous hidden state
       96              :    * @param d_prev_hidden_state previous hidden state gradient
       97              :    * @param prev_cell_state previous cell state
       98              :    * @param d_prev_cell_state previous cell state gradient
       99              :    * @param d_hidden_state hidden state gradient
      100              :    * @param cell_state cell state
      101              :    * @param d_cell_state cell state gradient
      102              :    * @param d_weight_ih weight_ih(weight for input to hidden) gradient
      103              :    * @param weight_hh weight for hidden to hidden
      104              :    * @param d_weight_hh weight_hh(weight for hidden to hidden) gradient
      105              :    * @param d_bias_h bias_h(bias for input and hidden) gradient
      106              :    * @param d_bias_ih bias_ih(bias for input) gradient
      107              :    * @param d_bias_hh bias_hh(bias for hidden) gradient
      108              :    * @param ifgo input gate, forget gate, memory cell, output gate
      109              :    * @param d_ifgo gradient for input gate, forget gate, memory cell, output
      110              :    * gate
      111              :    */
      112              :   void calcGradientLSTM(const unsigned int batch_size, const unsigned int unit,
      113              :                         const bool disable_bias, const bool integrate_bias,
      114              :                         ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
      115              :                         const Tensor &input, const Tensor &prev_hidden_state,
      116              :                         Tensor &d_prev_hidden_state,
      117              :                         const Tensor &prev_cell_state,
      118              :                         Tensor &d_prev_cell_state, const Tensor &d_hidden_state,
      119              :                         const Tensor &cell_state, const Tensor &d_cell_state,
      120              :                         Tensor &d_weight_ih, const Tensor &weight_hh,
      121              :                         Tensor &d_weight_hh, Tensor &d_bias_h,
      122              :                         Tensor &d_bias_ih, Tensor &d_bias_hh,
      123              :                         const Tensor &ifgo, Tensor &d_ifgo);
      124              : 
      125              :   /**
      126              :    * @copydoc Layer::setProperty(const std::vector<std::string>
      127              :    * &values)
      128              :    */
      129              :   void setProperty(const std::vector<std::string> &values) override;
      130              : 
      131              :   /**
      132              :    * @copydoc Layer::exportTo(Exporter &exporter,
      133              :    * const ml::train::ExportMethods &method) const
      134              :    */
      135              :   void exportTo(Exporter &exporter,
      136              :                 const ml::train::ExportMethods &method) const override;
      137              : 
      138              : protected:
      139              :   /**
      140              :    * Unit: number of output neurons
      141              :    * IntegrateBias: integrate bias_ih, bias_hh to bias_h
      142              :    * HiddenStateActivation: activation type for hidden state. default is tanh
      143              :    * RecurrentActivation: activation type for recurrent. default is sigmoid
      144              :    *
      145              :    */
      146              :   std::tuple<props::Unit, props::IntegrateBias, props::HiddenStateActivation,
      147              :              props::RecurrentActivation>
      148              :     lstmcore_props;
      149              : 
      150              :   /**
      151              :    * @brief     activation function for memory cell and cell state: default is tanh
      152              :    */
      153              :   ActiFunc acti_func;
      154              : 
      155              :   /**
      156              :    * @brief     activation function for input/forget/output gates: default is sigmoid
      157              :    */
      158              :   ActiFunc recurrent_acti_func;
      159              : 
      160              :   /**
      161              :    * @brief     epsilon value used to protect against overflow
      162              :    */
      163              :   float epsilon;
      164              : };
     165              : } // namespace nntrainer
     166              : 
     167              : #endif /* __cplusplus */
     168              : #endif /* __LSTMCELLCORE_H__ */
        

Generated by: LCOV version 2.0-1