Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file mol_attention_layer.h
6 : * @date 11 November 2021
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief This is MoL Attention Layer Class for Neural Network
11 : *
12 : */
13 :
14 : #ifndef __MOL_ATTENTION_LAYER_H__
15 : #define __MOL_ATTENTION_LAYER_H__
16 : #ifdef __cplusplus
17 :
18 : #include <attention_layer.h>
19 : #include <layer_impl.h>
20 :
21 : namespace nntrainer {
22 :
23 : /**
24 : * @class MoL Attention Layer
25 : * @brief Mixture of Logistics Attention Layer
26 : */
27 : class MoLAttentionLayer : public LayerImpl {
28 : public:
29 : /**
30 : * @brief Constructor of MoL Attention Layer
31 : */
32 : MoLAttentionLayer();
33 :
34 : /**
35 : * @brief Destructor of MoL Attention Layer
36 : */
37 : ~MoLAttentionLayer();
38 :
39 : /**
40 : * @brief Move constructor of MoLAttentionLayer.
41 : * @param[in] MoLAttentionLayer &&
42 : */
43 : MoLAttentionLayer(MoLAttentionLayer &&rhs) noexcept = default;
44 :
45 : /**
46 : * @brief Move assignment operator.
47 : * @parma[in] rhs MoLAttentionLayer to be moved.
48 : */
49 : MoLAttentionLayer &operator=(MoLAttentionLayer &&rhs) = default;
50 :
51 : /**
52 : * @copydoc Layer::finalize(InitLayerContext &context)
53 : */
54 : void finalize(InitLayerContext &context) override;
55 :
56 : /**
57 : * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
58 : */
59 : void forwarding(RunLayerContext &context, bool training) override;
60 :
61 : /**
62 : * @copydoc Layer::calcDerivative(RunLayerContext &context)
63 : */
64 : void calcDerivative(RunLayerContext &context) override;
65 :
66 : /**
67 : * @copydoc Layer::calcGradient(RunLayerContext &context)
68 : */
69 : void calcGradient(RunLayerContext &context) override;
70 :
71 : /**
72 : * @copydoc bool supportBackwarding() const
73 : */
74 2 : bool supportBackwarding() const override { return true; };
75 :
76 : /**
77 : * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
78 : * method)
79 : */
80 : void exportTo(Exporter &exporter,
81 : const ml::train::ExportMethods &method) const override;
82 :
83 : /**
84 : * @copydoc Layer::setProperty(const std::vector<std::string> &values)
85 : */
86 : void setProperty(const std::vector<std::string> &values) override;
87 :
88 : /**
89 : * @copydoc Layer::getType()
90 : */
91 24 : const std::string getType() const override {
92 24 : return MoLAttentionLayer::type;
93 : };
94 :
95 : /**
96 : * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
97 : */
98 : void setBatch(RunLayerContext &context, unsigned int batch) override;
99 :
100 : static constexpr const char *type = "mol_attention";
101 :
102 : private:
103 : std::tuple<props::Unit, props::MoL_K>
104 : mol_props; /**< mol attention layer properties : unit - number of output
105 : neurons */
106 :
107 : bool helper_exec; /** check if the helper function has already ran */
108 : ActiFunc softmax; /** softmax activation operation */
109 : ActiFunc tanh; /** softmax activation operation */
110 : ActiFunc sigmoid; /** softmax activation operation */
111 : std::array<unsigned int, 17>
112 : wt_idx; /**< indices of the weights and tensors */
113 :
114 : /**
115 : * @brief Helper function for calculation of the derivative
116 : *
117 : * @param context layer context
118 : * @param dstate to store the derivative of the state
119 : */
120 : void calcDerivativeHelper(RunLayerContext &context, Tensor &dstate);
121 : };
122 :
123 : } // namespace nntrainer
124 :
125 : #endif /* __cplusplus */
126 : #endif /* __MOL_ATTENTION_LAYER_H__ */
|