Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2025 Jeonghun Park <top231902@naver.com>
4 : *
5 : * @file lion.cpp
6 : * @date 1 December 2025
7 : * @see https://github.com/nntrainer/nntrainer
8 : * @author Jeonghun Park <top231902@naver.com>
9 : * @author Minseo Kim <ms05251@naver.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is the Lion Optimizer.
12 : */
13 :
14 : #include <cmath>
15 : #include <fstream>
16 :
17 : #include <lion.h>
18 : #include <nntrainer_error.h>
19 : #include <nntrainer_log.h>
20 : #include <node_exporter.h>
21 : #include <util_func.h>
22 :
23 : namespace nntrainer {
24 :
25 1 : Lion::Lion() : lion_props(PropsB1(), PropsB2(), PropsWeightDecayLion()) {
26 : auto &[beta1, beta2, weight_decay] = lion_props;
27 1 : beta1.set(0.9f);
28 1 : beta2.set(0.99f);
29 1 : weight_decay.set(0.0f);
30 1 : }
31 :
32 2 : Lion::~Lion() {}
33 :
34 : enum LionParams { m };
35 :
36 0 : std::vector<TensorDim> Lion::getOptimizerVariableDim(const TensorDim &dim) {
37 0 : TensorDim m_dim(dim);
38 : m_dim.setDataType(ml::train::TensorDim::DataType::FP32);
39 0 : return {m_dim};
40 : }
41 :
42 0 : void Lion::exportTo(Exporter &exporter,
43 : const ml::train::ExportMethods &method) const {
44 0 : exporter.saveResult(lion_props, method, this);
45 : Optimizer::exportTo(exporter, method);
46 0 : }
47 :
48 2 : void Lion::setProperty(const std::vector<std::string> &values) {
49 2 : auto left = loadProperties(values, lion_props);
50 2 : Optimizer::setProperty(left);
51 2 : }
52 :
53 0 : void Lion::applyGradient(RunOptimizerContext &context) {
54 : // 1. Get Tensors and Properties
55 0 : Tensor empty_tensor;
56 : Tensor &x_grad =
57 0 : context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32
58 0 : ? context.getGradient()
59 : : empty_tensor;
60 :
61 0 : if (x_grad.empty()) {
62 0 : x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32);
63 : }
64 :
65 0 : context.applyLossScale(x_grad);
66 :
67 0 : Tensor &m = context.getOptimizerVariable(LionParams::m);
68 :
69 0 : auto &beta1 = std::get<PropsB1>(lion_props).get();
70 0 : auto &beta2 = std::get<PropsB2>(lion_props).get();
71 0 : auto &weight_decay = std::get<PropsWeightDecayLion>(lion_props).get();
72 0 : float lr = context.getLearningRate();
73 :
74 0 : Tensor original_x_grad = x_grad.clone();
75 :
76 : // 2. Calculate interpolated momentum: c_t = beta1 * m_t + (1 - beta1) * g_t
77 0 : x_grad.multiply_i(1.0 - beta1);
78 0 : x_grad.add_i(m, beta1);
79 :
80 : // 3. Update momentum for next iteration: m_{t+1} = beta2 * m_t + (1 - beta2)
81 : // * g_t
82 0 : m.multiply_i(beta2);
83 0 : m.add_i(original_x_grad, 1.0 - beta2);
84 :
85 : // 4. Take the sign of the interpolated momentum
86 : std::function<float(float)> sign_func = [](float val) {
87 0 : if (val > 0.0f)
88 : return 1.0f;
89 0 : if (val < 0.0f)
90 0 : return -1.0f;
91 : return 0.0f;
92 : };
93 0 : x_grad.apply_i<float>(sign_func);
94 :
95 : // 5. Add decoupled weight decay term. w = w - lr * wd * w
96 0 : if (weight_decay > 0.0) {
97 0 : Tensor &w = context.isMixedPrecision() ? context.getWeightFP32()
98 0 : : context.getWeight();
99 0 : w.multiply_i(1.0f - (context.getLearningRate() * weight_decay));
100 : }
101 :
102 : // 6. Apply the final gradient update
103 0 : context.applyGradient(lr, x_grad);
104 0 : }
105 :
106 : } // namespace nntrainer
|