// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2025 Michal Wlasiuk <testmailsmtp12345@gmail.com>
 * Copyright (C) 2025 Sungsik Kong <ss.kong@samsung.com>
 *
 * @file ggml_interface_omp.cpp
 * @date 15 April 2025
 * @see https://github.com/nntrainer/nntrainer
 * @author Michal Wlasiuk <testmailsmtp12345@gmail.com>
 * @author Sungsik Kong <ss.kong@samsung.com>
 * @bug No known bugs except for NYI items
 * @brief Function interface to use the ggml library from cpu_backend -
 * accelerated only with OpenMP
 */

#include <ggml_interface.h>
#include <nntr_ggml_impl.h>
#include <nntr_ggml_impl_utils.h>

#include <algorithm>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

namespace nntrainer {

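/**
 * @brief GEMM/GEMV of a q4_0 weight matrix stored in the 4x8-interleaved
 * layout against float activations that are quantized to q8_0 at run time.
 * The N output columns are split into per-thread chunks that are rounded up
 * to a multiple of the interleaving width (NB_COLS) before the ggml-derived
 * micro-kernels are invoked.
 */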
template <>
void __ggml_q4_0_4x8_q8_0_GEMM(const unsigned int M, const unsigned int N,
                               const unsigned int K, const float *A,
                               const unsigned int lda, const void *B,
                               const unsigned int ldb, float *C,
                               const unsigned int ldc) {
  int NB_COLS = 4;
  if (M == 1) { // GEMV
    int n_threads = 4;
    unsigned int B_step = sizeof(block_q4_0) * (K / QK4_0);
    unsigned int blocks_per_row = (K + QK8_0 - 1) / QK8_0;
    unsigned int qa_size = sizeof(block_q8_0) * blocks_per_row;
    std::vector<char> QA = std::vector<char>(qa_size);
    nntr_quantize_row_q8_0(A, QA.data(), K);

#pragma omp parallel for num_threads(n_threads)
    for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
      unsigned int M_step_start = (thread_idx * N) / n_threads;     // = 0
      unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads; // ne01 = N

      M_step_start = (M_step_start % NB_COLS)
                       ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
                       : M_step_start;
      M_step_end = (M_step_end % NB_COLS)
                     ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
                     : M_step_end;

      nntr_gemv_q4_0_4x8_q8_0(K, (float *)((C) + M_step_start), N,
                              (void *)((char *)B + M_step_start * B_step),
                              QA.data(), M, M_step_end - M_step_start);
    }
  } else if (M % 4 != 0) {
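    // M is not a multiple of 4: the 4-aligned rows are handled by the packed
    // GEMM kernel below, the remaining 1 ~ 3 rows by the GEMV kernel.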
    int n_threads = 8;
    unsigned int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
    unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
    const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
    unsigned int M4 = ((M - M % 4) / 4);
    int B_step = sizeof(block_q4_0) * (K / QK4_0);

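    // Activation workspace: M4 packed 4-row q8_0 blocks, plus one extra
    // block-sized region that holds the (at most 3) leftover rows quantized
    // row by row; (((M >> 2) << 2) / 4) equals M4 here.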
    unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
    std::vector<char> QA = std::vector<char>(qa_size);

    // Quantize 4-divisible-M row portion with matrix-wise function
    for (unsigned int i = 0; i < M4; i++) {
      nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
                                 K);
    }
    // Quantize leftover 1 ~ 3 rows with row-wise function
    for (unsigned int i = M4 * 4; i < M; i++) {
      nntr_quantize_row_q8_0(
        (float *)A + i * K,
        (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
    }

    // Compute 4-divisible-M row portion with multithreaded GEMM
#pragma omp parallel for num_threads(n_threads)
    for (int i = 0; i < n_threads; i++) {
      unsigned int src0_start = (i * N) / n_threads;
      unsigned int src0_end = ((i + 1) * N) / n_threads;

      src0_start = (src0_start % NB_COLS)
                     ? src0_start + NB_COLS - (src0_start % NB_COLS)
                     : src0_start;
      src0_end = (src0_end % NB_COLS)
                   ? src0_end + NB_COLS - (src0_end % NB_COLS)
                   : src0_end;

      nntr_gemm_q4_0_4x8_q8_0(K, (float *)(C + src0_start), ldc,
                              (void *)((char *)B + src0_start * B_step),
                              QA.data(), M4 * 4, src0_end - src0_start);
    }

    // Compute leftover 1 ~ 3 rows with multithreaded GEMV
    n_threads = 4;
    for (unsigned int pb = M4 * 4; pb < M; pb++) {
#pragma omp parallel for num_threads(n_threads)
      for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
        unsigned int M_step_start = (thread_idx * N) / n_threads;     // = 0
        unsigned int M_step_end =
          ((thread_idx + 1) * N) / n_threads; // ne01 = N

        M_step_start = (M_step_start % NB_COLS)
                         ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
                         : M_step_start;
        M_step_end = (M_step_end % NB_COLS)
                       ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
                       : M_step_end;

        nntr_gemv_q4_0_4x8_q8_0(
          K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
          N, (void *)((char *)B + M_step_start * B_step),
          QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
          M_step_end - M_step_start);
      }
    }
  } else { // GEMM
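    // M is a multiple of 4: quantize all activation rows as packed 4-row
    // blocks and run a single multithreaded GEMM over the output columns.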
    unsigned int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
    unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
    unsigned int M4 = ((M + 3) / 4);

    unsigned int qa_size = qa_4_rows_size * M4;
    std::vector<char> QA = std::vector<char>(qa_size);

    // Quantization of activations
    /// @note Heuristic inspection showed that applying multithreading to
    /// run-time quantization hurts model latency
    // #pragma omp parallel for num_threads(16)
    for (int i = 0; i < static_cast<int>(M4); i++) {
      nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
                                 K);
    }
    int thread_num = std::thread::hardware_concurrency() / 2;
    unsigned int B_step = sizeof(block_q4_0) * (K / QK4_0);

#pragma omp parallel for num_threads(thread_num)
    for (int i = 0; i < thread_num; i++) {
      unsigned int src0_start = (i * N) / thread_num;
      unsigned int src0_end = ((i + 1) * N) / thread_num;

      src0_start = (src0_start % NB_COLS)
                     ? src0_start + NB_COLS - (src0_start % NB_COLS)
                     : src0_start;
      src0_end = (src0_end % NB_COLS)
                   ? src0_end + NB_COLS - (src0_end % NB_COLS)
                   : src0_end;

      nntr_gemm_q4_0_4x8_q8_0(K, (float *)(C + src0_start), ldc,
                              (void *)((char *)B + src0_start * B_step),
                              QA.data(), M, src0_end - src0_start);
    }
  }
}

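/**
 * @brief Multi-weight variant: multiplies one activation matrix against
 * several q4_0 (4x8-interleaved) weight matrices Bs that share the same K,
 * reusing a single run-time q8_0 quantization of A for every weight.
 */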
template <>
void __ggml_q4_0_4x8_q8_0_GEMM(const unsigned int M,
                               std::vector<unsigned int> Ns,
                               const unsigned int K, const float *A,
                               const unsigned int lda, std::vector<void *> Bs,
                               std::vector<unsigned int> ldbs,
                               std::vector<float *> Cs,
                               std::vector<unsigned int> ldcs) {
  int NB_COLS = 4;
  int B_step = sizeof(block_q4_0) * (K / QK4_0);
  int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;

  if (M == 1) {
    int qa_size = sizeof(block_q8_0) * blocks_per_4_rows;
    std::vector<char> QA = std::vector<char>(qa_size);
    auto qa_data = QA.data();
    nntr_quantize_row_q8_0(A, qa_data, K);
    if (std::all_of(Ns.begin(), Ns.end(),
                    [](unsigned int n) { return n <= 256; })) {
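      // Every weight is narrow (N <= 256): the aligned chunk covers the whole
      // output row, so each GEMV call is issued serially.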
      for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
        unsigned int N = Ns[num_w];
        float *C = Cs[num_w];
        void *B = Bs[num_w];

        unsigned int M_step_start = 0;
        unsigned int M_step_end = N;
        M_step_start = (M_step_start % NB_COLS)
                         ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
                         : M_step_start;
        M_step_end = (M_step_end % NB_COLS)
                       ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
                       : M_step_end;

        nntr_gemv_q4_0_4x8_q8_0(K, (float *)(C + M_step_start), N,
                                (void *)((char *)B + M_step_start * B_step),
                                QA.data(), M, M_step_end - M_step_start);
      }
    } else {
      int n_threads = 1;
      // std::cout << "Parallel gemv Ns.size(): " << Ns.size() << std::endl;
#pragma omp parallel for num_threads(n_threads)
      for (int i = 0; i < n_threads; ++i) {
        for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
          unsigned int N = Ns[num_w];
          float *C = Cs[num_w];
          void *B = Bs[num_w];
          unsigned int M_step_start = (i * N) / n_threads;
          unsigned int M_step_end = ((i + 1) * N) / n_threads;

          M_step_start = (M_step_start % NB_COLS)
                           ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
                           : M_step_start;
          M_step_end = (M_step_end % NB_COLS)
                         ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
                         : M_step_end;

          nntr_gemv_q4_0_4x8_q8_0(K, (float *)(C + M_step_start), N,
                                  (void *)((char *)B + M_step_start * B_step),
                                  QA.data(), M, M_step_end - M_step_start);
        }
      }
    }
  } else {
    int n_threads = 4;
    unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
    const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;

    unsigned int M4 = ((M - M % 4) / 4);
    unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);

    std::vector<char> QA = std::vector<char>(qa_size);

    for (unsigned int i = 0; i < M4; i++) {
      nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
                                 K);
    }

    for (unsigned int i = M4 * 4; i < M; i++) {
      nntr_quantize_row_q8_0(
        (float *)A + i * K,
        (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
    }

#pragma omp parallel for schedule(guided) num_threads(n_threads)
    for (int i = 0; i < n_threads; i++) {
      for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
        unsigned int N = Ns[num_w];
        unsigned int ldc = ldcs[num_w];

        float *C = Cs[num_w];
        void *B = Bs[num_w];

        unsigned int src0_start = (i * N) / n_threads;
        unsigned int src0_end = ((i + 1) * N) / n_threads;

        src0_start = (src0_start % NB_COLS)
                       ? src0_start + NB_COLS - (src0_start % NB_COLS)
                       : src0_start;

        src0_end = (src0_end % NB_COLS)
                     ? src0_end + NB_COLS - (src0_end % NB_COLS)
                     : src0_end;

        nntr_gemm_q4_0_4x8_q8_0(K, (float *)(C + src0_start), ldc,
                                (void *)((char *)B + src0_start * B_step),
                                QA.data(), M4 * 4, src0_end - src0_start);
      }
    }
    if (M4 * 4 != M) {
      n_threads = 4;
#pragma omp parallel for schedule(guided) num_threads(n_threads)
      for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
        for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
          unsigned int N = Ns[num_w];
          unsigned int ldc = ldcs[num_w];
          float *C = Cs[num_w];
          void *B = Bs[num_w];

          for (int pb = M4 * 4; pb < static_cast<int>(M); pb++) {
            unsigned int M_step_start = (thread_idx * N) / n_threads;
            unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;
            M_step_start = (M_step_start % NB_COLS)
                             ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
                             : M_step_start;
            M_step_end = (M_step_end % NB_COLS)
                           ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
                           : M_step_end;

            nntr_gemv_q4_0_4x8_q8_0(
              K,
              (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) +
                        M_step_start),
              N, (void *)((char *)B + M_step_start * B_step),
              QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size,
              1, M_step_end - M_step_start);
          }
        }
      }
    }
  }
}

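/**
 * @brief Same strategy as the 4x8 variant above, but for weights stored in
 * the q4_0 8x8-interleaved layout: per-thread column chunks are rounded up
 * to a multiple of 8 instead of NB_COLS.
 */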
void __ggml_q4_0_8x8_q8_0_GEMM(const unsigned int M, const unsigned int N,
                               const unsigned int K, const float *A,
                               const unsigned int lda, const void *B,
                               const unsigned int ldb, float *C,
                               const unsigned int ldc) {
  if (M == 1) { // GEMV
    int n_threads = 4;
    unsigned int B_step = sizeof(block_q4_0) * (K / QK4_0);
    unsigned int blocks_per_row = (K + QK8_0 - 1) / QK8_0;
    unsigned int qa_size = sizeof(block_q8_0) * blocks_per_row;
    std::vector<char> QA = std::vector<char>(qa_size);
    nntr_quantize_row_q8_0(A, QA.data(), K);

#pragma omp parallel for num_threads(n_threads)
    for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
      unsigned int M_step_start = (thread_idx * N) / n_threads;     // = 0
      unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads; // ne01 = N

      M_step_start = (M_step_start % 8) ? M_step_start + 8 - (M_step_start % 8)
                                        : M_step_start;
      M_step_end =
        (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;

      nntr_gemv_q4_0_8x8_q8_0(K, (float *)((C) + M_step_start), N,
                              (void *)((char *)B + M_step_start * B_step),
                              QA.data(), M, M_step_end - M_step_start);
    }
  } else { // GEMM
    int n_threads = std::thread::hardware_concurrency() / 2;
    unsigned int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
    unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
    const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
    unsigned int M4 = ((M - M % 4) / 4);
    int B_step = sizeof(block_q4_0) * (K / QK4_0);

    unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
    std::vector<char> QA = std::vector<char>(qa_size);

    // Quantize 4-divisible-M row portion with matrix-wise function
    for (unsigned int i = 0; i < M4; i++) {
      nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
                                 K);
    }
    // Quantize leftover 1 ~ 3 rows with row-wise function
    for (unsigned int i = M4 * 4; i < M; i++) {
      nntr_quantize_row_q8_0(
        (float *)A + i * K,
        (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
    }

    // Compute 4-divisible-M row portion with multithreaded GEMM
#pragma omp parallel for num_threads(n_threads)
    for (int i = 0; i < n_threads; i++) {
      unsigned int src0_start = (i * N) / n_threads;
      unsigned int src0_end = ((i + 1) * N) / n_threads;

      src0_start =
        (src0_start % 8) ? src0_start + 8 - (src0_start % 8) : src0_start;
      src0_end = (src0_end % 8) ? src0_end + 8 - (src0_end % 8) : src0_end;

      nntr_gemm_q4_0_8x8_q8_0(K, (float *)(C + src0_start), ldc,
                              (void *)((char *)B + src0_start * B_step),
                              QA.data(), M4 * 4, src0_end - src0_start);
    }

    // Compute leftover 1 ~ 3 rows with multithreaded GEMV
    n_threads = 4;
    for (unsigned int pb = M4 * 4; pb < M; pb++) {
#pragma omp parallel for num_threads(n_threads)
      for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
        unsigned int M_step_start = (thread_idx * N) / n_threads;
        unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;

        M_step_start = (M_step_start % 8)
                         ? M_step_start + 8 - (M_step_start % 8)
                         : M_step_start;
        M_step_end =
          (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;

        nntr_gemv_q4_0_8x8_q8_0(
          K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
          N, (void *)((char *)B + M_step_start * B_step),
          QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
          M_step_end - M_step_start);
      }
    }
  }
}

template <>
void __ggml_q4_0_8x8_q8_0_GEMM(const unsigned int M,
                               std::vector<unsigned int> Ns,
                               const unsigned int K, const float *A,
                               const unsigned int lda, std::vector<void *> Bs,
                               std::vector<unsigned int> ldbs,
                               std::vector<float *> C,
                               std::vector<unsigned int> ldcs) {
  throw std::runtime_error("nntrainer::__ggml_q4_0_8x8_q8_0_GEMM for "
                           "multi-weights is not implemented yet");
}

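/**
 * @brief GEMM/GEMV for q4_K weights in the 8x8-interleaved layout against
 * activations quantized to q8_K at run time; the threading and row-splitting
 * scheme mirrors the q4_0 variants above.
 */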
void __ggml_q4_K_8x8_q8_K_GEMM(const unsigned int M, const unsigned int N,
                               const unsigned int K, const float *A,
                               const unsigned int lda, const void *B,
                               const unsigned int ldb, float *C,
                               const unsigned int ldc) {
  if (M == 1) { // GEMV
    int n_threads = 4;
    unsigned int blocks_per_row = (K + QK_K - 1) / QK_K;
    unsigned int qa_size = sizeof(block_q8_K) * blocks_per_row;
    unsigned int B_step = sizeof(block_q4_K) * (K / QK_K);

    std::vector<char> QA = std::vector<char>(qa_size);

    nntr_quantize_row_q8_K(A, QA.data(), K);

#pragma omp parallel for num_threads(n_threads)
    for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
      unsigned int M_step_start = (thread_idx * N) / n_threads;     // = 0
      unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads; // ne01 = N

      M_step_start = (M_step_start % 8) ? M_step_start + 8 - (M_step_start % 8)
                                        : M_step_start;
      M_step_end =
        (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;

      nntr_gemv_q4_K_8x8_q8_K(K, (float *)((C) + M_step_start), N,
                              (void *)((char *)B + M_step_start * B_step),
                              QA.data(), M, M_step_end - M_step_start);
    }
  } else {
    int n_threads = std::thread::hardware_concurrency() / 2;
    unsigned int blocks_per_4_rows = (K + QK_K - 1) / QK_K;
    unsigned int qa_4_rows_size = sizeof(block_q8_Kx4) * blocks_per_4_rows;
    const size_t qa_row_size = (sizeof(block_q8_K) * K) / QK_K;
    unsigned int M4 = ((M - M % 4) / 4);
    int B_step = sizeof(block_q4_K) * (K / QK_K);

    unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
    std::vector<char> QA = std::vector<char>(qa_size);

    // Quantize 4-divisible-M row portion with matrix-wise function
    for (unsigned int i = 0; i < M4; i++) {
      nntr_quantize_mat_q8_K_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
                                 K);
    }
    // Quantize leftover 1 ~ 3 rows with row-wise function
    for (unsigned int i = M4 * 4; i < M; i++) {
      nntr_quantize_row_q8_K(
        (float *)A + i * K,
        (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
    }

    // Compute 4-divisible-M row portion with multithreaded GEMM
#pragma omp parallel for num_threads(n_threads)
    for (int i = 0; i < n_threads; i++) {
      unsigned int src0_start = (i * N) / n_threads;
      unsigned int src0_end = ((i + 1) * N) / n_threads;

      src0_start =
        (src0_start % 8) ? src0_start + 8 - (src0_start % 8) : src0_start;
      src0_end = (src0_end % 8) ? src0_end + 8 - (src0_end % 8) : src0_end;

      nntr_gemm_q4_K_8x8_q8_K(K, (float *)(C + src0_start), ldc,
                              (void *)((char *)B + src0_start * B_step),
                              QA.data(), M4 * 4, src0_end - src0_start);
    }

    // Compute leftover 1 ~ 3 rows with multithreaded GEMV
    n_threads = 4;
    for (unsigned int pb = M4 * 4; pb < M; pb++) {
#pragma omp parallel for num_threads(n_threads)
      for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
        unsigned int M_step_start = (thread_idx * N) / n_threads;
        unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;

        M_step_start = (M_step_start % 8)
                         ? M_step_start + 8 - (M_step_start % 8)
                         : M_step_start;
        M_step_end =
          (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;

        nntr_gemv_q4_K_8x8_q8_K(
          K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
          N, (void *)((char *)B + M_step_start * B_step),
          QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
          M_step_end - M_step_start);
      }
    }
  }
}

void __ggml_q4_K_8x8_q8_K_GEMM(const unsigned int M,
                               std::vector<unsigned int> Ns,
                               const unsigned int K, const float *A,
                               const unsigned int lda, std::vector<void *> Bs,
                               std::vector<unsigned int> ldbs,
                               std::vector<float *> C,
                               std::vector<unsigned int> ldcs) {
  throw std::runtime_error("nntrainer::__ggml_q4_K_8x8_q8_K_GEMM for "
                           "multi-weights is not implemented yet");
}

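/**
 * @brief q6_K weight x q8_K activation GEMM. Unlike the interleaved variants
 * above, this path quantizes the activation rows to q8_K and computes every
 * output element with the per-row nntr_vec_dot_q6_K_q8_K kernel, parallelized
 * over output rows (or over output columns when M == 1).
 */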
template <>
void __ggml_gemm_q6_K(const unsigned int M, const unsigned int N,
                      const unsigned int K, const float *A,
                      const unsigned int lda, const void *B,
                      const unsigned int ldb, float *C,
                      const unsigned int ldc) {
  int32_t thread_count = std::thread::hardware_concurrency() / 2;

  static constexpr const int32_t bs = 1;  // unused in ggml_vec_dot_q6_K_q8_K
  static constexpr const int32_t bx = 1;  // unused in ggml_vec_dot_q6_K_q8_K
  static constexpr const int32_t by = 1;  // unused in ggml_vec_dot_q6_K_q8_K
  static constexpr const int32_t nrc = 1; // unused in ggml_vec_dot_q6_K_q8_K

  const int32_t blocks_per_row = (K + QK_K - 1) / QK_K;
  const int32_t A_row_size = sizeof(block_q8_K) * blocks_per_row;
  const int32_t B_row_size = sizeof(block_q6_K) * blocks_per_row;

  // GEMV
  if (M == 1) {
    thread_count = 4;
    std::vector<char> quantized_A(A_row_size);
    nntr_quantize_row_q8_K(A, quantized_A.data(), K);

    const void *const quantized_A_data = quantized_A.data();

#pragma omp parallel for num_threads(thread_count)
    for (int32_t thread_job = 0; thread_job < static_cast<int>(N);
         thread_job++) {
      const int32_t B_row_data_offset = B_row_size * thread_job;

      const void *const B_data = (void *)((char *)B + B_row_data_offset);

      nntr_vec_dot_q6_K_q8_K(K, &C[thread_job], bs, B_data, bx,
                             quantized_A_data, by, nrc);
    }
  } else { // GEMM
    const int32_t A_total_size = A_row_size * M;
    std::vector<char> quantized_A(A_total_size);

#pragma omp parallel for num_threads(thread_count)
    for (int32_t thread_job = 0; thread_job < static_cast<int>(M);
         thread_job++) {
      const int32_t A_row_data_offset = A_row_size * thread_job;
      void *A_data = (void *)((char *)quantized_A.data() + A_row_data_offset);
      nntr_quantize_row_q8_K(A + thread_job * K, A_data, K);
    }
#pragma omp parallel for num_threads(thread_count)
    for (int32_t thread_job = 0; thread_job < static_cast<int>(M);
         thread_job++) {
      const int32_t A_row_data_offset = A_row_size * thread_job;
      void *A_data = (void *)((char *)quantized_A.data() + A_row_data_offset);

      for (uint32_t j = 0; j < N; j++) {
        const int32_t B_row_data_offset = B_row_size * j;
        const void *const B_data = (void *)((char *)B + B_row_data_offset);

        nntr_vec_dot_q6_K_q8_K(K, &C[thread_job * ldc + j], bs, B_data, bx,
                               A_data, by, nrc);
      }
    }
  }
}

} // namespace nntrainer