LCOV - code coverage report
Current view: top level - nntrainer/dataset - iteration_queue.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 88.2 % 102 90
Test Date: 2025-12-14 20:38:17 Functions: 86.7 % 15 13

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file   iteration_queue.cpp
       6              :  * @date   13 July 2021
       7              :  * @brief  This file contains thread safe queue
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : #include <chrono>
      14              : #include <iteration_queue.h>
      15              : 
      16              : #include <mutex>
      17              : #include <nntrainer_error.h>
      18              : #include <shared_mutex>
      19              : 
      20              : using namespace std::literals::chrono_literals;
      21              : 
      22              : namespace nntrainer {
      23              : 
      24         1441 : IterationQueue::IterationQueue(
      25              :   unsigned int num_slots, const std::vector<ml::train::TensorDim> &input_dims,
      26         1441 :   const std::vector<ml::train::TensorDim> &label_dims) :
      27         1441 :   being_filled(nullptr),
      28         1441 :   num_being_filled(0),
      29         1441 :   flow_state(IterationQueue::FlowState::FLOW_STATE_OPEN) {
      30         1444 :   NNTR_THROW_IF(num_slots == 0, std::invalid_argument)
      31              :     << "number of slots must be more then zero";
      32              : 
      33         1438 :   iterations.reserve(num_slots);
      34         7134 :   for (decltype(num_slots) i = 0; i < num_slots; ++i) {
      35         5703 :     iterations.emplace_back(input_dims, label_dims, this);
      36         5696 :     empty_q.push(&iterations.back());
      37              :   }
      38         1434 :   batch_size = iterations.front().get().batch();
      39         1448 : }
      40              : 
      41         1434 : IterationQueue::~IterationQueue() {
      42         1434 :   std::scoped_lock lg(empty_mutex, filled_mutex);
      43              : 
      44              :   /// if an iteration is not included in either empty_q or filled_q, that
      45              :   /// means it's either being filled or being served. Which means it will be
      46              :   /// dangerous to destroy @a this, we might want to wait on the destructor if
      47              :   /// we can assure this can stay no except
      48         1434 :   if (empty_q.size() + filled_q.size() < iterations.size()) {
      49            0 :     ml_logw(
      50              :       "Destroying the iteration queue, while some buffers are being used");
      51              :   }
      52         1434 : }
      53              : 
      54       107552 : ScopedView<Sample> IterationQueue::requestEmptySlot() {
      55       107552 :   std::scoped_lock lg(empty_mutex);
      56              :   auto current_flow_state = flow_state.load();
      57       107560 :   NNTR_THROW_IF(current_flow_state != FlowState::FLOW_STATE_OPEN,
      58              :                 std::invalid_argument)
      59              :     << "the queue expect state of "
      60              :     << static_cast<unsigned>(FlowState::FLOW_STATE_OPEN) << " but met "
      61              :     << static_cast<unsigned>(current_flow_state);
      62              : 
      63              :   /// below is useful information when debugging iteration queue, but there will
      64              :   /// be too much log if we turn the log on. so leaving it as a comment for now.
      65              :   // std::cout << "[requestEmptySlot] empty_q.size(): " << empty_q.size()
      66              :   // << " being_filled: " << num_being_filled
      67              :   // << " filled_q.size():  " << filled_q.size() << '\n';
      68              : 
      69       107544 :   if (being_filled == nullptr ||
      70              :       current_iterator + 1 == being_filled->get().end()) {
      71        12135 :     being_filled = empty_q.waitAndPop();
      72        12135 :     being_filled->reset();
      73        12135 :     num_being_filled++;
      74        12135 :     current_iterator = being_filled->get().begin();
      75              :   } else {
      76              :     current_iterator++;
      77              :   }
      78              : 
      79              :   auto view = ScopedView<Sample>(
      80              :     &(*current_iterator),
      81              :     [current_being_filed = this->being_filled] {
      82       107540 :       current_being_filed->markSampleFilled();
      83              :     },
      84       107544 :     [this, current_being_filled = this->being_filled] {
      85            4 :       std::unique_lock lg(empty_mutex);
      86            4 :       this->markEmpty(current_being_filled);
      87            4 :       num_being_filled--;
      88            4 :       notify_emptied_cv.notify_all();
      89       107548 :     });
      90       107544 :   return view;
      91              : }
      92              : 
      93        13529 : ScopedView<Iteration> IterationQueue::requestFilledSlot() {
      94        13529 :   std::scoped_lock lock(filled_mutex);
      95              : 
      96              :   /// below is useful information when debugging iteration queue, but there will
      97              :   /// be too much log if we turn the log on. so leaving it as a comment for now.
      98              :   // std::cout << "[requestFilledSlot] empty_q.size(): " << empty_q.size()
      99              :   // << " num being filled: " << num_being_filled
     100              :   // << " filled_q.size(): " << filled_q.size() << '\n';
     101        13529 :   if (flow_state.load() == FlowState::FLOW_STATE_STOPPED) {
     102              :     return ScopedView<Iteration>(nullptr);
     103              :   }
     104              : 
     105        13525 :   auto iteration = filled_q.waitAndPop();
     106        13525 :   if (iteration == nullptr) {
     107              :     auto stop_request_state = FlowState::FLOW_STATE_STOP_REQUESTED;
     108              :     bool exchange_result = flow_state.compare_exchange_strong(
     109              :       stop_request_state, FlowState::FLOW_STATE_STOPPED);
     110         1395 :     NNTR_THROW_IF(!exchange_result, std::runtime_error)
     111              :       << "the queue has either already stopped or running, but trying stopping "
     112              :          "without requesting stop, queue size: "
     113            0 :       << iterations.size() << " num currently empty: " << empty_q.size()
     114            0 :       << " filled_q.size(): " << filled_q.size();
     115              : 
     116              :     return ScopedView<Iteration>(nullptr);
     117              :   }
     118              : 
     119              :   return ScopedView<Iteration>(
     120        12126 :     &iteration->get(), [this, iteration] { markEmpty(iteration); },
     121            4 :     [this, iteration] {
     122            4 :       std::unique_lock lock(filled_mutex);
     123            4 :       flow_state.store(FlowState::FLOW_STATE_STOPPED);
     124            4 :       markEmpty(iteration);
     125        12134 :     });
     126              : }
     127              : 
     128         1408 : void IterationQueue::notifyEndOfRequestEmpty() {
     129         1408 :   std::unique_lock lg(empty_mutex);
     130              :   auto open_state = FlowState::FLOW_STATE_OPEN;
     131              : 
     132              :   /// we have to defined ordering of having stop_requested -> push nullptr to
     133              :   /// filled_q -> stopped so when the case of changing to stopped it has to push
     134              :   /// nullptr to empty_q, and filled_q to wake them up and stop. this has
     135              :   /// potential cases that weren't considered. let's change this to a simpler
     136              :   /// mechanisms to wait on conditional variable.
     137              :   bool exchange_result = flow_state.compare_exchange_strong(
     138              :     open_state, FlowState::FLOW_STATE_STOP_REQUESTED);
     139         1412 :   NNTR_THROW_IF(!exchange_result, std::invalid_argument)
     140              :     << "the queue expect state of " << static_cast<unsigned>(open_state)
     141              :     << " but met " << static_cast<unsigned>(flow_state.load());
     142              :   /// below is useful information when debugging iteration queue, but there will
     143              :   /// be too much log if we turn the log on. so leaving it as a comment for now.
     144              :   // std::cout << "[notifyEnd] empty_q.size(): " << empty_q.size()
     145              :   //           << " num being filled: " << num_being_filled
     146              :   //           << " filled_q.size(): " << filled_q.size() << '\n';
     147              : 
     148         1404 :   if (being_filled) {
     149         1396 :     being_filled->setEndSample(current_iterator + 1);
     150              :   }
     151         1410 :   notify_emptied_cv.wait(lg, [this] { return num_being_filled == 0; });
     152         1404 :   filled_q.push(nullptr);
     153         1404 : }
     154              : 
     155        10867 : void IterationQueue::markFilled(MarkableIteration *iteration) {
     156              :   {
     157        10867 :     std::lock_guard lg(empty_mutex);
     158        10867 :     --num_being_filled;
     159        10867 :     filled_q.push(iteration);
     160              :   }
     161        10867 :   notify_emptied_cv.notify_all();
     162        10867 : }
     163              : 
     164        12134 : void IterationQueue::markEmpty(MarkableIteration *iteration) {
     165        12134 :   empty_q.push(iteration);
     166        12134 : }
     167              : 
     168         5700 : IterationQueue::MarkableIteration::MarkableIteration(
     169              :   const std::vector<ml::train::TensorDim> &input_dims,
     170         5700 :   const std::vector<ml::train::TensorDim> &label_dims, IterationQueue *iq) :
     171         5700 :   num_observed(0), iteration(input_dims, label_dims), iq(iq) {}
     172              : 
     173            0 : IterationQueue::MarkableIteration::MarkableIteration(MarkableIteration &&rhs) :
     174            0 :   iteration(std::move(rhs.iteration)), iq(rhs.iq) {
     175            0 :   std::lock_guard notify_lock_guard(notify_mutex);
     176            0 :   num_observed = rhs.num_observed;
     177            0 : }
     178              : 
     179        12135 : void IterationQueue::MarkableIteration::reset() {
     180        12135 :   std::lock_guard notify_lock_guard(notify_mutex);
     181        12135 :   num_observed = 0;
     182        12135 :   iteration.setEndSample();
     183        12135 : }
     184              : 
     185              : IterationQueue::MarkableIteration &
     186            0 : IterationQueue::MarkableIteration::operator=(MarkableIteration &&rhs) {
     187            0 :   if (this == &rhs) {
     188              :     return *this;
     189              :   }
     190            0 :   std::scoped_lock lock(this->notify_mutex, rhs.notify_mutex);
     191            0 :   std::swap(iteration, rhs.iteration);
     192              :   std::swap(iq, rhs.iq);
     193              :   std::swap(num_observed, rhs.num_observed);
     194              :   return *this;
     195              : }
     196              : 
     197       107540 : void IterationQueue::MarkableIteration::markSampleFilled() {
     198       107540 :   std::unique_lock notify_lock_guard(notify_mutex);
     199       107540 :   num_observed++;
     200       107540 :   if (num_observed == iteration.batch()) {
     201        10867 :     num_observed = 0;
     202        10867 :     notify_lock_guard.unlock();
     203        10867 :     iq->markFilled(this);
     204        10867 :     notify_lock_guard.lock();
     205              :   }
     206       107540 : }
     207              : 
     208         1396 : void IterationQueue::MarkableIteration::setEndSample(
     209              :   std::vector<Sample>::iterator sample_iterator) {
     210         1396 :   std::scoped_lock notify_lock_guard(notify_mutex);
     211         1396 :   auto old_batch = iteration.batch();
     212         1396 :   if (sample_iterator != iteration.end()) {
     213         1266 :     iteration.setEndSample(sample_iterator);
     214              :   }
     215         1396 :   auto new_batch = iteration.batch();
     216              :   /// if batch has changed, check if this batch is partially filled and should
     217              :   /// be moved
     218         1396 :   if (old_batch != new_batch && num_observed == new_batch) {
     219              : #if DEBUG
     220              :     NNTR_THROW_IF_CLEANUP(iq->empty_mutex.try_lock(), std::runtime_error,
     221              :                           [this] { iq->empty_mutex.unlock(); })
     222              :       << "iteration queue must be locked already but empty_mutex is not "
     223              :          "locked.";
     224              : #endif
     225              :     /// warning: iq has to be locked with iq->empty_mutex
     226         1264 :     iq->num_being_filled--;
     227         1264 :     iq->filled_q.push(this);
     228         1264 :     iq->notify_emptied_cv.notify_all();
     229         1264 :     num_observed = 0;
     230              :   }
     231         1396 : }
     232              : 
     233              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1