Skip to content

Commit

Permalink
Merge pull request #134 from SciML/gpu
Browse files Browse the repository at this point in the history
Fix GPU exponential! defaults
  • Loading branch information
ChrisRackauckas authored Sep 9, 2023
2 parents 887d7d5 + c449067 commit a2a45ae
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/exp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ end
@deprecate _exp! exponential!
@deprecate exp_generic exponential!
exponential!(A) = exponential!(A, ExpMethodHigham2005(A));
exponential!(A::GPUArraysCore.AbstractGPUArray) = exponential!(A, ExpMethodHigham2005(false));

## The diagonalization based
"""
Expand Down
1 change: 1 addition & 0 deletions src/exp_noalloc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ struct ExpMethodHigham2005
end
ExpMethodHigham2005(A::AbstractMatrix) = ExpMethodHigham2005(A isa StridedMatrix)
ExpMethodHigham2005() = ExpMethodHigham2005(true)
ExpMethodHigham2005(A::GPUArraysCore.AbstractGPUArray) = ExpMethodHigham2005(false)

function alloc_mem(A, ::ExpMethodHigham2005)
T = eltype(A)
Expand Down
2 changes: 2 additions & 0 deletions test/gpu/gputests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ end
eA = exp(A)
A_d = cu(A)

exponential!(copy(A_d)) # Make sure simple command works

# Iterate over GPU-compatible methods
for m in (ExpMethodHigham2005(false),)
@testset "GPU Exponential, $(string(m))" begin
Expand Down

0 comments on commit a2a45ae

Please sign in to comment.