Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file lr_scheduler.h
6 : * @date 09 December 2021
7 : * @brief This is Learning Rate Scheduler interface class
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __LEARNING_RATE_SCHEDULER__
15 : #define __LEARNING_RATE_SCHEDULER__
16 : #ifdef __cplusplus
17 :
18 : #include <string>
19 :
20 : #include <optimizer.h>
21 :
22 : namespace nntrainer {
23 :
24 : class Exporter;
25 :
26 : /**
27 : * @brief Enumeration of learning rate scheduler types
28 : */
29 : enum LearningRateSchedulerType {
30 : CONSTANT = 0, /**< constant learning rate */
31 : EXPONENTIAL, /**< exponential decay */
32 : STEP, /**< step-wise decay */
33 : COSINE, /**< cosine annealing */
34 : LINEAR /**< linear decay */
35 : };
36 :
37 : /**
38 : * @class LearningRateScheduler
39 : * @brief Base class for all Learning Rate Schedulers
40 : */
41 : class LearningRateScheduler : public ml::train::LearningRateScheduler {
42 :
43 : public:
44 : /**
45 : * @brief Destructor of the learning rate scheduler class
46 : */
47 : virtual ~LearningRateScheduler() = default;
48 :
49 : /**
50 : * @brief Finalize creating the learning rate scheduler
51 : *
52 : * @details Verify that all the needed properties have been set and are
53 : * within the valid range.
54 : * @note Changing properties is not allowed after
55 : * calling this.
56 : */
57 : virtual void finalize() = 0;
58 :
59 : /**
60 : * @brief get Learning Rate for the given iteration
61 : * @param[in] iteration Iteration for the learning rate
62 : * @retval Learning rate in double
63 : * @details the return value of this function and getInitialLearningRate()
64 : * may not match for iteration == 0 (warmup can lead to different initial
65 : * learning rates).
66 : *
67 : * @note this function is intentionally non-const.
68 : */
69 : virtual double getLearningRate(size_t iteration) = 0;
70 :
71 : /**
72 : * @brief this function helps export the learning rate in a predefined
73 : * format, working around an issue caused by templated function type erasure
74 : * @note the default implementation provided here is a no-op
75 : * @param exporter exporter that contains exporting logic
76 : * @param method enum value to identify how it should be exported to
77 : */
78 0 : virtual void exportTo(Exporter &exporter,
79 0 : const ml::train::ExportMethods &method) const {}
80 :
81 : /**
82 : * @brief Default allowed properties
83 : * Constant Learning rate scheduler
84 : * - learning_rate : float
85 : *
86 : * Exponential Learning rate scheduler
87 : * - learning_rate : float
88 : * - decay_rate : float
89 : * - decay_steps : float
90 : *
91 : * Cosine Annealing Learning rate scheduler
92 : * - max_learning_rate : float
93 : * - min_learning_rate : float
94 : * - decay_steps : float
95 : *
96 : * Linear Learning rate scheduler
97 : * - max_learning_rate : float
98 : * - min_learning_rate : float
99 : * - decay_steps : positive integer
100 : *
101 : * more to be added
102 : */
103 :
104 : /**
105 : * @brief set learning rate scheduler properties
106 : * @param[in] values learning rate scheduler properties list
107 : * @details This function accepts a vector of properties in the format -
108 : * { std::string property_name = std::string property_val, ...}
109 : */
110 : virtual void setProperty(const std::vector<std::string> &values) = 0;
111 :
112 : /**
113 : * @brief get learning rate scheduler type
114 : * @retval learning rate scheduler type as a string
115 : */
116 : virtual const std::string getType() const = 0;
117 : };
118 :
119 : /**
120 : * @brief General LR Scheduler factory: default-constructs T, then sets props
121 : * @tparam T scheduler type to create; must derive from LearningRateScheduler
122 : * @param props properties forwarded to setProperty() (empty by default)
123 : * @return std::unique_ptr<nntrainer::LearningRateScheduler> created object
124 : */
125 : template <typename T,
126 : std::enable_if_t<std::is_base_of<LearningRateScheduler, T>::value, T>
127 : * = nullptr>
128 : std::unique_ptr<LearningRateScheduler>
129 6 : createLearningRateScheduler(const std::vector<std::string> &props = {}) {
130 6 : std::unique_ptr<LearningRateScheduler> ptr = std::make_unique<T>();
131 6 : ptr->setProperty(props);
132 1 : return ptr;
133 : }
134 :
135 : } /* namespace nntrainer */
136 :
137 : #endif /* __cplusplus */
138 : #endif /* __LEARNING_RATE_SCHEDULER__ */
|