Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2025 Jeonghun Park <top231902@naver.com>
4 : *
5 : * @file lion.h
6 : * @date 1 December 2025
7 : * @see https://github.com/nntrainer/nntrainer
8 : * @author Jeonghun Park <top231902@naver.com>
9 : * @author Minseo Kim <ms05251@naver.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is the Lion Optimizer Header.
12 : */
13 :
14 : #ifndef __LION_H__
15 : #define __LION_H__
16 : #ifdef __cplusplus
17 :
18 : #include <adam.h>
19 : #include <base_properties.h>
20 : #include <optimizer_devel.h>
21 : #include <tuple>
22 :
23 : namespace nntrainer {
24 :
25 : /**
26 : * @brief weight decay property
27 : *
28 : */
29 1 : class PropsWeightDecayLion : public Property<double> {
30 : public:
31 : static constexpr const char *key =
32 : "weight_decay"; /**< unique key to access */
33 : using prop_tag = double_prop_tag; /**< property type */
34 : };
35 :
36 : /**
37 : * @class Lion Optimizer class
38 : * @brief Lion Optimizer (E. Chen et al., 2023)
39 : */
40 : class Lion : public Optimizer {
41 : public:
42 : /**
43 : * @brief Construct a new Lion object
44 : */
45 : Lion();
46 :
47 : /**
48 : * @brief Destroy the Lion object
49 : */
50 : ~Lion();
51 :
52 : /**
53 : * @copydoc Optimizer::getDefaultLearningRate()
54 : */
55 1 : double getDefaultLearningRate() const override { return 1e-4; }
56 :
57 : /**
58 : * @copydoc Optimizer::applyGradient(RunOptimizerContext &context)
59 : */
60 : void applyGradient(RunOptimizerContext &context) override;
61 :
62 : /**
63 : * @copydoc Optimizer::getType()
64 : */
65 0 : const std::string getType() const override { return Lion::type; }
66 :
67 : /**
68 : * @copydoc Optimizer::getOptimizerVariableDim(const TensorDim &dim)
69 : */
70 : std::vector<TensorDim> getOptimizerVariableDim(const TensorDim &dim) override;
71 :
72 : /**
73 : * @copydoc Optimizer::exportTo(Exporter &exporter,
74 : * const ml::train::ExportMethods &method)
75 : */
76 : void exportTo(Exporter &exporter,
77 : const ml::train::ExportMethods &method) const override;
78 :
79 : /**
80 : * @copydoc Optimizer::setProperty(const std::vector<std::string> &values)
81 : */
82 : void setProperty(const std::vector<std::string> &values) override;
83 :
84 : static constexpr const char *type = "lion";
85 :
86 : private:
87 : std::tuple<PropsB1, PropsB2, PropsWeightDecayLion> lion_props;
88 : };
89 : } /* namespace nntrainer */
90 :
91 : #endif /* __cplusplus */
92 : #endif /* __LION_H__ */
|