Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file iteration_queue.h
6 : * @date 13 July 2021
7 : * @brief This file contains thread safe queue for a single iteration
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 : #ifndef __ITERATION_QUEUE_H__
14 : #define __ITERATION_QUEUE_H__
15 :
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <shared_mutex>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>
24 :
25 : #include <data_iteration.h>
26 : #include <data_producer.h>
27 : #include <nntrainer_log.h>
28 : #include <tensor.h>
29 : #include <tensor_dim.h>
30 :
31 : namespace nntrainer {
32 :
/**
 * @brief Thread safe FIFO queue dedicated to non-owning pointers
 * @note The queue only stores the pointers handed to push(); it never
 * allocates, copies or deletes the pointed-to objects.
 *
 * @tparam T type of the viewed object (T * is what gets pushed and popped)
 */
template <typename T> class ViewQueue {
public:
  /**
   * @brief Construct a new, empty queue
   */
  ViewQueue() : q() {}

  /**
   * @brief append a pointer to the back of the queue and wake one waiter
   *
   * @param data pointer to enqueue (stored as is, not dereferenced here)
   */
  void push(T *data) {
    {
      // keep the exclusive lock only for the container update so the woken
      // consumer does not immediately block on the mutex
      std::unique_lock<std::shared_mutex> guard(q_mutex);
      q.push(data);
    }

    q_cv.notify_one();
  }

  /**
   * @brief remove and return the oldest pointer, blocking while the queue is
   * empty
   * @note this call only returns once an element is available; the returned
   * value is exactly the pointer that was pushed, so nullptr is returned
   * only if nullptr was pushed
   * @return T* the pointer at the front of the queue
   */
  T *waitAndPop() {
    std::unique_lock<std::shared_mutex> guard(q_mutex);
    // loop to be robust against spurious wake ups
    while (q.empty()) {
      q_cv.wait(guard);
    }

    T *head = q.front();
    q.pop();

    return head;
  }

  /**
   * @brief check if the queue currently holds no element
   *
   * @return bool true if empty
   */
  bool isEmpty() const {
    std::shared_lock<std::shared_mutex> reader(q_mutex);
    return q.empty();
  }

  /**
   * @brief get the number of elements currently held by the queue
   *
   * @return size_type number of queued pointers
   */
  typename std::queue<T *>::size_type size() const {
    std::shared_lock<std::shared_mutex> reader(q_mutex);
    return q.size();
  }

private:
  mutable std::shared_mutex q_mutex; /**< guards every access to q */
  std::condition_variable_any q_cv;  /**< signalled on every push */

  std::queue<T *> q; /**< underlying container of non-owning pointers */
};
100 :
101 : /**
102 : * @brief A view container that calls a callback on destruct
103 : * @note the callback must be noexcept, and the given underlying data must
104 : * outlive the lifetime of this class
105 : *
106 : * @tparam T underlying type
107 : */
108 : template <typename T> class ScopedView {
109 : public:
110 : /**
111 : * @brief Construct a new Scoped View object
112 : *
113 : * @param data_ reference of the underlying data
114 : * @param on_notify_ callback to be called on exit
115 : * @param on_error_ callback to be called on error
116 : */
117 : ScopedView(T *data_, std::function<void(void)> &&on_notify_ = nullptr,
118 : std::function<void(void)> &&on_error_ = nullptr) :
119 108943 : data(data_),
120 : on_notify(std::forward<std::function<void(void)>>(on_notify_)),
121 : on_error(std::forward<std::function<void(void)>>(on_error_)) {}
122 :
123 : ScopedView(const ScopedView &rhs) = delete;
124 : ScopedView &operator=(const ScopedView &rhs) = delete;
125 :
126 19 : ScopedView(ScopedView &&rhs) = default;
127 : ScopedView &operator=(ScopedView &&rhs) = default;
128 :
129 : /**
130 : * @brief check if scoped view contains any underlying data
131 : *
132 : * @return bool if data is empty
133 : */
134 121057 : bool isEmpty() { return data == nullptr; }
135 :
136 : /**
137 : * @brief Destroy the Scoped View object, callback is called at this time
138 : *
139 : */
140 121092 : ~ScopedView() {
141 : try {
142 121092 : if (std::uncaught_exceptions()) {
143 8 : if (on_error) {
144 : on_error();
145 : }
146 : } else {
147 121084 : if (on_notify) {
148 : on_notify();
149 : }
150 : }
151 0 : } catch (...) {
152 0 : ml_loge("while handling on_notify or on_error, error happened");
153 : }
154 121092 : }
155 :
156 : /**
157 : * @brief get the underlying data
158 : *
159 : * @return T & reference to the underlying data
160 : */
161 4 : T &get() { return *data; }
162 :
163 : /**
164 : * @brief get the underlying data
165 : *
166 : * @return T & reference to the underlying data
167 : */
168 : T const &get() const { return *data; }
169 :
170 : private:
171 : T *data; /**< underlying data pointer */
172 : std::function<void(void)>
173 : on_notify; /**< called when destroyed without error */
174 : std::function<void(void)> on_error; /**< called when destroyed with error */
175 : };
176 :
177 : /**
178 : * @brief Iteration queue that owns the buffer for input / labels
179 : * @details
180 : *
181 : * - requestEmptySlot() will give a ScopedView<sample>
182 : * Destructing the returned object will notify the iteration that is done
183 : * filling the sample. Once iteration is done filling, it will internally call
184 : * IterationQueue::markFilled();
185 : * - requestFilledSlot() will give a ScopedView<Iteration>
186 : * Destructing this will notify the queue that is done used (internally
187 : * calls IterationQueue::markEmpty())
188 : *
189 : * @details For an iteration there can be four state.
190 : * 1. The buffer is empty, waiting to be filled (will be in empty_q)
191 : * 2. The buffer is being filled sample by sample, waiting to be marked as
192 : * filled.
193 : * 3. The buffer is filled, waiting to be served (will be in filled_q)
194 : * 4. The buffer is being served, waiting to be marked as emptied.
195 : * @todo apply this to the databuffer
196 : * @todo handle error case: 1. when ScopedView<Sample> has met throw
197 : * 2. when ScopedView<Iteration> has met throw
198 : */
199 : class IterationQueue {
200 : public:
201 : /**
202 : * @brief Construct a new Iteration Queue object
203 : * @note input_dimension and label_dimension should include the batch, if
204 : * IterationQueue::batch() is zero, it means it's invalid
205 : * @param num_slots number of slots this iteration queue will allocate, it
206 : * should be buffersize/batchsize
207 : * @param input_dims input dimensions
208 : * @param label_dims label dimensions
209 : */
210 : IterationQueue(unsigned int num_slots,
211 : const std::vector<ml::train::TensorDim> &input_dims,
212 : const std::vector<ml::train::TensorDim> &label_dims);
213 :
214 : /**
215 : * @brief Destroy the Iteration Queue object
216 : *
217 : */
218 : ~IterationQueue();
219 :
220 : /**
221 : * @brief request empty sample from the queue.
222 : * @note User must check if ScopedView actually has a value by calling
223 : * ScopedView::isEmpty()
224 : * @return ScopedView<Sample> sample view. ScopedView::isEmpty() == true
225 : * if there is no more data coming. Destroying the returned object will
226 : * signal the queue that the sample is filled.
227 : */
228 : ScopedView<Sample> requestEmptySlot();
229 :
230 : /**
231 : * @brief request filled iteration from the queue.
232 : * @note User must check if ScopedView actually has a value by calling
233 : * ScopedView::isEmpty()
234 : * @return ScopedView<Iteration> Ieration view. ScopedView::isEmpty() == true
235 : * if there is no more data coming. Destroying the returned object will
236 : * signal the queue that the sample is done using.
237 : *
238 : */
239 : ScopedView<Iteration> requestFilledSlot();
240 :
241 : /**
242 : * @brief get slot size, slot size is number of batches inside the queue
243 : *
244 : * @return unsigned int num slot
245 : */
246 47 : unsigned int slots() { return iterations.size(); }
247 :
248 : /**
249 : * @brief get size of batch for one iteration
250 : *
251 : * @return unsigned int size of batch
252 : */
253 3128 : unsigned int batch() { return batch_size; }
254 :
255 : /**
256 : * @brief notifyEndOfRequest, when the producing by requestEmptySlot has
257 : * finished.
258 : * @note It is important that the owner of this class must ensure that there
259 : * will be no more requestEmptySlot call after this. This means that, in case
260 : * of multiple workers, the manager of the worker(producer) must know every
261 : * producer has finished. and call this function other than each worker call
262 : * this function.
263 : *
264 : */
265 : void notifyEndOfRequestEmpty();
266 :
267 : private:
268 : /**
269 : * @brief A wrapper object around @c Iteration which marks filled when filling
270 : * sample is done
271 : * @note the given @a iteration_ and @a bq_ must outleave the lifetime of this
272 : * class
273 : *
274 : */
275 5696 : class MarkableIteration {
276 : public:
277 : /**
278 : * @brief Construct a new Markable Iteration object
279 : *
280 : * @param input_dims input dimensions
281 : * @param label_dims label dimensions
282 : * @param iq_ iteration queue view to notify
283 : */
284 : MarkableIteration(const std::vector<ml::train::TensorDim> &input_dims,
285 : const std::vector<ml::train::TensorDim> &label_dims,
286 : IterationQueue *iq);
287 :
288 : /**
289 : * @brief reset num observation and internal batch size of iteration
290 : *
291 : */
292 : void reset();
293 :
294 : /**
295 : * @brief Construct a new Markable Iteration object
296 : *
297 : * @param rhs right side to move
298 : */
299 : MarkableIteration(MarkableIteration &&rhs);
300 :
301 : /**
302 : * @brief Move Assignement operator
303 : *
304 : * @param rhs rhs to move
305 : * @return MarkableIteration& markable iteration
306 : */
307 : MarkableIteration &operator=(MarkableIteration &&rhs);
308 :
309 : /**
310 : * @brief mark iteration that one sample is filled
311 : * @todo make this function noexcept
312 : */
313 : void markSampleFilled() /** noexcept */;
314 :
315 : /**
316 : * @brief update end sample to the given iterator and mark last
317 : * @note after updating end iterator, this can be markFilled() if every
318 : * sample is already filled
319 : *
320 : * @param iterator non-inclusive iterator to mark the last
321 : */
322 : void setEndSample(std::vector<Sample>::iterator sample_iterator);
323 :
324 : /**
325 : * @brief get underlying iteration
326 : *
327 : * @return Iteration& iteration
328 : */
329 13564 : Iteration &get() { return iteration; }
330 :
331 : private:
332 : unsigned int num_observed; /**< number of observed samples which were passed
333 : to the callee and notified done filling */
334 : mutable std::mutex
335 : notify_mutex; /**< mutex which should be locked when try to notify */
336 : Iteration iteration; /**< underlying iteration that this class owns */
337 : IterationQueue *iq; /**< view of iteration queue */
338 : };
339 :
340 : /**
341 : * @brief Queue running state enum
342 : *
343 : */
344 : enum class FlowState {
345 : FLOW_STATE_OPEN = 0, /**< nothing */
346 : FLOW_STATE_STOP_REQUESTED = 1, /**< request stop */
347 : FLOW_STATE_STOPPED = 2, /**< queue is fully stopped */
348 : };
349 :
350 : /**
351 : * @brief mark the given iteration filled
352 : * @todo make this noexcept with the thread safe queue
353 : * @param iteration iteration to mark it as filled
354 : */
355 : void markFilled(MarkableIteration *iteration) /** noexcept */;
356 :
357 : /**
358 : * @brief mark the given iteration empty
359 : * @todo make this noexcept with the thread safe queue
360 : * @param iteration iteration to mark it as emptied
361 : */
362 : void markEmpty(MarkableIteration *iteration) /** noexcept */;
363 :
364 : std::vector<MarkableIteration> iterations; /**< allocated iterations */
365 : MarkableIteration *being_filled; /**< last iteration that is being filled */
366 : std::vector<Sample>::iterator
367 : current_iterator; /**< current sample iteration of being_filled */
368 :
369 : mutable std::mutex empty_mutex; /**< mutex to be used when it is mutually
370 : exclusive to the requesting empty slots */
371 : unsigned int
372 : num_being_filled; /**< number of iteration that is in being_filled state */
373 : mutable std::mutex
374 : filled_mutex; /**< mutex to be used when it is mutually exclusive to the
375 : requesting filled slots */
376 : std::condition_variable_any
377 : notify_emptied_cv; /**< conditional variable to wait based on the
378 : num_being_filled */
379 : std::atomic<FlowState> flow_state; /**< flow state of the queue */
380 :
381 : unsigned int batch_size;
382 : ViewQueue<MarkableIteration> empty_q; /**< iterations to be filled */
383 : ViewQueue<MarkableIteration> filled_q; /**< iterations to be served */
384 : };
385 :
386 : } // namespace nntrainer
387 :
388 : #endif // __ITERATION_QUEUE_H__
|