Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file data_iteration.h
6 : * @date 11 Aug 2021
7 : * @brief This file contains iteration and sample class
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 : #ifndef __DATA_SAMPLE_H__
14 : #define __DATA_SAMPLE_H__
15 :
16 : #include <functional>
17 : #include <memory>
18 : #include <tuple>
19 : #include <vector>
20 :
21 : #include <tensor.h>
22 : #include <tensor_dim.h>
23 :
24 : namespace nntrainer {
25 :
26 : class Sample;
27 :
28 : /**
29 : * @brief Iteration class which owns the memory chunk for a single batch
30 : *
31 : */
32 : class Iteration {
33 :
34 : public:
35 : /**
36 : * @brief Construct a new Iteration object
37 : * @note the batch dimension must be the same for all given dimensions and the
38 : * there must be at least one input
39 : *
40 : * @param input_dims input dimension
41 : * @param label_dims label dimension
42 : */
43 : Iteration(const std::vector<ml::train::TensorDim> &input_dims,
44 : const std::vector<ml::train::TensorDim> &label_dims);
45 :
46 : Iteration(const Iteration &rhs) = delete;
47 : Iteration &operator=(const Iteration &rhs) = delete;
48 0 : Iteration(Iteration &&rhs) = default;
49 0 : Iteration &operator=(Iteration &&rhs) = default;
50 :
51 : /**
52 : * @brief get batch size of iteration
53 : *
54 : * @return unsigned int batch size
55 : */
56 : unsigned int batch();
57 :
58 : /**
59 : * @brief Get the Input Reference object
60 : *
61 : * @return std::vector<Tensor>& input
62 : */
63 10860 : std::vector<Tensor> &getInputsRef() { return inputs; }
64 :
65 : /**
66 : * @brief Get the Input Reference object
67 : *
68 : * @return const std::vector<Tensor>& input
69 : */
70 91527 : const std::vector<Tensor> &getInputsRef() const { return inputs; }
71 :
72 : /**
73 : * @brief Get the Label Reference object
74 : *
75 : * @return std::vector<Tensor>& label
76 : */
77 10860 : std::vector<Tensor> &getLabelsRef() { return labels; }
78 :
79 : /**
80 : * @brief Get the Label Reference object
81 : *
82 : * @return const std::vector<Tensor>& label
83 : */
84 91525 : const std::vector<Tensor> &getLabelsRef() const { return labels; }
85 :
86 : /**
87 : * @brief get sample iterator begin()
88 : *
89 : * @return std::vector<Sample>::iterator
90 : */
91 : std::vector<Sample>::iterator begin() { return samples.begin(); }
92 :
93 : /**
94 : * @brief get sample iterator end
95 : *
96 : * @return std::vector<Sample>::iterator
97 : */
98 226224 : std::vector<Sample>::iterator end() { return end_iterator; }
99 :
100 : /**
101 : * @brief get sample iterator begin
102 : *
103 : * @return std::vector<Sample>::const_iterator
104 : */
105 : std::vector<Sample>::const_iterator begin() const { return samples.begin(); }
106 :
107 : /**
108 : * @brief get sample iterator end
109 : *
110 : * @return std::vector<Sample>::const_iterator
111 : */
112 : std::vector<Sample>::const_iterator end() const { return end_iterator; }
113 :
114 : /**
115 : * @brief set end of the sample which will be used to calculate the batch size
116 : * @note @a iteration must be non-inclusive
117 : *
118 : */
119 : void setEndSample(std::vector<Sample>::iterator sample_iterator);
120 :
121 : /**
122 : * @brief Set the End Sample to the original end
123 : *
124 : */
125 : void setEndSample();
126 :
127 : private:
128 : std::vector<Tensor> inputs, labels;
129 : std::vector<Sample> samples;
130 : std::vector<Sample>::iterator end_iterator; /**< actual end iterator */
131 : };
132 :
133 : /**
134 : * @brief Sample class which views the memory for a single sample
135 : *
136 : */
137 91525 : class Sample {
138 :
139 : public:
140 : /**
141 : * @brief Construct a new Sample object
142 : * @note the batch dimension will be ignored to make a single sample
143 : *
144 : * @param iter iteration objects
145 : * @param batch nth batch to create the sample
146 : */
147 : Sample(const Iteration &iter, unsigned int batch);
148 :
149 : Sample(const Sample &rhs) = delete;
150 : Sample &operator=(const Sample &rhs) = delete;
151 0 : Sample(Sample &&rhs) = default;
152 : Sample &operator=(Sample &&rhs) = default;
153 :
154 : /**
155 : * @brief Get the Input Reference object
156 : *
157 : * @return std::vector<Tensor>& input
158 : */
159 107532 : std::vector<Tensor> &getInputsRef() { return inputs; }
160 :
161 : /**
162 : * @brief Get the Input Reference object
163 : *
164 : * @return const std::vector<Tensor>& input
165 : */
166 : const std::vector<Tensor> &getInputsRef() const { return inputs; }
167 :
168 : /**
169 : * @brief Get the Label Reference object
170 : *
171 : * @return std::vector<Tensor>& label
172 : */
173 107532 : std::vector<Tensor> &getLabelsRef() { return labels; }
174 :
175 : /**
176 : * @brief Get the Label Reference object
177 : *
178 : * @return const std::vector<Tensor>& label
179 : */
180 : const std::vector<Tensor> &getLabelsRef() const { return labels; }
181 :
182 : private:
183 : std::vector<Tensor> inputs, labels;
184 : };
185 :
186 : } // namespace nntrainer
187 :
188 : #endif // __DATA_SAMPLE_H__
|