diff --git a/src/exp.jl b/src/exp.jl index 6b23b77..6dad01d 100644 --- a/src/exp.jl +++ b/src/exp.jl @@ -13,6 +13,7 @@ end @deprecate _exp! exponential! @deprecate exp_generic exponential! exponential!(A) = exponential!(A, ExpMethodHigham2005(A)); +exponential!(A::GPUArraysCore.AbstractGPUArray) = exponential!(A, ExpMethodHigham2005(false)); ## The diagonalization based """ diff --git a/src/exp_noalloc.jl b/src/exp_noalloc.jl index 7b9f7eb..147e235 100644 --- a/src/exp_noalloc.jl +++ b/src/exp_noalloc.jl @@ -10,6 +10,7 @@ struct ExpMethodHigham2005 end ExpMethodHigham2005(A::AbstractMatrix) = ExpMethodHigham2005(A isa StridedMatrix) ExpMethodHigham2005() = ExpMethodHigham2005(true) +ExpMethodHigham2005(A::GPUArraysCore.AbstractGPUArray) = ExpMethodHigham2005(false) function alloc_mem(A, ::ExpMethodHigham2005) T = eltype(A) diff --git a/test/gpu/gputests.jl b/test/gpu/gputests.jl index 9693472..347abd9 100644 --- a/test/gpu/gputests.jl +++ b/test/gpu/gputests.jl @@ -25,6 +25,8 @@ end eA = exp(A) A_d = cu(A) + exponential!(copy(A_d)) # Make sure simple command works + # Iterate over GPU-compatible methods for m in (ExpMethodHigham2005(false),) @testset "GPU Exponential, $(string(m))" begin