Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file lazy_tensor.h
6 : * @date 05 Jun 2020
7 : * @brief A lazy evaluation calculator for tensors
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #ifndef __LAZY_TENSOR_H__
15 : #define __LAZY_TENSOR_H__
16 : #ifdef __cplusplus
17 :
#include <functional>
#include <string>
#include <vector>

#include <tensor.h>
20 :
21 : namespace nntrainer {
22 :
23 : /**
24 : * @class LazyTensor a wrapper class for lazy calculation of tensor
25 : * @brief calculation is delayed until Tensor LazyTensor::run() is
26 : * called, can be contructed by Tensor::chain() method
27 : */
28 5020 : class LazyTensor {
29 : public:
30 : /**
31 : * @brief Constructor of Lazy Tensor, Tensor is copied to gaurantee
32 : * immutability
33 : */
34 10040 : LazyTensor(const Tensor &from) { target.copy(from); };
35 :
36 : /**
37 : * @brief Wrapper method of add_i. see tensor.h for more detail
38 : * @param[in] value to be added
39 : * @retval LazyTensor *this
40 : */
41 : LazyTensor &add_i(float const &value);
42 :
43 : /**
44 : * @brief Wrapper method of add_i. see tensor.h for more detail
45 : * @param[in] m Tensor to be added
46 : * @retval LazyTensor *this
47 : */
48 : LazyTensor &add_i(Tensor const &m, float const alpha = 1);
49 :
50 : /**
51 : * @brief Wrapper method of subtract_i. see tensor.h for more detail
52 : * @param[in] m Tensor to subtract
53 : * @retval LazyTensor *this
54 : */
55 : LazyTensor &subtract_i(Tensor const &m);
56 :
57 : /**
58 : * @brief Wrapper method of subtract_i. see tensor.h for more detail
59 : * @param[in] value value to subtract
60 : * @retval LazyTensor *this
61 : */
62 : LazyTensor &subtract_i(float const &value);
63 :
64 : /**
65 : * @brief Wrapper method of multiply_i. see tensor.h for more detail
66 : * @param[in] value to be added
67 : * @retval LazyTensor *this
68 : */
69 : LazyTensor &multiply_i(float const &value);
70 :
71 : /**
72 : * @brief Wrapper method of multiply_i. see tensor.h for more detail
73 : * @param[in] m Tensor to be multiplied
74 : * @retval LazyTensor *this
75 : */
76 : LazyTensor &multiply_i(Tensor const &m);
77 :
78 : /**
79 : * @brief Wrapper method of divide_i. see tensor.h for more detail
80 : * @param[in] value divisor
81 : * @retval LazyTensor *this
82 : */
83 : LazyTensor ÷_i(float const &value);
84 :
85 : /**
86 : * @brief Wrapper method of divide_i. see tensor.h for more detail
87 : * @param[in] m Tensor to for division
88 : * @retval LazyTensor *this
89 : */
90 : LazyTensor ÷_i(Tensor const &m);
91 :
92 : /**
93 : * @brief Wrapper method of dot. see tensor.h for more detail (memcopy
94 : * happens)
95 : * @param[in] m Tensor
96 : * @retval LazyTensor *this
97 : */
98 : LazyTensor &dot(Tensor const &m);
99 :
100 : /**
101 : * @brief Wrapper method of transpose. see tensor.h for more detail
102 : * (memcopy happens)
103 : * @param[in] direction to transpose ex) 0:2:1
104 : * @retval LazyTensor *this
105 : */
106 : LazyTensor &transpose(std::string direction);
107 :
108 : /**
109 : * @brief Wrapper method of sum_by_batch. see tensor.h for more detail
110 : * (memcopy happens)
111 : * @retval LazyTensor *this
112 : */
113 : LazyTensor &sum_by_batch();
114 :
115 : /**
116 : * @brief Wrapper method of sum. see tensor.h for more detail (memcopy
117 : * happens) 0 : batch direction 1 : channel direction 2 : height direction 3 :
118 : * width direction
119 : * @retval LazyTensor *this
120 : */
121 : LazyTensor &sum(int axis);
122 :
123 : /**
124 : * @brief Wrapper method of average. see tensor.h for more detail (memcopy
125 : * happens) 0 : batch direction 1 : channel direction 2 : height direction 3 :
126 : * width direction
127 : * @retval LazyTensor *this
128 : */
129 : LazyTensor &average(int axis);
130 :
131 : /**
132 : * @brief Wrapper method of average. see tensor.h for more detail (memcopy
133 : * happens)
134 : * @retval LazyTensor *this
135 : */
136 : LazyTensor &average();
137 :
138 : /**
139 : * @brief execute the call_chain to get the tensor
140 : * @retval calculated tensor
141 : */
142 : Tensor run();
143 :
144 : private:
145 : /**< handle the data as a std::vector type */
146 : std::vector<std::function<int(Tensor &)>> call_chain;
147 : Tensor target;
148 : };
149 :
150 : } /* namespace nntrainer */
151 :
152 : #endif /* __cplusplus */
153 : #endif /* __LAZY_TENSOR_H__ */
|