Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2024 Arm Limited and/or its affiliates
4 : *
5 : * @file fallback_kleidiai.cpp
6 : * @date 15 September 2025
7 : * @see https://github.com/nntrainer/nntrainer
8 : * @author Sungsik Kong <ss.kong@samsung.com>
9 : * @brief Modified computational backend components of kleidiai. Portions of
10 : * this file are derived from Arm Limited code licensed under the Apache
11 : * License, Version 2.0, with modifications
12 : * @bug No known bugs except for NYI items
13 : * @note Licensed under the Apache License, Version 2.0 (the "License");
14 : * you may not use this file except in compliance with the License.
15 : * You may obtain a copy of the License at
16 : * http://www.apache.org/licenses/LICENSE-2.0
17 : *
18 : * @modifications
19 : * - [2025-09-15] Integrated and adapted Arm-provided code into
20 : * nntrainer CPU backend
21 : *
22 : */
23 :
24 : #include <cassert>
25 : #include <cfloat>
26 : #include <cmath>
27 : #include <cstring>
28 : #include <iostream>
29 : #include <limits>
30 : #include <stdexcept>
31 : #include <string>
32 :
33 : #include <fallback_kleidiai.h>
34 :
35 : #define INT4_MIN (-8)
36 : #define INT4_MAX (7)
37 :
/**
 * @brief Round @p a up to the next multiple of @p b.
 * @param a value to round
 * @param b granularity; must be non-zero
 * @return smallest multiple of b that is >= a
 */
static size_t roundup(size_t a, size_t b) {
  const size_t num_blocks = (a + b - 1) / b;
  return num_blocks * b;
}
39 :
/**
 * @brief Reference per-row dynamic asymmetric quantization of an f32 matrix
 * into the qa8dx layout.
 *
 * Each of the @p m rows of @p lhs_f32 (k contiguous floats per row) is written
 * to @p lhs_qa8dx as
 *   [float reciprocal scale][int32 negated zero point][k x int8 values],
 * so consecutive destination rows are k + sizeof(float) + sizeof(int32_t)
 * bytes apart.
 *
 * @param m number of rows
 * @param k number of elements per row
 * @param lhs_f32 source matrix, m x k, row-major
 * @param lhs_qa8dx destination buffer of at least
 *        m * (k + sizeof(float) + sizeof(int32_t)) bytes
 */
void ref_quant_qa8dx_f32(size_t m, size_t k, const float *lhs_f32,
                         int8_t *lhs_qa8dx) {
  const size_t dst_stride =
    (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));

  // The f32 source is densely packed: consecutive rows are k floats apart.
  const size_t src_stride = k;

  // int8 quantization range
  const float qmin = static_cast<float>(INT8_MIN);
  const float qmax = static_cast<float>(INT8_MAX);

  for (size_t m_idx = 0; m_idx < m; ++m_idx) {
    const float *src_ptr = lhs_f32 + m_idx * src_stride;

    float max0 = -FLT_MAX;
    float min0 = FLT_MAX;

    // Find min/max of the row
    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
      const float src0_0 = src_ptr[k_idx];
      max0 = std::max(src0_0, max0);
      min0 = std::min(src0_0, min0);
    }

    // Widen the range so that it always contains 0.
    const float rmin0 = std::min(0.0f, min0);
    const float rmax0 = std::max(0.0f, max0);

    const float scale0 =
      rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);

    // Reciprocal of the quantization scale; stored in the row header.
    const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;

    const float descaled_min0 = rmin0 * scale0;
    const float descaled_max0 = rmax0 * scale0;

    const float zero_point_from_min_error0 = qmin + descaled_min0;
    const float zero_point_from_max_error0 = qmax + descaled_max0;

    // Choose between the two zero-point candidates based on the sign of the
    // summed error terms.
    float zero_point0 =
      zero_point_from_min_error0 + zero_point_from_max_error0 > 0
        ? qmin - descaled_min0
        : qmax - descaled_max0;

    zero_point0 = std::max(zero_point0, qmin);
    zero_point0 = std::min(zero_point0, qmax);

    // Round to nearest integer
    const int32_t nudged_zero_point0 =
      static_cast<int32_t>(std::lrint(zero_point0));

    int8_t *dst_ptr = lhs_qa8dx + m_idx * dst_stride;

    // Row header: reciprocal scale, then the negated zero point.
    // memcpy is required: dst_ptr points into an int8_t buffer whose rows are
    // k + 8 bytes apart and therefore not necessarily 4-byte aligned, and
    // storing through float* / int32_t* would also break strict aliasing.
    const int32_t neg_zero_point0 = -nudged_zero_point0;
    std::memcpy(dst_ptr, &recip_scale0, sizeof(float));
    dst_ptr += sizeof(float);
    std::memcpy(dst_ptr, &neg_zero_point0, sizeof(int32_t));
    dst_ptr += sizeof(int32_t);

    // Quantize the row
    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
      int32_t v0_s32 =
        static_cast<int32_t>(std::round(src_ptr[k_idx] * scale0));

      v0_s32 = v0_s32 + nudged_zero_point0;
      v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
      v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
      dst_ptr[k_idx] = static_cast<int8_t>(v0_s32);
    }
  }
}
113 :
114 3 : static void quant_nxk_qs4cx_f32(size_t n, size_t k, const float *rhs_f32,
115 : uint8_t *rhs_qs4cx, float *rhs_scales_f32) {
116 3 : const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2);
117 :
118 : // Make sure the output is filled with zeros
119 3 : std::memset(rhs_qs4cx, 0, n * rhs_qs4cx_stride);
120 :
121 11 : for (size_t n_idx = 0; n_idx < n; ++n_idx) {
122 8 : const float *src_ptr = rhs_f32 + n_idx * k;
123 :
124 8 : float max0 = -FLT_MAX;
125 8 : float min0 = FLT_MAX;
126 :
127 : // Find min/max for each channel
128 72 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
129 64 : const float src0_0 = src_ptr[k_idx];
130 :
131 64 : max0 = std::max(src0_0, max0);
132 64 : min0 = std::min(src0_0, min0);
133 : }
134 :
135 : // Maximum/minimum int8 values
136 : const float qmin = (float)INT4_MIN;
137 : const float qmax = (float)INT4_MAX;
138 :
139 8 : const float rmin0 = std::min(0.0f, min0);
140 8 : const float rmax0 = std::max(0.0f, max0);
141 :
142 8 : const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
143 :
144 : // Reciprocal to quantize
145 8 : const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
146 :
147 : // Quantize the channels
148 72 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
149 64 : const float src0_0 = src_ptr[k_idx];
150 :
151 : // Scale the values
152 64 : int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
153 :
154 : // Maximum/minimum int4 values
155 64 : v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT4_MIN));
156 64 : v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT4_MAX));
157 :
158 64 : const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
159 :
160 64 : const size_t dst_addr = (k_idx / 2) + n_idx * rhs_qs4cx_stride;
161 64 : uint8_t rhs_v0 = rhs_qs4cx[dst_addr];
162 :
163 64 : if ((k_idx % 2) == 0) {
164 32 : rhs_v0 |= v0_u8;
165 : } else {
166 32 : rhs_v0 |= (v0_u8 << 4);
167 : }
168 64 : rhs_qs4cx[dst_addr] = rhs_v0;
169 : }
170 :
171 8 : rhs_scales_f32[n_idx] = recip_scale0;
172 : }
173 3 : };
174 :
175 2 : static void quant_kxn_qs4cx_f32(size_t n, size_t k, const float *rhs_f32,
176 : uint8_t *rhs_qs4cx, float *rhs_scales_f32) {
177 2 : const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2);
178 :
179 : // Make sure the output is filled with zeros
180 2 : std::memset(rhs_qs4cx, 0, k * rhs_qs4cx_stride);
181 :
182 8 : for (size_t n_idx = 0; n_idx < n; ++n_idx) {
183 6 : const float *src_ptr = rhs_f32 + n_idx * k;
184 :
185 6 : float max0 = -FLT_MAX;
186 6 : float min0 = FLT_MAX;
187 :
188 : // Find min/max for each channel
189 54 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
190 48 : const float src0_0 = src_ptr[k_idx];
191 :
192 48 : max0 = std::max(src0_0, max0);
193 48 : min0 = std::min(src0_0, min0);
194 : }
195 :
196 : // Maximum/minimum int8 values
197 : const float qmin = (float)INT4_MIN;
198 : const float qmax = (float)INT4_MAX;
199 :
200 6 : const float rmin0 = std::min(0.0f, min0);
201 6 : const float rmax0 = std::max(0.0f, max0);
202 :
203 6 : const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
204 :
205 : // Reciprocal to quantize
206 6 : const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
207 :
208 : // Quantize the channels
209 54 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
210 48 : const float src0_0 = src_ptr[k_idx];
211 :
212 : // Scale the values
213 48 : int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
214 :
215 : // Maximum/minimum int4 values
216 48 : v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT4_MIN));
217 48 : v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT4_MAX));
218 :
219 48 : const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
220 :
221 48 : const size_t dst_addr = (n_idx / 2) + k_idx * rhs_qs4cx_stride;
222 48 : uint8_t rhs_v0 = rhs_qs4cx[dst_addr];
223 :
224 48 : if ((n_idx % 2) == 0) {
225 24 : rhs_v0 |= v0_u8;
226 : } else {
227 24 : rhs_v0 |= (v0_u8 << 4);
228 : }
229 48 : rhs_qs4cx[dst_addr] = rhs_v0;
230 : }
231 :
232 6 : rhs_scales_f32[n_idx] = recip_scale0;
233 : }
234 2 : };
235 :
236 5 : void quant_qs4cx_f32(size_t n, size_t k, rhs_format format,
237 : const float *rhs_f32, uint8_t *rhs_qs4cx,
238 : float *rhs_scales_f32) {
239 5 : if (rhs_format::nxk == format) {
240 3 : quant_nxk_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32);
241 : } else {
242 2 : quant_kxn_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32);
243 : }
244 5 : };
245 :
246 2 : static void ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx( // transB
247 : size_t m, size_t n, size_t k, const int8_t *lhs_qa8dx,
248 : const uint8_t *rhs_qs4cx, const float *rhs_scales_f32, float *dst_f32,
249 : float scalar_min, float scalar_max) {
250 2 : const size_t lhs_stride =
251 : k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
252 :
253 2 : const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2);
254 :
255 6 : for (size_t m_idx = 0; m_idx < m; ++m_idx) {
256 4 : const int8_t *lhs_ptr_start = lhs_qa8dx + m_idx * lhs_stride;
257 :
258 12 : for (size_t n_idx = 0; n_idx < n; ++n_idx) {
259 : // Main f32 accumulator
260 : int32_t iacc = 0;
261 :
262 : const int8_t *lhs_ptr = lhs_ptr_start;
263 8 : const uint8_t *rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride;
264 :
265 : // Get the LHS quantization parameters stored at the
266 : // beginning of each row
267 8 : const float lhs_scale = *(const float *)lhs_ptr;
268 : lhs_ptr += sizeof(float);
269 :
270 8 : const int32_t lhs_offset = *(const int32_t *)lhs_ptr;
271 8 : lhs_ptr += sizeof(int32_t);
272 :
273 72 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
274 : // Get the LHS values
275 64 : const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
276 :
277 : // Get the RHS values
278 64 : const uint8_t rhs_byte = rhs_ptr[0];
279 :
280 : // Unpack the RHS values
281 : int32_t rhs_v0 = 0;
282 64 : if ((k_idx % 2) == 0) {
283 32 : rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
284 : } else {
285 32 : rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
286 : }
287 :
288 64 : iacc += lhs_v0 * rhs_v0;
289 64 : iacc += lhs_offset * rhs_v0;
290 :
291 64 : lhs_ptr += 1;
292 :
293 : // Increment only when k_idx is not a multiple of 2
294 64 : rhs_ptr += k_idx % 2;
295 : }
296 :
297 : // Get the RHS scale
298 8 : const float rhs_scale = rhs_scales_f32[n_idx];
299 :
300 8 : float main_acc = iacc * rhs_scale;
301 :
302 8 : main_acc = main_acc * lhs_scale;
303 :
304 : // Clamp (min-max) operation
305 8 : main_acc = std::max(main_acc, scalar_min);
306 8 : main_acc = std::min(main_acc, scalar_max);
307 :
308 8 : dst_f32[0] = main_acc;
309 8 : dst_f32 += 1;
310 : }
311 : }
312 2 : };
313 :
314 1 : static void ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx( // noTrans
315 : size_t m, size_t n, size_t k, const int8_t *lhs_qa8dx,
316 : const uint8_t *rhs_qs4cx, const float *rhs_scales_f32, float *dst_f32,
317 : float scalar_min, float scalar_max) {
318 1 : const size_t lhs_stride =
319 : k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
320 :
321 1 : const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2);
322 :
323 3 : for (size_t m_idx = 0; m_idx < m; ++m_idx) {
324 2 : const int8_t *lhs_ptr_start = lhs_qa8dx + m_idx * lhs_stride;
325 :
326 6 : for (size_t n_idx = 0; n_idx < n; ++n_idx) {
327 : // Main f32 accumulator
328 : int32_t iacc = 0;
329 :
330 : const int8_t *lhs_ptr = lhs_ptr_start;
331 4 : const uint8_t *rhs_ptr = rhs_qs4cx + (n_idx / 2);
332 :
333 : // Get the LHS quantization parameters stored at the
334 : // beginning of each row
335 4 : const float lhs_scale = *(const float *)lhs_ptr;
336 : lhs_ptr += sizeof(float);
337 :
338 4 : const int32_t lhs_offset = *(const int32_t *)lhs_ptr;
339 4 : lhs_ptr += sizeof(int32_t);
340 :
341 36 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
342 : // Get the LHS values
343 32 : const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
344 :
345 : // Get the RHS values
346 32 : const uint8_t rhs_byte = rhs_ptr[0];
347 :
348 : // Unpack the RHS values
349 : int32_t rhs_v0 = 0;
350 32 : if ((n_idx % 2) == 0) {
351 16 : rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
352 : } else {
353 16 : rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
354 : }
355 :
356 32 : iacc += lhs_v0 * rhs_v0;
357 32 : iacc += lhs_offset * rhs_v0;
358 :
359 32 : lhs_ptr += 1;
360 :
361 : // Increment only when k_idx is not a multiple of 2
362 32 : rhs_ptr += rhs_qs4cx_stride;
363 : }
364 :
365 : // Get the RHS scale
366 4 : const float rhs_scale = rhs_scales_f32[n_idx];
367 :
368 4 : float main_acc = iacc * rhs_scale;
369 :
370 4 : main_acc = main_acc * lhs_scale;
371 :
372 : // Clamp (min-max) operation
373 4 : main_acc = std::max(main_acc, scalar_min);
374 4 : main_acc = std::min(main_acc, scalar_max);
375 :
376 4 : dst_f32[0] = main_acc;
377 4 : dst_f32 += 1;
378 : }
379 : }
380 1 : };
381 :
382 3 : void ref_matmul_f32_qa8dx_qs4cx(size_t m, size_t n, size_t k, rhs_format format,
383 : const int8_t *lhs_qa8dx,
384 : const uint8_t *rhs_qs4cx,
385 : const float *rhs_scales_f32, float *dst_f32,
386 : float scalar_min, float scalar_max) {
387 : const size_t lhs_stride =
388 : k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
389 :
390 3 : if (rhs_format::nxk == format) {
391 2 : ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx(m, n, k, lhs_qa8dx, rhs_qs4cx,
392 : rhs_scales_f32, dst_f32, scalar_min,
393 : scalar_max);
394 : } else {
395 1 : ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx(m, n, k, lhs_qa8dx, rhs_qs4cx,
396 : rhs_scales_f32, dst_f32, scalar_min,
397 : scalar_max);
398 : }
399 3 : };
400 :
/**
 * @brief Block-wise (block length bl) int4 quantization of the rhs.
 * @throws std::runtime_error always; not yet implemented in the fallback
 *         backend.
 */
void quant_qs4c32_f32(size_t /* n */, size_t /* k */, size_t /* bl */,
                      const float * /* rhs_f32 */,
                      uint8_t * /* rhs_qs4c32 */) {
  const char *const message = "NYI : quant_qs4c32_f32 (fallback)";
  throw std::runtime_error(message);
}
405 :
/**
 * @brief Block-wise (block length bl) int8 quantization.
 * @throws std::runtime_error always; not yet implemented in the fallback
 *         backend.
 */
void ref_quant_qs8d32_f32(size_t /* n */, size_t /* k */, size_t /* bl */,
                          const float * /* rhs_f32 */,
                          uint8_t * /* rhs_qs8c32 */) {
  const char *const message = "NYI : ref_quant_qs8d32_f32 (fallback)";
  throw std::runtime_error(message);
}
410 :
/**
 * @brief Reference matmul of block-quantized qs8d32 lhs with qs4c32 rhs.
 * @throws std::runtime_error always; not yet implemented in the fallback
 *         backend.
 */
void ref_matmul_f32_qs8d32_qs4c32(size_t /* m */, size_t /* n */,
                                  size_t /* k */, size_t /* bl */,
                                  const int8_t * /* lhs_qa8d32 */,
                                  const uint8_t * /* rhs_qs4c32 */,
                                  float * /* dst_f32 */,
                                  float /* scalar_min */,
                                  float /* scalar_max */) {
  const char *const message = "NYI : ref_matmul_f32_qs8d32_qs4c32 (fallback)";
  throw std::runtime_error(message);
}
|