LCOV - code coverage report
Current view: top level - nntrainer/tensor - short_tensor.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 20.0 % 5 1
Test Date: 2025-12-14 20:38:17 Functions: 25.0 % 4 1

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * @file        short_tensor.h
       4              :  * @date        10 January 2025
       5              :  * @brief       This is ShortTensor class for 16-bit signed integer calculation
       6              :  * @see         https://github.com/nnstreamer/nntrainer
       7              :  * @author      Donghyeon Jeong <dhyeon.jeong@samsung.com>
       8              :  * @bug         No known bugs except for NYI items
       9              :  */
      10              : 
      11              : #ifndef __SHORT_TENSOR_H__
      12              : #define __SHORT_TENSOR_H__
      13              : #ifdef __cplusplus
      14              : 
      15              : #include <tensor_base.h>
      16              : 
      17              : namespace nntrainer {
      18              : 
      19              : /**
      20              :  * @class ShortTensor class
      21              :  * @brief ShortTensor class for 16-bit unsigned integer calculation
      22              :  */
      23              : class ShortTensor : public TensorBase {
      24              : public:
      25              :   /**
      26              :    * @brief     Basic Constructor of Tensor
      27              :    */
      28              :   ShortTensor(std::string name_ = "", Tformat fm = Tformat::NCHW,
      29              :               QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE);
      30              : 
      31              :   /**
      32              :    * @brief Construct a new ShortTensor object
      33              :    *
      34              :    * @param d Tensor dim for this float tensor
      35              :    * @param alloc_now Allocate memory to this tensor or not
      36              :    * @param init Initializer for the tensor
      37              :    * @param name Name of the tensor
      38              :    */
      39              :   ShortTensor(const TensorDim &d, bool alloc_now,
      40              :               Initializer init = Initializer::NONE, std::string name = "",
      41              :               QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE);
      42              : 
      43              :   /**
      44              :    * @brief Construct a new ShortTensor object
      45              :    * @param d Tensor dim for this tensor
      46              :    * @param buf buffer
      47              :    */
      48              :   ShortTensor(const TensorDim &d, const void *buf = nullptr,
      49              :               QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE);
      50              : 
      51              :   /**
      52              :    * @brief Construct a new ShortTensor object
      53              :    * @param d data for the Tensor
      54              :    * @param fm format for the Tensor
      55              :    */
      56              :   ShortTensor(
      57              :     std::vector<std::vector<std::vector<std::vector<int16_t>>>> const &d,
      58              :     std::vector<float> const &scales, Tformat fm, QScheme qscheme_);
      59              : 
      60              :   /**
      61              :    * @brief Construct a new ShortTensor object
      62              :    * @param rhs TensorBase object to copy
      63              :    */
      64            0 :   ShortTensor(TensorBase &rhs) :
      65            0 :     TensorBase(rhs), qscheme(QScheme::PER_TENSOR_AFFINE) {}
      66              : 
      67              :   /**
      68              :    * @brief Basic Destructor
      69              :    */
      70           22 :   ~ShortTensor() {}
      71              : 
      72              :   /**
      73              :    * @brief     Comparison operator overload
      74              :    * @param[in] rhs Tensor to be compared with
      75              :    * @note      Only compares Tensor data
      76              :    */
      77              :   bool operator==(const ShortTensor &rhs) const;
      78              : 
      79              :   /**
      80              :    * @brief     Comparison operator overload
      81              :    * @param[in] rhs Tensor to be compared with
      82              :    * @note      Only compares Tensor data
      83              :    */
      84              :   bool operator!=(const ShortTensor &rhs) const { return !(*this == rhs); }
      85              : 
      86              :   /**
      87              :    * @copydoc Tensor::allocate()
      88              :    */
      89              :   void allocate() override;
      90              : 
      91              :   /**
      92              :    * @copydoc Tensor::deallocate()
      93              :    */
      94              :   void deallocate() override;
      95              : 
      96              :   /**
      97              :    * @copydoc Tensor::getData()
      98              :    */
      99              :   void *getData() const override;
     100              : 
     101              :   /**
     102              :    * @copydoc Tensor::getData(size_t idx)
     103              :    */
     104              :   void *getData(size_t idx) const override;
     105              : 
     106              :   /**
     107              :    * @copydoc Tensor::getScale()
     108              :    */
     109              :   void *getScale() const override;
     110              : 
     111              :   /**
     112              :    * @copydoc Tensor::getScale(size_t idx)
     113              :    */
     114              :   void *getScale(size_t idx) const override;
     115              : 
     116              :   /**
     117              :    * @brief     i data index
     118              :    * @retval    address of ith data
     119              :    */
     120              :   void *getAddress(unsigned int i) override;
     121              : 
     122              :   /**
     123              :    * @brief     i data index
     124              :    * @retval    address of ith data
     125              :    */
     126              :   const void *getAddress(unsigned int i) const override;
     127              : 
     128              :   /**
     129              :    * @brief     return value at specific location
     130              :    * @param[in] i index
     131              :    */
     132              :   const int16_t &getValue(unsigned int i) const;
     133              : 
     134              :   /**
     135              :    * @brief     return value at specific location
     136              :    * @param[in] i index
     137              :    */
     138              :   int16_t &getValue(unsigned int i);
     139              : 
     140              :   /**
     141              :    * @brief     return value at specific location
     142              :    * @param[in] b batch location
     143              :    * @param[in] c channel location
     144              :    * @param[in] h height location
     145              :    * @param[in] w width location
     146              :    */
     147              :   const int16_t &getValue(unsigned int b, unsigned int c, unsigned int h,
     148              :                           unsigned int w) const;
     149              : 
     150              :   /**
     151              :    * @brief     return value at specific location
     152              :    * @param[in] b batch location
     153              :    * @param[in] c channel location
     154              :    * @param[in] h height location
     155              :    * @param[in] w width location
     156              :    */
     157              :   int16_t &getValue(unsigned int b, unsigned int c, unsigned int h,
     158              :                     unsigned int w);
     159              : 
     160              :   /**
     161              :    * @copydoc Tensor::setValue(float value)
     162              :    */
     163              :   void setValue(float value) override;
     164              : 
     165              :   /**
     166              :    * @copydoc Tensor::setValue(b, c, h, w, value)
     167              :    */
     168              :   void setValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
     169              :                 float value) override;
     170              : 
     171              :   /**
     172              :    * @copydoc Tensor::addValue(b, c, h, w, value, beta)
     173              :    */
     174              :   void addValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
     175              :                 float value, float beta) override;
     176              : 
     177              :   /**
     178              :    * @copydoc Tensor::setZero()
     179              :    */
     180              :   void setZero() override;
     181              : 
     182              :   /**
     183              :    * @copydoc Tensor::initialize()
     184              :    */
     185              :   void initialize() override;
     186              : 
     187              :   /**
     188              :    * @copydoc Tensor::initialize(Initializer init)
     189              :    */
     190              :   void initialize(Initializer init) override;
     191              : 
     192              :   /**
     193              :    * @copydoc Tensor::copy(const Tensor &from)
     194              :    */
     195              :   void copy(const Tensor &from) override;
     196              : 
     197              :   /**
     198              :    * @copydoc Tensor::copyData(const Tensor &from)
     199              :    */
     200              :   void copyData(const Tensor &from) override;
     201              : 
     202              :   /**
     203              :    * @copydoc Tensor::copy_with_stride()
     204              :    */
     205              :   void copy_with_stride(const Tensor &input, Tensor &output) override;
     206              : 
     207              :   /**
     208              :    * @copydoc Tensor::save(std::ostream &file)
     209              :    */
     210              :   void save(std::ostream &file) override;
     211              : 
     212              :   /**
     213              :    * @copydoc Tensor::read(std::ifstream &file)
     214              :    */
     215              :   void read(std::ifstream &file, size_t start_offset,
     216              :             bool read_from_offset) override;
     217              : 
     218              :   /**
     219              :    * @copydoc Tensor::argmax()
     220              :    */
     221              :   std::vector<unsigned int> argmax() const override;
     222              : 
     223              :   /**
     224              :    * @copydoc Tensor::argmin()
     225              :    */
     226              :   std::vector<unsigned int> argmin() const override;
     227              : 
     228              :   /**
     229              :    * @copydoc Tensor::max_abs()
     230              :    */
     231              :   float max_abs() const override;
     232              : 
     233              :   /**
     234              :    * @copydoc Tensor::maxValue()
     235              :    */
     236              :   float maxValue() const override;
     237              : 
     238              :   /**
     239              :    * @copydoc Tensor::minValue()
     240              :    */
     241              :   float minValue() const override;
     242              : 
     243              :   /**
     244              :    * @copydoc Tensor::getMemoryBytes()
     245              :    */
     246              :   size_t getMemoryBytes() const override;
     247              : 
     248              :   /**
     249              :    * @copydoc Tensor::print(std::ostream &out)
     250              :    */
     251              :   void print(std::ostream &out) const override;
     252              : 
     253              :   /**
     254              :    * @copydoc TensorBase::save_quantization_info()
     255              :    */
     256              :   void save_quantization_info(std::ostream &file) override;
     257              : 
     258              :   /**
     259              :    * @copydoc TensorBase::read_quantization_info()
     260              :    */
     261              :   void read_quantization_info(std::ifstream &file, size_t start_offset,
     262              :                               bool read_from_offset) override;
     263              : 
     264              :   /**
     265              :    * @copydoc Tensor::scale_size()
     266              :    */
     267              :   size_t scale_size() const override;
     268              : 
     269              :   /**
     270              :    * @copydoc Tensor::scale_size()
     271              :    */
     272              :   QScheme q_scheme() const override;
     273              : 
     274              : private:
     275              :   /**
     276              :    * @brief quantization scheme
     277              :    */
     278              :   QScheme qscheme;
     279              : 
     280              :   /**
     281              :    * @brief copy a buffer to @a this, the caller has to ensure that @a this is
     282              :    * initialized otherwise undefined behavior
     283              :    *
     284              :    * @param buf buffer to copy from
     285              :    */
     286              :   void copy(const void *buf);
     287              : 
     288              :   /**
     289              :    * @brief  Get the Data Type String object
     290              :    * @return std::string of tensor data type (QINT16)
     291              :    */
     292            0 :   std::string getStringDataType() const override { return "QINT16"; }
     293              : 
     294              :   /**
     295              :    * @copydoc Tensor::isValid()
     296              :    */
     297            0 :   bool isValid() const override { return true; };
     298              : };
     299              : 
     300              : } // namespace nntrainer
     301              : 
     302              : #endif /* __cplusplus */
     303              : #endif /* __SHORT_TENSOR_H__ */
        

Generated by: LCOV version 2.0-1