Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2022 hyeonseok Lee <hs89.lee@samsung.com>
4 : *
5 : * @file multi_head_attention_layer.cpp
6 : * @date 08 July 2022
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * https://arxiv.org/abs/1706.03762
9 : * @author hyeonseok Lee <hs89.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is MultiHeadAttention Layer Class for Neural Network
12 : *
13 : */
14 :
15 : #include <cmath>
16 :
17 : #include <layer_context.h>
18 : #include <multi_head_attention_layer.h>
19 : #include <nntrainer_error.h>
20 : #include <nntrainer_log.h>
21 : #include <node_exporter.h>
22 :
23 : namespace nntrainer {
24 :
25 149 : MultiHeadAttentionLayer::MultiHeadAttentionLayer() :
26 : multi_head_attention_props(
27 298 : props::NumHeads(), props::ProjectedKeyDim(), props::ProjectedValueDim(),
28 447 : props::OutputShape(), props::DropOutRate(), props::ReturnAttentionWeight(),
29 0 : props::AverageAttentionWeight()),
30 149 : sm(ActivationType::ACT_SOFTMAX),
31 298 : epsilon(1e-3f) {
32 : weight_idx.fill(std::numeric_limits<unsigned>::max());
33 149 : }
34 :
35 298 : MultiHeadAttentionLayer::~MultiHeadAttentionLayer() {}
36 :
37 : enum INOUT_INDEX {
38 : /** input index */
39 : QUERY = 0,
40 : KEY = 1,
41 : VALUE = 2,
42 : MASK = 3,
43 : /** output index */
44 : OUTPUT = 0,
45 : RETURN_ATTENTION_WEIGHT = 1,
46 : };
47 :
48 : enum AttentionParams {
49 : query_fc_weight,
50 : query_fc_bias,
51 : key_fc_weight,
52 : key_fc_bias,
53 : value_fc_weight,
54 : value_fc_bias,
55 : fc_weight,
56 : fc_bias,
57 : projected_query,
58 : projected_key,
59 : projected_value,
60 : cache_key,
61 : cache_value,
62 : /** intentionally commented out; reserved for later use of attention_mask */
63 : // attention_mask,
64 : attention_weight,
65 : dropout_mask,
66 : attention_output,
67 : };
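/* Each AttentionParams entry indexes into weight_idx, which stores the handles
 * returned by context.requestWeight() / context.requestTensor() in finalize()
 * so that forwarding and the backward passes can fetch the same tensors. */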
68 :
69 121 : void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
70 121 : NNTR_THROW_IF(context.getNumInputs() < 3 || context.getNumInputs() > 4,
71 : std::invalid_argument)
72 : << "Multi head attention layer needs 3 or 4 inputs (query, key, value, and "
73 : "an optional mask)";
74 : const bool provide_attention_mask = context.getNumInputs() == 4;
75 :
76 : TensorDim::TensorType weight_type = {context.getFormat(),
77 : context.getWeightDataType()};
78 :
79 : TensorDim::TensorType activation_type = {context.getFormat(),
80 : context.getActivationDataType()};
81 :
82 121 : TensorDim empty_dim(activation_type);
83 :
84 : const std::vector<TensorDim> &input_dims = context.getInputDimensions();
85 : const TensorDim &query_dim = input_dims[INOUT_INDEX::QUERY];
86 : const TensorDim &key_dim = input_dims[INOUT_INDEX::KEY];
87 : const TensorDim &value_dim = input_dims[INOUT_INDEX::VALUE];
88 : const TensorDim &mask_dim =
89 121 : provide_attention_mask ? input_dims[INOUT_INDEX::MASK] : empty_dim;
90 :
91 121 : const unsigned int batch_size = query_dim.batch();
92 121 : const unsigned int query_height = query_dim.height();
93 121 : const unsigned int query_width = query_dim.width();
94 121 : const unsigned int key_height = key_dim.height();
95 121 : const unsigned int key_width = key_dim.width();
96 121 : const unsigned int value_height = value_dim.height();
97 121 : const unsigned int value_width = value_dim.width();
98 :
99 : const bool disable_bias =
100 121 : std::get<props::DisableBias>(*layer_impl_props).get();
101 : auto &weight_initializer =
102 121 : std::get<props::WeightInitializer>(*layer_impl_props).get();
103 : auto &weight_regularizer =
104 : std::get<props::WeightRegularizer>(*layer_impl_props);
105 : auto &weight_regularizer_constant =
106 : std::get<props::WeightRegularizerConstant>(*layer_impl_props);
107 : const float &weight_decay =
108 121 : std::get<props::WeightDecay>(*layer_impl_props).get();
109 :
110 121 : NNTR_THROW_IF(std::get<props::NumHeads>(multi_head_attention_props).empty(),
111 : std::invalid_argument)
112 : << "num_heads property is not provided for layer " << context.getName();
113 : const unsigned int num_heads =
114 121 : std::get<props::NumHeads>(multi_head_attention_props).get();
115 :
116 121 : if (std::get<props::ProjectedKeyDim>(multi_head_attention_props).empty()) {
117 54 : NNTR_THROW_IF(query_width % num_heads, std::invalid_argument)
118 : << "query_width: " << query_width
119 : << " is not divisible by num_heads: " << num_heads << " for layer "
120 : << context.getName();
121 :
122 54 : ml_logw("[multi head attention] ProjectedKeyDim property is not given. "
123 : "Default value (query_width / num_heads) is set");
124 :
125 : std::get<props::ProjectedKeyDim>(multi_head_attention_props)
126 54 : .set(query_width / num_heads);
127 : }
128 : const unsigned int projected_key_dim_prop =
129 121 : std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
130 :
131 121 : if (std::get<props::ProjectedValueDim>(multi_head_attention_props).empty()) {
132 : std::get<props::ProjectedValueDim>(multi_head_attention_props)
133 66 : .set(projected_key_dim_prop);
134 : }
135 : const unsigned int projected_value_dim_prop =
136 121 : std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
137 :
138 121 : if (std::get<props::OutputShape>(multi_head_attention_props).empty()) {
139 66 : std::get<props::OutputShape>(multi_head_attention_props).set(query_width);
140 : }
141 : const unsigned int output_shape =
142 121 : std::get<props::OutputShape>(multi_head_attention_props).get();
143 :
144 : const float dropout_rate =
145 121 : std::get<props::DropOutRate>(multi_head_attention_props).get();
146 :
147 121 : if (std::get<props::AverageAttentionWeight>(multi_head_attention_props)
148 : .empty()) {
149 : std::get<props::AverageAttentionWeight>(multi_head_attention_props)
150 66 : .set(true);
151 : }
152 : const bool average_attention_weight =
153 121 : std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
154 :
155 : const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
156 121 : std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
157 :
158 121 : const unsigned int projected_query_dim_prop = projected_key_dim_prop;
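/* The query projection reuses the key projection size because the attention
 * score is Q . K^T: both operands of the batched dot product need the same
 * per-head width. As an example, query_width = 512 with num_heads = 8 gives a
 * default projected dim of 512 / 8 = 64 per head. */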
159 :
160 121 : if (activation_type.data_type == TensorDim::DataType::FP32) {
161 121 : sm.setActiFunc(ActivationType::ACT_SOFTMAX);
162 0 : } else if (activation_type.data_type == TensorDim::DataType::FP16) {
163 : #ifdef ENABLE_FP16
164 : sm.setActiFunc<_FP16>(ActivationType::ACT_SOFTMAX);
165 : #else
166 0 : throw std::invalid_argument("Error: enable-fp16 is not enabled");
167 : #endif
168 : }
169 :
170 : // sm.setActiFunc(ActivationType::ACT_SOFTMAX);
171 :
172 121 : NNTR_THROW_IF(query_dim.channel() != 1, std::invalid_argument)
173 0 : << "Dimension of input query channel: " << query_dim.channel()
174 : << " is not 1 for layer " << context.getName();
175 121 : NNTR_THROW_IF(key_dim.channel() != 1, std::invalid_argument)
176 0 : << "Dimension of input key channel: " << key_dim.channel()
177 : << " is not 1 for layer " << context.getName();
178 121 : NNTR_THROW_IF(value_dim.channel() != 1, std::invalid_argument)
179 0 : << "Dimension of input value channel: " << value_dim.channel()
180 : << " is not 1 for layer " << context.getName();
181 121 : NNTR_THROW_IF(provide_attention_mask && mask_dim.channel() != num_heads,
182 : std::invalid_argument)
183 0 : << "Dimension of input mask channel: " << mask_dim.channel()
184 : << " does not match the num_heads property: " << num_heads << " for layer "
185 : << context.getName();
186 :
187 121 : NNTR_THROW_IF(key_height != value_height, std::invalid_argument)
188 : << "Dimension of input key height: " << key_height
189 : << " does not match the dimension of input value height: " << value_height
190 : << " for layer " << context.getName();
191 121 : NNTR_THROW_IF(provide_attention_mask && mask_dim.height() != query_height,
192 : std::invalid_argument)
193 0 : << "Dimension of input mask height: " << mask_dim.height()
194 : << " does not match the dimension of input query height: " << query_height
195 : << " for layer " << context.getName();
196 :
197 121 : NNTR_THROW_IF(provide_attention_mask && mask_dim.width() != key_height,
198 : std::invalid_argument)
199 0 : << "Dimension of input mask width: " << mask_dim.width()
200 : << " does not match the dimension of input key height: " << key_height
201 : << " for layer " << context.getName();
202 :
203 : /** weight/bias for query fc */
204 : TensorDim query_fc_weight_dim(
205 121 : {1, 1, query_width, num_heads * projected_query_dim_prop}, weight_type);
206 :
207 121 : weight_idx[AttentionParams::query_fc_weight] = context.requestWeight(
208 : query_fc_weight_dim, weight_initializer, weight_regularizer,
209 : weight_regularizer_constant, weight_decay, "query_fc_weight", true);
210 121 : if (!disable_bias) {
211 : TensorDim query_fc_bias_dim({1, 1, 1, num_heads * projected_query_dim_prop},
212 117 : weight_type);
213 117 : weight_idx[AttentionParams::query_fc_bias] = context.requestWeight(
214 : query_fc_bias_dim, weight_initializer, weight_regularizer,
215 : weight_regularizer_constant, weight_decay, "query_fc_bias", true);
216 : }
217 :
218 : /** weight/bias for key fc */
219 : TensorDim key_fc_weight_dim(
220 121 : {1, 1, key_width, num_heads * projected_key_dim_prop}, weight_type);
221 121 : weight_idx[AttentionParams::key_fc_weight] = context.requestWeight(
222 : key_fc_weight_dim, weight_initializer, weight_regularizer,
223 : weight_regularizer_constant, weight_decay, "key_fc_weight", true);
224 121 : if (!disable_bias) {
225 117 : TensorDim key_fc_bias_dim({1, 1, 1, num_heads * projected_key_dim_prop},
226 117 : weight_type);
227 117 : weight_idx[AttentionParams::key_fc_bias] = context.requestWeight(
228 : key_fc_bias_dim, weight_initializer, weight_regularizer,
229 : weight_regularizer_constant, weight_decay, "key_fc_bias", true);
230 : }
231 :
232 : /** weight/bias for value fc */
233 : TensorDim value_fc_weight_dim(
234 121 : {1, 1, value_width, num_heads * projected_value_dim_prop}, weight_type);
235 121 : weight_idx[AttentionParams::value_fc_weight] = context.requestWeight(
236 : value_fc_weight_dim, weight_initializer, weight_regularizer,
237 : weight_regularizer_constant, weight_decay, "value_fc_weight", true);
238 121 : if (!disable_bias) {
239 : TensorDim value_fc_bias_dim({1, 1, 1, num_heads * projected_value_dim_prop},
240 117 : weight_type);
241 117 : weight_idx[AttentionParams::value_fc_bias] = context.requestWeight(
242 : value_fc_bias_dim, weight_initializer, weight_regularizer,
243 : weight_regularizer_constant, weight_decay, "value_fc_bias", true);
244 : }
245 :
246 : /** weight/bias for out fc */
247 : TensorDim fc_weight_dim(
248 121 : {1, 1, num_heads * projected_value_dim_prop, output_shape}, weight_type);
249 121 : weight_idx[AttentionParams::fc_weight] = context.requestWeight(
250 : fc_weight_dim, weight_initializer, weight_regularizer,
251 : weight_regularizer_constant, weight_decay, "fc_weight", true);
252 121 : if (!disable_bias) {
253 117 : TensorDim fc_bias_dim({1, 1, 1, output_shape}, weight_type);
254 117 : weight_idx[AttentionParams::fc_bias] = context.requestWeight(
255 : fc_bias_dim, weight_initializer, weight_regularizer,
256 : weight_regularizer_constant, weight_decay, "fc_bias", true);
257 : }
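/* Four dense projections are requested above: query/key/value projections of
 * shape (input_width x num_heads * projected_dim) plus an output projection of
 * shape (num_heads * projected_value_dim x output_shape), each with an
 * optional bias. */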
258 :
259 : /** tensor for output of query fc */
260 : TensorDim projected_query_dim(
261 : {batch_size, 1, query_height, num_heads * projected_query_dim_prop},
262 121 : activation_type);
263 121 : weight_idx[AttentionParams::projected_query] = context.requestTensor(
264 : projected_query_dim, "projected_query", Initializer::NONE, true,
265 : TensorLifespan::ITERATION_LIFESPAN);
266 : /** tensor for output of key fc */
267 : TensorDim projected_key_dim(
268 121 : {batch_size, 1, key_height, num_heads * projected_key_dim_prop},
269 121 : activation_type);
270 121 : weight_idx[AttentionParams::projected_key] =
271 121 : context.requestTensor(projected_key_dim, "projected_key", Initializer::NONE,
272 : true, TensorLifespan::ITERATION_LIFESPAN);
273 : /** tensor for output of value fc */
274 : TensorDim projected_value_dim(
275 : {batch_size, 1, value_height, num_heads * projected_value_dim_prop},
276 121 : activation_type);
277 121 : weight_idx[AttentionParams::projected_value] = context.requestTensor(
278 : projected_value_dim, "projected_value", Initializer::NONE, true,
279 : TensorLifespan::ITERATION_LIFESPAN);
280 :
281 121 : weight_idx[AttentionParams::cache_key] =
282 121 : context.requestTensor(projected_key_dim, "cache_key", Initializer::NONE,
283 : true, TensorLifespan::MAX_LIFESPAN);
284 :
285 121 : weight_idx[AttentionParams::cache_value] =
286 121 : context.requestTensor(projected_value_dim, "cache_value", Initializer::NONE,
287 : true, TensorLifespan::MAX_LIFESPAN);
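/* cache_key / cache_value serve as a key-value cache for incremental decoding:
 * incremental_forwarding() appends the projected key/value of the current step
 * and attends over every cached position, so MAX_LIFESPAN keeps them alive
 * across forwarding calls. */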
288 :
289 : if (provide_attention_mask) {
290 : /** intentionally commented out: placeholder for a bool type mask */
291 : // TensorDim attention_mask_dim(
292 : // {batch_size, num_heads, query_height, key_height});
293 : // weight_idx[AttentionParams::attention_mask] = context.requestTensor(
294 : // attention_mask_dim, "attention_mask", Initializer::NONE, false,
295 : // TensorLifespan::FORWARD_FUNC_LIFESPAN);
296 : }
297 : /** tensor for attention weight */
298 : TensorDim attention_weight_dim(
299 121 : {batch_size, num_heads, query_height, key_height}, activation_type);
300 121 : weight_idx[AttentionParams::attention_weight] = context.requestTensor(
301 : attention_weight_dim, "attention_weight", Initializer::NONE, true,
302 : TensorLifespan::ITERATION_LIFESPAN);
303 121 : if (dropout_rate > epsilon) {
304 : /** tensor for dropout mask */
305 : TensorDim dropout_mask_dim(
306 0 : {batch_size, num_heads, query_height, key_height}, activation_type);
307 0 : weight_idx[AttentionParams::dropout_mask] =
308 0 : context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
309 : false, TensorLifespan::ITERATION_LIFESPAN);
310 : }
311 :
312 : /** tensor for attention output */
313 : TensorDim attention_output_dim(
314 : {batch_size, 1, query_height, num_heads * projected_value_dim_prop},
315 121 : activation_type);
316 121 : weight_idx[AttentionParams::attention_output] = context.requestTensor(
317 : attention_output_dim, "attention_output", Initializer::NONE, true,
318 : TensorLifespan::ITERATION_LIFESPAN);
319 :
320 : TensorDim output_dim({batch_size, 1, query_height, output_shape},
321 121 : activation_type);
322 121 : if (return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none) {
323 : TensorDim return_attention_weight_dim(
324 17 : {batch_size, average_attention_weight ? 1 : num_heads, query_height,
325 : key_height},
326 18 : activation_type);
327 17 : context.setOutputDimensions({output_dim, return_attention_weight_dim});
328 : } else {
329 104 : context.setOutputDimensions({output_dim});
330 : }
331 121 : }
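/* Usage sketch (not part of this file): assuming the layer is registered under
 * the key "multi_head_attention" and the usual snake_case property names, it
 * could be created through the ccapi factory roughly like
 *
 *   auto mha = ml::train::createLayer(
 *     "multi_head_attention",
 *     {"name=attn", "num_heads=8", "return_attention_weight=after"});
 *
 * with query/key/value producers wired in via the input_layers property. */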
332 :
333 186 : void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
334 : bool training) {
335 : const bool disable_bias =
336 186 : std::get<props::DisableBias>(*layer_impl_props).get();
337 :
338 : const unsigned int num_heads =
339 186 : std::get<props::NumHeads>(multi_head_attention_props).get();
340 : const unsigned int projected_key_dim_prop =
341 186 : std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
342 : const unsigned int projected_value_dim_prop =
343 186 : std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
344 : const float dropout_rate =
345 186 : std::get<props::DropOutRate>(multi_head_attention_props).get();
346 : const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
347 186 : std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
348 : const bool average_attention_weight =
349 186 : std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
350 :
351 186 : const bool provide_attention_mask = context.getNumInputs() == 4;
352 : const unsigned int projected_query_dim_prop = projected_key_dim_prop;
353 186 : const bool enable_dropout = dropout_rate > epsilon;
354 :
355 186 : Tensor empty_tensor;
356 :
357 : /** get inputs/outputs */
358 186 : Tensor &query = context.getInput(INOUT_INDEX::QUERY);
359 186 : Tensor &key = context.getInput(INOUT_INDEX::KEY);
360 186 : Tensor &value = context.getInput(INOUT_INDEX::VALUE);
361 : Tensor &mask =
362 186 : provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
363 :
364 186 : Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
365 : Tensor &ret_attention_weight =
366 : return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
367 186 : ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
368 : : empty_tensor;
369 :
370 : /** get weights */
371 : Tensor &query_fc_weight =
372 186 : context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
373 : Tensor &query_fc_bias =
374 : disable_bias
375 186 : ? empty_tensor
376 180 : : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
377 : Tensor &key_fc_weight =
378 186 : context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
379 : Tensor &key_fc_bias =
380 186 : disable_bias ? empty_tensor
381 180 : : context.getWeight(weight_idx[AttentionParams::key_fc_bias]);
382 : Tensor &value_fc_weight =
383 186 : context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
384 : Tensor &value_fc_bias =
385 : disable_bias
386 186 : ? empty_tensor
387 180 : : context.getWeight(weight_idx[AttentionParams::value_fc_bias]);
388 186 : Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
389 : Tensor &fc_bias = disable_bias
390 186 : ? empty_tensor
391 180 : : context.getWeight(weight_idx[AttentionParams::fc_bias]);
392 :
393 : /** get tensors */
394 : Tensor &projected_query =
395 186 : context.getTensor(weight_idx[AttentionParams::projected_query]);
396 : Tensor &projected_key =
397 186 : context.getTensor(weight_idx[AttentionParams::projected_key]);
398 : Tensor &projected_value =
399 186 : context.getTensor(weight_idx[AttentionParams::projected_value]);
400 :
401 : Tensor &attention_weight =
402 186 : context.getTensor(weight_idx[AttentionParams::attention_weight]);
403 : Tensor &attention_output =
404 186 : context.getTensor(weight_idx[AttentionParams::attention_output]);
405 :
406 186 : const TensorDim query_dim = query.getDim();
407 186 : const unsigned int batch_size = query_dim.batch();
408 186 : const unsigned int query_height = query_dim.height();
409 186 : const TensorDim key_dim = key.getDim();
410 186 : const unsigned int key_height = key_dim.height();
411 186 : const TensorDim value_dim = value.getDim();
412 186 : const unsigned int value_height = value_dim.height();
413 :
414 186 : query.dot(query_fc_weight, projected_query);
415 186 : if (!disable_bias) {
416 180 : projected_query.add_i(query_fc_bias);
417 : }
418 186 : key.dot(key_fc_weight, projected_key);
419 186 : if (!disable_bias) {
420 180 : projected_key.add_i(key_fc_bias);
421 : }
422 186 : value.dot(value_fc_weight, projected_value);
423 186 : if (!disable_bias) {
424 180 : projected_value.add_i(value_fc_bias);
425 : }
426 :
427 186 : projected_query.reshape(
428 372 : TensorDim({batch_size, query_height, num_heads, projected_query_dim_prop}));
429 186 : projected_key.reshape(
430 372 : TensorDim({batch_size, key_height, num_heads, projected_key_dim_prop}));
431 186 : projected_value.reshape(
432 186 : TensorDim({batch_size, value_height, num_heads, projected_value_dim_prop}));
433 :
434 372 : projected_query = projected_query.transpose("1:0:2");
435 372 : projected_key = projected_key.transpose("1:0:2");
436 372 : projected_value = projected_value.transpose("1:0:2");
437 :
438 : /** restore the original tensor names, which were dropped during the
439 : * transposes above */
440 186 : projected_query.setName("multi_head_attention:projected_query");
441 186 : projected_key.setName("multi_head_attention:projected_key");
442 372 : projected_value.setName("multi_head_attention:projected_value");
443 :
444 186 : projected_query.reshape(TensorDim(
445 186 : {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
446 186 : projected_key.reshape(
447 372 : TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
448 186 : projected_value.reshape(TensorDim(
449 : {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
450 :
451 186 : attention_weight.reshape(
452 372 : TensorDim({batch_size * num_heads, 1, query_height, key_height}));
453 186 : attention_output.reshape(TensorDim(
454 : {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
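/* Layout trick: each projection is reshaped to (batch, seq, heads, head_dim),
 * transposed to (batch, heads, seq, head_dim), and then folded so the head
 * axis joins the batch axis. Every dotBatched call below therefore computes an
 * independent attention per (batch, head) pair. */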
455 :
456 : /** scaled dot product attention */
457 186 : projected_query.dotBatched(projected_key, attention_weight, false, true);
458 186 : attention_weight.multiply_i(1.0f /
459 186 : std::sqrt((float)projected_query_dim_prop));
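/* attention_weight now holds Q . K^T / sqrt(d_k); after the optional additive
 * mask and the softmax below it becomes softmax(Q . K^T / sqrt(d_k) + M), i.e.
 * the scaled dot-product attention of https://arxiv.org/abs/1706.03762. */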
460 :
461 186 : if (provide_attention_mask) {
462 : // Tensor &attention_mask =
463 : // context.getTensor(weight_idx[AttentionParams::attention_mask]);
464 : /** @todo: enable bool type tensor */
465 : // if (torch_ref) {
466 : // attention_mask.setValue(-1e9);
467 : // } else {
468 : // // flip
469 : // attention_mask.setValue(1);
470 : // attention_mask.subtract_i(mask);
471 :
472 : // attention_mask.multiply_i(-1e9);
473 : // }
474 : // attention_mask.multiply_i(mask);
475 : // attention_weight.add_i(attention_mask);
476 :
477 60 : attention_weight.reshape(
478 60 : TensorDim({batch_size, num_heads, query_height, key_height}));
479 60 : attention_weight.add_i(mask);
480 60 : attention_weight.reshape(
481 120 : TensorDim({batch_size * num_heads, 1, query_height, key_height}));
482 : }
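/* The mask is added directly to the attention logits, so it is expected to
 * already contain 0 for visible positions and a large negative value for
 * masked ones; the commented block above sketches how a 0/1 boolean mask could
 * be converted into that form. */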
483 :
484 : sm.run_fn(attention_weight, attention_weight);
485 :
486 186 : if (return_attention_weight ==
487 : props::ReturnAttentionWeightInfo::Enum::before) {
488 5 : ret_attention_weight.copyData(attention_weight);
489 : }
490 :
491 186 : if (enable_dropout) {
492 : Tensor &dropout_mask =
493 0 : context.getTensor(weight_idx[AttentionParams::dropout_mask]);
494 0 : dropout_mask.dropout_mask(dropout_rate);
495 0 : attention_weight.multiply_i(dropout_mask);
496 : }
497 :
498 186 : if (return_attention_weight ==
499 : props::ReturnAttentionWeightInfo::Enum::after) {
500 24 : if (average_attention_weight) {
501 24 : attention_weight.reshape(
502 24 : TensorDim({batch_size, num_heads, query_height, key_height}));
503 24 : attention_weight.sum(1, ret_attention_weight, 1, 0);
504 24 : ret_attention_weight.divide_i(static_cast<float>(num_heads));
505 24 : attention_weight.reshape(
506 48 : TensorDim({batch_size * num_heads, 1, query_height, key_height}));
507 : } else {
508 0 : ret_attention_weight.copyData(attention_weight);
509 : }
510 : }
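/* When average_attention_weight is set, the returned map is the per-head
 * attention summed over the head axis and divided by num_heads, matching the
 * single-channel output dimension chosen in finalize(). */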
511 :
512 186 : attention_weight.dotBatched(projected_value, attention_output);
513 :
514 186 : attention_output.reshape(
515 186 : TensorDim({batch_size, num_heads, query_height, projected_value_dim_prop}));
516 :
517 372 : attention_output = attention_output.transpose("1:0:2");
518 :
519 : /** restore the original tensor name, which was dropped during the
520 : * transpose above */
521 372 : attention_output.setName("multi_head_attention:attention_output");
522 :
523 186 : attention_output.reshape(TensorDim(
524 186 : {batch_size * query_height, 1, 1, num_heads * projected_value_dim_prop}));
525 :
526 186 : attention_output.dot(fc_weight, output);
527 186 : if (!disable_bias) {
528 180 : output.add_i(fc_bias);
529 : }
530 :
531 : /** restore shape */
532 186 : projected_query.reshape(TensorDim(
533 186 : {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
534 186 : projected_key.reshape(
535 372 : TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
536 186 : projected_value.reshape(TensorDim(
537 : {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
538 :
539 186 : attention_weight.reshape(
540 372 : TensorDim({batch_size, num_heads, query_height, key_height}));
541 186 : attention_output.reshape(TensorDim(
542 : {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
543 186 : }
544 :
545 0 : void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
546 : unsigned int from,
547 : unsigned int to,
548 : bool training) {
549 : const bool disable_bias =
550 0 : std::get<props::DisableBias>(*layer_impl_props).get();
551 :
552 : const unsigned int num_heads =
553 0 : std::get<props::NumHeads>(multi_head_attention_props).get();
554 : const unsigned int projected_key_dim_prop =
555 0 : std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
556 : const unsigned int projected_value_dim_prop =
557 0 : std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
558 : const float dropout_rate =
559 0 : std::get<props::DropOutRate>(multi_head_attention_props).get();
560 : const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
561 0 : std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
562 : const bool average_attention_weight =
563 0 : std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
564 :
565 0 : const bool provide_attention_mask = context.getNumInputs() == 4;
566 : const unsigned int projected_query_dim_prop = projected_key_dim_prop;
567 : const bool enable_dropout = dropout_rate > epsilon;
568 :
569 : /** get inputs/outputs */
570 0 : Tensor &query = context.getInput(INOUT_INDEX::QUERY);
571 0 : Tensor &key = context.getInput(INOUT_INDEX::KEY);
572 0 : Tensor &value = context.getInput(INOUT_INDEX::VALUE);
573 :
574 0 : Tensor empty_tensor("empty", value.getFormat(), value.getDataType());
575 :
576 : Tensor &mask =
577 0 : provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
578 :
579 0 : TensorDim query_dim = query.getDim();
580 0 : TensorDim key_dim = key.getDim();
581 0 : TensorDim value_dim = value.getDim();
582 :
583 0 : Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
584 :
585 0 : TensorDim output_dim = output.getDim();
586 : Tensor &ret_attention_weight =
587 : return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
588 0 : ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
589 : : empty_tensor;
590 :
591 : /** get weights */
592 : Tensor &query_fc_weight =
593 0 : context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
594 : Tensor &query_fc_bias =
595 : disable_bias
596 0 : ? empty_tensor
597 0 : : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
598 : Tensor &key_fc_weight =
599 0 : context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
600 : Tensor &key_fc_bias =
601 0 : disable_bias ? empty_tensor
602 0 : : context.getWeight(weight_idx[AttentionParams::key_fc_bias]);
603 : Tensor &value_fc_weight =
604 0 : context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
605 : Tensor &value_fc_bias =
606 : disable_bias
607 0 : ? empty_tensor
608 0 : : context.getWeight(weight_idx[AttentionParams::value_fc_bias]);
609 0 : Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
610 : Tensor &fc_bias = disable_bias
611 0 : ? empty_tensor
612 0 : : context.getWeight(weight_idx[AttentionParams::fc_bias]);
613 :
614 : /** get tensors */
615 : Tensor &projected_query =
616 0 : context.getTensor(weight_idx[AttentionParams::projected_query]);
617 : Tensor &projected_key =
618 0 : context.getTensor(weight_idx[AttentionParams::projected_key]);
619 : Tensor &projected_value =
620 0 : context.getTensor(weight_idx[AttentionParams::projected_value]);
621 0 : Tensor &cache_key = context.getTensor(weight_idx[AttentionParams::cache_key]);
622 : Tensor &cache_value =
623 0 : context.getTensor(weight_idx[AttentionParams::cache_value]);
624 :
625 0 : TensorDim projected_query_dim = projected_query.getDim();
626 0 : TensorDim projected_key_dim = projected_key.getDim();
627 0 : TensorDim projected_value_dim = projected_value.getDim();
628 0 : TensorDim cache_key_dim = cache_key.getDim();
629 0 : TensorDim cache_value_dim = cache_value.getDim();
630 :
631 0 : TensorDim projected_query_step_dim = projected_query_dim;
632 :
633 0 : TensorDim projected_key_step_dim = projected_key_dim;
634 0 : TensorDim projected_value_step_dim = projected_value_dim;
635 0 : TensorDim cache_key_step_dim = cache_key_dim;
636 0 : TensorDim cache_value_step_dim = cache_value_dim;
637 0 : projected_query_step_dim.height(to - from);
638 :
639 0 : projected_key_step_dim.height(to);
640 0 : projected_value_step_dim.height(to);
641 0 : cache_key_step_dim.height(to - from);
642 0 : cache_value_step_dim.height(to - from);
643 :
644 : Tensor projected_query_step =
645 0 : projected_query.getSharedDataTensor(projected_query_step_dim, 0, true);
646 : Tensor projected_key_step =
647 0 : projected_key.getSharedDataTensor(projected_key_step_dim, 0, true);
648 : Tensor projected_value_step =
649 0 : projected_value.getSharedDataTensor(projected_value_step_dim, 0, true);
650 :
651 : Tensor cache_key_step = cache_key.getSharedDataTensor(
652 0 : cache_key_step_dim, from * cache_key_dim.width(), true);
653 : Tensor cache_value_step = cache_value.getSharedDataTensor(
654 0 : cache_value_step_dim, from * cache_value_dim.width(), true);
655 :
656 : TensorDim cached_key_dim = {cache_key_dim.batch(), cache_key_dim.channel(),
657 : to, cache_key_dim.width(),
658 0 : cache_key.getTensorType()};
659 : TensorDim cached_value_dim = {
660 : cache_value_dim.batch(), cache_value_dim.channel(), to,
661 0 : cache_value_dim.width(), cache_value.getTensorType()};
662 0 : Tensor cached_key = cache_key.getSharedDataTensor(cached_key_dim, 0, true);
663 : Tensor cached_value =
664 0 : cache_value.getSharedDataTensor(cached_value_dim, 0, true);
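/* cache_key_step / cache_value_step are views over rows [from, to) of the
 * caches and receive this step's freshly projected key/value, while
 * cached_key / cached_value view rows [0, to) so the attention below covers
 * every position decoded so far. */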
665 :
666 : Tensor &attention_weight =
667 0 : context.getTensor(weight_idx[AttentionParams::attention_weight]);
668 : Tensor &attention_output =
669 0 : context.getTensor(weight_idx[AttentionParams::attention_output]);
670 0 : TensorDim attention_weight_dim = attention_weight.getDim();
671 :
672 0 : TensorDim attention_weight_step_dim = attention_weight_dim;
673 0 : attention_weight_step_dim.height(to - from);
674 0 : attention_weight_step_dim.width(to);
675 :
676 : Tensor attention_weight_step =
677 0 : attention_weight.getSharedDataTensor(attention_weight_step_dim, 0, true);
678 :
679 0 : TensorDim attention_output_dim = attention_output.getDim();
680 0 : TensorDim attention_output_step_dim = attention_output_dim;
681 0 : attention_output_step_dim.height(to - from);
682 :
683 : Tensor attention_output_step =
684 0 : attention_output.getSharedDataTensor(attention_output_step_dim, 0, true);
685 :
686 0 : const unsigned int batch_size = query_dim.batch();
687 0 : const unsigned int query_height = query_dim.height();
688 0 : const unsigned int key_height = key_dim.height();
689 0 : const unsigned int value_height = value_dim.height();
690 :
691 0 : query.dot(query_fc_weight, projected_query_step);
692 0 : if (!disable_bias) {
693 0 : projected_query_step.add_i(query_fc_bias);
694 : }
695 0 : key.dot(key_fc_weight, cache_key_step);
696 0 : if (!disable_bias) {
697 0 : cache_key_step.add_i(key_fc_bias);
698 : }
699 0 : value.dot(value_fc_weight, cache_value_step);
700 0 : if (!disable_bias) {
701 0 : cache_value_step.add_i(value_fc_bias);
702 : }
703 :
704 0 : projected_query_step.reshape(
705 0 : TensorDim({batch_size, 1, num_heads, projected_query_dim_prop}));
706 0 : cached_key.reshape(
707 0 : TensorDim({batch_size, to, num_heads, projected_key_dim_prop}));
708 0 : cached_value.reshape(
709 0 : TensorDim({batch_size, to, num_heads, projected_value_dim_prop}));
710 :
711 0 : projected_query_step.transpose("1:0:2", projected_query_step);
712 0 : cached_key.transpose("1:0:2", projected_key_step);
713 0 : cached_value.transpose("1:0:2", projected_value_step);
714 :
715 0 : projected_query_step.reshape(
716 0 : TensorDim({batch_size * num_heads, 1, 1, projected_query_dim_prop}));
717 0 : projected_key_step.reshape(
718 0 : TensorDim({batch_size * num_heads, 1, to, projected_key_dim_prop}));
719 0 : projected_value_step.reshape(
720 0 : TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
721 :
722 0 : attention_weight_step.reshape(TensorDim({batch_size * num_heads, 1, 1, to}));
723 0 : attention_output_step.reshape(
724 0 : TensorDim({batch_size * num_heads, 1, 1, projected_value_dim_prop}));
725 :
726 : /** scaled dot product attention */
727 0 : projected_query_step.dotBatched(projected_key_step, attention_weight_step,
728 : false, true);
729 0 : attention_weight_step.multiply_i(1 / sqrt((float)projected_query_dim_prop));
730 :
731 0 : if (!from) {
732 0 : unsigned int mask_size = attention_weight_step.getDim().width();
733 : unsigned int mask_dim_height = mask_size;
734 : unsigned int mask_dim_width = mask_size;
735 :
736 0 : Tensor causal_mask(TensorDim{1, 1, mask_size, mask_size,
737 0 : attention_weight_step.getTensorType()});
738 :
739 0 : causal_mask.setZero();
740 :
741 : #ifdef ENABLE_FP16
742 : #define _MASK_NUM -1e4
743 : #else
744 : #define _MASK_NUM -1e10
745 : #endif
746 :
747 0 : for (unsigned int i = 0; i < mask_dim_height; ++i) {
748 0 : for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
749 0 : causal_mask.setValue(0, 0, i, j, _MASK_NUM);
750 : }
751 : }
752 :
753 0 : attention_weight_step.add_i(causal_mask);
754 0 : }
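/* Causal mask: on the first step (from == 0) an upper-triangular matrix of
 * large negative values is added so position i cannot attend to j > i;
 * _MASK_NUM stays at -1e4 under FP16 to remain within half-precision range. */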
755 :
756 : sm.run_fn(attention_weight_step, attention_weight_step);
757 :
758 0 : attention_weight_step.dotBatched(projected_value_step, attention_output_step);
759 :
760 0 : attention_output_step.reshape(
761 0 : TensorDim({batch_size, num_heads, to - from, projected_value_dim_prop}));
762 :
763 0 : attention_output_step = attention_output_step.transpose("1:0:2");
764 :
765 0 : attention_output_step.reshape(TensorDim(
766 0 : {batch_size * (to - from), 1, 1, num_heads * projected_value_dim_prop}));
767 :
768 0 : attention_output_step.dot(fc_weight, output);
769 0 : if (!disable_bias) {
770 0 : output.add_i(fc_bias);
771 : }
772 0 : }
773 :
774 85 : void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
775 : const unsigned int num_heads =
776 85 : std::get<props::NumHeads>(multi_head_attention_props).get();
777 : const unsigned int projected_key_dim_prop =
778 85 : std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
779 : const unsigned int projected_value_dim_prop =
780 85 : std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
781 : const float dropout_rate =
782 85 : std::get<props::DropOutRate>(multi_head_attention_props).get();
783 : const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
784 85 : std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
785 : const bool average_attention_weight =
786 85 : std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
787 :
788 85 : const bool provide_attention_mask = context.getNumInputs() == 4;
789 : const unsigned int projected_query_dim_prop = projected_key_dim_prop;
790 :
791 85 : Tensor empty_tensor;
792 :
793 85 : Tensor &query = context.getInput(INOUT_INDEX::QUERY);
794 85 : Tensor &key = context.getInput(INOUT_INDEX::KEY);
795 85 : Tensor &value = context.getInput(INOUT_INDEX::VALUE);
796 : const Tensor &incoming_derivative =
797 85 : context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
798 : const Tensor &d_ret_attention_weight =
799 : return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
800 85 : ? context.getIncomingDerivative(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
801 85 : : empty_tensor;
802 :
803 85 : Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
804 :
805 : Tensor &projected_query =
806 85 : context.getTensor(weight_idx[AttentionParams::projected_query]);
807 : Tensor &d_projected_query =
808 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
809 : Tensor &projected_key =
810 85 : context.getTensor(weight_idx[AttentionParams::projected_key]);
811 : Tensor &d_projected_key =
812 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
813 : Tensor &projected_value =
814 85 : context.getTensor(weight_idx[AttentionParams::projected_value]);
815 : Tensor &d_projected_value =
816 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
817 :
818 : Tensor &attention_weight =
819 85 : context.getTensor(weight_idx[AttentionParams::attention_weight]);
820 : Tensor &d_attention_weight =
821 85 : context.getTensorGrad(weight_idx[AttentionParams::attention_weight]);
822 : Tensor &d_attention_output =
823 85 : context.getTensorGrad(weight_idx[AttentionParams::attention_output]);
824 :
825 85 : const TensorDim query_dim = query.getDim();
826 85 : const unsigned int batch_size = query_dim.batch();
827 85 : const unsigned int query_height = query_dim.height();
828 85 : const TensorDim key_dim = key.getDim();
829 85 : const unsigned int key_height = key_dim.height();
830 85 : const TensorDim value_dim = value.getDim();
831 85 : const unsigned int value_height = value_dim.height();
832 :
833 85 : d_attention_output.dot_deriv_wrt_1(fc_weight, incoming_derivative);
834 :
835 85 : d_attention_output.reshape(
836 85 : TensorDim({batch_size, query_height, num_heads, projected_value_dim_prop}));
837 :
838 170 : d_attention_output = d_attention_output.transpose("1:0:2");
839 :
840 : /** restore the original tensor name, which was dropped during the
841 : * transpose above */
842 170 : d_attention_output.setName("multi_head_attention:attention_output:grad");
843 :
844 85 : projected_query.reshape(TensorDim(
845 85 : {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
846 85 : d_projected_query.reshape(TensorDim(
847 : {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
848 85 : projected_key.reshape(
849 170 : TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
850 85 : d_projected_key.reshape(
851 170 : TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
852 85 : projected_value.reshape(TensorDim(
853 : {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
854 85 : d_projected_value.reshape(TensorDim(
855 : {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
856 :
857 85 : attention_weight.reshape(
858 170 : TensorDim({batch_size * num_heads, 1, query_height, key_height}));
859 85 : d_attention_weight.reshape(
860 170 : TensorDim({batch_size * num_heads, 1, query_height, key_height}));
861 85 : d_attention_output.reshape(TensorDim(
862 : {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
863 :
864 85 : d_attention_weight.dot_batched_deriv_wrt_1(projected_value,
865 : d_attention_output);
866 85 : attention_weight.dot_batched_deriv_wrt_2(d_projected_value,
867 : d_attention_output);
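/* Backprop through attention_output = attention_weight . projected_value:
 * the first call fills d_attention_weight, the second fills d_projected_value,
 * both driven by d_attention_output. */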
868 :
869 85 : if (return_attention_weight ==
870 : props::ReturnAttentionWeightInfo::Enum::after) {
871 12 : const float scale = average_attention_weight ? 1 / (float)num_heads : 1;
872 12 : d_attention_weight.add_i(d_ret_attention_weight, scale);
873 : }
874 :
875 85 : if (dropout_rate > epsilon) {
876 : Tensor &dropout_mask =
877 0 : context.getTensor(weight_idx[AttentionParams::dropout_mask]);
878 0 : d_attention_weight.multiply_i(dropout_mask);
879 : }
880 :
881 85 : if (return_attention_weight ==
882 : props::ReturnAttentionWeightInfo::Enum::before) {
883 1 : d_attention_weight.add_i(d_ret_attention_weight);
884 : }
885 :
886 85 : sm.run_prime_fn(attention_weight, d_attention_weight, d_attention_weight);
887 85 : if (provide_attention_mask) {
888 30 : Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
889 30 : d_mask.copyData(d_attention_weight);
890 : }
891 85 : d_attention_weight.multiply_i(
892 85 : 1 / sqrt((float)projected_query_dim_prop)); /** scale */
893 :
894 85 : d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_weight,
895 : false, true);
896 85 : projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_weight,
897 : false, true);
898 :
899 85 : d_projected_query.reshape(
900 170 : TensorDim({batch_size, num_heads, query_height, projected_query_dim_prop}));
901 85 : d_projected_key.reshape(
902 170 : TensorDim({batch_size, num_heads, key_height, projected_key_dim_prop}));
903 85 : d_projected_value.reshape(
904 85 : TensorDim({batch_size, num_heads, value_height, projected_value_dim_prop}));
905 :
906 170 : d_projected_query = d_projected_query.transpose("1:0:2");
907 170 : d_projected_key = d_projected_key.transpose("1:0:2");
908 170 : d_projected_value = d_projected_value.transpose("1:0:2");
909 :
910 : /** restore the original tensor names, which were dropped during the
911 : * transposes above */
912 85 : d_projected_query.setName("multi_head_attention:projected_query:grad");
913 85 : d_projected_key.setName("multi_head_attention:projected_key:grad");
914 170 : d_projected_value.setName("multi_head_attention:projected_value:grad");
915 :
916 : /** restore shape */
917 85 : projected_query.reshape(TensorDim(
918 85 : {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
919 85 : d_projected_query.reshape(TensorDim(
920 85 : {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
921 85 : projected_key.reshape(
922 170 : TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
923 85 : d_projected_key.reshape(TensorDim(
924 85 : {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
925 85 : projected_value.reshape(TensorDim(
926 85 : {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
927 85 : d_projected_value.reshape(TensorDim(
928 85 : {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
929 :
930 85 : attention_weight.reshape(
931 170 : TensorDim({batch_size, num_heads, query_height, key_height}));
932 85 : d_attention_weight.reshape(
933 170 : TensorDim({batch_size, num_heads, query_height, key_height}));
934 85 : d_attention_output.reshape(TensorDim(
935 : {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
936 85 : }
937 :
938 85 : void MultiHeadAttentionLayer::calcDerivative(RunLayerContext &context) {
939 85 : if (!context.getTrainable()) {
940 0 : calcCommonDerivative(context);
941 : }
942 :
943 85 : Tensor &query = context.getInput(INOUT_INDEX::QUERY);
944 85 : Tensor &d_query = context.getOutgoingDerivative(INOUT_INDEX::QUERY);
945 85 : Tensor &key = context.getInput(INOUT_INDEX::KEY);
946 85 : Tensor &d_key = context.getOutgoingDerivative(INOUT_INDEX::KEY);
947 85 : Tensor &value = context.getInput(INOUT_INDEX::VALUE);
948 85 : Tensor &d_value = context.getOutgoingDerivative(INOUT_INDEX::VALUE);
949 : /** d_mask will be calculated in calcCommonDerivative */
950 :
951 : Tensor &query_fc_weight =
952 85 : context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
953 : Tensor &key_fc_weight =
954 85 : context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
955 : Tensor &value_fc_weight =
956 85 : context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
957 :
958 : Tensor &d_projected_query =
959 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
960 : Tensor &d_projected_key =
961 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
962 : Tensor &d_projected_value =
963 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
964 :
965 85 : const TensorDim query_dim = query.getDim();
966 85 : const TensorDim key_dim = key.getDim();
967 85 : const TensorDim value_dim = value.getDim();
968 :
969 85 : d_query.dot_deriv_wrt_1(query_fc_weight, d_projected_query);
970 85 : d_key.dot_deriv_wrt_1(key_fc_weight, d_projected_key);
971 85 : d_value.dot_deriv_wrt_1(value_fc_weight, d_projected_value, false, false);
972 85 : }
973 :
974 85 : void MultiHeadAttentionLayer::calcGradient(RunLayerContext &context) {
975 85 : calcCommonDerivative(context);
976 :
977 : const bool disable_bias =
978 85 : std::get<props::DisableBias>(*layer_impl_props).get();
979 :
980 : const unsigned int num_heads =
981 85 : std::get<props::NumHeads>(multi_head_attention_props).get();
982 : const unsigned int projected_key_dim_prop =
983 85 : std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
984 : const unsigned int projected_value_dim_prop =
985 85 : std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
986 : const unsigned int output_shape =
987 85 : std::get<props::OutputShape>(multi_head_attention_props).get();
988 :
989 : const unsigned int projected_query_dim_prop = projected_key_dim_prop;
990 :
991 85 : Tensor &query = context.getInput(INOUT_INDEX::QUERY);
992 85 : Tensor &key = context.getInput(INOUT_INDEX::KEY);
993 85 : Tensor &value = context.getInput(INOUT_INDEX::VALUE);
994 : const Tensor &incoming_derivative =
995 85 : context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
996 :
997 : Tensor &d_query_fc_weight =
998 85 : context.getWeightGrad(weight_idx[AttentionParams::query_fc_weight]);
999 : Tensor &d_key_fc_weight =
1000 85 : context.getWeightGrad(weight_idx[AttentionParams::key_fc_weight]);
1001 : Tensor &d_value_fc_weight =
1002 85 : context.getWeightGrad(weight_idx[AttentionParams::value_fc_weight]);
1003 : Tensor &d_fc_weight =
1004 85 : context.getWeightGrad(weight_idx[AttentionParams::fc_weight]);
1005 :
1006 85 : Tensor empty_tensor;
1007 : Tensor &d_query_fc_bias =
1008 : disable_bias
1009 85 : ? empty_tensor
1010 82 : : context.getWeightGrad(weight_idx[AttentionParams::query_fc_bias]);
1011 : Tensor &d_key_fc_bias =
1012 : disable_bias
1013 : ? empty_tensor
1014 82 : : context.getWeightGrad(weight_idx[AttentionParams::key_fc_bias]);
1015 : Tensor &d_value_fc_bias =
1016 : disable_bias
1017 85 : ? empty_tensor
1018 82 : : context.getWeightGrad(weight_idx[AttentionParams::value_fc_bias]);
1019 : Tensor &d_fc_bias =
1020 : disable_bias ? empty_tensor
1021 82 : : context.getWeightGrad(weight_idx[AttentionParams::fc_bias]);
1022 :
1023 : Tensor &d_projected_query =
1024 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
1025 : Tensor &d_projected_key =
1026 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
1027 : Tensor &d_projected_value =
1028 85 : context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
1029 :
1030 : Tensor &attention_output =
1031 85 : context.getTensor(weight_idx[AttentionParams::attention_output]);
1032 :
1033 85 : const TensorDim query_dim = query.getDim();
1034 85 : const unsigned int batch_size = query_dim.batch();
1035 85 : const unsigned int query_height = query_dim.height();
1036 85 : const TensorDim key_dim = key.getDim();
1037 85 : const unsigned int key_height = key_dim.height();
1038 85 : const TensorDim value_dim = value.getDim();
1039 85 : const unsigned int value_height = value_dim.height();
1040 :
1041 85 : attention_output.dot_deriv_wrt_2(
1042 : d_fc_weight, incoming_derivative, false, false,
1043 85 : !context.isGradientFirstAccess(weight_idx[AttentionParams::fc_weight]));
1044 :
1045 85 : if (!disable_bias) {
1046 82 : Tensor incoming_derivative_ = incoming_derivative;
1047 82 : incoming_derivative_.reshape(
1048 82 : TensorDim({batch_size * query_height, 1, 1, output_shape}));
1049 82 : incoming_derivative_.sum(
1050 : 0, d_fc_bias, 1,
1051 82 : !context.isGradientFirstAccess(weight_idx[AttentionParams::fc_bias]));
1052 82 : }
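/* Bias gradients are the corresponding output derivative summed over every
 * (batch, height) position; the same reshape-sum pattern repeats below for the
 * query/key/value projection biases. */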
1053 :
1054 85 : query.dot_deriv_wrt_2(d_query_fc_weight, d_projected_query, false, false,
1055 85 : !context.isGradientFirstAccess(
1056 : weight_idx[AttentionParams::query_fc_weight]));
1057 85 : if (!disable_bias) {
1058 82 : d_projected_query.reshape(TensorDim(
1059 82 : {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
1060 82 : d_projected_query.sum(0, d_query_fc_bias, 1,
1061 82 : !context.isGradientFirstAccess(
1062 : weight_idx[AttentionParams::query_fc_bias]));
1063 82 : d_projected_query.reshape(TensorDim(
1064 : {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
1065 : }
1066 :
1067 85 : key.dot_deriv_wrt_2(
1068 : d_key_fc_weight, d_projected_key, false, false,
1069 85 : !context.isGradientFirstAccess(weight_idx[AttentionParams::key_fc_weight]));
1070 85 : if (!disable_bias) {
1071 82 : d_projected_key.reshape(TensorDim(
1072 82 : {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
1073 82 : d_projected_key.sum(
1074 : 0, d_key_fc_bias, 1,
1075 82 : !context.isGradientFirstAccess(weight_idx[AttentionParams::key_fc_bias]));
1076 82 : d_projected_key.reshape(TensorDim(
1077 : {batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
1078 : }
1079 :
1080 85 : value.dot_deriv_wrt_2(d_value_fc_weight, d_projected_value, false, false,
1081 85 : !context.isGradientFirstAccess(
1082 : weight_idx[AttentionParams::value_fc_weight]));
1083 85 : if (!disable_bias) {
1084 82 : d_projected_value.reshape(TensorDim(
1085 82 : {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
1086 82 : d_projected_value.sum(0, d_value_fc_bias, 1,
1087 82 : !context.isGradientFirstAccess(
1088 : weight_idx[AttentionParams::value_fc_bias]));
1089 82 : d_projected_value.reshape(TensorDim(
1090 : {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
1091 : }
1092 85 : }
1093 :
1094 665 : void MultiHeadAttentionLayer::setProperty(
1095 : const std::vector<std::string> &values) {
1096 665 : auto remain_props = loadProperties(values, multi_head_attention_props);
1097 663 : LayerImpl::setProperty(remain_props);
1098 663 : }
1099 :
1100 108 : void MultiHeadAttentionLayer::setBatch(RunLayerContext &context,
1101 : unsigned int batch) {
1102 : const float dropout_rate =
1103 108 : std::get<props::DropOutRate>(multi_head_attention_props).get();
1104 :
1105 108 : context.updateTensor(weight_idx[AttentionParams::projected_query], batch);
1106 108 : context.updateTensor(weight_idx[AttentionParams::projected_key], batch);
1107 108 : context.updateTensor(weight_idx[AttentionParams::projected_value], batch);
1108 108 : context.updateTensor(weight_idx[AttentionParams::cache_key], batch);
1109 108 : context.updateTensor(weight_idx[AttentionParams::cache_value], batch);
1111 108 : context.updateTensor(weight_idx[AttentionParams::attention_weight], batch);
1112 108 : if (dropout_rate > epsilon) {
1113 0 : context.updateTensor(weight_idx[AttentionParams::dropout_mask], batch);
1114 : }
1115 108 : context.updateTensor(weight_idx[AttentionParams::attention_output], batch);
1116 108 : }
1117 :
1118 54 : void MultiHeadAttentionLayer::exportTo(
1119 : Exporter &exporter, const ml::train::ExportMethods &method) const {
1120 54 : LayerImpl::exportTo(exporter, method);
1121 54 : exporter.saveResult(multi_head_attention_props, method, this);
1122 54 : }
1123 :
1124 : } /* namespace nntrainer */