LCOV - code coverage report
Current view: top level - nntrainer/optimizers - optimizer_context.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 100.0 % 4 4
Test Date: 2025-12-14 20:38:17 Functions: - 0 0

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   optimizer_context.h
       6              :  * @date   30 July 2021
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
       9              :  * @bug    No known bugs except for NYI items
      10              :  * @brief  This is the optimizer context for each optimizer
      11              :  */
      12              : 
      13              : #ifndef __OPTIMIZER_CONTEXT_H__
      14              : #define __OPTIMIZER_CONTEXT_H__
      15              : 
      16              : #include <memory>
      17              : #include <vector>
      18              : 
      19              : #include <tensor.h>
      20              : 
      21              : namespace nntrainer {
      22              : 
      23              : class Weight;
      24              : 
      25              : /**
      26              :  * @class   Op Context class for all optimizers
      27              :  * @brief   Class for Optimizer context
      28              :  *
      29              :  * @details This provides for the optimizer execution.
      30              :  */
      31              : class RunOptimizerContext {
      32              : public:
      33              :   /**
      34              :    * @brief Construct a new Run Optimizer Context object
      35              :    *
      36              :    */
      37        15627 :   RunOptimizerContext(Weight *w = nullptr, size_t iter = 0, double lr = 0.0) :
      38        15627 :     weight(w), iteration(iter), learning_rate(lr) {}
      39              : 
      40              :   /**
      41              :    * @brief Get the Weight tensor object
      42              :    *
      43              :    * @return Tensor& Reference to the weight tensor
      44              :    */
      45              :   Tensor &getWeight() const;
      46              : 
      47              :   /**
      48              :    * @brief Get the Weight FP32 tensor object (master weight for mixed
      49              :    * precision)
      50              :    *
      51              :    * @return Tensor& Reference to the FP32 master weight tensor
      52              :    */
      53              :   Tensor &getWeightFP32() const;
      54              : 
      55              :   /**
      56              :    * @brief Get the Weight Gradient tensor object
      57              :    *
      58              :    * @return Tensor& Reference to the weight grad tensor
      59              :    */
      60              :   Tensor &getGradient() const;
      61              : 
      62              :   /**
      63              :    * @brief Return if the underlying weight is mixed precision
      64              :    */
      65              :   bool isMixedPrecision() const;
      66              : 
      67              :   /**
      68              :    * @brief Get the optimizer variable associated to this weight
      69              :    *
      70              :    * @param idx Identifier of the associated weight
      71              :    * @return Tensor& Reference to the optimizer variable
      72              :    */
      73              :   Tensor &getOptimizerVariable(unsigned int idx) const;
      74              : 
      75              :   /**
      76              :    * @brief   Check if run context is set and is ready to use
      77              :    *
      78              :    * @return true if ready, else false
      79              :    */
      80              :   bool readyToUse() const { return weight != nullptr; }
      81              : 
      82              :   /**
      83              :    * @brief   Apply the gradient with the given learning rate
      84              :    *
      85              :    * @param lr learning rate
      86              :    */
      87              :   void applyGradient(double lr) const;
      88              : 
      89              :   /**
      90              :    * @brief   Apply the gradient with the given learning rate and updated
      91              :    * gradient
      92              :    *
      93              :    * @param lr learning rate
      94              :    * @param updated_grad gradient tensor which is updated. (usually it could be
      95              :    * fp32)
      96              :    */
      97              :   void applyGradient(double lr, Tensor &updated_grad) const;
      98              : 
      99              :   /**
     100              :    * @brief   Get the current iteration value
     101              :    *
     102              :    * @return iteration value
     103              :    */
     104         1070 :   size_t getIteration() const { return iteration; }
     105              : 
     106              :   /**
     107              :    * @brief   Get the current iteration value
     108              :    *
     109              :    * @return iteration value
     110              :    */
     111        15627 :   double getLearningRate() const { return learning_rate; }
     112              : 
     113              :   /**
     114              :    * @brief   Apply loss scale to gradient (full precision)
     115              :    */
     116              :   void applyLossScale(Tensor &fp32_grad);
     117              : 
     118              : private:
     119              :   Weight *weight;       /**< weights for the optimizer */
     120              :   size_t iteration;     /**< iteration number */
     121              :   double learning_rate; /**< learning rate */
     122              : };
     123              : 
     124              : } // namespace nntrainer
     125              : #endif // __OPTIMIZER_CONTEXT_H__
        

Generated by: LCOV version 2.0-1