LCOV - code coverage report
Current view: top level - nntrainer/optimizers - lr_scheduler.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 66.7 % 6 4
Test Date: 2025-12-14 20:38:17 Functions: 50.0 % 2 1

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   lr_scheduler.h
       6              :  * @date   09 December 2021
       7              :  * @brief  This is Learning Rate Scheduler interface class
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : 
      14              : #ifndef __LEARNING_RATE_SCHEDULER__
      15              : #define __LEARNING_RATE_SCHEDULER__
      16              : #ifdef __cplusplus
      17              : 
      18              : #include <string>
      19              : 
      20              : #include <optimizer.h>
      21              : 
      22              : namespace nntrainer {
      23              : 
      24              : class Exporter;
      25              : 
      26              : /**
      27              :  * @brief     Enumeration of learning rate scheduler types
      28              :  */
      29              : enum LearningRateSchedulerType {
      30              :   CONSTANT = 0, /**< constant */
      31              :   EXPONENTIAL,  /**< exponentially decay */
      32              :   STEP,         /**< step wise decay */
      33              :   COSINE,       /**< cosine annealing */
      34              :   LINEAR        /**< linear decay */
      35              : };
      36              : 
      37              : /**
      38              :  * @class   LearningRateScheduler
      39              :  * @brief   Base class for all Learning Rate Schedulers
      40              :  */
      41              : class LearningRateScheduler : public ml::train::LearningRateScheduler {
      42              : 
      43              : public:
      44              :   /**
      45              :    * @brief     Destructor of learning rate scheduler Class
      46              :    */
      47              :   virtual ~LearningRateScheduler() = default;
      48              : 
      49              :   /**
      50              :    * @brief     Finalize creating the learning rate scheduler
      51              :    *
      52              :    * @details   Verify that all the needed properties have been set and are
      53              :    * within the valid range.
      54              :    * @note      After calling this it is not allowed to
      55              :    * change properties.
      56              :    */
      57              :   virtual void finalize() = 0;
      58              : 
      59              :   /**
      60              :    * @brief     get Learning Rate for the given iteration
      61              :    * @param[in] iteration Iteration for the learning rate
      62              :    * @retval    Learning rate in double
      63              :    * @details   the return value of this function and getInitialLearningRate()
      64              :    * may not match for iteration == 0 (warmup can lead to different initial
      65              :    * learning rates).
      66              :    *
      67              :    * @note this function is intentionally non-const.
      68              :    */
      69              :   virtual double getLearningRate(size_t iteration) = 0;
      70              : 
      71              :   /**
      72              :    * @brief this function helps export the learning rate scheduler in a
      73              :    * predefined format, working around templated function type erasure
      74              :    *
      75              :    * @param     exporter exporter that contains exporting logic
      76              :    * @param     method enum value to identify how it should be exported to
      77              :    */
      78            0 :   virtual void exportTo(Exporter &exporter,
      79            0 :                         const ml::train::ExportMethods &method) const {}
      80              : 
      81              :   /**
      82              :    * @brief     Default allowed properties
      83              :    * Constant Learning rate scheduler
      84              :    * - learning_rate : float
      85              :    *
      86              :    * Exponential Learning rate scheduler
      87              :    * - learning_rate : float
      88              :    * - decay_rate : float
      89              :    * - decay_steps : float
      90              :    *
      91              :    * Cosine Annealing Learning rate scheduler
      92              :    * - max_learning_rate : float
      93              :    * - min_learning_rate : float
      94              :    * - decay_steps : float
      95              :    *
      96              :    * Linear Learning rate scheduler
      97              :    * - max_learning_rate : float
      98              :    * - min_learning_rate : float
      99              :    * - decay_steps : positive integer
     100              :    *
     101              :    * more to be added
     102              :    */
     103              : 
     104              :   /**
     105              :    * @brief     set learning rate scheduler properties
     106              :    * @param[in] values learning rate scheduler properties list
     107              :    * @details   This function accepts vector of properties in the format -
     108              :    *  { std::string property_name = std::string property_val, ...}
     109              :    */
     110              :   virtual void setProperty(const std::vector<std::string> &values) = 0;
     111              : 
     112              :   /**
     113              :    * @brief     get learning rate scheduler Type
     114              :    * @retval    learning rate scheduler type
     115              :    */
     116              :   virtual const std::string getType() const = 0;
     117              : };
     118              : 
     119              : /**
     120              :  * @brief General LR Scheduler Factory function to create LR Scheduler
     121              :  *
     122              :  * @param props list of "property_name=property_val" strings passed to setProperty()
     123              :  * @return std::unique_ptr<nntrainer::LearningRateScheduler> created object
     124              :  */
     125              : template <typename T,
     126              :           std::enable_if_t<std::is_base_of<LearningRateScheduler, T>::value, T>
     127              :             * = nullptr>
     128              : std::unique_ptr<LearningRateScheduler>
     129            6 : createLearningRateScheduler(const std::vector<std::string> &props = {}) {
     130            6 :   std::unique_ptr<LearningRateScheduler> ptr = std::make_unique<T>();
     131            6 :   ptr->setProperty(props);
     132            1 :   return ptr;
     133              : }
     134              : 
     135              : } /* namespace nntrainer */
     136              : 
     137              : #endif /* __cplusplus */
     138              : #endif /* __LEARNING_RATE_SCHEDULER__ */
        

Generated by: LCOV version 2.0-1