LCOV - code coverage report
Current view: top level - nntrainer/tensor/cpu_backend/x86 - avx2_impl.cpp (source / functions)
Test: coverage_filtered.info                        Coverage      Hit / Total
Test Date: 2026-01-12 20:43:37           Lines:       57.1 %      348 / 609
                                         Functions:   47.6 %       20 / 42

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2023 Donghyeon Jeong <dhyeon.jeong@samsung.com>
       4              :  *
       5              :  * @file   avx2_impl.cpp
       6              :  * @date   20 Feb 2024
       7              :  * @see    https://github.com/nntrainer/nntrainer
       8              :  * @author Donghyeon Jeong <dhyeon.jeong@samsung.com>
       9              :  * @author Sungsik Kong <ss.kong@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  * @brief  This is a source for AVX implementation
      12              :  *
      13              :  */
      14              : 
      15              : #include "avx2_impl.h"
      16              : #include <array>
      17              : #if __has_include(<bit>)
      18              : #include <bit>
      19              : #endif
      20              : #include <cassert>
      21              : #include <cmath>
      22              : #include <cstdint>
      23              : #include <cstring>
      24              : #include <fp16.h>
      25              : #include <immintrin.h>
      26              : #include <limits>
      27              : #if __has_include(<numbers>)
      28              : #include <numbers>
      29              : #endif
      30              : #include <type_traits>
      31              : #if __has_include(<version>)
      32              : #include <version>
      33              : #endif
      34              : #include <fallback_internal.h>
      35              : #include <util_func.h>
      36              : #include <vector>
      37              : 
      38              : #include "nntr_ggml_impl_common.h"
      39              : 
      40              : #if !defined(__has_constexpr_builtin)
      41              : #define __has_constexpr_builtin(x) (0)
      42              : #endif
      43              : 
      44              : #if !defined(__has_cpp_attribute)
      45              : #define __has_cpp_attribute(x) (0)
      46              : #endif
      47              : 
      48              : // __vectorcall calling convention (no-op outside MSVC; the x86_64-linux-gnu
                      : // ABI already passes vectors in registers by default)
      49              : #if _MSC_VER >= 1700
      50              : #define _nnt_CC_VECTORCALL __vectorcall
      51              : #else
      52              : #define _nnt_CC_VECTORCALL
      53              : #endif
      54              : 
      55              : // Flatten attribute
      56              : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::flatten)
      57              : #define _nnt_ATTR_FLATTEN [[msvc::flatten]]
      58              : #elif __has_cpp_attribute(gnu::flatten)
      59              : // clang, g++
      60              : #define _nnt_ATTR_FLATTEN [[gnu::flatten]]
      61              : #else
      62              : #define _nnt_ATTR_FLATTEN
      63              : #endif
      64              : 
      65              : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::noinline)
      66              : #define _nnt_ATTR_NOINLINE [[msvc::noinline]]
      67              : #elif __has_cpp_attribute(gnu::noinline)
      68              : // clang, g++
      69              : #define _nnt_ATTR_NOINLINE [[gnu::noinline]]
      70              : #else
      71              : #define _nnt_ATTR_NOINLINE
      72              : #endif
      73              : 
      74              : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::forceinline)
      75              : #define _nnt_ATTR_ALWAYS_INLINE [[msvc::forceinline]]
      76              : #elif __has_cpp_attribute(gnu::always_inline)
      77              : #define _nnt_ATTR_ALWAYS_INLINE [[gnu::always_inline]]
      78              : #else
                      : #define _nnt_ATTR_ALWAYS_INLINE
                      : #endif
      79              : 
      80              : #if __has_cpp_attribute(unlikely)
      81              : #define UNLIKELY [[unlikely]]
      82              : #else
      83              : #define UNLIKELY
      84              : #endif
      85              : 
      86              : #if !defined(_MSC_VER) && !defined(__clang__)
      87              : #pragma GCC diagnostic ignored "-Wattributes"
      88              : #endif
      89              : 
      90              : #if defined(__clang__) || defined(__GNUC__)
      91              : #define RESTRICT __restrict__
      92              : #else
      93              : #define RESTRICT
      94              : #endif
      95              : 
      96              : namespace {
      97              : 
      98              : template <typename To_, typename From_>
      99              : constexpr inline bool concept17_BinaryCastable =
      100              :   sizeof(To_) == sizeof(From_) &&
      101              :   std::is_trivially_copyable_v<From_> && std::is_trivially_copyable_v<To_>;
     102              : 
     103              : template <class To_, class From_>
     104              : auto compat_bit_cast(const From_ &src) noexcept
     105              :   -> std::enable_if_t<concept17_BinaryCastable<To_, From_>, To_> {
     106              : #if __cpp_lib_bit_cast >= 201806L
     107              :   return std::bit_cast<To_>(src);
     108              : #else
     109              :   To_ dst;
     110              :   std::memcpy(&dst, &src, sizeof(To_));
     111              :   return dst;
     112              : #endif
     113              : }
     114              : 
     115              : [[nodiscard]] constexpr inline unsigned
     116              : constexpr_popcount(uint32_t v) noexcept {
     117              : #if __cpp_lib_bitops >= 201907L
     118              :   return std::popcount(v);
     119              : #else
     120              :   // Popcount via bit-hack
      121              :   v = v - ((v >> 1) & 0x55555555);                // 2-bit partial counts
      122              :   v = (v & 0x33333333) + ((v >> 2) & 0x33333333); // 4-bit partial counts
      123              :   auto c = (((v + (v >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; // sum the bytes
     124              :   return c;
     125              : #endif
     126              : }
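                      : // e.g. constexpr_popcount(12U) == 2; exactly one set bit identifies a
                      : // power of two, which concept17_PowerOfTwo below relies on:
                      : static_assert(constexpr_popcount(8U) == 1 && constexpr_popcount(12U) == 2,
                      :               "popcount sanity check (illustrative)");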
     127              : 
     128              : template <unsigned I_>
     129              : constexpr inline bool concept17_PowerOfTwo = (constexpr_popcount(I_) == 1);
     130              : 
     131              : namespace numbers {
     132              : 
     133              : #if __has_include(<numbers>) && __cpp_lib_math_constants >= 201907L
     134              : using std::numbers::ln2_v;
     135              : using std::numbers::log2e_v;
     136              : #else
     137              : template <typename Float_> constexpr inline auto ln2_v = Float_{M_LN2};
     138              : 
     139              : template <typename Float_> constexpr inline auto log2e_v = Float_{M_LOG2E};
     140              : #endif
     141              : } // namespace numbers
     142              : 
      143              : constexpr inline float EXP_ARG_MIN = -87.0f;
     144              : constexpr inline float EXP_ARG_MAX = +88.3762626647949f;
     145              : 
     146              : // @brief Precalculated lookup table for 2^x calculation
     147              : template <unsigned N_, typename Ty_ = uint32_t, typename Float_ = float,
     148              :           typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
     149              : struct Exp2Table {
     150              : 
     151              :   constexpr static inline auto MANTISSA_BITS =
     152              :     std::numeric_limits<Float_>::digits - 1;
     153              : 
     154              : #if __cpp_consteval >= 201811L && __has_constexpr_builtin(__builtin_exp2)
     155              :   [[nodiscard]] static consteval auto calculate() noexcept {
     156              :     std::array<Ty_, N_> t;
     157              :     for (unsigned i = 0; i < N_; ++i)
     158              :       t[i] = std::bit_cast<Ty_>(std::exp2(Float_{1.0} * i / N_)) -
     159              :              ((i << MANTISSA_BITS) / N_);
     160              :     return t;
     161              :   }
     162              : #endif
     163              : };
     164              : 
     165              : #if !__has_constexpr_builtin(__builtin_exp2) || !(__cpp_consteval >= 201811L)
     166              : 
     167              : // @brief Precalculated lookup table for 2^x calculation when we don't have
     168              : // constexpr math functions
     169              : template <> struct Exp2Table<8, uint32_t, float> {
     170              :   [[nodiscard]] static constexpr auto calculate() noexcept {
     171              :     std::array<uint32_t, 8> t = {0x3f800000U, 0x3f7b95c2U, 0x3f7837f0U,
     172              :                                  0x3f75fed7U, 0x3f7504f3U, 0x3f75672aU,
     173              :                                  0x3f7744fdU, 0x3f7ac0c7U};
     174              :     return t;
     175              :   }
     176              : };
     177              : 
     178              : #endif
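                      : // Encoding sketch (assumes IEEE-754 binary32): entry i holds
                      : // bit_cast<u32>(2^(i/8)) - (i << 20); adding (k*8 + i) << 20 back later
                      : // reconstructs the bit pattern of 2^(k + i/8). E.g. i = 4, k = 0:
                      : //   0x3f7504f3 + (4 << 20) = 0x3fb504f3, the bits of sqrt(2).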
     179              : 
     180              : template <unsigned N_> // requires PowerOfTwo<N_>
     181              : alignas(__m256) inline constexpr auto exp2_table_v = Exp2Table<N_>::calculate();
     182              : 
      183              : // Scalar version of the expf() approximation with a 3rd-degree polynomial
      184              : // of the fractional part.
      185              : //
      186              : // The error with respect to std::expf is less than 5e-6.
      187              : // It is valid in range [-87, +88.37) - +INF, NaN etc. are not handled;
      188              : // arguments are clamped to the valid range.
      189              : //
      190              : // The strategy picked is to approximate exp(x) as 2^K * 2^F.
     191              : template <unsigned N_, typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
     192              : [[nodiscard]] _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN inline float
     193              : approx_exp_exp2_lookup(float x) noexcept {
     194              :   constexpr static unsigned N_MASK = uint32_t(N_ - 1U);
     195              :   constexpr static unsigned FLT_MANTISSA_BITS =
     196              :     std::numeric_limits<float>::digits - 1U;
     197              : 
     198              :   x = std::max(x, EXP_ARG_MIN);
      199              :   x *= float(0x1.0p0 / numbers::ln2_v<double> * N_);
     200              :   x = std::min(x, float(EXP_ARG_MAX / numbers::ln2_v<double> * N_));
     201              : 
      202              :   // Round to nearest and split off the integer part (a la std::modf)
      203              :   // NB: unlike rint, this method doesn't round ties to even.
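                      :   // Adding 1.5 * 2^23 pushes the rounded integer part into the low
                      :   // mantissa bits, so the low bits of the float's bit pattern equal
                      :   // round(x) (the classic magic-number trick, valid for |x| < 2^22).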
     204              :   auto x_int = x + 0x1.8p23f;
     205              :   auto x_uint = compat_bit_cast<uint32_t>(x_int);
     206              :   x_int -= 0x1.8p23f;
     207              :   auto x_frac = x - x_int;
     208              : 
     209              :   auto s_int = exp2_table_v<N_>[x_uint & N_MASK];
     210              :   auto x_uint_shifted = x_uint
     211              :                         << (FLT_MANTISSA_BITS - constexpr_popcount(N_MASK));
     212              :   auto s_int_2 = s_int + x_uint_shifted;
     213              :   auto s = compat_bit_cast<float>(s_int_2);
     214              : 
      215              :   // Polynomial of the form C0*x^3 + C1*x^2 + C2*x + 1.0
     216              :   static constexpr float poly_4[] = {0x1.c6af84b912394p-5f / N_ / N_ / N_,
     217              :                                      0x1.ebfce50fac4f3p-3f / N_ / N_,
     218              :                                      0x1.62e42ff0c52d6p-1f / N_};
     219              : 
     220              :   auto q0 = std::fma(x_frac, poly_4[0], poly_4[1]);
     221              :   auto x_frac_pow2 = x_frac * x_frac;
     222              :   auto q2 = x_frac * poly_4[2]; // not adding +1.0
     223              : 
     224              :   x = std::fma(q0, x_frac_pow2, q2);
     225              : 
     226              :   return std::fma(x, s, s); // NB: (handles (x+1) by addition of s)
     227              : }
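                      : // Accuracy sketch (hypothetical driver, not part of this file):
                      : //   float max_rel = 0.0f;
                      : //   for (float v = -87.0f; v < 88.0f; v += 1e-3f) {
                      : //     float a = approx_exp_exp2_lookup<8>(v);
                      : //     max_rel = std::max(max_rel, std::abs(a - std::exp(v)) / std::exp(v));
                      : //   }
                      : //   // expect max_rel on the order of 5e-6, per the comment above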
     228              : 
      229              : // Vectorized version of the above
     230              : template <unsigned N_, typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
     231              : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN inline auto _nnt_CC_VECTORCALL
     232              : avx2_approx_exp_e2lookup(__m256 xs) noexcept {
     233              : 
     234              :   constexpr static uint32_t N_MASK = uint32_t(N_ - 1U);
     235              :   alignas(64) constexpr static auto EXP2_TBL = exp2_table_v<N_>;
     236              :   constexpr static unsigned MANTISSA_BITS =
     237              :     std::numeric_limits<float>::digits - 1;
     238              : 
     239              :   // Ensure arg in range [exp_arg_min, exp_arg_max]
     240              :   xs = _mm256_max_ps(xs, _mm256_set1_ps(EXP_ARG_MIN));
     241              :   // Would clamp to EXP_ARG_MAX but we move it after multiplication for IPC:
     242              :   // xs = _mm256_min_ps(xs, _mm256_set1_ps(EXP_ARG_MAX));
     243              : 
     244              :   xs =
     245              :     _mm256_mul_ps(xs, _mm256_set1_ps(float(1.0 / numbers::ln2_v<double> * N_)));
     246              :   // Clamp EXP_ARG_MAX after multiply
     247              :   xs = _mm256_min_ps(
     248              :     xs, _mm256_set1_ps(float(EXP_ARG_MAX / numbers::ln2_v<double> * N_)));
     249              : 
      250              :   // Mostly equivalent to (but doesn't round ties to even):
      251              :   //   auto xs_int = _mm256_round_ps(xs, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
      252              :   //   auto xs_int_as_u32 = _mm256_cvtps_epi32(xs_int);
     253              :   auto xs_int = _mm256_add_ps(xs, _mm256_set1_ps(0x1.8p23f));
     254              :   auto xs_int_as_u32 = _mm256_castps_si256(xs_int);
     255              :   xs_int = _mm256_sub_ps(xs_int, _mm256_set1_ps(0x1.8p23f));
     256              : 
     257              :   // Calculate fractional part
     258              :   auto xs_frac = _mm256_sub_ps(xs, xs_int);
     259              :   // Indices for lookup (modulo N_)
     260              :   auto exp2_idxs = _mm256_and_si256(xs_int_as_u32, _mm256_set1_epi32(N_MASK));
     261              : 
     262              :   __m256i s_ints;
     263              : 
      264              :   // Look up the scale factor s for the integer part
     265              :   if constexpr (N_ == 8) {
     266              :     // Lookup by vector permute
     267              :     auto tbl = _mm256_load_si256((__m256i *)EXP2_TBL.data());
     268              :     s_ints = _mm256_permutevar8x32_epi32(tbl, exp2_idxs);
     269              :   } else {
      270              :     // Fallback gather when the table doesn't fit one in-register permute
      271              :     s_ints = _mm256_i32gather_epi32((const int *)EXP2_TBL.data(), exp2_idxs, 4);
     272              :   }
     273              : 
     274              :   auto xs_uint_shifted = _mm256_slli_epi32(
     275              :     xs_int_as_u32, MANTISSA_BITS - constexpr_popcount(N_MASK));
     276              :   auto s_ints_2 = _mm256_add_epi32(s_ints, xs_uint_shifted);
     277              :   auto s_floats = _mm256_castsi256_ps(s_ints_2);
     278              : 
     279              :   static constexpr float poly_d4[] = {0x1.c6af84b912394p-5f / N_ / N_ / N_,
     280              :                                       0x1.ebfce50fac4f3p-3f / N_ / N_,
     281              :                                       0x1.62e42ff0c52d6p-1f / N_};
     282              : 
     283              :   const auto C0 = _mm256_set1_ps(poly_d4[0]);
     284              :   const auto C1 = _mm256_set1_ps(poly_d4[1]);
     285              :   const auto C2 = _mm256_set1_ps(poly_d4[2]);
     286              : 
     287              :   auto qs0 = _mm256_fmadd_ps(xs_frac, C0, C1);
     288              :   auto xs_frac_pow2 = _mm256_mul_ps(xs_frac, xs_frac);
     289              :   auto qs2 = _mm256_mul_ps(xs_frac, C2);
     290              : 
     291              :   xs = _mm256_fmadd_ps(qs0, xs_frac_pow2, qs2);
     292              : 
     293              :   return _mm256_fmadd_ps(xs, s_floats, s_floats);
     294              : }
     295              : 
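                      : // Negate by XORing the IEEE-754 sign bit; the all-ones source register is
                      : // materialized via cmpeq on an undefined value, avoiding a constant load.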
     296              : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
     297              : avx2_negate_ps(__m256 x) noexcept -> __m256 {
     298              :   constexpr auto SIGN_SHIFT = sizeof(float) * 8 - 1;
     299              :   const auto UNDEF = _mm256_undefined_si256();
     300              :   const auto sign_bit =
     301              :     _mm256_slli_epi32(_mm256_cmpeq_epi16(UNDEF, UNDEF), SIGN_SHIFT);
     302              :   auto flt_sign_bit = _mm256_castsi256_ps(sign_bit);
     303              :   auto neg_x = _mm256_xor_ps(x, flt_sign_bit);
     304              :   return neg_x;
     305              : }
     306              : 
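                      : // Lane-wise SwiGLU: x * sigmoid(x) * s = x / (1 + exp(-x)) * s, with
                      : // exp(-x) taken from the lookup-based approximation above.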
     307              : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
     308              : avx2_approx_swiglu(__m256 x, __m256 s) noexcept -> __m256 {
     309              :   auto neg_x = avx2_negate_ps(x);
     310              :   auto inv_sigmoid =
     311              :     _mm256_add_ps(avx2_approx_exp_e2lookup<8>(neg_x), _mm256_set1_ps(1.0f));
     312              :   auto swiglu_nonscaled = _mm256_div_ps(x, inv_sigmoid);
     313              :   return _mm256_mul_ps(swiglu_nonscaled, s);
     314              : }
     315              : 
     316              : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
     317              : avx2_approx_swiglu_alpha(__m256 x, __m256 s, __m256 alpha) noexcept -> __m256 {
     318              :   auto alpha_x = _mm256_mul_ps(alpha, x);
     319              :   auto neg_alpha_x = avx2_negate_ps(alpha_x);
     320              :   auto inv_sigmoid = _mm256_add_ps(avx2_approx_exp_e2lookup<8>(neg_alpha_x),
     321              :                                    _mm256_set1_ps(1.0f));
     322              :   auto swiglu_nonscaled = _mm256_div_ps(x, inv_sigmoid);
     323              :   return _mm256_mul_ps(swiglu_nonscaled, s);
     324              : }
     325              : } // namespace
     326              : 
     327              : namespace nntrainer::avx2 {
     328              : 
      329              : /**
      330              :  * @brief q4_0x8 block: 8 interleaved q4_0 sub-blocks (8 fp16 scales + 8 x 32 packed 4-bit quants)
      331              :  */
     332              : struct block_q4_0x8 {
     333              :   uint16_t d[8];   // 16B
     334              :   uint8_t qs[128]; // 16 x u64
     335              : };
     336              : 
     337              : #define USE_NONTEMPORAL_STORES 1
     338              : 
     339              : static inline void store256_u16(void *dst, __m256i v) {
     340              : #if defined(USE_NONTEMPORAL_STORES)
     341              :   // use NT only if 32B-aligned; otherwise fall back (correctness first)
     342              :   if (((uintptr_t)dst & 31u) == 0) {
     343              :     _mm256_stream_si256((__m256i *)dst, v);
     344              :     return;
     345              :   }
     346              : #endif
     347              :   _mm256_storeu_si256((__m256i *)dst, v);
     348              : }
     349              : 
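                      : // Unpack q4_0x8 blocks from src into a transposed layout: dT[c * N + r]
                      : // receives the fp16 scale of row r for K-subblock c, and qsT receives the
                      : // re-packed nibbles; columns are tiled by CT to keep stores cache-friendly.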
     350            0 : void unpack_q4_0x8_transpose16(const void *src, unsigned short *__restrict dT,
     351              :                                unsigned short *__restrict qsT, int N, int K,
     352              :                                int CT) // column tile (in units of 32-cols)
     353              : {
     354            0 :   assert((K % 256) == 0);
     355            0 :   assert((N % 8) == 0);
     356              : 
     357              :   const auto *__restrict x = static_cast<const block_q4_0x8 *>(src);
     358              : 
     359            0 :   const int groups_N8 = N / 8;    // number of 8-row groups
     360            0 :   const int cols_scales = K / 32; // K subblocks
     361              : 
     362              :   // AVX2 constants
     363            0 :   const __m128i v88 = _mm_set1_epi8((char)0x88);
     364            0 :   const __m128i v0f = _mm_set1_epi8((char)0x0F);
     365            0 :   const __m128i vF0 = _mm_set1_epi8((char)0xF0);
     366              : 
     367              :   const __m128i idx_even =
     368              :     _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, (char)0xFF, (char)0xFF, (char)0xFF,
     369            0 :                   (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF);
     370              :   const __m128i idx_odd =
     371              :     _mm_setr_epi8(1, 3, 5, 7, 9, 11, 13, 15, (char)0xFF, (char)0xFF, (char)0xFF,
     372            0 :                   (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF);
     373              :   const __m128i idx_0246 =
     374              :     _mm_setr_epi8(0, 2, 4, 6, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
     375              :                   (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
     376            0 :                   (char)0xFF, (char)0xFF, (char)0xFF);
     377              :   const __m128i idx_1357 =
     378              :     _mm_setr_epi8(1, 3, 5, 7, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
     379              :                   (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
     380            0 :                   (char)0xFF, (char)0xFF, (char)0xFF);
     381              : 
     382            0 :   auto pack_row8 = [&](const unsigned char *qs0, const unsigned char *qs1,
     383              :                        int off) -> __m128i {
     384            0 :     __m128i lo8 = _mm_loadl_epi64((const __m128i *)(qs0 + 8 * off));
     385            0 :     __m128i hi8 = _mm_loadl_epi64((const __m128i *)(qs1 + 8 * off));
     386              :     __m128i v = _mm_unpacklo_epi64(lo8, hi8);
     387            0 :     v = _mm_xor_si128(v, v88);
     388            0 :     __m128i lo = _mm_and_si128(v, v0f);
     389              :     __m128i hi = _mm_and_si128(_mm_srli_epi16(v, 4), v0f);
     390            0 :     __m128i lo_e = _mm_shuffle_epi8(lo, idx_even);
     391            0 :     __m128i lo_o = _mm_shuffle_epi8(lo, idx_odd);
     392              :     __m128i hi_e = _mm_shuffle_epi8(hi, idx_even);
     393              :     __m128i hi_o = _mm_shuffle_epi8(hi, idx_odd);
     394              :     __m128i low_lane =
     395            0 :       _mm_or_si128(lo_e, _mm_and_si128(_mm_slli_epi16(lo_o, 4), vF0));
     396              :     __m128i high_lane =
     397              :       _mm_or_si128(hi_e, _mm_and_si128(_mm_slli_epi16(hi_o, 4), vF0));
     398            0 :     __m128i low_e2 = _mm_shuffle_epi8(low_lane, idx_0246);
     399            0 :     __m128i low_o2 = _mm_shuffle_epi8(low_lane, idx_1357);
     400              :     __m128i high_e2 = _mm_shuffle_epi8(high_lane, idx_0246);
     401              :     __m128i high_o2 = _mm_shuffle_epi8(high_lane, idx_1357);
     402              :     __m128i pack_lo = _mm_unpacklo_epi8(low_e2, low_o2);   // 4×u16 (w0..w3)
     403              :     __m128i pack_hi = _mm_unpacklo_epi8(high_e2, high_o2); // 4×u16 (w4..w7)
     404            0 :     return _mm_unpacklo_epi64(pack_lo, pack_hi);           // 8×u16 (w0..w7)
     405            0 :   };
     406              : 
     407              :   auto transpose8x8_epi16 =
     408              :     [](__m128i r0, __m128i r1, __m128i r2, __m128i r3, __m128i r4, __m128i r5,
     409              :        __m128i r6, __m128i r7, __m128i &c0, __m128i &c1, __m128i &c2,
     410              :        __m128i &c3, __m128i &c4, __m128i &c5, __m128i &c6, __m128i &c7) {
     411              :       __m128i t0 = _mm_unpacklo_epi16(r0, r1);
     412              :       __m128i t1 = _mm_unpackhi_epi16(r0, r1);
     413              :       __m128i t2 = _mm_unpacklo_epi16(r2, r3);
     414              :       __m128i t3 = _mm_unpackhi_epi16(r2, r3);
     415              :       __m128i t4 = _mm_unpacklo_epi16(r4, r5);
     416              :       __m128i t5 = _mm_unpackhi_epi16(r4, r5);
     417              :       __m128i t6 = _mm_unpacklo_epi16(r6, r7);
     418              :       __m128i t7 = _mm_unpackhi_epi16(r6, r7);
     419              : 
     420              :       __m128i u0 = _mm_unpacklo_epi32(t0, t2);
     421              :       __m128i u1 = _mm_unpackhi_epi32(t0, t2);
     422              :       __m128i u2 = _mm_unpacklo_epi32(t1, t3);
     423              :       __m128i u3 = _mm_unpackhi_epi32(t1, t3);
     424              :       __m128i u4 = _mm_unpacklo_epi32(t4, t6);
     425              :       __m128i u5 = _mm_unpackhi_epi32(t4, t6);
     426              :       __m128i u6 = _mm_unpacklo_epi32(t5, t7);
     427              :       __m128i u7 = _mm_unpackhi_epi32(t5, t7);
     428              : 
     429              :       c0 = _mm_unpacklo_epi64(u0, u4);
     430              :       c1 = _mm_unpackhi_epi64(u0, u4);
     431              :       c2 = _mm_unpacklo_epi64(u1, u5);
     432              :       c3 = _mm_unpackhi_epi64(u1, u5);
     433              :       c4 = _mm_unpacklo_epi64(u2, u6);
     434              :       c5 = _mm_unpackhi_epi64(u2, u6);
     435              :       c6 = _mm_unpacklo_epi64(u3, u7);
     436              :       c7 = _mm_unpackhi_epi64(u3, u7);
     437              :     };
     438              : 
      439              :   // ---- pair-processing path: two 8-row groups (16 rows) per pass ----
      440              : 
     441            0 :   const int groups_pairs = groups_N8 / 2;
     442              : 
     443              : #ifdef _MSC_VER
     444              : #pragma warning(push)
     445              : #pragma warning(disable : 4849)
     446              : #endif
     447            0 : #pragma omp parallel for collapse(2) schedule(static)
     448              : #ifdef _MSC_VER
     449              : #pragma warning(pop)
     450              : #endif
     451              :   for (int c0 = 0; c0 < cols_scales; c0 += CT) {
     452              :     for (int bp = 0; bp < groups_pairs; ++bp) {
     453              :       const int b0 = 2 * bp;
     454              :       const int b1 = b0 + 1;
     455              :       const int r0 = b0 * 8; // 16 rows: r0..r0+15
     456              :       const int c1 = std::min(c0 + CT, cols_scales);
     457              : 
     458              :       for (int c = c0; c < c1; ++c) {
     459              :         const block_q4_0x8 &A = x[b0 * cols_scales + c];
     460              :         const block_q4_0x8 &B = x[b1 * cols_scales + c];
     461              : 
     462              :         unsigned short *__restrict dT_c = dT + c * N;
     463              :         unsigned short *__restrict qsT_c0 = qsT + (c * 8) * N;
     464              : 
     465              :         // scales: pack two 8×u16 vectors → one 256b store to dT[c, r0..r0+15]
     466              :         __m128i sd0 = _mm_loadu_si128((const __m128i *)A.d);
     467              :         __m128i sd1 = _mm_loadu_si128((const __m128i *)B.d);
     468              :         __m256i sdp = _mm256_set_m128i(sd1, sd0);
     469              :         store256_u16(dT_c + r0, sdp);
     470              : 
     471              :         // pre-split stripes
     472              :         const unsigned char *__restrict A0 = A.qs;      // + 8*off
     473              :         const unsigned char *__restrict A1 = A.qs + 64; // + 8*off
     474              :         const unsigned char *__restrict B0 = B.qs;
     475              :         const unsigned char *__restrict B1 = B.qs + 64;
     476              : 
     477              :         // build 8 rows for A and 8 rows for B
     478              :         __m128i Ra[8], Rb[8];
     479              :         for (int off = 0; off < 8; ++off) {
     480              :           Ra[off] = pack_row8(A0, A1, off);
     481              :           Rb[off] = pack_row8(B0, B1, off);
     482              :         }
     483              : 
     484              :         // 8×8 transpose → columns (each 8×u16) for A and B
     485              :         __m128i Ca0, Ca1, Ca2, Ca3, Ca4, Ca5, Ca6, Ca7;
     486              :         __m128i Cb0, Cb1, Cb2, Cb3, Cb4, Cb5, Cb6, Cb7;
     487              :         transpose8x8_epi16(Ra[0], Ra[1], Ra[2], Ra[3], Ra[4], Ra[5], Ra[6],
     488              :                            Ra[7], Ca0, Ca1, Ca2, Ca3, Ca4, Ca5, Ca6, Ca7);
     489              :         transpose8x8_epi16(Rb[0], Rb[1], Rb[2], Rb[3], Rb[4], Rb[5], Rb[6],
     490              :                            Rb[7], Cb0, Cb1, Cb2, Cb3, Cb4, Cb5, Cb6, Cb7);
     491              : 
      492              :         // pair and store 32B per column: rows r0..r0+15 are contiguous
     493              :         unsigned short *__restrict base = qsT_c0 + r0;
     494              :         const int S = N;
     495              :         store256_u16(base + 0 * S, _mm256_set_m128i(Cb0, Ca0));
     496              :         store256_u16(base + 1 * S, _mm256_set_m128i(Cb1, Ca1));
     497              :         store256_u16(base + 2 * S, _mm256_set_m128i(Cb2, Ca2));
     498              :         store256_u16(base + 3 * S, _mm256_set_m128i(Cb3, Ca3));
     499              :         store256_u16(base + 4 * S, _mm256_set_m128i(Cb4, Ca4));
     500              :         store256_u16(base + 5 * S, _mm256_set_m128i(Cb5, Ca5));
     501              :         store256_u16(base + 6 * S, _mm256_set_m128i(Cb6, Ca6));
     502              :         store256_u16(base + 7 * S, _mm256_set_m128i(Cb7, Ca7));
     503              :       }
     504              :     }
     505              :   }
     506              : 
      507              :   // ---- tail: odd number of 8-row groups - process the last one (8 rows) ----
      508              : 
     509            0 :   if (groups_N8 & 1) {
     510            0 :     const int b = groups_N8 - 1;
     511            0 :     const int r0 = b * 8;
     512              : 
     513            0 : #pragma omp parallel for schedule(static)
     514              :     for (int c0 = 0; c0 < cols_scales; c0 += CT) {
     515              :       const int c1 = std::min(c0 + CT, cols_scales);
     516              :       for (int c = c0; c < c1; ++c) {
     517              :         const block_q4_0x8 &A = x[b * cols_scales + c];
     518              :         unsigned short *__restrict dT_c = dT + c * N;
     519              :         unsigned short *__restrict qsT_c0 = qsT + (c * 8) * N;
     520              : 
     521              :         // scales (8×u16)
     522              :         __m128i sd0 = _mm_loadu_si128((const __m128i *)A.d);
     523              :         _mm_storeu_si128((__m128i *)(dT_c + r0), sd0);
     524              : 
     525              :         const unsigned char *__restrict A0 = A.qs;
     526              :         const unsigned char *__restrict A1 = A.qs + 64;
     527              : 
     528              :         __m128i R[8];
     529              :         for (int off = 0; off < 8; ++off)
     530              :           R[off] = pack_row8(A0, A1, off);
     531              : 
     532              :         __m128i C0, C1, C2, C3, C4, C5, C6, C7;
     533              :         transpose8x8_epi16(R[0], R[1], R[2], R[3], R[4], R[5], R[6], R[7], C0,
     534              :                            C1, C2, C3, C4, C5, C6, C7);
     535              : 
     536              :         unsigned short *__restrict base = qsT_c0 + r0;
     537              :         const int S = N;
     538              :         _mm_storeu_si128((__m128i *)(base + 0 * S), C0);
     539              :         _mm_storeu_si128((__m128i *)(base + 1 * S), C1);
     540              :         _mm_storeu_si128((__m128i *)(base + 2 * S), C2);
     541              :         _mm_storeu_si128((__m128i *)(base + 3 * S), C3);
     542              :         _mm_storeu_si128((__m128i *)(base + 4 * S), C4);
     543              :         _mm_storeu_si128((__m128i *)(base + 5 * S), C5);
     544              :         _mm_storeu_si128((__m128i *)(base + 6 * S), C6);
     545              :         _mm_storeu_si128((__m128i *)(base + 7 * S), C7);
     546              :       }
     547              :     }
     548              :   }
     549              : 
     550              : #if defined(USE_NONTEMPORAL_STORES)
     551              :   _mm_sfence(); // ensure NT stores are globally visible before returning
     552              : #endif
     553            0 : }
     554              : 
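                      : // Regroup nibbles within each 128-bit lane: output bytes 0..7 pack the low
                      : // nibbles of the 16 input bytes (two per byte), bytes 8..15 pack the highs.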
     555              : static inline __m256i butterfly32(__m256i a) {
     556              :   const __m256i SHUF_EVEN = _mm256_setr_epi8(
     557              :     0, 2, 4, 6, 8, 10, 12, 14, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
     558              :     (char)0x80, (char)0x80, (char)0x80, (char)0x80, 0, 2, 4, 6, 8, 10, 12, 14,
     559              :     (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
     560              :     (char)0x80, (char)0x80);
     561              :   const __m256i SHUF_ODD = _mm256_setr_epi8(
     562              :     1, 3, 5, 7, 9, 11, 13, 15, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
     563              :     (char)0x80, (char)0x80, (char)0x80, (char)0x80, 1, 3, 5, 7, 9, 11, 13, 15,
     564              :     (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
     565              :     (char)0x80, (char)0x80);
     566              :   const __m256i even = _mm256_shuffle_epi8(a, SHUF_EVEN);
     567              :   const __m256i odd = _mm256_shuffle_epi8(a, SHUF_ODD);
     568              :   const __m256i LO = _mm256_set1_epi8(0x0F);
     569              :   const __m256i HI = _mm256_set1_epi8((char)0xF0);
     570              :   __m256i low =
     571              :     _mm256_or_si256(_mm256_and_si256(even, LO),
     572              :                     _mm256_slli_epi16(_mm256_and_si256(odd, LO), 4));
     573              :   __m256i high =
     574              :     _mm256_or_si256(_mm256_srli_epi16(_mm256_and_si256(even, HI), 4),
     575              :                     _mm256_and_si256(odd, HI));
     576              :   high = _mm256_slli_si256(high, 8);
     577              :   return _mm256_or_si256(low, high);
     578              : }
     579              : 
     580              : // Build 16B packet [d0|d1] from two 8B chunks using vector loads (no GPR
     581              : // moves).
     582              : static inline __m128i make_pkt128(const uint8_t *base_qs, int d0, int d1) {
     583            0 :   __m128i lo = _mm_loadl_epi64((const __m128i *)(base_qs + ((size_t)d0 << 3)));
     584            0 :   __m128i hi = _mm_loadl_epi64((const __m128i *)(base_qs + ((size_t)d1 << 3)));
     585              :   return _mm_unpacklo_epi64(lo, hi);
     586              : }
     587              : 
      588              : // ========== core template with QS unrolled by 8 blocks ==========
      589              : 
     590              : template <int UNIT, int GROUPS>
     591              : static inline void convert_q4_0x8_noshuffle(const void *src,
     592              :                                             uint16_t *RESTRICT d_out,
     593              :                                             uint8_t *RESTRICT qs_out) {
     594              :   static_assert(UNIT % 16 == 0, "UNIT must be multiple of 16");
     595              :   constexpr int BLOCKS_PER_GROUP = UNIT / 8;  // d entries per offset per group
     596              :   constexpr int PAIRS_PER_OFFSET = UNIT / 16; // 16B packets per half per offset
     597              :   static_assert((PAIRS_PER_OFFSET % 4) == 0,
     598              :                 "need multiple of 4 packets (8 blocks) per iter");
     599              : 
     600              :   constexpr size_t D_ELEMS_PER_GROUP = 8 * BLOCKS_PER_GROUP;
     601              :   constexpr size_t QS_BYTES_PER_GROUP = (size_t)16 * UNIT;
     602              :   constexpr size_t QS_BYTES_PER_OFFSET = (size_t)2 * UNIT;
     603              : 
     604            0 :   const block_q4_0x8 *x = (const block_q4_0x8 *)src;
     605            0 :   const __m256i bias256 = _mm256_set1_epi8((char)0x88);
     606              : 
     607              : #ifdef _MSC_VER
     608              : #pragma warning(push)
     609              : #pragma warning(disable : 4849)
     610              : #endif
     611            0 : #pragma omp parallel for collapse(2) schedule(static)
     612              : #ifdef _MSC_VER
     613              : #pragma warning(pop)
     614              : #endif
     615              :   for (int b = 0; b < GROUPS; ++b) {
     616              :     for (int offset = 0; offset < 8; ++offset) {
     617              : 
     618              :       // ---- D slice ----
     619              :       {
     620              :         uint16_t *d_ptr = d_out + (size_t)b * D_ELEMS_PER_GROUP +
     621              :                           (size_t)offset * BLOCKS_PER_GROUP;
     622              :         const block_q4_0x8 *xb = x + (size_t)b * BLOCKS_PER_GROUP;
     623              :         for (int i = 0; i < BLOCKS_PER_GROUP; ++i) {
     624              :           d_ptr[i] = xb[i].d[offset];
     625              :         }
     626              :       }
     627              : 
     628              :       // ---- QS slice (unroll 8 blocks / 128B per iter) ----
     629              :       {
     630              :         uint8_t *qs_ptr = qs_out + (size_t)b * QS_BYTES_PER_GROUP +
     631              :                           (size_t)offset * QS_BYTES_PER_OFFSET;
     632              :         const int base_q = (b * UNIT * 2) + offset;
     633              :         const int d0 = (base_q & 15), d1 = d0 ^ 8;
     634              : 
     635            0 :         auto do_half = [&](int blk_base) {
     636              :           // Each iter handles 8 consecutive blocks: j..j+7
     637            0 :           for (int j = 0; j < PAIRS_PER_OFFSET; j += 8) {
     638            0 :             const uint8_t *q0 = x[blk_base + j + 0].qs;
     639            0 :             const uint8_t *q1 = x[blk_base + j + 1].qs;
     640            0 :             const uint8_t *q2 = x[blk_base + j + 2].qs;
     641            0 :             const uint8_t *q3 = x[blk_base + j + 3].qs;
     642            0 :             const uint8_t *q4 = x[blk_base + j + 4].qs;
     643            0 :             const uint8_t *q5 = x[blk_base + j + 5].qs;
     644            0 :             const uint8_t *q6 = x[blk_base + j + 6].qs;
     645            0 :             const uint8_t *q7 = x[blk_base + j + 7].qs;
     646              : 
     647              : #if Q4X8_PREFETCH_DIST > 0
     648              :             _mm_prefetch(
     649              :               (const char *)(x[blk_base + j + Q4X8_PREFETCH_DIST].qs),
     650              :               _MM_HINT_NTA);
     651              : #endif
     652              :             // Build 8 packets in XMM regs
     653            0 :             __m128i pkt0 = make_pkt128(q0, d0, d1);
     654              :             __m128i pkt1 = make_pkt128(q1, d0, d1);
     655              :             __m128i pkt2 = make_pkt128(q2, d0, d1);
     656              :             __m128i pkt3 = make_pkt128(q3, d0, d1);
     657              :             __m128i pkt4 = make_pkt128(q4, d0, d1);
     658              :             __m128i pkt5 = make_pkt128(q5, d0, d1);
     659              :             __m128i pkt6 = make_pkt128(q6, d0, d1);
     660              :             __m128i pkt7 = make_pkt128(q7, d0, d1);
     661              : 
     662              :             // Four 32B batches: [0|1], [2|3], [4|5], [6|7]
     663              :             __m256i v01 = _mm256_set_m128i(pkt1, pkt0);
     664              :             __m256i v23 = _mm256_set_m128i(pkt3, pkt2);
     665              :             __m256i v45 = _mm256_set_m128i(pkt5, pkt4);
     666              :             __m256i v67 = _mm256_set_m128i(pkt7, pkt6);
     667              : 
     668            0 :             v01 = _mm256_xor_si256(v01, bias256);
     669              :             v23 = _mm256_xor_si256(v23, bias256);
     670              :             v45 = _mm256_xor_si256(v45, bias256);
     671              :             v67 = _mm256_xor_si256(v67, bias256);
     672              : 
     673              :             __m256i o01 = butterfly32(v01);
     674              :             __m256i o23 = butterfly32(v23);
     675              :             __m256i o45 = butterfly32(v45);
     676              :             __m256i o67 = butterfly32(v67);
     677              : 
     678              : #if Q4X8_USE_STREAMING_STORES
     679              :             _mm256_stream_si256((__m256i *)(qs_ptr + 0), o01);
     680              :             _mm256_stream_si256((__m256i *)(qs_ptr + 32), o23);
     681              :             _mm256_stream_si256((__m256i *)(qs_ptr + 64), o45);
     682              :             _mm256_stream_si256((__m256i *)(qs_ptr + 96), o67);
     683              : #else
     684            0 :             _mm256_storeu_si256((__m256i *)(qs_ptr + 0), o01);
     685            0 :             _mm256_storeu_si256((__m256i *)(qs_ptr + 32), o23);
     686            0 :             _mm256_storeu_si256((__m256i *)(qs_ptr + 64), o45);
     687            0 :             _mm256_storeu_si256((__m256i *)(qs_ptr + 96), o67);
     688              : #endif
     689            0 :             qs_ptr += 128;
     690              :           }
     691              :         };
     692              : 
     693              :         // first half
     694              :         do_half(base_q >> 4);
     695              :         // second half (same d0/d1 pattern)
     696              :         do_half((base_q + UNIT) >> 4);
     697              :       }
     698              :     }
     699              :   }
     700              : 
     701              : #if Q4X8_USE_STREAMING_STORES
     702              :   _mm_sfence();
     703              : #endif
     704              : }
     705              : 
      706              : // ========== wrappers for the supported (K, N) combinations ==========
     707              : // K = 3072 (UNIT = 768)
     708            0 : void convert_q4_0x8_shuffle_K3072_N98304(const void *src, uint16_t *d_out,
     709              :                                          uint8_t *qs_out) {
     710              :   // groups = (N*8)/UNIT = 1024
     711              :   convert_q4_0x8_noshuffle<768, 1024>(src, d_out, qs_out);
     712            0 : }
     713            0 : void convert_q4_0x8_shuffle_K3072_N36864(const void *src, uint16_t *d_out,
     714              :                                          uint8_t *qs_out) {
     715              :   // groups = 384
     716              :   convert_q4_0x8_noshuffle<768, 384>(src, d_out, qs_out);
     717            0 : }
     718            0 : void convert_q4_0x8_shuffle_K3072_N3072(const void *src, uint16_t *d_out,
     719              :                                         uint8_t *qs_out) {
     720              :   // groups = 32
     721              :   convert_q4_0x8_noshuffle<768, 32>(src, d_out, qs_out);
     722            0 : }
     723              : 
     724              : // K = 8192 (UNIT = 2048)
     725            0 : void convert_q4_0x8_shuffle_K8192_N98304(const void *src, uint16_t *d_out,
     726              :                                          uint8_t *qs_out) {
     727              :   // groups = 384
     728              :   convert_q4_0x8_noshuffle<2048, 384>(src, d_out, qs_out);
     729            0 : }
     730            0 : void convert_q4_0x8_shuffle_K8192_N36864(const void *src, uint16_t *d_out,
     731              :                                          uint8_t *qs_out) {
     732              :   // groups = 144
     733              :   convert_q4_0x8_noshuffle<2048, 144>(src, d_out, qs_out);
     734            0 : }
     735            0 : void convert_q4_0x8_shuffle_K8192_N3072(const void *src, uint16_t *d_out,
     736              :                                         uint8_t *qs_out) {
     737              :   // groups = 12
     738              :   convert_q4_0x8_noshuffle<2048, 12>(src, d_out, qs_out);
     739            0 : }
     740              : 
      741              : // Optional dispatcher providing a single entry point over the wrappers:
     742            0 : void convert_q4_0x8_shuffle_dispatch_avx(const void *src, uint16_t *d_out,
     743              :                                          uint8_t *qs_out, int N, int K) {
     744            0 :   if (K == 3072) {
     745            0 :     if (N == 98304)
     746            0 :       return convert_q4_0x8_shuffle_K3072_N98304(src, d_out, qs_out);
     747            0 :     if (N == 36864)
     748            0 :       return convert_q4_0x8_shuffle_K3072_N36864(src, d_out, qs_out);
     749            0 :     if (N == 3072)
     750            0 :       return convert_q4_0x8_shuffle_K3072_N3072(src, d_out, qs_out);
     751              :   } else { // K == 8192
     752            0 :     if (N == 98304)
     753            0 :       return convert_q4_0x8_shuffle_K8192_N98304(src, d_out, qs_out);
     754            0 :     if (N == 36864)
     755            0 :       return convert_q4_0x8_shuffle_K8192_N36864(src, d_out, qs_out);
     756            0 :     if (N == 3072)
     757            0 :       return convert_q4_0x8_shuffle_K8192_N3072(src, d_out, qs_out);
     758              :   }
     759              :   // If a new combo appears, fall back to a generic version (not shown here).
     760            0 :   assert(!"Unsupported (K,N) combination");
     761              : }
     762              : 
     763           12 : bool is_valid(const unsigned int N, const float *input) {
     764           12 :   assert(N != 0);
     765           12 :   assert(input != NULL);
     766              : 
     767              :   int temp = 0;
     768              :   unsigned int idx = 0;
     769              : 
     770              :   const __m256 SIGN_MASK = _mm256_set1_ps(-0.0);
     771              :   const __m256 INF = _mm256_set1_ps(std::numeric_limits<float>::infinity());
     772              : 
      773              :   // process 16 floats per iteration: NaN check (X != X), then +-Inf check
     774           15 :   for (; N - idx >= 16; idx += 16) {
     775              :     __m256 vec0 = _mm256_loadu_ps(input);
     776              :     __m256 vec1 = _mm256_loadu_ps(input + 8);
     777            6 :     input += 16;
     778              :     __m256 res = _mm256_cmp_ps(vec0, vec0, _CMP_NEQ_UQ);
     779            6 :     temp = temp | _mm256_movemask_ps(res);
     780              : 
     781            6 :     if (temp)
     782              :       return false;
     783              : 
     784              :     // check infinity in vec0
     785              :     vec0 = _mm256_andnot_ps(SIGN_MASK, vec0);
     786              :     vec0 = _mm256_cmp_ps(vec0, INF, _CMP_EQ_OQ);
     787              : 
     788            5 :     temp = temp | _mm256_movemask_ps(vec0);
     789            5 :     if (temp)
     790              :       return false;
     791              : 
     792              :     __m256 res1 = _mm256_cmp_ps(vec1, vec1, _CMP_NEQ_UQ);
     793            3 :     temp = temp | _mm256_movemask_ps(res1);
     794              : 
     795            3 :     if (temp)
     796              :       return false;
     797              : 
     798              :     // check infinity in vec1
     799              :     vec1 = _mm256_andnot_ps(SIGN_MASK, vec1);
     800              :     vec1 = _mm256_cmp_ps(vec1, INF, _CMP_EQ_OQ);
     801              : 
     802            3 :     temp = temp | _mm256_movemask_ps(vec1);
     803              : 
     804            3 :     if (temp)
     805              :       return false;
     806              :   }
     807              : 
      808              :   // process 8 floats per iteration: NaN check (X != X), then +-Inf check
     809           11 :   for (; N - idx >= 8; idx += 8) {
     810              :     __m256 vec = _mm256_loadu_ps(input);
     811            5 :     input += 8;
     812              :     __m256 res = _mm256_cmp_ps(vec, vec, _CMP_NEQ_UQ);
     813            5 :     temp = temp | _mm256_movemask_ps(res);
     814              : 
     815            5 :     if (temp)
     816              :       return false;
     817              : 
     818              :     // check infinity in vec
     819              :     vec = _mm256_andnot_ps(SIGN_MASK, vec);
     820              :     vec = _mm256_cmp_ps(vec, INF, _CMP_EQ_OQ);
     821              : 
     822            4 :     temp = temp | _mm256_movemask_ps(vec);
     823              : 
     824            4 :     if (temp)
     825              :       return false;
     826              :   }
     827              : 
     828           11 :   while (idx < N) {
     829            8 :     if (!isFloatValid(*input)) {
     830              :       return false;
     831              :     }
     832            5 :     ++input;
     833            5 :     ++idx;
     834              :   }
     835              : 
     836              :   return true;
     837              : }
     838              : 
     839       173980 : void custom_scopy(const unsigned int N, const float *X, const int incX,
     840              :                   float *Y, const int incY) {
     841       173980 :   unsigned int N8 = (N >> 3) << 3;
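                      :   // Copies in 8-float blocks; incX/incY are ignored, so unit strides are
                      :   // assumed. The inline asm (non-Windows path) pins the copy to plain
                      :   // unaligned 32-byte vector moves, presumably to stop the compiler from
                      :   // rewriting the loop as a memcpy call (the _WIN32 path uses intrinsics).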
     842       957465 :   for (unsigned int i = 0; i < N8; i += 8) {
     843              : #if defined(_WIN32)
     844              :     __m256 temp = _mm256_loadu_ps(&X[i]);
     845              :     _mm256_storeu_ps(&Y[i], temp);
     846              : #else
     847       783485 :     __asm__ __volatile__("vmovups (%1), %%ymm0\n\t"
     848              :                          "vmovups %%ymm0, (%0)\n\t"
     849              :                          :
     850       783485 :                          : "r"(&Y[i]), "r"(&X[i])
     851              :                          : "ymm0", "memory");
     852              : #endif
     853              :   }
     854       499415 :   for (unsigned int i = N8; i < N; ++i) {
     855       325435 :     Y[i] = X[i];
     856              :   }
     857       173980 : }
     858              : 
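                      : // Transpose in 8x8 tiles: vindex = {0, ld_src, ..., 7 * ld_src}, so each
                      : // _mm256_i32gather_ps (scale 4 bytes) reads one column of a tile, which is
                      : // then stored as a contiguous row of dst.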
     859           25 : void transpose_matrix(const unsigned int M, const unsigned int N,
     860              :                       const float *src, unsigned int ld_src, float *dst,
     861              :                       unsigned int ld_dst) {
     862           25 :   unsigned int vindexm[8] = {0,          ld_src,     ld_src * 2, ld_src * 3,
     863           25 :                              ld_src * 4, ld_src * 5, ld_src * 6, ld_src * 7};
     864              :   __m256i vindex = _mm256_loadu_si256((__m256i *)&vindexm[0]);
     865              :   __m256 vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8;
     866              : 
     867           25 :   unsigned int M8 = (M & ~(7));
     868           25 :   unsigned int N8 = (N & ~(7));
     869         1378 :   for (unsigned int i = 0; i < M8; i += 8) {
     870       308626 :     for (unsigned int j = 0; j < N8; j += 8) {
     871              :       // loading from columns
     872       307273 :       vec1 = _mm256_i32gather_ps(&src[ld_src * i + j + 0], vindex, 4);
     873       307273 :       vec2 = _mm256_i32gather_ps(&src[ld_src * i + j + 1], vindex, 4);
     874       307273 :       vec3 = _mm256_i32gather_ps(&src[ld_src * i + j + 2], vindex, 4);
     875       307273 :       vec4 = _mm256_i32gather_ps(&src[ld_src * i + j + 3], vindex, 4);
     876       307273 :       vec5 = _mm256_i32gather_ps(&src[ld_src * i + j + 4], vindex, 4);
     877       307273 :       vec6 = _mm256_i32gather_ps(&src[ld_src * i + j + 5], vindex, 4);
     878       307273 :       vec7 = _mm256_i32gather_ps(&src[ld_src * i + j + 6], vindex, 4);
     879       307273 :       vec8 = _mm256_i32gather_ps(&src[ld_src * i + j + 7], vindex, 4);
     880              : 
     881              :       // storing to the rows
     882       307273 :       _mm256_storeu_ps(&dst[(j + 0) * ld_dst + i], vec1);
     883       307273 :       _mm256_storeu_ps(&dst[(j + 1) * ld_dst + i], vec2);
     884       307273 :       _mm256_storeu_ps(&dst[(j + 2) * ld_dst + i], vec3);
     885       307273 :       _mm256_storeu_ps(&dst[(j + 3) * ld_dst + i], vec4);
     886       307273 :       _mm256_storeu_ps(&dst[(j + 4) * ld_dst + i], vec5);
     887       307273 :       _mm256_storeu_ps(&dst[(j + 5) * ld_dst + i], vec6);
     888       307273 :       _mm256_storeu_ps(&dst[(j + 6) * ld_dst + i], vec7);
     889       307273 :       _mm256_storeu_ps(&dst[(j + 7) * ld_dst + i], vec8);
     890              :     }
     891              :   }
     892              : 
      893              :   // trailing columns on the right (j >= N8)
     894        10910 :   for (unsigned int i = 0; i < M; i++) {
     895        12316 :     for (unsigned int j = N8; j < N; j++) {
     896         1431 :       dst[i + j * ld_dst] = src[i * ld_src + j];
     897              :     }
     898              :   }
     899              : 
      900              :   // trailing rows at the bottom (i >= M8)
     901           86 :   for (unsigned int i = M8; i < M; i++) {
     902          356 :     for (unsigned int j = 0; j < N; j++) {
     903          295 :       dst[i + j * ld_dst] = src[i * ld_src + j];
     904              :     }
     905              :   }
     906           25 : }
     907              : 
     908            0 : void swiglu(const unsigned int N, float *X, const float *Y, const float *Z) {
     909              :   size_t i = 0;
     910              : 
     911              :   const auto oldcsr = _mm_getcsr();
     912            0 :   _mm_setcsr(oldcsr | 0x8040); // DAZ | FTZ
     913              : 
     914              :   // 16-wide blocks
     915            0 :   for (; i + 16 <= N; i += 16) {
     916            0 :     const __m256 y0 = _mm256_loadu_ps(Y + i);
     917            0 :     const __m256 y1 = _mm256_loadu_ps(Y + i + 8);
     918            0 :     const __m256 z0 = _mm256_loadu_ps(Z + i);
     919            0 :     const __m256 z1 = _mm256_loadu_ps(Z + i + 8);
     920              : 
     921            0 :     _mm256_storeu_ps(X + i, avx2_approx_swiglu(y0, z0));
     922            0 :     _mm256_storeu_ps(X + i + 8, avx2_approx_swiglu(y1, z1));
     923              :   }
     924              : 
     925              :   // One 8-wide block if available
     926            0 :   if (i + 8 <= N) {
     927            0 :     const __m256 y0 = _mm256_loadu_ps(Y + i);
     928            0 :     const __m256 z0 = _mm256_loadu_ps(Z + i);
     929            0 :     _mm256_storeu_ps(X + i, avx2_approx_swiglu(y0, z0));
     930              :     i += 8;
     931              :   }
     932              : 
     933              :   // Remaining 1..7 elements via maskload/maskstore
     934            0 :   if (i < N) {
     935            0 :     const int remain = static_cast<int>(N - i); // 1..7
     936              : 
     937            0 :     alignas(64) const int mtab[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
     938              :                                       0,  0,  0,  0,  0,  0,  0,  0};
      939              :     // Offset into mtab so the first 'remain' lanes are -1 (active), rest 0.
     940            0 :     const int off = 8 - remain; // in [1..7], or 0 if remain==8
     941            0 :     const __m256i vmask = _mm256_loadu_si256((const __m256i *)(mtab + off));
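                      :     // e.g. remain = 3 -> off = 5 -> vmask = {-1,-1,-1, 0, 0, 0, 0, 0}:
                      :     // exactly the first three lanes are loaded and stored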
     942              : 
     943            0 :     const __m256 y = _mm256_maskload_ps(Y + i, vmask);
     944            0 :     const __m256 z = _mm256_maskload_ps(Z + i, vmask);
     945              :     const __m256 r = avx2_approx_swiglu(y, z);
     946            0 :     _mm256_maskstore_ps(X + i, vmask, r);
     947              :   }
     948              : 
     949              :   _mm_setcsr(oldcsr);
     950            0 : }
     951              : 
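// --- Sketch of the tail-masking trick used above (illustrative only; the
// helper name `make_tail_mask` is hypothetical and not part of this file).
// Loading from a table of eight -1s followed by eight 0s at offset
// (8 - remain) yields a mask whose first `remain` lanes are active, so
// maskload/maskstore touch exactly `remain` floats and never read or
// write past the end of the buffer:
static inline __m256i make_tail_mask(int remain /* 1..7 */) {
  alignas(64) static const int mtab[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
                                           0,  0,  0,  0,  0,  0,  0,  0};
  return _mm256_loadu_si256((const __m256i *)(mtab + (8 - remain)));
}
// The MXCSR bits OR-ed in above (0x8040) are FTZ (bit 15, flush tiny
// results to zero) and DAZ (bit 6, treat denormal inputs as zero); the
// saved control word is restored before returning.
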
     952            0 : void swiglu(const unsigned int N, float *X, const float *Y, const float *Z,
     953              :             float alpha) {
     954              :   size_t i = 0;
     955              : 
     956              :   const auto oldcsr = _mm_getcsr();
     957            0 :   _mm_setcsr(oldcsr | 0x8040); // DAZ | FTZ
     958              : 
     959              :   const __m256 alpha_vec = _mm256_set1_ps(alpha);
     960              : 
     961              :   // 16-wide blocks
     962            0 :   for (; i + 16 <= N; i += 16) {
     963            0 :     const __m256 y0 = _mm256_loadu_ps(Y + i);
     964            0 :     const __m256 y1 = _mm256_loadu_ps(Y + i + 8);
     965            0 :     const __m256 z0 = _mm256_loadu_ps(Z + i);
     966            0 :     const __m256 z1 = _mm256_loadu_ps(Z + i + 8);
     967              : 
     968            0 :     _mm256_storeu_ps(X + i, avx2_approx_swiglu_alpha(y0, z0, alpha_vec));
     969            0 :     _mm256_storeu_ps(X + i + 8, avx2_approx_swiglu_alpha(y1, z1, alpha_vec));
     970              :   }
     971              : 
     972              :   // One 8-wide block if present
     973            0 :   if (i + 8 <= N) {
     974            0 :     const __m256 y0 = _mm256_loadu_ps(Y + i);
     975            0 :     const __m256 z0 = _mm256_loadu_ps(Z + i);
     976            0 :     _mm256_storeu_ps(X + i, avx2_approx_swiglu_alpha(y0, z0, alpha_vec));
     977              :     i += 8;
     978              :   }
     979              : 
     980              :   // Remaining 1..7 elements via masked AVX (no stray stores)
     981            0 :   if (i < N) {
     982            0 :     const int remain = static_cast<int>(N - i); // 1..7
     983              : 
     984            0 :     alignas(64) const int mtab[16] = {
     985              :       -1, -1, -1, -1, -1, -1, -1, -1, // ones
     986              :       0,  0,  0,  0,  0,  0,  0,  0   // zeros
     987              :     };
     988            0 :     const int off = 8 - remain; // choose first `remain` lanes active
     989            0 :     const __m256i vmask = _mm256_loadu_si256((const __m256i *)(mtab + off));
     990              : 
     991            0 :     const __m256 y = _mm256_maskload_ps(Y + i, vmask);
     992            0 :     const __m256 z = _mm256_maskload_ps(Z + i, vmask);
     993              :     const __m256 r = avx2_approx_swiglu_alpha(y, z, alpha_vec);
     994            0 :     _mm256_maskstore_ps(X + i, vmask, r);
     995              :   }
     996              : 
     997              :   _mm_setcsr(oldcsr);
     998            0 : }
     999              : 
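// Scalar reference for the two kernels above, assuming the approx helpers
// implement SiLU-gated multiplication (a sketch only; the exact polynomial
// approximation lives in avx2_approx_swiglu{_alpha}):
//   X[i] = Z[i] * Y[i] / (1.0f + std::exp(-Y[i]));          // swiglu
//   X[i] = Z[i] * Y[i] / (1.0f + std::exp(-alpha * Y[i]));  // swiglu(alpha)
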
    1000        29926 : void ele_mul(const unsigned int N, const float *X, const float *Y, float *Z,
    1001              :              float alpha, float beta, unsigned int i_stride,
    1002              :              unsigned int o_stride) {
    1003        29926 :   if (alpha == 1.0f && beta == 0.0f && o_stride == 1) {
    1004        29581 :     unsigned int N8 = (N & ~(7));
    1005        29581 :     if (i_stride == 0) {
    1006         9591 :       float vy8[8] = {Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0]};
    1007              :       auto y = _mm256_loadu_ps(&vy8[0]);
    1008        18108 :       for (unsigned int i = 0; i < N8; i += 8) {
    1009              :         auto x = _mm256_loadu_ps(X);
    1010              :         auto z = _mm256_mul_ps(x, y);
    1011              :         _mm256_storeu_ps(Z, z);
    1012         8517 :         X += 8;
    1013         8517 :         Y += i_stride * 8;
    1014         8517 :         Z += 8;
    1015              :       }
    1016              :     } else {
    1017     13214695 :       for (unsigned int i = 0; i < N8; i += 8) {
    1018              :         auto x = _mm256_loadu_ps(X);
    1019              :         auto y = _mm256_loadu_ps(Y);
    1020              :         auto z = _mm256_mul_ps(x, y);
    1021              :         _mm256_storeu_ps(Z, z);
    1022     13194705 :         X += 8;
    1023     13194705 :         Y += i_stride * 8;
    1024     13194705 :         Z += 8;
    1025              :       }
    1026              :     }
    1027       138986 :     for (unsigned int i = N8; i < N; ++i) {
    1028       109405 :       *Z = *X * *Y;
    1029       109405 :       X++;
    1030       109405 :       Y += i_stride;
    1031       109405 :       Z++;
    1032              :     }
    1033              :   } else {
     1034              :     // TODO: add an AVX2 path here if this branch proves performance-critical
    1035        65505 :     for (unsigned int i = 0; i < N; ++i) {
    1036        65160 :       *Z = *X * alpha * *Y + ((0.0f == beta) ? 0.0f : beta * *Z);
    1037        65160 :       X += o_stride;
    1038        65160 :       Y += i_stride;
    1039        65160 :       Z += o_stride;
    1040              :     }
    1041              :   }
    1042        29926 : }
    1043              : 
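// Note on the i_stride == 0 branch of ele_mul above: Y then points at a
// single scalar broadcast to all lanes, so the eight-element stack array
// is equivalent to the more direct
//   __m256 y = _mm256_set1_ps(Y[0]); // broadcast the scalar operand
// and the `Y += i_stride * 8` update degenerates to a no-op.
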
    1044       113033 : void ele_add(const unsigned int N, const float *X, const float *Y, float *Z,
    1045              :              float alpha, float beta, unsigned int i_stride,
    1046              :              unsigned int o_stride) {
    1047       113033 :   if (alpha == 1.0f && beta == 0.0f && o_stride == 1) {
    1048        75820 :     unsigned int N8 = (N & ~(7));
    1049        75820 :     if (i_stride == 0) {
    1050        26375 :       float vy8[8] = {Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0]};
    1051              :       auto y = _mm256_loadu_ps(&vy8[0]);
    1052       580391 :       for (unsigned int i = 0; i < N8; i += 8) {
    1053              :         auto x = _mm256_loadu_ps(X);
    1054              :         auto z = _mm256_add_ps(x, y);
    1055              :         _mm256_storeu_ps(Z, z);
    1056       554016 :         X += 8;
    1057       554016 :         Y += i_stride * 8;
    1058       554016 :         Z += 8;
    1059              :       }
    1060              :     } else {
    1061       169478 :       for (unsigned int i = 0; i < N8; i += 8) {
    1062              :         auto x = _mm256_loadu_ps(X);
    1063              :         auto y = _mm256_loadu_ps(Y);
    1064              :         auto z = _mm256_add_ps(x, y);
    1065              :         _mm256_storeu_ps(Z, z);
    1066       120033 :         X += 8;
    1067       120033 :         Y += i_stride * 8;
    1068       120033 :         Z += 8;
    1069              :       }
    1070              :     }
    1071       234632 :     for (unsigned int i = N8; i < N; ++i) {
    1072       158812 :       *Z = *X + *Y;
    1073       158812 :       X++;
    1074       158812 :       Y += i_stride;
    1075       158812 :       Z++;
    1076              :     }
    1077              :   } else {
     1078              :     // TODO: add an AVX2 path here if this branch proves performance-critical
    1079    203432843 :     for (unsigned int i = 0; i < N; ++i) {
    1080    203395630 :       *Z = *X + alpha * *Y + ((0.0f == beta) ? 0.0f : beta * *Z);
    1081    203395630 :       X += o_stride;
    1082    203395630 :       Y += i_stride;
    1083    203395630 :       Z += o_stride;
    1084              :     }
    1085              :   }
    1086       113033 : }
    1087              : 
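// Scalar reference of the generic strided fallback shared by ele_add and
// ele_mul above (illustrative; `ele_add_ref` is a hypothetical name):
static void ele_add_ref(unsigned int N, const float *X, const float *Y,
                        float *Z, float alpha, float beta,
                        unsigned int i_stride, unsigned int o_stride) {
  for (unsigned int i = 0; i < N; ++i)
    Z[i * o_stride] = X[i * o_stride] + alpha * Y[i * i_stride] +
                      (beta == 0.0f ? 0.0f : beta * Z[i * o_stride]);
}
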
    1088            9 : static inline __m256 exp256_ps(__m256 x) {
    1089              :   /*  Low-Precision Version I*/
    1090              :   // const __m256 c1 = _mm256_set1_ps(12102203.0f);
    1091              :   // const __m256 c2 = _mm256_set1_ps(1065353216.0f);
    1092              :   // __m256 fx = _mm256_add_ps(_mm256_mul_ps(x, c1),c2);
    1093              :   // return _mm256_castsi256_ps(_mm256_cvtps_epi32(fx));
    1094              : 
    1095              :   /* Low-Precision Version II*/
    1096              :   /*    const __m256 ln2 = _mm256_set1_ps(0.69314718056f);
    1097              :     const __m256 inv_ln2 = _mm256_set1_ps(1.44269504089f); // 1 / ln(2)
    1098              : 
    1099              :     // Range reduction: x = n * ln2 + r,  where n is integer and |r| <= ln2/2
    1100              :     __m256 fx = _mm256_mul_ps(x, inv_ln2);
    1101              :     fx = _mm256_round_ps(fx, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    1102              :     __m256i emm0 = _mm256_cvtps_epi32(fx);
    1103              : 
    1104              :     __m256 tmp = _mm256_mul_ps(fx, ln2);
    1105              :     __m256 r = _mm256_sub_ps(x, tmp);
    1106              : 
    1107              :     // Compute polynomial approximation of exp(r)
    1108              :     const __m256 c1 = _mm256_set1_ps(1.9875691500E-4f);
    1109              :     const __m256 c2 = _mm256_set1_ps(1.3981999507E-3f);
    1110              :     const __m256 c3 = _mm256_set1_ps(8.3334519073E-3f);
    1111              :     const __m256 c4 = _mm256_set1_ps(4.1665795894E-2f);
    1112              :     const __m256 c5 = _mm256_set1_ps(1.6666665459E-1f);
    1113              :     const __m256 c6 = _mm256_set1_ps(5.0000001201E-1f);
    1114              : 
    1115              :   //  __m256 r2 = _mm256_mul_ps(r, r);
    1116              :   //  __m256 r3 = _mm256_mul_ps(r2, r);
    1117              :   //  __m256 r4 = _mm256_mul_ps(r2, r2);
    1118              : 
    1119              :     __m256 y = _mm256_fmadd_ps(c1, r, c2);
    1120              :     y = _mm256_fmadd_ps(y, r, c3);
    1121              :     y = _mm256_fmadd_ps(y, r, c4);
    1122              :     y = _mm256_fmadd_ps(y, r, c5);
    1123              :     y = _mm256_fmadd_ps(y, r, c6);
    1124              :     y = _mm256_fmadd_ps(y, r, _mm256_set1_ps(1.0f));
    1125              : 
    1126              :     // Reconstruct exp(x) = 2^n * exp(r)
    1127              :     emm0 = _mm256_add_epi32(emm0, _mm256_set1_epi32(127));
    1128              :     emm0 = _mm256_slli_epi32(emm0, 23);
    1129              :     __m256 pow2n = _mm256_castsi256_ps(emm0);
    1130              : 
    1131              :     return _mm256_mul_ps(y, pow2n);
    1132              :   */
     1133              :   /* Low-Precision Version III */
    1134              :   const __m256 LOG2EF = _mm256_set1_ps(1.44269504088896341f); // 1 / ln(2)
    1135              :   const __m256 LN2 = _mm256_set1_ps(0.6931471805599453f);     // ln(2)
    1136              : 
    1137              :   // Clamp input to range to prevent overflow/underflow
     1138              :   const __m256 max_x = _mm256_set1_ps(88.3762626647949f);  // ~127.5*ln2
     1139              :   const __m256 min_x = _mm256_set1_ps(-88.3762626647949f); // symmetric clamp
    1140              :   x = _mm256_max_ps(min_x, _mm256_min_ps(max_x, x));
    1141              : 
    1142              :   // Range reduction: x = n * ln2 + r
    1143              :   __m256 fx = _mm256_mul_ps(x, LOG2EF); // x * (1/ln(2))
    1144              :   fx = _mm256_floor_ps(_mm256_add_ps(fx, _mm256_set1_ps(0.5f)));
    1145              : 
    1146              :   __m256 tmp = _mm256_mul_ps(fx, LN2); // n * ln(2)
    1147              :   __m256 r = _mm256_sub_ps(x, tmp);    // r = x - n * ln2
    1148              : 
    1149              :   // Compute exp(r) using 10th-order polynomial (Horner's method)
    1150              :   const __m256 c0 = _mm256_set1_ps(1.0f);
    1151              :   const __m256 c1 = _mm256_set1_ps(1.0f);
    1152              :   const __m256 c2 = _mm256_set1_ps(0.5f);
    1153              :   const __m256 c3 = _mm256_set1_ps(1.0f / 6.0f);
    1154              :   const __m256 c4 = _mm256_set1_ps(1.0f / 24.0f);
    1155              :   const __m256 c5 = _mm256_set1_ps(1.0f / 120.0f);
    1156              :   const __m256 c6 = _mm256_set1_ps(1.0f / 720.0f);
    1157              :   const __m256 c7 = _mm256_set1_ps(1.0f / 5040.0f);
    1158              :   const __m256 c8 = _mm256_set1_ps(1.0f / 40320.0f);
    1159              :   const __m256 c9 = _mm256_set1_ps(1.0f / 362880.0f);
    1160              :   const __m256 c10 = _mm256_set1_ps(1.0f / 3628800.0f);
    1161              : 
    1162              :   __m256 y = c10;
    1163              :   y = _mm256_fmadd_ps(y, r, c9);
    1164              :   y = _mm256_fmadd_ps(y, r, c8);
    1165              :   y = _mm256_fmadd_ps(y, r, c7);
    1166              :   y = _mm256_fmadd_ps(y, r, c6);
    1167              :   y = _mm256_fmadd_ps(y, r, c5);
    1168              :   y = _mm256_fmadd_ps(y, r, c4);
    1169              :   y = _mm256_fmadd_ps(y, r, c3);
    1170              :   y = _mm256_fmadd_ps(y, r, c2);
    1171              :   y = _mm256_fmadd_ps(y, r, c1);
    1172              :   y = _mm256_fmadd_ps(y, r, c0); // final y = (((...r+...)*r+...)*r + 1)
    1173              : 
    1174              :   // Reconstruct exp(x) = 2^n * exp(r)
    1175              :   __m256i emm0 = _mm256_cvtps_epi32(fx);
    1176              :   emm0 = _mm256_add_epi32(emm0, _mm256_set1_epi32(127));
    1177              :   emm0 = _mm256_slli_epi32(emm0, 23);
    1178              :   __m256 pow2n = _mm256_castsi256_ps(emm0);
    1179              : 
    1180            9 :   return _mm256_mul_ps(y, pow2n);
    1181              : }
    1182              : 
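// Scalar model of the scheme implemented above (a sketch; `exp_ref` is an
// illustrative name and uses <cmath>, already included). Clamp, split
// x = n*ln2 + r with |r| <= ln2/2, evaluate the degree-10 Taylor
// polynomial of exp(r) by Horner, then scale by 2^n:
static float exp_ref(float x) {
  x = std::fmin(88.3762626647949f, std::fmax(-88.3762626647949f, x));
  const float n = std::floor(x * 1.44269504088896341f + 0.5f);
  const float r = x - n * 0.6931471805599453f;
  float y = 1.0f / 3628800.0f; // c10 = 1/10!
  const float c[] = {1.0f / 362880.0f, 1.0f / 40320.0f, 1.0f / 5040.0f,
                     1.0f / 720.0f,    1.0f / 120.0f,   1.0f / 24.0f,
                     1.0f / 6.0f,      0.5f,            1.0f, 1.0f};
  for (float k : c)
    y = y * r + k;                             // Horner evaluation, c10..c0
  return std::ldexp(y, static_cast<int>(n));   // exp(x) = 2^n * exp(r)
}
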
    1183            1 : static void softmax_row_inplace(float *qk_out, size_t start_row, size_t end_row,
    1184              :                                 size_t num_heads) {
    1185            1 :   size_t row_range = end_row - start_row;
    1186            1 :   const size_t full_blocks = (num_heads / 8) * 8;
    1187              :   // const size_t remainder = num_heads % 8;
    1188              : 
    1189            1 :   float *max_vals = new float[num_heads];
    1190            1 :   float *sum_vals = new float[num_heads];
    1191              :   // 1. max
    1192           11 :   for (size_t c = 0; c < num_heads; ++c) {
    1193           10 :     float max_val = -INFINITY;
    1194           40 :     for (size_t r = start_row; r < end_row; ++r)
    1195           49 :       max_val = std::max(max_val, qk_out[r * num_heads + c]);
    1196           10 :     max_vals[c] = max_val;
    1197              :   }
    1198              : 
    1199              :   // 2. inplace exp + sum
    1200            2 :   for (size_t c = 0; c < full_blocks; c += 8) {
    1201            1 :     __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1202              :     __m256 sum = _mm256_setzero_ps();
    1203            4 :     for (size_t r = 0; r < row_range; ++r) {
    1204            3 :       float *ptr = &qk_out[(start_row + r) * num_heads + c];
    1205              :       __m256 val = _mm256_loadu_ps(ptr);
    1206            3 :       __m256 e = exp256_ps(_mm256_sub_ps(val, maxv));
    1207              :       _mm256_storeu_ps(ptr, e); // overwrite qk_out
    1208              :       sum = _mm256_add_ps(sum, e);
    1209              :     }
    1210            1 :     _mm256_storeu_ps(&sum_vals[c], sum);
    1211              :   }
    1212              : 
    1213            3 :   for (size_t c = full_blocks; c < num_heads; ++c) {
    1214              :     float sum = 0.0f;
    1215            2 :     float maxv = max_vals[c];
    1216            8 :     for (size_t r = 0; r < row_range; ++r) {
    1217            6 :       float &a = qk_out[(start_row + r) * num_heads + c];
    1218            6 :       a = std::exp(a - maxv); // overwrite qk_out
    1219            6 :       sum += a;
    1220              :     }
    1221            2 :     sum_vals[c] = sum;
    1222              :   }
    1223              :   // 3. softmax = exp / sum (inplace)
    1224            4 :   for (size_t r = 0; r < row_range; ++r) {
    1225            6 :     for (size_t c = 0; c < full_blocks; c += 8) {
    1226            3 :       float *ptr = &qk_out[(start_row + r) * num_heads + c];
    1227              :       __m256 val = _mm256_loadu_ps(ptr); // already exp(x - max)
    1228            3 :       __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
    1229              :       __m256 soft = _mm256_div_ps(val, sumv);
    1230              :       _mm256_storeu_ps(ptr, soft);
    1231              :     }
    1232            9 :     for (size_t c = full_blocks; c < num_heads; ++c) {
    1233            6 :       qk_out[(start_row + r) * num_heads + c] /= sum_vals[c];
    1234              :     }
    1235              :   }
    1236              : 
    1237            1 :   delete[] max_vals;
    1238            1 :   delete[] sum_vals;
    1239            1 : }
    1240              : 
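// Numerical-stability note for the column-wise softmax above: softmax is
// shift-invariant,
//   softmax(x)_r = exp(x_r - m) / sum_k exp(x_k - m)   for any constant m,
// and choosing m = max_k x_k keeps every exponent <= 0 so exp() cannot
// overflow; the per-column max/sum buffers let eight columns be processed
// per AVX register.
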
    1241            0 : static void softmax_row_with_sink_inplace(float *qk_out, size_t start_row,
    1242              :                                           size_t end_row, size_t num_heads,
    1243              :                                           float *sink) {
    1244            0 :   size_t row_range = end_row - start_row;
    1245            0 :   const size_t full_blocks = (num_heads / 8) * 8;
    1246              : 
    1247            0 :   float *max_vals = new float[num_heads];
    1248            0 :   float *sum_vals = new float[num_heads];
    1249              :   // 1. max
    1250            0 :   for (size_t c = 0; c < num_heads; ++c) {
    1251            0 :     float max_val = -INFINITY;
    1252            0 :     for (size_t r = start_row; r < end_row; ++r)
    1253            0 :       max_val = std::max(max_val, qk_out[r * num_heads + c]);
    1254            0 :     max_vals[c] = std::max(sink[c], max_val);
    1255              :   }
    1256              : 
    1257              :   // 2. inplace exp + sum
    1258            0 :   for (size_t c = 0; c < full_blocks; c += 8) {
    1259            0 :     __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1260            0 :     __m256 sum = _mm256_loadu_ps(&sink[c]);
    1261            0 :     sum = exp256_ps(_mm256_sub_ps(sum, maxv));
    1262            0 :     for (size_t r = 0; r < row_range; ++r) {
    1263            0 :       float *ptr = &qk_out[(start_row + r) * num_heads + c];
    1264              :       __m256 val = _mm256_loadu_ps(ptr);
    1265            0 :       __m256 e = exp256_ps(_mm256_sub_ps(val, maxv));
    1266              :       _mm256_storeu_ps(ptr, e); // overwrite qk_out
    1267              :       sum = _mm256_add_ps(sum, e);
    1268              :     }
    1269            0 :     _mm256_storeu_ps(&sum_vals[c], sum);
    1270              :   }
    1271              : 
    1272            0 :   for (size_t c = full_blocks; c < num_heads; ++c) {
    1273            0 :     float maxv = max_vals[c];
    1274            0 :     float sum = std::exp(sink[c] - maxv);
    1275            0 :     for (size_t r = 0; r < row_range; ++r) {
    1276            0 :       float &a = qk_out[(start_row + r) * num_heads + c];
    1277            0 :       a = std::exp(a - maxv); // overwrite qk_out
    1278            0 :       sum += a;
    1279              :     }
    1280            0 :     sum_vals[c] = sum;
    1281              :   }
    1282              :   // 3. softmax = exp / sum (inplace)
    1283            0 :   for (size_t r = 0; r < row_range; ++r) {
    1284            0 :     for (size_t c = 0; c < full_blocks; c += 8) {
    1285            0 :       float *ptr = &qk_out[(start_row + r) * num_heads + c];
    1286              :       __m256 val = _mm256_loadu_ps(ptr); // already exp(x - max)
    1287            0 :       __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
    1288              :       __m256 soft = _mm256_div_ps(val, sumv);
    1289              :       _mm256_storeu_ps(ptr, soft);
    1290              :     }
    1291            0 :     for (size_t c = full_blocks; c < num_heads; ++c) {
    1292            0 :       qk_out[(start_row + r) * num_heads + c] /= sum_vals[c];
    1293              :     }
    1294              :   }
    1295              : 
    1296            0 :   delete[] max_vals;
    1297            0 :   delete[] sum_vals;
    1298            0 : }
    1299              : 
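// The "sink" variant above folds one extra logit per column into the
// normalizer but never stores it, i.e. per column c:
//   denom_c = exp(sink[c] - m_c) + sum_r exp(qk[r][c] - m_c),
//   with m_c = max(sink[c], max_r qk[r][c]),
// so each column's stored probabilities sum to less than 1, the remainder
// being absorbed by the sink term.
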
    1300              : template <>
    1301            1 : void softmax_row_inplace(float *qk_out, size_t start_row, size_t end_row,
    1302              :                          size_t num_heads, float *sink) {
    1303            1 :   if (sink == nullptr) {
    1304            1 :     return softmax_row_inplace(qk_out, start_row, end_row, num_heads);
    1305              :   } else {
    1306            0 :     return softmax_row_with_sink_inplace(qk_out, start_row, end_row, num_heads,
    1307            0 :                                          sink);
    1308              :   }
    1309              : }
    1310              : 
    1311            1 : static void softmax_row(float *qk_out, size_t start_row, size_t end_row,
    1312              :                         size_t num_heads) {
    1313            1 :   const size_t full_block = (num_heads / 8) * 8;
    1314              : 
    1315            1 :   float *max_vals = new float[num_heads];
    1316            1 :   float *sum_vals = new float[num_heads];
    1317              : 
     1318              :   // 1. Find max along each column
    1319           11 :   for (size_t c = 0; c < num_heads; ++c) {
    1320           10 :     float max_val = -INFINITY;
    1321           40 :     for (size_t r = start_row; r < end_row; ++r) {
    1322           49 :       max_val = std::max(max_val, qk_out[r * num_heads + c]);
    1323              :     }
    1324           10 :     max_vals[c] = max_val;
    1325              :   }
    1326              : 
     1327              :   // 2. Compute exp-sum along each column (vectorized exp)
    1328            2 :   for (size_t c = 0; c < full_block; c += 8) {
    1329              :     __m256 sum = _mm256_setzero_ps();
    1330            4 :     for (size_t r = start_row; r < end_row; ++r) {
    1331            3 :       __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
    1332            3 :       __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1333              :       __m256 sub = _mm256_sub_ps(val, maxv);
    1334            3 :       __m256 e = exp256_ps(sub);
    1335              :       sum = _mm256_add_ps(sum, e);
    1336              :     }
    1337            1 :     _mm256_storeu_ps(&sum_vals[c], sum);
    1338              :   }
    1339              : 
    1340            3 :   for (size_t c = full_block; c < num_heads; ++c) {
    1341              :     float sum = 0.0f;
    1342            8 :     for (size_t r = start_row; r < end_row; ++r) {
    1343            6 :       sum += std::exp(qk_out[r * num_heads + c] - max_vals[c]);
    1344              :     }
    1345            2 :     sum_vals[c] = sum;
    1346              :   }
    1347              : 
    1348              :   // 3. apply softmax
    1349            4 :   for (size_t r = start_row; r < end_row; ++r) {
    1350            6 :     for (size_t c = 0; c < full_block; c += 8) {
    1351            3 :       __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
    1352            3 :       __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1353              :       __m256 sub = _mm256_sub_ps(val, maxv);
    1354            3 :       __m256 e = exp256_ps(sub);
    1355            3 :       __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
    1356              :       __m256 softmax = _mm256_div_ps(e, sumv);
    1357              :       _mm256_storeu_ps(&qk_out[r * num_heads + c], softmax);
    1358              :     }
    1359            9 :     for (size_t c = full_block; c < num_heads; ++c) {
    1360            6 :       qk_out[r * num_heads + c] =
    1361            6 :         std::exp(qk_out[r * num_heads + c] - max_vals[c]) / sum_vals[c];
    1362              :     }
    1363              :   }
    1364              : 
    1365            1 :   delete[] max_vals;
    1366            1 :   delete[] sum_vals;
    1367            1 : }
    1368              : 
    1369            0 : static void softmax_row_with_sink(float *qk_out, size_t start_row,
    1370              :                                   size_t end_row, size_t num_heads,
    1371              :                                   float *sink) {
    1372            0 :   const size_t full_block = (num_heads / 8) * 8;
    1373              : 
    1374            0 :   float *max_vals = new float[num_heads];
    1375            0 :   float *sum_vals = new float[num_heads];
    1376              : 
     1377              :   // 1. Find max along each column
    1378            0 :   for (size_t c = 0; c < num_heads; ++c) {
    1379            0 :     float max_val = -INFINITY;
    1380            0 :     for (size_t r = start_row; r < end_row; ++r) {
    1381            0 :       max_val = std::max(max_val, qk_out[r * num_heads + c]);
    1382              :     }
    1383            0 :     max_vals[c] = std::max(max_val, sink[c]);
    1384              :   }
    1385              : 
     1386              :   // 2. Compute exp-sum along each column (vectorized exp)
    1387            0 :   for (size_t c = 0; c < full_block; c += 8) {
    1388            0 :     __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1389            0 :     __m256 sum = _mm256_loadu_ps(&sink[c]);
    1390              :     sum = _mm256_sub_ps(sum, maxv);
    1391            0 :     sum = exp256_ps(sum);
    1392            0 :     for (size_t r = start_row; r < end_row; ++r) {
    1393            0 :       __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
    1394              :       __m256 sub = _mm256_sub_ps(val, maxv);
    1395            0 :       __m256 e = exp256_ps(sub);
    1396              :       sum = _mm256_add_ps(sum, e);
    1397              :     }
    1398            0 :     _mm256_storeu_ps(&sum_vals[c], sum);
    1399              :   }
    1400              : 
    1401            0 :   for (size_t c = full_block; c < num_heads; ++c) {
    1402            0 :     float sum = std::exp(sink[c] - max_vals[c]);
    1403            0 :     for (size_t r = start_row; r < end_row; ++r) {
    1404            0 :       sum += std::exp(qk_out[r * num_heads + c] - max_vals[c]);
    1405              :     }
    1406            0 :     sum_vals[c] = sum;
    1407              :   }
    1408              : 
    1409              :   // 3. apply softmax
    1410            0 :   for (size_t r = start_row; r < end_row; ++r) {
    1411            0 :     for (size_t c = 0; c < full_block; c += 8) {
    1412            0 :       __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
    1413            0 :       __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1414              :       __m256 sub = _mm256_sub_ps(val, maxv);
    1415            0 :       __m256 e = exp256_ps(sub);
    1416            0 :       __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
    1417              :       __m256 softmax = _mm256_div_ps(e, sumv);
    1418              :       _mm256_storeu_ps(&qk_out[r * num_heads + c], softmax);
    1419              :     }
    1420            0 :     for (size_t c = full_block; c < num_heads; ++c) {
    1421            0 :       qk_out[r * num_heads + c] =
    1422            0 :         std::exp(qk_out[r * num_heads + c] - max_vals[c]) / sum_vals[c];
    1423              :     }
    1424              :   }
    1425              : 
    1426            0 :   delete[] max_vals;
    1427            0 :   delete[] sum_vals;
    1428            0 : }
    1429              : 
    1430              : template <>
    1431            1 : void softmax_row(float *qk_out, size_t start_row, size_t end_row,
    1432              :                  size_t num_heads, float *sink) {
    1433            1 :   if (sink == nullptr) {
    1434            1 :     return softmax_row(qk_out, start_row, end_row, num_heads);
    1435              :   } else {
    1436            0 :     return softmax_row_with_sink(qk_out, start_row, end_row, num_heads, sink);
    1437              :   }
    1438              : }
    1439              : #ifdef _WIN32
    1440              : #define COMPUTE_FP16_TO_FP32(x)                                                \
    1441              :   _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
    1442              : #define COMPUTE_FP32_TO_FP16(x)                                                \
    1443              :   _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
    1444              : #elif defined(__TIZEN__) && !defined(__F16C__)
    1445              : #define COMPUTE_FP16_TO_FP32(x) nntrainer::compute_fp16_to_fp32(x)
    1446              : #define COMPUTE_FP32_TO_FP16(x) nntrainer::compute_fp32_to_fp16(x)
    1447              : #else
    1448              : #define COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
    1449              : #define COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
    1450              : #endif
    1451              : 
    1452              : static inline __m256 convert_vector_f16_to_f32(__m128i x) {
    1453              : #if defined(__TIZEN__) && !defined(__F16C__)
    1454              :   alignas(32) uint16_t u16_array[8]; // 32-byte aligned storage
    1455              :   alignas(32) float f32_array[8];    // 32-byte aligned storage
    1456              : 
    1457              :   // Safely store __m128i to array (avoids aliasing)
    1458              :   _mm_storeu_si128(reinterpret_cast<__m128i *>(u16_array), x);
    1459              : 
    1460              :   // Convert each FP16 value to FP32
    1461              :   for (int i = 0; i < 8; i++) {
    1462              :     f32_array[i] = COMPUTE_FP16_TO_FP32(u16_array[i]);
    1463              :   }
    1464              : 
    1465              :   // Load aligned array into __m256
    1466              :   return _mm256_load_ps(f32_array);
    1467              : #else
    1468              :   return _mm256_cvtph_ps(x);
    1469              : #endif
    1470              : }
    1471              : 
    1472              : static inline __m128i convert_vector_f32_to_f16(__m256 x) {
    1473              : #if defined(__TIZEN__) && !defined(__F16C__)
    1474              :   __m128i vec_f16;
    1475              :   float *f32_ptr = reinterpret_cast<float *>(&x);
    1476              :   uint16_t *u16_ptr = reinterpret_cast<uint16_t *>(&vec_f16);
    1477              :   for (int i = 0; i < 8; i++) {
    1478              :     u16_ptr[i] = COMPUTE_FP32_TO_FP16(f32_ptr[i]);
    1479              :   }
    1480              :   return vec_f16;
    1481              : #else
    1482              :   return _mm256_cvtps_ph(x, 0);
    1483              : #endif
    1484              : }
    1485              : 
    1486              : static inline __m128i convert_vector_f32_to_f16(__m128 x) {
    1487              : #if defined(__TIZEN__) && !defined(__F16C__)
    1488              :   __m128i vec_f16;
    1489              :   float *f32_ptr = reinterpret_cast<float *>(&x);
    1490              :   uint16_t *u16_ptr = reinterpret_cast<uint16_t *>(&vec_f16);
    1491              : 
    1492              :   for (int i = 0; i < 4; i++) {
    1493              :     u16_ptr[i] = COMPUTE_FP32_TO_FP16(f32_ptr[i]);
    1494              :   }
    1495              :   return vec_f16;
    1496              : #else
    1497              :   return _mm_cvtps_ph(x, 0);
    1498              : #endif
    1499              : }
    1500              : 
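// Hedged usage sketch of the converters above (F16C path; the values are
// placeholders). Rounding immediate 0 in _mm_cvtps_ph/_mm256_cvtps_ph
// selects round-to-nearest-even:
//   alignas(16) uint16_t h[8] = {/* 8 IEEE-754 binary16 values */};
//   __m256 f = convert_vector_f16_to_f32(
//     _mm_loadu_si128(reinterpret_cast<const __m128i *>(h)));
//   __m128i back = convert_vector_f32_to_f16(f); // 8 halfs in 16 bytes
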
    1501           12 : static inline void load_fp16_8_to_chunk(const uint16_t *src, float *dst,
    1502              :                                         int chunk_size) {
    1503              :   int i = 0;
    1504           24 :   for (; i + 8 <= chunk_size; i += 8) {
    1505           12 :     __m128i half = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src + i));
    1506              :     __m256 f32 = convert_vector_f16_to_f32(half);
    1507           12 :     _mm256_storeu_ps(&dst[i], f32);
    1508              :   }
    1509           32 :   for (; i < chunk_size; ++i) {
    1510           20 :     dst[i] = nntrainer::compute_fp16_to_fp32(src[i]);
    1511              :   }
    1512           12 : }
    1513              : 
    1514            1 : void compute_fp16vcache_fp32_transposed(int row_num, const float *in,
    1515              :                                         const uint16_t *vcache, float *output,
    1516              :                                         int num_cache_head, int gqa_size,
    1517              :                                         int head_dim,
    1518              :                                         size_t local_window_size) {
    1519              :   // cpu_set_t cpu_set;
    1520              :   // CPU_ZERO(&cpu_set);
    1521              :   // std::vector<bool> affinity(8, false);
    1522              :   // affinity[i % affinity.size()] = true;
    1523              : 
    1524              :   // for (std::size_t j = 0;
    1525              :   //      j < std::min<std::size_t>(affinity.size(), CPU_SETSIZE); ++j) {
    1526              :   //   if (affinity[j])
    1527              :   //     CPU_SET(j, &cpu_set);
    1528              :   // }
    1529              :   // pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpu_set);
    1530              : 
    1531            1 :   std::vector<float> tmp_fp32(head_dim);
    1532            1 :   int num_blocks = head_dim / 8;
    1533            2 :   __m256 *sumVec = new __m256[std::max(1, num_blocks * gqa_size)];
    1534              : 
    1535            3 :   for (int n = 0; n < num_cache_head; ++n) {
    1536            2 :     int rem = head_dim % 8;
    1537              : 
     1538              :     /* A std::vector<__m256> declaration, i.e.
     1539              :      * std::vector<__m256> sumVec(num_blocks * gqa_size, _mm256_setzero_ps()),
     1540              :      * triggers "ignoring attributes on template argument '__m256'"
     1541              :      * [-Wignored-attributes], so a raw __m256 array is used instead.
     1542              :      */
    1543            6 :     for (int i = 0; i < num_blocks * gqa_size; i++) {
    1544            4 :       sumVec[i] = _mm256_setzero_ps();
    1545              :     }
    1546            2 :     std::vector<float> sumRem((size_t)gqa_size * rem, 0.0f);
    1547              : 
    1548            6 :     for (int j = row_num < local_window_size ? 0
    1549            0 :                                              : row_num + 1 - local_window_size;
    1550            6 :          j <= row_num; ++j) {
    1551            4 :       const uint16_t *vptr = vcache + (j * num_cache_head + n) * head_dim;
    1552            4 :       load_fp16_8_to_chunk(vptr, tmp_fp32.data(), head_dim);
    1553              : 
    1554           12 :       for (int h = 0; h < gqa_size; ++h) {
    1555            8 :         float a_val =
    1556              :           in[(row_num < local_window_size
    1557            8 :                 ? j
    1558            0 :                 : (unsigned long)(j - (row_num + 1 - local_window_size))) *
    1559            8 :                (unsigned long)(gqa_size * num_cache_head) +
    1560            8 :              (unsigned long)(n * gqa_size) + h];
    1561              : 
    1562              :         __m256 inVec = _mm256_set1_ps(a_val);
    1563              : 
    1564           16 :         for (int b = 0; b < num_blocks; ++b) {
    1565            8 :           __m256 bVec = _mm256_loadu_ps(&tmp_fp32[b * 8]);
    1566            8 :           sumVec[h * num_blocks + b] =
    1567            8 :             _mm256_fmadd_ps(inVec, bVec, sumVec[h * num_blocks + b]);
    1568              :         }
    1569              : 
    1570            8 :         float *remPtr = &sumRem.data()[h * rem];
    1571            8 :         int base = num_blocks * 8;
    1572           16 :         for (int r = 0; r < rem; ++r) {
    1573            8 :           remPtr[r] += a_val * tmp_fp32[base + r];
    1574              :         }
    1575              :       }
    1576              :     }
    1577              : 
    1578            6 :     for (int h = 0; h < gqa_size; ++h) {
    1579            8 :       for (int b = 0; b < num_blocks; ++b) {
    1580            4 :         int out_base = (n * gqa_size + h) * head_dim + b * 8;
    1581            4 :         _mm256_storeu_ps(&output[out_base], sumVec[h * num_blocks + b]);
    1582              :       }
    1583              : 
    1584            4 :       float *remPtr = &sumRem.data()[h * rem];
    1585              :       // float *remPtr = &sumRem[h * rem];
    1586            4 :       int base = num_blocks * 8;
    1587            8 :       for (int r = 0; r < rem; ++r) {
    1588            4 :         int out_idx = (n * gqa_size + h) * head_dim + base + r;
    1589            4 :         output[out_idx] = remPtr[r];
    1590              :       }
    1591              :     }
    1592            2 :   }
    1593            1 :   delete[] sumVec;
    1594            1 : }
    1595              : 
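// What the routine above computes, per cache head n and GQA member h
// (a sketch of the math, not additional code): a windowed attention-times-V
// reduction over FP16 value rows,
//   output[(n*gqa_size + h)*head_dim + :] =
//     sum_{j in window} in[j][n*gqa_size + h] * fp32(vcache[j][n][:]),
// where the window is the last local_window_size rows ending at row_num.
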
    1596              : template <>
    1597            1 : void compute_kcaches(const float *in, const uint16_t *kcache, float *output,
    1598              :                      int num_rows, int num_cache_head, int head_dim,
    1599              :                      int gqa_size, int tile_size, size_t local_window_size) {
    1600            1 :   std::vector<float> tmp_fp32(head_dim);
    1601              : 
    1602            1 :   int start_row =
    1603            1 :     num_rows < local_window_size ? 0 : num_rows - local_window_size;
    1604            1 :   int row_cnt = num_rows < local_window_size ? num_rows : local_window_size;
    1605            1 :   const int tile_count = (row_cnt + tile_size - 1) / tile_size;
    1606              : 
    1607            3 :   for (int n = 0; n < num_cache_head; ++n) {
    1608            4 :     for (int t = 0; t < tile_count; ++t) {
    1609            2 :       int row_tile_start = t * tile_size;
    1610            2 :       int tile_rows = std::min(tile_size, row_cnt - row_tile_start);
    1611              : 
    1612           10 :       for (int g = 0; g < gqa_size; ++g) {
    1613            8 :         const float *in_ptr = in + n * gqa_size * head_dim + g * head_dim;
    1614           16 :         for (int t_row = 0; t_row < tile_rows; ++t_row) {
    1615            8 :           int row = start_row + row_tile_start + t_row;
    1616            8 :           if (row + 1 < num_rows) {
    1617            0 :             const uint16_t *next_kptr =
    1618            0 :               kcache + ((row + 1) * num_cache_head + n) * head_dim;
    1619              :             _mm_prefetch(reinterpret_cast<const char *>(next_kptr),
    1620              :                          _MM_HINT_T0);
    1621              :           }
    1622            8 :           const uint16_t *kptr = kcache + (row * num_cache_head + n) * head_dim;
    1623            8 :           load_fp16_8_to_chunk(kptr, tmp_fp32.data(), head_dim);
    1624              : 
    1625              :           const float *k_row = tmp_fp32.data();
    1626              : 
    1627              :           float sum = 0.0f;
    1628              :           int i = 0;
    1629              :           __m256 acc = _mm256_setzero_ps();
    1630           16 :           for (; i + 8 <= head_dim; i += 8) {
    1631            8 :             __m256 va = _mm256_loadu_ps(in_ptr + i);
    1632            8 :             __m256 vb = _mm256_loadu_ps(k_row + i);
    1633              :             acc = _mm256_fmadd_ps(va, vb, acc);
    1634              :           }
    1635              : 
    1636              :           __m128 low = _mm256_castps256_ps128(acc);
    1637              :           __m128 high = _mm256_extractf128_ps(acc, 1);
    1638              :           __m128 sum128 = _mm_add_ps(low, high);
    1639              :           sum128 = _mm_hadd_ps(sum128, sum128);
    1640              :           sum128 = _mm_hadd_ps(sum128, sum128);
    1641            8 :           sum += _mm_cvtss_f32(sum128);
    1642              : 
    1643           24 :           for (; i < head_dim; ++i)
    1644           16 :             sum += in_ptr[i] * k_row[i];
    1645              : 
    1646            8 :           output[(row - start_row) * num_cache_head * gqa_size + n * gqa_size +
    1647            8 :                  g] = sum / sqrt((float)head_dim);
    1648              :         }
    1649              :       }
    1650              :     }
    1651              :   }
    1652            1 : }
    1653              : 
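// The kernel above produces scaled dot-product attention logits: for each
// cached row and GQA query g it computes
//   output[row][n*gqa_size + g] = (q_g . k_row) / sqrt(head_dim),
// with the FP16 key row expanded to FP32 in tmp_fp32 and the dot product
// reduced via an FMA accumulator followed by two horizontal hadd folds.
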
    1654            2 : void compute_rotary_emb_value(unsigned int width, unsigned int dim,
    1655              :                               unsigned int half_, float *inout, void *output,
    1656              :                               const float *cos_, const float *sin_,
    1657              :                               bool only_convert_to_fp16) {
    1658              :   enum class OutputType { FP16, FP32 };
    1659              : 
    1660              :   OutputType out_type = OutputType::FP32;
    1661            2 :   if (output != nullptr)
    1662              :     out_type = OutputType::FP16;
    1663              : 
    1664            6 :   for (unsigned int w = 0; w < width; w += dim) {
    1665              :     unsigned int k = 0;
    1666            8 :     for (; k + 7 < half_; k += 8) {
    1667            4 :       unsigned int i0 = w + k;
    1668            4 :       unsigned int i1 = w + k + half_;
    1669              : 
    1670            4 :       __m256 a = _mm256_loadu_ps(&inout[i0]);
    1671            4 :       __m256 b = _mm256_loadu_ps(&inout[i1]);
    1672              : 
    1673            4 :       if (only_convert_to_fp16) {
    1674            0 :         if (out_type == OutputType::FP16) {
    1675              :           __m128i a_fp16 = convert_vector_f32_to_f16(a);
    1676              :           __m128i b_fp16 = convert_vector_f32_to_f16(b);
    1677              : 
    1678            0 :           _mm_storeu_si128(
    1679            0 :             reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i0),
    1680              :             a_fp16);
    1681            0 :           _mm_storeu_si128(
    1682            0 :             reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i1),
    1683              :             b_fp16);
    1684              :         }
    1685              : 
    1686              :       } else {
    1687            4 :         __m256 cos_v = _mm256_loadu_ps(&cos_[k]);
    1688            4 :         __m256 sin_v = _mm256_loadu_ps(&sin_[k]);
    1689              : 
    1690              :         __m256 out0 =
    1691              :           _mm256_sub_ps(_mm256_mul_ps(a, cos_v), _mm256_mul_ps(b, sin_v));
    1692              :         __m256 out1 =
    1693              :           _mm256_add_ps(_mm256_mul_ps(a, sin_v), _mm256_mul_ps(b, cos_v));
    1694              : 
    1695            4 :         if (out_type == OutputType::FP16) {
    1696              :           __m128i out0_fp16 = convert_vector_f32_to_f16(out0);
    1697              :           __m128i out1_fp16 = convert_vector_f32_to_f16(out1);
    1698              : 
    1699            2 :           _mm_storeu_si128(
    1700            2 :             reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i0),
    1701              :             out0_fp16);
    1702            2 :           _mm_storeu_si128(
    1703            2 :             reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i1),
    1704              :             out1_fp16);
    1705              : 
    1706              :         } else if (out_type == OutputType::FP32) {
    1707              :           _mm256_storeu_ps(&inout[i0], out0);
    1708              :           _mm256_storeu_ps(&inout[i1], out1);
    1709              :         }
    1710              :       }
    1711              :     }
    1712              : 
    1713           12 :     for (; k < half_; ++k) {
    1714            8 :       unsigned int i0 = w + k;
    1715            8 :       unsigned int i1 = w + k + half_;
    1716              :       // assert(i1 < width && "Scalar i1 overflow!");
    1717            8 :       float a = inout[i0];
    1718            8 :       float b = inout[i1];
    1719              : 
    1720            8 :       if (only_convert_to_fp16) {
    1721            0 :         static_cast<uint16_t *>(output)[i0] = COMPUTE_FP32_TO_FP16(a);
    1722            0 :         static_cast<uint16_t *>(output)[i1] = COMPUTE_FP32_TO_FP16(b);
    1723              :       } else {
    1724            8 :         float c = cos_[k];
    1725            8 :         float s = sin_[k];
    1726              : 
    1727            8 :         float out0 = a * c - b * s;
    1728            8 :         float out1 = a * s + b * c;
    1729              : 
    1730            8 :         if (out_type == OutputType::FP16) {
    1731            4 :           static_cast<uint16_t *>(output)[i0] = COMPUTE_FP32_TO_FP16(out0);
    1732            8 :           static_cast<uint16_t *>(output)[i1] = COMPUTE_FP32_TO_FP16(out1);
    1733              :         } else if (out_type == OutputType::FP32) {
    1734            4 :           inout[i0] = out0;
    1735            4 :           inout[i1] = out1;
    1736              :         }
    1737              :       }
    1738              :     }
    1739              :   }
    1740            2 : }
    1741              : 
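// The rotary update above is a 2-D rotation of the lane pair
// (a, b) = (inout[w + k], inout[w + k + half_]) by the per-index angle
// theta_k encoded in cos_/sin_:
//   out0 = a*cos(theta_k) - b*sin(theta_k)
//   out1 = a*sin(theta_k) + b*cos(theta_k)
// i.e. the "rotate-half" RoPE layout, with an optional FP16 store of
// either the rotated or the unmodified values.
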
     1742              : static float hsum_avx(__m256 v) { // horizontal sum of all 8 lanes
     1743              :   __m128 vlow = _mm256_castps256_ps128(v);    // lanes 0..3
     1744              :   __m128 vhigh = _mm256_extractf128_ps(v, 1); // lanes 4..7
     1745              :   vlow = _mm_add_ps(vlow, vhigh);       // fold 256-bit down to 128-bit
     1746              :   __m128 shuf = _mm_movehdup_ps(vlow);  // duplicate odd lanes: (1,1,3,3)
     1747              :   __m128 sums = _mm_add_ps(vlow, shuf); // pairwise sums in lanes 0 and 2
     1748              :   shuf = _mm_movehl_ps(shuf, sums);     // move lane 2 down to lane 0
     1749              :   sums = _mm_add_ss(sums, shuf);        // total in lane 0
     1750              :   return _mm_cvtss_f32(sums);
    1751              : }
    1752              : 
    1753            0 : void rms_norm_wrt_width_fp32_intrinsic(const float *__restrict X,
    1754              :                                        float *__restrict Y, size_t H, size_t W,
    1755              :                                        float epsilon) {
    1756            0 :   for (std::size_t h = 0; h < H; ++h) {
    1757            0 :     const float *rowX = X + h * W;
    1758            0 :     float *rowY = Y + h * W;
    1759              : 
    1760              :     std::size_t i = 0;
    1761              :     __m256 acc0 = _mm256_setzero_ps();
    1762              :     __m256 acc1 = _mm256_setzero_ps();
    1763              :     __m256 acc2 = _mm256_setzero_ps();
    1764              :     __m256 acc3 = _mm256_setzero_ps();
    1765              : 
    1766            0 :     for (; i + 32 <= W; i += 32) {
    1767            0 :       __m256 x0 = _mm256_loadu_ps(rowX + i);
    1768            0 :       __m256 x1 = _mm256_loadu_ps(rowX + i + 8);
    1769            0 :       __m256 x2 = _mm256_loadu_ps(rowX + i + 16);
    1770            0 :       __m256 x3 = _mm256_loadu_ps(rowX + i + 24);
    1771              :       acc0 = _mm256_fmadd_ps(x0, x0, acc0);
    1772              :       acc1 = _mm256_fmadd_ps(x1, x1, acc1);
    1773              :       acc2 = _mm256_fmadd_ps(x2, x2, acc2);
    1774              :       acc3 = _mm256_fmadd_ps(x3, x3, acc3);
    1775              :     }
    1776            0 :     for (; i + 8 <= W; i += 8) {
    1777            0 :       __m256 x = _mm256_loadu_ps(rowX + i);
    1778              :       acc0 = _mm256_fmadd_ps(x, x, acc0);
    1779              :     }
    1780              :     float sumsq =
    1781            0 :       hsum_avx(acc0) + hsum_avx(acc1) + hsum_avx(acc2) + hsum_avx(acc3);
    1782            0 :     for (; i < W; ++i) {
    1783            0 :       float v = rowX[i];
    1784            0 :       sumsq += v * v;
    1785              :     }
    1786              : 
    1787            0 :     float mean = sumsq / static_cast<float>(W);
    1788            0 :     float scale = 1.0f / std::sqrt(mean + epsilon);
    1789              :     __m256 vscale = _mm256_set1_ps(scale);
    1790              : 
    1791              :     i = 0;
    1792            0 :     for (; i + 32 <= W; i += 32) {
    1793            0 :       __m256 x0 = _mm256_loadu_ps(rowX + i);
    1794            0 :       __m256 x1 = _mm256_loadu_ps(rowX + i + 8);
    1795            0 :       __m256 x2 = _mm256_loadu_ps(rowX + i + 16);
    1796            0 :       __m256 x3 = _mm256_loadu_ps(rowX + i + 24);
    1797            0 :       _mm256_storeu_ps(rowY + i, _mm256_mul_ps(x0, vscale));
    1798            0 :       _mm256_storeu_ps(rowY + i + 8, _mm256_mul_ps(x1, vscale));
    1799            0 :       _mm256_storeu_ps(rowY + i + 16, _mm256_mul_ps(x2, vscale));
    1800            0 :       _mm256_storeu_ps(rowY + i + 24, _mm256_mul_ps(x3, vscale));
    1801              :     }
    1802            0 :     for (; i + 8 <= W; i += 8) {
    1803            0 :       __m256 x = _mm256_loadu_ps(rowX + i);
    1804            0 :       _mm256_storeu_ps(rowY + i, _mm256_mul_ps(x, vscale));
    1805              :     }
    1806            0 :     for (; i < W; ++i) {
    1807            0 :       rowY[i] = rowX[i] * scale;
    1808              :     }
    1809              :   }
    1810            0 : }
    1811              : 
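// Per row x of width W, the kernel above applies RMS normalization:
//   y = x / sqrt(mean(x^2) + epsilon),  mean(x^2) = (1/W) * sum_i x_i^2,
// accumulating the sum of squares in four independent FMA accumulators to
// hide FMA latency before the single scalar 1/sqrt and broadcast multiply.
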
    1812              : template <>
    1813           21 : void clamp(const float *input, float *output, size_t length, float lower_bound,
    1814              :            float upper_bound) {
    1815              :   const size_t step = 8;
    1816              :   const __m256 vLo = _mm256_set1_ps(lower_bound);
    1817              :   const __m256 vHi = _mm256_set1_ps(upper_bound);
    1818              : 
    1819              :   size_t i = 0;
    1820         8085 :   for (; i + step <= length; i += step) {
    1821         8064 :     __m256 v = _mm256_loadu_ps(input + i);
    1822              :     v = _mm256_max_ps(v, vLo);
    1823              :     v = _mm256_min_ps(v, vHi);
    1824         8064 :     _mm256_storeu_ps(output + i, v);
    1825              :   }
    1826           21 :   if (i < length) {
    1827            0 :     for (size_t k = i; k < length; ++k) {
    1828            0 :       float v = input[k];
    1829              :       // If v is NaN, the comparisons below will yield false; we keep NaN.
    1830              :       // This matches most framework "pass-through NaN" behavior.
    1831            0 :       output[k] =
    1832            0 :         (v < lower_bound) ? lower_bound : ((v > upper_bound) ? upper_bound : v);
    1833              :     }
    1834              :   }
    1835           21 : }
    1836              : 
    1837            1 : void copy_f16_f32(unsigned int N, const uint16_t *input, float *output) {
    1838              :   unsigned int idx = 0;
    1839              :   const uint16_t *data = (const uint16_t *)input;
    1840              : 
    1841              :   // 16 half-precision floating point values to single-precision values
    1842            6 :   for (; N - idx >= 16; idx += 16) {
    1843              :     const __m256 vec0 =
    1844              :       convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)data));
    1845              :     const __m256 vec1 =
    1846              :       convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)(data + 8)));
    1847            5 :     data += 16;
    1848              : 
    1849              :     _mm256_storeu_ps(output, vec0);
    1850              :     _mm256_storeu_ps(output + 8, vec1);
    1851            5 :     output += 16;
    1852              :   }
    1853              :   // 8 half-precision floating point values to single-precision values
    1854            2 :   for (; N - idx >= 8; idx += 8) {
    1855              :     const __m256 vec =
    1856              :       convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)data));
    1857            1 :     data += 8;
    1858              : 
    1859              :     _mm256_storeu_ps(output, vec);
    1860            1 :     output += 8;
    1861              :   }
    1862              :   // remaining half-precision floating point values to single-precision values
    1863            3 :   while (idx < N) {
    1864            2 :     *output = compute_fp16_to_fp32(*data);
    1865            2 :     ++output;
    1866            2 :     ++data;
    1867            2 :     ++idx;
    1868              :   }
    1869            1 : }
    1870              : 
    1871            0 : void copy_f32_f16(unsigned int N, const float *input, uint16_t *output) {
    1872              :   unsigned int idx = 0;
    1873              :   uint16_t *out_data = (uint16_t *)output;
    1874              : 
    1875              :   // 16 single-precision floating point values to half-precision values
    1876            0 :   for (; N - idx >= 16; idx += 16) {
    1877              :     const __m256 vec0 = _mm256_loadu_ps(input);
    1878              :     const __m256 vec1 = _mm256_loadu_ps(input + 8);
    1879            0 :     input += 16;
    1880              : 
    1881              :     _mm_storeu_si128((__m128i *)out_data, convert_vector_f32_to_f16(vec0));
    1882              :     _mm_storeu_si128((__m128i *)(out_data + 8),
    1883              :                      convert_vector_f32_to_f16(vec1));
    1884            0 :     out_data += 16;
    1885              :   }
    1886              :   // 8 single-precision floating point values to half-precision values
    1887            0 :   for (; N - idx >= 8; idx += 8) {
    1888              :     const __m256 vec = _mm256_loadu_ps(input);
    1889            0 :     input += 8;
    1890              : 
    1891              :     _mm_storeu_si128((__m128i *)out_data, convert_vector_f32_to_f16(vec));
    1892            0 :     out_data += 8;
    1893              :   }
    1894              :   // 4 single-precision floating point values to half-precision values
    1895            0 :   for (; N - idx >= 4; idx += 4) {
    1896              :     const __m128 vec = _mm_loadu_ps(input);
    1897            0 :     input += 4;
    1898              : 
    1899              :     _mm_storeu_si64((__m128i *)out_data, convert_vector_f32_to_f16(vec));
    1900            0 :     out_data += 4;
    1901              :   }
    1902              :   // remaining single-precision floating point values to half-precision values
    1903            0 :   while (idx < N) {
    1904            0 :     *out_data = compute_fp32_to_fp16(*input);
    1905            0 :     ++out_data;
    1906            0 :     ++input;
    1907            0 :     ++idx;
    1908              :   }
    1909            0 : }
    1910              : 
    1911            0 : void create_q4_0_weights(const uint8_t *int4_weight, uint8_t *q4_0_weight) {
    1912              :   // Load 16 bytes of input data
    1913              :   __m128i input = _mm_loadu_si128((const __m128i *)int4_weight);
    1914              : 
    1915              :   // Create masks for extracting low and high nibbles
    1916              :   const __m128i low_nibble_mask = _mm_set1_epi8(0x0F);
    1917              :   const __m128i high_nibble_mask = _mm_set1_epi8(static_cast<char>(0xF0));
    1918              : 
    1919              :   // Extract low nibbles from first 8 bytes
    1920              :   __m128i A = _mm_and_si128(input, low_nibble_mask);
    1921              : 
    1922              :   // Extract high nibbles from first 8 bytes and shift right
    1923              :   __m128i B = _mm_and_si128(input, high_nibble_mask);
    1924              :   B = _mm_srli_epi16(B, 4);
    1925              : 
    1926              :   // Extract low nibbles from second 8 bytes
    1927              :   __m128i input_shifted = _mm_bsrli_si128(input, 8);
    1928              :   __m128i C = _mm_and_si128(input_shifted, low_nibble_mask);
    1929              : 
    1930              :   // Extract high nibbles from second 8 bytes and shift right
    1931              :   __m128i D = _mm_and_si128(input_shifted, high_nibble_mask);
    1932              :   D = _mm_srli_epi16(D, 4);
    1933              : 
     1934              :   // Combine low nibbles of bytes i and i+8: A in the low nibble, C<<4 in the high nibble
    1935              :   __m128i AC = _mm_or_si128(A, _mm_slli_epi16(C, 4));
    1936              : 
     1937              :   // Combine high nibbles of bytes i and i+8: B in the low nibble, D<<4 in the high nibble
    1938              :   __m128i BD = _mm_or_si128(B, _mm_slli_epi16(D, 4));
    1939              : 
     1940              :   // Interleave the AC and BD bytes: out[2i] = AC[i], out[2i+1] = BD[i]
    1941              :   __m128i result = _mm_unpacklo_epi8(AC, BD);
    1942              : 
    1943              :   // Store the 16 bytes result
    1944              :   _mm_storeu_si128((__m128i *)q4_0_weight, result);
    1945            0 : }
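
For reference, a scalar sketch of the reordering above (illustrative, not part of the measured source): output byte 2*i packs the low nibbles of input bytes i and i+8, and byte 2*i+1 packs their high nibbles.

  #include <cstdint>

  // Scalar equivalent of create_q4_0_weights, for illustration only.
  static void create_q4_0_weights_ref(const uint8_t *in, uint8_t *out) {
    for (int i = 0; i < 8; ++i) {
      uint8_t lo0 = in[i] & 0x0F, hi0 = in[i] >> 4;
      uint8_t lo1 = in[i + 8] & 0x0F, hi1 = in[i + 8] >> 4;
      out[2 * i] = lo0 | (lo1 << 4);     // low nibbles of bytes i and i+8
      out[2 * i + 1] = hi0 | (hi1 << 4); // high nibbles of bytes i and i+8
    }
  }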
    1946              : 
    1947     10965600 : static inline void transpose_matrix_16x16(const uint8_t *input,
    1948              :                                           int input_stride, uint8_t *output,
    1949              :                                           int output_stride) {
    1950              :   const uint8_t *src = input;
    1951              :   uint8_t *dst = output;
    1952              : 
    1953              :   __m256i rows[8];
    1954     98690400 :   for (int i = 0; i < 8; ++i) {
    1955     87724800 :     rows[i] =
    1956     87724800 :       _mm256_loadu2_m128i((const __m128i *)(src + (8 + i) * input_stride),
    1957     87724800 :                           (const __m128i *)(src + i * input_stride));
    1958              :   }
    1959              : 
    1960              :   // Step 1: Transpose within 2x2 sub-blocks
    1961     10965600 :   __m256i temp0 = _mm256_unpacklo_epi8(rows[0], rows[1]);
    1962              :   __m256i temp1 = _mm256_unpackhi_epi8(rows[0], rows[1]);
    1963     10965600 :   __m256i temp2 = _mm256_unpacklo_epi8(rows[2], rows[3]);
    1964              :   __m256i temp3 = _mm256_unpackhi_epi8(rows[2], rows[3]);
    1965     10965600 :   __m256i temp4 = _mm256_unpacklo_epi8(rows[4], rows[5]);
    1966              :   __m256i temp5 = _mm256_unpackhi_epi8(rows[4], rows[5]);
    1967     10965600 :   __m256i temp6 = _mm256_unpacklo_epi8(rows[6], rows[7]);
    1968              :   __m256i temp7 = _mm256_unpackhi_epi8(rows[6], rows[7]);
    1969              : 
    1970              :   // Step 2: Transpose within 4x4 sub-blocks
    1971              :   __m256i interleave0 = _mm256_unpacklo_epi16(temp0, temp2);
    1972              :   __m256i interleave1 = _mm256_unpackhi_epi16(temp0, temp2);
    1973              :   __m256i interleave2 = _mm256_unpacklo_epi16(temp1, temp3);
    1974              :   __m256i interleave3 = _mm256_unpackhi_epi16(temp1, temp3);
    1975              :   __m256i interleave4 = _mm256_unpacklo_epi16(temp4, temp6);
    1976              :   __m256i interleave5 = _mm256_unpackhi_epi16(temp4, temp6);
    1977              :   __m256i interleave6 = _mm256_unpacklo_epi16(temp5, temp7);
    1978              :   __m256i interleave7 = _mm256_unpackhi_epi16(temp5, temp7);
    1979              : 
    1980              :   // Step 3: Transpose within 8x8 block
    1981              :   __m256i final0 = _mm256_unpacklo_epi32(interleave0, interleave4);
    1982              :   __m256i final1 = _mm256_unpackhi_epi32(interleave0, interleave4);
    1983              :   __m256i final2 = _mm256_unpacklo_epi32(interleave1, interleave5);
    1984              :   __m256i final3 = _mm256_unpackhi_epi32(interleave1, interleave5);
    1985              :   __m256i final4 = _mm256_unpacklo_epi32(interleave2, interleave6);
    1986              :   __m256i final5 = _mm256_unpackhi_epi32(interleave2, interleave6);
    1987              :   __m256i final6 = _mm256_unpacklo_epi32(interleave3, interleave7);
    1988              :   __m256i final7 = _mm256_unpackhi_epi32(interleave3, interleave7);
    1989              : 
    1990              :   // Step 4: Transpose within 16x16 block
    1991              :   __m256i res[8];
    1992     10965600 :   res[0] = _mm256_unpacklo_epi64(final0, final4);
    1993     10965600 :   res[1] = _mm256_unpackhi_epi64(final0, final4);
    1994     10965600 :   res[2] = _mm256_unpacklo_epi64(final1, final5);
    1995     10965600 :   res[3] = _mm256_unpackhi_epi64(final1, final5);
    1996     10965600 :   res[4] = _mm256_unpacklo_epi64(final2, final6);
    1997     10965600 :   res[5] = _mm256_unpackhi_epi64(final2, final6);
    1998     10965600 :   res[6] = _mm256_unpacklo_epi64(final3, final7);
    1999     10965600 :   res[7] = _mm256_unpackhi_epi64(final3, final7);
    2000              : 
    2001              :   const int perm_0213 = 0xd8; // 0, 2, 1, 3
    2002              :   const int perm_02 = 0x20;   // 0, 2
    2003              :   const int perm_13 = 0x31;   // 1, 3
    2004     54828000 :   for (int i = 0; i < 4; i++) {
    2005     43862400 :     __m256i a128x2 = _mm256_permute4x64_epi64(res[2 * i], perm_0213);
    2006     43862400 :     __m256i b128x2 = _mm256_permute4x64_epi64(res[2 * i + 1], perm_0213);
    2007     43862400 :     _mm256_storeu_si256((__m256i *)&dst[2 * i * output_stride],
    2008              :                         _mm256_permute2x128_si256(a128x2, b128x2, perm_02));
    2009     43862400 :     _mm256_storeu_si256((__m256i *)&dst[(8 + 2 * i) * output_stride],
    2010              :                         _mm256_permute2x128_si256(a128x2, b128x2, perm_13));
    2011              :   }
    2012     10965600 : }
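
The unpack/permute cascade above realizes a strided 16x16 byte transpose; the scalar equivalent (illustrative only) is simply:

  #include <cstdint>

  // Scalar equivalent of transpose_matrix_16x16, for illustration only.
  static void transpose_16x16_ref(const uint8_t *in, int in_stride,
                                  uint8_t *out, int out_stride) {
    for (int r = 0; r < 16; ++r)
      for (int c = 0; c < 16; ++c)
        out[c * out_stride + r] = in[r * in_stride + c];
  }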
    2013              : 
    2014     21908200 : static inline void create_q4_0_weights_x8(const uint8_t *int4_weight,
    2015              :                                           __m256i *q4_blocks) {
    2016              :   constexpr const size_t ROW_BLOCK_BYTE_SIZE = 16;
    2017              : 
    2018              :   // Create masks for extracting low and high nibbles
    2019              :   const __m256i low_nibble_mask = _mm256_set1_epi8(0x0F);
     2020              :   const __m256i high_nibble_mask = _mm256_set1_epi8(static_cast<char>(0xF0));
    2021              : 
     2022              :   // Repack two 16-byte rows per iteration (one per 128-bit lane)
    2023    109541000 :   for (int i = 0; i < 4; ++i) {
     2024              :     // Load 32 bytes of input data (two 16-byte rows, one per lane)
    2025              :     __m256i input = _mm256_loadu_si256(
    2026     87632800 :       (const __m256i *)(int4_weight + 2 * ROW_BLOCK_BYTE_SIZE * i));
    2027              : 
    2028              :     // A = input & low_nibble_mask
    2029              :     __m256i A = _mm256_and_si256(input, low_nibble_mask);
    2030              : 
    2031              :     // B = (input & high_nibble_mask) >> 4
    2032              :     __m256i B = _mm256_srli_epi16(_mm256_and_si256(input, high_nibble_mask), 4);
    2033              : 
     2034              :     // input_shifted: each 128-bit lane shifted right by 8 bytes
    2035              :     __m256i input_shifted = _mm256_bsrli_epi128(input, 8);
    2036              :     // C = input_shifted & low_nibble_mask
    2037              :     __m256i C = _mm256_and_si256(input_shifted, low_nibble_mask);
    2038              : 
    2039              :     // D = (input_shifted & high_nibble_mask) >> 4
    2040              :     __m256i D =
    2041              :       _mm256_srli_epi16(_mm256_and_si256(input_shifted, high_nibble_mask), 4);
    2042              : 
    2043              :     // AC = A | (C << 4)
    2044              :     __m256i AC = _mm256_or_si256(A, _mm256_slli_epi16(C, 4));
    2045              : 
    2046              :     // BD = B | (D << 4)
    2047              :     __m256i BD = _mm256_or_si256(B, _mm256_slli_epi16(D, 4));
    2048              : 
    2049              :     // Interleave AC and BD
    2050              :     __m256i result = _mm256_unpacklo_epi8(AC, BD);
    2051              : 
    2052     87632800 :     _mm256_store_si256(&q4_blocks[i], result);
    2053              :   }
    2054     21908200 : }
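
Since every 256-bit operation above acts per 128-bit lane, this routine amounts to the 16-byte repack applied to eight consecutive rows, two per load. In terms of the scalar reference sketched earlier (illustrative only):

  // Equivalent formulation via create_q4_0_weights_ref (illustration only);
  // out must provide 8 * 16 = 128 bytes, matching the four __m256i blocks.
  static void create_q4_0_weights_x8_ref(const uint8_t *int4_weight,
                                         uint8_t *out) {
    for (int row = 0; row < 8; ++row)
      create_q4_0_weights_ref(int4_weight + 16 * row, out + 16 * row);
  }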
    2055              : 
    2056     21908200 : inline static void nntr_make_block_q4_0x8(const __m256i *in, block_q4_0x8 *out,
    2057              :                                           const uint16_t *scales) {
    2058              :   constexpr size_t IN_CNT = 8;
    2059     21908200 :   memcpy(out->d, scales, IN_CNT * sizeof(uint16_t));
    2060              : 
    2061              :   const int perm_0213 = 0xd8; // 0, 2, 1, 3
    2062              :   const int perm_02 = 0x20;   // 0, 2
    2063              :   const int perm_13 = 0x31;   // 1, 3
     2064     21908200 :   __m256i a128x2 = _mm256_permute4x64_epi64(in[0], perm_0213);
     2065     21908200 :   __m256i b128x2 = _mm256_permute4x64_epi64(in[1], perm_0213);
     2066     21908200 :   __m256i c128x2 = _mm256_permute4x64_epi64(in[2], perm_0213);
     2067     21908200 :   __m256i d128x2 = _mm256_permute4x64_epi64(in[3], perm_0213);
    2068              :   _mm256_storeu_si256((__m256i *)&out->qs[0],
    2069              :                       _mm256_permute2x128_si256(a128x2, b128x2, perm_02));
    2070              :   _mm256_storeu_si256((__m256i *)&out->qs[32],
    2071              :                       _mm256_permute2x128_si256(c128x2, d128x2, perm_02));
    2072              :   _mm256_storeu_si256((__m256i *)&out->qs[64],
    2073              :                       _mm256_permute2x128_si256(a128x2, b128x2, perm_13));
    2074              :   _mm256_storeu_si256((__m256i *)&out->qs[96],
    2075              :                       _mm256_permute2x128_si256(c128x2, d128x2, perm_13));
    2076     21908200 : }
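
The two permute stages implement the usual 8-row q4_0x8 interleave: each 16-byte quantized row is split into two 8-byte halves, and qs stores the eight first halves back-to-back, then the eight second halves. A scalar view (illustration only; rows[r] stands for the 16-byte repacked row r):

  #include <cstdint>
  #include <cstring>

  // Scalar view of the qs layout produced above (illustration only).
  static void make_block_q4_0x8_ref(const uint8_t rows[8][16],
                                    uint8_t qs[128]) {
    for (int half = 0; half < 2; ++half) // 8-byte halves of each row
      for (int r = 0; r < 8; ++r)        // eight interleaved rows
        std::memcpy(qs + 64 * half + 8 * r, rows[r] + 8 * half, 8);
  }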
    2077              : 
    2078         2400 : void transform_int4_osv32_isv2_to_q4_0x8(size_t N, size_t K,
    2079              :                                          const uint8_t *osv32_weights,
    2080              :                                          const uint16_t *osv32_scales,
    2081              :                                          size_t scale_group_size,
    2082              :                                          void *dst_q4_0x) {
    2083              : 
    2084         2400 :   NNTR_THROW_IF((!(scale_group_size == 32 || scale_group_size == 64 ||
    2085              :                    scale_group_size == 128)),
    2086              :                 std::invalid_argument)
    2087              :     << "Scale group size must be 32/64/128";
    2088         2400 :   NNTR_THROW_IF(K % QK4_0 != 0, std::invalid_argument)
     2089              :     << "K size must be divisible by QK4_0 (32)";
    2090         2400 :   NNTR_THROW_IF(N % 8 != 0, std::invalid_argument)
     2091              :     << "N size must be divisible by 8";
    2092              : 
    2093              :   static constexpr const size_t NUM_Q4_0_BLOCKS = 8;
    2094              :   static constexpr const size_t ROW_BLOCK_SIZE = 32;
    2095              :   static constexpr const size_t COLUMN_BLOCK_SIZE = 2;
    2096              :   static constexpr const size_t ROW_BLOCK_BYTE_SIZE = 16;
    2097              : 
    2098              :   static constexpr const size_t dst_tmp_size =
    2099              :     (8 * ROW_BLOCK_BYTE_SIZE) / sizeof(__m256i);
    2100              :   uint8_t *dst_ = reinterpret_cast<uint8_t *>(dst_q4_0x);
    2101              : 
    2102              :   // --- Layout ---
    2103         2400 :   const size_t rows_count_pad = align(N, ROW_BLOCK_SIZE);
    2104         2400 :   const size_t columns_count_pad = align(K, ROW_BLOCK_SIZE);
    2105         2400 :   const size_t column_blocks_count =
    2106              :     columns_count_pad / COLUMN_BLOCK_SIZE; // COLUMN_BLOCK_SIZE == 2
    2107         2400 :   const size_t bytes_per_row_block_span = column_blocks_count * ROW_BLOCK_SIZE;
    2108         2400 :   const int column_blocks_cnt = K / QK4_0;
    2109              : 
    2110              :   alignas(32) static thread_local __m256i dst_tmp[dst_tmp_size];
    2111              :   alignas(32) static thread_local uint8_t mx16x16[16 * 16];
    2112              : 
    2113         2400 : #pragma omp parallel for schedule(guided)
    2114              :   for (int row_id = 0; row_id < (int)N; row_id += 16) {
    2115              :     const size_t row_in_block_id = row_id / ROW_BLOCK_SIZE;
    2116              :     size_t i_in_block = row_id % ROW_BLOCK_SIZE;
    2117              :     for (int column_out_block_id = 0; column_out_block_id < column_blocks_cnt;
    2118              :          column_out_block_id++) {
    2119              :       int column_idx = column_out_block_id * QK4_0;
    2120              :       int scale_offset = (column_idx / scale_group_size) * rows_count_pad;
    2121              :       const size_t row_block_base =
    2122              :         row_in_block_id * bytes_per_row_block_span + i_in_block;
    2123              :       int src_offset =
    2124              :         row_block_base + column_out_block_id * 16 * ROW_BLOCK_SIZE;
    2125              :       transpose_matrix_16x16(&osv32_weights[src_offset], ROW_BLOCK_SIZE,
    2126              :                              mx16x16, 16);
    2127              :       int max_r = std::min((size_t)16, N - row_id);
    2128              :       size_t row_out_block_id = row_id / NUM_Q4_0_BLOCKS;
    2129              :       int dst_offset =
    2130              :         (NUM_Q4_0_BLOCKS * sizeof(block_q4_0)) *
    2131              :         (column_out_block_id + row_out_block_id * column_blocks_cnt);
    2132              :       for (int r = 0; r < max_r; r += NUM_Q4_0_BLOCKS) {
    2133              :         create_q4_0_weights_x8(&mx16x16[16 * r], dst_tmp);
    2134              : 
    2135              :         nntr_make_block_q4_0x8(dst_tmp, (block_q4_0x8 *)(dst_ + dst_offset),
    2136              :                                &osv32_scales[scale_offset + row_id + r]);
    2137              :         row_out_block_id++;
    2138              :         dst_offset +=
    2139              :           (NUM_Q4_0_BLOCKS * sizeof(block_q4_0)) * column_blocks_cnt;
    2140              :       }
    2141              :     }
    2142              :   }
    2143         2400 : }
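
A hedged usage sketch (sizes follow from the checks above: N a multiple of 8, K a multiple of QK4_0; N and K are chosen as multiples of 32 here so the internal padding is a no-op, block_q4_0x8 comes from nntr_ggml_impl_common.h, and the OSV32 buffers must already hold data in the layout the loop expects):

  #include <cstdint>
  #include <vector>

  void demo_transform() {
    const size_t N = 64, K = 256, group = 32;
    std::vector<uint8_t> osv32_w(N * K / 2);        // two int4 values per byte
    std::vector<uint16_t> osv32_s((K / group) * N); // fp16 scale per group per row
    std::vector<uint8_t> dst((N / 8) * (K / 32) * sizeof(block_q4_0x8));
    nntrainer::avx2::transform_int4_osv32_isv2_to_q4_0x8(
      N, K, osv32_w.data(), osv32_s.data(), group, dst.data());
  }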
    2144              : 
    2145              : } // namespace nntrainer::avx2
        

Generated by: LCOV version 2.0-1