Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file lr_scheduler_step.h
6 : * @date 13 December 2021
7 : * @brief This is Step Learning Rate Scheduler class
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __LEARNING_RATE_SCHEDULER_STEP__
15 : #define __LEARNING_RATE_SCHEDULER_STEP__
16 : #ifdef __cplusplus
17 :
18 : #include <string>
19 : #include <vector>
20 :
21 : #include <lr_scheduler.h>
22 :
23 : namespace nntrainer {
24 :
25 : namespace props {
26 : class LearningRate;
27 : class Iteration;
28 : } // namespace props
29 :
30 : /**
31 : * @class Step Learning Rate Scheduler class
32 : * @brief class for Step Learning Rate Schedulers
33 : */
34 : class StepLearningRateScheduler final : public LearningRateScheduler {
35 :
36 : public:
37 : /**
38 : * @brief Construct a new step learning rate scheduler object
39 : *
40 : */
41 : StepLearningRateScheduler();
42 :
43 : /**
44 : * @copydoc LearningRateScheduler::getLearningRate(size_t iteration) const
45 : *
46 : */
47 : double getLearningRate(size_t iteration) override;
48 :
49 : /**
50 : * @copydoc LearningRateScheduler::finalize()
51 : *
52 : */
53 : void finalize() override;
54 :
55 : /**
56 : * @copydoc LearningRateScheduler::exportTo(Exporter &exporter, const
57 : * ml::train::ExportMethods& method)
58 : *
59 : */
60 : void exportTo(Exporter &exporter,
61 : const ml::train::ExportMethods &method) const override;
62 :
63 : /**
64 : * @copydoc LearningRateScheduler::setProperty(const std::vector<std::string>
65 : * &values)
66 : */
67 : void setProperty(const std::vector<std::string> &values) override;
68 :
69 : /**
70 : * @copydoc LearningRateScheduler::getType() const
71 : *
72 : */
73 0 : const std::string getType() const override {
74 0 : return StepLearningRateScheduler::type;
75 : }
76 :
77 : static constexpr const char *type = "step";
78 :
79 : private:
80 : std::tuple<std::vector<props::LearningRate>, std::vector<props::Iteration>>
81 : lr_props;
82 : };
83 :
84 : } /* namespace nntrainer */
85 :
86 : #endif /* __cplusplus */
87 : #endif /* __LEARNING_RATE_SCHEDULER_STEP__ */
|