LCOV - code coverage report
Current view: top level - nntrainer/optimizers - adamw.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 92.6 % 54 50
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 8 8

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2024 Daniel Jang <minhyukjang@snu.ac.kr>
       4              :  *
       5              :  * @file   adamw.cpp
       6              :  * @date   3 November 2024
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
       9              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
      10              :  * @author Daniel Jang <minhyukjang@snu.ac.kr>
      11              :  * @bug    No known bugs except for NYI items
      12              :  * @brief  This is the AdamW Optimizer.
      13              :  */
      14              : 
      15              : #include <cmath>
      16              : #include <fstream>
      17              : 
      18              : #include <adamw.h>
      19              : #include <nntrainer_error.h>
      20              : #include <nntrainer_log.h>
      21              : #include <node_exporter.h>
      22              : #include <util_func.h>
      23              : 
      24              : namespace nntrainer {
      25              : 
      26            8 : AdamW::AdamW() :
      27              :   adam_props(PropsB1(), PropsB2(), PropsEpsilon(), TorchRef(),
      28            8 :              PropsWeightDecayW()) {
      29              :   /** default properties */
      30              :   auto &[b1, b2, eps, torch_ref, weight_decay] = adam_props;
      31            8 :   b1.set(0.9f);
      32            8 :   b2.set(0.999f);
      33            8 :   eps.set(1.0e-8f);
      34            8 :   torch_ref.set(false);
      35            8 :   weight_decay.set(0.0f);
      36            8 : }
      37              : 
      38           16 : AdamW::~AdamW() {}
      39              : 
      40              : enum AdamParams { wm, wv };
      41              : 
      42            8 : std::vector<TensorDim> AdamW::getOptimizerVariableDim(const TensorDim &dim) {
      43              :   /**
      44              :    * @note We assume the optimizer parameters should be full precision to
      45              :    * maintain the accuracy even in mixed precision training.
      46              :    */
      47            8 :   TensorDim wm_dim(dim);
      48            8 :   TensorDim wv_dim(dim);
      49              :   wm_dim.setDataType(ml::train::TensorDim::DataType::FP32);
      50              :   wv_dim.setDataType(ml::train::TensorDim::DataType::FP32);
      51            8 :   return {wm_dim, wv_dim};
      52              : }
      53              : 
      54            2 : void AdamW::exportTo(Exporter &exporter,
      55              :                      const ml::train::ExportMethods &method) const {
      56            2 :   exporter.saveResult(adam_props, method, this);
      57              :   Optimizer::exportTo(exporter, method);
      58            2 : }
      59              : 
      60           13 : void AdamW::setProperty(const std::vector<std::string> &values) {
      61           13 :   auto left = loadProperties(values, adam_props);
      62           11 :   Optimizer::setProperty(left);
      63           11 : }
      64              : 
      65            6 : double AdamW::getUpdatedLearningRate(unsigned int iteration, double lr) const {
      66            6 :   auto &beta1 = std::get<PropsB1>(adam_props).get();
      67            6 :   auto &beta2 = std::get<PropsB2>(adam_props).get();
      68              :   auto biasCorrection = [&](double f) {
      69            6 :     return 1.0 - (double)pow(f, iteration + 1);
      70              :   };
      71            6 :   lr *= sqrt(biasCorrection(beta2)) / biasCorrection(beta1);
      72            6 :   return lr;
      73              : }
      74              : 
      75            6 : void AdamW::applyGradient(RunOptimizerContext &context) {
      76            6 :   Tensor empty_tensor;
      77              : 
      78              :   Tensor &x_grad =
      79            6 :     context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32
      80            6 :       ? context.getGradient()
      81              :       : empty_tensor;
      82              : 
      83            6 :   if (x_grad.empty()) {
      84            0 :     x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32);
      85              :   }
      86              : 
      87            6 :   context.applyLossScale(x_grad);
      88              : 
      89            6 :   auto &beta1 = std::get<PropsB1>(adam_props).get();
      90            6 :   auto &beta2 = std::get<PropsB2>(adam_props).get();
      91            6 :   auto &epsilon = std::get<PropsEpsilon>(adam_props).get();
      92            6 :   auto &weight_decay = std::get<PropsWeightDecayW>(adam_props).get();
      93              : 
      94            6 :   Tensor &wm = context.getOptimizerVariable(AdamParams::wm);
      95            6 :   Tensor &wv = context.getOptimizerVariable(AdamParams::wv);
      96              : 
      97            6 :   wm.multiply_i(beta1);
      98            6 :   wm.add_i(x_grad, 1.0f - beta1);
      99              : 
     100            6 :   wv.multiply_i(beta2);
     101            6 :   wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);
     102              : 
     103              :   // Decoupled weight decay: w = w - lr * wd * w
     104            6 :   if (weight_decay > 0.0) {
     105            0 :     Tensor &w = context.isMixedPrecision() ? context.getWeightFP32()
     106            0 :                                            : context.getWeight();
     107            0 :     w.multiply_i(1.0f - (context.getLearningRate() * weight_decay));
     108              :   }
     109              : 
     110              :   // Adam update with bias-corrected lr
     111              :   double lr_t =
     112            6 :     getUpdatedLearningRate(context.getIteration(), context.getLearningRate());
     113              : 
     114            6 :   std::function<double(double)> sqrtEps = [epsilon](double f) {
     115           48 :     return 1.0 / (sqrtDouble(f) + epsilon);
     116              :   };
     117            6 :   x_grad = wv.apply<float>(sqrtEps, x_grad);
     118            6 :   x_grad.multiply_i(wm);
     119            6 :   context.applyGradient(lr_t, x_grad);
     120            6 : }
     121              : 
     122              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1