Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2023 Jijoong Moon <jijoong.moon@@samsung.com>
4 : *
5 : * @file tensor_api.h
6 : * @date 11 December 2023
7 : * @see https://github.com/nnstreamer/nntrainer
8 : * @author Jijoong Moon <jijoong.moon@samsung.com>
9 : * @bug No known bugs except for NYI items
10 : * @brief This is Tensor interface for c++ API
11 : *
12 : * @note This is experimental API and not stable.
13 : */
14 :
15 : #ifndef __ML_TRAIN_TENSOR_H__
16 : #define __ML_TRAIN_TENSOR_H__
17 :
18 : #if __cplusplus < MIN_CPP_VERSION
19 : #error "CPP versions c++17 or over are only supported"
20 : #endif // __cpluscplus
21 :
22 : #include <layer.h>
23 : #include <tensor.h>
24 : #include <tuple>
25 : #include <var_grad.h>
26 :
27 : using iTensor = nntrainer::Tensor;
28 :
29 : namespace ml {
30 : namespace train {
31 :
32 : /**
33 : * @class Tensor
34 : * @brief Tensor extends over Var_Grad for the API
35 : */
36 : class Tensor : public nntrainer::Var_Grad {
37 : public:
38 : /**
39 : * @brief Weight default constructor
40 : */
41 1 : Tensor() : nntrainer::Var_Grad() {}
42 :
43 : /**
44 : * @brief Construct a new Tensor object
45 : *
46 : * @param dim Variable and gradient tensor dimension
47 : * @param init Initializer for the Tensor
48 : * @param needg If the tensor needs gradient
49 : * @param name Name for this tensor
50 : */
51 : explicit Tensor(
52 : const TensorDim &dim,
53 : const nntrainer::Initializer init = nntrainer::Initializer::ZEROS,
54 : bool ng = false, std::string name = ""){};
55 :
56 : /**
57 : * @brief Swap for weight
58 : *
59 : * @param lhs Swap to
60 : * @param rhs Swap from
61 : * @note Only swap gradient if need gradient
62 : */
63 : friend void swap(Tensor &lhs, Tensor &rhs) noexcept {
64 : using std::swap;
65 : swap(static_cast<Var_Grad &>(lhs), static_cast<Var_Grad &>(rhs));
66 : }
67 :
68 : /**
69 : * @brief Copy constructor for weight
70 : *
71 : * @param rhs weight to construct from
72 : */
73 : Tensor(const Tensor &rhs) = default;
74 :
75 : /**
76 : * @brief Move constructor for weight
77 : *
78 : * @param rhs weight to construct from
79 : */
80 : Tensor(Tensor &&rhs) = default;
81 :
82 : /**
83 : * @brief copy assigment
84 : *
85 : * @param rhs copy from
86 : * @return Tensor& Updated weight
87 : */
88 : Tensor &operator=(const Tensor &rhs) = default;
89 :
90 : /**
91 : * @brief move assignment
92 : *
93 : * @param rhs move from
94 : * @return Tensor& Updated weight
95 : */
96 : Tensor &operator=(Tensor &&rhs) = default;
97 :
98 : /**
99 : * @brief Clone the currnet object
100 : *
101 : * @return Cloned copy
102 : */
103 : Tensor clone() const {
104 : Tensor t(*this);
105 : if (!this->var->empty())
106 : t.var = std::make_shared<iTensor>(this->var->clone());
107 : if (!this->grad->empty())
108 : t.grad = std::make_shared<iTensor>(this->grad->clone());
109 :
110 : return t;
111 : }
112 :
113 : /**
114 : * @brief source layer setter
115 : *
116 : */
117 : void setSrcLayer(std::shared_ptr<Layer> l) { src_layer = l; }
118 :
119 : /**
120 : * @brief source layer getter
121 : *
122 : * @return Layer
123 : */
124 : std::shared_ptr<Layer> getSrcLayer() { return src_layer; }
125 :
126 : private:
127 : std::shared_ptr<Layer>
128 : src_layer; /**< source layer which create this Tensor */
129 : };
130 :
131 : } // namespace train
132 : } // namespace ml
133 : #endif // __ML_TRAIN_TENSOR_H__
|