Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file gru.h
6 : * @date 31 March 2021
7 : * @brief This is Gated Recurrent Unit Layer Class of Neural Network
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jijoong Moon <jijoong.moon@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __GRU_H__
15 : #define __GRU_H__
16 : #ifdef __cplusplus
17 :
18 : #include <acti_func.h>
19 : #include <common_properties.h>
20 : #include <layer_impl.h>
21 :
22 : namespace nntrainer {
23 :
24 : /**
25 : * @class GRULayer
26 : * @brief GRULayer
27 : */
28 : class GRULayer : public LayerImpl {
29 : public:
30 : /**
31 : * @brief Constructor of GRULayer
32 : */
33 : GRULayer();
34 :
35 : /**
36 : * @brief Destructor of GRULayer
37 : */
38 172 : ~GRULayer() = default;
39 :
40 : /**
41 : * @brief Move constructor.
42 : * @param[in] GRULayer &&
43 : */
44 : GRULayer(GRULayer &&rhs) noexcept = default;
45 :
46 : /**
47 : * @brief Move assignment operator.
48 : * @parma[in] rhs GRULayer to be moved.
49 : */
50 : GRULayer &operator=(GRULayer &&rhs) = default;
51 :
52 : /**
53 : * @copydoc Layer::finalize(InitLayerContext &context)
54 : */
55 : void finalize(InitLayerContext &context) override;
56 :
57 : /**
58 : * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
59 : */
60 : void forwarding(RunLayerContext &context, bool training) override;
61 :
62 : /**
63 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
64 : */
65 : void calcDerivative(RunLayerContext &context) override;
66 :
67 : /**
68 : * @copydoc Layer::calcGradient(RunLayerContext &context)
69 : */
70 : void calcGradient(RunLayerContext &context) override;
71 :
72 : /**
73 : * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
74 : * method)
75 : */
76 : void exportTo(Exporter &exporter,
77 : const ml::train::ExportMethods &method) const override;
78 :
79 : /**
80 : * @copydoc Layer::getType()
81 : */
82 2381 : const std::string getType() const override { return GRULayer::type; };
83 :
84 : /**
85 : * @copydoc Layer::supportBackwarding()
86 : */
87 102 : bool supportBackwarding() const override { return true; }
88 :
89 : /**
90 : * @copydoc Layer::setProperty(const PropertyType type, const std::string
91 : * &value)
92 : */
93 : void setProperty(const std::vector<std::string> &values) override;
94 :
95 : /**
96 : * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
97 : */
98 : void setBatch(RunLayerContext &context, unsigned int batch) override;
99 :
100 : static constexpr const char *type = "gru";
101 :
102 : private:
103 : static constexpr unsigned int NUM_GATE = 3;
104 :
105 : /**
106 : * Unit: number of output neurons
107 : * HiddenStateActivation: activation type for hidden state. default is tanh
108 : * RecurrentActivation: activation type for recurrent. default is sigmoid
109 : * ReturnSequence: option for return sequence
110 : * DropOutRate: dropout rate
111 : * IntegrateBias: integrate bias_ih, bias_hh to bias_h
112 : * ResetAfter: Whether apply reset gate before/after the matrix
113 : *
114 : * */
115 : std::tuple<props::Unit, props::HiddenStateActivation,
116 : props::RecurrentActivation, props::ReturnSequences,
117 : props::DropOutRate, props::IntegrateBias, props::ResetAfter>
118 : gru_props;
119 : std::array<unsigned int, 9> wt_idx; /**< indices of the weights */
120 :
121 : /**
122 : * @brief activation function for h_t : default is sigmoid
123 : */
124 : ActiFunc acti_func;
125 :
126 : /**
127 : * @brief activation function for recurrent : default is tanh
128 : */
129 : ActiFunc recurrent_acti_func;
130 :
131 : /**
132 : * @brief to protect overflow
133 : */
134 : float epsilon;
135 : };
136 : } // namespace nntrainer
137 :
138 : #endif /* __cplusplus */
139 : #endif /* __GRU_H__ */
|