LCOV - code coverage report
Current view: top level - nntrainer/layers - attention_layer.cpp (source / functions)
Test:      coverage_filtered.info
Test Date: 2025-12-14 20:38:17
Coverage:  Lines: 50.9 % (59 of 116)   Functions: 66.7 % (8 of 12)

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   attention_layer.cpp
       6              :  * @date   1 October 2021
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
       9              :  * @bug    No known bugs except for NYI items
      10              :  * @brief  This is Attention Layer Class for Neural Network
      11              :  *
      12              :  */
      13              : 
      14              : #include <cmath>
      15              : 
      16              : #include <attention_layer.h>
      17              : #include <layer_context.h>
      18              : #include <nntrainer_error.h>
      19              : #include <nntrainer_log.h>
      20              : #include <node_exporter.h>
      21              : 
      22              : namespace nntrainer {
      23              : 
      24           21 : AttentionLayer::AttentionLayer() {
      25              :   wt_idx.fill(std::numeric_limits<unsigned>::max());
      26           21 : }
      27              : 
      28           42 : AttentionLayer::~AttentionLayer() {}
      29              : 
      30              : static constexpr size_t SINGLE_INOUT_IDX = 0;
      31              : 
      32              : enum AttentionParams { query = 0, value = 1, key = 2, weights };
      33              : 
      34            7 : void AttentionLayer::finalizeCommon(InitLayerContext &context) {
      35            7 :   if (context.getNumInputs() < 2 || context.getNumInputs() > 3)
      36            0 :     throw std::runtime_error("Attention layer needs 2-3 inputs.");
      37              : 
      38              :   auto const &all_dims = context.getInputDimensions();
      39              :   auto const &query_dim = all_dims[AttentionParams::query];
      40              :   auto const &value_dim = all_dims[AttentionParams::value];
      41              : 
      42            7 :   NNTR_THROW_IF(query_dim.width() != value_dim.width(), std::invalid_argument)
      43              :     << "Query and Value dimension mismatch for layer " << context.getName();
      44              : 
      45            7 :   wt_idx[AttentionParams::query] = AttentionParams::query;
      46            7 :   wt_idx[AttentionParams::value] = AttentionParams::value;
      47            7 :   wt_idx[AttentionParams::key] = AttentionParams::value;
      48              : 
      49            7 :   if (context.getNumInputs() == 3) {
      50              :     auto const &key_dim = all_dims[AttentionParams::key];
      51            1 :     if (key_dim != value_dim)
      52            0 :       throw std::invalid_argument("Key and value must have same shape");
      53              : 
      54            1 :     wt_idx[AttentionParams::key] = AttentionParams::key;
      55              :   }
      56            7 : }
      57              : 
      58            7 : void AttentionLayer::finalize(InitLayerContext &context) {
      59            7 :   finalizeCommon(context);
      60              : 
      61              :   auto const &all_dims = context.getInputDimensions();
      62              :   auto const &query_dim = all_dims[AttentionParams::query];
      63              :   auto const &value_dim = all_dims[AttentionParams::value];
      64              : 
      65            7 :   auto weights_dim = query_dim;
      66            7 :   weights_dim.width(value_dim.height());
      67            7 :   wt_idx[AttentionParams::weights] =
      68            7 :     context.requestTensor(weights_dim, "weights", Initializer::NONE, false,
      69              :                           TensorLifespan::ITERATION_LIFESPAN);
      70              : 
      71            7 :   context.setOutputDimensions({query_dim});
      72              : 
      73              :   auto data_type = context.getActivationDataType();
      74            7 :   if (data_type == ml::train::TensorDim::DataType::FP32) {
      75            7 :     sm.setActiFunc<float>(ActivationType::ACT_SOFTMAX);
      76            0 :   } else if (data_type == ml::train::TensorDim::DataType::FP16) {
      77              : #ifdef ENABLE_FP16
      78              :     sm.setActiFunc<_FP16>(ActivationType::ACT_SOFTMAX);
      79              : #else
      80            0 :     throw std::runtime_error("enable-fp16 is not enabled");
      81              : #endif
      82              :   }
      83            7 : }
      84              : 
      85           15 : void AttentionLayer::forwarding(RunLayerContext &context, bool training) {
      86           15 :   Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
      87           15 :   Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
      88           15 :   Tensor &key = context.getInput(wt_idx[AttentionParams::key]);
      89              : 
      90           15 :   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
      91           15 :   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
      92              : 
      93           15 :   query.dotBatched(key, weights, false, true); /** dot 1 */
      94           15 :   if (std::get<props::ScaledDotProduct>(attention_props).get()) {
      95            0 :     weights.multiply_i(1 / sqrt((float)key.getDim().width()));
      96              :   }
      97           15 :   if (std::get<props::CausalMask>(attention_props).get()) {
      98            0 :     unsigned int mask_size = weights.getDim().width();
      99              :     unsigned int mask_dim_height = mask_size;
     100              :     unsigned int mask_dim_width = mask_size;
     101              : 
     102            0 :     Tensor causal_mask(TensorDim{mask_size, mask_size});
     103              : 
     104            0 :     causal_mask.setZero();
     105            0 :     for (unsigned int i = 0; i < mask_dim_height; ++i) {
     106            0 :       for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
     107            0 :         causal_mask.setValue(0, 0, i, j, -1e10);
     108              :       }
     109              :     }
     110              : 
     111            0 :     weights.add_i(causal_mask);
     112            0 :   }
     113              : 
     114              :   sm.run_fn(weights, weights);       /** softmax */
     115           15 :   weights.dotBatched(value, output); /** dot 2 */
     116           15 : }
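/*
 * forwarding() above is plain dot-product attention, optionally scaled by
 * 1/sqrt(d) and optionally causally masked. Below is a minimal single-batch
 * sketch of the same computation, assuming row-major std::vector matrices;
 * the names Mat and attention are illustrations only, not nntrainer API.
 */
#include <algorithm>
#include <cmath>
#include <vector>

using Mat = std::vector<std::vector<float>>;

/* out = softmax(Q * K^T [* 1/sqrt(d)] [+ causal mask]) * V */
Mat attention(const Mat &Q, const Mat &K, const Mat &V, bool scaled,
              bool causal) {
  const size_t q_len = Q.size(), kv_len = K.size(), d = Q[0].size();
  Mat w(q_len, std::vector<float>(kv_len, 0.0f));

  for (size_t i = 0; i < q_len; ++i) {
    for (size_t j = 0; j < kv_len; ++j) {
      float dot = 0.0f;
      for (size_t t = 0; t < d; ++t)
        dot += Q[i][t] * K[j][t];           /* dot 1: Q x K^T */
      if (scaled)
        dot /= std::sqrt((float)d);         /* ScaledDotProduct */
      if (causal && j > i)
        dot += -1e10f;                      /* CausalMask above the diagonal */
      w[i][j] = dot;
    }
    /* row-wise softmax */
    float mx = *std::max_element(w[i].begin(), w[i].end());
    float sum = 0.0f;
    for (float &v : w[i]) { v = std::exp(v - mx); sum += v; }
    for (float &v : w[i]) v /= sum;
  }

  Mat out(q_len, std::vector<float>(V[0].size(), 0.0f));
  for (size_t i = 0; i < q_len; ++i)
    for (size_t j = 0; j < kv_len; ++j)
      for (size_t t = 0; t < V[0].size(); ++t)
        out[i][t] += w[i][j] * V[j][t];     /* dot 2: weights x V */
  return out;
}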
     117              : 
     118            0 : void AttentionLayer::incremental_forwarding(RunLayerContext &context,
     119              :                                             unsigned int from, unsigned int to,
     120              :                                             bool training) {
     121            0 :   Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
     122            0 :   Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
     123            0 :   Tensor &key = context.getInput(wt_idx[AttentionParams::key]);
     124              : 
     125            0 :   TensorDim query_dim = query.getDim();
     126            0 :   TensorDim value_dim = value.getDim();
     127            0 :   TensorDim key_dim = key.getDim();
     128            0 :   TensorDim query_step_dim = {query_dim.batch(), query_dim.channel(), to - from,
     129            0 :                               query_dim.width()};
     130            0 :   TensorDim value_step_dim = {value_dim.batch(), value_dim.channel(), to,
     131            0 :                               value_dim.width()};
     132            0 :   TensorDim key_step_dim = {key_dim.batch(), key_dim.channel(), to,
     133            0 :                             key_dim.width()};
     134              :   Tensor query_step =
     135            0 :     query.getSharedDataTensor(query_step_dim, from * query_dim.width(), true);
     136            0 :   Tensor value_step = value.getSharedDataTensor(value_step_dim, 0, true);
     137            0 :   Tensor key_step = key.getSharedDataTensor(key_step_dim, 0, true);
     138              : 
     139            0 :   Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
     140            0 :   TensorDim output_dim = output.getDim();
     141            0 :   TensorDim output_step_dim = {output_dim.batch(), output_dim.channel(),
     142            0 :                                to - from, output_dim.width()};
     143              :   Tensor output_step = output.getSharedDataTensor(
     144            0 :     output_step_dim, from * output_dim.width(), true);
     145              : 
     146            0 :   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
     147            0 :   TensorDim weights_dim = weights.getDim();
     148              :   TensorDim weights_step_dim = {
     149            0 :     query_step_dim.batch(), query_step_dim.channel(), query_step_dim.height(),
     150            0 :     value_step_dim.height()};
     151              :   Tensor weights_step = weights.getSharedDataTensor(
     152            0 :     weights_step_dim, from * weights_dim.width(), true);
     153              : 
     154            0 :   query_step.dotBatched(key_step, weights_step, false, true); /** dot 1 */
     155            0 :   if (std::get<props::ScaledDotProduct>(attention_props).get()) {
     156            0 :     weights_step.multiply_i(1 / sqrt((float)key.getDim().width()));
     157              :   }
     158              : 
     159            0 :   if (std::get<props::CausalMask>(attention_props).get() && !from) {
     160            0 :     unsigned int mask_size = weights_step.getDim().width();
     161              :     unsigned int mask_dim_height = mask_size;
     162              :     unsigned int mask_dim_width = mask_size;
     163              : 
     164            0 :     Tensor causal_mask(TensorDim{mask_size, mask_size});
     165              : 
     166            0 :     causal_mask.setZero();
     167            0 :     for (unsigned int i = 0; i < mask_dim_height; ++i) {
     168            0 :       for (unsigned int j = i + 1; j < mask_dim_width; ++j) {
     169            0 :         causal_mask.setValue(0, 0, i, j, -1e10);
     170              :       }
     171              :     }
     172              : 
     173            0 :     weights_step.add_i(causal_mask);
     174            0 :   }
     175              : 
     176              :   sm.run_fn(weights_step, weights_step);            /** softmax */
     177            0 :   weights_step.dotBatched(value_step, output_step); /** dot 2 */
     178            0 : }
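/*
 * Worked example of the step slicing in incremental_forwarding(): with
 * from = 4 and to = 6, the query window covers rows 4-5 (to - from = 2 rows)
 * while the key and value windows cover rows 0-5 (to rows), so the weights
 * window is 2 x 6 and each newly decoded position attends over the whole
 * cached prefix. The causal mask is only applied on the first step
 * (from == 0), where the weights window is square.
 */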
     179              : 
     180            3 : void AttentionLayer::calcDerivative(RunLayerContext &context) {
     181            3 :   const Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
     182              : 
     183            3 :   Tensor &query = context.getInput(wt_idx[AttentionParams::query]);
     184            3 :   Tensor &value = context.getInput(wt_idx[AttentionParams::value]);
     185            3 :   Tensor &key = context.getInput(wt_idx[AttentionParams::key]);
     186              : 
     187              :   Tensor &dquery =
     188            3 :     context.getOutgoingDerivative(wt_idx[AttentionParams::query]);
     189              :   Tensor &dvalue =
     190            3 :     context.getOutgoingDerivative(wt_idx[AttentionParams::value]);
     191            3 :   Tensor &dkey = context.getOutgoingDerivative(wt_idx[AttentionParams::key]);
     192              : 
     193            3 :   Tensor &weights = context.getTensor(wt_idx[AttentionParams::weights]);
     194              : 
     195              :   Tensor dweight = Tensor(
     196            3 :     TensorDim({derivative.batch(), 1, derivative.height(), value.height()},
     197            3 :               weights.getTensorType()));
     198              : 
     199              :   /** derivative for dot 2 */
     200            3 :   dweight.dot_batched_deriv_wrt_1(value, derivative);
     201            3 :   weights.dot_batched_deriv_wrt_2(dvalue, derivative);
     202              : 
     203              :   /** derivative for softmax */
     204            3 :   sm.run_prime_fn(weights, dweight, dweight);
     205              : 
     206            3 :   if (std::get<props::ScaledDotProduct>(attention_props).get()) {
     207            0 :     dweight.multiply_i(1 / sqrt((float)key.getDim().width()));
     208              :   }
     209              : 
     210              :   /** derivative for dot 1 */
     211            3 :   dquery.dot_batched_deriv_wrt_1(key, dweight, false, true);
     212            4 :   query.dot_batched_deriv_wrt_2(dkey, dweight, false, true,
     213            3 :                                 context.getNumInputs() == 2);
     214            3 : }
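/*
 * calcDerivative() above walks the chain rule backwards through
 * output = softmax(Q * K^T [* 1/sqrt(d)]) * V: first dweight = derivative * V^T
 * and dvalue = weights^T * derivative (dot 2), then the softmax backward,
 * then the optional 1/sqrt(d) scale, and finally dquery = dweight * K and
 * dkey = dweight^T * Q (dot 1). With two inputs, key aliases value, and the
 * last argument of dot_batched_deriv_wrt_2 appears to make dkey accumulate
 * into the shared dvalue. A sketch of the row-wise softmax backward (the
 * helper name is hypothetical, not nntrainer's ActiFunc API):
 */
#include <vector>

/* Given the forward softmax output s and the incoming gradient ds of one
 * row, dw_j = s_j * (ds_j - sum_k ds_k * s_k). */
std::vector<float> softmax_backward_row(const std::vector<float> &s,
                                        const std::vector<float> &ds) {
  float inner = 0.0f;
  for (size_t k = 0; k < s.size(); ++k)
    inner += ds[k] * s[k];
  std::vector<float> dw(s.size());
  for (size_t j = 0; j < s.size(); ++j)
    dw[j] = s[j] * (ds[j] - inner);
  return dw;
}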
     215              : 
     216           36 : void AttentionLayer::setProperty(const std::vector<std::string> &values) {
     217           36 :   auto remain_props = loadProperties(values, attention_props);
     218           35 :   if (!remain_props.empty()) {
     219              :     std::string msg = "[AttentionLayer] Unknown Layer Properties count " +
     220            2 :                       std::to_string(values.size());
     221            4 :     throw exception::not_supported(msg);
     222              :   }
     223           35 : }
     224              : 
     225            0 : void AttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
     226            0 :   context.updateTensor(wt_idx[AttentionParams::weights], batch);
     227            0 : }
     228              : 
     229              : } /* namespace nntrainer */
        

Generated by: LCOV version 2.0-1