Skip to content

Commit

Permalink
add leakyrelu def
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed May 3, 2024
1 parent 51ee029 commit 8e33be9
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions ext/ForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,15 @@ for f in (:vmapt, :vmapnt, :vmapntt)
end
end
end

if Base.ifelse !== ifelse
@inline function NNlib.leakyrelu(x::AbstractSIMD)
fx = float(x)
NNlib.leakyrelu(fx, convert(typeof(fx), NNlib.leakyrelu_a))
end
@inline function NNlib.leakyrelu(x::AbstractSIMD, a)
fx = float(x)
ax = convert(typeof(fx), a * x)
ifelse(x > 0, fx, ax) # max(a*x, x) is 3x slower
end
end

0 comments on commit 8e33be9

Please sign in to comment.