Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file sgd.cpp
6 : * @date 6 October 2020
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is the SGD optimizer.
12 : */
13 :
14 : #include <sgd.h>
15 :
16 : namespace nntrainer {
17 :
18 14557 : void SGD::applyGradient(RunOptimizerContext &context) {
19 : // @todo This could go inside the context.
20 14557 : Tensor empty_tensor;
21 :
22 : Tensor &x_grad =
23 14557 : context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32
24 14557 : ? context.getGradient()
25 : : empty_tensor;
26 :
27 14557 : if (x_grad.empty()) {
28 0 : x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32);
29 0 : context.applyLossScale(x_grad);
30 : }
31 :
32 14557 : context.applyGradient(context.getLearningRate(), x_grad);
33 14557 : }
34 :
35 : } // namespace nntrainer
|