From dbc1dedf83af0d2e6c799657465d978cb339fd8c Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Fri, 3 May 2024 15:00:14 -0400 Subject: [PATCH] add leakyrelu def --- ext/ForwardDiffExt.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ext/ForwardDiffExt.jl b/ext/ForwardDiffExt.jl index 26227f69..01cfecff 100644 --- a/ext/ForwardDiffExt.jl +++ b/ext/ForwardDiffExt.jl @@ -377,3 +377,15 @@ for f in (:vmapt, :vmapnt, :vmapntt) end end end + +if Base.ifelse !== IfElse.ifelse + @inline function NNlib.leakyrelu(x::AbstractSIMD) + fx = float(x) + NNlib.leakyrelu(fx, convert(typeof(fx), NNlib.leakyrelu_a)) + end + @inline function NNlib.leakyrelu(x::AbstractSIMD, a) + fx = float(x) + ax = convert(typeof(fx), a * x) + ifelse(x > 0, fx, ax) # max(a*x, x) is 3x slower + end +end