SteadyStateAdjoint + GPUs: scalar getindex is disallowed #557

QiyaoWei opened this issue May 18, 2021 · 3 comments


Hi all,

I am trying to solve a SteadyStateProblem on the GPU (specifically, to implement Deep Equilibrium Models) and ran into a "scalar getindex is disallowed" error. Any help would be appreciated. Thanks a lot!


using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using CUDA

# FastChain with initial_params should also work
# But it's more stable to use Chain and destructure
# ann = FastChain(
#   FastDense(1, 2, relu),
#   FastDense(2, 1, tanh))
# p1 = initial_params(ann)

ann = Chain(Dense(1, 2), Dense(2, 1)) |> gpu
p,re = Flux.destructure(ann)
tspan = (0.0f0, 1.0f0) |> gpu

function solve_ss(x)
    xg = gpu(x)
    z = re(p)(xg) |> gpu
    function dudt_(u, _p, t)
        # Solving the equation f(u) - u = du = 0
        # Key question: Is there any difference between
        # re(_p)(x) and re(_p)(u+x)?
        re(_p)(u+xg) - u
    end
    ss = SteadyStateProblem(ODEProblem(gpu(dudt_), gpu(z), gpu(tspan), p))
    Array(solve(ss, DynamicSS(Tsit5()), u0 = z, sensealg=SteadyStateAdjoint()))
end

# Let's run a DEQ model on linear regression for y = 2x
X = [1;2;3;4;5;6;7;8;9;10]
Y = [2;4;6;8;10;12;14;16;18;20]
data = Flux.Data.DataLoader(gpu.(collect.((X, Y))), batchsize=1,shuffle=true)
opt = ADAM(0.05)

function loss(x, y)
  ŷ = solve_ss(x)
  @show x
  @show y
  @show ŷ
  @show sum((y .- ŷ).^2)
end

epochs = 100
for i in 1:epochs
    Flux.train!(loss, Flux.params(p), data, opt)
    println(solve_ss([-5])) # Print model prediction
end


└ @ SciMLBase C:\Users\administered\.julia\packages\SciMLBase\9EjAY\src\integrator_interface.jl:331
ERROR: LoadError: scalar getindex is disallowed
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] assertscalar(op::String)
    @ GPUArrays ~\.julia\packages\GPUArrays\gjXOn\src\host\indexing.jl:56
  [3] getindex
    @ ~\.julia\packages\GPUArrays\gjXOn\src\host\indexing.jl:98 [inlined]
  [4] getindex
    @ ~\.julia\packages\SciMLBase\9EjAY\src\solutions\solution_interface.jl:6 [inlined]
  [5] iterate
    @ .\abstractarray.jl:1096 [inlined]
  [6] iterate
    @ .\abstractarray.jl:1094 [inlined]
  [7] copyto_unaliased!(deststyle::IndexLinear, dest::Vector{Float32}, srcstyle::IndexCartesian, src::SciMLBase.NonlinearSolution{Float32, 1, CuArray{Float32, 1}, CuArray{Float32, 1}, SteadyStateProblem{CuArray{Float32, 1}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#1"{CuArray{Int64, 1}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Base .\abstractarray.jl:975
  [8] copyto!
    @ .\abstractarray.jl:950 [inlined]
  [9] copyto_axcheck!
    @ .\abstractarray.jl:1056 [inlined]
 [10] Array
    @ .\array.jl:540 [inlined]
 [11] Array
    @ .\boot.jl:472 [inlined]
 [12] adjoint
    @ ~\.julia\packages\Zygote\6HN9x\src\lib\array.jl:8 [inlined]
 [13] _pullback(__context__::Zygote.Context, 486::Type{Array}, xs::SciMLBase.NonlinearSolution{Float32, 1, CuArray{Float32, 1}, CuArray{Float32, 1}, SteadyStateProblem{CuArray{Float32, 1}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#1"{CuArray{Int64, 1}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57
 [14] _pullback
    @ e:\wqy\julia\simpleDEQ.jl:31 [inlined]
 [15] _pullback(ctx::Zygote.Context, f::typeof(solve_ss), args::CuArray{Int64, 1})
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [16] _pullback
    @ e:\wqy\julia\simpleDEQ.jl:42 [inlined]
 [17] _pullback(::Zygote.Context, ::typeof(loss), ::CuArray{Int64, 1}, ::CuArray{Int64, 1})
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [18] _apply
    @ .\boot.jl:804 [inlined]
 [19] adjoint
    @ ~\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:191 [inlined]
 [20] _pullback
    @ ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57 [inlined]
 [21] _pullback
    @ ~\.julia\packages\Flux\qp1gc\src\optimise\train.jl:102 [inlined]
 [22] _pullback(::Zygote.Context, ::Flux.Optimise.var"#39#45"{typeof(loss), Tuple{CuArray{Int64, 1}, CuArray{Int64, 1}}})
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [23] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:247
 [24] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:58
 [25] macro expansion
    @ ~\.julia\packages\Flux\qp1gc\src\optimise\train.jl:101 [inlined]
 [26] macro expansion
    @ ~\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [27] train!(loss::Function, ps::Zygote.Params, data::Flux.Data.DataLoader{Tuple{CuArray{Int64, 1}, CuArray{Int64, 1}}, Random._GLOBAL_RNG}, opt::ADAM; cb::Flux.Optimise.var"#40#46")
    @ Flux.Optimise ~\.julia\packages\Flux\qp1gc\src\optimise\train.jl:99
 [28] train!(loss::Function, ps::Zygote.Params, data::Flux.Data.DataLoader{Tuple{CuArray{Int64, 1}, CuArray{Int64, 1}}, Random._GLOBAL_RNG}, opt::ADAM)
    @ Flux.Optimise ~\.julia\packages\Flux\qp1gc\src\optimise\train.jl:97
 [29] top-level scope
    @ e:\wqy\julia\simpleDEQ.jl:51
 [30] include(fname::String)
    @ Base.MainInclude .\client.jl:444
 [31] startdebug(socket::Base.PipeEndpoint, error_handler::VSCodeDebugger.var"#3#4"{Tuple{String, String}})
    @ VSCodeDebugger.DebugAdapter ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\DebugAdapter\src\packagedef.jl:91
 [32] startdebugger()
    @ VSCodeDebugger ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeDebugger\src\VSCodeDebugger.jl:38
 [33] top-level scope
    @ ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\debugger\run_debugger.jl:9
in expression starting at e:\wqy\julia\simpleDEQ.jl:50
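
On the "key question" in the code comment above, a worked steady state may help. This is a sketch under one assumption: both `Dense` layers use Flux's default identity activation, so this 1-input, 1-output network collapses to a scalar affine map f(v) = a v + b. The symbols a and b are shorthand introduced here for the combined weight and bias; they do not appear in the original code.

```latex
\begin{aligned}
\text{with } f(u + x):\quad & 0 = \frac{du}{dt} = a\,(u^* + x) + b - u^*
  \;\Longrightarrow\; u^* = \frac{a\,x + b}{1 - a}, \qquad a \neq 1,\\
\text{with } f(x):\quad & 0 = \frac{du}{dt} = a\,x + b - u^*
  \;\Longrightarrow\; u^* = a\,x + b .
\end{aligned}
```

So the two forms do differ: with `re(_p)(x)` the steady state is just the feedforward output, while `re(_p)(u+xg)` defines `u` implicitly. Under the same assumption, fitting y = 2x with the implicit form needs a/(1 - a) = 2 and b = 0, that is a = 2/3, and the linear ODE du/dt = (a - 1)u + a x + b is then stable because a < 1.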

Update your packages to pick up SciML/DiffEqBase.jl#669; with that change, the following should be much closer to working:

using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using CUDA

# FastChain with initial_params should also work
# But it's more stable to use Chain and destructure
# ann = FastChain(
#   FastDense(1, 2, relu),
#   FastDense(2, 1, tanh))
# p1 = initial_params(ann)

ann = Chain(Dense(1, 2), Dense(2, 1)) |> gpu
p,re = Flux.destructure(ann)
tspan = (0.0f0, 1.0f0)

function solve_ss(x)
    xg = gpu(x)
    z = re(p)(xg) |> gpu
    function dudt_(u, _p, t)
        # Solving the equation f(u) - u = du = 0
        # Key question: Is there any difference between
        # re(_p)(x) and re(_p)(u+x)?
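        re(_p)(u+xg) - u
    end
    # Steady state by time stepping until du is within tolerance
    ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
    x = solve(ss, DynamicSS(Tsit5()), u0 = z, abstol = 1e-2, reltol = 1e-2).u

    # Same residual posed as a NonlinearProblem; this result overwrites x
    ss = NonlinearProblem(dudt_, gpu(z), p)
    x = solve(ss, NewtonRaphson(), u0 = z).u
end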

# Let's run a DEQ model on linear regression for y = 2x
X = [1;2;3;4;5;6;7;8;9;10]
Y = [2;4;6;8;10;12;14;16;18;20]
data = Flux.Data.DataLoader(gpu.(collect.((X, Y))), batchsize=1,shuffle=true)
opt = ADAM(0.05)

function loss(x, y)
  ŷ = solve_ss(x)
  sum(abs2, y .- ŷ)
end

epochs = 100
for i in 1:epochs
    Flux.train!(loss, Flux.params(p), data, opt)
    println(solve_ss([-5])) # Print model prediction
end

ChrisRackauckas changed the title from "scalar getindex is disallowed" to "SteadyStateAdjoint + GPUs: scalar getindex is disallowed" on May 24, 2021

@avik-pal or @DhairyaLGandhi if you could help dig in here that would be great.


Fixed with the new adjoints.
