Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file adam.h
6 : * @date 6 October 2020
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is the Adam optimizer.
12 : */
13 : #ifndef __ADAM_H__
14 : #define __ADAM_H__
15 : #ifdef __cplusplus
16 :
17 : #include <tuple>
18 :
19 : #include <base_properties.h>
20 : #include <optimizer_devel.h>
21 :
22 : namespace nntrainer {
23 :
24 : /**
25 : * @brief Beta 1 props
26 : *
27 : */
28 225 : class PropsB1 : public Property<double> {
29 : public:
30 : static constexpr const char *key = "beta1"; /**< unique key to access */
31 : using prop_tag = double_prop_tag; /**< property type */
32 : };
33 :
34 : /**
35 : * @brief Beta 2 props
36 : *
37 : */
38 225 : class PropsB2 : public Property<double> {
39 : public:
40 : static constexpr const char *key = "beta2"; /**< unique key to access */
41 : using prop_tag = double_prop_tag; /**< property type */
42 : };
43 :
44 : /**
45 : * @brief epsilon props
46 : * @todo move this to common props
47 : *
48 : */
49 225 : class PropsEpsilon : public Property<double> {
50 : public:
51 : static constexpr const char *key = "epsilon"; /**< unique key to access */
52 : using prop_tag = double_prop_tag; /**< property type */
53 : };
54 :
55 : /**
56 : * @brief pytorch reference implementation
57 : *
58 : */
59 442 : class TorchRef : public Property<bool> {
60 : public:
61 : static constexpr const char *key = "torch_ref"; /**< unique key to access */
62 : using prop_tag = bool_prop_tag; /**< property type */
63 : };
64 :
65 : /**
66 : * @class Adam optimizer class
67 : * @brief Adam optimizer
68 : */
69 : class Adam : public Optimizer {
70 : public:
71 : /**
72 : * @brief Construct a new Adam object
73 : *
74 : */
75 : Adam();
76 :
77 : /**
78 : * @brief Destroy the Adam object
79 : *
80 : */
81 : ~Adam();
82 :
83 : /**
84 : * @copydoc Optimizer::getDefaultLearningRate()
85 : *
86 : */
87 213 : double getDefaultLearningRate() const override { return 0.001; }
88 :
89 : /**
90 : * @copydoc applyGradient(RunOptimizerContext &context)
91 : */
92 : void applyGradient(RunOptimizerContext &context) override;
93 :
94 : /**
95 : * @copydoc Optimizer::getType()
96 : */
97 1088 : const std::string getType() const override { return Adam::type; }
98 :
99 : /**
100 : * @copydoc Optimizer::getOptimizerVariableDim(const TensorDim &dim)
101 : */
102 : std::vector<TensorDim> getOptimizerVariableDim(const TensorDim &dim) override;
103 :
104 : /**
105 : * @copydoc Optimizer::exportTo(Exporter &exporter, const
106 : * ml::train::ExportMethods& method)
107 : */
108 : void exportTo(Exporter &exporter,
109 : const ml::train::ExportMethods &method) const override;
110 :
111 : static constexpr const char *type = "adam";
112 :
113 : /**
114 : * @copydoc Optimizer::setProperty(const std::vector<std::string> &values)
115 : */
116 : void setProperty(const std::vector<std::string> &values) override;
117 :
118 : private:
119 : std::tuple<PropsB1, PropsB2, PropsEpsilon, TorchRef> adam_props;
120 :
121 : /**
122 : * @brief Get updated learning rate
123 : *
124 : * @param lr learning rate
125 : *
126 : * @return updated learning rate
127 : */
128 : double getUpdatedLearningRate(unsigned int iteration, double lr) const;
129 : };
130 : } /* namespace nntrainer */
131 :
132 : #endif /* __cplusplus */
133 : #endif /* __ADAM_H__ */
|