LCOV - code coverage report
Current view: top level - nntrainer/layers - multi_head_attention_layer.cpp (source / functions)
Test:         coverage_filtered.info
Test Date:    2025-12-14 20:38:17

                          Coverage      Total      Hit
Lines:                      73.7 %        552      407
Functions:                  78.6 %         14       11

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2022 hyeonseok Lee <hs89.lee@samsung.com>
       4              :  *
       5              :  * @file   multi_head_attention_layer.cpp
       6              :  * @date   08 July 2022
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  *         https://arxiv.org/abs/1706.03762
       9              :  * @author hyeonseok Lee <hs89.lee@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  * @brief  This is MultiHeadAttention Layer Class for Neural Network
      12              :  *
      13              :  */
      14              : 
      15              : #include <cmath>
      16              : 
      17              : #include <layer_context.h>
      18              : #include <multi_head_attention_layer.h>
      19              : #include <nntrainer_error.h>
      20              : #include <nntrainer_log.h>
      21              : #include <node_exporter.h>
      22              : 
      23              : namespace nntrainer {
      24              : 
      25          149 : MultiHeadAttentionLayer::MultiHeadAttentionLayer() :
      26              :   multi_head_attention_props(
      27          298 :     props::NumHeads(), props::ProjectedKeyDim(), props::ProjectedValueDim(),
      28          447 :     props::OutputShape(), props::DropOutRate(), props::ReturnAttentionWeight(),
      29            0 :     props::AverageAttentionWeight()),
      30          149 :   sm(ActivationType::ACT_SOFTMAX),
      31          298 :   epsilon(1e-3f) {
      32              :   weight_idx.fill(std::numeric_limits<unsigned>::max());
      33          149 : }
      34              : 
      35          298 : MultiHeadAttentionLayer::~MultiHeadAttentionLayer() {}
      36              : 
      37              : enum INOUT_INDEX {
      38              :   /** input index */
      39              :   QUERY = 0,
      40              :   KEY = 1,
      41              :   VALUE = 2,
      42              :   MASK = 3,
      43              :   /** output index */
      44              :   OUTPUT = 0,
      45              :   RETURN_ATTENTION_WEIGHT = 1,
      46              : };
      47              : 
      48              : enum AttentionParams {
      49              :   query_fc_weight,
      50              :   query_fc_bias,
      51              :   key_fc_weight,
      52              :   key_fc_bias,
      53              :   value_fc_weight,
      54              :   value_fc_bias,
      55              :   fc_weight,
      56              :   fc_bias,
      57              :   projected_query,
      58              :   projected_key,
      59              :   projected_value,
      60              :   cache_key,
      61              :   cache_value,
       62              :   /** intentionally commented out; reserved for later use of attention_mask */
      63              :   // attention_mask,
      64              :   attention_weight,
      65              :   dropout_mask,
      66              :   attention_output,
      67              : };
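
The constructor above pre-fills weight_idx with an unsigned-max sentinel, and finalize() later replaces each slot with the index returned by requestWeight()/requestTensor(). A trimmed, self-contained sketch of that bookkeeping (the enum is abbreviated and PARAMS_COUNT is a hypothetical sentinel, not part of the real enum):

// Illustrative sketch only; not part of the instrumented source above.
#include <array>
#include <limits>

enum AttentionParams {
  query_fc_weight,
  key_fc_weight,
  value_fc_weight,
  fc_weight,
  attention_weight,
  PARAMS_COUNT /* hypothetical count sentinel, not in the real enum */
};

int main() {
  std::array<unsigned, PARAMS_COUNT> weight_idx;
  // Sentinel meaning "not requested yet", as done in the constructor above.
  weight_idx.fill(std::numeric_limits<unsigned>::max());
  // finalize() later overwrites each slot with the index handed back by the
  // context, e.g. weight_idx[query_fc_weight] = context.requestWeight(...);
  return 0;
}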
      68              : 
      69          121 : void MultiHeadAttentionLayer::finalize(InitLayerContext &context) {
      70          121 :   NNTR_THROW_IF(context.getNumInputs() < 3 || context.getNumInputs() > 4,
       71              :                 std::invalid_argument)
       72              :     << "Multi head Attention layer needs 3 or 4 inputs (query, key, value, "
       73              :        "and optionally a mask).";
      74              :   const bool provide_attention_mask = context.getNumInputs() == 4;
      75              : 
      76              :   TensorDim::TensorType weight_type = {context.getFormat(),
      77              :                                        context.getWeightDataType()};
      78              : 
      79              :   TensorDim::TensorType activation_type = {context.getFormat(),
      80              :                                            context.getActivationDataType()};
      81              : 
      82          121 :   TensorDim empty_dim(activation_type);
      83              : 
      84              :   const std::vector<TensorDim> &input_dims = context.getInputDimensions();
      85              :   const TensorDim &query_dim = input_dims[INOUT_INDEX::QUERY];
      86              :   const TensorDim &key_dim = input_dims[INOUT_INDEX::KEY];
      87              :   const TensorDim &value_dim = input_dims[INOUT_INDEX::VALUE];
      88              :   const TensorDim &mask_dim =
      89          121 :     provide_attention_mask ? input_dims[INOUT_INDEX::MASK] : empty_dim;
      90              : 
      91          121 :   const unsigned int batch_size = query_dim.batch();
      92          121 :   const unsigned int query_height = query_dim.height();
      93          121 :   const unsigned int query_width = query_dim.width();
      94          121 :   const unsigned int key_height = key_dim.height();
      95          121 :   const unsigned int key_width = key_dim.width();
      96          121 :   const unsigned int value_height = value_dim.height();
      97          121 :   const unsigned int value_width = value_dim.width();
      98              : 
      99              :   const bool disable_bias =
     100          121 :     std::get<props::DisableBias>(*layer_impl_props).get();
     101              :   auto &weight_initializer =
     102          121 :     std::get<props::WeightInitializer>(*layer_impl_props).get();
     103              :   auto &weight_regularizer =
     104              :     std::get<props::WeightRegularizer>(*layer_impl_props);
     105              :   auto &weight_regularizer_constant =
     106              :     std::get<props::WeightRegularizerConstant>(*layer_impl_props);
     107              :   const float &weight_decay =
     108          121 :     std::get<props::WeightDecay>(*layer_impl_props).get();
     109              : 
     110          121 :   NNTR_THROW_IF(std::get<props::NumHeads>(multi_head_attention_props).empty(),
     111              :                 std::invalid_argument)
     112              :     << "num_heads property is not provided for layer " << context.getName();
     113              :   const unsigned int num_heads =
     114          121 :     std::get<props::NumHeads>(multi_head_attention_props).get();
     115              : 
     116          121 :   if (std::get<props::ProjectedKeyDim>(multi_head_attention_props).empty()) {
     117           54 :     NNTR_THROW_IF(query_width % num_heads, std::invalid_argument)
     118              :       << "query_width: " << query_width
     119              :       << " is not divisible by num_heads: " << num_heads << " for layer "
     120              :       << context.getName();
     121              : 
     122           54 :     ml_logw("[multi head attention] ProjectedKeyDim property is not given. "
      123              :             "Default value (query_width / num_heads) is set");
     124              : 
     125              :     std::get<props::ProjectedKeyDim>(multi_head_attention_props)
     126           54 :       .set(query_width / num_heads);
     127              :   }
     128              :   const unsigned int projected_key_dim_prop =
     129          121 :     std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
     130              : 
     131          121 :   if (std::get<props::ProjectedValueDim>(multi_head_attention_props).empty()) {
     132              :     std::get<props::ProjectedValueDim>(multi_head_attention_props)
     133           66 :       .set(projected_key_dim_prop);
     134              :   }
     135              :   const unsigned int projected_value_dim_prop =
     136          121 :     std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
     137              : 
     138          121 :   if (std::get<props::OutputShape>(multi_head_attention_props).empty()) {
     139           66 :     std::get<props::OutputShape>(multi_head_attention_props).set(query_width);
     140              :   }
     141              :   const unsigned int output_shape =
     142          121 :     std::get<props::OutputShape>(multi_head_attention_props).get();
     143              : 
     144              :   const float dropout_rate =
     145          121 :     std::get<props::DropOutRate>(multi_head_attention_props).get();
     146              : 
     147          121 :   if (std::get<props::AverageAttentionWeight>(multi_head_attention_props)
     148              :         .empty()) {
     149              :     std::get<props::AverageAttentionWeight>(multi_head_attention_props)
     150           66 :       .set(true);
     151              :   }
     152              :   const bool average_attention_weight =
     153          121 :     std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
     154              : 
     155              :   const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
     156          121 :     std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
     157              : 
     158          121 :   const unsigned int projected_query_dim_prop = projected_key_dim_prop;
     159              : 
     160          121 :   if (activation_type.data_type == TensorDim::DataType::FP32) {
     161          121 :     sm.setActiFunc(ActivationType::ACT_SOFTMAX);
     162            0 :   } else if (activation_type.data_type == TensorDim::DataType::FP16) {
     163              : #ifdef ENABLE_FP16
     164              :     sm.setActiFunc<_FP16>(ActivationType::ACT_SOFTMAX);
     165              : #else
     166            0 :     throw std::invalid_argument("Error: enable-fp16 is not enabled");
     167              : #endif
     168              :   }
     169              : 
     170              :   // sm.setActiFunc(ActivationType::ACT_SOFTMAX);
     171              : 
     172          121 :   NNTR_THROW_IF(query_dim.channel() != 1, std::invalid_argument)
     173            0 :     << "Dimension of input query channel: " << query_dim.channel()
     174              :     << " is not 1 for layer " << context.getName();
     175          121 :   NNTR_THROW_IF(key_dim.channel() != 1, std::invalid_argument)
     176            0 :     << "Dimension of input key channel: " << key_dim.channel()
     177              :     << " is not 1 for layer " << context.getName();
     178          121 :   NNTR_THROW_IF(value_dim.channel() != 1, std::invalid_argument)
     179            0 :     << "Dimension of input value channel: " << value_dim.channel()
     180              :     << " is not 1 for layer " << context.getName();
     181          121 :   NNTR_THROW_IF(provide_attention_mask && mask_dim.channel() != num_heads,
     182              :                 std::invalid_argument)
     183            0 :     << "Dimension of input mask channel: " << mask_dim.channel()
     184              :     << " is not matched with num_heads property: " << num_heads << " for layer "
     185              :     << context.getName();
     186              : 
     187          121 :   NNTR_THROW_IF(key_height != value_height, std::invalid_argument)
     188              :     << "Dimension of input key height: " << key_height
     189              :     << " is not matched with Dimension of input value height: " << value_height
     190              :     << " for layer " << context.getName();
     191          121 :   NNTR_THROW_IF(provide_attention_mask && mask_dim.height() != query_height,
     192              :                 std::invalid_argument)
     193            0 :     << "Dimension of input mask height: " << mask_dim.height()
     194              :     << " is not matched with Dimension of input query height: " << query_height
     195              :     << " for layer " << context.getName();
     196              : 
     197          121 :   NNTR_THROW_IF(provide_attention_mask && mask_dim.width() != key_height,
     198              :                 std::invalid_argument)
     199            0 :     << "Dimension of input mask width: " << mask_dim.width()
     200              :     << " is not matched with Dimension of input key height: " << key_height
     201              :     << " for layer " << context.getName();
     202              : 
     203              :   /** weight/bias for query fc */
     204              :   TensorDim query_fc_weight_dim(
     205          121 :     {1, 1, query_width, num_heads * projected_query_dim_prop}, weight_type);
     206              : 
     207          121 :   weight_idx[AttentionParams::query_fc_weight] = context.requestWeight(
     208              :     query_fc_weight_dim, weight_initializer, weight_regularizer,
     209              :     weight_regularizer_constant, weight_decay, "query_fc_weight", true);
     210          121 :   if (!disable_bias) {
     211              :     TensorDim query_fc_bias_dim({1, 1, 1, num_heads * projected_query_dim_prop},
     212          117 :                                 weight_type);
     213          117 :     weight_idx[AttentionParams::query_fc_bias] = context.requestWeight(
     214              :       query_fc_bias_dim, weight_initializer, weight_regularizer,
     215              :       weight_regularizer_constant, weight_decay, "query_fc_bias", true);
     216              :   }
     217              : 
     218              :   /** weight/bias for key fc */
     219              :   TensorDim key_fc_weight_dim(
     220          121 :     {1, 1, key_width, num_heads * projected_key_dim_prop}, weight_type);
     221          121 :   weight_idx[AttentionParams::key_fc_weight] = context.requestWeight(
     222              :     key_fc_weight_dim, weight_initializer, weight_regularizer,
     223              :     weight_regularizer_constant, weight_decay, "key_fc_weight", true);
     224          121 :   if (!disable_bias) {
     225          117 :     TensorDim key_fc_bias_dim({1, 1, 1, num_heads * projected_key_dim_prop},
     226          117 :                               weight_type);
     227          117 :     weight_idx[AttentionParams::key_fc_bias] = context.requestWeight(
     228              :       key_fc_bias_dim, weight_initializer, weight_regularizer,
     229              :       weight_regularizer_constant, weight_decay, "key_fc_bias", true);
     230              :   }
     231              : 
     232              :   /** weight/bias for value fc */
     233              :   TensorDim value_fc_weight_dim(
     234          121 :     {1, 1, value_width, num_heads * projected_value_dim_prop}, weight_type);
     235          121 :   weight_idx[AttentionParams::value_fc_weight] = context.requestWeight(
     236              :     value_fc_weight_dim, weight_initializer, weight_regularizer,
     237              :     weight_regularizer_constant, weight_decay, "value_fc_weight", true);
     238          121 :   if (!disable_bias) {
     239              :     TensorDim value_fc_bias_dim({1, 1, 1, num_heads * projected_value_dim_prop},
     240          117 :                                 weight_type);
     241          117 :     weight_idx[AttentionParams::value_fc_bias] = context.requestWeight(
     242              :       value_fc_bias_dim, weight_initializer, weight_regularizer,
     243              :       weight_regularizer_constant, weight_decay, "value_fc_bias", true);
     244              :   }
     245              : 
     246              :   /** weight/bias for out fc */
     247              :   TensorDim fc_weight_dim(
     248          121 :     {1, 1, num_heads * projected_value_dim_prop, output_shape}, weight_type);
     249          121 :   weight_idx[AttentionParams::fc_weight] = context.requestWeight(
     250              :     fc_weight_dim, weight_initializer, weight_regularizer,
     251              :     weight_regularizer_constant, weight_decay, "fc_weight", true);
     252          121 :   if (!disable_bias) {
     253          117 :     TensorDim fc_bias_dim({1, 1, 1, output_shape}, weight_type);
     254          117 :     weight_idx[AttentionParams::fc_bias] = context.requestWeight(
     255              :       fc_bias_dim, weight_initializer, weight_regularizer,
     256              :       weight_regularizer_constant, weight_decay, "fc_bias", true);
     257              :   }
     258              : 
     259              :   /** tensor for output of query fc */
     260              :   TensorDim projected_query_dim(
     261              :     {batch_size, 1, query_height, num_heads * projected_query_dim_prop},
     262          121 :     activation_type);
     263          121 :   weight_idx[AttentionParams::projected_query] = context.requestTensor(
     264              :     projected_query_dim, "projected_query", Initializer::NONE, true,
     265              :     TensorLifespan::ITERATION_LIFESPAN);
     266              :   /** tensor for output of key fc */
     267              :   TensorDim projected_key_dim(
     268          121 :     {batch_size, 1, key_height, num_heads * projected_key_dim_prop},
     269          121 :     activation_type);
     270          121 :   weight_idx[AttentionParams::projected_key] =
     271          121 :     context.requestTensor(projected_key_dim, "projected_key", Initializer::NONE,
     272              :                           true, TensorLifespan::ITERATION_LIFESPAN);
     273              :   /** tensor for output of value fc */
     274              :   TensorDim projected_value_dim(
     275              :     {batch_size, 1, value_height, num_heads * projected_value_dim_prop},
     276          121 :     activation_type);
     277          121 :   weight_idx[AttentionParams::projected_value] = context.requestTensor(
     278              :     projected_value_dim, "projected_value", Initializer::NONE, true,
     279              :     TensorLifespan::ITERATION_LIFESPAN);
     280              : 
     281          121 :   weight_idx[AttentionParams::cache_key] =
     282          121 :     context.requestTensor(projected_key_dim, "cache_key", Initializer::NONE,
     283              :                           true, TensorLifespan::MAX_LIFESPAN);
     284              : 
     285          121 :   weight_idx[AttentionParams::cache_value] =
     286          121 :     context.requestTensor(projected_value_dim, "cache_value", Initializer::NONE,
     287              :                           true, TensorLifespan::MAX_LIFESPAN);
     288              : 
     289              :   if (provide_attention_mask) {
      290              :     /** intentionally commented out; placeholder for a bool-type mask */
     291              :     // TensorDim attention_mask_dim(
     292              :     //   {batch_size, num_heads, query_height, key_height});
     293              :     // weight_idx[AttentionParams::attention_mask] = context.requestTensor(
     294              :     //   attention_mask_dim, "attention_mask", Initializer::NONE, false,
     295              :     //   TensorLifespan::FORWARD_FUNC_LIFESPAN);
     296              :   }
     297              :   /** tensor for attention weight */
     298              :   TensorDim attention_weight_dim(
     299          121 :     {batch_size, num_heads, query_height, key_height}, activation_type);
     300          121 :   weight_idx[AttentionParams::attention_weight] = context.requestTensor(
     301              :     attention_weight_dim, "attention_weight", Initializer::NONE, true,
     302              :     TensorLifespan::ITERATION_LIFESPAN);
     303          121 :   if (dropout_rate > epsilon) {
     304              :     /** tensor for dropout mask */
     305              :     TensorDim dropout_mask_dim(
     306            0 :       {batch_size, num_heads, query_height, key_height}, activation_type);
     307            0 :     weight_idx[AttentionParams::dropout_mask] =
     308            0 :       context.requestTensor(dropout_mask_dim, "dropout_mask", Initializer::NONE,
     309              :                             false, TensorLifespan::ITERATION_LIFESPAN);
     310              :   }
     311              : 
     312              :   /** tensor for attention output */
     313              :   TensorDim attention_output_dim(
     314              :     {batch_size, 1, query_height, num_heads * projected_value_dim_prop},
     315          121 :     activation_type);
     316          121 :   weight_idx[AttentionParams::attention_output] = context.requestTensor(
     317              :     attention_output_dim, "attention_output", Initializer::NONE, true,
     318              :     TensorLifespan::ITERATION_LIFESPAN);
     319              : 
     320              :   TensorDim output_dim({batch_size, 1, query_height, output_shape},
     321          121 :                        activation_type);
     322          121 :   if (return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none) {
     323              :     TensorDim return_attention_weight_dim(
     324           17 :       {batch_size, average_attention_weight ? 1 : num_heads, query_height,
     325              :        key_height},
     326           18 :       activation_type);
     327           17 :     context.setOutputDimensions({output_dim, return_attention_weight_dim});
     328              :   } else {
     329          104 :     context.setOutputDimensions({output_dim});
     330              :   }
     331          121 : }
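
For reference, a minimal standalone sketch of the default shape bookkeeping finalize() performs above: when ProjectedKeyDim is omitted it falls back to query_width / num_heads, ProjectedValueDim falls back to ProjectedKeyDim, OutputShape falls back to query_width, and the fully connected weights take the {1, 1, in, out} shapes requested above. This is illustrative only (not part of the instrumented source) and the concrete sizes are made-up example values.

// Illustrative sketch only; assumes query_width is divisible by num_heads.
#include <array>
#include <cstdio>

int main() {
  const unsigned num_heads = 4;
  const unsigned query_width = 64; // example input widths
  const unsigned key_width = 64;
  const unsigned value_width = 64;

  // Default when ProjectedKeyDim is not given: query_width / num_heads.
  const unsigned projected_key_dim = query_width / num_heads;
  // ProjectedValueDim defaults to ProjectedKeyDim; the query projection
  // reuses the same size as the key projection.
  const unsigned projected_value_dim = projected_key_dim;
  const unsigned projected_query_dim = projected_key_dim;
  // OutputShape defaults to query_width.
  const unsigned output_shape = query_width;

  // Fully connected weight shapes requested in finalize(), as {b, c, h, w}.
  std::array<unsigned, 4> query_fc_w = {1, 1, query_width,
                                        num_heads * projected_query_dim};
  std::array<unsigned, 4> key_fc_w = {1, 1, key_width,
                                      num_heads * projected_key_dim};
  std::array<unsigned, 4> value_fc_w = {1, 1, value_width,
                                        num_heads * projected_value_dim};
  std::array<unsigned, 4> out_fc_w = {1, 1, num_heads * projected_value_dim,
                                      output_shape};

  std::printf("query_fc_weight: {%u, %u, %u, %u}\n", query_fc_w[0],
              query_fc_w[1], query_fc_w[2], query_fc_w[3]);
  std::printf("key_fc_weight:   {%u, %u, %u, %u}\n", key_fc_w[0], key_fc_w[1],
              key_fc_w[2], key_fc_w[3]);
  std::printf("value_fc_weight: {%u, %u, %u, %u}\n", value_fc_w[0],
              value_fc_w[1], value_fc_w[2], value_fc_w[3]);
  std::printf("fc_weight:       {%u, %u, %u, %u}\n", out_fc_w[0], out_fc_w[1],
              out_fc_w[2], out_fc_w[3]);
  return 0;
}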
     332              : 
     333          186 : void MultiHeadAttentionLayer::forwarding(RunLayerContext &context,
     334              :                                          bool training) {
     335              :   const bool disable_bias =
     336          186 :     std::get<props::DisableBias>(*layer_impl_props).get();
     337              : 
     338              :   const unsigned int num_heads =
     339          186 :     std::get<props::NumHeads>(multi_head_attention_props).get();
     340              :   const unsigned int projected_key_dim_prop =
     341          186 :     std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
     342              :   const unsigned int projected_value_dim_prop =
     343          186 :     std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
     344              :   const float dropout_rate =
     345          186 :     std::get<props::DropOutRate>(multi_head_attention_props).get();
     346              :   const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
     347          186 :     std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
     348              :   const bool average_attention_weight =
     349          186 :     std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
     350              : 
     351          186 :   const bool provide_attention_mask = context.getNumInputs() == 4;
     352              :   const unsigned int projected_query_dim_prop = projected_key_dim_prop;
     353          186 :   const bool enable_dropout = dropout_rate > epsilon;
     354              : 
     355          186 :   Tensor empty_tensor;
     356              : 
     357              :   /** get inputs/outputs */
     358          186 :   Tensor &query = context.getInput(INOUT_INDEX::QUERY);
     359          186 :   Tensor &key = context.getInput(INOUT_INDEX::KEY);
     360          186 :   Tensor &value = context.getInput(INOUT_INDEX::VALUE);
     361              :   Tensor &mask =
     362          186 :     provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
     363              : 
     364          186 :   Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
     365              :   Tensor &ret_attention_weight =
     366              :     return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
     367          186 :       ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
     368              :       : empty_tensor;
     369              : 
     370              :   /** get weights */
     371              :   Tensor &query_fc_weight =
     372          186 :     context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
     373              :   Tensor &query_fc_bias =
     374              :     disable_bias
     375          186 :       ? empty_tensor
     376          180 :       : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
     377              :   Tensor &key_fc_weight =
     378          186 :     context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
     379              :   Tensor &key_fc_bias =
     380          186 :     disable_bias ? empty_tensor
     381          180 :                  : context.getWeight(weight_idx[AttentionParams::key_fc_bias]);
     382              :   Tensor &value_fc_weight =
     383          186 :     context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
     384              :   Tensor &value_fc_bias =
     385              :     disable_bias
     386          186 :       ? empty_tensor
     387          180 :       : context.getWeight(weight_idx[AttentionParams::value_fc_bias]);
     388          186 :   Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
     389              :   Tensor &fc_bias = disable_bias
     390          186 :                       ? empty_tensor
     391          180 :                       : context.getWeight(weight_idx[AttentionParams::fc_bias]);
     392              : 
     393              :   /** get tensors */
     394              :   Tensor &projected_query =
     395          186 :     context.getTensor(weight_idx[AttentionParams::projected_query]);
     396              :   Tensor &projected_key =
     397          186 :     context.getTensor(weight_idx[AttentionParams::projected_key]);
     398              :   Tensor &projected_value =
     399          186 :     context.getTensor(weight_idx[AttentionParams::projected_value]);
     400              : 
     401              :   Tensor &attention_weight =
     402          186 :     context.getTensor(weight_idx[AttentionParams::attention_weight]);
     403              :   Tensor &attention_output =
     404          186 :     context.getTensor(weight_idx[AttentionParams::attention_output]);
     405              : 
     406          186 :   const TensorDim query_dim = query.getDim();
     407          186 :   const unsigned int batch_size = query_dim.batch();
     408          186 :   const unsigned int query_height = query_dim.height();
     409          186 :   const TensorDim key_dim = key.getDim();
     410          186 :   const unsigned int key_height = key_dim.height();
     411          186 :   const TensorDim value_dim = value.getDim();
     412          186 :   const unsigned int value_height = value_dim.height();
     413              : 
     414          186 :   query.dot(query_fc_weight, projected_query);
     415          186 :   if (!disable_bias) {
     416          180 :     projected_query.add_i(query_fc_bias);
     417              :   }
     418          186 :   key.dot(key_fc_weight, projected_key);
     419          186 :   if (!disable_bias) {
     420          180 :     projected_key.add_i(key_fc_bias);
     421              :   }
     422          186 :   value.dot(value_fc_weight, projected_value);
     423          186 :   if (!disable_bias) {
     424          180 :     projected_value.add_i(value_fc_bias);
     425              :   }
     426              : 
     427          186 :   projected_query.reshape(
     428          372 :     TensorDim({batch_size, query_height, num_heads, projected_query_dim_prop}));
     429          186 :   projected_key.reshape(
     430          372 :     TensorDim({batch_size, key_height, num_heads, projected_key_dim_prop}));
     431          186 :   projected_value.reshape(
     432          186 :     TensorDim({batch_size, value_height, num_heads, projected_value_dim_prop}));
     433              : 
     434          372 :   projected_query = projected_query.transpose("1:0:2");
     435          372 :   projected_key = projected_key.transpose("1:0:2");
     436          372 :   projected_value = projected_value.transpose("1:0:2");
     437              : 
      438              :   /** restore the original tensor names, which were removed during the
      439              :    * transpose */
     440          186 :   projected_query.setName("multi_head_attention:projected_query");
     441          186 :   projected_key.setName("multi_head_attention:projected_key");
     442          372 :   projected_value.setName("multi_head_attention:projected_value");
     443              : 
     444          186 :   projected_query.reshape(TensorDim(
     445          186 :     {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
     446          186 :   projected_key.reshape(
     447          372 :     TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
     448          186 :   projected_value.reshape(TensorDim(
     449              :     {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
     450              : 
     451          186 :   attention_weight.reshape(
     452          372 :     TensorDim({batch_size * num_heads, 1, query_height, key_height}));
     453          186 :   attention_output.reshape(TensorDim(
     454              :     {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
     455              : 
     456              :   /** scaled dot product attention */
     457          186 :   projected_query.dotBatched(projected_key, attention_weight, false, true);
     458          186 :   attention_weight.multiply_i(1.0f /
     459          186 :                               std::sqrt((float)projected_query_dim_prop));
     460              : 
     461          186 :   if (provide_attention_mask) {
     462              :     // Tensor &attention_mask =
     463              :     //   context.getTensor(weight_idx[AttentionParams::attention_mask]);
     464              :     /** @todo: enable bool type tensor */
     465              :     // if (torch_ref) {
     466              :     //   attention_mask.setValue(-1e9);
     467              :     // } else {
     468              :     //   // flip
     469              :     //   attention_mask.setValue(1);
     470              :     //   attention_mask.subtract_i(mask);
     471              : 
     472              :     //   attention_mask.multiply_i(-1e9);
     473              :     // }
     474              :     // attention_mask.multiply_i(mask);
     475              :     // attention_weight.add_i(attention_mask);
     476              : 
     477           60 :     attention_weight.reshape(
     478           60 :       TensorDim({batch_size, num_heads, query_height, key_height}));
     479           60 :     attention_weight.add_i(mask);
     480           60 :     attention_weight.reshape(
     481          120 :       TensorDim({batch_size * num_heads, 1, query_height, key_height}));
     482              :   }
     483              : 
     484              :   sm.run_fn(attention_weight, attention_weight);
     485              : 
     486          186 :   if (return_attention_weight ==
     487              :       props::ReturnAttentionWeightInfo::Enum::before) {
     488            5 :     ret_attention_weight.copyData(attention_weight);
     489              :   }
     490              : 
     491          186 :   if (enable_dropout) {
     492              :     Tensor &dropout_mask =
     493            0 :       context.getTensor(weight_idx[AttentionParams::dropout_mask]);
     494            0 :     dropout_mask.dropout_mask(dropout_rate);
     495            0 :     attention_weight.multiply_i(dropout_mask);
     496              :   }
     497              : 
     498          186 :   if (return_attention_weight ==
     499              :       props::ReturnAttentionWeightInfo::Enum::after) {
     500           24 :     if (average_attention_weight) {
     501           24 :       attention_weight.reshape(
     502           24 :         TensorDim({batch_size, num_heads, query_height, key_height}));
     503           24 :       attention_weight.sum(1, ret_attention_weight, 1, 0);
     504           24 :       ret_attention_weight.divide_i(static_cast<float>(num_heads));
     505           24 :       attention_weight.reshape(
     506           48 :         TensorDim({batch_size * num_heads, 1, query_height, key_height}));
     507              :     } else {
     508            0 :       ret_attention_weight.copyData(attention_weight);
     509              :     }
     510              :   }
     511              : 
     512          186 :   attention_weight.dotBatched(projected_value, attention_output);
     513              : 
     514          186 :   attention_output.reshape(
     515          186 :     TensorDim({batch_size, num_heads, query_height, projected_value_dim_prop}));
     516              : 
     517          372 :   attention_output = attention_output.transpose("1:0:2");
     518              : 
      519              :   /** restore the original tensor name, which was removed during the
      520              :    * transpose */
     521          372 :   attention_output.setName("multi_head_attention:attention_output");
     522              : 
     523          186 :   attention_output.reshape(TensorDim(
     524          186 :     {batch_size * query_height, 1, 1, num_heads * projected_value_dim_prop}));
     525              : 
     526          186 :   attention_output.dot(fc_weight, output);
     527          186 :   if (!disable_bias) {
     528          180 :     output.add_i(fc_bias);
     529              :   }
     530              : 
     531              :   /** restore shape */
     532          186 :   projected_query.reshape(TensorDim(
     533          186 :     {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
     534          186 :   projected_key.reshape(
     535          372 :     TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
     536          186 :   projected_value.reshape(TensorDim(
     537              :     {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
     538              : 
     539          186 :   attention_weight.reshape(
     540          372 :     TensorDim({batch_size, num_heads, query_height, key_height}));
     541          186 :   attention_output.reshape(TensorDim(
     542              :     {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
     543          186 : }
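
The core of forwarding() above is scaled dot-product attention applied per (batch * head) slice: attention_weight = softmax(Q * K^T / sqrt(projected_query_dim)), then output = attention_weight * V. A minimal single-head sketch on plain row-major float buffers (illustrative only; the function name is hypothetical and mask, dropout, and bias are omitted):

// Illustrative sketch only; not part of the instrumented source above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static void scaled_dot_product_attention(
  const std::vector<float> &Q, // q_len x d
  const std::vector<float> &K, // k_len x d
  const std::vector<float> &V, // k_len x d
  std::vector<float> &O,       // q_len x d, pre-sized by the caller
  unsigned q_len, unsigned k_len, unsigned d) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> row(k_len);

  for (unsigned i = 0; i < q_len; ++i) {
    // attention scores for query row i, scaled by 1/sqrt(d)
    for (unsigned j = 0; j < k_len; ++j) {
      float s = 0.0f;
      for (unsigned t = 0; t < d; ++t)
        s += Q[i * d + t] * K[j * d + t];
      row[j] = s * scale;
    }
    // numerically stable softmax over the key axis
    const float row_max = *std::max_element(row.begin(), row.end());
    float denom = 0.0f;
    for (unsigned j = 0; j < k_len; ++j) {
      row[j] = std::exp(row[j] - row_max);
      denom += row[j];
    }
    // output row i = softmax weights * V
    for (unsigned t = 0; t < d; ++t) {
      float acc = 0.0f;
      for (unsigned j = 0; j < k_len; ++j)
        acc += (row[j] / denom) * V[j * d + t];
      O[i * d + t] = acc;
    }
  }
}

int main() {
  const unsigned q_len = 2, k_len = 3, d = 4;
  std::vector<float> Q(q_len * d, 0.1f), K(k_len * d, 0.2f),
    V(k_len * d, 0.3f), O(q_len * d, 0.0f);
  scaled_dot_product_attention(Q, K, V, O, q_len, k_len, d);
  std::printf("O[0] = %f\n", O[0]); // 0.3, since all rows of V are identical
  return 0;
}

The multi-head variant in forwarding() gets the same effect by reshaping the projected tensors to {batch * num_heads, 1, height, head_dim} and using dotBatched(), so every head attends independently.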
     544              : 
     545            0 : void MultiHeadAttentionLayer::incremental_forwarding(RunLayerContext &context,
     546              :                                                      unsigned int from,
     547              :                                                      unsigned int to,
     548              :                                                      bool training) {
     549              :   const bool disable_bias =
     550            0 :     std::get<props::DisableBias>(*layer_impl_props).get();
     551              : 
     552              :   const unsigned int num_heads =
     553            0 :     std::get<props::NumHeads>(multi_head_attention_props).get();
     554              :   const unsigned int projected_key_dim_prop =
     555            0 :     std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
     556              :   const unsigned int projected_value_dim_prop =
     557            0 :     std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
     558              :   const float dropout_rate =
     559            0 :     std::get<props::DropOutRate>(multi_head_attention_props).get();
     560              :   const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
     561            0 :     std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
     562              :   const bool average_attention_weight =
     563            0 :     std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
     564              : 
     565            0 :   const bool provide_attention_mask = context.getNumInputs() == 4;
     566              :   const unsigned int projected_query_dim_prop = projected_key_dim_prop;
     567              :   const bool enable_dropout = dropout_rate > epsilon;
     568              : 
     569              :   /** get inputs/outputs */
     570            0 :   Tensor &query = context.getInput(INOUT_INDEX::QUERY);
     571            0 :   Tensor &key = context.getInput(INOUT_INDEX::KEY);
     572            0 :   Tensor &value = context.getInput(INOUT_INDEX::VALUE);
     573              : 
     574            0 :   Tensor empty_tensor("empty", value.getFormat(), value.getDataType());
     575              : 
     576              :   Tensor &mask =
     577            0 :     provide_attention_mask ? context.getInput(INOUT_INDEX::MASK) : empty_tensor;
     578              : 
     579            0 :   TensorDim query_dim = query.getDim();
     580            0 :   TensorDim key_dim = key.getDim();
     581            0 :   TensorDim value_dim = value.getDim();
     582              : 
     583            0 :   Tensor &output = context.getOutput(INOUT_INDEX::OUTPUT);
     584              : 
     585            0 :   TensorDim output_dim = output.getDim();
     586              :   Tensor &ret_attention_weight =
     587              :     return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
     588            0 :       ? context.getOutput(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
     589              :       : empty_tensor;
     590              : 
     591              :   /** get weights */
     592              :   Tensor &query_fc_weight =
     593            0 :     context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
     594              :   Tensor &query_fc_bias =
     595              :     disable_bias
     596            0 :       ? empty_tensor
     597            0 :       : context.getWeight(weight_idx[AttentionParams::query_fc_bias]);
     598              :   Tensor &key_fc_weight =
     599            0 :     context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
     600              :   Tensor &key_fc_bias =
     601            0 :     disable_bias ? empty_tensor
     602            0 :                  : context.getWeight(weight_idx[AttentionParams::key_fc_bias]);
     603              :   Tensor &value_fc_weight =
     604            0 :     context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
     605              :   Tensor &value_fc_bias =
     606              :     disable_bias
     607            0 :       ? empty_tensor
     608            0 :       : context.getWeight(weight_idx[AttentionParams::value_fc_bias]);
     609            0 :   Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
     610              :   Tensor &fc_bias = disable_bias
     611            0 :                       ? empty_tensor
     612            0 :                       : context.getWeight(weight_idx[AttentionParams::fc_bias]);
     613              : 
     614              :   /** get tensors */
     615              :   Tensor &projected_query =
     616            0 :     context.getTensor(weight_idx[AttentionParams::projected_query]);
     617              :   Tensor &projected_key =
     618            0 :     context.getTensor(weight_idx[AttentionParams::projected_key]);
     619              :   Tensor &projected_value =
     620            0 :     context.getTensor(weight_idx[AttentionParams::projected_value]);
     621            0 :   Tensor &cache_key = context.getTensor(weight_idx[AttentionParams::cache_key]);
     622              :   Tensor &cache_value =
     623            0 :     context.getTensor(weight_idx[AttentionParams::cache_value]);
     624              : 
     625            0 :   TensorDim projected_query_dim = projected_query.getDim();
     626            0 :   TensorDim projected_key_dim = projected_key.getDim();
     627            0 :   TensorDim projected_value_dim = projected_value.getDim();
     628            0 :   TensorDim cache_key_dim = cache_key.getDim();
     629            0 :   TensorDim cache_value_dim = cache_value.getDim();
     630              : 
     631            0 :   TensorDim projected_query_step_dim = projected_query_dim;
     632              : 
     633            0 :   TensorDim projected_key_step_dim = projected_key_dim;
     634            0 :   TensorDim projected_value_step_dim = projected_value_dim;
     635            0 :   TensorDim cache_key_step_dim = cache_key_dim;
     636            0 :   TensorDim cache_value_step_dim = cache_value_dim;
     637            0 :   projected_query_step_dim.height(to - from);
     638              : 
     639            0 :   projected_key_step_dim.height(to);
     640            0 :   projected_value_step_dim.height(to);
     641            0 :   cache_key_step_dim.height(to - from);
     642            0 :   cache_value_step_dim.height(to - from);
     643              : 
     644              :   Tensor projected_query_step =
     645            0 :     projected_query.getSharedDataTensor(projected_query_step_dim, 0, true);
     646              :   Tensor projected_key_step =
     647            0 :     projected_key.getSharedDataTensor(projected_key_step_dim, 0, true);
     648              :   Tensor projected_value_step =
     649            0 :     projected_value.getSharedDataTensor(projected_value_step_dim, 0, true);
     650              : 
     651              :   Tensor cache_key_step = cache_key.getSharedDataTensor(
     652            0 :     cache_key_step_dim, from * cache_key_dim.width(), true);
     653              :   Tensor cache_value_step = cache_value.getSharedDataTensor(
     654            0 :     cache_value_step_dim, from * cache_value_dim.width(), true);
     655              : 
     656              :   TensorDim cached_key_dim = {cache_key_dim.batch(), cache_key_dim.channel(),
     657              :                               to, cache_key_dim.width(),
     658            0 :                               cache_key.getTensorType()};
     659              :   TensorDim cached_value_dim = {
     660              :     cache_value_dim.batch(), cache_value_dim.channel(), to,
     661            0 :     cache_value_dim.width(), cache_value.getTensorType()};
     662            0 :   Tensor cached_key = cache_key.getSharedDataTensor(cached_key_dim, 0, true);
     663              :   Tensor cached_value =
     664            0 :     cache_value.getSharedDataTensor(cached_value_dim, 0, true);
     665              : 
     666              :   Tensor &attention_weight =
     667            0 :     context.getTensor(weight_idx[AttentionParams::attention_weight]);
     668              :   Tensor &attention_output =
     669            0 :     context.getTensor(weight_idx[AttentionParams::attention_output]);
     670            0 :   TensorDim attention_weight_dim = attention_weight.getDim();
     671              : 
     672            0 :   TensorDim attention_weight_step_dim = attention_weight_dim;
     673            0 :   attention_weight_step_dim.height(to - from);
     674            0 :   attention_weight_step_dim.width(to);
     675              : 
     676              :   Tensor attention_weight_step =
     677            0 :     attention_weight.getSharedDataTensor(attention_weight_step_dim, 0, true);
     678              : 
     679            0 :   TensorDim attention_output_dim = attention_output.getDim();
     680            0 :   TensorDim attention_output_step_dim = attention_output_dim;
     681            0 :   attention_output_step_dim.height(to - from);
     682              : 
     683              :   Tensor attention_output_step =
     684            0 :     attention_output.getSharedDataTensor(attention_output_step_dim, 0, true);
     685              : 
     686            0 :   const unsigned int batch_size = query_dim.batch();
     687            0 :   const unsigned int query_height = query_dim.height();
     688            0 :   const unsigned int key_height = key_dim.height();
     689            0 :   const unsigned int value_height = value_dim.height();
     690              : 
     691            0 :   query.dot(query_fc_weight, projected_query_step);
     692            0 :   if (!disable_bias) {
     693            0 :     projected_query_step.add_i(query_fc_bias);
     694              :   }
     695            0 :   key.dot(key_fc_weight, cache_key_step);
     696            0 :   if (!disable_bias) {
     697            0 :     cache_key_step.add_i(key_fc_bias);
     698              :   }
     699            0 :   value.dot(value_fc_weight, cache_value_step);
     700            0 :   if (!disable_bias) {
     701            0 :     cache_value_step.add_i(value_fc_bias);
     702              :   }
     703              : 
     704            0 :   projected_query_step.reshape(
     705            0 :     TensorDim({batch_size, 1, num_heads, projected_query_dim_prop}));
     706            0 :   cached_key.reshape(
     707            0 :     TensorDim({batch_size, to, num_heads, projected_key_dim_prop}));
     708            0 :   cached_value.reshape(
     709            0 :     TensorDim({batch_size, to, num_heads, projected_value_dim_prop}));
     710              : 
     711            0 :   projected_query_step.transpose("1:0:2", projected_query_step);
     712            0 :   cached_key.transpose("1:0:2", projected_key_step);
     713            0 :   cached_value.transpose("1:0:2", projected_value_step);
     714              : 
     715            0 :   projected_query_step.reshape(
     716            0 :     TensorDim({batch_size * num_heads, 1, 1, projected_query_dim_prop}));
     717            0 :   projected_key_step.reshape(
     718            0 :     TensorDim({batch_size * num_heads, 1, to, projected_key_dim_prop}));
     719            0 :   projected_value_step.reshape(
     720            0 :     TensorDim({batch_size * num_heads, 1, to, projected_value_dim_prop}));
     721              : 
     722            0 :   attention_weight_step.reshape(TensorDim({batch_size * num_heads, 1, 1, to}));
     723            0 :   attention_output_step.reshape(
     724            0 :     TensorDim({batch_size * num_heads, 1, 1, projected_value_dim_prop}));
     725              : 
     726              :   /** scaled dot product attention */
     727            0 :   projected_query_step.dotBatched(projected_key_step, attention_weight_step,
     728              :                                   false, true);
     729            0 :   attention_weight_step.multiply_i(1 / sqrt((float)projected_query_dim_prop));
     730              : 
     731            0 :   if (!from) {
     732            0 :     unsigned int mask_size = attention_weight_step.getDim().width();
     733              :     unsigned int mask_dim_height = mask_size;
     734              :     unsigned int mask_dim_width = mask_size;
     735              : 
     736            0 :     Tensor causal_mask(TensorDim{1, 1, mask_size, mask_size,
     737            0 :                                  attention_weight_step.getTensorType()});
     738              : 
     739            0 :     causal_mask.setZero();
     740              : 
     741              : #ifdef ENABLE_FP16
     742              : #define _MASK_NUM -1e4
     743              : #else
     744              : #define _MASK_NUM -1e10
     745              : #endif
     746              : 
     747            0 :     for (unsigned int i = 0; i < mask_dim_height; ++i) {
     748            0 :       for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
     749            0 :         causal_mask.setValue(0, 0, i, j, _MASK_NUM);
     750              :       }
     751              :     }
     752              : 
     753            0 :     attention_weight_step.add_i(causal_mask);
     754            0 :   }
     755              : 
     756              :   sm.run_fn(attention_weight_step, attention_weight_step);
     757              : 
     758            0 :   attention_weight_step.dotBatched(projected_value_step, attention_output_step);
     759              : 
     760            0 :   attention_output_step.reshape(
     761            0 :     TensorDim({batch_size, num_heads, to - from, projected_value_dim_prop}));
     762              : 
     763            0 :   attention_output_step = attention_output_step.transpose("1:0:2");
     764              : 
     765            0 :   attention_output_step.reshape(TensorDim(
     766            0 :     {batch_size * (to - from), 1, 1, num_heads * projected_value_dim_prop}));
     767              : 
     768            0 :   attention_output_step.dot(fc_weight, output);
     769            0 :   if (!disable_bias) {
     770            0 :     output.add_i(fc_bias);
     771              :   }
     772            0 : }
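
When from == 0, incremental_forwarding() above adds a causal mask so each query position can only attend to itself and earlier key positions: entries with j > i are set to a large negative value (_MASK_NUM, -1e10, or -1e4 under FP16) so softmax drives them to roughly zero. A small illustrative sketch of that mask (example length only; not part of the instrumented source):

// Illustrative sketch only.
#include <cstdio>
#include <vector>

int main() {
  const unsigned seq_len = 4;    // example length; the layer uses the
                                 // attention_weight_step width
  const float mask_num = -1e10f; // -1e4 is used when built with FP16

  std::vector<float> causal_mask(seq_len * seq_len, 0.0f);
  for (unsigned i = 0; i < seq_len; ++i)
    for (unsigned j = i + 1; j < seq_len; ++j)
      causal_mask[i * seq_len + j] = mask_num;

  // print: row i may attend to columns 0..i only
  for (unsigned i = 0; i < seq_len; ++i) {
    for (unsigned j = 0; j < seq_len; ++j)
      std::printf("%9.1e ", causal_mask[i * seq_len + j]);
    std::printf("\n");
  }
  return 0;
}

For later steps (from > 0) no mask is added: the step's key/value projections are written into cache_key / cache_value at row offset from, and attention for the new query rows runs against the first `to` cached rows.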
     773              : 
     774           85 : void MultiHeadAttentionLayer::calcCommonDerivative(RunLayerContext &context) {
     775              :   const unsigned int num_heads =
     776           85 :     std::get<props::NumHeads>(multi_head_attention_props).get();
     777              :   const unsigned int projected_key_dim_prop =
     778           85 :     std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
     779              :   const unsigned int projected_value_dim_prop =
     780           85 :     std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
     781              :   const float dropout_rate =
     782           85 :     std::get<props::DropOutRate>(multi_head_attention_props).get();
     783              :   const props::ReturnAttentionWeightInfo::Enum return_attention_weight =
     784           85 :     std::get<props::ReturnAttentionWeight>(multi_head_attention_props).get();
     785              :   const bool average_attention_weight =
     786           85 :     std::get<props::AverageAttentionWeight>(multi_head_attention_props).get();
     787              : 
     788           85 :   const bool provide_attention_mask = context.getNumInputs() == 4;
     789              :   const unsigned int projected_query_dim_prop = projected_key_dim_prop;
     790              : 
     791           85 :   Tensor empty_tensor;
     792              : 
     793           85 :   Tensor &query = context.getInput(INOUT_INDEX::QUERY);
     794           85 :   Tensor &key = context.getInput(INOUT_INDEX::KEY);
     795           85 :   Tensor &value = context.getInput(INOUT_INDEX::VALUE);
     796              :   const Tensor &incoming_derivative =
     797           85 :     context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
     798              :   const Tensor &d_ret_attention_weight =
     799              :     return_attention_weight != props::ReturnAttentionWeightInfo::Enum::none
     800           85 :       ? context.getIncomingDerivative(INOUT_INDEX::RETURN_ATTENTION_WEIGHT)
     801           85 :       : empty_tensor;
     802              : 
     803           85 :   Tensor &fc_weight = context.getWeight(weight_idx[AttentionParams::fc_weight]);
     804              : 
     805              :   Tensor &projected_query =
     806           85 :     context.getTensor(weight_idx[AttentionParams::projected_query]);
     807              :   Tensor &d_projected_query =
     808           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
     809              :   Tensor &projected_key =
     810           85 :     context.getTensor(weight_idx[AttentionParams::projected_key]);
     811              :   Tensor &d_projected_key =
     812           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
     813              :   Tensor &projected_value =
     814           85 :     context.getTensor(weight_idx[AttentionParams::projected_value]);
     815              :   Tensor &d_projected_value =
     816           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
     817              : 
     818              :   Tensor &attention_weight =
     819           85 :     context.getTensor(weight_idx[AttentionParams::attention_weight]);
     820              :   Tensor &d_attention_weight =
     821           85 :     context.getTensorGrad(weight_idx[AttentionParams::attention_weight]);
     822              :   Tensor &d_attention_output =
     823           85 :     context.getTensorGrad(weight_idx[AttentionParams::attention_output]);
     824              : 
     825           85 :   const TensorDim query_dim = query.getDim();
     826           85 :   const unsigned int batch_size = query_dim.batch();
     827           85 :   const unsigned int query_height = query_dim.height();
     828           85 :   const TensorDim key_dim = key.getDim();
     829           85 :   const unsigned int key_height = key_dim.height();
     830           85 :   const TensorDim value_dim = value.getDim();
     831           85 :   const unsigned int value_height = value_dim.height();
     832              : 
     833           85 :   d_attention_output.dot_deriv_wrt_1(fc_weight, incoming_derivative);
     834              : 
     835           85 :   d_attention_output.reshape(
     836           85 :     TensorDim({batch_size, query_height, num_heads, projected_value_dim_prop}));
     837              : 
     838          170 :   d_attention_output = d_attention_output.transpose("1:0:2");
     839              : 
      840              :   /** restore the original tensor name, which was removed during the
      841              :    * transpose */
     842          170 :   d_attention_output.setName("multi_head_attention:attention_output:grad");
     843              : 
     844           85 :   projected_query.reshape(TensorDim(
     845           85 :     {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
     846           85 :   d_projected_query.reshape(TensorDim(
     847              :     {batch_size * num_heads, 1, query_height, projected_query_dim_prop}));
     848           85 :   projected_key.reshape(
     849          170 :     TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
     850           85 :   d_projected_key.reshape(
     851          170 :     TensorDim({batch_size * num_heads, 1, key_height, projected_key_dim_prop}));
     852           85 :   projected_value.reshape(TensorDim(
     853              :     {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
     854           85 :   d_projected_value.reshape(TensorDim(
     855              :     {batch_size * num_heads, 1, value_height, projected_value_dim_prop}));
     856              : 
     857           85 :   attention_weight.reshape(
     858          170 :     TensorDim({batch_size * num_heads, 1, query_height, key_height}));
     859           85 :   d_attention_weight.reshape(
     860          170 :     TensorDim({batch_size * num_heads, 1, query_height, key_height}));
     861           85 :   d_attention_output.reshape(TensorDim(
     862              :     {batch_size * num_heads, 1, query_height, projected_value_dim_prop}));
     863              : 
     864           85 :   d_attention_weight.dot_batched_deriv_wrt_1(projected_value,
     865              :                                              d_attention_output);
     866           85 :   attention_weight.dot_batched_deriv_wrt_2(d_projected_value,
     867              :                                            d_attention_output);
     868              : 
     869           85 :   if (return_attention_weight ==
     870              :       props::ReturnAttentionWeightInfo::Enum::after) {
     871           12 :     const float scale = average_attention_weight ? 1 / (float)num_heads : 1;
     872           12 :     d_attention_weight.add_i(d_ret_attention_weight, scale);
     873              :   }
     874              : 
     875           85 :   if (dropout_rate > epsilon) {
     876              :     Tensor &dropout_mask =
     877            0 :       context.getTensor(weight_idx[AttentionParams::dropout_mask]);
     878            0 :     d_attention_weight.multiply_i(dropout_mask);
     879              :   }
     880              : 
     881           85 :   if (return_attention_weight ==
     882              :       props::ReturnAttentionWeightInfo::Enum::before) {
     883            1 :     d_attention_weight.add_i(d_ret_attention_weight);
     884              :   }
     885              : 
     886           85 :   sm.run_prime_fn(attention_weight, d_attention_weight, d_attention_weight);
     887           85 :   if (provide_attention_mask) {
     888           30 :     Tensor &d_mask = context.getOutgoingDerivative(INOUT_INDEX::MASK);
     889           30 :     d_mask.copyData(d_attention_weight);
     890              :   }
     891           85 :   d_attention_weight.multiply_i(
     892           85 :     1 / sqrt((float)projected_query_dim_prop)); /** scale */
     893              : 
     894           85 :   d_projected_query.dot_batched_deriv_wrt_1(projected_key, d_attention_weight,
     895              :                                             false, true);
     896           85 :   projected_query.dot_batched_deriv_wrt_2(d_projected_key, d_attention_weight,
     897              :                                           false, true);
     898              : 
     899           85 :   d_projected_query.reshape(
     900          170 :     TensorDim({batch_size, num_heads, query_height, projected_query_dim_prop}));
     901           85 :   d_projected_key.reshape(
     902          170 :     TensorDim({batch_size, num_heads, key_height, projected_key_dim_prop}));
     903           85 :   d_projected_value.reshape(
     904           85 :     TensorDim({batch_size, num_heads, value_height, projected_value_dim_prop}));
     905              : 
     906          170 :   d_projected_query = d_projected_query.transpose("1:0:2");
     907          170 :   d_projected_key = d_projected_key.transpose("1:0:2");
     908          170 :   d_projected_value = d_projected_value.transpose("1:0:2");
     909              : 
      910              :   /** restore the tensor names, since the original names were removed
      911              :    * during the transpose */
     912           85 :   d_projected_query.setName("multi_head_attention:projected_query:grad");
     913           85 :   d_projected_key.setName("multi_head_attention:projected_key:grad");
     914          170 :   d_projected_value.setName("multi_head_attention:projected_value:grad");
     915              : 
     916              :   /** restore shape */
     917           85 :   projected_query.reshape(TensorDim(
     918           85 :     {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
     919           85 :   d_projected_query.reshape(TensorDim(
     920           85 :     {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
     921           85 :   projected_key.reshape(
     922          170 :     TensorDim({batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
     923           85 :   d_projected_key.reshape(TensorDim(
     924           85 :     {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
     925           85 :   projected_value.reshape(TensorDim(
     926           85 :     {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
     927           85 :   d_projected_value.reshape(TensorDim(
     928           85 :     {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
     929              : 
     930           85 :   attention_weight.reshape(
     931          170 :     TensorDim({batch_size, num_heads, query_height, key_height}));
     932           85 :   d_attention_weight.reshape(
     933          170 :     TensorDim({batch_size, num_heads, query_height, key_height}));
     934           85 :   d_attention_output.reshape(TensorDim(
     935              :     {batch_size, 1, query_height, num_heads * projected_value_dim_prop}));
     936           85 : }
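
For readers tracing the math above: this part of calcCommonDerivative backpropagates
through the scaled dot-product attention core, O = softmax(Q K^T / sqrt(d_k)) V. Below
is a minimal, self-contained sketch of that chain for a single head, using hypothetical
plain-matrix helpers (Mat, matmul, softmax_backward and attention_backward are
illustrative names, not nntrainer's Tensor API). It mirrors the order used above:
dW = dO * V^T, dV = W^T * dO, the row-wise softmax Jacobian, the 1/sqrt(d_k) scaling,
then dQ = dS * K and dK = dS^T * Q.

#include <cmath>
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<float>>; // row-major helper matrix

static Mat matmul(const Mat &a, const Mat &b, bool trans_a = false,
                  bool trans_b = false) {
  std::size_t m = trans_a ? a[0].size() : a.size();
  std::size_t k = trans_a ? a.size() : a[0].size();
  std::size_t n = trans_b ? b.size() : b[0].size();
  Mat c(m, std::vector<float>(n, 0.0f));
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t p = 0; p < k; ++p)
      for (std::size_t j = 0; j < n; ++j)
        c[i][j] +=
          (trans_a ? a[p][i] : a[i][p]) * (trans_b ? b[j][p] : b[p][j]);
  return c;
}

/// Given the softmax output w and the upstream gradient dw, return the
/// gradient with respect to the softmax input (row-wise Jacobian product).
static Mat softmax_backward(const Mat &w, const Mat &dw) {
  Mat ds = dw;
  for (std::size_t i = 0; i < w.size(); ++i) {
    float dot = 0.0f;
    for (std::size_t j = 0; j < w[i].size(); ++j)
      dot += w[i][j] * dw[i][j];
    for (std::size_t j = 0; j < w[i].size(); ++j)
      ds[i][j] = w[i][j] * (dw[i][j] - dot);
  }
  return ds;
}

/// Backward of O = softmax(Q K^T / sqrt(d_k)) V for one head.
/// attn_w is the softmax output saved during the forward pass.
void attention_backward(const Mat &q, const Mat &k, const Mat &v,
                        const Mat &attn_w, const Mat &d_out, Mat &d_q,
                        Mat &d_k, Mat &d_v) {
  const float scale = 1.0f / std::sqrt((float)q[0].size());
  Mat d_w = matmul(d_out, v, false, true);  // dW = dO * V^T
  d_v = matmul(attn_w, d_out, true, false); // dV = W^T * dO
  Mat d_s = softmax_backward(attn_w, d_w);  // through the softmax
  for (auto &row : d_s)
    for (auto &e : row)
      e *= scale; // the forward pass scaled the logits by 1/sqrt(d_k)
  d_q = matmul(d_s, k, false, false); // dQ = dS * K
  d_k = matmul(d_s, q, true, false);  // dK = dS^T * Q
}

The reshape/transpose pairs in the member function above only fold the head dimension
into and out of the batch axis so that these per-head products can run as batched
matrix multiplications; the per-head gradients themselves are the ones shown in the
sketch.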
     937              : 
     938           85 : void MultiHeadAttentionLayer::calcDerivative(RunLayerContext &context) {
     939           85 :   if (!context.getTrainable()) {
     940            0 :     calcCommonDerivative(context);
     941              :   }
     942              : 
     943           85 :   Tensor &query = context.getInput(INOUT_INDEX::QUERY);
     944           85 :   Tensor &d_query = context.getOutgoingDerivative(INOUT_INDEX::QUERY);
     945           85 :   Tensor &key = context.getInput(INOUT_INDEX::KEY);
     946           85 :   Tensor &d_key = context.getOutgoingDerivative(INOUT_INDEX::KEY);
     947           85 :   Tensor &value = context.getInput(INOUT_INDEX::VALUE);
     948           85 :   Tensor &d_value = context.getOutgoingDerivative(INOUT_INDEX::VALUE);
     949              :   /** d_mask will be calculated in calcCommonDerivative */
     950              : 
     951              :   Tensor &query_fc_weight =
     952           85 :     context.getWeight(weight_idx[AttentionParams::query_fc_weight]);
     953              :   Tensor &key_fc_weight =
     954           85 :     context.getWeight(weight_idx[AttentionParams::key_fc_weight]);
     955              :   Tensor &value_fc_weight =
     956           85 :     context.getWeight(weight_idx[AttentionParams::value_fc_weight]);
     957              : 
     958              :   Tensor &d_projected_query =
     959           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
     960              :   Tensor &d_projected_key =
     961           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
     962              :   Tensor &d_projected_value =
     963           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
     964              : 
     965           85 :   const TensorDim query_dim = query.getDim();
     966           85 :   const TensorDim key_dim = key.getDim();
     967           85 :   const TensorDim value_dim = value.getDim();
     968              : 
     969           85 :   d_query.dot_deriv_wrt_1(query_fc_weight, d_projected_query);
     970           85 :   d_key.dot_deriv_wrt_1(key_fc_weight, d_projected_key);
     971           85 :   d_value.dot_deriv_wrt_1(value_fc_weight, d_projected_value, false, false);
     972           85 : }
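
calcDerivative above pushes the projected-tensor gradients back to the layer inputs
using the identity that, for a dense projection y = x * W, the input gradient is
dx = dy * W^T; that is what dot_deriv_wrt_1 computes for each of the query/key/value
projections. A minimal sketch with hypothetical row-major buffers (input_grad is an
illustrative name, not an nntrainer function):

#include <cstddef>
#include <vector>

// dy: [n, out], w: [in, out]  ->  dx: [n, in]
std::vector<float> input_grad(const std::vector<float> &dy,
                              const std::vector<float> &w, std::size_t n,
                              std::size_t in, std::size_t out) {
  std::vector<float> dx(n * in, 0.0f);
  for (std::size_t r = 0; r < n; ++r)
    for (std::size_t i = 0; i < in; ++i)
      for (std::size_t o = 0; o < out; ++o)
        dx[r * in + i] += dy[r * out + o] * w[i * out + o]; // dy * W^T
  return dx;
}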
     973              : 
     974           85 : void MultiHeadAttentionLayer::calcGradient(RunLayerContext &context) {
     975           85 :   calcCommonDerivative(context);
     976              : 
     977              :   const bool disable_bias =
     978           85 :     std::get<props::DisableBias>(*layer_impl_props).get();
     979              : 
     980              :   const unsigned int num_heads =
     981           85 :     std::get<props::NumHeads>(multi_head_attention_props).get();
     982              :   const unsigned int projected_key_dim_prop =
     983           85 :     std::get<props::ProjectedKeyDim>(multi_head_attention_props).get();
     984              :   const unsigned int projected_value_dim_prop =
     985           85 :     std::get<props::ProjectedValueDim>(multi_head_attention_props).get();
     986              :   const unsigned int output_shape =
     987           85 :     std::get<props::OutputShape>(multi_head_attention_props).get();
     988              : 
     989              :   const unsigned int projected_query_dim_prop = projected_key_dim_prop;
     990              : 
     991           85 :   Tensor &query = context.getInput(INOUT_INDEX::QUERY);
     992           85 :   Tensor &key = context.getInput(INOUT_INDEX::KEY);
     993           85 :   Tensor &value = context.getInput(INOUT_INDEX::VALUE);
     994              :   const Tensor &incoming_derivative =
     995           85 :     context.getIncomingDerivative(INOUT_INDEX::OUTPUT);
     996              : 
     997              :   Tensor &d_query_fc_weight =
     998           85 :     context.getWeightGrad(weight_idx[AttentionParams::query_fc_weight]);
     999              :   Tensor &d_key_fc_weight =
    1000           85 :     context.getWeightGrad(weight_idx[AttentionParams::key_fc_weight]);
    1001              :   Tensor &d_value_fc_weight =
    1002           85 :     context.getWeightGrad(weight_idx[AttentionParams::value_fc_weight]);
    1003              :   Tensor &d_fc_weight =
    1004           85 :     context.getWeightGrad(weight_idx[AttentionParams::fc_weight]);
    1005              : 
    1006           85 :   Tensor empty_tensor;
    1007              :   Tensor &d_query_fc_bias =
    1008              :     disable_bias
    1009           85 :       ? empty_tensor
    1010           82 :       : context.getWeightGrad(weight_idx[AttentionParams::query_fc_bias]);
    1011              :   Tensor &d_key_fc_bias =
    1012              :     disable_bias
    1013              :       ? empty_tensor
    1014           82 :       : context.getWeightGrad(weight_idx[AttentionParams::key_fc_bias]);
    1015              :   Tensor &d_value_fc_bias =
    1016              :     disable_bias
    1017           85 :       ? empty_tensor
    1018           82 :       : context.getWeightGrad(weight_idx[AttentionParams::value_fc_bias]);
    1019              :   Tensor &d_fc_bias =
    1020              :     disable_bias ? empty_tensor
    1021           82 :                  : context.getWeightGrad(weight_idx[AttentionParams::fc_bias]);
    1022              : 
    1023              :   Tensor &d_projected_query =
    1024           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_query]);
    1025              :   Tensor &d_projected_key =
    1026           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_key]);
    1027              :   Tensor &d_projected_value =
    1028           85 :     context.getTensorGrad(weight_idx[AttentionParams::projected_value]);
    1029              : 
    1030              :   Tensor &attention_output =
    1031           85 :     context.getTensor(weight_idx[AttentionParams::attention_output]);
    1032              : 
    1033           85 :   const TensorDim query_dim = query.getDim();
    1034           85 :   const unsigned int batch_size = query_dim.batch();
    1035           85 :   const unsigned int query_height = query_dim.height();
    1036           85 :   const TensorDim key_dim = key.getDim();
    1037           85 :   const unsigned int key_height = key_dim.height();
    1038           85 :   const TensorDim value_dim = value.getDim();
    1039           85 :   const unsigned int value_height = value_dim.height();
    1040              : 
    1041           85 :   attention_output.dot_deriv_wrt_2(
    1042              :     d_fc_weight, incoming_derivative, false, false,
    1043           85 :     !context.isGradientFirstAccess(weight_idx[AttentionParams::fc_weight]));
    1044              : 
    1045           85 :   if (!disable_bias) {
    1046           82 :     Tensor incoming_derivative_ = incoming_derivative;
    1047           82 :     incoming_derivative_.reshape(
    1048           82 :       TensorDim({batch_size * query_height, 1, 1, output_shape}));
    1049           82 :     incoming_derivative_.sum(
    1050              :       0, d_fc_bias, 1,
    1051           82 :       !context.isGradientFirstAccess(weight_idx[AttentionParams::fc_bias]));
    1052           82 :   }
    1053              : 
    1054           85 :   query.dot_deriv_wrt_2(d_query_fc_weight, d_projected_query, false, false,
    1055           85 :                         !context.isGradientFirstAccess(
    1056              :                           weight_idx[AttentionParams::query_fc_weight]));
    1057           85 :   if (!disable_bias) {
    1058           82 :     d_projected_query.reshape(TensorDim(
    1059           82 :       {batch_size * query_height, 1, 1, num_heads * projected_query_dim_prop}));
    1060           82 :     d_projected_query.sum(0, d_query_fc_bias, 1,
    1061           82 :                           !context.isGradientFirstAccess(
    1062              :                             weight_idx[AttentionParams::query_fc_bias]));
    1063           82 :     d_projected_query.reshape(TensorDim(
    1064              :       {batch_size, 1, query_height, num_heads * projected_query_dim_prop}));
    1065              :   }
    1066              : 
    1067           85 :   key.dot_deriv_wrt_2(
    1068              :     d_key_fc_weight, d_projected_key, false, false,
    1069           85 :     !context.isGradientFirstAccess(weight_idx[AttentionParams::key_fc_weight]));
    1070           85 :   if (!disable_bias) {
    1071           82 :     d_projected_key.reshape(TensorDim(
    1072           82 :       {batch_size * key_height, 1, 1, num_heads * projected_key_dim_prop}));
    1073           82 :     d_projected_key.sum(
    1074              :       0, d_key_fc_bias, 1,
    1075           82 :       !context.isGradientFirstAccess(weight_idx[AttentionParams::key_fc_bias]));
    1076           82 :     d_projected_key.reshape(TensorDim(
    1077              :       {batch_size, 1, key_height, num_heads * projected_key_dim_prop}));
    1078              :   }
    1079              : 
    1080           85 :   value.dot_deriv_wrt_2(d_value_fc_weight, d_projected_value, false, false,
    1081           85 :                         !context.isGradientFirstAccess(
    1082              :                           weight_idx[AttentionParams::value_fc_weight]));
    1083           85 :   if (!disable_bias) {
    1084           82 :     d_projected_value.reshape(TensorDim(
    1085           82 :       {batch_size * value_height, 1, 1, num_heads * projected_value_dim_prop}));
    1086           82 :     d_projected_value.sum(0, d_value_fc_bias, 1,
    1087           82 :                           !context.isGradientFirstAccess(
    1088              :                             weight_idx[AttentionParams::value_fc_bias]));
    1089           82 :     d_projected_value.reshape(TensorDim(
    1090              :       {batch_size, 1, value_height, num_heads * projected_value_dim_prop}));
    1091              :   }
    1092           85 : }
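
calcGradient above accumulates the projection weight and bias gradients. For a dense
projection y = x * W + b with x of shape [n, in] and dy of shape [n, out], the
relations used are dW = x^T * dy and db = the sum of dy over its n rows; the reshape
to {batch * height, 1, 1, dim} turns that bias reduction into a single sum over
axis 0, and !isGradientFirstAccess selects accumulation instead of overwrite. A
hedged sketch with hypothetical names:

#include <algorithm>
#include <cstddef>
#include <vector>

void weight_bias_grad(const std::vector<float> &x,
                      const std::vector<float> &dy, std::vector<float> &dw,
                      std::vector<float> &db, std::size_t n, std::size_t in,
                      std::size_t out, bool accumulate) {
  if (!accumulate) { // first access: overwrite instead of accumulate
    std::fill(dw.begin(), dw.end(), 0.0f);
    std::fill(db.begin(), db.end(), 0.0f);
  }
  for (std::size_t r = 0; r < n; ++r)
    for (std::size_t o = 0; o < out; ++o) {
      db[o] += dy[r * out + o]; // bias: reduce over the row axis
      for (std::size_t i = 0; i < in; ++i)
        dw[i * out + o] += x[r * in + i] * dy[r * out + o]; // x^T * dy
    }
}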
    1093              : 
    1094          665 : void MultiHeadAttentionLayer::setProperty(
    1095              :   const std::vector<std::string> &values) {
    1096          665 :   auto remain_props = loadProperties(values, multi_head_attention_props);
    1097          663 :   LayerImpl::setProperty(remain_props);
    1098          663 : }
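
setProperty accepts the usual nntrainer key=value property strings for the properties
declared in the constructor. A purely illustrative usage sketch via the C++ API
follows; the layer type key and the property key strings are assumed from the class
and props:: names shown in this file and should be verified against the nntrainer
documentation.

#include <layer.h>

#include <memory>
#include <vector>

int main() {
  // Property keys below are assumed (e.g. props::NumHeads -> "num_heads").
  auto attention = ml::train::createLayer(
    "multi_head_attention", {"name=attn0", "num_heads=4"});
  return 0;
}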
    1099              : 
    1100          108 : void MultiHeadAttentionLayer::setBatch(RunLayerContext &context,
    1101              :                                        unsigned int batch) {
    1102              :   const float dropout_rate =
    1103          108 :     std::get<props::DropOutRate>(multi_head_attention_props).get();
    1104              : 
    1105          108 :   context.updateTensor(weight_idx[AttentionParams::projected_query], batch);
    1106          108 :   context.updateTensor(weight_idx[AttentionParams::projected_key], batch);
    1107          108 :   context.updateTensor(weight_idx[AttentionParams::projected_value], batch);
    1108          108 :   context.updateTensor(weight_idx[AttentionParams::cache_key], batch);
    1109          108 :   context.updateTensor(weight_idx[AttentionParams::cache_value], batch);
    1110              :   // context.updateTensor(weight_idx[AttentionParams::cache_value], batch);
    1111          108 :   context.updateTensor(weight_idx[AttentionParams::attention_weight], batch);
    1112          108 :   if (dropout_rate > epsilon) {
    1113            0 :     context.updateTensor(weight_idx[AttentionParams::dropout_mask], batch);
    1114              :   }
    1115          108 :   context.updateTensor(weight_idx[AttentionParams::attention_output], batch);
    1116          108 : }
    1117              : 
    1118           54 : void MultiHeadAttentionLayer::exportTo(
    1119              :   Exporter &exporter, const ml::train::ExportMethods &method) const {
    1120           54 :   LayerImpl::exportTo(exporter, method);
    1121           54 :   exporter.saveResult(multi_head_attention_props, method, this);
    1122           54 : }
    1123              : 
    1124              : } /* namespace nntrainer */
        

Generated by: LCOV version 2.0-1