LCOV - code coverage report
Current view: top level - nntrainer/tensor - task_executor.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 0.0 % 86 0
Test Date: 2025-12-14 20:38:17 Functions: 0.0 % 13 0

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2025 Jijoong Moon <jijoong.moon@samsung.com>
       4              :  *
       5              :  * @file   task_executor.h
       6              :  * @date   04 April 2025
       7              :  * @brief  This file contains a task executor
       8              :  * @see    https://github.com/nnstreamer/nntrainer
       9              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  * @brief  Task executor class
      12              :  *
      13              :  */
      14              : 
      15              : #include "task_executor.h"
      16              : 
      17              : #include <nntrainer_error.h>
      18              : #include <nntrainer_log.h>
      19              : 
      20              : namespace nntrainer {
      21              : 
      22            0 : TaskExecutor::TaskExecutor(std::string n, size_t thread_count) :
      23            0 :   name(n), stop(false) {
      24            0 :   for (size_t i = 0; i < thread_count; ++i) {
      25            0 :     workers.emplace_back([this] { this->worker_thread(); });
      26              :   }
      27            0 : }
      28              : 
      29            0 : TaskExecutor::~TaskExecutor() {
      30              :   {
      31            0 :     std::unique_lock<std::mutex> lock(queue_mutex);
      32              :     stop = true;
      33              :   }
      34              : 
      35            0 :   cond_var.notify_all();
      36            0 :   for (std::thread &t : workers) {
      37            0 :     if (t.joinable())
      38            0 :       t.join();
      39              :   }
      40            0 : }
      41              : 
      42            0 : void TaskExecutor::worker_thread() {
      43              : 
      44              :   while (true) {
      45              :     Task task;
      46              :     {
      47            0 :       std::unique_lock<std::mutex> lock(queue_mutex);
      48            0 :       cond_var.wait(lock, [this]() { return stop || !task_queue.empty(); });
      49              : 
      50            0 :       if (stop && task_queue.empty()) {
      51            0 :         return;
      52              :       }
      53              : 
      54            0 :       task = std::move(task_queue.front());
      55              :       task_queue.pop();
      56            0 :       task_started[task.id] = true;
      57            0 :       task_started_cv.notify_all();
      58              : 
      59              :       // we are not going to remove the Done Tasks.
      60              :       // we exeplicitly call release tasks. until then, we keep the results and
      61              :       // not going to submit that task again
      62              :       // queued_ids.erase(task.id);
      63              :     }
      64              : 
      65              :     try {
      66            0 :       task.callback(task.data);
      67            0 :       task.promise.set_value();
      68            0 :     } catch (...) {
      69            0 :       ml_loge("[%s] : [Error ] Task ID %d threw an exception\n", name.c_str(),
      70              :               task.id);
      71            0 :     }
      72            0 :   }
      73              : }
      74              : 
      75            0 : int TaskExecutor::submit(TaskCallback cb, void *data) {
      76              : 
      77            0 :   auto canceled = std::make_shared<std::atomic_bool>(false);
      78              :   auto promise = std::make_shared<std::promise<void>>();
      79            0 :   std::shared_future<void> fut = promise->get_future().share();
      80            0 :   int id = getNextTaskId();
      81              : 
      82              :   {
      83            0 :     std::lock_guard<std::mutex> lock(queue_mutex);
      84              : 
      85              :     if (future_map.count(id)) {
      86            0 :       if (!future_map[id].valid()) {
      87            0 :         ml_loge("[%s] : [Error] Future is not valid : Task id - %d\n",
      88              :                 name.c_str(), id);
      89              :       }
      90            0 :       auto status = future_map[id].wait_for(std::chrono::seconds(0));
      91            0 :       if (status != std::future_status::ready) {
      92            0 :         ml_logi("[%s] : Task id - %d is still active\n", name.c_str(), id);
      93            0 :         return id;
      94              :       }
      95              :     }
      96              : 
      97            0 :     Task task{id, std::move(*promise), cb, data};
      98              : 
      99            0 :     future_map[id] = fut;
     100              : 
     101              :     task_queue.push(std::move(task));
     102            0 :   }
     103            0 :   cond_var.notify_one();
     104            0 :   return id;
     105              : }
     106              : 
     107            0 : void TaskExecutor::submitTasks(const std::vector<TaskDesc> &tasks) {
     108            0 :   for (const auto &task : tasks) {
     109            0 :     submit(task.callback, task.data);
     110              :   }
     111            0 : }
     112              : 
     113            0 : bool TaskExecutor::cancel(int id) {
     114            0 :   std::lock_guard<std::mutex> lock(queue_mutex);
     115              :   auto it = cancel_map.find(id);
     116            0 :   if (it != cancel_map.end()) {
     117              :     *(it->second) = true;
     118            0 :     return true;
     119              :   }
     120              :   return false;
     121              : }
     122              : 
     123            0 : void TaskExecutor::wait(int id) {
     124            0 :   std::shared_future<void> fut;
     125              :   {
     126            0 :     std::unique_lock<std::mutex> lock(queue_mutex);
     127              : 
     128            0 :     task_started_cv.wait(
     129            0 :       lock, [&] { return task_started.count(id) && task_started[id]; });
     130              : 
     131              :     auto it = future_map.find(id);
     132            0 :     if (it == future_map.end() || !it->second.valid()) {
     133              :       return;
     134              :     }
     135            0 :     fut = it->second;
     136              :   }
     137              :   try {
     138            0 :     fut.wait();
     139            0 :   } catch (const std::future_error &e) {
     140            0 :     ml_loge("[%s] : exception while waiting on future : %s\n", name.c_str(),
     141              :             e.what());
     142            0 :   }
     143              : }
     144              : 
     145            0 : void TaskExecutor::waitAll(const std::vector<int> &ids) {
     146              :   std::vector<std::shared_future<void>> futures;
     147              :   {
     148            0 :     std::lock_guard<std::mutex> lock(queue_mutex);
     149            0 :     for (int id : ids) {
     150              :       auto it = future_map.find(id);
     151            0 :       if (it != future_map.end()) {
     152            0 :         futures.push_back(it->second);
     153              :       } else {
     154            0 :         ml_logw("[%s] : Task ID is not found : %d\n", name.c_str(), id);
     155              :       }
     156              :     }
     157              :   }
     158              : 
     159            0 :   for (auto &fut : futures) {
     160              :     try {
     161            0 :       fut.wait();
     162            0 :     } catch (const std::exception &e) {
     163            0 :       ml_loge("[%s] : exception while waiting on future : %s\n", name.c_str(),
     164              :               e.what());
     165            0 :     }
     166              :   }
     167            0 : }
     168              : 
     169            0 : bool TaskExecutor::isDone(int id) {
     170            0 :   std::lock_guard<std::mutex> lock(queue_mutex);
     171              :   auto it = future_map.find(id);
     172            0 :   if (it == future_map.end())
     173              :     return false;
     174            0 :   return it->second.wait_for(std::chrono::seconds(0)) ==
     175            0 :          std::future_status::ready;
     176              : }
     177              : 
     178            0 : bool TaskExecutor::isAllDone(const std::vector<int> &ids) {
     179            0 :   std::lock_guard<std::mutex> lock(queue_mutex);
     180            0 :   for (int id : ids) {
     181            0 :     isDone(id);
     182              :   }
     183            0 :   return true;
     184              : }
     185              : 
     186            0 : void TaskExecutor::releaseTask(int id) {
     187            0 :   std::lock_guard<std::mutex> lock(queue_mutex);
     188              :   future_map.erase(id);
     189              :   cancel_map.erase(id);
     190              :   reusable_ids.push(id);
     191            0 : }
     192              : 
     193              : } // namespace nntrainer
        

Generated by: LCOV version 2.0-1