diff --git a/ext/ForwardDiffExt.jl b/ext/ForwardDiffExt.jl index 26227f69..8f4448a8 100644 --- a/ext/ForwardDiffExt.jl +++ b/ext/ForwardDiffExt.jl @@ -377,3 +377,15 @@ for f in (:vmapt, :vmapnt, :vmapntt) end end end + +if Base.ifelse !== ifelse + @inline function NNlib.leakyrelu(x::AbstractSIMD) + fx = float(x) + NNlib.leakyrelu(fx, convert(typeof(fx), NNlib.leakyrelu_a)) + end + @inline function NNlib.leakyrelu(x::AbstractSIMD, a) + fx = float(x) + ax = convert(typeof(fx), a * x) + ifelse(x > 0, fx, ax) # max(a*x, x) is 3x slower + end +end