
SteadyStateAdjoint + GPUs: scalar getindex is disallowed #557

Closed
QiyaoWei opened this issue May 18, 2021 · 3 comments

@QiyaoWei
Contributor

Hi all,

I am trying to use the GPU on a SteadyStateProblem (specifically, to implement Deep Equilibrium Models) and run into a "scalar getindex is disallowed" error. Any help would be appreciated! Thanks a lot!

MWE:

using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using CUDA
CUDA.allowscalar(false)

# FastChain with initial_params should also work
# But it's more stable to use Chain and destructure
# ann = FastChain(
#   FastDense(1, 2, relu),
#   FastDense(2, 1, tanh))
# p1 = initial_params(ann)

ann = Chain(Dense(1, 2), Dense(2, 1)) |> gpu
p,re = Flux.destructure(ann)
tspan = (0.0f0, 1.0f0) |> gpu

function solve_ss(x)
    xg = gpu(x)
    z = re(p)(xg) |> gpu
    function dudt_(u, _p, t)
        # Solving the equation f(u) - u = du = 0
        # Key question: Is there any difference between
        # re(_p)(x) and re(_p)(u+x)?
        re(_p)(u+xg) - u
    end
    ss = SteadyStateProblem(ODEProblem(gpu(dudt_), gpu(z), gpu(tspan), p))
    Array(solve(ss, DynamicSS(Tsit5()), u0 = z, sensealg=SteadyStateAdjoint()))
end

# Let's run a DEQ model on linear regression for y = 2x
X = [1;2;3;4;5;6;7;8;9;10]
Y = [2;4;6;8;10;12;14;16;18;20]
data = Flux.Data.DataLoader(gpu.(collect.((X, Y))), batchsize=1, shuffle=true)
opt = ADAM(0.05)


function loss(x, y)
  ŷ = solve_ss(x)
  @show x
  @show y
  @show ŷ
  @show sum((y .- ŷ).^2)
end

epochs = 100
for i in 1:epochs
    Flux.train!(loss, Flux.params(p), data, opt)
    println(solve_ss([-5])) # Print model prediction
end

Error:

└ @ SciMLBase C:\Users\administered\.julia\packages\SciMLBase\9EjAY\src\integrator_interface.jl:331
ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] assertscalar(op::String)
    @ GPUArrays ~\.julia\packages\GPUArrays\gjXOn\src\host\indexing.jl:56
  [3] getindex
    @ ~\.julia\packages\GPUArrays\gjXOn\src\host\indexing.jl:98 [inlined]
  [4] getindex
    @ ~\.julia\packages\SciMLBase\9EjAY\src\solutions\solution_interface.jl:6 [inlined]
  [5] iterate
    @ .\abstractarray.jl:1096 [inlined]
  [6] iterate
    @ .\abstractarray.jl:1094 [inlined]
  [7] copyto_unaliased!(deststyle::IndexLinear, dest::Vector{Float32}, srcstyle::IndexCartesian, src::SciMLBase.NonlinearSolution{Float32, 1, CuArray{Float32, 1}, CuArray{Float32, 1}, SteadyStateProblem{CuArray{Float32, 1}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#1"{CuArray{Int64, 1}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Base .\abstractarray.jl:975
  [8] copyto!
    @ .\abstractarray.jl:950 [inlined]
  [9] copyto_axcheck!
    @ .\abstractarray.jl:1056 [inlined]
 [10] Array
    @ .\array.jl:540 [inlined]
 [11] Array
    @ .\boot.jl:472 [inlined]
 [12] adjoint
    @ ~\.julia\packages\Zygote\6HN9x\src\lib\array.jl:8 [inlined]
 [13] _pullback(__context__::Zygote.Context, 486::Type{Array}, xs::SciMLBase.NonlinearSolution{Float32, 1, CuArray{Float32, 1}, CuArray{Float32, 1}, SteadyStateProblem{CuArray{Float32, 1}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#1"{CuArray{Int64, 1}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), 
Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57
 [14] _pullback
    @ e:\wqy\julia\simpleDEQ.jl:31 [inlined]
 [15] _pullback(ctx::Zygote.Context, f::typeof(solve_ss), args::CuArray{Int64, 1})
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [16] _pullback
    @ e:\wqy\julia\simpleDEQ.jl:42 [inlined]
 [17] _pullback(::Zygote.Context, ::typeof(loss), ::CuArray{Int64, 1}, ::CuArray{Int64, 1})
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [18] _apply
    @ .\boot.jl:804 [inlined]
 [19] adjoint
    @ ~\.julia\packages\Zygote\6HN9x\src\lib\lib.jl:191 [inlined]
 [20] _pullback
    @ ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57 [inlined]
 [21] _pullback
    @ ~\.julia\packages\Flux\qp1gc\src\optimise\train.jl:102 [inlined]
 [22] _pullback(::Zygote.Context, ::Flux.Optimise.var"#39#45"{typeof(loss), Tuple{CuArray{Int64, 1}, CuArray{Int64, 1}}})
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface2.jl:0
 [23] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:247
 [24] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\6HN9x\src\compiler\interface.jl:58
 [25] macro expansion
    @ ~\.julia\packages\Flux\qp1gc\src\optimise\train.jl:101 [inlined]
 [26] macro expansion
    @ ~\.julia\packages\Juno\n6wyj\src\progress.jl:134 [inlined]
 [27] train!(loss::Function, ps::Zygote.Params, data::Flux.Data.DataLoader{Tuple{CuArray{Int64, 1}, CuArray{Int64, 1}}, Random._GLOBAL_RNG}, opt::ADAM; cb::Flux.Optimise.var"#40#46")
    @ Flux.Optimise ~\.julia\packages\Flux\qp1gc\src\optimise\train.jl:99
 [28] train!(loss::Function, ps::Zygote.Params, data::Flux.Data.DataLoader{Tuple{CuArray{Int64, 1}, CuArray{Int64, 1}}, Random._GLOBAL_RNG}, opt::ADAM)
    @ Flux.Optimise ~\.julia\packages\Flux\qp1gc\src\optimise\train.jl:97
 [29] top-level scope
    @ e:\wqy\julia\simpleDEQ.jl:51
 [30] include(fname::String)
    @ Base.MainInclude .\client.jl:444
 [31] startdebug(socket::Base.PipeEndpoint, error_handler::VSCodeDebugger.var"#3#4"{Tuple{String, String}})
    @ VSCodeDebugger.DebugAdapter ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\DebugAdapter\src\packagedef.jl:91
 [32] startdebugger()
    @ VSCodeDebugger ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\packages\VSCodeDebugger\src\VSCodeDebugger.jl:38
 [33] top-level scope
    @ ~\.vscode\extensions\julialang.language-julia-1.1.38\scripts\debugger\run_debugger.jl:9
in expression starting at e:\wqy\julia\simpleDEQ.jl:50
@ChrisRackauckas
Member

Update your packages so you get SciML/DiffEqBase.jl#669 and this should be much closer:

using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using CUDA
CUDA.allowscalar(false)

# FastChain with initial_params should also work
# But it's more stable to use Chain and destructure
# ann = FastChain(
#   FastDense(1, 2, relu),
#   FastDense(2, 1, tanh))
# p1 = initial_params(ann)

ann = Chain(Dense(1, 2), Dense(2, 1)) |> gpu
p,re = Flux.destructure(ann)
tspan = (0.0f0, 1.0f0)

function solve_ss(x)
    xg = gpu(x)
    z = re(p)(xg) |> gpu
    function dudt_(u, _p, t)
        # Solving the equation f(u) - u = du = 0
        # Key question: Is there any difference between
        # re(_p)(x) and re(_p)(u+x)?
        re(_p)(u+xg) - u
    end
    ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
    x = solve(ss, DynamicSS(Tsit5()), u0 = z, abstol = 1e-2, reltol = 1e-2).u

    #=
    ss = NonlinearProblem(dudt_, gpu(z), p)
    x = solve(ss, NewtonRaphson(), u0 = z).u
    =#
end

# Let's run a DEQ model on linear regression for y = 2x
X = [1;2;3;4;5;6;7;8;9;10]
Y = [2;4;6;8;10;12;14;16;18;20]
data = Flux.Data.DataLoader(gpu.(collect.((X, Y))), batchsize=1, shuffle=true)
opt = ADAM(0.05)

function loss(x, y)
  ŷ = solve_ss(x)
  sum(abs2,y .- ŷ)
end

epochs = 100
for i in 1:epochs
    Flux.train!(loss, Flux.params(p), data, opt)
    println(solve_ss([-5])) # Print model prediction
end

@ChrisRackauckas changed the title from "scalar getindex is disallowed" to "SteadyStateAdjoint + GPUs: scalar getindex is disallowed" on May 24, 2021
@ChrisRackauckas
Member

@avik-pal or @DhairyaLGandhi if you could help dig in here that would be great.

@ChrisRackauckas
Member

Fixed with the new adjoints.
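For anyone hitting the same error: a minimal sketch of the failure mode, independent of the solver packages. With `CUDA.allowscalar(false)`, any code path that reads a `CuArray` one element at a time (such as `Array(sol)` falling back to generic iteration over the solution object) raises this error, whereas a single bulk copy is fine. The corrected code above avoids the iteration by taking `.u` from the solution directly. This is an illustrative sketch assuming CUDA.jl and a working GPU, not part of the original MWE:

```julia
using CUDA
CUDA.allowscalar(false)

u = cu(Float32[1.0, 2.0, 3.0])  # stand-in for sol.u, a CuArray on the device

# Scalar indexing errors under allowscalar(false):
# u[1]            # ERROR: scalar getindex is disallowed

# Bulk transfer is allowed: one device-to-host copy of the whole array.
host = Array(u)   # Vector{Float32}, no per-element reads
```

The general fix in user code is the same as in the corrected example: keep the data as a `CuArray` (or convert it to the host in one `Array(...)` call on the underlying array) instead of letting generic fallbacks iterate element by element.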
