Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2023 Donghyeon Jeong <dhyeon.jeong@samsung.com>
4 : *
5 : * @file avx2_impl.cpp
6 : * @date 20 Feb 2024
7 : * @see https://github.com/nntrainer/nntrainer
8 : * @author Donghyeon Jeong <dhyeon.jeong@samsung.com>
9 : * @author Sungsik Kong <ss.kong@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is a source for AVX implementation
12 : *
13 : */
14 :
15 : #include "avx2_impl.h"
16 : #include <array>
17 : #if __has_include(<bit>)
18 : #include <bit>
19 : #endif
20 : #include <cassert>
21 : #include <cmath>
22 : #include <cstdint>
23 : #include <cstring>
24 : #include <fp16.h>
25 : #include <immintrin.h>
26 : #include <limits>
27 : #if __has_include(<numbers>)
28 : #include <numbers>
29 : #endif
30 : #include <type_traits>
31 : #if __has_include(<version>)
32 : #include <version>
33 : #endif
34 : #include <fallback_internal.h>
35 : #include <util_func.h>
36 : #include <vector>
37 :
38 : #include "nntr_ggml_impl_common.h"
39 :
40 : #if !defined(__has_constexpr_builtin)
41 : #define __has_constexpr_builtin(x) (0)
42 : #endif
43 :
44 : #if !defined(__has_cpp_attribute)
45 : #define __has_cpp_attribute(x) (0)
46 : #endif
47 :
48 : // __vectorcall calling convention (MSVC only; the default x86_64-linux-gnu ABI already passes vectors in registers)
49 : #if _MSC_VER >= 1700
50 : #define _nnt_CC_VECTORCALL __vectorcall
51 : #else
52 : #define _nnt_CC_VECTORCALL
53 : #endif
54 :
55 : // Flatten attribute
56 : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::flatten)
57 : #define _nnt_ATTR_FLATTEN [[msvc::flatten]]
58 : #elif __has_cpp_attribute(gnu::flatten)
59 : // clang, g++
60 : #define _nnt_ATTR_FLATTEN [[gnu::flatten]]
61 : #else
62 : #define _nnt_ATTR_FLATTEN
63 : #endif
64 :
65 : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::noinline)
66 : #define _nnt_ATTR_NOINLINE [[msvc::noinline]]
67 : #elif __has_cpp_attribute(gnu::noinline)
68 : // clang, g++
69 : #define _nnt_ATTR_NOINLINE [[gnu::noinline]]
70 : #else
71 : #define _nnt_ATTR_NOINLINE
72 : #endif
73 :
74 : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::forceinline)
75 : #define _nnt_ATTR_ALWAYS_INLINE [[msvc::forceinline]]
76 : #elif __has_cpp_attribute(gnu::always_inline)
77 : #define _nnt_ATTR_ALWAYS_INLINE [[gnu::always_inline]]
78 : #else
#define _nnt_ATTR_ALWAYS_INLINE
#endif
79 :
80 : #if __has_cpp_attribute(unlikely)
81 : #define UNLIKELY [[unlikely]]
82 : #else
83 : #define UNLIKELY
84 : #endif
85 :
86 : #if !defined(_MSC_VER) && !defined(__clang__)
87 : #pragma GCC diagnostic ignored "-Wattributes"
88 : #endif
89 :
90 : #if defined(__clang__) || defined(__GNUC__)
91 : #define RESTRICT __restrict__
92 : #else
93 : #define RESTRICT
94 : #endif
95 :
96 : namespace {
97 :
98 : template <typename To_, typename From_>
99 : constexpr inline bool concept17_BinaryCastable =
100 : sizeof(To_) == sizeof(From_) &&
101 : std::is_trivially_copyable_v<From_> && std::is_trivially_copyable_v<To_>;
102 :
103 : template <class To_, class From_>
104 : auto compat_bit_cast(const From_ &src) noexcept
105 : -> std::enable_if_t<concept17_BinaryCastable<To_, From_>, To_> {
106 : #if __cpp_lib_bit_cast >= 201806L
107 : return std::bit_cast<To_>(src);
108 : #else
109 : To_ dst;
110 : std::memcpy(&dst, &src, sizeof(To_));
111 : return dst;
112 : #endif
113 : }
114 :
115 : [[nodiscard]] constexpr inline unsigned
116 : constexpr_popcount(uint32_t v) noexcept {
117 : #if __cpp_lib_bitops >= 201907L
118 : return std::popcount(v);
119 : #else
120 : // Popcount via bit-hack
121 : v = v - ((v >> 1) & 0x55555555); // reuse input as temporary
122 : v = (v & 0x33333333) + ((v >> 2) & 0x33333333); // temp
123 : auto c = (((v + (v >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; // count
124 : return c;
125 : #endif
126 : }
127 :
128 : template <unsigned I_>
129 : constexpr inline bool concept17_PowerOfTwo = (constexpr_popcount(I_) == 1);
130 :
131 : namespace numbers {
132 :
133 : #if __has_include(<numbers>) && __cpp_lib_math_constants >= 201907L
134 : using std::numbers::ln2_v;
135 : using std::numbers::log2e_v;
136 : #else
137 : template <typename Float_> constexpr inline auto ln2_v = Float_{M_LN2};
138 :
139 : template <typename Float_> constexpr inline auto log2e_v = Float_{M_LOG2E};
140 : #endif
141 : } // namespace numbers
142 :
143 : constexpr inline float EXP_ARG_MIN = -87.0;
144 : constexpr inline float EXP_ARG_MAX = +88.3762626647949f;
145 :
146 : // @brief Precalculated lookup table for 2^x calculation
147 : template <unsigned N_, typename Ty_ = uint32_t, typename Float_ = float,
148 : typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
149 : struct Exp2Table {
150 :
151 : constexpr static inline auto MANTISSA_BITS =
152 : std::numeric_limits<Float_>::digits - 1;
153 :
154 : #if __cpp_consteval >= 201811L && __has_constexpr_builtin(__builtin_exp2)
155 : [[nodiscard]] static consteval auto calculate() noexcept {
156 : std::array<Ty_, N_> t;
157 : for (unsigned i = 0; i < N_; ++i)
158 : t[i] = std::bit_cast<Ty_>(std::exp2(Float_{1.0} * i / N_)) -
159 : ((i << MANTISSA_BITS) / N_);
160 : return t;
161 : }
162 : #endif
163 : };
164 :
165 : #if !__has_constexpr_builtin(__builtin_exp2) || !(__cpp_consteval >= 201811L)
166 :
167 : // @brief Precalculated lookup table for 2^x calculation when we don't have
168 : // constexpr math functions
169 : template <> struct Exp2Table<8, uint32_t, float> {
170 : [[nodiscard]] static constexpr auto calculate() noexcept {
171 : std::array<uint32_t, 8> t = {0x3f800000U, 0x3f7b95c2U, 0x3f7837f0U,
172 : 0x3f75fed7U, 0x3f7504f3U, 0x3f75672aU,
173 : 0x3f7744fdU, 0x3f7ac0c7U};
174 : return t;
175 : }
176 : };
177 :
178 : #endif
179 :
180 : template <unsigned N_> // requires PowerOfTwo<N_>
181 : alignas(__m256) inline constexpr auto exp2_table_v = Exp2Table<N_>::calculate();
182 :
183 : // Scalar version of an expf() approximation using a 3rd-degree polynomial of
184 : // the fractional part.
185 : //
186 : // The error with respect to std::expf is less than 5e-6.
187 : // Valid in the range [-87, +88.37); +INF, NaN, etc. are not handled.
188 : // Arguments are clamped to the valid range.
189 : //
190 : // The strategy is to approximate exp(x) as 2^K * 2^F.
191 : template <unsigned N_, typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
192 : [[nodiscard]] _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN inline float
193 : approx_exp_exp2_lookup(float x) noexcept {
194 : constexpr static unsigned N_MASK = uint32_t(N_ - 1U);
195 : constexpr static unsigned FLT_MANTISSA_BITS =
196 : std::numeric_limits<float>::digits - 1U;
197 :
198 : x = std::max(x, EXP_ARG_MIN);
199 : x *= float(0x1.0p1 / numbers::ln2_v<double> * N_);
200 : x = std::min(x, float(EXP_ARG_MAX / numbers::ln2_v<double> * N_));
201 :
202 : // Round to nearest and split off the integer part (cf. std::modf)
203 : // NB: this trick does not round ties to even.
204 : auto x_int = x + 0x1.8p23f;
205 : auto x_uint = compat_bit_cast<uint32_t>(x_int);
206 : x_int -= 0x1.8p23f;
207 : auto x_frac = x - x_int;
208 :
209 : auto s_int = exp2_table_v<N_>[x_uint & N_MASK];
210 : auto x_uint_shifted = x_uint
211 : << (FLT_MANTISSA_BITS - constexpr_popcount(N_MASK));
212 : auto s_int_2 = s_int + x_uint_shifted;
213 : auto s = compat_bit_cast<float>(s_int_2);
214 :
215 : // Polynomial of the form C0*x^3 + C1*x^2 + C2*x + 1.0
216 : static constexpr float poly_4[] = {0x1.c6af84b912394p-5f / N_ / N_ / N_,
217 : 0x1.ebfce50fac4f3p-3f / N_ / N_,
218 : 0x1.62e42ff0c52d6p-1f / N_};
219 :
220 : auto q0 = std::fma(x_frac, poly_4[0], poly_4[1]);
221 : auto x_frac_pow2 = x_frac * x_frac;
222 : auto q2 = x_frac * poly_4[2]; // not adding +1.0
223 :
224 : x = std::fma(q0, x_frac_pow2, q2);
225 :
226 : return std::fma(x, s, s); // s * (x + 1): the polynomial's +1 term is folded into the addend
227 : }
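// Illustrative check (not part of the original source; kept out of the build
// with #if 0): a minimal sketch of how the scalar approximation above could be
// compared against std::exp over its valid domain. The 5e-6 figure is the
// bound quoted in the comment above, not an independently verified result.
#if 0
inline float approx_exp_max_rel_error() {
  float max_rel_err = 0.0f;
  for (float v = EXP_ARG_MIN; v < 88.0f; v += 0.25f) {
    const float ref = std::exp(v);
    const float got = approx_exp_exp2_lookup<8>(v);
    max_rel_err = std::max(max_rel_err, std::fabs(got - ref) / ref);
  }
  return max_rel_err; // expected to stay below ~5e-6
}
#endif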
228 :
229 : // Vectorized version of above
230 : template <unsigned N_, typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
231 : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN inline auto _nnt_CC_VECTORCALL
232 : avx2_approx_exp_e2lookup(__m256 xs) noexcept {
233 :
234 : constexpr static uint32_t N_MASK = uint32_t(N_ - 1U);
235 : alignas(64) constexpr static auto EXP2_TBL = exp2_table_v<N_>;
236 : constexpr static unsigned MANTISSA_BITS =
237 : std::numeric_limits<float>::digits - 1;
238 :
239 : // Ensure the argument is in [EXP_ARG_MIN, EXP_ARG_MAX]
240 : xs = _mm256_max_ps(xs, _mm256_set1_ps(EXP_ARG_MIN));
241 : // Would clamp to EXP_ARG_MAX but we move it after multiplication for IPC:
242 : // xs = _mm256_min_ps(xs, _mm256_set1_ps(EXP_ARG_MAX));
243 :
244 : xs =
245 : _mm256_mul_ps(xs, _mm256_set1_ps(float(1.0 / numbers::ln2_v<double> * N_)));
246 : // Clamp EXP_ARG_MAX after multiply
247 : xs = _mm256_min_ps(
248 : xs, _mm256_set1_ps(float(EXP_ARG_MAX / numbers::ln2_v<double> * N_)));
249 :
250 : // Mostly equivalent to the following, except ties are not rounded to even:
251 : // auto xs_int = _mm256_round_ps(xs, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
252 : // auto xs_int_as_u32 = _mm256_cvtps_epi32(xs_int);
253 : auto xs_int = _mm256_add_ps(xs, _mm256_set1_ps(0x1.8p23f));
254 : auto xs_int_as_u32 = _mm256_castps_si256(xs_int);
255 : xs_int = _mm256_sub_ps(xs_int, _mm256_set1_ps(0x1.8p23f));
256 :
257 : // Calculate fractional part
258 : auto xs_frac = _mm256_sub_ps(xs, xs_int);
259 : // Indices for lookup (modulo N_)
260 : auto exp2_idxs = _mm256_and_si256(xs_int_as_u32, _mm256_set1_epi32(N_MASK));
261 :
262 : __m256i s_ints;
263 :
264 : // Look up the scale factor s for the integer part
265 : if constexpr (N_ == 8) {
266 : // Lookup by vector permute
267 : auto tbl = _mm256_load_si256((__m256i *)EXP2_TBL.data());
268 : s_ints = _mm256_permutevar8x32_epi32(tbl, exp2_idxs);
269 : } else {
270 : // Fallback gather for table sizes that don't match the vector permute width
271 : s_ints = _mm256_i32gather_epi32(reinterpret_cast<const int *>(EXP2_TBL.data()), exp2_idxs, 4);
272 : }
273 :
274 : auto xs_uint_shifted = _mm256_slli_epi32(
275 : xs_int_as_u32, MANTISSA_BITS - constexpr_popcount(N_MASK));
276 : auto s_ints_2 = _mm256_add_epi32(s_ints, xs_uint_shifted);
277 : auto s_floats = _mm256_castsi256_ps(s_ints_2);
278 :
279 : static constexpr float poly_d4[] = {0x1.c6af84b912394p-5f / N_ / N_ / N_,
280 : 0x1.ebfce50fac4f3p-3f / N_ / N_,
281 : 0x1.62e42ff0c52d6p-1f / N_};
282 :
283 : const auto C0 = _mm256_set1_ps(poly_d4[0]);
284 : const auto C1 = _mm256_set1_ps(poly_d4[1]);
285 : const auto C2 = _mm256_set1_ps(poly_d4[2]);
286 :
287 : auto qs0 = _mm256_fmadd_ps(xs_frac, C0, C1);
288 : auto xs_frac_pow2 = _mm256_mul_ps(xs_frac, xs_frac);
289 : auto qs2 = _mm256_mul_ps(xs_frac, C2);
290 :
291 : xs = _mm256_fmadd_ps(qs0, xs_frac_pow2, qs2);
292 :
293 : return _mm256_fmadd_ps(xs, s_floats, s_floats);
294 : }
295 :
296 : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
297 : avx2_negate_ps(__m256 x) noexcept -> __m256 {
298 : constexpr auto SIGN_SHIFT = sizeof(float) * 8 - 1;
299 : const auto UNDEF = _mm256_undefined_si256();
300 : const auto sign_bit =
301 : _mm256_slli_epi32(_mm256_cmpeq_epi16(UNDEF, UNDEF), SIGN_SHIFT);
302 : auto flt_sign_bit = _mm256_castsi256_ps(sign_bit);
303 : auto neg_x = _mm256_xor_ps(x, flt_sign_bit);
304 : return neg_x;
305 : }
306 :
307 : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
308 : avx2_approx_swiglu(__m256 x, __m256 s) noexcept -> __m256 {
309 : auto neg_x = avx2_negate_ps(x);
310 : auto inv_sigmoid =
311 : _mm256_add_ps(avx2_approx_exp_e2lookup<8>(neg_x), _mm256_set1_ps(1.0f));
312 : auto swiglu_nonscaled = _mm256_div_ps(x, inv_sigmoid);
313 : return _mm256_mul_ps(swiglu_nonscaled, s);
314 : }
315 :
316 : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
317 : avx2_approx_swiglu_alpha(__m256 x, __m256 s, __m256 alpha) noexcept -> __m256 {
318 : auto alpha_x = _mm256_mul_ps(alpha, x);
319 : auto neg_alpha_x = avx2_negate_ps(alpha_x);
320 : auto inv_sigmoid = _mm256_add_ps(avx2_approx_exp_e2lookup<8>(neg_alpha_x),
321 : _mm256_set1_ps(1.0f));
322 : auto swiglu_nonscaled = _mm256_div_ps(x, inv_sigmoid);
323 : return _mm256_mul_ps(swiglu_nonscaled, s);
324 : }
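// Scalar reference (an illustrative sketch, not part of the original source;
// kept out of the build with #if 0): the two helpers above compute
// swiglu(x, s; alpha) = s * x * sigmoid(alpha * x) = s * x / (1 + exp(-alpha * x)),
// with exp() replaced by the lookup-table approximation.
#if 0
inline float swiglu_scalar_ref(float x, float s, float alpha = 1.0f) {
  return s * x / (1.0f + std::exp(-alpha * x));
}
#endif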
325 : } // namespace
326 :
327 : namespace nntrainer::avx2 {
328 :
329 : /**
330 : * @brief struct of q4_0x8 block
331 : */
332 : struct block_q4_0x8 {
333 : uint16_t d[8]; // 16B
334 : uint8_t qs[128]; // 16 x u64
335 : };
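// Layout note with a minimal dequantization sketch (not part of the original
// source; the q4_0 semantics are an assumption based on the GGML-style packing
// this struct mirrors): d[] holds eight fp16 scales and qs[] the 8 x 16 = 128
// bytes of 4-bit quants, one sub-block being 32 weights sharing one scale.
#if 0
inline float dequant_q4_0_nibble(uint16_t d_fp16, uint8_t nibble) {
  // value = scale * (quant - 8), quant being the low 4 bits of a qs byte
  return nntrainer::compute_fp16_to_fp32(d_fp16) *
         static_cast<float>(static_cast<int>(nibble & 0x0F) - 8);
}
#endif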
336 :
337 : #define USE_NONTEMPORAL_STORES 1
338 :
339 : static inline void store256_u16(void *dst, __m256i v) {
340 : #if defined(USE_NONTEMPORAL_STORES)
341 : // use NT only if 32B-aligned; otherwise fall back (correctness first)
342 : if (((uintptr_t)dst & 31u) == 0) {
343 : _mm256_stream_si256((__m256i *)dst, v);
344 : return;
345 : }
346 : #endif
347 : _mm256_storeu_si256((__m256i *)dst, v);
348 : }
349 :
350 0 : void unpack_q4_0x8_transpose16(const void *src, unsigned short *__restrict dT,
351 : unsigned short *__restrict qsT, int N, int K,
352 : int CT) // column tile (in units of 32-cols)
353 : {
354 0 : assert((K % 256) == 0);
355 0 : assert((N % 8) == 0);
356 :
357 : const auto *__restrict x = static_cast<const block_q4_0x8 *>(src);
358 :
359 0 : const int groups_N8 = N / 8; // number of 8-row groups
360 0 : const int cols_scales = K / 32; // K subblocks
361 :
362 : // AVX2 constants
363 0 : const __m128i v88 = _mm_set1_epi8((char)0x88);
364 0 : const __m128i v0f = _mm_set1_epi8((char)0x0F);
365 0 : const __m128i vF0 = _mm_set1_epi8((char)0xF0);
366 :
367 : const __m128i idx_even =
368 : _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, (char)0xFF, (char)0xFF, (char)0xFF,
369 0 : (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF);
370 : const __m128i idx_odd =
371 : _mm_setr_epi8(1, 3, 5, 7, 9, 11, 13, 15, (char)0xFF, (char)0xFF, (char)0xFF,
372 0 : (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF);
373 : const __m128i idx_0246 =
374 : _mm_setr_epi8(0, 2, 4, 6, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
375 : (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
376 0 : (char)0xFF, (char)0xFF, (char)0xFF);
377 : const __m128i idx_1357 =
378 : _mm_setr_epi8(1, 3, 5, 7, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
379 : (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
380 0 : (char)0xFF, (char)0xFF, (char)0xFF);
381 :
382 0 : auto pack_row8 = [&](const unsigned char *qs0, const unsigned char *qs1,
383 : int off) -> __m128i {
384 0 : __m128i lo8 = _mm_loadl_epi64((const __m128i *)(qs0 + 8 * off));
385 0 : __m128i hi8 = _mm_loadl_epi64((const __m128i *)(qs1 + 8 * off));
386 : __m128i v = _mm_unpacklo_epi64(lo8, hi8);
387 0 : v = _mm_xor_si128(v, v88);
388 0 : __m128i lo = _mm_and_si128(v, v0f);
389 : __m128i hi = _mm_and_si128(_mm_srli_epi16(v, 4), v0f);
390 0 : __m128i lo_e = _mm_shuffle_epi8(lo, idx_even);
391 0 : __m128i lo_o = _mm_shuffle_epi8(lo, idx_odd);
392 : __m128i hi_e = _mm_shuffle_epi8(hi, idx_even);
393 : __m128i hi_o = _mm_shuffle_epi8(hi, idx_odd);
394 : __m128i low_lane =
395 0 : _mm_or_si128(lo_e, _mm_and_si128(_mm_slli_epi16(lo_o, 4), vF0));
396 : __m128i high_lane =
397 : _mm_or_si128(hi_e, _mm_and_si128(_mm_slli_epi16(hi_o, 4), vF0));
398 0 : __m128i low_e2 = _mm_shuffle_epi8(low_lane, idx_0246);
399 0 : __m128i low_o2 = _mm_shuffle_epi8(low_lane, idx_1357);
400 : __m128i high_e2 = _mm_shuffle_epi8(high_lane, idx_0246);
401 : __m128i high_o2 = _mm_shuffle_epi8(high_lane, idx_1357);
402 : __m128i pack_lo = _mm_unpacklo_epi8(low_e2, low_o2); // 4×u16 (w0..w3)
403 : __m128i pack_hi = _mm_unpacklo_epi8(high_e2, high_o2); // 4×u16 (w4..w7)
404 0 : return _mm_unpacklo_epi64(pack_lo, pack_hi); // 8×u16 (w0..w7)
405 0 : };
406 :
407 : auto transpose8x8_epi16 =
408 : [](__m128i r0, __m128i r1, __m128i r2, __m128i r3, __m128i r4, __m128i r5,
409 : __m128i r6, __m128i r7, __m128i &c0, __m128i &c1, __m128i &c2,
410 : __m128i &c3, __m128i &c4, __m128i &c5, __m128i &c6, __m128i &c7) {
411 : __m128i t0 = _mm_unpacklo_epi16(r0, r1);
412 : __m128i t1 = _mm_unpackhi_epi16(r0, r1);
413 : __m128i t2 = _mm_unpacklo_epi16(r2, r3);
414 : __m128i t3 = _mm_unpackhi_epi16(r2, r3);
415 : __m128i t4 = _mm_unpacklo_epi16(r4, r5);
416 : __m128i t5 = _mm_unpackhi_epi16(r4, r5);
417 : __m128i t6 = _mm_unpacklo_epi16(r6, r7);
418 : __m128i t7 = _mm_unpackhi_epi16(r6, r7);
419 :
420 : __m128i u0 = _mm_unpacklo_epi32(t0, t2);
421 : __m128i u1 = _mm_unpackhi_epi32(t0, t2);
422 : __m128i u2 = _mm_unpacklo_epi32(t1, t3);
423 : __m128i u3 = _mm_unpackhi_epi32(t1, t3);
424 : __m128i u4 = _mm_unpacklo_epi32(t4, t6);
425 : __m128i u5 = _mm_unpackhi_epi32(t4, t6);
426 : __m128i u6 = _mm_unpacklo_epi32(t5, t7);
427 : __m128i u7 = _mm_unpackhi_epi32(t5, t7);
428 :
429 : c0 = _mm_unpacklo_epi64(u0, u4);
430 : c1 = _mm_unpackhi_epi64(u0, u4);
431 : c2 = _mm_unpacklo_epi64(u1, u5);
432 : c3 = _mm_unpackhi_epi64(u1, u5);
433 : c4 = _mm_unpacklo_epi64(u2, u6);
434 : c5 = _mm_unpackhi_epi64(u2, u6);
435 : c6 = _mm_unpacklo_epi64(u3, u7);
436 : c7 = _mm_unpackhi_epi64(u3, u7);
437 : };
438 :
439 : // -------- pair-processing path: handle two 8-row groups (16 rows) per pass
440 : // --------
441 0 : const int groups_pairs = groups_N8 / 2;
442 :
443 : #ifdef _MSC_VER
444 : #pragma warning(push)
445 : #pragma warning(disable : 4849)
446 : #endif
447 0 : #pragma omp parallel for collapse(2) schedule(static)
448 : #ifdef _MSC_VER
449 : #pragma warning(pop)
450 : #endif
451 : for (int c0 = 0; c0 < cols_scales; c0 += CT) {
452 : for (int bp = 0; bp < groups_pairs; ++bp) {
453 : const int b0 = 2 * bp;
454 : const int b1 = b0 + 1;
455 : const int r0 = b0 * 8; // 16 rows: r0..r0+15
456 : const int c1 = std::min(c0 + CT, cols_scales);
457 :
458 : for (int c = c0; c < c1; ++c) {
459 : const block_q4_0x8 &A = x[b0 * cols_scales + c];
460 : const block_q4_0x8 &B = x[b1 * cols_scales + c];
461 :
462 : unsigned short *__restrict dT_c = dT + c * N;
463 : unsigned short *__restrict qsT_c0 = qsT + (c * 8) * N;
464 :
465 : // scales: pack two 8×u16 vectors → one 256b store to dT[c, r0..r0+15]
466 : __m128i sd0 = _mm_loadu_si128((const __m128i *)A.d);
467 : __m128i sd1 = _mm_loadu_si128((const __m128i *)B.d);
468 : __m256i sdp = _mm256_set_m128i(sd1, sd0);
469 : store256_u16(dT_c + r0, sdp);
470 :
471 : // pre-split stripes
472 : const unsigned char *__restrict A0 = A.qs; // + 8*off
473 : const unsigned char *__restrict A1 = A.qs + 64; // + 8*off
474 : const unsigned char *__restrict B0 = B.qs;
475 : const unsigned char *__restrict B1 = B.qs + 64;
476 :
477 : // build 8 rows for A and 8 rows for B
478 : __m128i Ra[8], Rb[8];
479 : for (int off = 0; off < 8; ++off) {
480 : Ra[off] = pack_row8(A0, A1, off);
481 : Rb[off] = pack_row8(B0, B1, off);
482 : }
483 :
484 : // 8×8 transpose → columns (each 8×u16) for A and B
485 : __m128i Ca0, Ca1, Ca2, Ca3, Ca4, Ca5, Ca6, Ca7;
486 : __m128i Cb0, Cb1, Cb2, Cb3, Cb4, Cb5, Cb6, Cb7;
487 : transpose8x8_epi16(Ra[0], Ra[1], Ra[2], Ra[3], Ra[4], Ra[5], Ra[6],
488 : Ra[7], Ca0, Ca1, Ca2, Ca3, Ca4, Ca5, Ca6, Ca7);
489 : transpose8x8_epi16(Rb[0], Rb[1], Rb[2], Rb[3], Rb[4], Rb[5], Rb[6],
490 : Rb[7], Cb0, Cb1, Cb2, Cb3, Cb4, Cb5, Cb6, Cb7);
491 :
492 : // pair and store 32B per column t: rows r0..r0+15 are contiguous
493 : unsigned short *__restrict base = qsT_c0 + r0;
494 : const int S = N;
495 : store256_u16(base + 0 * S, _mm256_set_m128i(Cb0, Ca0));
496 : store256_u16(base + 1 * S, _mm256_set_m128i(Cb1, Ca1));
497 : store256_u16(base + 2 * S, _mm256_set_m128i(Cb2, Ca2));
498 : store256_u16(base + 3 * S, _mm256_set_m128i(Cb3, Ca3));
499 : store256_u16(base + 4 * S, _mm256_set_m128i(Cb4, Ca4));
500 : store256_u16(base + 5 * S, _mm256_set_m128i(Cb5, Ca5));
501 : store256_u16(base + 6 * S, _mm256_set_m128i(Cb6, Ca6));
502 : store256_u16(base + 7 * S, _mm256_set_m128i(Cb7, Ca7));
503 : }
504 : }
505 : }
506 :
507 : // -------- tail: if odd number of 8-row groups, process the last one (8 rows)
508 : // --------
509 0 : if (groups_N8 & 1) {
510 0 : const int b = groups_N8 - 1;
511 0 : const int r0 = b * 8;
512 :
513 0 : #pragma omp parallel for schedule(static)
514 : for (int c0 = 0; c0 < cols_scales; c0 += CT) {
515 : const int c1 = std::min(c0 + CT, cols_scales);
516 : for (int c = c0; c < c1; ++c) {
517 : const block_q4_0x8 &A = x[b * cols_scales + c];
518 : unsigned short *__restrict dT_c = dT + c * N;
519 : unsigned short *__restrict qsT_c0 = qsT + (c * 8) * N;
520 :
521 : // scales (8×u16)
522 : __m128i sd0 = _mm_loadu_si128((const __m128i *)A.d);
523 : _mm_storeu_si128((__m128i *)(dT_c + r0), sd0);
524 :
525 : const unsigned char *__restrict A0 = A.qs;
526 : const unsigned char *__restrict A1 = A.qs + 64;
527 :
528 : __m128i R[8];
529 : for (int off = 0; off < 8; ++off)
530 : R[off] = pack_row8(A0, A1, off);
531 :
532 : __m128i C0, C1, C2, C3, C4, C5, C6, C7;
533 : transpose8x8_epi16(R[0], R[1], R[2], R[3], R[4], R[5], R[6], R[7], C0,
534 : C1, C2, C3, C4, C5, C6, C7);
535 :
536 : unsigned short *__restrict base = qsT_c0 + r0;
537 : const int S = N;
538 : _mm_storeu_si128((__m128i *)(base + 0 * S), C0);
539 : _mm_storeu_si128((__m128i *)(base + 1 * S), C1);
540 : _mm_storeu_si128((__m128i *)(base + 2 * S), C2);
541 : _mm_storeu_si128((__m128i *)(base + 3 * S), C3);
542 : _mm_storeu_si128((__m128i *)(base + 4 * S), C4);
543 : _mm_storeu_si128((__m128i *)(base + 5 * S), C5);
544 : _mm_storeu_si128((__m128i *)(base + 6 * S), C6);
545 : _mm_storeu_si128((__m128i *)(base + 7 * S), C7);
546 : }
547 : }
548 : }
549 :
550 : #if defined(USE_NONTEMPORAL_STORES)
551 : _mm_sfence(); // ensure NT stores are globally visible before returning
552 : #endif
553 0 : }
554 :
555 : static inline __m256i butterfly32(__m256i a) {
556 : const __m256i SHUF_EVEN = _mm256_setr_epi8(
557 : 0, 2, 4, 6, 8, 10, 12, 14, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
558 : (char)0x80, (char)0x80, (char)0x80, (char)0x80, 0, 2, 4, 6, 8, 10, 12, 14,
559 : (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
560 : (char)0x80, (char)0x80);
561 : const __m256i SHUF_ODD = _mm256_setr_epi8(
562 : 1, 3, 5, 7, 9, 11, 13, 15, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
563 : (char)0x80, (char)0x80, (char)0x80, (char)0x80, 1, 3, 5, 7, 9, 11, 13, 15,
564 : (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
565 : (char)0x80, (char)0x80);
566 : const __m256i even = _mm256_shuffle_epi8(a, SHUF_EVEN);
567 : const __m256i odd = _mm256_shuffle_epi8(a, SHUF_ODD);
568 : const __m256i LO = _mm256_set1_epi8(0x0F);
569 : const __m256i HI = _mm256_set1_epi8((char)0xF0);
570 : __m256i low =
571 : _mm256_or_si256(_mm256_and_si256(even, LO),
572 : _mm256_slli_epi16(_mm256_and_si256(odd, LO), 4));
573 : __m256i high =
574 : _mm256_or_si256(_mm256_srli_epi16(_mm256_and_si256(even, HI), 4),
575 : _mm256_and_si256(odd, HI));
576 : high = _mm256_slli_si256(high, 8);
577 : return _mm256_or_si256(low, high);
578 : }
579 :
580 : // Build 16B packet [d0|d1] from two 8B chunks using vector loads (no GPR
581 : // moves).
582 : static inline __m128i make_pkt128(const uint8_t *base_qs, int d0, int d1) {
583 0 : __m128i lo = _mm_loadl_epi64((const __m128i *)(base_qs + ((size_t)d0 << 3)));
584 0 : __m128i hi = _mm_loadl_epi64((const __m128i *)(base_qs + ((size_t)d1 << 3)));
585 : return _mm_unpacklo_epi64(lo, hi);
586 : }
587 :
588 : // ================== core template with QS unrolled by 8 blocks
589 : // ==================
590 : template <int UNIT, int GROUPS>
591 : static inline void convert_q4_0x8_noshuffle(const void *src,
592 : uint16_t *RESTRICT d_out,
593 : uint8_t *RESTRICT qs_out) {
594 : static_assert(UNIT % 16 == 0, "UNIT must be multiple of 16");
595 : constexpr int BLOCKS_PER_GROUP = UNIT / 8; // d entries per offset per group
596 : constexpr int PAIRS_PER_OFFSET = UNIT / 16; // 16B packets per half per offset
597 : static_assert((PAIRS_PER_OFFSET % 4) == 0,
598 : "need multiple of 4 packets (8 blocks) per iter");
599 :
600 : constexpr size_t D_ELEMS_PER_GROUP = 8 * BLOCKS_PER_GROUP;
601 : constexpr size_t QS_BYTES_PER_GROUP = (size_t)16 * UNIT;
602 : constexpr size_t QS_BYTES_PER_OFFSET = (size_t)2 * UNIT;
603 :
604 0 : const block_q4_0x8 *x = (const block_q4_0x8 *)src;
605 0 : const __m256i bias256 = _mm256_set1_epi8((char)0x88);
606 :
607 : #ifdef _MSC_VER
608 : #pragma warning(push)
609 : #pragma warning(disable : 4849)
610 : #endif
611 0 : #pragma omp parallel for collapse(2) schedule(static)
612 : #ifdef _MSC_VER
613 : #pragma warning(pop)
614 : #endif
615 : for (int b = 0; b < GROUPS; ++b) {
616 : for (int offset = 0; offset < 8; ++offset) {
617 :
618 : // ---- D slice ----
619 : {
620 : uint16_t *d_ptr = d_out + (size_t)b * D_ELEMS_PER_GROUP +
621 : (size_t)offset * BLOCKS_PER_GROUP;
622 : const block_q4_0x8 *xb = x + (size_t)b * BLOCKS_PER_GROUP;
623 : for (int i = 0; i < BLOCKS_PER_GROUP; ++i) {
624 : d_ptr[i] = xb[i].d[offset];
625 : }
626 : }
627 :
628 : // ---- QS slice (unroll 8 blocks / 128B per iter) ----
629 : {
630 : uint8_t *qs_ptr = qs_out + (size_t)b * QS_BYTES_PER_GROUP +
631 : (size_t)offset * QS_BYTES_PER_OFFSET;
632 : const int base_q = (b * UNIT * 2) + offset;
633 : const int d0 = (base_q & 15), d1 = d0 ^ 8;
634 :
635 0 : auto do_half = [&](int blk_base) {
636 : // Each iter handles 8 consecutive blocks: j..j+7
637 0 : for (int j = 0; j < PAIRS_PER_OFFSET; j += 8) {
638 0 : const uint8_t *q0 = x[blk_base + j + 0].qs;
639 0 : const uint8_t *q1 = x[blk_base + j + 1].qs;
640 0 : const uint8_t *q2 = x[blk_base + j + 2].qs;
641 0 : const uint8_t *q3 = x[blk_base + j + 3].qs;
642 0 : const uint8_t *q4 = x[blk_base + j + 4].qs;
643 0 : const uint8_t *q5 = x[blk_base + j + 5].qs;
644 0 : const uint8_t *q6 = x[blk_base + j + 6].qs;
645 0 : const uint8_t *q7 = x[blk_base + j + 7].qs;
646 :
647 : #if Q4X8_PREFETCH_DIST > 0
648 : _mm_prefetch(
649 : (const char *)(x[blk_base + j + Q4X8_PREFETCH_DIST].qs),
650 : _MM_HINT_NTA);
651 : #endif
652 : // Build 8 packets in XMM regs
653 0 : __m128i pkt0 = make_pkt128(q0, d0, d1);
654 : __m128i pkt1 = make_pkt128(q1, d0, d1);
655 : __m128i pkt2 = make_pkt128(q2, d0, d1);
656 : __m128i pkt3 = make_pkt128(q3, d0, d1);
657 : __m128i pkt4 = make_pkt128(q4, d0, d1);
658 : __m128i pkt5 = make_pkt128(q5, d0, d1);
659 : __m128i pkt6 = make_pkt128(q6, d0, d1);
660 : __m128i pkt7 = make_pkt128(q7, d0, d1);
661 :
662 : // Four 32B batches: [0|1], [2|3], [4|5], [6|7]
663 : __m256i v01 = _mm256_set_m128i(pkt1, pkt0);
664 : __m256i v23 = _mm256_set_m128i(pkt3, pkt2);
665 : __m256i v45 = _mm256_set_m128i(pkt5, pkt4);
666 : __m256i v67 = _mm256_set_m128i(pkt7, pkt6);
667 :
668 0 : v01 = _mm256_xor_si256(v01, bias256);
669 : v23 = _mm256_xor_si256(v23, bias256);
670 : v45 = _mm256_xor_si256(v45, bias256);
671 : v67 = _mm256_xor_si256(v67, bias256);
672 :
673 : __m256i o01 = butterfly32(v01);
674 : __m256i o23 = butterfly32(v23);
675 : __m256i o45 = butterfly32(v45);
676 : __m256i o67 = butterfly32(v67);
677 :
678 : #if Q4X8_USE_STREAMING_STORES
679 : _mm256_stream_si256((__m256i *)(qs_ptr + 0), o01);
680 : _mm256_stream_si256((__m256i *)(qs_ptr + 32), o23);
681 : _mm256_stream_si256((__m256i *)(qs_ptr + 64), o45);
682 : _mm256_stream_si256((__m256i *)(qs_ptr + 96), o67);
683 : #else
684 0 : _mm256_storeu_si256((__m256i *)(qs_ptr + 0), o01);
685 0 : _mm256_storeu_si256((__m256i *)(qs_ptr + 32), o23);
686 0 : _mm256_storeu_si256((__m256i *)(qs_ptr + 64), o45);
687 0 : _mm256_storeu_si256((__m256i *)(qs_ptr + 96), o67);
688 : #endif
689 0 : qs_ptr += 128;
690 : }
691 : };
692 :
693 : // first half
694 : do_half(base_q >> 4);
695 : // second half (same d0/d1 pattern)
696 : do_half((base_q + UNIT) >> 4);
697 : }
698 : }
699 : }
700 :
701 : #if Q4X8_USE_STREAMING_STORES
702 : _mm_sfence();
703 : #endif
704 : }
705 :
706 : // ============ wrappers for the supported (K, N) combinations ============
707 : // K = 3072 (UNIT = 768)
708 0 : void convert_q4_0x8_shuffle_K3072_N98304(const void *src, uint16_t *d_out,
709 : uint8_t *qs_out) {
710 : // groups = (N*8)/UNIT = 1024
711 : convert_q4_0x8_noshuffle<768, 1024>(src, d_out, qs_out);
712 0 : }
713 0 : void convert_q4_0x8_shuffle_K3072_N36864(const void *src, uint16_t *d_out,
714 : uint8_t *qs_out) {
715 : // groups = 384
716 : convert_q4_0x8_noshuffle<768, 384>(src, d_out, qs_out);
717 0 : }
718 0 : void convert_q4_0x8_shuffle_K3072_N3072(const void *src, uint16_t *d_out,
719 : uint8_t *qs_out) {
720 : // groups = 32
721 : convert_q4_0x8_noshuffle<768, 32>(src, d_out, qs_out);
722 0 : }
723 :
724 : // K = 8192 (UNIT = 2048)
725 0 : void convert_q4_0x8_shuffle_K8192_N98304(const void *src, uint16_t *d_out,
726 : uint8_t *qs_out) {
727 : // groups = 384
728 : convert_q4_0x8_noshuffle<2048, 384>(src, d_out, qs_out);
729 0 : }
730 0 : void convert_q4_0x8_shuffle_K8192_N36864(const void *src, uint16_t *d_out,
731 : uint8_t *qs_out) {
732 : // groups = 144
733 : convert_q4_0x8_noshuffle<2048, 144>(src, d_out, qs_out);
734 0 : }
735 0 : void convert_q4_0x8_shuffle_K8192_N3072(const void *src, uint16_t *d_out,
736 : uint8_t *qs_out) {
737 : // groups = 12
738 : convert_q4_0x8_noshuffle<2048, 12>(src, d_out, qs_out);
739 0 : }
740 :
741 : // Optional small dispatcher providing a single entry point:
742 0 : void convert_q4_0x8_shuffle_dispatch_avx(const void *src, uint16_t *d_out,
743 : uint8_t *qs_out, int N, int K) {
744 0 : if (K == 3072) {
745 0 : if (N == 98304)
746 0 : return convert_q4_0x8_shuffle_K3072_N98304(src, d_out, qs_out);
747 0 : if (N == 36864)
748 0 : return convert_q4_0x8_shuffle_K3072_N36864(src, d_out, qs_out);
749 0 : if (N == 3072)
750 0 : return convert_q4_0x8_shuffle_K3072_N3072(src, d_out, qs_out);
751 : } else { // K == 8192
752 0 : if (N == 98304)
753 0 : return convert_q4_0x8_shuffle_K8192_N98304(src, d_out, qs_out);
754 0 : if (N == 36864)
755 0 : return convert_q4_0x8_shuffle_K8192_N36864(src, d_out, qs_out);
756 0 : if (N == 3072)
757 0 : return convert_q4_0x8_shuffle_K8192_N3072(src, d_out, qs_out);
758 : }
759 : // If a new combo appears, fall back to a generic version (not shown here).
760 0 : assert(!"Unsupported (K,N) combination");
761 : }
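// Hypothetical example (not part of the original source) of wiring up a new
// (K, N) combination, following the pattern of the wrappers above. The rule
// UNIT = K / 4 and GROUPS = (N * 8) / UNIT is inferred from the existing
// instantiations and is an assumption, not a documented contract.
#if 0
void convert_q4_0x8_shuffle_K4096_N4096(const void *src, uint16_t *d_out,
                                        uint8_t *qs_out) {
  // UNIT = 4096 / 4 = 1024, GROUPS = (4096 * 8) / 1024 = 32
  convert_q4_0x8_noshuffle<1024, 32>(src, d_out, qs_out);
}
#endif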
762 :
763 12 : bool is_valid(const unsigned int N, const float *input) {
764 12 : assert(N != 0);
765 12 : assert(input != NULL);
766 :
767 : int temp = 0;
768 : unsigned int idx = 0;
769 :
770 : const __m256 SIGN_MASK = _mm256_set1_ps(-0.0);
771 : const __m256 INF = _mm256_set1_ps(std::numeric_limits<float>::infinity());
772 :
773 : // Check 16 single-precision values per iteration: (X != X) flags NaN
774 15 : for (; N - idx >= 16; idx += 16) {
775 : __m256 vec0 = _mm256_loadu_ps(input);
776 : __m256 vec1 = _mm256_loadu_ps(input + 8);
777 6 : input += 16;
778 : __m256 res = _mm256_cmp_ps(vec0, vec0, _CMP_NEQ_UQ);
779 6 : temp = temp | _mm256_movemask_ps(res);
780 :
781 6 : if (temp)
782 : return false;
783 :
784 : // check infinity in vec0
785 : vec0 = _mm256_andnot_ps(SIGN_MASK, vec0);
786 : vec0 = _mm256_cmp_ps(vec0, INF, _CMP_EQ_OQ);
787 :
788 5 : temp = temp | _mm256_movemask_ps(vec0);
789 5 : if (temp)
790 : return false;
791 :
792 : __m256 res1 = _mm256_cmp_ps(vec1, vec1, _CMP_NEQ_UQ);
793 3 : temp = temp | _mm256_movemask_ps(res1);
794 :
795 3 : if (temp)
796 : return false;
797 :
798 : // check infinity in vec1
799 : vec1 = _mm256_andnot_ps(SIGN_MASK, vec1);
800 : vec1 = _mm256_cmp_ps(vec1, INF, _CMP_EQ_OQ);
801 :
802 3 : temp = temp | _mm256_movemask_ps(vec1);
803 :
804 3 : if (temp)
805 : return false;
806 : }
807 :
808 : // Check 8 single-precision values per iteration: (X != X) flags NaN
809 11 : for (; N - idx >= 8; idx += 8) {
810 : __m256 vec = _mm256_loadu_ps(input);
811 5 : input += 8;
812 : __m256 res = _mm256_cmp_ps(vec, vec, _CMP_NEQ_UQ);
813 5 : temp = temp | _mm256_movemask_ps(res);
814 :
815 5 : if (temp)
816 : return false;
817 :
818 : // check infinity in vec
819 : vec = _mm256_andnot_ps(SIGN_MASK, vec);
820 : vec = _mm256_cmp_ps(vec, INF, _CMP_EQ_OQ);
821 :
822 4 : temp = temp | _mm256_movemask_ps(vec);
823 :
824 4 : if (temp)
825 : return false;
826 : }
827 :
828 11 : while (idx < N) {
829 8 : if (!isFloatValid(*input)) {
830 : return false;
831 : }
832 5 : ++input;
833 5 : ++idx;
834 : }
835 :
836 : return true;
837 : }
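// Scalar reference (an illustrative sketch, not part of the original source):
// the vectorized loops above are equivalent to rejecting any element that is
// NaN (detected via X != X) or +/-infinity, matching the isFloatValid() tail.
#if 0
inline bool is_valid_scalar_ref(const unsigned int N, const float *input) {
  for (unsigned int i = 0; i < N; ++i)
    if (std::isnan(input[i]) || std::isinf(input[i]))
      return false;
  return true;
}
#endif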
838 :
839 173980 : void custom_scopy(const unsigned int N, const float *X, const int incX,
840 : float *Y, const int incY) {
841 173980 : unsigned int N8 = (N >> 3) << 3;
842 957465 : for (unsigned int i = 0; i < N8; i += 8) {
843 : #if defined(_WIN32)
844 : __m256 temp = _mm256_loadu_ps(&X[i]);
845 : _mm256_storeu_ps(&Y[i], temp);
846 : #else
847 783485 : __asm__ __volatile__("vmovups (%1), %%ymm0\n\t"
848 : "vmovups %%ymm0, (%0)\n\t"
849 : :
850 783485 : : "r"(&Y[i]), "r"(&X[i])
851 : : "ymm0", "memory");
852 : #endif
853 : }
854 499415 : for (unsigned int i = N8; i < N; ++i) {
855 325435 : Y[i] = X[i];
856 : }
857 173980 : }
858 :
859 25 : void transpose_matrix(const unsigned int M, const unsigned int N,
860 : const float *src, unsigned int ld_src, float *dst,
861 : unsigned int ld_dst) {
862 25 : unsigned int vindexm[8] = {0, ld_src, ld_src * 2, ld_src * 3,
863 25 : ld_src * 4, ld_src * 5, ld_src * 6, ld_src * 7};
864 : __m256i vindex = _mm256_loadu_si256((__m256i *)&vindexm[0]);
865 : __m256 vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8;
866 :
867 25 : unsigned int M8 = (M & ~(7));
868 25 : unsigned int N8 = (N & ~(7));
869 1378 : for (unsigned int i = 0; i < M8; i += 8) {
870 308626 : for (unsigned int j = 0; j < N8; j += 8) {
871 : // loading from columns
872 307273 : vec1 = _mm256_i32gather_ps(&src[ld_src * i + j + 0], vindex, 4);
873 307273 : vec2 = _mm256_i32gather_ps(&src[ld_src * i + j + 1], vindex, 4);
874 307273 : vec3 = _mm256_i32gather_ps(&src[ld_src * i + j + 2], vindex, 4);
875 307273 : vec4 = _mm256_i32gather_ps(&src[ld_src * i + j + 3], vindex, 4);
876 307273 : vec5 = _mm256_i32gather_ps(&src[ld_src * i + j + 4], vindex, 4);
877 307273 : vec6 = _mm256_i32gather_ps(&src[ld_src * i + j + 5], vindex, 4);
878 307273 : vec7 = _mm256_i32gather_ps(&src[ld_src * i + j + 6], vindex, 4);
879 307273 : vec8 = _mm256_i32gather_ps(&src[ld_src * i + j + 7], vindex, 4);
880 :
881 : // storing to the rows
882 307273 : _mm256_storeu_ps(&dst[(j + 0) * ld_dst + i], vec1);
883 307273 : _mm256_storeu_ps(&dst[(j + 1) * ld_dst + i], vec2);
884 307273 : _mm256_storeu_ps(&dst[(j + 2) * ld_dst + i], vec3);
885 307273 : _mm256_storeu_ps(&dst[(j + 3) * ld_dst + i], vec4);
886 307273 : _mm256_storeu_ps(&dst[(j + 4) * ld_dst + i], vec5);
887 307273 : _mm256_storeu_ps(&dst[(j + 5) * ld_dst + i], vec6);
888 307273 : _mm256_storeu_ps(&dst[(j + 6) * ld_dst + i], vec7);
889 307273 : _mm256_storeu_ps(&dst[(j + 7) * ld_dst + i], vec8);
890 : }
891 : }
892 :
893 : // tailing right
894 10910 : for (unsigned int i = 0; i < M; i++) {
895 12316 : for (unsigned int j = N8; j < N; j++) {
896 1431 : dst[i + j * ld_dst] = src[i * ld_src + j];
897 : }
898 : }
899 :
900 : // tailing bottom
901 86 : for (unsigned int i = M8; i < M; i++) {
902 356 : for (unsigned int j = 0; j < N; j++) {
903 295 : dst[i + j * ld_dst] = src[i * ld_src + j];
904 : }
905 : }
906 25 : }
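// Usage sketch (illustrative, not part of the original source): for a dense
// row-major M x N source, ld_src is its row stride (>= N) and ld_dst the row
// stride of the N x M destination, so the routine realizes
// dst[j * ld_dst + i] = src[i * ld_src + j].
#if 0
void transpose_dense_example(const float *src, float *dst, unsigned int M,
                             unsigned int N) {
  transpose_matrix(M, N, src, /*ld_src=*/N, dst, /*ld_dst=*/M);
}
#endif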
907 :
908 0 : void swiglu(const unsigned int N, float *X, const float *Y, const float *Z) {
909 : size_t i = 0;
910 :
911 : const auto oldcsr = _mm_getcsr();
912 0 : _mm_setcsr(oldcsr | 0x8040); // DAZ | FTZ
913 :
914 : // 16-wide blocks
915 0 : for (; i + 16 <= N; i += 16) {
916 0 : const __m256 y0 = _mm256_loadu_ps(Y + i);
917 0 : const __m256 y1 = _mm256_loadu_ps(Y + i + 8);
918 0 : const __m256 z0 = _mm256_loadu_ps(Z + i);
919 0 : const __m256 z1 = _mm256_loadu_ps(Z + i + 8);
920 :
921 0 : _mm256_storeu_ps(X + i, avx2_approx_swiglu(y0, z0));
922 0 : _mm256_storeu_ps(X + i + 8, avx2_approx_swiglu(y1, z1));
923 : }
924 :
925 : // One 8-wide block if available
926 0 : if (i + 8 <= N) {
927 0 : const __m256 y0 = _mm256_loadu_ps(Y + i);
928 0 : const __m256 z0 = _mm256_loadu_ps(Z + i);
929 0 : _mm256_storeu_ps(X + i, avx2_approx_swiglu(y0, z0));
930 : i += 8;
931 : }
932 :
933 : // Remaining 1..7 elements via maskload/maskstore
934 0 : if (i < N) {
935 0 : const int remain = static_cast<int>(N - i); // 1..7
936 :
937 0 : alignas(64) const int mtab[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
938 : 0, 0, 0, 0, 0, 0, 0, 0};
939 : // Start so that we take 'remain' ones then zeros.
940 0 : const int off = 8 - remain; // in [1..7], or 0 if remain==8
941 0 : const __m256i vmask = _mm256_loadu_si256((const __m256i *)(mtab + off));
942 :
943 0 : const __m256 y = _mm256_maskload_ps(Y + i, vmask);
944 0 : const __m256 z = _mm256_maskload_ps(Z + i, vmask);
945 : const __m256 r = avx2_approx_swiglu(y, z);
946 0 : _mm256_maskstore_ps(X + i, vmask, r);
947 : }
948 :
949 : _mm_setcsr(oldcsr);
950 0 : }
951 :
952 0 : void swiglu(const unsigned int N, float *X, const float *Y, const float *Z,
953 : float alpha) {
954 : size_t i = 0;
955 :
956 : const auto oldcsr = _mm_getcsr();
957 0 : _mm_setcsr(oldcsr | 0x8040); // DAZ | FTZ
958 :
959 : const __m256 alpha_vec = _mm256_set1_ps(alpha);
960 :
961 : // 16-wide blocks
962 0 : for (; i + 16 <= N; i += 16) {
963 0 : const __m256 y0 = _mm256_loadu_ps(Y + i);
964 0 : const __m256 y1 = _mm256_loadu_ps(Y + i + 8);
965 0 : const __m256 z0 = _mm256_loadu_ps(Z + i);
966 0 : const __m256 z1 = _mm256_loadu_ps(Z + i + 8);
967 :
968 0 : _mm256_storeu_ps(X + i, avx2_approx_swiglu_alpha(y0, z0, alpha_vec));
969 0 : _mm256_storeu_ps(X + i + 8, avx2_approx_swiglu_alpha(y1, z1, alpha_vec));
970 : }
971 :
972 : // One 8-wide block if present
973 0 : if (i + 8 <= N) {
974 0 : const __m256 y0 = _mm256_loadu_ps(Y + i);
975 0 : const __m256 z0 = _mm256_loadu_ps(Z + i);
976 0 : _mm256_storeu_ps(X + i, avx2_approx_swiglu_alpha(y0, z0, alpha_vec));
977 : i += 8;
978 : }
979 :
980 : // Remaining 1..7 elements via masked AVX (no stray stores)
981 0 : if (i < N) {
982 0 : const int remain = static_cast<int>(N - i); // 1..7
983 :
984 0 : alignas(64) const int mtab[16] = {
985 : -1, -1, -1, -1, -1, -1, -1, -1, // ones
986 : 0, 0, 0, 0, 0, 0, 0, 0 // zeros
987 : };
988 0 : const int off = 8 - remain; // choose first `remain` lanes active
989 0 : const __m256i vmask = _mm256_loadu_si256((const __m256i *)(mtab + off));
990 :
991 0 : const __m256 y = _mm256_maskload_ps(Y + i, vmask);
992 0 : const __m256 z = _mm256_maskload_ps(Z + i, vmask);
993 : const __m256 r = avx2_approx_swiglu_alpha(y, z, alpha_vec);
994 0 : _mm256_maskstore_ps(X + i, vmask, r);
995 : }
996 :
997 : _mm_setcsr(oldcsr);
998 0 : }
999 :
1000 29926 : void ele_mul(const unsigned int N, const float *X, const float *Y, float *Z,
1001 : float alpha, float beta, unsigned int i_stride,
1002 : unsigned int o_stride) {
1003 29926 : if (alpha == 1.0f && beta == 0.0f && o_stride == 1) {
1004 29581 : unsigned int N8 = (N & ~(7));
1005 29581 : if (i_stride == 0) {
1006 9591 : float vy8[8] = {Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0]};
1007 : auto y = _mm256_loadu_ps(&vy8[0]);
1008 18108 : for (unsigned int i = 0; i < N8; i += 8) {
1009 : auto x = _mm256_loadu_ps(X);
1010 : auto z = _mm256_mul_ps(x, y);
1011 : _mm256_storeu_ps(Z, z);
1012 8517 : X += 8;
1013 8517 : Y += i_stride * 8;
1014 8517 : Z += 8;
1015 : }
1016 : } else {
1017 13214695 : for (unsigned int i = 0; i < N8; i += 8) {
1018 : auto x = _mm256_loadu_ps(X);
1019 : auto y = _mm256_loadu_ps(Y);
1020 : auto z = _mm256_mul_ps(x, y);
1021 : _mm256_storeu_ps(Z, z);
1022 13194705 : X += 8;
1023 13194705 : Y += i_stride * 8;
1024 13194705 : Z += 8;
1025 : }
1026 : }
1027 138986 : for (unsigned int i = N8; i < N; ++i) {
1028 109405 : *Z = *X * *Y;
1029 109405 : X++;
1030 109405 : Y += i_stride;
1031 109405 : Z++;
1032 : }
1033 : } else {
1034 : // TODO: AVX2 implementation if used
1035 65505 : for (unsigned int i = 0; i < N; ++i) {
1036 65160 : *Z = *X * alpha * *Y + ((0.0f == beta) ? 0.0f : beta * *Z);
1037 65160 : X += o_stride;
1038 65160 : Y += i_stride;
1039 65160 : Z += o_stride;
1040 : }
1041 : }
1042 29926 : }
1043 :
1044 113033 : void ele_add(const unsigned int N, const float *X, const float *Y, float *Z,
1045 : float alpha, float beta, unsigned int i_stride,
1046 : unsigned int o_stride) {
1047 113033 : if (alpha == 1.0f && beta == 0.0f && o_stride == 1) {
1048 75820 : unsigned int N8 = (N & ~(7));
1049 75820 : if (i_stride == 0) {
1050 26375 : float vy8[8] = {Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0]};
1051 : auto y = _mm256_loadu_ps(&vy8[0]);
1052 580391 : for (unsigned int i = 0; i < N8; i += 8) {
1053 : auto x = _mm256_loadu_ps(X);
1054 : auto z = _mm256_add_ps(x, y);
1055 : _mm256_storeu_ps(Z, z);
1056 554016 : X += 8;
1057 554016 : Y += i_stride * 8;
1058 554016 : Z += 8;
1059 : }
1060 : } else {
1061 169478 : for (unsigned int i = 0; i < N8; i += 8) {
1062 : auto x = _mm256_loadu_ps(X);
1063 : auto y = _mm256_loadu_ps(Y);
1064 : auto z = _mm256_add_ps(x, y);
1065 : _mm256_storeu_ps(Z, z);
1066 120033 : X += 8;
1067 120033 : Y += i_stride * 8;
1068 120033 : Z += 8;
1069 : }
1070 : }
1071 234632 : for (unsigned int i = N8; i < N; ++i) {
1072 158812 : *Z = *X + *Y;
1073 158812 : X++;
1074 158812 : Y += i_stride;
1075 158812 : Z++;
1076 : }
1077 : } else {
1078 : // TODO: AVX2 implementation if used
1079 203432843 : for (unsigned int i = 0; i < N; ++i) {
1080 203395630 : *Z = *X + alpha * *Y + ((0.0f == beta) ? 0.0f : beta * *Z);
1081 203395630 : X += o_stride;
1082 203395630 : Y += i_stride;
1083 203395630 : Z += o_stride;
1084 : }
1085 : }
1086 113033 : }
1087 :
1088 9 : static inline __m256 exp256_ps(__m256 x) {
1089 : /* Low-Precision Version I*/
1090 : // const __m256 c1 = _mm256_set1_ps(12102203.0f);
1091 : // const __m256 c2 = _mm256_set1_ps(1065353216.0f);
1092 : // __m256 fx = _mm256_add_ps(_mm256_mul_ps(x, c1),c2);
1093 : // return _mm256_castsi256_ps(_mm256_cvtps_epi32(fx));
1094 :
1095 : /* Low-Precision Version II*/
1096 : /* const __m256 ln2 = _mm256_set1_ps(0.69314718056f);
1097 : const __m256 inv_ln2 = _mm256_set1_ps(1.44269504089f); // 1 / ln(2)
1098 :
1099 : // Range reduction: x = n * ln2 + r, where n is integer and |r| <= ln2/2
1100 : __m256 fx = _mm256_mul_ps(x, inv_ln2);
1101 : fx = _mm256_round_ps(fx, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
1102 : __m256i emm0 = _mm256_cvtps_epi32(fx);
1103 :
1104 : __m256 tmp = _mm256_mul_ps(fx, ln2);
1105 : __m256 r = _mm256_sub_ps(x, tmp);
1106 :
1107 : // Compute polynomial approximation of exp(r)
1108 : const __m256 c1 = _mm256_set1_ps(1.9875691500E-4f);
1109 : const __m256 c2 = _mm256_set1_ps(1.3981999507E-3f);
1110 : const __m256 c3 = _mm256_set1_ps(8.3334519073E-3f);
1111 : const __m256 c4 = _mm256_set1_ps(4.1665795894E-2f);
1112 : const __m256 c5 = _mm256_set1_ps(1.6666665459E-1f);
1113 : const __m256 c6 = _mm256_set1_ps(5.0000001201E-1f);
1114 :
1115 : // __m256 r2 = _mm256_mul_ps(r, r);
1116 : // __m256 r3 = _mm256_mul_ps(r2, r);
1117 : // __m256 r4 = _mm256_mul_ps(r2, r2);
1118 :
1119 : __m256 y = _mm256_fmadd_ps(c1, r, c2);
1120 : y = _mm256_fmadd_ps(y, r, c3);
1121 : y = _mm256_fmadd_ps(y, r, c4);
1122 : y = _mm256_fmadd_ps(y, r, c5);
1123 : y = _mm256_fmadd_ps(y, r, c6);
1124 : y = _mm256_fmadd_ps(y, r, _mm256_set1_ps(1.0f));
1125 :
1126 : // Reconstruct exp(x) = 2^n * exp(r)
1127 : emm0 = _mm256_add_epi32(emm0, _mm256_set1_epi32(127));
1128 : emm0 = _mm256_slli_epi32(emm0, 23);
1129 : __m256 pow2n = _mm256_castsi256_ps(emm0);
1130 :
1131 : return _mm256_mul_ps(y, pow2n);
1132 : */
1133 : /* Low-Precision Version III */
1134 : const __m256 LOG2EF = _mm256_set1_ps(1.44269504088896341f); // 1 / ln(2)
1135 : const __m256 LN2 = _mm256_set1_ps(0.6931471805599453f); // ln(2)
1136 :
1137 : // Clamp input to range to prevent overflow/underflow
1138 : const __m256 max_x = _mm256_set1_ps(88.3762626647949f); // log(FLT_MAX)
1139 : const __m256 min_x = _mm256_set1_ps(-88.3762626647949f); // -log(FLT_MAX)
1140 : x = _mm256_max_ps(min_x, _mm256_min_ps(max_x, x));
1141 :
1142 : // Range reduction: x = n * ln2 + r
1143 : __m256 fx = _mm256_mul_ps(x, LOG2EF); // x * (1/ln(2))
1144 : fx = _mm256_floor_ps(_mm256_add_ps(fx, _mm256_set1_ps(0.5f)));
1145 :
1146 : __m256 tmp = _mm256_mul_ps(fx, LN2); // n * ln(2)
1147 : __m256 r = _mm256_sub_ps(x, tmp); // r = x - n * ln2
1148 :
1149 : // Compute exp(r) using 10th-order polynomial (Horner's method)
1150 : const __m256 c0 = _mm256_set1_ps(1.0f);
1151 : const __m256 c1 = _mm256_set1_ps(1.0f);
1152 : const __m256 c2 = _mm256_set1_ps(0.5f);
1153 : const __m256 c3 = _mm256_set1_ps(1.0f / 6.0f);
1154 : const __m256 c4 = _mm256_set1_ps(1.0f / 24.0f);
1155 : const __m256 c5 = _mm256_set1_ps(1.0f / 120.0f);
1156 : const __m256 c6 = _mm256_set1_ps(1.0f / 720.0f);
1157 : const __m256 c7 = _mm256_set1_ps(1.0f / 5040.0f);
1158 : const __m256 c8 = _mm256_set1_ps(1.0f / 40320.0f);
1159 : const __m256 c9 = _mm256_set1_ps(1.0f / 362880.0f);
1160 : const __m256 c10 = _mm256_set1_ps(1.0f / 3628800.0f);
1161 :
1162 : __m256 y = c10;
1163 : y = _mm256_fmadd_ps(y, r, c9);
1164 : y = _mm256_fmadd_ps(y, r, c8);
1165 : y = _mm256_fmadd_ps(y, r, c7);
1166 : y = _mm256_fmadd_ps(y, r, c6);
1167 : y = _mm256_fmadd_ps(y, r, c5);
1168 : y = _mm256_fmadd_ps(y, r, c4);
1169 : y = _mm256_fmadd_ps(y, r, c3);
1170 : y = _mm256_fmadd_ps(y, r, c2);
1171 : y = _mm256_fmadd_ps(y, r, c1);
1172 : y = _mm256_fmadd_ps(y, r, c0); // final y = (((...r+...)*r+...)*r + 1)
1173 :
1174 : // Reconstruct exp(x) = 2^n * exp(r)
1175 : __m256i emm0 = _mm256_cvtps_epi32(fx);
1176 : emm0 = _mm256_add_epi32(emm0, _mm256_set1_epi32(127));
1177 : emm0 = _mm256_slli_epi32(emm0, 23);
1178 : __m256 pow2n = _mm256_castsi256_ps(emm0);
1179 :
1180 9 : return _mm256_mul_ps(y, pow2n);
1181 : }
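// Scalar mirror (an illustrative sketch, not part of the original source) of
// the range reduction used above: exp(x) = 2^n * exp(r), n = round(x / ln2),
// r = x - n * ln2; exp(r) uses the same degree-10 Taylor polynomial, and 2^n
// is rebuilt by writing (n + 127) into the float exponent field.
#if 0
inline float exp_scalar_ref(float x) {
  x = std::max(-88.3762626647949f, std::min(88.3762626647949f, x));
  const float n = std::floor(x * 1.44269504088896341f + 0.5f);
  const float r = x - n * 0.6931471805599453f;
  float y = 1.0f / 3628800.0f;
  for (float c : {1.0f / 362880.0f, 1.0f / 40320.0f, 1.0f / 5040.0f,
                  1.0f / 720.0f, 1.0f / 120.0f, 1.0f / 24.0f, 1.0f / 6.0f,
                  0.5f, 1.0f, 1.0f})
    y = y * r + c;
  const int32_t bits = (static_cast<int32_t>(n) + 127) << 23;
  float pow2n;
  std::memcpy(&pow2n, &bits, sizeof(pow2n));
  return y * pow2n;
}
#endif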
1182 :
1183 1 : static void softmax_row_inplace(float *qk_out, size_t start_row, size_t end_row,
1184 : size_t num_heads) {
1185 1 : size_t row_range = end_row - start_row;
1186 1 : const size_t full_blocks = (num_heads / 8) * 8;
1187 : // const size_t remainder = num_heads % 8;
1188 :
1189 1 : float *max_vals = new float[num_heads];
1190 1 : float *sum_vals = new float[num_heads];
1191 : // 1. max
1192 11 : for (size_t c = 0; c < num_heads; ++c) {
1193 10 : float max_val = -INFINITY;
1194 40 : for (size_t r = start_row; r < end_row; ++r)
1195 49 : max_val = std::max(max_val, qk_out[r * num_heads + c]);
1196 10 : max_vals[c] = max_val;
1197 : }
1198 :
1199 : // 2. inplace exp + sum
1200 2 : for (size_t c = 0; c < full_blocks; c += 8) {
1201 1 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1202 : __m256 sum = _mm256_setzero_ps();
1203 4 : for (size_t r = 0; r < row_range; ++r) {
1204 3 : float *ptr = &qk_out[(start_row + r) * num_heads + c];
1205 : __m256 val = _mm256_loadu_ps(ptr);
1206 3 : __m256 e = exp256_ps(_mm256_sub_ps(val, maxv));
1207 : _mm256_storeu_ps(ptr, e); // overwrite qk_out
1208 : sum = _mm256_add_ps(sum, e);
1209 : }
1210 1 : _mm256_storeu_ps(&sum_vals[c], sum);
1211 : }
1212 :
1213 3 : for (size_t c = full_blocks; c < num_heads; ++c) {
1214 : float sum = 0.0f;
1215 2 : float maxv = max_vals[c];
1216 8 : for (size_t r = 0; r < row_range; ++r) {
1217 6 : float &a = qk_out[(start_row + r) * num_heads + c];
1218 6 : a = std::exp(a - maxv); // overwrite qk_out
1219 6 : sum += a;
1220 : }
1221 2 : sum_vals[c] = sum;
1222 : }
1223 : // 3. softmax = exp / sum (inplace)
1224 4 : for (size_t r = 0; r < row_range; ++r) {
1225 6 : for (size_t c = 0; c < full_blocks; c += 8) {
1226 3 : float *ptr = &qk_out[(start_row + r) * num_heads + c];
1227 : __m256 val = _mm256_loadu_ps(ptr); // already exp(x - max)
1228 3 : __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
1229 : __m256 soft = _mm256_div_ps(val, sumv);
1230 : _mm256_storeu_ps(ptr, soft);
1231 : }
1232 9 : for (size_t c = full_blocks; c < num_heads; ++c) {
1233 6 : qk_out[(start_row + r) * num_heads + c] /= sum_vals[c];
1234 : }
1235 : }
1236 :
1237 1 : delete[] max_vals;
1238 1 : delete[] sum_vals;
1239 1 : }
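// Scalar reference (an illustrative sketch, not part of the original source):
// qk_out is laid out as [rows x num_heads]; the routine above computes a
// softmax over the rows in [start_row, end_row) independently for each of the
// num_heads columns, overwriting qk_out in place.
#if 0
static void softmax_col_scalar_ref(float *qk_out, size_t start_row,
                                   size_t end_row, size_t num_heads) {
  for (size_t c = 0; c < num_heads; ++c) {
    float max_val = -INFINITY;
    for (size_t r = start_row; r < end_row; ++r)
      max_val = std::max(max_val, qk_out[r * num_heads + c]);
    float sum = 0.0f;
    for (size_t r = start_row; r < end_row; ++r) {
      float &a = qk_out[r * num_heads + c];
      a = std::exp(a - max_val);
      sum += a;
    }
    for (size_t r = start_row; r < end_row; ++r)
      qk_out[r * num_heads + c] /= sum;
  }
}
#endif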
1240 :
1241 0 : static void softmax_row_with_sink_inplace(float *qk_out, size_t start_row,
1242 : size_t end_row, size_t num_heads,
1243 : float *sink) {
1244 0 : size_t row_range = end_row - start_row;
1245 0 : const size_t full_blocks = (num_heads / 8) * 8;
1246 :
1247 0 : float *max_vals = new float[num_heads];
1248 0 : float *sum_vals = new float[num_heads];
1249 : // 1. max
1250 0 : for (size_t c = 0; c < num_heads; ++c) {
1251 0 : float max_val = -INFINITY;
1252 0 : for (size_t r = start_row; r < end_row; ++r)
1253 0 : max_val = std::max(max_val, qk_out[r * num_heads + c]);
1254 0 : max_vals[c] = std::max(sink[c], max_val);
1255 : }
1256 :
1257 : // 2. inplace exp + sum
1258 0 : for (size_t c = 0; c < full_blocks; c += 8) {
1259 0 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1260 0 : __m256 sum = _mm256_loadu_ps(&sink[c]);
1261 0 : sum = exp256_ps(_mm256_sub_ps(sum, maxv));
1262 0 : for (size_t r = 0; r < row_range; ++r) {
1263 0 : float *ptr = &qk_out[(start_row + r) * num_heads + c];
1264 : __m256 val = _mm256_loadu_ps(ptr);
1265 0 : __m256 e = exp256_ps(_mm256_sub_ps(val, maxv));
1266 : _mm256_storeu_ps(ptr, e); // overwrite qk_out
1267 : sum = _mm256_add_ps(sum, e);
1268 : }
1269 0 : _mm256_storeu_ps(&sum_vals[c], sum);
1270 : }
1271 :
1272 0 : for (size_t c = full_blocks; c < num_heads; ++c) {
1273 0 : float maxv = max_vals[c];
1274 0 : float sum = std::exp(sink[c] - maxv);
1275 0 : for (size_t r = 0; r < row_range; ++r) {
1276 0 : float &a = qk_out[(start_row + r) * num_heads + c];
1277 0 : a = std::exp(a - maxv); // overwrite qk_out
1278 0 : sum += a;
1279 : }
1280 0 : sum_vals[c] = sum;
1281 : }
1282 : // 3. softmax = exp / sum (inplace)
1283 0 : for (size_t r = 0; r < row_range; ++r) {
1284 0 : for (size_t c = 0; c < full_blocks; c += 8) {
1285 0 : float *ptr = &qk_out[(start_row + r) * num_heads + c];
1286 : __m256 val = _mm256_loadu_ps(ptr); // already exp(x - max)
1287 0 : __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
1288 : __m256 soft = _mm256_div_ps(val, sumv);
1289 : _mm256_storeu_ps(ptr, soft);
1290 : }
1291 0 : for (size_t c = full_blocks; c < num_heads; ++c) {
1292 0 : qk_out[(start_row + r) * num_heads + c] /= sum_vals[c];
1293 : }
1294 : }
1295 :
1296 0 : delete[] max_vals;
1297 0 : delete[] sum_vals;
1298 0 : }
1299 :
1300 : template <>
1301 1 : void softmax_row_inplace(float *qk_out, size_t start_row, size_t end_row,
1302 : size_t num_heads, float *sink) {
1303 1 : if (sink == nullptr) {
1304 1 : return softmax_row_inplace(qk_out, start_row, end_row, num_heads);
1305 : } else {
1306 0 : return softmax_row_with_sink_inplace(qk_out, start_row, end_row, num_heads,
1307 0 : sink);
1308 : }
1309 : }
1310 :
1311 1 : static void softmax_row(float *qk_out, size_t start_row, size_t end_row,
1312 : size_t num_heads) {
1313 1 : const size_t full_block = (num_heads / 8) * 8;
1314 :
1315 1 : float *max_vals = new float[num_heads];
1316 1 : float *sum_vals = new float[num_heads];
1317 :
1318 : // 1. Find the max along each column
1319 11 : for (size_t c = 0; c < num_heads; ++c) {
1320 10 : float max_val = -INFINITY;
1321 40 : for (size_t r = start_row; r < end_row; ++r) {
1322 49 : max_val = std::max(max_val, qk_out[r * num_heads + c]);
1323 : }
1324 10 : max_vals[c] = max_val;
1325 : }
1326 :
1327 : // 2. Compute the sum along each column (vectorized exp)
1328 2 : for (size_t c = 0; c < full_block; c += 8) {
1329 : __m256 sum = _mm256_setzero_ps();
1330 4 : for (size_t r = start_row; r < end_row; ++r) {
1331 3 : __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
1332 3 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1333 : __m256 sub = _mm256_sub_ps(val, maxv);
1334 3 : __m256 e = exp256_ps(sub);
1335 : sum = _mm256_add_ps(sum, e);
1336 : }
1337 1 : _mm256_storeu_ps(&sum_vals[c], sum);
1338 : }
1339 :
1340 3 : for (size_t c = full_block; c < num_heads; ++c) {
1341 : float sum = 0.0f;
1342 8 : for (size_t r = start_row; r < end_row; ++r) {
1343 6 : sum += std::exp(qk_out[r * num_heads + c] - max_vals[c]);
1344 : }
1345 2 : sum_vals[c] = sum;
1346 : }
1347 :
1348 : // 3. apply softmax
1349 4 : for (size_t r = start_row; r < end_row; ++r) {
1350 6 : for (size_t c = 0; c < full_block; c += 8) {
1351 3 : __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
1352 3 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1353 : __m256 sub = _mm256_sub_ps(val, maxv);
1354 3 : __m256 e = exp256_ps(sub);
1355 3 : __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
1356 : __m256 softmax = _mm256_div_ps(e, sumv);
1357 : _mm256_storeu_ps(&qk_out[r * num_heads + c], softmax);
1358 : }
1359 9 : for (size_t c = full_block; c < num_heads; ++c) {
1360 6 : qk_out[r * num_heads + c] =
1361 6 : std::exp(qk_out[r * num_heads + c] - max_vals[c]) / sum_vals[c];
1362 : }
1363 : }
1364 :
1365 1 : delete[] max_vals;
1366 1 : delete[] sum_vals;
1367 1 : }
1368 :
1369 0 : static void softmax_row_with_sink(float *qk_out, size_t start_row,
1370 : size_t end_row, size_t num_heads,
1371 : float *sink) {
1372 0 : const size_t full_block = (num_heads / 8) * 8;
1373 :
1374 0 : float *max_vals = new float[num_heads];
1375 0 : float *sum_vals = new float[num_heads];
1376 :
1377 : // 1. Find the max along each column
1378 0 : for (size_t c = 0; c < num_heads; ++c) {
1379 0 : float max_val = -INFINITY;
1380 0 : for (size_t r = start_row; r < end_row; ++r) {
1381 0 : max_val = std::max(max_val, qk_out[r * num_heads + c]);
1382 : }
1383 0 : max_vals[c] = std::max(max_val, sink[c]);
1384 : }
1385 :
1386 : // 2. Compute the sum along each column (vectorized exp)
1387 0 : for (size_t c = 0; c < full_block; c += 8) {
1388 0 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1389 0 : __m256 sum = _mm256_loadu_ps(&sink[c]);
1390 : sum = _mm256_sub_ps(sum, maxv);
1391 0 : sum = exp256_ps(sum);
1392 0 : for (size_t r = start_row; r < end_row; ++r) {
1393 0 : __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
1394 : __m256 sub = _mm256_sub_ps(val, maxv);
1395 0 : __m256 e = exp256_ps(sub);
1396 : sum = _mm256_add_ps(sum, e);
1397 : }
1398 0 : _mm256_storeu_ps(&sum_vals[c], sum);
1399 : }
1400 :
1401 0 : for (size_t c = full_block; c < num_heads; ++c) {
1402 0 : float sum = std::exp(sink[c] - max_vals[c]);
1403 0 : for (size_t r = start_row; r < end_row; ++r) {
1404 0 : sum += std::exp(qk_out[r * num_heads + c] - max_vals[c]);
1405 : }
1406 0 : sum_vals[c] = sum;
1407 : }
1408 :
1409 : // 3. apply softmax
1410 0 : for (size_t r = start_row; r < end_row; ++r) {
1411 0 : for (size_t c = 0; c < full_block; c += 8) {
1412 0 : __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
1413 0 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1414 : __m256 sub = _mm256_sub_ps(val, maxv);
1415 0 : __m256 e = exp256_ps(sub);
1416 0 : __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
1417 : __m256 softmax = _mm256_div_ps(e, sumv);
1418 : _mm256_storeu_ps(&qk_out[r * num_heads + c], softmax);
1419 : }
1420 0 : for (size_t c = full_block; c < num_heads; ++c) {
1421 0 : qk_out[r * num_heads + c] =
1422 0 : std::exp(qk_out[r * num_heads + c] - max_vals[c]) / sum_vals[c];
1423 : }
1424 : }
1425 :
1426 0 : delete[] max_vals;
1427 0 : delete[] sum_vals;
1428 0 : }
1429 :
1430 : template <>
1431 1 : void softmax_row(float *qk_out, size_t start_row, size_t end_row,
1432 : size_t num_heads, float *sink) {
1433 1 : if (sink == nullptr) {
1434 1 : return softmax_row(qk_out, start_row, end_row, num_heads);
1435 : } else {
1436 0 : return softmax_row_with_sink(qk_out, start_row, end_row, num_heads, sink);
1437 : }
1438 : }
1439 : #ifdef _WIN32
1440 : #define COMPUTE_FP16_TO_FP32(x) \
1441 : _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
1442 : #define COMPUTE_FP32_TO_FP16(x) \
1443 : _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
1444 : #elif defined(__TIZEN__) && !defined(__F16C__)
1445 : #define COMPUTE_FP16_TO_FP32(x) nntrainer::compute_fp16_to_fp32(x)
1446 : #define COMPUTE_FP32_TO_FP16(x) nntrainer::compute_fp32_to_fp16(x)
1447 : #else
1448 : #define COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
1449 : #define COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
1450 : #endif
1451 :
1452 : static inline __m256 convert_vector_f16_to_f32(__m128i x) {
1453 : #if defined(__TIZEN__) && !defined(__F16C__)
1454 : alignas(32) uint16_t u16_array[8]; // 32-byte aligned storage
1455 : alignas(32) float f32_array[8]; // 32-byte aligned storage
1456 :
1457 : // Safely store __m128i to array (avoids aliasing)
1458 : _mm_storeu_si128(reinterpret_cast<__m128i *>(u16_array), x);
1459 :
1460 : // Convert each FP16 value to FP32
1461 : for (int i = 0; i < 8; i++) {
1462 : f32_array[i] = COMPUTE_FP16_TO_FP32(u16_array[i]);
1463 : }
1464 :
1465 : // Load aligned array into __m256
1466 : return _mm256_load_ps(f32_array);
1467 : #else
1468 : return _mm256_cvtph_ps(x);
1469 : #endif
1470 : }
1471 :
1472 : static inline __m128i convert_vector_f32_to_f16(__m256 x) {
1473 : #if defined(__TIZEN__) && !defined(__F16C__)
1474 : __m128i vec_f16;
1475 : float *f32_ptr = reinterpret_cast<float *>(&x);
1476 : uint16_t *u16_ptr = reinterpret_cast<uint16_t *>(&vec_f16);
1477 : for (int i = 0; i < 8; i++) {
1478 : u16_ptr[i] = COMPUTE_FP32_TO_FP16(f32_ptr[i]);
1479 : }
1480 : return vec_f16;
1481 : #else
1482 : return _mm256_cvtps_ph(x, 0);
1483 : #endif
1484 : }
1485 :
1486 : static inline __m128i convert_vector_f32_to_f16(__m128 x) {
1487 : #if defined(__TIZEN__) && !defined(__F16C__)
1488 : __m128i vec_f16;
1489 : float *f32_ptr = reinterpret_cast<float *>(&x);
1490 : uint16_t *u16_ptr = reinterpret_cast<uint16_t *>(&vec_f16);
1491 :
1492 : for (int i = 0; i < 4; i++) {
1493 : u16_ptr[i] = COMPUTE_FP32_TO_FP16(f32_ptr[i]);
1494 : }
1495 : return vec_f16;
1496 : #else
1497 : return _mm_cvtps_ph(x, 0);
1498 : #endif
1499 : }
1500 :
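 : /**
 :  * @brief Convert chunk_size fp16 values from src into fp32 at dst,
 :  *        8 elements per vector iteration plus a scalar tail.
 :  */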
1501 12 : static inline void load_fp16_8_to_chunk(const uint16_t *src, float *dst,
1502 : int chunk_size) {
1503 : int i = 0;
1504 24 : for (; i + 8 <= chunk_size; i += 8) {
1505 12 : __m128i half = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src + i));
1506 : __m256 f32 = convert_vector_f16_to_f32(half);
1507 12 : _mm256_storeu_ps(&dst[i], f32);
1508 : }
1509 32 : for (; i < chunk_size; ++i) {
1510 20 : dst[i] = nntrainer::compute_fp16_to_fp32(src[i]);
1511 : }
1512 12 : }
1513 :
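 : /**
 :  * @brief Weighted accumulation of fp16 value-cache rows into fp32 output
 :  *        (presumably the attention-weights x V step). For each cache head n
 :  *        and each of its gqa_size query heads h, the cached rows j inside
 :  *        the local window (last local_window_size rows up to row_num) are
 :  *        converted to fp32 and summed as in[j, n, h] * vcache[j, n, :] into
 :  *        output[(n * gqa_size + h) * head_dim + :].
 :  */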
1514 1 : void compute_fp16vcache_fp32_transposed(int row_num, const float *in,
1515 : const uint16_t *vcache, float *output,
1516 : int num_cache_head, int gqa_size,
1517 : int head_dim,
1518 : size_t local_window_size) {
1519 : // cpu_set_t cpu_set;
1520 : // CPU_ZERO(&cpu_set);
1521 : // std::vector<bool> affinity(8, false);
1522 : // affinity[i % affinity.size()] = true;
1523 :
1524 : // for (std::size_t j = 0;
1525 : // j < std::min<std::size_t>(affinity.size(), CPU_SETSIZE); ++j) {
1526 : // if (affinity[j])
1527 : // CPU_SET(j, &cpu_set);
1528 : // }
1529 : // pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpu_set);
1530 :
1531 1 : std::vector<float> tmp_fp32(head_dim);
1532 1 : int num_blocks = head_dim / 8;
1533 2 : __m256 *sumVec = new __m256[std::max(1, num_blocks * gqa_size)];
1534 :
1535 3 : for (int n = 0; n < num_cache_head; ++n) {
1536 2 : int rem = head_dim % 8;
1537 :
1538 : /* Note: declaring std::vector<__m256> sumVec(num_blocks * gqa_size,
1539 : * _mm256_setzero_ps()) triggered "ignoring attributes on template
1540 : * argument '__m256'" [-Wignored-attributes], so a raw __m256 array
1541 : * (allocated above) is zero-initialized here instead.
1542 : */
1543 6 : for (int i = 0; i < num_blocks * gqa_size; i++) {
1544 4 : sumVec[i] = _mm256_setzero_ps();
1545 : }
1546 2 : std::vector<float> sumRem((size_t)gqa_size * rem, 0.0f);
1547 :
1548 6 : for (int j = row_num < local_window_size ? 0
1549 0 : : row_num + 1 - local_window_size;
1550 6 : j <= row_num; ++j) {
1551 4 : const uint16_t *vptr = vcache + (j * num_cache_head + n) * head_dim;
1552 4 : load_fp16_8_to_chunk(vptr, tmp_fp32.data(), head_dim);
1553 :
1554 12 : for (int h = 0; h < gqa_size; ++h) {
1555 8 : float a_val =
1556 : in[(row_num < local_window_size
1557 8 : ? j
1558 0 : : (unsigned long)(j - (row_num + 1 - local_window_size))) *
1559 8 : (unsigned long)(gqa_size * num_cache_head) +
1560 8 : (unsigned long)(n * gqa_size) + h];
1561 :
1562 : __m256 inVec = _mm256_set1_ps(a_val);
1563 :
1564 16 : for (int b = 0; b < num_blocks; ++b) {
1565 8 : __m256 bVec = _mm256_loadu_ps(&tmp_fp32[b * 8]);
1566 8 : sumVec[h * num_blocks + b] =
1567 8 : _mm256_fmadd_ps(inVec, bVec, sumVec[h * num_blocks + b]);
1568 : }
1569 :
1570 8 : float *remPtr = &sumRem.data()[h * rem];
1571 8 : int base = num_blocks * 8;
1572 16 : for (int r = 0; r < rem; ++r) {
1573 8 : remPtr[r] += a_val * tmp_fp32[base + r];
1574 : }
1575 : }
1576 : }
1577 :
1578 6 : for (int h = 0; h < gqa_size; ++h) {
1579 8 : for (int b = 0; b < num_blocks; ++b) {
1580 4 : int out_base = (n * gqa_size + h) * head_dim + b * 8;
1581 4 : _mm256_storeu_ps(&output[out_base], sumVec[h * num_blocks + b]);
1582 : }
1583 :
1584 4 : float *remPtr = &sumRem.data()[h * rem];
1585 : // float *remPtr = &sumRem[h * rem];
1586 4 : int base = num_blocks * 8;
1587 8 : for (int r = 0; r < rem; ++r) {
1588 4 : int out_idx = (n * gqa_size + h) * head_dim + base + r;
1589 4 : output[out_idx] = remPtr[r];
1590 : }
1591 : }
1592 2 : }
1593 1 : delete[] sumVec;
1594 1 : }
1595 :
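 : /**
 :  * @brief Dot products between the per-head query vectors in `in` and fp16
 :  *        key-cache rows within the local window, scaled by 1/sqrt(head_dim)
 :  *        and written to output in (row, cache_head, gqa_head) order.
 :  */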
1596 : template <>
1597 1 : void compute_kcaches(const float *in, const uint16_t *kcache, float *output,
1598 : int num_rows, int num_cache_head, int head_dim,
1599 : int gqa_size, int tile_size, size_t local_window_size) {
1600 1 : std::vector<float> tmp_fp32(head_dim);
1601 :
1602 1 : int start_row =
1603 1 : num_rows < local_window_size ? 0 : num_rows - local_window_size;
1604 1 : int row_cnt = num_rows < local_window_size ? num_rows : local_window_size;
1605 1 : const int tile_count = (row_cnt + tile_size - 1) / tile_size;
1606 :
1607 3 : for (int n = 0; n < num_cache_head; ++n) {
1608 4 : for (int t = 0; t < tile_count; ++t) {
1609 2 : int row_tile_start = t * tile_size;
1610 2 : int tile_rows = std::min(tile_size, row_cnt - row_tile_start);
1611 :
1612 10 : for (int g = 0; g < gqa_size; ++g) {
1613 8 : const float *in_ptr = in + n * gqa_size * head_dim + g * head_dim;
1614 16 : for (int t_row = 0; t_row < tile_rows; ++t_row) {
1615 8 : int row = start_row + row_tile_start + t_row;
1616 8 : if (row + 1 < num_rows) {
1617 0 : const uint16_t *next_kptr =
1618 0 : kcache + ((row + 1) * num_cache_head + n) * head_dim;
1619 : _mm_prefetch(reinterpret_cast<const char *>(next_kptr),
1620 : _MM_HINT_T0);
1621 : }
1622 8 : const uint16_t *kptr = kcache + (row * num_cache_head + n) * head_dim;
1623 8 : load_fp16_8_to_chunk(kptr, tmp_fp32.data(), head_dim);
1624 :
1625 : const float *k_row = tmp_fp32.data();
1626 :
1627 : float sum = 0.0f;
1628 : int i = 0;
1629 : __m256 acc = _mm256_setzero_ps();
1630 16 : for (; i + 8 <= head_dim; i += 8) {
1631 8 : __m256 va = _mm256_loadu_ps(in_ptr + i);
1632 8 : __m256 vb = _mm256_loadu_ps(k_row + i);
1633 : acc = _mm256_fmadd_ps(va, vb, acc);
1634 : }
1635 :
1636 : __m128 low = _mm256_castps256_ps128(acc);
1637 : __m128 high = _mm256_extractf128_ps(acc, 1);
1638 : __m128 sum128 = _mm_add_ps(low, high);
1639 : sum128 = _mm_hadd_ps(sum128, sum128);
1640 : sum128 = _mm_hadd_ps(sum128, sum128);
1641 8 : sum += _mm_cvtss_f32(sum128);
1642 :
1643 24 : for (; i < head_dim; ++i)
1644 16 : sum += in_ptr[i] * k_row[i];
1645 :
1646 8 : output[(row - start_row) * num_cache_head * gqa_size + n * gqa_size +
1647 8 : g] = sum / std::sqrt(static_cast<float>(head_dim));
1648 : }
1649 : }
1650 : }
1651 : }
1652 1 : }
1653 :
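 : /**
 :  * @brief Apply rotary embedding to `inout`: within each dim-wide chunk the
 :  *        pair (w + k, w + k + half_) is rotated by cos_[k] / sin_[k].
 :  *        Results go back into `inout` as fp32, or into `output` as fp16
 :  *        when it is non-null; with only_convert_to_fp16 the values are
 :  *        stored to fp16 `output` unrotated.
 :  */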
1654 2 : void compute_rotary_emb_value(unsigned int width, unsigned int dim,
1655 : unsigned int half_, float *inout, void *output,
1656 : const float *cos_, const float *sin_,
1657 : bool only_convert_to_fp16) {
1658 : enum class OutputType { FP16, FP32 };
1659 :
1660 : OutputType out_type = OutputType::FP32;
1661 2 : if (output != nullptr)
1662 : out_type = OutputType::FP16;
1663 :
1664 6 : for (unsigned int w = 0; w < width; w += dim) {
1665 : unsigned int k = 0;
1666 8 : for (; k + 7 < half_; k += 8) {
1667 4 : unsigned int i0 = w + k;
1668 4 : unsigned int i1 = w + k + half_;
1669 :
1670 4 : __m256 a = _mm256_loadu_ps(&inout[i0]);
1671 4 : __m256 b = _mm256_loadu_ps(&inout[i1]);
1672 :
1673 4 : if (only_convert_to_fp16) {
1674 0 : if (out_type == OutputType::FP16) {
1675 : __m128i a_fp16 = convert_vector_f32_to_f16(a);
1676 : __m128i b_fp16 = convert_vector_f32_to_f16(b);
1677 :
1678 0 : _mm_storeu_si128(
1679 0 : reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i0),
1680 : a_fp16);
1681 0 : _mm_storeu_si128(
1682 0 : reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i1),
1683 : b_fp16);
1684 : }
1685 :
1686 : } else {
1687 4 : __m256 cos_v = _mm256_loadu_ps(&cos_[k]);
1688 4 : __m256 sin_v = _mm256_loadu_ps(&sin_[k]);
1689 :
1690 : __m256 out0 =
1691 : _mm256_sub_ps(_mm256_mul_ps(a, cos_v), _mm256_mul_ps(b, sin_v));
1692 : __m256 out1 =
1693 : _mm256_add_ps(_mm256_mul_ps(a, sin_v), _mm256_mul_ps(b, cos_v));
1694 :
1695 4 : if (out_type == OutputType::FP16) {
1696 : __m128i out0_fp16 = convert_vector_f32_to_f16(out0);
1697 : __m128i out1_fp16 = convert_vector_f32_to_f16(out1);
1698 :
1699 2 : _mm_storeu_si128(
1700 2 : reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i0),
1701 : out0_fp16);
1702 2 : _mm_storeu_si128(
1703 2 : reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i1),
1704 : out1_fp16);
1705 :
1706 : } else if (out_type == OutputType::FP32) {
1707 : _mm256_storeu_ps(&inout[i0], out0);
1708 : _mm256_storeu_ps(&inout[i1], out1);
1709 : }
1710 : }
1711 : }
1712 :
1713 12 : for (; k < half_; ++k) {
1714 8 : unsigned int i0 = w + k;
1715 8 : unsigned int i1 = w + k + half_;
1716 : // assert(i1 < width && "Scalar i1 overflow!");
1717 8 : float a = inout[i0];
1718 8 : float b = inout[i1];
1719 :
1720 8 : if (only_convert_to_fp16) {
1721 0 : static_cast<uint16_t *>(output)[i0] = COMPUTE_FP32_TO_FP16(a);
1722 0 : static_cast<uint16_t *>(output)[i1] = COMPUTE_FP32_TO_FP16(b);
1723 : } else {
1724 8 : float c = cos_[k];
1725 8 : float s = sin_[k];
1726 :
1727 8 : float out0 = a * c - b * s;
1728 8 : float out1 = a * s + b * c;
1729 :
1730 8 : if (out_type == OutputType::FP16) {
1731 4 : static_cast<uint16_t *>(output)[i0] = COMPUTE_FP32_TO_FP16(out0);
1732 8 : static_cast<uint16_t *>(output)[i1] = COMPUTE_FP32_TO_FP16(out1);
1733 : } else if (out_type == OutputType::FP32) {
1734 4 : inout[i0] = out0;
1735 4 : inout[i1] = out1;
1736 : }
1737 : }
1738 : }
1739 : }
1740 2 : }
1741 :
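 : /**
 :  * @brief Horizontal sum of the eight float lanes of a __m256.
 :  */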
1742 : static float hsum_avx(__m256 v) {
1743 : __m128 vlow = _mm256_castps256_ps128(v);
1744 : __m128 vhigh = _mm256_extractf128_ps(v, 1);
1745 : vlow = _mm_add_ps(vlow, vhigh);
1746 : __m128 shuf = _mm_movehdup_ps(vlow);
1747 : __m128 sums = _mm_add_ps(vlow, shuf);
1748 : shuf = _mm_movehl_ps(shuf, sums);
1749 : sums = _mm_add_ss(sums, shuf);
1750 : return _mm_cvtss_f32(sums);
1751 : }
1752 :
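 : /**
 :  * @brief Row-wise RMS normalization along the width W:
 :  *        Y[h][w] = X[h][w] / sqrt(mean_w(X[h][w]^2) + epsilon).
 :  */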
1753 0 : void rms_norm_wrt_width_fp32_intrinsic(const float *__restrict X,
1754 : float *__restrict Y, size_t H, size_t W,
1755 : float epsilon) {
1756 0 : for (std::size_t h = 0; h < H; ++h) {
1757 0 : const float *rowX = X + h * W;
1758 0 : float *rowY = Y + h * W;
1759 :
1760 : std::size_t i = 0;
1761 : __m256 acc0 = _mm256_setzero_ps();
1762 : __m256 acc1 = _mm256_setzero_ps();
1763 : __m256 acc2 = _mm256_setzero_ps();
1764 : __m256 acc3 = _mm256_setzero_ps();
1765 :
1766 0 : for (; i + 32 <= W; i += 32) {
1767 0 : __m256 x0 = _mm256_loadu_ps(rowX + i);
1768 0 : __m256 x1 = _mm256_loadu_ps(rowX + i + 8);
1769 0 : __m256 x2 = _mm256_loadu_ps(rowX + i + 16);
1770 0 : __m256 x3 = _mm256_loadu_ps(rowX + i + 24);
1771 : acc0 = _mm256_fmadd_ps(x0, x0, acc0);
1772 : acc1 = _mm256_fmadd_ps(x1, x1, acc1);
1773 : acc2 = _mm256_fmadd_ps(x2, x2, acc2);
1774 : acc3 = _mm256_fmadd_ps(x3, x3, acc3);
1775 : }
1776 0 : for (; i + 8 <= W; i += 8) {
1777 0 : __m256 x = _mm256_loadu_ps(rowX + i);
1778 : acc0 = _mm256_fmadd_ps(x, x, acc0);
1779 : }
1780 : float sumsq =
1781 0 : hsum_avx(acc0) + hsum_avx(acc1) + hsum_avx(acc2) + hsum_avx(acc3);
1782 0 : for (; i < W; ++i) {
1783 0 : float v = rowX[i];
1784 0 : sumsq += v * v;
1785 : }
1786 :
1787 0 : float mean = sumsq / static_cast<float>(W);
1788 0 : float scale = 1.0f / std::sqrt(mean + epsilon);
1789 : __m256 vscale = _mm256_set1_ps(scale);
1790 :
1791 : i = 0;
1792 0 : for (; i + 32 <= W; i += 32) {
1793 0 : __m256 x0 = _mm256_loadu_ps(rowX + i);
1794 0 : __m256 x1 = _mm256_loadu_ps(rowX + i + 8);
1795 0 : __m256 x2 = _mm256_loadu_ps(rowX + i + 16);
1796 0 : __m256 x3 = _mm256_loadu_ps(rowX + i + 24);
1797 0 : _mm256_storeu_ps(rowY + i, _mm256_mul_ps(x0, vscale));
1798 0 : _mm256_storeu_ps(rowY + i + 8, _mm256_mul_ps(x1, vscale));
1799 0 : _mm256_storeu_ps(rowY + i + 16, _mm256_mul_ps(x2, vscale));
1800 0 : _mm256_storeu_ps(rowY + i + 24, _mm256_mul_ps(x3, vscale));
1801 : }
1802 0 : for (; i + 8 <= W; i += 8) {
1803 0 : __m256 x = _mm256_loadu_ps(rowX + i);
1804 0 : _mm256_storeu_ps(rowY + i, _mm256_mul_ps(x, vscale));
1805 : }
1806 0 : for (; i < W; ++i) {
1807 0 : rowY[i] = rowX[i] * scale;
1808 : }
1809 : }
1810 0 : }
1811 :
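 : /**
 :  * @brief Element-wise clamp of `input` into [lower_bound, upper_bound],
 :  *        8 floats per iteration plus a scalar tail. Example with
 :  *        hypothetical buffers: clamp(src, dst, 1024, 0.0f, 6.0f).
 :  */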
1812 : template <>
1813 21 : void clamp(const float *input, float *output, size_t length, float lower_bound,
1814 : float upper_bound) {
1815 : const size_t step = 8;
1816 : const __m256 vLo = _mm256_set1_ps(lower_bound);
1817 : const __m256 vHi = _mm256_set1_ps(upper_bound);
1818 :
1819 : size_t i = 0;
1820 8085 : for (; i + step <= length; i += step) {
1821 8064 : __m256 v = _mm256_loadu_ps(input + i);
1822 : v = _mm256_max_ps(v, vLo);
1823 : v = _mm256_min_ps(v, vHi);
1824 8064 : _mm256_storeu_ps(output + i, v);
1825 : }
1826 21 : if (i < length) {
1827 0 : for (size_t k = i; k < length; ++k) {
1828 0 : float v = input[k];
1829 : // If v is NaN, both scalar comparisons below yield false, so NaN is kept here,
1830 : // unlike the AVX path above, where maxps/minps map NaN lanes to lower_bound.
1831 0 : output[k] =
1832 0 : (v < lower_bound) ? lower_bound : ((v > upper_bound) ? upper_bound : v);
1833 : }
1834 : }
1835 21 : }
1836 :
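 : /**
 :  * @brief Convert N fp16 values in `input` to fp32 in `output`
 :  *        (16- and 8-wide vector paths plus a scalar tail).
 :  */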
1837 1 : void copy_f16_f32(unsigned int N, const uint16_t *input, float *output) {
1838 : unsigned int idx = 0;
1839 : const uint16_t *data = (const uint16_t *)input;
1840 :
1841 : // 16 half-precision floating point values to single-precision values
1842 6 : for (; N - idx >= 16; idx += 16) {
1843 : const __m256 vec0 =
1844 : convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)data));
1845 : const __m256 vec1 =
1846 : convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)(data + 8)));
1847 5 : data += 16;
1848 :
1849 : _mm256_storeu_ps(output, vec0);
1850 : _mm256_storeu_ps(output + 8, vec1);
1851 5 : output += 16;
1852 : }
1853 : // 8 half-precision floating point values to single-precision values
1854 2 : for (; N - idx >= 8; idx += 8) {
1855 : const __m256 vec =
1856 : convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)data));
1857 1 : data += 8;
1858 :
1859 : _mm256_storeu_ps(output, vec);
1860 1 : output += 8;
1861 : }
1862 : // remaining half-precision floating point values to single-precision values
1863 3 : while (idx < N) {
1864 2 : *output = compute_fp16_to_fp32(*data);
1865 2 : ++output;
1866 2 : ++data;
1867 2 : ++idx;
1868 : }
1869 1 : }
1870 :
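 : /**
 :  * @brief Convert N fp32 values in `input` to fp16 in `output`
 :  *        (16-, 8- and 4-wide vector paths plus a scalar tail).
 :  */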
1871 0 : void copy_f32_f16(unsigned int N, const float *input, uint16_t *output) {
1872 : unsigned int idx = 0;
1873 : uint16_t *out_data = (uint16_t *)output;
1874 :
1875 : // 16 single-precision floating point values to half-precision values
1876 0 : for (; N - idx >= 16; idx += 16) {
1877 : const __m256 vec0 = _mm256_loadu_ps(input);
1878 : const __m256 vec1 = _mm256_loadu_ps(input + 8);
1879 0 : input += 16;
1880 :
1881 : _mm_storeu_si128((__m128i *)out_data, convert_vector_f32_to_f16(vec0));
1882 : _mm_storeu_si128((__m128i *)(out_data + 8),
1883 : convert_vector_f32_to_f16(vec1));
1884 0 : out_data += 16;
1885 : }
1886 : // 8 single-precision floating point values to half-precision values
1887 0 : for (; N - idx >= 8; idx += 8) {
1888 : const __m256 vec = _mm256_loadu_ps(input);
1889 0 : input += 8;
1890 :
1891 : _mm_storeu_si128((__m128i *)out_data, convert_vector_f32_to_f16(vec));
1892 0 : out_data += 8;
1893 : }
1894 : // 4 single-precision floating point values to half-precision values
1895 0 : for (; N - idx >= 4; idx += 4) {
1896 : const __m128 vec = _mm_loadu_ps(input);
1897 0 : input += 4;
1898 :
1899 : _mm_storeu_si64((__m128i *)out_data, convert_vector_f32_to_f16(vec));
1900 0 : out_data += 4;
1901 : }
1902 : // remaining single-precision floating point values to half-precision values
1903 0 : while (idx < N) {
1904 0 : *out_data = compute_fp32_to_fp16(*input);
1905 0 : ++out_data;
1906 0 : ++input;
1907 0 : ++idx;
1908 : }
1909 0 : }
1910 :
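 : /**
 :  * @brief Repack 16 bytes of packed int4 weights into q4_0 nibble order.
 :  *        Assuming the source stores elements 2k / 2k+1 in the low / high
 :  *        nibble of byte k, output byte j then holds element j in its low
 :  *        nibble and element j + 16 in its high nibble.
 :  */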
1911 0 : void create_q4_0_weights(const uint8_t *int4_weight, uint8_t *q4_0_weight) {
1912 : // Load 16 bytes of input data
1913 : __m128i input = _mm_loadu_si128((const __m128i *)int4_weight);
1914 :
1915 : // Create masks for extracting low and high nibbles
1916 : const __m128i low_nibble_mask = _mm_set1_epi8(0x0F);
1917 : const __m128i high_nibble_mask = _mm_set1_epi8(static_cast<char>(0xF0));
1918 :
1919 : // Extract low nibbles from first 8 bytes
1920 : __m128i A = _mm_and_si128(input, low_nibble_mask);
1921 :
1922 : // Extract high nibbles from first 8 bytes and shift right
1923 : __m128i B = _mm_and_si128(input, high_nibble_mask);
1924 : B = _mm_srli_epi16(B, 4);
1925 :
1926 : // Extract low nibbles from second 8 bytes
1927 : __m128i input_shifted = _mm_bsrli_si128(input, 8);
1928 : __m128i C = _mm_and_si128(input_shifted, low_nibble_mask);
1929 :
1930 : // Extract high nibbles from second 8 bytes and shift right
1931 : __m128i D = _mm_and_si128(input_shifted, high_nibble_mask);
1932 : D = _mm_srli_epi16(D, 4);
1933 :
1934 : // Interleave low nibbles: v0 from first8, v2 from second8
1935 : __m128i AC = _mm_or_si128(A, _mm_slli_epi16(C, 4));
1936 :
1937 : // Interleave high nibbles: v1 from first8, v3 from second8
1938 : __m128i BD = _mm_or_si128(B, _mm_slli_epi16(D, 4));
1939 :
1940 : // Pack the results: interleave low and high bytes
1941 : __m128i result = _mm_unpacklo_epi8(AC, BD);
1942 :
1943 : // Store the 16-byte result
1944 : _mm_storeu_si128((__m128i *)q4_0_weight, result);
1945 0 : }
1946 :
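 : /**
 :  * @brief Transpose a 16x16 byte tile (source rows input_stride bytes apart)
 :  *        into `output` (rows output_stride bytes apart) with AVX2
 :  *        unpack/permute steps.
 :  */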
1947 10965600 : static inline void transpose_matrix_16x16(const uint8_t *input,
1948 : int input_stride, uint8_t *output,
1949 : int output_stride) {
1950 : const uint8_t *src = input;
1951 : uint8_t *dst = output;
1952 :
1953 : __m256i rows[8];
1954 98690400 : for (int i = 0; i < 8; ++i) {
1955 87724800 : rows[i] =
1956 87724800 : _mm256_loadu2_m128i((const __m128i *)(src + (8 + i) * input_stride),
1957 87724800 : (const __m128i *)(src + i * input_stride));
1958 : }
1959 :
1960 : // Step 1: Transpose within 2x2 sub-blocks
1961 10965600 : __m256i temp0 = _mm256_unpacklo_epi8(rows[0], rows[1]);
1962 : __m256i temp1 = _mm256_unpackhi_epi8(rows[0], rows[1]);
1963 10965600 : __m256i temp2 = _mm256_unpacklo_epi8(rows[2], rows[3]);
1964 : __m256i temp3 = _mm256_unpackhi_epi8(rows[2], rows[3]);
1965 10965600 : __m256i temp4 = _mm256_unpacklo_epi8(rows[4], rows[5]);
1966 : __m256i temp5 = _mm256_unpackhi_epi8(rows[4], rows[5]);
1967 10965600 : __m256i temp6 = _mm256_unpacklo_epi8(rows[6], rows[7]);
1968 : __m256i temp7 = _mm256_unpackhi_epi8(rows[6], rows[7]);
1969 :
1970 : // Step 2: Transpose within 4x4 sub-blocks
1971 : __m256i interleave0 = _mm256_unpacklo_epi16(temp0, temp2);
1972 : __m256i interleave1 = _mm256_unpackhi_epi16(temp0, temp2);
1973 : __m256i interleave2 = _mm256_unpacklo_epi16(temp1, temp3);
1974 : __m256i interleave3 = _mm256_unpackhi_epi16(temp1, temp3);
1975 : __m256i interleave4 = _mm256_unpacklo_epi16(temp4, temp6);
1976 : __m256i interleave5 = _mm256_unpackhi_epi16(temp4, temp6);
1977 : __m256i interleave6 = _mm256_unpacklo_epi16(temp5, temp7);
1978 : __m256i interleave7 = _mm256_unpackhi_epi16(temp5, temp7);
1979 :
1980 : // Step 3: Transpose within 8x8 block
1981 : __m256i final0 = _mm256_unpacklo_epi32(interleave0, interleave4);
1982 : __m256i final1 = _mm256_unpackhi_epi32(interleave0, interleave4);
1983 : __m256i final2 = _mm256_unpacklo_epi32(interleave1, interleave5);
1984 : __m256i final3 = _mm256_unpackhi_epi32(interleave1, interleave5);
1985 : __m256i final4 = _mm256_unpacklo_epi32(interleave2, interleave6);
1986 : __m256i final5 = _mm256_unpackhi_epi32(interleave2, interleave6);
1987 : __m256i final6 = _mm256_unpacklo_epi32(interleave3, interleave7);
1988 : __m256i final7 = _mm256_unpackhi_epi32(interleave3, interleave7);
1989 :
1990 : // Step 4: Transpose within 16x16 block
1991 : __m256i res[8];
1992 10965600 : res[0] = _mm256_unpacklo_epi64(final0, final4);
1993 10965600 : res[1] = _mm256_unpackhi_epi64(final0, final4);
1994 10965600 : res[2] = _mm256_unpacklo_epi64(final1, final5);
1995 10965600 : res[3] = _mm256_unpackhi_epi64(final1, final5);
1996 10965600 : res[4] = _mm256_unpacklo_epi64(final2, final6);
1997 10965600 : res[5] = _mm256_unpackhi_epi64(final2, final6);
1998 10965600 : res[6] = _mm256_unpacklo_epi64(final3, final7);
1999 10965600 : res[7] = _mm256_unpackhi_epi64(final3, final7);
2000 :
2001 : const int perm_0213 = 0xd8; // 0, 2, 1, 3
2002 : const int perm_02 = 0x20; // 0, 2
2003 : const int perm_13 = 0x31; // 1, 3
2004 54828000 : for (int i = 0; i < 4; i++) {
2005 43862400 : __m256i a128x2 = _mm256_permute4x64_epi64(res[2 * i], perm_0213);
2006 43862400 : __m256i b128x2 = _mm256_permute4x64_epi64(res[2 * i + 1], perm_0213);
2007 43862400 : _mm256_storeu_si256((__m256i *)&dst[2 * i * output_stride],
2008 : _mm256_permute2x128_si256(a128x2, b128x2, perm_02));
2009 43862400 : _mm256_storeu_si256((__m256i *)&dst[(8 + 2 * i) * output_stride],
2010 : _mm256_permute2x128_si256(a128x2, b128x2, perm_13));
2011 : }
2012 10965600 : }
2013 :
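 : /**
 :  * @brief Nibble repacking as in create_q4_0_weights, applied to eight
 :  *        16-byte row blocks (128 bytes) at once; each iteration handles two
 :  *        row blocks in one __m256i and stores the result into q4_blocks.
 :  */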
2014 21908200 : static inline void create_q4_0_weights_x8(const uint8_t *int4_weight,
2015 : __m256i *q4_blocks) {
2016 : constexpr const size_t ROW_BLOCK_BYTE_SIZE = 16;
2017 :
2018 : // Create masks for extracting low and high nibbles
2019 : const __m256i low_nibble_mask = _mm256_set1_epi8(0x0F);
2020 : const __m256i high_nibble_mask = _mm256_set1_epi8(static_cast<char>(0xF0));
2021 :
2022 : // Create two blocks in one iteration
2023 109541000 : for (int i = 0; i < 4; ++i) {
2024 : // Load 32 bytes of input data (two 16-byte row blocks)
2025 : __m256i input = _mm256_loadu_si256(
2026 87632800 : (const __m256i *)(int4_weight + 2 * ROW_BLOCK_BYTE_SIZE * i));
2027 :
2028 : // A = input & low_nibble_mask
2029 : __m256i A = _mm256_and_si256(input, low_nibble_mask);
2030 :
2031 : // B = (input & high_nibble_mask) >> 4
2032 : __m256i B = _mm256_srli_epi16(_mm256_and_si256(input, high_nibble_mask), 4);
2033 :
2034 : // input_shifted = each 128-bit lane of input shifted right by 8 bytes
2035 : __m256i input_shifted = _mm256_bsrli_epi128(input, 8);
2036 : // C = input_shifted & low_nibble_mask
2037 : __m256i C = _mm256_and_si256(input_shifted, low_nibble_mask);
2038 :
2039 : // D = (input_shifted & high_nibble_mask) >> 4
2040 : __m256i D =
2041 : _mm256_srli_epi16(_mm256_and_si256(input_shifted, high_nibble_mask), 4);
2042 :
2043 : // AC = A | (C << 4)
2044 : __m256i AC = _mm256_or_si256(A, _mm256_slli_epi16(C, 4));
2045 :
2046 : // BD = B | (D << 4)
2047 : __m256i BD = _mm256_or_si256(B, _mm256_slli_epi16(D, 4));
2048 :
2049 : // Interleave AC and BD
2050 : __m256i result = _mm256_unpacklo_epi8(AC, BD);
2051 :
2052 87632800 : _mm256_store_si256(&q4_blocks[i], result);
2053 : }
2054 21908200 : }
2055 :
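 : /**
 :  * @brief Assemble a block_q4_0x8 from four __m256i of repacked nibbles and
 :  *        eight fp16 scales: the scales are copied into out->d and the
 :  *        128-bit lanes of the four inputs are interleaved into out->qs.
 :  */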
2056 21908200 : inline static void nntr_make_block_q4_0x8(const __m256i *in, block_q4_0x8 *out,
2057 : const uint16_t *scales) {
2058 : constexpr size_t IN_CNT = 8;
2059 21908200 : memcpy(out->d, scales, IN_CNT * sizeof(uint16_t));
2060 :
2061 : const int perm_0213 = 0xd8; // 0, 2, 1, 3
2062 : const int perm_02 = 0x20; // 0, 2
2063 : const int perm_13 = 0x31; // 1, 3
2064 21908200 : __m256i a128x2 = _mm256_permute4x64_epi64(*(__m256i *)&in[0], perm_0213);
2065 21908200 : __m256i b128x2 = _mm256_permute4x64_epi64(*(__m256i *)&in[1], perm_0213);
2066 21908200 : __m256i c128x2 = _mm256_permute4x64_epi64(*(__m256i *)&in[2], perm_0213);
2067 21908200 : __m256i d128x2 = _mm256_permute4x64_epi64(*(__m256i *)&in[3], perm_0213);
2068 : _mm256_storeu_si256((__m256i *)&out->qs[0],
2069 : _mm256_permute2x128_si256(a128x2, b128x2, perm_02));
2070 : _mm256_storeu_si256((__m256i *)&out->qs[32],
2071 : _mm256_permute2x128_si256(c128x2, d128x2, perm_02));
2072 : _mm256_storeu_si256((__m256i *)&out->qs[64],
2073 : _mm256_permute2x128_si256(a128x2, b128x2, perm_13));
2074 : _mm256_storeu_si256((__m256i *)&out->qs[96],
2075 : _mm256_permute2x128_si256(c128x2, d128x2, perm_13));
2076 21908200 : }
2077 :
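 : /**
 :  * @brief Convert int4 weights in the blocked osv32_isv2 layout, with fp16
 :  *        scales grouped by scale_group_size (32/64/128), into q4_0x8 blocks
 :  *        at dst_q4_0x. Processes 16 output rows at a time: a 16x16 byte tile
 :  *        is transposed, repacked via create_q4_0_weights_x8, and emitted with
 :  *        its scales through nntr_make_block_q4_0x8; rows are parallelized
 :  *        with OpenMP. A hypothetical call: transform_int4_osv32_isv2_to_q4_0x8(
 :  *        N, K, weights_osv32, scales_osv32, 64, dst_buffer).
 :  */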
2078 2400 : void transform_int4_osv32_isv2_to_q4_0x8(size_t N, size_t K,
2079 : const uint8_t *osv32_weights,
2080 : const uint16_t *osv32_scales,
2081 : size_t scale_group_size,
2082 : void *dst_q4_0x) {
2083 :
2084 2400 : NNTR_THROW_IF((!(scale_group_size == 32 || scale_group_size == 64 ||
2085 : scale_group_size == 128)),
2086 : std::invalid_argument)
2087 : << "Scale group size must be 32/64/128";
2088 2400 : NNTR_THROW_IF(K % QK4_0 != 0, std::invalid_argument)
2089 : << "K size must be divisable by QK4_0 (32)";
2090 2400 : NNTR_THROW_IF(N % 8 != 0, std::invalid_argument)
2091 : << "N size must be divisable by 8";
2092 :
2093 : static constexpr const size_t NUM_Q4_0_BLOCKS = 8;
2094 : static constexpr const size_t ROW_BLOCK_SIZE = 32;
2095 : static constexpr const size_t COLUMN_BLOCK_SIZE = 2;
2096 : static constexpr const size_t ROW_BLOCK_BYTE_SIZE = 16;
2097 :
2098 : static constexpr const size_t dst_tmp_size =
2099 : (8 * ROW_BLOCK_BYTE_SIZE) / sizeof(__m256i);
2100 : uint8_t *dst_ = reinterpret_cast<uint8_t *>(dst_q4_0x);
2101 :
2102 : // --- Layout ---
2103 2400 : const size_t rows_count_pad = align(N, ROW_BLOCK_SIZE);
2104 2400 : const size_t columns_count_pad = align(K, ROW_BLOCK_SIZE);
2105 2400 : const size_t column_blocks_count =
2106 : columns_count_pad / COLUMN_BLOCK_SIZE; // COLUMN_BLOCK_SIZE == 2
2107 2400 : const size_t bytes_per_row_block_span = column_blocks_count * ROW_BLOCK_SIZE;
2108 2400 : const int column_blocks_cnt = K / QK4_0;
2109 :
2110 : alignas(32) static thread_local __m256i dst_tmp[dst_tmp_size];
2111 : alignas(32) static thread_local uint8_t mx16x16[16 * 16];
2112 :
2113 2400 : #pragma omp parallel for schedule(guided)
2114 : for (int row_id = 0; row_id < (int)N; row_id += 16) {
2115 : const size_t row_in_block_id = row_id / ROW_BLOCK_SIZE;
2116 : size_t i_in_block = row_id % ROW_BLOCK_SIZE;
2117 : for (int column_out_block_id = 0; column_out_block_id < column_blocks_cnt;
2118 : column_out_block_id++) {
2119 : int column_idx = column_out_block_id * QK4_0;
2120 : int scale_offset = (column_idx / scale_group_size) * rows_count_pad;
2121 : const size_t row_block_base =
2122 : row_in_block_id * bytes_per_row_block_span + i_in_block;
2123 : int src_offset =
2124 : row_block_base + column_out_block_id * 16 * ROW_BLOCK_SIZE;
2125 : transpose_matrix_16x16(&osv32_weights[src_offset], ROW_BLOCK_SIZE,
2126 : mx16x16, 16);
2127 : int max_r = std::min((size_t)16, N - row_id);
2128 : size_t row_out_block_id = row_id / NUM_Q4_0_BLOCKS;
2129 : int dst_offset =
2130 : (NUM_Q4_0_BLOCKS * sizeof(block_q4_0)) *
2131 : (column_out_block_id + row_out_block_id * column_blocks_cnt);
2132 : for (int r = 0; r < max_r; r += NUM_Q4_0_BLOCKS) {
2133 : create_q4_0_weights_x8(&mx16x16[16 * r], dst_tmp);
2134 :
2135 : nntr_make_block_q4_0x8(dst_tmp, (block_q4_0x8 *)(dst_ + dst_offset),
2136 : &osv32_scales[scale_offset + row_id + r]);
2137 : row_out_block_id++;
2138 : dst_offset +=
2139 : (NUM_Q4_0_BLOCKS * sizeof(block_q4_0)) * column_blocks_cnt;
2140 : }
2141 : }
2142 : }
2143 2400 : }
2144 :
2145 : } // namespace nntrainer::avx2