// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
 *
 * @file attention_layer.cpp
 * @date 1 October 2021
 * @see https://github.com/nnstreamer/nntrainer
 * @author Parichay Kapoor <pk.kapoor@samsung.com>
 * @bug No known bugs except for NYI items
 * @brief This is Attention Layer Class for Neural Network
 *
 */

#include <cmath>
#include <limits>

#include <attention_layer.h>
#include <layer_context.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>

namespace nntrainer {

AttentionLayer::AttentionLayer() {
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}

AttentionLayer::~AttentionLayer() {}

static constexpr size_t SINGLE_INOUT_IDX = 0;

enum AttentionParams { query = 0, value = 1, key = 2, weights };

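/**
 * Shared finalize step: checks that the layer receives either (query, value)
 * or (query, value, key) with consistent dimensions and records the input
 * indices in wt_idx. With only two inputs the key entry aliases the value
 * entry, so the same tensor is used as both key and value.
 */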
void AttentionLayer::finalizeCommon(InitLayerContext &context) {
  if (context.getNumInputs() < 2 || context.getNumInputs() > 3)
    throw std::runtime_error("Attention layer needs 2-3 inputs.");

  auto const &all_dims = context.getInputDimensions();
  auto const &query_dim = all_dims[AttentionParams::query];
  auto const &value_dim = all_dims[AttentionParams::value];

  NNTR_THROW_IF(query_dim.width() != value_dim.width(), std::invalid_argument)
    << "Query and Value dimension mismatch for layer " << context.getName();

  wt_idx[AttentionParams::query] = AttentionParams::query;
  wt_idx[AttentionParams::value] = AttentionParams::value;
  wt_idx[AttentionParams::key] = AttentionParams::value;

  if (context.getNumInputs() == 3) {
    auto const &key_dim = all_dims[AttentionParams::key];
    if (key_dim != value_dim)
      throw std::invalid_argument("Key and value must have same shape");

    wt_idx[AttentionParams::key] = AttentionParams::key;
  }
}

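/**
 * Layer-specific finalize: requests the attention score tensor ("weights",
 * the query dimension with its width replaced by the value height) for the
 * whole iteration, sets the output dimension to the query dimension, and
 * selects a float or _FP16 softmax depending on the activation data type.
 */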
void AttentionLayer::finalize(InitLayerContext &context) {
  finalizeCommon(context);

  auto const &all_dims = context.getInputDimensions();
  auto const &query_dim = all_dims[AttentionParams::query];
  auto const &value_dim = all_dims[AttentionParams::value];

  auto weights_dim = query_dim;
  weights_dim.width(value_dim.height());
  wt_idx[AttentionParams::weights] =
    context.requestTensor(weights_dim, "weights", Initializer::NONE, false,
                          TensorLifespan::ITERATION_LIFESPAN);

  context.setOutputDimensions({query_dim});

  auto data_type = context.getActivationDataType();
  if (data_type == ml::train::TensorDim::DataType::FP32) {
    sm.setActiFunc<float>(ActivationType::ACT_SOFTMAX);
  } else if (data_type == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    sm.setActiFunc<_FP16>(ActivationType::ACT_SOFTMAX);
#else
    throw std::runtime_error("enable-fp16 is not enabled");
#endif
  }
}

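/**
 * Forward pass of (optionally scaled, optionally causal) dot-product
 * attention, computed per batch:
 *
 *   weights = softmax(query . key^T [* 1 / sqrt(d_k)] [+ causal_mask])
 *   output  = weights . value
 *
 * where d_k is the key width and causal_mask holds -1e10 above the diagonal
 * so that a position cannot attend to later positions.
 */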
void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
  Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
  Tensor &key = context.getInput(wt_idx[AttentionParams::key]);

  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
  Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);

  query.dotBatched(key, weights, false, true); /** dot 1 */
  if (std::get<props::ScaledDotProduct>(attention_props).get()) {
    weights.multiply_i(1 / sqrt((float)key.getDim().width()));
  }
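  /** Add an upper-triangular mask of -1e10 to the scores so that softmax
   *  assigns (near) zero probability to future positions. */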
  if (std::get<props::CausalMask>(attention_props).get()) {
    unsigned int mask_size = weights.getDim().width();
    unsigned int mask_dim_height = mask_size;
    unsigned int mask_dim_width = mask_size;

    Tensor causal_mask(TensorDim{mask_size, mask_size});

    causal_mask.setZero();
    for (unsigned int i = 0; i < mask_dim_height; ++i) {
      for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
        causal_mask.setValue(0, 0, i, j, -1e10);
      }
    }

    weights.add_i(causal_mask);
  }

  sm.run_fn(weights, weights);       /** softmax */
  weights.dotBatched(value, output); /** dot 2 */
}

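/**
 * Step-wise forwarding for incremental (autoregressive) decoding: only query
 * rows [from, to) are processed against the first `to` rows of key and value,
 * and the matching slices of the output and score tensors are written. The
 * causal mask is applied only on the first step (from == 0); later steps
 * attend over all cached positions without masking.
 */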
void AttentionLayer::incremental_forwarding(RunLayerContext &context,
                                            unsigned int from, unsigned int to,
                                            bool training) {
  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
  Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
  Tensor &key = context.getInput(wt_idx[AttentionParams::key]);

  TensorDim query_dim = query.getDim();
  TensorDim value_dim = value.getDim();
  TensorDim key_dim = key.getDim();
  TensorDim query_step_dim = {query_dim.batch(), query_dim.channel(),
                              to - from, query_dim.width()};
  TensorDim value_step_dim = {value_dim.batch(), value_dim.channel(), to,
                              value_dim.width()};
  TensorDim key_step_dim = {key_dim.batch(), key_dim.channel(), to,
                            key_dim.width()};
  Tensor query_step =
    query.getSharedDataTensor(query_step_dim, from * query_dim.width(), true);
  Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
  Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);

  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
  TensorDim output_dim = output.getDim();
  TensorDim output_step_dim = {output_dim.batch(), output_dim.channel(),
                               to - from, output_dim.width()};
  Tensor output_step = output.getSharedDataTensor(
    output_step_dim, from * output_dim.width(), true);

  Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
  TensorDim weights_dim = weights.getDim();
  TensorDim weights_step_dim = {
    query_step_dim.batch(), query_step_dim.channel(), query_step_dim.height(),
    value_step_dim.height()};
  Tensor weights_step = weights.getSharedDataTensor(
    weights_step_dim, from * weights_dim.width(), true);

  query_step.dotBatched(key_step, weights_step, false, true); /** dot 1 */
  if (std::get<props::ScaledDotProduct>(attention_props).get()) {
    weights_step.multiply_i(1 / sqrt((float)key.getDim().width()));
  }

  if (std::get<props::CausalMask>(attention_props).get() && !from) {
    unsigned int mask_size = weights_step.getDim().width();
    unsigned int mask_dim_height = mask_size;
    unsigned int mask_dim_width = mask_size;

    Tensor causal_mask(TensorDim{mask_size, mask_size});

    causal_mask.setZero();
    for (unsigned int i = 0; i < mask_dim_height; ++i) {
      for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
        causal_mask.setValue(0, 0, i, j, -1e10);
      }
    }

    weights_step.add_i(causal_mask);
  }

  sm.run_fn(weights_step, weights_step);            /** softmax */
  weights_step.dotBatched(value_step, output_step); /** dot 2 */
}

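/**
 * Backward pass mirroring forwarding(): the incoming derivative is propagated
 * through dot 2 (yielding dvalue and the score gradient dweight), through the
 * softmax using its stored output, through the optional 1 / sqrt(d_k) scaling,
 * and finally through dot 1 (yielding dquery and dkey). With two inputs, key
 * and value share a tensor, so the key gradient is accumulated into dvalue.
 */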
void AttentionLayer::calcDerivative(RunLayerContext &context) {
  const Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);

  Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
  Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
  Tensor &key = context.getInput(wt_idx[AttentionParams::key]);

  Tensor &dquery =
    context.getOutgoingDerivative(wt_idx[AttentionParams::query]);
  Tensor &dvalue =
    context.getOutgoingDerivative(wt_idx[AttentionParams::value]);
  Tensor &dkey = context.getOutgoingDerivative(wt_idx[AttentionParams::key]);

  Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);

  Tensor dweight = Tensor(
    TensorDim({derivative.batch(), 1, derivative.height(), value.height()},
              weights.getTensorType()));

  /** derivative for dot 2 */
  dweight.dot_batched_deriv_wrt_1(value, derivative);
  weights.dot_batched_deriv_wrt_2(dvalue, derivative);

  /** derivative for softmax */
  sm.run_prime_fn(weights, dweight, dweight);

  if (std::get<props::ScaledDotProduct>(attention_props).get()) {
    dweight.multiply_i(1 / sqrt((float)key.getDim().width()));
  }

  /** derivative for dot 1 */
  dquery.dot_batched_deriv_wrt_1(key, dweight, false, true);
  query.dot_batched_deriv_wrt_2(dkey, dweight, false, true,
                                context.getNumInputs() == 2);
}

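/**
 * Loads the attention-specific properties (scaled dot product and causal
 * masking) and rejects any property this layer does not recognize. As a
 * rough usage sketch, the properties would be passed as strings at layer
 * creation; the exact keys below are assumed from the property names and
 * should be checked against the property definitions:
 *
 *   auto attention = ml::train::createLayer(
 *     "attention", {"scaled_dot_product=true", "causal_mask=true"});
 */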
void AttentionLayer::setProperty(const std::vector<std::string> &values) {
  auto remain_props = loadProperties(values, attention_props);
  if (!remain_props.empty()) {
    std::string msg = "[AttentionLayer] Unknown Layer Properties count " +
                      std::to_string(values.size());
    throw exception::not_supported(msg);
  }
}

void AttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
  context.updateTensor(wt_idx[AttentionParams::weights], batch);
}

} /* namespace nntrainer */