// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
 *
 * @file acti_func.h
 * @date 22 March 2021
 * @see https://github.com/nnstreamer/nntrainer
 * @author Jihoon Lee <jhoon.it.lee@samsung.com>
 * @author Jijoong Moon <jijoong.moon@samsung.com>
 * @bug No known bugs except for NYI items
 * @brief This is Activation Function Class for Neural Network
 *
 */

#ifndef __ACTI_FUNC_H__
#define __ACTI_FUNC_H__
#ifdef __cplusplus

#include <common_properties.h>
#include <cpu_backend.h>

#if defined(_WIN32)
#define _USE_MATH_DEFINES
#include <math.h>
#endif

namespace nntrainer {

class Tensor;

/**
 * @class ActiFunc Class
 * @brief ActiFunc Class
 */
class ActiFunc {

public:
  constexpr static inline float NEGATIVE_SLOPE = 0.01f;

  /**
   * @brief Constructor of ActiFunc
   */
  template <typename T = float>
  ActiFunc(ActivationType at = ActivationType::ACT_NONE,
           bool is_inplace_ = true) :
    is_inplace(is_inplace_) {
    setActiFunc<T>(at);
  }

  /**
   * @brief Destructor of ActiFunc
   */
  ~ActiFunc(){};

  /**
   * @brief setActivation by preset ActivationType
   *
   * @param[in] acti_type preset activation type
   */
  template <typename T = float> void setActiFunc(ActivationType acti_type) {
    activation_type = acti_type;

    switch (acti_type) {
    case ActivationType::ACT_TANH:
      this->setActivation<T>(tanhFloat<T>, tanhPrime<T>);
      break;
    case ActivationType::ACT_SIGMOID:
      this->setActivation<T>(sigmoid<T>, sigmoidPrime<T>);
      break;
    case ActivationType::ACT_SOFTMAX:
      this->setActivation<Tensor>(softmax<T>, softmaxPrime<T>);
      break;
    case ActivationType::ACT_RELU:
      this->setActivation<T>(relu<T>, reluPrime<T>);
      break;
    case ActivationType::ACT_LEAKY_RELU:
      this->setActivation<T>(leakyRelu<T>, leakyReluPrime<T>);
      break;
    case ActivationType::ACT_SWISH:
      is_inplace = false;
      this->setActivation<Tensor>(swish<T>, swishPrime<T>);
      break;
    case ActivationType::ACT_GELU:
      is_inplace = false;
      this->setActivation<Tensor>(gelu<T>, geluPrime<T>);
      break;
    case ActivationType::ACT_TANH_GELU:
      is_inplace = false;
      this->setActivation<Tensor>(tanhGelu<T>, tanhGeluPrime<T>);
      break;
    case ActivationType::ACT_SIGMOID_GELU:
      is_inplace = false;
      this->setActivation<Tensor>(sigmoidGelu<T>, sigmoidGeluPrime<T>);
      break;
    case ActivationType::ACT_ELU:
      this->setActivation<T>(elu<T>, eluPrime<T>);
      break;
    case ActivationType::ACT_SELU:
      this->setActivation<T>(selu<T>, seluPrime<T>);
      break;
    case ActivationType::ACT_SOFTPLUS:
      this->setActivation<T>(softplus<T>, softplusPrime<T>);
      break;
    case ActivationType::ACT_MISH:
      this->setActivation<T>(mish<T>, mishPrime<T>);
      break;
    case ActivationType::ACT_NONE:
      this->setActivation<T>(no_op<T>, no_op_prime<T>);
      break;
    case ActivationType::ACT_UNKNOWN:
    default:
      throw std::runtime_error("Error: Not Supported Activation Type");
    }
  }
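
  /*
   * Dispatch summary for the switch above: scalar activations (tanh, sigmoid,
   * relu, leaky relu, elu, selu, softplus, mish, none) are registered through
   * the element-wise setActivation overload, while softmax, swish and the
   * gelu variants need whole-tensor context and therefore register the
   * Tensor-level overloads; swish and the gelu variants additionally force
   * is_inplace = false. A minimal re-targeting sketch (illustrative only):
   *
   * @code
   *   ActiFunc act;                                     // defaults to ACT_NONE
   *   act.setActiFunc<float>(ActivationType::ACT_RELU); // switch to ReLU
   * @endcode
   */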

  /**
   * @brief run function
   *
   * @param[in] input : input
   * @param[out] output : output
   */
  void run_fn(Tensor const &input, Tensor &output) { _act_fn(input, output); }

  /**
   * @brief run prime function
   *
   * @param[in] input input
   * @param[in] output output
   * @param[out] outgoing_derivative outgoing derivative
   * @param[in] incoming_derivative incoming derivative
   * @retval Tensor
   */
  Tensor &run_prime_fn(Tensor &input, Tensor &output,
                       Tensor &outgoing_derivative,
                       Tensor const &incoming_derivative) {
    return _act_prime_fn(input, output, outgoing_derivative,
                         incoming_derivative);
  }

  /**
   * @brief run prime function
   *
   * @param[in] output output
   * @param[out] outgoing_derivative outgoing derivative
   * @param[in] incoming_derivative incoming derivative
   * @retval Tensor
   */
  Tensor &run_prime_fn(Tensor &output, Tensor &outgoing_derivative,
                       Tensor const &incoming_derivative) {
    return _act_prime_fn(Tensor(), output, outgoing_derivative,
                         incoming_derivative);
  }
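
  /*
   * Call-flow sketch (illustrative only; `input`, `output`, `dx` and `dy` are
   * assumed pre-allocated Tensors, not defined in this header):
   *
   * @code
   *   act.run_fn(input, output);        // forward: output = f(input)
   *   act.run_prime_fn(output, dx, dy); // backward: dx from f'(.) and dy
   * @endcode
   *
   * The four-argument overload additionally forwards the original input for
   * activations whose derivative needs it.
   */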

  /**
   * @copydoc Layer::supportInPlace()
   */
  bool supportInPlace() const { return is_inplace; }

  /**
   * @brief Calculate softmax for Tensor Type
   * @param[in] input input Tensor
   * @param[out] output output Tensor
   * @retval Tensor
   */
  template <typename T = float>
  static Tensor &softmax(Tensor const &input, Tensor &output) {
    /**
     * shiftx_logit = logit - max_batch(logit)
     * softmax = exp(shiftx_logit) / (sum(exp(shiftx_logit)))
     *
     * @note softmax is applied on the last dimension
     */
    /** TODO: support strided operations */
    if (input.size() == output.size() &&
        input.getStrides() != output.getStrides())
      throw std::invalid_argument(
        "Softmax does not support operating on strided tensors");

    unsigned int width = input.width();
    unsigned int bch_size = input.getDim().getDataLen() / width;

    // copy will not be executed in the in-place case
    output.copy(input);

    T *output_data = output.getData<T>();

    // prevent overflow
    Tensor tmp(width, input.getTensorType());
    for (unsigned int i = 0; i < bch_size; i++) {
      T *ptr = output_data + i * width;

      // find max value and subtract it
      T max_value = *std::max_element(ptr, ptr + width);

      tmp.setValue(max_value);
      saxpy(width, -1, tmp.getData<T>(), 1, ptr, 1);
    }

    // take exp
    output.apply<T>(exp_util<T>, output);

    // take sum over the last dimension
    Tensor sum = output.sum(3);

    for (unsigned int i = 0; i < bch_size; i++) {
      T *ptr = output_data + i * width;
      std::transform(ptr, ptr + width, ptr,
                     std::bind(std::divides<T>(), std::placeholders::_1,
                               sum.getValue<T>(i)));
    }

    return output;
  }
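
  /*
   * Worked example of the max-shifted softmax above (values rounded):
   * for a row x = [1, 2, 3], shifting by max(x) = 3 gives [-2, -1, 0];
   * exp -> [0.135, 0.368, 1.0], sum = 1.503, so
   * softmax(x) ~= [0.090, 0.245, 0.665], identical to the unshifted softmax
   * but without risking overflow in exp().
   */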

  /**
   * @brief Calculate derivative of softmax function
   * @param[in] output output tensor
   * @param[out] outgoing_derivative result of calculated derivative of softmax
   * @param[in] incoming_derivative incoming derivative tensor from next layer
   * @retval Tensor
   */

  template <typename T = float>
  static Tensor &softmaxPrime(Tensor const &output, Tensor &outgoing_derivative,
                              Tensor const &incoming_derivative = Tensor()) {
    /** TODO: support strided operations */

    if ((output.size() == outgoing_derivative.size() &&
         output.getStrides() != outgoing_derivative.getStrides()) ||
        (output.size() == incoming_derivative.size() &&
         output.getStrides() != incoming_derivative.getStrides()))
      throw std::invalid_argument(
        "SoftmaxPrime does not support operating on strided tensors");

    unsigned int batch = output.batch();
    unsigned int channel = output.channel();
    unsigned int height = output.height();
    unsigned int width = output.width();

    if (outgoing_derivative.empty())
      outgoing_derivative = Tensor(output.getDim());

    const T *output_data = output.getData<T>();
    const T *incoming_derivative_data = incoming_derivative.getData<T>();
    T *outgoing_derivative_data = outgoing_derivative.getData<T>();

    Tensor tmp = Tensor(width, output.getTensorType());
    T *tmp_data = tmp.getData<T>();
    unsigned int output_width_stride = output.getStrides()[3];
    for (unsigned int b = 0; b < batch; ++b) {
      int b_offset = b * channel * height * width;
      for (unsigned int c = 0; c < channel; ++c) {
        int bc_offset = b_offset + c * height * width;
        for (unsigned int h = 0; h < height; ++h) {
          int bch_offset = bc_offset + h * width;
          for (unsigned int w1 = 0; w1 < width; ++w1) {
            T sum = 0;
            for (unsigned int w2 = 0; w2 < width; ++w2) {
              T val;
              if (w1 == w2) {
                val = output_data[bch_offset + w2] *
                      ((T)1 - output_data[bch_offset + w1]);
              } else {
                val =
                  -output_data[bch_offset + w2] * output_data[bch_offset + w1];
              }
              if (!incoming_derivative.empty())
                val *= incoming_derivative_data[bch_offset + w2];
              sum += val;
            }
            tmp.setValue(0, 0, 0, w1, sum);
          }
          scopy(width, tmp_data, 1, outgoing_derivative_data + bch_offset,
                output_width_stride);
        }
      }
    }

    return outgoing_derivative;
  }
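
  /*
   * The loops above apply the softmax Jacobian, J_ij = y_i * (delta_ij - y_j),
   * to the incoming derivative via the chain rule:
   *   dL/dx_i = sum_j J_ji * dL/dy_j = y_i * (dL/dy_i - sum_j y_j * dL/dy_j),
   * where y is the softmax output stored in output_data.
   */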

  /**
   * @brief sigmoid activation function
   * @param[in] x input
   */
  template <typename T = float> static T sigmoid(T x) {
    return static_cast<T>(1.0 / (1.0 + exp_util<T>(-x)));
  }

  /**
   * @brief derivative sigmoid function
   * @param[in] x input
   */
  template <typename T = float> static T sigmoidPrime(T x) {
    return static_cast<T>(x * (static_cast<T>(1.0) - x));
  }
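
  /*
   * Note (drawn from how setActivation() applies these functions): the scalar
   * *Prime functions receive the activation *output*, not the original input.
   * For sigmoid, with y = sigmoid(x), the derivative is y * (1 - y), which is
   * exactly x * (1 - x) in sigmoidPrime's argument convention; tanhPrime
   * below uses the same convention with 1 - y^2.
   */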

  /**
   * @brief tanh function for float type
   * @param[in] x input
   */
  template <typename T = float> static T tanhFloat(T x) {
    return static_cast<T>(2.0 * sigmoid<T>(static_cast<T>(2.0) * x) - 1.0);
  }

  /**
   * @brief derivative tanh function
   * @param[in] x input
   */
  template <typename T = float> static T tanhPrime(T x) {
    return static_cast<T>(1.0 - x * x);
  }

  /**
   * @brief relu activation function
   * @param[in] x input
   */
  template <typename T = float> static T relu(T x) {
    if (x <= 0)
      return 0;
    return x;
  }

  /**
   * @brief derivative relu function
   * @param[in] x input
   */
  template <typename T = float> static T reluPrime(T x) {
    if (x <= 0)
      return 0;
    return 1;
  }

  /**
   * @brief no_op function
   * @param[in] x input
   */
  template <typename T = float> static T no_op(T x) { return x; }

  /**
   * @brief no_op function
   * @param[in] x input
   */
  template <typename T = float> static T no_op_prime(T x) { return 1; }

  /**
   * @brief leaky relu function
   * @note a slope parameter is needed for leaky relu, but supporting a
   * property on this class would need extensive refactoring. For now, 0.01 is
   * used for the negative slope.
   *
   * @param x input
   * @return float output
   */
  template <typename T = float> static T leakyRelu(T x) {
    return x >= static_cast<T>(0.0) ? x : static_cast<T>(NEGATIVE_SLOPE) * x;
  }

  /**
   * @brief leaky relu prime function
   * @note a slope parameter is needed for leaky relu, but supporting a
   * property on this class would need extensive refactoring. For now, 0.01 is
   * used for the negative slope.
   *
   * @param x input
   * @return float output
   */
  template <typename T = float> static T leakyReluPrime(T x) {
    return x >= static_cast<T>(0.0) ? static_cast<T>(1.0)
                                    : static_cast<T>(NEGATIVE_SLOPE);
  }

  /**
   * @brief Softplus activation function
   * @tparam T type of an input/output
   * @param x input
   * @return T type output
   */
  template <typename T = float> static T softplus(T x) {
    /** TODO: Change beta to be a property */
    return static_cast<T>(log(1 + exp(beta * x)) / beta);
  }

  /**
   * @brief derivative softplus function
   * @tparam T type of an input/output
   * @param x input
   * @return T type output
   */
  template <typename T = float> static T softplusPrime(T x) {
    return sigmoid<T>(static_cast<T>(beta * x));
  }

  /**
   * @brief swish activation function
   * @param[in] t_in input tensor
   * @param[out] t_out output tensor
   */
  template <typename T = float>
  static Tensor &swish(Tensor const &t_in, Tensor &t_out) {
    t_in.apply<T>([&](T x) { return sigmoid<T>(x); }, t_out);
    t_out.multiply_i(t_in);

    return t_out;
  }

  /**
   * @brief derivative swish function
   * @param[in] t_in input tensor
   * @param[in] t_out output tensor
   * @param[out] outgoing_derivative outgoing derivative
   * @param[in] incoming_derivative incoming derivative
   */
  template <typename T = float>
  static Tensor &swishPrime(Tensor const &t_in, Tensor const &t_out,
                            Tensor &outgoing_derivative,
                            Tensor const &incoming_derivative = Tensor()) {
    if (outgoing_derivative.empty())
      outgoing_derivative = Tensor(t_out.getDim());

    Tensor tmp = Tensor(t_out.getDim());
    t_in.apply<T>([&](T x) { return sigmoid(x); }, outgoing_derivative);
    t_out.apply<T>([&](T x) { return 1 - x; }, tmp);
    outgoing_derivative.multiply_i(tmp);
    outgoing_derivative.add_i(t_out);

    outgoing_derivative.multiply_i_strided(incoming_derivative);

    return outgoing_derivative;
  }
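
  /*
   * Derivation used above (with s(x) = x * sigmoid(x) already stored in t_out):
   *   d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
   *                         = sigmoid(x) * (1 - s(x)) + s(x),
   * which is computed as sigmoid(t_in) * (1 - t_out) + t_out before the
   * element-wise multiplication by the incoming derivative.
   */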

  /**
   * @brief gelu activation function
   * @param[in] t_in input tensor
   * @param[out] t_out output tensor
   */
  template <typename T = float>
  static Tensor &gelu(Tensor const &t_in, Tensor &t_out) {
    double tmp = 1.0 / sqrt(2.0);
    t_in.apply<T>(
      [&](T x) { return static_cast<T>(0.5 * x * (1 + erf(x * tmp))); }, t_out);
    return t_out;
  }

  /**
   * @brief derivative gelu function
   * @param[in] t_in input tensor
   * @param[in] t_out output tensor
   * @param[out] outgoing_derivative outgoing derivative
   * @param[in] incoming_derivative incoming derivative
   */
  template <typename T = float>
  static Tensor &geluPrime(Tensor const &t_in, Tensor const &t_out,
                           Tensor &outgoing_derivative,
                           Tensor const &incoming_derivative = Tensor()) {

    if (outgoing_derivative.empty())
      outgoing_derivative = Tensor(t_out.getDim());

    T tmp = static_cast<T>(1 / sqrt(2));
    t_in.apply<T>(
      [&](T x) {
        return static_cast<T>(
          0.5 * (1 + erf(x * tmp) +
                 x * ((2 / sqrt(M_PI)) * exp(-pow(x * tmp, 2))) * tmp));
      },
      outgoing_derivative);

    outgoing_derivative.multiply_i_strided(incoming_derivative);

    return outgoing_derivative;
  }
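
  /*
   * The lambda above implements the exact GELU derivative
   *   GELU'(x) = Phi(x) + x * phi(x),
   * where Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF
   * and phi(x) = exp(-x^2 / 2) / sqrt(2 * pi) is its PDF; note that
   * 0.5 * (2 / sqrt(pi)) * (1 / sqrt(2)) = 1 / sqrt(2 * pi).
   */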

  /**
   * @brief tanh-based gelu approximate function
   * @param[in] t_in input tensor
   * @param[out] t_out output tensor
   */
  template <typename T = float>
  static Tensor &tanhGelu(Tensor const &t_in, Tensor &t_out) {
    t_in.apply<T>(
      [&](T x) {
        return static_cast<T>(
          0.5 * x *
          (1 + tanhFloat<T>(
                 static_cast<T>(sqrt(2 / M_PI) * (x + 0.044715 * pow(x, 3))))));
      },
      t_out);
    return t_out;
  }

  /**
   * @brief derivative of tanh-based gelu approximate function
   * @param[in] t_in input tensor
   * @param[in] t_out output tensor
   * @param[out] outgoing_derivative outgoing derivative
   * @param[in] incoming_derivative incoming derivative
   */
  template <typename T = float>
  static Tensor &tanhGeluPrime(Tensor const &t_in, Tensor const &t_out,
                               Tensor &outgoing_derivative,
                               Tensor const &incoming_derivative = Tensor()) {
    // NYI
    ml_logw("tanhGeluPrime, which calculates the derivative of the tanhGelu "
            "function, is not yet implemented");
    return outgoing_derivative;
  }

  /**
   * @brief sigmoid-based gelu approximate function (quick gelu)
   * @param[in] t_in input tensor
   * @param[out] t_out output tensor
   */
  template <typename T = float>
  static Tensor &sigmoidGelu(Tensor const &t_in, Tensor &t_out) {
    t_in.apply<T>(
      [&](T x) {
        return static_cast<T>(x * (sigmoid<T>(static_cast<T>(1.702 * x))));
      },
      t_out);
    return t_out;
  }

  /**
   * @brief derivative of sigmoid-based gelu approximate function
   * @param[in] t_in input tensor
   * @param[in] t_out output tensor
   * @param[out] outgoing_derivative outgoing derivative
   * @param[in] incoming_derivative incoming derivative
   */
  template <typename T = float>
  static Tensor &
  sigmoidGeluPrime(Tensor const &t_in, Tensor const &t_out,
                   Tensor &outgoing_derivative,
                   Tensor const &incoming_derivative = Tensor()) {
    // NYI
    ml_logw("sigmoidGeluPrime, which calculates the derivative of the "
            "sigmoidGelu function, is not yet implemented");
    return outgoing_derivative;
  }

  /**
   * @brief elu function
   * @note an alpha parameter is needed for elu, but supporting a property on
   * this class would need extensive refactoring. For now, 1.0 is used for
   * alpha.
   *
   * @tparam T type of an input/output
   * @param x input
   * @return T type output
   */
  template <typename T = float> static T elu(T x) {
    return x >= static_cast<T>(0.0) ? x : static_cast<T>(alpha * (exp(x) - 1));
  }

  /**
   * @brief elu prime function
   * @note an alpha parameter is needed for elu, but supporting a property on
   * this class would need extensive refactoring. For now, 1.0 is used for
   * alpha.
   *
   * @tparam T type of an input/output
   * @param x input
   * @return T type output
   */
  template <typename T = float> static T eluPrime(T x) {
    return x >= static_cast<T>(0.0) ? static_cast<T>(1.0)
                                    : static_cast<T>(alpha * exp(x));
  }

  /**
   * @brief selu function
   * @tparam T type of an input/output
   * @param x input
   * @return T type output
   */
  template <typename T = float> static T selu(T x) {
    return x > static_cast<T>(0.0)
             ? static_cast<T>(selu_scale * x)
             : static_cast<T>(selu_scale * selu_alpha * (exp(x) - 1));
  }

  /**
   * @brief selu prime function
   * @tparam T type of an input/output
   * @param x input
   * @return T type output
   */
  template <typename T = float> static T seluPrime(T x) {
    return x > static_cast<T>(0.0)
             ? static_cast<T>(selu_scale)
             : static_cast<T>(selu_scale * selu_alpha * exp(x));
  }

  /**
   * @brief mish activation function
   * @param[in] x input
   */
  template <typename T = float> static T mish(T x) {
    return static_cast<T>(x * tanhFloat<T>(softplus<T>(x)));
  }

  /**
   * @brief mish prime function
   * @param[in] x input
   */
  template <typename T = float> static T mishPrime(T x) {
    return static_cast<T>(tanhFloat<T>(softplus<T>(x)) +
                          x * softplusPrime<T>(x) *
                            tanhPrime<T>(tanhFloat<T>(softplus<T>(x))));
  }
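
  /*
   * Chain rule used by mishPrime above, with sp(x) = softplus(x) and
   * t = tanh(sp(x)):
   *   d/dx [x * tanh(sp(x))] = tanh(sp(x)) + x * sp'(x) * (1 - t^2),
   * where sp'(x) = sigmoid(beta * x) (see softplusPrime) and (1 - t^2) is
   * obtained by passing the tanh output through tanhPrime.
   */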

  /**
   * @brief setActivation by custom activation function
   * @note the incoming derivative is not applied by the wrapper here, since
   * this activation_prime_fn consumes the incoming derivative itself
   * @param[in] std::function<Tensor(Tensor const &, Tensor &)> activation_fn
   * activation function to be used
   * @param[in] std::function<Tensor(Tensor const &, Tensor &)>
   * activation_prime_fn activation_prime_function to be used
   * @retval #ML_ERROR_NONE when successful
   */
  template <typename funcParam = Tensor>
  int setActivation(
    std::function<funcParam &(funcParam const &, funcParam &)> const
      &activation_fn,
    std::function<funcParam &(funcParam &, funcParam &,
                              funcParam const &)> const &activation_prime_fn) {
    _act_fn = activation_fn;
    _act_prime_fn = [activation_prime_fn](
                      funcParam const &t_in, funcParam &t_out,
                      funcParam &outgoing_derivative,
                      funcParam const &incoming_derivative) -> funcParam & {
      return activation_prime_fn(t_out, outgoing_derivative,
                                 incoming_derivative);
    };

    return ML_ERROR_NONE;
  }

  /**
   * @brief setActivation by custom activation function
   * @note derivative not applied here as this activation_prime_fn applies
   * derivative itself
   * @param[in] std::function<Tensor(Tensor const &, Tensor &)> activation_fn
   * activation function to be used
   * @param[in] std::function<Tensor(Tensor const &, Tensor &, Tensor const &)>
   * activation_prime_fn activation_prime_function to be used
   * @retval #ML_ERROR_NONE when successful
   */
  template <typename funcParam = Tensor>
  int setActivation(
    std::function<funcParam &(funcParam const &, funcParam &)> const
      &activation_fn,
    std::function<funcParam &(funcParam const &, funcParam const &, funcParam &,
                              funcParam const &)> const &activation_prime_fn) {
    if (is_inplace)
      return ML_ERROR_INVALID_PARAMETER;

    _act_fn = activation_fn;
    _act_prime_fn = activation_prime_fn;

    return ML_ERROR_NONE;
  }

  /**
   * @brief setActivation by custom activation function
   * @note the incoming derivative is applied by the wrapper here, since this
   * activation_prime_fn does not take the incoming derivative
   * @param[in] activation_fn activation function to be used
   * @param[in] activation_prime_fn activation prime function to be used
   * @retval #ML_ERROR_NONE when successful
   */
  template <typename funcParam = Tensor>
  int setActivation(
    std::function<funcParam &(funcParam const &, funcParam &)> const
      &activation_fn,
    std::function<funcParam &(funcParam &, funcParam &)> const
      &activation_prime_fn) {
    if (!is_inplace) {
      _act_prime_fn = [activation_prime_fn](
                        funcParam const &t_in, funcParam &t_out,
                        funcParam &outgoing_derivative,
                        funcParam const &incoming_derivative) -> funcParam & {
        /** @todo update this based on supportInPlace */
        activation_prime_fn(t_out, outgoing_derivative);
        outgoing_derivative.multiply_i_strided(incoming_derivative);

        return outgoing_derivative;
      };
    } else {
      _act_prime_fn = [activation_prime_fn](
                        funcParam const &t_in, funcParam &t_out,
                        funcParam &outgoing_derivative,
                        funcParam const &incoming_derivative) -> funcParam & {
        activation_prime_fn(t_out, t_out);
        incoming_derivative.multiply_strided(t_out, outgoing_derivative);

        return outgoing_derivative;
      };
    }

    return ML_ERROR_NONE;
  }
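
  /*
   * In-place contract sketched by the two lambdas above: when is_inplace is
   * true, t_out no longer needs to hold the forward result during backward,
   * so the derivative is written over t_out and then multiplied into
   * outgoing_derivative; when is_inplace is false, t_out is preserved and the
   * derivative-times-incoming product is accumulated directly in
   * outgoing_derivative.
   */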

  /**
   * @brief setActivation by custom activation function
   * @note the incoming derivative is applied by the wrapper, since this
   * activation_prime_fn does not take the incoming derivative
   * @param[in] std::function<float(float const &)> activation_fn activation
   * function to be used
   * @param[in] std::function<float(float const &)> activation_prime_fn
   * activation_prime_function to be used
   * @retval #ML_ERROR_NONE when successful
   */
  template <typename funcParam = float>
  int setActivation(
    std::function<funcParam(funcParam const)> const &activation_fn,
    std::function<funcParam(funcParam const)> const &activation_prime_fn) {
    _act_fn = [activation_fn](Tensor const &x, Tensor &hidden) -> Tensor & {
      return x.apply(activation_fn, hidden);
    };
    if (!is_inplace) {
      _act_prime_fn =
        [activation_prime_fn](Tensor const &t_in, Tensor &t_out,
                              Tensor &outgoing_derivative,
                              Tensor const &incoming_derivative) -> Tensor & {
        /** @todo update this based on supportInPlace */
        t_out.apply(activation_prime_fn, outgoing_derivative);
        outgoing_derivative.multiply_i_strided(incoming_derivative);

        return outgoing_derivative;
      };
    } else {
      _act_prime_fn =
        [activation_prime_fn](Tensor const &t_in, Tensor &t_out,
                              Tensor &outgoing_derivative,
                              Tensor const &incoming_derivative) -> Tensor & {
        t_out.apply(activation_prime_fn, t_out);
        incoming_derivative.multiply_strided(t_out, outgoing_derivative);

        return outgoing_derivative;
      };
    }

    return ML_ERROR_NONE;
  }

  /**
   * @brief setActivation by custom activation function
   * @note apply derivative as this activation_prime_fn does not utilize
   * derivative
   * @param[in] std::function<float(float const)> activation_fn activation
   * function to be used
   * @param[in] std::function<float(float const, float const)>
   * activation_prime_fn activation_prime_function to be used
   * @retval #ML_ERROR_NONE when successful
   */
  int setActivation(
    std::function<float(float const)> const &activation_fn,
    std::function<float(float const, float const)> const &activation_prime_fn);

  /**
   * @brief Notify that this layer will execute in-place
   *
   * @param val True if execute in-place, else false
   */
  void setInPlace(bool val) {
    if (val && !supportInPlace())
      throw std::runtime_error(
        "Error setting activation layer to work in-place");

    is_inplace = val;
  }
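
  /*
   * End-to-end usage sketch (illustrative only; the tensors `in`, `out`, `dx`
   * and `dy` are assumptions, the ActiFunc calls mirror the API above):
   *
   * @code
   *   ActiFunc act(ActivationType::ACT_RELU);
   *   act.setInPlace(act.supportInPlace());
   *
   *   act.run_fn(in, out);            // forward: out = relu(in)
   *   act.run_prime_fn(out, dx, dy);  // backward: dx = relu'(.) * dy
   * @endcode
   */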

private:
  constexpr static inline float alpha = 1.0f; /**< alpha for elu */
  constexpr static inline float beta = 1.0f;  /**< beta for Softplus */
  constexpr static inline float selu_alpha = 1.67326324f; /**< alpha for selu */
  constexpr static inline float selu_scale = 1.05070098f; /**< scale for selu */

  std::function<Tensor &(Tensor const &, Tensor &)> _act_fn;
  std::function<Tensor &(Tensor const &, Tensor &, Tensor &, Tensor const &)>
    _act_prime_fn; /**< prime function with input and output */

  ActivationType
    activation_type; /**< type of the activation represented by this */
  bool is_inplace;   /**< if this class should operate in-place */
};

} // namespace nntrainer

#endif /* __cplusplus */
#endif /* __ACTI_FUNC_H__ */