LCOV - code coverage report
Current view: top level - nntrainer/tensor - int4_tensor.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 46.4 % 306 142
Test Date: 2025-12-14 20:38:17 Functions: 44.2 % 43 19

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * @file        int4_tensor.cpp
       4              :  * @date        23 January 2025
       5              :  * @brief       This is Int4QTensor class for quantized 4-bit integer calculation
       6              :  * @see         https://github.com/nnstreamer/nntrainer
       7              :  * @author      Donghyeon Jeong <dhyeon.jeong@samsung.com>
       8              :  * @bug         No known bugs except for NYI items
       9              :  */
      10              : 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <iostream>

#include <cpu_backend.h>
#include <int4_tensor.h>
#include <tensor.h>
      17              : 
      18              : namespace nntrainer {
      19              : 
      20              : size_t Int4QTensor::group_size = 32;
      21              : 
      22            0 : Int4QTensor::Int4QTensor(std::string name_, Tformat fm, QScheme qscheme_,
      23            0 :                          size_t g_size) :
      24            0 :   TensorBase(name_, fm, Tdatatype::QINT4), qscheme(qscheme_) {
      25            0 :   group_size = g_size;
      26            0 : }
      27              : 
      28            6 : Int4QTensor::Int4QTensor(const TensorDim &d, bool alloc_now, Initializer init,
      29            6 :                          std::string name, QScheme qscheme_, size_t g_size) :
      30            6 :   TensorBase(d, alloc_now, init, name), qscheme(qscheme_) {
      31            6 :   group_size = g_size;
      32            6 :   if (alloc_now)
      33            6 :     allocate();
      34            6 : }
      35              : 
      36            2 : Int4QTensor::Int4QTensor(const TensorDim &d, const void *buf, QScheme qscheme_,
      37            2 :                          size_t g_size) :
      38            2 :   Int4QTensor(d, true, Initializer::NONE, "", qscheme_, g_size) {
      39            2 :   if (d.getDataLen() != 0) {
      40            2 :     if (buf != nullptr)
      41            0 :       copy(buf);
      42              :   }
      43            2 : }
      44              : 
      45            5 : Int4QTensor::Int4QTensor(
      46              :   std::vector<std::vector<std::vector<std::vector<int8_t>>>> const &d,
      47              :   std::vector<float> const &scales, Tformat fm, QScheme qscheme_,
      48            5 :   size_t g_size) :
      49            5 :   qscheme(qscheme_) {
      50            5 :   group_size = g_size;
      51            5 :   if (d.empty() || d[0].empty() || d[0][0].empty() || d[0][0][0].empty()) {
      52              :     throw std::out_of_range(
      53            0 :       "[Tensor] trying to initialize Int4QTensor from empty vector");
      54              :   }
      55              : 
      56            6 :   NNTR_THROW_IF(scales.size() != scale_size(), std::invalid_argument)
      57              :     << "invalid scale factor size " << scales.size();
      58              : 
      59            4 :   dim.setTensorDim(0, d.size());
      60            4 :   if (fm == Tformat::NCHW) {
      61            4 :     dim.setTensorDim(1, d[0].size());
      62            4 :     dim.setTensorDim(2, d[0][0].size());
      63            4 :     dim.setTensorDim(3, d[0][0][0].size());
      64              :   } else {
      65            0 :     dim.setTensorDim(2, d[0].size());
      66            0 :     dim.setTensorDim(3, d[0][0].size());
      67            0 :     dim.setTensorDim(1, d[0][0][0].size());
      68              :   }
      69              : 
      70              :   dim.setTensorType({fm, Tdatatype::QINT4});
      71              : 
      72            4 :   strides = dim.computeStrides();
      73            4 :   contiguous = true;
      74            4 :   initializer = Initializer::NONE;
      75            4 :   qscheme = qscheme_;
      76              : 
      77              :   /// @note sizeof(float) * scale_size() assumes scale factors are in
      78              :   /// full-precision fp.
      79              :   MemoryData *mem_data =
      80            4 :     new MemoryData((void *)(new int8_t[(dim.getDataLen() + 1) / 2 +
      81          104 :                                        sizeof(float) * scale_size()]()));
      82            5 :   data = std::shared_ptr<MemoryData>(mem_data, [](MemoryData *mem_data) {
      83            4 :     delete[] mem_data->getAddr<int8_t>();
      84            4 :     delete mem_data;
      85              :   });
      86              : 
      87            4 :   offset = 0;
      88              : 
      89            4 :   if (fm == Tformat::NCHW) {
      90           10 :     for (unsigned int i = 0; i < batch(); ++i)
      91           17 :       for (unsigned int j = 0; j < channel(); ++j)
      92           48 :         for (unsigned int k = 0; k < height(); ++k)
      93          195 :           for (unsigned int l = 0; l < width(); ++l)
      94          158 :             this->setValue(i, j, k, l, d[i][j][k][l]);
      95              :   } else {
      96            0 :     for (unsigned int i = 0; i < batch(); ++i)
      97            0 :       for (unsigned int j = 0; j < height(); ++j)
      98            0 :         for (unsigned int k = 0; k < width(); ++k)
      99            0 :           for (unsigned int l = 0; l < channel(); ++l)
     100            0 :             this->setValue(i, l, j, k, d[i][j][k][l]);
     101              :   }
     102              : 
     103              :   // copy scale factors
     104            4 :   scopy(scale_size(), scales.data(), 1, (float *)getScale(), 1);
     105            5 : }
     106              : 
     107            2 : bool Int4QTensor::operator==(const Int4QTensor &rhs) const {
     108            2 :   if (qscheme != rhs.qscheme)
     109              :     return false;
     110              : 
     111              :   // compare quantized data
     112            2 :   const int8_t *_data = (int8_t *)getData();
     113            2 :   const int8_t *_rdata = (int8_t *)rhs.getData();
     114           14 :   for (size_t i = 0; i < (size() + 1) / 2; ++i) {
     115           13 :     if (_data[i] != _rdata[i])
     116              :       return false;
     117              :   }
     118              : 
     119              :   // compare scale factors
     120            1 :   const float *_scales = (float *)getScale();
     121            1 :   const float *_rscales = (float *)rhs.getScale();
     122            2 :   for (size_t i = 0; i < scale_size(); ++i) {
     123            1 :     if (std::fabs(_scales[i] - _rscales[i]) > 1e-5)
     124              :       return false;
     125              :   }
     126              : 
     127              :   return true;
     128              : }
     129              : 
     130            6 : void Int4QTensor::allocate() {
     131            6 :   if (empty() || data)
     132              :     return;
     133              : 
     134            6 :   if (src_tensor) {
     135              :     /// allocate data based on the source tensor
     136            0 :     allocateSrcTensor();
     137              :     /** as this memory is shared, do NOT initialize */
     138              :   } else {
     139              :     /// allocate new memory for the tensor data
     140              :     MemoryData *mem_data;
     141              : 
     142              :     /// quantized 4-bit is stored as a 8-bit signed integer (int4x2)
     143              :     mem_data =
     144            6 :       new MemoryData((void *)(new int8_t[(dim.getDataLen() + 1) / 2 +
     145         2748 :                                          sizeof(float) * scale_size()]{}));
     146            6 :     data = std::shared_ptr<MemoryData>(mem_data, [](auto *mem_data) {
     147            6 :       delete[] mem_data->template getAddr<int8_t>();
     148            6 :       delete mem_data;
     149              :     });
     150              : 
     151            6 :     offset = 0;
     152            6 :     initialize();
     153              :   }
     154              : }
     155              : 
     156            0 : void Int4QTensor::deallocate() {
     157              :   data = nullptr;
     158            0 :   offset = 0;
     159            0 : }
     160              : 
     161          496 : void *Int4QTensor::getData() const {
     162          496 :   if (!data)
     163              :     return nullptr;
     164              : 
     165              :   data->validate();
     166          496 :   return data->getAddr<int8_t>() + offset;
     167              : }
     168              : 
     169            0 : void *Int4QTensor::getData(size_t idx) const {
     170            0 :   if (!data)
     171              :     return nullptr;
     172              : 
     173              :   data->validate();
     174            0 :   return data->getAddr<int8_t>() + offset + (idx / 2);
     175              : }
     176              : 
     177            7 : void *Int4QTensor::getScale() const {
     178            7 :   if (!data)
     179              :     return nullptr;
     180              : 
     181              :   data->validate();
     182            7 :   return ((int8_t *)getData()) + (size() + 1) / 2;
     183              : }
     184              : 
     185            0 : void *Int4QTensor::getScale(size_t idx) const {
     186            0 :   NNTR_THROW_IF(idx > scale_size(), std::invalid_argument)
     187              :     << "Tensor::getScale() index is not valid";
     188              : 
     189            0 :   if (!data)
     190              :     return nullptr;
     191              : 
     192              :   data->validate();
     193            0 :   return ((float *)getScale()) + idx;
     194              : }
     195              : 
     196            0 : void *Int4QTensor::getAddress(unsigned int i) {
     197            0 :   size_t index = getIndex(batch(), channel(), height(), width());
     198            0 :   if (i > index) {
     199              :     return nullptr;
     200              :   }
     201            0 :   return &((int8_t *)getData())[i / 2];
     202              : }
     203              : 
     204            0 : const void *Int4QTensor::getAddress(unsigned int i) const {
     205            0 :   size_t index = getIndex(batch(), channel(), height(), width());
     206            0 :   if (i > index) {
     207              :     return nullptr;
     208              :   }
     209            0 :   return &((int8_t *)getData())[i / 2];
     210              : }
     211              : 
     212          125 : const int8_t Int4QTensor::getValue(unsigned int i) const {
     213          125 :   int8_t value = ((int8_t *)getData())[i / 2];
     214          125 :   return (i % 2 == 0) ? value >> 4 : ((int8_t)(value << 4) >> 4);
     215              : }
     216              : 
     217            2 : int8_t Int4QTensor::getValue(unsigned int i) {
     218            2 :   int8_t value = ((int8_t *)getData())[i / 2];
     219            2 :   return (i % 2 == 0) ? value >> 4 : ((int8_t)(value << 4) >> 4);
     220              : }
     221              : 
     222            0 : const int8_t Int4QTensor::getValue(unsigned int b, unsigned int c,
     223              :                                    unsigned int h, unsigned int w) const {
     224            0 :   return getValue(getIndex(b, c, h, w));
     225              : }
     226              : 
     227            0 : int8_t Int4QTensor::getValue(unsigned int b, unsigned int c, unsigned int h,
     228              :                              unsigned int w) {
     229            0 :   return getValue(getIndex(b, c, h, w));
     230              : }
     231              : 
     232              : /// @todo this func should be template function
     233            5 : void Int4QTensor::setValue(float value) {
     234            6 :   NNTR_THROW_IF(value < -8 || value > 7, std::out_of_range)
     235              :     << "Value must be in range [-8, 7]. Input value: " << value;
     236              : 
     237            4 :   int8_t val = static_cast<int8_t>(value);
     238            4 :   int8_t *data = (int8_t *)getData();
     239            4 :   std::fill(data, data + (size() + 1) / 2, (val << 4) | (val & 0x0f));
     240            4 : }
     241              : 
     242              : /// @todo this func should be template function
     243            2 : void Int4QTensor::addValue(unsigned int b, unsigned int c, unsigned int h,
     244              :                            unsigned int w, float value, float beta) {
     245            2 :   auto const &idx = getIndex(b, c, h, w);
     246            2 :   float output = getValue(idx);
     247            2 :   output *= beta;
     248            2 :   output += value;
     249              : 
     250              :   // if result value is out of range, clamp to max/min value
     251            2 :   int8_t val = static_cast<int8_t>(std::trunc(std::clamp((int)output, -8, 7)));
     252              : 
     253              :   // encode result value to int8 data
     254            2 :   ((int8_t *)getData())[idx / 2] =
     255            2 :     (idx % 2 == 0) ? (val << 4) | (((int8_t *)getData())[idx / 2] & 0x0f)
     256            0 :                    : (((int8_t *)getData())[idx / 2] & 0xf0) | (val & 0x0f);
     257            2 : }
     258              : 
     259              : /// @todo this func should be template function
     260          158 : void Int4QTensor::setValue(unsigned int b, unsigned int c, unsigned int h,
     261              :                            unsigned int w, float value) {
     262          158 :   NNTR_THROW_IF(value < -8 || value > 7, std::out_of_range)
     263              :     << "Value must be in range [-8, 7]. Input value: " << value;
     264              : 
     265          158 :   auto const &idx = getIndex(b, c, h, w);
     266          158 :   int8_t val = static_cast<int8_t>(value);
     267              : 
     268          158 :   ((int8_t *)getData())[idx / 2] =
     269          158 :     (idx % 2 == 0) ? (val << 4) | (((int8_t *)getData())[idx / 2] & 0x0f)
     270           78 :                    : (((int8_t *)getData())[idx / 2] & 0xf0) | (val & 0x0f);
     271          158 : }
     272              : 
     273            1 : void Int4QTensor::setZero() {
     274              :   /// @todo accelerate with SIMD
     275            1 :   setValue(0);
     276            1 : }
     277              : 
     278            9 : void Int4QTensor::initialize() {
     279            9 :   if (empty() || !isAllocated())
     280              :     return;
     281              : 
     282              :   /// @note Sampling from the normal/uniform distribution is invalid
     283            9 :   switch (initializer) {
     284            1 :   case Initializer::ZEROS:
     285            1 :     setZero();
     286            1 :     break;
     287            3 :   case Initializer::ONES:
     288            3 :     setValue(1.0f);
     289            3 :     break;
     290              :   case Initializer::NONE:
     291              :     break;
     292            1 :   default:
     293              :     throw std::invalid_argument(
     294            1 :       "Initializer other than zero and one is not valid for " +
     295            3 :       getStringDataType());
     296              :     break;
     297              :   }
     298              : 
     299            8 :   putData();
     300              : }
     301              : 
     302            3 : void Int4QTensor::initialize(Initializer init) {
     303            3 :   initializer = init;
     304            3 :   initialize();
     305            2 : }
     306              : 
     307            0 : void Int4QTensor::copy(const Tensor &from) {
     308            0 :   reshape(from.getDim());
     309            0 :   copy(from.getData());
     310            0 : }
     311              : 
     312            0 : void Int4QTensor::copyData(const Tensor &from) {
     313            0 :   NNTR_THROW_IF(!contiguous, std::invalid_argument)
     314              :     << getName() << " is not contiguous, cannot copy.";
     315              : 
     316            0 :   NNTR_THROW_IF(size() != from.size(), std::invalid_argument)
     317              :     << "Size of the tensor to copy must match.";
     318              : 
     319              :   /// @todo support copy from float32 & float16 to int8 data
     320            0 :   switch (from.getDataType()) {
     321              :   case ml::train::TensorDim::DataType::QINT4:
     322            0 :     copy(from.getData());
     323              :     break;
     324            0 :   default:
     325            0 :     throw std::invalid_argument("Error: Unsupported data type");
     326              :     break;
     327              :   }
     328            0 : }
     329              : 
     330            0 : void Int4QTensor::copy_with_stride(const Tensor &input, Tensor &output) {
     331            0 :   for (unsigned int b = 0; b < output.batch(); ++b) {
     332            0 :     for (unsigned int c = 0; c < output.channel(); ++c) {
     333            0 :       for (unsigned int h = 0; h < output.height(); ++h) {
     334            0 :         for (unsigned int w = 0; w < output.width(); ++w) {
     335            0 :           output.setValue(b, c, h, w, input.getValue<int8_t>(b, c, h, w));
     336              :         }
     337              :       }
     338              :     }
     339              :   }
     340            0 : }
     341              : 
     342            0 : void Int4QTensor::save(std::ostream &file) {
     343              :   /// @note Save quantization information
     344            0 :   save_quantization_info(file);
     345              : 
     346            0 :   std::streamsize sz = static_cast<std::streamsize>(getMemoryBytes());
     347              : 
     348            0 :   NNTR_THROW_IF(sz < 0, std::invalid_argument)
     349            0 :     << "save size: " << getMemoryBytes()
     350              :     << " is too big. It cannot be represented by std::streamsize";
     351              : 
     352            0 :   checkedWrite(file, (char *)getData(), sz,
     353              :                "[Int4QTensor::save] operation failed");
     354            0 :   putData();
     355            0 : }
     356              : 
     357            0 : void Int4QTensor::read(std::ifstream &file, size_t start_offset,
     358              :                        bool read_from_offset) {
     359            0 :   if (start_offset == std::numeric_limits<size_t>::max()) {
     360            0 :     start_offset = file_offset;
     361              :   }
     362            0 :   read_quantization_info(file, start_offset, read_from_offset);
     363              : 
     364            0 :   std::streamsize sz = static_cast<std::streamsize>(getMemoryBytes());
     365              : 
     366            0 :   NNTR_THROW_IF(sz < 0, std::invalid_argument)
     367            0 :     << "read size: " << getMemoryBytes()
     368              :     << " is too big. It cannot be represented by std::streamsize";
     369              : 
     370            0 :   if (read_from_offset) {
     371            0 :     start_offset += sizeof(uint16_t);
     372              :   }
     373              : 
     374            0 :   checkedRead(file, (char *)getData(), sz,
     375              :               "[Int4QTensor::read] operation failed", start_offset,
     376              :               read_from_offset);
     377            0 :   putData();
     378            0 : }
     379              : 
     380            0 : void Int4QTensor::read(ReadSource src, size_t start_offset,
     381              :                        bool read_from_offset) {
     382            0 :   if (start_offset == std::numeric_limits<size_t>::max()) {
     383            0 :     start_offset = file_offset;
     384              :   }
     385            0 :   read_quantization_info(src, start_offset, read_from_offset);
     386              : 
     387            0 :   std::streamsize sz = static_cast<std::streamsize>(getMemoryBytes());
     388              : 
     389            0 :   NNTR_THROW_IF(sz < 0, std::invalid_argument)
     390            0 :     << "read size: " << getMemoryBytes()
     391              :     << " is too big. It cannot be represented by std::streamsize";
     392              : 
     393            0 :   if (read_from_offset) {
     394            0 :     start_offset += sizeof(uint16_t);
     395              :   }
     396              : 
     397            0 :   checkedRead(src, (char *)getData(), sz,
     398              :               "[Int4QTensor::read] operation failed", start_offset,
     399              :               read_from_offset);
     400            0 :   putData();
     401            0 : }
     402              : 
     403            1 : std::vector<unsigned int> Int4QTensor::argmax() const {
     404              :   std::vector<unsigned int> result;
     405            1 :   const int8_t *data = (int8_t *)getData();
     406              :   size_t batch_size = batch();
     407            1 :   size_t feature_len = dim.getFeatureLen();
     408            1 :   result.resize(batch_size);
     409              : 
     410            4 :   for (unsigned int b = 0; b < batch_size; ++b) {
     411              :     int8_t curr_val, max_val = -8;
     412              :     unsigned int max_element_idx = 0;
     413           30 :     for (unsigned int idx = 0; idx < feature_len; ++idx) {
     414           27 :       curr_val = getValue(idx + b * feature_len);
     415              : 
     416           27 :       if (curr_val > max_val) {
     417              :         max_val = curr_val;
     418              :         max_element_idx = idx;
     419              :       }
     420              :     }
     421            3 :     result[b] = max_element_idx;
     422              :   }
     423            1 :   return result;
     424            0 : }
     425              : 
     426            0 : std::vector<unsigned int> Int4QTensor::argmin() const {
     427              :   std::vector<unsigned int> result;
     428            0 :   const int8_t *data = (int8_t *)getData();
     429              :   size_t batch_size = batch();
     430            0 :   size_t feature_len = dim.getFeatureLen();
     431            0 :   result.resize(batch_size);
     432              : 
     433            0 :   for (unsigned int b = 0; b < batch_size; ++b) {
     434              :     int8_t curr_val, min_val = 7;
     435              :     unsigned int min_element_idx = 0;
     436            0 :     for (unsigned int idx = 0; idx < feature_len; ++idx) {
     437            0 :       curr_val = getValue(idx + b * feature_len);
     438              : 
     439            0 :       if (curr_val < min_val) {
     440              :         min_val = curr_val;
     441              :         min_element_idx = idx;
     442              :       }
     443              :     }
     444            0 :     result[b] = min_element_idx;
     445              :   }
     446            0 :   return result;
     447            0 : }
     448              : 
     449            1 : float Int4QTensor::max_abs() const {
     450              :   int8_t abs_max_val = 0;
     451              :   int8_t curr_val;
     452           23 :   for (unsigned int idx = 0; idx < size(); ++idx) {
     453           23 :     curr_val = std::abs(getValue(idx));
     454           23 :     abs_max_val = (curr_val > abs_max_val) ? curr_val : abs_max_val;
     455              : 
     456              :     // Terminate search when abs_max_val is an Int4 absolute max value 8
     457           23 :     if (abs_max_val == 8)
     458              :       return abs_max_val;
     459              :   }
     460              : 
     461            0 :   return abs_max_val;
     462              : }
     463              : 
     464            0 : float Int4QTensor::maxValue() const {
     465              :   int8_t max_val = -8;
     466              :   int8_t curr_val;
     467            0 :   for (unsigned int idx = 0; idx < size(); ++idx) {
     468            0 :     curr_val = getValue(idx);
     469            0 :     max_val = (curr_val > max_val) ? curr_val : max_val;
     470              : 
     471              :     // Terminate search when max_val is an Int4 max value 7
     472            0 :     if (max_val == 7)
     473              :       return max_val;
     474              :   }
     475              : 
     476            0 :   return max_val;
     477              : }
     478              : 
     479            3 : float Int4QTensor::minValue() const {
     480              :   int8_t min_val = 7;
     481              :   int8_t curr_val;
     482           77 :   for (unsigned int idx = 0; idx < size(); ++idx) {
     483           75 :     curr_val = getValue(idx);
     484           75 :     min_val = (curr_val < min_val) ? curr_val : min_val;
     485              : 
     486              :     // Terminate search when min_val is an Int4 min value -8
     487           75 :     if (min_val == -8)
     488              :       return min_val;
     489              :   }
     490              : 
     491            2 :   return min_val;
     492              : }
     493              : 
     494            0 : void Int4QTensor::print(std::ostream &out) const {
     495            0 :   const int8_t *data = (int8_t *)getData();
     496            0 :   unsigned int len = size();
     497            0 :   out << "data addr: " << reinterpret_cast<const float *>(data) << '\n';
     498            0 :   out << dim;
     499              : 
     500            0 :   if (len > 100) {
     501            0 :     out << '[' << (int)getValue(0) << ' ' << (int)getValue(1) << ' '
     502            0 :         << (int)getValue(2) << " ... " << (int)getValue(len - 3) << ' '
     503            0 :         << (int)getValue(len - 2) << ' ' << (int)getValue(len - 1) << ']'
     504              :         << std::endl;
     505            0 :     return;
     506              :   }
     507              : 
     508            0 :   std::ios init(NULL);
     509            0 :   init.copyfmt(out);
     510            0 :   if (getFormat() == Tformat::NCHW) {
     511            0 :     for (unsigned int k = 0; k < batch(); k++) {
     512            0 :       for (unsigned int l = 0; l < channel(); l++) {
     513            0 :         for (unsigned int i = 0; i < height(); i++) {
     514            0 :           for (unsigned int j = 0; j < width(); j++) {
     515            0 :             out << std::setw(10) << (int)this->getValue(k, l, i, j) << " ";
     516              :           }
     517              :           out << std::endl;
     518              :         }
     519              :         out << std::endl;
     520              :       }
     521              :       out << "-------" << std::endl;
     522              :     }
     523              :   } else {
     524            0 :     for (unsigned int k = 0; k < batch(); k++) {
     525            0 :       for (unsigned int i = 0; i < height(); i++) {
     526            0 :         for (unsigned int j = 0; j < width(); j++) {
     527            0 :           for (unsigned int l = 0; l < channel(); l++) {
     528            0 :             out << std::setw(10) << (int)this->getValue(k, l, i, j) << " ";
     529              :           }
     530              :           out << std::endl;
     531              :         }
     532              :         out << std::endl;
     533              :       }
     534              :       out << "-------" << std::endl;
     535              :     }
     536            0 :     out.copyfmt(init);
     537              :   }
     538              : 
     539              :   /// print quantization information
     540            0 :   const float *q_scales = (float *)getScale();
     541              : 
     542            0 :   if (scale_size() > 50) {
     543            0 :     out << "Scale factors: [" << q_scales[0] << ' ' << q_scales[1] << ' '
     544            0 :         << q_scales[2] << " ... " << q_scales[len - 3] << ' '
     545            0 :         << q_scales[len - 2] << ' ' << q_scales[len - 1] << ']' << std::endl;
     546              :     return;
     547              :   }
     548              : 
     549            0 :   out << "Scale factors: ";
     550            0 :   for (unsigned i = 0; i < scale_size(); ++i) {
     551            0 :     out << q_scales[i] << " ";
     552              :   }
     553              :   out << std::endl;
     554              : }
     555              : 
     556            0 : size_t Int4QTensor::getMemoryBytes() const {
     557            0 :   return ((size() + 1) / 2) * dim.getDataTypeSize() +
     558            0 :          scale_size() * sizeof(uint16_t);
     559              : }
     560              : 
     561           21 : size_t Int4QTensor::scale_size() const {
     562           21 :   switch (qscheme) {
     563              :   case QScheme::PER_TENSOR_AFFINE:
     564              :     return 1;
     565              :     break;
     566            3 :   case QScheme::PER_CHANNEL_AFFINE:
     567            3 :     return height() * width() / group_size;
     568              :     break;
     569              :   default:
     570              :     break;
     571              :   }
     572            0 :   return 0;
     573              : }
     574              : 
     575            0 : QScheme Int4QTensor::q_scheme() const { return qscheme; }
     576              : 
     577            0 : void Int4QTensor::copy(const void *buf) {
     578            0 :   NNTR_THROW_IF(!contiguous, std::invalid_argument)
     579              :     << getName() << " is not contiguous, cannot copy.";
     580              : 
     581            0 :   if (buf == getData()) {
     582              :     return;
     583              :   }
     584              :   // copy tensor data
     585            0 :   scopy((size() + 1) / 2, (int8_t *)buf, 1, (int8_t *)getData(), 1);
     586              : 
     587              :   // copy scale factor data
     588            0 :   float *scales = (float *)(((int8_t *)buf) + (size() + 1) / 2);
     589            0 :   scopy(scale_size(), scales, 1, (float *)getScale(), 1);
     590              : }
     591              : 
     592            0 : void Int4QTensor::save_quantization_info(std::ostream &file) {
     593            0 :   checkedWrite(file, (char *)&qscheme, sizeof(uint16_t),
     594              :                "[Int4QTensor::save] failed to write quantization information");
     595            0 : }
     596              : 
     597            0 : void Int4QTensor::read_quantization_info(std::ifstream &file,
     598              :                                          size_t start_offset,
     599              :                                          bool read_from_offset) {
     600            0 :   checkedRead(file, (char *)&qscheme, sizeof(uint16_t),
     601              :               "[Int4QTensor::read] failed to read quantization information",
     602              :               start_offset, read_from_offset);
     603            0 :   group_size = 32; /// Remove me
     604            0 : }
     605              : 
     606            0 : void Int4QTensor::read_quantization_info(ReadSource src, size_t start_offset,
     607              :                                          bool read_from_offset) {
     608            0 :   checkedRead(src, (char *)&qscheme, sizeof(uint16_t),
     609              :               "[Int4QTensor::read] failed to read quantization information",
     610              :               start_offset, read_from_offset);
     611            0 :   group_size = 32; /// Remove me
     612            0 : }
     613              : 
     614            0 : size_t Int4QTensor::getGroupSize() { return group_size; }
     615              : 
     616              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1