Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2022 Jiho Chu <jiho.chu@samsung.com>
4 : *
5 : * @file cache_loader.cpp
6 : * @date 10 Nov 2022
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jiho Chu <jiho.chu@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief Cache loader class
11 : *
12 : */
13 :
14 : #include "cache_loader.h"
15 : #include "task.h"
16 : #include "task_executor.h"
17 :
18 : #include <cache_pool.h>
19 : #include <climits>
20 : #include <cstdint>
21 : #include <exception>
22 : #include <memory>
23 : #include <nntrainer_error.h>
24 : #include <nntrainer_log.h>
25 :
26 : namespace nntrainer {
27 :
28 0 : CacheLoader::CacheLoader(std::shared_ptr<CachePool> cache_pool) :
29 : pool(cache_pool),
30 0 : load_task_executor(nullptr),
31 0 : unload_task_executor(nullptr) {}
32 :
33 0 : CacheLoader::~CacheLoader() {
34 0 : if (load_task_executor)
35 0 : delete load_task_executor;
36 0 : if (unload_task_executor)
37 0 : delete unload_task_executor;
38 0 : }
39 :
40 0 : void CacheLoader::init() {
41 0 : if (load_task_executor == nullptr)
42 0 : load_task_executor = new TaskExecutor("loadPool", 2);
43 0 : if (unload_task_executor == nullptr)
44 0 : unload_task_executor = new TaskExecutor("UnloadPool", 2);
45 0 : }
46 :
47 0 : void CacheLoader::finish() {
48 0 : delete load_task_executor;
49 0 : load_task_executor = nullptr;
50 0 : delete unload_task_executor;
51 0 : unload_task_executor = nullptr;
52 0 : }
53 :
54 0 : void CacheLoader::load(unsigned int order) { loadAllinOrder(order); }
55 :
56 0 : bool CacheLoader::loadAllinOrder(unsigned int order) {
57 0 : if (!load_task_executor) {
58 0 : ml_loge("init is needed");
59 0 : return false;
60 : }
61 :
62 0 : std::set<unsigned int> exec_id = pool->getExecIDs(order);
63 :
64 0 : for (auto &id : exec_id) {
65 0 : loadTensor(id);
66 : }
67 :
68 : return true;
69 : }
70 :
71 0 : int CacheLoader::loadTensor(unsigned int id) {
72 0 : if (!load_task_executor) {
73 0 : ml_loge("init is needed");
74 0 : return ML_ERROR_INVALID_PARAMETER;
75 : }
76 0 : checkUnloadComplete(id);
77 :
78 0 : std::lock_guard<std::mutex> lock(state_mutex);
79 :
80 0 : if (states[id] == LoadState::Loading || states[id] == LoadState::Loaded)
81 0 : return -1;
82 :
83 0 : states[id] = LoadState::Loading;
84 :
85 0 : int load_task_id = load_task_executor->submit(
86 0 : [this, id](void *data) {
87 0 : pool->loadTensor(id);
88 0 : std::lock_guard<std::mutex> lock(this->state_mutex);
89 0 : this->states[id] = LoadState::Loaded;
90 0 : },
91 0 : (void *)(std::uintptr_t)id);
92 :
93 0 : pool->getCacheElem(id).setLoadTaskID(load_task_id);
94 :
95 0 : return load_task_id;
96 : }
97 :
98 0 : bool CacheLoader::unloadAllinOrder(unsigned int order) {
99 0 : if (!load_task_executor) {
100 0 : ml_loge("init is needed");
101 0 : return false;
102 : }
103 :
104 0 : std::set<unsigned int> exec_id = pool->getExecIDs(order);
105 :
106 0 : for (auto &id : exec_id) {
107 0 : unloadTensor(id);
108 : }
109 :
110 : return true;
111 : }
112 :
113 0 : int CacheLoader::unloadTensor(unsigned int id) {
114 0 : if (!load_task_executor) {
115 0 : ml_loge("init is needed");
116 0 : return ML_ERROR_INVALID_PARAMETER;
117 : }
118 :
119 0 : checkLoadComplete(id);
120 :
121 0 : std::lock_guard<std::mutex> lock(state_mutex);
122 :
123 0 : if (states[id] != LoadState::Loaded)
124 : return -1;
125 :
126 0 : states[id] = LoadState::Unloading;
127 :
128 0 : int unload_task_id = load_task_executor->submit(
129 0 : [this, id](void *data) {
130 0 : pool->unloadTensor(id);
131 0 : std::lock_guard<std::mutex> lock(this->state_mutex);
132 0 : this->states[id] = LoadState::Idle;
133 0 : },
134 0 : (void *)(std::uintptr_t)id);
135 :
136 0 : pool->getCacheElem(id).setUnloadTaskID(unload_task_id);
137 0 : return unload_task_id;
138 : }
139 :
140 0 : LoadState CacheLoader::getState(int id) const {
141 0 : std::lock_guard<std::mutex> lock(state_mutex);
142 : auto it = states.find(id);
143 0 : if (it == states.end())
144 : return LoadState::Idle;
145 0 : return it->second;
146 : }
147 :
148 0 : int CacheLoader::flushAsync(unsigned int order,
149 : TaskExecutor::CompleteCallback complete) {
150 0 : return flushAsync(order, complete, LONG_MAX);
151 : }
152 :
153 0 : int CacheLoader::flushAsync(unsigned int order,
154 : TaskExecutor::CompleteCallback complete,
155 : long timeout_ms) {
156 0 : if (!unload_task_executor) {
157 0 : ml_loge("init is needed");
158 0 : return ML_ERROR_INVALID_PARAMETER;
159 : }
160 :
161 0 : std::set<unsigned int> exec_id = pool->getExecIDs(order);
162 :
163 0 : for (auto &id : exec_id) {
164 0 : unloadTensor(id);
165 : }
166 :
167 : return 0;
168 : }
169 0 : void CacheLoader::flush() {
170 : auto actives = pool->getActiveElems();
171 :
172 0 : for (auto &id : actives) {
173 0 : unloadTensor(id);
174 : }
175 :
176 0 : for (auto &id : actives) {
177 0 : checkUnloadComplete(id);
178 : }
179 :
180 0 : pool->flush();
181 0 : }
182 :
183 0 : int CacheLoader::cancelAsync(int id) {
184 : try {
185 0 : load_task_executor->cancel(id);
186 0 : } catch (const std::exception &e) {
187 0 : ml_loge("CacheLoader(%s): failed to cancel(%d): %s",
188 : pool->getName().c_str(), id, e.what());
189 : return ML_ERROR_UNKNOWN;
190 0 : }
191 :
192 : return ML_ERROR_NONE;
193 : }
194 :
195 0 : unsigned int CacheLoader::inActive(unsigned int order) {
196 0 : std::set<unsigned int> exec_id = pool->getExecIDs(order);
197 0 : for (auto &id : exec_id) {
198 0 : auto &elem = pool->getCacheElem(id);
199 : int load_task_id = elem.getLoadTaskID();
200 0 : if (load_task_id >= 0) {
201 0 : load_task_executor->releaseTask(load_task_id);
202 : elem.setLoadTaskID(-1);
203 0 : states[id] = LoadState::Unloading;
204 : }
205 0 : pool->inActive(id);
206 : }
207 0 : return 0;
208 : }
209 :
210 0 : bool CacheLoader::checkAllLoadComplete(unsigned int order) {
211 :
212 0 : std::set<unsigned int> exec_id = pool->getExecIDs(order);
213 :
214 0 : for (auto &id : exec_id) {
215 0 : checkLoadComplete(id);
216 : }
217 0 : return true;
218 : }
219 :
220 0 : bool CacheLoader::checkAllUnloadComplete(unsigned int order) {
221 :
222 0 : std::set<unsigned int> exec_id = pool->getExecIDs(order);
223 :
224 0 : for (auto &id : exec_id) {
225 0 : checkUnloadComplete(id);
226 : }
227 0 : return true;
228 : }
229 :
230 0 : bool CacheLoader::checkLoadComplete(unsigned int id) {
231 0 : auto &elem = pool->getCacheElem(id);
232 : int unload_task_id = elem.getUnloadTaskID();
233 : int load_task_id = elem.getLoadTaskID();
234 0 : if (unload_task_id >= 0) {
235 0 : load_task_executor->releaseTask(unload_task_id);
236 : elem.setUnloadTaskID(-1);
237 : }
238 :
239 0 : if (load_task_id >= 0) {
240 0 : load_task_executor->wait(load_task_id);
241 : }
242 :
243 0 : return true;
244 : }
245 :
246 0 : bool CacheLoader::checkUnloadComplete(unsigned int id) {
247 0 : auto &elem = pool->getCacheElem(id);
248 : int unload_task_id = elem.getUnloadTaskID();
249 : int load_task_id = elem.getLoadTaskID();
250 0 : if (load_task_id >= 0) {
251 0 : load_task_executor->releaseTask(load_task_id);
252 : elem.setLoadTaskID(-1);
253 : }
254 0 : if (unload_task_id >= 0) {
255 0 : load_task_executor->wait(unload_task_id);
256 : }
257 0 : return true;
258 : }
259 :
260 : } // namespace nntrainer
|