Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2022 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file nntr_threads.cpp
6 : * @date 07 July 2022
7 : * @brief Thread Management for NNTrainer
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jijoong Moon <jijoong.moon@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : */
12 :
13 : #include <algorithm>
14 : #include <nntr_threads.h>
15 :
/**
 * @brief Worker-thread count used by ParallelBatch in this translation unit.
 *
 * Taken from the NNTR_NUM_THREADS macro when the build defines it; otherwise
 * a single thread is used.
 */
#ifndef NNTR_NUM_THREADS
static const unsigned int nntr_num_threads = 1;
#else
static const unsigned int nntr_num_threads = NNTR_NUM_THREADS;
#endif
21 :
22 : namespace nntrainer {
23 :
24 331 : ParallelBatch::ParallelBatch(unsigned int batch_size) :
25 : cb(nullptr),
26 331 : batch(batch_size),
27 331 : num_workers(nntr_num_threads > batch ? 1 : nntr_num_threads),
28 331 : user_data_prop(new props::PropsUserData(nullptr)){};
29 :
30 911 : ParallelBatch::ParallelBatch(threaded_cb threaded_cb_, unsigned int batch_size,
31 911 : void *user_data_) :
32 911 : cb(threaded_cb_),
33 911 : batch(batch_size),
34 911 : num_workers(nntr_num_threads > batch ? 1 : nntr_num_threads),
35 911 : user_data_prop(new props::PropsUserData(user_data_)) {}
36 :
37 2484 : ParallelBatch::~ParallelBatch() {}
38 :
39 0 : void ParallelBatch::run() {
40 :
41 0 : if (!cb) {
42 0 : throw std::invalid_argument("nntrainer threads: callback is not defined");
43 : }
44 :
45 : unsigned int start = 0;
46 0 : unsigned int end = batch;
47 :
48 0 : unsigned int chunk = (end - start + (num_workers - 1)) / num_workers;
49 :
50 0 : for (unsigned int i = 0; i < num_workers; ++i) {
51 0 : unsigned int s = start + i * chunk;
52 0 : unsigned int e = s + chunk;
53 0 : if (e > end)
54 0 : e = end;
55 0 : workers.push_back(std::thread(cb, s, e, i, user_data_prop->get()));
56 : }
57 :
58 : std::for_each(workers.begin(), workers.end(),
59 : std::mem_fn(&std::thread::join));
60 0 : }
61 :
62 0 : void ParallelBatch::setCallback(threaded_cb threaded_cb_, void *user_data_) {
63 0 : cb = threaded_cb_;
64 0 : user_data_prop = std::make_unique<props::PropsUserData>(user_data_);
65 0 : }
66 :
67 : } // namespace nntrainer
|