Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file adam.cpp
6 : * @date 6 October 2020
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is the Adam optimizer.
12 : */
13 :
14 : #include <cmath>
15 : #include <fstream>
16 :
17 : #include <adam.h>
18 : #include <nntrainer_error.h>
19 : #include <nntrainer_log.h>
20 : #include <node_exporter.h>
21 : #include <util_func.h>
22 :
23 : namespace nntrainer {
24 :
25 217 : Adam::Adam() : adam_props(PropsB1(), PropsB2(), PropsEpsilon(), TorchRef()) {
26 : /** default properties */
27 : auto &[b1, b2, eps, torch_ref] = adam_props;
28 217 : b1.set(0.9f);
29 217 : b2.set(0.999f);
30 217 : eps.set(1.0e-7f);
31 217 : torch_ref.set(false);
32 217 : }
33 :
34 434 : Adam::~Adam() {}
35 :
36 : enum AdamParams { wm, wv };
37 :
38 210 : std::vector<TensorDim> Adam::getOptimizerVariableDim(const TensorDim &dim) {
39 : /**
40 : * @note We assume the optimizer parameters should be full precsion to
41 : * maintain the accuracy even in mixed precision training.
42 : */
43 210 : TensorDim wm_dim(dim);
44 210 : TensorDim wv_dim(dim);
45 : wm_dim.setDataType(ml::train::TensorDim::DataType::FP32);
46 : wv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
47 210 : return {wm_dim, wv_dim};
48 : }
49 :
50 3 : void Adam::exportTo(Exporter &exporter,
51 : const ml::train::ExportMethods &method) const {
52 3 : exporter.saveResult(adam_props, method, this);
53 : Optimizer::exportTo(exporter, method);
54 3 : }
55 :
56 443 : void Adam::setProperty(const std::vector<std::string> &values) {
57 443 : auto left = loadProperties(values, adam_props);
58 434 : Optimizer::setProperty(left);
59 434 : }
60 :
61 1064 : double Adam::getUpdatedLearningRate(unsigned int iteration, double lr) const {
62 1064 : auto &beta1 = std::get<PropsB1>(adam_props).get();
63 1064 : auto &beta2 = std::get<PropsB2>(adam_props).get();
64 :
65 : /// @note This change suppresses C4244 warnings due to precision loss.
66 : #ifdef _MSC_VER
67 : #pragma warning(push)
68 : #pragma warning(disable : 4244)
69 : #endif
70 : std::function<float(double)> biasCorrection = [&](float f) {
71 2128 : return 1.0f - pow(f, iteration + 1);
72 : };
73 :
74 2128 : lr *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1);
75 : #ifdef _MSC_VER
76 : #pragma warning(pop)
77 : #endif
78 :
79 1064 : return lr;
80 : }
81 :
82 1064 : void Adam::applyGradient(RunOptimizerContext &context) {
83 1064 : Tensor empty_tensor;
84 :
85 : Tensor &x_grad =
86 1064 : context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32
87 1064 : ? context.getGradient()
88 : : empty_tensor;
89 :
90 1064 : if (x_grad.empty()) {
91 0 : x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32);
92 : }
93 :
94 1064 : context.applyLossScale(x_grad);
95 :
96 1064 : auto &beta1 = std::get<PropsB1>(adam_props).get();
97 1064 : auto &beta2 = std::get<PropsB2>(adam_props).get();
98 1064 : auto &epsilon = std::get<PropsEpsilon>(adam_props).get();
99 1064 : auto &torch_ref = std::get<TorchRef>(adam_props).get();
100 :
101 : // This is implementation of adam from original paper.
102 : // This is not deleted intentionally.
103 1064 : unsigned int iteration = context.getIteration();
104 1064 : float biasCorrection1 = 1.0 - pow((double)beta1, iteration + 1);
105 1064 : float biasCorrection2 = 1.0 - pow((double)beta2, iteration + 1);
106 1064 : Tensor &wm = context.getOptimizerVariable(AdamParams::wm);
107 1064 : Tensor &wv = context.getOptimizerVariable(AdamParams::wv);
108 :
109 1064 : wm.multiply_i(beta1);
110 1064 : wm.add_i(x_grad, 1.0f - beta1);
111 :
112 1064 : wv.multiply_i(beta2);
113 1064 : wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);
114 :
115 1064 : if (torch_ref) {
116 0 : Tensor denom = wv.apply<float>(sqrtFloat<float>);
117 0 : denom.divide_i(sqrtFloat(biasCorrection2));
118 0 : denom.add_i(epsilon);
119 0 : wm.divide(denom, x_grad);
120 :
121 0 : context.applyGradient(context.getLearningRate() / biasCorrection1, x_grad);
122 :
123 0 : } else {
124 : /// @note This change suppresses C4244 warnings due to precision loss.
125 : #ifdef _MSC_VER
126 : #pragma warning(push)
127 : #pragma warning(disable : 4244)
128 : #endif
129 1064 : std::function<double(double)> sqrtEps = [epsilon](double f) {
130 52680858 : return 1.0 / (sqrtDouble(f) + epsilon);
131 : };
132 :
133 1064 : x_grad = wv.apply<float>(sqrtEps, x_grad);
134 : #ifdef _MSC_VER
135 : #pragma warning(pop)
136 : #endif
137 1064 : x_grad.multiply_i(wm);
138 1064 : context.applyGradient(
139 : getUpdatedLearningRate(context.getIteration(), context.getLearningRate()),
140 : x_grad);
141 : }
142 1064 : }
143 :
144 : } // namespace nntrainer
|