Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2024 Hyunwoo LEE <dlgusdn0414@snu.ac.kr>
4 : *
5 : * @file lr_scheduler_linear.cpp
6 : * @date 11 November 2024
7 : * @brief This is Linear Learning Rate Scheduler class
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Hyunwoo LEE <dlgusdn0414@snu.ac.kr>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #include <cmath>
15 :
16 : #include <common_properties.h>
17 : #include <lr_scheduler_linear.h>
18 : #include <nntrainer_error.h>
19 : #include <nntrainer_log.h>
20 : #include <node_exporter.h>
21 :
22 : namespace nntrainer {
23 :
24 10 : LinearLearningRateScheduler::LinearLearningRateScheduler() :
25 : lr_props(props::MaxLearningRate(), props::MinLearningRate(),
26 10 : props::DecaySteps()) {}
27 :
28 4 : void LinearLearningRateScheduler::finalize() {
29 5 : NNTR_THROW_IF(std::get<props::MaxLearningRate>(lr_props).empty(),
30 : std::invalid_argument)
31 : << "[LinearLearningRateScheduler] Max Learning Rate is not set";
32 3 : NNTR_THROW_IF(std::get<props::MinLearningRate>(lr_props).empty(),
33 : std::invalid_argument)
34 : << "[LinearLearningRateScheduler] Min Learning Rate is not set";
35 3 : NNTR_THROW_IF(std::get<props::DecaySteps>(lr_props).empty(),
36 : std::invalid_argument)
37 : << "[LinearLearningRateScheduler] Decay Steps is not set";
38 3 : NNTR_THROW_IF(std::get<props::DecaySteps>(lr_props) <= 0,
39 : std::invalid_argument)
40 : << "[LinearLearningRateScheduler] Decay Steps must be a positive integer";
41 3 : }
42 :
43 26 : void LinearLearningRateScheduler::setProperty(
44 : const std::vector<std::string> &values) {
45 26 : auto left = loadProperties(values, lr_props);
46 26 : NNTR_THROW_IF(left.size(), std::invalid_argument)
47 : << "[LinearLearningRateScheduler] There are unparsed properties";
48 25 : }
49 :
50 0 : void LinearLearningRateScheduler::exportTo(
51 : Exporter &exporter, const ml::train::ExportMethods &method) const {
52 0 : exporter.saveResult(lr_props, method, this);
53 0 : }
54 :
55 5 : double LinearLearningRateScheduler::getLearningRate(size_t iteration) {
56 : auto const &max_lr = std::get<props::MaxLearningRate>(lr_props);
57 : auto const &min_lr = std::get<props::MinLearningRate>(lr_props);
58 : auto const &decay_steps = std::get<props::DecaySteps>(lr_props);
59 :
60 : // Linear formula
61 5 : double lr = max_lr - (max_lr - min_lr) * (iteration / (double)decay_steps);
62 :
63 5 : return std::max(lr, (double)min_lr);
64 : }
65 :
66 : } // namespace nntrainer
|