// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
 *
 * @file mol_attention_layer.cpp
 * @date 11 November 2021
 * @see https://github.com/nnstreamer/nntrainer
 * @author Parichay Kapoor <pk.kapoor@samsung.com>
 * @bug No known bugs except for NYI items
 * @brief This is MoL Attention Layer Class for Neural Network
 *
 */

#include <math.h>

#include <layer_context.h>
#include <mol_attention_layer.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>

namespace nntrainer {

MoLAttentionLayer::MoLAttentionLayer() :
  helper_exec(false),
  softmax(ActivationType::ACT_SOFTMAX, false),
  tanh(ActivationType::ACT_TANH, false),
  sigmoid(ActivationType::ACT_SIGMOID, false) {
  wt_idx.fill(std::numeric_limits<unsigned>::max());
}

MoLAttentionLayer::~MoLAttentionLayer() {}

static constexpr size_t SINGLE_INOUT_IDX = 0;

enum MoLAttentionParams {
  query = 0,
  value = 1,
  state = 2,
  mask_len = 3,
  fc_w,
  fc_bias,
  fc_proj_w,
  fc_out,
  fc_tanh,
  fc_proj_out,
  scores,
  prob,
  prob_left,
  prob_right,
  u_neg_div,
  u_pos_div,
  dstate,
};

void MoLAttentionLayer::finalize(InitLayerContext &context) {
  NNTR_THROW_IF(context.getNumInputs() < 3 || context.getNumInputs() > 4,
                std::invalid_argument)
    << "MoL Attention layer needs 3-4 inputs.";

  auto const &all_dims = context.getInputDimensions();
  auto const &query_dim = all_dims[MoLAttentionParams::query];
  auto const &value_dim = all_dims[MoLAttentionParams::value];
  auto const &state_dim = all_dims[MoLAttentionParams::state];

  wt_idx[MoLAttentionParams::query] = MoLAttentionParams::query;
  wt_idx[MoLAttentionParams::value] = MoLAttentionParams::value;
  wt_idx[MoLAttentionParams::state] = MoLAttentionParams::state;
  wt_idx[MoLAttentionParams::mask_len] = MoLAttentionParams::mask_len;

  NNTR_THROW_IF(query_dim.width() != value_dim.width(), std::invalid_argument)
    << "Query and Value dimension mismatch for layer " << context.getName();

  NNTR_THROW_IF(std::get<props::Unit>(mol_props).empty(), std::invalid_argument)
    << "Number of units not provided for layer " << context.getName();
  auto unit = std::get<props::Unit>(mol_props).get();

  NNTR_THROW_IF(std::get<props::MoL_K>(mol_props).empty(),
                std::invalid_argument)
    << "MoL_K property not provided for layer " << context.getName();
  auto mol_k = std::get<props::MoL_K>(mol_props).get();

  NNTR_THROW_IF(mol_k != state_dim.width(), std::invalid_argument)
    << "MoL_K property does not match the provided state dimension for layer "
    << context.getName();

  auto &weight_regularizer =
    std::get<props::WeightRegularizer>(*layer_impl_props);
  auto &weight_regularizer_constant =
    std::get<props::WeightRegularizerConstant>(*layer_impl_props);
  auto &weight_initializer =
    std::get<props::WeightInitializer>(*layer_impl_props);
  auto &bias_initializer = std::get<props::BiasInitializer>(*layer_impl_props);
  auto &weight_decay = std::get<props::WeightDecay>(*layer_impl_props);
  auto &bias_decay = std::get<props::BiasDecay>(*layer_impl_props);

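  /** trainable parameters of the scoring network: a fully connected layer
   * (with bias) followed by tanh, and a projection to the mixture
   * parameters */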
  TensorDim fc_w_dim = {query_dim.width(), unit};
  wt_idx[MoLAttentionParams::fc_w] = context.requestWeight(
    fc_w_dim, weight_initializer, weight_regularizer,
    weight_regularizer_constant, weight_decay, "fc_w", true);
  TensorDim fc_bias_dim = {unit};
  wt_idx[MoLAttentionParams::fc_bias] = context.requestWeight(
    fc_bias_dim, bias_initializer, weight_regularizer,
    weight_regularizer_constant, bias_decay, "fc_bias", true);

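  /** the projection emits 3 * mol_k values per query step: mol_k each for
   * the kappa (mean offset), beta (scale) and alpha (mixture weight)
   * sources */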
  TensorDim fc_proj_w_dim = {unit, 3 * mol_k};
  wt_idx[MoLAttentionParams::fc_proj_w] = context.requestWeight(
    fc_proj_w_dim, weight_initializer, weight_regularizer,
    weight_regularizer_constant, weight_decay, "fc_proj_w", true);

  TensorDim fc_out_dim = query_dim;
  fc_out_dim.width(fc_w_dim.width());
  wt_idx[MoLAttentionParams::fc_out] =
    context.requestTensor(fc_out_dim, "fc_out", Initializer::NONE, false,
                          TensorLifespan::FORWARD_FUNC_LIFESPAN);

  wt_idx[MoLAttentionParams::fc_tanh] =
    context.requestTensor(fc_out_dim, "fc_tanh", Initializer::NONE, false,
                          TensorLifespan::ITERATION_LIFESPAN);

  TensorDim fc_proj_out_dim = fc_out_dim;
  fc_proj_out_dim.width(fc_proj_w_dim.width());
  wt_idx[MoLAttentionParams::fc_proj_out] =
    context.requestTensor(fc_proj_out_dim, "fc_proj_out", Initializer::NONE,
                          false, TensorLifespan::ITERATION_LIFESPAN);

  TensorDim scores_dim =
    TensorDim({value_dim.batch(), 1, 1, value_dim.height()});
  wt_idx[MoLAttentionParams::scores] =
    context.requestTensor(scores_dim, "scores", Initializer::NONE, false,
                          TensorLifespan::ITERATION_LIFESPAN);

  TensorDim prob_dim = value_dim;
  prob_dim.width(mol_k);
  wt_idx[MoLAttentionParams::prob] =
    context.requestTensor(prob_dim, "prob", Initializer::NONE, false,
                          TensorLifespan::ITERATION_LIFESPAN);
  wt_idx[MoLAttentionParams::prob_left] =
    context.requestTensor(prob_dim, "prob_left", Initializer::NONE, false,
                          TensorLifespan::ITERATION_LIFESPAN);
  wt_idx[MoLAttentionParams::prob_right] =
    context.requestTensor(prob_dim, "prob_right", Initializer::NONE, false,
                          TensorLifespan::ITERATION_LIFESPAN);
  wt_idx[MoLAttentionParams::u_neg_div] =
    context.requestTensor(prob_dim, "u_neg_div", Initializer::NONE, false,
                          TensorLifespan::ITERATION_LIFESPAN);
  wt_idx[MoLAttentionParams::u_pos_div] =
    context.requestTensor(prob_dim, "u_pos_div", Initializer::NONE, false,
                          TensorLifespan::ITERATION_LIFESPAN);
  wt_idx[MoLAttentionParams::dstate] =
    context.requestTensor(state_dim, "dstate", Initializer::NONE, false,
                          TensorLifespan::BACKWARD_FUNC_LIFESPAN);

  if (context.getNumRequestedOutputs() == 2)
    context.setOutputDimensions({query_dim, state_dim});
  else
    context.setOutputDimensions({query_dim});
}

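/**
 * Forward pass of MoL (mixture of logistics) attention. With per-component
 * means mu_k = state_k + kappa_k and scales beta_k, the attention score of
 * the value timestep at position u is the probability mass the mixture
 * assigns to the unit-width window around u:
 *
 *   scores(u) = sum_k alpha_k * (sigmoid((u + 0.5 - mu_k) / beta_k)
 *                                - sigmoid((u - 0.5 - mu_k) / beta_k))
 *
 * and the layer output is the scores-weighted sum over value.
 */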
void MoLAttentionLayer::forwarding(RunLayerContext &context, bool training) {
  Tensor &query = context.getInput(wt_idx[MoLAttentionParams::query]);
  Tensor &value = context.getInput(wt_idx[MoLAttentionParams::value]);
  Tensor &state = context.getInput(wt_idx[MoLAttentionParams::state]);

  Tensor &output = context.getOutput(0);
  Tensor &fc_w = context.getWeight(wt_idx[MoLAttentionParams::fc_w]);
  Tensor &fc_bias = context.getWeight(wt_idx[MoLAttentionParams::fc_bias]);
  Tensor &fc_proj_w = context.getWeight(wt_idx[MoLAttentionParams::fc_proj_w]);
  Tensor &fc_out = context.getTensor(wt_idx[MoLAttentionParams::fc_out]);
  Tensor &fc_tanh = context.getTensor(wt_idx[MoLAttentionParams::fc_tanh]);
  Tensor &fc_proj_out =
    context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
  Tensor &scores = context.getTensor(wt_idx[MoLAttentionParams::scores]);
  Tensor &prob = context.getTensor(wt_idx[MoLAttentionParams::prob]);
  Tensor &prob_left = context.getTensor(wt_idx[MoLAttentionParams::prob_left]);
  Tensor &prob_right =
    context.getTensor(wt_idx[MoLAttentionParams::prob_right]);
  Tensor &u_neg_div = context.getTensor(wt_idx[MoLAttentionParams::u_neg_div]);
  Tensor &u_pos_div = context.getTensor(wt_idx[MoLAttentionParams::u_pos_div]);

  const TensorDim &input_dim = query.getDim();
  unsigned int batch = input_dim.batch();
  auto mol_k = std::get<props::MoL_K>(mol_props).get();

  /** reset helper state */
  helper_exec = false;

  query.dot(fc_w, fc_out);
  fc_out.add_i(fc_bias);

  tanh.run_fn(fc_out, fc_tanh);

  fc_tanh.dot(fc_proj_w, fc_proj_out);

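  /** split the projection into its three mol_k-wide slices; exp keeps kappa
   * and beta positive, softmax normalizes alpha into mixture weights */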
  Tensor kappa_src, beta_src, alpha_src;
  kappa_src.copy_with_stride(
    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false));
  beta_src.copy_with_stride(
    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false));
  alpha_src.copy_with_stride(
    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false));

  kappa_src.apply_i<float>(&expf);
  beta_src.apply_i<float>(&expf);
  Tensor kappa = kappa_src;
  Tensor beta = beta_src;

  Tensor alpha;
  softmax.run_fn(alpha_src, alpha);

  fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false)
    .copy_with_stride(kappa);
  fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false)
    .copy_with_stride(beta);
  fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
    .copy_with_stride(alpha);

  /** @todo cache u_base, u_pos, u_neg */
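  /** u_base holds the 1-based position index of every value timestep,
   * broadcast across the mol_k mixture components */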
  Tensor u_base = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
  for (unsigned int b = 0; b < batch; b++) {
    for (unsigned int h = 0; h < u_base.height(); h++) {
      float *u_data = u_base.getAddress<float>(b, 0, h, 0);
      std::fill(u_data, u_data + u_base.width(), static_cast<float>(h + 1));
    }
  }

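  /** upper and lower edges of the unit-width window around each position */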
  Tensor u_pos = u_base.add(0.5f);
  u_base.add_i(-0.5f);
  Tensor u_neg = u_base;

  Tensor beta_eps = beta.add(1e-8f);

  Tensor u_pos_m, u_neg_m;
  if (context.getNumOutputs() == 2) {
    Tensor &updated_state = context.getOutput(1);
    state.add(kappa, updated_state);
    u_pos_m = u_pos.subtract(updated_state);
    u_neg_m = u_neg.subtract(updated_state);
  } else {
    Tensor updated_state = state.add(kappa);
    u_pos_m = u_pos.subtract(updated_state);
    u_neg_m = u_neg.subtract(updated_state);
  }

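  /** probability mass each logistic component assigns to the window:
   * CDF(u_pos) - CDF(u_neg), where the logistic CDF is a sigmoid */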
  u_pos_m.divide(beta_eps, u_pos_div);
  sigmoid.run_fn(u_pos_div, prob_left);

  u_neg_m.divide(beta_eps, u_neg_div);
  sigmoid.run_fn(u_neg_div, prob_right);

  prob_left.subtract(prob_right, prob);

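  /** weight the per-component masses by alpha and sum over the mixture
   * dimension to get one attention score per value timestep */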
  Tensor prob_scaled = prob.multiply(alpha);
  prob_scaled.sum(3, scores);

  if (context.getNumInputs() == 4) {
    Tensor mask = Tensor(scores.getDim());
    mask.filter_mask(context.getInput(wt_idx[MoLAttentionParams::mask_len]),
                     false);
    scores.multiply_i(mask);
  }

  scores.dotBatched(value, output);
}

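/**
 * Compute the gradients shared by calcDerivative() and calcGradient():
 * dfc_proj_out (written in place of fc_proj_out) and dstate. Whichever of
 * the two runs first executes this helper; helper_exec prevents it from
 * running twice in the same backward pass.
 */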
void MoLAttentionLayer::calcDerivativeHelper(RunLayerContext &context,
                                             Tensor &dstate) {
  /** @todo optimize temporary tensor usage here */
  Tensor &query = context.getInput(wt_idx[MoLAttentionParams::query]);
  Tensor &value = context.getInput(wt_idx[MoLAttentionParams::value]);

  const Tensor &derivative = context.getIncomingDerivative(0);

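  /** dfc_proj_out aliases fc_proj_out; the forward result is fully consumed
   * below before the gradient overwrites it */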
  Tensor &fc_proj_out =
    context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
  Tensor &dfc_proj_out =
    context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
  Tensor &scores = context.getTensor(wt_idx[MoLAttentionParams::scores]);
  Tensor &prob = context.getTensor(wt_idx[MoLAttentionParams::prob]);
  Tensor &prob_left = context.getTensor(wt_idx[MoLAttentionParams::prob_left]);
  Tensor &prob_right =
    context.getTensor(wt_idx[MoLAttentionParams::prob_right]);
  Tensor &u_neg_div = context.getTensor(wt_idx[MoLAttentionParams::u_neg_div]);
  Tensor &u_pos_div = context.getTensor(wt_idx[MoLAttentionParams::u_pos_div]);

  const TensorDim &input_dim = query.getDim();
  unsigned int batch = input_dim.batch();
  auto mol_k = std::get<props::MoL_K>(mol_props).get();

  Tensor kappa, beta, alpha;
  kappa.copy_with_stride(
    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false));
  beta.copy_with_stride(
    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false));
  alpha.copy_with_stride(
    fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false));

  Tensor dscores = Tensor(TensorDim({value.batch(), 1, 1, value.height()}));
  dscores.dot_batched_deriv_wrt_1(value, derivative);
  dscores.reshape(TensorDim({scores.batch(), 1, scores.width(), 1}));
  if (context.getNumInputs() == 4) {
    Tensor mask = Tensor(dscores.getDim());
    mask.filter_mask(context.getInput(wt_idx[MoLAttentionParams::mask_len]));
    dscores.multiply_i(mask);
  }

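  /** broadcast the score gradient over the mixture dimension: every
   * component of a timestep receives the same dscores value */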
  Tensor dprob_scaled = Tensor(TensorDim({batch, 1, value.height(), mol_k}));
  dprob_scaled.setZero();
  dprob_scaled.add_i(dscores);

  Tensor dalpha = dprob_scaled.multiply(prob).sum(2);
  Tensor dprob = dprob_scaled.multiply(alpha);

  Tensor dprob_left = dprob;
  Tensor dprob_right = dprob.multiply(-1);

  Tensor beta_eps = beta.add(1e-8f);
  Tensor du_neg_div, du_pos_div;
  sigmoid.run_prime_fn(prob_right, du_neg_div, dprob_right);
  Tensor du_neg_m = du_neg_div.divide(beta_eps);
  Tensor dm_neg = du_neg_m.multiply(-1).sum(2);
  Tensor dbeta_eps_neg = du_neg_m.multiply(u_neg_div).multiply(-1).sum(2);

  sigmoid.run_prime_fn(prob_left, du_pos_div, dprob_left);
  Tensor du_pos_m = du_pos_div.divide(beta_eps);
  Tensor dm_pos = du_pos_m.multiply(-1).sum(2);
  Tensor dbeta_eps_pos = du_pos_m.multiply(u_pos_div).multiply(-1).sum(2);

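  /** mu = state + kappa enters both CDF branches, so the state gradient is
   * the sum of the two branch contributions; dkappa below is identical to
   * dstate */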
  Tensor dbeta_eps = dbeta_eps_neg.add(dbeta_eps_pos);
  dm_neg.add(dm_pos, dstate);
  if (context.getNumOutputs() == 2) {
    const Tensor &derivative_state = context.getIncomingDerivative(1);
    dstate.add_i(derivative_state);
  }
  Tensor dkappa = dstate;
  Tensor dbeta = dbeta_eps;

  Tensor dalpha_src;
  softmax.run_prime_fn(alpha, dalpha_src, dalpha);

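  /** kappa and beta are exp of their sources and d/dx exp(x) = exp(x), so
   * chain through by multiplying with the forward outputs */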
  Tensor dkappa_src = dkappa.multiply(kappa);
  Tensor dbeta_src = dbeta.multiply(beta);

  /** dfc_proj_out shares memory with fc_proj_out */
  dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, 0, false)
    .copy_with_stride(dkappa_src);
  dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k, false)
    .copy_with_stride(dbeta_src);
  dfc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false)
    .copy_with_stride(dalpha_src);

  /** update the helper state */
  helper_exec = true;
}

void MoLAttentionLayer::calcDerivative(RunLayerContext &context) {
  Tensor &dquery =
    context.getOutgoingDerivative(wt_idx[MoLAttentionParams::query]);
  Tensor &dvalue =
    context.getOutgoingDerivative(wt_idx[MoLAttentionParams::value]);
  Tensor &dstate =
    context.getOutgoingDerivative(wt_idx[MoLAttentionParams::state]);
  Tensor &dstate_local = context.getTensor(wt_idx[MoLAttentionParams::dstate]);

  const Tensor &derivative = context.getIncomingDerivative(SINGLE_INOUT_IDX);

  Tensor &fc_w = context.getWeight(wt_idx[MoLAttentionParams::fc_w]);
  Tensor &fc_proj_w = context.getWeight(wt_idx[MoLAttentionParams::fc_proj_w]);
  Tensor &fc_tanh = context.getTensor(wt_idx[MoLAttentionParams::fc_tanh]);
  Tensor &dfc_proj_out =
    context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);
  Tensor &scores = context.getTensor(wt_idx[MoLAttentionParams::scores]);

  scores.dot_batched_deriv_wrt_2(dvalue, derivative);

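  /** the helper may already have run as part of calcGradient(); if so, reuse
   * the cached dstate instead of recomputing it */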
  if (!helper_exec)
    calcDerivativeHelper(context, dstate);
  else
    dstate.copyData(dstate_local);

  Tensor dfc_tanh = Tensor(fc_tanh.getDim());
  dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out, false, false);

  Tensor dfc_out;
  tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
  dquery.dot_deriv_wrt_1(fc_w, dfc_out, false, false);
}

void MoLAttentionLayer::calcGradient(RunLayerContext &context) {
  Tensor &query = context.getInput(wt_idx[MoLAttentionParams::query]);
  Tensor &dstate = context.getTensor(wt_idx[MoLAttentionParams::dstate]);

  Tensor &fc_proj_w = context.getWeight(wt_idx[MoLAttentionParams::fc_proj_w]);
  Tensor &dfc_w = context.getWeightGrad(wt_idx[MoLAttentionParams::fc_w]);
  Tensor &dfc_bias =
    context.getWeightGrad(wt_idx[MoLAttentionParams::fc_bias]);
  Tensor &dfc_proj_w =
    context.getWeightGrad(wt_idx[MoLAttentionParams::fc_proj_w]);
  Tensor &fc_tanh = context.getTensor(wt_idx[MoLAttentionParams::fc_tanh]);
  Tensor &dfc_proj_out =
    context.getTensor(wt_idx[MoLAttentionParams::fc_proj_out]);

  if (!helper_exec)
    calcDerivativeHelper(context, dstate);

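  /** the final argument of dot_deriv_wrt_2 requests accumulation into the
   * existing gradient when this is not the first access within the
   * iteration */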
  Tensor dfc_tanh = Tensor(fc_tanh.getDim());
  fc_tanh.dot_deriv_wrt_2(
    dfc_proj_w, dfc_proj_out, false, false,
    !context.isGradientFirstAccess(wt_idx[MoLAttentionParams::fc_proj_w]));
  dfc_tanh.dot_deriv_wrt_1(fc_proj_w, dfc_proj_out);

  Tensor dfc_out;
  tanh.run_prime_fn(fc_tanh, dfc_out, dfc_tanh);
  query.dot_deriv_wrt_2(
    dfc_w, dfc_out, false, false,
    !context.isGradientFirstAccess(wt_idx[MoLAttentionParams::fc_w]));

  if (context.isGradientFirstAccess(wt_idx[MoLAttentionParams::fc_bias])) {
    dfc_out.sum({0, 1, 2}, dfc_bias);
  } else {
    /// @todo optimize below by adding beta to Tensor::sum
    Tensor t = dfc_out.sum({0, 1, 2});
    dfc_bias.add_i(t);
  }
}

void MoLAttentionLayer::setProperty(const std::vector<std::string> &values) {
  auto remain_props = loadProperties(values, mol_props);
  LayerImpl::setProperty(remain_props);
}

void MoLAttentionLayer::setBatch(RunLayerContext &context,
                                 unsigned int batch) {
  context.updateTensor(wt_idx[MoLAttentionParams::fc_out], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::fc_tanh], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::fc_proj_out], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::scores], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::prob], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::prob_left], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::prob_right], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::u_neg_div], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::u_pos_div], batch);
  context.updateTensor(wt_idx[MoLAttentionParams::dstate], batch);
}

void MoLAttentionLayer::exportTo(Exporter &exporter,
                                 const ml::train::ExportMethods &method) const {
  LayerImpl::exportTo(exporter, method);
  exporter.saveResult(mol_props, method, this);
}

} /* namespace nntrainer */