Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2024 Daniel Jang <minhyukjang@snu.ac.kr>
4 : *
5 : * @file adamw.cpp
6 : * @date 3 November 2024
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @author Daniel Jang <minhyukjang@snu.ac.kr>
11 : * @bug No known bugs except for NYI items
12 : * @brief This is the AdamW Optimizer.
13 : */
14 :
15 : #include <cmath>
16 : #include <fstream>
17 :
18 : #include <adamw.h>
19 : #include <nntrainer_error.h>
20 : #include <nntrainer_log.h>
21 : #include <node_exporter.h>
22 : #include <util_func.h>
23 :
24 : namespace nntrainer {
25 :
26 8 : AdamW::AdamW() :
27 : adam_props(PropsB1(), PropsB2(), PropsEpsilon(), TorchRef(),
28 8 : PropsWeightDecayW()) {
29 : /** default properties */
30 : auto &[b1, b2, eps, torch_ref, weight_decay] = adam_props;
31 8 : b1.set(0.9f);
32 8 : b2.set(0.999f);
33 8 : eps.set(1.0e-8f);
34 8 : torch_ref.set(false);
35 8 : weight_decay.set(0.0f);
36 8 : }
37 :
38 16 : AdamW::~AdamW() {}
39 :
40 : enum AdamParams { wm, wv };
41 :
42 8 : std::vector<TensorDim> AdamW::getOptimizerVariableDim(const TensorDim &dim) {
43 : /**
44 : * @note We assume the optimizer parameters should be full precision to
45 : * maintain the accuracy even in mixed precision training.
46 : */
47 8 : TensorDim wm_dim(dim);
48 8 : TensorDim wv_dim(dim);
49 : wm_dim.setDataType(ml::train::TensorDim::DataType::FP32);
50 : wv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
51 8 : return {wm_dim, wv_dim};
52 : }
53 :
54 2 : void AdamW::exportTo(Exporter &exporter,
55 : const ml::train::ExportMethods &method) const {
56 2 : exporter.saveResult(adam_props, method, this);
57 : Optimizer::exportTo(exporter, method);
58 2 : }
59 :
60 13 : void AdamW::setProperty(const std::vector<std::string> &values) {
61 13 : auto left = loadProperties(values, adam_props);
62 11 : Optimizer::setProperty(left);
63 11 : }
64 :
65 6 : double AdamW::getUpdatedLearningRate(unsigned int iteration, double lr) const {
66 6 : auto &beta1 = std::get<PropsB1>(adam_props).get();
67 6 : auto &beta2 = std::get<PropsB2>(adam_props).get();
68 : auto biasCorrection = [&](double f) {
69 6 : return 1.0 - (double)pow(f, iteration + 1);
70 : };
71 6 : lr *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1);
72 6 : return lr;
73 : }
74 :
75 6 : void AdamW::applyGradient(RunOptimizerContext &context) {
76 6 : Tensor empty_tensor;
77 :
78 : Tensor &x_grad =
79 6 : context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32
80 6 : ? context.getGradient()
81 : : empty_tensor;
82 :
83 6 : if (x_grad.empty()) {
84 0 : x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32);
85 : }
86 :
87 6 : context.applyLossScale(x_grad);
88 :
89 6 : auto &beta1 = std::get<PropsB1>(adam_props).get();
90 6 : auto &beta2 = std::get<PropsB2>(adam_props).get();
91 6 : auto &epsilon = std::get<PropsEpsilon>(adam_props).get();
92 6 : auto &weight_decay = std::get<PropsWeightDecayW>(adam_props).get();
93 :
94 6 : Tensor &wm = context.getOptimizerVariable(AdamParams::wm);
95 6 : Tensor &wv = context.getOptimizerVariable(AdamParams::wv);
96 :
97 6 : wm.multiply_i(beta1);
98 6 : wm.add_i(x_grad, 1.0f - beta1);
99 :
100 6 : wv.multiply_i(beta2);
101 6 : wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);
102 :
103 : // Decoupled weight decay: w = w - lr * wd * w
104 6 : if (weight_decay > 0.0) {
105 0 : Tensor &w = context.isMixedPrecision() ? context.getWeightFP32()
106 0 : : context.getWeight();
107 0 : w.multiply_i(1.0f - (context.getLearningRate() * weight_decay));
108 : }
109 :
110 : // Adam update with bias-corrected lr
111 : double lr_t =
112 6 : getUpdatedLearningRate(context.getIteration(), context.getLearningRate());
113 :
114 6 : std::function<double(double)> sqrtEps = [epsilon](double f) {
115 48 : return 1.0 / (sqrtDouble(f) + epsilon);
116 : };
117 6 : x_grad = wv.apply<float>(sqrtEps, x_grad);
118 6 : x_grad.multiply_i(wm);
119 6 : context.applyGradient(lr_t, x_grad);
120 6 : }
121 :
122 : } // namespace nntrainer
|