src/nf.f90: 2 additions & 1 deletion
@@ -8,5 +8,6 @@ module nf
  use nf_optimizers, only: sgd
  use nf_activation, only: activation_function, elu, exponential, &
                           gaussian, linear, relu, leaky_relu, &
                           sigmoid, softmax, softplus, step, tanhf
                           sigmoid, softmax, softplus, step, tanhf, &
                           celu
end module nf
src/nf/nf_activation.f90: 60 additions & 0 deletions
@@ -18,6 +18,7 @@ module nf_activation
  public :: softplus
  public :: step
  public :: tanhf
  public :: celu

  type, abstract :: activation_function
  contains
@@ -140,6 +141,15 @@ end function eval_3d_i
    procedure :: eval_3d_prime => eval_3d_tanh_prime
  end type tanhf

  type, extends(activation_function) :: celu
    real :: alpha = 1.0 ! PyTorch default
  contains
    procedure :: eval_1d => eval_1d_celu
    procedure :: eval_1d_prime => eval_1d_celu_prime
    procedure :: eval_3d => eval_3d_celu
    procedure :: eval_3d_prime => eval_3d_celu_prime
  end type celu

contains

  pure function eval_1d_elu(self, x) result(res)
@@ -522,6 +532,54 @@ pure function eval_3d_tanh_prime(self, x) result(res)
    res = 1 - tanh(x)**2
  end function eval_3d_tanh_prime

  pure function eval_1d_celu(self, x) result(res)
    ! CELU activation function.
    class(celu), intent(in) :: self
    real, intent(in) :: x(:)
    real :: res(size(x))
    where (x >= 0.0)
      res = x
    else where
      res = self % alpha * (exp(x / self % alpha) - 1.0)
    end where
  end function eval_1d_celu

  pure function eval_1d_celu_prime(self, x) result(res)
    ! First derivative of the CELU activation function.
    class(celu), intent(in) :: self
    real, intent(in) :: x(:)
    real :: res(size(x))
    where (x >= 0.0)
      res = 1.0
    else where
      res = exp(x / self % alpha)
    end where
  end function eval_1d_celu_prime

  pure function eval_3d_celu(self, x) result(res)
    ! CELU activation function.
    class(celu), intent(in) :: self
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    where (x >= 0.0)
      res = x
    else where
      res = self % alpha * (exp(x / self % alpha) - 1.0)
    end where
  end function eval_3d_celu

  pure function eval_3d_celu_prime(self, x) result(res)
    ! First derivative of the CELU activation function.
    class(celu), intent(in) :: self
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    where (x >= 0.0)
      res = 1.0
    else where
      res = exp(x / self % alpha)
    end where
  end function eval_3d_celu_prime

  pure function get_name(self) result(name)
    !! Return the name of the activation function.
    !!
@@ -556,6 +614,8 @@ pure function get_name(self) result(name)
        name = 'step'
      class is (tanhf)
        name = 'tanh'
      class is (celu)
        name = 'celu'
      class default
        error stop 'Unknown activation function type.'
    end select
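The four eval routines above implement CELU(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1)) and its elementwise derivative (1 where x >= 0, exp(x/alpha) otherwise). Below is a minimal sketch of exercising the new type on its own; the program name and sample values are illustrative, and it assumes the celu component and type-bound eval procedures remain publicly accessible as shown in the diff.

```fortran
program celu_demo
  ! Illustrative sketch only (not part of this PR).
  use nf_activation, only: celu
  implicit none
  type(celu) :: act
  real :: x(5), y(5), dy(5)

  ! Structure constructor; alpha defaults to 1.0 (the PyTorch default).
  act = celu(alpha=0.5)

  x = [-2.0, -0.5, 0.0, 0.5, 2.0]
  y  = act % eval_1d(x)        ! x where x >= 0, alpha*(exp(x/alpha) - 1) elsewhere
  dy = act % eval_1d_prime(x)  ! 1 where x >= 0, exp(x/alpha) elsewhere
  print *, 'celu:       ', y
  print *, 'derivative: ', dy
end program celu_demo
```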
src/nf/nf_network_submodule.f90: 6 additions & 2 deletions
@@ -25,7 +25,8 @@
    softmax, &
    softplus, &
    step, &
    tanhf
    tanhf, &
    celu

  implicit none

@@ -268,10 +269,13 @@ pure function get_activation_by_name(activation_name) result(res)
      case('tanh')
        allocate ( res, source = tanhf() )

      case('celu')
        allocate ( res, source = celu() )

      case default
        error stop 'activation_name must be one of: ' // &
          '"elu", "exponential", "gaussian", "linear", "relu", ' // &
          '"leaky_relu", "sigmoid", "softmax", "softplus", "step", or "tanh".'
          '"leaky_relu", "sigmoid", "softmax", "softplus", "step", "tanh", or "celu".'
    end select

  end function get_activation_by_name
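Since celu is also re-exported from the top-level nf module (first hunk above), it can be passed to a layer constructor like any other activation. The following is a hypothetical usage sketch, not taken from this PR: it assumes the dense constructor accepts a class(activation_function) instance through an activation argument, and the program name is made up.

```fortran
program celu_network_sketch
  ! Hypothetical sketch (not part of this PR): assumes dense() accepts an
  ! `activation` argument of class(activation_function), as the other
  ! activations re-exported from nf are typically used.
  use nf, only: network, input, dense, celu
  implicit none
  type(network) :: net

  net = network([ &
    input(3), &
    dense(8, activation=celu()), &           ! PyTorch-default alpha = 1.0
    dense(1, activation=celu(alpha=0.5)) &   ! custom alpha via the constructor
  ])
end program celu_network_sketch
```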