LCOV - code coverage report
Current view: top level - nntrainer/tensor/cpu_backend/ggml_interface - ggml_interface_omp.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 49.0 % 147 72
Test Date: 2026-01-12 20:43:37 Functions: 42.9 % 7 3

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2025 Michal Wlasiuk <testmailsmtp12345@gmail.com>
       4              :  * Copyright (C) 2025 Sungsik Kong <ss.kong@samsung.com>
       5              :  *
       6              :  * @file   ggml_interface_omp.cpp
       7              :  * @date   15 April 2025
       8              :  * @see    https://github.com/nntrainer/nntrainer
       9              :  * @author Michal Wlasiuk <testmailsmtp12345@gmail.com>
      10              :  * @author Sungsik Kong <ss.kong@samsung.com>
      11              :  * @bug    No known bugs except for NYI items
      12              :  * @brief  Function interface to use ggml lib from cpu_backend - accelerated
      13              :  * only with openMP
      14              :  */
      15              : 
      16              : #include <ggml_interface.h>
      17              : #include <nntr_ggml_impl.h>
      18              : #include <nntr_ggml_impl_utils.h>
      19              : 
      20              : #include <algorithm>
      21              : #include <stdexcept>
      22              : #include <string>
      23              : #include <thread>
      24              : #include <vector>
      25              : 
      26              : namespace nntrainer {
      27              : 
      28              : template <>
      29            0 : void __ggml_q4_0_4x8_q8_0_GEMM(const unsigned int M, const unsigned int N,
      30              :                                const unsigned int K, const float *A,
      31              :                                const unsigned int lda, const void *B,
      32              :                                const unsigned int ldb, float *C,
      33              :                                const unsigned int ldc) {
      34              :   int NB_COLS = 4;
      35            0 :   if (M == 1) { // GEMV
      36              :     int n_threads = 4;
      37            0 :     unsigned int B_step = sizeof(block_q4_0) * (K / QK4_0);
      38            0 :     unsigned int blocks_per_row = (K + QK8_0 - 1) / QK8_0;
      39            0 :     unsigned int qa_size = sizeof(block_q8_0) * blocks_per_row;
      40            0 :     std::vector<char> QA = std::vector<char>(qa_size);
      41            0 :     nntr_quantize_row_q8_0(A, QA.data(), K);
      42              : 
      43            0 : #pragma omp parallel for num_threads(n_threads)
      44              :     for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
      45              :       unsigned int M_step_start = (thread_idx * N) / n_threads;     // = 0
      46              :       unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads; // ne01 = N
      47              : 
      48              :       M_step_start = (M_step_start % NB_COLS)
      49              :                        ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
      50              :                        : M_step_start;
      51              :       M_step_end = (M_step_end % NB_COLS)
      52              :                      ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
      53              :                      : M_step_end;
      54              : 
      55              :       nntr_gemv_q4_0_4x8_q8_0(K, (float *)((C) + M_step_start), N,
      56              :                               (void *)((char *)B + M_step_start * B_step),
      57              :                               QA.data(), M, M_step_end - M_step_start);
      58              :     }
      59            0 :   } else if (M % 4 != 0) {
      60              :     int n_threads = 8;
      61            0 :     unsigned int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
      62            0 :     unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
      63            0 :     const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
      64            0 :     unsigned int M4 = ((M - M % 4) / 4);
      65            0 :     int B_step = sizeof(block_q4_0) * (K / QK4_0);
      66              : 
      67            0 :     unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
      68            0 :     std::vector<char> QA = std::vector<char>(qa_size);
      69              : 
      70              :     // Quantize 4-divisible-M row portion with matrix-wise function
      71            0 :     for (unsigned int i = 0; i < M4; i++) {
      72            0 :       nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
      73              :                                  K);
      74              :     }
      75              :     // Quantize leftover 1 ~ 3 rows with row-wise function
      76            0 :     for (unsigned int i = M4 * 4; i < M; i++) {
      77            0 :       nntr_quantize_row_q8_0(
      78            0 :         (float *)A + i * K,
      79            0 :         (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
      80              :     }
      81              : 
      82              : // Compute 4-divisible-M row portion with multithreaded GEMM
      83            0 : #pragma omp parallel for num_threads(n_threads)
      84              :     for (int i = 0; i < n_threads; i++) {
      85              :       unsigned int src0_start = (i * N) / n_threads;
      86              :       unsigned int src0_end = ((i + 1) * N) / n_threads;
      87              : 
      88              :       src0_start = (src0_start % NB_COLS)
      89              :                      ? src0_start + NB_COLS - (src0_start % NB_COLS)
      90              :                      : src0_start;
      91              :       src0_end = (src0_end % NB_COLS)
      92              :                    ? src0_end + NB_COLS - (src0_end % NB_COLS)
      93              :                    : src0_end;
      94              : 
      95              :       nntr_gemm_q4_0_4x8_q8_0(K, (float *)(C + src0_start), ldc,
      96              :                               (void *)((char *)B + src0_start * B_step),
      97              :                               QA.data(), M4 * 4, src0_end - src0_start);
      98              :     }
      99              : 
     100              :     // Compute leftover 1 ~ 3 rows with multithreaded GEMV
     101              :     n_threads = 4;
     102            0 :     for (unsigned int pb = M4 * 4; pb < M; pb++) {
     103            0 : #pragma omp parallel for num_threads(n_threads)
     104              :       for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     105              :         unsigned int M_step_start = (thread_idx * N) / n_threads; // = 0
     106              :         unsigned int M_step_end =
     107              :           ((thread_idx + 1) * N) / n_threads; // ne01 = N
     108              : 
     109              :         M_step_start = (M_step_start % NB_COLS)
     110              :                          ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
     111              :                          : M_step_start;
     112              :         M_step_end = (M_step_end % NB_COLS)
     113              :                        ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     114              :                        : M_step_end;
     115              : 
     116              :         nntr_gemv_q4_0_4x8_q8_0(
     117              :           K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     118              :           N, (void *)((char *)B + M_step_start * B_step),
     119              :           QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     120              :           M_step_end - M_step_start);
     121              :       }
     122              :     }
     123            0 :   } else { // GEMM
     124            0 :     unsigned int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
     125            0 :     unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
     126            0 :     unsigned int M4 = ((M + 3) / 4);
     127              : 
     128            0 :     unsigned int qa_size = qa_4_rows_size * M4;
     129            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     130              : 
     131              :     // Quantization of activations
     132              :     /// @note Heuristic inspection conducted that applying multithreading on
     133              :     /// run-time quantization hurts model latency
     134              :     // #pragma omp parallel for num_threads(16)
     135            0 :     for (int i = 0; i < static_cast<int>(M4); i++) {
     136            0 :       nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     137              :                                  K);
     138              :     }
     139            0 :     int thread_num = std::thread::hardware_concurrency() / 2;
     140            0 :     unsigned int B_step = sizeof(block_q4_0) * (K / QK4_0);
     141              : 
     142            0 : #pragma omp parallel for num_threads(thread_num)
     143              :     for (int i = 0; i < thread_num; i++) {
     144              :       unsigned int src0_start = (i * N) / thread_num;
     145              :       unsigned int src0_end = ((i + 1) * N) / thread_num;
     146              : 
     147              :       src0_start = (src0_start % NB_COLS)
     148              :                      ? src0_start + NB_COLS - (src0_start % NB_COLS)
     149              :                      : src0_start;
     150              :       src0_end = (src0_end % NB_COLS)
     151              :                    ? src0_end + NB_COLS - (src0_end % NB_COLS)
     152              :                    : src0_end;
     153              : 
     154              :       nntr_gemm_q4_0_4x8_q8_0(K, (float *)(C + src0_start), ldc,
     155              :                               (void *)((char *)B + src0_start * B_step),
     156              :                               QA.data(), M, src0_end - src0_start);
     157              :     }
     158            0 :   }
     159            0 : }
     160              : 
     161              : template <>
     162            0 : void __ggml_q4_0_4x8_q8_0_GEMM(const unsigned int M,
     163              :                                std::vector<unsigned int> Ns,
     164              :                                const unsigned int K, const float *A,
     165              :                                const unsigned int lda, std::vector<void *> Bs,
     166              :                                std::vector<unsigned int> ldbs,
     167              :                                std::vector<float *> Cs,
     168              :                                std::vector<unsigned int> ldcs) {
     169              :   int NB_COLS = 4;
     170            0 :   int B_step = sizeof(block_q4_0) * (K / QK4_0);
     171            0 :   int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
     172              : 
     173            0 :   if (M == 1) {
     174            0 :     int qa_size = sizeof(block_q8_0) * blocks_per_4_rows;
     175            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     176              :     auto qa_data = QA.data();
     177            0 :     nntr_quantize_row_q8_0(A, qa_data, K);
     178            0 :     if (std::all_of(Ns.begin(), Ns.end(),
     179              :                     [](unsigned int n) { return n <= 256; })) {
     180            0 :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     181            0 :         unsigned int N = Ns[num_w];
     182            0 :         float *C = Cs[num_w];
     183            0 :         void *B = Bs[num_w];
     184              : 
     185              :         unsigned int M_step_start = 0;
     186              :         unsigned int M_step_end = N;
     187              :         M_step_start = (M_step_start % NB_COLS)
     188              :                          ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
     189              :                          : M_step_start;
     190            0 :         M_step_end = (M_step_end % NB_COLS)
     191            0 :                        ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     192              :                        : M_step_end;
     193              : 
     194            0 :         nntr_gemv_q4_0_4x8_q8_0(K, (float *)(C + M_step_start), N,
     195              :                                 (void *)((char *)B + M_step_start * B_step),
     196              :                                 QA.data(), M, M_step_end - M_step_start);
     197              :       }
     198              :     } else {
     199              :       int n_threads = 1;
     200              :       // std::cout << "Parrallel gemv Ns.size(): " << Ns.size() << std::endl;
     201            0 : #pragma omp parallel for num_threads(n_threads)
     202              :       for (int i = 0; i < n_threads; ++i) {
     203              :         for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     204              :           unsigned int N = Ns[num_w];
     205              :           float *C = Cs[num_w];
     206              :           void *B = Bs[num_w];
     207              :           unsigned int M_step_start = (i * N) / n_threads;
     208              :           unsigned int M_step_end = ((i + 1) * N) / n_threads;
     209              : 
     210              :           M_step_start = (M_step_start % NB_COLS)
     211              :                            ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
     212              :                            : M_step_start;
     213              :           M_step_end = (M_step_end % NB_COLS)
     214              :                          ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     215              :                          : M_step_end;
     216              : 
     217              :           nntr_gemv_q4_0_4x8_q8_0(K, (float *)(C + M_step_start), N,
     218              :                                   (void *)((char *)B + M_step_start * B_step),
     219              :                                   QA.data(), M, M_step_end - M_step_start);
     220              :         }
     221              :       }
     222              :     }
     223            0 :   } else {
     224              :     int n_threads = 4;
     225            0 :     unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
     226            0 :     const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
     227              : 
     228            0 :     unsigned int M4 = ((M - M % 4) / 4);
     229            0 :     unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
     230              : 
     231            0 :     std::vector<char> QA = std::vector<char>(qa_size);
     232              : 
     233            0 :     for (unsigned int i = 0; i < M4; i++) {
     234            0 :       nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     235              :                                  K);
     236              :     }
     237              : 
     238            0 :     for (unsigned int i = M4 * 4; i < M; i++) {
     239            0 :       nntr_quantize_row_q8_0(
     240            0 :         (float *)A + i * K,
     241            0 :         (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
     242              :     }
     243              : 
     244            0 : #pragma omp parallel for schedule(guided) num_threads(n_threads)
     245              :     for (int i = 0; i < n_threads; i++) {
     246              :       for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     247              :         unsigned int N = Ns[num_w];
     248              :         unsigned int ldc = ldcs[num_w];
     249              : 
     250              :         float *C = Cs[num_w];
     251              :         void *B = Bs[num_w];
     252              : 
     253              :         unsigned int src0_start = (i * N) / n_threads;
     254              :         unsigned int src0_end = ((i + 1) * N) / n_threads;
     255              : 
     256              :         src0_start = (src0_start % NB_COLS)
     257              :                        ? src0_start + NB_COLS - (src0_start % NB_COLS)
     258              :                        : src0_start;
     259              : 
     260              :         src0_end = (src0_end % NB_COLS)
     261              :                      ? src0_end + NB_COLS - (src0_end % NB_COLS)
     262              :                      : src0_end;
     263              : 
     264              :         nntr_gemm_q4_0_4x8_q8_0(K, (float *)(C + src0_start), ldc,
     265              :                                 (void *)((char *)B + src0_start * B_step),
     266              :                                 QA.data(), M4 * 4, src0_end - src0_start);
     267              :       }
     268              :     }
     269            0 :     if (M4 * 4 != M) {
     270              :       n_threads = 4;
     271            0 : #pragma omp parallel for schedule(guided) num_threads(n_threads)
     272              :       for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     273              :         for (unsigned int num_w = 0; num_w < Ns.size(); ++num_w) {
     274              :           unsigned int N = Ns[num_w];
     275              :           unsigned int ldc = ldcs[num_w];
     276              :           float *C = Cs[num_w];
     277              :           void *B = Bs[num_w];
     278              : 
     279              :           for (int pb = M4 * 4; pb < static_cast<int>(M); pb++) {
     280              :             unsigned int M_step_start = (thread_idx * N) / n_threads;
     281              :             unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;
     282              :             M_step_start = (M_step_start % NB_COLS)
     283              :                              ? M_step_start + NB_COLS - (M_step_start % NB_COLS)
     284              :                              : M_step_start;
     285              :             M_step_end = (M_step_end % NB_COLS)
     286              :                            ? M_step_end + NB_COLS - (M_step_end % NB_COLS)
     287              :                            : M_step_end;
     288              : 
     289              :             nntr_gemv_q4_0_4x8_q8_0(
     290              :               K,
     291              :               (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) +
     292              :                         M_step_start),
     293              :               N, (void *)((char *)B + M_step_start * B_step),
     294              :               QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size,
     295              :               1, M_step_end - M_step_start);
     296              :           }
     297              :         }
     298              :       }
     299              :     }
     300            0 :   }
     301            0 : }
     302              : 
     303           55 : void __ggml_q4_0_8x8_q8_0_GEMM(const unsigned int M, const unsigned int N,
     304              :                                const unsigned int K, const float *A,
     305              :                                const unsigned int lda, const void *B,
     306              :                                const unsigned int ldb, float *C,
     307              :                                const unsigned int ldc) {
     308           55 :   if (M == 1) { // GEMV
     309              :     int n_threads = 4;
     310            3 :     unsigned int B_step = sizeof(block_q4_0) * (K / QK4_0);
     311            3 :     unsigned int blocks_per_row = (K + QK8_0 - 1) / QK8_0;
     312            3 :     unsigned int qa_size = sizeof(block_q8_0) * blocks_per_row;
     313            3 :     std::vector<char> QA = std::vector<char>(qa_size);
     314            3 :     nntr_quantize_row_q8_0(A, QA.data(), K);
     315              : 
     316            3 : #pragma omp parallel for num_threads(n_threads)
     317              :     for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     318              :       unsigned int M_step_start = (thread_idx * N) / n_threads;     // = 0
     319              :       unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads; // ne01 = N
     320              : 
     321              :       M_step_start = (M_step_start % 8) ? M_step_start + 8 - (M_step_start % 8)
     322              :                                         : M_step_start;
     323              :       M_step_end =
     324              :         (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     325              : 
     326              :       nntr_gemv_q4_0_8x8_q8_0(K, (float *)((C) + M_step_start), N,
     327              :                               (void *)((char *)B + M_step_start * B_step),
     328              :                               QA.data(), M, M_step_end - M_step_start);
     329              :     }
     330            3 :   } else { // GEMM
     331           52 :     int n_threads = std::thread::hardware_concurrency() / 2;
     332           52 :     unsigned int blocks_per_4_rows = (K + QK8_0 - 1) / QK8_0;
     333           52 :     unsigned int qa_4_rows_size = sizeof(block_q8_0x4) * blocks_per_4_rows;
     334           52 :     const size_t qa_row_size = (sizeof(block_q8_0) * K) / QK8_0;
     335           52 :     unsigned int M4 = ((M - M % 4) / 4);
     336           52 :     int B_step = sizeof(block_q4_0) * (K / QK4_0);
     337              : 
     338           52 :     unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
     339           52 :     std::vector<char> QA = std::vector<char>(qa_size);
     340              : 
     341              :     // Quantize 4-divisible-M row portion with matrix-wise function
     342          570 :     for (unsigned int i = 0; i < M4; i++) {
     343          518 :       nntr_quantize_mat_q8_0_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     344              :                                  K);
     345              :     }
     346              :     // Quantize leftover 1 ~ 3 rows with row-wise function
     347           58 :     for (unsigned int i = M4 * 4; i < M; i++) {
     348            6 :       nntr_quantize_row_q8_0(
     349            6 :         (float *)A + i * K,
     350            6 :         (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
     351              :     }
     352              : 
     353              : // Compute 4-divisible-M row portion with multithreaded GEMM
     354           52 : #pragma omp parallel for num_threads(n_threads)
     355              :     for (int i = 0; i < n_threads; i++) {
     356              :       unsigned int src0_start = (i * N) / n_threads;
     357              :       unsigned int src0_end = ((i + 1) * N) / n_threads;
     358              : 
     359              :       src0_start =
     360              :         (src0_start % 8) ? src0_start + 8 - (src0_start % 8) : src0_start;
     361              :       src0_end = (src0_end % 8) ? src0_end + 8 - (src0_end % 8) : src0_end;
     362              : 
     363              :       nntr_gemm_q4_0_8x8_q8_0(K, (float *)(C + src0_start), ldc,
     364              :                               (void *)((char *)B + src0_start * B_step),
     365              :                               QA.data(), M4 * 4, src0_end - src0_start);
     366              :     }
     367              : 
     368              :     // Compute leftover 1 ~ 3 rows with multithreaded GEMV
     369              :     n_threads = 4;
     370           58 :     for (unsigned int pb = M4 * 4; pb < M; pb++) {
     371            6 : #pragma omp parallel for num_threads(n_threads)
     372              :       for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     373              :         unsigned int M_step_start = (thread_idx * N) / n_threads;
     374              :         unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;
     375              : 
     376              :         M_step_start = (M_step_start % 8)
     377              :                          ? M_step_start + 8 - (M_step_start % 8)
     378              :                          : M_step_start;
     379              :         M_step_end =
     380              :           (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     381              : 
     382              :         nntr_gemv_q4_0_8x8_q8_0(
     383              :           K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     384              :           N, (void *)((char *)B + M_step_start * B_step),
     385              :           QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     386              :           M_step_end - M_step_start);
     387              :       }
     388              :     }
     389           52 :   }
     390           55 : }
     391              : 
     392              : template <>
     393            0 : void __ggml_q4_0_8x8_q8_0_GEMM(const unsigned int M,
     394              :                                std::vector<unsigned int> Ns,
     395              :                                const unsigned int K, const float *A,
     396              :                                const unsigned int lda, std::vector<void *> Bs,
     397              :                                std::vector<unsigned int> ldbs,
     398              :                                std::vector<float *> C,
     399              :                                std::vector<unsigned int> ldcs) {
     400              :   throw std::runtime_error("nntrainer::__ggml_q4_0_8x8_q8_0_GEMM for "
     401            0 :                            "multi-weights is not implemented yet");
     402              : }
     403              : 
     404            9 : void __ggml_q4_K_8x8_q8_K_GEMM(const unsigned int M, const unsigned int N,
     405              :                                const unsigned int K, const float *A,
     406              :                                const unsigned int lda, const void *B,
     407              :                                const unsigned int ldb, float *C,
     408              :                                const unsigned int ldc) {
     409            9 :   if (M == 1) { // GEMV
     410              :     int n_threads = 4;
     411            5 :     unsigned int blocks_per_row = (K + QK_K - 1) / QK_K;
     412            5 :     unsigned int qa_size = sizeof(block_q8_K) * blocks_per_row;
     413            5 :     unsigned int B_step = sizeof(block_q4_K) * (K / QK_K);
     414              : 
     415            5 :     std::vector<char> QA = std::vector<char>(qa_size);
     416              : 
     417            5 :     nntr_quantize_row_q8_K(A, QA.data(), K);
     418              : 
     419            5 : #pragma omp parallel for num_threads(n_threads)
     420              :     for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     421              :       unsigned int M_step_start = (thread_idx * N) / n_threads;     // = 0
     422              :       unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads; // ne01 = N
     423              : 
     424              :       M_step_start = (M_step_start % 8) ? M_step_start + 8 - (M_step_start % 8)
     425              :                                         : M_step_start;
     426              :       M_step_end =
     427              :         (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     428              : 
     429              :       nntr_gemv_q4_K_8x8_q8_K(K, (float *)((C) + M_step_start), N,
     430              :                               (void *)((char *)B + M_step_start * B_step),
     431              :                               QA.data(), M, M_step_end - M_step_start);
     432              :     }
     433            5 :   } else {
     434            4 :     int n_threads = std::thread::hardware_concurrency() / 2;
     435            4 :     unsigned int blocks_per_4_rows = (K + QK_K - 1) / QK_K;
     436            4 :     unsigned int qa_4_rows_size = sizeof(block_q8_Kx4) * blocks_per_4_rows;
     437            4 :     const size_t qa_row_size = (sizeof(block_q8_K) * K) / QK_K;
     438            4 :     unsigned int M4 = ((M - M % 4) / 4);
     439            4 :     int B_step = sizeof(block_q4_K) * (K / QK_K);
     440              : 
     441            4 :     unsigned int qa_size = qa_4_rows_size * (((M >> 2) << 2) / 4 + 1);
     442            4 :     std::vector<char> QA = std::vector<char>(qa_size);
     443              : 
     444              :     // Quantize 4-divisible-M row portion with matrix-wise function
     445          474 :     for (unsigned int i = 0; i < M4; i++) {
     446          470 :       nntr_quantize_mat_q8_K_4x8(A + 4 * i * K, QA.data() + i * qa_4_rows_size,
     447              :                                  K);
     448              :     }
     449              :     // Quantize leftover 1 ~ 3 rows with row-wise function
     450           10 :     for (unsigned int i = M4 * 4; i < M; i++) {
     451            6 :       nntr_quantize_row_q8_K(
     452            6 :         (float *)A + i * K,
     453            6 :         (QA.data() + (M4 * qa_4_rows_size) + (i - M4 * 4) * qa_row_size), K);
     454              :     }
     455              : 
     456              : // Compute 4-divisible-M row portion with multithreaded GEMM
     457            4 : #pragma omp parallel for num_threads(n_threads)
     458              :     for (int i = 0; i < n_threads; i++) {
     459              :       unsigned int src0_start = (i * N) / n_threads;
     460              :       unsigned int src0_end = ((i + 1) * N) / n_threads;
     461              : 
     462              :       src0_start =
     463              :         (src0_start % 8) ? src0_start + 8 - (src0_start % 8) : src0_start;
     464              :       src0_end = (src0_end % 8) ? src0_end + 8 - (src0_end % 8) : src0_end;
     465              : 
     466              :       nntr_gemm_q4_K_8x8_q8_K(K, (float *)(C + src0_start), ldc,
     467              :                               (void *)((char *)B + src0_start * B_step),
     468              :                               QA.data(), M4 * 4, src0_end - src0_start);
     469              :     }
     470              : 
     471              :     // Compute leftover 1 ~ 3 rows with multithreaded GEMV
     472              :     n_threads = 4;
     473           10 :     for (unsigned int pb = M4 * 4; pb < M; pb++) {
     474            6 : #pragma omp parallel for num_threads(n_threads)
     475              :       for (int thread_idx = 0; thread_idx < n_threads; ++thread_idx) {
     476              :         unsigned int M_step_start = (thread_idx * N) / n_threads;
     477              :         unsigned int M_step_end = ((thread_idx + 1) * N) / n_threads;
     478              : 
     479              :         M_step_start = (M_step_start % 8)
     480              :                          ? M_step_start + 8 - (M_step_start % 8)
     481              :                          : M_step_start;
     482              :         M_step_end =
     483              :           (M_step_end % 8) ? M_step_end + 8 - (M_step_end % 8) : M_step_end;
     484              : 
     485              :         nntr_gemv_q4_K_8x8_q8_K(
     486              :           K, (float *)((C + ((pb - M4 * 4) * N) + (M4 * 4 * N)) + M_step_start),
     487              :           N, (void *)((char *)B + M_step_start * B_step),
     488              :           QA.data() + (M4 * qa_4_rows_size) + (pb - M4 * 4) * qa_row_size, 1,
     489              :           M_step_end - M_step_start);
     490              :       }
     491              :     }
     492            4 :   }
     493            9 : }
     494              : 
     495            0 : void __ggml_q4_K_8x8_q8_K_GEMM(const unsigned int M,
     496              :                                std::vector<unsigned int> Ns,
     497              :                                const unsigned int K, const float *A,
     498              :                                const unsigned int lda, std::vector<void *> Bs,
     499              :                                std::vector<unsigned int> ldbs,
     500              :                                std::vector<float *> C,
     501              :                                std::vector<unsigned int> ldcs) {
     502              :   throw std::runtime_error("nntrainer::__ggml_q4_K_8x8_q8_K_GEMM for "
     503            0 :                            "multi-weights is not implemented yet");
     504              : }
     505              : 
     506              : template <>
     507            7 : void __ggml_gemm_q6_K(const unsigned int M, const unsigned int N,
     508              :                       const unsigned int K, const float *A,
     509              :                       const unsigned int lda, const void *B,
     510              :                       const unsigned int ldb, float *C,
     511              :                       const unsigned int ldc) {
     512            7 :   int32_t thread_count = std::thread::hardware_concurrency() / 2;
     513              : 
     514              :   static constexpr const int32_t bs = 1;  // unused in ggml_vec_dot_q6_K_q8_K
     515              :   static constexpr const int32_t bx = 1;  // unused in ggml_vec_dot_q6_K_q8_K
     516              :   static constexpr const int32_t by = 1;  // unused in ggml_vec_dot_q6_K_q8_K
     517              :   static constexpr const int32_t nrc = 1; // unused in ggml_vec_dot_q6_K_q8_K
     518              : 
     519            7 :   const int32_t blocks_per_row = (K + QK_K - 1) / QK_K;
     520            7 :   const int32_t A_row_size = sizeof(block_q8_K) * blocks_per_row;
     521            7 :   const int32_t B_row_size = sizeof(block_q6_K) * blocks_per_row;
     522              : 
     523              :   // GEMV
     524            7 :   if (M == 1) {
     525              :     thread_count = 4;
     526            3 :     std::vector<char> quantized_A(A_row_size);
     527            3 :     nntr_quantize_row_q8_K(A, quantized_A.data(), K);
     528              : 
     529              :     const void *const quantized_A_data = quantized_A.data();
     530              : 
     531            3 : #pragma omp parallel for num_threads(thread_count)
     532              :     for (int32_t thread_job = 0; thread_job < static_cast<int>(N);
     533              :          thread_job++) {
     534              :       const int32_t B_row_data_offset = B_row_size * thread_job;
     535              : 
     536              :       const void *const B_data = (void *)((char *)B + B_row_data_offset);
     537              : 
     538              :       nntr_vec_dot_q6_K_q8_K(K, &C[thread_job], bs, B_data, bx,
     539              :                              quantized_A_data, by, nrc);
     540              :     }
     541            3 :   } else { // GEMM
     542            4 :     const int32_t A_total_size = A_row_size * M;
     543            4 :     std::vector<char> quantized_A(A_total_size);
     544              : 
     545            4 : #pragma omp parallel for num_threads(thread_count)
     546              :     for (int32_t thread_job = 0; thread_job < static_cast<int>(M);
     547              :          thread_job++) {
     548              :       const int32_t A_row_data_offset = A_row_size * thread_job;
     549              :       void *A_data = (void *)((char *)quantized_A.data() + A_row_data_offset);
     550              :       nntr_quantize_row_q8_K(A + thread_job * K, A_data, K);
     551              :     }
     552            4 : #pragma omp parallel for num_threads(thread_count)
     553              :     for (int32_t thread_job = 0; thread_job < static_cast<int>(M);
     554              :          thread_job++) {
     555              :       const int32_t A_row_data_offset = A_row_size * thread_job;
     556              :       void *A_data = (void *)((char *)quantized_A.data() + A_row_data_offset);
     557              : 
     558              :       for (uint32_t j = 0; j < N; j++) {
     559              :         const int32_t B_row_data_offset = B_row_size * j;
     560              :         const void *const B_data = (void *)((char *)B + B_row_data_offset);
     561              : 
     562              :         nntr_vec_dot_q6_K_q8_K(K, &C[thread_job * ldc + j], bs, B_data, bx,
     563              :                                A_data, by, nrc);
     564              :       }
     565              :     }
     566            4 :   }
     567            7 : }
     568              : 
     569              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1