Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * @file short_tensor.h
4 : * @date 10 January 2025
5 : * @brief This is ShortTensor class for 16-bit signed integer calculation
6 : * @see https://github.com/nnstreamer/nntrainer
7 : * @author Donghyeon Jeong <dhyeon.jeong@samsung.com>
8 : * @bug No known bugs except for NYI items
9 : */
10 :
11 : #ifndef __SHORT_TENSOR_H__
12 : #define __SHORT_TENSOR_H__
13 : #ifdef __cplusplus
14 :
15 : #include <tensor_base.h>
16 :
17 : namespace nntrainer {
18 :
19 : /**
20 : * @class ShortTensor class
21 : * @brief ShortTensor class for 16-bit unsigned integer calculation
22 : */
23 : class ShortTensor : public TensorBase {
24 : public:
25 : /**
26 : * @brief Basic Constructor of Tensor
27 : */
28 : ShortTensor(std::string name_ = "", Tformat fm = Tformat::NCHW,
29 : QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE);
30 :
31 : /**
32 : * @brief Construct a new ShortTensor object
33 : *
34 : * @param d Tensor dim for this float tensor
35 : * @param alloc_now Allocate memory to this tensor or not
36 : * @param init Initializer for the tensor
37 : * @param name Name of the tensor
38 : */
39 : ShortTensor(const TensorDim &d, bool alloc_now,
40 : Initializer init = Initializer::NONE, std::string name = "",
41 : QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE);
42 :
43 : /**
44 : * @brief Construct a new ShortTensor object
45 : * @param d Tensor dim for this tensor
46 : * @param buf buffer
47 : */
48 : ShortTensor(const TensorDim &d, const void *buf = nullptr,
49 : QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE);
50 :
51 : /**
52 : * @brief Construct a new ShortTensor object
53 : * @param d data for the Tensor
54 : * @param fm format for the Tensor
55 : */
56 : ShortTensor(
57 : std::vector<std::vector<std::vector<std::vector<int16_t>>>> const &d,
58 : std::vector<float> const &scales, Tformat fm, QScheme qscheme_);
59 :
60 : /**
61 : * @brief Construct a new ShortTensor object
62 : * @param rhs TensorBase object to copy
63 : */
64 0 : ShortTensor(TensorBase &rhs) :
65 0 : TensorBase(rhs), qscheme(QScheme::PER_TENSOR_AFFINE) {}
66 :
67 : /**
68 : * @brief Basic Destructor
69 : */
70 22 : ~ShortTensor() {}
71 :
72 : /**
73 : * @brief Comparison operator overload
74 : * @param[in] rhs Tensor to be compared with
75 : * @note Only compares Tensor data
76 : */
77 : bool operator==(const ShortTensor &rhs) const;
78 :
79 : /**
80 : * @brief Comparison operator overload
81 : * @param[in] rhs Tensor to be compared with
82 : * @note Only compares Tensor data
83 : */
84 : bool operator!=(const ShortTensor &rhs) const { return !(*this == rhs); }
85 :
86 : /**
87 : * @copydoc Tensor::allocate()
88 : */
89 : void allocate() override;
90 :
91 : /**
92 : * @copydoc Tensor::deallocate()
93 : */
94 : void deallocate() override;
95 :
96 : /**
97 : * @copydoc Tensor::getData()
98 : */
99 : void *getData() const override;
100 :
101 : /**
102 : * @copydoc Tensor::getData(size_t idx)
103 : */
104 : void *getData(size_t idx) const override;
105 :
106 : /**
107 : * @copydoc Tensor::getScale()
108 : */
109 : void *getScale() const override;
110 :
111 : /**
112 : * @copydoc Tensor::getScale(size_t idx)
113 : */
114 : void *getScale(size_t idx) const override;
115 :
116 : /**
117 : * @brief i data index
118 : * @retval address of ith data
119 : */
120 : void *getAddress(unsigned int i) override;
121 :
122 : /**
123 : * @brief i data index
124 : * @retval address of ith data
125 : */
126 : const void *getAddress(unsigned int i) const override;
127 :
128 : /**
129 : * @brief return value at specific location
130 : * @param[in] i index
131 : */
132 : const int16_t &getValue(unsigned int i) const;
133 :
134 : /**
135 : * @brief return value at specific location
136 : * @param[in] i index
137 : */
138 : int16_t &getValue(unsigned int i);
139 :
140 : /**
141 : * @brief return value at specific location
142 : * @param[in] b batch location
143 : * @param[in] c channel location
144 : * @param[in] h height location
145 : * @param[in] w width location
146 : */
147 : const int16_t &getValue(unsigned int b, unsigned int c, unsigned int h,
148 : unsigned int w) const;
149 :
150 : /**
151 : * @brief return value at specific location
152 : * @param[in] b batch location
153 : * @param[in] c channel location
154 : * @param[in] h height location
155 : * @param[in] w width location
156 : */
157 : int16_t &getValue(unsigned int b, unsigned int c, unsigned int h,
158 : unsigned int w);
159 :
160 : /**
161 : * @copydoc Tensor::setValue(float value)
162 : */
163 : void setValue(float value) override;
164 :
165 : /**
166 : * @copydoc Tensor::setValue(b, c, h, w, value)
167 : */
168 : void setValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
169 : float value) override;
170 :
171 : /**
172 : * @copydoc Tensor::addValue(b, c, h, w, value, beta)
173 : */
174 : void addValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
175 : float value, float beta) override;
176 :
177 : /**
178 : * @copydoc Tensor::setZero()
179 : */
180 : void setZero() override;
181 :
182 : /**
183 : * @copydoc Tensor::initialize()
184 : */
185 : void initialize() override;
186 :
187 : /**
188 : * @copydoc Tensor::initialize(Initializer init)
189 : */
190 : void initialize(Initializer init) override;
191 :
192 : /**
193 : * @copydoc Tensor::copy(const Tensor &from)
194 : */
195 : void copy(const Tensor &from) override;
196 :
197 : /**
198 : * @copydoc Tensor::copyData(const Tensor &from)
199 : */
200 : void copyData(const Tensor &from) override;
201 :
202 : /**
203 : * @copydoc Tensor::copy_with_stride()
204 : */
205 : void copy_with_stride(const Tensor &input, Tensor &output) override;
206 :
207 : /**
208 : * @copydoc Tensor::save(std::ostream &file)
209 : */
210 : void save(std::ostream &file) override;
211 :
212 : /**
213 : * @copydoc Tensor::read(std::ifstream &file)
214 : */
215 : void read(std::ifstream &file, size_t start_offset,
216 : bool read_from_offset) override;
217 :
218 : /**
219 : * @copydoc Tensor::argmax()
220 : */
221 : std::vector<unsigned int> argmax() const override;
222 :
223 : /**
224 : * @copydoc Tensor::argmin()
225 : */
226 : std::vector<unsigned int> argmin() const override;
227 :
228 : /**
229 : * @copydoc Tensor::max_abs()
230 : */
231 : float max_abs() const override;
232 :
233 : /**
234 : * @copydoc Tensor::maxValue()
235 : */
236 : float maxValue() const override;
237 :
238 : /**
239 : * @copydoc Tensor::minValue()
240 : */
241 : float minValue() const override;
242 :
243 : /**
244 : * @copydoc Tensor::getMemoryBytes()
245 : */
246 : size_t getMemoryBytes() const override;
247 :
248 : /**
249 : * @copydoc Tensor::print(std::ostream &out)
250 : */
251 : void print(std::ostream &out) const override;
252 :
253 : /**
254 : * @copydoc TensorBase::save_quantization_info()
255 : */
256 : void save_quantization_info(std::ostream &file) override;
257 :
258 : /**
259 : * @copydoc TensorBase::read_quantization_info()
260 : */
261 : void read_quantization_info(std::ifstream &file, size_t start_offset,
262 : bool read_from_offset) override;
263 :
264 : /**
265 : * @copydoc Tensor::scale_size()
266 : */
267 : size_t scale_size() const override;
268 :
269 : /**
270 : * @copydoc Tensor::scale_size()
271 : */
272 : QScheme q_scheme() const override;
273 :
274 : private:
275 : /**
276 : * @brief quantization scheme
277 : */
278 : QScheme qscheme;
279 :
280 : /**
281 : * @brief copy a buffer to @a this, the caller has to ensure that @a this is
282 : * initialized otherwise undefined behavior
283 : *
284 : * @param buf buffer to copy from
285 : */
286 : void copy(const void *buf);
287 :
288 : /**
289 : * @brief Get the Data Type String object
290 : * @return std::string of tensor data type (QINT16)
291 : */
292 0 : std::string getStringDataType() const override { return "QINT16"; }
293 :
294 : /**
295 : * @copydoc Tensor::isValid()
296 : */
297 0 : bool isValid() const override { return true; };
298 : };
299 :
300 : } // namespace nntrainer
301 :
302 : #endif /* __cplusplus */
303 : #endif /* __SHORT_TENSOR_H__ */
|