LCOV - code coverage report
Current view: top level - nntrainer/tensor - int4_utils.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 62.9 % 151 95
Test Date: 2025-12-14 20:38:17 Functions: 77.8 % 9 7

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * @file        int4_utils.cpp
       4              :  * @date        15 October 2025
       5              :  * @brief       This is Int4Utils class for utils for INT4 quantization format.
       6              :  * @see         https://github.com/nnstreamer/nntrainer
       7              :  * @author      Grzegorz Kisala <gkisala@gmail.com>
       8              :  * @bug         No known bugs
       9              :  */
      10              : 
#include "int4_utils.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <stdexcept>

#include "cpu_backend.h"
#include "fp16.h"
#include "nntrainer_error.h"
#include "util_func.h"
      20              : 
      21              : namespace nntrainer {
      22              : 
      23      1700176 : float Int4Utils::computeScaleForGroup(const float *group_weights,
      24              :                                       const size_t group_size) {
      25              :   auto max_absolute_weight = 0.0f;
      26              : 
      27     57785168 :   for (size_t i = 0; i < group_size; ++i) {
      28     56084992 :     auto weight = group_weights[i];
      29              : 
      30     56084992 :     NNTR_THROW_IF(!std::isfinite(weight), std::invalid_argument)
      31              :       << "Weight is not finite value";
      32              : 
      33              :     const auto absolute_weight = std::abs(weight);
      34              : 
      35     56084992 :     if (absolute_weight > max_absolute_weight) {
      36              :       max_absolute_weight = absolute_weight;
      37              :     }
      38              :   }
      39              : 
      40              :   auto group_scale =
      41      1700176 :     (max_absolute_weight == 0.0f) ? 1.0f : (max_absolute_weight / 7.0f);
      42              : 
      43      1700176 :   NNTR_THROW_IF(!std::isfinite(group_scale), std::invalid_argument)
      44              :     << "Scale is not finite value";
      45              : 
      46      1700176 :   return group_scale;
      47              : }
      48              : 
      49           24 : void Int4Utils::computeScales(const float *weights, const size_t rows_count,
      50              :                               const size_t columns_count,
      51              :                               const size_t group_size,
      52              :                               std::vector<float> &scales) {
      53              :   // NNTR_THROW_IF(columns_count % group_size, std::invalid_argument)
      54              :   //   << "Columns size not divisible by group size";
      55           24 :   NNTR_THROW_IF(columns_count % 4, std::invalid_argument)
      56              :     << "Columns size not divisible by 4";
      57              : 
      58           24 :   const auto full_groups_per_row = columns_count / group_size;
      59           24 :   const auto last_group_size = columns_count % group_size;
      60           24 :   const auto padded_groups_per_row = ceilDiv(columns_count, group_size);
      61           24 :   const auto rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
      62           24 :   scales.resize(rows_count_pad * padded_groups_per_row, 1.0f);
      63              : 
      64        23352 :   for (size_t row_id = 0; row_id < rows_count; ++row_id) {
      65        23328 :     const auto *weights_row = weights + (row_id * columns_count);
      66              : 
      67      1723504 :     for (size_t group_id = 0; group_id < full_groups_per_row; ++group_id) {
      68      1700176 :       const auto *weights_group = weights_row + (group_id * group_size);
      69      1700176 :       scales[(group_id * rows_count_pad) + row_id] =
      70      1700176 :         computeScaleForGroup(weights_group, group_size);
      71              :     }
      72              : 
      73              :     // Compute scale for the last padded group
      74        23328 :     if (last_group_size > 0) {
      75            0 :       const auto *weights_group =
      76            0 :         weights_row + (full_groups_per_row * group_size);
      77            0 :       scales[(full_groups_per_row * rows_count_pad) + row_id] =
      78            0 :         computeScaleForGroup(weights_group, last_group_size);
      79              :     }
      80              :   }
      81           24 : }
      82              : 
      83     56084992 : uint8_t Int4Utils::pack(const float *weights, const float *scales,
      84              :                         const size_t row_id, const size_t column_id,
      85              :                         const size_t groups_per_row, const size_t group_size,
      86              :                         const size_t rows_count, const size_t columns_count) {
      87              :   {
      88     56084992 :     const auto rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
      89     56084992 :     const float scale =
      90     56084992 :       scales[row_id + ((column_id / group_size) * rows_count_pad)];
      91     56084992 :     const float weight = weights[(row_id * columns_count) + column_id];
      92     56084992 :     return quantizeToInt4(weight, scale);
      93              :   }
      94              : }
      95              : 
      96           24 : void Int4Utils::quantizeAndRepack(const float *weights, const size_t rows_count,
      97              :                                   const size_t columns_count,
      98              :                                   const size_t group_size,
      99              :                                   std::vector<uint8_t> &out_weights,
     100              :                                   std::vector<uint16_t> &out_scales) {
     101           24 :   NNTR_THROW_IF(!weights, std::invalid_argument) << "Weight cannot be null";
     102              : 
     103           24 :   NNTR_THROW_IF((rows_count <= 0), std::invalid_argument)
     104              :     << "Rows count needs to be greater than 0";
     105              : 
     106           24 :   NNTR_THROW_IF((columns_count <= 0), std::invalid_argument)
     107              :     << "Columns count needs to be greater than 0";
     108              : 
     109           24 :   NNTR_THROW_IF((!(group_size == 32 || group_size == 64 || group_size == 128)),
     110              :                 std::invalid_argument)
     111              :     << "Group size must be 32/64/128";
     112              : 
     113              :   std::vector<float> scales_fp32;
     114           24 :   computeScales(weights, rows_count, columns_count, group_size, scales_fp32);
     115              : 
     116           24 :   out_scales.resize(scales_fp32.size());
     117      1703448 :   for (size_t scale_id = 0; scale_id < scales_fp32.size(); ++scale_id) {
     118      1703424 :     out_scales[scale_id] = compute_fp32_to_fp16(scales_fp32[scale_id]);
     119              :   }
     120              : 
     121           24 :   NNTR_THROW_IF(columns_count % COLUMN_BLOCK_SIZE, std::invalid_argument)
     122              :     << "Columns size not divisible by column block size";
     123              : 
     124              :   // Prepare output buffer in OS_IS_YX_OSV32_ISV2 layout
     125           24 :   const auto groups_per_row = ceilDiv(columns_count, group_size);
     126           24 :   const auto row_blocks_count = ceilDiv(rows_count, ROW_BLOCK_SIZE);
     127           24 :   const auto columns_count_pad = align(columns_count, group_size);
     128              :   const auto column_blocks_count =
     129           24 :     ceilDiv(columns_count_pad, COLUMN_BLOCK_SIZE);
     130           24 :   const auto rows_count_pad = row_blocks_count * ROW_BLOCK_SIZE;
     131              : 
     132           24 :   out_weights.resize((rows_count_pad * columns_count_pad) / 2, 0);
     133              : 
     134              :   size_t out_idx = 0;
     135              : 
     136          766 :   for (size_t row_block_id = 0; row_block_id < row_blocks_count;
     137              :        ++row_block_id) {
     138       879846 :     for (size_t column_block_id = 0; column_block_id < column_blocks_count;
     139              :          ++column_block_id) {
     140     29010432 :       for (size_t i = 0; i < ROW_BLOCK_SIZE; ++i) {
     141              :         uint8_t lo = 0, hi = 0;
     142     28131328 :         const auto row_id_absolute = (row_block_id * ROW_BLOCK_SIZE) + i;
     143     28131328 :         if (row_id_absolute < rows_count) {
     144     28042496 :           const auto column_id_absolute_lo =
     145              :             (column_block_id * COLUMN_BLOCK_SIZE);
     146     28042496 :           if (column_id_absolute_lo < columns_count) {
     147     28042496 :             lo = pack(weights, scales_fp32.data(), row_id_absolute,
     148              :                       column_id_absolute_lo, groups_per_row, group_size,
     149              :                       rows_count, columns_count);
     150              : 
     151     28042496 :             const auto column_id_absolute_hi = column_id_absolute_lo + 1;
     152     28042496 :             if (column_id_absolute_hi < columns_count) {
     153     28042496 :               hi = pack(weights, scales_fp32.data(), row_id_absolute,
     154              :                         column_id_absolute_hi, groups_per_row, group_size,
     155              :                         rows_count, columns_count);
     156              :             }
     157              :           }
     158              :         }
     159              : 
     160     28131328 :         out_weights[out_idx++] = uint8_t((hi << 4) | lo);
     161              :       }
     162              :     }
     163              :   }
     164           24 : }
     165              : 
     166     56084992 : uint8_t Int4Utils::quantizeToInt4(const float weight, const float scale) {
     167     56084992 :   auto div = std::nearbyintf(weight / scale);
     168              : 
     169     56084992 :   if (std::isnan(div)) {
     170            0 :     div = 0.0f;
     171              :   }
     172              : 
     173     56084992 :   div = std::clamp(div, -8.0f, 7.0f);
     174     56084992 :   int quantized = (int)div;
     175     56084992 :   return uint8_t(quantized & 0xF);
     176              : }
     177              : 
     178     56084992 : int Int4Utils::convertInt4ToInt(const uint8_t int4_value) {
     179              :   static int lookup[] = {0,  1,  2,  3,  4,  5,  6,  7,
     180              :                          -8, -7, -6, -5, -4, -3, -2, -1};
     181              : 
     182     56084992 :   return lookup[int4_value];
     183              : }
     184              : 
     185           24 : void Int4Utils::dequantizePacked(const std::vector<uint8_t> &weights,
     186              :                                  const std::vector<uint16_t> &scales,
     187              :                                  const size_t rows_count,
     188              :                                  const size_t columns_count,
     189              :                                  const size_t group_size,
     190              :                                  std::vector<float> &dequantized_weights) {
     191           24 :   const auto groups_per_row = ceilDiv(columns_count, group_size);
     192           24 :   const auto rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
     193           24 :   const auto row_blocks_count = ceilDiv(rows_count, ROW_BLOCK_SIZE);
     194           24 :   const auto columns_count_pad = align(columns_count, group_size);
     195              :   const auto column_blocks_count =
     196           24 :     ceilDiv(columns_count_pad, COLUMN_BLOCK_SIZE);
     197              : 
     198           24 :   dequantized_weights.resize(rows_count * columns_count);
     199              : 
     200              :   size_t weights_idx = 0;
     201              : 
     202          766 :   for (size_t row_block_id = 0; row_block_id < row_blocks_count;
     203              :        ++row_block_id) {
     204       879846 :     for (size_t column_block_id = 0; column_block_id < column_blocks_count;
     205              :          ++column_block_id) {
     206     29010432 :       for (size_t i = 0; i < ROW_BLOCK_SIZE; ++i) {
     207              :         uint8_t lo = 0, hi = 0;
     208     28131328 :         const auto row_id_absolute = (row_block_id * ROW_BLOCK_SIZE) + i;
     209     28131328 :         if (row_id_absolute < rows_count) {
     210     28042496 :           const auto column_id_absolute_lo =
     211              :             (column_block_id * COLUMN_BLOCK_SIZE);
     212     28042496 :           if (column_id_absolute_lo < columns_count) {
     213     28042496 :             const auto column_id_absolute_hi = column_id_absolute_lo + 1;
     214              : 
     215              :             const auto scale_lo =
     216              :               scales[row_id_absolute +
     217     28042496 :                      ((column_id_absolute_lo / group_size) * rows_count_pad)];
     218              : 
     219              :             const auto scale_hi =
     220              :               scales[row_id_absolute +
     221     28042496 :                      ((column_id_absolute_hi / group_size) * rows_count_pad)];
     222              : 
     223     28042496 :             const auto weight = weights[weights_idx];
     224              :             const auto weight_lo = weight & 0xF;
     225     28042496 :             const auto weight_hi = (weight >> 4) & 0xF;
     226              : 
     227     28042496 :             dequantized_weights[(row_id_absolute * columns_count) +
     228     28042496 :                                 column_id_absolute_lo] =
     229     56084992 :               Int4Utils::convertInt4ToInt(weight_lo) *
     230     28042496 :               nntrainer::compute_fp16_to_fp32(scale_lo);
     231              : 
     232     28042496 :             if (column_id_absolute_hi < columns_count) {
     233              :               dequantized_weights[(row_id_absolute * columns_count) +
     234     28042496 :                                   column_id_absolute_hi] =
     235     56084992 :                 Int4Utils::convertInt4ToInt(weight_hi) *
     236     28042496 :                 nntrainer::compute_fp16_to_fp32(scale_hi);
     237              :             }
     238              :           }
     239              :         }
     240     28131328 :         weights_idx++;
     241              :       }
     242              :     }
     243              :   }
     244           24 : }
     245              : 
     246            0 : void Int4Utils::dequantizePackedRow(uint8_t *weights, uint16_t *scales,
     247              :                                     const size_t rows_count,
     248              :                                     const size_t columns_count,
     249              :                                     const size_t group_size,
     250              :                                     const size_t row_index,
     251              :                                     float *dequantized_row) {
     252              :   // --- Validate ---
     253            0 :   NNTR_THROW_IF(rows_count == 0 || columns_count == 0, std::invalid_argument)
     254              :     << "rows_count and columns_count must be > 0";
     255            0 :   NNTR_THROW_IF(row_index >= rows_count, std::out_of_range)
     256              :     << "row_index out of range";
     257            0 :   NNTR_THROW_IF(!(group_size == 32 || group_size == 64 || group_size == 128),
     258              :                 std::invalid_argument)
     259              :     << "group_size must be 32/64/128";
     260              : 
     261              :   // --- Layout ---
     262            0 :   const size_t rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
     263            0 :   const size_t columns_count_pad = align(columns_count, group_size);
     264              :   const size_t column_blocks_count =
     265            0 :     ceilDiv(columns_count_pad, COLUMN_BLOCK_SIZE); // COLUMN_BLOCK_SIZE == 2
     266            0 :   const size_t padded_groups_per_row = ceilDiv(columns_count, group_size);
     267              : 
     268              :   // Address the bytes for this row
     269            0 :   const size_t row_block_id = row_index / ROW_BLOCK_SIZE;
     270            0 :   const size_t i_in_block = row_index % ROW_BLOCK_SIZE;
     271              :   const size_t bytes_per_row_block_span = column_blocks_count * ROW_BLOCK_SIZE;
     272            0 :   const size_t row_block_base =
     273            0 :     row_block_id * bytes_per_row_block_span + i_in_block;
     274              : 
     275            0 :   for (size_t column_block_id = 0; column_block_id < column_blocks_count;
     276              :        ++column_block_id) {
     277            0 :     const size_t weights_idx =
     278            0 :       row_block_base + column_block_id * ROW_BLOCK_SIZE;
     279            0 :     const uint8_t packed_byte = weights[weights_idx];
     280              : 
     281            0 :     const size_t col_lo = column_block_id * COLUMN_BLOCK_SIZE;
     282            0 :     const size_t col_hi = col_lo + 1;
     283              : 
     284            0 :     const int q_lo = Int4Utils::convertInt4ToInt(packed_byte & 0xF);
     285            0 :     const int q_hi = Int4Utils::convertInt4ToInt((packed_byte >> 4) & 0xF);
     286              : 
     287            0 :     if (col_lo < columns_count) {
     288            0 :       const size_t g_lo = col_lo / group_size;
     289            0 :       const float s_lo = nntrainer::compute_fp16_to_fp32(
     290            0 :         scales[row_index + g_lo * rows_count_pad]);
     291            0 :       dequantized_row[col_lo] = static_cast<float>(q_lo) * s_lo;
     292              :     }
     293            0 :     if (col_hi < columns_count) {
     294            0 :       const size_t g_hi = col_hi / group_size;
     295            0 :       const float s_hi = nntrainer::compute_fp16_to_fp32(
     296            0 :         scales[row_index + g_hi * rows_count_pad]);
     297            0 :       dequantized_row[col_hi] = static_cast<float>(q_hi) * s_hi;
     298              :     }
     299              :   }
     300            0 : }
     301              : 
     302            0 : void Int4Utils::dequantizePackedRow32ToInt4Scale(
     303              :   const uint8_t *weights, const uint16_t *scales, const size_t rows_count,
     304              :   const size_t columns_count, const size_t group_size, const size_t row_index,
     305              :   const size_t column_index, uint8_t *weight_int4_row32, uint16_t *scale) {
     306              :   // --- Validate ---
     307            0 :   NNTR_THROW_IF(rows_count == 0 || columns_count == 0, std::invalid_argument)
     308              :     << "rows_count and columns_count must be > 0";
     309            0 :   NNTR_THROW_IF(row_index >= rows_count, std::out_of_range)
     310              :     << "row_index out of range";
     311            0 :   NNTR_THROW_IF(!(group_size == 32 || group_size == 64 || group_size == 128),
     312              :                 std::invalid_argument)
     313              :     << "group_size must be 32/64/128";
     314            0 :   NNTR_THROW_IF(columns_count % 32 != 0, std::invalid_argument)
     315              :     << "columns_count must be divisible by 32";
     316              : 
     317              :   // --- Layout ---
     318            0 :   const size_t rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
     319            0 :   const size_t columns_count_pad = align(columns_count, group_size);
     320              :   const size_t column_blocks_count =
     321            0 :     ceilDiv(columns_count_pad, COLUMN_BLOCK_SIZE); // COLUMN_BLOCK_SIZE == 2
     322            0 :   const size_t padded_groups_per_row = ceilDiv(columns_count, group_size);
     323              : 
     324              :   // Address the bytes for this row
     325            0 :   const size_t row_block_id = row_index / ROW_BLOCK_SIZE;
     326            0 :   const size_t i_in_block = row_index % ROW_BLOCK_SIZE;
     327              :   const size_t bytes_per_row_block_span = column_blocks_count * ROW_BLOCK_SIZE;
     328            0 :   const size_t row_block_base =
     329            0 :     row_block_id * bytes_per_row_block_span + i_in_block;
     330              : 
     331            0 :   for (size_t column_block_id = 0; column_block_id < 16; ++column_block_id) {
     332            0 :     const size_t weights_idx =
     333            0 :       row_block_base + (column_index / 2 + column_block_id) * ROW_BLOCK_SIZE;
     334            0 :     const uint8_t packed_byte = weights[weights_idx];
     335              : 
     336            0 :     weight_int4_row32[column_block_id] = packed_byte;
     337              :   }
     338              : 
     339            0 :   *scale = scales[row_index + (column_index / group_size) * rows_count_pad];
     340            0 : }
     341              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1