Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2025 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file task_executor.h
6 : * @date 04 April 2025
7 : * @brief This file contains a task executor
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jijoong Moon <jijoong.moon@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : * @brief Task executor class
12 : *
13 : */
14 :
15 : #ifndef __TASK_EXECUTOR_H__
16 : #define __TASK_EXECUTOR_H__
17 :
18 : #include <atomic>
19 : #include <condition_variable>
20 : #include <functional>
21 : #include <future>
22 : #include <iostream>
23 : #include <map>
24 : #include <mutex>
25 : #include <queue>
26 : #include <task.h>
27 : #include <thread>
28 : #include <unordered_set>
29 : #include <vector>
30 :
31 : namespace nntrainer {
32 :
33 : /**
34 : * @brief This is call back for load/unload Task
35 : *
36 : */
37 : using TaskCallback = std::function<void(void *)>;
38 :
39 : /**
40 : * @class TaskExecutor Class
41 : * @brief This is load / unload Task Executor with thread pool
42 : *
43 : */
44 : class TaskExecutor {
45 : public:
46 : /**
47 : * @enum Temperal Enum for CompeleteStatus
48 : *
49 : */
50 : enum CompleteStatus {
51 : SUCCESS,
52 : FAIL_CANCEL,
53 : FAIL_TIMEOUT,
54 : FAIL,
55 : };
56 :
57 : /**
58 : * @struct To describe Task
59 : * @brief THis is Task Describe struct
60 : */
61 : struct TaskDesc {
62 : int id;
63 : TaskCallback callback;
64 : void *data;
65 : };
66 :
67 : /**
68 : * @enum Temperal definition for callback
69 : *
70 : */
71 : using CompleteCallback =
72 : std::function<void(int, CompleteStatus,
73 : std::future<CompleteStatus>)>; /**< (task id, status) */
74 :
75 : template <typename T = std::chrono::milliseconds>
76 : using TaskInfo =
77 : std::tuple<int, std::shared_ptr<TaskAsync<T>>, CompleteCallback,
78 : std::atomic_bool, std::promise<CompleteStatus>>;
79 : /**< (task id, task, complete callback, running, complete promise) */
80 :
81 : /**
82 : * @brief Constructor of TaskExecutor
83 : *
84 : */
85 : TaskExecutor(std::string name = "",
86 : size_t thread_count = std::thread::hardware_concurrency());
87 :
88 : /**
89 : * @brief Destructor of TaskExecutor
90 : *
91 : */
92 : ~TaskExecutor();
93 :
94 : /**
95 : * @brief submit Task
96 : *
97 : */
98 : int submit(TaskCallback cb, void *data = nullptr);
99 :
100 : /**
101 : * @brief Cancel Task
102 : *
103 : */
104 : bool cancel(int id);
105 :
106 : /**
107 : * @brief Wait to complete
108 : *
109 : */
110 : void wait(int id);
111 :
112 : /**
113 : * @brief Wait to complete Tasks in vectors
114 : *
115 : */
116 : void waitAll(const std::vector<int> &ids);
117 :
118 : /**
119 : * @brief check done of task id
120 : *
121 : */
122 : bool isDone(int id);
123 :
124 : /**
125 : * @brief check done all the tasks in vector
126 : *
127 : */
128 : bool isAllDone(const std::vector<int> &ids);
129 :
130 : /**
131 : * @brief Submit mutiple tasks
132 : *
133 : */
134 : void submitTasks(const std::vector<TaskDesc> &tasks);
135 :
136 : /**
137 : * @brief release Task
138 : *
139 : */
140 : void releaseTask(int id);
141 :
142 : private:
143 : /**
144 : * @brief Definition of Task
145 : *
146 : */
147 0 : struct Task {
148 : int id;
149 : std::promise<void> promise;
150 : TaskCallback callback;
151 : void *data = nullptr;
152 : };
153 :
154 : /**
155 : * @brief Create Worker Thread
156 : *
157 : */
158 : void worker_thread();
159 :
160 : /**
161 : * @brief Get Next Task Id for protect the overflow
162 : *
163 : */
164 0 : int getNextTaskId() {
165 0 : if (!reusable_ids.empty()) {
166 0 : int id = reusable_ids.front();
167 : reusable_ids.pop();
168 0 : return id;
169 : }
170 0 : return next_task_id.fetch_add(1);
171 : }
172 :
173 : std::string name;
174 : std::vector<std::thread> workers;
175 : std::queue<Task> task_queue;
176 : std::map<int, std::shared_ptr<std::atomic_bool>> cancel_map;
177 : std::map<int, std::shared_future<void>> future_map;
178 : std::map<int, bool> task_started;
179 : std::mutex queue_mutex;
180 : std::condition_variable cond_var;
181 : std::condition_variable task_started_cv;
182 : std::atomic<bool> stop;
183 : std::unordered_set<int> queued_ids;
184 : std::queue<int> reusable_ids;
185 : std::atomic<int> next_task_id{0};
186 : };
187 :
188 : } // namespace nntrainer
189 :
190 : #endif /** __TASK_EXECUTOR_H__ */
|