LCOV - code coverage report
Current view: top level - nntrainer/tensor - int4_tensor.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 40.0 % 5 2
Test Date: 2025-12-14 20:38:17 Functions: 50.0 % 4 2

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * @file        int4_tensor.h
       4              :  * @date        23 January 2025
       5              :  * @brief       This is Int4QTensor class for quantized 4-bit integer calculation
       6              :  * @see         https://github.com/nnstreamer/nntrainer
       7              :  * @author      Donghyeon Jeong <dhyeon.jeong@samsung.com>
       8              :  * @bug         No known bugs except for NYI items
       9              :  */
      10              : 
      11              : #ifndef __INT4_TENSOR_H__
      12              : #define __INT4_TENSOR_H__
      13              : #ifdef __cplusplus
      14              : 
      15              : #include <quantizer.h>
      16              : #include <tensor_base.h>
      17              : 
      18              : namespace nntrainer {
      19              : 
      20              : /**
      21              :  * @class Int4QTensor class
      22              :  * @brief Int4QTensor class for quantized 4-bit integer calculation
      23              :  *
      24              :  * @note Int4QTensor stores int4 data within the int8 memory space.
      25              :  * Specifically, each int8 value contains two int4 values packed together.
      26              :  * The high-order (most significant) four bits hold the first int4 value,
      27              :  * while the low-order four bits hold the second int4 value.
      28              :  * E.g., 01011001 (89) represents 0101 (+5) and 1001 (-7, two's complement)
      29              :  *
      30              :  * @todo Remove variable `group_size` and add PER_GROUP_AFFINE_32,64,128
      31              :  */
      32              : class Int4QTensor : public TensorBase {
      33              : public:
      34              :   /**
      35              :    * @brief     Basic Constructor of Tensor
      36              :    */
      37              :   Int4QTensor(std::string name_ = "", Tformat fm = Tformat::NCHW,
      38              :               QScheme qscheme_ = QScheme::PER_CHANNEL_AFFINE,
      39              :               size_t g_size = 32);
      40              : 
      41              :   /**
      42              :    * @brief Construct a new Int4QTensor object
      43              :    *
      44              :    * @param d Tensor dim for this qint4 tensor
      45              :    * @param alloc_now Allocate memory to this tensor or not
      46              :    * @param init Initializer for the tensor
      47              :    * @param name Name of the tensor
      48              :    * @param qscheme_ Quantization scheme of the tensor
      49              :    */
      50              :   Int4QTensor(const TensorDim &d, bool alloc_now,
      51              :               Initializer init = Initializer::NONE, std::string name = "",
      52              :               QScheme qscheme_ = QScheme::PER_CHANNEL_AFFINE,
      53              :               size_t g_size = 32);
      54              : 
      55              :   /**
      56              :    * @brief Construct a new Int4QTensor object
      57              :    *
      58              :    * @param d Tensor dim for this tensor
      59              :    * @param buf buffer
      60              :    * @param qscheme_ quantization scheme of the tensor
      61              :    */
      62              :   Int4QTensor(const TensorDim &d, const void *buf = nullptr,
      63              :               QScheme qscheme_ = QScheme::PER_CHANNEL_AFFINE,
      64              :               size_t g_size = 32);
      65              : 
      66              :   /**
      67              :    * @brief Construct a new Int4QTensor object
      68              :    *
      69              :    * @param d data for the Tensor
      70              :    * @param scales scale factors for the Tensor
      71              :    * @param fm format for the Tensor
      72              :    * @param qscheme_ quantization scheme of the tensor
      73              :    */
      74              :   Int4QTensor(
      75              :     std::vector<std::vector<std::vector<std::vector<int8_t>>>> const &d,
      76              :     std::vector<float> const &scales, Tformat fm, QScheme qscheme_,
      77              :     size_t g_size = 32);
      78              : 
      79              :   /**
      80              :    * @brief Construct a new Int4QTensor object
      81              :    * @param rhs TensorBase object to copy
      82              :    */
      83            0 :   Int4QTensor(TensorBase &rhs) :
      84            0 :     TensorBase(rhs), qscheme(QScheme::PER_CHANNEL_AFFINE) {}
      85              : 
      86              :   /**
      87              :    * @brief Basic Destructor
      88              :    */
      89           10 :   ~Int4QTensor() {}
      90              : 
      91              :   /**
      92              :    * @brief     Comparison operator overload
      93              :    * @param[in] rhs Tensor to be compared with
      94              :    */
      95              :   bool operator==(const Int4QTensor &rhs) const;
      96              : 
      97              :   /**
      98              :    * @brief     Comparison operator overload
      99              :    * @param[in] rhs Tensor to be compared with
     100              :    */
     101              :   bool operator!=(const Int4QTensor &rhs) const { return !(*this == rhs); }
     102              : 
     103              :   /**
     104              :    * @copydoc Tensor::allocate()
     105              :    */
     106              :   void allocate() override;
     107              : 
     108              :   /**
     109              :    * @copydoc Tensor::deallocate()
     110              :    */
     111              :   void deallocate() override;
     112              : 
     113              :   /**
     114              :    * @copydoc Tensor::getData()
     115              :    */
     116              :   void *getData() const override;
     117              : 
     118              :   /**
     119              :    * @copydoc Tensor::getData(size_t idx)
     120              :    */
     121              :   void *getData(size_t idx) const override;
     122              : 
     123              :   /**
     124              :    * @copydoc Tensor::getScale()
     125              :    */
     126              :   void *getScale() const override;
     127              : 
     128              :   /**
     129              :    * @copydoc Tensor::getScale(size_t idx)
     130              :    */
     131              :   void *getScale(size_t idx) const override;
     132              : 
     133              :   /**
     134              :    * @brief     Get the address of the data at index i
     135              :    * @retval    address of ith data
     136              :    */
     137              :   void *getAddress(unsigned int i) override;
     138              : 
     139              :   /**
     140              :    * @brief     Get the address of the data at index i
     141              :    * @retval    address of ith data
     142              :    */
     143              :   const void *getAddress(unsigned int i) const override;
     144              : 
     145              :   /**
     146              :    * @brief     return value at specific location
     147              :    * @param[in] i index
     148              :    */
     149              :   const int8_t getValue(unsigned int i) const;
     150              : 
     151              :   /**
     152              :    * @brief     return value at specific location
     153              :    * @param[in] i index
     154              :    */
     155              :   int8_t getValue(unsigned int i);
     156              : 
     157              :   /**
     158              :    * @brief     return value at specific location
     159              :    * @param[in] b batch location
     160              :    * @param[in] c channel location
     161              :    * @param[in] h height location
     162              :    * @param[in] w width location
     163              :    */
     164              :   const int8_t getValue(unsigned int b, unsigned int c, unsigned int h,
     165              :                         unsigned int w) const;
     166              : 
     167              :   /**
     168              :    * @brief     return value at specific location
     169              :    * @param[in] b batch location
     170              :    * @param[in] c channel location
     171              :    * @param[in] h height location
     172              :    * @param[in] w width location
     173              :    */
     174              :   int8_t getValue(unsigned int b, unsigned int c, unsigned int h,
     175              :                   unsigned int w);
     176              : 
     177              :   /**
     178              :    * @copydoc Tensor::setValue(float value)
     179              :    */
     180              :   void setValue(float value) override;
     181              : 
     182              :   /**
     183              :    * @copydoc Tensor::setValue(b, c, h, w, value)
     184              :    */
     185              :   void setValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
     186              :                 float value) override;
     187              : 
     188              :   /**
     189              :    * @copydoc Tensor::addValue(b, c, h, w, value, beta)
     190              :    */
     191              :   void addValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
     192              :                 float value, float beta) override;
     193              : 
     194              :   /**
     195              :    * @copydoc Tensor::setZero()
     196              :    */
     197              :   void setZero() override;
     198              : 
     199              :   /**
     200              :    * @copydoc Tensor::initialize()
     201              :    */
     202              :   void initialize() override;
     203              : 
     204              :   /**
     205              :    * @copydoc Tensor::initialize(Initializer init)
     206              :    */
     207              :   void initialize(Initializer init) override;
     208              : 
     209              :   /**
     210              :    * @copydoc Tensor::copy(const Tensor &from)
     211              :    */
     212              :   void copy(const Tensor &from) override;
     213              : 
     214              :   /**
     215              :    * @copydoc Tensor::copyData(const Tensor &from)
     216              :    */
     217              :   void copyData(const Tensor &from) override;
     218              : 
     219              :   /**
     220              :    * @copydoc Tensor::copy_with_stride()
     221              :    */
     222              :   void copy_with_stride(const Tensor &input, Tensor &output) override;
     223              : 
     224              :   /**
     225              :    * @copydoc Tensor::save(std::ostream &file)
     226              :    */
     227              :   void save(std::ostream &file) override;
     228              : 
     229              :   /**
     230              :    * @copydoc Tensor::read(std::ifstream &file)
     231              :    */
     232              :   void read(std::ifstream &file, size_t start_offset,
     233              :             bool read_from_offset) override;
     234              : 
     235              :   /**
     236              :    * @brief     Read the Tensor from file
     237              :    * @param[in] src input file stream
     238              :    */
     239              :   void read(ReadSource src, size_t start_offset = 0,
     240              :             bool read_from_offset = false) override;
     241              : 
     242              :   /**
     243              :    * @copydoc Tensor::argmax()
     244              :    */
     245              :   std::vector<unsigned int> argmax() const override;
     246              : 
     247              :   /**
     248              :    * @copydoc Tensor::argmin()
     249              :    */
     250              :   std::vector<unsigned int> argmin() const override;
     251              : 
     252              :   /**
     253              :    * @copydoc Tensor::max_abs()
     254              :    */
     255              :   float max_abs() const override;
     256              : 
     257              :   /**
     258              :    * @copydoc Tensor::maxValue()
     259              :    */
     260              :   float maxValue() const override;
     261              : 
     262              :   /**
     263              :    * @copydoc Tensor::minValue()
     264              :    */
     265              :   float minValue() const override;
     266              : 
     267              :   /**
     268              :    * @copydoc Tensor::print(std::ostream &out)
     269              :    */
     270              :   void print(std::ostream &out) const override;
     271              : 
     272              :   /**
     273              :    * @copydoc TensorBase::save_quantization_info()
     274              :    */
     275              :   void save_quantization_info(std::ostream &file) override;
     276              : 
     277              :   /**
     278              :    * @copydoc TensorBase::read_quantization_info()
     279              :    */
     280              :   void read_quantization_info(std::ifstream &file, size_t start_offset,
     281              :                               bool read_from_offset) override;
     282              : 
     283              :   /**
     284              :    * @copydoc TensorBase::read_quantization_info()
     285              :    */
     286              :   void read_quantization_info(ReadSource src, size_t start_offset,
     287              :                               bool read_from_offset) override;
     288              :   /**
     289              :    * @copydoc Tensor::getMemoryBytes()
     290              :    */
     291              :   size_t getMemoryBytes() const override;
     292              : 
     293              :   /**
     294              :    * @copydoc Tensor::scale_size()
     295              :    */
     296              :   size_t scale_size() const override;
     297              : 
     298              :   /**
     299              :    * @copydoc Tensor::q_scheme()
     300              :    */
     301              :   QScheme q_scheme() const override;
     302              : 
     303              :   /**
     304              :    * @brief Returns quantization group size
     305              :    */
     306              :   static size_t getGroupSize();
     307              : 
     308              : private:
     309              :   /**
     310              :    * @brief quantization scheme
     311              :    */
     312              :   QScheme qscheme;
     313              : 
     314              :   /**
     315              :    * @brief Quantization group size
     316              :    *
     317              :    * @note static (shared by all instances) although constructors take a per-tensor g_size; needs to be properly defined
     318              :    */
     319              :   static size_t group_size;
     320              : 
     321              :   /**
     322              :    * @brief copy a buffer to @a this, the caller has to ensure that @a this is
     323              :    * initialized otherwise undefined behavior
     324              :    *
     325              :    * @param buf buffer to copy from
     326              :    */
     327              :   void copy(const void *buf);
     328              : 
     329              :   /**
     330              :    * @brief  Get the Data Type String object
     331              :    * @return std::string of tensor data type (QINT4)
     332              :    */
     333            1 :   std::string getStringDataType() const override { return "QINT4"; }
     334              : 
     335              :   /**
     336              :    * @copydoc Tensor::isValid()
     337              :    */
     338            0 :   bool isValid() const override { return true; };
     339              : };
     340              : 
     341              : } // namespace nntrainer
     342              : 
     343              : #endif /* __cplusplus */
     344              : #endif /* __INT4_TENSOR_H__ */
        

Generated by: LCOV version 2.0-1