LCOV - code coverage report
Current view: top level - nntrainer/tensor/cpu_backend/x86 - avx2_impl.cpp (source / functions)
Test:      coverage_filtered.info
Test Date: 2025-12-14 20:38:17

            Coverage    Total    Hit
Lines:      54.0 %      563      304
Functions:  44.7 %      38       17

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2023 Donghyeon Jeong <dhyeon.jeong@samsung.com>
       4              :  *
       5              :  * @file   avx2_impl.cpp
       6              :  * @date   20 Feb 2024
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Donghyeon Jeong <dhyeon.jeong@samsung.com>
       9              :  * @author Sungsik Kong <ss.kong@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  * @brief  This is a source for AVX implementation
      12              :  *
      13              :  */
      14              : 
      15              : #include "avx2_impl.h"
      16              : #include <array>
      17              : #if __has_include(<bit>)
      18              : #include <bit>
      19              : #endif
      20              : #include <cassert>
      21              : #include <cmath>
      22              : #include <cstdint>
      23              : #include <cstring>
      24              : #include <fp16.h>
      25              : #include <immintrin.h>
      26              : #include <limits>
      27              : #if __has_include(<numbers>)
      28              : #include <numbers>
      29              : #endif
      30              : #include <type_traits>
      31              : #if __has_include(<version>)
      32              : #include <version>
      33              : #endif
      34              : #include <fallback_internal.h>
      35              : #include <util_func.h>
      36              : #include <vector>
      37              : 
      38              : #if !defined(__has_constexpr_builtin)
      39              : #define __has_constexpr_builtin(x) (0)
      40              : #endif
      41              : 
      42              : #if !defined(__has_cpp_attribute)
      43              : #define __has_cpp_attribute(x) (0)
      44              : #endif
      45              : 
       46              : // VECTORCALL calling convention (MSVC only; the default x86_64-linux-gnu
                       : // ABI already passes vector arguments in registers)
       47              : #if defined(_MSC_VER) && _MSC_VER >= 1700
      48              : #define _nnt_CC_VECTORCALL __vectorcall
      49              : #else
      50              : #define _nnt_CC_VECTORCALL
      51              : #endif
      52              : 
      53              : // Flatten attribute
       54              : #if (defined(_MSC_VER) && _MSC_VER >= 1700) || __has_cpp_attribute(msvc::flatten)
      55              : #define _nnt_ATTR_FLATTEN [[msvc::flatten]]
      56              : #elif __has_cpp_attribute(gnu::flatten)
      57              : // clang, g++
      58              : #define _nnt_ATTR_FLATTEN [[gnu::flatten]]
      59              : #else
      60              : #define _nnt_ATTR_FLATTEN
      61              : #endif
      62              : 
       63              : #if (defined(_MSC_VER) && _MSC_VER >= 1700) || __has_cpp_attribute(msvc::noinline)
       64              : #define _nnt_ATTR_NOINLINE [[msvc::noinline]]
       65              : #elif __has_cpp_attribute(gnu::noinline)
      66              : // clang, g++
      67              : #define _nnt_ATTR_NOINLINE [[gnu::noinline]]
      68              : #else
      69              : #define _nnt_ATTR_NOINLINE
      70              : #endif
      71              : 
       72              : #if (defined(_MSC_VER) && _MSC_VER >= 1700) || __has_cpp_attribute(msvc::forceinline)
       73              : #define _nnt_ATTR_ALWAYS_INLINE [[msvc::forceinline]]
       74              : #elif __has_cpp_attribute(gnu::always_inline)
       75              : #define _nnt_ATTR_ALWAYS_INLINE [[gnu::always_inline]]
       76              : #else
                       : #define _nnt_ATTR_ALWAYS_INLINE
                       : #endif
      77              : 
       78              : #if __has_cpp_attribute(unlikely)
      79              : #define UNLIKELY [[unlikely]]
      80              : #else
      81              : #define UNLIKELY
      82              : #endif
      83              : 
      84              : #if !defined(_MSC_VER) && !defined(__clang__)
      85              : #pragma GCC diagnostic ignored "-Wattributes"
      86              : #endif
      87              : 
      88              : #if defined(__clang__) || defined(__GNUC__)
      89              : #define RESTRICT __restrict__
      90              : #else
      91              : #define RESTRICT
      92              : #endif
      93              : 
      94              : namespace {
      95              : 
      96              : template <typename To_, typename From_>
      97              : constexpr inline bool concept17_BinaryCastable =
       98              :   sizeof(To_) == sizeof(From_) && std::is_trivially_copyable_v<From_> &&
       99              :   std::is_trivially_copyable_v<To_>;
     100              : 
     101              : template <class To_, class From_>
     102              : auto compat_bit_cast(const From_ &src) noexcept
     103              :   -> std::enable_if_t<concept17_BinaryCastable<To_, From_>, To_> {
     104              : #if __cpp_lib_bit_cast >= 201806L
     105              :   return std::bit_cast<To_>(src);
     106              : #else
     107              :   To_ dst;
     108              :   std::memcpy(&dst, &src, sizeof(To_));
     109              :   return dst;
     110              : #endif
     111              : }
     112              : 
     113              : [[nodiscard]] constexpr inline unsigned
     114              : constexpr_popcount(uint32_t v) noexcept {
     115              : #if __cpp_lib_bitops >= 201907L
     116              :   return std::popcount(v);
     117              : #else
     118              :   // Popcount via bit-hack
     119              :   v = v - ((v >> 1) & 0x55555555);                // reuse input as temporary
     120              :   v = (v & 0x33333333) + ((v >> 2) & 0x33333333); // temp
     121              :   auto c = (((v + (v >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; // count
     122              :   return c;
     123              : #endif
     124              : }
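
    A few compile-time checks make the SWAR popcount above concrete (an
    editorial sketch, not part of the measured source): bit pairs are summed
    first, then nibbles, and the multiply by 0x1010101 accumulates every byte
    sum into the top byte.

        static_assert(constexpr_popcount(0x0u) == 0);
        static_assert(constexpr_popcount(0x88u) == 2);
        static_assert(constexpr_popcount(0xFFFFFFFFu) == 32);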
     125              : 
     126              : template <unsigned I_>
     127              : constexpr inline bool concept17_PowerOfTwo = (constexpr_popcount(I_) == 1);
     128              : 
     129              : namespace numbers {
     130              : 
     131              : #if __has_include(<numbers>) && __cpp_lib_math_constants >= 201907L
     132              : using std::numbers::ln2_v;
     133              : using std::numbers::log2e_v;
     134              : #else
     135              : template <typename Float_> constexpr inline auto ln2_v = Float_{M_LN2};
     136              : 
     137              : template <typename Float_> constexpr inline auto log2e_v = Float_{M_LOG2E};
     138              : #endif
     139              : } // namespace numbers
     140              : 
      141              : constexpr inline float EXP_ARG_MIN = -87.0f;
     142              : constexpr inline float EXP_ARG_MAX = +88.3762626647949f;
     143              : 
     144              : // @brief Precalculated lookup table for 2^x calculation
     145              : template <unsigned N_, typename Ty_ = uint32_t, typename Float_ = float,
     146              :           typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
     147              : struct Exp2Table {
     148              : 
     149              :   constexpr static inline auto MANTISSA_BITS =
     150              :     std::numeric_limits<Float_>::digits - 1;
     151              : 
     152              : #if __cpp_consteval >= 201811L && __has_constexpr_builtin(__builtin_exp2)
     153              :   [[nodiscard]] static consteval auto calculate() noexcept {
     154              :     std::array<Ty_, N_> t;
     155              :     for (unsigned i = 0; i < N_; ++i)
     156              :       t[i] = std::bit_cast<Ty_>(std::exp2(Float_{1.0} * i / N_)) -
     157              :              ((i << MANTISSA_BITS) / N_);
     158              :     return t;
     159              :   }
     160              : #endif
     161              : };
     162              : 
     163              : #if !__has_constexpr_builtin(__builtin_exp2) || !(__cpp_consteval >= 201811L)
     164              : 
     165              : // @brief Precalculated lookup table for 2^x calculation when we don't have
     166              : // constexpr math functions
     167              : template <> struct Exp2Table<8, uint32_t, float> {
     168              :   [[nodiscard]] static constexpr auto calculate() noexcept {
     169              :     std::array<uint32_t, 8> t = {0x3f800000U, 0x3f7b95c2U, 0x3f7837f0U,
     170              :                                  0x3f75fed7U, 0x3f7504f3U, 0x3f75672aU,
     171              :                                  0x3f7744fdU, 0x3f7ac0c7U};
     172              :     return t;
     173              :   }
     174              : };
     175              : 
     176              : #endif
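
    The hard-coded entries are the runtime-computed ones spelled out: each is
    the bit pattern of 2^(i/8) minus the exponent share (i << 23) / 8 that is
    re-added later from the integer part. A check for i = 1 (editorial sketch,
    assuming a correctly rounded std::exp2):

        float f = std::exp2(1.0f / 8.0f); // 2^(1/8) = 1.0905077f = 0x3F8B95C2
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        assert(bits - ((1u << 23) / 8) == 0x3f7b95c2U); // matches t[1] above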
     177              : 
     178              : template <unsigned N_> // requires PowerOfTwo<N_>
     179              : alignas(__m256) inline constexpr auto exp2_table_v = Exp2Table<N_>::calculate();
     180              : 
      181              : // Scalar version of the expf() approximation: a 3rd-degree polynomial of
      182              : // the fractional part combined with a 2^(k/N) table lookup.
      183              : //
      184              : // The error with respect to std::expf is less than 5e-6.
      185              : // It is valid in the range [-87, +88.37); +INF, NaN etc. are not handled,
      186              : // and the argument is clamped to the valid range.
      187              : //
      188              : // The strategy picked is to approximate exp(x) as 2^K * 2^F
     189              : template <unsigned N_, typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
     190              : [[nodiscard]] _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN inline float
     191              : approx_exp_exp2_lookup(float x) noexcept {
     192              :   constexpr static unsigned N_MASK = uint32_t(N_ - 1U);
     193              :   constexpr static unsigned FLT_MANTISSA_BITS =
     194              :     std::numeric_limits<float>::digits - 1U;
     195              : 
     196              :   x = std::max(x, EXP_ARG_MIN);
      197              :   x *= float(1.0 / numbers::ln2_v<double> * N_); // x * N / ln(2)
     198              :   x = std::min(x, float(EXP_ARG_MAX / numbers::ln2_v<double> * N_));
     199              : 
      200              :   // Round to nearest and split off the integer part (cf. std::modf)
      201              :   // NB: This trick doesn't round ties to even.
     202              :   auto x_int = x + 0x1.8p23f;
     203              :   auto x_uint = compat_bit_cast<uint32_t>(x_int);
     204              :   x_int -= 0x1.8p23f;
     205              :   auto x_frac = x - x_int;
     206              : 
     207              :   auto s_int = exp2_table_v<N_>[x_uint & N_MASK];
     208              :   auto x_uint_shifted = x_uint
     209              :                         << (FLT_MANTISSA_BITS - constexpr_popcount(N_MASK));
     210              :   auto s_int_2 = s_int + x_uint_shifted;
     211              :   auto s = compat_bit_cast<float>(s_int_2);
     212              : 
      213              :   // Polynomial of the form C0*x^3 + C1*x^2 + C2*x + 1.0
     214              :   static constexpr float poly_4[] = {0x1.c6af84b912394p-5f / N_ / N_ / N_,
     215              :                                      0x1.ebfce50fac4f3p-3f / N_ / N_,
     216              :                                      0x1.62e42ff0c52d6p-1f / N_};
     217              : 
     218              :   auto q0 = std::fma(x_frac, poly_4[0], poly_4[1]);
     219              :   auto x_frac_pow2 = x_frac * x_frac;
     220              :   auto q2 = x_frac * poly_4[2]; // not adding +1.0
     221              : 
     222              :   x = std::fma(q0, x_frac_pow2, q2);
     223              : 
     224              :   return std::fma(x, s, s); // NB: (handles (x+1) by addition of s)
     225              : }
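
    The 0x1.8p23f constant is the classic round-to-nearest trick; a standalone
    sketch of what the add/subtract pair does (editorial; nearest_int_demo is a
    hypothetical name, valid in the default round-to-nearest mode):

        float nearest_int_demo(float x) { // requires |x| < 2^22
          float t = x + 0x1.8p23f; // rounds away the fractional bits; the
                                   // integer also sits in t's low mantissa bits
          return t - 0x1.8p23f;    // == nearbyintf(x)
        }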
     226              : 
     227              : // Vectorized version of above
     228              : template <unsigned N_, typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
     229              : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN inline auto _nnt_CC_VECTORCALL
     230              : avx2_approx_exp_e2lookup(__m256 xs) noexcept {
     231              : 
     232              :   constexpr static uint32_t N_MASK = uint32_t(N_ - 1U);
     233              :   alignas(64) constexpr static auto EXP2_TBL = exp2_table_v<N_>;
     234              :   constexpr static unsigned MANTISSA_BITS =
     235              :     std::numeric_limits<float>::digits - 1;
     236              : 
     237              :   // Ensure arg in range [exp_arg_min, exp_arg_max]
     238              :   xs = _mm256_max_ps(xs, _mm256_set1_ps(EXP_ARG_MIN));
     239              :   // Would clamp to EXP_ARG_MAX but we move it after multiplication for IPC:
     240              :   // xs = _mm256_min_ps(xs, _mm256_set1_ps(EXP_ARG_MAX));
     241              : 
     242              :   xs =
     243              :     _mm256_mul_ps(xs, _mm256_set1_ps(float(1.0 / numbers::ln2_v<double> * N_)));
     244              :   // Clamp EXP_ARG_MAX after multiply
     245              :   xs = _mm256_min_ps(
     246              :     xs, _mm256_set1_ps(float(EXP_ARG_MAX / numbers::ln2_v<double> * N_)));
     247              : 
      248              :   // Mostly equivalent to the following, except ties are not rounded to even:
      249              :   //   auto xs_int = _mm256_round_ps(xs, _MM_FROUND_TO_NEAREST_INT |
      250              :   //                                     _MM_FROUND_NO_EXC);
                       :   //   auto xs_int_as_u32 = _mm256_cvtps_epi32(xs_int);
     251              :   auto xs_int = _mm256_add_ps(xs, _mm256_set1_ps(0x1.8p23f));
     252              :   auto xs_int_as_u32 = _mm256_castps_si256(xs_int);
     253              :   xs_int = _mm256_sub_ps(xs_int, _mm256_set1_ps(0x1.8p23f));
     254              : 
     255              :   // Calculate fractional part
     256              :   auto xs_frac = _mm256_sub_ps(xs, xs_int);
     257              :   // Indices for lookup (modulo N_)
     258              :   auto exp2_idxs = _mm256_and_si256(xs_int_as_u32, _mm256_set1_epi32(N_MASK));
     259              : 
     260              :   __m256i s_ints;
     261              : 
      262              :   // Look up the 2^(k/N) scale factor s for the integer part
     263              :   if constexpr (N_ == 8) {
     264              :     // Lookup by vector permute
     265              :     auto tbl = _mm256_load_si256((__m256i *)EXP2_TBL.data());
     266              :     s_ints = _mm256_permutevar8x32_epi32(tbl, exp2_idxs);
      267              :   } else {
      268              :     // Fallback for table sizes that don't fit one 8-lane permute
      269              :     s_ints = _mm256_i32gather_epi32(
                       :       reinterpret_cast<const int *>(EXP2_TBL.data()), exp2_idxs, 4);
     270              :   }
     271              : 
     272              :   auto xs_uint_shifted = _mm256_slli_epi32(
     273              :     xs_int_as_u32, MANTISSA_BITS - constexpr_popcount(N_MASK));
     274              :   auto s_ints_2 = _mm256_add_epi32(s_ints, xs_uint_shifted);
     275              :   auto s_floats = _mm256_castsi256_ps(s_ints_2);
     276              : 
     277              :   static constexpr float poly_d4[] = {0x1.c6af84b912394p-5f / N_ / N_ / N_,
     278              :                                       0x1.ebfce50fac4f3p-3f / N_ / N_,
     279              :                                       0x1.62e42ff0c52d6p-1f / N_};
     280              : 
     281              :   const auto C0 = _mm256_set1_ps(poly_d4[0]);
     282              :   const auto C1 = _mm256_set1_ps(poly_d4[1]);
     283              :   const auto C2 = _mm256_set1_ps(poly_d4[2]);
     284              : 
     285              :   auto qs0 = _mm256_fmadd_ps(xs_frac, C0, C1);
     286              :   auto xs_frac_pow2 = _mm256_mul_ps(xs_frac, xs_frac);
     287              :   auto qs2 = _mm256_mul_ps(xs_frac, C2);
     288              : 
     289              :   xs = _mm256_fmadd_ps(qs0, xs_frac_pow2, qs2);
     290              : 
     291              :   return _mm256_fmadd_ps(xs, s_floats, s_floats);
     292              : }
     293              : 
     294              : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
     295              : avx2_negate_ps(__m256 x) noexcept -> __m256 {
     296              :   constexpr auto SIGN_SHIFT = sizeof(float) * 8 - 1;
     297              :   const auto UNDEF = _mm256_undefined_si256();
     298              :   const auto sign_bit =
     299              :     _mm256_slli_epi32(_mm256_cmpeq_epi16(UNDEF, UNDEF), SIGN_SHIFT);
     300              :   auto flt_sign_bit = _mm256_castsi256_ps(sign_bit);
     301              :   auto neg_x = _mm256_xor_ps(x, flt_sign_bit);
     302              :   return neg_x;
     303              : }
     304              : 
     305              : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
     306              : avx2_approx_swiglu(__m256 x, __m256 s) noexcept -> __m256 {
     307              :   auto neg_x = avx2_negate_ps(x);
     308              :   auto inv_sigmoid =
     309              :     _mm256_add_ps(avx2_approx_exp_e2lookup<8>(neg_x), _mm256_set1_ps(1.0f));
     310              :   auto swiglu_nonscaled = _mm256_div_ps(x, inv_sigmoid);
     311              :   return _mm256_mul_ps(swiglu_nonscaled, s);
     312              : }
     313              : 
     314              : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
     315              : avx2_approx_swiglu_alpha(__m256 x, __m256 s, __m256 alpha) noexcept -> __m256 {
     316              :   auto alpha_x = _mm256_mul_ps(alpha, x);
     317              :   auto neg_alpha_x = avx2_negate_ps(alpha_x);
     318              :   auto inv_sigmoid = _mm256_add_ps(avx2_approx_exp_e2lookup<8>(neg_alpha_x),
     319              :                                    _mm256_set1_ps(1.0f));
     320              :   auto swiglu_nonscaled = _mm256_div_ps(x, inv_sigmoid);
     321              :   return _mm256_mul_ps(swiglu_nonscaled, s);
     322              : }
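
    For reference, the scalar functions the two helpers above approximate
    (editorial sketch; swiglu_ref and swiglu_alpha_ref are hypothetical names):
    swiglu(x, s) = s * x * sigmoid(x), with alpha applied inside the sigmoid in
    the second variant.

        float swiglu_ref(float x, float s) {
          return s * x / (1.0f + std::exp(-x));
        }
        float swiglu_alpha_ref(float x, float s, float alpha) {
          return s * x / (1.0f + std::exp(-alpha * x));
        }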
     323              : } // namespace
     324              : 
     325              : namespace nntrainer::avx2 {
     326              : 
     327              : /**
     328              :  * @brief struct of q4_0x8 block
     329              :  */
     330              : struct block_q4_0x8 {
      331              :   uint16_t d[8];   // fp16 scales of the 8 interleaved q4_0 blocks (16 B)
      332              :   uint8_t qs[128]; // packed 4-bit quants of the 8 blocks (16 x u64)
     333              : };
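
    Each block packs 8 interleaved q4_0 sub-blocks: 8 fp16 scales plus 4-bit
    quants biased by 8 (hence the 0x88 XORs below). A decode sketch for one
    nibble (editorial; dequant_q4_nibble is a hypothetical helper that ignores
    the interleaved addressing handled by the routines below):

        float dequant_q4_nibble(uint16_t d_fp16, uint8_t nibble) {
          float scale = fp16_ieee_to_fp32_value(d_fp16); // from <fp16.h>
          return scale * (int(nibble & 0x0F) - 8);       // bias-8 decode
        }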
     334              : 
     335              : #define USE_NONTEMPORAL_STORES 1
     336              : 
     337              : static inline void store256_u16(void *dst, __m256i v) {
      338              : #if USE_NONTEMPORAL_STORES
     339              :   // use NT only if 32B-aligned; otherwise fall back (correctness first)
     340              :   if (((uintptr_t)dst & 31u) == 0) {
     341              :     _mm256_stream_si256((__m256i *)dst, v);
     342              :     return;
     343              :   }
     344              : #endif
     345              :   _mm256_storeu_si256((__m256i *)dst, v);
     346              : }
     347              : 
     348            0 : void unpack_q4_0x8_transpose16(const void *src, unsigned short *__restrict dT,
     349              :                                unsigned short *__restrict qsT, int N, int K,
     350              :                                int CT) // column tile (in units of 32-cols)
     351              : {
     352            0 :   assert((K % 256) == 0);
     353            0 :   assert((N % 8) == 0);
     354              : 
     355              :   const auto *__restrict x = static_cast<const block_q4_0x8 *>(src);
     356              : 
     357            0 :   const int groups_N8 = N / 8;    // number of 8-row groups
     358            0 :   const int cols_scales = K / 32; // K subblocks
     359              : 
     360              :   // AVX2 constants
     361            0 :   const __m128i v88 = _mm_set1_epi8((char)0x88);
     362            0 :   const __m128i v0f = _mm_set1_epi8((char)0x0F);
     363            0 :   const __m128i vF0 = _mm_set1_epi8((char)0xF0);
     364              : 
     365              :   const __m128i idx_even =
     366              :     _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, (char)0xFF, (char)0xFF, (char)0xFF,
     367            0 :                   (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF);
     368              :   const __m128i idx_odd =
     369              :     _mm_setr_epi8(1, 3, 5, 7, 9, 11, 13, 15, (char)0xFF, (char)0xFF, (char)0xFF,
     370            0 :                   (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF);
     371              :   const __m128i idx_0246 =
     372              :     _mm_setr_epi8(0, 2, 4, 6, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
     373              :                   (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
     374            0 :                   (char)0xFF, (char)0xFF, (char)0xFF);
     375              :   const __m128i idx_1357 =
     376              :     _mm_setr_epi8(1, 3, 5, 7, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
     377              :                   (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
     378            0 :                   (char)0xFF, (char)0xFF, (char)0xFF);
     379              : 
     380            0 :   auto pack_row8 = [&](const unsigned char *qs0, const unsigned char *qs1,
     381              :                        int off) -> __m128i {
     382            0 :     __m128i lo8 = _mm_loadl_epi64((const __m128i *)(qs0 + 8 * off));
     383            0 :     __m128i hi8 = _mm_loadl_epi64((const __m128i *)(qs1 + 8 * off));
     384              :     __m128i v = _mm_unpacklo_epi64(lo8, hi8);
     385            0 :     v = _mm_xor_si128(v, v88);
     386            0 :     __m128i lo = _mm_and_si128(v, v0f);
     387              :     __m128i hi = _mm_and_si128(_mm_srli_epi16(v, 4), v0f);
     388            0 :     __m128i lo_e = _mm_shuffle_epi8(lo, idx_even);
     389            0 :     __m128i lo_o = _mm_shuffle_epi8(lo, idx_odd);
     390              :     __m128i hi_e = _mm_shuffle_epi8(hi, idx_even);
     391              :     __m128i hi_o = _mm_shuffle_epi8(hi, idx_odd);
     392              :     __m128i low_lane =
     393            0 :       _mm_or_si128(lo_e, _mm_and_si128(_mm_slli_epi16(lo_o, 4), vF0));
     394              :     __m128i high_lane =
     395              :       _mm_or_si128(hi_e, _mm_and_si128(_mm_slli_epi16(hi_o, 4), vF0));
     396            0 :     __m128i low_e2 = _mm_shuffle_epi8(low_lane, idx_0246);
     397            0 :     __m128i low_o2 = _mm_shuffle_epi8(low_lane, idx_1357);
     398              :     __m128i high_e2 = _mm_shuffle_epi8(high_lane, idx_0246);
     399              :     __m128i high_o2 = _mm_shuffle_epi8(high_lane, idx_1357);
     400              :     __m128i pack_lo = _mm_unpacklo_epi8(low_e2, low_o2);   // 4×u16 (w0..w3)
     401              :     __m128i pack_hi = _mm_unpacklo_epi8(high_e2, high_o2); // 4×u16 (w4..w7)
     402            0 :     return _mm_unpacklo_epi64(pack_lo, pack_hi);           // 8×u16 (w0..w7)
     403            0 :   };
     404              : 
     405              :   auto transpose8x8_epi16 =
     406              :     [](__m128i r0, __m128i r1, __m128i r2, __m128i r3, __m128i r4, __m128i r5,
     407              :        __m128i r6, __m128i r7, __m128i &c0, __m128i &c1, __m128i &c2,
     408              :        __m128i &c3, __m128i &c4, __m128i &c5, __m128i &c6, __m128i &c7) {
     409              :       __m128i t0 = _mm_unpacklo_epi16(r0, r1);
     410              :       __m128i t1 = _mm_unpackhi_epi16(r0, r1);
     411              :       __m128i t2 = _mm_unpacklo_epi16(r2, r3);
     412              :       __m128i t3 = _mm_unpackhi_epi16(r2, r3);
     413              :       __m128i t4 = _mm_unpacklo_epi16(r4, r5);
     414              :       __m128i t5 = _mm_unpackhi_epi16(r4, r5);
     415              :       __m128i t6 = _mm_unpacklo_epi16(r6, r7);
     416              :       __m128i t7 = _mm_unpackhi_epi16(r6, r7);
     417              : 
     418              :       __m128i u0 = _mm_unpacklo_epi32(t0, t2);
     419              :       __m128i u1 = _mm_unpackhi_epi32(t0, t2);
     420              :       __m128i u2 = _mm_unpacklo_epi32(t1, t3);
     421              :       __m128i u3 = _mm_unpackhi_epi32(t1, t3);
     422              :       __m128i u4 = _mm_unpacklo_epi32(t4, t6);
     423              :       __m128i u5 = _mm_unpackhi_epi32(t4, t6);
     424              :       __m128i u6 = _mm_unpacklo_epi32(t5, t7);
     425              :       __m128i u7 = _mm_unpackhi_epi32(t5, t7);
     426              : 
     427              :       c0 = _mm_unpacklo_epi64(u0, u4);
     428              :       c1 = _mm_unpackhi_epi64(u0, u4);
     429              :       c2 = _mm_unpacklo_epi64(u1, u5);
     430              :       c3 = _mm_unpackhi_epi64(u1, u5);
     431              :       c4 = _mm_unpacklo_epi64(u2, u6);
     432              :       c5 = _mm_unpackhi_epi64(u2, u6);
     433              :       c6 = _mm_unpacklo_epi64(u3, u7);
     434              :       c7 = _mm_unpackhi_epi64(u3, u7);
     435              :     };
     436              : 
     437              :   // -------- pair-processing path: handle two 8-row groups (16 rows) per pass
     438              :   // --------
     439            0 :   const int groups_pairs = groups_N8 / 2;
     440              : 
     441              : #ifdef _MSC_VER
     442              : #pragma warning(push)
     443              : #pragma warning(disable : 4849)
     444              : #endif
     445            0 : #pragma omp parallel for collapse(2) schedule(static)
     446              : #ifdef _MSC_VER
     447              : #pragma warning(pop)
     448              : #endif
     449              :   for (int c0 = 0; c0 < cols_scales; c0 += CT) {
     450              :     for (int bp = 0; bp < groups_pairs; ++bp) {
     451              :       const int b0 = 2 * bp;
     452              :       const int b1 = b0 + 1;
     453              :       const int r0 = b0 * 8; // 16 rows: r0..r0+15
     454              :       const int c1 = std::min(c0 + CT, cols_scales);
     455              : 
     456              :       for (int c = c0; c < c1; ++c) {
     457              :         const block_q4_0x8 &A = x[b0 * cols_scales + c];
     458              :         const block_q4_0x8 &B = x[b1 * cols_scales + c];
     459              : 
     460              :         unsigned short *__restrict dT_c = dT + c * N;
     461              :         unsigned short *__restrict qsT_c0 = qsT + (c * 8) * N;
     462              : 
     463              :         // scales: pack two 8×u16 vectors → one 256b store to dT[c, r0..r0+15]
     464              :         __m128i sd0 = _mm_loadu_si128((const __m128i *)A.d);
     465              :         __m128i sd1 = _mm_loadu_si128((const __m128i *)B.d);
     466              :         __m256i sdp = _mm256_set_m128i(sd1, sd0);
     467              :         store256_u16(dT_c + r0, sdp);
     468              : 
     469              :         // pre-split stripes
     470              :         const unsigned char *__restrict A0 = A.qs;      // + 8*off
     471              :         const unsigned char *__restrict A1 = A.qs + 64; // + 8*off
     472              :         const unsigned char *__restrict B0 = B.qs;
     473              :         const unsigned char *__restrict B1 = B.qs + 64;
     474              : 
     475              :         // build 8 rows for A and 8 rows for B
     476              :         __m128i Ra[8], Rb[8];
     477              :         for (int off = 0; off < 8; ++off) {
     478              :           Ra[off] = pack_row8(A0, A1, off);
     479              :           Rb[off] = pack_row8(B0, B1, off);
     480              :         }
     481              : 
     482              :         // 8×8 transpose → columns (each 8×u16) for A and B
     483              :         __m128i Ca0, Ca1, Ca2, Ca3, Ca4, Ca5, Ca6, Ca7;
     484              :         __m128i Cb0, Cb1, Cb2, Cb3, Cb4, Cb5, Cb6, Cb7;
     485              :         transpose8x8_epi16(Ra[0], Ra[1], Ra[2], Ra[3], Ra[4], Ra[5], Ra[6],
     486              :                            Ra[7], Ca0, Ca1, Ca2, Ca3, Ca4, Ca5, Ca6, Ca7);
     487              :         transpose8x8_epi16(Rb[0], Rb[1], Rb[2], Rb[3], Rb[4], Rb[5], Rb[6],
     488              :                            Rb[7], Cb0, Cb1, Cb2, Cb3, Cb4, Cb5, Cb6, Cb7);
     489              : 
     490              :         // pair and store 32B per column t: rows r0..r0+15 are contiguous
     491              :         unsigned short *__restrict base = qsT_c0 + r0;
     492              :         const int S = N;
     493              :         store256_u16(base + 0 * S, _mm256_set_m128i(Cb0, Ca0));
     494              :         store256_u16(base + 1 * S, _mm256_set_m128i(Cb1, Ca1));
     495              :         store256_u16(base + 2 * S, _mm256_set_m128i(Cb2, Ca2));
     496              :         store256_u16(base + 3 * S, _mm256_set_m128i(Cb3, Ca3));
     497              :         store256_u16(base + 4 * S, _mm256_set_m128i(Cb4, Ca4));
     498              :         store256_u16(base + 5 * S, _mm256_set_m128i(Cb5, Ca5));
     499              :         store256_u16(base + 6 * S, _mm256_set_m128i(Cb6, Ca6));
     500              :         store256_u16(base + 7 * S, _mm256_set_m128i(Cb7, Ca7));
     501              :       }
     502              :     }
     503              :   }
     504              : 
     505              :   // -------- tail: if odd number of 8-row groups, process the last one (8 rows)
     506              :   // --------
     507            0 :   if (groups_N8 & 1) {
     508            0 :     const int b = groups_N8 - 1;
     509            0 :     const int r0 = b * 8;
     510              : 
     511            0 : #pragma omp parallel for schedule(static)
     512              :     for (int c0 = 0; c0 < cols_scales; c0 += CT) {
     513              :       const int c1 = std::min(c0 + CT, cols_scales);
     514              :       for (int c = c0; c < c1; ++c) {
     515              :         const block_q4_0x8 &A = x[b * cols_scales + c];
     516              :         unsigned short *__restrict dT_c = dT + c * N;
     517              :         unsigned short *__restrict qsT_c0 = qsT + (c * 8) * N;
     518              : 
     519              :         // scales (8×u16)
     520              :         __m128i sd0 = _mm_loadu_si128((const __m128i *)A.d);
     521              :         _mm_storeu_si128((__m128i *)(dT_c + r0), sd0);
     522              : 
     523              :         const unsigned char *__restrict A0 = A.qs;
     524              :         const unsigned char *__restrict A1 = A.qs + 64;
     525              : 
     526              :         __m128i R[8];
     527              :         for (int off = 0; off < 8; ++off)
     528              :           R[off] = pack_row8(A0, A1, off);
     529              : 
     530              :         __m128i C0, C1, C2, C3, C4, C5, C6, C7;
     531              :         transpose8x8_epi16(R[0], R[1], R[2], R[3], R[4], R[5], R[6], R[7], C0,
     532              :                            C1, C2, C3, C4, C5, C6, C7);
     533              : 
     534              :         unsigned short *__restrict base = qsT_c0 + r0;
     535              :         const int S = N;
     536              :         _mm_storeu_si128((__m128i *)(base + 0 * S), C0);
     537              :         _mm_storeu_si128((__m128i *)(base + 1 * S), C1);
     538              :         _mm_storeu_si128((__m128i *)(base + 2 * S), C2);
     539              :         _mm_storeu_si128((__m128i *)(base + 3 * S), C3);
     540              :         _mm_storeu_si128((__m128i *)(base + 4 * S), C4);
     541              :         _mm_storeu_si128((__m128i *)(base + 5 * S), C5);
     542              :         _mm_storeu_si128((__m128i *)(base + 6 * S), C6);
     543              :         _mm_storeu_si128((__m128i *)(base + 7 * S), C7);
     544              :       }
     545              :     }
     546              :   }
     547              : 
      548              : #if USE_NONTEMPORAL_STORES
     549              :   _mm_sfence(); // ensure NT stores are globally visible before returning
     550              : #endif
     551            0 : }
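
    A usage sketch for the routine above (editorial; buffer sizes follow from
    the indexing inside the function, and CT = 64 is an arbitrary example
    tile):

        void unpack_example(const void *src, int N, int K) {
          std::vector<unsigned short> dT((size_t)K / 32 * N);      // K/32 x N scales
          std::vector<unsigned short> qsT((size_t)K / 32 * 8 * N); // (K/32)*8 x N quants
          unpack_q4_0x8_transpose16(src, dT.data(), qsT.data(), N, K, /*CT=*/64);
        }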
     552              : 
     553              : static inline __m256i butterfly32(__m256i a) {
     554              :   const __m256i SHUF_EVEN = _mm256_setr_epi8(
     555              :     0, 2, 4, 6, 8, 10, 12, 14, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
     556              :     (char)0x80, (char)0x80, (char)0x80, (char)0x80, 0, 2, 4, 6, 8, 10, 12, 14,
     557              :     (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
     558              :     (char)0x80, (char)0x80);
     559              :   const __m256i SHUF_ODD = _mm256_setr_epi8(
     560              :     1, 3, 5, 7, 9, 11, 13, 15, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
     561              :     (char)0x80, (char)0x80, (char)0x80, (char)0x80, 1, 3, 5, 7, 9, 11, 13, 15,
     562              :     (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
     563              :     (char)0x80, (char)0x80);
     564              :   const __m256i even = _mm256_shuffle_epi8(a, SHUF_EVEN);
     565              :   const __m256i odd = _mm256_shuffle_epi8(a, SHUF_ODD);
     566              :   const __m256i LO = _mm256_set1_epi8(0x0F);
     567              :   const __m256i HI = _mm256_set1_epi8((char)0xF0);
     568              :   __m256i low =
     569              :     _mm256_or_si256(_mm256_and_si256(even, LO),
     570              :                     _mm256_slli_epi16(_mm256_and_si256(odd, LO), 4));
     571              :   __m256i high =
     572              :     _mm256_or_si256(_mm256_srli_epi16(_mm256_and_si256(even, HI), 4),
     573              :                     _mm256_and_si256(odd, HI));
     574              :   high = _mm256_slli_si256(high, 8);
     575              :   return _mm256_or_si256(low, high);
     576              : }
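
    A scalar model of butterfly32 for one 16-byte lane (editorial sketch;
    butterfly16_ref is a hypothetical name): the low nibbles of byte pair
    (2k, 2k+1) land in output byte k, their high nibbles in output byte k + 8.

        void butterfly16_ref(const uint8_t in[16], uint8_t out[16]) {
          for (int k = 0; k < 8; ++k) {
            uint8_t e = in[2 * k], o = in[2 * k + 1];
            out[k] = (e & 0x0F) | uint8_t((o & 0x0F) << 4);     // low nibbles
            out[k + 8] = uint8_t((e & 0xF0) >> 4) | (o & 0xF0); // high nibbles
          }
        }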
     577              : 
     578              : // Build 16B packet [d0|d1] from two 8B chunks using vector loads (no GPR
     579              : // moves).
     580              : static inline __m128i make_pkt128(const uint8_t *base_qs, int d0, int d1) {
     581            0 :   __m128i lo = _mm_loadl_epi64((const __m128i *)(base_qs + ((size_t)d0 << 3)));
     582            0 :   __m128i hi = _mm_loadl_epi64((const __m128i *)(base_qs + ((size_t)d1 << 3)));
     583              :   return _mm_unpacklo_epi64(lo, hi);
     584              : }
     585              : 
     586              : // ================== core template with QS unrolled by 8 blocks
     587              : // ==================
     588              : template <int UNIT, int GROUPS>
     589              : static inline void convert_q4_0x8_noshuffle(const void *src,
     590              :                                             uint16_t *RESTRICT d_out,
     591              :                                             uint8_t *RESTRICT qs_out) {
     592              :   static_assert(UNIT % 16 == 0, "UNIT must be multiple of 16");
     593              :   constexpr int BLOCKS_PER_GROUP = UNIT / 8;  // d entries per offset per group
     594              :   constexpr int PAIRS_PER_OFFSET = UNIT / 16; // 16B packets per half per offset
     595              :   static_assert((PAIRS_PER_OFFSET % 4) == 0,
     596              :                 "need multiple of 4 packets (8 blocks) per iter");
     597              : 
     598              :   constexpr size_t D_ELEMS_PER_GROUP = 8 * BLOCKS_PER_GROUP;
     599              :   constexpr size_t QS_BYTES_PER_GROUP = (size_t)16 * UNIT;
     600              :   constexpr size_t QS_BYTES_PER_OFFSET = (size_t)2 * UNIT;
     601              : 
     602            0 :   const block_q4_0x8 *x = (const block_q4_0x8 *)src;
     603            0 :   const __m256i bias256 = _mm256_set1_epi8((char)0x88);
     604              : 
     605              : #ifdef _MSC_VER
     606              : #pragma warning(push)
     607              : #pragma warning(disable : 4849)
     608              : #endif
     609            0 : #pragma omp parallel for collapse(2) schedule(static)
     610              : #ifdef _MSC_VER
     611              : #pragma warning(pop)
     612              : #endif
     613              :   for (int b = 0; b < GROUPS; ++b) {
     614              :     for (int offset = 0; offset < 8; ++offset) {
     615              : 
     616              :       // ---- D slice ----
     617              :       {
     618              :         uint16_t *d_ptr = d_out + (size_t)b * D_ELEMS_PER_GROUP +
     619              :                           (size_t)offset * BLOCKS_PER_GROUP;
     620              :         const block_q4_0x8 *xb = x + (size_t)b * BLOCKS_PER_GROUP;
     621              :         for (int i = 0; i < BLOCKS_PER_GROUP; ++i) {
     622              :           d_ptr[i] = xb[i].d[offset];
     623              :         }
     624              :       }
     625              : 
     626              :       // ---- QS slice (unroll 8 blocks / 128B per iter) ----
     627              :       {
     628              :         uint8_t *qs_ptr = qs_out + (size_t)b * QS_BYTES_PER_GROUP +
     629              :                           (size_t)offset * QS_BYTES_PER_OFFSET;
     630              :         const int base_q = (b * UNIT * 2) + offset;
     631              :         const int d0 = (base_q & 15), d1 = d0 ^ 8;
     632              : 
     633            0 :         auto do_half = [&](int blk_base) {
     634              :           // Each iter handles 8 consecutive blocks: j..j+7
     635            0 :           for (int j = 0; j < PAIRS_PER_OFFSET; j += 8) {
     636            0 :             const uint8_t *q0 = x[blk_base + j + 0].qs;
     637            0 :             const uint8_t *q1 = x[blk_base + j + 1].qs;
     638            0 :             const uint8_t *q2 = x[blk_base + j + 2].qs;
     639            0 :             const uint8_t *q3 = x[blk_base + j + 3].qs;
     640            0 :             const uint8_t *q4 = x[blk_base + j + 4].qs;
     641            0 :             const uint8_t *q5 = x[blk_base + j + 5].qs;
     642            0 :             const uint8_t *q6 = x[blk_base + j + 6].qs;
     643            0 :             const uint8_t *q7 = x[blk_base + j + 7].qs;
     644              : 
     645              : #if Q4X8_PREFETCH_DIST > 0
     646              :             _mm_prefetch(
     647              :               (const char *)(x[blk_base + j + Q4X8_PREFETCH_DIST].qs),
     648              :               _MM_HINT_NTA);
     649              : #endif
     650              :             // Build 8 packets in XMM regs
     651            0 :             __m128i pkt0 = make_pkt128(q0, d0, d1);
     652              :             __m128i pkt1 = make_pkt128(q1, d0, d1);
     653              :             __m128i pkt2 = make_pkt128(q2, d0, d1);
     654              :             __m128i pkt3 = make_pkt128(q3, d0, d1);
     655              :             __m128i pkt4 = make_pkt128(q4, d0, d1);
     656              :             __m128i pkt5 = make_pkt128(q5, d0, d1);
     657              :             __m128i pkt6 = make_pkt128(q6, d0, d1);
     658              :             __m128i pkt7 = make_pkt128(q7, d0, d1);
     659              : 
     660              :             // Four 32B batches: [0|1], [2|3], [4|5], [6|7]
     661              :             __m256i v01 = _mm256_set_m128i(pkt1, pkt0);
     662              :             __m256i v23 = _mm256_set_m128i(pkt3, pkt2);
     663              :             __m256i v45 = _mm256_set_m128i(pkt5, pkt4);
     664              :             __m256i v67 = _mm256_set_m128i(pkt7, pkt6);
     665              : 
     666            0 :             v01 = _mm256_xor_si256(v01, bias256);
     667              :             v23 = _mm256_xor_si256(v23, bias256);
     668              :             v45 = _mm256_xor_si256(v45, bias256);
     669              :             v67 = _mm256_xor_si256(v67, bias256);
     670              : 
     671              :             __m256i o01 = butterfly32(v01);
     672              :             __m256i o23 = butterfly32(v23);
     673              :             __m256i o45 = butterfly32(v45);
     674              :             __m256i o67 = butterfly32(v67);
     675              : 
     676              : #if Q4X8_USE_STREAMING_STORES
     677              :             _mm256_stream_si256((__m256i *)(qs_ptr + 0), o01);
     678              :             _mm256_stream_si256((__m256i *)(qs_ptr + 32), o23);
     679              :             _mm256_stream_si256((__m256i *)(qs_ptr + 64), o45);
     680              :             _mm256_stream_si256((__m256i *)(qs_ptr + 96), o67);
     681              : #else
     682            0 :             _mm256_storeu_si256((__m256i *)(qs_ptr + 0), o01);
     683            0 :             _mm256_storeu_si256((__m256i *)(qs_ptr + 32), o23);
     684            0 :             _mm256_storeu_si256((__m256i *)(qs_ptr + 64), o45);
     685            0 :             _mm256_storeu_si256((__m256i *)(qs_ptr + 96), o67);
     686              : #endif
     687            0 :             qs_ptr += 128;
     688              :           }
     689              :         };
     690              : 
     691              :         // first half
     692              :         do_half(base_q >> 4);
     693              :         // second half (same d0/d1 pattern)
     694              :         do_half((base_q + UNIT) >> 4);
     695              :       }
     696              :     }
     697              :   }
     698              : 
     699              : #if Q4X8_USE_STREAMING_STORES
     700              :   _mm_sfence();
     701              : #endif
     702              : }
     703              : 
      704              : // ============== wrappers for the supported K,N combinations ==============
     705              : // K = 3072 (UNIT = 768)
     706            0 : void convert_q4_0x8_shuffle_K3072_N98304(const void *src, uint16_t *d_out,
     707              :                                          uint8_t *qs_out) {
     708              :   // groups = (N*8)/UNIT = 1024
     709              :   convert_q4_0x8_noshuffle<768, 1024>(src, d_out, qs_out);
     710            0 : }
     711            0 : void convert_q4_0x8_shuffle_K3072_N36864(const void *src, uint16_t *d_out,
     712              :                                          uint8_t *qs_out) {
     713              :   // groups = 384
     714              :   convert_q4_0x8_noshuffle<768, 384>(src, d_out, qs_out);
     715            0 : }
     716            0 : void convert_q4_0x8_shuffle_K3072_N3072(const void *src, uint16_t *d_out,
     717              :                                         uint8_t *qs_out) {
     718              :   // groups = 32
     719              :   convert_q4_0x8_noshuffle<768, 32>(src, d_out, qs_out);
     720            0 : }
     721              : 
     722              : // K = 8192 (UNIT = 2048)
     723            0 : void convert_q4_0x8_shuffle_K8192_N98304(const void *src, uint16_t *d_out,
     724              :                                          uint8_t *qs_out) {
     725              :   // groups = 384
     726              :   convert_q4_0x8_noshuffle<2048, 384>(src, d_out, qs_out);
     727            0 : }
     728            0 : void convert_q4_0x8_shuffle_K8192_N36864(const void *src, uint16_t *d_out,
     729              :                                          uint8_t *qs_out) {
     730              :   // groups = 144
     731              :   convert_q4_0x8_noshuffle<2048, 144>(src, d_out, qs_out);
     732            0 : }
     733            0 : void convert_q4_0x8_shuffle_K8192_N3072(const void *src, uint16_t *d_out,
     734              :                                         uint8_t *qs_out) {
     735              :   // groups = 12
     736              :   convert_q4_0x8_noshuffle<2048, 12>(src, d_out, qs_out);
     737            0 : }
     738              : 
      739              : // Optional dispatcher providing a single entry point:
     740            0 : void convert_q4_0x8_shuffle_dispatch_avx(const void *src, uint16_t *d_out,
     741              :                                          uint8_t *qs_out, int N, int K) {
     742            0 :   if (K == 3072) {
     743            0 :     if (N == 98304)
     744            0 :       return convert_q4_0x8_shuffle_K3072_N98304(src, d_out, qs_out);
     745            0 :     if (N == 36864)
     746            0 :       return convert_q4_0x8_shuffle_K3072_N36864(src, d_out, qs_out);
     747            0 :     if (N == 3072)
     748            0 :       return convert_q4_0x8_shuffle_K3072_N3072(src, d_out, qs_out);
     749              :   } else { // K == 8192
     750            0 :     if (N == 98304)
     751            0 :       return convert_q4_0x8_shuffle_K8192_N98304(src, d_out, qs_out);
     752            0 :     if (N == 36864)
     753            0 :       return convert_q4_0x8_shuffle_K8192_N36864(src, d_out, qs_out);
     754            0 :     if (N == 3072)
     755            0 :       return convert_q4_0x8_shuffle_K8192_N3072(src, d_out, qs_out);
     756              :   }
     757              :   // If a new combo appears, fall back to a generic version (not shown here).
     758            0 :   assert(!"Unsupported (K,N) combination");
     759              : }
     760              : 
     761           12 : bool is_valid(const unsigned int N, const float *input) {
     762           12 :   assert(N != 0);
     763           12 :   assert(input != NULL);
     764              : 
     765              :   int temp = 0;
     766              :   unsigned int idx = 0;
     767              : 
     768              :   const __m256 SIGN_MASK = _mm256_set1_ps(-0.0);
     769              :   const __m256 INF = _mm256_set1_ps(std::numeric_limits<float>::infinity());
     770              : 
      771              :   // NaN check, 16 floats per iteration: (X != X)
     772           15 :   for (; N - idx >= 16; idx += 16) {
     773              :     __m256 vec0 = _mm256_loadu_ps(input);
     774              :     __m256 vec1 = _mm256_loadu_ps(input + 8);
     775            6 :     input += 16;
     776              :     __m256 res = _mm256_cmp_ps(vec0, vec0, _CMP_NEQ_UQ);
     777            6 :     temp = temp | _mm256_movemask_ps(res);
     778              : 
     779            6 :     if (temp)
     780              :       return false;
     781              : 
     782              :     // check infinity in vec0
     783              :     vec0 = _mm256_andnot_ps(SIGN_MASK, vec0);
     784              :     vec0 = _mm256_cmp_ps(vec0, INF, _CMP_EQ_OQ);
     785              : 
     786            5 :     temp = temp | _mm256_movemask_ps(vec0);
     787            5 :     if (temp)
     788              :       return false;
     789              : 
     790              :     __m256 res1 = _mm256_cmp_ps(vec1, vec1, _CMP_NEQ_UQ);
     791            3 :     temp = temp | _mm256_movemask_ps(res1);
     792              : 
     793            3 :     if (temp)
     794              :       return false;
     795              : 
     796              :     // check infinity in vec1
     797              :     vec1 = _mm256_andnot_ps(SIGN_MASK, vec1);
     798              :     vec1 = _mm256_cmp_ps(vec1, INF, _CMP_EQ_OQ);
     799              : 
     800            3 :     temp = temp | _mm256_movemask_ps(vec1);
     801              : 
     802            3 :     if (temp)
     803              :       return false;
     804              :   }
     805              : 
      806              :   // NaN check, 8 floats per iteration: (X != X)
     807           11 :   for (; N - idx >= 8; idx += 8) {
     808              :     __m256 vec = _mm256_loadu_ps(input);
     809            5 :     input += 8;
     810              :     __m256 res = _mm256_cmp_ps(vec, vec, _CMP_NEQ_UQ);
     811            5 :     temp = temp | _mm256_movemask_ps(res);
     812              : 
     813            5 :     if (temp)
     814              :       return false;
     815              : 
     816              :     // check infinity in vec
     817              :     vec = _mm256_andnot_ps(SIGN_MASK, vec);
     818              :     vec = _mm256_cmp_ps(vec, INF, _CMP_EQ_OQ);
     819              : 
     820            4 :     temp = temp | _mm256_movemask_ps(vec);
     821              : 
     822            4 :     if (temp)
     823              :       return false;
     824              :   }
     825              : 
     826           11 :   while (idx < N) {
     827            8 :     if (!isFloatValid(*input)) {
     828              :       return false;
     829              :     }
     830            5 :     ++input;
     831            5 :     ++idx;
     832              :   }
     833              : 
     834              :   return true;
     835              : }
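
    Scalar semantics of is_valid, for reference (editorial sketch; is_valid_ref
    is a hypothetical name): return false iff any element is NaN (X != X) or
    +/-Inf (|X| == Inf).

        bool is_valid_ref(unsigned int N, const float *p) {
          for (unsigned int i = 0; i < N; ++i)
            if (std::isnan(p[i]) || std::isinf(p[i]))
              return false;
          return true;
        }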
     836              : 
     837       173980 : void custom_scopy(const unsigned int N, const float *X, const int incX,
     838              :                   float *Y, const int incY) {
     839       173980 :   unsigned int N8 = (N >> 3) << 3;
     840       957465 :   for (unsigned int i = 0; i < N8; i += 8) {
     841              : #if defined(_WIN32)
     842              :     __m256 temp = _mm256_loadu_ps(&X[i]);
     843              :     _mm256_storeu_ps(&Y[i], temp);
     844              : #else
     845       783485 :     __asm__ __volatile__("vmovups (%1), %%ymm0\n\t"
     846              :                          "vmovups %%ymm0, (%0)\n\t"
     847              :                          :
     848       783485 :                          : "r"(&Y[i]), "r"(&X[i])
     849              :                          : "ymm0", "memory");
     850              : #endif
     851              :   }
     852       499415 :   for (unsigned int i = N8; i < N; ++i) {
     853       325435 :     Y[i] = X[i];
     854              :   }
     855       173980 : }
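
    Note that custom_scopy ignores incX/incY and assumes unit strides. The
    inline asm above is just an unaligned 32-byte copy, equivalent to the
    intrinsic pair used on the _WIN32 path:

        _mm256_storeu_ps(&Y[i], _mm256_loadu_ps(&X[i]));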
     856              : 
     857           25 : void transpose_matrix(const unsigned int M, const unsigned int N,
     858              :                       const float *src, unsigned int ld_src, float *dst,
     859              :                       unsigned int ld_dst) {
     860           25 :   unsigned int vindexm[8] = {0,          ld_src,     ld_src * 2, ld_src * 3,
     861           25 :                              ld_src * 4, ld_src * 5, ld_src * 6, ld_src * 7};
     862              :   __m256i vindex = _mm256_loadu_si256((__m256i *)&vindexm[0]);
     863              :   __m256 vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8;
     864              : 
     865           25 :   unsigned int M8 = (M & ~(7));
     866           25 :   unsigned int N8 = (N & ~(7));
     867         1378 :   for (unsigned int i = 0; i < M8; i += 8) {
     868       308626 :     for (unsigned int j = 0; j < N8; j += 8) {
     869              :       // loading from columns
     870       307273 :       vec1 = _mm256_i32gather_ps(&src[ld_src * i + j + 0], vindex, 4);
     871       307273 :       vec2 = _mm256_i32gather_ps(&src[ld_src * i + j + 1], vindex, 4);
     872       307273 :       vec3 = _mm256_i32gather_ps(&src[ld_src * i + j + 2], vindex, 4);
     873       307273 :       vec4 = _mm256_i32gather_ps(&src[ld_src * i + j + 3], vindex, 4);
     874       307273 :       vec5 = _mm256_i32gather_ps(&src[ld_src * i + j + 4], vindex, 4);
     875       307273 :       vec6 = _mm256_i32gather_ps(&src[ld_src * i + j + 5], vindex, 4);
     876       307273 :       vec7 = _mm256_i32gather_ps(&src[ld_src * i + j + 6], vindex, 4);
     877       307273 :       vec8 = _mm256_i32gather_ps(&src[ld_src * i + j + 7], vindex, 4);
     878              : 
     879              :       // storing to the rows
     880       307273 :       _mm256_storeu_ps(&dst[(j + 0) * ld_dst + i], vec1);
     881       307273 :       _mm256_storeu_ps(&dst[(j + 1) * ld_dst + i], vec2);
     882       307273 :       _mm256_storeu_ps(&dst[(j + 2) * ld_dst + i], vec3);
     883       307273 :       _mm256_storeu_ps(&dst[(j + 3) * ld_dst + i], vec4);
     884       307273 :       _mm256_storeu_ps(&dst[(j + 4) * ld_dst + i], vec5);
     885       307273 :       _mm256_storeu_ps(&dst[(j + 5) * ld_dst + i], vec6);
     886       307273 :       _mm256_storeu_ps(&dst[(j + 6) * ld_dst + i], vec7);
     887       307273 :       _mm256_storeu_ps(&dst[(j + 7) * ld_dst + i], vec8);
     888              :     }
     889              :   }
     890              : 
      891              :   // tail: rightmost columns (j >= N8)
     892        10910 :   for (unsigned int i = 0; i < M; i++) {
     893        12316 :     for (unsigned int j = N8; j < N; j++) {
     894         1431 :       dst[i + j * ld_dst] = src[i * ld_src + j];
     895              :     }
     896              :   }
     897              : 
      898              :   // tail: bottom rows (i >= M8)
     899           86 :   for (unsigned int i = M8; i < M; i++) {
     900          356 :     for (unsigned int j = 0; j < N; j++) {
     901          295 :       dst[i + j * ld_dst] = src[i * ld_src + j];
     902              :     }
     903              :   }
     904           25 : }
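
    The scalar loop this routine vectorizes with 8-wide gathers (editorial
    sketch; transpose_ref is a hypothetical name). src is M x N with leading
    dimension ld_src; dst receives the N x M transpose with leading dimension
    ld_dst:

        void transpose_ref(unsigned M, unsigned N, const float *src,
                           unsigned ld_src, float *dst, unsigned ld_dst) {
          for (unsigned i = 0; i < M; ++i)
            for (unsigned j = 0; j < N; ++j)
              dst[j * ld_dst + i] = src[i * ld_src + j];
        }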
     905              : 
     906            0 : void swiglu(const unsigned int N, float *X, const float *Y, const float *Z) {
     907              :   size_t i = 0;
     908              : 
     909              :   const auto oldcsr = _mm_getcsr();
     910            0 :   _mm_setcsr(oldcsr | 0x8040); // DAZ | FTZ
     911              : 
     912              :   // 16-wide blocks
     913            0 :   for (; i + 16 <= N; i += 16) {
     914            0 :     const __m256 y0 = _mm256_loadu_ps(Y + i);
     915            0 :     const __m256 y1 = _mm256_loadu_ps(Y + i + 8);
     916            0 :     const __m256 z0 = _mm256_loadu_ps(Z + i);
     917            0 :     const __m256 z1 = _mm256_loadu_ps(Z + i + 8);
     918              : 
     919            0 :     _mm256_storeu_ps(X + i, avx2_approx_swiglu(y0, z0));
     920            0 :     _mm256_storeu_ps(X + i + 8, avx2_approx_swiglu(y1, z1));
     921              :   }
     922              : 
     923              :   // One 8-wide block if available
     924            0 :   if (i + 8 <= N) {
     925            0 :     const __m256 y0 = _mm256_loadu_ps(Y + i);
     926            0 :     const __m256 z0 = _mm256_loadu_ps(Z + i);
     927            0 :     _mm256_storeu_ps(X + i, avx2_approx_swiglu(y0, z0));
     928              :     i += 8;
     929              :   }
     930              : 
     931              :   // Remaining 1..7 elements via maskload/maskstore
     932            0 :   if (i < N) {
     933            0 :     const int remain = static_cast<int>(N - i); // 1..7
     934              : 
     935            0 :     alignas(64) const int mtab[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
     936              :                                       0,  0,  0,  0,  0,  0,  0,  0};
      937              :     // Offset into mtab so the first `remain` lanes are -1 (active), rest 0.
      938            0 :     const int off = 8 - remain; // in [1..7]
     939            0 :     const __m256i vmask = _mm256_loadu_si256((const __m256i *)(mtab + off));
     940              : 
     941            0 :     const __m256 y = _mm256_maskload_ps(Y + i, vmask);
     942            0 :     const __m256 z = _mm256_maskload_ps(Z + i, vmask);
     943              :     const __m256 r = avx2_approx_swiglu(y, z);
     944            0 :     _mm256_maskstore_ps(X + i, vmask, r);
     945              :   }
     946              : 
     947              :   _mm_setcsr(oldcsr);
     948            0 : }
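
Two ideas above are worth isolating. First, the gate itself: a scalar reference, assuming avx2_approx_swiglu (defined earlier in this file) approximates the usual SiLU gating y * sigmoid(y) * z. Second, the mask-table trick: sliding an 8-wide load across a table of eight -1s followed by eight 0s turns a remainder count into an AVX2 lane mask without branches. A sketch under those assumptions:

    #include <cmath>
    #include <immintrin.h>

    // Scalar reference, assuming swiglu(y, z) = y * sigmoid(y) * z.
    static float swiglu_ref(float y, float z) {
      return (y / (1.0f + std::exp(-y))) * z;
    }

    // Loading 8 ints at offset (8 - remain) yields exactly `remain`
    // active lanes (-1) for maskload/maskstore.
    static __m256i tail_mask(int remain /* 1..7 */) {
      alignas(64) static const int mtab[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
                                               0,  0,  0,  0,  0,  0,  0,  0};
      return _mm256_loadu_si256(
        reinterpret_cast<const __m256i *>(mtab + (8 - remain)));
    }
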
     949              : 
     950            0 : void swiglu(const unsigned int N, float *X, const float *Y, const float *Z,
     951              :             float alpha) {
     952              :   size_t i = 0;
     953              : 
     954              :   const auto oldcsr = _mm_getcsr();
     955            0 :   _mm_setcsr(oldcsr | 0x8040); // DAZ | FTZ
     956              : 
     957              :   const __m256 alpha_vec = _mm256_set1_ps(alpha);
     958              : 
     959              :   // 16-wide blocks
     960            0 :   for (; i + 16 <= N; i += 16) {
     961            0 :     const __m256 y0 = _mm256_loadu_ps(Y + i);
     962            0 :     const __m256 y1 = _mm256_loadu_ps(Y + i + 8);
     963            0 :     const __m256 z0 = _mm256_loadu_ps(Z + i);
     964            0 :     const __m256 z1 = _mm256_loadu_ps(Z + i + 8);
     965              : 
     966            0 :     _mm256_storeu_ps(X + i, avx2_approx_swiglu_alpha(y0, z0, alpha_vec));
     967            0 :     _mm256_storeu_ps(X + i + 8, avx2_approx_swiglu_alpha(y1, z1, alpha_vec));
     968              :   }
     969              : 
     970              :   // One 8-wide block if present
     971            0 :   if (i + 8 <= N) {
     972            0 :     const __m256 y0 = _mm256_loadu_ps(Y + i);
     973            0 :     const __m256 z0 = _mm256_loadu_ps(Z + i);
     974            0 :     _mm256_storeu_ps(X + i, avx2_approx_swiglu_alpha(y0, z0, alpha_vec));
     975              :     i += 8;
     976              :   }
     977              : 
     978              :   // Remaining 1..7 elements via masked AVX (no stray stores)
     979            0 :   if (i < N) {
     980            0 :     const int remain = static_cast<int>(N - i); // 1..7
     981              : 
     982            0 :     alignas(64) const int mtab[16] = {
     983              :       -1, -1, -1, -1, -1, -1, -1, -1, // ones
     984              :       0,  0,  0,  0,  0,  0,  0,  0   // zeros
     985              :     };
     986            0 :     const int off = 8 - remain; // choose first `remain` lanes active
     987            0 :     const __m256i vmask = _mm256_loadu_si256((const __m256i *)(mtab + off));
     988              : 
     989            0 :     const __m256 y = _mm256_maskload_ps(Y + i, vmask);
     990            0 :     const __m256 z = _mm256_maskload_ps(Z + i, vmask);
     991              :     const __m256 r = avx2_approx_swiglu_alpha(y, z, alpha_vec);
     992            0 :     _mm256_maskstore_ps(X + i, vmask, r);
     993              :   }
     994              : 
     995              :   _mm_setcsr(oldcsr);
     996            0 : }
     997              : 
     998        29926 : void ele_mul(const unsigned int N, const float *X, const float *Y, float *Z,
     999              :              float alpha, float beta, unsigned int i_stride,
    1000              :              unsigned int o_stride) {
    1001        29926 :   if (alpha == 1.0f && beta == 0.0f && o_stride == 1) {
    1002        29581 :     unsigned int N8 = (N & ~(7));
    1003        29581 :     if (i_stride == 0) {
     1004              :       // broadcast the scalar Y[0] across all eight lanes
     1005         9591 :       auto y = _mm256_set1_ps(Y[0]);
    1006        18108 :       for (unsigned int i = 0; i < N8; i += 8) {
    1007              :         auto x = _mm256_loadu_ps(X);
    1008              :         auto z = _mm256_mul_ps(x, y);
    1009              :         _mm256_storeu_ps(Z, z);
    1010         8517 :         X += 8;
    1011         8517 :         Y += i_stride * 8;
    1012         8517 :         Z += 8;
    1013              :       }
    1014              :     } else {
    1015     13214695 :       for (unsigned int i = 0; i < N8; i += 8) {
    1016              :         auto x = _mm256_loadu_ps(X);
    1017              :         auto y = _mm256_loadu_ps(Y);
    1018              :         auto z = _mm256_mul_ps(x, y);
    1019              :         _mm256_storeu_ps(Z, z);
    1020     13194705 :         X += 8;
    1021     13194705 :         Y += i_stride * 8;
    1022     13194705 :         Z += 8;
    1023              :       }
    1024              :     }
    1025       138986 :     for (unsigned int i = N8; i < N; ++i) {
    1026       109405 :       *Z = *X * *Y;
    1027       109405 :       X++;
    1028       109405 :       Y += i_stride;
    1029       109405 :       Z++;
    1030              :     }
    1031              :   } else {
    1032              :     // TODO: AVX2 implementation if used
    1033        65505 :     for (unsigned int i = 0; i < N; ++i) {
    1034        65160 :       *Z = *X * alpha * *Y + ((0.0f == beta) ? 0.0f : beta * *Z);
    1035        65160 :       X += o_stride;
    1036        65160 :       Y += i_stride;
    1037        65160 :       Z += o_stride;
    1038              :     }
    1039              :   }
    1040        29926 : }
    1041              : 
    1042       113033 : void ele_add(const unsigned int N, const float *X, const float *Y, float *Z,
    1043              :              float alpha, float beta, unsigned int i_stride,
    1044              :              unsigned int o_stride) {
    1045       113033 :   if (alpha == 1.0f && beta == 0.0f && o_stride == 1) {
    1046        75820 :     unsigned int N8 = (N & ~(7));
    1047        75820 :     if (i_stride == 0) {
     1048              :       // broadcast the scalar Y[0] across all eight lanes
     1049        26375 :       auto y = _mm256_set1_ps(Y[0]);
    1050       580391 :       for (unsigned int i = 0; i < N8; i += 8) {
    1051              :         auto x = _mm256_loadu_ps(X);
    1052              :         auto z = _mm256_add_ps(x, y);
    1053              :         _mm256_storeu_ps(Z, z);
    1054       554016 :         X += 8;
    1055       554016 :         Y += i_stride * 8;
    1056       554016 :         Z += 8;
    1057              :       }
    1058              :     } else {
    1059       169478 :       for (unsigned int i = 0; i < N8; i += 8) {
    1060              :         auto x = _mm256_loadu_ps(X);
    1061              :         auto y = _mm256_loadu_ps(Y);
    1062              :         auto z = _mm256_add_ps(x, y);
    1063              :         _mm256_storeu_ps(Z, z);
    1064       120033 :         X += 8;
    1065       120033 :         Y += i_stride * 8;
    1066       120033 :         Z += 8;
    1067              :       }
    1068              :     }
    1069       234632 :     for (unsigned int i = N8; i < N; ++i) {
    1070       158812 :       *Z = *X + *Y;
    1071       158812 :       X++;
    1072       158812 :       Y += i_stride;
    1073       158812 :       Z++;
    1074              :     }
    1075              :   } else {
    1076              :     // TODO: AVX2 implementation if used
    1077    203432843 :     for (unsigned int i = 0; i < N; ++i) {
    1078    203395630 :       *Z = *X + alpha * *Y + ((0.0f == beta) ? 0.0f : beta * *Z);
    1079    203395630 :       X += o_stride;
    1080    203395630 :       Y += i_stride;
    1081    203395630 :       Z += o_stride;
    1082              :     }
    1083              :   }
    1084       113033 : }
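
Both fast paths above specialize alpha == 1, beta == 0, o_stride == 1; the general contract, which the scalar fallbacks implement, reads more clearly in scalar form. A sketch (ele_mul is identical with the first + replaced by *):

    // i_stride == 0 broadcasts Y[0] over the whole operation.
    void ele_add_ref(unsigned N, const float *X, const float *Y, float *Z,
                     float alpha, float beta, unsigned i_stride,
                     unsigned o_stride) {
      for (unsigned i = 0; i < N; ++i)
        Z[i * o_stride] = X[i * o_stride] + alpha * Y[i * i_stride] +
                          (beta == 0.0f ? 0.0f : beta * Z[i * o_stride]);
    }
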
    1085              : 
    1086            9 : static inline __m256 exp256_ps(__m256 x) {
    1087              :   /*  Low-Precision Version I*/
    1088              :   // const __m256 c1 = _mm256_set1_ps(12102203.0f);
    1089              :   // const __m256 c2 = _mm256_set1_ps(1065353216.0f);
    1090              :   // __m256 fx = _mm256_add_ps(_mm256_mul_ps(x, c1),c2);
    1091              :   // return _mm256_castsi256_ps(_mm256_cvtps_epi32(fx));
    1092              : 
    1093              :   /* Low-Precision Version II*/
    1094              :   /*    const __m256 ln2 = _mm256_set1_ps(0.69314718056f);
    1095              :     const __m256 inv_ln2 = _mm256_set1_ps(1.44269504089f); // 1 / ln(2)
    1096              : 
    1097              :     // Range reduction: x = n * ln2 + r,  where n is integer and |r| <= ln2/2
    1098              :     __m256 fx = _mm256_mul_ps(x, inv_ln2);
    1099              :     fx = _mm256_round_ps(fx, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    1100              :     __m256i emm0 = _mm256_cvtps_epi32(fx);
    1101              : 
    1102              :     __m256 tmp = _mm256_mul_ps(fx, ln2);
    1103              :     __m256 r = _mm256_sub_ps(x, tmp);
    1104              : 
    1105              :     // Compute polynomial approximation of exp(r)
    1106              :     const __m256 c1 = _mm256_set1_ps(1.9875691500E-4f);
    1107              :     const __m256 c2 = _mm256_set1_ps(1.3981999507E-3f);
    1108              :     const __m256 c3 = _mm256_set1_ps(8.3334519073E-3f);
    1109              :     const __m256 c4 = _mm256_set1_ps(4.1665795894E-2f);
    1110              :     const __m256 c5 = _mm256_set1_ps(1.6666665459E-1f);
    1111              :     const __m256 c6 = _mm256_set1_ps(5.0000001201E-1f);
    1112              : 
    1113              :   //  __m256 r2 = _mm256_mul_ps(r, r);
    1114              :   //  __m256 r3 = _mm256_mul_ps(r2, r);
    1115              :   //  __m256 r4 = _mm256_mul_ps(r2, r2);
    1116              : 
    1117              :     __m256 y = _mm256_fmadd_ps(c1, r, c2);
    1118              :     y = _mm256_fmadd_ps(y, r, c3);
    1119              :     y = _mm256_fmadd_ps(y, r, c4);
    1120              :     y = _mm256_fmadd_ps(y, r, c5);
    1121              :     y = _mm256_fmadd_ps(y, r, c6);
    1122              :     y = _mm256_fmadd_ps(y, r, _mm256_set1_ps(1.0f));
    1123              : 
    1124              :     // Reconstruct exp(x) = 2^n * exp(r)
    1125              :     emm0 = _mm256_add_epi32(emm0, _mm256_set1_epi32(127));
    1126              :     emm0 = _mm256_slli_epi32(emm0, 23);
    1127              :     __m256 pow2n = _mm256_castsi256_ps(emm0);
    1128              : 
    1129              :     return _mm256_mul_ps(y, pow2n);
    1130              :   */
     1131              :   /* Low-Precision Version III */
    1132              :   const __m256 LOG2EF = _mm256_set1_ps(1.44269504088896341f); // 1 / ln(2)
    1133              :   const __m256 LN2 = _mm256_set1_ps(0.6931471805599453f);     // ln(2)
    1134              : 
    1135              :   // Clamp input to range to prevent overflow/underflow
    1136              :   const __m256 max_x = _mm256_set1_ps(88.3762626647949f);  // log(FLT_MAX)
     1137              :   const __m256 min_x = _mm256_set1_ps(-88.3762626647949f); // -log(FLT_MAX)
    1138              :   x = _mm256_max_ps(min_x, _mm256_min_ps(max_x, x));
    1139              : 
    1140              :   // Range reduction: x = n * ln2 + r
    1141              :   __m256 fx = _mm256_mul_ps(x, LOG2EF); // x * (1/ln(2))
    1142              :   fx = _mm256_floor_ps(_mm256_add_ps(fx, _mm256_set1_ps(0.5f)));
    1143              : 
    1144              :   __m256 tmp = _mm256_mul_ps(fx, LN2); // n * ln(2)
    1145              :   __m256 r = _mm256_sub_ps(x, tmp);    // r = x - n * ln2
    1146              : 
     1147              :   // Compute exp(r) via its degree-10 Taylor polynomial (Horner's method)
    1148              :   const __m256 c0 = _mm256_set1_ps(1.0f);
    1149              :   const __m256 c1 = _mm256_set1_ps(1.0f);
    1150              :   const __m256 c2 = _mm256_set1_ps(0.5f);
    1151              :   const __m256 c3 = _mm256_set1_ps(1.0f / 6.0f);
    1152              :   const __m256 c4 = _mm256_set1_ps(1.0f / 24.0f);
    1153              :   const __m256 c5 = _mm256_set1_ps(1.0f / 120.0f);
    1154              :   const __m256 c6 = _mm256_set1_ps(1.0f / 720.0f);
    1155              :   const __m256 c7 = _mm256_set1_ps(1.0f / 5040.0f);
    1156              :   const __m256 c8 = _mm256_set1_ps(1.0f / 40320.0f);
    1157              :   const __m256 c9 = _mm256_set1_ps(1.0f / 362880.0f);
    1158              :   const __m256 c10 = _mm256_set1_ps(1.0f / 3628800.0f);
    1159              : 
    1160              :   __m256 y = c10;
    1161              :   y = _mm256_fmadd_ps(y, r, c9);
    1162              :   y = _mm256_fmadd_ps(y, r, c8);
    1163              :   y = _mm256_fmadd_ps(y, r, c7);
    1164              :   y = _mm256_fmadd_ps(y, r, c6);
    1165              :   y = _mm256_fmadd_ps(y, r, c5);
    1166              :   y = _mm256_fmadd_ps(y, r, c4);
    1167              :   y = _mm256_fmadd_ps(y, r, c3);
    1168              :   y = _mm256_fmadd_ps(y, r, c2);
    1169              :   y = _mm256_fmadd_ps(y, r, c1);
    1170              :   y = _mm256_fmadd_ps(y, r, c0); // final y = (((...r+...)*r+...)*r + 1)
    1171              : 
    1172              :   // Reconstruct exp(x) = 2^n * exp(r)
    1173              :   __m256i emm0 = _mm256_cvtps_epi32(fx);
    1174              :   emm0 = _mm256_add_epi32(emm0, _mm256_set1_epi32(127));
    1175              :   emm0 = _mm256_slli_epi32(emm0, 23);
    1176              :   __m256 pow2n = _mm256_castsi256_ps(emm0);
    1177              : 
    1178            9 :   return _mm256_mul_ps(y, pow2n);
    1179              : }
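
For checking exp256_ps against std::exp, the same scheme in scalar form: clamp, split x = n*ln2 + r with |r| <= ln2/2, evaluate the degree-10 polynomial by Horner, then rebuild exp(x) = 2^n * exp(r) by constructing 2^n directly from its exponent bits. A sketch:

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    static float exp_ref(float x) {
      x = std::fmax(-88.3762626647949f, std::fmin(88.3762626647949f, x));
      const float n = std::floor(x * 1.44269504088896341f + 0.5f);
      const float r = x - n * 0.6931471805599453f;
      // Horner over 1/10!, 1/9!, ..., 1/1!, 1/0!
      float y = 1.0f / 3628800.0f;
      const float c[] = {1.0f / 362880.0f, 1.0f / 40320.0f, 1.0f / 5040.0f,
                         1.0f / 720.0f,    1.0f / 120.0f,   1.0f / 24.0f,
                         1.0f / 6.0f,      0.5f,            1.0f, 1.0f};
      for (float ck : c)
        y = y * r + ck;
      // 2^n via the IEEE-754 exponent field (n is integral and in
      // range after the clamp above).
      uint32_t bits =
        (static_cast<uint32_t>(static_cast<int32_t>(n)) + 127u) << 23;
      float pow2n;
      std::memcpy(&pow2n, &bits, sizeof(pow2n));
      return y * pow2n;
    }
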
    1180              : 
    1181            1 : static void softmax_row_inplace(float *qk_out, size_t start_row, size_t end_row,
    1182              :                                 size_t num_heads) {
    1183            1 :   size_t row_range = end_row - start_row;
    1184            1 :   const size_t full_blocks = (num_heads / 8) * 8;
    1186              : 
    1187            1 :   float *max_vals = new float[num_heads];
    1188            1 :   float *sum_vals = new float[num_heads];
    1189              :   // 1. max
    1190           11 :   for (size_t c = 0; c < num_heads; ++c) {
    1191           10 :     float max_val = -INFINITY;
    1192           40 :     for (size_t r = start_row; r < end_row; ++r)
    1193           49 :       max_val = std::max(max_val, qk_out[r * num_heads + c]);
    1194           10 :     max_vals[c] = max_val;
    1195              :   }
    1196              : 
    1197              :   // 2. inplace exp + sum
    1198            2 :   for (size_t c = 0; c < full_blocks; c += 8) {
    1199            1 :     __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1200              :     __m256 sum = _mm256_setzero_ps();
    1201            4 :     for (size_t r = 0; r < row_range; ++r) {
    1202            3 :       float *ptr = &qk_out[(start_row + r) * num_heads + c];
    1203              :       __m256 val = _mm256_loadu_ps(ptr);
    1204            3 :       __m256 e = exp256_ps(_mm256_sub_ps(val, maxv));
    1205              :       _mm256_storeu_ps(ptr, e); // overwrite qk_out
    1206              :       sum = _mm256_add_ps(sum, e);
    1207              :     }
    1208            1 :     _mm256_storeu_ps(&sum_vals[c], sum);
    1209              :   }
    1210              : 
    1211            3 :   for (size_t c = full_blocks; c < num_heads; ++c) {
    1212              :     float sum = 0.0f;
    1213            2 :     float maxv = max_vals[c];
    1214            8 :     for (size_t r = 0; r < row_range; ++r) {
    1215            6 :       float &a = qk_out[(start_row + r) * num_heads + c];
    1216            6 :       a = std::exp(a - maxv); // overwrite qk_out
    1217            6 :       sum += a;
    1218              :     }
    1219            2 :     sum_vals[c] = sum;
    1220              :   }
    1221              :   // 3. softmax = exp / sum (inplace)
    1222            4 :   for (size_t r = 0; r < row_range; ++r) {
    1223            6 :     for (size_t c = 0; c < full_blocks; c += 8) {
    1224            3 :       float *ptr = &qk_out[(start_row + r) * num_heads + c];
    1225              :       __m256 val = _mm256_loadu_ps(ptr); // already exp(x - max)
    1226            3 :       __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
    1227              :       __m256 soft = _mm256_div_ps(val, sumv);
    1228              :       _mm256_storeu_ps(ptr, soft);
    1229              :     }
    1230            9 :     for (size_t c = full_blocks; c < num_heads; ++c) {
    1231            6 :       qk_out[(start_row + r) * num_heads + c] /= sum_vals[c];
    1232              :     }
    1233              :   }
    1234              : 
    1235            1 :   delete[] max_vals;
    1236            1 :   delete[] sum_vals;
    1237            1 : }
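
For reference, the same column-wise, max-stabilized softmax in scalar form: for each column c of the row-major (rows x num_heads) matrix, softmax is taken over rows [start_row, end_row):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    void softmax_col_ref(float *qk, size_t start_row, size_t end_row,
                         size_t num_heads) {
      for (size_t c = 0; c < num_heads; ++c) {
        float m = -INFINITY;
        for (size_t r = start_row; r < end_row; ++r)
          m = std::max(m, qk[r * num_heads + c]);
        float sum = 0.0f;
        for (size_t r = start_row; r < end_row; ++r) {
          float &a = qk[r * num_heads + c];
          a = std::exp(a - m);
          sum += a;
        }
        for (size_t r = start_row; r < end_row; ++r)
          qk[r * num_heads + c] /= sum;
      }
    }
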
    1238              : 
    1239            0 : static void softmax_row_with_sink_inplace(float *qk_out, size_t start_row,
    1240              :                                           size_t end_row, size_t num_heads,
    1241              :                                           float *sink) {
    1242            0 :   size_t row_range = end_row - start_row;
    1243            0 :   const size_t full_blocks = (num_heads / 8) * 8;
    1244              : 
    1245            0 :   float *max_vals = new float[num_heads];
    1246            0 :   float *sum_vals = new float[num_heads];
    1247              :   // 1. max
    1248            0 :   for (size_t c = 0; c < num_heads; ++c) {
    1249            0 :     float max_val = -INFINITY;
    1250            0 :     for (size_t r = start_row; r < end_row; ++r)
    1251            0 :       max_val = std::max(max_val, qk_out[r * num_heads + c]);
    1252            0 :     max_vals[c] = std::max(sink[c], max_val);
    1253              :   }
    1254              : 
    1255              :   // 2. inplace exp + sum
    1256            0 :   for (size_t c = 0; c < full_blocks; c += 8) {
    1257            0 :     __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1258            0 :     __m256 sum = _mm256_loadu_ps(&sink[c]);
    1259            0 :     sum = exp256_ps(_mm256_sub_ps(sum, maxv));
    1260            0 :     for (size_t r = 0; r < row_range; ++r) {
    1261            0 :       float *ptr = &qk_out[(start_row + r) * num_heads + c];
    1262              :       __m256 val = _mm256_loadu_ps(ptr);
    1263            0 :       __m256 e = exp256_ps(_mm256_sub_ps(val, maxv));
    1264              :       _mm256_storeu_ps(ptr, e); // overwrite qk_out
    1265              :       sum = _mm256_add_ps(sum, e);
    1266              :     }
    1267            0 :     _mm256_storeu_ps(&sum_vals[c], sum);
    1268              :   }
    1269              : 
    1270            0 :   for (size_t c = full_blocks; c < num_heads; ++c) {
    1271            0 :     float maxv = max_vals[c];
    1272            0 :     float sum = std::exp(sink[c] - maxv);
    1273            0 :     for (size_t r = 0; r < row_range; ++r) {
    1274            0 :       float &a = qk_out[(start_row + r) * num_heads + c];
    1275            0 :       a = std::exp(a - maxv); // overwrite qk_out
    1276            0 :       sum += a;
    1277              :     }
    1278            0 :     sum_vals[c] = sum;
    1279              :   }
    1280              :   // 3. softmax = exp / sum (inplace)
    1281            0 :   for (size_t r = 0; r < row_range; ++r) {
    1282            0 :     for (size_t c = 0; c < full_blocks; c += 8) {
    1283            0 :       float *ptr = &qk_out[(start_row + r) * num_heads + c];
    1284              :       __m256 val = _mm256_loadu_ps(ptr); // already exp(x - max)
    1285            0 :       __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
    1286              :       __m256 soft = _mm256_div_ps(val, sumv);
    1287              :       _mm256_storeu_ps(ptr, soft);
    1288              :     }
    1289            0 :     for (size_t c = full_blocks; c < num_heads; ++c) {
    1290            0 :       qk_out[(start_row + r) * num_heads + c] /= sum_vals[c];
    1291              :     }
    1292              :   }
    1293              : 
    1294            0 :   delete[] max_vals;
    1295            0 :   delete[] sum_vals;
    1296            0 : }
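
The sink variant changes only the normalization: sink[c] joins the per-column max and contributes one extra exponential to the denominator, but no output row is written for it. In formula form:

    // m_c       = max(sink[c], max_r x[r][c])
    // out[r][c] = exp(x[r][c] - m_c) /
    //             (exp(sink[c] - m_c) + sum_r exp(x[r][c] - m_c))
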
    1297              : 
    1298              : template <>
    1299            1 : void softmax_row_inplace(float *qk_out, size_t start_row, size_t end_row,
    1300              :                          size_t num_heads, float *sink) {
    1301            1 :   if (sink == nullptr) {
    1302            1 :     return softmax_row_inplace(qk_out, start_row, end_row, num_heads);
    1303              :   } else {
    1304            0 :     return softmax_row_with_sink_inplace(qk_out, start_row, end_row, num_heads,
    1305            0 :                                          sink);
    1306              :   }
    1307              : }
    1308              : 
    1309            1 : static void softmax_row(float *qk_out, size_t start_row, size_t end_row,
    1310              :                         size_t num_heads) {
    1311            1 :   const size_t full_block = (num_heads / 8) * 8;
    1312              : 
    1313            1 :   float *max_vals = new float[num_heads];
    1314            1 :   float *sum_vals = new float[num_heads];
    1315              : 
     1316              :   // 1. Find the max along each column
    1317           11 :   for (size_t c = 0; c < num_heads; ++c) {
    1318           10 :     float max_val = -INFINITY;
    1319           40 :     for (size_t r = start_row; r < end_row; ++r) {
    1320           49 :       max_val = std::max(max_val, qk_out[r * num_heads + c]);
    1321              :     }
    1322           10 :     max_vals[c] = max_val;
    1323              :   }
    1324              : 
     1325              :   // 2. Compute the exponential sum along each column (vectorized)
    1326            2 :   for (size_t c = 0; c < full_block; c += 8) {
    1327              :     __m256 sum = _mm256_setzero_ps();
    1328            4 :     for (size_t r = start_row; r < end_row; ++r) {
    1329            3 :       __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
    1330            3 :       __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1331              :       __m256 sub = _mm256_sub_ps(val, maxv);
    1332            3 :       __m256 e = exp256_ps(sub);
    1333              :       sum = _mm256_add_ps(sum, e);
    1334              :     }
    1335            1 :     _mm256_storeu_ps(&sum_vals[c], sum);
    1336              :   }
    1337              : 
    1338            3 :   for (size_t c = full_block; c < num_heads; ++c) {
    1339              :     float sum = 0.0f;
    1340            8 :     for (size_t r = start_row; r < end_row; ++r) {
    1341            6 :       sum += std::exp(qk_out[r * num_heads + c] - max_vals[c]);
    1342              :     }
    1343            2 :     sum_vals[c] = sum;
    1344              :   }
    1345              : 
    1346              :   // 3. apply softmax
    1347            4 :   for (size_t r = start_row; r < end_row; ++r) {
    1348            6 :     for (size_t c = 0; c < full_block; c += 8) {
    1349            3 :       __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
    1350            3 :       __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1351              :       __m256 sub = _mm256_sub_ps(val, maxv);
    1352            3 :       __m256 e = exp256_ps(sub);
    1353            3 :       __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
    1354              :       __m256 softmax = _mm256_div_ps(e, sumv);
    1355              :       _mm256_storeu_ps(&qk_out[r * num_heads + c], softmax);
    1356              :     }
    1357            9 :     for (size_t c = full_block; c < num_heads; ++c) {
    1358            6 :       qk_out[r * num_heads + c] =
    1359            6 :         std::exp(qk_out[r * num_heads + c] - max_vals[c]) / sum_vals[c];
    1360              :     }
    1361              :   }
    1362              : 
    1363            1 :   delete[] max_vals;
    1364            1 :   delete[] sum_vals;
    1365            1 : }
    1366              : 
    1367            0 : static void softmax_row_with_sink(float *qk_out, size_t start_row,
    1368              :                                   size_t end_row, size_t num_heads,
    1369              :                                   float *sink) {
    1370            0 :   const size_t full_block = (num_heads / 8) * 8;
    1371              : 
    1372            0 :   float *max_vals = new float[num_heads];
    1373            0 :   float *sum_vals = new float[num_heads];
    1374              : 
     1375              :   // 1. Find the max along each column, including the sink logit
    1376            0 :   for (size_t c = 0; c < num_heads; ++c) {
    1377            0 :     float max_val = -INFINITY;
    1378            0 :     for (size_t r = start_row; r < end_row; ++r) {
    1379            0 :       max_val = std::max(max_val, qk_out[r * num_heads + c]);
    1380              :     }
    1381            0 :     max_vals[c] = std::max(max_val, sink[c]);
    1382              :   }
    1383              : 
     1384              :   // 2. Compute the exponential sum along each column, seeded with the sink term
    1385            0 :   for (size_t c = 0; c < full_block; c += 8) {
    1386            0 :     __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1387            0 :     __m256 sum = _mm256_loadu_ps(&sink[c]);
    1388              :     sum = _mm256_sub_ps(sum, maxv);
    1389            0 :     sum = exp256_ps(sum);
    1390            0 :     for (size_t r = start_row; r < end_row; ++r) {
    1391            0 :       __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
    1392              :       __m256 sub = _mm256_sub_ps(val, maxv);
    1393            0 :       __m256 e = exp256_ps(sub);
    1394              :       sum = _mm256_add_ps(sum, e);
    1395              :     }
    1396            0 :     _mm256_storeu_ps(&sum_vals[c], sum);
    1397              :   }
    1398              : 
    1399            0 :   for (size_t c = full_block; c < num_heads; ++c) {
    1400            0 :     float sum = std::exp(sink[c] - max_vals[c]);
    1401            0 :     for (size_t r = start_row; r < end_row; ++r) {
    1402            0 :       sum += std::exp(qk_out[r * num_heads + c] - max_vals[c]);
    1403              :     }
    1404            0 :     sum_vals[c] = sum;
    1405              :   }
    1406              : 
    1407              :   // 3. apply softmax
    1408            0 :   for (size_t r = start_row; r < end_row; ++r) {
    1409            0 :     for (size_t c = 0; c < full_block; c += 8) {
    1410            0 :       __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
    1411            0 :       __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
    1412              :       __m256 sub = _mm256_sub_ps(val, maxv);
    1413            0 :       __m256 e = exp256_ps(sub);
    1414            0 :       __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
    1415              :       __m256 softmax = _mm256_div_ps(e, sumv);
    1416              :       _mm256_storeu_ps(&qk_out[r * num_heads + c], softmax);
    1417              :     }
    1418            0 :     for (size_t c = full_block; c < num_heads; ++c) {
    1419            0 :       qk_out[r * num_heads + c] =
    1420            0 :         std::exp(qk_out[r * num_heads + c] - max_vals[c]) / sum_vals[c];
    1421              :     }
    1422              :   }
    1423              : 
    1424            0 :   delete[] max_vals;
    1425            0 :   delete[] sum_vals;
    1426            0 : }
    1427              : 
    1428              : template <>
    1429            1 : void softmax_row(float *qk_out, size_t start_row, size_t end_row,
    1430              :                  size_t num_heads, float *sink) {
    1431            1 :   if (sink == nullptr) {
    1432            1 :     return softmax_row(qk_out, start_row, end_row, num_heads);
    1433              :   } else {
    1434            0 :     return softmax_row_with_sink(qk_out, start_row, end_row, num_heads, sink);
    1435              :   }
    1436              : }
    1437              : #ifdef _WIN32
    1438              : #define COMPUTE_FP16_TO_FP32(x)                                                \
    1439              :   _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
    1440              : #define COMPUTE_FP32_TO_FP16(x)                                                \
    1441              :   _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
    1442              : #elif defined(__TIZEN__) && !defined(__F16C__)
    1443              : #define COMPUTE_FP16_TO_FP32(x) nntrainer::compute_fp16_to_fp32(x)
    1444              : #define COMPUTE_FP32_TO_FP16(x) nntrainer::compute_fp32_to_fp16(x)
    1445              : #else
    1446              : #define COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
    1447              : #define COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
    1448              : #endif
    1449              : 
    1450              : static inline __m256 convert_vector_f16_to_f32(__m128i x) {
    1451              : #if defined(__TIZEN__) && !defined(__F16C__)
    1452              :   alignas(32) uint16_t u16_array[8]; // 32-byte aligned storage
    1453              :   alignas(32) float f32_array[8];    // 32-byte aligned storage
    1454              : 
    1455              :   // Safely store __m128i to array (avoids aliasing)
    1456              :   _mm_storeu_si128(reinterpret_cast<__m128i *>(u16_array), x);
    1457              : 
    1458              :   // Convert each FP16 value to FP32
    1459              :   for (int i = 0; i < 8; i++) {
    1460              :     f32_array[i] = COMPUTE_FP16_TO_FP32(u16_array[i]);
    1461              :   }
    1462              : 
    1463              :   // Load aligned array into __m256
    1464              :   return _mm256_load_ps(f32_array);
    1465              : #else
    1466              :   return _mm256_cvtph_ps(x);
    1467              : #endif
    1468              : }
    1469              : 
    1470              : static inline __m128i convert_vector_f32_to_f16(__m256 x) {
    1471              : #if defined(__TIZEN__) && !defined(__F16C__)
    1472              :   __m128i vec_f16;
    1473              :   float *f32_ptr = reinterpret_cast<float *>(&x);
    1474              :   uint16_t *u16_ptr = reinterpret_cast<uint16_t *>(&vec_f16);
    1475              :   for (int i = 0; i < 8; i++) {
    1476              :     u16_ptr[i] = COMPUTE_FP32_TO_FP16(f32_ptr[i]);
    1477              :   }
    1478              :   return vec_f16;
    1479              : #else
    1480              :   return _mm256_cvtps_ph(x, 0);
    1481              : #endif
    1482              : }
    1483              : 
    1484              : static inline __m128i convert_vector_f32_to_f16(__m128 x) {
    1485              : #if defined(__TIZEN__) && !defined(__F16C__)
    1486              :   __m128i vec_f16;
    1487              :   float *f32_ptr = reinterpret_cast<float *>(&x);
    1488              :   uint16_t *u16_ptr = reinterpret_cast<uint16_t *>(&vec_f16);
    1489              : 
    1490              :   for (int i = 0; i < 4; i++) {
    1491              :     u16_ptr[i] = COMPUTE_FP32_TO_FP16(f32_ptr[i]);
    1492              :   }
    1493              :   return vec_f16;
    1494              : #else
    1495              :   return _mm_cvtps_ph(x, 0);
    1496              : #endif
    1497              : }
    1498              : 
    1499           12 : static inline void load_fp16_8_to_chunk(const uint16_t *src, float *dst,
    1500              :                                         int chunk_size) {
    1501              :   int i = 0;
    1502           24 :   for (; i + 8 <= chunk_size; i += 8) {
    1503           12 :     __m128i half = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src + i));
    1504              :     __m256 f32 = convert_vector_f16_to_f32(half);
    1505           12 :     _mm256_storeu_ps(&dst[i], f32);
    1506              :   }
    1507           32 :   for (; i < chunk_size; ++i) {
    1508           20 :     dst[i] = nntrainer::compute_fp16_to_fp32(src[i]);
    1509              :   }
    1510           12 : }
    1511              : 
    1512            1 : void compute_fp16vcache_fp32_transposed(int row_num, const float *in,
    1513              :                                         const uint16_t *vcache, float *output,
    1514              :                                         int num_cache_head, int gqa_size,
    1515              :                                         int head_dim,
    1516              :                                         size_t local_window_size) {
    1529            1 :   std::vector<float> tmp_fp32(head_dim);
    1530            1 :   int num_blocks = head_dim / 8;
    1531            2 :   __m256 *sumVec = new __m256[std::max(1, num_blocks * gqa_size)];
    1532              : 
    1533            3 :   for (int n = 0; n < num_cache_head; ++n) {
    1534            2 :     int rem = head_dim % 8;
    1535              : 
     1536              :     /* Declaring std::vector<__m256> sumVec(num_blocks * gqa_size,
     1537              :      * _mm256_setzero_ps()) triggers "ignoring attributes on template
     1538              :      * argument '__m256' [-Wignored-attributes]", so a raw __m256 array
     1539              :      * is used instead and zero-initialized here.
     1540              :      */
    1541            6 :     for (int i = 0; i < num_blocks * gqa_size; i++) {
    1542            4 :       sumVec[i] = _mm256_setzero_ps();
    1543              :     }
    1544            2 :     std::vector<float> sumRem((size_t)gqa_size * rem, 0.0f);
    1545              : 
    1546            6 :     for (int j = row_num < local_window_size ? 0
    1547            0 :                                              : row_num + 1 - local_window_size;
    1548            6 :          j <= row_num; ++j) {
    1549            4 :       const uint16_t *vptr = vcache + (j * num_cache_head + n) * head_dim;
    1550            4 :       load_fp16_8_to_chunk(vptr, tmp_fp32.data(), head_dim);
    1551              : 
    1552           12 :       for (int h = 0; h < gqa_size; ++h) {
    1553            8 :         float a_val =
    1554              :           in[(row_num < local_window_size
    1555            8 :                 ? j
    1556            0 :                 : (unsigned long)(j - (row_num + 1 - local_window_size))) *
    1557            8 :                (unsigned long)(gqa_size * num_cache_head) +
    1558            8 :              (unsigned long)(n * gqa_size) + h];
    1559              : 
    1560              :         __m256 inVec = _mm256_set1_ps(a_val);
    1561              : 
    1562           16 :         for (int b = 0; b < num_blocks; ++b) {
    1563            8 :           __m256 bVec = _mm256_loadu_ps(&tmp_fp32[b * 8]);
    1564            8 :           sumVec[h * num_blocks + b] =
    1565            8 :             _mm256_fmadd_ps(inVec, bVec, sumVec[h * num_blocks + b]);
    1566              :         }
    1567              : 
     1568            8 :         float *remPtr = sumRem.data() + h * rem;
    1569            8 :         int base = num_blocks * 8;
    1570           16 :         for (int r = 0; r < rem; ++r) {
    1571            8 :           remPtr[r] += a_val * tmp_fp32[base + r];
    1572              :         }
    1573              :       }
    1574              :     }
    1575              : 
    1576            6 :     for (int h = 0; h < gqa_size; ++h) {
    1577            8 :       for (int b = 0; b < num_blocks; ++b) {
    1578            4 :         int out_base = (n * gqa_size + h) * head_dim + b * 8;
    1579            4 :         _mm256_storeu_ps(&output[out_base], sumVec[h * num_blocks + b]);
    1580              :       }
    1581              : 
     1582            4 :       float *remPtr = sumRem.data() + h * rem;
    1584            4 :       int base = num_blocks * 8;
    1585            8 :       for (int r = 0; r < rem; ++r) {
    1586            4 :         int out_idx = (n * gqa_size + h) * head_dim + base + r;
    1587            4 :         output[out_idx] = remPtr[r];
    1588              :       }
    1589              :     }
    1590            2 :   }
    1591            1 :   delete[] sumVec;
    1592            1 : }
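
Stripped of the SIMD blocking and remainder handling, the function above is a weighted sum of fp16 value-cache rows over the attention window ending at row_num. A scalar sketch (fp16_to_fp32 is a hypothetical stand-in for nntrainer::compute_fp16_to_fp32):

    #include <cstddef>
    #include <cstdint>

    void vcache_ref(int row_num, const float *in, const uint16_t *vcache,
                    float *out, int num_cache_head, int gqa_size, int head_dim,
                    size_t window, float (*fp16_to_fp32)(uint16_t)) {
      const int j0 = (size_t)row_num < window ? 0 : row_num + 1 - (int)window;
      for (int n = 0; n < num_cache_head; ++n)
        for (int h = 0; h < gqa_size; ++h)
          for (int d = 0; d < head_dim; ++d) {
            float acc = 0.0f;
            for (int j = j0; j <= row_num; ++j) {
              const float a = in[(j - j0) * (gqa_size * num_cache_head) +
                                 n * gqa_size + h];
              acc += a * fp16_to_fp32(
                           vcache[((size_t)j * num_cache_head + n) * head_dim +
                                  d]);
            }
            out[(n * gqa_size + h) * head_dim + d] = acc;
          }
    }
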
    1593              : 
    1594              : template <>
    1595            1 : void compute_kcaches(const float *in, const uint16_t *kcache, float *output,
    1596              :                      int num_rows, int num_cache_head, int head_dim,
    1597              :                      int gqa_size, int tile_size, size_t local_window_size) {
    1598            1 :   std::vector<float> tmp_fp32(head_dim);
    1599              : 
    1600            1 :   int start_row =
    1601            1 :     num_rows < local_window_size ? 0 : num_rows - local_window_size;
    1602            1 :   int row_cnt = num_rows < local_window_size ? num_rows : local_window_size;
    1603            1 :   const int tile_count = (row_cnt + tile_size - 1) / tile_size;
    1604              : 
    1605            3 :   for (int n = 0; n < num_cache_head; ++n) {
    1606            4 :     for (int t = 0; t < tile_count; ++t) {
    1607            2 :       int row_tile_start = t * tile_size;
    1608            2 :       int tile_rows = std::min(tile_size, row_cnt - row_tile_start);
    1609              : 
    1610           10 :       for (int g = 0; g < gqa_size; ++g) {
    1611            8 :         const float *in_ptr = in + n * gqa_size * head_dim + g * head_dim;
    1612           16 :         for (int t_row = 0; t_row < tile_rows; ++t_row) {
    1613            8 :           int row = start_row + row_tile_start + t_row;
    1614            8 :           if (row + 1 < num_rows) {
    1615            0 :             const uint16_t *next_kptr =
    1616            0 :               kcache + ((row + 1) * num_cache_head + n) * head_dim;
    1617              :             _mm_prefetch(reinterpret_cast<const char *>(next_kptr),
    1618              :                          _MM_HINT_T0);
    1619              :           }
    1620            8 :           const uint16_t *kptr = kcache + (row * num_cache_head + n) * head_dim;
    1621            8 :           load_fp16_8_to_chunk(kptr, tmp_fp32.data(), head_dim);
    1622              : 
    1623              :           const float *k_row = tmp_fp32.data();
    1624              : 
    1625              :           float sum = 0.0f;
    1626              :           int i = 0;
    1627              :           __m256 acc = _mm256_setzero_ps();
    1628           16 :           for (; i + 8 <= head_dim; i += 8) {
    1629            8 :             __m256 va = _mm256_loadu_ps(in_ptr + i);
    1630            8 :             __m256 vb = _mm256_loadu_ps(k_row + i);
    1631              :             acc = _mm256_fmadd_ps(va, vb, acc);
    1632              :           }
    1633              : 
    1634              :           __m128 low = _mm256_castps256_ps128(acc);
    1635              :           __m128 high = _mm256_extractf128_ps(acc, 1);
    1636              :           __m128 sum128 = _mm_add_ps(low, high);
    1637              :           sum128 = _mm_hadd_ps(sum128, sum128);
    1638              :           sum128 = _mm_hadd_ps(sum128, sum128);
    1639            8 :           sum += _mm_cvtss_f32(sum128);
    1640              : 
    1641           24 :           for (; i < head_dim; ++i)
    1642           16 :             sum += in_ptr[i] * k_row[i];
    1643              : 
    1644            8 :           output[(row - start_row) * num_cache_head * gqa_size + n * gqa_size +
    1645            8 :                  g] = sum / sqrt((float)head_dim);
    1646              :         }
    1647              :       }
    1648              :     }
    1649              :   }
    1650            1 : }
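
compute_kcaches above is, per (window row, cache head, gqa head), a dot product between the query head and an fp16 key row, scaled by 1/sqrt(head_dim); the tiling and prefetch only change traversal order. A scalar sketch under the same conventions (fp16_to_fp32 again stands in for nntrainer::compute_fp16_to_fp32):

    #include <cmath>
    #include <cstddef>
    #include <cstdint>

    void kcache_ref(const float *in, const uint16_t *kcache, float *out,
                    int num_rows, int num_cache_head, int head_dim,
                    int gqa_size, size_t window,
                    float (*fp16_to_fp32)(uint16_t)) {
      const int start = (size_t)num_rows < window ? 0 : num_rows - (int)window;
      const float scale = 1.0f / std::sqrt((float)head_dim);
      for (int row = start; row < num_rows; ++row)
        for (int n = 0; n < num_cache_head; ++n)
          for (int g = 0; g < gqa_size; ++g) {
            const float *q = in + (n * gqa_size + g) * head_dim;
            float sum = 0.0f;
            for (int i = 0; i < head_dim; ++i)
              sum += q[i] * fp16_to_fp32(
                              kcache[((size_t)row * num_cache_head + n) *
                                       head_dim + i]);
            out[(row - start) * num_cache_head * gqa_size + n * gqa_size + g] =
              sum * scale;
          }
    }
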
    1651              : 
    1652            2 : void compute_rotary_emb_value(unsigned int width, unsigned int dim,
    1653              :                               unsigned int half_, float *inout, void *output,
    1654              :                               const float *cos_, const float *sin_,
    1655              :                               bool only_convert_to_fp16) {
    1656              :   enum class OutputType { FP16, FP32 };
    1657              : 
    1658              :   OutputType out_type = OutputType::FP32;
    1659            2 :   if (output != nullptr)
    1660              :     out_type = OutputType::FP16;
    1661              : 
    1662            6 :   for (unsigned int w = 0; w < width; w += dim) {
    1663              :     unsigned int k = 0;
    1664            8 :     for (; k + 7 < half_; k += 8) {
    1665            4 :       unsigned int i0 = w + k;
    1666            4 :       unsigned int i1 = w + k + half_;
    1667              : 
    1668            4 :       __m256 a = _mm256_loadu_ps(&inout[i0]);
    1669            4 :       __m256 b = _mm256_loadu_ps(&inout[i1]);
    1670              : 
    1671            4 :       if (only_convert_to_fp16) {
    1672            0 :         if (out_type == OutputType::FP16) {
    1673              :           __m128i a_fp16 = convert_vector_f32_to_f16(a);
    1674              :           __m128i b_fp16 = convert_vector_f32_to_f16(b);
    1675              : 
    1676            0 :           _mm_storeu_si128(
    1677            0 :             reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i0),
    1678              :             a_fp16);
    1679            0 :           _mm_storeu_si128(
    1680            0 :             reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i1),
    1681              :             b_fp16);
    1682              :         }
    1683              : 
    1684              :       } else {
    1685            4 :         __m256 cos_v = _mm256_loadu_ps(&cos_[k]);
    1686            4 :         __m256 sin_v = _mm256_loadu_ps(&sin_[k]);
    1687              : 
    1688              :         __m256 out0 =
    1689              :           _mm256_sub_ps(_mm256_mul_ps(a, cos_v), _mm256_mul_ps(b, sin_v));
    1690              :         __m256 out1 =
    1691              :           _mm256_add_ps(_mm256_mul_ps(a, sin_v), _mm256_mul_ps(b, cos_v));
    1692              : 
    1693            4 :         if (out_type == OutputType::FP16) {
    1694              :           __m128i out0_fp16 = convert_vector_f32_to_f16(out0);
    1695              :           __m128i out1_fp16 = convert_vector_f32_to_f16(out1);
    1696              : 
    1697            2 :           _mm_storeu_si128(
    1698            2 :             reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i0),
    1699              :             out0_fp16);
    1700            2 :           _mm_storeu_si128(
    1701            2 :             reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i1),
    1702              :             out1_fp16);
    1703              : 
    1704              :         } else if (out_type == OutputType::FP32) {
    1705              :           _mm256_storeu_ps(&inout[i0], out0);
    1706              :           _mm256_storeu_ps(&inout[i1], out1);
    1707              :         }
    1708              :       }
    1709              :     }
    1710              : 
    1711           12 :     for (; k < half_; ++k) {
    1712            8 :       unsigned int i0 = w + k;
    1713            8 :       unsigned int i1 = w + k + half_;
    1714              :       // assert(i1 < width && "Scalar i1 overflow!");
    1715            8 :       float a = inout[i0];
    1716            8 :       float b = inout[i1];
    1717              : 
    1718            8 :       if (only_convert_to_fp16) {
    1719            0 :         static_cast<uint16_t *>(output)[i0] = COMPUTE_FP32_TO_FP16(a);
    1720            0 :         static_cast<uint16_t *>(output)[i1] = COMPUTE_FP32_TO_FP16(b);
    1721              :       } else {
    1722            8 :         float c = cos_[k];
    1723            8 :         float s = sin_[k];
    1724              : 
    1725            8 :         float out0 = a * c - b * s;
    1726            8 :         float out1 = a * s + b * c;
    1727              : 
    1728            8 :         if (out_type == OutputType::FP16) {
    1729            4 :           static_cast<uint16_t *>(output)[i0] = COMPUTE_FP32_TO_FP16(out0);
    1730            8 :           static_cast<uint16_t *>(output)[i1] = COMPUTE_FP32_TO_FP16(out1);
    1731              :         } else if (out_type == OutputType::FP32) {
    1732            4 :           inout[i0] = out0;
    1733            4 :           inout[i1] = out1;
    1734              :         }
    1735              :       }
    1736              :     }
    1737              :   }
    1738            2 : }
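
Setting aside the FP16 output path, the transform above is the standard rotary embedding: each pair (inout[w + k], inout[w + k + half_]) is rotated by the angle whose cosine and sine are cos_[k] and sin_[k]. A scalar reference for the FP32 in-place case:

    void rotary_ref(unsigned width, unsigned dim, unsigned half_,
                    float *inout, const float *cos_, const float *sin_) {
      for (unsigned w = 0; w < width; w += dim)
        for (unsigned k = 0; k < half_; ++k) {
          const float a = inout[w + k];
          const float b = inout[w + k + half_];
          inout[w + k] = a * cos_[k] - b * sin_[k];         // out0
          inout[w + k + half_] = a * sin_[k] + b * cos_[k]; // out1
        }
    }
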
    1739              : 
    1740              : static float hsum_avx(__m256 v) {
    1741              :   __m128 vlow = _mm256_castps256_ps128(v);
    1742              :   __m128 vhigh = _mm256_extractf128_ps(v, 1);
    1743              :   vlow = _mm_add_ps(vlow, vhigh);
    1744              :   __m128 shuf = _mm_movehdup_ps(vlow);
    1745              :   __m128 sums = _mm_add_ps(vlow, shuf);
    1746              :   shuf = _mm_movehl_ps(shuf, sums);
    1747              :   sums = _mm_add_ss(sums, shuf);
    1748              :   return _mm_cvtss_f32(sums);
    1749              : }
    1750              : 
    1751            0 : void rms_norm_wrt_width_fp32_intrinsic(const float *__restrict X,
    1752              :                                        float *__restrict Y, size_t H, size_t W,
    1753              :                                        float epsilon) {
    1754            0 :   for (std::size_t h = 0; h < H; ++h) {
    1755            0 :     const float *rowX = X + h * W;
    1756            0 :     float *rowY = Y + h * W;
    1757              : 
    1758              :     std::size_t i = 0;
    1759              :     __m256 acc0 = _mm256_setzero_ps();
    1760              :     __m256 acc1 = _mm256_setzero_ps();
    1761              :     __m256 acc2 = _mm256_setzero_ps();
    1762              :     __m256 acc3 = _mm256_setzero_ps();
    1763              : 
    1764            0 :     for (; i + 32 <= W; i += 32) {
    1765            0 :       __m256 x0 = _mm256_loadu_ps(rowX + i);
    1766            0 :       __m256 x1 = _mm256_loadu_ps(rowX + i + 8);
    1767            0 :       __m256 x2 = _mm256_loadu_ps(rowX + i + 16);
    1768            0 :       __m256 x3 = _mm256_loadu_ps(rowX + i + 24);
    1769              :       acc0 = _mm256_fmadd_ps(x0, x0, acc0);
    1770              :       acc1 = _mm256_fmadd_ps(x1, x1, acc1);
    1771              :       acc2 = _mm256_fmadd_ps(x2, x2, acc2);
    1772              :       acc3 = _mm256_fmadd_ps(x3, x3, acc3);
    1773              :     }
    1774            0 :     for (; i + 8 <= W; i += 8) {
    1775            0 :       __m256 x = _mm256_loadu_ps(rowX + i);
    1776              :       acc0 = _mm256_fmadd_ps(x, x, acc0);
    1777              :     }
    1778              :     float sumsq =
    1779            0 :       hsum_avx(acc0) + hsum_avx(acc1) + hsum_avx(acc2) + hsum_avx(acc3);
    1780            0 :     for (; i < W; ++i) {
    1781            0 :       float v = rowX[i];
    1782            0 :       sumsq += v * v;
    1783              :     }
    1784              : 
    1785            0 :     float mean = sumsq / static_cast<float>(W);
    1786            0 :     float scale = 1.0f / std::sqrt(mean + epsilon);
    1787              :     __m256 vscale = _mm256_set1_ps(scale);
    1788              : 
    1789              :     i = 0;
    1790            0 :     for (; i + 32 <= W; i += 32) {
    1791            0 :       __m256 x0 = _mm256_loadu_ps(rowX + i);
    1792            0 :       __m256 x1 = _mm256_loadu_ps(rowX + i + 8);
    1793            0 :       __m256 x2 = _mm256_loadu_ps(rowX + i + 16);
    1794            0 :       __m256 x3 = _mm256_loadu_ps(rowX + i + 24);
    1795            0 :       _mm256_storeu_ps(rowY + i, _mm256_mul_ps(x0, vscale));
    1796            0 :       _mm256_storeu_ps(rowY + i + 8, _mm256_mul_ps(x1, vscale));
    1797            0 :       _mm256_storeu_ps(rowY + i + 16, _mm256_mul_ps(x2, vscale));
    1798            0 :       _mm256_storeu_ps(rowY + i + 24, _mm256_mul_ps(x3, vscale));
    1799              :     }
    1800            0 :     for (; i + 8 <= W; i += 8) {
    1801            0 :       __m256 x = _mm256_loadu_ps(rowX + i);
    1802            0 :       _mm256_storeu_ps(rowY + i, _mm256_mul_ps(x, vscale));
    1803              :     }
    1804            0 :     for (; i < W; ++i) {
    1805            0 :       rowY[i] = rowX[i] * scale;
    1806              :     }
    1807              :   }
    1808            0 : }
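
Vector widths aside, this is plain row-wise RMS normalization, Y[h][w] = X[h][w] / sqrt(mean_w(X[h][w]^2) + epsilon). A scalar reference:

    #include <cmath>
    #include <cstddef>

    void rms_norm_ref(const float *X, float *Y, size_t H, size_t W,
                      float epsilon) {
      for (size_t h = 0; h < H; ++h) {
        const float *x = X + h * W;
        float sumsq = 0.0f;
        for (size_t w = 0; w < W; ++w)
          sumsq += x[w] * x[w];
        const float scale = 1.0f / std::sqrt(sumsq / (float)W + epsilon);
        for (size_t w = 0; w < W; ++w)
          Y[h * W + w] = x[w] * scale;
      }
    }
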
    1809              : 
    1810              : template <>
    1811           21 : void clamp(const float *input, float *output, size_t length, float lower_bound,
    1812              :            float upper_bound) {
    1813              :   const size_t step = 8;
    1814              :   const __m256 vLo = _mm256_set1_ps(lower_bound);
    1815              :   const __m256 vHi = _mm256_set1_ps(upper_bound);
    1816              : 
    1817              :   size_t i = 0;
    1818         8085 :   for (; i + step <= length; i += step) {
    1819         8064 :     __m256 v = _mm256_loadu_ps(input + i);
    1820              :     v = _mm256_max_ps(v, vLo);
    1821              :     v = _mm256_min_ps(v, vHi);
    1822         8064 :     _mm256_storeu_ps(output + i, v);
    1823              :   }
    1824           21 :   if (i < length) {
    1825            0 :     for (size_t k = i; k < length; ++k) {
    1826            0 :       float v = input[k];
     1827              :       // If v is NaN, both comparisons below are false, so NaN passes through
     1828              :       // (the vectorized path above instead maps NaN to lower_bound).
    1829            0 :       output[k] =
    1830            0 :         (v < lower_bound) ? lower_bound : ((v > upper_bound) ? upper_bound : v);
    1831              :     }
    1832              :   }
    1833           21 : }
    1834              : 
    1835            1 : void copy_f16_f32(unsigned int N, const uint16_t *input, float *output) {
    1836              :   unsigned int idx = 0;
     1837              :   const uint16_t *data = input;
    1838              : 
    1839              :   // 16 half-precision floating point values to single-precision values
    1840            6 :   for (; N - idx >= 16; idx += 16) {
    1841              :     const __m256 vec0 =
    1842              :       convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)data));
    1843              :     const __m256 vec1 =
    1844              :       convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)(data + 8)));
    1845            5 :     data += 16;
    1846              : 
    1847              :     _mm256_storeu_ps(output, vec0);
    1848              :     _mm256_storeu_ps(output + 8, vec1);
    1849            5 :     output += 16;
    1850              :   }
    1851              :   // 8 half-precision floating point values to single-precision values
    1852            2 :   for (; N - idx >= 8; idx += 8) {
    1853              :     const __m256 vec =
    1854              :       convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)data));
    1855            1 :     data += 8;
    1856              : 
    1857              :     _mm256_storeu_ps(output, vec);
    1858            1 :     output += 8;
    1859              :   }
    1860              :   // remaining half-precision floating point values to single-precision values
    1861            3 :   while (idx < N) {
    1862            2 :     *output = compute_fp16_to_fp32(*data);
    1863            2 :     ++output;
    1864            2 :     ++data;
     1865              :   uint16_t *out_data = output;
    1866              :   }
    1867            1 : }
    1868              : 
    1869            0 : void copy_f32_f16(unsigned int N, const float *input, uint16_t *output) {
    1870              :   unsigned int idx = 0;
    1871              :   uint16_t *out_data = (uint16_t *)output;
    1872              : 
    1873              :   // 16 single-precision floating point values to half-precision values
    1874            0 :   for (; N - idx >= 16; idx += 16) {
    1875              :     const __m256 vec0 = _mm256_loadu_ps(input);
    1876              :     const __m256 vec1 = _mm256_loadu_ps(input + 8);
    1877            0 :     input += 16;
    1878              : 
    1879              :     _mm_storeu_si128((__m128i *)out_data, convert_vector_f32_to_f16(vec0));
    1880              :     _mm_storeu_si128((__m128i *)(out_data + 8),
    1881              :                      convert_vector_f32_to_f16(vec1));
    1882            0 :     out_data += 16;
    1883              :   }
    1884              :   // 8 single-precision floating point values to half-precision values
    1885            0 :   for (; N - idx >= 8; idx += 8) {
    1886              :     const __m256 vec = _mm256_loadu_ps(input);
    1887            0 :     input += 8;
    1888              : 
    1889              :     _mm_storeu_si128((__m128i *)out_data, convert_vector_f32_to_f16(vec));
    1890            0 :     out_data += 8;
    1891              :   }
    1892              :   // 4 single-precision floating point values to half-precision values
    1893            0 :   for (; N - idx >= 4; idx += 4) {
    1894              :     const __m128 vec = _mm_loadu_ps(input);
    1895            0 :     input += 4;
    1896              : 
    1897              :     _mm_storeu_si64((__m128i *)out_data, convert_vector_f32_to_f16(vec));
    1898            0 :     out_data += 4;
    1899              :   }
    1900              :   // remaining single-precision floating point values to half-precision values
    1901            0 :   while (idx < N) {
    1902            0 :     *out_data = compute_fp32_to_fp16(*input);
    1903            0 :     ++out_data;
    1904            0 :     ++input;
    1905            0 :     ++idx;
    1906              :   }
    1907            0 : }
    1908              : 
    1909      1752656 : void create_q4_0_weights(const uint8_t *int4_weight, uint8_t *q4_0_weight) {
    1910              :   // Load 16 bytes of input data
    1911              :   __m128i input = _mm_loadu_si128((const __m128i *)int4_weight);
    1912              : 
    1913              :   // Create masks for extracting low and high nibbles
    1914              :   const __m128i low_nibble_mask = _mm_set1_epi8(0x0F);
    1915              :   const __m128i high_nibble_mask = _mm_set1_epi8(static_cast<char>(0xF0));
    1916              : 
    1917              :   // Extract low nibbles from first 8 bytes
    1918              :   __m128i A = _mm_and_si128(input, low_nibble_mask);
    1919              : 
    1920              :   // Extract high nibbles from first 8 bytes and shift right
    1921              :   __m128i B = _mm_and_si128(input, high_nibble_mask);
    1922              :   B = _mm_srli_epi16(B, 4);
    1923              : 
    1924              :   // Extract low nibbles from second 8 bytes
    1925              :   __m128i input_shifted = _mm_bsrli_si128(input, 8);
    1926              :   __m128i C = _mm_and_si128(input_shifted, low_nibble_mask);
    1927              : 
    1928              :   // Extract high nibbles from second 8 bytes and shift right
    1929              :   __m128i D = _mm_and_si128(input_shifted, high_nibble_mask);
    1930              :   D = _mm_srli_epi16(D, 4);
    1931              : 
    1932              :   // Interleave low nibbles: v0 from first8, v2 from second8
    1933              :   __m128i AC = _mm_or_si128(A, _mm_slli_epi16(C, 4));
    1934              : 
    1935              :   // Interleave high nibbles: v1 from first8, v3 from second8
    1936              :   __m128i BD = _mm_or_si128(B, _mm_slli_epi16(D, 4));
    1937              : 
    1938              :   // Pack the results: interleave low and high bytes
    1939              :   __m128i result = _mm_unpacklo_epi8(AC, BD);
    1940              : 
    1941              :   // Store the 16 bytes result
    1942              :   _mm_storeu_si128((__m128i *)q4_0_weight, result);
    1943      1752656 : }
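
In scalar terms, and assuming input byte k packs element 2k in its low nibble and element 2k + 1 in its high nibble, the shuffle above produces the Q4_0 layout in which output byte j packs element j (low nibble) together with element j + 16 (high nibble):

    #include <cstdint>

    void create_q4_0_weights_ref(const uint8_t *in, uint8_t *out) {
      uint8_t elem[32];
      for (int k = 0; k < 16; ++k) {
        elem[2 * k] = in[k] & 0x0F;   // low nibble: element 2k
        elem[2 * k + 1] = in[k] >> 4; // high nibble: element 2k + 1
      }
      for (int j = 0; j < 16; ++j)
        out[j] = (uint8_t)(elem[j] | (elem[j + 16] << 4));
    }
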
    1944              : 
    1945              : } // namespace nntrainer::avx2
        

Generated by: LCOV version 2.0-1