LCOV - code coverage report
Current view: top level - nntrainer/layers - acti_func.h (source / functions)
Test:         coverage_filtered.info
Test Date:    2025-12-14 20:38:17

              Coverage      Hit    Total
Lines:        82.1 %        184    224
Functions:    82.9 %        34     41

            Line data    Source code
       1              : // SPDX-License-Identifier: Apache-2.0
       2              : /**
       3              :  * Copyright (C) 2020 Jihoon Lee <jhoon.it.lee@samsung.com>
       4              :  *
       5              :  * @file   acti_func.h
       6              :  * @date   22 March 2021
       7              :  * @see    https://github.com/nnstreamer/nntrainer
       8              :  * @author Jihoon Lee <jhoon.it.lee@samsung.com>
       9              :  * @author Jijoong Moon <jijoong.moon@samsung.com>
      10              :  * @bug    No known bugs except for NYI items
      11              :  * @brief  Activation function class for neural networks
      12              :  *
      13              :  */
      14              : 
      15              : #ifndef __ACTI_FUNC_H__
      16              : #define __ACTI_FUNC_H__
      17              : #ifdef __cplusplus
      18              : 
      19              : #include <common_properties.h>
      20              : #include <cpu_backend.h>
      21              : 
      22              : #if defined(_WIN32)
      23              : #define _USE_MATH_DEFINES
      24              : #include <math.h>
      25              : #endif
      26              : 
      27              : namespace nntrainer {
      28              : 
      29              : class Tensor;
      30              : 
      31              : /**
      32              :  * @class   ActiFunc Class
      33              :  * @brief   ActiFunc Class
      34              :  */
      35              : class ActiFunc {
      36              : 
      37              : public:
      38              :   constexpr static inline float NEGATIVE_SLOPE = 0.01f;
      39              : 
      40              :   /**
      41              :    * @brief     Constructor of ActiFunc
      42              :    */
      43              :   template <typename T = float>
      44         2213 :   ActiFunc(ActivationType at = ActivationType::ACT_NONE,
      45              :            bool is_inplace_ = true) :
      46         2213 :     is_inplace(is_inplace_) {
      47         2213 :     setActiFunc<T>(at);
      48         2213 :   }
      49              : 
      50              :   /**
      51              :    * @brief     Destructor of ActiFunc
      52              :    */
      53         2213 :   ~ActiFunc() {}
      54              : 
      55              :   /**
      56              :    * @brief setActivation by preset ActivationType
      57              :    *
      58              :    * @param[in] ActivationType
      59              :    */
      60         9066 :   template <typename T = float> void setActiFunc(ActivationType acti_type) {
      61         9066 :     activation_type = acti_type;
      62              : 
      63         9066 :     switch (acti_type) {
      64              :     case ActivationType::ACT_TANH:
      65         1016 :       this->setActivation<T>(tanhFloat<T>, tanhPrime<T>);
      66          508 :       break;
      67              :     case ActivationType::ACT_SIGMOID:
      68         5072 :       this->setActivation<T>(sigmoid<T>, sigmoidPrime<T>);
      69         2536 :       break;
      70              :     case ActivationType::ACT_SOFTMAX:
      71         3852 :       this->setActivation<Tensor>(softmax<T>, softmaxPrime<T>);
      72         1926 :       break;
      73              :     case ActivationType::ACT_RELU:
      74         2574 :       this->setActivation<T>(relu<T>, reluPrime<T>);
      75         1287 :       break;
      76              :     case ActivationType::ACT_LEAKY_RELU:
      77           40 :       this->setActivation<T>(leakyRelu<T>, leakyReluPrime<T>);
      78           20 :       break;
      79            8 :     case ActivationType::ACT_SWISH:
      80            8 :       is_inplace = false;
      81           16 :       this->setActivation<Tensor>(swish<T>, swishPrime<T>);
      82            8 :       break;
      83            8 :     case ActivationType::ACT_GELU:
      84            8 :       is_inplace = false;
      85           16 :       this->setActivation<Tensor>(gelu<T>, geluPrime<T>);
      86            8 :       break;
      87            0 :     case ActivationType::ACT_TANH_GELU:
      88            0 :       is_inplace = false;
      89            0 :       this->setActivation<Tensor>(tanhGelu<T>, tanhGeluPrime<T>);
      90            0 :       break;
      91            0 :     case ActivationType::ACT_SIGMOID_GELU:
      92            0 :       is_inplace = false;
      93            0 :       this->setActivation<Tensor>(sigmoidGelu<T>, sigmoidGeluPrime<T>);
      94            0 :       break;
      95              :     case ActivationType::ACT_ELU:
      96            0 :       this->setActivation<T>(elu<T>, eluPrime<T>);
      97            0 :       break;
      98              :     case ActivationType::ACT_SELU:
      99            0 :       this->setActivation<T>(selu<T>, seluPrime<T>);
     100            0 :       break;
     101              :     case ActivationType::ACT_SOFTPLUS:
     102            0 :       this->setActivation<T>(softplus<T>, softplusPrime<T>);
     103            0 :       break;
     104              :     case ActivationType::ACT_MISH:
     105            0 :       this->setActivation<T>(mish<T>, mishPrime<T>);
     106            0 :       break;
     107              :     case ActivationType::ACT_NONE:
     108         5546 :       this->setActivation<T>(no_op<T>, no_op_prime<T>);
     109         2773 :       break;
     110            0 :     case ActivationType::ACT_UNKNOWN:
     111              :     default:
     112            0 :       throw std::runtime_error("Error: Not Supported Activation Type");
     113              :     }
     114         9066 :   }
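
The switch above binds each preset ActivationType to a forward function and a matching derivative, stored as std::function members. A minimal standalone sketch of the same dispatch pattern, assuming only the C++ standard library (the enum and helper below are illustrative, not nntrainer's API):

#include <cmath>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>

enum class Act { NONE, SIGMOID, RELU };

// Return a (forward, derivative) pair for the requested activation.
// Like ActiFunc, the derivative takes the already-activated value.
std::pair<std::function<float(float)>, std::function<float(float)>>
pickActivation(Act a) {
  switch (a) {
  case Act::SIGMOID:
    return {[](float x) { return 1.0f / (1.0f + std::exp(-x)); },
            [](float y) { return y * (1.0f - y); }};
  case Act::RELU:
    return {[](float x) { return x > 0.0f ? x : 0.0f; },
            [](float y) { return y > 0.0f ? 1.0f : 0.0f; }};
  case Act::NONE:
    return {[](float x) { return x; }, [](float) { return 1.0f; }};
  }
  throw std::runtime_error("Error: Not Supported Activation Type");
}

int main() {
  auto [fn, prime] = pickActivation(Act::SIGMOID);
  float y = fn(0.5f);
  std::cout << y << ' ' << prime(y) << '\n'; // ~0.6225 0.2350
}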
     115              : 
     116              :   /**
     117              :    * @brief run function
     118              :    *
     119              :    * @param[in] input : input
     120              :    * @param[out] output : output
     121              :    */
     122         8455 :   void run_fn(Tensor const &input, Tensor &output) { _act_fn(input, output); }
     123              : 
     124              :   /**
     125              :    * @brief run prime function
     126              :    *
     127              :    * @param[in] input input
     128              :    * @param[in] output output
     129              :    * @param[out] outgoing_derivative outgoing derivative
     130              :    * @param[in] incoming_derivative incoming derivative
     131              :    * @retval    Tensor
     132              :    */
     133              :   Tensor &run_prime_fn(Tensor &input, Tensor &output,
     134              :                        Tensor &outgoing_derivative,
     135              :                        Tensor const &incoming_derivative) {
     136              :     return _act_prime_fn(input, output, outgoing_derivative,
     137          693 :                          incoming_derivative);
     138              :   }
     139              : 
     140              :   /**
     141              :    * @brief run prime function
     142              :    *
     143              :    * @param[in] output output
     144              :    * @param[out] outgoing_derivative outgoing derivative
     145              :    * @param[in] incoming_derivative incoming derivative
     146              :    * @retval    Tensor
     147              :    */
     148         4421 :   Tensor &run_prime_fn(Tensor &output, Tensor &outgoing_derivative,
     149              :                        Tensor const &incoming_derivative) {
     150         8842 :     return _act_prime_fn(Tensor(), output, outgoing_derivative,
     151         4421 :                          incoming_derivative);
     152              :   }
     153              : 
     154              :   /**
     155              :    * @copydoc Layer::supportInPlace()
     156              :    */
     157         1176 :   bool supportInPlace() const { return is_inplace; }
     158              : 
     159              :   /**
     160              :    * @brief       Calculate softmax for Tensor Type
     161              :    * @param[in] input input Tensor
     162              :    * @param[out] output output Tensor
     163              :    * @retval      Tensor
     164              :    */
     165              :   template <typename T = float>
     166         1252 :   static Tensor &softmax(Tensor const &input, Tensor &output) {
     167              :     /**
     168              :      * shiftx_logit = logit - max_batch(logit)
     169              :      * softmax = exp(shiftx_logit) / (sum(exp(shiftx_logit)))
     170              :      *
     171              :      * @note softmax is applied on the last dimension
     172              :      */
     173              :     /** TODO: support strided operations */
     174         2158 :     if (input.size() == output.size() &&
     175          906 :         input.getStrides() != output.getStrides())
     176            1 :       throw std::invalid_argument(
     177              :         "Softmax does not support operating on strided tensors");
     178              : 
     179         1251 :     unsigned int width = input.width();
     180         1251 :     unsigned int bch_size = input.getDim().getDataLen() / width;
     181              : 
     182              :     // copy will not be executed in the in-place case
     183         1251 :     output.copy(input);
     184              : 
     185              :     T *output_data = output.getData<T>();
     186              : 
     187              :     // prevent overflow
     188         1251 :     Tensor tmp(width, input.getTensorType());
     189        13807 :     for (unsigned int i = 0; i < bch_size; i++) {
     190        12556 :       T *ptr = output_data + i * width;
     191              : 
     192              :       // find max value and subtract it
     193        12556 :       T max_value = *std::max_element(ptr, ptr + width);
     194              : 
     195        12556 :       tmp.setValue(max_value);
     196        12556 :       saxpy(width, -1, tmp.getData<T>(), 1, ptr, 1);
     197              :     }
     198              : 
     199              :     // take exp
     200         1251 :     output.apply<T>(exp_util<T>, output);
     201              : 
     202              :     // take sum over the last dimension
     203         1251 :     Tensor sum = output.sum(3);
     204              : 
     205        13807 :     for (unsigned int i = 0; i < bch_size; i++) {
     206        12556 :       T *ptr = output_data + i * width;
     207        12556 :       std::transform(ptr, ptr + width, ptr,
     208              :                      std::bind(std::divides<T>(), std::placeholders::_1,
     209              :                                sum.getValue<T>(i)));
     210              :     }
     211              : 
     212         1251 :     return output;
     213         1251 :   }
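
A minimal standalone sketch of the shift-by-max softmax computed above, operating row-wise on a flat buffer and assuming only the standard library (nntrainer's Tensor is not used here):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Row-wise softmax over a row-major buffer with `width` columns: each row is
// shifted by its max for numerical stability, exponentiated, then normalized.
void softmaxRows(std::vector<float> &data, std::size_t width) {
  for (std::size_t off = 0; off < data.size(); off += width) {
    float *row = data.data() + off;
    float max_value = *std::max_element(row, row + width);
    float sum = 0.0f;
    for (std::size_t w = 0; w < width; ++w) {
      row[w] = std::exp(row[w] - max_value); // shift, then exp
      sum += row[w];
    }
    for (std::size_t w = 0; w < width; ++w)
      row[w] /= sum; // each row now sums to 1
  }
}

int main() {
  std::vector<float> x = {1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f};
  softmaxRows(x, 3); // two rows of width 3
  for (float v : x)
    std::cout << v << ' '; // ~0.0900 0.2447 0.6652 0.3333 0.3333 0.3333
  std::cout << '\n';
}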
     214              : 
     215              :   /**
     216              :    * @brief     Calculate derivative of softmax function
     217              :    * @param[in] output output tensor
     218              :    * @param[out] outgoing_derivative result of calculated derivative of softmax
     219              :    * @param[in] incoming_derivative incoming derivative tensor from next layer
     220              :    * @retval    Tensor
     221              :    */
     222              : 
     223              :   template <typename T = float>
     224          237 :   static Tensor &softmaxPrime(Tensor const &output, Tensor &outgoing_derivative,
     225              :                               Tensor const &incoming_derivative = Tensor()) {
     226              :     /** TODO: support strided operations */
     227              : 
     228          473 :     if ((output.size() == outgoing_derivative.size() &&
     229          473 :          output.getStrides() != outgoing_derivative.getStrides()) ||
     230          471 :         (output.size() == incoming_derivative.size() &&
     231          235 :          output.getStrides() != incoming_derivative.getStrides()))
     232            1 :       throw std::invalid_argument(
     233              :         "SoftmaxPrime does not support operating on strided tensors");
     234              : 
     235          236 :     unsigned int batch = output.batch();
     236          236 :     unsigned int channel = output.channel();
     237          236 :     unsigned int height = output.height();
     238          236 :     unsigned int width = output.width();
     239              : 
     240          236 :     if (outgoing_derivative.empty())
     241            2 :       outgoing_derivative = Tensor(output.getDim());
     242              : 
     243              :     const T *output_data = output.getData<T>();
     244              :     const T *incoming_derivative_data = incoming_derivative.getData<T>();
     245              :     T *outgoing_derivative_data = outgoing_derivative.getData<T>();
     246              : 
     247          472 :     Tensor tmp = Tensor(width, output.getTensorType());
     248              :     T *tmp_data = tmp.getData<T>();
     249          236 :     unsigned int output_width_stride = output.getStrides()[3];
     250         1890 :     for (unsigned int b = 0; b < batch; ++b) {
     251         1654 :       int b_offset = b * channel * height * width;
     252         3308 :       for (unsigned int c = 0; c < channel; ++c) {
     253         1654 :         int bc_offset = b_offset + c * height * width;
     254         4960 :         for (unsigned int h = 0; h < height; ++h) {
     255         3306 :           int bch_offset = bc_offset + h * width;
     256        24065 :           for (unsigned int w1 = 0; w1 < width; ++w1) {
     257              :             T sum = 0;
     258       177634 :             for (unsigned int w2 = 0; w2 < width; ++w2) {
     259              :               T val;
     260       156875 :               if (w1 == w2) {
     261        20759 :                 val = output_data[bch_offset + w2] *
     262        20759 :                       ((T)1 - output_data[bch_offset + w1]);
     263              :               } else {
     264       136116 :                 val =
     265       136116 :                   -output_data[bch_offset + w2] * output_data[bch_offset + w1];
     266              :               }
     267       156875 :               if (!incoming_derivative.empty())
     268       156575 :                 val *= incoming_derivative_data[bch_offset + w2];
     269       156875 :               sum += val;
     270              :             }
     271        20759 :             tmp.setValue(0, 0, 0, w1, sum);
     272              :           }
     273         3306 :           scopy(width, tmp_data, 1, outgoing_derivative_data + bch_offset,
     274              :                 output_width_stride);
     275              :         }
     276              :       }
     277              :     }
     278              : 
     279          236 :     return outgoing_derivative;
     280          236 :   }
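
The nested w1/w2 loops above form, for each row, the product of the softmax Jacobian with the incoming derivative. A standalone sketch for a single row, under the same convention that y is already a softmax output (illustrative, standard library only):

#include <cstddef>
#include <iostream>
#include <vector>

// Jacobian-vector product for one softmax row y (already activated):
// d_in[w1] = sum_w2 [ y[w2] * (delta(w1, w2) - y[w1]) ] * d_out[w2]
std::vector<float> softmaxPrimeRow(const std::vector<float> &y,
                                   const std::vector<float> &d_out) {
  std::vector<float> d_in(y.size(), 0.0f);
  for (std::size_t w1 = 0; w1 < y.size(); ++w1) {
    float sum = 0.0f;
    for (std::size_t w2 = 0; w2 < y.size(); ++w2) {
      float val = (w1 == w2) ? y[w2] * (1.0f - y[w1]) : -y[w2] * y[w1];
      sum += val * d_out[w2];
    }
    d_in[w1] = sum;
  }
  return d_in;
}

int main() {
  std::vector<float> y = {0.0900f, 0.2447f, 0.6652f}; // softmax output
  std::vector<float> d_out = {0.0f, 0.0f, 1.0f};      // incoming derivative
  for (float v : softmaxPrimeRow(y, d_out))
    std::cout << v << ' '; // ~ -0.0599 -0.1628 0.2227
  std::cout << '\n';
}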
     281              : 
     282              :   /**
     283              :    * @brief     sigmoid activation function
     284              :    * @param[in] x input
     285              :    */
     286     28336070 :   template <typename T = float> static T sigmoid(T x) {
     287     28351486 :     return static_cast<T>(1.0 / (1.0 + exp_util<T>(-x)));
     288              :   }
     289              : 
     290              :   /**
     291              :    * @brief     derivative sigmoid function
     292              :    * @param[in] x input
     293              :    */
     294      2044733 :   template <typename T = float> static T sigmoidPrime(T x) {
     295      2044733 :     return static_cast<T>(x * (static_cast<T>(1.0) - x));
     296              :   }
     297              : 
     298              :   /**
     299              :    * @brief     tanh function for float type
     300              :    * @param[in] x input
     301              :    */
     302        15236 :   template <typename T = float> static T tanhFloat(T x) {
     303        15296 :     return static_cast<T>(2.0 * sigmoid<T>(static_cast<T>(2.0) * x) - 1.0);
     304              :   }
     305              : 
     306              :   /**
     307              :    * @brief     derivative tanh function
     308              :    * @param[in] x input
     309              :    */
     310         5753 :   template <typename T = float> static T tanhPrime(T x) {
     311         5783 :     return static_cast<T>(1.0 - x * x);
     312              :   }
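
Note that sigmoidPrime and tanhPrime expect the already-activated output rather than the pre-activation input. A small standalone check of both identities against central finite differences (standard library only):

#include <cmath>
#include <iostream>

static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main() {
  const float x = 0.7f, eps = 1e-3f;

  // sigmoidPrime(s) == s * (1 - s), where s = sigmoid(x)
  float s = sigmoid(x);
  float sig_analytic = s * (1.0f - s);
  float sig_numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2.0f * eps);

  // tanhPrime(t) == 1 - t * t, where t = tanh(x)
  float t = std::tanh(x);
  float tanh_analytic = 1.0f - t * t;
  float tanh_numeric = (std::tanh(x + eps) - std::tanh(x - eps)) / (2.0f * eps);

  std::cout << sig_analytic << " ~ " << sig_numeric << '\n';
  std::cout << tanh_analytic << " ~ " << tanh_numeric << '\n';
}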
     313              : 
     314              :   /**
     315              :    * @brief     relu activation function
     316              :    * @param[in] x input
     317              :    */
     318       943944 :   template <typename T = float> static T relu(T x) {
     319       943944 :     if (x <= 0)
     320       547262 :       return 0;
     321              :     return x;
     322              :   }
     323              : 
     324              :   /**
     325              :    * @brief     derivative relu function
     326              :    * @param[in] x input
     327              :    */
     328       428597 :   template <typename T = float> static T reluPrime(T x) {
     329       428597 :     if (x <= 0)
     330       265759 :       return 0;
     331              :     return 1;
     332              :   }
     333              : 
     334              :   /**
     335              :    * @brief     no_op function
     336              :    * @param[in] x input
     337              :    */
     338            0 :   template <typename T = float> static T no_op(T x) { return x; }
     339              : 
     340              :   /**
     341              :    * @brief     no_op prime function (derivative of no_op)
     342              :    * @param[in] x input
     343              :    */
     344            0 :   template <typename T = float> static T no_op_prime(T x) { return 1; }
     345              : 
     346              :   /**
     347              :    * @brief leaky relu function
     348              :    * @note slope parameter is needed for leaky relu, but supporting property on
     349              :    * this class will need extensive refactoring. For now 0.01 is used for
     350              :    * negative slope.
     351              :    *
     352              :    * @param x input
     353              :    * @return float output
     354              :    */
     355          120 :   template <typename T = float> static T leakyRelu(T x) {
     356          120 :     return x >= static_cast<T>(0.0) ? x : static_cast<T>(NEGATIVE_SLOPE) * x;
     357              :   }
     358              : 
     359              :   /**
     360              :    * @brief leaky relu prime function
     361              :    * @note  a slope parameter is needed for leaky relu, but supporting a
     362              :    * property on this class would require extensive refactoring. For now,
     363              :    * 0.01 is used as the negative slope.
     364              :    *
     365              :    * @param x input
     366              :    * @return T output
     367              :    */
     368           30 :   template <typename T = float> static T leakyReluPrime(T x) {
     369           30 :     return x >= static_cast<T>(0.0) ? static_cast<T>(1.0)
     370           30 :                                     : static_cast<T>(NEGATIVE_SLOPE);
     371              :   }
     372              : 
     373              :   /**
     374              :    * @brief     Softplus activation function
     375              :    * @tparam T type of an input/output
     376              :    * @param x input
     377              :    * @return T type output
     378              :    */
     379           30 :   template <typename T = float> static T softplus(T x) {
     380              :     /** TODO: Change beta to be a property */
     381           90 :     return static_cast<T>(log(1 + exp(beta * x)) / beta);
     382              :   }
     383              : 
     384              :   /**
     385              :    * @brief     derivative softplus function
     386              :    * @tparam T type of an input/output
     387              :    * @param x input
     388              :    * @return T type output
     389              :    */
     390           30 :   template <typename T = float> static T softplusPrime(T x) {
     391           30 :     return sigmoid<T>(static_cast<T>(beta * x));
     392              :   }
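
With beta fixed at 1.0, softplus(x) = log(1 + exp(x)) and its derivative is sigmoid(x), which is what softplusPrime returns. A scalar sketch with a finite-difference check (standard library only):

#include <cmath>
#include <iostream>

constexpr float beta = 1.0f; // matches the fixed beta in ActiFunc

static float softplus(float x) {
  return std::log(1.0f + std::exp(beta * x)) / beta;
}
static float softplusPrime(float x) {
  return 1.0f / (1.0f + std::exp(-beta * x)); // sigmoid(beta * x)
}

int main() {
  const float x = 0.3f, eps = 1e-3f;
  float numeric = (softplus(x + eps) - softplus(x - eps)) / (2.0f * eps);
  std::cout << softplusPrime(x) << " ~ " << numeric << '\n'; // ~0.5744
}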
     393              : 
     394              :   /**
     395              :    * @brief     swish activation function
     396              :    * @param[in] t_in input tensor
     397              :    * @param[in] t_out output tensor
     398              :    */
     399              :   template <typename T = float>
     400            2 :   static Tensor &swish(Tensor const &t_in, Tensor &t_out) {
     401            2 :     t_in.apply<T>([&](T x) { return sigmoid<T>(x); }, t_out);
     402            2 :     t_out.multiply_i(t_in);
     403              : 
     404            2 :     return t_out;
     405              :   }
     406              : 
     407              :   /**
     408              :    * @brief     derivative swish function
     409              :    * @param[in] t_in input tensor
     410              :    * @param[in] t_out output tensor
     411              :    * @param[in] outgoing_derivative outgoing derivative
     412              :    * @param[in] incoming_derivative incoming derivative
     413              :    */
     414              :   template <typename T = float>
     415            1 :   static Tensor &swishPrime(Tensor const &t_in, Tensor const &t_out,
     416              :                             Tensor &outgoing_derivative,
     417              :                             Tensor const &incoming_derivative = Tensor()) {
     418            1 :     if (outgoing_derivative.empty())
     419            0 :       outgoing_derivative = Tensor(t_out.getDim());
     420              : 
     421            1 :     Tensor tmp = Tensor(t_out.getDim());
     422            2 :     t_in.apply<T>([&](T x) { return sigmoid(x); }, outgoing_derivative);
     423            1 :     t_out.apply<T>([&](T x) { return 1 - x; }, tmp);
     424            1 :     outgoing_derivative.multiply_i(tmp);
     425            1 :     outgoing_derivative.add_i(t_out);
     426              : 
     427            1 :     outgoing_derivative.multiply_i_strided(incoming_derivative);
     428              : 
     429            1 :     return outgoing_derivative;
     430            1 :   }
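
In scalar form, swish(x) = x * sigmoid(x) and its derivative is sigmoid(x) * (1 - swish(x)) + swish(x), which is exactly the expression the tensor code above assembles. A standalone check against finite differences (illustrative, standard library only):

#include <cmath>
#include <iostream>

static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
static float swish(float x) { return x * sigmoid(x); }

// Same formula as swishPrime above: sigmoid(x) * (1 - swish(x)) + swish(x)
static float swishPrime(float x) {
  return sigmoid(x) * (1.0f - swish(x)) + swish(x);
}

int main() {
  const float x = 1.2f, eps = 1e-3f;
  float numeric = (swish(x + eps) - swish(x - eps)) / (2.0f * eps);
  std::cout << swishPrime(x) << " ~ " << numeric << '\n'; // ~0.9820
}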
     431              : 
     432              :   /**
     433              :    * @brief     gelu activation function
     434              :    * @param[in] t_in input tensor
     435              :    * @param[in] t_out output tensor
     436              :    */
     437              :   template <typename T = float>
     438            2 :   static Tensor &gelu(Tensor const &t_in, Tensor &t_out) {
     439            2 :     double tmp = 1.0 / sqrt(2.0);
     440            2 :     t_in.apply<T>(
     441           60 :       [&](T x) { return static_cast<T>(0.5 * x * (1 + erf(x * tmp))); }, t_out);
     442            2 :     return t_out;
     443              :   }
     444              : 
     445              :   /**
     446              :    * @brief     derivative gelu function
     447              :    * @param[in] t_in input tensor
     448              :    * @param[in] t_out output tensor
     449              :    * @param[in] outgoing_derivative outgoing derivative
     450              :    * @param[in] incoming_derivative incoming derivative
     451              :    */
     452              :   template <typename T = float>
     453            1 :   static Tensor &geluPrime(Tensor const &t_in, Tensor const &t_out,
     454              :                            Tensor &outgoing_derivative,
     455              :                            Tensor const &incoming_derivative = Tensor()) {
     456              : 
     457            1 :     if (outgoing_derivative.empty())
     458            0 :       outgoing_derivative = Tensor(t_out.getDim());
     459              : 
     460            1 :     T tmp = static_cast<T>(1 / sqrt(2));
     461            1 :     t_in.apply<T>(
     462           30 :       [&](T x) {
     463              :         return static_cast<T>(
     464           30 :           0.5 * (1 + erf(x * tmp) +
     465           30 :                  x * ((2 / sqrt(M_PI)) * exp(-pow(x * tmp, 2))) * tmp));
     466              :       },
     467              :       outgoing_derivative);
     468              : 
     469            1 :     outgoing_derivative.multiply_i_strided(incoming_derivative);
     470              : 
     471            1 :     return outgoing_derivative;
     472              :   }
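
In scalar terms, gelu(x) = x * Phi(x) with Phi the standard normal CDF, and gelu'(x) = Phi(x) + x * phi(x); the lambda above builds the same expression via erf. A standalone finite-difference check (standard library only):

#include <cmath>
#include <iostream>

static float gelu(float x) {
  return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}

// Phi(x) + x * phi(x), with phi the standard normal density
static float geluPrime(float x) {
  const float kSqrt2Pi = 2.5066282746f; // sqrt(2 * pi)
  float Phi = 0.5f * (1.0f + std::erf(x / std::sqrt(2.0f)));
  float phi = std::exp(-0.5f * x * x) / kSqrt2Pi;
  return Phi + x * phi;
}

int main() {
  const float x = 0.8f, eps = 1e-3f;
  float numeric = (gelu(x + eps) - gelu(x - eps)) / (2.0f * eps);
  std::cout << geluPrime(x) << " ~ " << numeric << '\n'; // ~1.0199
}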
     473              : 
     474              :   /**
     475              :    * @brief     tanh-based gelu approximate function
     476              :    * @param[in] t_in input tensor
     477              :    * @param[in] t_out output tensor
     478              :    */
     479              :   template <typename T = float>
     480            0 :   static Tensor &tanhGelu(Tensor const &t_in, Tensor &t_out) {
     481            0 :     t_in.apply<T>(
     482            0 :       [&](T x) {
     483              :         return static_cast<T>(
     484            0 :           0.5 * x *
     485            0 :           (1 + tanhFloat<T>(
     486            0 :                  static_cast<T>(sqrt(2 / M_PI) * (x + 0.044715 * pow(x, 3))))));
     487              :       },
     488              :       t_out);
     489            0 :     return t_out;
     490              :   }
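
This is the usual tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A standalone comparison against the exact erf-based GELU (illustrative, standard library only):

#include <cmath>
#include <iostream>

static float geluExact(float x) {
  return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}

// Same approximation as tanhGelu above
static float geluTanh(float x) {
  const float kSqrt2OverPi = 0.7978845608f; // sqrt(2 / pi)
  return 0.5f * x *
         (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}

int main() {
  for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f})
    std::cout << x << ": " << geluExact(x) << " vs " << geluTanh(x) << '\n';
}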
     491              : 
     492              :   /**
     493              :    * @brief     derivative of tanh-based gelu approximate function
     494              :    * @param[in] t_in input tensor
     495              :    * @param[in] t_out output tensor
     496              :    * @param[in] outgoing_derivative outgoing derivative
     497              :    * @param[in] incoming_derivative incoming derivative
     498              :    */
     499              :   template <typename T = float>
     500            0 :   static Tensor &tanhGeluPrime(Tensor const &t_in, Tensor const &t_out,
     501              :                                Tensor &outgoing_derivative,
     502              :                                Tensor const &incoming_derivative = Tensor()) {
     503              :     // NYI
     504            0 :     ml_logw("tanhGeluPrime which is calculate derivate of tanhGelu function is "
     505              :             "not yet implemented");
     506            0 :     return outgoing_derivative;
     507              :   }
     508              : 
     509              :   /**
     510              :    * @brief     sigmoid-based gelu approximate function (quick gelu)
     511              :    * @param[in] t_in input tensor
     512              :    * @param[in] t_out output tensor
     513              :    */
     514              :   template <typename T = float>
     515            0 :   static Tensor &sigmoidGelu(Tensor const &t_in, Tensor &t_out) {
     516            0 :     t_in.apply<T>(
     517              :       [&](T x) {
     518            0 :         return static_cast<T>(x * (sigmoid<T>(static_cast<T>(1.702 * x))));
     519              :       },
     520              :       t_out);
     521            0 :     return t_out;
     522              :   }
     523              : 
     524              :   /**
     525              :    * @brief     derivative of sigmoid-based gelu approximate function
     526              :    * @param[in] t_in input tensor
     527              :    * @param[in] t_out output tensor
     528              :    * @param[in] outgoing_derivative outgoing derivative
     529              :    * @param[in] incoming_derivative incoming derivative
     530              :    */
     531              :   template <typename T = float>
     532              :   static Tensor &
     533            0 :   sigmoidGeluPrime(Tensor const &t_in, Tensor const &t_out,
     534              :                    Tensor &outgoing_derivative,
     535              :                    Tensor const &incoming_derivative = Tensor()) {
     536              :     // NYI
     537            0 :     ml_logw("sigmoidGeluPrime which is calculate derivate of sigmoidGelu "
     538              :             "function is not yet implemented");
     539            0 :     return outgoing_derivative;
     540              :   }
     541              : 
     542              :   /**
     543              :    * @brief elu function
     544              :    * @note  an alpha parameter is needed for elu, but supporting a property
     545              :    * on this class would require extensive refactoring. For now, 1.0 is used
     546              :    * for alpha.
     547              :    *
     548              :    * @tparam T type of an input/output
     549              :    * @param x input
     550              :    * @return T type output
     551              :    */
     552           30 :   template <typename T = float> static T elu(T x) {
     553           30 :     return x >= static_cast<T>(0.0) ? x : static_cast<T>(alpha * (exp(x) - 1));
     554              :   }
     555              : 
     556              :   /**
     557              :    * @brief elu prime function
     558              :    * @note  an alpha parameter is needed for elu, but supporting a property
     559              :    * on this class would require extensive refactoring. For now, 1.0 is used
     560              :    * for alpha.
     561              :    *
     562              :    * @tparam T type of an input/output
     563              :    * @param x input
     564              :    * @return T type output
     565              :    */
     566           30 :   template <typename T = float> static T eluPrime(T x) {
     567           30 :     return x >= static_cast<T>(0.0) ? static_cast<T>(1.0)
     568           12 :                                     : static_cast<T>(alpha * exp(x));
     569              :   }
     570              : 
     571              :   /**
     572              :    * @brief selu function
     573              :    * @tparam T type of an input/output
     574              :    * @param x input
     575              :    * @return T type output
     576              :    */
     577           30 :   template <typename T = float> static T selu(T x) {
     578              :     return x > static_cast<T>(0.0)
     579           30 :              ? static_cast<T>(selu_scale * x)
     580           15 :              : static_cast<T>(selu_scale * selu_alpha * (exp(x) - 1));
     581              :   }
     582              : 
     583              :   /**
     584              :    * @brief selu prime function
     585              :    * @tparam T type of an input/output
     586              :    * @param x input
     587              :    * @return T type output
     588              :    */
     589           30 :   template <typename T = float> static T seluPrime(T x) {
     590              :     return x > static_cast<T>(0.0)
     591           30 :              ? static_cast<T>(selu_scale)
     592           15 :              : static_cast<T>(selu_scale * selu_alpha * exp(x));
     593              :   }
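
A standalone sketch of elu and selu with the fixed constants used above (alpha = 1.0, selu_alpha ≈ 1.6733, selu_scale ≈ 1.0507); the check below shows both are continuous at the origin (illustrative, standard library only):

#include <cmath>
#include <iostream>

constexpr float alpha = 1.0f;             // fixed alpha for elu
constexpr float selu_alpha = 1.67326324f; // fixed alpha for selu
constexpr float selu_scale = 1.05070098f; // fixed scale for selu

static float elu(float x) {
  return x >= 0.0f ? x : alpha * (std::exp(x) - 1.0f);
}
static float selu(float x) {
  return x > 0.0f ? selu_scale * x
                  : selu_scale * selu_alpha * (std::exp(x) - 1.0f);
}

int main() {
  const float eps = 1e-4f;
  // Both functions approach 0 from either side of the origin.
  std::cout << elu(-eps) << " ~ " << elu(eps) << '\n';
  std::cout << selu(-eps) << " ~ " << selu(eps) << '\n';
}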
     594              : 
     595              :   /**
     596              :    * @brief     mish activation function
     597              :    * @param[in] x input
     598              :    */
     599           30 :   template <typename T = float> static T mish(T x) {
     600           30 :     return static_cast<T>(x * tanhFloat<T>(softplus<T>(x)));
     601              :   }
     602              : 
     603              :   /**
     604              :    * @brief     mish prime function
     605              :    * @param[in] x input
     606              :    */
     607           30 :   template <typename T = float> static T mishPrime(T x) {
     608              :     return static_cast<T>(tanhFloat<T>(softplus<T>(x)) +
     609           30 :                           x * softplusPrime<T>(x) *
     610           30 :                             tanhPrime<T>(tanhFloat<T>(softplus<T>(x))));
     611              :   }
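
Mish composes the pieces above: mish(x) = x * tanh(softplus(x)), and mishPrime differentiates that product with the chain rule. A scalar finite-difference check (standard library only):

#include <cmath>
#include <iostream>

static float softplus(float x) { return std::log(1.0f + std::exp(x)); }
static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
static float mish(float x) { return x * std::tanh(softplus(x)); }

// d/dx [x * tanh(softplus(x))]
//   = tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))^2)
static float mishPrime(float x) {
  float t = std::tanh(softplus(x));
  return t + x * sigmoid(x) * (1.0f - t * t);
}

int main() {
  const float x = -0.4f, eps = 1e-3f;
  float numeric = (mish(x + eps) - mish(x - eps)) / (2.0f * eps);
  std::cout << mishPrime(x) << " ~ " << numeric << '\n';
}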
     612              : 
     613              :   /**
     614              :    * @brief setActivation by custom activation function
     615              :    * @note  the incoming derivative is not applied by the wrapper, as this
     616              :    * activation_prime_fn takes and applies it itself
     617              :    * @param[in] std::function<Tensor(Tensor const &, Tensor &)> activation_fn
     618              :    * activation function to be used
     619              :    * @param[in] std::function<Tensor(Tensor const &, Tensor &)>
     620              :    * activation_prime_fn activation_prime_function to be used
     621              :    * @retval #ML_ERROR_NONE when successful
     622              :    */
     623              :   template <typename funcParam = Tensor>
     624         1926 :   int setActivation(
     625              :     std::function<funcParam &(funcParam const &, funcParam &)> const
     626              :       &activation_fn,
     627              :     std::function<funcParam &(funcParam &, funcParam &,
     628              :                               funcParam const &)> const &activation_prime_fn) {
     629         1926 :     _act_fn = activation_fn;
     630         5778 :     _act_prime_fn = [activation_prime_fn](
     631              :                       funcParam const &t_in, funcParam &t_out,
     632              :                       funcParam &outgoing_derivative,
     633              :                       funcParam const &incoming_derivative) -> funcParam & {
     634              :       return activation_prime_fn(t_out, outgoing_derivative,
     635              :                                  incoming_derivative);
     636              :     };
     637              : 
     638         1926 :     return ML_ERROR_NONE;
     639              :   }
     640              : 
     641              :   /**
     642              :    * @brief setActivation by custom activation function
     643              :    * @note  the incoming derivative is not applied here, as this
     644              :    * activation_prime_fn applies it itself
     645              :    * @param[in] std::function<Tensor(Tensor const &, Tensor &)> activation_fn
     646              :    * activation function to be used
     647              :    * @param[in] std::function<Tensor(Tensor const &, Tensor &, Tensor const &)>
     648              :    * activation_prime_fn activation_prime_function to be used
     649              :    * @retval #ML_ERROR_NONE when successful
     650              :    */
     651              :   template <typename funcParam = Tensor>
     652           16 :   int setActivation(
     653              :     std::function<funcParam &(funcParam const &, funcParam &)> const
     654              :       &activation_fn,
     655              :     std::function<funcParam &(funcParam const &, funcParam const &, funcParam &,
     656              :                               funcParam const &)> const &activation_prime_fn) {
     657           16 :     if (is_inplace)
     658              :       return ML_ERROR_INVALID_PARAMETER;
     659              : 
     660           16 :     _act_fn = activation_fn;
     661           16 :     _act_prime_fn = activation_prime_fn;
     662              : 
     663           16 :     return ML_ERROR_NONE;
     664              :   }
     665              : 
     666              :   /**
     667              :    * @brief setActivation by custom activation function
     668              :    * @note  the incoming derivative is applied by the wrapper, as this
     669              :    * activation_prime_fn does not take it
     670              :    * @param[in] activation_fn activation function to be used
     671              :    * @param[in] activation_prime_fn activation prime function to be used
     672              :    * @retval #ML_ERROR_NONE when successful
     673              :    */
     674              :   template <typename funcParam = Tensor>
     675              :   int setActivation(
     676              :     std::function<funcParam &(funcParam const &, funcParam &)> const
     677              :       &activation_fn,
     678              :     std::function<funcParam &(funcParam &, funcParam &)> const
     679              :       &activation_prime_fn) {
     680              :     if (!is_inplace) {
     681              :       _act_prime_fn = [activation_prime_fn](
     682              :                         funcParam const &t_in, funcParam &t_out,
     683              :                         funcParam &outgoing_derivative,
     684              :                         funcParam const &incoming_derivative) -> funcParam & {
     685              :         /** @todo update this based on supportInPlace */
     686              :         activation_prime_fn(t_out, outgoing_derivative);
     687              :         outgoing_derivative.multiply_i_strided(incoming_derivative);
     688              : 
     689              :         return outgoing_derivative;
     690              :       };
     691              :     } else {
     692              :       _act_prime_fn = [activation_prime_fn](
     693              :                         funcParam const &t_in, funcParam &t_out,
     694              :                         funcParam &outgoing_derivative,
     695              :                         funcParam const &incoming_derivative) -> funcParam & {
     696              :         activation_prime_fn(t_out, t_out);
     697              :         incoming_derivative.multiply_strided(t_out, outgoing_derivative);
     698              : 
     699              :         return outgoing_derivative;
     700              :       };
     701              :     }
     702              : 
     703              :     return ML_ERROR_NONE;
     704              :   }
     705              : 
     706              :   /**
     707              :    * @brief setActivation by custom activation function
     708              :    * @note  the incoming derivative is applied by the wrapper, as this
     709              :    * activation_prime_fn does not take it
     710              :    * @param[in] std::function<float(float const &)> activation_fn activation
     711              :    *            function to be used
     712              :    * @param[in] std::function<float(float const &)> activation_prime_fn
     713              :    *            activation_prime_function to be used
     714              :    * @retval #ML_ERROR_NONE when successful
     715              :    */
     716              :   template <typename funcParam = float>
     717         7124 :   int setActivation(
     718              :     std::function<funcParam(funcParam const)> const &activation_fn,
     719              :     std::function<funcParam(funcParam const)> const &activation_prime_fn) {
     720        21372 :     _act_fn = [activation_fn](Tensor const &x, Tensor &hidden) -> Tensor & {
     721        18764 :       return x.apply(activation_fn, hidden);
     722              :     };
     723         7124 :     if (!is_inplace) {
     724          274 :       _act_prime_fn =
     725         1096 :         [activation_prime_fn](Tensor const &t_in, Tensor &t_out,
     726              :                               Tensor &outgoing_derivative,
     727              :                               Tensor const &incoming_derivative) -> Tensor & {
     728              :         /** @todo update this based on supportInPlace */
     729          308 :         t_out.apply(activation_prime_fn, outgoing_derivative);
     730          308 :         outgoing_derivative.multiply_i_strided(incoming_derivative);
     731              : 
     732          308 :         return outgoing_derivative;
     733              :       };
     734              :     } else {
     735         6850 :       _act_prime_fn =
     736        27400 :         [activation_prime_fn](Tensor const &t_in, Tensor &t_out,
     737              :                               Tensor &outgoing_derivative,
     738              :                               Tensor const &incoming_derivative) -> Tensor & {
     739         4571 :         t_out.apply(activation_prime_fn, t_out);
     740         4571 :         incoming_derivative.multiply_strided(t_out, outgoing_derivative);
     741              : 
     742         4571 :         return outgoing_derivative;
     743              :       };
     744              :     }
     745              : 
     746         7124 :     return ML_ERROR_NONE;
     747              :   }
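
The scalar overload above wraps elementwise functions into tensor-level ones and, depending on is_inplace, either writes the prime values into a scratch output or overwrites the activation output before multiplying by the incoming derivative. A standalone sketch of the two wrapping strategies on plain arrays (Vec and makeBackward below are illustrative, not nntrainer's API):

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

using Vec = std::vector<float>;

// Build a backward function from a scalar prime that expects the activated
// output. `inplace` mirrors ActiFunc::is_inplace: when true, the activation
// output buffer is overwritten with the prime values before scaling by the
// incoming derivative.
std::function<void(Vec &, Vec &, const Vec &)>
makeBackward(std::function<float(float)> prime, bool inplace) {
  if (!inplace) {
    return [prime](Vec &t_out, Vec &outgoing, const Vec &incoming) {
      for (std::size_t i = 0; i < t_out.size(); ++i)
        outgoing[i] = prime(t_out[i]) * incoming[i];
    };
  }
  return [prime](Vec &t_out, Vec &outgoing, const Vec &incoming) {
    for (std::size_t i = 0; i < t_out.size(); ++i) {
      t_out[i] = prime(t_out[i]);           // overwrite activation output
      outgoing[i] = incoming[i] * t_out[i]; // then scale by incoming grad
    }
  };
}

int main() {
  auto sigmoidPrime = [](float y) { return y * (1.0f - y); };
  Vec out = {0.6225f, 0.7311f}; // already sigmoid-activated values
  Vec grad_in = {1.0f, 0.5f};
  Vec grad_out(2);
  makeBackward(sigmoidPrime, /*inplace=*/true)(out, grad_out, grad_in);
  std::cout << grad_out[0] << ' ' << grad_out[1] << '\n'; // ~0.2350 0.0983
}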
     748              : 
     749              :   /**
     750              :    * @brief setActivation by custom activation function
     751              :    * @note  apply derivative as this activation_prime_fn does not utilize
     752              :    * derivative
     753              :    * @param[in] std::function<float(float const)> activation_fn activation
     754              :    * function to be used
     755              :    * @param[in] std::function<float(float const, float const)>
     756              :    * activation_prime_fn activation_prime_function to be used
     757              :    * @retval #ML_ERROR_NONE when successful
     758              :    */
     759              :   int setActivation(
     760              :     std::function<float(float const)> const &activation_fn,
     761              :     std::function<float(float const, float const)> const &activation_prime_fn);
     762              : 
     763              :   /**
     764              :    * @brief   Notify that this layer will execute in-place
     765              :    *
     766              :    * @param val True if execute in-place, else false
     767              :    */
     768          432 :   void setInPlace(bool val) {
     769          432 :     if (val && !supportInPlace())
     770              :       throw std::runtime_error(
     771            0 :         "Error setting activation layer to work in-place");
     772              : 
     773          432 :     is_inplace = val;
     774          432 :   }
     775              : 
     776              : private:
     777              :   constexpr static inline float alpha = 1.0f; /**< alpha for elu */
     778              :   constexpr static inline float beta = 1.0f;  /**< beta for Softplus */
     779              :   constexpr static inline float selu_alpha = 1.67326324f; /**< alpha for selu */
     780              :   constexpr static inline float selu_scale = 1.05070098f; /**< scale for selu */
     781              : 
     782              :   std::function<Tensor &(Tensor const &, Tensor &)> _act_fn;
     783              :   std::function<Tensor &(Tensor const &, Tensor &, Tensor &, Tensor const &)>
     784              :     _act_prime_fn; /**< prime function with input and output*/
     785              : 
     786              :   ActivationType
     787              :     activation_type; /**< type of the activation represented by this */
     788              :   bool is_inplace;   /**< whether this class should operate in-place */
     789              : };
     790              : 
     791              : } // namespace nntrainer
     792              : 
     793              : #endif /* __cplusplus */
     794              : #endif /* __ACTI_FUNC_H__ */
        

Generated by: LCOV version 2.0-1