// SPDX-License-Identifier: Apache-2.0
/**
 * @file quantizer.h
 * @date 10 December 2024
 * @brief This defines quantizers for different types of quantization schemes
 * @see https://github.com/nnstreamer/nntrainer
 * @author Donghyeon Jeong <dhyeon.jeong@samsung.com>
 * @bug No known bugs except for NYI items
 */

#ifndef __QUANTIZER_H__
#define __QUANTIZER_H__
#ifdef __cplusplus

#include <memory>
#include <stdexcept>
#include <unordered_map>

#include <tensor_dim.h>

namespace nntrainer {

class Tensor;

/**
 * @brief Defines the quantization scheme
 * @details NNTrainer provides basic quantization schemes (e.g., per-tensor
 * affine quantization), and additional schemes will be added over time. If you
 * would like to use a different quantization technique, select a custom
 * quantizer scheme.
 */
enum class QScheme : uint16_t {
  /** predefined quantizers */
  PER_TENSOR_AFFINE = 0x00,
  PER_CHANNEL_AFFINE = 0x01,
  BINARY_CODE_BASED = 0x02,
  Q4_Kx8 = 0x03,
  Q6_K = 0x04,
  Q4_0 = 0x05,
  /** these are for custom use */
  CUSTOM_QUANTIZER_01 = 0x10,
  CUSTOM_QUANTIZER_02 = 0x11,
  CUSTOM_QUANTIZER_03 = 0x12,
  CUSTOM_QUANTIZER_04 = 0x13,
  CUSTOM_QUANTIZER_05 = 0x14,
  CUSTOM_QUANTIZER_06 = 0x15,
};

/**
 * @class Quantizer class
 * @brief Quantizer class is a base class for all quantizers.
 * @note A custom quantizer must inherit this class and implement its virtual
 * functions.
 */
class Quantizer {
private:
  /** Hash table that holds empty instances of the custom quantizers */
  static inline std::unordered_map<QScheme, Quantizer *> custom_quantizers;

protected:
  long int quant_min = 0;
  long int quant_max = 0;

  /**
   * @brief Register the user-defined quantizer class
   *
   * @param qscheme Quantization scheme (use CUSTOM_QUANTIZER_#)
   * @param quantizer quantizer class to register
   *
   * @note This function registers a custom quantizer class. A user-defined
   * derived class must be registered with this function (see the usage sketch
   * after this class).
   */
  static void registerQuantizer(QScheme qscheme, Quantizer &quantizer) {
    custom_quantizers.insert(std::make_pair(qscheme, &quantizer));
  }

  /**
   * @brief Calculate the quantization parameters
   *
   * @note This will be used to determine the quantization parameters.
   * QParams must be determined before quantization.
   *
   * @param input input tensor
   * @param qtype quantized data type
   */
  virtual void calculateQParams(const Tensor &input,
                                ml::train::TensorDim::DataType qtype) = 0;

  /**
   * @brief Calculate the minimum & maximum value
   * @param qtype quantized data type
   */
  void calculateMinMaxValue(ml::train::TensorDim::DataType qtype);

public:
  /**
   * @brief Basic Constructor of a Quantizer
   */
  Quantizer() = default;

  /**
   * @brief Basic Destructor of a Quantizer
   */
  virtual ~Quantizer() = default;

  /**
   * @brief Get the Registered Quantizer object
   *
   * @param qscheme Quantization scheme
   * @return Quantizer* registered quantizer object
   */
  static Quantizer *getRegisteredQuantizer(QScheme qscheme) {
    if (custom_quantizers.find(qscheme) == custom_quantizers.end()) {
      throw std::invalid_argument("requested quantizer is not registered.");
    }
    return custom_quantizers.at(qscheme);
  }

  /** Derived classes must implement the following functions */
  /**
   * @brief Create a new object of itself
   *
   * @return std::unique_ptr<Quantizer>
   */
  virtual std::unique_ptr<Quantizer> create() = 0;

  /**
   * @brief Quantize a tensor into a quantized tensor.
   * @param[in] input Floating point tensor to quantize
   * @param[in] qtype quantized data type
   * @return Tensor quantized tensor
   */
  virtual Tensor quantize(const Tensor &input,
                          ml::train::TensorDim::DataType qtype) = 0;

  /**
   * @brief Quantize a tensor into a quantized tensor with given parameters.
   * @param[in] input Floating point tensor to quantize
   * @param[out] output Quantized tensor
   * @param[in] scales float scale factors
   * @param[in] zero_points unsigned int zero points
   * @return Tensor& quantized tensor
   */
  virtual Tensor &quantize(const Tensor &input, Tensor &output, float *scales,
                           unsigned int *zero_points = nullptr) = 0;

  /**
   * @brief Dequantize a quantized tensor into a tensor.
   * @param[in] input Quantized tensor to dequantize
   * @param[in] qtype data type to dequantize to
   * @return Tensor dequantized tensor
   */
  virtual Tensor dequantize(const Tensor &input,
                            ml::train::TensorDim::DataType qtype) = 0;

  /**
   * @brief Get quantization scheme type.
   * @return Quantization scheme
   */
  virtual QScheme qscheme() const = 0;
};
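
/*
 * A minimal sketch of how a user-defined quantizer could be plugged in.
 * MyCustomQuantizer, its registerSelf() helper, and the chosen scheme
 * CUSTOM_QUANTIZER_01 are hypothetical; only the Quantizer interface above is
 * from this header, and the overridden methods are assumed to be defined in a
 * separate source file.
 *
 *   class MyCustomQuantizer : public nntrainer::Quantizer {
 *   public:
 *     std::unique_ptr<Quantizer> create() override {
 *       return std::make_unique<MyCustomQuantizer>();
 *     }
 *     Tensor quantize(const Tensor &input,
 *                     ml::train::TensorDim::DataType qtype) override;
 *     Tensor &quantize(const Tensor &input, Tensor &output, float *scales,
 *                      unsigned int *zero_points = nullptr) override;
 *     Tensor dequantize(const Tensor &input,
 *                       ml::train::TensorDim::DataType qtype) override;
 *     QScheme qscheme() const override {
 *       return QScheme::CUSTOM_QUANTIZER_01;
 *     }
 *
 *     // registerQuantizer() is protected, so registration happens from
 *     // within the derived class.
 *     static void registerSelf() {
 *       static MyCustomQuantizer instance;
 *       registerQuantizer(QScheme::CUSTOM_QUANTIZER_01, instance);
 *     }
 *
 *   protected:
 *     void calculateQParams(const Tensor &input,
 *                           ml::train::TensorDim::DataType qtype) override {}
 *   };
 *
 * After MyCustomQuantizer::registerSelf() has run, the registered instance
 * can be retrieved with
 * Quantizer::getRegisteredQuantizer(QScheme::CUSTOM_QUANTIZER_01).
 */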

/**
 * @class UniformQuantizer class
 * @brief UniformQuantizer class serves as the parent class for various types
 * of uniform quantizers.
 */
class UniformQuantizer : public Quantizer {
public:
  UniformQuantizer() : Quantizer() {}
};

/**
 * @class NonUniformQuantizer class
 * @brief NonUniformQuantizer class serves as the parent class for various
 * types of non-uniform quantizers.
 */
class NonUniformQuantizer : public Quantizer {
public:
  NonUniformQuantizer() : Quantizer() {}
};

/**
 * @class PerTensorAffineQuantizer class
 * @brief PerTensorAffineQuantizer class uses an affine quantization scheme.
 *
 * Quantization:   x_q = clip(round(x / scale + zero_point), min, max)
 * Dequantization: x   = scale * (x_q - zero_point)
 *
 * @note Single scale and zero point values are used for the entire tensor.
 * A worked numeric example follows this class.
 */
class PerTensorAffineQuantizer : public UniformQuantizer {
public:
  /**
   * @brief Basic Constructor of a PerTensorAffineQuantizer
   */
  PerTensorAffineQuantizer() : UniformQuantizer(), scale(1) {}

  /**
   * @copydoc Quantizer::create()
   */
  std::unique_ptr<Quantizer> create() override;

  /**
   * @copydoc Quantizer::quantize(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  Tensor quantize(const Tensor &input,
                  ml::train::TensorDim::DataType qtype) override;

  /**
   * @copydoc Quantizer::quantize(const Tensor &input, Tensor &output, float
   * *scales, unsigned int *zero_points)
   */
  Tensor &quantize(const Tensor &input, Tensor &output, float *scales,
                   unsigned int *zero_points = nullptr) override;

  /**
   * @copydoc Quantizer::dequantize(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  Tensor dequantize(const Tensor &input,
                    ml::train::TensorDim::DataType dtype) override;

  /**
   * @copydoc Quantizer::qscheme()
   */
  QScheme qscheme() const override;

private:
  float scale;
  unsigned int zero_point = 0;

  /**
   * @copydoc Quantizer::calculateQParams(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  void calculateQParams(const Tensor &input,
                        ml::train::TensorDim::DataType qtype) override;
};
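
/*
 * Worked example of the per-tensor affine formulas above, assuming a signed
 * 8-bit quantization range [-128, 127]; the concrete values are illustrative
 * only.
 *
 *   x = 2.5, scale = 0.1, zero_point = 0
 *   x_q = clip(round(2.5 / 0.1 + 0), -128, 127) = clip(25, -128, 127) = 25
 *   x'  = 0.1 * (25 - 0) = 2.5          // exact recovery in this case
 *
 *   x = 20.0 with the same qparams saturates:
 *   x_q = clip(round(200), -128, 127) = 127
 *   x'  = 0.1 * 127 = 12.7              // clipping error from saturation
 */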

/**
 * @class PerChannelAffineQuantizer class
 * @brief PerChannelAffineQuantizer class uses an affine quantization scheme.
 *
 * @note PerChannelAffineQuantizer is similar to PerTensorAffineQuantizer, but
 * it keeps separate scale and zero_point parameters for each channel. This
 * allows different channels within the same tensor to be quantized more
 * precisely. A brief illustration follows this class.
 */
class PerChannelAffineQuantizer : public UniformQuantizer {
public:
  /**
   * @brief Basic Constructor of a PerChannelAffineQuantizer
   */
  PerChannelAffineQuantizer() : UniformQuantizer() {}

  /**
   * @copydoc Quantizer::create()
   */
  std::unique_ptr<Quantizer> create() override;

  /**
   * @copydoc Quantizer::quantize(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  Tensor quantize(const Tensor &input,
                  ml::train::TensorDim::DataType qtype) override;

  /**
   * @copydoc Quantizer::quantize(const Tensor &input, Tensor &output, float
   * *scales, unsigned int *zero_points)
   */
  Tensor &quantize(const Tensor &input, Tensor &output, float *scales,
                   unsigned int *zero_points = nullptr) override;

  /**
   * @copydoc Quantizer::dequantize(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  Tensor dequantize(const Tensor &input,
                    ml::train::TensorDim::DataType dtype) override;

  /**
   * @copydoc Quantizer::qscheme()
   */
  QScheme qscheme() const override;

private:
  /**
   * @copydoc Quantizer::calculateQParams(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  void calculateQParams(const Tensor &input,
                        ml::train::TensorDim::DataType qtype) override {}
};
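
/*
 * Brief illustration of the per-channel idea, with made-up numbers: each
 * channel keeps its own scale (and zero point), so a channel with small
 * values is not forced onto the coarse grid chosen for a channel with large
 * values.
 *
 *   channel 0 weights in [-0.02, 0.02] -> scale_0 = 0.02 / 127 ≈ 0.000157
 *   channel 1 weights in [-3.0,  3.0 ] -> scale_1 = 3.0  / 127 ≈ 0.0236
 *
 * With a single per-tensor scale of 0.0236, channel 0 would map onto only a
 * handful of integer levels; per-channel scales preserve its resolution.
 */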

/**
 * @class BinaryCodeBasedQuantizer class
 * @brief BinaryCodeBasedQuantizer class uses a binary-code-based quantization
 * (BCQ) scheme (sketched briefly after this class).
 */
class BinaryCodeBasedQuantizer : public NonUniformQuantizer {
public:
  /**
   * @brief Basic Constructor of a BinaryCodeBasedQuantizer
   */
  BinaryCodeBasedQuantizer() : NonUniformQuantizer() {}

  /**
   * @copydoc Quantizer::create()
   */
  std::unique_ptr<Quantizer> create() override;

  /**
   * @copydoc Quantizer::quantize(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  Tensor quantize(const Tensor &input,
                  ml::train::TensorDim::DataType qtype) override;

  /**
   * @copydoc Quantizer::quantize(const Tensor &input, Tensor &output, float
   * *scales, unsigned int *zero_points)
   */
  Tensor &quantize(const Tensor &input, Tensor &output, float *scales,
                   unsigned int *zero_points = nullptr) override;

  /**
   * @copydoc Quantizer::dequantize(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  Tensor dequantize(const Tensor &input,
                    ml::train::TensorDim::DataType dtype) override;

  /**
   * @copydoc Quantizer::qscheme()
   */
  QScheme qscheme() const override;

private:
  // float *scales = nullptr;
  // int *zero_points = nullptr;
  // long int quant_min = 0;
  // long int quant_max = 0;

  /**
   * @copydoc Quantizer::calculateQParams(const Tensor &input,
   * ml::train::TensorDim::DataType qtype)
   */
  void calculateQParams(const Tensor &input,
                        ml::train::TensorDim::DataType qtype) override {}
};
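
/*
 * Rough sketch of the general BCQ idea (not a description of this class's
 * internals): a weight vector w is approximated by a sum of scaled binary
 * vectors, w ≈ sum_i alpha_i * b_i with b_i in {-1, +1}^n, so storage drops
 * to a few bit-planes plus one float coefficient per plane.
 *
 *   w       = [ 0.9, -1.1, 0.2 ]
 *   1 plane : alpha_1 = mean(|w|) ≈ 0.73, b_1 = sign(w) = [+1, -1, +1]
 *             -> reconstruction [ 0.73, -0.73, 0.73 ]
 *
 * Adding more planes reduces the approximation error; the coefficients play
 * the role that scale factors play in the uniform quantizers above.
 */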

/**
 * @brief Quantization class to create a quantizer
 *
 * @details The Quantization class is a creator class that builds both the
 * predefined quantizers and user-defined quantizers. Please check QScheme to
 * find out about the predefined quantizers.
 *
 * If no predefined quantization scheme fits your needs, create a new class
 * that inherits Quantizer, pick one of the CUSTOM_QUANTIZER_# schemes,
 * register it using registerQuantizer(), and then use it (a usage example
 * follows this class).
 */
class Quantization {
public:
  /**
   * @brief Create a Quantizer object
   *
   * @param qscheme quantization scheme
   * @return std::unique_ptr<Quantizer> quantizer object
   */
  static std::unique_ptr<Quantizer> createQuantizer(QScheme qscheme) {
    switch (qscheme) {
    case QScheme::PER_TENSOR_AFFINE:
      return std::make_unique<PerTensorAffineQuantizer>();
    case QScheme::PER_CHANNEL_AFFINE:
      return std::make_unique<PerChannelAffineQuantizer>();
    case QScheme::BINARY_CODE_BASED:
      return std::make_unique<BinaryCodeBasedQuantizer>();
    default:
      return Quantizer::getRegisteredQuantizer(qscheme)->create();
    }
  }
};
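
/*
 * Minimal usage sketch of the factory above. The tensor setup is elided, and
 * the chosen data types (QINT8 for the quantized tensor, FP32 for the
 * dequantized result) are assumptions for illustration.
 *
 *   std::unique_ptr<nntrainer::Quantizer> quantizer =
 *     nntrainer::Quantization::createQuantizer(
 *       nntrainer::QScheme::PER_TENSOR_AFFINE);
 *
 *   nntrainer::Tensor q = quantizer->quantize(
 *     weights, ml::train::TensorDim::DataType::QINT8);
 *   nntrainer::Tensor w = quantizer->dequantize(
 *     q, ml::train::TensorDim::DataType::FP32);
 */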

} // namespace nntrainer

#endif /* __cplusplus */
#endif /* __QUANTIZER_H__ */