LCOV - code coverage report
Current view: top level - nntrainer/tensor - cache_loader.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 0.0 % 129 0
Test Date: 2025-12-14 20:38:17 Functions: 0.0 % 22 0

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
       4              :  *
       5              :  * @file   cache_loader.cpp
       6              :  * @date   10 Nov 2022
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Jiho Chu <jiho.chu@samsung.com>
       9              :  * @bug    No known bugs except for NYI items
      10              :  * @brief  Cache loader class
      11              :  *
      12              :  */
      13              : 
      14              : #include "cache_loader.h"
      15              : #include "task.h"
      16              : #include "task_executor.h"
      17              : 
      18              : #include <cache_pool.h>
      19              : #include <climits>
      20              : #include <cstdint>
      21              : #include <exception>
      22              : #include <memory>
      23              : #include <nntrainer_error.h>
      24              : #include <nntrainer_log.h>
      25              : 
      26              : namespace nntrainer {
      27              : 
      28            0 : CacheLoader::CacheLoader(std::shared_ptr<CachePool> cache_pool) :
      29              :   pool(cache_pool),
      30            0 :   load_task_executor(nullptr),
      31            0 :   unload_task_executor(nullptr) {}
      32              : 
      33            0 : CacheLoader::~CacheLoader() {
      34            0 :   if (load_task_executor)
      35            0 :     delete load_task_executor;
      36            0 :   if (unload_task_executor)
      37            0 :     delete unload_task_executor;
      38            0 : }
      39              : 
      40            0 : void CacheLoader::init() {
      41            0 :   if (load_task_executor == nullptr)
      42            0 :     load_task_executor = new TaskExecutor("loadPool", 2);
      43            0 :   if (unload_task_executor == nullptr)
      44            0 :     unload_task_executor = new TaskExecutor("UnloadPool", 2);
      45            0 : }
      46              : 
      47            0 : void CacheLoader::finish() {
      48            0 :   delete load_task_executor;
      49            0 :   load_task_executor = nullptr;
      50            0 :   delete unload_task_executor;
      51            0 :   unload_task_executor = nullptr;
      52            0 : }
      53              : 
      54            0 : void CacheLoader::load(unsigned int order) { loadAllinOrder(order); }
      55              : 
      56            0 : bool CacheLoader::loadAllinOrder(unsigned int order) {
      57            0 :   if (!load_task_executor) {
      58            0 :     ml_loge("init is needed");
      59            0 :     return false;
      60              :   }
      61              : 
      62            0 :   std::set<unsigned int> exec_id = pool->getExecIDs(order);
      63              : 
      64            0 :   for (auto &id : exec_id) {
      65            0 :     loadTensor(id);
      66              :   }
      67              : 
      68              :   return true;
      69              : }
      70              : 
      71            0 : int CacheLoader::loadTensor(unsigned int id) {
      72            0 :   if (!load_task_executor) {
      73            0 :     ml_loge("init is needed");
      74            0 :     return ML_ERROR_INVALID_PARAMETER;
      75              :   }
      76            0 :   checkUnloadComplete(id);
      77              : 
      78            0 :   std::lock_guard<std::mutex> lock(state_mutex);
      79              : 
      80            0 :   if (states[id] == LoadState::Loading || states[id] == LoadState::Loaded)
      81            0 :     return -1;
      82              : 
      83            0 :   states[id] = LoadState::Loading;
      84              : 
      85            0 :   int load_task_id = load_task_executor->submit(
      86            0 :     [this, id](void *data) {
      87            0 :       pool->loadTensor(id);
      88            0 :       std::lock_guard<std::mutex> lock(this->state_mutex);
      89            0 :       this->states[id] = LoadState::Loaded;
      90            0 :     },
      91            0 :     (void *)(std::uintptr_t)id);
      92              : 
      93            0 :   pool->getCacheElem(id).setLoadTaskID(load_task_id);
      94              : 
      95            0 :   return load_task_id;
      96              : }
      97              : 
      98            0 : bool CacheLoader::unloadAllinOrder(unsigned int order) {
      99            0 :   if (!load_task_executor) {
     100            0 :     ml_loge("init is needed");
     101            0 :     return false;
     102              :   }
     103              : 
     104            0 :   std::set<unsigned int> exec_id = pool->getExecIDs(order);
     105              : 
     106            0 :   for (auto &id : exec_id) {
     107            0 :     unloadTensor(id);
     108              :   }
     109              : 
     110              :   return true;
     111              : }
     112              : 
     113            0 : int CacheLoader::unloadTensor(unsigned int id) {
     114            0 :   if (!load_task_executor) {
     115            0 :     ml_loge("init is needed");
     116            0 :     return ML_ERROR_INVALID_PARAMETER;
     117              :   }
     118              : 
     119            0 :   checkLoadComplete(id);
     120              : 
     121            0 :   std::lock_guard<std::mutex> lock(state_mutex);
     122              : 
     123            0 :   if (states[id] != LoadState::Loaded)
     124              :     return -1;
     125              : 
     126            0 :   states[id] = LoadState::Unloading;
     127              : 
     128            0 :   int unload_task_id = load_task_executor->submit(
     129            0 :     [this, id](void *data) {
     130            0 :       pool->unloadTensor(id);
     131            0 :       std::lock_guard<std::mutex> lock(this->state_mutex);
     132            0 :       this->states[id] = LoadState::Idle;
     133            0 :     },
     134            0 :     (void *)(std::uintptr_t)id);
     135              : 
     136            0 :   pool->getCacheElem(id).setUnloadTaskID(unload_task_id);
     137            0 :   return unload_task_id;
     138              : }
     139              : 
     140            0 : LoadState CacheLoader::getState(int id) const {
     141            0 :   std::lock_guard<std::mutex> lock(state_mutex);
     142              :   auto it = states.find(id);
     143            0 :   if (it == states.end())
     144              :     return LoadState::Idle;
     145            0 :   return it->second;
     146              : }
     147              : 
     148            0 : int CacheLoader::flushAsync(unsigned int order,
     149              :                             TaskExecutor::CompleteCallback complete) {
     150            0 :   return flushAsync(order, complete, LONG_MAX);
     151              : }
     152              : 
     153            0 : int CacheLoader::flushAsync(unsigned int order,
     154              :                             TaskExecutor::CompleteCallback complete,
     155              :                             long timeout_ms) {
     156            0 :   if (!unload_task_executor) {
     157            0 :     ml_loge("init is needed");
     158            0 :     return ML_ERROR_INVALID_PARAMETER;
     159              :   }
     160              : 
     161            0 :   std::set<unsigned int> exec_id = pool->getExecIDs(order);
     162              : 
     163            0 :   for (auto &id : exec_id) {
     164            0 :     unloadTensor(id);
     165              :   }
     166              : 
     167              :   return 0;
     168              : }
     169            0 : void CacheLoader::flush() {
     170              :   auto actives = pool->getActiveElems();
     171              : 
     172            0 :   for (auto &id : actives) {
     173            0 :     unloadTensor(id);
     174              :   }
     175              : 
     176            0 :   for (auto &id : actives) {
     177            0 :     checkUnloadComplete(id);
     178              :   }
     179              : 
     180            0 :   pool->flush();
     181            0 : }
     182              : 
     183            0 : int CacheLoader::cancelAsync(int id) {
     184              :   try {
     185            0 :     load_task_executor->cancel(id);
     186            0 :   } catch (const std::exception &e) {
     187            0 :     ml_loge("CacheLoader(%s): failed to cancel(%d): %s",
     188              :             pool->getName().c_str(), id, e.what());
     189              :     return ML_ERROR_UNKNOWN;
     190            0 :   }
     191              : 
     192              :   return ML_ERROR_NONE;
     193              : }
     194              : 
     195            0 : unsigned int CacheLoader::inActive(unsigned int order) {
     196            0 :   std::set<unsigned int> exec_id = pool->getExecIDs(order);
     197            0 :   for (auto &id : exec_id) {
     198            0 :     auto &elem = pool->getCacheElem(id);
     199              :     int load_task_id = elem.getLoadTaskID();
     200            0 :     if (load_task_id >= 0) {
     201            0 :       load_task_executor->releaseTask(load_task_id);
     202              :       elem.setLoadTaskID(-1);
     203            0 :       states[id] = LoadState::Unloading;
     204              :     }
     205            0 :     pool->inActive(id);
     206              :   }
     207            0 :   return 0;
     208              : }
     209              : 
     210            0 : bool CacheLoader::checkAllLoadComplete(unsigned int order) {
     211              : 
     212            0 :   std::set<unsigned int> exec_id = pool->getExecIDs(order);
     213              : 
     214            0 :   for (auto &id : exec_id) {
     215            0 :     checkLoadComplete(id);
     216              :   }
     217            0 :   return true;
     218              : }
     219              : 
     220            0 : bool CacheLoader::checkAllUnloadComplete(unsigned int order) {
     221              : 
     222            0 :   std::set<unsigned int> exec_id = pool->getExecIDs(order);
     223              : 
     224            0 :   for (auto &id : exec_id) {
     225            0 :     checkUnloadComplete(id);
     226              :   }
     227            0 :   return true;
     228              : }
     229              : 
     230            0 : bool CacheLoader::checkLoadComplete(unsigned int id) {
     231            0 :   auto &elem = pool->getCacheElem(id);
     232              :   int unload_task_id = elem.getUnloadTaskID();
     233              :   int load_task_id = elem.getLoadTaskID();
     234            0 :   if (unload_task_id >= 0) {
     235            0 :     load_task_executor->releaseTask(unload_task_id);
     236              :     elem.setUnloadTaskID(-1);
     237              :   }
     238              : 
     239            0 :   if (load_task_id >= 0) {
     240            0 :     load_task_executor->wait(load_task_id);
     241              :   }
     242              : 
     243            0 :   return true;
     244              : }
     245              : 
     246            0 : bool CacheLoader::checkUnloadComplete(unsigned int id) {
     247            0 :   auto &elem = pool->getCacheElem(id);
     248              :   int unload_task_id = elem.getUnloadTaskID();
     249              :   int load_task_id = elem.getLoadTaskID();
     250            0 :   if (load_task_id >= 0) {
     251            0 :     load_task_executor->releaseTask(load_task_id);
     252              :     elem.setLoadTaskID(-1);
     253              :   }
     254            0 :   if (unload_task_id >= 0) {
     255            0 :     load_task_executor->wait(unload_task_id);
     256              :   }
     257            0 :   return true;
     258              : }
     259              : 
     260              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1