Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file iteration_queue.cpp
6 : * @date 13 July 2021
7 : * @brief This file contains thread safe queue
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 : #include <chrono>
14 : #include <iteration_queue.h>
15 :
16 : #include <mutex>
17 : #include <nntrainer_error.h>
18 : #include <shared_mutex>
19 :
20 : using namespace std::literals::chrono_literals;
21 :
22 : namespace nntrainer {
23 :
24 1441 : IterationQueue::IterationQueue(
25 : unsigned int num_slots, const std::vector<ml::train::TensorDim> &input_dims,
26 1441 : const std::vector<ml::train::TensorDim> &label_dims) :
27 1441 : being_filled(nullptr),
28 1441 : num_being_filled(0),
29 1441 : flow_state(IterationQueue::FlowState::FLOW_STATE_OPEN) {
30 1444 : NNTR_THROW_IF(num_slots == 0, std::invalid_argument)
31 : << "number of slots must be more then zero";
32 :
33 1438 : iterations.reserve(num_slots);
34 7134 : for (decltype(num_slots) i = 0; i < num_slots; ++i) {
35 5703 : iterations.emplace_back(input_dims, label_dims, this);
36 5696 : empty_q.push(&iterations.back());
37 : }
38 1434 : batch_size = iterations.front().get().batch();
39 1448 : }
40 :
41 1434 : IterationQueue::~IterationQueue() {
42 1434 : std::scoped_lock lg(empty_mutex, filled_mutex);
43 :
44 : /// if an iteration is not included in either empty_q or filled_q, it is
45 : /// either being filled or being served, which makes it dangerous to destroy
46 : /// @a this. We might want to wait in the destructor once we can ensure it
47 : /// stays noexcept.
48 1434 : if (empty_q.size() + filled_q.size() < iterations.size()) {
49 0 : ml_logw(
50 : "Destroying the iteration queue, while some buffers are being used");
51 : }
52 1434 : }
53 :
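 : /// The intended flow, as implemented in this file, is roughly:
 : ///   - producer threads repeatedly call requestEmptySlot() and fill one
 : ///     Sample through the returned view; releasing the view normally calls
 : ///     markSampleFilled(), and a fully observed iteration moves to filled_q.
 : ///   - the producer side calls notifyEndOfRequestEmpty() at the end of an
 : ///     epoch, which flushes a partially filled iteration and pushes a
 : ///     nullptr sentinel into filled_q.
 : ///   - the consumer calls requestFilledSlot() until it receives a view
 : ///     wrapping nullptr, which indicates the queue has stopped.
 : /// How callers access data through ScopedView is defined elsewhere and is
 : /// not shown here.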
54 107552 : ScopedView<Sample> IterationQueue::requestEmptySlot() {
55 107552 : std::scoped_lock lg(empty_mutex);
56 : auto current_flow_state = flow_state.load();
57 107560 : NNTR_THROW_IF(current_flow_state != FlowState::FLOW_STATE_OPEN,
58 : std::invalid_argument)
59 : << "the queue expect state of "
60 : << static_cast<unsigned>(FlowState::FLOW_STATE_OPEN) << " but met "
61 : << static_cast<unsigned>(current_flow_state);
62 :
63 : /// below is useful information when debugging the iteration queue, but it
64 : /// would produce too much log output if enabled, so it is left as a comment.
65 : // std::cout << "[requestEmptySlot] empty_q.size(): " << empty_q.size()
66 : // << " being_filled: " << num_being_filled
67 : // << " filled_q.size(): " << filled_q.size() << '\n';
68 :
69 107544 : if (being_filled == nullptr ||
70 : current_iterator + 1 == being_filled->get().end()) {
71 12135 : being_filled = empty_q.waitAndPop();
72 12135 : being_filled->reset();
73 12135 : num_being_filled++;
74 12135 : current_iterator = being_filled->get().begin();
75 : } else {
76 : current_iterator++;
77 : }
78 :
79 : auto view = ScopedView<Sample>(
80 : &(*current_iterator),
81 : [current_being_filled = this->being_filled] {
82 107540 : current_being_filled->markSampleFilled();
83 : },
84 107544 : [this, current_being_filled = this->being_filled] {
85 4 : std::unique_lock lg(empty_mutex);
86 4 : this->markEmpty(current_being_filled);
87 4 : num_being_filled--;
88 4 : notify_emptied_cv.notify_all();
89 107548 : });
90 107544 : return view;
91 : }
92 :
93 13529 : ScopedView<Iteration> IterationQueue::requestFilledSlot() {
94 13529 : std::scoped_lock lock(filled_mutex);
95 :
96 : /// below is useful information when debugging the iteration queue, but it
97 : /// would produce too much log output if enabled, so it is left as a comment.
98 : // std::cout << "[requestFilledSlot] empty_q.size(): " << empty_q.size()
99 : // << " num being filled: " << num_being_filled
100 : // << " filled_q.size(): " << filled_q.size() << '\n';
101 13529 : if (flow_state.load() == FlowState::FLOW_STATE_STOPPED) {
102 : return ScopedView<Iteration>(nullptr);
103 : }
104 :
105 13525 : auto iteration = filled_q.waitAndPop();
106 13525 : if (iteration == nullptr) {
107 : auto stop_request_state = FlowState::FLOW_STATE_STOP_REQUESTED;
108 : bool exchange_result = flow_state.compare_exchange_strong(
109 : stop_request_state, FlowState::FLOW_STATE_STOPPED);
110 1395 : NNTR_THROW_IF(!exchange_result, std::runtime_error)
111 : << "the queue has either already stopped or running, but trying stopping "
112 : "without requesting stop, queue size: "
113 0 : << iterations.size() << " num currently empty: " << empty_q.size()
114 0 : << " filled_q.size(): " << filled_q.size();
115 :
116 : return ScopedView<Iteration>(nullptr);
117 : }
118 :
119 : return ScopedView<Iteration>(
120 12126 : &iteration->get(), [this, iteration] { markEmpty(iteration); },
121 4 : [this, iteration] {
122 4 : std::unique_lock lock(filled_mutex);
123 4 : flow_state.store(FlowState::FLOW_STATE_STOPPED);
124 4 : markEmpty(iteration);
125 12134 : });
126 : }
127 :
128 1408 : void IterationQueue::notifyEndOfRequestEmpty() {
129 1408 : std::unique_lock lg(empty_mutex);
130 : auto open_state = FlowState::FLOW_STATE_OPEN;
131 :
132 : /// we have to define the ordering stop_requested -> push nullptr to
133 : /// filled_q -> stopped, so that on the transition to stopped a nullptr is
134 : /// pushed to empty_q and filled_q to wake the waiters up and stop them.
135 : /// This has potential cases that weren't considered; let's change this to
136 : /// a simpler mechanism that waits on a condition variable.
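 : /// Concretely, the transitions implemented here are: FLOW_STATE_OPEN ->
 : /// FLOW_STATE_STOP_REQUESTED (in this function), while FLOW_STATE_STOPPED
 : /// is reached in requestFilledSlot(), either by exchanging from
 : /// FLOW_STATE_STOP_REQUESTED when the nullptr sentinel is popped or
 : /// directly when a filled view is cancelled.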
137 : bool exchange_result = flow_state.compare_exchange_strong(
138 : open_state, FlowState::FLOW_STATE_STOP_REQUESTED);
139 1412 : NNTR_THROW_IF(!exchange_result, std::invalid_argument)
140 : << "the queue expect state of " << static_cast<unsigned>(open_state)
141 : << " but met " << static_cast<unsigned>(flow_state.load());
142 : /// below is useful information when debugging the iteration queue, but it
143 : /// would produce too much log output if enabled, so it is left as a comment.
144 : // std::cout << "[notifyEnd] empty_q.size(): " << empty_q.size()
145 : // << " num being filled: " << num_being_filled
146 : // << " filled_q.size(): " << filled_q.size() << '\n';
147 :
148 1404 : if (being_filled) {
149 1396 : being_filled->setEndSample(current_iterator + 1);
150 : }
151 1410 : notify_emptied_cv.wait(lg, [this] { return num_being_filled == 0; });
152 1404 : filled_q.push(nullptr);
153 1404 : }
154 :
155 10867 : void IterationQueue::markFilled(MarkableIteration *iteration) {
156 : {
157 10867 : std::lock_guard lg(empty_mutex);
158 10867 : --num_being_filled;
159 10867 : filled_q.push(iteration);
160 : }
161 10867 : notify_emptied_cv.notify_all();
162 10867 : }
163 :
164 12134 : void IterationQueue::markEmpty(MarkableIteration *iteration) {
165 12134 : empty_q.push(iteration);
166 12134 : }
167 :
168 5700 : IterationQueue::MarkableIteration::MarkableIteration(
169 : const std::vector<ml::train::TensorDim> &input_dims,
170 5700 : const std::vector<ml::train::TensorDim> &label_dims, IterationQueue *iq) :
171 5700 : num_observed(0), iteration(input_dims, label_dims), iq(iq) {}
172 :
173 0 : IterationQueue::MarkableIteration::MarkableIteration(MarkableIteration &&rhs) :
174 0 : iteration(std::move(rhs.iteration)), iq(rhs.iq) {
175 0 : std::lock_guard notify_lock_guard(notify_mutex);
176 0 : num_observed = rhs.num_observed;
177 0 : }
178 :
179 12135 : void IterationQueue::MarkableIteration::reset() {
180 12135 : std::lock_guard notify_lock_guard(notify_mutex);
181 12135 : num_observed = 0;
182 12135 : iteration.setEndSample();
183 12135 : }
184 :
185 : IterationQueue::MarkableIteration &
186 0 : IterationQueue::MarkableIteration::operator=(MarkableIteration &&rhs) {
187 0 : if (this == &rhs) {
188 : return *this;
189 : }
190 0 : std::scoped_lock lock(this->notify_mutex, rhs.notify_mutex);
191 0 : std::swap(iteration, rhs.iteration);
192 : std::swap(iq, rhs.iq);
193 : std::swap(num_observed, rhs.num_observed);
194 : return *this;
195 : }
196 :
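 : /// For example, with an illustrative batch size of 4, the first three
 : /// calls below only increment num_observed; the fourth resets it to 0 and
 : /// hands the iteration to iq->markFilled(), which pushes it to filled_q.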
197 107540 : void IterationQueue::MarkableIteration::markSampleFilled() {
198 107540 : std::unique_lock notify_lock_guard(notify_mutex);
199 107540 : num_observed++;
200 107540 : if (num_observed == iteration.batch()) {
201 10867 : num_observed = 0;
202 10867 : notify_lock_guard.unlock();
203 10867 : iq->markFilled(this);
204 10867 : notify_lock_guard.lock();
205 : }
206 107540 : }
207 :
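 : /// For example (illustrative numbers), if the iteration was created with a
 : /// batch of 32 and the epoch ends after 20 samples, the call below shrinks
 : /// the batch to 20; if those 20 samples have already been observed, the
 : /// partially filled iteration is pushed to filled_q from here rather than
 : /// waiting for another markSampleFilled() call.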
208 1396 : void IterationQueue::MarkableIteration::setEndSample(
209 : std::vector<Sample>::iterator sample_iterator) {
210 1396 : std::scoped_lock notify_lock_guard(notify_mutex);
211 1396 : auto old_batch = iteration.batch();
212 1396 : if (sample_iterator != iteration.end()) {
213 1266 : iteration.setEndSample(sample_iterator);
214 : }
215 1396 : auto new_batch = iteration.batch();
216 : /// if batch has changed, check if this batch is partially filled and should
217 : /// be moved
218 1396 : if (old_batch != new_batch && num_observed == new_batch) {
219 : #if DEBUG
220 : NNTR_THROW_IF_CLEANUP(iq->empty_mutex.try_lock(), std::runtime_error,
221 : [this] { iq->empty_mutex.unlock(); })
222 : << "iteration queue must be locked already but empty_mutex is not "
223 : "locked.";
224 : #endif
225 : /// warning: iq has to be locked with iq->empty_mutex
226 1264 : iq->num_being_filled--;
227 1264 : iq->filled_q.push(this);
228 1264 : iq->notify_emptied_cv.notify_all();
229 1264 : num_observed = 0;
230 : }
231 1396 : }
232 :
233 : } // namespace nntrainer