SteadyStateAdjoint + GPUs: scalar getindex is disallowed #557
Comments
Update your packages so you get SciML/DiffEqBase.jl#669 and this should be much closer:

    using Flux
    using DiffEqSensitivity
    using SteadyStateDiffEq
    using DiffEqFlux
    using OrdinaryDiffEq
    using CUDA
    CUDA.allowscalar(false)

    # FastChain with initial_params should also work
    # But it's more stable to use Chain and destructure
    # ann = FastChain(
    #   FastDense(1, 2, relu),
    #   FastDense(2, 1, tanh))
    # p1 = initial_params(ann)

    ann = Chain(Dense(1, 2), Dense(2, 1)) |> gpu
    p, re = Flux.destructure(ann)
    tspan = (0.0f0, 1.0f0)

    function solve_ss(x)
        xg = gpu(x)
        z = re(p)(xg) |> gpu
        function dudt_(u, _p, t)
            # Solving the equation f(u) - u = du = 0
            # Key question: Is there any difference between
            # re(_p)(x) and re(_p)(u+x)?
            re(_p)(u + xg) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
        x = solve(ss, DynamicSS(Tsit5()), u0 = z, abstol = 1e-2, reltol = 1e-2).u
        #=
        ss = NonlinearProblem(dudt_, gpu(z), p)
        x = solve(ss, NewtonRaphson(), u0 = z).u
        =#
    end

    # Let's run a DEQ model on linear regression for y = 2x
    X = [1;2;3;4;5;6;7;8;9;10]
    Y = [2;4;6;8;10;12;14;16;18;20]
    data = Flux.Data.DataLoader(gpu.(collect.((X, Y))), batchsize=1, shuffle=true)
    opt = ADAM(0.05)

    function loss(x, y)
        ŷ = solve_ss(x)
        sum(abs2, y .- ŷ)
    end

    epochs = 100
    for i in 1:epochs
        Flux.train!(loss, Flux.params(p), data, opt)
        println(solve_ss([-5])) # Print model prediction
    end
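For readers unfamiliar with the steady-state formulation in `dudt_`, here is a minimal CPU-only sketch of the condition it encodes, f(u) - u = 0. It is not the code from this thread: the scalar function `f`, its weight `w`, and the helper `fixed_point` are hypothetical stand-ins, and it uses plain fixed-point iteration rather than the `DynamicSS(Tsit5())` time integration used above.

```julia
# Hypothetical scalar stand-in for the network: f(u, x) = tanh(w * (u + x)).
# With |w| < 1 this map is a contraction, so the iteration below converges.
f(u, x; w = 0.5) = tanh(w * (u + x))

# Iterate u <- f(u, x) until successive iterates differ by less than tol.
function fixed_point(f, x; u0 = 0.0, tol = 1e-8, maxiter = 1_000)
    u = u0
    for _ in 1:maxiter
        u_next = f(u, x)
        abs(u_next - u) < tol && return u_next
        u = u_next
    end
    return u
end

u_star = fixed_point(f, 1.0)
residual = f(u_star, 1.0) - u_star   # plays the role of du in dudt_
println(abs(residual) < 1e-6)
```

At the returned `u_star` the residual is approximately zero, which is the same condition the steady-state solver drives `du` toward.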
ChrisRackauckas changed the title from "scalar getindex is disallowed" to "SteadyStateAdjoint + GPUs: scalar getindex is disallowed" on May 24, 2021
@avik-pal or @DhairyaLGandhi if you could help dig in here that would be great.
Fixed with the new adjoints.
Hi all,
I am trying to use a GPU with a SteadyStateProblem (to implement Deep Equilibrium Models, to be more specific) and run into a "scalar getindex is disallowed" error. Any help would be appreciated! Thanks a lot!
MWE:
Error: