Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file lr_scheduler_exponential.cpp
6 : * @date 09 December 2021
7 : * @brief This is Exponential Learning Rate Scheduler class
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #include <cmath>
15 :
16 : #include <common_properties.h>
17 : #include <lr_scheduler_exponential.h>
18 : #include <nntrainer_error.h>
19 : #include <nntrainer_log.h>
20 : #include <node_exporter.h>
21 :
22 : namespace nntrainer {
23 :
24 105 : ExponentialLearningRateScheduler::ExponentialLearningRateScheduler() :
25 105 : lr_props(props::DecayRate(), props::DecaySteps()) {}
26 :
27 98 : void ExponentialLearningRateScheduler::finalize() {
28 101 : NNTR_THROW_IF(std::get<props::DecayRate>(lr_props).empty(),
29 : std::invalid_argument)
30 : << "[ExponentialLearningRateScheduler] Decay Rate is not set";
31 96 : NNTR_THROW_IF(std::get<props::DecaySteps>(lr_props).empty(),
32 : std::invalid_argument)
33 : << "[ExponentialLearningRateScheduler] Decay Steps is not set";
34 94 : ConstantLearningRateScheduler::finalize();
35 93 : }
36 :
37 290 : void ExponentialLearningRateScheduler::setProperty(
38 : const std::vector<std::string> &values) {
39 290 : auto left = loadProperties(values, lr_props);
40 286 : ConstantLearningRateScheduler::setProperty(left);
41 286 : }
42 :
43 1 : void ExponentialLearningRateScheduler::exportTo(
44 : Exporter &exporter, const ml::train::ExportMethods &method) const {
45 1 : ConstantLearningRateScheduler::exportTo(exporter, method);
46 1 : exporter.saveResult(lr_props, method, this);
47 1 : }
48 :
49 175 : double ExponentialLearningRateScheduler::getLearningRate(size_t iteration) {
50 175 : auto const &lr = ConstantLearningRateScheduler::getLearningRate(iteration);
51 : auto const &[decay_rate, decay_steps] = lr_props;
52 :
53 174 : return lr * pow(decay_rate, (iteration / (float)decay_steps));
54 : }
55 :
56 : } // namespace nntrainer
|