LCOV - code coverage report
Current view: top level - nntrainer/optimizers - lr_scheduler_linear.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 83.3 % 18 15
Test Date: 2025-12-14 20:38:17 Functions: 80.0 % 5 4

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2024 Hyunwoo LEE <dlgusdn0414@snu.ac.kr>
       4              :  *
       5              :  * @file   lr_scheduler_linear.cpp
       6              :  * @date   11 November 2024
       7              :  * @brief  This is Linear Learning Rate Scheduler class
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Hyunwoo LEE <dlgusdn0414@snu.ac.kr>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : 
      14              : #include <cmath>
      15              : 
      16              : #include <common_properties.h>
      17              : #include <lr_scheduler_linear.h>
      18              : #include <nntrainer_error.h>
      19              : #include <nntrainer_log.h>
      20              : #include <node_exporter.h>
      21              : 
      22              : namespace nntrainer {
      23              : 
      24           10 : LinearLearningRateScheduler::LinearLearningRateScheduler() :
      25              :   lr_props(props::MaxLearningRate(), props::MinLearningRate(),
      26           10 :            props::DecaySteps()) {}
      27              : 
      28            4 : void LinearLearningRateScheduler::finalize() {
      29            5 :   NNTR_THROW_IF(std::get<props::MaxLearningRate>(lr_props).empty(),
      30              :                 std::invalid_argument)
      31              :     << "[LinearLearningRateScheduler] Max Learning Rate is not set";
      32            3 :   NNTR_THROW_IF(std::get<props::MinLearningRate>(lr_props).empty(),
      33              :                 std::invalid_argument)
      34              :     << "[LinearLearningRateScheduler] Min Learning Rate is not set";
      35            3 :   NNTR_THROW_IF(std::get<props::DecaySteps>(lr_props).empty(),
      36              :                 std::invalid_argument)
      37              :     << "[LinearLearningRateScheduler] Decay Steps is not set";
      38            3 :   NNTR_THROW_IF(std::get<props::DecaySteps>(lr_props) <= 0,
      39              :                 std::invalid_argument)
      40              :     << "[LinearLearningRateScheduler] Decay Steps must be a positive integer";
      41            3 : }
      42              : 
      43           26 : void LinearLearningRateScheduler::setProperty(
      44              :   const std::vector<std::string> &values) {
      45           26 :   auto left = loadProperties(values, lr_props);
      46           26 :   NNTR_THROW_IF(left.size(), std::invalid_argument)
      47              :     << "[LinearLearningRateScheduler] There are unparsed properties";
      48           25 : }
      49              : 
      50            0 : void LinearLearningRateScheduler::exportTo(
      51              :   Exporter &exporter, const ml::train::ExportMethods &method) const {
      52            0 :   exporter.saveResult(lr_props, method, this);
      53            0 : }
      54              : 
      55            5 : double LinearLearningRateScheduler::getLearningRate(size_t iteration) {
      56              :   auto const &max_lr = std::get<props::MaxLearningRate>(lr_props);
      57              :   auto const &min_lr = std::get<props::MinLearningRate>(lr_props);
      58              :   auto const &decay_steps = std::get<props::DecaySteps>(lr_props);
      59              : 
      60              :   // Linear formula
      61            5 :   double lr = max_lr - (max_lr - min_lr) * (iteration / (double)decay_steps);
      62              : 
      63            5 :   return std::max(lr, (double)min_lr);
      64              : }
      65              : 
      66              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1