Skip to content

Commit

Permalink
Change to GPUArraysCore
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jun 22, 2022
1 parent 92a25f7 commit 2375658
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
Expand Down Expand Up @@ -54,7 +54,7 @@ EllipsisNotation = "1"
Enzyme = "0.8, 0.9, 0.10"
FiniteDiff = "2"
ForwardDiff = "0.10"
GPUArrays = "8"
GPUArraysCore = "0.1"
LinearSolve = "1"
OrdinaryDiffEq = "5.60, 6"
Parameters = "0.12"
Expand Down
2 changes: 1 addition & 1 deletion src/DiffEqSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using Random
import ZygoteRules, Zygote, ReverseDiff
import ArrayInterfaceCore, ArrayInterfaceTracker
import Enzyme
import GPUArrays
import GPUArraysCore

using Cassette, DiffRules
using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot
Expand Down
6 changes: 3 additions & 3 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function automatic_sensealg_choice(prob::Union{ODEProblem,SDEProblem},u0,p,verbo
!(eltype(p) <: Complex) &&
length(u0) + length(p) <= 100
ForwardDiffSensitivity()
elseif u0 isa GPUArrays.AbstractGPUArray || !DiffEqBase.isinplace(prob)
elseif u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob)
# only Zygote is GPU compatible and fast
# so if out-of-place, try Zygote
if p === nothing || p === DiffEqBase.NullParameters()
Expand All @@ -75,7 +75,7 @@ end

function automatic_sensealg_choice(prob::Union{NonlinearProblem,SteadyStateProblem}, u0, p, verbose)

default_sensealg = if u0 isa GPUArrays.AbstractGPUArray || !DiffEqBase.isinplace(prob)
default_sensealg = if u0 isa GPUArraysCore.AbstractGPUArray || !DiffEqBase.isinplace(prob)
# autodiff = false because forwarddiff fails on many GPU kernels
# this only effects the Jacobian calculation and is same computation order
SteadyStateAdjoint(autodiff=false, autojacvec=ZygoteVJP())
Expand Down Expand Up @@ -747,7 +747,7 @@ end
function DiffEqBase._concrete_solve_adjoint(prob,alg,sensealg::ReverseDiffAdjoint,
u0,p,originator::SciMLBase.ADOriginator,args...;kwargs...)

if typeof(u0) isa GPUArrays.AbstractGPUArray
if typeof(u0) isa GPUArraysCore.AbstractGPUArray
throw(ReverseDiffGPUStateCompatibilityError())
end

Expand Down

0 comments on commit 2375658

Please sign in to comment.