LCOV - code coverage report
Current view: top level - nntrainer/optimizers - adam.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 100.0 % 6 6
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 2 2

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   adam.h
       6              :  * @date   6 October 2020
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
       9              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  * @brief  This is the Adam optimizer.
      12              :  */
      13              : #ifndef __ADAM_H__
      14              : #define __ADAM_H__
      15              : #ifdef __cplusplus
      16              : 
      17              : #include <tuple>
      18              : 
      19              : #include <base_properties.h>
      20              : #include <optimizer_devel.h>
      21              : 
      22              : namespace nntrainer {
      23              : 
      24              : /**
      25              :  * @brief Beta 1 props
      26              :  *
      27              :  */
      28          225 : class PropsB1 : public Property<double> {
      29              : public:
      30              :   static constexpr const char *key = "beta1"; /**< unique key to access */
      31              :   using prop_tag = double_prop_tag;           /**< property type */
      32              : };
      33              : 
      34              : /**
      35              :  * @brief Beta 2 props
      36              :  *
      37              :  */
      38          225 : class PropsB2 : public Property<double> {
      39              : public:
      40              :   static constexpr const char *key = "beta2"; /**< unique key to access */
      41              :   using prop_tag = double_prop_tag;           /**< property type */
      42              : };
      43              : 
      44              : /**
      45              :  * @brief epsilon props
      46              :  * @todo move this to common props
      47              :  *
      48              :  */
      49          225 : class PropsEpsilon : public Property<double> {
      50              : public:
      51              :   static constexpr const char *key = "epsilon"; /**< unique key to access */
      52              :   using prop_tag = double_prop_tag;             /**< property type */
      53              : };
      54              : 
      55              : /**
      56              :  * @brief pytorch reference implementation
      57              :  *
      58              :  */
      59          442 : class TorchRef : public Property<bool> {
      60              : public:
      61              :   static constexpr const char *key = "torch_ref"; /**< unique key to access */
      62              :   using prop_tag = bool_prop_tag;                 /**< property type */
      63              : };
      64              : 
      65              : /**
      66              :  * @class   Adam optimizer class
      67              :  * @brief   Adam optimizer
      68              :  */
      69              : class Adam : public Optimizer {
      70              : public:
      71              :   /**
      72              :    * @brief Construct a new Adam object
      73              :    *
      74              :    */
      75              :   Adam();
      76              : 
      77              :   /**
      78              :    * @brief Destroy the Adam object
      79              :    *
      80              :    */
      81              :   ~Adam();
      82              : 
      83              :   /**
      84              :    * @copydoc Optimizer::getDefaultLearningRate()
      85              :    *
      86              :    */
      87          213 :   double getDefaultLearningRate() const override { return 0.001; }
      88              : 
      89              :   /**
      90              :    * @copydoc applyGradient(RunOptimizerContext &context)
      91              :    */
      92              :   void applyGradient(RunOptimizerContext &context) override;
      93              : 
      94              :   /**
      95              :    * @copydoc Optimizer::getType()
      96              :    */
      97         1088 :   const std::string getType() const override { return Adam::type; }
      98              : 
      99              :   /**
     100              :    * @copydoc Optimizer::getOptimizerVariableDim(const TensorDim &dim)
     101              :    */
     102              :   std::vector<TensorDim> getOptimizerVariableDim(const TensorDim &dim) override;
     103              : 
     104              :   /**
     105              :    * @copydoc Optimizer::exportTo(Exporter &exporter, const
     106              :    * ml::train::ExportMethods& method)
     107              :    */
     108              :   void exportTo(Exporter &exporter,
     109              :                 const ml::train::ExportMethods &method) const override;
     110              : 
     111              :   static constexpr const char *type = "adam";
     112              : 
     113              :   /**
     114              :    * @copydoc Optimizer::setProperty(const std::vector<std::string> &values)
     115              :    */
     116              :   void setProperty(const std::vector<std::string> &values) override;
     117              : 
     118              : private:
     119              :   std::tuple<PropsB1, PropsB2, PropsEpsilon, TorchRef> adam_props;
     120              : 
     121              :   /**
     122              :    * @brief Get updated learning rate
     123              :    *
     124              :    * @param lr learning rate
     125              :    *
     126              :    * @return updated learning rate
     127              :    */
     128              :   double getUpdatedLearningRate(unsigned int iteration, double lr) const;
     129              : };
     130              : } /* namespace nntrainer */
     131              : 
     132              : #endif /* __cplusplus */
     133              : #endif /* __ADAM_H__ */
        

Generated by: LCOV version 2.0-1