LCOV - code coverage report
Current view: top level - nntrainer/tensor - cache_pool.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 0.0 % 4 0
Test Date: 2025-12-14 20:38:17 Functions: 0.0 % 2 0

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
       4              :  *
       5              :  * @file   cache_pool.h
       6              :  * @date   01 July 2022
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Jiho Chu <jiho.chu@samsung.com>
       9              :  * @bug    No known bugs except for NYI items
      10              :  * @brief  Cache pool class inherited from memory pool
      11              :  *
      12              :  */
      13              : 
      14              : #ifndef __CACHE_POOL_H__
      15              : #define __CACHE_POOL_H__
      16              : 
      17              : #include <list>
      18              : #include <mutex>
      19              : #include <unordered_set>
      20              : #include <vector>
      21              : 
      22              : #include <cache_elem.h>
      23              : #include <common.h>
      24              : #include <memory_pool.h>
      25              : #include <swap_device.h>
      26              : 
      27              : namespace nntrainer {
      28              : 
      29              : /**
      30              :  * @class   CachePool
      31              :  * @brief   Cache memory with fixed size utilizing swap device
      32              :  */
      33              : class CachePool : public MemoryPool {
      34              : public:
      35              :   using CacheElems =
      36              :     std::unordered_map<unsigned int,
      37              :                        std::unique_ptr<CacheElem>>; /**< cache id, cache elem */
      38              :   using CacheElemsIter = CacheElems::iterator;
      39              :   using ExecIds = std::set<unsigned int>;
      40              :   using ExecIdsIter = ExecIds::iterator;
      41              : 
      42              :   /**
      43              :    * @brief CachePool constructor with pool name
      44              :    *
      45              :    * @param name name of the cache pool
      46              :    */
      47              :   explicit CachePool(const std::string &name);
      48              : 
      49              :   /**
      50              :    * @brief CachePool constructor with cache path
      51              :    * @note two-argument calls are ambiguous with the defaulted overload below
      52              :    */
      53              :   explicit CachePool(const std::string &path, const std::string &name);
      54              : 
      55              :   /**
      56              :    * @brief CachePool constructor with cache path & ExecutionMode
      57              :    * @param exec_mode execution mode, defaults to ExecutionMode::TRAIN
      58              :    */
      59              :   explicit CachePool(
      60              :     const std::string &path, const std::string &name,
      61              :     ml::train::ExecutionMode exec_mode = ml::train::ExecutionMode::TRAIN);
      62              : 
      63              :   /**
      64              :    * @brief CachePool destructor
      65              :    *
      66              :    */
      67              :   virtual ~CachePool();
      68              : 
      69              :   /**
      70              :    * @brief Inactivate elements
      71              :    *
      72              :    * @param order order to inactivate
      73              :    */
      74              :   void inActive(unsigned int order);
      75              : 
      76              :   /**
      77              :    * @brief Do the allocation of cache
      78              :    *
      79              :    */
      80              :   virtual void allocate() override;
      81              : 
      82              :   /**
      83              :    * @brief Free all the allocated cache
      84              :    *
      85              :    */
      86              :   virtual void deallocate() override;
      87              : 
      88              :   /**
      89              :    * @brief Request Memory from memory pool
      90              :    * @note start_time is inclusive, but end_time is exclusive
      91              :    */
      92              :   virtual unsigned int requestMemory(
      93              :     size_t bytes, unsigned int start_time, unsigned int end_time,
      94              :     std::vector<unsigned int> exec_order = std::vector<unsigned int>(),
      95              :     TensorLifespan lifespan = TensorLifespan::MAX_LIFESPAN,
      96              :     bool is_wgrad = false) override;
      97              :   /**
      98              :    * @brief Get the allocated cache
      99              :    *
     100              :    * @param id The token received from the requestMemory
     101              :    *
     102              :    * @return The pointer of the cache
     103              :    *
     104              :    * @details This function will throw if called before allocation.
     105              :    */
     106              :   virtual std::shared_ptr<MemoryData> getMemory(unsigned int id) override;
     107              : 
     108              :   /**
     109              :    * @brief Is the cache pool allocated
     110              :    *
     111              :    * @return true if the memory is allocated, else false
     112              :    */
     113              :   virtual bool isAllocated() const override;
     114              : 
     115              :   /**
     116              :    * @brief Flush cache data to device
     117              :    *
     118              :    * @note it must be called only when epoch ends.
     119              :    */
     120              :   virtual void flush();
     121              : 
     122              :   /**
     123              :    * @brief Flush cache data to device except given order
     124              :    *
     125              :    * @param order execution order to exclude
     126              :    */
     127              :   virtual void flushExcept(unsigned int order);
     128              : 
     129              :   /**
     130              :    * @brief Flush cache data to device except given orders
     131              :    *
     132              :    * @param order execution orders to exclude
     133              :    */
     134              :   virtual void flushExcept(std::vector<unsigned int> order);
     135              : 
     136              :   /**
     137              :    * @brief Clear the memory pool
     138              :    *
     139              :    */
     140              :   virtual void clear() override;
     141              : 
     142              :   /**
     143              :    * @brief Load cache data by execution order
     144              :    *
     145              :    * @param order execution order
     146              :    */
     147              :   virtual void loadExec(unsigned int order);
     148              : 
     149              :   /**
     150              :    * @brief Load Tensor
     151              :    *
     152              :    * @param order order of Tensor to load
     153              :    */
     154              :   virtual void loadTensor(unsigned int order);
     155              : 
     156              :   /**
     157              :    * @brief Load cache data by execution order
     158              :    * @param order execution order
     159              :    * @param iter ExecIds iterator, taken by non-const reference
     160              :    */
     161              :   virtual bool loadExecOnce(unsigned int order, ExecIdsIter &iter);
     162              : 
     163              :   /**
     164              :    * @brief Unload cache data by execution order
     165              :    *
     166              :    * @param order execution order
     167              :    */
     168              :   virtual void unloadExec(unsigned int order);
     169              : 
     170              :   /**
     171              :    * @brief Unload Tensor
     172              :    *
     173              :    * @param order order of Tensor to unload
     174              :    */
     175              :   virtual void unloadTensor(unsigned int order);
     176              : 
     177              :   /**
     178              :    * @brief Load active cache data
     179              :    */
     180              :   virtual void loadActives();
     181              : 
     182              :   /**
     183              :    * @brief Unload active cache data
     184              :    */
     185              :   virtual void unloadActives();
     186              : 
     187              :   /**
     188              :    * @brief Get name
     189              :    *
     190              :    * @return cache pool name
     191              :    */
     192            0 :   virtual std::string getName() { return name; }
     193              : 
     194              :   /**
     195              :    * @brief Get ExecutionMode
     196              :    *
     197              :    * @return ml::train::ExecutionMode
     198              :    */
     199              :   ml::train::ExecutionMode getExecMode() const { return execution_mode_; }
     200              : 
     201              :   /**
     202              :    * @brief set FSU weight path
     203              :    *
     204              :    * @param path FSU weight file path
     205              :    */
     206              :   void setFsuWeightPath(std::string path) override;
     207              : 
     208              :   /**
     209              :    * @brief set weight file offset for FSU loading
     210              :    *
     211              :    * @param offsets weight file offset
     212              :    */
     213              :   void
     214            0 :   setWeightOffset(std::vector<std::pair<size_t, size_t>> offsets) override {
     215            0 :     swap_device->setWeightOffset(offsets);
     216            0 :   }
     217              : 
     218              :   /**
     219              :    * @brief get Tensor ID set in order
     220              :    *
     221              :    * @param order Execution order
     222              :    * @return copy of the Tensor id set (an empty set is inserted if absent)
     223              :    */
     224              :   std::set<unsigned int> getExecIDs(unsigned int order) {
     225              :     return exec_ids[order];
     226              :   }
     227              : 
     228              :   /**
     229              :    * @brief get Active Cache Elem lists
     230              :    *
     231              :    * @return copy of the Active Cache Elem set
     232              :    */
     233              :   std::unordered_set<unsigned int> getActiveElems() { return actives; }
     234              : 
     235              :   /**
     236              :    * @brief get Cache Elem with id
     237              :    * @param id Tensor ID
     238              :    * @return Cache Elem (id must exist; a missing id dereferences a null entry)
     239              :    */
     240              :   CacheElem &getCacheElem(unsigned int id) { return *elems[id]; }
     241              : 
     242              :   /**
     243              :    * @brief check Cache Elem with id is loaded (Active)
     244              :    * @param id Tensor ID
     245              :    * @return true if it is loaded (id must exist; missing id is a null deref)
     246              :    */
     247              :   bool isLoaded(unsigned int id) { return elems[id]->isActive(); }
     248              : 
     249              : protected:
     250              :   /**
     251              :    * @brief validate cache element
     252              :    *
     253              :    * @param id cache element id
     254              :    */
     255              :   virtual void validate(unsigned int id);
     256              : 
     257              :   /**
     258              :    * @brief invalidate cache element
     259              :    *
     260              :    * @param id cache element id
     261              :    */
     262              :   virtual void invalidate(unsigned int id);
     263              : 
     264              :   /**
     265              :    * @brief Get cache policies
     266              :    *
     267              :    * @return Cache policies
     268              :    */
     269              :   std::vector<CachePolicy> &getCachePolicy() { return policies; }
     270              : 
     271              : private:
     272              :   void eraseActiveIf(const std::function<bool(unsigned int id)> &pred);
     273              : 
     274              :   std::string name;                         /**< pool name */
     275              :   ml::train::ExecutionMode execution_mode_; /**< execution mode */
     276              :   std::shared_ptr<SwapDevice> swap_device;  /**< swap device */
     277              :   CacheElems elems;                         /**< cache elements */
     278              :   std::unordered_set<unsigned int> actives; /**< active cache elems */
     279              :   std::vector<CachePolicy> policies;        /**< cache policies */
     280              :   std::unordered_map<unsigned int, ExecIds> exec_ids; /**< order, id set */
     281              : 
     282              :   mutable std::mutex mutex;
     283              : };
     284              : 
     285              : } // namespace nntrainer
     286              : 
     287              : #endif /** __CACHE_POOL_H__ */
        

Generated by: LCOV version 2.0-1