// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
 *
 * @file cache_pool.cpp
 * @date 01 July 2022
 * @see https://github.com/nnstreamer/nntrainer
 * @author Jiho Chu <jiho.chu@samsung.com>
 * @bug No known bugs except for NYI items
 * @brief Cache pool class inherited from memory pool
 *
 */

#include "cache_pool.h"

#include <limits>
#include <numeric>
#include <stdexcept>
#include <vector>

#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <profiler.h>

namespace nntrainer {

namespace {

/**
 * @brief convert tensor lifespan to cache policy
 *
 * @param lifespan tensor lifespan
 * @return cache policy
 * @note The cache policy is determined by the tensor's lifespan. If the
 * tensor's value must be preserved, ALWAYS_SYNCED or ITERATION_CONSIST is
 * appropriate; otherwise, TEMPORAL does not keep its value.
 */
inline const CachePolicy
convertTensorLifespanToCachePolicy(const TensorLifespan lifespan) {
  CachePolicy policy;

  switch (lifespan) {
  case TensorLifespan::UNMANAGED:
    policy = CachePolicy::ALWAYS_SYNCED;
    break;
  case TensorLifespan::FORWARD_FUNC_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::FORWARD_INFER_LIFESPAN:
    policy = CachePolicy::SYNC_ONCE;
    break;
  case TensorLifespan::CALC_DERIV_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::CALC_GRAD_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::CALC_AGRAD_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::CALC_GRAD_DERIV_LIFESPAN:
    policy = CachePolicy::TEMPORAL;
    break;
  case TensorLifespan::CALC_GRAD_DERIV_AGRAD_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::FORWARD_GRAD_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::FORWARD_GRAD_AGRAD_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::FORWARD_DERIV_LIFESPAN:
    policy = CachePolicy::ALWAYS_SYNCED;
    break;
  case TensorLifespan::ITERATION_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::EPOCH_LIFESPAN:
    policy = CachePolicy::ITERATION_CONSIST;
    break;
  case TensorLifespan::MAX_LIFESPAN:
    policy = CachePolicy::ALWAYS_SYNCED;
    break;
  default:
    policy = CachePolicy::ALWAYS_SYNCED;
    break;
  }

  return policy;
}

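/**
 * @brief Counter shared by all cache pools; combined with the pool name and
 * the process id it gives every pool a unique swap file name.
 */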
std::atomic_int pool_id = 0;

} // namespace

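/**
 * @brief Construct a cache pool backed by a swap device whose file is named
 * "<name>_<pid>_<pool id>", optionally created under the given path.
 */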
CachePool::CachePool(const std::string &n) :
  name(n),
  execution_mode_(ml::train::ExecutionMode::TRAIN),
  swap_device(std::make_shared<SwapDevice>(n + "_" + std::to_string(getpid()) +
                                           "_" + std::to_string(pool_id++))) {}

CachePool::CachePool(const std::string &path, const std::string &n) :
  name(n), execution_mode_(ml::train::ExecutionMode::TRAIN) {
  if (path.empty())
    swap_device = std::make_shared<SwapDevice>(
      n + "_" + std::to_string(getpid()) + "_" + std::to_string(pool_id++));
  else
    swap_device =
      std::make_shared<SwapDevice>(path, n + "_" + std::to_string(getpid()) +
                                           "_" + std::to_string(pool_id++));
}

CachePool::CachePool(const std::string &path, const std::string &name_,
                     ml::train::ExecutionMode exec_mode_) :
  name(name_), execution_mode_(exec_mode_) {
  if (path.empty())
    swap_device = std::make_shared<SwapDevice>(
      name_ + "_" + std::to_string(getpid()) + "_" +
      std::to_string(pool_id++));
  else
    swap_device = std::make_shared<SwapDevice>(
      path,
      name_ + "_" + std::to_string(getpid()) + "_" +
      std::to_string(pool_id++));
}

CachePool::~CachePool() {
  try {
    deallocate();
  } catch (...) {
    ml_loge("Failed deallocate");
  }
}

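/**
 * @brief Deactivate every cache element registered for the given execution
 * order and drop it from the active set.
 */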
void CachePool::inActive(unsigned int order) {
  auto exec_id = exec_ids[order];
  std::lock_guard<std::mutex> lock(mutex);
  for (auto &id : exec_id) {
    actives.erase(id);
    elems[id]->inActive();
  }
}

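/**
 * @brief Start the swap device for this pool; in INFERENCE mode the base
 * memory pool is allocated first via allocateFSU().
 */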
void CachePool::allocate() {
  NNTR_THROW_IF(swap_device->isOperating(), std::runtime_error)
    << "Cache pool is already allocated";

  size_t pool_size = size();

  NNTR_THROW_IF(pool_size == 0, std::runtime_error)
    << "Allocating memory pool with size 0";
  if (execution_mode_ == ml::train::ExecutionMode::INFERENCE)
    MemoryPool::allocateFSU();
  swap_device->start(size(), execution_mode_);
}

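/**
 * @brief Release the pool: invalidate every element, clear the active set and
 * stop the swap device if it is operating.
 */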
void CachePool::deallocate() {
  MemoryPool::deallocate();
  if (!swap_device->isOperating())
    return;

  if (execution_mode_ == ml::train::ExecutionMode::INFERENCE)
    MemoryPool::deallocate();

  for (auto &[id, elem] : elems)
    invalidate(id);

  actives.clear();
  swap_device->finish();
}

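/**
 * @brief Swap the element in and mark it active; no-op if it is already
 * active.
 */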
void CachePool::validate(unsigned int id) {
  if (!elems[id]->isActive()) {
    elems[id]->swapIn();
    std::lock_guard<std::mutex> lock(mutex);
    actives.insert(id);
  }
}

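/**
 * @brief Swap the element out and remove it from the active set; no-op if it
 * is not active.
 */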
void CachePool::invalidate(unsigned int id) {
  if (elems[id]->isActive()) {
    elems[id]->swapOut();
    std::lock_guard<std::mutex> lock(mutex);
    actives.erase(id);
  }
}

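/**
 * @brief Forward the request to MemoryPool and record the cache policy
 * derived from the tensor lifespan; policy indices stay aligned with the
 * returned memory ids.
 */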
unsigned int CachePool::requestMemory(size_t bytes, unsigned int start_time,
                                      unsigned int end_time,
                                      std::vector<unsigned int> exec_order,
                                      TensorLifespan lifespan, bool is_wgrad) {
  auto id = MemoryPool::requestMemory(bytes, start_time, end_time, exec_order,
                                      lifespan, is_wgrad);

  const CachePolicy policy = convertTensorLifespanToCachePolicy(lifespan);

  policies.push_back(policy);

  NNTR_THROW_IF(id != policies.size(), std::runtime_error)
    << "Invalid requestMemory call exist";

  return id;
}

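/**
 * @brief Create the CacheElem and MemoryData for a previously requested id
 * and register the id per execution order; in INFERENCE mode only the first
 * execution order is registered and the pool's raw pointer is attached.
 */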
std::shared_ptr<MemoryData> CachePool::getMemory(unsigned int id) {
  NNTR_THROW_IF(!swap_device->isOperating(), std::invalid_argument)
    << "Allocate memory before allocation";

  off_t offset = getMemoryOffset().at(id - 1);
  size_t len = getMemorySize().at(id - 1);
  auto exe_order = getMemoryExecOrder().at(id - 1);
  auto policy = getCachePolicy().at(id - 1);

  void *memory_ptr = nullptr;
  if (execution_mode_ == ml::train::ExecutionMode::INFERENCE) {
    memory_ptr = getMemoryPtrs().at(id - 1);
  }

  auto mem_data = std::make_shared<MemoryData>(
    id, std::bind(&CachePool::validate, this, std::placeholders::_1),
    std::bind(&CachePool::invalidate, this, std::placeholders::_1),
    memory_ptr);

  elems.emplace(id, std::make_unique<CacheElem>(swap_device, id, offset, len,
                                                mem_data, policy, memory_ptr));

  std::string ords;

  if (execution_mode_ == ml::train::ExecutionMode::INFERENCE) {
    auto &o = exe_order[0];
    exec_ids[o].insert(id);
    ords.append(std::to_string(o));
  } else {
    for (auto &o : exe_order) {
      exec_ids[o].insert(id);
      ords.append(std::to_string(o));
    }
  }
  ml_logd("[%d] exe_order(%s), offset: %llu, len: %zu", id, ords.c_str(),
          (long long unsigned int)offset, len);

  return mem_data;
}

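/**
 * @brief Swap out every active element as a last access, reset all elements
 * and clear the active set.
 */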
void CachePool::flush() {
  for (auto &id : actives) {
    elems[id]->swapOut(CacheElem::LAST_ACCESS);
  }

  for (auto &[id, elem] : elems)
    elem->reset();

  actives.clear();
}

void CachePool::flushExcept(unsigned int order) {
  auto exe_orders = getMemoryExecOrder();

  eraseActiveIf([&, order](const unsigned int id) -> bool {
    auto exe_order = exe_orders.at(id - 1);
    auto found = std::find(exe_order.begin(), exe_order.end(), order);
    if (found != exe_order.end()) {
      /**
       * We assume that flushExcept is called right before each execution
       * order and that orders are processed in increasing order. Therefore,
       * once the given order is past the maximum order of a cache element,
       * that element has seen its last access.
       */
      CacheElem::Options opt = CacheElem::NONE;
      if (*std::max_element(exe_order.begin(), exe_order.end()) < order)
        opt = CacheElem::LAST_ACCESS;
      elems[id]->swapOut(opt);
      return true;
    }
    return false;
  });
}

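/**
 * @brief Flush active elements whose execution orders contain none of the
 * given orders.
 */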
void CachePool::flushExcept(std::vector<unsigned int> order) {
  auto exe_orders = getMemoryExecOrder();

  eraseActiveIf([&, order](const unsigned int id) -> bool {
    auto exe_order = exe_orders.at(id - 1);
    for (auto &o : order) {
      auto found = std::find(exe_order.begin(), exe_order.end(), o);
      if (found != exe_order.end())
        return false;
    }
    /**
     * We assume that flushExcept is called right before each execution order
     * and that orders are processed in increasing order. Therefore, once the
     * given order is past the maximum order of a cache element, that element
     * has seen its last access.
     */
    CacheElem::Options opt = CacheElem::NONE;
    if (*std::max_element(exe_order.begin(), exe_order.end()) < order[0])
      opt = CacheElem::LAST_ACCESS;
    elems[id]->swapOut(opt);
    return true;
  });
}

void CachePool::clear() {
  flush();
  deallocate();
  policies.clear();
  MemoryPool::clear();
}

bool CachePool::isAllocated() const { return swap_device->isOperating(); }

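/**
 * @brief Swap in the tensors registered for an execution order; loadExecOnce
 * loads one tensor per call and returns true once the iterator is exhausted.
 */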
void CachePool::loadExec(unsigned int order) {
  for (auto &id : exec_ids[order]) {
    validate(id);
  }
}

void CachePool::loadTensor(unsigned int id) { validate(id); }

bool CachePool::loadExecOnce(unsigned int order, ExecIdsIter &iter) {
  if (iter == exec_ids[order].end())
    return true;

  validate(*iter);

  iter++;
  return false;
}

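/**
 * @brief Swap out the tensors of an execution order (clearing the active
 * set), or a single tensor by id.
 */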
void CachePool::unloadExec(unsigned int order) {
  for (auto &id : exec_ids[order]) {
    invalidate(id);
  }
  actives.clear();
}

void CachePool::unloadTensor(unsigned int order) {
  invalidate(order);
  std::lock_guard<std::mutex> lock(mutex);
  actives.erase(order);
}

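/**
 * @brief Swap all currently active elements back in (loadActives) or out
 * (unloadActives) without changing the active set.
 */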
void CachePool::loadActives() {
  ml_logd("load active caches");

  for (auto &id : actives) {
    elems[id]->swapIn();
  }
}

void CachePool::unloadActives() {
  ml_logd("unload active caches");
  for (auto &id : actives) {
    elems[id]->swapOut();
  }
}

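/**
 * @brief Switch the swap device to an FSU weight file: the current swap file
 * is removed unless its name starts with "weight_pool", then the device is
 * stopped and restarted so it uses the new weight path.
 */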
void CachePool::setFsuWeightPath(std::string path) {
  auto start_with = [](const std::string &str, const std::string &prefix) {
    return str.size() >= prefix.size() &&
           str.compare(0, prefix.size(), prefix) == 0;
  };

  if (!start_with(swap_device->getDevicePath(), "weight_pool")) {
    remove(swap_device->getDevicePath().c_str());
  }

  swap_device->setFsuWeightPath(path);
  swap_device->finish();
  swap_device->start(size(), execution_mode_);
}

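/**
 * @brief Remove every id from the active set for which the predicate returns
 * true; the predicate may swap the element out as a side effect.
 */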
void CachePool::eraseActiveIf(
  const std::function<bool(unsigned int id)> &pred) {
  for (auto it = actives.begin(); it != actives.end();
       pred(*it) ? it = actives.erase(it) : ++it) {
  }
}

} // namespace nntrainer