LCOV - code coverage report
Current view: top level - nntrainer/tensor/cpu_backend/ggml_interface - ggml_interface_mixed.cpp (source / functions)
Test:      coverage_filtered.info
Test Date: 2025-12-14 20:38:17

               Coverage    Total    Hit
Lines:           42.0 %      362    152
Functions:       55.6 %       27     15

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2025 Sungsik Kong <ss.kong@samsung.com>
       4              :  *
       5              :  * @file   ggml_interface_mixed.cpp
       6              :  * @date   15 April 2025
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Michal Wlasiuk <testmailsmtp12345@gmail.com>
       9              :  * @author Sungsik Kong <ss.kong@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
       11              :  * @brief  Function interface to use the ggml lib from cpu_backend. This file
       12              :  *         is known to be optimized for GB devices on Windows
      13              :  */
      14              : 
      15              : #include <algorithm>
      16              : #include <bs_thread_pool_manager.hpp>
      17              : #include <cmath>
      18              : #include <ggml_interface.h>
      19              : #include <nntr_ggml_impl.h>
      20              : #include <nntr_ggml_impl_utils.h>
      21              : #include <string>
      22              : #include <thread>
      23              : #include <vector>
      24              : 
      25              : namespace nntrainer {
      26              : 
      27            0 : static inline void __ggml_q4_0_4x8_q8_0_GEMM_GEMV(
      28              :   const unsigned int M, const unsigned int N, const unsigned int K,
      29              :   const float *A, const unsigned int lda, const void *B, const unsigned int ldb,
      30              :   float *C, const unsigned int ldc) {
      31              :   int NB_COLS = 4;
      32            0 :   int blocks_per_row = (K + QK8_0 - 1) / QK8_0;
      33            0 :   int qa_size = sizeof(block_q8_0) * blocks_per_row;
      34            0 :   std::vector<char> QA = std::vector<char>(qa_size);
      35              : 
      36              :   auto qa_data = QA.data();
      37              : 
      38            0 :   nntr_quantize_row_q8_0(A, qa_data, K);
      39            0 :   int B_step = sizeof(block_q4_0) * (K / QK4_0);
      40              : 
      41            0 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
      42            0 :   int thread_num = bs_thread_pool.get_thread_count();
      43              :   BS::multi_future<void> loop_future =
      44            0 :     bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
      45            0 :       unsigned int M_step_start = (i * N) / thread_num;
      46            0 :       unsigned int M_step_end = ((i + 1) * N) / thread_num;
      47              : 
      48            0 :       M_step_start = (M_step_start % NB_COLS)
      49            0 :                        ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
      50              :                        : M_step_start;
      51            0 :       M_step_end = (M_step_end % NB_COLS)
      52            0 :                      ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
      53              :                      : M_step_end;
      54              : 
      55            0 :       nntr_gemv_q4_0_4x8_q8_0(K, (float *)(C + M_step_start), N,
      56            0 :                               (void *)((char *)B + M_step_start * B_step),
      57            0 :                               QA.data(), M, M_step_end - M_step_start);
      58            0 :     });
      59              :   loop_future.wait();
      60            0 : }
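
The per-thread [M_step_start, M_step_end) column ranges in the loop above are rounded up to the packing width (NB_COLS columns per repacked q4_0 tile), so every task starts and ends on a tile boundary. A minimal sketch of that rounding as a stand-alone helper (align_up is a hypothetical name, not part of this file):

  // Rounds v up to the next multiple of step; equivalent to the inline
  // ternaries used for M_step_start / M_step_end above.
  // e.g. align_up(5, 4) == 8, align_up(8, 4) == 8.
  static inline unsigned int align_up(unsigned int v, unsigned int step) {
    return (v % step) ? v + step - (v % step) : v;
  }
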
      61              : 
      62            0 : static inline void __ggml_q4_0_4x8_q8_0_GEMM_GEMM(
      63              :   const unsigned int M, const unsigned int N, const unsigned int K,
      64              :   const float *A, const unsigned int lda, const void *B, const unsigned int ldb,
      65              :   float *C, const unsigned int ldc) {
      66              :   int NB_COLS = 4;
      67            0 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
      68            0 :   unsigned int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
      69            0 :   unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
      70            0 :   const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
      71            0 :   unsigned int M4 = ((M - M % 4) / 4);
      72            0 :   int B_step = sizeof(block_q4_0) * (K / QK4_0);
      73              : 
      74            0 :   unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
      75            0 :   std::vector<char> QA = std::vector<char>(qa_size);
      76              : 
      77              :   // Quantize 4-divisible-M row portion with matrix-wise function
      78            0 :   for (unsigned int i = 0; i < M4; i++) {
      79            0 :     nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
      80              :                                K);
      81              :   }
      82              :   // Quantize leftover 1 ~ 3 rows with row-wise function
      83            0 :   for (unsigned int i = M4 * 4; i < M; i++) {
      84            0 :     nntr_quantize_row_q8_0(
      85            0 :       (float *)A + i * K,
      86            0 :       (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
      87              :   }
      88              : 
      89              :   ///@todo Dynamic thread-number selection for GEMM problem size
      90            0 :   int thread_num = bs_thread_pool.get_thread_count();
      91              :   BS::multi_future<void> multi_future =
      92            0 :     bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
      93            0 :       unsigned int M_step_start = (i * N) / thread_num;
      94            0 :       unsigned int M_step_end = ((i + 1) * N) / thread_num;
      95              : 
      96            0 :       M_step_start = (M_step_start % NB_COLS)
      97            0 :                        ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
      98              :                        : M_step_start;
      99            0 :       M_step_end = (M_step_end % NB_COLS)
     100            0 :                      ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     101              :                      : M_step_end;
     102              : 
     103            0 :       nntr_gemm_q4_0_4x8_q8_0(K, (C + (M_step_start)), ldc,
     104            0 :                               ((char *)B + ((M_step_start)*B_step)), QA.data(),
     105            0 :                               M4 * 4, (M_step_end) - (M_step_start));
     106            0 :     });
     107              :   multi_future.wait();
     108              : 
     109            0 :   for (unsigned int pb = M4 * 4; pb < M; pb++) {
     110              :     BS::multi_future<void> loop_future =
     111            0 :       bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     112            0 :         unsigned int M_step_start = (i * N) / thread_num;
     113            0 :         unsigned int M_step_end = ((i + 1) * N) / thread_num;
     114              : 
     115            0 :         M_step_start = (M_step_start % NB_COLS)
     116            0 :                          ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
     117              :                          : M_step_start;
     118            0 :         M_step_end = (M_step_end % NB_COLS)
     119            0 :                        ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     120              :                        : M_step_end;
     121              : 
     122            0 :         nntr_gemv_q4_0_4x8_q8_0(
     123            0 :           K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     124            0 :           N, (void *)((char *)B + M_step_start * B_step),
     125            0 :           QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     126            0 :           M_step_end - M_step_start);
     127            0 :       });
     128              :     loop_future.wait();
     129              :   }
     130            0 : }
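
In the GEMM path above, the QA scratch buffer holds M4 groups of four interleaved rows (quantized by nntr_quantize_mat_q8_0_4x8) followed by the 1 to 3 leftover rows quantized individually; since ((M >> 2) << 2) / 4 equals M4, the allocation reserves M4 + 1 group-sized slots, with the spare slot covering the leftover rows. A small worked example of the sizing arithmetic (the concrete M and K values are illustrative only, and the usual QK8_0 == 32 is assumed):

  // Assume QK8_0 == 32, M == 6, K == 64:
  //   blocks_per_4_rows = (64 + 32 - 1) / 32                        = 2
  //   qa_4_rows_size    = sizeof(block_q8_0x4) * 2
  //   M4                = (6 - 6 % 4) / 4                           = 1   // rows 0..3 interleaved
  //   qa_size           = qa_4_rows_size * (((6 >> 2) << 2) / 4 + 1)      // = 2 group slots
  //   qa_row_size       = sizeof(block_q8_0) * 64 / 32                    // per leftover row (rows 4..5)
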
     131              : 
     132              : template <>
     133            0 : void __ggml_q4_0_4x8_q8_0_GEMM(const unsigned int M, const unsigned int N,
     134              :                                const unsigned int K, const float *A,
     135              :                                const unsigned int lda, const void *B,
     136              :                                const unsigned int ldb, float *C,
     137              :                                const unsigned int ldc) {
     138            0 :   if (M == 1) { // GEMV
     139            0 :     __ggml_q4_0_4x8_q8_0_GEMM_GEMV(M, N, K, A, lda, B, ldb, C, ldc);
     140              :   } else { // GEMM
     141            0 :     __ggml_q4_0_4x8_q8_0_GEMM_GEMM(M, N, K, A, lda, B, ldb, C, ldc);
     142              :   }
     143            0 : }
     144              : 
     145              : template <>
     146            0 : void __ggml_q4_0_4x8_q8_0_GEMM(const unsigned int M,
     147              :                                std::vector<unsigned int> Ns,
     148              :                                const unsigned int K, const float *A,
     149              :                                const unsigned int lda, std::vector<void *> Bs,
     150              :                                std::vector<unsigned int> ldbs,
     151              :                                std::vector<float *> Cs,
     152              :                                std::vector<unsigned int> ldcs) {
     153            0 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
     154            0 :   int thread_num = bs_thread_pool.get_thread_count();
     155              : 
     156              :   int NB_COLS = 4;
     157            0 :   int B_step = sizeof(block_q4_0) * (K / QK4_0);
     158            0 :   int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
     159              : 
     160            0 :   if (M == 1) {
     161            0 :     int qa_size = sizeof(block_q8_0) * blocks_per_4_rows;
     162            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     163              :     auto qa_data = QA.data();
     164            0 :     nntr_quantize_row_q8_0(A, qa_data, K);
     165            0 :     if (std::all_of(Ns.begin(), Ns.end(),
     166              :                     [](unsigned int n) { return n <= 256; })) {
     167            0 :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     168            0 :         unsigned int N = Ns[num_w];
     169            0 :         float *C = Cs[num_w];
     170            0 :         void *B = Bs[num_w];
     171              : 
     172              :         unsigned int M_step_start = 0;
     173              :         unsigned int M_step_end = N;
     174              :         M_step_start = (M_step_start % NB_COLS)
     175              :                          ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
     176              :                          : M_step_start;
     177            0 :         M_step_end = (M_step_end % NB_COLS)
     178            0 :                        ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     179              :                        : M_step_end;
     180              : 
     181            0 :         nntr_gemv_q4_0_4x8_q8_0(K, (float *)(C + M_step_start), N,
     182              :                                 (void *)((char *)B + M_step_start * B_step),
     183              :                                 QA.data(), M, M_step_end - M_step_start);
     184              :       }
     185              :     } else {
     186              :       BS::multi_future<void> loop_future =
     187            0 :         bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     188            0 :           for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     189            0 :             unsigned int N = Ns[num_w];
     190            0 :             float *C = Cs[num_w];
     191            0 :             void *B = Bs[num_w];
     192            0 :             unsigned int M_step_start = (i * N) / thread_num;
     193            0 :             unsigned int M_step_end = ((i + 1) * N) / thread_num;
     194              : 
     195            0 :             M_step_start = (M_step_start % NB_COLS)
     196            0 :                              ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
     197              :                              : M_step_start;
     198            0 :             M_step_end = (M_step_end % NB_COLS)
     199            0 :                            ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     200              :                            : M_step_end;
     201              : 
     202            0 :             nntr_gemv_q4_0_4x8_q8_0(K, (float *)(C + M_step_start), N,
     203            0 :                                     (void *)((char *)B + M_step_start * B_step),
     204            0 :                                     QA.data(), M, M_step_end - M_step_start);
     205              :           }
     206            0 :         });
     207              :       loop_future.wait();
     208              :     }
     209            0 :   } else {
     210            0 :     int n_threads = std::thread::hardware_concurrency() / 2;
     211            0 :     unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
     212            0 :     const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
     213              : 
     214            0 :     unsigned int M4 = ((M - M % 4) / 4);
     215            0 :     unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
     216              : 
     217            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     218              : 
     219            0 :     for (unsigned int i = 0; i < M4; i++) {
     220            0 :       nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     221              :                                  K);
     222              :     }
     223              : 
     224            0 :     for (unsigned int i = M4 * 4; i < M; i++) {
     225            0 :       nntr_quantize_row_q8_0(
     226            0 :         (float *)A + i * K,
     227            0 :         (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
     228              :     }
     229              : 
     230            0 : #pragma omp parallel for schedule(guided) num_threads(n_threads)
     231              :     for (int i = 0; i < n_threads; i++) {
     232              :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     233              :         unsigned int N = Ns[num_w];
     234              :         unsigned int ldc = ldcs[num_w];
     235              : 
     236              :         float *C = Cs[num_w];
     237              :         void *B = Bs[num_w];
     238              : 
     239              :         unsigned int src0_start = (i * N) / n_threads;
     240              :         unsigned int src0_end = ((i + 1) * N) / n_threads;
     241              : 
     242              :         src0_start = (src0_start % NB_COLS)
     243              :                        ? src0_start + NB_COLS - (src0_start % NB_COLS)
     244              :                        : src0_start;
     245              : 
     246              :         src0_end = (src0_end % NB_COLS)
     247              :                      ? src0_end + NB_COLS - (src0_end % NB_COLS)
     248              :                      : src0_end;
     249              : 
     250              :         nntr_gemm_q4_0_4x8_q8_0(K, (float *)(C + src0_start), ldc,
     251              :                                 (void *)((char *)B + src0_start * B_step),
     252              :                                 QA.data(), M4 * 4, src0_end - src0_start);
     253              :       }
     254              :     }
     255              : 
     256              :     n_threads = 4;
     257            0 : #pragma omp parallel for schedule(guided) num_threads(n_threads)
     258              :     for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     259              :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     260              :         unsigned int N = Ns[num_w];
     261              :         unsigned int ldc = ldcs[num_w];
     262              :         float *C = Cs[num_w];
     263              :         void *B = Bs[num_w];
     264              : 
     265              :         for (int pb = M4 * 4; pb < static_cast<int>(M); pb++) {
     266              :           unsigned int M_step_start = (thread_idx * N) / n_threads;
     267              :           unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;
     268              :           M_step_start = (M_step_start % NB_COLS)
     269              :                            ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
     270              :                            : M_step_start;
     271              :           M_step_end = (M_step_end % NB_COLS)
     272              :                          ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     273              :                          : M_step_end;
     274              : 
     275              :           nntr_gemv_q4_0_4x8_q8_0(
     276              :             K,
     277              :             (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     278              :             N, (void *)((char *)B + M_step_start * B_step),
     279              :             QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     280              :             M_step_end - M_step_start);
     281              :         }
     282              :       }
     283              :     }
     284            0 :   }
     285            0 : }
     286              : 
     287            3 : static inline void __ggml_q4_0_8x8_q8_0_GEMM_GEMV(
     288              :   const unsigned int M, const unsigned int N, const unsigned int K,
     289              :   const float *A, const unsigned int lda, const void *B, const unsigned int ldb,
     290              :   float *C, const unsigned int ldc) {
     291            3 :   int blocks_per_row = (K + QK8_0 - 1) / QK8_0;
     292            3 :   int qa_size = sizeof(block_q8_0) * blocks_per_row;
     293            3 :   std::vector<char> QA = std::vector<char>(qa_size);
     294              : 
     295              :   auto qa_data = QA.data();
     296              : 
     297            3 :   nntr_quantize_row_q8_0(A, qa_data, K);
     298            3 :   int B_step = sizeof(block_q4_0) * (K / QK4_0);
     299              : 
     300            3 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
     301            3 :   int thread_num = bs_thread_pool.get_thread_count();
     302              :   BS::multi_future<void> loop_future =
     303            9 :     bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     304           12 :       unsigned int M_step_start = (i * N) / thread_num;
     305           12 :       unsigned int M_step_end = ((i + 1) * N) / thread_num;
     306              : 
     307           12 :       M_step_start = (M_step_start % 8) ? M_step_start + 8 - (M_step_start % 8)
     308              :                                         : M_step_start;
     309              :       M_step_end =
     310           12 :         (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     311              : 
     312           12 :       nntr_gemv_q4_0_8x8_q8_0(K, (float *)(C + M_step_start), N,
     313           12 :                               (void *)((char *)B + M_step_start * B_step),
     314           12 :                               QA.data(), M, M_step_end - M_step_start);
     315            3 :     });
     316              :   loop_future.wait();
     317            3 : }
     318              : 
     319           52 : static inline void __ggml_q4_0_8x8_q8_0_GEMM_GEMM(
     320              :   const unsigned int M, const unsigned int N, const unsigned int K,
     321              :   const float *A, const unsigned int lda, const void *B, const unsigned int ldb,
     322              :   float *C, const unsigned int ldc) {
     323           52 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
     324           52 :   unsigned int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
     325           52 :   unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
     326           52 :   const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
     327           52 :   unsigned int M4 = ((M - M % 4) / 4);
     328           52 :   int B_step = sizeof(block_q4_0) * (K / QK4_0);
     329              : 
     330           52 :   unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
     331           52 :   std::vector<char> QA = std::vector<char>(qa_size);
     332              : 
     333              :   // Quantize 4-divisible-M row portion with matrix-wise function
     334          570 :   for (unsigned int i = 0; i < M4; i++) {
     335          518 :     nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     336              :                                K);
     337              :   }
     338              :   // Quantize leftover 1 ~ 3 rows with row-wise function
     339           58 :   for (unsigned int i = M4 * 4; i < M; i++) {
     340            6 :     nntr_quantize_row_q8_0(
     341            6 :       (float *)A + i * K,
     342            6 :       (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
     343              :   }
     344              : 
     345              :   ///@todo Dynamic thread-number selection for GEMM problem size
     346           52 :   int thread_num = bs_thread_pool.get_thread_count();
     347              :   BS::multi_future<void> multi_future =
     348          156 :     bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     349          208 :       unsigned int M_step_start = (i * N) / thread_num;
     350          208 :       unsigned int M_step_end = ((i + 1) * N) / thread_num;
     351              : 
     352          208 :       M_step_start = (M_step_start % 8) ? M_step_start + 8 - (M_step_start % 8)
     353              :                                         : M_step_start;
     354              :       M_step_end =
     355          208 :         (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     356              : 
     357          208 :       nntr_gemm_q4_0_8x8_q8_0(K, (C + (M_step_start)), ldc,
     358          208 :                               ((char *)B + ((M_step_start)*B_step)), QA.data(),
     359          208 :                               M4 * 4, (M_step_end) - (M_step_start));
     360           52 :     });
     361              :   multi_future.wait();
     362              : 
     363           58 :   for (unsigned int pb = M4 * 4; pb < M; pb++) {
     364              :     BS::multi_future<void> loop_future =
     365           18 :       bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     366           24 :         unsigned int M_step_start = (i * N) / thread_num;
     367           24 :         unsigned int M_step_end = ((i + 1) * N) / thread_num;
     368              : 
     369           24 :         M_step_start = (M_step_start % 8)
     370           24 :                          ? M_step_start + 8 - (M_step_start % 8)
     371              :                          : M_step_start;
     372              :         M_step_end =
     373           24 :           (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     374              : 
     375           24 :         nntr_gemv_q4_0_8x8_q8_0(
     376           24 :           K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     377           24 :           N, (void *)((char *)B + M_step_start * B_step),
     378           24 :           QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     379           24 :           M_step_end - M_step_start);
     380            6 :       });
     381              :     loop_future.wait();
     382              :   }
     383           52 : }
     384              : 
     385           55 : void __ggml_q4_0_8x8_q8_0_GEMM(const unsigned int M, const unsigned int N,
     386              :                                const unsigned int K, const float *A,
     387              :                                const unsigned int lda, const void *B,
     388              :                                const unsigned int ldb, float *C,
     389              :                                const unsigned int ldc) {
     390           55 :   if (M == 1) { // GEMV
     391            3 :     __ggml_q4_0_8x8_q8_0_GEMM_GEMV(M, N, K, A, lda, B, ldb, C, ldc);
     392              :   } else { // GEMM
     393           52 :     __ggml_q4_0_8x8_q8_0_GEMM_GEMM(M, N, K, A, lda, B, ldb, C, ldc);
     394              :   }
     395           55 : }
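
A minimal caller sketch for this entry point, assuming row-major fp32 activations and outputs (lda = K, ldc = N) and a weight buffer packed_W that was already repacked into the q4_0 8x8 block layout elsewhere; following how B is stepped per output column above, packed_W stores the N weight rows of length K:

  std::vector<float> A(M * K);   // fp32 activations, M x K
  std::vector<float> C(M * N);   // fp32 output, M x N
  // ... fill A and obtain packed_W (q4_0 8x8-repacked N x K weights) ...
  nntrainer::__ggml_q4_0_8x8_q8_0_GEMM(M, N, K, A.data(), /*lda=*/K, packed_W,
                                       /*ldb=*/K, C.data(), /*ldc=*/N);
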
     396              : 
     397              : template <>
     398            0 : void __ggml_q4_0_8x8_q8_0_GEMM(const unsigned int M,
     399              :                                std::vector<unsigned int> Ns,
     400              :                                const unsigned int K, const float *A,
     401              :                                const unsigned int lda, std::vector<void *> Bs,
     402              :                                std::vector<unsigned int> ldbs,
     403              :                                std::vector<float *> Cs,
     404              :                                std::vector<unsigned int> ldcs) {
     405            0 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
     406            0 :   int thread_num = bs_thread_pool.get_thread_count();
     407              : 
     408            0 :   int B_step = sizeof(block_q4_0) * (K / QK4_0);
     409            0 :   int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
     410              : 
     411            0 :   if (M == 1) {
     412            0 :     int qa_size = sizeof(block_q8_0) * blocks_per_4_rows;
     413            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     414              :     auto qa_data = QA.data();
     415            0 :     nntr_quantize_row_q8_0(A, qa_data, K);
     416              : 
     417            0 :     for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     418            0 :       unsigned int N = Ns[num_w];
     419            0 :       float *C = Cs[num_w];
     420            0 :       void *B = Bs[num_w];
     421              : 
     422            0 :       if (N <= 256) {
     423              :         unsigned int M_step_start = 0;
     424              :         unsigned int M_step_end = N;
     425              :         M_step_start = (M_step_start % 8)
     426              :                          ? M_step_start + 8 - (M_step_start % 8)
     427              :                          : M_step_start;
     428              :         M_step_end =
     429            0 :           (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     430              : 
     431            0 :         nntr_gemv_q4_0_8x8_q8_0(K, (float *)(C + M_step_start), N,
     432              :                                 (void *)((char *)B + M_step_start * B_step),
     433              :                                 QA.data(), M, M_step_end - M_step_start);
     434              :       }
     435              :     }
     436              : 
     437              :     BS::multi_future<void> loop_future =
     438            0 :       bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     439            0 :         for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     440            0 :           unsigned int N = Ns[num_w];
     441            0 :           float *C = Cs[num_w];
     442            0 :           void *B = Bs[num_w];
     443            0 :           unsigned int M_step_start = (i * N) / thread_num;
     444            0 :           unsigned int M_step_end = ((i + 1) * N) / thread_num;
     445              : 
     446            0 :           M_step_start = (M_step_start % 8)
     447            0 :                            ? M_step_start + 8 - (M_step_start % 8)
     448              :                            : M_step_start;
     449              :           M_step_end =
     450            0 :             (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     451              : 
     452            0 :           nntr_gemv_q4_0_8x8_q8_0(K, (float *)(C + M_step_start), N,
     453            0 :                                   (void *)((char *)B + M_step_start * B_step),
     454            0 :                                   QA.data(), M, M_step_end - M_step_start);
     455              :         }
     456            0 :       });
     457              :     loop_future.wait();
     458            0 :   } else {
     459            0 :     int n_threads = std::thread::hardware_concurrency() / 2;
     460            0 :     unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
     461            0 :     const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
     462              : 
     463            0 :     unsigned int M4 = ((M - M % 4) / 4);
     464            0 :     unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
     465              : 
     466            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     467              : 
     468            0 :     for (unsigned int i = 0; i < M4; i++) {
     469            0 :       nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     470              :                                  K);
     471              :     }
     472              : 
     473            0 :     for (unsigned int i = M4 * 4; i < M; i++) {
     474            0 :       nntr_quantize_row_q8_0(
     475            0 :         (float *)A + i * K,
     476            0 :         (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
     477              :     }
     478              : 
     479            0 : #pragma omp parallel for schedule(guided) num_threads(n_threads)
     480              :     for (int i = 0; i < n_threads; i++) {
     481              :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     482              :         unsigned int N = Ns[num_w];
     483              :         unsigned int ldc = ldcs[num_w];
     484              : 
     485              :         float *C = Cs[num_w];
     486              :         void *B = Bs[num_w];
     487              : 
     488              :         unsigned int src0_start = (i * N) / n_threads;
     489              :         unsigned int src0_end = ((i + 1) * N) / n_threads;
     490              : 
     491              :         src0_start =
     492              :           (src0_start % 8) ? src0_start + 8 - (src0_start % 8) : src0_start;
     493              : 
     494              :         src0_end = (src0_end % 8) ? src0_end + 8 - (src0_end % 8) : src0_end;
     495              : 
     496              :         nntr_gemm_q4_0_8x8_q8_0(K, (float *)(C + src0_start), ldc,
     497              :                                 (void *)((char *)B + src0_start * B_step),
     498              :                                 QA.data(), M4 * 4, src0_end - src0_start);
     499              :       }
     500              :     }
     501              : 
     502              :     n_threads = 4;
     503            0 : #pragma omp parallel for schedule(guided) num_threads(n_threads)
     504              :     for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     505              :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     506              :         unsigned int N = Ns[num_w];
     507              :         unsigned int ldc = ldcs[num_w];
     508              :         float *C = Cs[num_w];
     509              :         void *B = Bs[num_w];
     510              : 
     511              :         for (int pb = M4 * 4; pb < static_cast<int>(M); pb++) {
     512              :           unsigned int M_step_start = (thread_idx * N) / n_threads;
     513              :           unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;
     514              :           M_step_start = (M_step_start % 8)
     515              :                            ? M_step_start + 8 - (M_step_start % 8)
     516              :                            : M_step_start;
     517              :           M_step_end =
     518              :             (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     519              : 
     520              :           nntr_gemv_q4_0_8x8_q8_0(
     521              :             K,
     522              :             (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     523              :             N, (void *)((char *)B + M_step_start * B_step),
     524              :             QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     525              :             M_step_end - M_step_start);
     526              :         }
     527              :       }
     528              :     }
     529            0 :   }
     530            0 : }
     531              : 
     532            5 : static inline void __ggml_q4_K_8x8_q8_K_GEMM_GEMV(
     533              :   const unsigned int M, const unsigned int N, const unsigned int K,
     534              :   const float *A, const unsigned int lda, const void *B, const unsigned int ldb,
     535              :   float *C, const unsigned int ldc) {
     536            5 :   int B_step = sizeof(block_q4_K) * (K / QK_K);
     537            5 :   int blocks_per_row = (K + QK_K - 1) / QK_K;
     538            5 :   int qa_size = sizeof(block_q8_K) * blocks_per_row;
     539            5 :   std::vector<char> QA = std::vector<char>(qa_size);
     540              :   auto qa_data = QA.data();
     541            5 :   nntr_quantize_row_q8_K(A, qa_data, K);
     542              : 
     543            5 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
     544            5 :   int thread_num = bs_thread_pool.get_thread_count();
     545              :   BS::multi_future<void> loop_future =
     546           15 :     bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     547           20 :       unsigned int M_step_start = (i * N) / thread_num;
     548           20 :       unsigned int M_step_end = ((i + 1) * N) / thread_num;
     549              : 
     550           20 :       M_step_start = (M_step_start % 8) ? M_step_start + 8 - (M_step_start % 8)
     551              :                                         : M_step_start;
     552              :       M_step_end =
     553           20 :         (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     554              : 
     555           20 :       nntr_gemv_q4_K_8x8_q8_K(K, (float *)(C + M_step_start), N,
     556           20 :                               (void *)((char *)B + M_step_start * B_step),
     557           20 :                               QA.data(), M, M_step_end - M_step_start);
     558            5 :     });
     559              :   loop_future.wait();
     560            5 : }
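
The q4_K / q8_K variants follow the same GEMV / GEMM split as the q4_0 / q8_0 functions earlier in this file, but operate on K-quant super-blocks; assuming ggml's usual block sizes (QK_K == 256, QK4_0 == QK8_0 == 32), a worked example of the per-row block counts:

  // For K == 1024:
  //   q8_K activation blocks per row: (1024 + 256 - 1) / 256 = 4
  //   q8_0 activation blocks per row: (1024 +  32 - 1) /  32 = 32
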
     561              : 
     562            4 : static inline void __ggml_q4_K_8x8_q8_K_GEMM_GEMM(
     563              :   const unsigned int M, const unsigned int N, const unsigned int K,
     564              :   const float *A, const unsigned int lda, const void *B, const unsigned int ldb,
     565              :   float *C, const unsigned int ldc) {
     566            4 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
     567            4 :   unsigned int blocks_per_4_rows = (K + QK_K - 1) / QK_K;
     568            4 :   unsigned int qa_4_rows_size = sizeof(block_q8_Kx4) * blocks_per_4_rows;
     569            4 :   const size_t qa_row_size = (sizeof(block_q8_K) * K) / QK_K;
     570            4 :   unsigned int M4 = ((M - M % 4) / 4);
     571            4 :   int B_step = sizeof(block_q4_K) * (K / QK_K);
     572              : 
     573            4 :   unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
     574            4 :   std::vector<char> QA = std::vector<char>(qa_size);
     575              : 
     576              :   // Quantize 4-divisible-M row portion with matrix-wise function
     577          474 :   for (unsigned int i = 0; i < M4; i++) {
     578          470 :     nntr_quantize_mat_q8_K_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     579              :                                K);
     580              :   }
     581              :   // Quantize leftover 1 ~ 3 rows with row-wise function
     582           10 :   for (unsigned int i = M4 * 4; i < M; i++) {
     583            6 :     nntr_quantize_row_q8_K(
     584            6 :       (float *)A + i * K,
     585            6 :       (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
     586              :   }
     587              : 
     588              :   ///@todo Dynamic thread-number selection for GEMM problem size
     589            4 :   int thread_num = bs_thread_pool.get_thread_count();
     590              :   BS::multi_future<void> multi_future =
     591           12 :     bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     592           16 :       unsigned int M_step_start = (i * N) / thread_num;
     593           16 :       unsigned int M_step_end = ((i + 1) * N) / thread_num;
     594              : 
     595           16 :       M_step_start = (M_step_start % 8) ? M_step_start + 8 - (M_step_start % 8)
     596              :                                         : M_step_start;
     597              :       M_step_end =
     598           16 :         (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     599              : 
     600           16 :       nntr_gemm_q4_K_8x8_q8_K(K, (C + (M_step_start)), ldc,
     601           16 :                               ((char *)B + ((M_step_start)*B_step)), QA.data(),
     602           16 :                               M4 * 4, (M_step_end) - (M_step_start));
     603            4 :     });
     604              :   multi_future.wait();
     605              : 
     606           10 :   for (unsigned int pb = M4 * 4; pb < M; pb++) {
     607              :     BS::multi_future<void> loop_future =
     608           18 :       bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     609           24 :         unsigned int M_step_start = (i * N) / thread_num;
     610           24 :         unsigned int M_step_end = ((i + 1) * N) / thread_num;
     611              : 
     612           24 :         M_step_start = (M_step_start % 8)
     613           24 :                          ? M_step_start + 8 - (M_step_start % 8)
     614              :                          : M_step_start;
     615              :         M_step_end =
     616           24 :           (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     617              : 
     618           24 :         nntr_gemv_q4_K_8x8_q8_K(
     619           24 :           K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     620           24 :           N, (void *)((char *)B + M_step_start * B_step),
     621           24 :           QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     622           24 :           M_step_end - M_step_start);
     623            6 :       });
     624              :     loop_future.wait();
     625              :   }
     626            4 : }
     627              : 
     628            9 : void __ggml_q4_K_8x8_q8_K_GEMM(const unsigned int M, const unsigned int N,
     629              :                                const unsigned int K, const float *A,
     630              :                                const unsigned int lda, const void *B,
     631              :                                const unsigned int ldb, float *C,
     632              :                                const unsigned int ldc) {
     633            9 :   if (M == 1) { // GEMV
     634            5 :     __ggml_q4_K_8x8_q8_K_GEMM_GEMV(M, N, K, A, lda, B, ldb, C, ldc);
     635              :   } else { // GEMM
     636            4 :     __ggml_q4_K_8x8_q8_K_GEMM_GEMM(M, N, K, A, lda, B, ldb, C, ldc);
     637              :   }
     638            9 : }
     639              : 
     640            0 : void __ggml_q4_K_8x8_q8_K_GEMM(const unsigned int M,
     641              :                                std::vector<unsigned int> Ns,
     642              :                                const unsigned int K, const float *A,
     643              :                                const unsigned int lda, std::vector<void *> Bs,
     644              :                                std::vector<unsigned int> ldbs,
     645              :                                std::vector<float *> Cs,
     646              :                                std::vector<unsigned int> ldcs) {
     647              : 
     648            0 :   auto &bs_thread_pool = ThreadPoolManager::Global().getThreadPool();
     649            0 :   int thread_num = bs_thread_pool.get_thread_count();
     650              : 
     651            0 :   int B_step = sizeof(block_q4_K) * (K / QK_K);
     652            0 :   int blocks_per_4_rows = (K + QK_K - 1) / QK_K;
     653              : 
     654            0 :   if (M == 1) {
     655            0 :     int qa_size = sizeof(block_q8_K) * blocks_per_4_rows;
     656            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     657              :     auto qa_data = QA.data();
     658            0 :     nntr_quantize_row_q8_K(A, qa_data, K);
     659            0 :     if (std::all_of(Ns.begin(), Ns.end(),
     660              :                     [](unsigned int n) { return n <= 256; })) {
     661            0 :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     662            0 :         unsigned int N = Ns[num_w];
     663            0 :         float *C = Cs[num_w];
     664            0 :         void *B = Bs[num_w];
     665              : 
     666              :         unsigned int M_step_start = 0;
     667              :         unsigned int M_step_end = N;
     668              :         M_step_start = (M_step_start % 8)
     669              :                          ? M_step_start + 8 - (M_step_start % 8)
     670              :                          : M_step_start;
     671              :         M_step_end =
     672            0 :           (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     673              : 
     674            0 :         nntr_gemv_q4_K_8x8_q8_K(K, (float *)(C + M_step_start), N,
     675              :                                 (void *)((char *)B + M_step_start * B_step),
     676              :                                 QA.data(), M, M_step_end - M_step_start);
     677              :       }
     678              :     } else {
     679              :       BS::multi_future<void> loop_future =
     680            0 :         bs_thread_pool.submit_loop(0, thread_num, [=](int i) {
     681            0 :           for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     682            0 :             unsigned int N = Ns[num_w];
     683            0 :             float *C = Cs[num_w];
     684            0 :             void *B = Bs[num_w];
     685            0 :             unsigned int M_step_start = (i * N) / thread_num;
     686            0 :             unsigned int M_step_end = ((i + 1) * N) / thread_num;
     687              : 
     688            0 :             M_step_start = (M_step_start % 8)
     689            0 :                              ? M_step_start + 8 - (M_step_start % 8)
     690              :                              : M_step_start;
     691              :             M_step_end =
     692            0 :               (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     693              : 
     694            0 :             nntr_gemv_q4_K_8x8_q8_K(K, (float *)(C + M_step_start), N,
     695            0 :                                     (void *)((char *)B + M_step_start * B_step),
     696            0 :                                     QA.data(), M, M_step_end - M_step_start);
     697              :           }
     698            0 :         });
     699              :       loop_future.wait();
     700              :     }
     701            0 :   } else {
     702              : 
     703            0 :     int n_threads = std::thread::hardware_concurrency() / 2;
     704            0 :     unsigned int qa_4_rows_size = sizeof(block_q8_Kx4) * blocks_per_4_rows;
     705            0 :     const size_t qa_row_size = (sizeof(block_q8_K) * K) / QK_K;
     706              : 
     707            0 :     unsigned int M4 = ((M - M % 4) / 4);
     708            0 :     unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
     709              : 
     710            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     711              : 
     712            0 :     for (unsigned int i = 0; i < M4; i++) {
     713            0 :       nntr_quantize_mat_q8_K_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     714              :                                  K);
     715              :     }
     716              : 
     717            0 :     for (unsigned int i = M4 * 4; i < M; i++) {
     718            0 :       nntr_quantize_row_q8_K(
     719            0 :         (float *)A + i * K,
     720            0 :         (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
     721              :     }
     722              : 
     723            0 : #pragma omp parallel for schedule(guided) num_threads(n_threads)
     724              :     for (int i = 0; i < n_threads; i++) {
     725              :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     726              :         unsigned int N = Ns[num_w];
     727              :         unsigned int ldc = ldcs[num_w];
     728              : 
     729              :         float *C = Cs[num_w];
     730              :         void *B = Bs[num_w];
     731              : 
     732              :         unsigned int src0_start = (i * N) / n_threads;
     733              :         unsigned int src0_end = ((i + 1) * N) / n_threads;
     734              : 
     735              :         src0_start =
     736              :           (src0_start % 8) ? src0_start + 8 - (src0_start % 8) : src0_start;
     737              : 
     738              :         src0_end = (src0_end % 8) ? src0_end + 8 - (src0_end % 8) : src0_end;
     739              : 
     740              :         nntr_gemm_q4_K_8x8_q8_K(K, (float *)(C + src0_start), ldc,
     741              :                                 (void *)((char *)B + src0_start * B_step),
     742              :                                 QA.data(), M4 * 4, src0_end - src0_start);
     743              :       }
     744              :     }
     745              : 
     746              :     n_threads = 4;
     747            0 : #pragma omp parallel for schedule(guided) num_threads(n_threads)
     748              :     for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     749              :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     750              :         unsigned int N = Ns[num_w];
     751              :         unsigned int ldc = ldcs[num_w];
     752              :         float *C = Cs[num_w];
     753              :         void *B = Bs[num_w];
     754              : 
     755              :         for (int pb = M4 * 4; pb < static_cast<int>(M); pb++) {
     756              :           unsigned int M_step_start = (thread_idx * N) / n_threads;
     757              :           unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;
     758              :           M_step_start = (M_step_start % 8)
     759              :                            ? M_step_start + 8 - (M_step_start % 8)
     760              :                            : M_step_start;
     761              :           M_step_end =
     762              :             (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     763              : 
     764              :           nntr_gemv_q4_K_8x8_q8_K(
     765              :             K,
     766              :             (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     767              :             N, (void *)((char *)B + M_step_start * B_step),
     768              :             QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     769              :             M_step_end - M_step_start);
     770              :         }
     771              :       }
     772              :     }
     773            0 :   }
     774            0 : }
     775              : 
     776              : template <>
     777            7 : void __ggml_gemm_q6_K(const unsigned int M, const unsigned int N,
     778              :                       const unsigned int K, const float *A,
     779              :                       const unsigned int lda, const void *B,
     780              :                       const unsigned int ldb, float *C,
     781              :                       const unsigned int ldc) {
     782              :   static constexpr const int32_t bs = 1;
     783              :   static constexpr const int32_t bx = 1;
     784              :   static constexpr const int32_t by = 1;
     785              :   static constexpr const int32_t nrc = 1;
     786              : 
     787            7 :   const int32_t blocks_per_row = (K + QK_K - 1) / QK_K;
     788            7 :   const int32_t A_row_size = sizeof(block_q8_K) * blocks_per_row;
     789            7 :   const int32_t B_row_size = sizeof(block_q6_K) * blocks_per_row;
     790              : 
     791            7 :   auto &tp = ThreadPoolManager::Global().getThreadPool();
     792            7 :   if (M == 1) {
     793            3 :     std::vector<char> quantized_A(A_row_size);
     794            3 :     nntr_quantize_row_q8_K(A, quantized_A.data(), K);
     795            3 :     const void *quantized_A_data = quantized_A.data();
     796              : 
     797            3 :     auto fut = tp.submit_loop(0, static_cast<int>(N), [&](int i) {
     798         4608 :       const void *bptr = (const char *)B + i * B_row_size;
     799         4608 :       nntr_vec_dot_q6_K_q8_K(K, &C[i], bs, bptr, bx, quantized_A_data, by, nrc);
     800            3 :     });
     801              :     fut.wait();
     802            3 :   } else {
     803            4 :     const int32_t A_total_size = A_row_size * static_cast<int32_t>(M);
     804            4 :     std::vector<char> quantized_A(A_total_size);
     805              : 
     806         1890 :     for (int i = 0; i < static_cast<int>(M); ++i) {
     807         1886 :       void *row_ptr = quantized_A.data() + i * A_row_size;
     808         1886 :       nntr_quantize_row_q8_K(A + i * K, row_ptr, K);
     809              :     }
     810              : 
     811            4 :     auto fut = tp.submit_loop(0, static_cast<int>(M), [&](int i) {
     812         1886 :       const void *a_row = quantized_A.data() + i * A_row_size;
     813         1886 :       float *c_row = C + i * ldc;
     814       967518 :       for (unsigned int j = 0; j < N; ++j) {
     815       965632 :         const void *bptr = (const char *)B + j * B_row_size;
     816       965632 :         nntr_vec_dot_q6_K_q8_K(K, &c_row[j], bs, bptr, bx, a_row, by, nrc);
     817              :       }
     818            4 :     });
     819              :     fut.wait();
     820            4 :   }
     821            7 : }
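
For reference, a plain-fp32 sketch of what both branches of the q6_K path compute; B is consumed as N quantized rows of length K, so the result is effectively C = A * B^T. B_dequantized below is a hypothetical fp32 view of those rows, and each nntr_vec_dot_q6_K_q8_K call above corresponds to one (i, j) dot product over K:

  for (unsigned int i = 0; i < M; ++i) {
    for (unsigned int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (unsigned int k = 0; k < K; ++k)
        acc += A[i * K + k] * B_dequantized[j * K + k];
      C[i * ldc + j] = acc;   // up to quantization error
    }
  }
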
     822              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1