Line data Source code
1 : // SPDX-License-Identifier: Apache-2.0
2 : /**
3 : * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
4 : *
5 : * @file lazy_tensor.cpp
6 : * @date 05 Jun 2020
7 : * @brief A lazy evaluation calculator for tensors
8 : * @see https://github.com/nnstreamer/nntrainer
9 : * @author Jihoon Lee <jhoon.it.lee@samsung.com>
10 : * @bug No known bugs except for NYI items
11 : *
12 : */
13 :
14 : #include <lazy_tensor.h>
15 : #include <nntrainer_error.h>
16 :
17 : namespace nntrainer {
18 :
19 : /**
20 : * @brief Wrapper method of add_i (immediate version of add)
21 : * @retval this
22 : */
23 7 : LazyTensor &LazyTensor::add_i(float const &value) {
24 14 : call_chain.push_back(
25 14 : [value](Tensor &t) mutable -> int { return t.add_i(value); });
26 7 : return *this;
27 : }
28 : /**
29 : * @brief Wrapper method of add_i. see tensor.h for more detail
30 : * @param[in] m Tensor to be added
31 : * @retval LazyTensor *this
32 : */
33 2 : LazyTensor &LazyTensor::add_i(Tensor const &m, float const alpha) {
34 2 : auto f = [&m, alpha](Tensor &t) mutable -> int { return t.add_i(m, alpha); };
35 2 : call_chain.push_back(f);
36 2 : return *this;
37 : }
38 :
39 : /**
40 : * @brief Wrapper method of subtract_i. see tensor.h for more detail
41 : * @param[in] m Tensor to subtract
42 : * @retval LazyTensor *this
43 : */
44 1 : LazyTensor &LazyTensor::subtract_i(Tensor const &m) {
45 1 : auto f = [&m](Tensor &t) mutable -> int { return t.subtract_i(m); };
46 1 : call_chain.push_back(f);
47 1 : return *this;
48 : }
49 :
50 : /**
51 : * @brief Wrapper method of subtract_i. see tensor.h for more detail
52 : * @param[in] value value to subtract
53 : * @retval LazyTensor *this
54 : */
55 1 : LazyTensor &LazyTensor::subtract_i(float const &value) {
56 1 : auto f = [value](Tensor &t) mutable -> int { return t.subtract_i(value); };
57 1 : call_chain.push_back(f);
58 1 : return *this;
59 : }
60 :
61 : /**
62 : * @brief Wrapper method of multiply_i. see tensor.h for more detail
63 : * @param[in] value to be added
64 : * @retval LazyTensor *this
65 : */
66 1 : LazyTensor &LazyTensor::multiply_i(float const &value) {
67 1 : auto f = [value](Tensor &t) mutable -> int { return t.multiply_i(value); };
68 1 : call_chain.push_back(f);
69 1 : return *this;
70 : }
71 :
72 : /**
73 : * @brief Wrapper method of multiply_i. see tensor.h for more detail
74 : * @param[in] m Tensor to be multiplied
75 : * @retval LazyTensor *this
76 : */
77 5002 : LazyTensor &LazyTensor::multiply_i(Tensor const &m) {
78 5002 : auto f = [&m](Tensor &t) mutable -> int { return t.multiply_i(m); };
79 5002 : call_chain.push_back(f);
80 5002 : return *this;
81 : }
82 :
83 : /**
84 : * @brief Wrapper method of divide_i. see tensor.h for more detail
85 : * @param[in] value divisor
86 : * @retval LazyTensor *this
87 : */
88 1 : LazyTensor &LazyTensor::divide_i(float const &value) {
89 1 : auto f = [value](Tensor &t) mutable -> int { return t.divide_i(value); };
90 1 : call_chain.push_back(f);
91 1 : return *this;
92 : }
93 :
94 : /**
95 : * @brief Wrapper method of divide_i. see tensor.h for more detail
96 : * @param[in] m Tensor to for division
97 : * @retval LazyTensor *this
98 : */
99 2 : LazyTensor &LazyTensor::divide_i(Tensor const &m) {
100 2 : auto f = [&m](Tensor &t) mutable -> int { return t.divide_i(m); };
101 2 : call_chain.push_back(f);
102 2 : return *this;
103 : }
104 :
105 : /**
106 : * @brief Wrapper method of dot. see tensor.h for more detail (memcopy
107 : * happens)
108 : * @param[in] m Tensor
109 : * @retval LazyTensor *this
110 : */
111 0 : LazyTensor &LazyTensor::dot(Tensor const &m) {
112 0 : auto f = [&m](Tensor &t) mutable -> int {
113 : try {
114 0 : t = t.dot(m);
115 0 : return ML_ERROR_NONE;
116 0 : } catch (std::runtime_error &e) {
117 : return ML_ERROR_INVALID_PARAMETER;
118 0 : }
119 : };
120 :
121 0 : call_chain.push_back(f);
122 0 : return *this;
123 : }
124 :
125 : /**
126 : * @brief Wrapper method of transpose. see tensor.h for more detail (memcopy
127 : * happens)
128 : * @param[in] direction to transpose ex) 0:2:1
129 : * @retval LazyTensor *this
130 : */
131 0 : LazyTensor &LazyTensor::transpose(std::string direction) {
132 0 : auto f = [direction](Tensor &t) mutable -> int {
133 : try {
134 0 : t = t.transpose(direction);
135 0 : return ML_ERROR_NONE;
136 0 : } catch (std::runtime_error &e) {
137 : return ML_ERROR_INVALID_PARAMETER;
138 0 : }
139 : };
140 :
141 0 : call_chain.push_back(f);
142 0 : return *this;
143 : }
144 :
145 : /**
146 : * @brief Wrapper method of sum. see tensor.h for more detail (memcopy
147 : * happens)
148 : * @param[in] direction to transpose ex) 0:2:1
149 : * @retval LazyTensor *this
150 : */
151 1 : LazyTensor &LazyTensor::sum_by_batch() {
152 1 : auto f = [](Tensor &t) mutable -> int {
153 : try {
154 2 : t = t.sum_by_batch();
155 1 : return ML_ERROR_NONE;
156 0 : } catch (std::runtime_error &e) {
157 : return ML_ERROR_INVALID_PARAMETER;
158 0 : }
159 : };
160 :
161 1 : call_chain.push_back(f);
162 1 : return *this;
163 : }
164 :
165 : /**
166 : * @brief Wrapper method of sum. see tensor.h for more detail (memcopy
167 : * happens) 0 : batch direction 1 : channel direction 2 : channel direction 3 :
168 : * channel direction
169 : * @retval LazyTensor *this
170 : */
171 4 : LazyTensor &LazyTensor::sum(int axis) {
172 4 : auto f = [axis](Tensor &t) mutable -> int {
173 : try {
174 8 : t = t.sum(axis);
175 4 : return ML_ERROR_NONE;
176 0 : } catch (std::runtime_error &e) {
177 : return ML_ERROR_INVALID_PARAMETER;
178 0 : }
179 : };
180 :
181 4 : call_chain.push_back(f);
182 4 : return *this;
183 : }
184 :
185 : /**
186 : * @brief Wrapper method of average. see tensor.h for more detail (memcopy
187 : * happens)
188 : * @retval LazyTensor *this
189 : */
190 0 : LazyTensor &LazyTensor::average(int axis) {
191 0 : auto f = [axis](Tensor &t) mutable -> int {
192 : try {
193 0 : t = t.average(axis);
194 0 : return ML_ERROR_NONE;
195 0 : } catch (std::runtime_error &e) {
196 : return ML_ERROR_INVALID_PARAMETER;
197 0 : }
198 : };
199 :
200 0 : call_chain.push_back(f);
201 0 : return *this;
202 : }
203 :
204 : /**
205 : * @brief Wrapper method of average. see tensor.h for more detail (memcopy
206 : * happens)
207 : * @retval LazyTensor *this
208 : */
209 0 : LazyTensor &LazyTensor::average() {
210 0 : auto f = [](Tensor &t) mutable -> int {
211 : try {
212 0 : t = t.average();
213 0 : return ML_ERROR_NONE;
214 0 : } catch (std::runtime_error &e) {
215 : return ML_ERROR_INVALID_PARAMETER;
216 0 : }
217 : };
218 :
219 0 : call_chain.push_back(f);
220 0 : return *this;
221 : }
222 :
223 : /**
224 : * @brief execute the call_chain to evaluate
225 : * @retval calculated tensor
226 : */
227 5020 : Tensor LazyTensor::run() {
228 : int status;
229 10038 : for (auto &item : call_chain) {
230 5022 : status = item(target);
231 5022 : if (status != ML_ERROR_NONE) {
232 4 : throw std::runtime_error("Error: evaluation failed");
233 : }
234 : }
235 5016 : return target;
236 : }
237 :
238 : } /* namespace nntrainer */
|