LCOV - code coverage report
Current view: top level - api/ccapi/include - optimizer.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 100.0 % 10 10
Test Date: 2026-01-12 20:43:37 Functions: 100.0 % 5 5

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   optimizer.h
       6              :  * @date   14 October 2020
       7              :  * @see    https://github.com/nntrainer/nntrainer
       8              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
       9              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  * @brief  This is optimizers interface for c++ API
      12              :  *
      13              :  * @note This is experimental API and not stable.
      14              :  */
      15              : 
      16              : #ifndef __ML_TRAIN_OPTIMIZER_H__
      17              : #define __ML_TRAIN_OPTIMIZER_H__
      18              : 
      19              : #if __cplusplus >= MIN_CPP_VERSION
      20              : 
      21              : #include <string>
      22              : #include <vector>
      23              : 
      24              : #include <common.h>
      25              : 
      26              : namespace ml {
      27              : namespace train {
      28              : 
      29              : /** forward declaration */
      30              : class LearningRateScheduler;
      31              : 
      32              : /**
      33              :  * @brief     Enumeration of optimizer type
      34              :  */
      35              : enum OptimizerType {
      36              :   ADAM = ML_TRAIN_OPTIMIZER_TYPE_ADAM,      /** adam */
      37              :   ADAMW = ML_TRAIN_OPTIMIZER_TYPE_ADAMW,    /** AdamW */
      38              :   LION = ML_TRAIN_OPTIMIZER_TYPE_LION,      /** Lion */
      39              :   SGD = ML_TRAIN_OPTIMIZER_TYPE_SGD,        /** sgd */
      40              :   UNKNOWN = ML_TRAIN_OPTIMIZER_TYPE_UNKNOWN /** unknown */
      41              : };
      42              : 
      43              : /**
      44              :  * @class   Optimizer Base class for optimizers
      45              :  * @brief   Base class for all optimizers
      46              :  */
      47              : class Optimizer {
      48              : public:
      49              :   /**
      50              :    * @brief     Destructor of Optimizer Class
      51              :    */
      52              :   virtual ~Optimizer() = default;
      53              : 
      54              :   /**
      55              :    * @brief     get Optimizer Type
      56              :    * @retval    Optimizer type
      57              :    */
      58              :   virtual const std::string getType() const = 0;
      59              : 
      60              :   /**
      61              :    * @brief     Default allowed properties
      62              :    * Available for all optimizers
      63              :    * - learning_rate : float
      64              :    *
      65              :    * Available for SGD and Adam optimizers
      66              :    * - decay_rate : float,
      67              :    * - decay_steps : float,
      68              :    *
      69              :    * Available for Adam optimizer
      70              :    * - beta1 : float,
      71              :    * - beta2 : float,
      72              :    * - epsilon : float,
      73              :    */
      74              : 
      75              :   /**
      76              :    * @brief     set Optimizer Parameters
      77              :    * @param[in] values Optimizer Parameter list
      78              :    * @details   This function accepts vector of properties in the format -
      79              :    *  { std::string property_name, void * property_val, ...}
      80              :    */
      81              :   virtual void setProperty(const std::vector<std::string> &values) = 0;
      82              : 
      83              :   /**
      84              :    * @brief Set the Learning Rate Scheduler object
      85              :    *
      86              :    * @param lrs the learning rate scheduler object
      87              :    */
      88              :   virtual int setLearningRateScheduler(
      89              :     std::shared_ptr<ml::train::LearningRateScheduler> lrs) = 0;
      90              : };
      91              : 
      92              : /**
      93              :  * @brief Factory creator with constructor for optimizer
      94              :  */
      95              : std::unique_ptr<Optimizer>
      96              : createOptimizer(const std::string &type,
      97              :                 const std::vector<std::string> &properties = {});
      98              : 
      99              : /**
     100              :  * @brief Factory creator with constructor for optimizer
     101              :  */
     102              : std::unique_ptr<Optimizer>
     103              : createOptimizer(const OptimizerType &type,
     104              :                 const std::vector<std::string> &properties = {});
     105              : 
     106              : /**
     107              :  * @brief General Optimizer Factory function to register optimizer
     108              :  *
     109              :  * @param props property representation
     110              :  * @return std::unique_ptr<ml::train::Optimizer> created object
     111              :  */
     112              : template <typename T,
     113              :           std::enable_if_t<std::is_base_of<Optimizer, T>::value, T> * = nullptr>
     114              : std::unique_ptr<Optimizer>
     115              : createOptimizer(const std::vector<std::string> &props = {}) {
     116              :   std::unique_ptr<Optimizer> ptr = std::make_unique<T>();
     117              : 
     118              :   ptr->setProperty(props);
     119              :   return ptr;
     120              : }
     121              : 
     122              : namespace optimizer {
     123              : 
     124              : /**
     125              :  * @brief Helper function to create adam optimizer
     126              :  */
     127              : inline std::unique_ptr<Optimizer>
     128              : Adam(const std::vector<std::string> &properties = {}) {
     129            5 :   return createOptimizer(OptimizerType::ADAM, properties);
     130              : }
     131              : 
     132              : /**
     133              :  * @brief Helper function to create sgd optimizer
     134              :  */
     135              : inline std::unique_ptr<Optimizer>
     136              : SGD(const std::vector<std::string> &properties = {}) {
     137            3 :   return createOptimizer(OptimizerType::SGD, properties);
     138              : }
     139              : 
     140              : /**
     141              :  * @brief Helper function to create AdamW Optimizer
     142              :  */
     143              : inline std::unique_ptr<Optimizer>
     144              : AdamW(const std::vector<std::string> &properties = {}) {
     145            1 :   return createOptimizer(OptimizerType::ADAMW, properties);
     146              : }
     147              : 
     148              : /**
     149              :  * @brief Helper function to create Lion Optimizer
     150              :  */
     151              : inline std::unique_ptr<Optimizer>
     152              : Lion(const std::vector<std::string> &properties = {}) {
     153            1 :   return createOptimizer(OptimizerType::LION, properties);
     154              : }
     155              : 
     156              : } // namespace optimizer
     157              : 
     158              : /**
     159              :  * @brief     Enumeration of learning rate scheduler type
     160              :  */
     161              : enum LearningRateSchedulerType {
     162              :   CONSTANT = ML_TRAIN_LR_SCHEDULER_TYPE_CONSTANT, /**< constant */
     163              :   EXPONENTIAL =
     164              :     ML_TRAIN_LR_SCHEDULER_TYPE_EXPONENTIAL,  /**< exponentially decay */
     165              :   STEP = ML_TRAIN_LR_SCHEDULER_TYPE_STEP,    /**< step wise decay */
     166              :   COSINE = ML_TRAIN_LR_SCHEDULER_TYPE_COSINE /**< cosine annealing */
     167              : };
     168              : 
     169              : /**
     170              :  * @class   Learning Rate Schedulers Base class
     171              :  * @brief   Base class for all Learning Rate Schedulers
     172              :  */
     173              : class LearningRateScheduler {
     174              : 
     175              : public:
     176              :   /**
     177              :    * @brief     Destructor of learning rate scheduler Class
     178              :    */
     179              :   virtual ~LearningRateScheduler() = default;
     180              : 
     181              :   /**
     182              :    * @brief     Default allowed properties
     183              :    * Constant Learning rate scheduler
     184              :    * - learning_rate : float
     185              :    *
     186              :    * Exponential Learning rate scheduler
     187              :    * - learning_rate : float
     188              :    * - decay_rate : float,
     189              :    * - decay_steps : float,
     190              :    *
     191              :    * Step Learning rate scheduler
     192              :    * - learing_rate : float, float, ...
     193              :    * - iteration : uint, uint, ...
     194              :    *
     195              :    * more to be added
     196              :    */
     197              : 
     198              :   /**
     199              :    * @brief     set learning rate scheduler properties
     200              :    * @param[in] values learning rate scheduler properties list
     201              :    * @details   This function accepts vector of properties in the format -
     202              :    *  { std::string property_name = std::string property_val, ...}
     203              :    */
     204              :   virtual void setProperty(const std::vector<std::string> &values) = 0;
     205              : 
     206              :   /**
     207              :    * @brief     get learning rate scheduler Type
     208              :    * @retval    learning rate scheduler type
     209              :    */
     210              :   virtual const std::string getType() const = 0;
     211              : };
     212              : 
     213              : /**
     214              :  * @brief Factory creator with constructor for learning rate scheduler type
     215              :  */
     216              : std::unique_ptr<ml::train::LearningRateScheduler>
     217              : createLearningRateScheduler(const LearningRateSchedulerType &type,
     218              :                             const std::vector<std::string> &properties = {});
     219              : 
     220              : /**
     221              :  * @brief Factory creator with constructor for learning rate scheduler
     222              :  */
     223              : std::unique_ptr<ml::train::LearningRateScheduler>
     224              : createLearningRateScheduler(const std::string &type,
     225              :                             const std::vector<std::string> &properties = {});
     226              : 
     227              : /**
     228              :  * @brief General LR Scheduler Factory function to create LR Scheduler
     229              :  *
     230              :  * @param props property representation
     231              :  * @return std::unique_ptr<nntrainer::LearningRateScheduler> created object
     232              :  */
     233              : template <typename T,
     234              :           std::enable_if_t<std::is_base_of<LearningRateScheduler, T>::value, T>
     235              :             * = nullptr>
     236              : std::unique_ptr<LearningRateScheduler>
     237           74 : createLearningRateScheduler(const std::vector<std::string> &props = {}) {
     238           74 :   std::unique_ptr<LearningRateScheduler> ptr = std::make_unique<T>();
     239           74 :   ptr->setProperty(props);
     240           74 :   return ptr;
     241              : }
     242              : 
     243              : namespace optimizer {
     244              : namespace learning_rate {
     245              : 
     246              : /**
     247              :  * @brief Helper function to create constant learning rate scheduler
     248              :  */
     249              : inline std::unique_ptr<LearningRateScheduler>
     250              : Constant(const std::vector<std::string> &properties = {}) {
     251              :   return createLearningRateScheduler(LearningRateSchedulerType::CONSTANT,
     252              :                                      properties);
     253              : }
     254              : 
     255              : /**
     256              :  * @brief Helper function to create exponential learning rate scheduler
     257              :  */
     258              : inline std::unique_ptr<LearningRateScheduler>
     259              : Exponential(const std::vector<std::string> &properties = {}) {
     260            2 :   return createLearningRateScheduler(LearningRateSchedulerType::EXPONENTIAL,
     261            2 :                                      properties);
     262              : }
     263              : 
     264              : /**
     265              :  * @brief Helper function to create step learning rate scheduler
     266              :  */
     267              : inline std::unique_ptr<LearningRateScheduler>
     268              : Step(const std::vector<std::string> &properties = {}) {
     269              :   return createLearningRateScheduler(LearningRateSchedulerType::STEP,
     270              :                                      properties);
     271              : }
     272              : 
     273              : /**
     274              :  * @brief Helper function to create cosine learning rate scheduler
     275              :  */
     276              : inline std::unique_ptr<LearningRateScheduler>
     277              : Cosine(const std::vector<std::string> &properties = {}) {
     278              :   return createLearningRateScheduler(LearningRateSchedulerType::COSINE,
     279              :                                      properties);
     280              : }
     281              : 
     282              : } // namespace learning_rate
     283              : } // namespace optimizer
     284              : 
     285              : } // namespace train
     286              : } // namespace ml
     287              : 
     288              : #else
     289              : #error "CPP versions c++17 or over are only supported"
     290              : #endif // __cpluscplus
     291              : #endif // __ML_TRAIN_OPTIMIZER_H__
        

Generated by: LCOV version 2.0-1