Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2021 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file base_properties.h
6 : * @date 08 April 2021
7 : * @brief Convenient property type definition for automated serialization
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : */
12 : #ifndef __BASE_PROPERTIES_H__
13 : #define __BASE_PROPERTIES_H__
14 :
15 : #include <array>
16 : #include <memory>
17 : #include <regex>
18 : #include <sstream>
19 : #include <string>
20 : #include <vector>
21 :
22 : #include <common.h>
23 : #include <nntrainer_error.h>
24 : #include <tensor_dim.h>
25 : #include <util_func.h>
26 :
27 : /** base and predefined structures */
28 :
29 : namespace nntrainer {
30 :
31 : using TensorDim = ml::train::TensorDim;
32 :
33 : /**
34 : * @brief property info to specialize functions based on this
35 : * @tparam T property type
36 : */
37 : template <typename T> struct prop_info {
38 : using prop_type = std::decay_t<T>; /** property type of T */
39 : using tag_type = typename prop_type::prop_tag; /** Property tag of T */
40 : using data_type =
41 : std::decay_t<decltype(std::declval<prop_type>().get())>; /** Underlying
42 : datatype of T */
43 : };
44 :
45 : /**
46 : * @brief property info when it is wrapped inside a vector
47 : *
48 : * @tparam T property type
49 : */
50 : template <typename T> struct prop_info<std::vector<T>> {
51 : using prop_type = typename prop_info<T>::prop_type;
52 : using tag_type = typename prop_info<T>::tag_type;
53 : using data_type = typename prop_info<T>::data_type;
54 : };
55 :
56 : /**
57 : * @brief property info when it is wrapped inside an array
58 : *
59 : * @tparam T property type
60 : */
61 : template <typename T, size_t size> struct prop_info<std::array<T, size>> {
62 : using prop_type = typename prop_info<T>::prop_type;
63 : using tag_type = typename prop_info<T>::tag_type;
64 : using data_type = typename prop_info<T>::data_type;
65 : };
66 :
67 : /**
68 : * @brief Get the Prop Key object
69 : *
70 : * @tparam T property to get type
71 : * @param prop property
72 : * @return constexpr const char* key
73 : */
74 : template <typename T> constexpr const char *getPropKey(T &&prop) {
75 : return prop_info<std::decay_t<T>>::prop_type::key;
76 : }
77 :
78 : /**
79 : * @brief property is treated as integer
80 : *
81 : */
82 : struct int_prop_tag {};
83 :
84 : /**
85 : * @brief property is treated as unsigned integer
86 : *
87 : */
88 : struct uint_prop_tag {};
89 :
90 : /**
91 : * @brief property is treated as unsigned integer
92 : *
93 : */
94 : struct size_t_prop_tag {};
95 :
96 : /**
97 : * @brief property is treated as dimension, eg 1:2:3
98 : *
99 : */
100 : struct dimension_prop_tag {};
101 :
102 : /**
103 : * @brief property is treated as float
104 : *
105 : */
106 : struct float_prop_tag {};
107 :
108 : /**
109 : * @brief property is treated as double
110 : *
111 : */
112 : struct double_prop_tag {};
113 :
114 : /**
115 : * @brief property is treated as string
116 : *
117 : */
118 : struct str_prop_tag {};
119 :
120 : /**
121 : * @brief property is treated as boolean
122 : *
123 : */
124 : struct bool_prop_tag {};
125 :
126 : /**
127 : * @brief property is treated as enum class
128 : *
129 : */
130 : struct enum_class_prop_tag {};
131 :
132 : /**
133 : * @brief property treated as a raw pointer
134 : *
135 : */
136 : struct ptr_prop_tag {};
137 :
138 : /**
139 : * @brief base property class, inherit this to make a convenient property
140 : *
141 : * @tparam T
142 : */
143 : template <typename T> class Property {
144 :
145 : public:
146 : /**
147 : * @brief Construct a new Property object
148 : *
149 : */
150 69690 : Property() : value(nullptr){};
151 :
152 : /**
153 : * @brief Construct a new Property object
154 : *
155 : * @param value default value
156 : */
157 28347 : Property(const T &value_) { set(value_); }
158 :
159 : /**
160 : * @brief Copy Construct a new Property object
161 : *
162 : * @param rhs right side to copy from
163 : */
164 82389 : Property(const Property &rhs) {
165 82389 : if (this != &rhs && rhs.value) {
166 1706 : value = std::make_unique<T>(*rhs.value);
167 : }
168 82389 : }
169 :
170 : /**
171 : * @brief Copy assignment operator of a new property
172 : *
173 : * @param rhs right side to copy from
174 : * @return Property& this
175 : */
176 3218 : Property &operator=(const Property &rhs) {
177 3218 : if (this != &rhs && rhs.value) {
178 26 : value = std::make_unique<T>(*rhs.value);
179 : }
180 3218 : return *this;
181 : };
182 :
183 25067 : Property(Property &&rhs) noexcept = default;
184 : Property &operator=(Property &&rhs) noexcept = default;
185 :
186 : /**
187 : * @brief Destroy the Property object
188 : *
189 : */
190 64455 : virtual ~Property() = default;
191 :
192 : /**
193 : * @brief cast operator for property
194 : *
195 : * @return T value
196 : */
197 1950922 : operator T &() { return get(); }
198 :
199 : /**
200 : * @brief cast operator for property
201 : *
202 : * @return T value
203 : */
204 44231404 : operator const T &() const { return get(); }
205 :
206 : /**
207 : * @brief get the underlying data
208 : *
209 : * @return const T& data
210 : */
211 44292560 : const T &get() const {
212 44292562 : NNTR_THROW_IF(value == nullptr, std::invalid_argument)
213 : << "Cannot get property, property is empty";
214 44292558 : return *value;
215 : }
216 :
217 : /**
218 : * @brief get the underlying data
219 : *
220 : * @return T& data
221 : */
222 2307088 : T &get() {
223 2307090 : NNTR_THROW_IF(value == nullptr, std::invalid_argument)
224 : << "Cannot get property, property is empty";
225 2307086 : return *value;
226 : }
227 :
228 : /**
229 : * @brief check if property is empty
230 : *
231 : * @retval true empty
232 : * @retval false not empty
233 : */
234 : bool empty() const { return value == nullptr; }
235 :
236 : /**
237 : * @brief set the underlying data
238 : *
239 : * @param v value to set
240 : * @throw std::invalid_argument if argument is not valid
241 : */
242 189210 : virtual void set(const T &v) {
243 189254 : NNTR_THROW_IF(isValid(v) == false, std::invalid_argument)
244 : << "argument is not valid";
245 93626 : value = std::make_unique<T>(v);
246 189158 : }
247 :
248 : /**
249 : * @brief check if given value is valid
250 : *
251 : * @param v value to check
252 : * @retval true if valid
253 : * @retval false if not valid
254 : */
255 86526 : virtual bool isValid(const T &v) const { return true; }
256 :
257 : /**
258 : * @brief operator==
259 : *
260 : * @param rhs right side to compare
261 : * @retval true if equal
262 : * @retval false if not equal
263 : */
264 17 : bool operator==(const Property<T> &rhs) const { return *value == *rhs.value; }
265 :
266 : private:
267 : std::unique_ptr<T> value; /**< underlying data */
268 : };
269 :
270 : /**
271 : * @brief enum property
272 : *
273 : * @tparam T underlying type info to query enum_info
274 : */
275 : template <typename EnumInfo>
276 13940 : class EnumProperty : public Property<typename EnumInfo::Enum> {
277 : public:
278 : static EnumInfo enum_info_;
279 : };
280 :
281 : /**
282 : * @brief abstract class for tensor dimension
283 : *
284 : */
285 37 : class TensorDimProperty : public Property<TensorDim> {
286 : public:
287 : /**
288 : * @brief Destroy the TensorDim Property object
289 : *
290 : */
291 : virtual ~TensorDimProperty() = default;
292 : };
293 :
294 : /**
295 : * @brief abstract class for positive integer
296 : *
297 : */
298 12920 : class PositiveIntegerProperty : public Property<unsigned int> {
299 : public:
300 : /**
301 : * @brief Destroy the Positive Integer Property object
302 : *
303 : */
304 0 : virtual ~PositiveIntegerProperty() = default;
305 :
306 : /**
307 : * @brief isValid override, check if value > 0
308 : *
309 : * @param value value to check
310 : * @retval true if value > 0
311 : */
312 : virtual bool isValid(const unsigned int &value) const override;
313 : };
314 : /**
315 : * @brief meta function to cast tag to it's base
316 : * @code below is the test spec for the cast
317 : *
318 : * struct custom_tag: int_prop_tag {};
319 : *
320 : * using tag_type = tag_cast<custom_tag, float_prop_tag>::type
321 : * static_assert(<std::is_same_v<tag_type, custom_tag> == true);
322 : *
323 : * using tag_type = tag_cast<custom_tag, int_prop_tag>::type
324 : * static_assert(<std::is_same_v<tag_type, int_prop_tag> == true);
325 : *
326 : * using tag_type = tag_cast<custom_tag, float_prop_tag, int_prop_tag>::type
327 : * static_assert(std::is_same_v<tag_type, int_prop_tag> == true);
328 : *
329 : * @tparam Tags First tag: tag to be casted, rest tags: candidates
330 : */
331 : template <typename... Tags> struct tag_cast;
332 :
333 : /**
334 : * @brief base case of tag_cast, if nothing matches return @a Tag
335 : *
336 : * @tparam Tag Tag to be casted
337 : * @tparam Others empty parameter pack
338 : */
339 : template <typename Tag, typename... Others> struct tag_cast<Tag, Others...> {
340 : using type = Tag;
341 : };
342 :
343 : /**
344 : * @brief normal case of the tag cast
345 : *
346 : * @tparam Tag tag to be casted
347 : * @tparam BaseTag candidates to cast the tag
348 : * @tparam Others pending candidates to be compared
349 : */
350 : template <typename Tag, typename BaseTag, typename... Others>
351 : struct tag_cast<Tag, BaseTag, Others...> {
352 : using type = std::conditional_t<std::is_base_of<BaseTag, Tag>::value, BaseTag,
353 : typename tag_cast<Tag, Others...>::type>;
354 : };
355 :
356 : /**
357 : * @brief property to string converter.
358 : * This structure defines how to convert to convert from/to string
359 : *
360 : * @tparam Tag tag type for the converter
361 : * @tparam DataType underlying datatype
362 : */
363 : template <typename Tag, typename DataType> struct str_converter {
364 :
365 : /**
366 : * @brief convert underlying value to string
367 : *
368 : * @param value value to convert to string
369 : * @return std::string string
370 : */
371 : static std::string to_string(const DataType &value);
372 :
373 : /**
374 : * @brief convert string to underlying value
375 : *
376 : * @param value value to convert to string
377 : * @return DataType converted type
378 : */
379 : static DataType from_string(const std::string &value);
380 : };
381 :
382 : /**
383 : * @brief str converter specialization for enum classes
384 : *
385 : * @tparam EnumInfo enum informations
386 : */
387 : template <typename EnumInfo>
388 : struct str_converter<enum_class_prop_tag, EnumInfo> {
389 :
390 : /**
391 : * @copydoc template <typename Tag, typename DataType> struct str_converter
392 : */
393 8765 : static std::string to_string(const typename EnumInfo::Enum &value) {
394 : constexpr auto size = EnumInfo::EnumList.size();
395 : constexpr const auto data = std::data(EnumInfo::EnumList);
396 32692 : for (unsigned i = 0; i < size; ++i) {
397 32692 : if (data[i] == value) {
398 8765 : return EnumInfo::EnumStr[i];
399 : }
400 : }
401 0 : throw std::invalid_argument("Cannot find value in the enum list");
402 : }
403 :
404 : /**
405 : * @copydoc template <typename Tag, typename DataType> struct str_converter
406 : */
407 41595 : static typename EnumInfo::Enum from_string(const std::string &value) {
408 : constexpr auto size = EnumInfo::EnumList.size();
409 : constexpr const auto data = std::data(EnumInfo::EnumList);
410 209421 : for (unsigned i = 0; i < size; ++i) {
411 418812 : if (istrequal(EnumInfo::EnumStr[i], value.c_str())) {
412 41580 : return data[i];
413 : }
414 : }
415 30 : throw std::invalid_argument("No matching enum for value: " + value);
416 : }
417 : };
418 :
419 : /**
420 : * @brief str converter which serializes a pointer and returns back to a ptr
421 : *
422 : * @tparam DataType pointer type
423 : */
424 : template <typename DataType> struct str_converter<ptr_prop_tag, DataType> {
425 :
426 : /**
427 : * @brief convert underlying value to string
428 : *
429 : * @param value value to convert to string
430 : * @return std::string string
431 : */
432 2 : static std::string to_string(const DataType &value) {
433 2 : std::ostringstream ss;
434 2 : ss << value;
435 2 : return ss.str();
436 2 : }
437 :
438 : /**
439 : * @brief convert string to underlying value
440 : *
441 : * @param value value to convert to string
442 : * @return DataType converted type
443 : */
444 4 : static DataType from_string(const std::string &value) {
445 4 : std::stringstream ss(value);
446 : uintptr_t addr = static_cast<uintptr_t>(std::stoull(value, 0, 16));
447 4 : return reinterpret_cast<DataType>(addr);
448 4 : }
449 : };
450 :
451 : /**
452 : * @copydoc template <typename Tag, typename DataType> struct str_converter
453 : */
454 : template <>
455 : std::string
456 : str_converter<str_prop_tag, std::string>::to_string(const std::string &value);
457 :
458 : /**
459 : * @copydoc template <typename Tag, typename DataType> struct str_converter
460 : */
461 : template <>
462 : std::string
463 : str_converter<str_prop_tag, std::string>::from_string(const std::string &value);
464 :
465 : /**
466 : * @copydoc template <typename Tag, typename DataType> struct str_converter
467 : */
468 : template <>
469 : std::string str_converter<uint_prop_tag, unsigned int>::to_string(
470 : const unsigned int &value);
471 :
472 : /**
473 : * @copydoc template <typename Tag, typename DataType> struct str_converter
474 : */
475 : template <>
476 : unsigned int str_converter<uint_prop_tag, unsigned int>::from_string(
477 : const std::string &value);
478 :
479 : /**
480 : * @copydoc template <typename Tag, typename DataType> struct str_converter
481 : */
482 : template <>
483 : std::string
484 : str_converter<size_t_prop_tag, size_t>::to_string(const size_t &value);
485 :
486 : /**
487 : * @copydoc template <typename Tag, typename DataType> struct str_converter
488 : */
489 : template <>
490 : size_t
491 : str_converter<size_t_prop_tag, size_t>::from_string(const std::string &value);
492 :
493 : /**
494 : * @copydoc template <typename Tag, typename DataType> struct str_converter
495 : */
496 : template <>
497 : std::string str_converter<bool_prop_tag, bool>::to_string(const bool &value);
498 :
499 : /**
500 : * @copydoc template <typename Tag, typename DataType> struct str_converter
501 : */
502 : template <>
503 : bool str_converter<bool_prop_tag, bool>::from_string(const std::string &value);
504 :
505 : /**
506 : * @copydoc template <typename Tag, typename DataType> struct str_converter
507 : */
508 : template <>
509 : std::string str_converter<float_prop_tag, float>::to_string(const float &value);
510 :
511 : /**
512 : * @copydoc template <typename Tag, typename DataType> struct str_converter
513 : */
514 : template <>
515 : float str_converter<float_prop_tag, float>::from_string(
516 : const std::string &value);
517 :
518 : /**
519 : * @copydoc template <typename Tag, typename DataType> struct str_converter
520 : */
521 : template <>
522 : std::string
523 : str_converter<double_prop_tag, double>::to_string(const double &value);
524 :
525 : /**
526 : * @copydoc template <typename Tag, typename DataType> struct str_converter
527 : */
528 : template <>
529 : double
530 : str_converter<double_prop_tag, double>::from_string(const std::string &value);
531 :
532 : /**
533 : * @brief convert dispatcher (to string)
534 : *
535 : * @tparam T type to convert
536 : * @param property property to convert
537 : * @return std::string converted string
538 : */
539 : template <typename T> std::string to_string(const T &property) {
540 : using info = prop_info<T>;
541 : using tag_type =
542 : typename tag_cast<typename info::tag_type, int_prop_tag, uint_prop_tag,
543 : dimension_prop_tag, float_prop_tag, str_prop_tag,
544 : enum_class_prop_tag>::type;
545 : using data_type = typename info::data_type;
546 :
547 : if constexpr (std::is_same_v<tag_type, enum_class_prop_tag>) {
548 : return str_converter<tag_type, decltype(T::enum_info_)>::to_string(
549 8071 : property.get());
550 : } else {
551 29631 : return str_converter<tag_type, data_type>::to_string(property.get());
552 : }
553 : }
554 :
555 : /**
556 : * @brief to_string vector specialization
557 : * @copydoc template <typename T> std::string to_string(const T &property)
558 : */
559 2816 : template <typename T> std::string to_string(const std::vector<T> &property) {
560 2816 : std::stringstream ss;
561 : auto last_iter = property.end() - 1;
562 3985 : for (auto iter = property.begin(); iter != last_iter; ++iter) {
563 2338 : ss << to_string(*iter) << ',';
564 : }
565 2816 : ss << to_string(*last_iter);
566 :
567 2816 : return ss.str();
568 2816 : }
569 :
570 : /**
571 : * @brief to_string array specialization
572 : * @copydoc template <typename T> std::string to_string(const T &property)
573 : */
574 : template <typename T, size_t sz>
575 141 : static std::string to_string(const std::array<T, sz> &value) {
576 141 : std::stringstream ss;
577 141 : auto last_iter = value.end() - 1;
578 284 : for (auto iter = value.begin(); iter != last_iter; ++iter) {
579 286 : ss << to_string(*iter) << ',';
580 : }
581 141 : ss << to_string(*last_iter);
582 :
583 141 : return ss.str();
584 141 : }
585 :
586 : /**
587 : *
588 : * @brief convert dispatcher (from string)
589 : *
590 : *
591 : * @tparam T type to convert
592 : * @param str string to convert
593 : * @param[out] property property, converted type
594 : */
595 53610 : template <typename T> void from_string(const std::string &str, T &property) {
596 : using info = prop_info<T>;
597 : using tag_type =
598 : typename tag_cast<typename info::tag_type, int_prop_tag, uint_prop_tag,
599 : dimension_prop_tag, float_prop_tag, str_prop_tag,
600 : enum_class_prop_tag>::type;
601 : using data_type = typename info::data_type;
602 :
603 : if constexpr (std::is_same_v<tag_type, enum_class_prop_tag>) {
604 15650 : property.set(
605 7834 : str_converter<tag_type, decltype(T::enum_info_)>::from_string(str));
606 : } else {
607 56029 : property.set(str_converter<tag_type, data_type>::from_string(str));
608 : }
609 53570 : }
610 :
611 : /**
612 : * @brief transform iternal data, this is to use with std::transform
613 : *
614 : * @param item item to transform
615 : * @return DataType transformed result
616 : */
617 10335 : template <typename T> static T from_string_helper_(const std::string &item) {
618 7066 : T t;
619 10335 : from_string(item, t);
620 10325 : return t;
621 : }
622 :
623 : static const std::regex reg_("\\s*\\,\\s*");
624 :
625 : /**
626 : * @brief from_string array specialization
627 : * @copydoc template <typename T> void from_string(const std::string &str, T
628 : * &property)
629 : * @note array implies that the size is @b fixed so there will be a validation
630 : * check on size
631 : */
632 : template <typename T, size_t sz>
633 420 : void from_string(const std::string &value, std::array<T, sz> &property) {
634 420 : auto v = split(value, reg_);
635 429 : NNTR_THROW_IF(v.size() != sz, std::invalid_argument)
636 : << "size must match with array size, array size: " << sz
637 : << " string: " << value;
638 :
639 411 : std::transform(v.begin(), v.end(), property.begin(), from_string_helper_<T>);
640 420 : }
641 :
642 : /**
643 : * @brief from_string vector specialization
644 : * @copydoc str_converter<Tag, DataType>::to_string(const DataType &value)
645 : * @note vector implies that the size is @b not fixed so there shouldn't be any
646 : * validation on size
647 : *
648 : */
649 : template <typename T>
650 6902 : void from_string(const std::string &value, std::vector<T> &property) {
651 6902 : auto v = split(value, reg_);
652 :
653 : property.clear();
654 6902 : property.reserve(v.size());
655 6902 : std::transform(v.begin(), v.end(), std::back_inserter(property),
656 : from_string_helper_<T>);
657 6902 : }
658 : /******** below section is for enumerations ***************/
659 : /**
660 : * @brief Enumeration of Data Type for model & layer
661 : */
662 : struct TensorDataTypeInfo {
663 : using Enum = nntrainer::TensorDim::DataType;
664 : static constexpr std::initializer_list<Enum> EnumList = {
665 : Enum::BCQ, Enum::QINT4, Enum::QINT8, Enum::QINT16,
666 : Enum::FP16, Enum::FP32, Enum::UINT4, Enum::UINT8,
667 : Enum::UINT16, Enum::Q4_K, Enum::Q6_K, Enum::Q4_0};
668 : static constexpr const char *EnumStr[] = {
669 : "BCQ", "QINT4", "QINT8", "QINT16", "FP16", "FP32",
670 : "UINT4", "UINT8", "UINT16", "Q4_K", "Q6_K", "Q4_0"};
671 : };
672 :
673 : /**
674 : * @brief Enumeration of Format for model & layer
675 : */
676 : struct TensorFormatInfo {
677 : using Enum = nntrainer::TensorDim::Format;
678 : static constexpr std::initializer_list<Enum> EnumList = {Enum::NCHW,
679 : Enum::NHWC};
680 :
681 : static constexpr const char *EnumStr[] = {"NCHW", "NHWC"};
682 : };
683 :
684 : /**
685 : * @brief Enumeration of Tensor Type for model & layer
686 : */
687 : enum class TensorType_ {
688 : WEIGHT, /**< Weight Tensor */
689 : IN_TENSOR, /**< Input Tensor for FORWARD_FUNC_LIFESPAN */
690 : MAX_IN_TENSOR, /**< MAX LIFESPAN INPUT TENSOR */
691 : OUT_TENSOR, /**< OUTPUT TENSOR */
692 : MAX_OUT_TENSOR, /**< MAX LIFESPAN OUTPUT TENSOR */
693 : };
694 :
695 : /**
696 : * @brief Enumeration of Tensor Type for model & layer
697 : */
698 : struct TensorTypeInfo {
699 : using Enum = nntrainer::TensorType_;
700 : static constexpr std::initializer_list<Enum> EnumList = {
701 : Enum::WEIGHT, Enum::IN_TENSOR, Enum::MAX_IN_TENSOR, Enum::OUT_TENSOR,
702 : Enum::MAX_OUT_TENSOR};
703 : static constexpr const char *EnumStr[] = {
704 : "WEIGHT", "IN_TENSOR", "MAX_IN_TENSOR", "OUT_TENSOR", "MAX_OUT_TENSOR"};
705 : };
706 :
707 : namespace props {
708 :
709 : /**
710 : * @brief Weight Data Type Enumeration Information
711 : * This property can be used when any layer is created.
712 : * This property is differentiated with TensorDataType in that it doesn't have
713 : * default value
714 : */
715 12382 : class WeightDtype final : public EnumProperty<TensorDataTypeInfo> {
716 : public:
717 : using prop_tag = enum_class_prop_tag;
718 : static constexpr const char *key = "weight_dtype";
719 :
720 : /**
721 : * @brief Constructor
722 : */
723 6191 : WeightDtype(){};
724 :
725 : /**
726 : * @brief Constructor
727 : *
728 : * @param value value to set
729 : */
730 : WeightDtype(TensorDataTypeInfo::Enum value) { set(value); };
731 : };
732 :
733 : /**
734 : * @brief Activation Enumeration Information
735 : *
736 : */
737 118 : class TensorDataType final : public EnumProperty<TensorDataTypeInfo> {
738 : public:
739 : using prop_tag = enum_class_prop_tag;
740 : static constexpr const char *key = "tensor_dtype";
741 :
742 : /**
743 : * @brief Constructor
744 : *
745 : * @param value value to set, defaults to FP32
746 : */
747 81 : TensorDataType(
748 81 : TensorDataTypeInfo::Enum value = TensorDataTypeInfo::Enum::FP32) {
749 81 : set(value);
750 81 : };
751 : };
752 :
753 : /**
754 : * @brief model tensor type : NCHW or NHWC
755 : *
756 : */
757 2370 : class TensorFormat final : public EnumProperty<TensorFormatInfo> {
758 : public:
759 : static constexpr const char *key =
760 : "tensor_format"; /**< unique key to access */
761 : using prop_tag = enum_class_prop_tag; /**< property type */
762 :
763 : /**
764 : * @brief Constructor
765 : *
766 : * @param value value to set, defaults to NCHW
767 : */
768 836 : TensorFormat(TensorFormatInfo::Enum value = TensorFormatInfo::Enum::NCHW) {
769 836 : set(value);
770 836 : };
771 : };
772 :
773 : /**
774 : * @brief model tensor type is the clue of life span
775 : *
776 : */
777 : class TensorType final : public EnumProperty<nntrainer::TensorTypeInfo> {
778 : public:
779 : static constexpr const char *key = "tensor_type";
780 : using prop_tag = enum_class_prop_tag;
781 : TensorType(TensorTypeInfo::Enum value = TensorTypeInfo::Enum::WEIGHT) {
782 : set(value);
783 : }
784 : };
785 :
786 : /**
787 : * @brief Enumeration of Run Engine type
788 : */
789 : struct ComputeEngineTypeInfo {
790 : using Enum = ml::train::LayerComputeEngine;
791 : static constexpr std::initializer_list<Enum> EnumList = {Enum::CPU, Enum::GPU,
792 : Enum::QNN};
793 : static constexpr const char *EnumStr[] = {"cpu", "gpu", "qnn"};
794 : };
795 :
796 : /**
797 : * @brief ComputeEngine Enumeration Information
798 : *
799 : */
800 18573 : class ComputeEngine final
801 : : public EnumProperty<nntrainer::props::ComputeEngineTypeInfo> {
802 : public:
803 : using prop_tag = enum_class_prop_tag;
804 : static constexpr const char *key = "engine";
805 : };
806 :
807 : // /**
808 : // * @brief trainable property, use this to set and check how if certain
809 : // layer is
810 : // * trainable
811 : // *
812 : // */
813 : // class Trainable : public nntrainer::Property<bool> {
814 : // public:
815 : // /**
816 : // * @brief Construct a new Trainable object
817 : // *
818 : // */
819 : // Trainable(bool val = true) : nntrainer::Property<bool>(val) {}
820 : // static constexpr const char *key = "trainable";
821 : // using prop_tag = bool_prop_tag;
822 : // };
823 :
824 : } // namespace props
825 :
826 : } // namespace nntrainer
827 :
828 : #endif // __BASE_PROPERTIES_H__
|