Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file data_producer.h
6 : * @date 09 July 2021
7 : * @brief This file contains data producer interface
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 : #ifndef __DATA_PRODUCER_H__
14 : #define __DATA_PRODUCER_H__
15 :
#include <functional>
#include <limits>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include <common.h>
#include <tensor.h>
#include <tensor_dim.h>
25 : namespace nntrainer {
26 :
27 : class Exporter;
28 :
29 : /**
30 : * @brief DataProducer interface used to abstract data provider
31 : *
32 : */
33 : class DataProducer {
34 : public:
35 : /**
36 : * @brief generator callable type which will fill a sample
37 : * @param[in] index current index with range of [0, size() - 1]. If
38 : * size() == SIZE_UNDEFINED, this parameter can be ignored
39 : * @param[out] inputs allocate tensor before expected to be filled by this
40 : * function
41 : * @param[out] labels allocate tensor before expected to be filled by this
42 : * function function.
43 : * @return bool true if this is the last sample, samples will NOT be ignored
44 : * and should be used, or passed at will of caller
45 : *
46 : */
47 : using Generator = std::function<bool(unsigned int, /** index */
48 : std::vector<Tensor> & /** inputs */,
49 : std::vector<Tensor> & /** labels */)>;
50 :
51 : constexpr inline static unsigned int SIZE_UNDEFINED =
52 : std::numeric_limits<unsigned int>::max();
53 :
54 : /**
55 : * @brief Destroy the Data Loader object
56 : *
57 : */
58 : virtual ~DataProducer() {}
59 :
60 : /**
61 : * @brief Get the producer type
62 : * @return const std::string type representation
63 : */
64 : virtual const std::string getType() const = 0;
65 :
66 : /**
67 : * @brief Set the Property object
68 : *
69 : * @param properties properties to set
70 : */
71 0 : virtual void setProperty(const std::vector<std::string> &properties) {
72 0 : if (!properties.empty()) {
73 0 : throw std::invalid_argument("There are unparsed properties");
74 : }
75 0 : }
76 :
77 : /**
78 : * @brief finalize the class to return an immutable Generator.
79 : * @remark this function must assume that the batch dimension of each tensor
80 : * dimension is one. If actual dimension is not one, this function must ignore
81 : * the batch dimension and assume it to be one.
82 : * @param input_dims input dimensions.
83 : * @param label_dims label dimensions.
84 : * @param user_data user data to be used when finalize.
85 : * @return Generator generator is a function that generates a sample upon
86 : * call.
87 : */
88 0 : virtual Generator finalize(const std::vector<TensorDim> &input_dims,
89 : const std::vector<TensorDim> &label_dims,
90 : void *user_data = nullptr) {
91 0 : return Generator();
92 : }
93 :
94 : /**
95 : * @brief get the number of samples inside the dataset, if size
96 : * cannot be determined, this function must return.
97 : * DataProducer::SIZE_UNDEFINED.
98 : * @remark this function must assume that the batch dimension of each tensor
99 : * dimension is one. If actual dimension is not one, this function must ignore
100 : * the batch dimension and assume it to be one
101 : * @param input_dims input dimensions
102 : * @param label_dims label dimensions
103 : *
104 : * @return size calculated size
105 : */
106 1321 : virtual unsigned int size(const std::vector<TensorDim> &input_dims,
107 : const std::vector<TensorDim> &label_dims) const {
108 1321 : return SIZE_UNDEFINED;
109 : }
110 :
111 : /**
112 : * @brief this function helps exporting the data producer in a predefined
113 : * format, while workarounding issue caused by templated function type eraser
114 : *
115 : * @param exporter exporter that contains exporting logic
116 : * @param method enum value to identify how it should be exported to
117 : */
118 0 : virtual void exportTo(Exporter &exporter,
119 0 : const ml::train::ExportMethods &method) const {}
120 :
121 : /**
122 : * @brief denote if given producer is thread safe and can be parallelized.
123 : * @note if size() == SIZE_UNDEFINED, thread safe shall be false
124 : *
125 : * @return bool true if thread safe.
126 : */
127 0 : virtual bool isMultiThreadSafe() const { return false; }
128 : };
129 : } // namespace nntrainer
130 : #endif // __DATA_PRODUCER_H__
|