Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * @file int4_utils.cpp
4 : * @date 15 October 2025
 * @brief Int4Utils: utilities for the INT4 quantization format.
6 : * @see https://github.com/nnstreamer/nntrainer
7 : * @author Grzegorz Kisala <gkisala@gmail.com>
8 : * @bug No known bugs
9 : */
10 :
#include "int4_utils.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <stdexcept>

#include "cpu_backend.h"
#include "fp16.h"
#include "nntrainer_error.h"
#include "util_func.h"
20 :
21 : namespace nntrainer {
22 :
23 1700176 : float Int4Utils::computeScaleForGroup(const float *group_weights,
24 : const size_t group_size) {
25 : auto max_absolute_weight = 0.0f;
26 :
27 57785168 : for (size_t i = 0; i < group_size; ++i) {
28 56084992 : auto weight = group_weights[i];
29 :
30 56084992 : NNTR_THROW_IF(!std::isfinite(weight), std::invalid_argument)
31 : << "Weight is not finite value";
32 :
33 : const auto absolute_weight = std::abs(weight);
34 :
35 56084992 : if (absolute_weight > max_absolute_weight) {
36 : max_absolute_weight = absolute_weight;
37 : }
38 : }
39 :
40 : auto group_scale =
41 1700176 : (max_absolute_weight == 0.0f) ? 1.0f : (max_absolute_weight / 7.0f);
42 :
43 1700176 : NNTR_THROW_IF(!std::isfinite(group_scale), std::invalid_argument)
44 : << "Scale is not finite value";
45 :
46 1700176 : return group_scale;
47 : }
48 :
49 24 : void Int4Utils::computeScales(const float *weights, const size_t rows_count,
50 : const size_t columns_count,
51 : const size_t group_size,
52 : std::vector<float> &scales) {
53 : // NNTR_THROW_IF(columns_count % group_size, std::invalid_argument)
54 : // << "Columns size not divisible by group size";
55 24 : NNTR_THROW_IF(columns_count % 4, std::invalid_argument)
56 : << "Columns size not divisible by 4";
57 :
58 24 : const auto full_groups_per_row = columns_count / group_size;
59 24 : const auto last_group_size = columns_count % group_size;
60 24 : const auto padded_groups_per_row = ceilDiv(columns_count, group_size);
61 24 : const auto rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
62 24 : scales.resize(rows_count_pad * padded_groups_per_row, 1.0f);
63 :
64 23352 : for (size_t row_id = 0; row_id < rows_count; ++row_id) {
65 23328 : const auto *weights_row = weights + (row_id * columns_count);
66 :
67 1723504 : for (size_t group_id = 0; group_id < full_groups_per_row; ++group_id) {
68 1700176 : const auto *weights_group = weights_row + (group_id * group_size);
69 1700176 : scales[(group_id * rows_count_pad) + row_id] =
70 1700176 : computeScaleForGroup(weights_group, group_size);
71 : }
72 :
73 : // Compute scale for the last padded group
74 23328 : if (last_group_size > 0) {
75 0 : const auto *weights_group =
76 0 : weights_row + (full_groups_per_row * group_size);
77 0 : scales[(full_groups_per_row * rows_count_pad) + row_id] =
78 0 : computeScaleForGroup(weights_group, last_group_size);
79 : }
80 : }
81 24 : }
82 :
83 56084992 : uint8_t Int4Utils::pack(const float *weights, const float *scales,
84 : const size_t row_id, const size_t column_id,
85 : const size_t groups_per_row, const size_t group_size,
86 : const size_t rows_count, const size_t columns_count) {
87 : {
88 56084992 : const auto rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
89 56084992 : const float scale =
90 56084992 : scales[row_id + ((column_id / group_size) * rows_count_pad)];
91 56084992 : const float weight = weights[(row_id * columns_count) + column_id];
92 56084992 : return quantizeToInt4(weight, scale);
93 : }
94 : }
95 :
96 24 : void Int4Utils::quantizeAndRepack(const float *weights, const size_t rows_count,
97 : const size_t columns_count,
98 : const size_t group_size,
99 : std::vector<uint8_t> &out_weights,
100 : std::vector<uint16_t> &out_scales) {
101 24 : NNTR_THROW_IF(!weights, std::invalid_argument) << "Weight cannot be null";
102 :
103 24 : NNTR_THROW_IF((rows_count <= 0), std::invalid_argument)
104 : << "Rows count needs to be greater than 0";
105 :
106 24 : NNTR_THROW_IF((columns_count <= 0), std::invalid_argument)
107 : << "Columns count needs to be greater than 0";
108 :
109 24 : NNTR_THROW_IF((!(group_size == 32 || group_size == 64 || group_size == 128)),
110 : std::invalid_argument)
111 : << "Group size must be 32/64/128";
112 :
113 : std::vector<float> scales_fp32;
114 24 : computeScales(weights, rows_count, columns_count, group_size, scales_fp32);
115 :
116 24 : out_scales.resize(scales_fp32.size());
117 1703448 : for (size_t scale_id = 0; scale_id < scales_fp32.size(); ++scale_id) {
118 1703424 : out_scales[scale_id] = compute_fp32_to_fp16(scales_fp32[scale_id]);
119 : }
120 :
121 24 : NNTR_THROW_IF(columns_count % COLUMN_BLOCK_SIZE, std::invalid_argument)
122 : << "Columns size not divisible by column block size";
123 :
124 : // Prepare output buffer in OS_IS_YX_OSV32_ISV2 layout
125 24 : const auto groups_per_row = ceilDiv(columns_count, group_size);
126 24 : const auto row_blocks_count = ceilDiv(rows_count, ROW_BLOCK_SIZE);
127 24 : const auto columns_count_pad = align(columns_count, group_size);
128 : const auto column_blocks_count =
129 24 : ceilDiv(columns_count_pad, COLUMN_BLOCK_SIZE);
130 24 : const auto rows_count_pad = row_blocks_count * ROW_BLOCK_SIZE;
131 :
132 24 : out_weights.resize((rows_count_pad * columns_count_pad) / 2, 0);
133 :
134 : size_t out_idx = 0;
135 :
136 766 : for (size_t row_block_id = 0; row_block_id < row_blocks_count;
137 : ++row_block_id) {
138 879846 : for (size_t column_block_id = 0; column_block_id < column_blocks_count;
139 : ++column_block_id) {
140 29010432 : for (size_t i = 0; i < ROW_BLOCK_SIZE; ++i) {
141 : uint8_t lo = 0, hi = 0;
142 28131328 : const auto row_id_absolute = (row_block_id * ROW_BLOCK_SIZE) + i;
143 28131328 : if (row_id_absolute < rows_count) {
144 28042496 : const auto column_id_absolute_lo =
145 : (column_block_id * COLUMN_BLOCK_SIZE);
146 28042496 : if (column_id_absolute_lo < columns_count) {
147 28042496 : lo = pack(weights, scales_fp32.data(), row_id_absolute,
148 : column_id_absolute_lo, groups_per_row, group_size,
149 : rows_count, columns_count);
150 :
151 28042496 : const auto column_id_absolute_hi = column_id_absolute_lo + 1;
152 28042496 : if (column_id_absolute_hi < columns_count) {
153 28042496 : hi = pack(weights, scales_fp32.data(), row_id_absolute,
154 : column_id_absolute_hi, groups_per_row, group_size,
155 : rows_count, columns_count);
156 : }
157 : }
158 : }
159 :
160 28131328 : out_weights[out_idx++] = uint8_t((hi << 4) | lo);
161 : }
162 : }
163 : }
164 24 : }
165 :
166 56084992 : uint8_t Int4Utils::quantizeToInt4(const float weight, const float scale) {
167 56084992 : auto div = std::nearbyintf(weight / scale);
168 :
169 56084992 : if (std::isnan(div)) {
170 0 : div = 0.0f;
171 : }
172 :
173 56084992 : div = std::clamp(div, -8.0f, 7.0f);
174 56084992 : int quantized = (int)div;
175 56084992 : return uint8_t(quantized & 0xF);
176 : }
177 :
178 56084992 : int Int4Utils::convertInt4ToInt(const uint8_t int4_value) {
179 : static int lookup[] = {0, 1, 2, 3, 4, 5, 6, 7,
180 : -8, -7, -6, -5, -4, -3, -2, -1};
181 :
182 56084992 : return lookup[int4_value];
183 : }
184 :
185 24 : void Int4Utils::dequantizePacked(const std::vector<uint8_t> &weights,
186 : const std::vector<uint16_t> &scales,
187 : const size_t rows_count,
188 : const size_t columns_count,
189 : const size_t group_size,
190 : std::vector<float> &dequantized_weights) {
191 24 : const auto groups_per_row = ceilDiv(columns_count, group_size);
192 24 : const auto rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
193 24 : const auto row_blocks_count = ceilDiv(rows_count, ROW_BLOCK_SIZE);
194 24 : const auto columns_count_pad = align(columns_count, group_size);
195 : const auto column_blocks_count =
196 24 : ceilDiv(columns_count_pad, COLUMN_BLOCK_SIZE);
197 :
198 24 : dequantized_weights.resize(rows_count * columns_count);
199 :
200 : size_t weights_idx = 0;
201 :
202 766 : for (size_t row_block_id = 0; row_block_id < row_blocks_count;
203 : ++row_block_id) {
204 879846 : for (size_t column_block_id = 0; column_block_id < column_blocks_count;
205 : ++column_block_id) {
206 29010432 : for (size_t i = 0; i < ROW_BLOCK_SIZE; ++i) {
207 : uint8_t lo = 0, hi = 0;
208 28131328 : const auto row_id_absolute = (row_block_id * ROW_BLOCK_SIZE) + i;
209 28131328 : if (row_id_absolute < rows_count) {
210 28042496 : const auto column_id_absolute_lo =
211 : (column_block_id * COLUMN_BLOCK_SIZE);
212 28042496 : if (column_id_absolute_lo < columns_count) {
213 28042496 : const auto column_id_absolute_hi = column_id_absolute_lo + 1;
214 :
215 : const auto scale_lo =
216 : scales[row_id_absolute +
217 28042496 : ((column_id_absolute_lo / group_size) * rows_count_pad)];
218 :
219 : const auto scale_hi =
220 : scales[row_id_absolute +
221 28042496 : ((column_id_absolute_hi / group_size) * rows_count_pad)];
222 :
223 28042496 : const auto weight = weights[weights_idx];
224 : const auto weight_lo = weight & 0xF;
225 28042496 : const auto weight_hi = (weight >> 4) & 0xF;
226 :
227 28042496 : dequantized_weights[(row_id_absolute * columns_count) +
228 28042496 : column_id_absolute_lo] =
229 56084992 : Int4Utils::convertInt4ToInt(weight_lo) *
230 28042496 : nntrainer::compute_fp16_to_fp32(scale_lo);
231 :
232 28042496 : if (column_id_absolute_hi < columns_count) {
233 : dequantized_weights[(row_id_absolute * columns_count) +
234 28042496 : column_id_absolute_hi] =
235 56084992 : Int4Utils::convertInt4ToInt(weight_hi) *
236 28042496 : nntrainer::compute_fp16_to_fp32(scale_hi);
237 : }
238 : }
239 : }
240 28131328 : weights_idx++;
241 : }
242 : }
243 : }
244 24 : }
245 :
246 0 : void Int4Utils::dequantizePackedRow(uint8_t *weights, uint16_t *scales,
247 : const size_t rows_count,
248 : const size_t columns_count,
249 : const size_t group_size,
250 : const size_t row_index,
251 : float *dequantized_row) {
252 : // --- Validate ---
253 0 : NNTR_THROW_IF(rows_count == 0 || columns_count == 0, std::invalid_argument)
254 : << "rows_count and columns_count must be > 0";
255 0 : NNTR_THROW_IF(row_index >= rows_count, std::out_of_range)
256 : << "row_index out of range";
257 0 : NNTR_THROW_IF(!(group_size == 32 || group_size == 64 || group_size == 128),
258 : std::invalid_argument)
259 : << "group_size must be 32/64/128";
260 :
261 : // --- Layout ---
262 0 : const size_t rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
263 0 : const size_t columns_count_pad = align(columns_count, group_size);
264 : const size_t column_blocks_count =
265 0 : ceilDiv(columns_count_pad, COLUMN_BLOCK_SIZE); // COLUMN_BLOCK_SIZE == 2
266 0 : const size_t padded_groups_per_row = ceilDiv(columns_count, group_size);
267 :
268 : // Address the bytes for this row
269 0 : const size_t row_block_id = row_index / ROW_BLOCK_SIZE;
270 0 : const size_t i_in_block = row_index % ROW_BLOCK_SIZE;
271 : const size_t bytes_per_row_block_span = column_blocks_count * ROW_BLOCK_SIZE;
272 0 : const size_t row_block_base =
273 0 : row_block_id * bytes_per_row_block_span + i_in_block;
274 :
275 0 : for (size_t column_block_id = 0; column_block_id < column_blocks_count;
276 : ++column_block_id) {
277 0 : const size_t weights_idx =
278 0 : row_block_base + column_block_id * ROW_BLOCK_SIZE;
279 0 : const uint8_t packed_byte = weights[weights_idx];
280 :
281 0 : const size_t col_lo = column_block_id * COLUMN_BLOCK_SIZE;
282 0 : const size_t col_hi = col_lo + 1;
283 :
284 0 : const int q_lo = Int4Utils::convertInt4ToInt(packed_byte & 0xF);
285 0 : const int q_hi = Int4Utils::convertInt4ToInt((packed_byte >> 4) & 0xF);
286 :
287 0 : if (col_lo < columns_count) {
288 0 : const size_t g_lo = col_lo / group_size;
289 0 : const float s_lo = nntrainer::compute_fp16_to_fp32(
290 0 : scales[row_index + g_lo * rows_count_pad]);
291 0 : dequantized_row[col_lo] = static_cast<float>(q_lo) * s_lo;
292 : }
293 0 : if (col_hi < columns_count) {
294 0 : const size_t g_hi = col_hi / group_size;
295 0 : const float s_hi = nntrainer::compute_fp16_to_fp32(
296 0 : scales[row_index + g_hi * rows_count_pad]);
297 0 : dequantized_row[col_hi] = static_cast<float>(q_hi) * s_hi;
298 : }
299 : }
300 0 : }
301 :
302 0 : void Int4Utils::dequantizePackedRow32ToInt4Scale(
303 : const uint8_t *weights, const uint16_t *scales, const size_t rows_count,
304 : const size_t columns_count, const size_t group_size, const size_t row_index,
305 : const size_t column_index, uint8_t *weight_int4_row32, uint16_t *scale) {
306 : // --- Validate ---
307 0 : NNTR_THROW_IF(rows_count == 0 || columns_count == 0, std::invalid_argument)
308 : << "rows_count and columns_count must be > 0";
309 0 : NNTR_THROW_IF(row_index >= rows_count, std::out_of_range)
310 : << "row_index out of range";
311 0 : NNTR_THROW_IF(!(group_size == 32 || group_size == 64 || group_size == 128),
312 : std::invalid_argument)
313 : << "group_size must be 32/64/128";
314 0 : NNTR_THROW_IF(columns_count % 32 != 0, std::invalid_argument)
315 : << "columns_count must be divisible by 32";
316 :
317 : // --- Layout ---
318 0 : const size_t rows_count_pad = align(rows_count, ROW_BLOCK_SIZE);
319 0 : const size_t columns_count_pad = align(columns_count, group_size);
320 : const size_t column_blocks_count =
321 0 : ceilDiv(columns_count_pad, COLUMN_BLOCK_SIZE); // COLUMN_BLOCK_SIZE == 2
322 0 : const size_t padded_groups_per_row = ceilDiv(columns_count, group_size);
323 :
324 : // Address the bytes for this row
325 0 : const size_t row_block_id = row_index / ROW_BLOCK_SIZE;
326 0 : const size_t i_in_block = row_index % ROW_BLOCK_SIZE;
327 : const size_t bytes_per_row_block_span = column_blocks_count * ROW_BLOCK_SIZE;
328 0 : const size_t row_block_base =
329 0 : row_block_id * bytes_per_row_block_span + i_in_block;
330 :
331 0 : for (size_t column_block_id = 0; column_block_id < 16; ++column_block_id) {
332 0 : const size_t weights_idx =
333 0 : row_block_base + (column_index / 2 + column_block_id) * ROW_BLOCK_SIZE;
334 0 : const uint8_t packed_byte = weights[weights_idx];
335 :
336 0 : weight_int4_row32[column_block_id] = packed_byte;
337 : }
338 :
339 0 : *scale = scales[row_index + (column_index / group_size) * rows_count_pad];
340 0 : }
341 : } // namespace nntrainer
|