// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
 *
 * @file optimizer_context.cpp
 * @date 30 July 2021
 * @see https://github.com/nnstreamer/nntrainer
 * @author Parichay Kapoor <pk.kapoor@samsung.com>
 * @bug No known bugs except for NYI items
 * @brief This is the optimizer context for each weight
 */

#include <stdexcept>

#include <optimizer_context.h>
#include <weight.h>

namespace nntrainer {

/**
 * @brief Get the Weight tensor object
 */
Tensor &RunOptimizerContext::getWeight() const {
  return weight->getVariableRef();
}

/**
 * @brief Get the Weight FP32 tensor object (master weight for mixed precision)
 */
Tensor &RunOptimizerContext::getWeightFP32() const {
  return weight->getVariableFP32Ref();
}

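/**
 * Usage sketch (illustrative, not part of this translation unit): in mixed
 * precision the optimizer updates the FP32 master copy rather than the
 * low-precision weight itself. The context variable `ctx` is assumed to be
 * constructed elsewhere by the surrounding optimizer code.
 * @code
 *   Tensor &w32 = ctx.getWeightFP32(); // FP32 master weight
 *   // ... compute the update against w32 in full precision ...
 * @endcode
 */
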
/**
 * @brief Get the Weight Gradient tensor object
 */
Tensor &RunOptimizerContext::getGradient() const {
  return weight->getGradientRef();
}

/**
 * @brief Get the optimizer variable associated with this weight
 */
Tensor &RunOptimizerContext::getOptimizerVariable(unsigned int idx) const {
  return weight->getOptimizerVariableRef(idx);
}

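/**
 * Usage sketch (illustrative only): stateful optimizers such as Adam keep
 * extra per-weight tensors and fetch them here by index. The meaning of each
 * index is an assumption for illustration; it is defined by the optimizer
 * that requested the variables.
 * @code
 *   Tensor &m = ctx.getOptimizerVariable(0); // e.g. first moment
 *   Tensor &v = ctx.getOptimizerVariable(1); // e.g. second moment
 * @endcode
 */
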
/**
 * @brief Apply the gradient with the given learning rate
 */
void RunOptimizerContext::applyGradient(double lr) const {
  weight->applyGradient(lr);
}

/**
 * @brief Apply the given (updated) gradient with the given learning rate
 */
void RunOptimizerContext::applyGradient(double lr, Tensor &updated_grad) const {
  weight->applyGradient(lr, updated_grad);
}

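/**
 * Usage sketch (illustrative only): a plain SGD step applies the raw
 * gradient with the current learning rate; `ctx` and `lr` are assumed to be
 * provided by the surrounding optimizer code.
 * @code
 *   ctx.applyGradient(lr, ctx.getGradient()); // weight -= lr * gradient
 * @endcode
 */
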
/**
 * @brief Apply the loss scale to the full-precision gradient (divides the
 *        gradient by the scale; no-op unless the weight is mixed precision)
 */
void RunOptimizerContext::applyLossScale(Tensor &fp32_grad) {
  if (!weight->isMixedPrecision())
    return;
  if (fp32_grad.getDataType() != ml::train::TensorDim::DataType::FP32)
    throw std::invalid_argument(
      "gradient should be full precision to maintain accuracy");
  float loss_scale = weight->getLossScale();
  fp32_grad.divide_i(loss_scale);
}

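/**
 * Usage sketch (illustrative only): in a mixed-precision step the
 * full-precision gradient is divided by the loss scale before being applied,
 * undoing the scaling introduced during backpropagation. `grad32` is assumed
 * to already hold the FP32 gradient.
 * @code
 *   ctx.applyLossScale(grad32);    // no-op unless mixed precision
 *   ctx.applyGradient(lr, grad32); // update with the unscaled gradient
 * @endcode
 */
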
/**
 * @brief Return whether the underlying weight is mixed precision
 */
bool RunOptimizerContext::isMixedPrecision() const {
  return weight->isMixedPrecision();
}
} // namespace nntrainer