diff --git a/src/nf.f90 b/src/nf.f90
index e159ee64..e4522aa1 100644
--- a/src/nf.f90
+++ b/src/nf.f90
@@ -8,5 +8,6 @@ module nf
   use nf_optimizers, only: sgd
   use nf_activation, only: activation_function, elu, exponential, &
                            gaussian, linear, relu, leaky_relu, &
-                           sigmoid, softmax, softplus, step, tanhf
+                           sigmoid, softmax, softplus, step, tanhf, &
+                           celu
 end module nf
diff --git a/src/nf/nf_activation.f90 b/src/nf/nf_activation.f90
index b83fc0f7..67034a37 100644
--- a/src/nf/nf_activation.f90
+++ b/src/nf/nf_activation.f90
@@ -18,6 +18,7 @@ module nf_activation
   public :: softplus
   public :: step
   public :: tanhf
+  public :: celu
 
   type, abstract :: activation_function
   contains
@@ -140,6 +141,15 @@ end function eval_3d_i
     procedure :: eval_3d_prime => eval_3d_tanh_prime
   end type tanhf
 
+  type, extends(activation_function) :: celu
+    real :: alpha = 1.0 ! PyTorch default
+  contains
+    procedure :: eval_1d => eval_1d_celu
+    procedure :: eval_1d_prime => eval_1d_celu_prime
+    procedure :: eval_3d => eval_3d_celu
+    procedure :: eval_3d_prime => eval_3d_celu_prime
+  end type celu
+
 contains
 
   pure function eval_1d_elu(self, x) result(res)
@@ -522,6 +532,54 @@ pure function eval_3d_tanh_prime(self, x) result(res)
     res = 1 - tanh(x)**2
   end function eval_3d_tanh_prime
 
+  pure function eval_1d_celu(self, x) result(res)
+    ! Continuously Differentiable Exponential Linear Unit (CELU) activation function.
+    class(celu), intent(in) :: self
+    real, intent(in) :: x(:)
+    real :: res(size(x))
+    where (x >= 0.0)
+      res = x
+    else where
+      res = self % alpha * (exp(x / self % alpha) - 1.0)
+    end where
+  end function eval_1d_celu
+
+  pure function eval_1d_celu_prime(self, x) result(res)
+    ! First derivative of the CELU activation function.
+    class(celu), intent(in) :: self
+    real, intent(in) :: x(:)
+    real :: res(size(x))
+    where (x >= 0.0)
+      res = 1.0
+    else where
+      res = exp(x / self % alpha)
+    end where
+  end function eval_1d_celu_prime
+
+  pure function eval_3d_celu(self, x) result(res)
+    ! CELU activation function (3-d variant).
+    class(celu), intent(in) :: self
+    real, intent(in) :: x(:,:,:)
+    real :: res(size(x,1),size(x,2),size(x,3))
+    where (x >= 0.0)
+      res = x
+    else where
+      res = self % alpha * (exp(x / self % alpha) - 1.0)
+    end where
+  end function eval_3d_celu
+
+  pure function eval_3d_celu_prime(self, x) result(res)
+    ! First derivative of the CELU activation function (3-d variant).
+    class(celu), intent(in) :: self
+    real, intent(in) :: x(:,:,:)
+    real :: res(size(x,1),size(x,2),size(x,3))
+    where (x >= 0.0)
+      res = 1.0
+    else where
+      res = exp(x / self % alpha)
+    end where
+  end function eval_3d_celu_prime
+
   pure function get_name(self) result(name)
     !! Return the name of the activation function.
     !!
@@ -556,6 +614,8 @@ pure function get_name(self) result(name)
         name = 'step'
       class is (tanhf)
         name = 'tanh'
+      class is (celu)
+        name = 'celu'
       class default
         error stop 'Unknown activation function type.'
     end select
diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90
index 33c4a665..e71a8e80 100644
--- a/src/nf/nf_network_submodule.f90
+++ b/src/nf/nf_network_submodule.f90
@@ -25,7 +25,8 @@
                            softmax, &
                            softplus, &
                            step, &
-                           tanhf
+                           tanhf, &
+                           celu
 
   implicit none
 
@@ -268,10 +269,13 @@ pure function get_activation_by_name(activation_name) result(res)
       case('tanh')
         allocate ( res, source = tanhf() )
 
+      case('celu')
+        allocate ( res, source = celu() )
+
       case default
         error stop 'activation_name must be one of: ' // &
           '"elu", "exponential", "gaussian", "linear", "relu", ' // &
-          '"leaky_relu", "sigmoid", "softmax", "softplus", "step", or "tanh".'
+          '"leaky_relu", "sigmoid", "softmax", "softplus", "step", "tanh", or "celu".'
     end select
 
   end function get_activation_by_name
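
For reviewers, a minimal usage sketch (not part of the diff) of how the new activation would be selected from user code. It assumes the `dense` constructor's optional `activation` argument and the `net % print_info()` call as used elsewhere in neural-fortran's examples; the layer sizes and the `alpha = 0.5` value are arbitrary illustration.

```fortran
! Minimal sketch: a small dense network using the new CELU activation.
program celu_example
  use nf, only: network, dense, input, celu
  implicit none
  type(network) :: net

  net = network([ &
    input(3), &
    dense(5, activation=celu()), &          ! default alpha = 1.0
    dense(2, activation=celu(alpha=0.5)) &  ! custom alpha via the structure constructor
  ])

  call net % print_info()

end program celu_example
```

The `where` / `else where` blocks in the new eval routines apply the piecewise CELU definition elementwise, mirroring the existing `elu` implementation: the forward pass returns x for x >= 0 and alpha*(exp(x/alpha) - 1) otherwise, and the derivative routines return 1 and exp(x/alpha) on the same branches.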