Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
4 : *
5 : * @file lstmcell_core.h
6 : * @date 25 November 2021
7 : * @brief This is the LSTM core class.
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author hyeonseok lee <hs89.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __LSTMCELLCORE_H__
15 : #define __LSTMCELLCORE_H__
16 : #ifdef __cplusplus
17 :
18 : #include <acti_func.h>
19 : #include <common.h>
20 : #include <layer_impl.h>
21 : #include <node_exporter.h>
22 :
23 : namespace nntrainer {
24 :
25 : /**
26 : * @class LSTMCore
27 : * @brief LSTMCore: shared LSTM cell computation (forward pass, input
28 : */
29 : class LSTMCore : public LayerImpl {
30 : public:
31 : /**
32 : * @brief Constructor of LSTMCore
33 : */
34 : LSTMCore();
35 :
36 : /**
37 : * @brief Destructor of LSTMCore
38 : */
39 838 : ~LSTMCore() = default;
40 :
41 : /**
42 : * @brief lstm cell forwarding implementation
43 : *
44 : * @param batch_size batch size
45 : * @param unit number of output neurons
46 : * @param disable_bias whether to disable bias or not
47 : * @param integrate_bias when true the single combined bias_h is used; when false the separate bias_ih and bias_hh are used (presumed from the parameter set — confirm in the .cpp implementation)
48 : * @param acti_func activation function for memory cell, cell state
49 : * @param recurrent_acti_func activation function for input/output/forget
50 : * gate
51 : * @param input input
52 : * @param prev_hidden_state previous hidden state
53 : * @param prev_cell_state previous cell state
54 : * @param hidden_state hidden state (output of this call)
55 : * @param cell_state cell state (output of this call)
56 : * @param weight_ih weight for input to hidden
57 : * @param weight_hh weight for hidden to hidden
58 : * @param bias_h combined bias for input and hidden (see integrate_bias)
59 : * @param bias_ih bias for input (see integrate_bias)
60 : * @param bias_hh bias for hidden (see integrate_bias)
61 : * @param ifgo input gate, forget gate, memory cell, output gate — gate order follows the tensor name: i, f, g, o
62 : */
63 : void forwardLSTM(const unsigned int batch_size, const unsigned int unit,
64 : const bool disable_bias, const bool integrate_bias,
65 : ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
66 : const Tensor &input, const Tensor &prev_hidden_state,
67 : const Tensor &prev_cell_state, Tensor &hidden_state,
68 : Tensor &cell_state, const Tensor &weight_ih,
69 : const Tensor &weight_hh, const Tensor &bias_h,
70 : const Tensor &bias_ih, const Tensor &bias_hh, Tensor &ifgo);
71 :
72 : /**
73 : * @brief lstm cell calculate derivative implementation
74 : *
75 : * @param outgoing_derivative derivative for input
76 : * @param weight_ih weight for input to hidden
77 : * @param d_ifgo gradient for input gate, forget gate, memory cell, output
78 : * gate
79 : * @param alpha scale applied to the existing contents of outgoing_derivative (default 0.0f overwrites; presumed BLAS-style accumulation — confirm in the .cpp implementation)
80 : */
81 : void calcDerivativeLSTM(Tensor &outgoing_derivative, const Tensor &weight_ih,
82 : const Tensor &d_ifgo, const float alpha = 0.0f);
83 :
84 : /**
85 : * @brief lstm cell calculate gradient implementation
86 : *
87 : * @param batch_size batch size
88 : * @param unit number of output neurons
89 : * @param disable_bias whether to disable bias or not
90 : * @param integrate_bias when true the single combined bias_h gradient is accumulated; when false the separate bias_ih and bias_hh gradients are (presumed from the parameter set — confirm in the .cpp implementation)
91 : * @param acti_func activation function for memory cell, cell state
92 : * @param recurrent_acti_func activation function for input/output/forget
93 : * gate
94 : * @param input input
95 : * @param prev_hidden_state previous hidden state
96 : * @param d_prev_hidden_state previous hidden state gradient
97 : * @param prev_cell_state previous cell state
98 : * @param d_prev_cell_state previous cell state gradient
99 : * @param d_hidden_state hidden state gradient
100 : * @param cell_state cell state
101 : * @param d_cell_state cell state gradient
102 : * @param d_weight_ih weight_ih(weight for input to hidden) gradient
103 : * @param weight_hh weight for hidden to hidden
104 : * @param d_weight_hh weight_hh(weight for hidden to hidden) gradient
105 : * @param d_bias_h bias_h(bias for input and hidden) gradient
106 : * @param d_bias_ih bias_ih(bias for input) gradient
107 : * @param d_bias_hh bias_hh(bias for hidden) gradient
108 : * @param ifgo input gate, forget gate, memory cell, output gate
109 : * @param d_ifgo gradient for input gate, forget gate, memory cell, output
110 : * gate
111 : */
112 : void calcGradientLSTM(const unsigned int batch_size, const unsigned int unit,
113 : const bool disable_bias, const bool integrate_bias,
114 : ActiFunc &acti_func, ActiFunc &recurrent_acti_func,
115 : const Tensor &input, const Tensor &prev_hidden_state,
116 : Tensor &d_prev_hidden_state,
117 : const Tensor &prev_cell_state,
118 : Tensor &d_prev_cell_state, const Tensor &d_hidden_state,
119 : const Tensor &cell_state, const Tensor &d_cell_state,
120 : Tensor &d_weight_ih, const Tensor &weight_hh,
121 : Tensor &d_weight_hh, Tensor &d_bias_h,
122 : Tensor &d_bias_ih, Tensor &d_bias_hh,
123 : const Tensor &ifgo, Tensor &d_ifgo);
124 :
125 : /**
126 : * @copydoc Layer::setProperty(const PropertyType type, const std::string
127 : * &value)
128 : */
129 : void setProperty(const std::vector<std::string> &values) override;
130 :
131 : /**
132 : * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
133 : * method)
134 : */
135 : void exportTo(Exporter &exporter,
136 : const ml::train::ExportMethods &method) const override;
137 :
138 : protected:
139 : /**
140 : * Unit: number of output neurons
141 : * IntegrateBias: integrate bias_ih, bias_hh to bias_h
142 : * HiddenStateActivation: activation type for hidden state. default is tanh
143 : * RecurrentActivation: activation type for recurrent. default is sigmoid
144 : *
145 : * */
146 : std::tuple<props::Unit, props::IntegrateBias, props::HiddenStateActivation,
147 : props::RecurrentActivation>
148 : lstmcore_props;
149 :
150 : /**
151 : * @brief activation function: default is tanh
152 : */
153 : ActiFunc acti_func;
154 :
155 : /**
156 : * @brief activation function for recurrent: default is sigmoid
157 : */
158 : ActiFunc recurrent_acti_func;
159 :
160 : /**
161 : * @brief small constant used to guard against numerical overflow (exact usage is in the .cpp implementation)
162 : */
163 : float epsilon;
164 : };
165 : } // namespace nntrainer
166 :
167 : #endif /* __cplusplus */
168 : #endif /* __LSTMCELLCORE_H__ */
|