LCOV - code coverage report
Current view: top level - nntrainer/tensor - weight.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 65.9 % 44 29
Test Date: 2025-12-14 20:38:17 Functions: 83.3 % 6 5

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Parichay Kapoor <pk.kapoor@samsung.com>
       4              :  *
       5              :  * @file   weight.h
       6              :  * @date   22 September 2020
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Parichay Kapoor <pk.kapoor@samsung.com>
       9              :  * @bug    No known bugs except for NYI items
      10              :  * @brief  This is Weight Class for Neural Network
      11              :  *
      12              :  */
      13              : 
      14              : #ifndef __WEIGHT_H__
      15              : #define __WEIGHT_H__
      16              : 
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <tensor.h>
#include <tensor_wrap_specs.h>
#include <var_grad.h>
      22              : 
      23              : namespace nntrainer {
      24              : 
      25              : /**
      26              :  * @class   Weight
      27              :  * @brief   Weight extends over Var_Grad with regularization & optimizer updates
      28              :  */
      29              : class Weight : public Var_Grad {
      30              : public:
      31              :   /**
      32              :    * @brief Specification of the Weight
      33              :    *
      34              :    * @details The tuple values are dimension, initializer, regularizer,
      35              :    * regularizer_constant, need_gradient property amd name of the Weight object.
      36              :    */
      37              :   typedef WeightSpec Spec;
      38              : 
      39              :   /**
      40              :    * @brief Weight default constructor
      41              :    */
      42            0 :   Weight() :
      43              :     Var_Grad(),
      44            0 :     regularizer(WeightRegularizer::UNKNOWN),
      45            0 :     regularizer_constant(1.0f),
      46            0 :     decay(0.0f),
      47            0 :     clip_by_global_norm(0.0f),
      48            0 :     output_axis(3),
      49            0 :     loss_scale(1.0),
      50            0 :     is_mixed(false) {}
      51              : 
      52              :   /**
      53              :    * @brief Construct a new Weight object
      54              :    *
      55              :    * @param dim Variable and gradient tensor dimension
      56              :    * @param init Initializer for the weight
      57              :    * @param reg Regularizer for the weight
      58              :    * @param reg_const Constant multiplier for regularizer
      59              :    * @param ng If the variable needs gradient
      60              :    * @param alloc_now The memory for the weight tensors be allocated upon init
      61              :    * @param name Name for this weight
      62              :    */
      63              :   explicit Weight(const TensorDim &dim,
      64              :                   const Initializer init = Initializer::XAVIER_UNIFORM,
      65              :                   const WeightRegularizer reg = WeightRegularizer::NONE,
      66              :                   const float reg_const = 1.0f, const float decay = 0.0f,
      67              :                   const float clip_by_global_norm = 0.0f, bool ng = true,
      68              :                   bool alloc_now = false, std::string name = "",
      69              :                   unsigned int axis = 3, float loss_scale_ = 1.0,
      70              :                   bool is_mixed = false);
      71              : 
      72              :   /**
      73              :    * @brief Construct a new Weight object
      74              :    *
      75              :    * @param dim_v Variable and gradient tensor dimension
      76              :    * @param dim_g Gradient tensor dimension
      77              :    * @param init Initializer for the weight
      78              :    * @param reg Regularizer for the weight
      79              :    * @param reg_const Constant multiplier for regularizer
      80              :    * @param ng If the variable needs gradient
      81              :    * @param alloc_now The memory for the weight tensors be allocated upon init
      82              :    * @param name Name for this weight
      83              :    */
      84              :   explicit Weight(const TensorDim &dim_v, const TensorDim &dim_g,
      85              :                   const Initializer init = Initializer::XAVIER_UNIFORM,
      86              :                   const WeightRegularizer reg = WeightRegularizer::NONE,
      87              :                   const float reg_const = 1.0f, const float decay = 0.0f,
      88              :                   const float clip_by_global_norm = 0.0f, bool ng = true,
      89              :                   bool alloc_now = false, std::string name = "",
      90              :                   unsigned int axis = 3, float loss_scale_ = 1.0,
      91              :                   bool is_mixed = false);
      92              : 
      93              :   /**
      94              :    * @brief Construct a new Weight object
      95              :    *
      96              :    * @param spec Weight specification
      97              :    */
      98          258 :   explicit Weight(const Spec &spec, bool alloc_now = false) :
      99              :     Weight(std::get<0>(spec), // TensorDim for Variable
     100              :            std::get<1>(spec), // TensorDim for Gradient
     101              :            std::get<2>(spec), // Initializer
     102              :            std::get<3>(spec), // WeightRegularizer
     103              :            std::get<4>(spec), // WeightRegularizerConstant
     104              :            std::get<5>(spec), // weight decay constant
     105              :            std::get<6>(spec), // MaxNorm for clipping
     106          258 :            std::get<7>(spec), // need_gradient
     107              :            alloc_now,
     108              :            std::get<8>(spec),  // Name
     109              :            std::get<9>(spec),  // out axis
     110              :            std::get<10>(spec), // loss scale
     111          258 :            std::get<11>(spec)  // is Mixed precision training
     112          516 :     ) {}
     113              : 
     114              :   /**
     115              :    * @brief Construct a new Weight object
     116              :    *
     117              :    * @param v Already created variable object
     118              :    * @param g Already created gradient object
     119              :    * @param v32 Already created var32 object
     120              :    * @param n Name for this Weight
     121              :    *
     122              :    * @note This is primarily used to created wrapper of variable extracted from
     123              :    * context. If needed, add support for regularizer, and opt_vars.
     124              :    *
     125              :    * @note This API is not recommended for usage and must be used for internal
     126              :    * uses only, as Weight does not own the tensors v and g, and can go invalid
     127              :    * if the owner of these tensors free the tensors.
     128              :    */
     129              :   explicit Weight(const Tensor &v, const Tensor &g, const Tensor &v32,
     130              :                   const std::string &n = "", bool is_dependent = false,
     131              :                   unsigned int output_axis_ = 3);
     132              : 
     133              :   /**
     134              :    * @brief Construct a new Weight object
     135              :    *
     136              :    * @param v ptr to already created variable tensor
     137              :    * @param g ptr to already created gradient tensor
     138              :    * @param v32 ptr to already created variable32 tensor
     139              :    * @param reg Regularizer for the weight
     140              :    * @param reg_const Constant multiplier for regularizer
     141              :    */
     142              :   explicit Weight(Tensor *v, Tensor *g, Tensor *v32,
     143              :                   const WeightRegularizer reg, const float reg_const,
     144              :                   const float decay, bool is_dependent = false,
     145              :                   const float max_norm = 0.0f, unsigned int output_axis_ = 3,
     146              :                   float loss_scale_ = 1.0f, bool is_mixed = false);
     147              : 
     148              :   /**
     149              :    * @brief Swap for weight
     150              :    *
     151              :    * @param lhs Swap to
     152              :    * @param rhs Swap from
     153              :    * @note Only swap gradient if need gradient
     154              :    */
     155              :   friend void swap(Weight &lhs, Weight &rhs) noexcept {
     156              :     using std::swap;
     157              :     swap(static_cast<Var_Grad &>(lhs), static_cast<Var_Grad &>(rhs));
     158              :     swap(lhs.regularizer, rhs.regularizer);
     159              :     swap(lhs.regularizer_constant, rhs.regularizer_constant);
     160              :     swap(lhs.decay, rhs.decay);
     161              :     swap(lhs.clip_by_global_norm, rhs.clip_by_global_norm);
     162              :     swap(lhs.output_axis, rhs.output_axis);
     163              :     swap(lhs.opt_vars, rhs.opt_vars);
     164              :     swap(lhs.loss_scale, rhs.loss_scale);
     165              :     swap(lhs.var32, rhs.var32);
     166              :     swap(lhs.is_mixed, rhs.is_mixed);
     167              :   }
     168              : 
     169              :   /**
     170              :    * @brief Copy constructor for weight
     171              :    *
     172              :    * @param rhs weight to construct from
     173              :    */
     174         3644 :   Weight(const Weight &rhs) = default;
     175              : 
     176              :   /**
     177              :    * @brief Move constructor for weight
     178              :    *
     179              :    * @param rhs weight to construct from
     180              :    */
     181         3234 :   Weight(Weight &&rhs) = default;
     182              : 
     183              :   /**
     184              :    * @brief copy assigment
     185              :    *
     186              :    * @param rhs copy from
     187              :    * @return Weight& Updated weight
     188              :    */
     189              :   Weight &operator=(const Weight &rhs) = default;
     190              : 
     191              :   /**
     192              :    * @brief move assignment
     193              :    *
     194              :    * @param rhs move from
     195              :    * @return Weight& Updated weight
     196              :    */
     197            0 :   Weight &operator=(Weight &&rhs) = default;
     198              : 
     199              :   /**
     200              :    * @brief Clone the currnet object
     201              :    *
     202              :    * @return Cloned copy
     203              :    */
     204         1822 :   Weight clone() const {
     205         1822 :     Weight w(*this);
     206         1822 :     if (!this->var->empty())
     207         3644 :       w.var = std::make_shared<Tensor>(this->var->clone());
     208         1822 :     if (!this->grad->empty())
     209         3380 :       w.grad = std::make_shared<Tensor>(this->grad->clone());
     210         1822 :     if (!this->var32->empty())
     211            0 :       w.var32 = std::make_shared<Tensor>(this->var32->clone());
     212              : 
     213         1822 :     return w;
     214            0 :   }
     215              : 
     216              :   /**
     217              :    * @brief Clear optimizer variables
     218              :    */
     219              :   void clearOptimizerVariables() { opt_vars.clear(); }
     220              : 
     221              :   /**
     222              :    * @brief Add optimizer variables
     223              :    * @param dim Optimizer variable dimension
     224              :    */
     225              :   void setOptimizerVariables(std::vector<Tensor *> tensors) {
     226         3722 :     opt_vars = tensors;
     227         3722 :   }
     228              : 
     229              :   /**
     230              :    * @brief Get optimizer variable reference
     231              :    * @param idx Index of the optimizer variable to get
     232              :    * @retval Reference of the optimizer variable
     233              :    */
     234         2320 :   Tensor &getOptimizerVariableRef(unsigned int idx) { return *opt_vars[idx]; }
     235              : 
     236              :   /**
     237              :    * @brief Get number of optimizer variable
     238              :    * @retval number of optimizer variable
     239              :    */
     240              :   int getNumOptVariable() { return opt_vars.size(); }
     241              : 
     242              :   /**
     243              :    * @brief Get axis of Weight
     244              :    * @retval axis of Wegiht
     245              :    */
     246              :   unsigned int getOutputAxis() { return output_axis; }
     247              : 
     248              :   /**
     249              :    * @brief     check if weight regularizer type is l2norm
     250              :    * @return    bool is weight regrulatizer type is L2 Norm
     251              :    */
     252              :   bool isWeightRegularizerL2Norm() {
     253        32786 :     return regularizer == WeightRegularizer::L2NORM;
     254              :   }
     255              : 
     256              :   /**
     257              :    * @brief     check if weight decay is enabled
     258              :    * @return    true if weight decay is enabled else false
     259              :    */
     260        15621 :   bool isWeightDecay() { return decay > epsilon_decay; }
     261              : 
     262              :   /**
     263              :    * @brief     Get loss from the regularization of the weight
     264              :    */
     265        21545 :   float getRegularizationLoss() {
     266        21545 :     if (hasGradient() && isWeightRegularizerL2Norm())
     267          142 :       return regularizer_constant * 0.5f * var->l2norm();
     268              : 
     269              :     return 0;
     270              :   }
     271              : 
     272              :   /**
     273              :    * @brief     Calculate gradient from the regularization of the weight
     274              :    */
     275              :   void calcRegularizationGradient() {
     276        15627 :     if (isWeightRegularizerL2Norm())
     277           71 :       grad->add_i(*var.get(), regularizer_constant);
     278              :   }
     279              : 
     280              :   /**
     281              :    * @brief     Calculate gradient from the decay of the weight
     282              :    */
     283              :   void calcWeightDecayGradient() {
     284        15621 :     if (isWeightDecay())
     285              :       applyWeightDecay();
     286              :   }
     287              : 
     288              :   /**
     289              :    * @brief     Apply the gradient to the weight
     290              :    */
     291        15627 :   void applyGradient(double lr) { var->add_i(*grad.get(), -lr); }
     292              : 
     293              :   /**
     294              :    * @brief     Apply the gradient to the weight with updated gradient
     295              :    * @param[in] updated_grad gradient tensor which is updated in optimizer
     296              :    * it might be different data type with gradient in weight. .eg : FP32
     297              :    */
     298              :   void applyGradient(double lr, Tensor &updated_grad);
     299              : 
     300              :   /**
     301              :    * @brief Check if the gradient is supposed to be clipped by global norm with
     302              :    * the given max_norm value
     303              :    *
     304              :    * @param max_norm
     305              :    * @return true if it is to be clipped
     306              :    * @return false otherwise
     307              :    */
     308              :   static bool isGradientClipByGlobalNorm(const float max_norm) {
     309              :     return max_norm > epsilon;
     310              :   }
     311              : 
     312              :   /**
     313              :    * @brief Check if the gradient is supposed to be clipped by global norm
     314              :    *
     315              :    * @return true if it is to be clipped
     316              :    * @return false otherwise
     317              :    */
     318              :   bool isGradientClipByGlobalNorm() const {
     319        19979 :     return clip_by_global_norm > epsilon;
     320              :   }
     321              : 
     322              :   /**
     323              :    * @brief Check if the variable type is not full precision
     324              :    *
     325              :    * @return true if it is not full precsion
     326              :    * @return false otherwise
     327              :    */
     328        44310 :   bool isMixedPrecision() const { return is_mixed; }
     329              : 
     330              :   /**
     331              :    * @brief clip the gradient value based on the given global norm
     332              :    *
     333              :    * @param global_norm the global norm for all the weights
     334              :    */
     335              :   void clipGradientByGlobalNorm(const float global_norm) {
     336           44 :     if ((global_norm + epsilon) > clip_by_global_norm)
     337            0 :       grad->multiply_i(clip_by_global_norm / (global_norm + epsilon));
     338              :   }
     339              : 
     340              :   /**
     341              :    * @brief Get the variable FP32 tensor (by reference)
     342              :    *
     343              :    * @return Tensor Variable FP32 tensor
     344              :    */
     345              :   Tensor &getVariableFP32Ref() { return *var32.get(); }
     346              : 
     347              :   /**
     348              :    * @brief Quantize var32 to var
     349              :    *
     350              :    */
     351              :   void quantizeWeight();
     352              : 
     353              :   /**
     354              :    * @brief set loss scale
     355              :    * param[in] scale
     356              :    *
     357              :    */
     358            0 :   void setLossScale(float scale) { loss_scale = scale; };
     359              : 
     360              :   /**
     361              :    * @brief get loss scale
     362              :    *
     363              :    */
     364            0 :   const float getLossScale() { return loss_scale; };
     365              : 
     366              : private:
     367              :   static constexpr float epsilon = 1e-6f; /**< epsilon for zero comparison */
     368              :   static constexpr float epsilon_decay =
     369              :     1e-8f; /**< epsilon for zero comparison */
     370              : 
     371              :   WeightRegularizer regularizer; /**< regularizer for this variable */
     372              :   float regularizer_constant;    /**< constant factor for regularization */
     373              :   float decay;                   /**< constant factor for the weight decay */
     374              :   float clip_by_global_norm; /**< constant factor to clip gradient by L2 norm */
     375              :   unsigned int output_axis;
     376              :   float loss_scale;
     377              :   bool is_mixed;
     378              :   std::vector<Tensor *>
     379              :     opt_vars; /**< optimizer variables : We assume it is always full-precsion*/
     380              :   std::shared_ptr<Tensor> var32;
     381              : 
     382              :   /**
     383              :    * @brief     Apply the weight decay to the weight
     384              :    */
     385            0 :   void applyWeightDecay() { grad->add_i(*var.get(), decay); }
     386              : };
     387              : 
     388              : } // namespace nntrainer
     389              : 
     390              : #endif /** __WEIGHT_H__ */
        

Generated by: LCOV version 2.0-1