LCOV - code coverage report
Current view: top level - nntrainer/tensor - memory_data.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 50.0 % 24 12
Test Date: 2025-12-14 20:38:17 Functions: 66.7 % 3 2

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
       4              :  *
       5              :  * @file   memory_data.h
       6              :  * @date   14 Oct 2022
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Jiho Chu <jiho.chu@samsung.com>
       9              :  * @bug    No known bugs except for NYI items
      10              :  * @brief  MemoryData class
      11              :  *
      12              :  */
      13              : 
      14              : #ifndef __MEMORY_DATA_H__
      15              : #define __MEMORY_DATA_H__
      16              : 
#include <functional>
#include <utility>
      18              : 
      19              : namespace nntrainer {
      20              : 
      21              : using MemoryDataValidateCallback = std::function<void(unsigned int)>;
      22              : 
      23              : /**
      24              :  * @brief  MemoryData Class
      25              :  */
      26              : class MemoryData {
      27              :   /**
      28              :    * @brief MemoryPool is granted friend access to call setSVM()
      29              :    * @details This restricts the ability to modify the SVM allocation flag
      30              :    *          to only MemoryPool::getMemory(), preventing malicious or
      31              :    *          accidental modification from other parts of the codebase.
      32              :    */
      33              :   friend class MemoryPool;
      34              : 
      35              : public:
      36              :   /**
      37              :    * @brief  Constructor of Memory Data
      38              :    * @param[in] addr Memory data
      39              :    */
      40       406808 :   explicit MemoryData(void *addr) :
      41       406808 :     valid(true),
      42       406808 :     id(0),
      43       406808 :     address(addr),
      44              :     validate_cb([](unsigned int) {}),
      45              :     invalidate_cb([](unsigned int) {}),
      46       406808 :     svm_allocation(false) {}
      47              : 
      48              :   /**
      49              :    * @brief  Constructor of Memory Data
      50              :    * @param[in] mem_id validate callback.
      51              :    * @param[in] v_cb validate callback.
      52              :    * @param[in] i_cb invalidate callback.
      53              :    */
      54            0 :   explicit MemoryData(unsigned int mem_id, MemoryDataValidateCallback v_cb,
      55              :                       MemoryDataValidateCallback i_cb,
      56            0 :                       void *memory_ptr = nullptr) :
      57            0 :     valid(false),
      58            0 :     id(mem_id),
      59            0 :     address(memory_ptr),
      60            0 :     validate_cb(v_cb),
      61            0 :     invalidate_cb(i_cb),
      62            0 :     svm_allocation(false) {}
      63              : 
      64              :   /**
      65              :    * @brief  Deleted constructor of Memory Data
      66              :    */
      67              :   explicit MemoryData() = delete;
      68              : 
      69              :   /**
      70              :    * @brief  Constructor of MemoryData
      71              :    */
      72              :   explicit MemoryData(MemoryDataValidateCallback v_cb,
      73              :                       MemoryDataValidateCallback i_cb) = delete;
      74              :   /**
      75              :    * @brief  Constructor of MemoryData
      76              :    */
      77              :   explicit MemoryData(void *addr, MemoryDataValidateCallback v_cb,
      78              :                       MemoryDataValidateCallback i_cb) = delete;
      79              : 
      80              :   /**
      81              :    * @brief  Destructor of Memory Data
      82              :    */
      83       748388 :   virtual ~MemoryData() = default;
      84              : 
      85              :   /**
      86              :    * @brief  Set address
      87              :    */
      88            0 :   void setAddr(void *addr) { address = addr; }
      89              : 
      90              :   /**
      91              :    * @brief  Get address
      92              :    */
      93              :   template <typename T = float> T *getAddr() const {
      94    136521695 :     return static_cast<T *>(address);
      95              :   }
      96              : 
      97              :   /**
      98              :    * @brief  Validate memory data
      99              :    */
     100              :   void validate() {
     101    136180265 :     if (valid)
     102              :       return;
     103            0 :     if (validate_cb != nullptr)
     104            0 :       validate_cb(id);
     105              :   }
     106              : 
     107              :   /**
     108              :    * @brief  Invalidate memory data
     109              :    */
     110              :   void invalidate() {
     111       386074 :     if (!valid)
     112              :       return;
     113       386074 :     if (invalidate_cb != nullptr)
     114       386074 :       invalidate_cb(id);
     115              :   }
     116              : 
     117              :   /**
     118              :    * @brief  Set valid
     119              :    */
     120            0 :   void setValid(bool v) { valid = v; }
     121              : 
     122              :   /**
     123              :    * @brief   Check if data is a shared virtual memory
     124              :    */
     125              :   bool isSVM() const { return svm_allocation; }
     126              : 
     127              : private:
     128              :   /**
     129              :    * @brief  Set SVM allocation flag (private - only accessible by MemoryPool)
     130              :    * @param[in] is_svm True if this memory is a shared virtual memory
     131              :    * @note This method is intentionally private to prevent modification of the
     132              :    *       SVM flag after MemoryData creation. Only MemoryPool (friend class)
     133              :    *       can call this during memory allocation to ensure data integrity.
     134              :    */
     135        19092 :   void setSVM(bool is_svm) { svm_allocation = is_svm; }
     136              : 
     137              :   bool valid;
     138              :   unsigned int id;
     139              :   void *address;
     140              :   MemoryDataValidateCallback validate_cb;
     141              :   MemoryDataValidateCallback invalidate_cb;
     142              :   bool svm_allocation;
     143              : };
     144              : 
     145              : } // namespace nntrainer
     146              : 
     147              : #endif /* __MEMORY_DATA_H__ */
        

Generated by: LCOV version 2.0-1