Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 hyeonseok lee <hs89.lee@samsung.com>
4 : *
5 : * @file lstmcell_core.cpp
6 : * @date 25 November 2021
7 : * @brief This is lstm core class.
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author hyeonseok lee <hs89.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #include <lstmcell_core.h>
15 : #include <nntrainer_error.h>
16 : #include <nntrainer_log.h>
17 :
18 : namespace nntrainer {
19 :
/**
 * @brief Default constructor of LSTMCore.
 *
 * Registers the lstm-core properties (unit, integrate_bias) with their
 * default activations: tanh for the hidden/cell-state activation and
 * sigmoid for the recurrent (gate) activation. The actual activation
 * function objects start out as ACT_NONE and are configured later from
 * these properties.
 */
LSTMCore::LSTMCore() :
  LayerImpl(),
  lstmcore_props(props::Unit(), props::IntegrateBias(),
                 props::HiddenStateActivation() = ActivationType::ACT_TANH,
                 props::RecurrentActivation() = ActivationType::ACT_SIGMOID),
  acti_func(ActivationType::ACT_NONE, true),
  recurrent_acti_func(ActivationType::ACT_NONE, true),
  // NOTE(review): presumably a numerical epsilon used by derived lstm layers;
  // not referenced in this translation unit — confirm its consumer.
  epsilon(1e-3f) {}
28 :
29 1244 : void LSTMCore::forwardLSTM(const unsigned int batch_size,
30 : const unsigned int unit, const bool disable_bias,
31 : const bool integrate_bias, ActiFunc &acti_func,
32 : ActiFunc &recurrent_acti_func, const Tensor &input,
33 : const Tensor &prev_hidden_state,
34 : const Tensor &prev_cell_state, Tensor &hidden_state,
35 : Tensor &cell_state, const Tensor &weight_ih,
36 : const Tensor &weight_hh, const Tensor &bias_h,
37 : const Tensor &bias_ih, const Tensor &bias_hh,
38 : Tensor &ifgo) {
39 1244 : input.dot(weight_ih, ifgo);
40 1244 : prev_hidden_state.dot(weight_hh, ifgo, false, false, 1.0);
41 1244 : if (!disable_bias) {
42 1244 : if (integrate_bias) {
43 560 : ifgo.add_i(bias_h);
44 : } else {
45 684 : ifgo.add_i(bias_ih);
46 684 : ifgo.add_i(bias_hh);
47 : }
48 : }
49 :
50 1244 : TensorDim::TensorType tensor_type = ifgo.getTensorType();
51 :
52 : Tensor input_forget_gate = ifgo.getSharedDataTensor(
53 1244 : {batch_size, 1, 1, unit * 2, tensor_type}, 0, false);
54 : Tensor input_gate =
55 1244 : ifgo.getSharedDataTensor({batch_size, 1, 1, unit, tensor_type}, 0, false);
56 : Tensor forget_gate = ifgo.getSharedDataTensor(
57 1244 : {batch_size, 1, 1, unit, tensor_type}, unit, false);
58 : Tensor memory_cell = ifgo.getSharedDataTensor(
59 1244 : {batch_size, 1, 1, unit, tensor_type}, unit * 2, false);
60 : Tensor output_gate = ifgo.getSharedDataTensor(
61 1244 : {batch_size, 1, 1, unit, tensor_type}, unit * 3, false);
62 :
63 : recurrent_acti_func.run_fn(input_forget_gate, input_forget_gate);
64 : recurrent_acti_func.run_fn(output_gate, output_gate);
65 : acti_func.run_fn(memory_cell, memory_cell);
66 :
67 1244 : prev_cell_state.multiply_strided(forget_gate, cell_state);
68 1244 : memory_cell.multiply_strided(input_gate, cell_state, 1.0f);
69 :
70 : acti_func.run_fn(cell_state, hidden_state);
71 1244 : hidden_state.multiply_i_strided(output_gate);
72 1244 : }
73 :
/**
 * @brief Compute the derivative flowing to the layer input.
 *
 * outgoing_derivative = d_ifgo . W_ih^T, where d_ifgo holds the
 * pre-activation gate gradients produced by calcGradientLSTM.
 *
 * @param alpha trailing scalar forwarded to Tensor::dot; NOTE(review):
 *        presumably the accumulation (beta) factor applied to the existing
 *        contents of outgoing_derivative, as with the 1.0 used in
 *        forwardLSTM's accumulating dot — confirm against Tensor::dot.
 */
void LSTMCore::calcDerivativeLSTM(Tensor &outgoing_derivative,
                                  const Tensor &weight_ih, const Tensor &d_ifgo,
                                  const float alpha) {
  d_ifgo.dot(weight_ih, outgoing_derivative, false, true, alpha);
}
79 :
80 606 : void LSTMCore::calcGradientLSTM(
81 : const unsigned int batch_size, const unsigned int unit,
82 : const bool disable_bias, const bool integrate_bias, ActiFunc &acti_func,
83 : ActiFunc &recurrent_acti_func, const Tensor &input,
84 : const Tensor &prev_hidden_state, Tensor &d_prev_hidden_state,
85 : const Tensor &prev_cell_state, Tensor &d_prev_cell_state,
86 : const Tensor &d_hidden_state, const Tensor &cell_state,
87 : const Tensor &d_cell_state, Tensor &d_weight_ih, const Tensor &weight_hh,
88 : Tensor &d_weight_hh, Tensor &d_bias_h, Tensor &d_bias_ih, Tensor &d_bias_hh,
89 : const Tensor &ifgo, Tensor &d_ifgo) {
90 606 : TensorDim::TensorType tensor_type = ifgo.getTensorType();
91 : Tensor input_forget_gate = ifgo.getSharedDataTensor(
92 606 : {batch_size, 1, 1, unit * 2, tensor_type}, 0, false);
93 : Tensor input_gate =
94 606 : ifgo.getSharedDataTensor({batch_size, 1, 1, unit, tensor_type}, 0, false);
95 : Tensor forget_gate = ifgo.getSharedDataTensor(
96 606 : {batch_size, 1, 1, unit, tensor_type}, unit, false);
97 : Tensor memory_cell = ifgo.getSharedDataTensor(
98 606 : {batch_size, 1, 1, unit, tensor_type}, unit * 2, false);
99 : Tensor output_gate = ifgo.getSharedDataTensor(
100 606 : {batch_size, 1, 1, unit, tensor_type}, unit * 3, false);
101 :
102 : Tensor d_input_forget_gate = d_ifgo.getSharedDataTensor(
103 606 : {batch_size, 1, 1, unit * 2, tensor_type}, 0, false);
104 : Tensor d_input_gate =
105 606 : d_ifgo.getSharedDataTensor({batch_size, 1, 1, unit, tensor_type}, 0, false);
106 : Tensor d_forget_gate = d_ifgo.getSharedDataTensor(
107 606 : {batch_size, 1, 1, unit, tensor_type}, unit, false);
108 : Tensor d_memory_cell = d_ifgo.getSharedDataTensor(
109 606 : {batch_size, 1, 1, unit, tensor_type}, unit * 2, false);
110 : Tensor d_output_gate = d_ifgo.getSharedDataTensor(
111 606 : {batch_size, 1, 1, unit, tensor_type}, unit * 3, false);
112 :
113 : Tensor activated_cell_state = Tensor(
114 1212 : "activated_cell_state", cell_state.getFormat(), cell_state.getDataType());
115 :
116 : acti_func.run_fn(cell_state, activated_cell_state);
117 606 : d_hidden_state.multiply_strided(activated_cell_state, d_output_gate);
118 606 : acti_func.run_prime_fn(activated_cell_state, d_prev_cell_state,
119 : d_hidden_state);
120 606 : d_prev_cell_state.multiply_i_strided(output_gate);
121 606 : d_prev_cell_state.add_i(d_cell_state);
122 :
123 606 : d_prev_cell_state.multiply_strided(input_gate, d_memory_cell);
124 606 : d_prev_cell_state.multiply_strided(memory_cell, d_input_gate);
125 :
126 606 : d_prev_cell_state.multiply_strided(prev_cell_state, d_forget_gate);
127 606 : d_prev_cell_state.multiply_i_strided(forget_gate);
128 :
129 606 : recurrent_acti_func.run_prime_fn(output_gate, d_output_gate, d_output_gate);
130 606 : recurrent_acti_func.run_prime_fn(input_forget_gate, d_input_forget_gate,
131 : d_input_forget_gate);
132 606 : acti_func.run_prime_fn(memory_cell, d_memory_cell, d_memory_cell);
133 :
134 606 : if (!disable_bias) {
135 606 : if (integrate_bias) {
136 264 : d_ifgo.sum(0, d_bias_h, 1.0f, 1.0f);
137 : } else {
138 342 : d_ifgo.sum(0, d_bias_ih, 1.0f, 1.0f);
139 342 : d_ifgo.sum(0, d_bias_hh, 1.0f, 1.0f);
140 : }
141 : }
142 :
143 606 : if (input.batch() != 1) {
144 19 : input.dot(d_ifgo, d_weight_ih, true, false, 1.0f);
145 : } else {
146 :
147 1960 : for (unsigned int i = 0; i < d_weight_ih.height(); ++i) {
148 1373 : unsigned int out_width = d_weight_ih.width();
149 1373 : d_weight_ih.add_i_partial(out_width, i * out_width, d_ifgo, 1, 1, input,
150 : i);
151 : }
152 : }
153 :
154 606 : if (prev_hidden_state.batch() != 1) {
155 19 : prev_hidden_state.dot(d_ifgo, d_weight_hh, true, false, 1.0f);
156 : } else {
157 1912 : for (unsigned int i = 0; i < d_weight_hh.height(); ++i) {
158 1325 : unsigned int out_width = d_weight_hh.width();
159 1325 : d_weight_hh.add_i_partial(out_width, i * out_width, d_ifgo, 1, 1,
160 : prev_hidden_state, i);
161 : }
162 : }
163 606 : d_ifgo.dot(weight_hh, d_prev_hidden_state, false, true);
164 606 : }
165 :
166 2097 : void LSTMCore::setProperty(const std::vector<std::string> &values) {
167 : const std::vector<std::string> &remain_props =
168 2097 : loadProperties(values, lstmcore_props);
169 2097 : LayerImpl::setProperty(remain_props);
170 2097 : }
171 :
/**
 * @brief Export (save) this layer's configuration.
 *
 * Exports the LayerImpl-level properties first, then appends the lstm-core
 * specific properties, both via the given export @p method. Order matters
 * for the produced output, so the two calls must not be swapped.
 */
void LSTMCore::exportTo(Exporter &exporter,
                        const ml::train::ExportMethods &method) const {
  LayerImpl::exportTo(exporter, method);
  exporter.saveResult(lstmcore_props, method, this);
}
177 : } // namespace nntrainer
|