// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2021 Parichay Kapoor <pk.kapoor@samsung.com>
 *
 * @file optimizer_context.h
 * @date 30 July 2021
 * @see https://github.com/nnstreamer/nntrainer
 * @author Parichay Kapoor <pk.kapoor@samsung.com>
 * @bug No known bugs except for NYI items
 * @brief This is the optimizer context for each optimizer
 */

#ifndef __OPTIMIZER_CONTEXT_H__
#define __OPTIMIZER_CONTEXT_H__

#include <memory>
#include <vector>

#include <tensor.h>

namespace nntrainer {

class Weight;

/**
 * @class RunOptimizerContext
 * @brief Class for Optimizer context
 *
 * @details This provides the context required for the execution of an
 * optimizer on a given weight.
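 *
 * @code
 * // Minimal usage sketch (illustrative values and names): wrap a weight
 * // with the current iteration and learning rate, then apply its gradient.
 * RunOptimizerContext ctx(&weight, iteration, 0.001);
 * if (ctx.readyToUse())
 *   ctx.applyGradient(ctx.getLearningRate());
 * @endcode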
 */
class RunOptimizerContext {
public:
  /**
   * @brief Construct a new Run Optimizer Context object
   *
   * @param w weight to be updated by the optimizer
   * @param iter current iteration number
   * @param lr learning rate to apply
   */
  RunOptimizerContext(Weight *w = nullptr, size_t iter = 0, double lr = 0.0) :
    weight(w), iteration(iter), learning_rate(lr) {}

  /**
   * @brief Get the Weight tensor object
   *
   * @return Tensor& Reference to the weight tensor
   */
  Tensor &getWeight() const;

  /**
   * @brief Get the Weight FP32 tensor object (master weight for mixed
   * precision)
   *
   * @return Tensor& Reference to the FP32 master weight tensor
   */
  Tensor &getWeightFP32() const;

  /**
   * @brief Get the Weight Gradient tensor object
   *
   * @return Tensor& Reference to the weight gradient tensor
   */
  Tensor &getGradient() const;

  /**
   * @brief Check whether the underlying weight is mixed precision
   *
   * @return true if the weight is mixed precision, else false
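   *
   * @code
   * // Illustrative sketch (assumed context object `ctx`): operate on the
   * // FP32 master copy in mixed precision, else on the weight itself.
   * Tensor &w = ctx.isMixedPrecision() ? ctx.getWeightFP32() : ctx.getWeight();
   * @endcode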
   */
  bool isMixedPrecision() const;

  /**
   * @brief Get the optimizer variable associated to this weight
   *
   * @param idx Identifier of the optimizer variable
   * @return Tensor& Reference to the optimizer variable
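   *
   * @code
   * // Illustrative Adam-style sketch: indices 0 and 1 are assumed to be
   * // the first and second moment tensors set up by the optimizer.
   * Tensor &m = ctx.getOptimizerVariable(0);
   * Tensor &v = ctx.getOptimizerVariable(1);
   * @endcode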
   */
  Tensor &getOptimizerVariable(unsigned int idx) const;

  /**
   * @brief Check if the run context is set and ready to use
   *
   * @return true if ready, else false
   */
  bool readyToUse() const { return weight != nullptr; }

  /**
   * @brief Apply the gradient with the given learning rate
   *
   * @param lr learning rate
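   *
   * @code
   * // Minimal SGD-style sketch (illustrative only): apply the raw
   * // gradient scaled by the context's current learning rate.
   * ctx.applyGradient(ctx.getLearningRate());
   * @endcode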
   */
  void applyGradient(double lr) const;

  /**
   * @brief Apply the gradient with the given learning rate and an updated
   * gradient
   *
   * @param lr learning rate
   * @param updated_grad updated gradient tensor to be applied (typically
   * full precision, FP32)
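   *
   * @code
   * // Illustrative mixed-precision sketch (names and ordering assumed):
   * // apply the loss scale to the FP32 copy of the gradient, then apply
   * // the scaled gradient to the weight with this overload.
   * ctx.applyLossScale(fp32_grad);
   * ctx.applyGradient(lr, fp32_grad);
   * @endcode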
   */
  void applyGradient(double lr, Tensor &updated_grad) const;

  /**
   * @brief Get the current iteration value
   *
   * @return iteration value
   */
  size_t getIteration() const { return iteration; }

  /**
   * @brief Get the current learning rate value
   *
   * @return learning rate value
   */
  double getLearningRate() const { return learning_rate; }

  /**
   * @brief Apply the loss scale to the gradient (full precision)
   *
   * @param fp32_grad full precision gradient tensor to which the loss scale
   * is applied
   */
  void applyLossScale(Tensor &fp32_grad);

private:
  Weight *weight;       /**< weights for the optimizer */
  size_t iteration;     /**< iteration number */
  double learning_rate; /**< learning rate */
};

} // namespace nntrainer
#endif // __OPTIMIZER_CONTEXT_H__