Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
4 : *
5 : * @file zoneout_lstmcell.h
6 : * @date 30 November 2021
7 : * @brief This is ZoneoutLSTMCell Layer Class of Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * https://arxiv.org/pdf/1606.01305.pdf
10 : * https://github.com/teganmaharaj/zoneout
11 : * @author hyeonseok lee <hs89.lee@samsung.com>
12 : * @bug No known bugs except for NYI items
13 : *
14 : */
15 :
16 : #ifndef __ZONEOUTLSTMCELL_H__
17 : #define __ZONEOUTLSTMCELL_H__
18 : #ifdef __cplusplus
19 :
20 : #include <acti_func.h>
21 : #include <common_properties.h>
22 : #include <lstmcell_core.h>
23 :
24 : namespace nntrainer {
25 :
26 : /**
27 : * @class ZoneoutLSTMCellLayer
28 : * @brief LSTM cell layer with zoneout rates for the hidden and cell states
29 : */
30 : class ZoneoutLSTMCellLayer : public LSTMCore {
31 : public:
32 : /**
33 : * @brief HiddenStateZoneOutRate property, this defines the zoneout rate for
34 : * the hidden state
35 : *
36 : */
37 0 : class HiddenStateZoneOutRate : public nntrainer::Property<float> {
38 :
39 : public:
40 : /**
41 : * @brief Construct a new HiddenStateZoneOutRate object with a default value
42 : * 0.0
43 : *
44 : */
45 270 : HiddenStateZoneOutRate(float value = 0.0) :
46 270 : nntrainer::Property<float>(value) {}
47 : static constexpr const char *key =
48 : "hidden_state_zoneout_rate"; /**< unique key to access */
49 : using prop_tag = float_prop_tag; /**< property type */
50 :
51 : /**
52 : * @brief HiddenStateZoneOutRate validator
53 : *
54 : * @param value float to validate
55 : * @retval true if it is greater than or equal to 0.0 and less than or equal
56 : * to 1.0
57 : * @retval false if it is smaller than 0.0 or greater than 1.0
58 : */
59 : bool isValid(const float &value) const override;
60 : };
61 :
62 : /**
63 : * @brief CellStateZoneOutRate property, this defines the zoneout rate for the
64 : * cell state
65 : *
66 : */
67 0 : class CellStateZoneOutRate : public nntrainer::Property<float> {
68 :
69 : public:
70 : /**
71 : * @brief Construct a new CellStateZoneOutRate object with a default value
72 : * 0.0
73 : *
74 : */
75 270 : CellStateZoneOutRate(float value = 0.0) :
76 270 : nntrainer::Property<float>(value) {}
77 : static constexpr const char *key =
78 : "cell_state_zoneout_rate"; /**< unique key to access */
79 : using prop_tag = float_prop_tag; /**< property type */
80 :
81 : /**
82 : * @brief CellStateZoneOutRate validator
83 : *
84 : * @param value float to validate
85 : * @retval true if it is greater than or equal to 0.0 and less than or equal
86 : * to 1.0
87 : * @retval false if it is smaller than 0.0 or greater than 1.0
88 : */
89 : bool isValid(const float &value) const override;
90 : };
91 :
92 : /**
93 : * @brief Test property, this property is set to true when testing the zoneout
94 : * lstmcell in unit tests
95 : *
96 : */
97 270 : class Test : public nntrainer::Property<bool> {
98 :
99 : public:
100 : /**
101 : * @brief Construct a new Test object with a default value false
102 : *
103 : */
104 270 : Test(bool value = false) : nntrainer::Property<bool>(value) {}
105 : static constexpr const char *key = "test"; /**< unique key to access */
106 : using prop_tag = bool_prop_tag; /**< property type */
107 : };
108 :
109 : /**
110 : * @brief Constructor of ZoneoutLSTMCellLayer
111 : */
112 : ZoneoutLSTMCellLayer();
113 :
114 : /**
115 : * @brief Destructor of ZoneoutLSTMCellLayer
116 : */
117 270 : ~ZoneoutLSTMCellLayer() = default;
118 :
119 : /**
120 : * @copydoc Layer::finalize(InitLayerContext &context)
121 : */
122 : void finalize(InitLayerContext &context) override;
123 :
124 : /**
125 : * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
126 : */
127 : void forwarding(RunLayerContext &context, bool training) override;
128 :
129 : /**
130 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
131 : */
132 : void calcDerivative(RunLayerContext &context) override;
133 :
134 : /**
135 : * @copydoc Layer::calcGradient(RunLayerContext &context)
136 : */
137 : void calcGradient(RunLayerContext &context) override;
138 : /**
139 : * @copydoc Layer::exportTo(Exporter &exporter, const
140 : * ml::train::ExportMethods &method)
141 : */
142 : void exportTo(Exporter &exporter,
143 : const ml::train::ExportMethods &method) const override;
144 :
145 : /**
146 : * @copydoc Layer::getType()
147 : */
148 5148 : const std::string getType() const override {
149 5148 : return ZoneoutLSTMCellLayer::type;
150 : };
151 :
152 : /**
153 : * @copydoc Layer::supportBackwarding()
154 : */
155 468 : bool supportBackwarding() const override { return true; }
156 :
157 : /**
158 : * @copydoc Layer::setProperty(const std::vector<std::string>
159 : * &values)
160 : */
161 : void setProperty(const std::vector<std::string> &values) override;
162 :
163 : /**
164 : * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
165 : */
166 : void setBatch(RunLayerContext &context, unsigned int batch) override;
167 :
168 : static constexpr const char *type = "zoneout_lstmcell";
169 :
170 : private:
171 : static constexpr unsigned int NUM_GATE = 4; /**< NOTE(review): presumably the number of LSTM gates - confirm */
172 : enum INOUT_INDEX { // NOTE(review): values repeat; inputs and outputs appear to be indexed separately - confirm
173 : INPUT = 0,
174 : INPUT_HIDDEN_STATE = 1,
175 : INPUT_CELL_STATE = 2,
176 : OUTPUT_HIDDEN_STATE = 0,
177 : OUTPUT_CELL_STATE = 1
178 : };
179 :
180 : /** common properties like Unit, IntegrateBias, HiddenStateActivation and
181 : * RecurrentActivation are in lstmcore_props */
182 :
183 : /**
184 : * HiddenStateZoneOutRate: zoneout rate for hidden_state
185 : * CellStateZoneOutRate: zoneout rate for cell_state
186 : * Test: property for test mode
187 : * MaxTimestep: maximum timestep for zoneout lstmcell
188 : * Timestep: timestep for which lstm should operate
189 : *
190 : * */
191 : std::tuple<HiddenStateZoneOutRate, CellStateZoneOutRate, Test,
192 : props::MaxTimestep, props::Timestep>
193 : zoneout_lstmcell_props;
194 : std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
195 : };
196 : } // namespace nntrainer
197 :
198 : #endif /* __cplusplus */
199 : #endif /* __ZONEOUTLSTMCELL_H__ */
|