Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2025 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file task_executor.h
6 : * @date 04 April 2025
7 : * @brief This file contains a task executor
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jijoong Moon <jijoong.moon@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief Task executor class
12 : *
13 : */
14 :
15 : #include "task_executor.h"
16 :
17 : #include <nntrainer_error.h>
18 : #include <nntrainer_log.h>
19 :
20 : namespace nntrainer {
21 :
22 0 : TaskExecutor::TaskExecutor(std::string n, size_t thread_count) :
23 0 : name(n), stop(false) {
24 0 : for (size_t i = 0; i < thread_count; ++i) {
25 0 : workers.emplace_back([this] { this->worker_thread(); });
26 : }
27 0 : }
28 :
29 0 : TaskExecutor::~TaskExecutor() {
30 : {
31 0 : std::unique_lock<std::mutex> lock(queue_mutex);
32 : stop = true;
33 : }
34 :
35 0 : cond_var.notify_all();
36 0 : for (std::thread &t : workers) {
37 0 : if (t.joinable())
38 0 : t.join();
39 : }
40 0 : }
41 :
42 0 : void TaskExecutor::worker_thread() {
43 :
44 : while (true) {
45 : Task task;
46 : {
47 0 : std::unique_lock<std::mutex> lock(queue_mutex);
48 0 : cond_var.wait(lock, [this]() { return stop || !task_queue.empty(); });
49 :
50 0 : if (stop && task_queue.empty()) {
51 0 : return;
52 : }
53 :
54 0 : task = std::move(task_queue.front());
55 : task_queue.pop();
56 0 : task_started[task.id] = true;
57 0 : task_started_cv.notify_all();
58 :
59 : // we are not going to remove the Done Tasks.
60 : // we exeplicitly call release tasks. until then, we keep the results and
61 : // not going to submit that task again
62 : // queued_ids.erase(task.id);
63 : }
64 :
65 : try {
66 0 : task.callback(task.data);
67 0 : task.promise.set_value();
68 0 : } catch (...) {
69 0 : ml_loge("[%s] : [Error ] Task ID %d threw an exception\n", name.c_str(),
70 : task.id);
71 0 : }
72 0 : }
73 : }
74 :
75 0 : int TaskExecutor::submit(TaskCallback cb, void *data) {
76 :
77 0 : auto canceled = std::make_shared<std::atomic_bool>(false);
78 : auto promise = std::make_shared<std::promise<void>>();
79 0 : std::shared_future<void> fut = promise->get_future().share();
80 0 : int id = getNextTaskId();
81 :
82 : {
83 0 : std::lock_guard<std::mutex> lock(queue_mutex);
84 :
85 : if (future_map.count(id)) {
86 0 : if (!future_map[id].valid()) {
87 0 : ml_loge("[%s] : [Error] Future is not valid : Task id - %d\n",
88 : name.c_str(), id);
89 : }
90 0 : auto status = future_map[id].wait_for(std::chrono::seconds(0));
91 0 : if (status != std::future_status::ready) {
92 0 : ml_logi("[%s] : Task id - %d is still active\n", name.c_str(), id);
93 0 : return id;
94 : }
95 : }
96 :
97 0 : Task task{id, std::move(*promise), cb, data};
98 :
99 0 : future_map[id] = fut;
100 :
101 : task_queue.push(std::move(task));
102 0 : }
103 0 : cond_var.notify_one();
104 0 : return id;
105 : }
106 :
107 0 : void TaskExecutor::submitTasks(const std::vector<TaskDesc> &tasks) {
108 0 : for (const auto &task : tasks) {
109 0 : submit(task.callback, task.data);
110 : }
111 0 : }
112 :
113 0 : bool TaskExecutor::cancel(int id) {
114 0 : std::lock_guard<std::mutex> lock(queue_mutex);
115 : auto it = cancel_map.find(id);
116 0 : if (it != cancel_map.end()) {
117 : *(it->second) = true;
118 0 : return true;
119 : }
120 : return false;
121 : }
122 :
123 0 : void TaskExecutor::wait(int id) {
124 0 : std::shared_future<void> fut;
125 : {
126 0 : std::unique_lock<std::mutex> lock(queue_mutex);
127 :
128 0 : task_started_cv.wait(
129 0 : lock, [&] { return task_started.count(id) && task_started[id]; });
130 :
131 : auto it = future_map.find(id);
132 0 : if (it == future_map.end() || !it->second.valid()) {
133 : return;
134 : }
135 0 : fut = it->second;
136 : }
137 : try {
138 0 : fut.wait();
139 0 : } catch (const std::future_error &e) {
140 0 : ml_loge("[%s] : exception while waiting on future : %s\n", name.c_str(),
141 : e.what());
142 0 : }
143 : }
144 :
145 0 : void TaskExecutor::waitAll(const std::vector<int> &ids) {
146 : std::vector<std::shared_future<void>> futures;
147 : {
148 0 : std::lock_guard<std::mutex> lock(queue_mutex);
149 0 : for (int id : ids) {
150 : auto it = future_map.find(id);
151 0 : if (it != future_map.end()) {
152 0 : futures.push_back(it->second);
153 : } else {
154 0 : ml_logw("[%s] : Task ID is not found : %d\n", name.c_str(), id);
155 : }
156 : }
157 : }
158 :
159 0 : for (auto &fut : futures) {
160 : try {
161 0 : fut.wait();
162 0 : } catch (const std::exception &e) {
163 0 : ml_loge("[%s] : exception while waiting on future : %s\n", name.c_str(),
164 : e.what());
165 0 : }
166 : }
167 0 : }
168 :
169 0 : bool TaskExecutor::isDone(int id) {
170 0 : std::lock_guard<std::mutex> lock(queue_mutex);
171 : auto it = future_map.find(id);
172 0 : if (it == future_map.end())
173 : return false;
174 0 : return it->second.wait_for(std::chrono::seconds(0)) ==
175 0 : std::future_status::ready;
176 : }
177 :
178 0 : bool TaskExecutor::isAllDone(const std::vector<int> &ids) {
179 0 : std::lock_guard<std::mutex> lock(queue_mutex);
180 0 : for (int id : ids) {
181 0 : isDone(id);
182 : }
183 0 : return true;
184 : }
185 :
186 0 : void TaskExecutor::releaseTask(int id) {
187 0 : std::lock_guard<std::mutex> lock(queue_mutex);
188 : future_map.erase(id);
189 : cancel_map.erase(id);
190 : reusable_ids.push(id);
191 0 : }
192 :
193 : } // namespace nntrainer
|