Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2022 Jijoong Moon <jijoong.moon@samsung.com>
4 : *
5 : * @file nntr_threads.h
6 : * @date 07 July 2022
7 : * @brief Thread Management for NNTrainer
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jijoong Moon <jijoong.moon@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : */
12 : #ifndef __NNTR_THREADS_H__
13 : #define __NNTR_THREADS_H__
14 :
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>

#include <common_properties.h>
#include <nntrainer_error.h>
#include <util_func.h>
22 :
/**
 * @brief Raw function-pointer signature of a callback executed by a worker
 *        thread of ParallelBatch
 * @param start start of the range handed to this call
 *        (NOTE(review): inclusive/exclusive is decided by the caller in the
 *        out-of-line implementation — confirm)
 * @param end end of the range handed to this call
 * @param pid identifier of the worker executing the call (presumably the
 *        worker index — confirm)
 * @param user_data opaque user data forwarded to the callback
 */
using loop_cb = void (*)(unsigned int start, unsigned int end, unsigned int pid,
                         void *user_data);

/**
 * @brief Type-erased callable with exactly the signature of loop_cb, so that
 *        lambdas and functors can be used as well as plain function pointers.
 *        Deriving it from loop_cb keeps the two signatures in sync.
 */
using threaded_cb = std::function<std::remove_pointer<loop_cb>::type>;
27 :
28 : namespace nntrainer {
29 :
30 : /**
31 : * @brief ParallelBatch class to parallelize along batch direction
32 : *
33 : */
34 : class ParallelBatch {
35 : public:
36 : /**
37 : * @brief Construct a new ParallelBatch object
38 : * @param unsigned int total number of batch
39 : *
40 : */
41 :
42 : ParallelBatch(unsigned int batch);
43 :
44 : /**
45 : * @brief Construct a new ParallelBatch object
46 : * @param threaded_cb the function run in thread
47 : * @param unsigned int total number of batch
48 : * @param void* user data for the threaded callback function
49 : *
50 : */
51 : ParallelBatch(threaded_cb threaded_cb_, unsigned int batch, void *user_data_);
52 :
53 : /**
54 : * @brief Destroy the ParallelBatch object
55 : *
56 : */
57 : ~ParallelBatch();
58 :
59 : /**
60 : * @brief Run the workders
61 : *
62 : */
63 : void run();
64 :
65 : /**
66 : * @brief set the thread callback function
67 : * @param Threadedcb the function run in thread
68 : * @param void* user data for the threaded callback function
69 : */
70 : void setCallback(threaded_cb t_cb, void *user_data);
71 :
72 : /**
73 : * @brief return the number of workders
74 : * @return unsigned int the number of workers
75 : *
76 : */
77 1242 : unsigned int getNumWorkers() { return num_workers; }
78 :
79 : private:
80 : threaded_cb cb;
81 : unsigned int batch;
82 : unsigned int num_workers;
83 : std::vector<std::thread> workers;
84 : std::unique_ptr<props::PropsUserData> user_data_prop;
85 : };
86 :
87 : } // namespace nntrainer
88 : #endif // __NNTR_THREADS_H__
|