LCOV - code coverage report
Current view: top level - nntrainer/optimizers - sgd.cpp (source / functions)
Test: coverage_filtered.info
Test Date: 2025-12-14 20:38:17
Coverage summary (Coverage / Total / Hit):
Lines: 77.8 % / 9 / 7
Functions: 100.0 % / 1 / 1

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   sgd.cpp
       6              :  * @date   6 October 2020
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
       9              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  * @brief  This is the SGD optimizer.
      12              :  */
      13              : 
      14              : #include <sgd.h>
      15              : 
      16              : namespace nntrainer {
      17              : 
      18        14557 : void SGD::applyGradient(RunOptimizerContext &context) {
      19              :   // @todo This could go inside the context.
      20        14557 :   Tensor empty_tensor;
      21              : 
      22              :   Tensor &x_grad =
      23        14557 :     context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32
      24        14557 :       ? context.getGradient()
      25              :       : empty_tensor;
      26              : 
      27        14557 :   if (x_grad.empty()) {
      28            0 :     x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32);
      29            0 :     context.applyLossScale(x_grad);
      30              :   }
      31              : 
      32        14557 :   context.applyGradient(context.getLearningRate(), x_grad);
      33        14557 : }
      34              : 
      35              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1