// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
 *
 * @file   cache_pool.cpp
 * @date   01 July 2022
 * @see    https://github.com/nnstreamer/nntrainer
 * @author Jiho Chu <jiho.chu@samsung.com>
 * @bug    No known bugs except for NYI items
 * @brief  Cache pool class inherited from memory pool
 *
 */

#include "cache_pool.h"

#include <limits>
#include <numeric>
#include <stdexcept>
#include <vector>

#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <profiler.h>

namespace nntrainer {

namespace {

/**
 * @brief convert tensor lifespan to cache policy
 *
 * @param lifespan tensor lifespan
 * @return cache policy
 * @note The cache policy is derived from the tensor's lifespan. If the tensor
 * must keep its value across accesses, ALWAYS_SYNCED or ITERATION_CONSIST is
 * appropriate; otherwise TEMPORAL, which does not preserve the value, can be
 * used.
 */
inline CachePolicy
convertTensorLifespanToCachePolicy(const TensorLifespan lifespan) {
  CachePolicy policy;

  switch (lifespan) {
  case TensorLifespan::UNMANAGED:
    policy = CachePolicy::ALWAYS_SYNCED;
    break;
  case TensorLifespan::FORWARD_FUNC_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::FORWARD_INFER_LIFESPAN:
    policy = CachePolicy::SYNC_ONCE;
    break;
  case TensorLifespan::CALC_DERIV_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::CALC_GRAD_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::CALC_AGRAD_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::CALC_GRAD_DERIV_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::CALC_GRAD_DERIV_AGRAD_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::FORWARD_GRAD_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::FORWARD_GRAD_AGRAD_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::FORWARD_DERIV_LIFESPAN:
    policy = CachePolicy::ALWAYS_SYNCED;
    break;
  case TensorLifespan::ITERATION_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::EPOCH_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::MAX_LIFESPAN:
    policy = CachePolicy::ALWAYS_SYNCED;
    break;
  default:
    policy = CachePolicy::ALWAYS_SYNCED;
    break;
  }

  return policy;
}

std::atomic_int pool_id = 0;

} // namespace

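/**
 * @brief Construct a cache pool backed by a per-process swap device.
 *
 * The swap device name is made unique by appending the pid and an
 * atomically increasing pool id, i.e. "<name>_<pid>_<id>".
 */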
CachePool::CachePool(const std::string &n) :
  name(n),
  execution_mode_(ml::train::ExecutionMode::TRAIN),
  swap_device(std::make_shared<SwapDevice>(n + "_" + std::to_string(getpid()) +
                                           "_" + std::to_string(pool_id++))) {}

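/**
 * @brief Construct a cache pool whose swap file lives under @a path;
 * an empty path falls back to the default location.
 */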
CachePool::CachePool(const std::string &path, const std::string &n) :
  name(n), execution_mode_(ml::train::ExecutionMode::TRAIN) {
  if (path.empty())
    swap_device = std::make_shared<SwapDevice>(
      n + "_" + std::to_string(getpid()) + "_" + std::to_string(pool_id++));
  else
    swap_device =
      std::make_shared<SwapDevice>(path, n + "_" + std::to_string(getpid()) +
                                           "_" + std::to_string(pool_id++));
}

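/**
 * @brief Construct a cache pool with an explicit execution mode; the mode
 * selects between the TRAIN and INFERENCE (FSU) swap behaviors.
 */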
CachePool::CachePool(const std::string &path, const std::string &name_,
                     ml::train::ExecutionMode exec_mode_) :
  name(name_), execution_mode_(exec_mode_) {
  if (path.empty())
    swap_device = std::make_shared<SwapDevice>(
      name_ + "_" + std::to_string(getpid()) + "_" + std::to_string(pool_id++));
  else
    swap_device = std::make_shared<SwapDevice>(
      path,
      name_ + "_" + std::to_string(getpid()) + "_" + std::to_string(pool_id++));
}

CachePool::~CachePool() {
  try {
    deallocate();
  } catch (...) {
    ml_loge("Failed to deallocate");
  }
}

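/**
 * @brief Deactivate every cache element scheduled at @a order and drop its
 * id from the active set.
 */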
void CachePool::inActive(unsigned int order) {
  auto exec_id = exec_ids[order];
  std::lock_guard<std::mutex> lock(mutex);
  for (auto &id : exec_id) {
    actives.erase(id);
    elems[id]->inActive();
  }
}

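/**
 * @brief Start the swap device for the planned pool size; for inference,
 * FSU memory is allocated up front. Throws if the pool is already allocated
 * or its size is zero.
 */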
void CachePool::allocate() {
  NNTR_THROW_IF(swap_device->isOperating(), std::runtime_error)
    << "Cache pool is already allocated";

  size_t pool_size = size();

  NNTR_THROW_IF(pool_size == 0, std::runtime_error)
    << "Allocating memory pool with size 0";
  if (execution_mode_ == ml::train::ExecutionMode::INFERENCE)
    MemoryPool::allocateFSU();
  swap_device->start(pool_size, execution_mode_);
}

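/**
 * @brief Release the pool: swap out and invalidate every cache element,
 * then stop the swap device. Safe to call when nothing is allocated.
 */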
void CachePool::deallocate() {
  MemoryPool::deallocate();
  if (!swap_device->isOperating())
    return;

  if (execution_mode_ == ml::train::ExecutionMode::INFERENCE)
    MemoryPool::deallocate();

  for (auto &[id, elem] : elems)
    invalidate(id);

  actives.clear();
  swap_device->finish();
}

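/**
 * @brief Swap the element for @a id in (if it is not already resident) and
 * add it to the active set.
 */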
void CachePool::validate(unsigned int id) {
  if (!elems[id]->isActive()) {
    elems[id]->swapIn();
    std::lock_guard<std::mutex> lock(mutex);
    actives.insert(id);
  }
}

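/**
 * @brief Swap the element for @a id out (if it is resident) and remove it
 * from the active set.
 */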
void CachePool::invalidate(unsigned int id) {
  if (elems[id]->isActive()) {
    elems[id]->swapOut();
    std::lock_guard<std::mutex> lock(mutex);
    actives.erase(id);
  }
}

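/**
 * @brief Request memory from the underlying MemoryPool and record the cache
 * policy derived from @a lifespan. Ids are expected to be issued
 * contiguously, which the policies-size check below enforces.
 */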
unsigned int CachePool::requestMemory(size_t bytes, unsigned int start_time,
                                      unsigned int end_time,
                                      std::vector<unsigned int> exec_order,
                                      TensorLifespan lifespan, bool is_wgrad) {
  auto id = MemoryPool::requestMemory(bytes, start_time, end_time, exec_order,
                                      lifespan, is_wgrad);

  const CachePolicy policy = convertTensorLifespanToCachePolicy(lifespan);

  policies.push_back(policy);

  NNTR_THROW_IF(id != policies.size(), std::runtime_error)
    << "Invalid requestMemory call exists";

  return id;
}

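/**
 * @brief Create a MemoryData handle for @a id, backed by a CacheElem bound
 * to the swap device, and register the id with each execution order that
 * uses it.
 *
 * A minimal usage sketch (assuming the pool has been planned and allocated
 * through the MemoryPool interface):
 *
 *   auto id = pool.requestMemory(bytes, start, end, orders, lifespan);
 *   pool.allocate();
 *   auto mem = pool.getMemory(id); // validate()/invalidate() drive swapping
 */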
std::shared_ptr<MemoryData> CachePool::getMemory(unsigned int id) {
  NNTR_THROW_IF(!swap_device->isOperating(), std::invalid_argument)
    << "Allocate the cache pool before requesting its memory";

  off_t offset = getMemoryOffset().at(id - 1);
  size_t len = getMemorySize().at(id - 1);
  auto exe_order = getMemoryExecOrder().at(id - 1);
  auto policy = getCachePolicy().at(id - 1);

  void *memory_ptr = nullptr;
  if (execution_mode_ == ml::train::ExecutionMode::INFERENCE) {
    memory_ptr = getMemoryPtrs().at(id - 1);
  }

  auto mem_data = std::make_shared<MemoryData>(
    id, std::bind(&CachePool::validate, this, std::placeholders::_1),
    std::bind(&CachePool::invalidate, this, std::placeholders::_1), memory_ptr);

  elems.emplace(id, std::make_unique<CacheElem>(swap_device, id, offset, len,
                                                mem_data, policy, memory_ptr));

  std::string ords;

  if (execution_mode_ == ml::train::ExecutionMode::INFERENCE) {
    auto &o = exe_order[0];
    exec_ids[o].insert(id);
    ords.append(std::to_string(o));
  } else {
    for (auto &o : exe_order) {
      exec_ids[o].insert(id);
      ords.append(std::to_string(o));
    }
  }
  ml_logd("[%d] exe_order(%s), offset: %llu, len: %zu", id, ords.c_str(),
          (long long unsigned int)offset, len);

  return mem_data;
}

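/**
 * @brief Swap out all active elements with LAST_ACCESS and reset every
 * cache element, emptying the active set.
 */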
void CachePool::flush() {
  for (auto &id : actives) {
    elems[id]->swapOut(CacheElem::LAST_ACCESS);
  }

  for (auto &[id, elem] : elems)
    elem->reset();

  actives.clear();
}

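/**
 * @brief Swap out and deactivate every active element that is not scheduled
 * at @a order; an element whose whole schedule lies before @a order is
 * swapped out with LAST_ACCESS.
 */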
void CachePool::flushExcept(unsigned int order) {
  auto exe_orders = getMemoryExecOrder();

  eraseActiveIf([&, order](const unsigned int id) -> bool {
    auto exe_order = exe_orders.at(id - 1);
    auto found = std::find(exe_order.begin(), exe_order.end(), order);
    if (found == exe_order.end()) {
      /**
       * We assume that flushExcept is called just before each execution
       * order and that orders are strictly increasing. Therefore, once the
       * current order passes the maximum order of a cache element, that was
       * the LAST access of the element.
       */
      CacheElem::Options opt = CacheElem::NONE;
      if (*std::max_element(exe_order.begin(), exe_order.end()) < order)
        opt = CacheElem::LAST_ACCESS;
      elems[id]->swapOut(opt);
      return true;
    }
    return false;
  });
}

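/**
 * @brief Same as above, but an element is kept when it is scheduled at any
 * of the given orders.
 */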
void CachePool::flushExcept(std::vector<unsigned int> order) {
  auto exe_orders = getMemoryExecOrder();

  eraseActiveIf([&, order](const unsigned int id) -> bool {
    auto exe_order = exe_orders.at(id - 1);
    for (auto &o : order) {
      auto found = std::find(exe_order.begin(), exe_order.end(), o);
      if (found != exe_order.end())
        return false;
    }
    /**
     * We assume that flushExcept is called just before each execution order
     * and that orders are strictly increasing. Therefore, once the current
     * order passes the maximum order of a cache element, that was the LAST
     * access of the element.
     */
    CacheElem::Options opt = CacheElem::NONE;
    if (*std::max_element(exe_order.begin(), exe_order.end()) < order[0])
      opt = CacheElem::LAST_ACCESS;
    elems[id]->swapOut(opt);
    return true;
  });
}

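/**
 * @brief Flush every element, release the pool, and reset the recorded
 * cache policies.
 */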
void CachePool::clear() {
  flush();
  deallocate();
  policies.clear();
  MemoryPool::clear();
}

bool CachePool::isAllocated() const { return swap_device->isOperating(); }

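/**
 * @brief Swap in every element scheduled at execution order @a order.
 */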
void CachePool::loadExec(unsigned int order) {
  for (auto &id : exec_ids[order]) {
    validate(id);
  }
}

void CachePool::loadTensor(unsigned int id) { validate(id); }

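/**
 * @brief Swap in one element scheduled at @a order per call, advancing
 * @a iter; returns true once every element for the order has been loaded.
 */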
bool CachePool::loadExecOnce(unsigned int order, ExecIdsIter &iter) {
  if (iter == exec_ids[order].end())
    return true;

  validate(*iter);

  iter++;
  return false;
}

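/**
 * @brief Swap out every element scheduled at @a order and clear the
 * active set.
 */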
void CachePool::unloadExec(unsigned int order) {
  for (auto &id : exec_ids[order]) {
    invalidate(id);
  }
  actives.clear();
}

void CachePool::unloadTensor(unsigned int order) {
  invalidate(order);
  std::lock_guard<std::mutex> lock(mutex);
  actives.erase(order);
}

void CachePool::loadActives() {
  ml_logd("load active caches");

  for (auto &id : actives) {
    elems[id]->swapIn();
  }
}

void CachePool::unloadActives() {
  ml_logd("unload active caches");
  for (auto &id : actives) {
    elems[id]->swapOut();
  }
}

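/**
 * @brief Re-point the swap device at @a path for FSU weights and restart
 * it; the previous swap file is removed unless it is a weight pool file.
 */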
void CachePool::setFsuWeightPath(std::string path) {
  auto start_with = [](const std::string &str, const std::string &prefix) {
    return str.size() >= prefix.size() &&
           str.compare(0, prefix.size(), prefix) == 0;
  };

  if (!start_with(swap_device->getDevicePath(), "weight_pool")) {
    remove(swap_device->getDevicePath().c_str());
  }

  swap_device->setFsuWeightPath(path);
  swap_device->finish();
  swap_device->start(size(), execution_mode_);
}

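/**
 * @brief Erase every id in the active set for which @a pred returns true;
 * erase() yields the next iterator, so iteration stays valid.
 */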
void CachePool::eraseActiveIf(
  const std::function<bool(unsigned int id)> &pred) {
  for (auto it = actives.begin(); it != actives.end();
       pred(*it) ? it = actives.erase(it) : ++it) {
  }
}

} // namespace nntrainer