Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2024 Arm Limited and/or its affiliates
4 : *
5 : * @file fallback_kleidiai.cpp
6 : * @date 15 September 2025
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Sungsik Kong <ss.kong@samsung.com>
9 : * @brief Modified computational backend components of kleidiai. Portions of
10 : * this file are derived from Arm Limited code licensed under the Apache
11 : * License, Version 2.0, with modifications
12 : * @bug No known bugs except for NYI items
13 : * @note Licensed under the Apache License, Version 2.0 (the "License");
14 : * you may not use this file except in compliance with the License.
15 : * You may obtain a copy of the License at
16 : * http://www.apache.org/licenses/LICENSE-2.0
17 : *
18 : * @modifications
19 : * - [2025-09-15] Integrated and adapted Arm-provided code into
20 : * nntrainer CPU backend
21 : *
22 : */
23 :
24 : #include <cassert>
25 : #include <cfloat>
26 : #include <cmath>
27 : #include <cstring>
28 : #include <iostream>
29 : #include <limits>
30 : #include <string>
31 :
32 : #include <fallback_kleidiai.h>
33 :
34 : #define INT4_MIN (-8)
35 : #define INT4_MAX (7)
36 :
/**
 * @brief Round a value up to the nearest multiple of another value
 * @param a value to round up
 * @param b multiple to round to (callers in this file pass 2)
 * @return smallest multiple of @p b that is greater than or equal to @p a
 */
static size_t roundup(size_t a, size_t b) {
  const size_t num_blocks = (a + b - 1) / b;
  return num_blocks * b;
}
38 :
/**
 * @brief Quantize an m x k float matrix to 8-bit values, one scale and zero
 * point per row
 * @param m number of rows in the source matrix
 * @param k number of float values per row
 * @param lhs_f32 source values, m rows of k contiguous floats
 * @param lhs_qa8dx destination buffer of m * (k + sizeof(float) +
 * sizeof(int32_t)) bytes. Each row is written as
 * [float reciprocal scale][int32_t negated zero point][k x int8_t values]
 * @note Rows are k + 8 bytes apart, so a row header is not 4-byte aligned
 * unless k is a multiple of 4. The header is therefore written with
 * std::memcpy rather than through a casted pointer.
 */
void ref_quant_qa8dx_f32(size_t m, size_t k, const float *lhs_f32,
                         int8_t *lhs_qa8dx) {
  const size_t dst_stride =
    (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));

  const size_t src_stride = k;

  // Maximum/minimum int8 values
  const float qmin = static_cast<float>(INT8_MIN);
  const float qmax = static_cast<float>(INT8_MAX);

  for (size_t m_idx = 0; m_idx < m; ++m_idx) {
    const float *src_ptr = lhs_f32 + m_idx * src_stride;

    float max0 = -FLT_MAX;
    float min0 = FLT_MAX;

    // Find min/max of the row
    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
      const float src0_0 = src_ptr[k_idx];

      max0 = std::max(src0_0, max0);
      min0 = std::min(src0_0, min0);
    }

    // The quantized range always has to contain 0.0f
    const float rmin0 = std::min(0.0f, min0);
    const float rmax0 = std::max(0.0f, max0);

    const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);

    // Reciprocal, stored in the row header
    const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;

    const float descaled_min0 = rmin0 * scale0;
    const float descaled_max0 = rmax0 * scale0;

    const float zero_point_from_min_error0 = qmin + descaled_min0;
    const float zero_point_from_max_error0 = qmax + descaled_max0;

    float zero_point0 =
      zero_point_from_min_error0 + zero_point_from_max_error0 > 0
        ? qmin - descaled_min0
        : qmax - descaled_max0;

    zero_point0 = std::max(zero_point0, qmin);
    zero_point0 = std::min(zero_point0, qmax);

    // Round to nearest integer
    const int32_t nudged_zero_point0 =
      static_cast<int32_t>(std::lrintf(zero_point0));

    int8_t *dst_ptr = lhs_qa8dx + m_idx * dst_stride;

    // Row header: reciprocal scale followed by the negated zero point.
    // memcpy keeps the stores valid when the row start is not 4-byte aligned.
    const int32_t neg_zero_point0 = -nudged_zero_point0;
    std::memcpy(dst_ptr, &recip_scale0, sizeof(float));
    dst_ptr += sizeof(float);
    std::memcpy(dst_ptr, &neg_zero_point0, sizeof(int32_t));
    dst_ptr += sizeof(int32_t);

    // Quantize the values of the row
    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
      const float src0_0 = src_ptr[k_idx];

      // Scale the value, shift by the zero point and saturate to int8
      int32_t v0_s32 = static_cast<int32_t>(std::round(src0_0 * scale0));

      v0_s32 = v0_s32 + nudged_zero_point0;
      v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
      v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
      dst_ptr[k_idx] = static_cast<int8_t>(v0_s32);
    }
  }
}
112 :
113 3 : static void quant_nxk_qs4cx_f32(size_t n, size_t k, const float *rhs_f32,
114 : uint8_t *rhs_qs4cx, float *rhs_scales_f32) {
115 3 : const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2);
116 :
117 : // Make sure the output is filled with zeros
118 3 : std::memset(rhs_qs4cx, 0, n * rhs_qs4cx_stride);
119 :
120 11 : for (size_t n_idx = 0; n_idx < n; ++n_idx) {
121 8 : const float *src_ptr = rhs_f32 + n_idx * k;
122 :
123 8 : float max0 = -FLT_MAX;
124 8 : float min0 = FLT_MAX;
125 :
126 : // Find min/max for each channel
127 72 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
128 64 : const float src0_0 = src_ptr[k_idx];
129 :
130 64 : max0 = std::max(src0_0, max0);
131 64 : min0 = std::min(src0_0, min0);
132 : }
133 :
134 : // Maximum/minimum int8 values
135 : const float qmin = (float)INT4_MIN;
136 : const float qmax = (float)INT4_MAX;
137 :
138 8 : const float rmin0 = std::min(0.0f, min0);
139 8 : const float rmax0 = std::max(0.0f, max0);
140 :
141 8 : const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
142 :
143 : // Reciprocal to quantize
144 8 : const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
145 :
146 : // Quantize the channels
147 72 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
148 64 : const float src0_0 = src_ptr[k_idx];
149 :
150 : // Scale the values
151 64 : int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
152 :
153 : // Maximum/minimum int4 values
154 64 : v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT4_MIN));
155 64 : v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT4_MAX));
156 :
157 64 : const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
158 :
159 64 : const size_t dst_addr = (k_idx / 2) + n_idx * rhs_qs4cx_stride;
160 64 : uint8_t rhs_v0 = rhs_qs4cx[dst_addr];
161 :
162 64 : if ((k_idx % 2) == 0) {
163 32 : rhs_v0 |= v0_u8;
164 : } else {
165 32 : rhs_v0 |= (v0_u8 << 4);
166 : }
167 64 : rhs_qs4cx[dst_addr] = rhs_v0;
168 : }
169 :
170 8 : rhs_scales_f32[n_idx] = recip_scale0;
171 : }
172 3 : };
173 :
174 2 : static void quant_kxn_qs4cx_f32(size_t n, size_t k, const float *rhs_f32,
175 : uint8_t *rhs_qs4cx, float *rhs_scales_f32) {
176 2 : const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2);
177 :
178 : // Make sure the output is filled with zeros
179 2 : std::memset(rhs_qs4cx, 0, k * rhs_qs4cx_stride);
180 :
181 8 : for (size_t n_idx = 0; n_idx < n; ++n_idx) {
182 6 : const float *src_ptr = rhs_f32 + n_idx * k;
183 :
184 6 : float max0 = -FLT_MAX;
185 6 : float min0 = FLT_MAX;
186 :
187 : // Find min/max for each channel
188 54 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
189 48 : const float src0_0 = src_ptr[k_idx];
190 :
191 48 : max0 = std::max(src0_0, max0);
192 48 : min0 = std::min(src0_0, min0);
193 : }
194 :
195 : // Maximum/minimum int8 values
196 : const float qmin = (float)INT4_MIN;
197 : const float qmax = (float)INT4_MAX;
198 :
199 6 : const float rmin0 = std::min(0.0f, min0);
200 6 : const float rmax0 = std::max(0.0f, max0);
201 :
202 6 : const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
203 :
204 : // Reciprocal to quantize
205 6 : const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
206 :
207 : // Quantize the channels
208 54 : for (size_t k_idx = 0; k_idx < k; ++k_idx) {
209 48 : const float src0_0 = src_ptr[k_idx];
210 :
211 : // Scale the values
212 48 : int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
213 :
214 : // Maximum/minimum int4 values
215 48 : v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT4_MIN));
216 48 : v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT4_MAX));
217 :
218 48 : const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
219 :
220 48 : const size_t dst_addr = (n_idx / 2) + k_idx * rhs_qs4cx_stride;
221 48 : uint8_t rhs_v0 = rhs_qs4cx[dst_addr];
222 :
223 48 : if ((n_idx % 2) == 0) {
224 24 : rhs_v0 |= v0_u8;
225 : } else {
226 24 : rhs_v0 |= (v0_u8 << 4);
227 : }
228 48 : rhs_qs4cx[dst_addr] = rhs_v0;
229 : }
230 :
231 6 : rhs_scales_f32[n_idx] = recip_scale0;
232 : }
233 2 : };
234 :
235 5 : void quant_qs4cx_f32(size_t n, size_t k, rhs_format format,
236 : const float *rhs_f32, uint8_t *rhs_qs4cx,
237 : float *rhs_scales_f32) {
238 5 : if (rhs_format::nxk == format) {
239 3 : quant_nxk_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32);
240 : } else {
241 2 : quant_kxn_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32);
242 : }
243 5 : };
244 :
/**
 * @brief Reference matmul of an 8-bit quantized LHS with a 4-bit quantized
 * RHS stored as n packed rows of k values (transB), producing clamped floats
 * @param m number of LHS rows / output rows
 * @param n number of RHS rows / output columns
 * @param k number of values accumulated per output element
 * @param lhs_qa8dx LHS rows of k + sizeof(float) + sizeof(int32_t) bytes,
 * each laid out as [float scale][int32_t offset][k x int8_t values]
 * @param rhs_qs4cx RHS, n rows of (k + 1) / 2 bytes. Even k indices are in
 * the low nibble, odd k indices in the high nibble, each stored with a +8 bias
 * @param rhs_scales_f32 n per-row RHS scales
 * @param dst_f32 output, m * n floats written row by row
 * @param scalar_min lower clamp bound applied to every output
 * @param scalar_max upper clamp bound applied to every output
 * @note LHS rows are k + 8 bytes apart, so a row header is not 4-byte aligned
 * unless k is a multiple of 4. It is therefore read with std::memcpy.
 */
static void ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx( // transB
  size_t m, size_t n, size_t k, const int8_t *lhs_qa8dx,
  const uint8_t *rhs_qs4cx, const float *rhs_scales_f32, float *dst_f32,
  float scalar_min, float scalar_max) {
  const size_t lhs_stride =
    k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);

  // Two 4-bit values per byte, rounded up to whole bytes per row
  const size_t rhs_qs4cx_stride = (k + 1) / 2;

  for (size_t m_idx = 0; m_idx < m; ++m_idx) {
    const int8_t *lhs_row = lhs_qa8dx + m_idx * lhs_stride;

    // LHS quantization parameters stored at the beginning of each row.
    // They do not depend on n, so they are decoded once per row.
    float lhs_scale = 0.0f;
    int32_t lhs_offset = 0;
    std::memcpy(&lhs_scale, lhs_row, sizeof(float));
    std::memcpy(&lhs_offset, lhs_row + sizeof(float), sizeof(int32_t));

    const int8_t *lhs_values = lhs_row + sizeof(float) + sizeof(int32_t);

    for (size_t n_idx = 0; n_idx < n; ++n_idx) {
      const uint8_t *rhs_row = rhs_qs4cx + n_idx * rhs_qs4cx_stride;

      // Main int32 accumulator
      int32_t iacc = 0;

      for (size_t k_idx = 0; k_idx < k; ++k_idx) {
        const int32_t lhs_v0 = static_cast<int32_t>(lhs_values[k_idx]);

        // Two consecutive k indices share one RHS byte
        const uint8_t rhs_byte = rhs_row[k_idx / 2];

        // Unpack the RHS value and remove the +8 storage bias
        const int32_t rhs_v0 =
          ((k_idx % 2) == 0)
            ? (static_cast<int32_t>(rhs_byte & 0x0F) - 8)
            : (static_cast<int32_t>(rhs_byte >> 4) - 8);

        iacc += lhs_v0 * rhs_v0;
        iacc += lhs_offset * rhs_v0;
      }

      // Apply the RHS scale first, then the LHS scale
      float main_acc = static_cast<float>(iacc) * rhs_scales_f32[n_idx];
      main_acc = main_acc * lhs_scale;

      // Clamp (min-max) operation
      main_acc = std::max(main_acc, scalar_min);
      main_acc = std::min(main_acc, scalar_max);

      dst_f32[m_idx * n + n_idx] = main_acc;
    }
  }
}
312 :
/**
 * @brief Reference matmul of an 8-bit quantized LHS with a 4-bit quantized
 * RHS stored as k packed rows of n values (noTrans), producing clamped floats
 * @param m number of LHS rows / output rows
 * @param n number of output columns
 * @param k number of values accumulated per output element
 * @param lhs_qa8dx LHS rows of k + sizeof(float) + sizeof(int32_t) bytes,
 * each laid out as [float scale][int32_t offset][k x int8_t values]
 * @param rhs_qs4cx RHS, k rows of (n + 1) / 2 bytes. Even n indices are in
 * the low nibble, odd n indices in the high nibble, each stored with a +8 bias
 * @param rhs_scales_f32 n per-column RHS scales
 * @param dst_f32 output, m * n floats written row by row
 * @param scalar_min lower clamp bound applied to every output
 * @param scalar_max upper clamp bound applied to every output
 * @note LHS rows are k + 8 bytes apart, so a row header is not 4-byte aligned
 * unless k is a multiple of 4. It is therefore read with std::memcpy.
 */
static void ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx( // noTrans
  size_t m, size_t n, size_t k, const int8_t *lhs_qa8dx,
  const uint8_t *rhs_qs4cx, const float *rhs_scales_f32, float *dst_f32,
  float scalar_min, float scalar_max) {
  const size_t lhs_stride =
    k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);

  // Two 4-bit values per byte, rounded up to whole bytes per packed row
  const size_t rhs_qs4cx_stride = (n + 1) / 2;

  for (size_t m_idx = 0; m_idx < m; ++m_idx) {
    const int8_t *lhs_row = lhs_qa8dx + m_idx * lhs_stride;

    // LHS quantization parameters stored at the beginning of each row.
    // They do not depend on n, so they are decoded once per row.
    float lhs_scale = 0.0f;
    int32_t lhs_offset = 0;
    std::memcpy(&lhs_scale, lhs_row, sizeof(float));
    std::memcpy(&lhs_offset, lhs_row + sizeof(float), sizeof(int32_t));

    const int8_t *lhs_values = lhs_row + sizeof(float) + sizeof(int32_t);

    for (size_t n_idx = 0; n_idx < n; ++n_idx) {
      // Byte column and nibble are fixed for a given output column
      const uint8_t *rhs_col = rhs_qs4cx + (n_idx / 2);
      const bool low_nibble = (n_idx % 2) == 0;

      // Main int32 accumulator
      int32_t iacc = 0;

      for (size_t k_idx = 0; k_idx < k; ++k_idx) {
        const int32_t lhs_v0 = static_cast<int32_t>(lhs_values[k_idx]);

        // Each k index lives in its own packed row
        const uint8_t rhs_byte = rhs_col[k_idx * rhs_qs4cx_stride];

        // Unpack the RHS value and remove the +8 storage bias
        const int32_t rhs_v0 =
          low_nibble ? (static_cast<int32_t>(rhs_byte & 0x0F) - 8)
                     : (static_cast<int32_t>(rhs_byte >> 4) - 8);

        iacc += lhs_v0 * rhs_v0;
        iacc += lhs_offset * rhs_v0;
      }

      // Apply the RHS scale first, then the LHS scale
      float main_acc = static_cast<float>(iacc) * rhs_scales_f32[n_idx];
      main_acc = main_acc * lhs_scale;

      // Clamp (min-max) operation
      main_acc = std::max(main_acc, scalar_min);
      main_acc = std::min(main_acc, scalar_max);

      dst_f32[m_idx * n + n_idx] = main_acc;
    }
  }
}
380 :
381 3 : void ref_matmul_f32_qa8dx_qs4cx(size_t m, size_t n, size_t k, rhs_format format,
382 : const int8_t *lhs_qa8dx,
383 : const uint8_t *rhs_qs4cx,
384 : const float *rhs_scales_f32, float *dst_f32,
385 : float scalar_min, float scalar_max) {
386 : const size_t lhs_stride =
387 : k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
388 :
389 3 : if (rhs_format::nxk == format) {
390 2 : ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx(m, n, k, lhs_qa8dx, rhs_qs4cx,
391 : rhs_scales_f32, dst_f32, scalar_min,
392 : scalar_max);
393 : } else {
394 1 : ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx(m, n, k, lhs_qa8dx, rhs_qs4cx,
395 : rhs_scales_f32, dst_f32, scalar_min,
396 : scalar_max);
397 : }
398 3 : };
|