Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2023 Donghyeon Jeong <dhyeon.jeong@samsung.com>
4 : *
5 : * @file avx2_impl.cpp
6 : * @date 20 Feb 2024
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Donghyeon Jeong <dhyeon.jeong@samsung.com>
9 : * @author Sungsik Kong <ss.kong@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief This is a source for AVX implementation
12 : *
13 : */
14 :
15 : #include "avx2_impl.h"
#include <algorithm> // std::min / std::max are used below
16 : #include <array>
17 : #if __has_include(<bit>)
18 : #include <bit>
19 : #endif
20 : #include <cassert>
21 : #include <cmath>
22 : #include <cstdint>
23 : #include <cstring>
24 : #include <fp16.h>
25 : #include <immintrin.h>
26 : #include <limits>
27 : #if __has_include(<numbers>)
28 : #include <numbers>
29 : #endif
30 : #include <type_traits>
31 : #if __has_include(<version>)
32 : #include <version>
33 : #endif
34 : #include <fallback_internal.h>
35 : #include <util_func.h>
36 : #include <vector>
37 :
38 : #if !defined(__has_constexpr_builtin)
39 : #define __has_constexpr_builtin(x) (0)
40 : #endif
41 :
42 : #if !defined(__has_cpp_attribute)
43 : #define __has_cpp_attribute(x) (0)
44 : #endif
45 :
46 : // VECTORCALL calling convention (MSVC only; the SysV x86_64 ABI already passes vector arguments in registers)
47 : #if _MSC_VER >= 1700
48 : #define _nnt_CC_VECTORCALL __vectorcall
49 : #else
50 : #define _nnt_CC_VECTORCALL
51 : #endif
52 :
53 : // Flatten attribute
54 : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::flatten)
55 : #define _nnt_ATTR_FLATTEN [[msvc::flatten]]
56 : #elif __has_cpp_attribute(gnu::flatten)
57 : // clang, g++
58 : #define _nnt_ATTR_FLATTEN [[gnu::flatten]]
59 : #else
60 : #define _nnt_ATTR_FLATTEN
61 : #endif
62 :
63 : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::noinline)
64 : #define _nnt_ATTR_NOINLINE [[msvc::noinline]]
65 : #elif __has_cpp_attribute(gnu::noinline)
66 : // clang, g++
67 : #define _nnt_ATTR_NOINLINE [[gnu::noinline]]
68 : #else
69 : #define _nnt_ATTR_NOINLINE
70 : #endif
71 :
72 : #if _MSC_VER >= 1700 || __has_cpp_attribute(msvc::forceinline)
73 : #define _nnt_ATTR_ALWAYS_INLINE [[msvc::forceinline]]
74 : #elif __has_cpp_attribute(gnu::always_inline)
75 : #define _nnt_ATTR_ALWAYS_INLINE [[gnu::always_inline]]
#else
#define _nnt_ATTR_ALWAYS_INLINE
76 : #endif
77 :
78 : #if __has_cpp_attribute(unlikely)
79 : #define UNLIKELY [[unlikely]]
80 : #else
81 : #define UNLIKELY
82 : #endif
83 :
84 : #if !defined(_MSC_VER) && !defined(__clang__)
85 : #pragma GCC diagnostic ignored "-Wattributes"
86 : #endif
87 :
88 : #if defined(__clang__) || defined(__GNUC__)
89 : #define RESTRICT __restrict__
90 : #else
91 : #define RESTRICT
92 : #endif
93 :
94 : namespace {
95 :
96 : template <typename To_, typename From_>
97 : constexpr inline bool concept17_BinaryCastable =
98 : sizeof(To_) == sizeof(From_) &&
99 : std::is_trivially_copyable_v<From_> && std::is_trivially_copyable_v<To_>;
100 :
101 : template <class To_, class From_>
102 : auto compat_bit_cast(const From_ &src) noexcept
103 : -> std::enable_if_t<concept17_BinaryCastable<To_, From_>, To_> {
104 : #if __cpp_lib_bit_cast >= 201806L
105 : return std::bit_cast<To_>(src);
106 : #else
107 : To_ dst;
108 : std::memcpy(&dst, &src, sizeof(To_));
109 : return dst;
110 : #endif
111 : }
112 :
113 : [[nodiscard]] constexpr inline unsigned
114 : constexpr_popcount(uint32_t v) noexcept {
115 : #if __cpp_lib_bitops >= 201907L
116 : return std::popcount(v);
117 : #else
118 : // Popcount via bit-hack
119 : v = v - ((v >> 1) & 0x55555555); // reuse input as temporary
120 : v = (v & 0x33333333) + ((v >> 2) & 0x33333333); // temp
121 : auto c = (((v + (v >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; // count
122 : return c;
123 : #endif
124 : }
125 :
126 : template <unsigned I_>
127 : constexpr inline bool concept17_PowerOfTwo = (constexpr_popcount(I_) == 1);
128 :
129 : namespace numbers {
130 :
131 : #if __has_include(<numbers>) && __cpp_lib_math_constants >= 201907L
132 : using std::numbers::ln2_v;
133 : using std::numbers::log2e_v;
134 : #else
135 : template <typename Float_> constexpr inline auto ln2_v = Float_{M_LN2};
136 :
137 : template <typename Float_> constexpr inline auto log2e_v = Float_{M_LOG2E};
138 : #endif
139 : } // namespace numbers
140 :
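// Clamp bounds for the exp() approximations below: EXP_ARG_MAX sits just below
// ln(FLT_MAX), beyond which expf overflows to +inf, and below roughly -87 the
// result drops under FLT_MIN, so inputs are clamped to this interval first.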
141 : constexpr inline float EXP_ARG_MIN = -87.0;
142 : constexpr inline float EXP_ARG_MAX = +88.3762626647949f;
143 :
144 : // @brief Precalculated lookup table for 2^x calculation
145 : template <unsigned N_, typename Ty_ = uint32_t, typename Float_ = float,
146 : typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
147 : struct Exp2Table {
148 :
149 : constexpr static inline auto MANTISSA_BITS =
150 : std::numeric_limits<Float_>::digits - 1;
151 :
152 : #if __cpp_consteval >= 201811L && __has_constexpr_builtin(__builtin_exp2)
153 : [[nodiscard]] static consteval auto calculate() noexcept {
154 : std::array<Ty_, N_> t;
155 : for (unsigned i = 0; i < N_; ++i)
156 : t[i] = std::bit_cast<Ty_>(std::exp2(Float_{1.0} * i / N_)) -
157 : ((i << MANTISSA_BITS) / N_);
158 : return t;
159 : }
160 : #endif
161 : };
162 :
163 : #if !__has_constexpr_builtin(__builtin_exp2) || !(__cpp_consteval >= 201811L)
164 :
165 : // @brief Precalculated lookup table for 2^x calculation when we don't have
166 : // constexpr math functions
167 : template <> struct Exp2Table<8, uint32_t, float> {
168 : [[nodiscard]] static constexpr auto calculate() noexcept {
169 : std::array<uint32_t, 8> t = {0x3f800000U, 0x3f7b95c2U, 0x3f7837f0U,
170 : 0x3f75fed7U, 0x3f7504f3U, 0x3f75672aU,
171 : 0x3f7744fdU, 0x3f7ac0c7U};
172 : return t;
173 : }
174 : };
175 :
176 : #endif
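// Note: the eight hard-coded words above are the IEEE-754 bit patterns of
// 2^(i/8) for i = 0..7 with (i << 23) / 8 already subtracted, i.e. the same
// values the consteval calculate() above would produce.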
177 :
178 : template <unsigned N_> // requires PowerOfTwo<N_>
179 : alignas(__m256) inline constexpr auto exp2_table_v = Exp2Table<N_>::calculate();
180 :
181 : // Scalar version of the expf() approximation using a 3rd-degree polynomial of
182 : // the fractional part.
183 : //
184 : // The error relative to std::expf is less than 5e-6.
185 : // Valid in the range [-87, +88.37); +INF, NaN, etc. are not handled.
186 : // Inputs are clamped to this valid range.
187 : //
188 : // The strategy is to approximate exp(x) as 2^K * 2^F (K integer, F fractional).
189 : template <unsigned N_, typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
190 : [[nodiscard]] _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN inline float
191 : approx_exp_exp2_lookup(float x) noexcept {
192 : constexpr static unsigned N_MASK = uint32_t(N_ - 1U);
193 : constexpr static unsigned FLT_MANTISSA_BITS =
194 : std::numeric_limits<float>::digits - 1U;
195 :
196 : x = std::max(x, EXP_ARG_MIN);
197 : x *= float(0x1.0p0 / numbers::ln2_v<double> * N_); // scale by N*log2(e), matching the AVX2 version below
198 : x = std::min(x, float(EXP_ARG_MAX / numbers::ln2_v<double> * N_));
199 :
200 : // Round to nearest and split off the integer part (cf. std::modf)
201 : // NB: this trick does not round ties to even.
202 : auto x_int = x + 0x1.8p23f;
203 : auto x_uint = compat_bit_cast<uint32_t>(x_int);
204 : x_int -= 0x1.8p23f;
205 : auto x_frac = x - x_int;
206 :
207 : auto s_int = exp2_table_v<N_>[x_uint & N_MASK];
208 : auto x_uint_shifted = x_uint
209 : << (FLT_MANTISSA_BITS - constexpr_popcount(N_MASK));
210 : auto s_int_2 = s_int + x_uint_shifted;
211 : auto s = compat_bit_cast<float>(s_int_2);
212 :
213 : // Polynomial of the form C0*x^3 + C1*x^2 + C2*x + 1.0
214 : static constexpr float poly_4[] = {0x1.c6af84b912394p-5f / N_ / N_ / N_,
215 : 0x1.ebfce50fac4f3p-3f / N_ / N_,
216 : 0x1.62e42ff0c52d6p-1f / N_};
217 :
218 : auto q0 = std::fma(x_frac, poly_4[0], poly_4[1]);
219 : auto x_frac_pow2 = x_frac * x_frac;
220 : auto q2 = x_frac * poly_4[2]; // not adding +1.0
221 :
222 : x = std::fma(q0, x_frac_pow2, q2);
223 :
224 : return std::fma(x, s, s); // NB: (handles (x+1) by addition of s)
225 : }
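// For reference (scalar sketch): with t = x * N * log2(e) split into an integer
// part i and fraction f, e^x = 2^(t/N) = 2^(i/N) * 2^(f/N); 2^((i mod N)/N)
// comes from the lookup table, the remaining integer exponent is added via the
// bit shift into the float exponent field, and 2^(f/N) - 1 is approximated by
// the cubic polynomial above.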
226 :
227 : // Vectorized version of above
228 : template <unsigned N_, typename = std::enable_if_t<concept17_PowerOfTwo<N_>>>
229 : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN inline auto _nnt_CC_VECTORCALL
230 : avx2_approx_exp_e2lookup(__m256 xs) noexcept {
231 :
232 : constexpr static uint32_t N_MASK = uint32_t(N_ - 1U);
233 : alignas(64) constexpr static auto EXP2_TBL = exp2_table_v<N_>;
234 : constexpr static unsigned MANTISSA_BITS =
235 : std::numeric_limits<float>::digits - 1;
236 :
237 : // Ensure arg in range [exp_arg_min, exp_arg_max]
238 : xs = _mm256_max_ps(xs, _mm256_set1_ps(EXP_ARG_MIN));
239 : // Would clamp to EXP_ARG_MAX but we move it after multiplication for IPC:
240 : // xs = _mm256_min_ps(xs, _mm256_set1_ps(EXP_ARG_MAX));
241 :
242 : xs =
243 : _mm256_mul_ps(xs, _mm256_set1_ps(float(1.0 / numbers::ln2_v<double> * N_)));
244 : // Clamp EXP_ARG_MAX after multiply
245 : xs = _mm256_min_ps(
246 : xs, _mm256_set1_ps(float(EXP_ARG_MAX / numbers::ln2_v<double> * N_)));
247 :
248 : // Mostly equivalent to the following, except that ties are not rounded to even:
249 : //   auto xs_int = _mm256_round_ps(xs, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
250 : //   auto xs_int_as_u32 = _mm256_cvtps_epi32(xs_int);
251 : auto xs_int = _mm256_add_ps(xs, _mm256_set1_ps(0x1.8p23f));
252 : auto xs_int_as_u32 = _mm256_castps_si256(xs_int);
253 : xs_int = _mm256_sub_ps(xs_int, _mm256_set1_ps(0x1.8p23f));
254 :
255 : // Calculate fractional part
256 : auto xs_frac = _mm256_sub_ps(xs, xs_int);
257 : // Indices for lookup (modulo N_)
258 : auto exp2_idxs = _mm256_and_si256(xs_int_as_u32, _mm256_set1_epi32(N_MASK));
259 :
260 : __m256i s_ints;
261 :
262 : // Lookup e^xs_int s factor
263 : if constexpr (N_ == 8) {
264 : // Lookup by vector permute
265 : auto tbl = _mm256_load_si256((__m256i *)EXP2_TBL.data());
266 : s_ints = _mm256_permutevar8x32_epi32(tbl, exp2_idxs);
267 : } else {
268 : // Fallback lookup when the table does not fit a single vector permute
269 : s_ints = _mm256_i32gather_epi32(reinterpret_cast<const int *>(EXP2_TBL.data()), exp2_idxs, 4); // scale = 4 bytes per 32-bit element
270 : }
271 :
272 : auto xs_uint_shifted = _mm256_slli_epi32(
273 : xs_int_as_u32, MANTISSA_BITS - constexpr_popcount(N_MASK));
274 : auto s_ints_2 = _mm256_add_epi32(s_ints, xs_uint_shifted);
275 : auto s_floats = _mm256_castsi256_ps(s_ints_2);
276 :
277 : static constexpr float poly_d4[] = {0x1.c6af84b912394p-5f / N_ / N_ / N_,
278 : 0x1.ebfce50fac4f3p-3f / N_ / N_,
279 : 0x1.62e42ff0c52d6p-1f / N_};
280 :
281 : const auto C0 = _mm256_set1_ps(poly_d4[0]);
282 : const auto C1 = _mm256_set1_ps(poly_d4[1]);
283 : const auto C2 = _mm256_set1_ps(poly_d4[2]);
284 :
285 : auto qs0 = _mm256_fmadd_ps(xs_frac, C0, C1);
286 : auto xs_frac_pow2 = _mm256_mul_ps(xs_frac, xs_frac);
287 : auto qs2 = _mm256_mul_ps(xs_frac, C2);
288 :
289 : xs = _mm256_fmadd_ps(qs0, xs_frac_pow2, qs2);
290 :
291 : return _mm256_fmadd_ps(xs, s_floats, s_floats);
292 : }
293 :
294 : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
295 : avx2_negate_ps(__m256 x) noexcept -> __m256 {
296 : constexpr auto SIGN_SHIFT = sizeof(float) * 8 - 1;
297 : const auto UNDEF = _mm256_undefined_si256();
298 : const auto sign_bit =
299 : _mm256_slli_epi32(_mm256_cmpeq_epi16(UNDEF, UNDEF), SIGN_SHIFT);
300 : auto flt_sign_bit = _mm256_castsi256_ps(sign_bit);
301 : auto neg_x = _mm256_xor_ps(x, flt_sign_bit);
302 : return neg_x;
303 : }
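// Note: cmpeq of a register with itself yields all-ones, so shifting left by 31
// produces 0x80000000 in every lane; the XOR above then just flips the sign bit
// without loading a constant from memory.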
304 :
305 : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
306 : avx2_approx_swiglu(__m256 x, __m256 s) noexcept -> __m256 {
307 : auto neg_x = avx2_negate_ps(x);
308 : auto inv_sigmoid =
309 : _mm256_add_ps(avx2_approx_exp_e2lookup<8>(neg_x), _mm256_set1_ps(1.0f));
310 : auto swiglu_nonscaled = _mm256_div_ps(x, inv_sigmoid);
311 : return _mm256_mul_ps(swiglu_nonscaled, s);
312 : }
313 :
314 : _nnt_ATTR_ALWAYS_INLINE _nnt_ATTR_FLATTEN auto _nnt_CC_VECTORCALL
315 : avx2_approx_swiglu_alpha(__m256 x, __m256 s, __m256 alpha) noexcept -> __m256 {
316 : auto alpha_x = _mm256_mul_ps(alpha, x);
317 : auto neg_alpha_x = avx2_negate_ps(alpha_x);
318 : auto inv_sigmoid = _mm256_add_ps(avx2_approx_exp_e2lookup<8>(neg_alpha_x),
319 : _mm256_set1_ps(1.0f));
320 : auto swiglu_nonscaled = _mm256_div_ps(x, inv_sigmoid);
321 : return _mm256_mul_ps(swiglu_nonscaled, s);
322 : }
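// In scalar terms, the two helpers above compute (sketch):
//   avx2_approx_swiglu(x, s)              = s * x * sigmoid(x)        = s * x / (1 + e^-x)
//   avx2_approx_swiglu_alpha(x, s, alpha) = s * x * sigmoid(alpha*x)  = s * x / (1 + e^-(alpha*x))
// using the AVX2 exp approximation above for the sigmoid denominator.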
323 : } // namespace
324 :
325 : namespace nntrainer::avx2 {
326 :
327 : /**
328 : * @brief struct of q4_0x8 block
329 : */
330 : struct block_q4_0x8 {
331 : uint16_t d[8]; // 16B
332 : uint8_t qs[128]; // 16 x u64
333 : };
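// Layout note: each block_q4_0x8 interleaves eight q4_0 sub-blocks; d[] holds
// one FP16 scale per sub-block and qs[] holds their 8 x 32 quantized values
// packed as 4-bit nibbles (two per byte), 128 bytes in total.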
334 :
335 : #define USE_NONTEMPORAL_STORES 1
336 :
337 : static inline void store256_u16(void *dst, __m256i v) {
338 : #if defined(USE_NONTEMPORAL_STORES)
339 : // use NT only if 32B-aligned; otherwise fall back (correctness first)
340 : if (((uintptr_t)dst & 31u) == 0) {
341 : _mm256_stream_si256((__m256i *)dst, v);
342 : return;
343 : }
344 : #endif
345 : _mm256_storeu_si256((__m256i *)dst, v);
346 : }
347 :
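// Unpacks q4_0x8 blocks into two transposed 16-bit planes: the per-subblock
// scales land in dT (K/32 rows of N u16 scales) and the re-packed 4-bit quants
// land in qsT (8*(K/32) rows of N u16 words). Two 8-row groups (16 rows) are
// handled per pass via an 8x8 u16 transpose, with an 8-row tail path below.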
348 0 : void unpack_q4_0x8_transpose16(const void *src, unsigned short *__restrict dT,
349 : unsigned short *__restrict qsT, int N, int K,
350 : int CT) // column tile (in units of 32-cols)
351 : {
352 0 : assert((K % 256) == 0);
353 0 : assert((N % 8) == 0);
354 :
355 : const auto *__restrict x = static_cast<const block_q4_0x8 *>(src);
356 :
357 0 : const int groups_N8 = N / 8; // number of 8-row groups
358 0 : const int cols_scales = K / 32; // K subblocks
359 :
360 : // AVX2 constants
361 0 : const __m128i v88 = _mm_set1_epi8((char)0x88);
362 0 : const __m128i v0f = _mm_set1_epi8((char)0x0F);
363 0 : const __m128i vF0 = _mm_set1_epi8((char)0xF0);
364 :
365 : const __m128i idx_even =
366 : _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, (char)0xFF, (char)0xFF, (char)0xFF,
367 0 : (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF);
368 : const __m128i idx_odd =
369 : _mm_setr_epi8(1, 3, 5, 7, 9, 11, 13, 15, (char)0xFF, (char)0xFF, (char)0xFF,
370 0 : (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF);
371 : const __m128i idx_0246 =
372 : _mm_setr_epi8(0, 2, 4, 6, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
373 : (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
374 0 : (char)0xFF, (char)0xFF, (char)0xFF);
375 : const __m128i idx_1357 =
376 : _mm_setr_epi8(1, 3, 5, 7, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
377 : (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF, (char)0xFF,
378 0 : (char)0xFF, (char)0xFF, (char)0xFF);
379 :
380 0 : auto pack_row8 = [&](const unsigned char *qs0, const unsigned char *qs1,
381 : int off) -> __m128i {
382 0 : __m128i lo8 = _mm_loadl_epi64((const __m128i *)(qs0 + 8 * off));
383 0 : __m128i hi8 = _mm_loadl_epi64((const __m128i *)(qs1 + 8 * off));
384 : __m128i v = _mm_unpacklo_epi64(lo8, hi8);
385 0 : v = _mm_xor_si128(v, v88);
386 0 : __m128i lo = _mm_and_si128(v, v0f);
387 : __m128i hi = _mm_and_si128(_mm_srli_epi16(v, 4), v0f);
388 0 : __m128i lo_e = _mm_shuffle_epi8(lo, idx_even);
389 0 : __m128i lo_o = _mm_shuffle_epi8(lo, idx_odd);
390 : __m128i hi_e = _mm_shuffle_epi8(hi, idx_even);
391 : __m128i hi_o = _mm_shuffle_epi8(hi, idx_odd);
392 : __m128i low_lane =
393 0 : _mm_or_si128(lo_e, _mm_and_si128(_mm_slli_epi16(lo_o, 4), vF0));
394 : __m128i high_lane =
395 : _mm_or_si128(hi_e, _mm_and_si128(_mm_slli_epi16(hi_o, 4), vF0));
396 0 : __m128i low_e2 = _mm_shuffle_epi8(low_lane, idx_0246);
397 0 : __m128i low_o2 = _mm_shuffle_epi8(low_lane, idx_1357);
398 : __m128i high_e2 = _mm_shuffle_epi8(high_lane, idx_0246);
399 : __m128i high_o2 = _mm_shuffle_epi8(high_lane, idx_1357);
400 : __m128i pack_lo = _mm_unpacklo_epi8(low_e2, low_o2); // 4×u16 (w0..w3)
401 : __m128i pack_hi = _mm_unpacklo_epi8(high_e2, high_o2); // 4×u16 (w4..w7)
402 0 : return _mm_unpacklo_epi64(pack_lo, pack_hi); // 8×u16 (w0..w7)
403 0 : };
404 :
405 : auto transpose8x8_epi16 =
406 : [](__m128i r0, __m128i r1, __m128i r2, __m128i r3, __m128i r4, __m128i r5,
407 : __m128i r6, __m128i r7, __m128i &c0, __m128i &c1, __m128i &c2,
408 : __m128i &c3, __m128i &c4, __m128i &c5, __m128i &c6, __m128i &c7) {
409 : __m128i t0 = _mm_unpacklo_epi16(r0, r1);
410 : __m128i t1 = _mm_unpackhi_epi16(r0, r1);
411 : __m128i t2 = _mm_unpacklo_epi16(r2, r3);
412 : __m128i t3 = _mm_unpackhi_epi16(r2, r3);
413 : __m128i t4 = _mm_unpacklo_epi16(r4, r5);
414 : __m128i t5 = _mm_unpackhi_epi16(r4, r5);
415 : __m128i t6 = _mm_unpacklo_epi16(r6, r7);
416 : __m128i t7 = _mm_unpackhi_epi16(r6, r7);
417 :
418 : __m128i u0 = _mm_unpacklo_epi32(t0, t2);
419 : __m128i u1 = _mm_unpackhi_epi32(t0, t2);
420 : __m128i u2 = _mm_unpacklo_epi32(t1, t3);
421 : __m128i u3 = _mm_unpackhi_epi32(t1, t3);
422 : __m128i u4 = _mm_unpacklo_epi32(t4, t6);
423 : __m128i u5 = _mm_unpackhi_epi32(t4, t6);
424 : __m128i u6 = _mm_unpacklo_epi32(t5, t7);
425 : __m128i u7 = _mm_unpackhi_epi32(t5, t7);
426 :
427 : c0 = _mm_unpacklo_epi64(u0, u4);
428 : c1 = _mm_unpackhi_epi64(u0, u4);
429 : c2 = _mm_unpacklo_epi64(u1, u5);
430 : c3 = _mm_unpackhi_epi64(u1, u5);
431 : c4 = _mm_unpacklo_epi64(u2, u6);
432 : c5 = _mm_unpackhi_epi64(u2, u6);
433 : c6 = _mm_unpacklo_epi64(u3, u7);
434 : c7 = _mm_unpackhi_epi64(u3, u7);
435 : };
436 :
437 : // -------- pair-processing path: handle two 8-row groups (16 rows) per pass
438 : // --------
439 0 : const int groups_pairs = groups_N8 / 2;
440 :
441 : #ifdef _MSC_VER
442 : #pragma warning(push)
443 : #pragma warning(disable : 4849)
444 : #endif
445 0 : #pragma omp parallel for collapse(2) schedule(static)
446 : #ifdef _MSC_VER
447 : #pragma warning(pop)
448 : #endif
449 : for (int c0 = 0; c0 < cols_scales; c0 += CT) {
450 : for (int bp = 0; bp < groups_pairs; ++bp) {
451 : const int b0 = 2 * bp;
452 : const int b1 = b0 + 1;
453 : const int r0 = b0 * 8; // 16 rows: r0..r0+15
454 : const int c1 = std::min(c0 + CT, cols_scales);
455 :
456 : for (int c = c0; c < c1; ++c) {
457 : const block_q4_0x8 &A = x[b0 * cols_scales + c];
458 : const block_q4_0x8 &B = x[b1 * cols_scales + c];
459 :
460 : unsigned short *__restrict dT_c = dT + c * N;
461 : unsigned short *__restrict qsT_c0 = qsT + (c * 8) * N;
462 :
463 : // scales: pack two 8×u16 vectors → one 256b store to dT[c, r0..r0+15]
464 : __m128i sd0 = _mm_loadu_si128((const __m128i *)A.d);
465 : __m128i sd1 = _mm_loadu_si128((const __m128i *)B.d);
466 : __m256i sdp = _mm256_set_m128i(sd1, sd0);
467 : store256_u16(dT_c + r0, sdp);
468 :
469 : // pre-split stripes
470 : const unsigned char *__restrict A0 = A.qs; // + 8*off
471 : const unsigned char *__restrict A1 = A.qs + 64; // + 8*off
472 : const unsigned char *__restrict B0 = B.qs;
473 : const unsigned char *__restrict B1 = B.qs + 64;
474 :
475 : // build 8 rows for A and 8 rows for B
476 : __m128i Ra[8], Rb[8];
477 : for (int off = 0; off < 8; ++off) {
478 : Ra[off] = pack_row8(A0, A1, off);
479 : Rb[off] = pack_row8(B0, B1, off);
480 : }
481 :
482 : // 8×8 transpose → columns (each 8×u16) for A and B
483 : __m128i Ca0, Ca1, Ca2, Ca3, Ca4, Ca5, Ca6, Ca7;
484 : __m128i Cb0, Cb1, Cb2, Cb3, Cb4, Cb5, Cb6, Cb7;
485 : transpose8x8_epi16(Ra[0], Ra[1], Ra[2], Ra[3], Ra[4], Ra[5], Ra[6],
486 : Ra[7], Ca0, Ca1, Ca2, Ca3, Ca4, Ca5, Ca6, Ca7);
487 : transpose8x8_epi16(Rb[0], Rb[1], Rb[2], Rb[3], Rb[4], Rb[5], Rb[6],
488 : Rb[7], Cb0, Cb1, Cb2, Cb3, Cb4, Cb5, Cb6, Cb7);
489 :
490 : // pair and store 32B per column t: rows r0..r0+15 are contiguous
491 : unsigned short *__restrict base = qsT_c0 + r0;
492 : const int S = N;
493 : store256_u16(base + 0 * S, _mm256_set_m128i(Cb0, Ca0));
494 : store256_u16(base + 1 * S, _mm256_set_m128i(Cb1, Ca1));
495 : store256_u16(base + 2 * S, _mm256_set_m128i(Cb2, Ca2));
496 : store256_u16(base + 3 * S, _mm256_set_m128i(Cb3, Ca3));
497 : store256_u16(base + 4 * S, _mm256_set_m128i(Cb4, Ca4));
498 : store256_u16(base + 5 * S, _mm256_set_m128i(Cb5, Ca5));
499 : store256_u16(base + 6 * S, _mm256_set_m128i(Cb6, Ca6));
500 : store256_u16(base + 7 * S, _mm256_set_m128i(Cb7, Ca7));
501 : }
502 : }
503 : }
504 :
505 : // -------- tail: if odd number of 8-row groups, process the last one (8 rows)
506 : // --------
507 0 : if (groups_N8 & 1) {
508 0 : const int b = groups_N8 - 1;
509 0 : const int r0 = b * 8;
510 :
511 0 : #pragma omp parallel for schedule(static)
512 : for (int c0 = 0; c0 < cols_scales; c0 += CT) {
513 : const int c1 = std::min(c0 + CT, cols_scales);
514 : for (int c = c0; c < c1; ++c) {
515 : const block_q4_0x8 &A = x[b * cols_scales + c];
516 : unsigned short *__restrict dT_c = dT + c * N;
517 : unsigned short *__restrict qsT_c0 = qsT + (c * 8) * N;
518 :
519 : // scales (8×u16)
520 : __m128i sd0 = _mm_loadu_si128((const __m128i *)A.d);
521 : _mm_storeu_si128((__m128i *)(dT_c + r0), sd0);
522 :
523 : const unsigned char *__restrict A0 = A.qs;
524 : const unsigned char *__restrict A1 = A.qs + 64;
525 :
526 : __m128i R[8];
527 : for (int off = 0; off < 8; ++off)
528 : R[off] = pack_row8(A0, A1, off);
529 :
530 : __m128i C0, C1, C2, C3, C4, C5, C6, C7;
531 : transpose8x8_epi16(R[0], R[1], R[2], R[3], R[4], R[5], R[6], R[7], C0,
532 : C1, C2, C3, C4, C5, C6, C7);
533 :
534 : unsigned short *__restrict base = qsT_c0 + r0;
535 : const int S = N;
536 : _mm_storeu_si128((__m128i *)(base + 0 * S), C0);
537 : _mm_storeu_si128((__m128i *)(base + 1 * S), C1);
538 : _mm_storeu_si128((__m128i *)(base + 2 * S), C2);
539 : _mm_storeu_si128((__m128i *)(base + 3 * S), C3);
540 : _mm_storeu_si128((__m128i *)(base + 4 * S), C4);
541 : _mm_storeu_si128((__m128i *)(base + 5 * S), C5);
542 : _mm_storeu_si128((__m128i *)(base + 6 * S), C6);
543 : _mm_storeu_si128((__m128i *)(base + 7 * S), C7);
544 : }
545 : }
546 : }
547 :
548 : #if defined(USE_NONTEMPORAL_STORES)
549 : _mm_sfence(); // ensure NT stores are globally visible before returning
550 : #endif
551 0 : }
552 :
553 : static inline __m256i butterfly32(__m256i a) {
554 : const __m256i SHUF_EVEN = _mm256_setr_epi8(
555 : 0, 2, 4, 6, 8, 10, 12, 14, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
556 : (char)0x80, (char)0x80, (char)0x80, (char)0x80, 0, 2, 4, 6, 8, 10, 12, 14,
557 : (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
558 : (char)0x80, (char)0x80);
559 : const __m256i SHUF_ODD = _mm256_setr_epi8(
560 : 1, 3, 5, 7, 9, 11, 13, 15, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
561 : (char)0x80, (char)0x80, (char)0x80, (char)0x80, 1, 3, 5, 7, 9, 11, 13, 15,
562 : (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80, (char)0x80,
563 : (char)0x80, (char)0x80);
564 : const __m256i even = _mm256_shuffle_epi8(a, SHUF_EVEN);
565 : const __m256i odd = _mm256_shuffle_epi8(a, SHUF_ODD);
566 : const __m256i LO = _mm256_set1_epi8(0x0F);
567 : const __m256i HI = _mm256_set1_epi8((char)0xF0);
568 : __m256i low =
569 : _mm256_or_si256(_mm256_and_si256(even, LO),
570 : _mm256_slli_epi16(_mm256_and_si256(odd, LO), 4));
571 : __m256i high =
572 : _mm256_or_si256(_mm256_srli_epi16(_mm256_and_si256(even, HI), 4),
573 : _mm256_and_si256(odd, HI));
574 : high = _mm256_slli_si256(high, 8);
575 : return _mm256_or_si256(low, high);
576 : }
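// butterfly32: within each 128-bit lane, the shuffles above collect the low
// nibbles of the 16 input bytes into the lower 8 output bytes and the high
// nibbles into the upper 8 output bytes (two nibbles per byte), effectively
// de-interleaving the q4_0 nibble pairs.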
577 :
578 : // Build 16B packet [d0|d1] from two 8B chunks using vector loads (no GPR
579 : // moves).
580 : static inline __m128i make_pkt128(const uint8_t *base_qs, int d0, int d1) {
581 0 : __m128i lo = _mm_loadl_epi64((const __m128i *)(base_qs + ((size_t)d0 << 3)));
582 0 : __m128i hi = _mm_loadl_epi64((const __m128i *)(base_qs + ((size_t)d1 << 3)));
583 : return _mm_unpacklo_epi64(lo, hi);
584 : }
585 :
586 : // ================== core template with QS unrolled by 8 blocks
587 : // ==================
588 : template <int UNIT, int GROUPS>
589 : static inline void convert_q4_0x8_noshuffle(const void *src,
590 : uint16_t *RESTRICT d_out,
591 : uint8_t *RESTRICT qs_out) {
592 : static_assert(UNIT % 16 == 0, "UNIT must be multiple of 16");
593 : constexpr int BLOCKS_PER_GROUP = UNIT / 8; // d entries per offset per group
594 : constexpr int PAIRS_PER_OFFSET = UNIT / 16; // 16B packets per half per offset
595 : static_assert((PAIRS_PER_OFFSET % 4) == 0,
596 : "need multiple of 4 packets (8 blocks) per iter");
597 :
598 : constexpr size_t D_ELEMS_PER_GROUP = 8 * BLOCKS_PER_GROUP;
599 : constexpr size_t QS_BYTES_PER_GROUP = (size_t)16 * UNIT;
600 : constexpr size_t QS_BYTES_PER_OFFSET = (size_t)2 * UNIT;
601 :
602 0 : const block_q4_0x8 *x = (const block_q4_0x8 *)src;
603 0 : const __m256i bias256 = _mm256_set1_epi8((char)0x88);
604 :
605 : #ifdef _MSC_VER
606 : #pragma warning(push)
607 : #pragma warning(disable : 4849)
608 : #endif
609 0 : #pragma omp parallel for collapse(2) schedule(static)
610 : #ifdef _MSC_VER
611 : #pragma warning(pop)
612 : #endif
613 : for (int b = 0; b < GROUPS; ++b) {
614 : for (int offset = 0; offset < 8; ++offset) {
615 :
616 : // ---- D slice ----
617 : {
618 : uint16_t *d_ptr = d_out + (size_t)b * D_ELEMS_PER_GROUP +
619 : (size_t)offset * BLOCKS_PER_GROUP;
620 : const block_q4_0x8 *xb = x + (size_t)b * BLOCKS_PER_GROUP;
621 : for (int i = 0; i < BLOCKS_PER_GROUP; ++i) {
622 : d_ptr[i] = xb[i].d[offset];
623 : }
624 : }
625 :
626 : // ---- QS slice (unroll 8 blocks / 128B per iter) ----
627 : {
628 : uint8_t *qs_ptr = qs_out + (size_t)b * QS_BYTES_PER_GROUP +
629 : (size_t)offset * QS_BYTES_PER_OFFSET;
630 : const int base_q = (b * UNIT * 2) + offset;
631 : const int d0 = (base_q & 15), d1 = d0 ^ 8;
632 :
633 0 : auto do_half = [&](int blk_base) {
634 : // Each iter handles 8 consecutive blocks: j..j+7
635 0 : for (int j = 0; j < PAIRS_PER_OFFSET; j += 8) {
636 0 : const uint8_t *q0 = x[blk_base + j + 0].qs;
637 0 : const uint8_t *q1 = x[blk_base + j + 1].qs;
638 0 : const uint8_t *q2 = x[blk_base + j + 2].qs;
639 0 : const uint8_t *q3 = x[blk_base + j + 3].qs;
640 0 : const uint8_t *q4 = x[blk_base + j + 4].qs;
641 0 : const uint8_t *q5 = x[blk_base + j + 5].qs;
642 0 : const uint8_t *q6 = x[blk_base + j + 6].qs;
643 0 : const uint8_t *q7 = x[blk_base + j + 7].qs;
644 :
645 : #if Q4X8_PREFETCH_DIST > 0
646 : _mm_prefetch(
647 : (const char *)(x[blk_base + j + Q4X8_PREFETCH_DIST].qs),
648 : _MM_HINT_NTA);
649 : #endif
650 : // Build 8 packets in XMM regs
651 0 : __m128i pkt0 = make_pkt128(q0, d0, d1);
652 : __m128i pkt1 = make_pkt128(q1, d0, d1);
653 : __m128i pkt2 = make_pkt128(q2, d0, d1);
654 : __m128i pkt3 = make_pkt128(q3, d0, d1);
655 : __m128i pkt4 = make_pkt128(q4, d0, d1);
656 : __m128i pkt5 = make_pkt128(q5, d0, d1);
657 : __m128i pkt6 = make_pkt128(q6, d0, d1);
658 : __m128i pkt7 = make_pkt128(q7, d0, d1);
659 :
660 : // Four 32B batches: [0|1], [2|3], [4|5], [6|7]
661 : __m256i v01 = _mm256_set_m128i(pkt1, pkt0);
662 : __m256i v23 = _mm256_set_m128i(pkt3, pkt2);
663 : __m256i v45 = _mm256_set_m128i(pkt5, pkt4);
664 : __m256i v67 = _mm256_set_m128i(pkt7, pkt6);
665 :
666 0 : v01 = _mm256_xor_si256(v01, bias256);
667 : v23 = _mm256_xor_si256(v23, bias256);
668 : v45 = _mm256_xor_si256(v45, bias256);
669 : v67 = _mm256_xor_si256(v67, bias256);
670 :
671 : __m256i o01 = butterfly32(v01);
672 : __m256i o23 = butterfly32(v23);
673 : __m256i o45 = butterfly32(v45);
674 : __m256i o67 = butterfly32(v67);
675 :
676 : #if Q4X8_USE_STREAMING_STORES
677 : _mm256_stream_si256((__m256i *)(qs_ptr + 0), o01);
678 : _mm256_stream_si256((__m256i *)(qs_ptr + 32), o23);
679 : _mm256_stream_si256((__m256i *)(qs_ptr + 64), o45);
680 : _mm256_stream_si256((__m256i *)(qs_ptr + 96), o67);
681 : #else
682 0 : _mm256_storeu_si256((__m256i *)(qs_ptr + 0), o01);
683 0 : _mm256_storeu_si256((__m256i *)(qs_ptr + 32), o23);
684 0 : _mm256_storeu_si256((__m256i *)(qs_ptr + 64), o45);
685 0 : _mm256_storeu_si256((__m256i *)(qs_ptr + 96), o67);
686 : #endif
687 0 : qs_ptr += 128;
688 : }
689 : };
690 :
691 : // first half
692 : do_half(base_q >> 4);
693 : // second half (same d0/d1 pattern)
694 : do_half((base_q + UNIT) >> 4);
695 : }
696 : }
697 : }
698 :
699 : #if Q4X8_USE_STREAMING_STORES
700 : _mm_sfence();
701 : #endif
702 : }
703 :
704 : // ================== wrappers for the supported K,N combinations ==================
705 : // K = 3072 (UNIT = 768)
706 0 : void convert_q4_0x8_shuffle_K3072_N98304(const void *src, uint16_t *d_out,
707 : uint8_t *qs_out) {
708 : // groups = (N*8)/UNIT = 1024
709 : convert_q4_0x8_noshuffle<768, 1024>(src, d_out, qs_out);
710 0 : }
711 0 : void convert_q4_0x8_shuffle_K3072_N36864(const void *src, uint16_t *d_out,
712 : uint8_t *qs_out) {
713 : // groups = 384
714 : convert_q4_0x8_noshuffle<768, 384>(src, d_out, qs_out);
715 0 : }
716 0 : void convert_q4_0x8_shuffle_K3072_N3072(const void *src, uint16_t *d_out,
717 : uint8_t *qs_out) {
718 : // groups = 32
719 : convert_q4_0x8_noshuffle<768, 32>(src, d_out, qs_out);
720 0 : }
721 :
722 : // K = 8192 (UNIT = 2048)
723 0 : void convert_q4_0x8_shuffle_K8192_N98304(const void *src, uint16_t *d_out,
724 : uint8_t *qs_out) {
725 : // groups = 384
726 : convert_q4_0x8_noshuffle<2048, 384>(src, d_out, qs_out);
727 0 : }
728 0 : void convert_q4_0x8_shuffle_K8192_N36864(const void *src, uint16_t *d_out,
729 : uint8_t *qs_out) {
730 : // groups = 144
731 : convert_q4_0x8_noshuffle<2048, 144>(src, d_out, qs_out);
732 0 : }
733 0 : void convert_q4_0x8_shuffle_K8192_N3072(const void *src, uint16_t *d_out,
734 : uint8_t *qs_out) {
735 : // groups = 12
736 : convert_q4_0x8_noshuffle<2048, 12>(src, d_out, qs_out);
737 0 : }
738 :
739 : // Optional tiny dispatcher if you want one entry point:
740 0 : void convert_q4_0x8_shuffle_dispatch_avx(const void *src, uint16_t *d_out,
741 : uint8_t *qs_out, int N, int K) {
742 0 : if (K == 3072) {
743 0 : if (N == 98304)
744 0 : return convert_q4_0x8_shuffle_K3072_N98304(src, d_out, qs_out);
745 0 : if (N == 36864)
746 0 : return convert_q4_0x8_shuffle_K3072_N36864(src, d_out, qs_out);
747 0 : if (N == 3072)
748 0 : return convert_q4_0x8_shuffle_K3072_N3072(src, d_out, qs_out);
749 : } else { // K == 8192
750 0 : if (N == 98304)
751 0 : return convert_q4_0x8_shuffle_K8192_N98304(src, d_out, qs_out);
752 0 : if (N == 36864)
753 0 : return convert_q4_0x8_shuffle_K8192_N36864(src, d_out, qs_out);
754 0 : if (N == 3072)
755 0 : return convert_q4_0x8_shuffle_K8192_N3072(src, d_out, qs_out);
756 : }
757 : // If a new combo appears, fall back to a generic version (not shown here).
758 0 : assert(!"Unsupported (K,N) combination");
759 : }
760 :
761 12 : bool is_valid(const unsigned int N, const float *input) {
762 12 : assert(N != 0);
763 12 : assert(input != NULL);
764 :
765 : int temp = 0;
766 : unsigned int idx = 0;
767 :
768 : const __m256 SIGN_MASK = _mm256_set1_ps(-0.0);
769 : const __m256 INF = _mm256_set1_ps(std::numeric_limits<float>::infinity());
770 :
771 : // Process 16 floats per iteration: NaN check ( X != X ), then infinity check
772 15 : for (; N - idx >= 16; idx += 16) {
773 : __m256 vec0 = _mm256_loadu_ps(input);
774 : __m256 vec1 = _mm256_loadu_ps(input + 8);
775 6 : input += 16;
776 : __m256 res = _mm256_cmp_ps(vec0, vec0, _CMP_NEQ_UQ);
777 6 : temp = temp | _mm256_movemask_ps(res);
778 :
779 6 : if (temp)
780 : return false;
781 :
782 : // check infinity in vec0
783 : vec0 = _mm256_andnot_ps(SIGN_MASK, vec0);
784 : vec0 = _mm256_cmp_ps(vec0, INF, _CMP_EQ_OQ);
785 :
786 5 : temp = temp | _mm256_movemask_ps(vec0);
787 5 : if (temp)
788 : return false;
789 :
790 : __m256 res1 = _mm256_cmp_ps(vec1, vec1, _CMP_NEQ_UQ);
791 3 : temp = temp | _mm256_movemask_ps(res1);
792 :
793 3 : if (temp)
794 : return false;
795 :
796 : // check infinity in vec1
797 : vec1 = _mm256_andnot_ps(SIGN_MASK, vec1);
798 : vec1 = _mm256_cmp_ps(vec1, INF, _CMP_EQ_OQ);
799 :
800 3 : temp = temp | _mm256_movemask_ps(vec1);
801 :
802 3 : if (temp)
803 : return false;
804 : }
805 :
806 : // Process 8 floats per iteration: NaN check ( X != X ), then infinity check
807 11 : for (; N - idx >= 8; idx += 8) {
808 : __m256 vec = _mm256_loadu_ps(input);
809 5 : input += 8;
810 : __m256 res = _mm256_cmp_ps(vec, vec, _CMP_NEQ_UQ);
811 5 : temp = temp | _mm256_movemask_ps(res);
812 :
813 5 : if (temp)
814 : return false;
815 :
816 : // check infinity in vec
817 : vec = _mm256_andnot_ps(SIGN_MASK, vec);
818 : vec = _mm256_cmp_ps(vec, INF, _CMP_EQ_OQ);
819 :
820 4 : temp = temp | _mm256_movemask_ps(vec);
821 :
822 4 : if (temp)
823 : return false;
824 : }
825 :
826 11 : while (idx < N) {
827 8 : if (!isFloatValid(*input)) {
828 : return false;
829 : }
830 5 : ++input;
831 5 : ++idx;
832 : }
833 :
834 : return true;
835 : }
836 :
837 173980 : void custom_scopy(const unsigned int N, const float *X, const int incX,
838 : float *Y, const int incY) {
839 173980 : unsigned int N8 = (N >> 3) << 3;
840 957465 : for (unsigned int i = 0; i < N8; i += 8) {
841 : #if defined(_WIN32)
842 : __m256 temp = _mm256_loadu_ps(&X[i]);
843 : _mm256_storeu_ps(&Y[i], temp);
844 : #else
845 783485 : __asm__ __volatile__("vmovups (%1), %%ymm0\n\t"
846 : "vmovups %%ymm0, (%0)\n\t"
847 : :
848 783485 : : "r"(&Y[i]), "r"(&X[i])
849 : : "ymm0", "memory");
850 : #endif
851 : }
852 499415 : for (unsigned int i = N8; i < N; ++i) {
853 325435 : Y[i] = X[i];
854 : }
855 173980 : }
856 :
857 25 : void transpose_matrix(const unsigned int M, const unsigned int N,
858 : const float *src, unsigned int ld_src, float *dst,
859 : unsigned int ld_dst) {
860 25 : unsigned int vindexm[8] = {0, ld_src, ld_src * 2, ld_src * 3,
861 25 : ld_src * 4, ld_src * 5, ld_src * 6, ld_src * 7};
862 : __m256i vindex = _mm256_loadu_si256((__m256i *)&vindexm[0]);
863 : __m256 vec1, vec2, vec3, vec4, vec5, vec6, vec7, vec8;
864 :
865 25 : unsigned int M8 = (M & ~(7));
866 25 : unsigned int N8 = (N & ~(7));
867 1378 : for (unsigned int i = 0; i < M8; i += 8) {
868 308626 : for (unsigned int j = 0; j < N8; j += 8) {
869 : // loading from columns
870 307273 : vec1 = _mm256_i32gather_ps(&src[ld_src * i + j + 0], vindex, 4);
871 307273 : vec2 = _mm256_i32gather_ps(&src[ld_src * i + j + 1], vindex, 4);
872 307273 : vec3 = _mm256_i32gather_ps(&src[ld_src * i + j + 2], vindex, 4);
873 307273 : vec4 = _mm256_i32gather_ps(&src[ld_src * i + j + 3], vindex, 4);
874 307273 : vec5 = _mm256_i32gather_ps(&src[ld_src * i + j + 4], vindex, 4);
875 307273 : vec6 = _mm256_i32gather_ps(&src[ld_src * i + j + 5], vindex, 4);
876 307273 : vec7 = _mm256_i32gather_ps(&src[ld_src * i + j + 6], vindex, 4);
877 307273 : vec8 = _mm256_i32gather_ps(&src[ld_src * i + j + 7], vindex, 4);
878 :
879 : // storing to the rows
880 307273 : _mm256_storeu_ps(&dst[(j + 0) * ld_dst + i], vec1);
881 307273 : _mm256_storeu_ps(&dst[(j + 1) * ld_dst + i], vec2);
882 307273 : _mm256_storeu_ps(&dst[(j + 2) * ld_dst + i], vec3);
883 307273 : _mm256_storeu_ps(&dst[(j + 3) * ld_dst + i], vec4);
884 307273 : _mm256_storeu_ps(&dst[(j + 4) * ld_dst + i], vec5);
885 307273 : _mm256_storeu_ps(&dst[(j + 5) * ld_dst + i], vec6);
886 307273 : _mm256_storeu_ps(&dst[(j + 6) * ld_dst + i], vec7);
887 307273 : _mm256_storeu_ps(&dst[(j + 7) * ld_dst + i], vec8);
888 : }
889 : }
890 :
891 : // tailing right
892 10910 : for (unsigned int i = 0; i < M; i++) {
893 12316 : for (unsigned int j = N8; j < N; j++) {
894 1431 : dst[i + j * ld_dst] = src[i * ld_src + j];
895 : }
896 : }
897 :
898 : // tailing bottom
899 86 : for (unsigned int i = M8; i < M; i++) {
900 356 : for (unsigned int j = 0; j < N; j++) {
901 295 : dst[i + j * ld_dst] = src[i * ld_src + j];
902 : }
903 : }
904 25 : }
905 :
906 0 : void swiglu(const unsigned int N, float *X, const float *Y, const float *Z) {
907 : size_t i = 0;
908 :
909 : const auto oldcsr = _mm_getcsr();
910 0 : _mm_setcsr(oldcsr | 0x8040); // DAZ | FTZ
911 :
912 : // 16-wide blocks
913 0 : for (; i + 16 <= N; i += 16) {
914 0 : const __m256 y0 = _mm256_loadu_ps(Y + i);
915 0 : const __m256 y1 = _mm256_loadu_ps(Y + i + 8);
916 0 : const __m256 z0 = _mm256_loadu_ps(Z + i);
917 0 : const __m256 z1 = _mm256_loadu_ps(Z + i + 8);
918 :
919 0 : _mm256_storeu_ps(X + i, avx2_approx_swiglu(y0, z0));
920 0 : _mm256_storeu_ps(X + i + 8, avx2_approx_swiglu(y1, z1));
921 : }
922 :
923 : // One 8-wide block if available
924 0 : if (i + 8 <= N) {
925 0 : const __m256 y0 = _mm256_loadu_ps(Y + i);
926 0 : const __m256 z0 = _mm256_loadu_ps(Z + i);
927 0 : _mm256_storeu_ps(X + i, avx2_approx_swiglu(y0, z0));
928 : i += 8;
929 : }
930 :
931 : // Remaining 1..7 elements via maskload/maskstore
932 0 : if (i < N) {
933 0 : const int remain = static_cast<int>(N - i); // 1..7
934 :
935 0 : alignas(64) const int mtab[16] = {-1, -1, -1, -1, -1, -1, -1, -1,
936 : 0, 0, 0, 0, 0, 0, 0, 0};
937 : // Start so that we take 'remain' ones then zeros.
938 0 : const int off = 8 - remain; // in [1..7], or 0 if remain==8
939 0 : const __m256i vmask = _mm256_loadu_si256((const __m256i *)(mtab + off));
940 :
941 0 : const __m256 y = _mm256_maskload_ps(Y + i, vmask);
942 0 : const __m256 z = _mm256_maskload_ps(Z + i, vmask);
943 : const __m256 r = avx2_approx_swiglu(y, z);
944 0 : _mm256_maskstore_ps(X + i, vmask, r);
945 : }
946 :
947 : _mm_setcsr(oldcsr);
948 0 : }
949 :
950 0 : void swiglu(const unsigned int N, float *X, const float *Y, const float *Z,
951 : float alpha) {
952 : size_t i = 0;
953 :
954 : const auto oldcsr = _mm_getcsr();
955 0 : _mm_setcsr(oldcsr | 0x8040); // DAZ | FTZ
956 :
957 : const __m256 alpha_vec = _mm256_set1_ps(alpha);
958 :
959 : // 16-wide blocks
960 0 : for (; i + 16 <= N; i += 16) {
961 0 : const __m256 y0 = _mm256_loadu_ps(Y + i);
962 0 : const __m256 y1 = _mm256_loadu_ps(Y + i + 8);
963 0 : const __m256 z0 = _mm256_loadu_ps(Z + i);
964 0 : const __m256 z1 = _mm256_loadu_ps(Z + i + 8);
965 :
966 0 : _mm256_storeu_ps(X + i, avx2_approx_swiglu_alpha(y0, z0, alpha_vec));
967 0 : _mm256_storeu_ps(X + i + 8, avx2_approx_swiglu_alpha(y1, z1, alpha_vec));
968 : }
969 :
970 : // One 8-wide block if present
971 0 : if (i + 8 <= N) {
972 0 : const __m256 y0 = _mm256_loadu_ps(Y + i);
973 0 : const __m256 z0 = _mm256_loadu_ps(Z + i);
974 0 : _mm256_storeu_ps(X + i, avx2_approx_swiglu_alpha(y0, z0, alpha_vec));
975 : i += 8;
976 : }
977 :
978 : // Remaining 1..7 elements via masked AVX (no stray stores)
979 0 : if (i < N) {
980 0 : const int remain = static_cast<int>(N - i); // 1..7
981 :
982 0 : alignas(64) const int mtab[16] = {
983 : -1, -1, -1, -1, -1, -1, -1, -1, // ones
984 : 0, 0, 0, 0, 0, 0, 0, 0 // zeros
985 : };
986 0 : const int off = 8 - remain; // choose first `remain` lanes active
987 0 : const __m256i vmask = _mm256_loadu_si256((const __m256i *)(mtab + off));
988 :
989 0 : const __m256 y = _mm256_maskload_ps(Y + i, vmask);
990 0 : const __m256 z = _mm256_maskload_ps(Z + i, vmask);
991 : const __m256 r = avx2_approx_swiglu_alpha(y, z, alpha_vec);
992 0 : _mm256_maskstore_ps(X + i, vmask, r);
993 : }
994 :
995 : _mm_setcsr(oldcsr);
996 0 : }
997 :
998 29926 : void ele_mul(const unsigned int N, const float *X, const float *Y, float *Z,
999 : float alpha, float beta, unsigned int i_stride,
1000 : unsigned int o_stride) {
1001 29926 : if (alpha == 1.0f && beta == 0.0f && o_stride == 1) {
1002 29581 : unsigned int N8 = (N & ~(7));
1003 29581 : if (i_stride == 0) {
1004 9591 : float vy8[8] = {Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0]};
1005 : auto y = _mm256_loadu_ps(&vy8[0]);
1006 18108 : for (unsigned int i = 0; i < N8; i += 8) {
1007 : auto x = _mm256_loadu_ps(X);
1008 : auto z = _mm256_mul_ps(x, y);
1009 : _mm256_storeu_ps(Z, z);
1010 8517 : X += 8;
1011 8517 : Y += i_stride * 8;
1012 8517 : Z += 8;
1013 : }
1014 : } else {
1015 13214695 : for (unsigned int i = 0; i < N8; i += 8) {
1016 : auto x = _mm256_loadu_ps(X);
1017 : auto y = _mm256_loadu_ps(Y);
1018 : auto z = _mm256_mul_ps(x, y);
1019 : _mm256_storeu_ps(Z, z);
1020 13194705 : X += 8;
1021 13194705 : Y += i_stride * 8;
1022 13194705 : Z += 8;
1023 : }
1024 : }
1025 138986 : for (unsigned int i = N8; i < N; ++i) {
1026 109405 : *Z = *X * *Y;
1027 109405 : X++;
1028 109405 : Y += i_stride;
1029 109405 : Z++;
1030 : }
1031 : } else {
1032 : // TODO: AVX2 implementation if used
1033 65505 : for (unsigned int i = 0; i < N; ++i) {
1034 65160 : *Z = *X * alpha * *Y + ((0.0f == beta) ? 0.0f : beta * *Z);
1035 65160 : X += o_stride;
1036 65160 : Y += i_stride;
1037 65160 : Z += o_stride;
1038 : }
1039 : }
1040 29926 : }
1041 :
1042 113033 : void ele_add(const unsigned int N, const float *X, const float *Y, float *Z,
1043 : float alpha, float beta, unsigned int i_stride,
1044 : unsigned int o_stride) {
1045 113033 : if (alpha == 1.0f && beta == 0.0f && o_stride == 1) {
1046 75820 : unsigned int N8 = (N & ~(7));
1047 75820 : if (i_stride == 0) {
1048 26375 : float vy8[8] = {Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0], Y[0]};
1049 : auto y = _mm256_loadu_ps(&vy8[0]);
1050 580391 : for (unsigned int i = 0; i < N8; i += 8) {
1051 : auto x = _mm256_loadu_ps(X);
1052 : auto z = _mm256_add_ps(x, y);
1053 : _mm256_storeu_ps(Z, z);
1054 554016 : X += 8;
1055 554016 : Y += i_stride * 8;
1056 554016 : Z += 8;
1057 : }
1058 : } else {
1059 169478 : for (unsigned int i = 0; i < N8; i += 8) {
1060 : auto x = _mm256_loadu_ps(X);
1061 : auto y = _mm256_loadu_ps(Y);
1062 : auto z = _mm256_add_ps(x, y);
1063 : _mm256_storeu_ps(Z, z);
1064 120033 : X += 8;
1065 120033 : Y += i_stride * 8;
1066 120033 : Z += 8;
1067 : }
1068 : }
1069 234632 : for (unsigned int i = N8; i < N; ++i) {
1070 158812 : *Z = *X + *Y;
1071 158812 : X++;
1072 158812 : Y += i_stride;
1073 158812 : Z++;
1074 : }
1075 : } else {
1076 : // TODO: AVX2 implementation if used
1077 203432843 : for (unsigned int i = 0; i < N; ++i) {
1078 203395630 : *Z = *X + alpha * *Y + ((0.0f == beta) ? 0.0f : beta * *Z);
1079 203395630 : X += o_stride;
1080 203395630 : Y += i_stride;
1081 203395630 : Z += o_stride;
1082 : }
1083 : }
1084 113033 : }
1085 :
1086 9 : static inline __m256 exp256_ps(__m256 x) {
1087 : /* Low-Precision Version I*/
1088 : // const __m256 c1 = _mm256_set1_ps(12102203.0f);
1089 : // const __m256 c2 = _mm256_set1_ps(1065353216.0f);
1090 : // __m256 fx = _mm256_add_ps(_mm256_mul_ps(x, c1),c2);
1091 : // return _mm256_castsi256_ps(_mm256_cvtps_epi32(fx));
1092 :
1093 : /* Low-Precision Version II*/
1094 : /* const __m256 ln2 = _mm256_set1_ps(0.69314718056f);
1095 : const __m256 inv_ln2 = _mm256_set1_ps(1.44269504089f); // 1 / ln(2)
1096 :
1097 : // Range reduction: x = n * ln2 + r, where n is integer and |r| <= ln2/2
1098 : __m256 fx = _mm256_mul_ps(x, inv_ln2);
1099 : fx = _mm256_round_ps(fx, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
1100 : __m256i emm0 = _mm256_cvtps_epi32(fx);
1101 :
1102 : __m256 tmp = _mm256_mul_ps(fx, ln2);
1103 : __m256 r = _mm256_sub_ps(x, tmp);
1104 :
1105 : // Compute polynomial approximation of exp(r)
1106 : const __m256 c1 = _mm256_set1_ps(1.9875691500E-4f);
1107 : const __m256 c2 = _mm256_set1_ps(1.3981999507E-3f);
1108 : const __m256 c3 = _mm256_set1_ps(8.3334519073E-3f);
1109 : const __m256 c4 = _mm256_set1_ps(4.1665795894E-2f);
1110 : const __m256 c5 = _mm256_set1_ps(1.6666665459E-1f);
1111 : const __m256 c6 = _mm256_set1_ps(5.0000001201E-1f);
1112 :
1113 : // __m256 r2 = _mm256_mul_ps(r, r);
1114 : // __m256 r3 = _mm256_mul_ps(r2, r);
1115 : // __m256 r4 = _mm256_mul_ps(r2, r2);
1116 :
1117 : __m256 y = _mm256_fmadd_ps(c1, r, c2);
1118 : y = _mm256_fmadd_ps(y, r, c3);
1119 : y = _mm256_fmadd_ps(y, r, c4);
1120 : y = _mm256_fmadd_ps(y, r, c5);
1121 : y = _mm256_fmadd_ps(y, r, c6);
1122 : y = _mm256_fmadd_ps(y, r, _mm256_set1_ps(1.0f));
1123 :
1124 : // Reconstruct exp(x) = 2^n * exp(r)
1125 : emm0 = _mm256_add_epi32(emm0, _mm256_set1_epi32(127));
1126 : emm0 = _mm256_slli_epi32(emm0, 23);
1127 : __m256 pow2n = _mm256_castsi256_ps(emm0);
1128 :
1129 : return _mm256_mul_ps(y, pow2n);
1130 : */
1131 : /* Low-Precision Version III */
1132 : const __m256 LOG2EF = _mm256_set1_ps(1.44269504088896341f); // 1 / ln(2)
1133 : const __m256 LN2 = _mm256_set1_ps(0.6931471805599453f); // ln(2)
1134 :
1135 : // Clamp input to range to prevent overflow/underflow
1136 : const __m256 max_x = _mm256_set1_ps(88.3762626647949f);  // ~log(FLT_MAX)
1137 : const __m256 min_x = _mm256_set1_ps(-88.3762626647949f); // symmetric lower clamp (exp underflows to 0)
1138 : x = _mm256_max_ps(min_x, _mm256_min_ps(max_x, x));
1139 :
1140 : // Range reduction: x = n * ln2 + r
1141 : __m256 fx = _mm256_mul_ps(x, LOG2EF); // x * (1/ln(2))
1142 : fx = _mm256_floor_ps(_mm256_add_ps(fx, _mm256_set1_ps(0.5f)));
1143 :
1144 : __m256 tmp = _mm256_mul_ps(fx, LN2); // n * ln(2)
1145 : __m256 r = _mm256_sub_ps(x, tmp); // r = x - n * ln2
1146 :
1147 : // Compute exp(r) using 10th-order polynomial (Horner's method)
1148 : const __m256 c0 = _mm256_set1_ps(1.0f);
1149 : const __m256 c1 = _mm256_set1_ps(1.0f);
1150 : const __m256 c2 = _mm256_set1_ps(0.5f);
1151 : const __m256 c3 = _mm256_set1_ps(1.0f / 6.0f);
1152 : const __m256 c4 = _mm256_set1_ps(1.0f / 24.0f);
1153 : const __m256 c5 = _mm256_set1_ps(1.0f / 120.0f);
1154 : const __m256 c6 = _mm256_set1_ps(1.0f / 720.0f);
1155 : const __m256 c7 = _mm256_set1_ps(1.0f / 5040.0f);
1156 : const __m256 c8 = _mm256_set1_ps(1.0f / 40320.0f);
1157 : const __m256 c9 = _mm256_set1_ps(1.0f / 362880.0f);
1158 : const __m256 c10 = _mm256_set1_ps(1.0f / 3628800.0f);
1159 :
1160 : __m256 y = c10;
1161 : y = _mm256_fmadd_ps(y, r, c9);
1162 : y = _mm256_fmadd_ps(y, r, c8);
1163 : y = _mm256_fmadd_ps(y, r, c7);
1164 : y = _mm256_fmadd_ps(y, r, c6);
1165 : y = _mm256_fmadd_ps(y, r, c5);
1166 : y = _mm256_fmadd_ps(y, r, c4);
1167 : y = _mm256_fmadd_ps(y, r, c3);
1168 : y = _mm256_fmadd_ps(y, r, c2);
1169 : y = _mm256_fmadd_ps(y, r, c1);
1170 : y = _mm256_fmadd_ps(y, r, c0); // final y = (((...r+...)*r+...)*r + 1)
1171 :
1172 : // Reconstruct exp(x) = 2^n * exp(r)
1173 : __m256i emm0 = _mm256_cvtps_epi32(fx);
1174 : emm0 = _mm256_add_epi32(emm0, _mm256_set1_epi32(127));
1175 : emm0 = _mm256_slli_epi32(emm0, 23);
1176 : __m256 pow2n = _mm256_castsi256_ps(emm0);
1177 :
1178 9 : return _mm256_mul_ps(y, pow2n);
1179 : }
1180 :
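// Column-wise softmax over rows [start_row, end_row): for each head (column) c
// the maximum is taken down the rows, exp(x - max) overwrites qk_out in place,
// and every entry is then divided by the per-column sum. Full 8-column blocks
// use the AVX2 exp256_ps() path; the remainder falls back to std::exp.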
1181 1 : static void softmax_row_inplace(float *qk_out, size_t start_row, size_t end_row,
1182 : size_t num_heads) {
1183 1 : size_t row_range = end_row - start_row;
1184 1 : const size_t full_blocks = (num_heads / 8) * 8;
1185 : // const size_t remainder = num_heads % 8;
1186 :
1187 1 : float *max_vals = new float[num_heads];
1188 1 : float *sum_vals = new float[num_heads];
1189 : // 1. max
1190 11 : for (size_t c = 0; c < num_heads; ++c) {
1191 10 : float max_val = -INFINITY;
1192 40 : for (size_t r = start_row; r < end_row; ++r)
1193 49 : max_val = std::max(max_val, qk_out[r * num_heads + c]);
1194 10 : max_vals[c] = max_val;
1195 : }
1196 :
1197 : // 2. inplace exp + sum
1198 2 : for (size_t c = 0; c < full_blocks; c += 8) {
1199 1 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1200 : __m256 sum = _mm256_setzero_ps();
1201 4 : for (size_t r = 0; r < row_range; ++r) {
1202 3 : float *ptr = &qk_out[(start_row + r) * num_heads + c];
1203 : __m256 val = _mm256_loadu_ps(ptr);
1204 3 : __m256 e = exp256_ps(_mm256_sub_ps(val, maxv));
1205 : _mm256_storeu_ps(ptr, e); // overwrite qk_out
1206 : sum = _mm256_add_ps(sum, e);
1207 : }
1208 1 : _mm256_storeu_ps(&sum_vals[c], sum);
1209 : }
1210 :
1211 3 : for (size_t c = full_blocks; c < num_heads; ++c) {
1212 : float sum = 0.0f;
1213 2 : float maxv = max_vals[c];
1214 8 : for (size_t r = 0; r < row_range; ++r) {
1215 6 : float &a = qk_out[(start_row + r) * num_heads + c];
1216 6 : a = std::exp(a - maxv); // overwrite qk_out
1217 6 : sum += a;
1218 : }
1219 2 : sum_vals[c] = sum;
1220 : }
1221 : // 3. softmax = exp / sum (inplace)
1222 4 : for (size_t r = 0; r < row_range; ++r) {
1223 6 : for (size_t c = 0; c < full_blocks; c += 8) {
1224 3 : float *ptr = &qk_out[(start_row + r) * num_heads + c];
1225 : __m256 val = _mm256_loadu_ps(ptr); // already exp(x - max)
1226 3 : __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
1227 : __m256 soft = _mm256_div_ps(val, sumv);
1228 : _mm256_storeu_ps(ptr, soft);
1229 : }
1230 9 : for (size_t c = full_blocks; c < num_heads; ++c) {
1231 6 : qk_out[(start_row + r) * num_heads + c] /= sum_vals[c];
1232 : }
1233 : }
1234 :
1235 1 : delete[] max_vals;
1236 1 : delete[] sum_vals;
1237 1 : }
1238 :
1239 0 : static void softmax_row_with_sink_inplace(float *qk_out, size_t start_row,
1240 : size_t end_row, size_t num_heads,
1241 : float *sink) {
1242 0 : size_t row_range = end_row - start_row;
1243 0 : const size_t full_blocks = (num_heads / 8) * 8;
1244 :
1245 0 : float *max_vals = new float[num_heads];
1246 0 : float *sum_vals = new float[num_heads];
1247 : // 1. max
1248 0 : for (size_t c = 0; c < num_heads; ++c) {
1249 0 : float max_val = -INFINITY;
1250 0 : for (size_t r = start_row; r < end_row; ++r)
1251 0 : max_val = std::max(max_val, qk_out[r * num_heads + c]);
1252 0 : max_vals[c] = std::max(sink[c], max_val);
1253 : }
1254 :
1255 : // 2. inplace exp + sum
1256 0 : for (size_t c = 0; c < full_blocks; c += 8) {
1257 0 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1258 0 : __m256 sum = _mm256_loadu_ps(&sink[c]);
1259 0 : sum = exp256_ps(_mm256_sub_ps(sum, maxv));
1260 0 : for (size_t r = 0; r < row_range; ++r) {
1261 0 : float *ptr = &qk_out[(start_row + r) * num_heads + c];
1262 : __m256 val = _mm256_loadu_ps(ptr);
1263 0 : __m256 e = exp256_ps(_mm256_sub_ps(val, maxv));
1264 : _mm256_storeu_ps(ptr, e); // overwrite qk_out
1265 : sum = _mm256_add_ps(sum, e);
1266 : }
1267 0 : _mm256_storeu_ps(&sum_vals[c], sum);
1268 : }
1269 :
1270 0 : for (size_t c = full_blocks; c < num_heads; ++c) {
1271 0 : float maxv = max_vals[c];
1272 0 : float sum = std::exp(sink[c] - maxv);
1273 0 : for (size_t r = 0; r < row_range; ++r) {
1274 0 : float &a = qk_out[(start_row + r) * num_heads + c];
1275 0 : a = std::exp(a - maxv); // overwrite qk_out
1276 0 : sum += a;
1277 : }
1278 0 : sum_vals[c] = sum;
1279 : }
1280 : // 3. softmax = exp / sum (inplace)
1281 0 : for (size_t r = 0; r < row_range; ++r) {
1282 0 : for (size_t c = 0; c < full_blocks; c += 8) {
1283 0 : float *ptr = &qk_out[(start_row + r) * num_heads + c];
1284 : __m256 val = _mm256_loadu_ps(ptr); // already exp(x - max)
1285 0 : __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
1286 : __m256 soft = _mm256_div_ps(val, sumv);
1287 : _mm256_storeu_ps(ptr, soft);
1288 : }
1289 0 : for (size_t c = full_blocks; c < num_heads; ++c) {
1290 0 : qk_out[(start_row + r) * num_heads + c] /= sum_vals[c];
1291 : }
1292 : }
1293 :
1294 0 : delete[] max_vals;
1295 0 : delete[] sum_vals;
1296 0 : }
1297 :
1298 : template <>
1299 1 : void softmax_row_inplace(float *qk_out, size_t start_row, size_t end_row,
1300 : size_t num_heads, float *sink) {
1301 1 : if (sink == nullptr) {
1302 1 : return softmax_row_inplace(qk_out, start_row, end_row, num_heads);
1303 : } else {
1304 0 : return softmax_row_with_sink_inplace(qk_out, start_row, end_row, num_heads,
1305 0 : sink);
1306 : }
1307 : }
1308 :
1309 1 : static void softmax_row(float *qk_out, size_t start_row, size_t end_row,
1310 : size_t num_heads) {
1311 1 : const size_t full_block = (num_heads / 8) * 8;
1312 :
1313 1 : float *max_vals = new float[num_heads];
1314 1 : float *sum_vals = new float[num_heads];
1315 :
1316 : // 1. Find Max along with col
1317 11 : for (size_t c = 0; c < num_heads; ++c) {
1318 10 : float max_val = -INFINITY;
1319 40 : for (size_t r = start_row; r < end_row; ++r) {
1320 49 : max_val = std::max(max_val, qk_out[r * num_heads + c]);
1321 : }
1322 10 : max_vals[c] = max_val;
1323 : }
1324 :
1325 : // 2. Compute sum along with col (exp vectorized)
1326 2 : for (size_t c = 0; c < full_block; c += 8) {
1327 : __m256 sum = _mm256_setzero_ps();
1328 4 : for (size_t r = start_row; r < end_row; ++r) {
1329 3 : __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
1330 3 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1331 : __m256 sub = _mm256_sub_ps(val, maxv);
1332 3 : __m256 e = exp256_ps(sub);
1333 : sum = _mm256_add_ps(sum, e);
1334 : }
1335 1 : _mm256_storeu_ps(&sum_vals[c], sum);
1336 : }
1337 :
1338 3 : for (size_t c = full_block; c < num_heads; ++c) {
1339 : float sum = 0.0f;
1340 8 : for (size_t r = start_row; r < end_row; ++r) {
1341 6 : sum += std::exp(qk_out[r * num_heads + c] - max_vals[c]);
1342 : }
1343 2 : sum_vals[c] = sum;
1344 : }
1345 :
1346 : // 3. apply softmax
1347 4 : for (size_t r = start_row; r < end_row; ++r) {
1348 6 : for (size_t c = 0; c < full_block; c += 8) {
1349 3 : __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
1350 3 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1351 : __m256 sub = _mm256_sub_ps(val, maxv);
1352 3 : __m256 e = exp256_ps(sub);
1353 3 : __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
1354 : __m256 softmax = _mm256_div_ps(e, sumv);
1355 : _mm256_storeu_ps(&qk_out[r * num_heads + c], softmax);
1356 : }
1357 9 : for (size_t c = full_block; c < num_heads; ++c) {
1358 6 : qk_out[r * num_heads + c] =
1359 6 : std::exp(qk_out[r * num_heads + c] - max_vals[c]) / sum_vals[c];
1360 : }
1361 : }
1362 :
1363 1 : delete[] max_vals;
1364 1 : delete[] sum_vals;
1365 1 : }
1366 :
1367 0 : static void softmax_row_with_sink(float *qk_out, size_t start_row,
1368 : size_t end_row, size_t num_heads,
1369 : float *sink) {
1370 0 : const size_t full_block = (num_heads / 8) * 8;
1371 :
1372 0 : float *max_vals = new float[num_heads];
1373 0 : float *sum_vals = new float[num_heads];
1374 :
1375 : // 1. Find Max along with col
1376 0 : for (size_t c = 0; c < num_heads; ++c) {
1377 0 : float max_val = -INFINITY;
1378 0 : for (size_t r = start_row; r < end_row; ++r) {
1379 0 : max_val = std::max(max_val, qk_out[r * num_heads + c]);
1380 : }
1381 0 : max_vals[c] = std::max(max_val, sink[c]);
1382 : }
1383 :
1384 : // 2. Compute sum along with col (exp vectorized)
1385 0 : for (size_t c = 0; c < full_block; c += 8) {
1386 0 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1387 0 : __m256 sum = _mm256_loadu_ps(&sink[c]);
1388 : sum = _mm256_sub_ps(sum, maxv);
1389 0 : sum = exp256_ps(sum);
1390 0 : for (size_t r = start_row; r < end_row; ++r) {
1391 0 : __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
1392 : __m256 sub = _mm256_sub_ps(val, maxv);
1393 0 : __m256 e = exp256_ps(sub);
1394 : sum = _mm256_add_ps(sum, e);
1395 : }
1396 0 : _mm256_storeu_ps(&sum_vals[c], sum);
1397 : }
1398 :
1399 0 : for (size_t c = full_block; c < num_heads; ++c) {
1400 0 : float sum = std::exp(sink[c] - max_vals[c]);
1401 0 : for (size_t r = start_row; r < end_row; ++r) {
1402 0 : sum += std::exp(qk_out[r * num_heads + c] - max_vals[c]);
1403 : }
1404 0 : sum_vals[c] = sum;
1405 : }
1406 :
1407 : // 3. apply softmax
1408 0 : for (size_t r = start_row; r < end_row; ++r) {
1409 0 : for (size_t c = 0; c < full_block; c += 8) {
1410 0 : __m256 val = _mm256_loadu_ps(&qk_out[r * num_heads + c]);
1411 0 : __m256 maxv = _mm256_loadu_ps(&max_vals[c]);
1412 : __m256 sub = _mm256_sub_ps(val, maxv);
1413 0 : __m256 e = exp256_ps(sub);
1414 0 : __m256 sumv = _mm256_loadu_ps(&sum_vals[c]);
1415 : __m256 softmax = _mm256_div_ps(e, sumv);
1416 : _mm256_storeu_ps(&qk_out[r * num_heads + c], softmax);
1417 : }
1418 0 : for (size_t c = full_block; c < num_heads; ++c) {
1419 0 : qk_out[r * num_heads + c] =
1420 0 : std::exp(qk_out[r * num_heads + c] - max_vals[c]) / sum_vals[c];
1421 : }
1422 : }
1423 :
1424 0 : delete[] max_vals;
1425 0 : delete[] sum_vals;
1426 0 : }
1427 :
1428 : template <>
1429 1 : void softmax_row(float *qk_out, size_t start_row, size_t end_row,
1430 : size_t num_heads, float *sink) {
1431 1 : if (sink == nullptr) {
1432 1 : return softmax_row(qk_out, start_row, end_row, num_heads);
1433 : } else {
1434 0 : return softmax_row_with_sink(qk_out, start_row, end_row, num_heads, sink);
1435 : }
1436 : }
1437 : #ifdef _WIN32
1438 : #define COMPUTE_FP16_TO_FP32(x) \
1439 : _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
1440 : #define COMPUTE_FP32_TO_FP16(x) \
1441 : _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
1442 : #elif defined(__TIZEN__) && !defined(__F16C__)
1443 : #define COMPUTE_FP16_TO_FP32(x) nntrainer::compute_fp16_to_fp32(x)
1444 : #define COMPUTE_FP32_TO_FP16(x) nntrainer::compute_fp32_to_fp16(x)
1445 : #else
1446 : #define COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
1447 : #define COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
1448 : #endif
1449 :
1450 : static inline __m256 convert_vector_f16_to_f32(__m128i x) {
1451 : #if defined(__TIZEN__) && !defined(__F16C__)
1452 : alignas(32) uint16_t u16_array[8]; // 32-byte aligned storage
1453 : alignas(32) float f32_array[8]; // 32-byte aligned storage
1454 :
1455 : // Safely store __m128i to array (avoids aliasing)
1456 : _mm_storeu_si128(reinterpret_cast<__m128i *>(u16_array), x);
1457 :
1458 : // Convert each FP16 value to FP32
1459 : for (int i = 0; i < 8; i++) {
1460 : f32_array[i] = COMPUTE_FP16_TO_FP32(u16_array[i]);
1461 : }
1462 :
1463 : // Load aligned array into __m256
1464 : return _mm256_load_ps(f32_array);
1465 : #else
1466 : return _mm256_cvtph_ps(x);
1467 : #endif
1468 : }
1469 :
1470 : static inline __m128i convert_vector_f32_to_f16(__m256 x) {
1471 : #if defined(__TIZEN__) && !defined(__F16C__)
1472 : __m128i vec_f16;
1473 : float *f32_ptr = reinterpret_cast<float *>(&x);
1474 : uint16_t *u16_ptr = reinterpret_cast<uint16_t *>(&vec_f16);
1475 : for (int i = 0; i < 8; i++) {
1476 : u16_ptr[i] = COMPUTE_FP32_TO_FP16(f32_ptr[i]);
1477 : }
1478 : return vec_f16;
1479 : #else
1480 : return _mm256_cvtps_ph(x, 0);
1481 : #endif
1482 : }
1483 :
1484 : static inline __m128i convert_vector_f32_to_f16(__m128 x) {
1485 : #if defined(__TIZEN__) && !defined(__F16C__)
1486 : __m128i vec_f16;
1487 : float *f32_ptr = reinterpret_cast<float *>(&x);
1488 : uint16_t *u16_ptr = reinterpret_cast<uint16_t *>(&vec_f16);
1489 :
1490 : for (int i = 0; i < 4; i++) {
1491 : u16_ptr[i] = COMPUTE_FP32_TO_FP16(f32_ptr[i]);
1492 : }
1493 : return vec_f16;
1494 : #else
1495 : return _mm_cvtps_ph(x, 0);
1496 : #endif
1497 : }
1498 :
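 : // Convert chunk_size FP16 values from src to FP32 in dst, 8 at a time with
 : // a scalar tail for the remainder.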
1499 12 : static inline void load_fp16_8_to_chunk(const uint16_t *src, float *dst,
1500 : int chunk_size) {
1501 : int i = 0;
1502 24 : for (; i + 8 <= chunk_size; i += 8) {
1503 12 : __m128i half = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src + i));
1504 : __m256 f32 = convert_vector_f16_to_f32(half);
1505 12 : _mm256_storeu_ps(&dst[i], f32);
1506 : }
1507 32 : for (; i < chunk_size; ++i) {
1508 20 : dst[i] = nntrainer::compute_fp16_to_fp32(src[i]);
1509 : }
1510 12 : }
1511 :
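 : // Weighted sum over the FP16 value cache: for each cache head n and each of
 : // its gqa_size query heads h, output[(n, h), :] accumulates
 : // in[j, (n, h)] * vcache[j, n, :] over the last local_window_size cached
 : // rows (or all rows up to row_num when fewer are cached).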
1512 1 : void compute_fp16vcache_fp32_transposed(int row_num, const float *in,
1513 : const uint16_t *vcache, float *output,
1514 : int num_cache_head, int gqa_size,
1515 : int head_dim,
1516 : size_t local_window_size) {
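 : // The block below is disabled (commented-out) thread-affinity pinning code.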
1517 : // cpu_set_t cpu_set;
1518 : // CPU_ZERO(&cpu_set);
1519 : // std::vector<bool> affinity(8, false);
1520 : // affinity[i % affinity.size()] = true;
1521 :
1522 : // for (std::size_t j = 0;
1523 : // j < std::min<std::size_t>(affinity.size(), CPU_SETSIZE); ++j) {
1524 : // if (affinity[j])
1525 : // CPU_SET(j, &cpu_set);
1526 : // }
1527 : // pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpu_set);
1528 :
1529 1 : std::vector<float> tmp_fp32(head_dim);
1530 1 : int num_blocks = head_dim / 8;
1531 2 : __m256 *sumVec = new __m256[std::max(1, num_blocks * gqa_size)];
1532 :
1533 3 : for (int n = 0; n < num_cache_head; ++n) {
1534 2 : int rem = head_dim % 8;
1535 :
1536 : /* Declaring std::vector<__m256> sumVec(num_blocks * gqa_size,
1537 :  * _mm256_setzero_ps()) triggers the warning "ignoring attributes on
1538 :  * template argument '__m256'" [-Wignored-attributes], so the raw __m256
1539 :  * array allocated above is zeroed and reused here instead.
1540 :  */
1541 6 : for (int i = 0; i < num_blocks * gqa_size; i++) {
1542 4 : sumVec[i] = _mm256_setzero_ps();
1543 : }
1544 2 : std::vector<float> sumRem((size_t)gqa_size * rem, 0.0f);
1545 :
1546 6 : for (int j = row_num < local_window_size ? 0
1547 0 : : row_num + 1 - local_window_size;
1548 6 : j <= row_num; ++j) {
1549 4 : const uint16_t *vptr = vcache + (j * num_cache_head + n) * head_dim;
1550 4 : load_fp16_8_to_chunk(vptr, tmp_fp32.data(), head_dim);
1551 :
1552 12 : for (int h = 0; h < gqa_size; ++h) {
1553 8 : float a_val =
1554 : in[(row_num < local_window_size
1555 8 : ? j
1556 0 : : (unsigned long)(j - (row_num + 1 - local_window_size))) *
1557 8 : (unsigned long)(gqa_size * num_cache_head) +
1558 8 : (unsigned long)(n * gqa_size) + h];
1559 :
1560 : __m256 inVec = _mm256_set1_ps(a_val);
1561 :
1562 16 : for (int b = 0; b < num_blocks; ++b) {
1563 8 : __m256 bVec = _mm256_loadu_ps(&tmp_fp32[b * 8]);
1564 8 : sumVec[h * num_blocks + b] =
1565 8 : _mm256_fmadd_ps(inVec, bVec, sumVec[h * num_blocks + b]);
1566 : }
1567 :
1568 8 : float *remPtr = &sumRem.data()[h * rem];
1569 8 : int base = num_blocks * 8;
1570 16 : for (int r = 0; r < rem; ++r) {
1571 8 : remPtr[r] += a_val * tmp_fp32[base + r];
1572 : }
1573 : }
1574 : }
1575 :
1576 6 : for (int h = 0; h < gqa_size; ++h) {
1577 8 : for (int b = 0; b < num_blocks; ++b) {
1578 4 : int out_base = (n * gqa_size + h) * head_dim + b * 8;
1579 4 : _mm256_storeu_ps(&output[out_base], sumVec[h * num_blocks + b]);
1580 : }
1581 :
1582 4 : float *remPtr = &sumRem.data()[h * rem];
1583 : // float *remPtr = &sumRem[h * rem];
1584 4 : int base = num_blocks * 8;
1585 8 : for (int r = 0; r < rem; ++r) {
1586 4 : int out_idx = (n * gqa_size + h) * head_dim + base + r;
1587 4 : output[out_idx] = remPtr[r];
1588 : }
1589 : }
1590 2 : }
1591 1 : delete[] sumVec;
1592 1 : }
1593 :
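 : // Scaled dot products against the FP16 key cache, processed in tiles of
 : // tile_size rows over the last local_window_size rows:
 : // output[row - start_row, n, g] = dot(in[n, g, :], kcache[row, n, :]) / sqrt(head_dim).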
1594 : template <>
1595 1 : void compute_kcaches(const float *in, const uint16_t *kcache, float *output,
1596 : int num_rows, int num_cache_head, int head_dim,
1597 : int gqa_size, int tile_size, size_t local_window_size) {
1598 1 : std::vector<float> tmp_fp32(head_dim);
1599 :
1600 1 : int start_row =
1601 1 : num_rows < local_window_size ? 0 : num_rows - local_window_size;
1602 1 : int row_cnt = num_rows < local_window_size ? num_rows : local_window_size;
1603 1 : const int tile_count = (row_cnt + tile_size - 1) / tile_size;
1604 :
1605 3 : for (int n = 0; n < num_cache_head; ++n) {
1606 4 : for (int t = 0; t < tile_count; ++t) {
1607 2 : int row_tile_start = t * tile_size;
1608 2 : int tile_rows = std::min(tile_size, row_cnt - row_tile_start);
1609 :
1610 10 : for (int g = 0; g < gqa_size; ++g) {
1611 8 : const float *in_ptr = in + n * gqa_size * head_dim + g * head_dim;
1612 16 : for (int t_row = 0; t_row < tile_rows; ++t_row) {
1613 8 : int row = start_row + row_tile_start + t_row;
1614 8 : if (row + 1 < num_rows) {
1615 0 : const uint16_t *next_kptr =
1616 0 : kcache + ((row + 1) * num_cache_head + n) * head_dim;
1617 : _mm_prefetch(reinterpret_cast<const char *>(next_kptr),
1618 : _MM_HINT_T0);
1619 : }
1620 8 : const uint16_t *kptr = kcache + (row * num_cache_head + n) * head_dim;
1621 8 : load_fp16_8_to_chunk(kptr, tmp_fp32.data(), head_dim);
1622 :
1623 : const float *k_row = tmp_fp32.data();
1624 :
1625 : float sum = 0.0f;
1626 : int i = 0;
1627 : __m256 acc = _mm256_setzero_ps();
1628 16 : for (; i + 8 <= head_dim; i += 8) {
1629 8 : __m256 va = _mm256_loadu_ps(in_ptr + i);
1630 8 : __m256 vb = _mm256_loadu_ps(k_row + i);
1631 : acc = _mm256_fmadd_ps(va, vb, acc);
1632 : }
1633 :
1634 : __m128 low = _mm256_castps256_ps128(acc);
1635 : __m128 high = _mm256_extractf128_ps(acc, 1);
1636 : __m128 sum128 = _mm_add_ps(low, high);
1637 : sum128 = _mm_hadd_ps(sum128, sum128);
1638 : sum128 = _mm_hadd_ps(sum128, sum128);
1639 8 : sum += _mm_cvtss_f32(sum128);
1640 :
1641 24 : for (; i < head_dim; ++i)
1642 16 : sum += in_ptr[i] * k_row[i];
1643 :
1644 8 : output[(row - start_row) * num_cache_head * gqa_size + n * gqa_size +
1645 8 : g] = sum / sqrt((float)head_dim);
1646 : }
1647 : }
1648 : }
1649 : }
1650 1 : }
1651 :
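 : // Apply rotary position embedding to each dim-sized block of inout, rotating
 : // element k against element k + half_ with the cos_/sin_ tables. Results are
 : // written back to inout as FP32, or to output as FP16 when output is non-null;
 : // with only_convert_to_fp16 the values are converted to FP16 without rotation
 : // (output is then expected to be a valid FP16 buffer).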
1652 2 : void compute_rotary_emb_value(unsigned int width, unsigned int dim,
1653 : unsigned int half_, float *inout, void *output,
1654 : const float *cos_, const float *sin_,
1655 : bool only_convert_to_fp16) {
1656 : enum class OutputType { FP16, FP32 };
1657 :
1658 : OutputType out_type = OutputType::FP32;
1659 2 : if (output != nullptr)
1660 : out_type = OutputType::FP16;
1661 :
1662 6 : for (unsigned int w = 0; w < width; w += dim) {
1663 : unsigned int k = 0;
1664 8 : for (; k + 7 < half_; k += 8) {
1665 4 : unsigned int i0 = w + k;
1666 4 : unsigned int i1 = w + k + half_;
1667 :
1668 4 : __m256 a = _mm256_loadu_ps(&inout[i0]);
1669 4 : __m256 b = _mm256_loadu_ps(&inout[i1]);
1670 :
1671 4 : if (only_convert_to_fp16) {
1672 0 : if (out_type == OutputType::FP16) {
1673 : __m128i a_fp16 = convert_vector_f32_to_f16(a);
1674 : __m128i b_fp16 = convert_vector_f32_to_f16(b);
1675 :
1676 0 : _mm_storeu_si128(
1677 0 : reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i0),
1678 : a_fp16);
1679 0 : _mm_storeu_si128(
1680 0 : reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i1),
1681 : b_fp16);
1682 : }
1683 :
1684 : } else {
1685 4 : __m256 cos_v = _mm256_loadu_ps(&cos_[k]);
1686 4 : __m256 sin_v = _mm256_loadu_ps(&sin_[k]);
1687 :
1688 : __m256 out0 =
1689 : _mm256_sub_ps(_mm256_mul_ps(a, cos_v), _mm256_mul_ps(b, sin_v));
1690 : __m256 out1 =
1691 : _mm256_add_ps(_mm256_mul_ps(a, sin_v), _mm256_mul_ps(b, cos_v));
1692 :
1693 4 : if (out_type == OutputType::FP16) {
1694 : __m128i out0_fp16 = convert_vector_f32_to_f16(out0);
1695 : __m128i out1_fp16 = convert_vector_f32_to_f16(out1);
1696 :
1697 2 : _mm_storeu_si128(
1698 2 : reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i0),
1699 : out0_fp16);
1700 2 : _mm_storeu_si128(
1701 2 : reinterpret_cast<__m128i *>(static_cast<uint16_t *>(output) + i1),
1702 : out1_fp16);
1703 :
1704 : } else if (out_type == OutputType::FP32) {
1705 : _mm256_storeu_ps(&inout[i0], out0);
1706 : _mm256_storeu_ps(&inout[i1], out1);
1707 : }
1708 : }
1709 : }
1710 :
1711 12 : for (; k < half_; ++k) {
1712 8 : unsigned int i0 = w + k;
1713 8 : unsigned int i1 = w + k + half_;
1714 : // assert(i1 < width && "Scalar i1 overflow!");
1715 8 : float a = inout[i0];
1716 8 : float b = inout[i1];
1717 :
1718 8 : if (only_convert_to_fp16) {
1719 0 : static_cast<uint16_t *>(output)[i0] = COMPUTE_FP32_TO_FP16(a);
1720 0 : static_cast<uint16_t *>(output)[i1] = COMPUTE_FP32_TO_FP16(b);
1721 : } else {
1722 8 : float c = cos_[k];
1723 8 : float s = sin_[k];
1724 :
1725 8 : float out0 = a * c - b * s;
1726 8 : float out1 = a * s + b * c;
1727 :
1728 8 : if (out_type == OutputType::FP16) {
1729 4 : static_cast<uint16_t *>(output)[i0] = COMPUTE_FP32_TO_FP16(out0);
1730 8 : static_cast<uint16_t *>(output)[i1] = COMPUTE_FP32_TO_FP16(out1);
1731 : } else if (out_type == OutputType::FP32) {
1732 4 : inout[i0] = out0;
1733 4 : inout[i1] = out1;
1734 : }
1735 : }
1736 : }
1737 : }
1738 2 : }
1739 :
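 : // Horizontal sum of the 8 floats in a __m256.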
1740 : static float hsum_avx(__m256 v) {
1741 : __m128 vlow = _mm256_castps256_ps128(v);
1742 : __m128 vhigh = _mm256_extractf128_ps(v, 1);
1743 : vlow = _mm_add_ps(vlow, vhigh);
1744 : __m128 shuf = _mm_movehdup_ps(vlow);
1745 : __m128 sums = _mm_add_ps(vlow, shuf);
1746 : shuf = _mm_movehl_ps(shuf, sums);
1747 : sums = _mm_add_ss(sums, shuf);
1748 : return _mm_cvtss_f32(sums);
1749 : }
1750 :
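 : // RMS-normalize each of the H rows of length W:
 : // Y[h, :] = X[h, :] / sqrt(mean(X[h, :]^2) + epsilon).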
1751 0 : void rms_norm_wrt_width_fp32_intrinsic(const float *__restrict X,
1752 : float *__restrict Y, size_t H, size_t W,
1753 : float epsilon) {
1754 0 : for (std::size_t h = 0; h < H; ++h) {
1755 0 : const float *rowX = X + h * W;
1756 0 : float *rowY = Y + h * W;
1757 :
1758 : std::size_t i = 0;
1759 : __m256 acc0 = _mm256_setzero_ps();
1760 : __m256 acc1 = _mm256_setzero_ps();
1761 : __m256 acc2 = _mm256_setzero_ps();
1762 : __m256 acc3 = _mm256_setzero_ps();
1763 :
1764 0 : for (; i + 32 <= W; i += 32) {
1765 0 : __m256 x0 = _mm256_loadu_ps(rowX + i);
1766 0 : __m256 x1 = _mm256_loadu_ps(rowX + i + 8);
1767 0 : __m256 x2 = _mm256_loadu_ps(rowX + i + 16);
1768 0 : __m256 x3 = _mm256_loadu_ps(rowX + i + 24);
1769 : acc0 = _mm256_fmadd_ps(x0, x0, acc0);
1770 : acc1 = _mm256_fmadd_ps(x1, x1, acc1);
1771 : acc2 = _mm256_fmadd_ps(x2, x2, acc2);
1772 : acc3 = _mm256_fmadd_ps(x3, x3, acc3);
1773 : }
1774 0 : for (; i + 8 <= W; i += 8) {
1775 0 : __m256 x = _mm256_loadu_ps(rowX + i);
1776 : acc0 = _mm256_fmadd_ps(x, x, acc0);
1777 : }
1778 : float sumsq =
1779 0 : hsum_avx(acc0) + hsum_avx(acc1) + hsum_avx(acc2) + hsum_avx(acc3);
1780 0 : for (; i < W; ++i) {
1781 0 : float v = rowX[i];
1782 0 : sumsq += v * v;
1783 : }
1784 :
1785 0 : float mean = sumsq / static_cast<float>(W);
1786 0 : float scale = 1.0f / std::sqrt(mean + epsilon);
1787 : __m256 vscale = _mm256_set1_ps(scale);
1788 :
1789 : i = 0;
1790 0 : for (; i + 32 <= W; i += 32) {
1791 0 : __m256 x0 = _mm256_loadu_ps(rowX + i);
1792 0 : __m256 x1 = _mm256_loadu_ps(rowX + i + 8);
1793 0 : __m256 x2 = _mm256_loadu_ps(rowX + i + 16);
1794 0 : __m256 x3 = _mm256_loadu_ps(rowX + i + 24);
1795 0 : _mm256_storeu_ps(rowY + i, _mm256_mul_ps(x0, vscale));
1796 0 : _mm256_storeu_ps(rowY + i + 8, _mm256_mul_ps(x1, vscale));
1797 0 : _mm256_storeu_ps(rowY + i + 16, _mm256_mul_ps(x2, vscale));
1798 0 : _mm256_storeu_ps(rowY + i + 24, _mm256_mul_ps(x3, vscale));
1799 : }
1800 0 : for (; i + 8 <= W; i += 8) {
1801 0 : __m256 x = _mm256_loadu_ps(rowX + i);
1802 0 : _mm256_storeu_ps(rowY + i, _mm256_mul_ps(x, vscale));
1803 : }
1804 0 : for (; i < W; ++i) {
1805 0 : rowY[i] = rowX[i] * scale;
1806 : }
1807 : }
1808 0 : }
1809 :
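 : // Element-wise clamp of input into [lower_bound, upper_bound].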
1810 : template <>
1811 21 : void clamp(const float *input, float *output, size_t length, float lower_bound,
1812 : float upper_bound) {
1813 : const size_t step = 8;
1814 : const __m256 vLo = _mm256_set1_ps(lower_bound);
1815 : const __m256 vHi = _mm256_set1_ps(upper_bound);
1816 :
1817 : size_t i = 0;
1818 8085 : for (; i + step <= length; i += step) {
1819 8064 : __m256 v = _mm256_loadu_ps(input + i);
 : // Keep v as the second operand: _mm256_max_ps/_mm256_min_ps return their
 : // second source when an operand is NaN, so NaN passes through unchanged,
 : // matching the scalar tail below.
1820 : v = _mm256_max_ps(vLo, v);
1821 : v = _mm256_min_ps(vHi, v);
1822 8064 : _mm256_storeu_ps(output + i, v);
1823 : }
1824 21 : if (i < length) {
1825 0 : for (size_t k = i; k < length; ++k) {
1826 0 : float v = input[k];
1827 : // If v is NaN, the comparisons below will yield false; we keep NaN.
1828 : // This matches most framework "pass-through NaN" behavior.
1829 0 : output[k] =
1830 0 : (v < lower_bound) ? lower_bound : ((v > upper_bound) ? upper_bound : v);
1831 : }
1832 : }
1833 21 : }
1834 :
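 : // Convert N FP16 values to FP32 (16- and 8-wide vector steps, scalar tail).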
1835 1 : void copy_f16_f32(unsigned int N, const uint16_t *input, float *output) {
1836 : unsigned int idx = 0;
1837 : const uint16_t *data = (const uint16_t *)input;
1838 :
1839 : // 16 half-precision floating point values to single-precision values
1840 6 : for (; N - idx >= 16; idx += 16) {
1841 : const __m256 vec0 =
1842 : convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)data));
1843 : const __m256 vec1 =
1844 : convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)(data + 8)));
1845 5 : data += 16;
1846 :
1847 : _mm256_storeu_ps(output, vec0);
1848 : _mm256_storeu_ps(output + 8, vec1);
1849 5 : output += 16;
1850 : }
1851 : // 8 half-precision floating point values to single-precision values
1852 2 : for (; N - idx >= 8; idx += 8) {
1853 : const __m256 vec =
1854 : convert_vector_f16_to_f32(_mm_loadu_si128((const __m128i *)data));
1855 1 : data += 8;
1856 :
1857 : _mm256_storeu_ps(output, vec);
1858 1 : output += 8;
1859 : }
1860 : // remaining half-precision floating point values to single-precision values
1861 3 : while (idx < N) {
1862 2 : *output = compute_fp16_to_fp32(*data);
1863 2 : ++output;
1864 2 : ++data;
1865 2 : ++idx;
1866 : }
1867 1 : }
1868 :
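 : // Convert N FP32 values to FP16 (16-, 8-, and 4-wide vector steps, scalar tail).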
1869 0 : void copy_f32_f16(unsigned int N, const float *input, uint16_t *output) {
1870 : unsigned int idx = 0;
1871 : uint16_t *out_data = (uint16_t *)output;
1872 :
1873 : // 16 single-precision floating point values to half-precision values
1874 0 : for (; N - idx >= 16; idx += 16) {
1875 : const __m256 vec0 = _mm256_loadu_ps(input);
1876 : const __m256 vec1 = _mm256_loadu_ps(input + 8);
1877 0 : input += 16;
1878 :
1879 : _mm_storeu_si128((__m128i *)out_data, convert_vector_f32_to_f16(vec0));
1880 : _mm_storeu_si128((__m128i *)(out_data + 8),
1881 : convert_vector_f32_to_f16(vec1));
1882 0 : out_data += 16;
1883 : }
1884 : // 8 single-precision floating point values to half-precision values
1885 0 : for (; N - idx >= 8; idx += 8) {
1886 : const __m256 vec = _mm256_loadu_ps(input);
1887 0 : input += 8;
1888 :
1889 : _mm_storeu_si128((__m128i *)out_data, convert_vector_f32_to_f16(vec));
1890 0 : out_data += 8;
1891 : }
1892 : // 4 single-precision floating point values to half-precision values
1893 0 : for (; N - idx >= 4; idx += 4) {
1894 : const __m128 vec = _mm_loadu_ps(input);
1895 0 : input += 4;
1896 :
1897 : _mm_storeu_si64((__m128i *)out_data, convert_vector_f32_to_f16(vec));
1898 0 : out_data += 4;
1899 : }
1900 : // remaining single-precision floating point values to half-precision values
1901 0 : while (idx < N) {
1902 0 : *out_data = compute_fp32_to_fp16(*input);
1903 0 : ++out_data;
1904 0 : ++input;
1905 0 : ++idx;
1906 : }
1907 0 : }
1908 :
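 : // Repack 16 bytes of 4-bit weights: output byte 2i pairs the low nibbles of
 : // input bytes i and i+8, output byte 2i+1 pairs their high nibbles. With
 : // elements packed two per byte (low nibble first), output byte j then holds
 : // element j in its low nibble and element j+16 in its high nibble, i.e. the
 : // Q4_0 nibble layout.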
1909 1752656 : void create_q4_0_weights(const uint8_t *int4_weight, uint8_t *q4_0_weight) {
1910 : // Load 16 bytes of input data
1911 : __m128i input = _mm_loadu_si128((const __m128i *)int4_weight);
1912 :
1913 : // Create masks for extracting low and high nibbles
1914 : const __m128i low_nibble_mask = _mm_set1_epi8(0x0F);
1915 : const __m128i high_nibble_mask = _mm_set1_epi8(static_cast<char>(0xF0));
1916 :
1917 : // Extract low nibbles from first 8 bytes
1918 : __m128i A = _mm_and_si128(input, low_nibble_mask);
1919 :
1920 : // Extract high nibbles from first 8 bytes and shift right
1921 : __m128i B = _mm_and_si128(input, high_nibble_mask);
1922 : B = _mm_srli_epi16(B, 4);
1923 :
1924 : // Extract low nibbles from second 8 bytes
1925 : __m128i input_shifted = _mm_bsrli_si128(input, 8);
1926 : __m128i C = _mm_and_si128(input_shifted, low_nibble_mask);
1927 :
1928 : // Extract high nibbles from second 8 bytes and shift right
1929 : __m128i D = _mm_and_si128(input_shifted, high_nibble_mask);
1930 : D = _mm_srli_epi16(D, 4);
1931 :
1932 : // Interleave low nibbles: v0 from first8, v2 from second8
1933 : // Interleave low nibbles: v0 from the first 8 bytes, v2 from the second 8 bytes
1934 :
1935 : // Interleave high nibbles: v1 from first8, v3 from second8
1936 : // Interleave high nibbles: v1 from the first 8 bytes, v3 from the second 8 bytes
1937 :
1938 : // Pack the results: interleave low and high bytes
1939 : __m128i result = _mm_unpacklo_epi8(AC, BD);
1940 :
1941 : // Store the 16 bytes result
1942 : _mm_storeu_si128((__m128i *)q4_0_weight, result);
1943 1752656 : }
1944 :
1945 : } // namespace nntrainer::avx2