Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
4 : *
5 : * @file cache_pool.h
6 : * @date 01 July 2022
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jiho Chu <jiho.chu@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief Cache pool class inherited from memory pool
11 : *
12 : */
13 :
14 : #ifndef __CACHE_POOL_H__
15 : #define __CACHE_POOL_H__
16 :
17 : #include <list>
18 : #include <mutex>
19 : #include <unordered_set>
20 : #include <vector>
21 :
22 : #include <cache_elem.h>
23 : #include <common.h>
24 : #include <memory_pool.h>
25 : #include <swap_device.h>
26 :
27 : namespace nntrainer {
28 :
29 : /**
30 : * @class CachePool
31 : * @brief Cache memory with fixed size utilizing swap device
32 : */
33 : class CachePool : public MemoryPool {
34 : public:
35 : using CacheElems =
36 : std::unordered_map<unsigned int,
37 : std::unique_ptr<CacheElem>>; /**< cache id, cache elem */
38 : using CacheElemsIter = CacheElems::iterator;
39 : using ExecIds = std::set<unsigned int>;
40 : using ExecIdsIter = ExecIds::iterator;
41 :
42 : /**
43 : * @brief CachePool default constructor
44 : *
45 : * @param name name of the cache pool
46 : */
47 : explicit CachePool(const std::string &name);
48 :
49 : /**
50 : * @brief CachePool constructor with cache path
51 : *
52 : */
53 : explicit CachePool(const std::string &path, const std::string &name);
54 :
55 : /**
56 : * @brief CachePool constructor with cache path & ExecutionMode
57 : *
58 : */
59 : explicit CachePool(
60 : const std::string &path, const std::string &name,
61 : ml::train::ExecutionMode exec_mode = ml::train::ExecutionMode::TRAIN);
62 :
63 : /**
64 : * @brief MemoryPool destructor
65 : *
66 : */
67 : virtual ~CachePool();
68 :
69 : /**
70 : * @brief inactive elements
71 : *
72 : * @param order order to inactive
73 : */
74 : void inActive(unsigned int order);
75 :
76 : /**
77 : * @brief Do the allocation of cache
78 : *
79 : */
80 : virtual void allocate() override;
81 :
82 : /**
83 : * @brief Free all the allocated cache
84 : *
85 : */
86 : virtual void deallocate() override;
87 :
88 : /**
89 : * @brief Request Memory from memory pool
90 : * @note start_time is inclusive, but end_time is exclusive
91 : */
92 : virtual unsigned int requestMemory(
93 : size_t bytes, unsigned int start_time, unsigned int end_time,
94 : std::vector<unsigned int> exec_order = std::vector<unsigned int>(),
95 : TensorLifespan lifespan = TensorLifespan::MAX_LIFESPAN,
96 : bool is_wgrad = false) override;
97 : /**
98 : * @brief Get the allocated cache
99 : *
100 : * @param id The token received from the requestMemory
101 : *
102 : * @return The pointer of the cache
103 : *
104 : * @details This function will throw if called before allocation.
105 : */
106 : virtual std::shared_ptr<MemoryData> getMemory(unsigned int id) override;
107 :
108 : /**
109 : * @brief Is the cache pool allocated
110 : *
111 : * @return true if the memory is allocated, else false
112 : */
113 : virtual bool isAllocated() const override;
114 :
115 : /**
116 : * @brief Flush cache data to device
117 : *
118 : * @note it must be called only when epoch ends.
119 : */
120 : virtual void flush();
121 :
122 : /**
123 : * @brief Flush cache data to device except given order
124 : *
125 : * @param order except execution order
126 : */
127 : virtual void flushExcept(unsigned int order);
128 :
129 : /**
130 : * @brief Flush cache data to device except given order
131 : *
132 : * @param order except execution order
133 : */
134 : virtual void flushExcept(std::vector<unsigned int> order);
135 :
136 : /**
137 : * @brief Clear the memory pool
138 : *
139 : */
140 : virtual void clear() override;
141 :
142 : /**
143 : * @brief Load cache data by execution order
144 : *
145 : * @param order execution order
146 : */
147 : virtual void loadExec(unsigned int order);
148 :
149 : /**
150 : * @brief Load Tensor
151 : *
152 : * @param order order of Tensor to load
153 : */
154 : virtual void loadTensor(unsigned int order);
155 :
156 : /**
157 : * @brief Load cache data by execution order
158 : *
159 : * @param order execution order
160 : */
161 : virtual bool loadExecOnce(unsigned int order, ExecIdsIter &iter);
162 :
163 : /**
164 : * @brief Unload cache data by execution order
165 : *
166 : * @param order execution order
167 : */
168 : virtual void unloadExec(unsigned int order);
169 :
170 : /**
171 : * @brief Unload Tensor
172 : *
173 : * @param order order of Tensor to unload
174 : */
175 : virtual void unloadTensor(unsigned int order);
176 :
177 : /**
178 : * @brief Load active cache data
179 : */
180 : virtual void loadActives();
181 :
182 : /**
183 : * @brief Unload active cache data
184 : */
185 : virtual void unloadActives();
186 :
187 : /**
188 : * @brief Get name
189 : *
190 : * @return cache pool name
191 : */
192 0 : virtual std::string getName() { return name; }
193 :
194 : /**
195 : * @brief Get ExecutionMode
196 : *
197 : * @return ml::train::ExecutionMode
198 : */
199 : ml::train::ExecutionMode getExecMode() const { return execution_mode_; }
200 :
201 : /**
202 : * @brief set FSU weight path
203 : *
204 : * @param path FSU weight file path
205 : */
206 : void setFsuWeightPath(std::string path) override;
207 :
208 : /**
209 : * @brief set weight file offset for FSU loading
210 : *
211 : * @param offsets weight file offset
212 : */
213 : void
214 0 : setWeightOffset(std::vector<std::pair<size_t, size_t>> offsets) override {
215 0 : swap_device->setWeightOffset(offsets);
216 0 : }
217 :
218 : /**
219 : * @brief get Tensor ID set in order
220 : *
221 : * @param order Execution order
222 : * @return Tensor id set
223 : */
224 : std::set<unsigned int> getExecIDs(unsigned int order) {
225 : return exec_ids[order];
226 : }
227 :
228 : /**
229 : * @brief get Active Cache Elem lists
230 : *
231 : * @return Active Cache Elem list
232 : */
233 : std::unordered_set<unsigned int> getActiveElems() { return actives; }
234 :
235 : /**
236 : * @brief get Cache Elem with id
237 : * @param id Tensor ID
238 : * @return Cache Elem
239 : */
240 : CacheElem &getCacheElem(unsigned int id) { return *elems[id]; }
241 :
242 : /**
243 : * @brief check Cache Elem with id is loaded (Active)
244 : * @param id Tensor ID
245 : * @return true if it is loaded
246 : */
247 : bool isLoaded(unsigned int id) { return elems[id]->isActive(); }
248 :
249 : protected:
250 : /**
251 : * @brief validate cache element
252 : *
253 : * @param cache element id
254 : */
255 : virtual void validate(unsigned int id);
256 :
257 : /**
258 : * @brief invalidate cache element
259 : *
260 : * @param cache element id
261 : */
262 : virtual void invalidate(unsigned int id);
263 :
264 : /**
265 : * @brief Get cache policies
266 : *
267 : * @return Cache polices
268 : */
269 : std::vector<CachePolicy> &getCachePolicy() { return policies; }
270 :
271 : private:
272 : void eraseActiveIf(const std::function<bool(unsigned int id)> &pred);
273 :
274 : std::string name; /**< pool name */
275 : ml::train::ExecutionMode execution_mode_; /**< execution mode */
276 : std::shared_ptr<SwapDevice> swap_device; /**< swap device */
277 : CacheElems elems; /**< cache elements */
278 : std::unordered_set<unsigned int> actives;
279 : std::vector<CachePolicy> policies;
280 : std::unordered_map<unsigned int, ExecIds> exec_ids;
281 :
282 : mutable std::mutex mutex;
283 : };
284 :
285 : } // namespace nntrainer
286 :
287 : #endif /** __CACHE_POOL_H__ */
|