Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
4 : *
5 : * @file memory_data.h
6 : * @date 14 Oct 2022
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jiho Chu <jiho.chu@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief MemoryData class
11 : *
12 : */
13 :
14 : #ifndef __MEMORY_DATA_H__
15 : #define __MEMORY_DATA_H__
16 :
17 : #include <functional>
18 :
19 : namespace nntrainer {
20 :
21 : using MemoryDataValidateCallback = std::function<void(unsigned int)>;
22 :
23 : /**
24 : * @brief MemoryData Class
25 : */
26 : class MemoryData {
27 : /**
28 : * @brief MemoryPool is granted friend access to call setSVM()
29 : * @details This restricts the ability to modify the SVM allocation flag
30 : * to only MemoryPool::getMemory(), preventing malicious or
31 : * accidental modification from other parts of the codebase.
32 : */
33 : friend class MemoryPool;
34 :
35 : public:
36 : /**
37 : * @brief Constructor of Memory Data
38 : * @param[in] addr Memory data
39 : */
40 406808 : explicit MemoryData(void *addr) :
41 406808 : valid(true),
42 406808 : id(0),
43 406808 : address(addr),
44 : validate_cb([](unsigned int) {}),
45 : invalidate_cb([](unsigned int) {}),
46 406808 : svm_allocation(false) {}
47 :
48 : /**
49 : * @brief Constructor of Memory Data
50 : * @param[in] mem_id validate callback.
51 : * @param[in] v_cb validate callback.
52 : * @param[in] i_cb invalidate callback.
53 : */
54 0 : explicit MemoryData(unsigned int mem_id, MemoryDataValidateCallback v_cb,
55 : MemoryDataValidateCallback i_cb,
56 0 : void *memory_ptr = nullptr) :
57 0 : valid(false),
58 0 : id(mem_id),
59 0 : address(memory_ptr),
60 0 : validate_cb(v_cb),
61 0 : invalidate_cb(i_cb),
62 0 : svm_allocation(false) {}
63 :
64 : /**
65 : * @brief Deleted constructor of Memory Data
66 : */
67 : explicit MemoryData() = delete;
68 :
69 : /**
70 : * @brief Constructor of MemoryData
71 : */
72 : explicit MemoryData(MemoryDataValidateCallback v_cb,
73 : MemoryDataValidateCallback i_cb) = delete;
74 : /**
75 : * @brief Constructor of MemoryData
76 : */
77 : explicit MemoryData(void *addr, MemoryDataValidateCallback v_cb,
78 : MemoryDataValidateCallback i_cb) = delete;
79 :
80 : /**
81 : * @brief Destructor of Memory Data
82 : */
83 748388 : virtual ~MemoryData() = default;
84 :
85 : /**
86 : * @brief Set address
87 : */
88 0 : void setAddr(void *addr) { address = addr; }
89 :
90 : /**
91 : * @brief Get address
92 : */
93 : template <typename T = float> T *getAddr() const {
94 136521695 : return static_cast<T *>(address);
95 : }
96 :
97 : /**
98 : * @brief Validate memory data
99 : */
100 : void validate() {
101 136180265 : if (valid)
102 : return;
103 0 : if (validate_cb != nullptr)
104 0 : validate_cb(id);
105 : }
106 :
107 : /**
108 : * @brief Invalidate memory data
109 : */
110 : void invalidate() {
111 386074 : if (!valid)
112 : return;
113 386074 : if (invalidate_cb != nullptr)
114 386074 : invalidate_cb(id);
115 : }
116 :
117 : /**
118 : * @brief Set valid
119 : */
120 0 : void setValid(bool v) { valid = v; }
121 :
122 : /**
123 : * @brief Check if data is a shared virtual memory
124 : */
125 : bool isSVM() const { return svm_allocation; }
126 :
127 : private:
128 : /**
129 : * @brief Set SVM allocation flag (private - only accessible by MemoryPool)
130 : * @param[in] is_svm True if this memory is a shared virtual memory
131 : * @note This method is intentionally private to prevent modification of the
132 : * SVM flag after MemoryData creation. Only MemoryPool (friend class)
133 : * can call this during memory allocation to ensure data integrity.
134 : */
135 19092 : void setSVM(bool is_svm) { svm_allocation = is_svm; }
136 :
137 : bool valid;
138 : unsigned int id;
139 : void *address;
140 : MemoryDataValidateCallback validate_cb;
141 : MemoryDataValidateCallback invalidate_cb;
142 : bool svm_allocation;
143 : };
144 :
145 : } // namespace nntrainer
146 :
147 : #endif /* __MEMORY_DATA_H__ */
|