LCOV - code coverage report
Current view: top level - nntrainer/tensor - tensor.h (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 95.9 % 98 94
Test Date: 2025-12-14 20:38:17 Functions: 82.9 % 41 34

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * @file        tensor.h
       4              :  * @date        01 December 2023
       5              :  * @brief       This is a Tensor class
       6              :  * @see         https://github.com/nnstreamer/nntrainer
       7              :  * @author      Jijoong Moon <jijoong.moon@samsung.com>
       8              :  * @author      Donghyeon Jeong <dhyeon.jeong@samsung.com>
       9              :  * @bug         No known bugs except for NYI items
      10              :  */
      11              : 
      12              : #ifndef __TENSOR_H__
      13              : #define __TENSOR_H__
      14              : #ifdef __cplusplus
      15              : 
/**
 * @brief Convenience macro: construct an nntrainer::Tensor from the given
 *        arguments and wrap it in a std::shared_ptr.
 */
#define MAKE_SHARED_TENSOR(...) std::make_shared<nntrainer::Tensor>(__VA_ARGS__)
      17              : 
      18              : #define CREATE_IF_EMPTY_DIMS(tensor, ...)                                      \
      19              :   do {                                                                         \
      20              :     if (tensor.empty())                                                        \
      21              :       tensor = Tensor(__VA_ARGS__);                                            \
      22              :   } while (0);
      23              : 
      24              : #include <cstddef>
      25              : 
      26              : #include <cpu_backend.h>
      27              : #include <nntrainer_log.h>
      28              : #include <tensor_base.h>
      29              : 
      30              : #ifdef ENABLE_FP16
      31              : #include <half_tensor.h>
      32              : #endif
      33              : 
      34              : namespace nntrainer {
      35              : 
      36              : class LazyTensor;
      37              : 
      38              : /**
      39              :  * @class Tensor Class
      40              :  * @brief Tensor is a multidimensional matrix that contain elements of a single
      41              :  * data type and can perform various operations like addition, division,
      42              :  * multiplication, dot product, data averaging, and more.
      43              :  * NNTrainer defines tensor types using different data types and memory formats.
      44              :  * Supported data types and format are specified in the file 'tensor_dim.h'.
      45              :  *
      46              :  * @note The Tensor class utilizes the TensorBase class to support tensors with
      47              :  * various data types. In other words, this tensor class serves as a container
      48              :  * for tensors, and thus the functionality of the tensor should be defined in
      49              :  * each tensor class (FloatTensor, HalfTensor, etc.).
      50              :  *
      51              :  */
      52              : class Tensor {
      53              : public:
  /**
   * @brief     Basic Constructor of Tensor (empty tensor)
   * @param[in] name_ Name of the tensor (default "")
   * @param[in] fm Tensor format (default NCHW)
   * @param[in] d_type Tensor data type (default FP32)
   */
  Tensor(std::string name_ = "", Tformat fm = Tformat::NCHW,
         Tdatatype d_type = Tdatatype::FP32);

  /**
   * @brief     Constructor of Tensor with dimension, possibly lazily
   * @param d Tensor dim for this tensor
   * @param alloc_now If the memory of the tensor must be allocated
   *                  immediately (false defers allocation to allocate())
   * @param init Initializer for the tensor
   * @param name Name of the tensor
   * @param qscheme_ Quantization scheme (only applies to Quantized Tensor)
   * @param is_virtual virtual tensor boolean (default=false)
   */
  Tensor(const TensorDim &d, bool alloc_now,
         Initializer init = Initializer::NONE, std::string name = "",
         QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE,
         bool is_virtual = false);

  /**
   * @brief     Constructor of Tensor with dimension/buf
   * @param d Tensor dim for this tensor
   * @param buf buffer; may be null, in which case memory is still allocated
   *            (presumably uninitialized/initializer-driven — see tensor.cpp)
   * @param qscheme_ Quantization scheme (only applies to Quantized Tensor)
   * @note Memory for this tensor is instantaneously allocated
   */
  Tensor(const TensorDim &d, const void *buf = nullptr,
         QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE);
      83              : 
  /**
   * @brief     Constructor of Tensor
   * @param[in] d0 Batch of Tensor
   * @param[in] d1 Channel
   * @param[in] d2 Height
   * @param[in] d3 Width
   * @param[in] fm Tensor Format
   * @param[in] d_type Tensor Data Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  // Delegates to the TensorDim-based constructor; memory is allocated
  // immediately (nullptr buffer path).
  Tensor(size_t d0, size_t d1, size_t d2, size_t d3, Tformat fm = Tformat::NCHW,
         Tdatatype d_type = Tdatatype::FP32,
         QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE) :
    Tensor(TensorDim(d0, d1, d2, d3, fm, d_type), nullptr, qscheme_){};

  /**
   * @brief     Constructor of Tensor with batch size one
   * @param[in] d1 Channel
   * @param[in] d2 Height
   * @param[in] d3 Width
   * @param[in] fm Tensor Format
   * @param[in] d_type Tensor Data Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  Tensor(size_t d1, size_t d2, size_t d3, Tformat fm = Tformat::NCHW,
         Tdatatype d_type = Tdatatype::FP32,
         QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE) :
    Tensor(1, d1, d2, d3, fm, d_type, qscheme_){};

  /**
   * @brief     Constructor of Tensor with batch size one and d1 size one
   * @param[in] d2 Height (NCHW) or Width (NHWC)
   * @param[in] d3 Width (NCHW) or Channel (NHWC)
   * @param[in] fm Tensor Format
   * @param[in] d_type Tensor Data Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  // NOTE(review): unlike the TensorType-based 2D overload, this form does NOT
  // remap d2/d3 for NHWC — d2/d3 always land in the last two dims. Confirm
  // this asymmetry is intended.
  Tensor(size_t d2, size_t d3, Tformat fm = Tformat::NCHW,
         Tdatatype d_type = Tdatatype::FP32,
         QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE) :
    Tensor(1, 1, d2, d3, fm, d_type, qscheme_){};

  /**
   * @brief     Constructor of Tensor with just Width or Channel
   * @param[in] d3 Width (NCHW) or Channel (NHWC)
   * @param[in] fm Tensor Format
   * @param[in] d_type Tensor Data Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  // explicit: prevents accidental implicit size_t -> Tensor conversion.
  explicit Tensor(size_t d3, Tformat fm = Tformat::NCHW,
                  Tdatatype d_type = Tdatatype::FP32,
                  QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE) :
    Tensor(1, 1, 1, d3, fm, d_type, qscheme_){};

  /**
   * @brief     Constructor of Tensor
   * @param[in] d0 Batch of Tensor
   * @param[in] d1 Channel (NCHW) or Height (NHWC)
   * @param[in] d2 Height (NCHW) or Width (NHWC)
   * @param[in] d3 Width (NCHW) or Channel (NHWC)
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  Tensor(size_t d0, size_t d1, size_t d2, size_t d3,
         ml::train::TensorDim::TensorType t_type,
         QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE) :
    Tensor(TensorDim(d0, d1, d2, d3, t_type), nullptr, qscheme_){};
     151              : 
     152              :   /**
     153              :    * @brief     Constructor of Tensor
     154              :    * @param[in] d1 Channel
     155              :    * @param[in] d2 Height
     156              :    * @param[in] d3 Width
     157              :    * @param[in] t_type Tensor Type
     158              :    * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
     159              :    */
     160              :   Tensor(size_t d1, size_t d2, size_t d3,
     161              :          ml::train::TensorDim::TensorType t_type,
     162              :          QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE) :
     163              :     Tensor(1, d1, d2, d3, t_type){};
     164              : 
     165              :   /**
     166              :    * @brief     Constructor of Tensor with batch size one and d1 size one
     167              :    * @param[in] d2 Height (NCHW) or Width (NHWC)
     168              :    * @param[in] d3 Width (NCHW) or Channel (NHWC)
     169              :    * @param[in] t_type Tensor Type
     170              :    * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
     171              :    */
     172              :   Tensor(size_t d2, size_t d3, ml::train::TensorDim::TensorType t_type,
     173              :          QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE) :
     174              :     Tensor(1, (t_type.format == Tformat::NCHW) ? 1 : d3,
     175              :            (t_type.format == Tformat::NCHW) ? d2 : 1,
     176              :            (t_type.format == Tformat::NCHW) ? d3 : d2, t_type, qscheme_){};
     177              :   /**
     178              :    * @brief     Constructor of Tensor with just Width or Channel
     179              :    * @param[in] d3 Width (NCHW) or Channel (NHWC)
     180              :    * @param[in] t_type Tensor Type
     181              :    * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
     182              :    */
     183              :   explicit Tensor(size_t d3, ml::train::TensorDim::TensorType t_type,
     184         3127 :                   QScheme qscheme_ = QScheme::PER_TENSOR_AFFINE) :
     185              :     Tensor(1, (t_type.format == Tformat::NCHW) ? 1 : d3, 1,
     186         3127 :            (t_type.format == Tformat::NCHW) ? d3 : 1, t_type, qscheme_){};
     187              : 
  /**
   * @brief     Constructor of Tensor from nested 4D float data
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] t_type Tensor Type
   */
  Tensor(std::vector<std::vector<std::vector<std::vector<float>>>> const &d,
         ml::train::TensorDim::TensorType t_type);

  /**
   * @brief     Constructor of Tensor from 3D float data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] t_type Tensor Type
   */
  // Wraps the 3D data in a one-element (batch) vector and delegates to the
  // 4D overload.
  Tensor(std::vector<std::vector<std::vector<float>>> const &d,
         ml::train::TensorDim::TensorType t_type) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};

  /**
   * @brief     Constructor of Tensor from 2D float data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor with batch size one
   * @param[in] t_type Tensor Type
   */
  // Wraps the 2D data one level deeper and delegates to the 3D overload.
  Tensor(std::vector<std::vector<float>> const &d,
         ml::train::TensorDim::TensorType t_type) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
     215              : 
#ifdef ENABLE_FP16
  /**
   * @brief     Constructor of Tensor from nested 4D _FP16 data
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] t_type Tensor Type
   * @todo      It is more desirable to move this implementaton into
   *            `tensor.cpp`, for it requires half_tensor.h
   */
  // Unlike the float overloads, this constructs the HalfTensor backend
  // directly rather than delegating to another Tensor constructor.
  Tensor(std::vector<std::vector<std::vector<std::vector<_FP16>>>> const &d,
         ml::train::TensorDim::TensorType t_type) {
    itensor_ = std::make_unique<HalfTensor>(d, t_type.format);
  }

  /**
   * @brief     Constructor of Tensor from 3D _FP16 data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] t_type Tensor Type
   */
  // Wraps the 3D data in a one-element (batch) vector and delegates to the
  // 4D overload.
  Tensor(std::vector<std::vector<std::vector<_FP16>>> const &d,
         ml::train::TensorDim::TensorType t_type) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};

  /**
   * @brief     Constructor of Tensor from 2D _FP16 data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor with batch size one
   * @param[in] t_type Tensor Type
   */
  Tensor(std::vector<std::vector<_FP16>> const &d,
         ml::train::TensorDim::TensorType t_type) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, t_type){};
#endif
     250              : 
  /**
   * @brief     Constructor of quantized Tensor from nested 4D uint8 data
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  Tensor(std::vector<std::vector<std::vector<std::vector<uint8_t>>>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_);

  /**
   * @brief     Constructor of quantized Tensor from 3D uint8 data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  // Wraps the 3D data in a one-element (batch) vector and delegates to the
  // 4D overload.
  Tensor(std::vector<std::vector<std::vector<uint8_t>>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, zero_points,
           t_type, qscheme_){};

  /**
   * @brief     Constructor of quantized Tensor from 2D uint8 data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor with batch size one
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  Tensor(std::vector<std::vector<uint8_t>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, zero_points,
           t_type, qscheme_){};
     286              : 
  /**
   * @brief     Constructor of quantized Tensor from nested 4D uint16 data
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  Tensor(std::vector<std::vector<std::vector<std::vector<uint16_t>>>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_);

  /**
   * @brief     Constructor of quantized Tensor from 3D uint16 data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  // Wraps the 3D data in a one-element (batch) vector and delegates to the
  // 4D overload.
  Tensor(std::vector<std::vector<std::vector<uint16_t>>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, zero_points,
           t_type, qscheme_){};

  /**
   * @brief     Constructor of quantized Tensor from 2D uint16 data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor with batch size one
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  Tensor(std::vector<std::vector<uint16_t>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, zero_points,
           t_type, qscheme_){};
     322              : 
  /**
   * @brief     Constructor of quantized Tensor from nested 4D uint32 data
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  Tensor(std::vector<std::vector<std::vector<std::vector<uint32_t>>>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_);

  /**
   * @brief     Constructor of quantized Tensor from 3D uint32 data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  // Wraps the 3D data in a one-element (batch) vector and delegates to the
  // 4D overload.
  Tensor(std::vector<std::vector<std::vector<uint32_t>>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, zero_points,
           t_type, qscheme_){};

  /**
   * @brief     Constructor of quantized Tensor from 2D uint32 data
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor with batch size one
   * @param[in] scales scale factors for the Tensor
   * @param[in] zero_points zero points for the Tensor
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme
   */
  Tensor(std::vector<std::vector<uint32_t>> const &d,
         std::vector<float> const &scales,
         std::vector<unsigned int> const &zero_points,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, zero_points,
           t_type, qscheme_){};
     358              : 
  /**
   * @brief     Constructor of CharTensor (QINT8)
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor.
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  Tensor(std::vector<std::vector<std::vector<std::vector<int8_t>>>> const &d,
         std::vector<float> const &scales,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_);

  /**
   * @brief     Constructor of CharTensor (QINT8)
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor.
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  // Wraps the 3D data in a one-element (batch) vector and delegates to the
  // 4D overload.
  Tensor(std::vector<std::vector<std::vector<int8_t>>> const &d,
         std::vector<float> const &scales,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, t_type,
           qscheme_){};

  /**
   * @brief     Constructor of CharTensor (QINT8)
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor with batch size one
   * @param[in] scales scale factors for the Tensor.
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  Tensor(std::vector<std::vector<int8_t>> const &d,
         std::vector<float> const &scales,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, t_type,
           qscheme_){};
     397              : 
  /**
   * @brief     Constructor of quantized Tensor (QINT16)
   * @note      Doc previously said "CharTensor"; the data here is int16_t,
   *            so the backing type is presumably the int16 tensor class —
   *            confirm against tensor.cpp.
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor.
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  Tensor(std::vector<std::vector<std::vector<std::vector<int16_t>>>> const &d,
         std::vector<float> const &scales,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_);

  /**
   * @brief     Constructor of quantized Tensor (QINT16)
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor. It needs to set format properly.
   * @param[in] scales scale factors for the Tensor.
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  // Wraps the 3D data in a one-element (batch) vector and delegates to the
  // 4D overload.
  Tensor(std::vector<std::vector<std::vector<int16_t>>> const &d,
         std::vector<float> const &scales,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, t_type,
           qscheme_){};

  /**
   * @brief     Constructor of quantized Tensor (QINT16)
   * @note      This constructor copies vector again. needs refactoring
   * @param[in] d data for the Tensor with batch size one
   * @param[in] scales scale factors for the Tensor.
   * @param[in] t_type Tensor Type
   * @param[in] qscheme_ Quantization scheme (only applies to Quantized Tensor)
   */
  Tensor(std::vector<std::vector<int16_t>> const &d,
         std::vector<float> const &scales,
         ml::train::TensorDim::TensorType t_type, QScheme qscheme_) :
    Tensor(std::vector<std::decay<decltype(d)>::type>{d}, scales, t_type,
           qscheme_){};
     436              : 
  /**
   * @brief  Constructor of Tensor by directly assigning TensorBase.
   * @param[in] rhs unique_ptr of a TensorBase
   * @note TensorBase is an abstract class so we can't directly instantiate
   *       it. Make sure to use a unique_ptr with a derived class when
   *       utilizing this constructor.
   */
  Tensor(const std::unique_ptr<TensorBase> &rhs);
     446              : 
  /**
   * @brief Basic Destructor
   * @note Virtual tensors with a live mapping must be deactivated before
   *       destruction; deactivate() presumably releases mapped_ptr (defined
   *       elsewhere — confirm in tensor.cpp). All other state is cleaned up
   *       by member destructors.
   */
  ~Tensor() noexcept {
    // Only a virtual tensor that is currently mapped needs explicit cleanup.
    if (is_virtual && mapped_ptr != nullptr) {
      deactivate();
    }
  };
     455              : 
  /**
   *  @brief  Copy constructor of Tensor.
   *  @param[in] rhs Tensor to be copied
   */
  Tensor(const Tensor &rhs);

  /**
   *  @brief  Move constructor of Tensor.
   *  @param[in] rhs Tensor to be moved from (left valid but unspecified)
   */
  Tensor(Tensor &&rhs) noexcept = default;

  /**
   * @brief  Copy assignment operator.
   * @param[in] rhs Tensor to be copied.
   */
  Tensor &operator=(const Tensor &rhs);

  /**
   * @brief  Move assignment operator.
   * @param[in] rhs Tensor to be moved.
   */
  Tensor &operator=(Tensor &&rhs) noexcept = default;
     479              : 
  /**
   * @brief     Comparison operator overload
   * @param[in] rhs Tensor to be compared with
   * @return true when the tensors compare equal, false otherwise
   */
  bool operator==(const Tensor &rhs) const;
     485              : 
     486              :   /**
     487              :    * @brief     Comparison operator overload
     488              :    * @param[in] rhs Tensor to be compared with
     489              :    */
     490           18 :   bool operator!=(const Tensor &rhs) const { return !(*this == rhs); }
     491              : 
     492              :   /**
     493              :    *  @brief  Compare itensor considering dynamic type checking.
     494              :    *  @param[in] lhs pointer of a TensorBase
     495              :    *  @param[in] rhs pointer of a TensorBase
     496              :    */
     497              :   template <typename T>
     498        10824 :   static bool itensorCompare(const TensorBase *lhs, const TensorBase *rhs) {
     499        10824 :     auto lhs_cast = dynamic_cast<const T *>(lhs);
     500        10824 :     auto rhs_cast = dynamic_cast<const T *>(rhs);
     501              : 
     502        10824 :     if (!lhs_cast || !rhs_cast) {
     503              :       return false;
     504              :     }
     505              : 
     506        10824 :     return *lhs_cast == *rhs_cast;
     507              :   }
     508              : 
     509              :   /**
     510              :    * @brief Construct a new Tensor object from a buffer
     511              :    * This will not copy buffer to a new tensor but directly uses it
     512              :    *
     513              :    * @param[in] buf buffer
     514              :    * @param[in] bytes buffer size in bytes
     515              :    * @param[in] d tensor dim
     516              :    * @param[in] offset offset to be used from current
     517              :    * @return    Tensor object
     518              :    * @throws    std::invalid_argument if buf is null
     519              :    */
     520              :   template <typename T = float>
     521        46138 :   static Tensor Map(T *buf, unsigned int bytes, const TensorDim &d,
     522              :                     size_t offset = 0) {
     523        46138 :     if (d.getDataLen() == 0 || buf == nullptr) {
     524            1 :       throw std::invalid_argument(
     525              :         "[Tensor::Map] empty tensor dim is not allowed");
     526              :     }
     527              : 
     528        46137 :     if (d.getDataLen() * sizeof(T) + offset > bytes) {
     529            1 :       throw std::invalid_argument(
     530              :         "Creating shared tensor of size bigger than tensor memory.");
     531              :     }
     532              : 
     533        46136 :     Tensor output("", d.getFormat(), d.getDataType());
     534        46136 :     output.setTensorVar(d, buf, offset);
     535        46136 :     return output;
     536            0 :   };
     537              : 
     538              :   /**
     539              :    * @brief    Allocate memory for this tensor
     540              :    */
     541              :   void allocate();
     542              : 
     543              :   /**
     544              :    * @brief    Deallocate memory for this tensor
     545              :    * @note     This will not necessary free the memory as tensors share memory
     546              :    */
     547              :   void deallocate();
     548              : 
     549              :   /**
     550              :    * @brief    Check if the tensor has memory allocated/assigned/associated
     551              :    */
     552              :   bool isAllocated();
     553              : 
     554              :   /**
     555              :    * @brief     return Data pointer of Tensor
     556              :    * @retval    template T pointer
     557              :    */
     558              :   template <typename T = float> T *getData() const {
     559     63802778 :     return (T *)itensor_->getData();
     560              :   }
     561              : 
     562              :   /**
     563              :    * @brief     return Data pointer of Tensor
     564              :    * @retval    template T pointer
     565              :    */
     566              :   template <typename T = float> T *getData(size_t idx) const {
     567            4 :     return (T *)itensor_->getData(idx);
     568              :   }
     569              : 
     570              :   /**
     571              :    * @brief     return scale pointer of Tensor
     572              :    * @retval    template T pointer
     573              :    */
     574              :   template <typename T = float> T *getScale() const {
     575           27 :     return (T *)itensor_->getScale();
     576              :   }
     577              : 
     578              :   /**
     579              :    * @brief     return scale pointer of Tensor
     580              :    * @retval    template T pointer
     581              :    */
     582              :   template <typename T = float> T *getScale(size_t idx) const {
     583              :     return (T *)itensor_->getScale(idx);
     584              :   }
     585              : 
     586              :   /**
     587              :    * @brief     return zero point pointer of Tensor
     588              :    * @retval    unsigned int pointer
     589              :    */
  // Zero-point lookup is delegated to the underlying tensor implementation.
  unsigned int *getZeroPoint() const { return itensor_->getZeroPoint(); }
     591              : 
     592              :   /**
     593              :    * @brief     return zero point pointer of Tensor
     594              :    * @retval    unsigned int pointer
     595              :    */
  // Zero point for element index idx, delegated to the implementation.
  unsigned int *getZeroPoint(size_t idx) const {
    return itensor_->getZeroPoint(idx);
  }
     599              : 
     600              :   /**
     601              :    * @brief     i data index
     602              :    * @retval    template T pointer (address of ith data)
     603              :    */
     604              :   template <typename T = float> T *getAddress(unsigned int i) {
     605      5277730 :     return (T *)itensor_->getAddress(i);
     606              :   }
     607              : 
     608              :   /**
     609              :    * @brief     i data index
     610              :    * @retval    template T pointer (address of ith data)
     611              :    */
     612              :   template <typename T = float> const T *getAddress(unsigned int i) const {
     613       145562 :     return (T *)itensor_->getAddress(i);
     614              :   }
     615              : 
     616              :   /**
     617              :    * @brief    get address of n-d data
     618              :    */
     619              :   template <typename T = float>
     620      5260242 :   T *getAddress(unsigned int b, unsigned int c, unsigned int h,
     621              :                 unsigned int w) {
     622      5260242 :     return getAddress<T>(getIndex(b, c, h, w));
     623              :   }
     624              : 
     625              :   /**
     626              :    * @brief    get address of n-d data
     627              :    */
     628              :   template <typename T = float>
     629       141505 :   const T *getAddress(unsigned int b, unsigned int c, unsigned int h,
     630              :                       unsigned int w) const {
     631       141505 :     return getAddress<T>(getIndex(b, c, h, w));
     632              :   }
     633              : 
     634              :   /**
     635              :    * @brief     return value at specific location
     636              :    * @param[in] idx location
     637              :    */
     638              :   template <typename T = float>
     639              :   const T &getValue(unsigned int idx) const noexcept {
     640     41674434 :     return getData<T>()[idx];
     641              :   }
     642              : 
     643              :   /**
     644              :    * @brief     return value at specific location
     645              :    * @param[in] idx location
     646              :    */
     647              :   template <typename T = float> T &getValue(unsigned int idx) noexcept {
     648      8155258 :     return getData<T>()[idx];
     649              :   }
     650              : 
     651              :   /**
     652              :    * @brief     return value at specific location
     653              :    * @param[in] b batch location
     654              :    * @param[in] c channel location
     655              :    * @param[in] h height location
     656              :    * @param[in] w width location
     657              :    */
     658              :   template <typename T = float>
     659     41753829 :   const T &getValue(unsigned int b, unsigned int c, unsigned int h,
     660              :                     unsigned int w) const noexcept {
     661     41753829 :     return getValue<T>(getIndex(b, c, h, w));
     662              :   }
     663              : 
     664              :   /**
     665              :    * @brief     return value at specific location
     666              :    * @param[in] b batch location
     667              :    * @param[in] c channel location
     668              :    * @param[in] h height location
     669              :    * @param[in] w width location
     670              :    */
     671              :   template <typename T = float>
     672     14568250 :   T &getValue(unsigned int b, unsigned int c, unsigned int h,
     673              :               unsigned int w) noexcept {
     674     14568250 :     return getValue<T>(getIndex(b, c, h, w));
     675              :   }
     676              : 
     677              :   /**
     678              :    * @brief     Fill the Tensor elements with value
     679              :    * @param[in] value value to be stored
     680              :    */
     681              :   void setValue(float value);
     682              : 
     683              :   /**
     684              :    * @brief     Set the element value
     685              :    * @param[in] b batch location
     686              :    * @param[in] c channel location
     687              :    * @param[in] h height location
     688              :    * @param[in] w width location
     689              :    * @param[in] value value to be stored
     690              :    */
     691              :   void setValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
     692              :                 float value);
     693              : 
     694              :   /**
     695              :    * @brief     Set the element value
     696              :    * @param[in] offset offset from start location
     697              :    * @param[in] value value to be stored
     698              :    *
   * @todo      This is a temporary workaround. Remove this
     700              :    */
     701              :   void setValueInt(unsigned int offset, int value) noexcept {
     702              :     int *data_int = (int *)getData();
     703       520703 :     data_int[offset] = value;
     704       520698 :   }
     705              : 
     706              :   /**
     707              :    * @brief     add the element value to the location
     708              :    * @param[in] b batch location
     709              :    * @param[in] c channel location
     710              :    * @param[in] h height location
     711              :    * @param[in] w width location
     712              :    * @param[in] value value to be stored
     713              :    * @param[in] beta scalar to multiply output with and add
     714              :    */
     715              :   void addValue(unsigned int b, unsigned int c, unsigned int h, unsigned int w,
     716              :                 float value, float beta) noexcept;
     717              : 
     718              :   /**
     719              :    * @brief     Fill the Tensor elements with zero
     720              :    */
     721              :   void setZero();
     722              : 
     723              :   /**
     724              :    * @brief     Set the tensor with random normal distribution
     725              :    * @param[in] mean mean of the distribution
     726              :    * @param[in] std standard deviation of the distribution
     727              :    */
     728              :   void setRandNormal(float mean = 0.0f, float stddev = 0.05f);
     729              : 
     730              :   /**
     731              :    * @brief     Set the tensor with random uniform distribution
     732              :    * @param[in] min minimum value for the distribution
     733              :    * @param[in] max maximum value for the distribution
     734              :    */
     735              :   void setRandUniform(float min = -0.05f, float max = 0.05f);
     736              : 
     737              :   /**
     738              :    * @brief     Set the tensor with random bernoulli distribution
     739              :    * @param[in] probability probability value for the distribution
     740              :    */
     741              :   void setRandBernoulli(float probability = 0.5f);
     742              : 
     743              :   /**
     744              :    * @brief     Initialize the memory of the given tensor
     745              :    */
     746              :   void initialize();
     747              : 
     748              :   /**
     749              :    * @brief     Initialize the memory of the given tensor
     750              :    * @param     init Initiailizer to use for the initialization
     751              :    */
     752              :   void initialize(Initializer init);
     753              : 
     754              :   /**
     755              :    * @brief Apply instantly to the element
     756              :    * @param[in] *function function pointer applied
     757              :    * @return int ML_ERROR_NONE if successful
     758              :    */
     759          133 :   template <typename T = float> int apply_i(std::function<T(T)> f) {
     760          133 :     Tensor result = *this;
     761          133 :     apply<T>(f, result);
     762              : 
     763          133 :     return ML_ERROR_NONE;
     764          133 :   };
     765              : 
     766              :   /**
     767              :    * @brief     Apply function element by element
     768              :    * @param[in] *function function pointer applied
     769              :    * @retval    Tensor
     770              :    */
     771        20913 :   template <typename T = float> Tensor apply(std::function<T(T)> f) const {
     772        20913 :     Tensor result;
     773        20913 :     apply<T>(f, result);
     774              : 
     775        20913 :     return result;
     776            0 :   };
     777              : 
     778              :   /**
     779              :    * @brief     Apply function element by element
     780              :    * @param[in] *function function pointer applied
     781              :    * @param[out] output output tensor
     782              :    * @retval    Tensor
     783              :    */
     784              :   template <typename T = float>
     785        47635 :   Tensor &apply(std::function<T(T)> f, Tensor &output) const {
     786        69154 :     CREATE_IF_EMPTY_DIMS(output, itensor_->getDim(), nullptr);
     787              : 
     788        47635 :     if (itensor_->getFormat() != output.itensor_->getFormat() ||
     789              :         itensor_->getDataType() != output.itensor_->getDataType()) {
     790              :       /// @todo add unittest
     791            0 :       throw std::invalid_argument(
     792              :         "[Tensor::apply] output format or data type does not match");
     793              :     }
     794              : 
     795        47635 :     itensor_->apply(f, output);
     796              : 
     797        47635 :     return output;
     798              :   }
     799              : 
     800              :   /**
     801              :    * @brief     Apply function to Tensor
     802              :    * @param[in] *function function pointer applied
     803              :    * @retval    Tensor
     804              :    */
     805              :   Tensor apply(std::function<Tensor(Tensor)> f) const;
     806              : 
     807              :   /**
     808              :    * @brief     Apply function to Tensor
     809              :    * @param[in] *function function pointer applied
     810              :    * @param[out] output output tensor
     811              :    * @retval    Tensor
     812              :    */
     813              :   Tensor &apply(std::function<Tensor &(Tensor, Tensor &)> f,
     814              :                 Tensor &output) const;
     815              : 
     816              :   /**
     817              :    * @brief     Multiply Tensor Elementwise
     818              :    * @param[in] m Tensor to be multiplied
     819              :    * @param[in] beta scalar to multiply output with and add
     820              :    * @retval    #ML_ERROR_NONE successful
     821              :    *
     822              :    * @note support different strided inputs and output
     823              :    * @note does not support broadcasting
     824              :    *
     825              :    * @todo merge this to multiply_i
     826              :    */
     827              :   int multiply_i_strided(Tensor const &m, const float beta = 0.0);
     828              : 
     829              :   /**
     830              :    * @brief     Multiply Tensor Element by Element ( Not the MxM )
     831              :    * @param[in] m Tensor to be multiplied
     832              :    * @param[in] beta scalar to multiply output with and add
     833              :    * @retval    Calculated Tensor
     834              :    *
     835              :    * @note support different strided inputs and output
     836              :    * @note does not support broadcasting
     837              :    *
     838              :    * @todo merge this to multiply
     839              :    */
     840              :   Tensor multiply_strided(Tensor const &m, const float beta = 0.0) const;
     841              : 
     842              :   /**
     843              :    * @brief     Multiply Tensor Element by Element ( Not the MxM )
     844              :    * @param[in] m Tensor to be multiplied
     845              :    * @param[out] output Tensor to store the result
     846              :    * @param[in] beta scalar to multiply output with and add
     847              :    * @retval    Calculated Tensor
     848              :    *
     849              :    * @note support different strided inputs and output
     850              :    * @note does not support broadcasting
     851              :    *
     852              :    * @todo merge this to multiply
     853              :    */
     854              :   Tensor &multiply_strided(Tensor const &m, Tensor &output,
     855              :                            const float beta = 0.0) const;
     856              : 
     857              :   /**
     858              :    * @brief     Multiply value element by element immediately
     859              :    * @param[in] value multiplier
     860              :    * @retval    #ML_ERROR_INVALID_PARAMETER Tensor dimension is not right
     861              :    * @retval    #ML_ERROR_NONE Successful
     862              :    */
     863              :   int multiply_i(float const &value);
     864              : 
     865              :   /**
     866              :    * @brief     Multiply value element by element
     867              :    * @param[in] value multiplier
     868              :    * @retval    Calculated Tensor
     869              :    */
     870              :   Tensor multiply(float const &value) const;
     871              : 
     872              :   /**
     873              :    * @brief      multiply value element by element
     874              :    * @param[in]  value multiplier
     875              :    * @param[out] out out tensor to store the result
     876              :    * @retval     Calculated Tensor
     877              :    */
     878              :   Tensor &multiply(float const &value, Tensor &out) const;
     879              : 
     880              :   /**
     881              :    * @brief     Multiply Tensor Elementwise
     882              :    * @param[in] m Tensor to be multiplied
     883              :    * @param[in] beta scalar to multiply output with and add
     884              :    * @retval    #ML_ERROR_NONE successful
     885              :    */
     886              :   int multiply_i(Tensor const &m, const float beta = 0.0);
     887              : 
     888              :   /**
     889              :    * @brief     Multiply Tensor Element by Element ( Not the MxM )
     890              :    * @param[in] m Tensor to be multiplied
     891              :    * @param[in] beta scalar to multiply output with and add
     892              :    * @retval    Calculated Tensor
     893              :    */
     894              :   Tensor multiply(Tensor const &m, const float beta = 0.0) const;
     895              : 
     896              :   /**
     897              :    * @brief      Multiply Tensor Element by Element ( Not the MxM )
     898              :    * @param[in]  m Tensor to be multiplied
     899              :    * @param[out] output Tensor to store the result
     900              :    * @param[in]  beta scalar to multiply output with and add
     901              :    * @retval     Calculated Tensor
     902              :    */
     903              :   Tensor &multiply(Tensor const &m, Tensor &output,
     904              :                    const float beta = 0.0) const;
     905              : 
     906              :   /**
     907              :    * @brief     Divide value element by element immediately
     908              :    * @param[in] value divisor
     909              :    * @retval    #ML_ERROR_INVALID_PARAMETER Tensor dimension is not right
     910              :    * @retval    #ML_ERROR_NONE Successful
     911              :    */
     912              :   int divide_i(float const &value);
     913              : 
     914              :   /**
     915              :    * @brief     Divide value element by element
     916              :    * @param[in] value Divisor
     917              :    * @retval    Calculated Tensor
     918              :    */
     919              :   Tensor divide(float const &value) const;
     920              : 
     921              :   /**
     922              :    * @brief     Divide value element by element
     923              :    * @param[in] value Divisor
     924              :    * @param[out] output Tensor to store the result
     925              :    * @retval    Calculated Tensor
     926              :    */
     927              :   Tensor &divide(float const &value, Tensor &output) const;
     928              : 
     929              :   /**
     930              :    * @brief     divide Tensor Elementwise
     931              :    * @param[in] m Tensor to be multiplied
     932              :    * @retval    #ML_ERROR_NONE successful
     933              :    */
     934              :   int divide_i(Tensor const &m);
     935              : 
     936              :   /**
     937              :    * @brief     Divide Tensor Element by Element
     938              :    * @param[in] m Divisor Tensor
     939              :    * @retval    Calculated Tensor
     940              :    */
     941              :   Tensor divide(Tensor const &m) const;
     942              : 
     943              :   /**
     944              :    * @brief     divide Tensor Elementwise
     945              :    * @param[in] m Tensor to be multiplied
     946              :    * @param[out] output Tensor to store the result
     947              :    * @retval    Calculated Tensor
     948              :    */
     949              :   Tensor &divide(Tensor const &m, Tensor &output) const;
     950              : 
     951              :   /**
     952              :    * @brief     Add Tensor Elementwise
     953              :    * @param[in] input Tensor to be added
     954              :    * @param[in] beta scalar to add output with and add
     955              :    * @retval    #ML_ERROR_NONE successful
     956              :    *
     957              :    * @note support different strided inputs and output
     958              :    * @note does not support broadcasting
     959              :    *
     960              :    * @todo merge this to add_i
     961              :    */
     962              :   int add_i_strided(Tensor const &input, const float beta = 0.0);
     963              : 
     964              :   /**
     965              :    * @brief     Add Tensor Element by Element
     966              :    * @param[in] input Tensor to be added
   * @param[in] beta Value used to scale the input tensor
     968              :    * @retval    Calculated Tensor
     969              :    *
     970              :    * @note support different strided inputs and output
     971              :    * @note does not support broadcasting
     972              :    *
     973              :    * @todo merge this to add
     974              :    */
     975              :   Tensor add_strided(Tensor const &input, const float beta = 0.0) const;
     976              : 
     977              :   /**
     978              :    * @brief      Add Tensor Element by Element
     979              :    * @param[in]  input Tensor to be added
     980              :    * @param[out] output Tensor to store the result
     981              :    * @param[in]  beta Value to be scale the input tensor
     982              :    * @retval     Calculated Tensor
     983              :    *
     984              :    * @note support different strided inputs and output
     985              :    * @note does not support broadcasting
     986              :    *
     987              :    * @todo merge this to add
     988              :    */
     989              :   Tensor &add_strided(Tensor const &input, Tensor &output,
     990              :                       const float beta = 0.0) const;
     991              : 
     992              :   /**
     993              :    * @brief     Add Tensor Element immediately to target tensor without mem copy
     994              :    * @param[in] value value to be added
     995              :    * @retval    #ML_ERROR_NONE  Successful
     996              :    * @retval    #ML_ERROR_INVALID_PARAMETER Invalid Parameter
     997              :    */
     998              :   int add_i(float const &value);
     999              : 
    1000              :   /**
    1001              :    * @brief     Add value Element by Element
    1002              :    * @param[in] value value to be added
    1003              :    * @retval    Calculated Tensor
    1004              :    */
    1005              :   Tensor add(float const &value) const;
    1006              : 
    1007              :   /**
    1008              :    * @brief      Add Tensor Element by Element
    1009              :    * @param[in]  value value to be added
    1010              :    * @param[out] output Tensor to save output without allocating new memory
    1011              :    * @retval     Calculated Tensor
    1012              :    */
    1013              :   Tensor &add(float const &value, Tensor &output) const;
    1014              : 
    1015              :   /**
    1016              :    * @brief     Add Tensor Element by Element without mem copy
    1017              :    * @param[in] m Tensor to be added
    1018              :    * @param[in] alpha Values to be scaled
    1019              :    * @retval    #ML_ERROR_NONE  Successful
    1020              :    * @retval    #ML_ERROR_INVALID_PARAMETER Invalid Parameter
    1021              :    */
    1022              :   int add_i(Tensor const &m, float const alpha = 1.F);
    1023              : 
    1024              :   /**
    1025              :    * @brief Do add_i for specific section
    1026              :    *
    1027              :    * @param len Length of the specific section
   * @param addr_idx Starting index of the specific section
    1029              :    * @param m Input Tensor to be added
    1030              :    * @param incX Incremental index of X
    1031              :    * @param incY Incremental index of Y
    1032              :    * @param alphas Vector of multiple alpha values
    1033              :    * @param alpha_idx Index of alpha in alpha vector
    1034              :    * @retval #ML_ERROR_NONE  Successful
    1035              :    * @retval #ML_ERROR_INVALID_PARAMETER Invalid Parameter
    1036              :    */
    1037              :   int add_i_partial(unsigned int len, unsigned int addr_idx, Tensor &m,
    1038              :                     unsigned int incX, unsigned int incY, const Tensor alphas,
    1039              :                     unsigned int alpha_idx);
    1040              : 
    1041              :   /**
    1042              :    * @brief     Add Tensor Element by Element
    1043              :    * @param[in] m Tensor to be added
    1044              :    * @param[in] alpha Values to be scaled
    1045              :    * @retval    Calculated Tensor
    1046              :    */
    1047              :   Tensor add(Tensor const &m, float const alpha = 1) const;
    1048              : 
    1049              :   /**
    1050              :    * @brief      Add Tensor Element by Element
    1051              :    * @param[in]  m Tensor to be added
    1052              :    * @param[out] output Tensor to be out
    1053              :    * @param[in]  alpha Values to be scaled
    1054              :    * @retval     Calculated Tensor
    1055              :    */
    1056              :   Tensor &add(Tensor const &m, Tensor &output, float const alpha = 1) const;
    1057              : 
    1058              :   /**
    1059              :    * @brief     memcpyless version of subtract
    1060              :    * @retval    #ML_ERROR_NONE  Successful
    1061              :    * @retval    #ML_ERROR_INVALID_PARAMETER Invalid Parameter
    1062              :    */
    1063              :   int subtract_i(float const &value);
    1064              : 
    1065              :   /**
    1066              :    * @brief     subtract value Element by Element
    1067              :    * @param[in] value value to be subtracted
    1068              :    * @retval    Calculated Tensor
    1069              :    */
    1070              :   Tensor subtract(float const &value) const;
    1071              : 
    1072              :   /**
    1073              :    * @brief      Subtract Tensor Element by Element
    1074              :    * @param[in]  value value to be added
    1075              :    * @param[out] output Tensor to save output without allocating new memory
    1076              :    * @retval     Calculated Tensor
    1077              :    */
    1078              :   Tensor &subtract(float const &value, Tensor &output) const;
    1079              : 
    1080              :   /**
    1081              :    * @brief     memcpyless version of subtract
    1082              :    * @param[in] m Tensor to be subtracted
    1083              :    * @retval    #ML_ERROR_NONE  Successful
    1084              :    * @retval    #ML_ERROR_INVALID_PARAMETER Invalid Parameter
    1085              :    */
    1086              :   int subtract_i(Tensor const &m);
    1087              : 
    1088              :   /**
   * @brief     Subtract Tensor Element by Element
    1090              :    * @param[in] m Tensor to be subtracted
    1091              :    * @retval    Calculated Tensor
    1092              :    */
    1093              :   Tensor subtract(Tensor const &m) const;
    1094              : 
    1095              :   /**
    1096              :    * @brief      Subtract Tensor Element by Element
    1097              :    * @param[in]  m Tensor to be added
    1098              :    * @param[out] output Tensor to be out
    1099              :    * @retval     Calculated Tensor
    1100              :    */
    1101              :   Tensor &subtract(Tensor const &m, Tensor &output) const;
    1102              : 
    1103              :   /**
    1104              :    * @brief     sum all the Tensor elements according to the batch
    1105              :    * @retval    Calculated Tensor(batch, 1, 1, 1)
    1106              :    */
    1107              :   Tensor sum_by_batch() const;
    1108              : 
    1109              :   /**
    1110              :    * @brief     sum all the Tensor elements according to the axis
    1111              :    *            0 : batch direction
    1112              :    *            1 : channel direction
    1113              :    *            2 : height direction
    1114              :    *            3 : width direction
    1115              :    * @param[in] axis Axis to calculate sum along
    1116              :    * @param[in] alpha Scale the sum by this value
    1117              :    * @retval    Calculated Tensor
    1118              :    */
    1119              :   Tensor sum(unsigned int axis, float alpha = 1.0) const;
    1120              : 
    1121              :   /**
    1122              :    * @brief     sum all the Tensor elements according to the axis
    1123              :    *            0 : batch direction
    1124              :    *            1 : channel direction
    1125              :    *            2 : height direction
    1126              :    *            3 : width direction
    1127              :    * @param[in] axis Axis to calculate sum along
    1128              :    * @param[out] output output tensor
    1129              :    * @param[in] alpha Scale the sum by this value
                      :    * @param[in] beta Scale for the existing output value (presumably
                      :    * output = alpha * sum + beta * output — confirm in implementation)
    1130              :    * @retval    Calculated Tensor
    1131              :    */
    1132              :   Tensor &sum(unsigned int axis, Tensor &output, float alpha = 1.0,
    1133              :               float beta = 0.0) const;
    1134              : 
    1135              :   /**
    1136              :    * @brief sum all the Tensor by multiple axes
    1137              :    *
    1138              :    * @param axes axes to sum along
    1139              :    * @param alpha Scale the sum by this value
    1140              :    * @return Tensor
    1141              :    */
    1142              :   Tensor sum(const std::vector<unsigned int> &axes, float alpha = 1.0) const;
    1143              : 
    1144              :   /**
    1145              :    * @brief sum all the Tensor by multiple axes
    1146              :    *
    1147              :    * @param axes axes to sum along
    1148              :    * @param[out] output output tensor
    1149              :    * @param alpha Scale the sum by this value
    1150              :    * @return Tensor
    1151              :    */
    1152              :   Tensor &sum(const std::vector<unsigned int> &axes, Tensor &output,
    1153              :               float alpha = 1.0) const;
    1154              : 
    1155              :   /**
    1156              :    * @brief  return absolute value
    1157              :    * @retval Calculated Tensor
    1158              :    */
    1159              :   Tensor &abs(Tensor &output) const;
    1160              : 
    1161              :   /**
    1162              :    * @brief     Averaging the Tensor elements according to the axis
    1163              :    *            0 : batch direction
    1164              :    *            1 : channel direction
    1165              :    *            2 : height direction
    1166              :    *            3 : width direction
    1167              :    * @retval    Calculated Tensor
    1168              :    */
    1169              :   Tensor average(unsigned int axis) const;
    1170              : 
    1171              :   /**
    1172              :    * @brief     Averaging the Tensor elements according to the axis
    1173              :    * @retval    Calculated Tensor
    1174              :    */
    1175              :   Tensor &average(unsigned int axis, Tensor &output) const;
    1176              : 
    1177              :   /**
    1178              :    * @brief     Average all the Tensor by multiple axes
    1179              :    * @param[in] axes axes to average along
    1180              :    * @retval    Calculated Tensor
    1181              :    */
    1182              :   Tensor average(const std::vector<unsigned int> &axes) const;
    1183              : 
    1184              :   /**
    1185              :    * @brief      Average all the Tensor by multiple axes
    1186              :    * @param[in]  axes axes to average along
    1187              :    * @param[out] output output tensor
    1188              :    * @retval     Calculated Tensor
    1189              :    */
    1190              :   Tensor &average(const std::vector<unsigned int> &axes, Tensor &output) const;
    1191              : 
    1192              :   /**
    1193              :    * @brief     Average the Tensor elements by all axis
    1194              :    * @retval    Calculated Tensor
    1195              :    */
    1196              :   Tensor average() const;
    1197              : 
    1198              :   /**
    1199              :    * @brief     Averaging the Tensor elements by all axis
    1200              :    * @retval    Calculated Tensor
    1201              :    */
    1202              :   Tensor &average(Tensor &output) const;
    1203              : 
    1204              :   /**
    1205              :    * @brief     Tensor power element without mem copy
    1206              :    * @param[in] exponent exponent
    1207              :    * @retval    #ML_ERROR_NONE  Successful
    1208              :    */
    1209              :   int pow_i(float exponent);
    1210              : 
    1211              :   /**
    1212              :    * @brief     Tensor power element by element
    1213              :    * @param[in] exponent exponent
    1214              :    * @retval    Calculated Tensor
    1215              :    */
    1216              :   Tensor pow(float exponent) const;
    1217              : 
    1218              :   /**
    1219              :    * @brief      Tensor power element by element
    1220              :    * @param[in]  exponent exponent
    1221              :    * @param[out] output out to store the result
    1222              :    * @retval     Calculated Tensor
    1223              :    */
    1224              :   Tensor &pow(float exponent, Tensor &output) const;
    1225              : 
    1226              :   /**
    1227              :    * @brief     Compute square-root element by element
    1228              :    * @retval    #ML_ERROR_NONE  Successful
    1229              :    */
    1230              :   int sqrt_i();
    1231              : 
    1232              :   /**
    1233              :    * @brief     Compute square-root by element
    1234              :    * @retval    Calculated Tensor
    1235              :    */
    1236              :   Tensor sqrt() const;
    1237              : 
    1238              :   /**
    1239              :    * @brief      Compute square-root by element
    1240              :    * @param[out] output out to store the result
    1241              :    * @retval     Calculated Tensor
    1242              :    */
    1243              :   Tensor &sqrt(Tensor &output) const;
    1244              : 
    1245              :   /**
    1246              :    * @brief     Compute negation by element
    1247              :    * @retval    Calculated Tensor
    1248              :    */
    1249              :   Tensor neg() const;
    1250              : 
    1251              :   /**
    1252              :    * @brief      Compute negation by element
    1253              :    * @param[out] output out to store the result
    1254              :    * @retval     Calculated Tensor
    1255              :    */
    1256              :   Tensor &neg(Tensor &output) const;
    1257              : 
    1258              :   /**
    1259              :    * @brief     Gauss error function
    1260              :    * @retval    #ML_ERROR_NONE  Successful
    1261              :    */
    1262              :   int erf_i();
    1263              : 
    1264              :   /**
    1265              :    * @brief     Gauss error function
    1266              :    * @retval    Calculated Tensor
    1267              :    */
    1268              :   Tensor erf() const;
    1269              : 
    1270              :   /**
    1271              :    * @brief      Gauss error function
    1272              :    * @param[out] output out to store the result
    1273              :    * @retval     Calculated Tensor
    1274              :    */
    1275              :   Tensor &erf(Tensor &output) const;
    1276              : 
    1277              :   /**
    1278              :    * @brief    sin transform function
    1279              :    * @param[out] out out to store the result
    1280              :    */
    1281              :   void sin(Tensor &out, float alpha = 1.0) const;
    1282              : 
    1283              :   /**
    1284              :    * @brief    cos transform function
    1285              :    * @param[out] out out to store the result
    1286              :    */
    1287              :   void cos(Tensor &out, float alpha = 1.0) const;
    1288              : 
    1289              :   /**
    1290              :    * @brief tangent transform function
    1291              :    * @param[out] output out to store the result
    1292              :    */
    1293              :   void tan(Tensor &output, float alpha = 1.0) const;
    1294              : 
    1295              :   /**
    1296              :    * @brief inverse squared root function (in-place)
    1297              :    */
    1298              :   void inv_sqrt_i();
    1299              : 
    1300              :   /**
    1301              :    * @brief inverse squared root function
    1302              :    * @param[out] out output Tensor
    1303              :    */
    1304              :   Tensor inv_sqrt(Tensor &out) const;
    1305              : 
    1306              :   /**
    1307              :    * @brief     Anchor a starting point to defer following evaluation
    1308              :    * @retval    LazyTensor class that can be used with run();
    1309              :    */
    1310              :   LazyTensor chain() const;
    1311              : 
    1312              :   /**
    1313              :    * @brief     l2norm the Tensor elements
    1314              :    * @retval    Calculated l2norm
    1315              :    */
    1316              :   float l2norm() const;
    1317              : 
    1318              :   /**
    1319              :    * @brief     Normalize the Tensor elements
    1320              :    * @retval    Calculated Tensor
    1321              :    */
    1322              :   Tensor &normalization(Tensor &output) const;
    1323              : 
    1324              :   /**
    1325              :    * @brief     Standardize the Tensor elements
    1326              :    * @retval    Calculated Tensor
    1327              :    */
    1328              :   Tensor &standardization(Tensor &output) const;
    1329              : 
    1330              :   /**
    1331              :    * @brief     Normalize the Tensor elements in-place
    1332              :    * @retval    Calculated Tensor
    1333              :    */
    1334              :   void normalization_i();
    1335              : 
    1336              :   /**
    1337              :    * @brief     Standardize the Tensor elements in-place
    1338              :    * @retval    Calculated Tensor
    1339              :    */
    1340              :   void standardization_i();
    1341              : 
    1342              :   /**
    1343              :    * @brief     Dot Product of Tensor ( equal MxM )
    1344              :    * @details   This applies dot of the last dimension of this and second-last
    1345              :    * dimension of passed input tensor.
    1346              :    * @param[in] input Tensor
    1347              :    * @param[in] trans Transpose
    1348              :    * @param[in] trans_in Transpose input
    1349              :    * @retval    Calculated Tensor
    1350              :    */
    1351              :   Tensor dot(Tensor const &input, bool trans = false,
    1352              :              bool trans_in = false) const;
    1353              : 
    1354              :   /**
    1355              :    * @brief     Dot Product of Tensor ( equal MxM )
    1356              :    * @details   This applies dot of the last dimension of this and
    1357              :    * second-last dimension of passed input tensor.
    1358              :    * @param[in] input Tensor
    1359              :    * @param[in] output output Tensor
    1360              :    * @param[in] trans Transpose
    1361              :    * @param[in] trans_in Transpose input
    1362              :    * @param[in] beta beta
    1363              :    * @retval    Calculated Tensor
    1364              :    */
    1365              :   Tensor &dot(Tensor const &input, Tensor &output, bool trans = false,
    1366              :               bool trans_in = false, float beta = 0.0f) const;
    1367              : 
                      :   /**
                      :    * @brief      Dot product over lists of input and output tensors
                      :    * @param[in]  inputs input Tensors
                      :    * @param[out] outputs output Tensors to store the results
                      :    * @param[in]  trans Transpose
                      :    * @param[in]  trans_in Transpose input
                      :    * @param[in]  beta beta
                      :    */
    1368              :   void dot(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs,
    1369              :            bool trans = false, bool trans_in = false, float beta = 0.0f) const;
    1370              : 
    1371              :   /**
    1372              :    * @brief compute the derivative of this in the current tensor
    1373              :    * @param input same as given to the dot()
    1374              :    * @param output_deriv the derivative of the output
    1375              :    * @param[in] trans same as given to the dot()
    1376              :    * @param[in] trans_in same as given to the dot()
    1377              :    * @param[in] beta same as given to the dot()
    1378              :    * @note This will compute the derivative in-place and will overwrite
    1379              :    existing
    1380              :    * data in the tensor
    1381              :    */
    1382              :   Tensor &dot_deriv_wrt_1(Tensor const &input, Tensor const &output_deriv,
    1383              :                           bool trans = false, bool trans_in = false,
    1384              :                           float beta = 0.0f);
    1385              : 
    1386              :   /**
    1387              :    * @brief compute the derivative wrt m in the input tensor
    1388              :    * @param input_deriv tensor where derivative wrt m will be stored
    1389              :    * @param output_deriv the derivative of the output
    1390              :    * @param[in] trans same as given to the dot()
    1391              :    * @param[in] trans_in same as given to the dot()
    1392              :    * @param[in] beta same as given to the dot()
    1393              :    * @note The caller tensor must be the same tensor as the one which called
    1394              :    the dot() product.
    1395              :    */
    1396              :   Tensor &dot_deriv_wrt_2(Tensor &input_deriv, Tensor const &output_deriv,
    1397              :                           bool trans = false, bool trans_in = false,
    1398              :                           float beta = 0.0f) const;
    1399              : 
    1400              :   /**
    1401              :    * @copydoc Tensor::dot(Tensor const &input, Tensor &output, bool trans,
    1402              :               bool trans_in, float beta) const
    1403              :    * @details performs dot operation over a batch of inputs. If the batch sizes
    1404              :    of the given two tensors are different, the bigger one should be a multiple
    1405              :    of the smaller one.
    1406              :    */
    1407              :   Tensor &dotBatched(Tensor const &input, Tensor &result, bool trans = false,
    1408              :                      bool trans_in = false, float beta = 0.0f) const;
    1409              : 
    1410              :   /**
    1411              :    * @copydoc Tensor::dot_deriv_wrt_1(Tensor const &input, Tensor const
    1412              :    &output_deriv, bool trans, bool trans_in, float beta)
    1413              :    */
    1414              :   Tensor &dot_batched_deriv_wrt_1(Tensor const &input,
    1415              :                                   Tensor const &output_deriv,
    1416              :                                   bool trans = false, bool trans_in = false,
    1417              :                                   float beta = 0.0f);
    1418              : 
    1419              :   /**
    1420              :    * @copydoc Tensor::dot_deriv_wrt_2(Tensor const &input_deriv, Tensor const
    1421              :    &output_deriv, bool trans, bool trans_in, float beta) const
    1422              :    */
    1423              :   Tensor &dot_batched_deriv_wrt_2(Tensor &input_deriv,
    1424              :                                   Tensor const &output_deriv,
    1425              :                                   bool trans = false, bool trans_in = false,
    1426              :                                   float beta = 0.0f) const;
    1427              : 
    1428              :   /**
    1429              :    * @brief Calculate Drop Out Mask : x * 1.0/(1.0-rate)
    1430              :    * @param dropout drop out rate
    1431              :    * @retval Tensor drop out mask
    1432              :    */
    1433              :   Tensor dropout_mask(float dropout) const;
    1434              : 
    1435              :   /**
    1436              :    * @brief Calculate Drop Out Mask : x * 1.0/(1.0-rate) inplace
    1437              :    * @param dropout drop out rate
    1438              :    */
    1439              :   void dropout_mask(float dropout);
    1440              : 
    1441              :   /**
    1442              :    * @brief Calculate filter mask
    1443              :    * @param mask_len length of each mask along the last axis
    1444              :    * @param reverse invert the mask when true
    1445              :    */
    1446              :   void filter_mask(const Tensor &mask_len, bool reverse = false);
    1447              : 
    1448              :   /**
    1449              :    * @brief Calculate 2 Zone Out Mask
    1450              :    * @details Calculate zone out mask according to the bernoulli distribution.
    1451              :    * Zone out mask with rate @a zoneout for inplace and the other zone out mask
    1452              :    * with rate @a (1-zoneout).
    1453              :    * @param zoneout zone out rate
    1454              :    * @retval Tensor zone out mask for opposite tensor
    1455              :    */
    1456              :   Tensor zoneout_mask(float zoneout);
    1457              : 
    1458              :   /**
    1459              :    * @brief Calculate 2 Zone Out Mask
    1460              :    * @details Calculate zone out mask according to the bernoulli distribution.
    1461              :    * Zone out mask with rate @a zoneout for inplace and the other zone out mask
    1462              :    * with rate @a (1-zoneout).
    1463              :    * @param opposite opposite zone out mask
    1464              :    * @param zoneout zone out rate
    1465              :    */
    1466              :   void zoneout_mask(Tensor &opposite, float zoneout);
    1467              : 
    1468              :   /**
    1469              :    * @brief split tensor along axis.
    1470              :    *
    1471              :    * @param num_size num_size
    1472              :    * @param axis axis
    1473              :    * @return Tensor split tensors
    1474              :    */
    1475              :   std::vector<Tensor> split(unsigned num_size, int axis = 0);
    1476              : 
    1477              :   /**
    1478              :    * @brief split tensor along axis.
    1479              :    *
    1480              :    * @param sizes sizes
    1481              :    * @param axis axis
    1482              :    * @return Tensor split tensors
    1483              :    * @note if the given array sizes is just a 1 unsigned int value, assumes that
    1484              :    * it divide tensor by given size evenly
    1485              :    */
    1486              :   std::vector<Tensor> split(std::vector<size_t> sizes, int axis = 0);
    1487              : 
    1488              :   /**
    1489              :    * @brief concatenate tensors along axis
    1490              :    *
    1491              :    * @param tensors tensors to be concatenated to the first tensor
    1492              :    * @param axis axis
    1493              :    * @param output output tensor to store the result
    1494              :    * @return Tensor concatenated tensor
    1495              :    *
    1496              :    * @note  This function should not be used directly. Please use cat() instead.
    1497              :    */
    1498              :   Tensor concat(const std::vector<Tensor> &tensors, int axis, Tensor &output);
    1499              : 
    1500              :   /**
    1501              :    * @brief concatenate tensors along axis
    1502              :    *
    1503              :    * @param tensors tensors to be concatenated to the first tensor
    1504              :    * @param axis axis
    1505              :    * @return Tensor concatenated tensor
    1506              :    */
    1507              :   static Tensor cat(const std::vector<Tensor> &tensors, int axis = 0);
    1508              : 
    1509              :   /**
    1510              :    * @brief concatenate tensors along axis
    1511              :    *
    1512              :    * @param tensors tensors to be concatenated to the first tensor
    1513              :    * @param axis axis
    1514              :    * @param output output tensor to store the result
    1515              :    * @return Tensor concatenated tensor
    1516              :    */
    1517              :   static Tensor cat(const std::vector<Tensor> &tensors, int axis,
    1518              :                     Tensor &output);
    1519              : 
    1520              :   /**
    1521              :    * @brief     Print element
    1522              :    * @param[in] out out stream
    1523              :    */
    1524              :   void print(std::ostream &out) const;
    1525              : 
    1526              :   /**
    1527              :    * @brief     put data of Tensor
    1528              :    * @note      It is only effective when fsu is used
    1529              :    */
    1530              :   void putData() const;
    1531              : 
    1532              :   /**
    1533              :    * @brief Set the memory buffer for the tensor
    1534              :    *
    1535              :    * @param buf the memory buffer
    1536              :    * @param init initialize the buffer
    1537              :    */
    1538              :   void setData(const std::shared_ptr<MemoryData> buf, size_t off = 0,
    1539              :                bool init = false);
    1540              : 
    1541              :   /**
    1542              :    * @brief     return Data pointer of Tensor
    1543              :    * @retval    template T pointer (float pointer as default)
    1544              :    */
    1545              :   const std::shared_ptr<MemoryData> getMemoryData() const;
    1546              : 
    1547              :   /**
    1548              :    * @brief     return offset
    1549              :    */
    1550              :   size_t getOffset() const;
    1551              : 
    1552              :   /**
    1553              :    * @brief     Copy the Tensor
    1554              :    * @param[in] from Tensor to be copied
    1555              :    *
    1556              :    * @note copy can reshape the tensor to match the shape
    1557              :    * @note support copying data from multiple data type
    1558              :    */
    1559              :   void copy(const Tensor &from);
    1560              : 
    1561              :   /**
    1562              :    * @brief     Copy the Tensor
    1563              :    * @param[in] from Tensor to be copied
    1564              :    * @note      support copying data from multiple data type
    1565              :    */
    1566              :   void copyData(const Tensor &from);
    1567              : 
    1568              :   /**
    1569              :    * @brief     Copy the Tensor
    1570              :    * @param[in] from Tensor to be copied
    1571              :    * @note      only support copying data from tensor with the same data type
    1572              :    */
    1573              :   void copy_with_stride(const Tensor &from);
    1574              : 
    1575              :   /**
    1576              :    * @brief Get slice of the tensor, sliced by batch
    1577              :    * @param[in] offset offset in batch to start the slice
    1578              :    * @param[in] size size of the slice
    1579              :    * @retval slice of this tensor
    1580              :    * @note This function provides a slice of this tensor, and does not create a
    1581              :    * copy
    1582              :    */
    1583              :   Tensor getBatchSlice(size_t offset, unsigned int size) const;
    1584              : 
    1585              :   /**
    1586              :    * @brief Extract sub-tensor containing specified batch indices
    1587              :    *
    1588              :    * @param indices List of batch indices to extract (0-based) Duplicates are
    1589              :    * allowed and will result in the same batch data being copied multiple times.
    1590              :    * @return Tensor New tensor containing only specified batches  (copied
    1591              :    * tensor!)
    1592              :    *
    1593              :    * @details
    1594              :    * This function creates a new tensor containing copies of data from
    1595              :    * specified batch indices of the original tensor. The operation:
    1596              :    * - Requires the original tensor to be contiguous in memory
    1597              :    * - Preserves channel/height/width dimensions
    1598              :    * - Maintains data ordering within each batch
    1599              :    * - Uses memcpy for efficient memory operations
    1600              :    *
    1601              :    * @note Duplicate indices: If the same index appears multiple times, the
    1602              :    * corresponding batch data will be copied to each position in the output
    1603              :    * tensor. Example: indices {0, 1, 1} creates output with 3 batches where
    1604              :    * positions 1 and 2 contain identical copies  of input batch 1.
    1605              :    *
    1606              :    * @note
    1607              :    * - Time complexity: O(k*C*H*W) where k = num_indices
    1608              :    * - Memory complexity: O(k*C*H*W)
    1609              :    * - Thread-safe when using different indices in parallel
    1610              :    *
    1611              :    * @throw std::runtime_error If:
    1612              :    * - Tensor is not contiguous
    1613              :    * - Any index is out of bounds
    1614              :    */
    1615              :   Tensor getBatchSlice(const std::vector<unsigned int> &indices) const;
    1616              : 
    1617              :   /**
    1618              :    * @brief     Convenient wrapper for inplace copy of @a this.
    1619              :    * @retval    Copied version of this
    1620              :    */
    1621              :   Tensor clone() const;
    1622              : 
    1623              :   /**
    1624              :    * @brief     Convenient wrapper for inplace copy of @a this.
    1625              :    * @param[in] type output tensor data type
    1626              :    * @retval    Copied version of this
    1627              :    */
    1628              :   Tensor clone(ml::train::TensorDim::DataType type) const;
    1629              : 
    1630              :   /**
    1631              :    * @brief     Read the Tensor For FSU
    1632              :    *
    1633              :    */
    1634              :   void readFSU();
    1635              : 
    1636              :   /**
    1637              :    * @brief     Save the Tensor into file
    1638              :    * @param[in] file output file stream
    1639              :    */
    1640              :   void save(std::ostream &file);
    1641              : 
    1642              :   /**
    1643              :    * @brief     Read the Tensor from file
    1644              :    * @param[in] file input file stream
    1645              :    */
    1646              :   void read(std::ifstream &file, size_t start_offset = 0,
    1647              :             bool read_from_offset = false, int file_fd = -1);
    1648              : 
    1649              :   /**
    1650              :    * @brief     Read the Tensor from a read source
    1651              :    * @param[in] src input read source
    1652              :    */
    1653              :   void read(ReadSource src, size_t start_offset = 0,
    1654              :             bool read_from_offset = false);
    1655              : 
    1656              :   /**
    1657              :    * @brief     return argument index which value is max by batch
    1658              :    * @retval    unsigned int argument indices
    1659              :    */
    1660              :   std::vector<unsigned int> argmax() const;
    1661              : 
    1662              :   /**
    1663              :    * @brief     return argument index which value is min by batch
    1664              :    * @retval    unsigned int argument indices
    1665              :    */
    1666              :   std::vector<unsigned int> argmin() const;
    1667              : 
    1668              :   /**
    1669              :    * @brief Find top-K maximum values along the width dimension and return
    1670              :    * results as tensors
    1671              :    *
    1672              :    * @details This function computes the top-K maximum values and their
    1673              :    * corresponding indices along the **width** dimension for each batch,
    1674              :    * channel, and height slice. The operation preserves the original tensor
    1675              :    * format (NCHW/NHWC) while reducing the width dimension to size K. The
    1676              :    * indices are returned as a separate tensor of type `UINT32`.
    1677              :    *
    1678              :    * @param[in] k Number of largest elements to select (1 <= k <= width_size)
    1679              :    *
    1680              :    * @return std::pair<Tensor, Tensor> containing:
    1681              :    *         - First: Output tensor of shape [batch][channel][height][k] (NCHW)
    1682              :    * or [batch][height][k][channel] (NHWC) with top-K values
    1683              :    *         - Second: Indices tensor of shape [batch][channel][height][k]
    1684              :    * (NCHW) or [batch][height][k][channel] (NHWC) with original width positions
    1685              :    *
    1686              :    * @throw std::invalid_argument If:
    1687              :    *         - k is 0 or exceeds width dimension size
    1688              :    *         - Called on non-floating point tensor (UINT8/UINT16/etc)
    1689              :    *
    1690              :    * @note
    1691              :    * - Indices represent positions in the **original width dimension**
    1692              :    * - Sorting is done in descending order
    1693              :    * - Preserves tensor format (NCHW/NHWC) of the original tensor
    1694              :    */
    1695              :   std::pair<Tensor, Tensor> topK(unsigned int k) const;
    1696              : 
    1697              :   /**
    1698              :    * @brief     return max of the absolute values of the tensor
    1699              :    * @retval    maximum absolute value
    1700              :    */
    1701              :   float max_abs() const;
    1702              : 
    1703              :   /**
    1704              :    * @brief  return maximum value
    1705              :    * @retval Maximum value of the tensor data
    1706              :    */
    1707              :   float maxValue() const;
    1708              : 
    1709              :   /**
    1710              :    * @brief  return minimum value
    1711              :    * @retval Minimum value of the tensor data
    1712              :    */
    1713              :   float minValue() const;
    1714              : 
    1715              :   /**
    1716              :    * @brief  Transpose Tensor
    1717              :    * @param  direction to transpose ex) 0:2:1
    1718              :    * @return Tensor
    1719              :    */
    1720              :   Tensor transpose(const std::string &direction) const;
    1721              : 
    1722              :   /**
    1723              :    * @brief      Transpose Tensor
    1724              :    * @param      direction to transpose ex) 0:2:1
     1725              :    * @param[out] out Tensor to save to, dimension is always reshaped.
    1726              :    * @retval     Tensor& reference to the out
    1727              :    */
    1728              :   Tensor &transpose(const std::string &direction, Tensor &out) const;
    1729              : 
    1730              :   /**
    1731              :    * @brief     set Tensor Dim
    1732              :    * @param[in] d TensorDim
    1733              :    * @note      Throws std::invalid_argument if size mismatch
    1734              :    */
    1735              :   void reshape(const TensorDim &d);
    1736              : 
    1737              :   /**
    1738              :    * @brief fill tensor data with current value,
    1739              :    * if dimension is not exactly same, it is a hard error in this function
     1740              :    * so, only stride is overridden to @a this
    1741              :    *
    1742              :    * @param from Tensor to fill the data from
    1743              :    * @param allocate if unallocated, allocate with from.getDim()
    1744              :    * @throws std::invalid_argument if dimension and stride does not match
    1745              :    */
    1746              :   void fill(const Tensor &from, bool allocate = false);
    1747              : 
    1748              :   /**
    1749              :    * @brief     return a copy of the Tensor Dim
    1750              :    * @retval    TensorDim
    1751              :    */
    1752              :   TensorDim getDim() const;
    1753              : 
    1754              :   /**
    1755              :    * @brief     return Tensor Type
    1756              :    */
    1757              :   TensorDim::TensorType getTensorType() const;
    1758              : 
    1759              :   /**
    1760              :    * @brief Get initializer for the tensor
    1761              :    *
    1762              :    * @return initializer of the tensor
    1763              :    */
    1764              :   Initializer getInitializer() const;
    1765              : 
    1766              :   /**
    1767              :    * @brief Get format for the tensor
    1768              :    * @return format of the tensor
    1769              :    */
    1770              :   TensorDim::Format getFormat() const;
    1771              : 
    1772              :   /**
    1773              :    * @brief Get data type for the tensor
    1774              :    *
    1775              :    * @return data type of the tensor
    1776              :    */
    1777              :   Tdatatype getDataType() const;
    1778              : 
    1779              :   /**
    1780              :    * @brief     update batch size for this tensor
    1781              :    * @param     batch size
    1782              :    * @note      The batchsize of src_tensor need not be related with this
    1783              :    * tensor's batch size
    1784              :    *
    1785              :    * @note      The memory for this tensor will re-allocated/re-assigned if the
    1786              :    * updated batch size is different than the current batch size.
    1787              :    *
    1788              :    * @note      If this tensor is/was the src_tensor for some other, then
    1789              :    * reduction in batch size can make the dependent tensors allocate fail due to
    1790              :    * memory smaller. Caller must handle this in their own end.
    1791              :    *
    1792              :    * @note      If this tensor is re-allocated, then the memory might not be
    1793              :    * immediately freed as the tensor already depending on this tensor also
    1794              :    * share the same memory. So, the peak memory consumption in worst case can
    1795              :    * reach the total memory requirements of a model with old batchsize and the
    1796              :    * new batch size. It is recommended to first deallocate all the tensors,
    1797              :    * updateBatch and then allocate again to avoid such issues.
    1798              :    */
    1799              :   void updateBatch(unsigned int batch);
    1800              : 
    1801              :   /**
    1802              :    * @brief     update the dimension for this tensor
    1803              :    * @param     dimension dimension to be updated
    1804              :    * @note      if this tensor is allocated this will throw an error.
    1805              :    * @note      we assume that the caller checks if the tensor is not allocated
    1806              :    */
    1807              :   void updateDimension(TensorDim dimension);
    1808              : 
    1809              :   /**
    1810              :    * @brief     return whether tensor is contiguous or not.
    1811              :    * @retval    bool contiguous
    1812              :    */
    1813              :   const bool getContiguous() const noexcept;
    1814              : 
    1815              :   /**
    1816              :    * @brief     return current stride of tensor.
    1817              :    * @retval    int[MAXDIM] strides
    1818              :    */
    1819              :   const std::array<size_t, TensorDim::MAXDIM> getStrides() const noexcept;
    1820              : 
    1821              :   /**
    1822              :    * @brief     Check if two given axes are contiguous
    1823              :    * @param[in] np1 first axis
    1824              :    * @param[in] np2 second axis to compare with first axis
    1825              :    * @retval    bool continuous
    1826              :    */
    1827              :   bool checkContinuous(unsigned int np1, unsigned int np2) const;
    1828              : 
    1829              :   /**
    1830              :    * @brief     set FileOffset to Tensor
     1831              :    * @param     file_offset FileOffset
    1832              :    */
    1833              :   void setFileOffset(size_t file_offset);
    1834              : 
    1835              :   /**
    1836              :    * @brief     get FileOffset of Tensor
    1837              :    * @return    size_t fileOffset
    1838              :    */
    1839              :   size_t getFileOffset() const;
    1840              : 
    1841              :   /**
    1842              :    * @brief     Set name of the tensor
    1843              :    * @param[in] name_ tensor name
    1844              :    */
    1845              :   void setName(const std::string &name_);
    1846              : 
    1847              :   /**
    1848              :    * @brief     Get name of the tensor
    1849              :    * @retval    string name
    1850              :    */
    1851              :   const std::string &getName() const;
    1852              : 
    1853              :   /**
    1854              :    * @brief Get linear index given the n-d index
    1855              :    */
    1856              :   size_t getIndex(unsigned int b, unsigned int c, unsigned int h,
    1857              :                   unsigned int w) const noexcept;
    1858              :   /**
    1859              :    * @brief     Get size of current tensor
    1860              :    * @retval    unsigned int size of the current tensor
    1861              :    */
    1862              :   size_t size() const;
    1863              : 
    1864              :   /**
    1865              :    * @brief     Get if the tensor is empty
    1866              :    * @retval    true if the tensor is empty
    1867              :    */
    1868              :   bool empty() const;
    1869              : 
    1870              :   /**
    1871              :    * @brief     Get size of the data in bytes
    1872              :    * @retval    size_t Size in bytes
    1873              :    */
    1874              :   size_t bytes() const;
    1875              : 
    1876              :   /**
    1877              :    * @brief     Get a total size of the memory data in bytes
    1878              :    * @retval    size_t Size in bytes
    1879              :    * @note      This is the total size of the memory data, including the scale
    1880              :    * factors and the zero points. For float type, this will return the same as
    1881              :    * bytes()
    1882              :    */
    1883              :   size_t getMemoryBytes() const;
    1884              : 
    1885              :   /**
    1886              :    * @brief     return Tensor batch size
    1887              :    * @retval    batch size
    1888              :    */
    1889              :   size_t batch() const;
    1890              : 
    1891              :   /**
    1892              :    * @brief     return Tensor channel size
    1893              :    * @retval    channel size
    1894              :    */
    1895              :   size_t channel() const;
    1896              : 
    1897              :   /**
    1898              :    * @brief     return Tensor height size
    1899              :    * @retval    height size
    1900              :    */
    1901              :   size_t height() const;
    1902              : 
    1903              :   /**
    1904              :    * @brief     return Tensor width size
    1905              :    * @retval    width size
    1906              :    */
    1907              :   size_t width() const;
    1908              : 
    1909              :   /**
    1910              :    * @brief     return Tensor scale factor size if exists
    1911              :    * @retval    scale factor size
    1912              :    */
    1913              :   size_t scale_size() const;
    1914              : 
    1915              :   /**
    1916              :    * @brief     return Tensor quantization scheme
    1917              :    * @retval    Qscheme qscheme
    1918              :    */
    1919              :   QScheme q_scheme() const;
    1920              : 
    1921              :   /**
    1922              :    * @brief Merge the given two axis for tensor at second axis inplace
    1923              :    *
    1924              :    * @param axis1 first axis to merge
    1925              :    * @param axis2 second axis to merge
    1926              :    */
    1927              :   void mergeAxis(unsigned int axis1, unsigned int axis2);
    1928              : 
    1929              :   /**
    1930              :    * @brief Update destination tensor to share memory with source tensor
    1931              :    *
    1932              :    * @param src src tensor containing the memory
    1933              :    * @param dest destination tensor which will share the memory
    1934              :    * @param offset offset to be used from the start of the data in bytes
    1935              :    * @note The new tensor will share the same data as the current tensor but
    1936              :    * can have different size.
    1937              :    * @note New size added with offset must be less than the size of the original
    1938              :    * tensor.
    1939              :    */
    1940              :   void createSharedDataTensor(const Tensor &src, Tensor &dest,
    1941              :                               size_t offset) const;
    1942              : 
    1943              :   /**
    1944              :    * @brief Get new tensor which shares memory with current tensor but different
    1945              :    * shape
    1946              :    *
    1947              :    * @param dim new dimension to be set for this tensor
    1948              :    * @param offset offset to be used from the start of the data in elements
    1949              :    * @note The new tensor will share the same data as the current tensor but
    1950              :    * can have different size.
    1951              :    * @note New size added with offset must be less than the size of the original
    1952              :    * tensor.
    1953              :    */
    1954              :   Tensor getSharedDataTensor(const TensorDim dim_, size_t offset,
    1955              :                              bool reset_stride = true,
    1956              :                              const std::string &name_ = "") const;
    1957              : 
    1958              :   /**
    1959              :    * @brief    Swaps Tensor lhs and rhs
    1960              :    * @param[in] lhs Tensor to be swapped
    1961              :    * @param[in] rhs Tensor to be swapped
    1962              :    */
    1963              :   friend void swap(Tensor &lhs, Tensor &rhs) noexcept {
    1964              :     std::swap(lhs.itensor_, rhs.itensor_);
    1965              :   }
    1966              : 
    1967              :   /**
    1968              :    * @brief      check if there is NaN or Inf element
    1969              :    * @param[out] bool false if there is NaN or Inf else false
    1970              :    */
    1971           12 :   bool isValid() const { return itensor_->isValid(); };
    1972              : 
     1973              :   /**
     1974              :    * @brief check if tensor is virtual
     1975              :    * @return true if this tensor is virtual, false otherwise
     1976              :    */
     1977              :   bool isVirtual() const { return is_virtual; }
    1978              : 
    1979              :   /**
    1980              :    * @brief activate virtual tensor
    1981              :    * @note if the tensor is virtual, this method activates virtual tensor, which
    1982              :    * means allocate the tensor memory and read the corresponding value from the
    1983              :    * file descriptor
    1984              :    * @todo it is not supported on Windows yet
    1985              :    */
    1986              :   void activate();
    1987              : 
    1988              :   /**
    1989              :    * @brief deactivate virtual tensor
    1990              :    * @note if the tensor is virtual and already activated, the tensor is
    1991              :    * deallocated.
    1992              :    */
    1993              :   void deactivate();
    1994              : 
    1995              :   static constexpr float epsilon = 1e-5f;
    1996              : 
    1997              : private:
    1998              :   std::unique_ptr<TensorBase> itensor_;
    1999              : 
    2000              :   /**
    2001              :    * @brief properties for virtual tensor
     2002              :    * @note This should be removed by defining VirtualTensor class
    2003              :    * */
    2004              :   bool is_virtual = false;    /** flag to check virtual */
    2005              :   size_t read_offset;         /** save read_offset info for virtual */
    2006              :   int fd = -1;                /** save fd info for virtual */
    2007              :   void *mapped_ptr = nullptr; /** save mmap buf pointer for virtual */
    2008              : 
    2009              :   /**
    2010              :    * @brief Set tensor variables
    2011              :    *
    2012              :    * @param[in] d TensorDim
    2013              :    * @param[in] buf buffer
    2014              :    * @param[in] offset offset to be used
    2015              :    */
    2016              :   void setTensorVar(TensorDim d, void *buf, size_t offset);
    2017              : 
    2018              :   /**
    2019              :    * @brief Calculate the output tensor dimension of the concatenating a list of
    2020              :    * tensors as an input.
    2021              :    *
    2022              :    * @param[in] tensors tensors to be concatenated to the first tensor
    2023              :    * @param[in] axis axis
    2024              :    */
    2025              :   static TensorDim calculateConcatOutputDim(const std::vector<Tensor> &tensors,
    2026              :                                             int axis);
    2027              : };
    2028              : 
    2029              : /**
    2030              :  * @brief   Overriding output stream
    2031              :  */
    2032              : std::ostream &operator<<(std::ostream &out, Tensor const &input);
    2033              : 
    2034              : typedef std::shared_ptr<Tensor> sharedTensor;
    2035              : 
    2036              : typedef std::shared_ptr<const Tensor> sharedConstTensor;
    2037              : 
    2038              : typedef std::vector<sharedConstTensor> sharedConstTensors;
    2039              : 
    2040              : typedef std::vector<sharedTensor> sharedTensors;
    2041              : 
    2042              : } // namespace nntrainer
    2043              : 
    2044              : #endif /* __cplusplus */
    2045              : #endif /* __TENSOR_H__ */
        

Generated by: LCOV version 2.0-1