LCOV - code coverage report
Current view: top level - nntrainer/layers - mol_attention_layer.cpp (source / functions)
Test:         coverage_filtered.info
Test Date:    2025-12-14 20:38:17

              Coverage    Total    Hit
Lines:        22.5 %        262     59
Functions:    38.5 %         13      5
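(Read as: "Hit" of "Total" instrumented items were executed — 59/262 ≈ 22.5 % of lines and 5/13 ≈ 38.5 % of functions.)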

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   mol_attention_layer.cpp
       6              :  * @date   11 November 2021
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
       9              :  * @bug    No known bugs except for NYI items
       10              :  * @brief  MoL (Mixture of Logistics) attention layer class for neural networks
      11              :  *
      12              :  */
      13              : 
      14              : #include <math.h>
      15              : 
      16              : #include <layer_context.h>
      17              : #include <mol_attention_layer.h>
      18              : #include <nntrainer_error.h>
      19              : #include <nntrainer_log.h>
      20              : #include <node_exporter.h>
      21              : 
      22              : namespace nntrainer {
      23              : 
      24           19 : MoLAttentionLayer::MoLAttentionLayer() :
      25           19 :   helper_exec(false),
      26           19 :   softmax(ActivationType::ACT_SOFTMAX, false),
      27           19 :   tanh(ActivationType::ACT_TANH, false),
      28           38 :   sigmoid(ActivationType::ACT_SIGMOID, false) {
      29              :   wt_idx.fill(std::numeric_limits<unsigned>::max());
      30           19 : }
      31              : 
      32           38 : MoLAttentionLayer::~MoLAttentionLayer() {}
      33              : 
      34              : static constexpr size_t SINGLE_INOUT_IDX = 0;
      35              : 
      36              : enum MoLAttentionParams {
      37              :   query = 0,
      38              :   value = 1,
      39              :   state = 2,
      40              :   mask_len = 3,
      41              :   fc_w,
      42              :   fc_bias,
      43              :   fc_proj_w,
      44              :   fc_out,
      45              :   fc_tanh,
      46              :   fc_proj_out,
      47              :   scores,
      48              :   prob,
      49              :   prob_left,
      50              :   prob_right,
      51              :   u_neg_div,
      52              :   u_pos_div,
      53              :   dstate,
      54              : };
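                       : /**
                       :  * Editorial note: the first four enumerators double as the input slot
                       :  * indices (query, value, state, and the optional mask_len); the rest
                       :  * index weights and temporary tensors requested in finalize() below.
                       :  */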
      55              : 
      56            4 : void MoLAttentionLayer::finalize(InitLayerContext &context) {
      57            4 :   NNTR_THROW_IF(context.getNumInputs() < 3 || context.getNumInputs() > 4,
      58              :                 std::invalid_argument)
      59              :     << "MoL Attention layer needs 3-4 inputs.";
      60              : 
      61              :   auto const &all_dims = context.getInputDimensions();
      62              :   auto const &query_dim = all_dims[MoLAttentionParams::query];
      63              :   auto const &value_dim = all_dims[MoLAttentionParams::value];
      64              :   auto const &state_dim = all_dims[MoLAttentionParams::state];
      65              : 
      66            4 :   wt_idx[MoLAttentionParams::query] = MoLAttentionParams::query;
      67            4 :   wt_idx[MoLAttentionParams::value] = MoLAttentionParams::value;
      68            4 :   wt_idx[MoLAttentionParams::state] = MoLAttentionParams::state;
      69            4 :   wt_idx[MoLAttentionParams::mask_len] = MoLAttentionParams::mask_len;
      70              : 
      71            4 :   NNTR_THROW_IF(query_dim.width() != value_dim.width(), std::invalid_argument)
      72              :     << "Query and Value dimension mismatch for layer " << context.getName();
      73              : 
      74            4 :   NNTR_THROW_IF(std::get<props::Unit>(mol_props).empty(), std::invalid_argument)
      75              :     << "Number of units not provided for layer " << context.getName();
      76            4 :   auto unit = std::get<props::Unit>(mol_props).get();
      77              : 
      78            4 :   NNTR_THROW_IF(std::get<props::MoL_K>(mol_props).empty(),
      79              :                 std::invalid_argument)
      80              :     << "MoL_K property not provided for layer " << context.getName();
      81            4 :   auto mol_k = std::get<props::MoL_K>(mol_props).get();
      82              : 
      83            4 :   NNTR_THROW_IF(mol_k != state_dim.width(), std::invalid_argument)
      84              :     << "MoL_K property does not match the provided state dimension for layer "
      85              :     << context.getName();
      86              : 
      87              :   auto &weight_regularizer =
      88              :     std::get<props::WeightRegularizer>(*layer_impl_props);
      89              :   auto &weight_regularizer_constant =
      90              :     std::get<props::WeightRegularizerConstant>(*layer_impl_props);
      91              :   auto &weight_initializer =
      92              :     std::get<props::WeightInitializer>(*layer_impl_props);
      93              :   auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
      94              :   auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
      95              :   auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);
      96              : 
      97            4 :   TensorDim fc_w_dim = {query_dim.width(), unit};
      98            4 :   wt_idx[MoLAttentionParams::fc_w] = context.requestWeight(
      99              :     fc_w_dim, weight_initializer, weight_regularizer,
     100              :     weight_regularizer_constant, weight_decay, "fc_w", true);
     101            4 :   TensorDim fc_bias_dim = {unit};
     102            4 :   wt_idx[MoLAttentionParams::fc_bias] = context.requestWeight(
     103              :     fc_bias_dim, bias_initializer, weight_regularizer,
     104              :     weight_regularizer_constant, bias_decay, "fc_bias", true);
     105              : 
     106            4 :   TensorDim fc_proj_w_dim = {unit, 3 * mol_k};
     107            8 :   wt_idx[MoLAttentionParams::fc_proj_w] = context.requestWeight(
     108              :     fc_proj_w_dim, weight_initializer, weight_regularizer,
     109              :     weight_regularizer_constant, weight_decay, "fc_proj_w", true);
     110              : 
     111            4 :   TensorDim fc_out_dim = query_dim;
     112            4 :   fc_out_dim.width(fc_w_dim.width());
     113            4 :   wt_idx[MoLAttentionParams::fc_out] =
     114            4 :     context.requestTensor(fc_out_dim, "fc_out", Initializer::NONE, false,
     115              :                           TensorLifespan::FORWARD_FUNC_LIFESPAN);
     116              : 
     117            4 :   wt_idx[MoLAttentionParams::fc_tanh] =
     118            4 :     context.requestTensor(fc_out_dim, "fc_tanh", Initializer::NONE, false,
     119              :                           TensorLifespan::ITERATION_LIFESPAN);
     120              : 
     121            4 :   TensorDim fc_proj_out_dim = fc_out_dim;
     122            4 :   fc_proj_out_dim.width(fc_proj_w_dim.width());
     123            4 :   wt_idx[MoLAttentionParams::fc_proj_out] =
     124            8 :     context.requestTensor(fc_proj_out_dim, "fc_proj_out", Initializer::NONE,
     125              :                           false, TensorLifespan::ITERATION_LIFESPAN);
     126              : 
     127              :   TensorDim scores_dim =
     128            4 :     TensorDim({value_dim.batch(), 1, 1, value_dim.height()});
     129            4 :   wt_idx[MoLAttentionParams::scores] =
     130            4 :     context.requestTensor(scores_dim, "scores", Initializer::NONE, false,
     131              :                           TensorLifespan::ITERATION_LIFESPAN);
     132              : 
     133            4 :   TensorDim prob_dim = value_dim;
     134            4 :   prob_dim.width(mol_k);
     135            4 :   wt_idx[MoLAttentionParams::prob] =
     136            4 :     context.requestTensor(prob_dim, "prob", Initializer::NONE, false,
     137              :                           TensorLifespan::ITERATION_LIFESPAN);
     138            4 :   wt_idx[MoLAttentionParams::prob_left] =
     139            4 :     context.requestTensor(prob_dim, "prob_left", Initializer::NONE, false,
     140              :                           TensorLifespan::ITERATION_LIFESPAN);
     141            4 :   wt_idx[MoLAttentionParams::prob_right] =
     142            4 :     context.requestTensor(prob_dim, "prob_right", Initializer::NONE, false,
     143              :                           TensorLifespan::ITERATION_LIFESPAN);
     144            4 :   wt_idx[MoLAttentionParams::u_neg_div] =
     145            4 :     context.requestTensor(prob_dim, "u_neg_div", Initializer::NONE, false,
     146              :                           TensorLifespan::ITERATION_LIFESPAN);
     147            4 :   wt_idx[MoLAttentionParams::u_pos_div] =
     148            4 :     context.requestTensor(prob_dim, "u_pos_div", Initializer::NONE, false,
     149              :                           TensorLifespan::ITERATION_LIFESPAN);
     150            4 :   wt_idx[MoLAttentionParams::dstate] =
     151            4 :     context.requestTensor(state_dim, "dstate", Initializer::NONE, false,
     152              :                           TensorLifespan::BACKWARD_FUNC_LIFESPAN);
     153              : 
     154            4 :   if (context.getNumRequestedOutputs() == 2)
     155            0 :     context.setOutputDimensions({query_dim, state_dim});
     156              :   else
     157            4 :     context.setOutputDimensions({query_dim});
     158            4 : }
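                       : /*
                       :  * A minimal usage sketch (editorial; not part of this file), assuming
                       :  * the registered type key "mol_attention" and the property keys "unit"
                       :  * and "mol_k" inferred from props::Unit / props::MoL_K above:
                       :  *
                       :  *   auto attn = ml::train::createLayer(
                       :  *     "mol_attention", {"name=mol_attn", "unit=128", "mol_k=5"});
                       :  *   // connect inputs in the order: query, value, state[, mask_len];
                       :  *   // with two requested outputs the layer also emits the updated state.
                       :  */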
     159              : 
     160            0 : void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
     161            0 :   Tensor &query = context.getInput(wt_idx[MoLAttentionParams::query]);
     162            0 :   Tensor &value = context.getInput(wt_idx[MoLAttentionParams::value]);
     163            0 :   Tensor &state = context.getInput(wt_idx[MoLAttentionParams::state]);
     164              : 
     165            0 :   Tensor &output = context.getOutput(0);
     166            0 :   Tensor &fc_w = context.getWeight(wt_idx[MoLAttentionParams::fc_w]);
     167            0 :   Tensor &fc_bias = context.getWeight(wt_idx[MoLAttentionParams::fc_bias]);
     168            0 :   Tensor &fc_proj_w = context.getWeight(wt_idx[MoLAttentionParams::fc_proj_w]);
     169            0 :   Tensor &fc_out = context.getTensor(wt_idx[MoLAttentionParams::fc_out]);
     170            0 :   Tensor &fc_tanh = context.getTensor(wt_idx[MoLAttentionParams::fc_tanh]);
     171              :   Tensor &fc_proj_out =
     172            0 :     context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
     173            0 :   Tensor &scores = context.getTensor(wt_idx[MoLAttentionParams::scores]);
     174            0 :   Tensor &prob = context.getTensor(wt_idx[MoLAttentionParams::prob]);
     175            0 :   Tensor &prob_left = context.getTensor(wt_idx[MoLAttentionParams::prob_left]);
     176              :   Tensor &prob_right =
     177            0 :     context.getTensor(wt_idx[MoLAttentionParams::prob_right]);
     178            0 :   Tensor &u_neg_div = context.getTensor(wt_idx[MoLAttentionParams::u_neg_div]);
     179            0 :   Tensor &u_pos_div = context.getTensor(wt_idx[MoLAttentionParams::u_pos_div]);
     180              : 
     181            0 :   const TensorDim &input_dim = query.getDim();
     182            0 :   unsigned int batch = input_dim.batch();
     183            0 :   auto mol_k = std::get<props::MoL_K>(mol_props).get();
     184              : 
     185              :   /** reset helper state */
     186            0 :   helper_exec = false;
     187              : 
     188            0 :   query.dot(fc_w, fc_out);
     189            0 :   fc_out.add_i(fc_bias);
     190              : 
     191              :   tanh.run_fn(fc_out, fc_tanh);
     192              : 
     193            0 :   fc_tanh.dot(fc_proj_w, fc_proj_out);
     194              : 
     195            0 :   Tensor kappa_src, beta_src, alpha_src;
     196            0 :   kappa_src.copy_with_stride(
     197            0 :     fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false));
     198            0 :   beta_src.copy_with_stride(
     199            0 :     fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false));
     200            0 :   alpha_src.copy_with_stride(
     201            0 :     fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false));
     202              : 
     203            0 :   kappa_src.apply_i<float>(&expf);
     204            0 :   beta_src.apply_i<float>(&expf);
     205            0 :   Tensor kappa = kappa_src;
     206            0 :   Tensor beta = beta_src;
     207              : 
     208            0 :   Tensor alpha;
     209              :   softmax.run_fn(alpha_src, alpha);
     210              : 
     211            0 :   fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false)
     212            0 :     .copy_with_stride(kappa);
     213            0 :   fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false)
     214            0 :     .copy_with_stride(beta);
     215            0 :   fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
     216            0 :     .copy_with_stride(alpha);
     217              : 
     218              :   /** @todo cache u_base, u_pos, u_neg */
     219            0 :   Tensor u_base = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
     220            0 :   for (unsigned int b = 0; b < batch; b++) {
     221            0 :     for (unsigned int h = 0; h < u_base.height(); h++) {
     222            0 :       float *u_data = u_base.getAddress<float>(b, 0, h, 0);
     223            0 :       std::fill(u_data, u_data + u_base.width(), static_cast<float>(h + 1));
     224              :     }
     225              :   }
     226              : 
     227            0 :   Tensor u_pos = u_base.add(0.5f);
     228            0 :   u_base.add_i(-0.5f);
     229            0 :   Tensor u_neg = u_base;
     230              : 
     231            0 :   Tensor beta_eps = beta.add(1e-8f);
     232              : 
     233            0 :   Tensor u_pos_m, u_neg_m;
     234            0 :   if (context.getNumOutputs() == 2) {
     235            0 :     Tensor &updated_state = context.getOutput(1);
     236            0 :     state.add(kappa, updated_state);
     237            0 :     u_pos_m = u_pos.subtract(updated_state);
     238            0 :     u_neg_m = u_neg.subtract(updated_state);
     239              :   } else {
     240            0 :     Tensor updated_state = state.add(kappa);
     241            0 :     u_pos_m = u_pos.subtract(updated_state);
     242            0 :     u_neg_m = u_neg.subtract(updated_state);
     243            0 :   }
     244              : 
     245            0 :   u_pos_m.divide(beta_eps, u_pos_div);
     246              :   sigmoid.run_fn(u_pos_div, prob_left);
     247              : 
     248            0 :   u_neg_m.divide(beta_eps, u_neg_div);
     249              :   sigmoid.run_fn(u_neg_div, prob_right);
     250              : 
     251            0 :   prob_left.subtract(prob_right, prob);
     252              : 
     253            0 :   Tensor prob_scaled = prob.multiply(alpha);
     254            0 :   prob_scaled.sum(3, scores);
     255              : 
     256            0 :   if (context.getNumInputs() == 4) {
     257            0 :     Tensor mask = Tensor(scores.getDim());
     258            0 :     mask.filter_mask(context.getInput(wt_idx[MoLAttentionParams::mask_len]),
     259              :                      false);
     260            0 :     scores.multiply_i(mask);
     261            0 :   }
     262              : 
     263            0 :   scores.dotBatched(value, output);
     264            0 : }
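                       : /*
                       :  * Editorial summary of the forward pass above (k indexes the mixture
                       :  * components, j the value timesteps, u_j = j + 1):
                       :  *
                       :  *   [kappa_src; beta_src; alpha_src] = tanh(query . fc_w + fc_bias) . fc_proj_w
                       :  *   kappa = exp(kappa_src),  beta = exp(beta_src),  alpha = softmax(alpha_src)
                       :  *   s = state + kappa                                  (updated state)
                       :  *   prob[j,k]  = sigmoid((u_j + 0.5 - s_k) / (beta_k + 1e-8))
                       :  *              - sigmoid((u_j - 0.5 - s_k) / (beta_k + 1e-8))
                       :  *   scores[j]  = sum_k alpha_k * prob[j,k]             (masked when 4 inputs)
                       :  *   output     = scores . value                        (batched dot product)
                       :  */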
     265              : 
     266            0 : void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
     267              :                                              Tensor &dstate) {
      268              :   /** @todo optimize temporary tensor usage here */
     269            0 :   Tensor &query = context.getInput(wt_idx[MoLAttentionParams::query]);
     270            0 :   Tensor &value = context.getInput(wt_idx[MoLAttentionParams::value]);
     271              : 
     272            0 :   const Tensor &derivative = context.getIncomingDerivative(0);
     273              : 
     274              :   Tensor &fc_proj_out =
     275            0 :     context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
     276              :   Tensor &dfc_proj_out =
     277            0 :     context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
     278            0 :   Tensor &scores = context.getTensor(wt_idx[MoLAttentionParams::scores]);
     279            0 :   Tensor &prob = context.getTensor(wt_idx[MoLAttentionParams::prob]);
     280            0 :   Tensor &prob_left = context.getTensor(wt_idx[MoLAttentionParams::prob_left]);
     281              :   Tensor &prob_right =
     282            0 :     context.getTensor(wt_idx[MoLAttentionParams::prob_right]);
     283            0 :   Tensor &u_neg_div = context.getTensor(wt_idx[MoLAttentionParams::u_neg_div]);
     284            0 :   Tensor &u_pos_div = context.getTensor(wt_idx[MoLAttentionParams::u_pos_div]);
     285              : 
     286            0 :   const TensorDim &input_dim = query.getDim();
     287            0 :   unsigned int batch = input_dim.batch();
     288            0 :   auto mol_k = std::get<props::MoL_K>(mol_props).get();
     289              : 
     290            0 :   Tensor kappa, beta, alpha;
     291            0 :   kappa.copy_with_stride(
     292            0 :     fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false));
     293            0 :   beta.copy_with_stride(
     294            0 :     fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false));
     295            0 :   alpha.copy_with_stride(
     296            0 :     fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false));
     297              : 
     298            0 :   Tensor dscores = Tensor(TensorDim({value.batch(), 1, 1, value.height()}));
     299            0 :   dscores.dot_batched_deriv_wrt_1(value, derivative);
     300            0 :   dscores.reshape(TensorDim({scores.batch(), 1, scores.width(), 1}));
     301            0 :   if (context.getNumInputs() == 4) {
     302            0 :     Tensor mask = Tensor(dscores.getDim());
     303            0 :     mask.filter_mask(context.getInput(wt_idx[MoLAttentionParams::mask_len]));
     304            0 :     dscores.multiply_i(mask);
     305            0 :   }
     306              : 
     307            0 :   Tensor dprob_scaled = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
     308            0 :   dprob_scaled.setZero();
     309            0 :   dprob_scaled.add_i(dscores);
     310              : 
     311            0 :   Tensor dalpha = dprob_scaled.multiply(prob).sum(2);
     312            0 :   Tensor dprob = dprob_scaled.multiply(alpha);
     313              : 
     314            0 :   Tensor dprob_left = dprob;
     315            0 :   Tensor dprob_right = dprob.multiply(-1);
     316              : 
     317            0 :   Tensor beta_eps = beta.add(1e-8f);
     318            0 :   Tensor du_neg_div, du_pos_div;
     319            0 :   sigmoid.run_prime_fn(prob_right, du_neg_div, dprob_right);
     320            0 :   Tensor du_neg_m = du_neg_div.divide(beta_eps);
     321            0 :   Tensor dm_neg = du_neg_m.multiply(-1).sum(2);
     322            0 :   Tensor dbeta_eps_neg = du_neg_m.multiply(u_neg_div).multiply(-1).sum(2);
     323              : 
     324            0 :   sigmoid.run_prime_fn(prob_left, du_pos_div, dprob_left);
     325            0 :   Tensor du_pos_m = du_pos_div.divide(beta_eps);
     326            0 :   Tensor dm_pos = du_pos_m.multiply(-1).sum(2);
     327            0 :   Tensor dbeta_eps_pos = du_pos_m.multiply(u_pos_div).multiply(-1).sum(2);
     328              : 
     329            0 :   Tensor dbeta_eps = dbeta_eps_neg.add(dbeta_eps_pos);
     330            0 :   dm_neg.add(dm_pos, dstate);
     331            0 :   if (context.getNumOutputs() == 2) {
     332            0 :     const Tensor &derivative_state = context.getIncomingDerivative(1);
     333            0 :     dstate.add_i(derivative_state);
     334            0 :   }
     335            0 :   Tensor dkappa = dstate;
     336            0 :   Tensor dbeta = dbeta_eps;
     337              : 
     338            0 :   Tensor dalpha_src;
     339            0 :   softmax.run_prime_fn(alpha, dalpha_src, dalpha);
     340              : 
     341            0 :   Tensor dkappa_src = dkappa.multiply(kappa);
     342            0 :   Tensor dbeta_src = dbeta.multiply(beta);
     343              : 
     344              :   /** dfc_proj_out shares memory with fc_proj_out */
     345            0 :   dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false)
     346            0 :     .copy_with_stride(dkappa_src);
     347            0 :   dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false)
     348            0 :     .copy_with_stride(dbeta_src);
     349            0 :   dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
     350            0 :     .copy_with_stride(dalpha_src);
     351              : 
     352              :   /** update the helper state */
     353            0 :   helper_exec = true;
     354            0 : }
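                       : /*
                       :  * Editorial note: calcDerivativeHelper() is shared by calcDerivative()
                       :  * and calcGradient(). forwarding() clears helper_exec, the first backward
                       :  * routine to run executes the helper (rewriting fc_proj_out in place as
                       :  * dfc_proj_out), and the second one reuses the cached result instead of
                       :  * recomputing it.
                       :  */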
     355              : 
     356            0 : void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
     357              :   Tensor &dquery =
     358            0 :     context.getOutgoingDerivative(wt_idx[MoLAttentionParams::query]);
     359              :   Tensor &dvalue =
     360            0 :     context.getOutgoingDerivative(wt_idx[MoLAttentionParams::value]);
     361              :   Tensor &dstate =
     362            0 :     context.getOutgoingDerivative(wt_idx[MoLAttentionParams::state]);
     363            0 :   Tensor &dstate_local = context.getTensor(wt_idx[MoLAttentionParams::dstate]);
     364              : 
     365            0 :   const Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);
     366              : 
     367            0 :   Tensor &fc_w = context.getWeight(wt_idx[MoLAttentionParams::fc_w]);
     368            0 :   Tensor &fc_proj_w = context.getWeight(wt_idx[MoLAttentionParams::fc_proj_w]);
     369            0 :   Tensor &fc_tanh = context.getTensor(wt_idx[MoLAttentionParams::fc_tanh]);
     370              :   Tensor &dfc_proj_out =
     371            0 :     context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
     372            0 :   Tensor &scores = context.getTensor(wt_idx[MoLAttentionParams::scores]);
     373              : 
     374            0 :   scores.dot_batched_deriv_wrt_2(dvalue, derivative);
     375              : 
     376            0 :   if (!helper_exec)
     377            0 :     calcDerivativeHelper(context, dstate);
     378              :   else
     379            0 :     dstate.copyData(dstate_local);
     380              : 
     381            0 :   Tensor dfc_tanh = Tensor(fc_tanh.getDim());
     382            0 :   dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out, false, false);
     383              : 
     384            0 :   Tensor dfc_out;
     385            0 :   tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
     386            0 :   dquery.dot_deriv_wrt_1(fc_w, dfc_out, false, false);
     387            0 : }
     388              : 
     389            0 : void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
     390            0 :   Tensor &query = context.getInput(wt_idx[MoLAttentionParams::query]);
     391            0 :   Tensor &dstate = context.getTensor(wt_idx[MoLAttentionParams::dstate]);
     392              : 
     393            0 :   Tensor &fc_proj_w = context.getWeight(wt_idx[MoLAttentionParams::fc_proj_w]);
     394            0 :   Tensor &dfc_w = context.getWeightGrad(wt_idx[MoLAttentionParams::fc_w]);
     395            0 :   Tensor &dfc_bias = context.getWeightGrad(wt_idx[MoLAttentionParams::fc_bias]);
     396              :   Tensor &dfc_proj_w =
     397            0 :     context.getWeightGrad(wt_idx[MoLAttentionParams::fc_proj_w]);
     398            0 :   Tensor &fc_tanh = context.getTensor(wt_idx[MoLAttentionParams::fc_tanh]);
     399              :   Tensor &dfc_proj_out =
     400            0 :     context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
     401              : 
     402            0 :   if (!helper_exec)
     403            0 :     calcDerivativeHelper(context, dstate);
     404              : 
     405            0 :   Tensor dfc_tanh = Tensor(fc_tanh.getDim());
     406            0 :   fc_tanh.dot_deriv_wrt_2(
     407              :     dfc_proj_w, dfc_proj_out, false, false,
     408            0 :     !context.isGradientFirstAccess(wt_idx[MoLAttentionParams::fc_proj_w]));
     409            0 :   dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);
     410              : 
     411            0 :   Tensor dfc_out;
     412            0 :   tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
     413            0 :   query.dot_deriv_wrt_2(
     414              :     dfc_w, dfc_out, false, false,
     415            0 :     !context.isGradientFirstAccess(wt_idx[MoLAttentionParams::fc_w]));
     416              : 
     417            0 :   if (context.isGradientFirstAccess(wt_idx[MoLAttentionParams::fc_bias])) {
     418            0 :     dfc_out.sum({0, 1, 2}, dfc_bias);
     419              :   } else {
     420              :     /// @todo optimize below by adding beta to Tensor::sum
     421            0 :     Tensor t = dfc_out.sum({0, 1, 2});
     422            0 :     dfc_bias.add_i(t);
     423            0 :   }
     424            0 : }
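                       : /*
                       :  * Editorial note: the isGradientFirstAccess() checks let the weight
                       :  * gradients accumulate correctly when they are written more than once:
                       :  * the first access overwrites the gradient buffer, later accesses add
                       :  * into it (hence the explicit sum-then-add_i fallback for fc_bias above).
                       :  */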
     425              : 
     426           35 : void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
     427           35 :   auto remain_props = loadProperties(values, mol_props);
     428           34 :   LayerImpl::setProperty(remain_props);
     429           34 : }
     430              : 
     431            0 : void MoLAttentionLayer::setBatch(RunLayerContext &context, unsigned int batch) {
     432            0 :   context.updateTensor(wt_idx[MoLAttentionParams::fc_out], batch);
     433            0 :   context.updateTensor(wt_idx[MoLAttentionParams::fc_tanh], batch);
     434            0 :   context.updateTensor(wt_idx[MoLAttentionParams::fc_proj_out], batch);
     435            0 :   context.updateTensor(wt_idx[MoLAttentionParams::scores], batch);
     436            0 :   context.updateTensor(wt_idx[MoLAttentionParams::prob], batch);
     437            0 :   context.updateTensor(wt_idx[MoLAttentionParams::prob_left], batch);
     438            0 :   context.updateTensor(wt_idx[MoLAttentionParams::prob_right], batch);
     439            0 :   context.updateTensor(wt_idx[MoLAttentionParams::u_neg_div], batch);
     440            0 :   context.updateTensor(wt_idx[MoLAttentionParams::u_pos_div], batch);
     441            0 :   context.updateTensor(wt_idx[MoLAttentionParams::dstate], batch);
     442            0 : }
     443              : 
     444            0 : void MoLAttentionLayer::exportTo(Exporter &exporter,
     445              :                                  const ml::train::ExportMethods &method) const {
     446            0 :   LayerImpl::exportTo(exporter, method);
     447            0 :   exporter.saveResult(mol_props, method, this);
     448            0 : }
     449              : 
     450              : } /* namespace nntrainer */
        

Generated by: LCOV version 2.0-1