LCOV - code coverage report
Current view: top level - nntrainer/dataset - iteration_queue.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 93.1 % 29 27
Test Date: 2025-12-14 20:38:17 Functions: 100.0 % 6 6

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file   iteration_queue.h
       6              :  * @date   13 July 2021
       7              :  * @brief  This file contains thread safe queue for a single iteration
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  *
      12              :  */
      13              : #ifndef __ITERATION_QUEUE_H__
      14              : #define __ITERATION_QUEUE_H__
      15              : 
#include <atomic>
#include <condition_variable>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <shared_mutex>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>

#include <data_iteration.h>
#include <data_producer.h>
#include <nntrainer_log.h>
#include <tensor.h>
#include <tensor_dim.h>
      30              : 
      31              : namespace nntrainer {
      32              : 
      33              : /**
      34              :  * @brief Thread Safe Queue implementation dedicated for the non-owing pointer
      35              :  *
      36              :  * @tparam original type of the view (T * will be pushed and pop)
      37              :  */
      38         1448 : template <typename T> class ViewQueue {
      39              : public:
      40              :   /**
      41              :    * @brief Construct a new queue
      42              :    */
      43         5764 :   ViewQueue() : q() {}
      44              : 
      45              :   /**
      46              :    * @brief push data to queue
      47              :    *
      48              :    * @param data data to put
      49              :    */
      50        31365 :   void push(T *data) {
      51              :     {
      52        31365 :       std::unique_lock<std::shared_mutex> lk(q_mutex);
      53              :       q.push(data);
      54              :     }
      55              : 
      56        31365 :     q_cv.notify_one();
      57        31365 :   }
      58              : 
      59              :   /**
      60              :    * @brief pop data from the queue, wait if empty
      61              :    * @note when fail to get, this will return nullptr (eg) when interrupt
      62              :    * happens)
      63              :    * @return T* view of the data
      64              :    */
      65        25660 :   T *waitAndPop() {
      66        25660 :     std::unique_lock<std::shared_mutex> lk(q_mutex);
      67        25660 :     q_cv.wait(lk, [this] { return !q.empty(); });
      68        25660 :     auto ptr = q.front();
      69              :     q.pop();
      70              : 
      71        25660 :     return ptr;
      72              :   }
      73              : 
      74              :   /**
      75              :    * @brief check if current queue is empty
      76              :    *
      77              :    * @return bool true if empty
      78              :    */
      79              :   bool isEmpty() const {
      80              :     std::shared_lock<std::shared_mutex> lk(q_mutex);
      81              :     return q.empty();
      82              :   }
      83              : 
      84              :   /**
      85              :    * @brief check if current queue is empty
      86              :    *
      87              :    * @return bool true if empty
      88              :    */
      89         2868 :   typename std::queue<T *>::size_type size() const {
      90         2868 :     std::shared_lock<std::shared_mutex> lk(q_mutex);
      91         2868 :     return q.size();
      92              :   }
      93              : 
      94              : private:
      95              :   mutable std::shared_mutex q_mutex;
      96              :   std::condition_variable_any q_cv;
      97              : 
      98              :   std::queue<T *> q;
      99              : };
     100              : 
     101              : /**
     102              :  * @brief A view container that calls a callback on destruct
     103              :  * @note the callback must be noexcept, and the given underlying data must
     104              :  * outlive the lifetime of this class
     105              :  *
     106              :  * @tparam T underlying type
     107              :  */
     108              : template <typename T> class ScopedView {
     109              : public:
     110              :   /**
     111              :    * @brief Construct a new Scoped View object
     112              :    *
     113              :    * @param data_ reference of the underlying data
     114              :    * @param on_notify_ callback to be called on exit
     115              :    * @param on_error_ callback to be called on error
     116              :    */
     117              :   ScopedView(T *data_, std::function<void(void)> &&on_notify_ = nullptr,
     118              :              std::function<void(void)> &&on_error_ = nullptr) :
     119       108943 :     data(data_),
     120              :     on_notify(std::forward<std::function<void(void)>>(on_notify_)),
     121              :     on_error(std::forward<std::function<void(void)>>(on_error_)) {}
     122              : 
     123              :   ScopedView(const ScopedView &rhs) = delete;
     124              :   ScopedView &operator=(const ScopedView &rhs) = delete;
     125              : 
     126           19 :   ScopedView(ScopedView &&rhs) = default;
     127              :   ScopedView &operator=(ScopedView &&rhs) = default;
     128              : 
     129              :   /**
     130              :    * @brief check if scoped view contains any underlying data
     131              :    *
     132              :    * @return bool if data is empty
     133              :    */
     134       121057 :   bool isEmpty() { return data == nullptr; }
     135              : 
     136              :   /**
     137              :    * @brief Destroy the Scoped View object, callback is called at this time
     138              :    *
     139              :    */
     140       121092 :   ~ScopedView() {
     141              :     try {
     142       121092 :       if (std::uncaught_exceptions()) {
     143            8 :         if (on_error) {
     144              :           on_error();
     145              :         }
     146              :       } else {
     147       121084 :         if (on_notify) {
     148              :           on_notify();
     149              :         }
     150              :       }
     151            0 :     } catch (...) {
     152            0 :       ml_loge("while handling on_notify or on_error, error happened");
     153              :     }
     154       121092 :   }
     155              : 
     156              :   /**
     157              :    * @brief get the underlying data
     158              :    *
     159              :    * @return T & reference to the underlying data
     160              :    */
     161            4 :   T &get() { return *data; }
     162              : 
     163              :   /**
     164              :    * @brief get the underlying data
     165              :    *
     166              :    * @return T & reference to the underlying data
     167              :    */
     168              :   T const &get() const { return *data; }
     169              : 
     170              : private:
     171              :   T *data; /**< underlying data pointer */
     172              :   std::function<void(void)>
     173              :     on_notify; /**< called when destroyed without error */
     174              :   std::function<void(void)> on_error; /**< called when destroyed with error */
     175              : };
     176              : 
     177              : /**
     178              :  * @brief Iteration queue that owns the buffer for input / labels
     179              :  * @details
     180              :  *
     181              :  * - requestEmptySlot() will give a ScopedView<sample>
     182              :  *     Destructing the returned object will notify the iteration that is done
     183              :  * filling the sample. Once iteration is done filling, it will internally call
     184              :  * IterationQueue::markFilled();
     185              :  * - requestFilledSlot() will give a ScopedView<Iteration>
     186              :  *     Destructing this will notify the queue that is done used (internally
     187              :  * calls IterationQueue::markEmpty())
     188              :  *
     189              :  * @details For an iteration there can be four state.
     190              :  * 1. The buffer is empty, waiting to be filled (will be in empty_q)
     191              :  * 2. The buffer is being filled sample by sample, waiting to be marked as
     192              :  * filled.
     193              :  * 3. The buffer is filled, waiting to be served (will be in filled_q)
     194              :  * 4. The buffer is being served, waiting to be marked as emptied.
     195              :  * @todo apply this to the databuffer
     196              :  * @todo handle error case: 1. when ScopedView<Sample> has met throw
     197              :  *                          2. when ScopedView<Iteration> has met throw
     198              :  */
     199              : class IterationQueue {
     200              : public:
     201              :   /**
     202              :    * @brief Construct a new Iteration Queue object
     203              :    * @note  input_dimension and label_dimension should include the batch, if
     204              :    * IterationQueue::batch() is zero, it means it's invalid
     205              :    * @param num_slots number of slots this iteration queue will allocate, it
     206              :    * should be buffersize/batchsize
     207              :    * @param input_dims input dimensions
     208              :    * @param label_dims label dimensions
     209              :    */
     210              :   IterationQueue(unsigned int num_slots,
     211              :                  const std::vector<ml::train::TensorDim> &input_dims,
     212              :                  const std::vector<ml::train::TensorDim> &label_dims);
     213              : 
     214              :   /**
     215              :    * @brief Destroy the Iteration Queue object
     216              :    *
     217              :    */
     218              :   ~IterationQueue();
     219              : 
     220              :   /**
     221              :    * @brief request empty sample from the queue.
     222              :    * @note User must check if ScopedView actually has a value by calling
     223              :    * ScopedView::isEmpty()
     224              :    * @return ScopedView<Sample> sample view. ScopedView::isEmpty() == true
     225              :    * if there is no more data coming. Destroying the returned object will
     226              :    * signal the queue that the sample is filled.
     227              :    */
     228              :   ScopedView<Sample> requestEmptySlot();
     229              : 
     230              :   /**
     231              :    * @brief request filled iteration from the queue.
     232              :    * @note User must check if ScopedView actually has a value by calling
     233              :    * ScopedView::isEmpty()
     234              :    * @return ScopedView<Iteration> Ieration view. ScopedView::isEmpty() == true
     235              :    * if there is no more data coming. Destroying the returned object will
     236              :    * signal the queue that the sample is done using.
     237              :    *
     238              :    */
     239              :   ScopedView<Iteration> requestFilledSlot();
     240              : 
     241              :   /**
     242              :    * @brief get slot size, slot size is number of batches inside the queue
     243              :    *
     244              :    * @return unsigned int num slot
     245              :    */
     246           47 :   unsigned int slots() { return iterations.size(); }
     247              : 
     248              :   /**
     249              :    * @brief get size of batch for one iteration
     250              :    *
     251              :    * @return unsigned int size of batch
     252              :    */
     253         3128 :   unsigned int batch() { return batch_size; }
     254              : 
     255              :   /**
     256              :    * @brief notifyEndOfRequest, when the producing by requestEmptySlot has
     257              :    * finished.
     258              :    * @note It is important that the owner of this class must ensure that there
     259              :    * will be no more requestEmptySlot call after this. This means that, in case
     260              :    * of multiple workers, the manager of the worker(producer) must know every
     261              :    * producer has finished. and call this function other than each worker call
     262              :    * this function.
     263              :    *
     264              :    */
     265              :   void notifyEndOfRequestEmpty();
     266              : 
     267              : private:
     268              :   /**
     269              :    * @brief A wrapper object around @c Iteration which marks filled when filling
     270              :    * sample is done
     271              :    * @note the given @a iteration_ and @a bq_ must outleave the lifetime of this
     272              :    * class
     273              :    *
     274              :    */
     275         5696 :   class MarkableIteration {
     276              :   public:
     277              :     /**
     278              :      * @brief Construct a new Markable Iteration object
     279              :      *
     280              :      * @param input_dims input dimensions
     281              :      * @param label_dims label dimensions
     282              :      * @param iq_ iteration queue view to notify
     283              :      */
     284              :     MarkableIteration(const std::vector<ml::train::TensorDim> &input_dims,
     285              :                       const std::vector<ml::train::TensorDim> &label_dims,
     286              :                       IterationQueue *iq);
     287              : 
     288              :     /**
     289              :      * @brief reset num observation and internal batch size of iteration
     290              :      *
     291              :      */
     292              :     void reset();
     293              : 
     294              :     /**
     295              :      * @brief Construct a new Markable Iteration object
     296              :      *
     297              :      * @param rhs right side to move
     298              :      */
     299              :     MarkableIteration(MarkableIteration &&rhs);
     300              : 
     301              :     /**
     302              :      * @brief Move Assignement operator
     303              :      *
     304              :      * @param rhs rhs to move
     305              :      * @return MarkableIteration& markable iteration
     306              :      */
     307              :     MarkableIteration &operator=(MarkableIteration &&rhs);
     308              : 
     309              :     /**
     310              :      * @brief mark iteration that one sample is filled
     311              :      * @todo make this function noexcept
     312              :      */
     313              :     void markSampleFilled() /** noexcept */;
     314              : 
     315              :     /**
     316              :      * @brief update end sample to the given iterator and mark last
     317              :      * @note after updating end iterator, this can be markFilled() if every
     318              :      * sample is already filled
     319              :      *
     320              :      * @param iterator non-inclusive iterator to mark the last
     321              :      */
     322              :     void setEndSample(std::vector<Sample>::iterator sample_iterator);
     323              : 
     324              :     /**
     325              :      * @brief get underlying iteration
     326              :      *
     327              :      * @return Iteration& iteration
     328              :      */
     329        13564 :     Iteration &get() { return iteration; }
     330              : 
     331              :   private:
     332              :     unsigned int num_observed; /**< number of observed samples which were passed
     333              :                                   to the callee and notified done filling */
     334              :     mutable std::mutex
     335              :       notify_mutex;      /**< mutex which should be locked when try to notify */
     336              :     Iteration iteration; /**< underlying iteration that this class owns */
     337              :     IterationQueue *iq;  /**< view of iteration queue */
     338              :   };
     339              : 
     340              :   /**
     341              :    * @brief Queue running state enum
     342              :    *
     343              :    */
     344              :   enum class FlowState {
     345              :     FLOW_STATE_OPEN = 0,           /**< nothing */
     346              :     FLOW_STATE_STOP_REQUESTED = 1, /**< request stop */
     347              :     FLOW_STATE_STOPPED = 2,        /**< queue is fully stopped */
     348              :   };
     349              : 
     350              :   /**
     351              :    * @brief mark the given iteration filled
     352              :    * @todo make this noexcept with the thread safe queue
     353              :    * @param iteration iteration to mark it as filled
     354              :    */
     355              :   void markFilled(MarkableIteration *iteration) /** noexcept */;
     356              : 
     357              :   /**
     358              :    * @brief mark the given iteration empty
     359              :    * @todo make this noexcept with the thread safe queue
     360              :    * @param iteration iteration to mark it as emptied
     361              :    */
     362              :   void markEmpty(MarkableIteration *iteration) /** noexcept */;
     363              : 
     364              :   std::vector<MarkableIteration> iterations; /**< allocated iterations */
     365              :   MarkableIteration *being_filled; /**< last iteration that is being filled */
     366              :   std::vector<Sample>::iterator
     367              :     current_iterator; /**< current sample iteration of being_filled */
     368              : 
     369              :   mutable std::mutex empty_mutex; /**< mutex to be used when it is mutually
     370              :                                      exclusive to the requesting empty slots */
     371              :   unsigned int
     372              :     num_being_filled; /**< number of iteration that is in being_filled state */
     373              :   mutable std::mutex
     374              :     filled_mutex; /**< mutex to be used when it is mutually exclusive to the
     375              :                      requesting filled slots */
     376              :   std::condition_variable_any
     377              :     notify_emptied_cv; /**< conditional variable to wait based on the
     378              :                            num_being_filled */
     379              :   std::atomic<FlowState> flow_state; /**< flow state of the queue */
     380              : 
     381              :   unsigned int batch_size;
     382              :   ViewQueue<MarkableIteration> empty_q;  /**< iterations to be filled */
     383              :   ViewQueue<MarkableIteration> filled_q; /**< iterations to be served */
     384              : };
     385              : 
     386              : } // namespace nntrainer
     387              : 
     388              : #endif // __ITERATION_QUEUE_H__
        

Generated by: LCOV version 2.0-1