LCOV - code coverage report
Current view: top level - nntrainer/layers - zoneout_lstmcell.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 83.3 % 12 10
Test Date: 2025-12-14 20:38:17 Functions: 75.0 % 4 3

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
       4              :  *
       5              :  * @file   zoneout_lstmcell.h
       6              :  * @date   30 November 2021
       7              :  * @brief  This is ZoneoutLSTMCell Layer Class of Neural Network
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  *         https://arxiv.org/pdf/1606.01305.pdf
      10              :  *         https://github.com/teganmaharaj/zoneout
      11              :  * @author hyeonseok lee <hs89.lee@samsung.com>
      12              :  * @bug    No known bugs except for NYI items
      13              :  *
      14              :  */
      15              : 
      16              : #ifndef __ZONEOUTLSTMCELL_H__
      17              : #define __ZONEOUTLSTMCELL_H__
      18              : #ifdef __cplusplus
      19              : 
      20              : #include <acti_func.h>
      21              : #include <common_properties.h>
      22              : #include <lstmcell_core.h>
      23              : 
      24              : namespace nntrainer {
      25              : 
      26              : /**
      27              :  * @class   ZoneoutLSTMCellLayer
      28              :  * @brief   ZoneoutLSTMCellLayer
      29              :  */
      30              : class ZoneoutLSTMCellLayer : public LSTMCore {
      31              : public:
      32              :   /**
      33              :    * @brief HiddenStateZoneOutRate property, this defines zone out rate for
      34              :    * hidden state
      35              :    *
      36              :    */
      37            0 :   class HiddenStateZoneOutRate : public nntrainer::Property<float> {
      38              : 
      39              :   public:
      40              :     /**
      41              :      * @brief Construct a new HiddenStateZoneOutRate object with a default value
      42              :      * 0.0
      43              :      *
      44              :      */
      45          270 :     HiddenStateZoneOutRate(float value = 0.0) :
      46          270 :       nntrainer::Property<float>(value) {}
      47              :     static constexpr const char *key =
      48              :       "hidden_state_zoneout_rate";   /**< unique key to access */
      49              :     using prop_tag = float_prop_tag; /**< property type */
      50              : 
      51              :     /**
      52              :      * @brief HiddenStateZoneOutRate validator
      53              :      *
      54              :      * @param value float to validate
      55              :      * @retval true if it is greater than or equal to 0.0 and less than or equal
      56              :      * to 1.0
      57              :      * @retval false if it is smaller than 0.0 or greater than 1.0
      58              :      */
      59              :     bool isValid(const float &value) const override;
      60              :   };
      61              : 
      62              :   /**
      63              :    * @brief CellStateZoneOutRate property, this defines zone out rate for cell
      64              :    * state
      65              :    *
      66              :    */
      67            0 :   class CellStateZoneOutRate : public nntrainer::Property<float> {
      68              : 
      69              :   public:
      70              :     /**
      71              :      * @brief Construct a new CellStateZoneOutRate object with a default value
      72              :      * 0.0
      73              :      *
      74              :      */
      75          270 :     CellStateZoneOutRate(float value = 0.0) :
      76          270 :       nntrainer::Property<float>(value) {}
      77              :     static constexpr const char *key =
      78              :       "cell_state_zoneout_rate";     /**< unique key to access */
      79              :     using prop_tag = float_prop_tag; /**< property type */
      80              : 
      81              :     /**
      82              :      * @brief CellStateZoneOutRate validator
      83              :      *
      84              :      * @param value float to validate
      85              :      * @retval true if it is greater than or equal to 0.0 and less than or equal
      86              :      * to 1.0
      87              :      * @retval false if it is smaller than 0.0 or greater than 1.0
      88              :      */
      89              :     bool isValid(const float &value) const override;
      90              :   };
      91              : 
      92              :   /**
      93              :    * @brief Test property, this property is set to true when testing the
      94              :    * zoneout lstmcell in unit tests
      95              :    *
      96              :    */
      97          270 :   class Test : public nntrainer::Property<bool> {
      98              : 
      99              :   public:
     100              :     /**
     101              :      * @brief Construct a new Test object with a default value false
     102              :      *
     103              :      */
     104          270 :     Test(bool value = false) : nntrainer::Property<bool>(value) {}
     105              :     static constexpr const char *key = "test"; /**< unique key to access */
     106              :     using prop_tag = bool_prop_tag;            /**< property type */
     107              :   };
     108              : 
     109              :   /**
     110              :    * @brief     Constructor of ZoneoutLSTMCellLayer
     111              :    */
     112              :   ZoneoutLSTMCellLayer();
     113              : 
     114              :   /**
     115              :    * @brief     Destructor of ZoneoutLSTMCellLayer
     116              :    */
     117          270 :   ~ZoneoutLSTMCellLayer() = default;
     118              : 
     119              :   /**
     120              :    * @copydoc Layer::finalize(InitLayerContext &context)
     121              :    */
     122              :   void finalize(InitLayerContext &context) override;
     123              : 
     124              :   /**
     125              :    * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
     126              :    */
     127              :   void forwarding(RunLayerContext &context, bool training) override;
     128              : 
     129              :   /**
     130              :    * @copydoc Layer::calcDerivative(RunLayerContext &context)
     131              :    */
     132              :   void calcDerivative(RunLayerContext &context) override;
     133              : 
     134              :   /**
     135              :    * @copydoc Layer::calcGradient(RunLayerContext &context)
     136              :    */
     137              :   void calcGradient(RunLayerContext &context) override;
     138              :   /**
     139              :    * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
     140              :    * method)
     141              :    */
     142              :   void exportTo(Exporter &exporter,
     143              :                 const ml::train::ExportMethods &method) const override;
     144              : 
     145              :   /**
     146              :    * @copydoc Layer::getType()
     147              :    */
     148         5148 :   const std::string getType() const override {
     149         5148 :     return ZoneoutLSTMCellLayer::type;
     150              :   };
     151              : 
     152              :   /**
     153              :    * @copydoc Layer::supportBackwarding()
     154              :    */
     155          468 :   bool supportBackwarding() const override { return true; }
     156              : 
     157              :   /**
     158              :    * @copydoc Layer::setProperty(const std::vector<std::string>
     159              :    * &values)
     160              :    */
     161              :   void setProperty(const std::vector<std::string> &values) override;
     162              : 
     163              :   /**
     164              :    * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
     165              :    */
     166              :   void setBatch(RunLayerContext &context, unsigned int batch) override;
     167              : 
     168              :   static constexpr const char *type = "zoneout_lstmcell";
     169              : 
     170              : private:
     171              :   static constexpr unsigned int NUM_GATE = 4;
     172              :   enum INOUT_INDEX {
     173              :     INPUT = 0,
     174              :     INPUT_HIDDEN_STATE = 1,
     175              :     INPUT_CELL_STATE = 2,
     176              :     OUTPUT_HIDDEN_STATE = 0,
     177              :     OUTPUT_CELL_STATE = 1
     178              :   };
     179              : 
     180              :   /** common properties like Unit, IntegrateBias, HiddenStateActivation and
     181              :    * RecurrentActivation are in lstmcore_props */
     182              : 
     183              :   /**
     184              :    * HiddenStateZoneOutRate: zoneout rate for hidden_state
     185              :    * CellStateZoneOutRate: zoneout rate for cell_state
     186              :    * Test: property for test mode
     187              :    * MaxTimestep: maximum timestep for zoneout lstmcell
     188              :    * TimeStep: timestep for which lstm should operate
     189              :    *
     190              :    * */
     191              :   std::tuple<HiddenStateZoneOutRate, CellStateZoneOutRate, Test,
     192              :              props::MaxTimestep, props::Timestep>
     193              :     zoneout_lstmcell_props;
     194              :   std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
     195              : };
     196              : } // namespace nntrainer
     197              : 
     198              : #endif /* __cplusplus */
     199              : #endif /* __ZONEOUTLSTMCELL_H__ */
        

Generated by: LCOV version 2.0-1