Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2024 Daniel Jang <minhyukjang@snu.ac.kr>
4 : *
5 : * @file adamw.h
6 : * @date 3 November 2024
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @author Daniel Jang <minhyukjang@snu.ac.kr>
11 : * @bug No known bugs except for NYI items
12 : * @brief This is the AdamW Optimizer.
13 : */
14 : #ifndef __ADAMW_H__
15 : #define __ADAMW_H__
16 : #ifdef __cplusplus
17 :
18 : #include <tuple>
19 :
20 : #include <adam.h>
21 :
22 : #include <base_properties.h>
23 : #include <optimizer_devel.h>
24 :
25 : namespace nntrainer {
26 :
27 : /**
28 : * @brief weight decay property for AdamW
29 : */
30 8 : class PropsWeightDecayW : public Property<double> {
31 : public:
32 : static constexpr const char *key = "weight_decay";
33 : using prop_tag = double_prop_tag;
34 : };
35 :
36 : /**
37 : * @class AdamW Optimizer class
38 : * @brief AdamW Optimizer
39 : */
40 : class AdamW : public Optimizer {
41 : public:
42 : /**
43 : * @brief Construct a new AdamW object
44 : *
45 : */
46 : AdamW();
47 :
48 : /**
49 : * @brief Destroy the AdamW object
50 : *
51 : */
52 : ~AdamW();
53 :
54 : /**
55 : * @copydoc Optimizer::getDefaultLearningRate()
56 : *
57 : */
58 5 : double getDefaultLearningRate() const override { return 0.001; }
59 :
60 : /**
61 : * @copydoc applyGradient(RunOptimizerContext &context)
62 : */
63 : void applyGradient(RunOptimizerContext &context) override;
64 :
65 : /**
66 : * @copydoc Optimizer::getType()
67 : */
68 8 : const std::string getType() const override { return AdamW::type; }
69 :
70 : /**
71 : * @copydoc Optimizer::getOptimizerVariableDim(const TensorDim &dim)
72 : */
73 : std::vector<TensorDim> getOptimizerVariableDim(const TensorDim &dim) override;
74 :
75 : /**
76 : * @copydoc Optimizer::exportTo(Exporter &exporter, const
77 : * ml::train::ExportMethods& method)
78 : */
79 : void exportTo(Exporter &exporter,
80 : const ml::train::ExportMethods &method) const override;
81 :
82 : static constexpr const char *type = "adamw";
83 :
84 : /**
85 : * @copydoc Optimizer::setProperty(const std::vector<std::string> &values)
86 : */
87 : void setProperty(const std::vector<std::string> &values) override;
88 :
89 : private:
90 : std::tuple<PropsB1, PropsB2, PropsEpsilon, TorchRef, PropsWeightDecayW>
91 : adam_props;
92 :
93 : /**
94 : * @brief Get updated learning rate
95 : *
96 : * @param lr learning rate
97 : *
98 : * @return updated learning rate
99 : */
100 : double getUpdatedLearningRate(unsigned int iteration, double lr) const;
101 : };
102 : } /* namespace nntrainer */
103 :
104 : #endif /* __cplusplus */
105 : #endif /* __ADAMW_H__ */
|