LCOV - code coverage report
Current view: top level - nntrainer/tensor/cpu_backend/fallback - fallback_kleidiai.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 96.5 % 170 164
Test Date: 2026-01-12 20:43:37 Functions: 70.0 % 10 7

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2024 Arm Limited and/or its affiliates
       4              :  *
       5              :  * @file   fallback_kleidiai.cpp
       6              :  * @date   15 September 2025
       7              :  * @see    https://github.com/nntrainer/nntrainer
       8              :  * @author Sungsik Kong <ss.kong@samsung.com>
       9              :  * @brief  Modified computational backend components of kleidiai. Portions of
      10              :  * this file are derived from Arm Limited code licensed under the Apache
      11              :  * License, Version 2.0, with modifications
      12              :  * @bug    No known bugs except for NYI items
      13              :  * @note   Licensed under the Apache License, Version 2.0 (the "License");
      14              :  *         you may not use this file except in compliance with the License.
      15              :  *         You may obtain a copy of the License at
      16              :  *             http://www.apache.org/licenses/LICENSE-2.0
      17              :  *
      18              :  * @modifications
      19              :  *   - [2025-09-15] Integrated and adapted Arm-provided code into
      20              :  *     nntrainer CPU backend
      21              :  *
      22              :  */
      23              : 
      24              : #include <cassert>
      25              : #include <cfloat>
      26              : #include <cmath>
      27              : #include <cstring>
      28              : #include <iostream>
      29              : #include <limits>
      30              : #include <stdexcept>
      31              : #include <string>
      32              : 
      33              : #include <fallback_kleidiai.h>
      34              : 
      35              : #define INT4_MIN (-8)
      36              : #define INT4_MAX (7)
      37              : 
      38            8 : static size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }
      39              : 
      40            4 : void ref_quant_qa8dx_f32(size_t m, size_t k, const float *lhs_f32,
      41              :                          int8_t *lhs_qa8dx) {
      42            4 :   const size_t dst_stride =
      43              :     (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
      44              : 
      45              :   const size_t lhs_qa8dx_stride = k;
      46              : 
      47           12 :   for (size_t m_idx = 0; m_idx < m; ++m_idx) {
      48            8 :     const float *src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride;
      49              : 
      50            8 :     float max0 = -FLT_MAX;
      51            8 :     float min0 = FLT_MAX;
      52              : 
      53              :     // Find min/max for each channel
      54           72 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
      55           64 :       const float src0_0 = src_ptr[k_idx];
      56              : 
      57           64 :       max0 = std::max(src0_0, max0);
      58           64 :       min0 = std::min(src0_0, min0);
      59              :     }
      60              : 
      61              :     // Maximum/minimum int8 values
      62            8 :     const float qmin = (float)INT8_MIN;
      63            8 :     const float qmax = (float)INT8_MAX;
      64              : 
      65            8 :     const float rmin0 = std::min(0.0f, min0);
      66            8 :     const float rmax0 = std::max(0.0f, max0);
      67              : 
      68            8 :     const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
      69              : 
      70              :     // Reciprocal to quantize
      71            8 :     const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
      72              : 
      73            8 :     const float descaled_min0 = rmin0 * scale0;
      74            8 :     const float descaled_max0 = rmax0 * scale0;
      75              : 
      76            8 :     const float zero_point_from_min_error0 = qmin + descaled_min0;
      77            8 :     const float zero_point_from_max_error0 = qmax + descaled_max0;
      78              : 
      79            8 :     float zero_point0 =
      80            8 :       zero_point_from_min_error0 + zero_point_from_max_error0 > 0
      81            8 :         ? qmin - descaled_min0
      82              :         : qmax - descaled_max0;
      83              : 
      84            8 :     zero_point0 = std::max(zero_point0, qmin);
      85            8 :     zero_point0 = std::min(zero_point0, qmax);
      86              : 
      87              :     // Round to nearest integer
      88            8 :     const int32_t nudged_zero_point0 = lrintf(zero_point0);
      89              : 
      90            8 :     int8_t *dst_ptr = (int8_t *)lhs_qa8dx + m_idx * dst_stride;
      91              : 
      92              :     // LHS offset at the beginning of the row
      93            8 :     *((float *)(dst_ptr)) = recip_scale0;
      94              :     dst_ptr += sizeof(float);
      95            8 :     *((int32_t *)(dst_ptr)) = -nudged_zero_point0;
      96            8 :     dst_ptr += sizeof(int32_t);
      97              : 
      98              :     // Quantize the channels
      99           72 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     100           64 :       const float src0_0 = src_ptr[k_idx];
     101              : 
     102              :       // Scale the values
     103           64 :       int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
     104              : 
     105           64 :       v0_s32 = v0_s32 + nudged_zero_point0;
     106           64 :       v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
     107           64 :       v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
     108           64 :       dst_ptr[0] = (int8_t)v0_s32;
     109           64 :       dst_ptr += sizeof(int8_t);
     110              :     }
     111              :   }
     112            4 : };
     113              : 
     114            3 : static void quant_nxk_qs4cx_f32(size_t n, size_t k, const float *rhs_f32,
     115              :                                 uint8_t *rhs_qs4cx, float *rhs_scales_f32) {
     116            3 :   const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2);
     117              : 
     118              :   // Make sure the output is filled with zeros
     119            3 :   std::memset(rhs_qs4cx, 0, n * rhs_qs4cx_stride);
     120              : 
     121           11 :   for (size_t n_idx = 0; n_idx < n; ++n_idx) {
     122            8 :     const float *src_ptr = rhs_f32 + n_idx * k;
     123              : 
     124            8 :     float max0 = -FLT_MAX;
     125            8 :     float min0 = FLT_MAX;
     126              : 
     127              :     // Find min/max for each channel
     128           72 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     129           64 :       const float src0_0 = src_ptr[k_idx];
     130              : 
     131           64 :       max0 = std::max(src0_0, max0);
     132           64 :       min0 = std::min(src0_0, min0);
     133              :     }
     134              : 
     135              :     // Maximum/minimum int8 values
     136              :     const float qmin = (float)INT4_MIN;
     137              :     const float qmax = (float)INT4_MAX;
     138              : 
     139            8 :     const float rmin0 = std::min(0.0f, min0);
     140            8 :     const float rmax0 = std::max(0.0f, max0);
     141              : 
     142            8 :     const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
     143              : 
     144              :     // Reciprocal to quantize
     145            8 :     const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
     146              : 
     147              :     // Quantize the channels
     148           72 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     149           64 :       const float src0_0 = src_ptr[k_idx];
     150              : 
     151              :       // Scale the values
     152           64 :       int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
     153              : 
     154              :       // Maximum/minimum int4 values
     155           64 :       v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT4_MIN));
     156           64 :       v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT4_MAX));
     157              : 
     158           64 :       const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
     159              : 
     160           64 :       const size_t dst_addr = (k_idx / 2) + n_idx * rhs_qs4cx_stride;
     161           64 :       uint8_t rhs_v0 = rhs_qs4cx[dst_addr];
     162              : 
     163           64 :       if ((k_idx % 2) == 0) {
     164           32 :         rhs_v0 |= v0_u8;
     165              :       } else {
     166           32 :         rhs_v0 |= (v0_u8 << 4);
     167              :       }
     168           64 :       rhs_qs4cx[dst_addr] = rhs_v0;
     169              :     }
     170              : 
     171            8 :     rhs_scales_f32[n_idx] = recip_scale0;
     172              :   }
     173            3 : };
     174              : 
     175            2 : static void quant_kxn_qs4cx_f32(size_t n, size_t k, const float *rhs_f32,
     176              :                                 uint8_t *rhs_qs4cx, float *rhs_scales_f32) {
     177            2 :   const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2);
     178              : 
     179              :   // Make sure the output is filled with zeros
     180            2 :   std::memset(rhs_qs4cx, 0, k * rhs_qs4cx_stride);
     181              : 
     182            8 :   for (size_t n_idx = 0; n_idx < n; ++n_idx) {
     183            6 :     const float *src_ptr = rhs_f32 + n_idx * k;
     184              : 
     185            6 :     float max0 = -FLT_MAX;
     186            6 :     float min0 = FLT_MAX;
     187              : 
     188              :     // Find min/max for each channel
     189           54 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     190           48 :       const float src0_0 = src_ptr[k_idx];
     191              : 
     192           48 :       max0 = std::max(src0_0, max0);
     193           48 :       min0 = std::min(src0_0, min0);
     194              :     }
     195              : 
     196              :     // Maximum/minimum int8 values
     197              :     const float qmin = (float)INT4_MIN;
     198              :     const float qmax = (float)INT4_MAX;
     199              : 
     200            6 :     const float rmin0 = std::min(0.0f, min0);
     201            6 :     const float rmax0 = std::max(0.0f, max0);
     202              : 
     203            6 :     const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
     204              : 
     205              :     // Reciprocal to quantize
     206            6 :     const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
     207              : 
     208              :     // Quantize the channels
     209           54 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     210           48 :       const float src0_0 = src_ptr[k_idx];
     211              : 
     212              :       // Scale the values
     213           48 :       int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
     214              : 
     215              :       // Maximum/minimum int4 values
     216           48 :       v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT4_MIN));
     217           48 :       v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT4_MAX));
     218              : 
     219           48 :       const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
     220              : 
     221           48 :       const size_t dst_addr = (n_idx / 2) + k_idx * rhs_qs4cx_stride;
     222           48 :       uint8_t rhs_v0 = rhs_qs4cx[dst_addr];
     223              : 
     224           48 :       if ((n_idx % 2) == 0) {
     225           24 :         rhs_v0 |= v0_u8;
     226              :       } else {
     227           24 :         rhs_v0 |= (v0_u8 << 4);
     228              :       }
     229           48 :       rhs_qs4cx[dst_addr] = rhs_v0;
     230              :     }
     231              : 
     232            6 :     rhs_scales_f32[n_idx] = recip_scale0;
     233              :   }
     234            2 : };
     235              : 
     236            5 : void quant_qs4cx_f32(size_t n, size_t k, rhs_format format,
     237              :                      const float *rhs_f32, uint8_t *rhs_qs4cx,
     238              :                      float *rhs_scales_f32) {
     239            5 :   if (rhs_format::nxk == format) {
     240            3 :     quant_nxk_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32);
     241              :   } else {
     242            2 :     quant_kxn_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32);
     243              :   }
     244            5 : };
     245              : 
     246            2 : static void ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx( // transB
     247              :   size_t m, size_t n, size_t k, const int8_t *lhs_qa8dx,
     248              :   const uint8_t *rhs_qs4cx, const float *rhs_scales_f32, float *dst_f32,
     249              :   float scalar_min, float scalar_max) {
     250            2 :   const size_t lhs_stride =
     251              :     k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
     252              : 
     253            2 :   const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2);
     254              : 
     255            6 :   for (size_t m_idx = 0; m_idx < m; ++m_idx) {
     256            4 :     const int8_t *lhs_ptr_start = lhs_qa8dx + m_idx * lhs_stride;
     257              : 
     258           12 :     for (size_t n_idx = 0; n_idx < n; ++n_idx) {
     259              :       // Main f32 accumulator
     260              :       int32_t iacc = 0;
     261              : 
     262              :       const int8_t *lhs_ptr = lhs_ptr_start;
     263            8 :       const uint8_t *rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride;
     264              : 
     265              :       // Get the LHS quantization parameters stored at the
     266              :       // beginning of each row
     267            8 :       const float lhs_scale = *(const float *)lhs_ptr;
     268              :       lhs_ptr += sizeof(float);
     269              : 
     270            8 :       const int32_t lhs_offset = *(const int32_t *)lhs_ptr;
     271            8 :       lhs_ptr += sizeof(int32_t);
     272              : 
     273           72 :       for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     274              :         // Get the LHS values
     275           64 :         const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
     276              : 
     277              :         // Get the RHS values
     278           64 :         const uint8_t rhs_byte = rhs_ptr[0];
     279              : 
     280              :         // Unpack the RHS values
     281              :         int32_t rhs_v0 = 0;
     282           64 :         if ((k_idx % 2) == 0) {
     283           32 :           rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
     284              :         } else {
     285           32 :           rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
     286              :         }
     287              : 
     288           64 :         iacc += lhs_v0 * rhs_v0;
     289           64 :         iacc += lhs_offset * rhs_v0;
     290              : 
     291           64 :         lhs_ptr += 1;
     292              : 
     293              :         // Increment only when k_idx is not a multiple of 2
     294           64 :         rhs_ptr += k_idx % 2;
     295              :       }
     296              : 
     297              :       // Get the RHS scale
     298            8 :       const float rhs_scale = rhs_scales_f32[n_idx];
     299              : 
     300            8 :       float main_acc = iacc * rhs_scale;
     301              : 
     302            8 :       main_acc = main_acc * lhs_scale;
     303              : 
     304              :       // Clamp (min-max) operation
     305            8 :       main_acc = std::max(main_acc, scalar_min);
     306            8 :       main_acc = std::min(main_acc, scalar_max);
     307              : 
     308            8 :       dst_f32[0] = main_acc;
     309            8 :       dst_f32 += 1;
     310              :     }
     311              :   }
     312            2 : };
     313              : 
     314            1 : static void ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx( // noTrans
     315              :   size_t m, size_t n, size_t k, const int8_t *lhs_qa8dx,
     316              :   const uint8_t *rhs_qs4cx, const float *rhs_scales_f32, float *dst_f32,
     317              :   float scalar_min, float scalar_max) {
     318            1 :   const size_t lhs_stride =
     319              :     k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
     320              : 
     321            1 :   const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2);
     322              : 
     323            3 :   for (size_t m_idx = 0; m_idx < m; ++m_idx) {
     324            2 :     const int8_t *lhs_ptr_start = lhs_qa8dx + m_idx * lhs_stride;
     325              : 
     326            6 :     for (size_t n_idx = 0; n_idx < n; ++n_idx) {
     327              :       // Main f32 accumulator
     328              :       int32_t iacc = 0;
     329              : 
     330              :       const int8_t *lhs_ptr = lhs_ptr_start;
     331            4 :       const uint8_t *rhs_ptr = rhs_qs4cx + (n_idx / 2);
     332              : 
     333              :       // Get the LHS quantization parameters stored at the
     334              :       // beginning of each row
     335            4 :       const float lhs_scale = *(const float *)lhs_ptr;
     336              :       lhs_ptr += sizeof(float);
     337              : 
     338            4 :       const int32_t lhs_offset = *(const int32_t *)lhs_ptr;
     339            4 :       lhs_ptr += sizeof(int32_t);
     340              : 
     341           36 :       for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     342              :         // Get the LHS values
     343           32 :         const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
     344              : 
     345              :         // Get the RHS values
     346           32 :         const uint8_t rhs_byte = rhs_ptr[0];
     347              : 
     348              :         // Unpack the RHS values
     349              :         int32_t rhs_v0 = 0;
     350           32 :         if ((n_idx % 2) == 0) {
     351           16 :           rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
     352              :         } else {
     353           16 :           rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
     354              :         }
     355              : 
     356           32 :         iacc += lhs_v0 * rhs_v0;
     357           32 :         iacc += lhs_offset * rhs_v0;
     358              : 
     359           32 :         lhs_ptr += 1;
     360              : 
     361              :         // Increment only when k_idx is not a multiple of 2
     362           32 :         rhs_ptr += rhs_qs4cx_stride;
     363              :       }
     364              : 
     365              :       // Get the RHS scale
     366            4 :       const float rhs_scale = rhs_scales_f32[n_idx];
     367              : 
     368            4 :       float main_acc = iacc * rhs_scale;
     369              : 
     370            4 :       main_acc = main_acc * lhs_scale;
     371              : 
     372              :       // Clamp (min-max) operation
     373            4 :       main_acc = std::max(main_acc, scalar_min);
     374            4 :       main_acc = std::min(main_acc, scalar_max);
     375              : 
     376            4 :       dst_f32[0] = main_acc;
     377            4 :       dst_f32 += 1;
     378              :     }
     379              :   }
     380            1 : };
     381              : 
     382            3 : void ref_matmul_f32_qa8dx_qs4cx(size_t m, size_t n, size_t k, rhs_format format,
     383              :                                 const int8_t *lhs_qa8dx,
     384              :                                 const uint8_t *rhs_qs4cx,
     385              :                                 const float *rhs_scales_f32, float *dst_f32,
     386              :                                 float scalar_min, float scalar_max) {
     387              :   const size_t lhs_stride =
     388              :     k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
     389              : 
     390            3 :   if (rhs_format::nxk == format) {
     391            2 :     ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx(m, n, k, lhs_qa8dx, rhs_qs4cx,
     392              :                                            rhs_scales_f32, dst_f32, scalar_min,
     393              :                                            scalar_max);
     394              :   } else {
     395            1 :     ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx(m, n, k, lhs_qa8dx, rhs_qs4cx,
     396              :                                            rhs_scales_f32, dst_f32, scalar_min,
     397              :                                            scalar_max);
     398              :   }
     399            3 : };
     400              : 
     401            0 : void quant_qs4c32_f32(size_t n, size_t k, size_t bl, const float *rhs_f32,
     402              :                       uint8_t *rhs_qs4c32) {
     403            0 :   throw std::runtime_error("NYI : quant_qs4c32_f32 (fallback)");
     404              : }
     405              : 
     406            0 : void ref_quant_qs8d32_f32(size_t n, size_t k, size_t bl, const float *rhs_f32,
     407              :                           uint8_t *rhs_qs8c32) {
     408            0 :   throw std::runtime_error("NYI : ref_quant_qs8d32_f32 (fallback)");
     409              : }
     410              : 
     411            0 : void ref_matmul_f32_qs8d32_qs4c32(size_t m, size_t n, size_t k, size_t bl,
     412              :                                   const int8_t *lhs_qa8d32,
     413              :                                   const uint8_t *rhs_qs4c32, float *dst_f32,
     414              :                                   float scalar_min, float scalar_max) {
     415            0 :   throw std::runtime_error("NYI : ref_matmul_f32_qs8d32_qs4c32 (fallback)");
     416              : }
        

Generated by: LCOV version 2.0-1