LCOV - code coverage report
Current view: top level - nntrainer/tensor/cpu_backend/fallback - fallback_kleidiai.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 100.0 % 164 164
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 7 7

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2024 Arm Limited and/or its affiliates
       4              :  *
       5              :  * @file   fallback_kleidiai.cpp
       6              :  * @date   15 September 2025
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Sungsik Kong <ss.kong@samsung.com>
       9              :  * @brief  Modified computational backend components of kleidiai. Portions of
      10              :  * this file are derived from Arm Limited code licensed under the Apache
      11              :  * License, Version 2.0, with modifications
      12              :  * @bug    No known bugs except for NYI items
      13              :  * @note   Licensed under the Apache License, Version 2.0 (the "License");
      14              :  *         you may not use this file except in compliance with the License.
      15              :  *         You may obtain a copy of the License at
      16              :  *             http://www.apache.org/licenses/LICENSE-2.0
      17              :  *
      18              :  * @modifications
      19              :  *   - [2025-09-15] Integrated and adapted Arm-provided code into
      20              :  *     nntrainer CPU backend
      21              :  *
      22              :  */
      23              : 
      24              : #include <cassert>
      25              : #include <cfloat>
      26              : #include <cmath>
      27              : #include <cstring>
      28              : #include <iostream>
      29              : #include <limits>
      30              : #include <string>
      31              : 
      32              : #include <fallback_kleidiai.h>
      33              : 
      34              : #define INT4_MIN (-8)
      35              : #define INT4_MAX (7)
      36              : 
      37            8 : static size_t roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; }
      38              : 
      39            4 : void ref_quant_qa8dx_f32(size_t m, size_t k, const float *lhs_f32,
      40              :                          int8_t *lhs_qa8dx) {
      41            4 :   const size_t dst_stride =
      42              :     (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
      43              : 
      44              :   const size_t lhs_qa8dx_stride = k;
      45              : 
      46           12 :   for (size_t m_idx = 0; m_idx < m; ++m_idx) {
      47            8 :     const float *src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride;
      48              : 
      49            8 :     float max0 = -FLT_MAX;
      50            8 :     float min0 = FLT_MAX;
      51              : 
      52              :     // Find min/max for each channel
      53           72 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
      54           64 :       const float src0_0 = src_ptr[k_idx];
      55              : 
      56           64 :       max0 = std::max(src0_0, max0);
      57           64 :       min0 = std::min(src0_0, min0);
      58              :     }
      59              : 
      60              :     // Maximum/minimum int8 values
      61            8 :     const float qmin = (float)INT8_MIN;
      62            8 :     const float qmax = (float)INT8_MAX;
      63              : 
      64            8 :     const float rmin0 = std::min(0.0f, min0);
      65            8 :     const float rmax0 = std::max(0.0f, max0);
      66              : 
      67            8 :     const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
      68              : 
      69              :     // Reciprocal to quantize
      70            8 :     const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
      71              : 
      72            8 :     const float descaled_min0 = rmin0 * scale0;
      73            8 :     const float descaled_max0 = rmax0 * scale0;
      74              : 
      75            8 :     const float zero_point_from_min_error0 = qmin + descaled_min0;
      76            8 :     const float zero_point_from_max_error0 = qmax + descaled_max0;
      77              : 
      78            8 :     float zero_point0 =
      79            8 :       zero_point_from_min_error0 + zero_point_from_max_error0 > 0
      80            8 :         ? qmin - descaled_min0
      81              :         : qmax - descaled_max0;
      82              : 
      83            8 :     zero_point0 = std::max(zero_point0, qmin);
      84            8 :     zero_point0 = std::min(zero_point0, qmax);
      85              : 
      86              :     // Round to nearest integer
      87            8 :     const int32_t nudged_zero_point0 = lrintf(zero_point0);
      88              : 
      89            8 :     int8_t *dst_ptr = (int8_t *)lhs_qa8dx + m_idx * dst_stride;
      90              : 
      91              :     // LHS offset at the beginning of the row
      92            8 :     *((float *)(dst_ptr)) = recip_scale0;
      93              :     dst_ptr += sizeof(float);
      94            8 :     *((int32_t *)(dst_ptr)) = -nudged_zero_point0;
      95            8 :     dst_ptr += sizeof(int32_t);
      96              : 
      97              :     // Quantize the channels
      98           72 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
      99           64 :       const float src0_0 = src_ptr[k_idx];
     100              : 
     101              :       // Scale the values
     102           64 :       int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
     103              : 
     104           64 :       v0_s32 = v0_s32 + nudged_zero_point0;
     105           64 :       v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
     106           64 :       v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
     107           64 :       dst_ptr[0] = (int8_t)v0_s32;
     108           64 :       dst_ptr += sizeof(int8_t);
     109              :     }
     110              :   }
     111            4 : };
     112              : 
     113            3 : static void quant_nxk_qs4cx_f32(size_t n, size_t k, const float *rhs_f32,
     114              :                                 uint8_t *rhs_qs4cx, float *rhs_scales_f32) {
     115            3 :   const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2);
     116              : 
     117              :   // Make sure the output is filled with zeros
     118            3 :   std::memset(rhs_qs4cx, 0, n * rhs_qs4cx_stride);
     119              : 
     120           11 :   for (size_t n_idx = 0; n_idx < n; ++n_idx) {
     121            8 :     const float *src_ptr = rhs_f32 + n_idx * k;
     122              : 
     123            8 :     float max0 = -FLT_MAX;
     124            8 :     float min0 = FLT_MAX;
     125              : 
     126              :     // Find min/max for each channel
     127           72 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     128           64 :       const float src0_0 = src_ptr[k_idx];
     129              : 
     130           64 :       max0 = std::max(src0_0, max0);
     131           64 :       min0 = std::min(src0_0, min0);
     132              :     }
     133              : 
     134              :     // Maximum/minimum int8 values
     135              :     const float qmin = (float)INT4_MIN;
     136              :     const float qmax = (float)INT4_MAX;
     137              : 
     138            8 :     const float rmin0 = std::min(0.0f, min0);
     139            8 :     const float rmax0 = std::max(0.0f, max0);
     140              : 
     141            8 :     const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
     142              : 
     143              :     // Reciprocal to quantize
     144            8 :     const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
     145              : 
     146              :     // Quantize the channels
     147           72 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     148           64 :       const float src0_0 = src_ptr[k_idx];
     149              : 
     150              :       // Scale the values
     151           64 :       int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
     152              : 
     153              :       // Maximum/minimum int4 values
     154           64 :       v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT4_MIN));
     155           64 :       v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT4_MAX));
     156              : 
     157           64 :       const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
     158              : 
     159           64 :       const size_t dst_addr = (k_idx / 2) + n_idx * rhs_qs4cx_stride;
     160           64 :       uint8_t rhs_v0 = rhs_qs4cx[dst_addr];
     161              : 
     162           64 :       if ((k_idx % 2) == 0) {
     163           32 :         rhs_v0 |= v0_u8;
     164              :       } else {
     165           32 :         rhs_v0 |= (v0_u8 << 4);
     166              :       }
     167           64 :       rhs_qs4cx[dst_addr] = rhs_v0;
     168              :     }
     169              : 
     170            8 :     rhs_scales_f32[n_idx] = recip_scale0;
     171              :   }
     172            3 : };
     173              : 
     174            2 : static void quant_kxn_qs4cx_f32(size_t n, size_t k, const float *rhs_f32,
     175              :                                 uint8_t *rhs_qs4cx, float *rhs_scales_f32) {
     176            2 :   const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2);
     177              : 
     178              :   // Make sure the output is filled with zeros
     179            2 :   std::memset(rhs_qs4cx, 0, k * rhs_qs4cx_stride);
     180              : 
     181            8 :   for (size_t n_idx = 0; n_idx < n; ++n_idx) {
     182            6 :     const float *src_ptr = rhs_f32 + n_idx * k;
     183              : 
     184            6 :     float max0 = -FLT_MAX;
     185            6 :     float min0 = FLT_MAX;
     186              : 
     187              :     // Find min/max for each channel
     188           54 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     189           48 :       const float src0_0 = src_ptr[k_idx];
     190              : 
     191           48 :       max0 = std::max(src0_0, max0);
     192           48 :       min0 = std::min(src0_0, min0);
     193              :     }
     194              : 
     195              :     // Maximum/minimum int8 values
     196              :     const float qmin = (float)INT4_MIN;
     197              :     const float qmax = (float)INT4_MAX;
     198              : 
     199            6 :     const float rmin0 = std::min(0.0f, min0);
     200            6 :     const float rmax0 = std::max(0.0f, max0);
     201              : 
     202            6 :     const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
     203              : 
     204              :     // Reciprocal to quantize
     205            6 :     const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
     206              : 
     207              :     // Quantize the channels
     208           54 :     for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     209           48 :       const float src0_0 = src_ptr[k_idx];
     210              : 
     211              :       // Scale the values
     212           48 :       int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
     213              : 
     214              :       // Maximum/minimum int4 values
     215           48 :       v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT4_MIN));
     216           48 :       v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT4_MAX));
     217              : 
     218           48 :       const uint8_t v0_u8 = (uint8_t)(v0_s32 + 8);
     219              : 
     220           48 :       const size_t dst_addr = (n_idx / 2) + k_idx * rhs_qs4cx_stride;
     221           48 :       uint8_t rhs_v0 = rhs_qs4cx[dst_addr];
     222              : 
     223           48 :       if ((n_idx % 2) == 0) {
     224           24 :         rhs_v0 |= v0_u8;
     225              :       } else {
     226           24 :         rhs_v0 |= (v0_u8 << 4);
     227              :       }
     228           48 :       rhs_qs4cx[dst_addr] = rhs_v0;
     229              :     }
     230              : 
     231            6 :     rhs_scales_f32[n_idx] = recip_scale0;
     232              :   }
     233            2 : };
     234              : 
     235            5 : void quant_qs4cx_f32(size_t n, size_t k, rhs_format format,
     236              :                      const float *rhs_f32, uint8_t *rhs_qs4cx,
     237              :                      float *rhs_scales_f32) {
     238            5 :   if (rhs_format::nxk == format) {
     239            3 :     quant_nxk_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32);
     240              :   } else {
     241            2 :     quant_kxn_qs4cx_f32(n, k, rhs_f32, rhs_qs4cx, rhs_scales_f32);
     242              :   }
     243            5 : };
     244              : 
     245            2 : static void ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx( // transB
     246              :   size_t m, size_t n, size_t k, const int8_t *lhs_qa8dx,
     247              :   const uint8_t *rhs_qs4cx, const float *rhs_scales_f32, float *dst_f32,
     248              :   float scalar_min, float scalar_max) {
     249            2 :   const size_t lhs_stride =
     250              :     k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
     251              : 
     252            2 :   const size_t rhs_qs4cx_stride = (roundup(k, 2) / 2);
     253              : 
     254            6 :   for (size_t m_idx = 0; m_idx < m; ++m_idx) {
     255            4 :     const int8_t *lhs_ptr_start = lhs_qa8dx + m_idx * lhs_stride;
     256              : 
     257           12 :     for (size_t n_idx = 0; n_idx < n; ++n_idx) {
     258              :       // Main f32 accumulator
     259              :       int32_t iacc = 0;
     260              : 
     261              :       const int8_t *lhs_ptr = lhs_ptr_start;
     262            8 :       const uint8_t *rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride;
     263              : 
     264              :       // Get the LHS quantization parameters stored at the
     265              :       // beginning of each row
     266            8 :       const float lhs_scale = *(const float *)lhs_ptr;
     267              :       lhs_ptr += sizeof(float);
     268              : 
     269            8 :       const int32_t lhs_offset = *(const int32_t *)lhs_ptr;
     270            8 :       lhs_ptr += sizeof(int32_t);
     271              : 
     272           72 :       for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     273              :         // Get the LHS values
     274           64 :         const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
     275              : 
     276              :         // Get the RHS values
     277           64 :         const uint8_t rhs_byte = rhs_ptr[0];
     278              : 
     279              :         // Unpack the RHS values
     280              :         int32_t rhs_v0 = 0;
     281           64 :         if ((k_idx % 2) == 0) {
     282           32 :           rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
     283              :         } else {
     284           32 :           rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
     285              :         }
     286              : 
     287           64 :         iacc += lhs_v0 * rhs_v0;
     288           64 :         iacc += lhs_offset * rhs_v0;
     289              : 
     290           64 :         lhs_ptr += 1;
     291              : 
     292              :         // Increment only when k_idx is not a multiple of 2
     293           64 :         rhs_ptr += k_idx % 2;
     294              :       }
     295              : 
     296              :       // Get the RHS scale
     297            8 :       const float rhs_scale = rhs_scales_f32[n_idx];
     298              : 
     299            8 :       float main_acc = iacc * rhs_scale;
     300              : 
     301            8 :       main_acc = main_acc * lhs_scale;
     302              : 
     303              :       // Clamp (min-max) operation
     304            8 :       main_acc = std::max(main_acc, scalar_min);
     305            8 :       main_acc = std::min(main_acc, scalar_max);
     306              : 
     307            8 :       dst_f32[0] = main_acc;
     308            8 :       dst_f32 += 1;
     309              :     }
     310              :   }
     311            2 : };
     312              : 
     313            1 : static void ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx( // noTrans
     314              :   size_t m, size_t n, size_t k, const int8_t *lhs_qa8dx,
     315              :   const uint8_t *rhs_qs4cx, const float *rhs_scales_f32, float *dst_f32,
     316              :   float scalar_min, float scalar_max) {
     317            1 :   const size_t lhs_stride =
     318              :     k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
     319              : 
     320            1 :   const size_t rhs_qs4cx_stride = (roundup(n, 2) / 2);
     321              : 
     322            3 :   for (size_t m_idx = 0; m_idx < m; ++m_idx) {
     323            2 :     const int8_t *lhs_ptr_start = lhs_qa8dx + m_idx * lhs_stride;
     324              : 
     325            6 :     for (size_t n_idx = 0; n_idx < n; ++n_idx) {
     326              :       // Main f32 accumulator
     327              :       int32_t iacc = 0;
     328              : 
     329              :       const int8_t *lhs_ptr = lhs_ptr_start;
     330            4 :       const uint8_t *rhs_ptr = rhs_qs4cx + (n_idx / 2);
     331              : 
     332              :       // Get the LHS quantization parameters stored at the
     333              :       // beginning of each row
     334            4 :       const float lhs_scale = *(const float *)lhs_ptr;
     335              :       lhs_ptr += sizeof(float);
     336              : 
     337            4 :       const int32_t lhs_offset = *(const int32_t *)lhs_ptr;
     338            4 :       lhs_ptr += sizeof(int32_t);
     339              : 
     340           36 :       for (size_t k_idx = 0; k_idx < k; ++k_idx) {
     341              :         // Get the LHS values
     342           32 :         const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
     343              : 
     344              :         // Get the RHS values
     345           32 :         const uint8_t rhs_byte = rhs_ptr[0];
     346              : 
     347              :         // Unpack the RHS values
     348              :         int32_t rhs_v0 = 0;
     349           32 :         if ((n_idx % 2) == 0) {
     350           16 :           rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
     351              :         } else {
     352           16 :           rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
     353              :         }
     354              : 
     355           32 :         iacc += lhs_v0 * rhs_v0;
     356           32 :         iacc += lhs_offset * rhs_v0;
     357              : 
     358           32 :         lhs_ptr += 1;
     359              : 
     360              :         // Increment only when k_idx is not a multiple of 2
     361           32 :         rhs_ptr += rhs_qs4cx_stride;
     362              :       }
     363              : 
     364              :       // Get the RHS scale
     365            4 :       const float rhs_scale = rhs_scales_f32[n_idx];
     366              : 
     367            4 :       float main_acc = iacc * rhs_scale;
     368              : 
     369            4 :       main_acc = main_acc * lhs_scale;
     370              : 
     371              :       // Clamp (min-max) operation
     372            4 :       main_acc = std::max(main_acc, scalar_min);
     373            4 :       main_acc = std::min(main_acc, scalar_max);
     374              : 
     375            4 :       dst_f32[0] = main_acc;
     376            4 :       dst_f32 += 1;
     377              :     }
     378              :   }
     379            1 : };
     380              : 
     381            3 : void ref_matmul_f32_qa8dx_qs4cx(size_t m, size_t n, size_t k, rhs_format format,
     382              :                                 const int8_t *lhs_qa8dx,
     383              :                                 const uint8_t *rhs_qs4cx,
     384              :                                 const float *rhs_scales_f32, float *dst_f32,
     385              :                                 float scalar_min, float scalar_max) {
     386              :   const size_t lhs_stride =
     387              :     k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
     388              : 
     389            3 :   if (rhs_format::nxk == format) {
     390            2 :     ref_matmul_mxn_mxk_nxk_f32_qa8dx_qs4cx(m, n, k, lhs_qa8dx, rhs_qs4cx,
     391              :                                            rhs_scales_f32, dst_f32, scalar_min,
     392              :                                            scalar_max);
     393              :   } else {
     394            1 :     ref_matmul_mxn_mxk_kxn_f32_qa8dx_qs4cx(m, n, k, lhs_qa8dx, rhs_qs4cx,
     395              :                                            rhs_scales_f32, dst_f32, scalar_min,
     396              :                                            scalar_max);
     397              :   }
     398            3 : };
        

Generated by: LCOV version 2.0-1