Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
4 : *
5 : * @file weight.h
6 : * @date 22 September 2020
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Parichay Kapoor <pk.kapoor@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief This is Weight Class for Neural Network
11 : *
12 : */
13 :
14 : #ifndef __WEIGHT_H__
15 : #define __WEIGHT_H__
16 :
17 : #include <tuple>
18 :
19 : #include <tensor.h>
20 : #include <tensor_wrap_specs.h>
21 : #include <var_grad.h>
22 :
23 : namespace nntrainer {
24 :
25 : /**
26 : * @class Weight
27 : * @brief Weight extends Var_Grad with regularization & optimizer updates
28 : */
29 : class Weight : public Var_Grad {
30 : public:
31 : /**
32 : * @brief Specification of the Weight
33 : *
34 : * @details The tuple values are the variable and gradient dimensions, initializer, regularizer, regularizer constant,
35 : * decay constant, clip norm, need_gradient property, name, output axis, loss scale and mixed-precision flag of the Weight object.
36 : */
37 : typedef WeightSpec Spec;
38 :
39 : /**
40 : * @brief Weight default constructor
41 : */
42 0 : Weight() :
43 : Var_Grad(),
44 0 : regularizer(WeightRegularizer::UNKNOWN),
45 0 : regularizer_constant(1.0f),
46 0 : decay(0.0f),
47 0 : clip_by_global_norm(0.0f),
48 0 : output_axis(3),
49 0 : loss_scale(1.0),
50 0 : is_mixed(false) {}
51 :
52 : /**
53 : * @brief Construct a new Weight object
54 : *
55 : * @param dim Variable and gradient tensor dimension
56 : * @param init Initializer for the weight
57 : * @param reg Regularizer for the weight
58 : * @param reg_const Constant multiplier for regularizer
59 : * @param ng If the variable needs gradient
60 : * @param alloc_now Whether the memory for the weight tensors should be allocated upon init
61 : * @param name Name for this weight
62 : */
63 : explicit Weight(const TensorDim &dim,
64 : const Initializer init = Initializer::XAVIER_UNIFORM,
65 : const WeightRegularizer reg = WeightRegularizer::NONE,
66 : const float reg_const = 1.0f, const float decay = 0.0f,
67 : const float clip_by_global_norm = 0.0f, bool ng = true,
68 : bool alloc_now = false, std::string name = "",
69 : unsigned int axis = 3, float loss_scale_ = 1.0,
70 : bool is_mixed = false);
71 :
72 : /**
73 : * @brief Construct a new Weight object
74 : *
75 : * @param dim_v Variable tensor dimension
76 : * @param dim_g Gradient tensor dimension
77 : * @param init Initializer for the weight
78 : * @param reg Regularizer for the weight
79 : * @param reg_const Constant multiplier for regularizer
80 : * @param ng If the variable needs gradient
81 : * @param alloc_now Whether the memory for the weight tensors should be allocated upon init
82 : * @param name Name for this weight
83 : */
84 : explicit Weight(const TensorDim &dim_v, const TensorDim &dim_g,
85 : const Initializer init = Initializer::XAVIER_UNIFORM,
86 : const WeightRegularizer reg = WeightRegularizer::NONE,
87 : const float reg_const = 1.0f, const float decay = 0.0f,
88 : const float clip_by_global_norm = 0.0f, bool ng = true,
89 : bool alloc_now = false, std::string name = "",
90 : unsigned int axis = 3, float loss_scale_ = 1.0,
91 : bool is_mixed = false);
92 :
93 : /**
94 : * @brief Construct a new Weight object
95 : *
96 : * @param spec Weight specification
97 : */
98 258 : explicit Weight(const Spec &spec, bool alloc_now = false) :
99 : Weight(std::get<0>(spec), // TensorDim for Variable
100 : std::get<1>(spec), // TensorDim for Gradient
101 : std::get<2>(spec), // Initializer
102 : std::get<3>(spec), // WeightRegularizer
103 : std::get<4>(spec), // WeightRegularizerConstant
104 : std::get<5>(spec), // weight decay constant
105 : std::get<6>(spec), // MaxNorm for clipping
106 258 : std::get<7>(spec), // need_gradient
107 : alloc_now,
108 : std::get<8>(spec), // Name
109 : std::get<9>(spec), // out axis
110 : std::get<10>(spec), // loss scale
111 258 : std::get<11>(spec) // is Mixed precision training
112 516 : ) {}
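A minimal construction sketch (annotation, not part of the header), assuming the usual nntrainer include path and a 4D TensorDim(batch, channel, height, width) constructor; the dimensions and the name "example:weight" are illustrative only:

#include <weight.h>

using namespace nntrainer;

void make_example_weight() {
  TensorDim dim(1, 1, 64, 32); // batch, channel, height, width (illustrative)
  // L2-regularized weight that needs a gradient and allocates its tensors now
  Weight w(dim, Initializer::XAVIER_UNIFORM, WeightRegularizer::L2NORM,
           /*reg_const=*/1.0f, /*decay=*/0.0f, /*clip_by_global_norm=*/0.0f,
           /*ng=*/true, /*alloc_now=*/true, "example:weight");
}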
113 :
114 : /**
115 : * @brief Construct a new Weight object
116 : *
117 : * @param v Already created variable object
118 : * @param g Already created gradient object
119 : * @param v32 Already created var32 object
120 : * @param n Name for this Weight
121 : *
122 : * @note This is primarily used to create a wrapper of a variable extracted from
123 : * context. If needed, add support for regularizer and opt_vars.
124 : *
125 : * @note This API is not recommended for general use and must be used for internal
126 : * purposes only, as Weight does not own the tensors v and g, and can become invalid
127 : * if the owner of these tensors frees them.
128 : */
129 : explicit Weight(const Tensor &v, const Tensor &g, const Tensor &v32,
130 : const std::string &n = "", bool is_dependent = false,
131 : unsigned int output_axis_ = 3);
132 :
133 : /**
134 : * @brief Construct a new Weight object
135 : *
136 : * @param v ptr to already created variable tensor
137 : * @param g ptr to already created gradient tensor
138 : * @param v32 ptr to already created variable32 tensor
139 : * @param reg Regularizer for the weight
140 : * @param reg_const Constant multiplier for regularizer
141 : */
142 : explicit Weight(Tensor *v, Tensor *g, Tensor *v32,
143 : const WeightRegularizer reg, const float reg_const,
144 : const float decay, bool is_dependent = false,
145 : const float max_norm = 0.0f, unsigned int output_axis_ = 3,
146 : float loss_scale_ = 1.0f, bool is_mixed = false);
147 :
148 : /**
149 : * @brief Swap for weight
150 : *
151 : * @param lhs Swap to
152 : * @param rhs Swap from
153 : * @note Only swap gradient if need gradient
154 : */
155 : friend void swap(Weight &lhs, Weight &rhs) noexcept {
156 : using std::swap;
157 : swap(static_cast<Var_Grad &>(lhs), static_cast<Var_Grad &>(rhs));
158 : swap(lhs.regularizer, rhs.regularizer);
159 : swap(lhs.regularizer_constant, rhs.regularizer_constant);
160 : swap(lhs.decay, rhs.decay);
161 : swap(lhs.clip_by_global_norm, rhs.clip_by_global_norm);
162 : swap(lhs.output_axis, rhs.output_axis);
163 : swap(lhs.opt_vars, rhs.opt_vars);
164 : swap(lhs.loss_scale, rhs.loss_scale);
165 : swap(lhs.var32, rhs.var32);
166 : swap(lhs.is_mixed, rhs.is_mixed);
167 : }
168 :
169 : /**
170 : * @brief Copy constructor for weight
171 : *
172 : * @param rhs weight to construct from
173 : */
174 3644 : Weight(const Weight &rhs) = default;
175 :
176 : /**
177 : * @brief Move constructor for weight
178 : *
179 : * @param rhs weight to construct from
180 : */
181 3234 : Weight(Weight &&rhs) = default;
182 :
183 : /**
184 : * @brief copy assigment
185 : *
186 : * @param rhs copy from
187 : * @return Weight& Updated weight
188 : */
189 : Weight &operator=(const Weight &rhs) = default;
190 :
191 : /**
192 : * @brief move assignment
193 : *
194 : * @param rhs move from
195 : * @return Weight& Updated weight
196 : */
197 0 : Weight &operator=(Weight &&rhs) = default;
198 :
199 : /**
200 : * @brief Clone the current object
201 : *
202 : * @return Cloned copy
203 : */
204 1822 : Weight clone() const {
205 1822 : Weight w(*this);
206 1822 : if (!this->var->empty())
207 3644 : w.var = std::make_shared<Tensor>(this->var->clone());
208 1822 : if (!this->grad->empty())
209 3380 : w.grad = std::make_shared<Tensor>(this->grad->clone());
210 1822 : if (!this->var32->empty())
211 0 : w.var32 = std::make_shared<Tensor>(this->var32->clone());
212 :
213 1822 : return w;
214 0 : }
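The defaulted copy constructor above shares the underlying tensors (Var_Grad holds them through shared_ptr), while clone() makes independent copies of var, grad and var32. A brief sketch of the distinction, assuming an already constructed weight w:

Weight shared = w;       // copies the shared_ptrs: both objects refer to the same tensors
Weight deep = w.clone(); // new Tensor objects cloned from var, grad and var32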
215 :
216 : /**
217 : * @brief Clear optimizer variables
218 : */
219 : void clearOptimizerVariables() { opt_vars.clear(); }
220 :
221 : /**
222 : * @brief Set optimizer variables
223 : * @param tensors Optimizer variable tensors
224 : */
225 : void setOptimizerVariables(std::vector<Tensor *> tensors) {
226 3722 : opt_vars = tensors;
227 3722 : }
228 :
229 : /**
230 : * @brief Get optimizer variable reference
231 : * @param idx Index of the optimizer variable to get
232 : * @retval Reference of the optimizer variable
233 : */
234 2320 : Tensor &getOptimizerVariableRef(unsigned int idx) { return *opt_vars[idx]; }
235 :
236 : /**
237 : * @brief Get number of optimizer variable
238 : * @retval number of optimizer variable
239 : */
240 : int getNumOptVariable() { return opt_vars.size(); }
241 :
242 : /**
243 : * @brief Get axis of Weight
244 : * @retval axis of Weight
245 : */
246 : unsigned int getOutputAxis() { return output_axis; }
247 :
248 : /**
249 : * @brief check if weight regularizer type is l2norm
250 : * @return true if the weight regularizer type is L2 Norm, else false
251 : */
252 : bool isWeightRegularizerL2Norm() {
253 32786 : return regularizer == WeightRegularizer::L2NORM;
254 : }
255 :
256 : /**
257 : * @brief check if weight decay is enabled
258 : * @return true if weight decay is enabled else false
259 : */
260 15621 : bool isWeightDecay() { return decay > epsilon_decay; }
261 :
262 : /**
263 : * @brief Get loss from the regularization of the weight
264 : */
265 21545 : float getRegularizationLoss() {
266 21545 : if (hasGradient() && isWeightRegularizerL2Norm())
267 142 : return regularizer_constant * 0.5f * var->l2norm();
268 :
269 : return 0;
270 : }
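Restating the loss above in equation form (assuming Tensor::l2norm() returns the Euclidean norm of the variable tensor): when the weight has a gradient and the regularizer is L2NORM the returned value is L_reg = 0.5 * regularizer_constant * ||W||_2, and zero otherwise.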
271 :
272 : /**
273 : * @brief Calculate gradient from the regularization of the weight
274 : */
275 : void calcRegularizationGradient() {
276 15627 : if (isWeightRegularizerL2Norm())
277 71 : grad->add_i(*var.get(), regularizer_constant);
278 : }
279 :
280 : /**
281 : * @brief Calculate gradient from the decay of the weight
282 : */
283 : void calcWeightDecayGradient() {
284 15621 : if (isWeightDecay())
285 : applyWeightDecay();
286 : }
287 :
288 : /**
289 : * @brief Apply the gradient to the weight
290 : */
291 15627 : void applyGradient(double lr) { var->add_i(*grad.get(), -lr); }
292 :
293 : /**
294 : * @brief Apply the gradient to the weight with updated gradient
295 : * @param[in] updated_grad gradient tensor which is updated in the optimizer;
296 : * it may have a different data type from the gradient in the weight, e.g. FP32
297 : */
298 : void applyGradient(double lr, Tensor &updated_grad);
299 :
300 : /**
301 : * @brief Check if the gradient is supposed to be clipped by global norm with
302 : * the given max_norm value
303 : *
304 : * @param max_norm
305 : * @return true if it is to be clipped
306 : * @return false otherwise
307 : */
308 : static bool isGradientClipByGlobalNorm(const float max_norm) {
309 : return max_norm > epsilon;
310 : }
311 :
312 : /**
313 : * @brief Check if the gradient is supposed to be clipped by global norm
314 : *
315 : * @return true if it is to be clipped
316 : * @return false otherwise
317 : */
318 : bool isGradientClipByGlobalNorm() const {
319 19979 : return clip_by_global_norm > epsilon;
320 : }
321 :
322 : /**
323 : * @brief Check if the variable type is not full precision
324 : *
325 : * @return true if it is not full precision
326 : * @return false otherwise
327 : */
328 44310 : bool isMixedPrecision() const { return is_mixed; }
329 :
330 : /**
331 : * @brief clip the gradient value based on the given global norm
332 : *
333 : * @param global_norm the global norm for all the weights
334 : */
335 : void clipGradientByGlobalNorm(const float global_norm) {
336 44 : if ((global_norm + epsilon) > clip_by_global_norm)
337 0 : grad->multiply_i(clip_by_global_norm / (global_norm + epsilon));
338 : }
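Restating the clipping rule above: whenever global_norm + epsilon exceeds clip_by_global_norm, the gradient is rescaled in place as grad <- grad * clip_by_global_norm / (global_norm + epsilon); gradients whose global norm already falls under the threshold are left untouched.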
339 :
340 : /**
341 : * @brief Get the variable FP32 tensor (by reference)
342 : *
343 : * @return Tensor Variable FP32 tensor
344 : */
345 : Tensor &getVariableFP32Ref() { return *var32.get(); }
346 :
347 : /**
348 : * @brief Quantize var32 to var
349 : *
350 : */
351 : void quantizeWeight();
352 :
353 : /**
354 : * @brief set loss scale
355 : * @param[in] scale loss scale value
356 : *
357 : */
358 0 : void setLossScale(float scale) { loss_scale = scale; };
359 :
360 : /**
361 : * @brief get loss scale
362 : *
363 : */
364 0 : const float getLossScale() { return loss_scale; };
365 :
366 : private:
367 : static constexpr float epsilon = 1e-6f; /**< epsilon for zero comparison */
368 : static constexpr float epsilon_decay =
369 : 1e-8f; /**< epsilon for zero comparison */
370 :
371 : WeightRegularizer regularizer; /**< regularizer for this variable */
372 : float regularizer_constant; /**< constant factor for regularization */
373 : float decay; /**< constant factor for the weight decay */
374 : float clip_by_global_norm; /**< constant factor to clip gradient by L2 norm */
375 : unsigned int output_axis; /**< output axis of the weight */
376 : float loss_scale; /**< loss scale factor for mixed precision training */
377 : bool is_mixed; /**< true if the variable is not full precision */
378 : std::vector<Tensor *>
379 : opt_vars; /**< optimizer variables : We assume it is always full-precision */
380 : std::shared_ptr<Tensor> var32; /**< FP32 copy of the variable for mixed precision training */
381 :
382 : /**
383 : * @brief Apply the weight decay to the weight
384 : */
385 0 : void applyWeightDecay() { grad->add_i(*var.get(), decay); }
386 : };
387 :
388 : } // namespace nntrainer
389 :
390 : #endif /** __WEIGHT_H__ */
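A hedged sketch of how a caller might drive one plain gradient-descent update with the members above (the helper and its call order are assumptions for illustration, not something this header prescribes):

#include <weight.h>

// Hypothetical helper, not part of nntrainer: applies regularization, weight
// decay and a vanilla SGD step to a single weight.
void sgd_step(nntrainer::Weight &w, double lr) {
  w.calcRegularizationGradient(); // grad += regularizer_constant * var when L2NORM is set
  w.calcWeightDecayGradient();    // grad += decay * var when weight decay is enabled
  w.applyGradient(lr);            // var -= lr * grad
}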