Steady State GPU Training Hangs and Quits #567
See code above. Also, this issue might be related to #557. Although the fix in #557 fixed the code in that issue, I am using the same fix here and still get no output during training; the system simply quits after some time. Any help would be appreciated. Thanks!

Comments
@frankschae can you take a look? |
yup, I'll have a look tomorrow. @QiyaoWei Could you post the equivalent code that worked for you on CPU? |
So on my machine, I do get an output when the training is aborted (after a couple of hours though, so there is probably some computation that is just very slow on GPU):

```
┌ Warning: Interrupted. Larger maxiters is needed.
└ @ SciMLBase ~/.julia/packages/SciMLBase/Z1NtH/src/integrator_interface.jl:331
```

(and the |
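For what it's worth, the knob that warning points at can be raised directly; a hedged sketch (the value is arbitrary, and `ss` is assumed to be the SteadyStateProblem from the posted code):

```julia
using OrdinaryDiffEq, SteadyStateDiffEq

# Raise the iteration cap the warning refers to; solver and tolerances copied from the thread.
sol = solve(ss, DynamicSS(Tsit5()); abstol=1f-2, reltol=1f-2, maxiters=1_000_000)
```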
Allow me to respond to your question by arguing why the GPU implementation should in theory work. I will provide two working examples on GPU (non-SteadyState MNIST training and non-MNIST SteadyState training), the GPU implementation (which should replace the code above, since I have done some debugging on it), and the CPU implementation (which unfortunately hangs, i.e. shows the same problem as the GPU version). As you will notice, the current GPU implementation (third code segment) is simply a combination of the two working GPU examples (first and second code segments). That is why I am confused that the training hangs in both the GPU version (third code segment) and the CPU version (fourth code segment). Here is an example of non-SteadyState MNIST training that runs
Here is an example of non-MNIST SteadyState training that runs
Here is the new GPU implementation, with the problem that training hangs
Here is the CPU version, which also hangs
|
Make sure ReverseDiffVJP is avoided (ZygoteVJP is required) and the RHS function is out of place. If those two hold, then I think we're hitting some GPU bug, because last time I saw some stalls as well. |
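In case it helps to be explicit about that, here's a hedged sketch of forcing the Zygote VJP for the steady-state adjoint (keyword names from DiffEqSensitivity; whether this is needed here is an assumption, and `ss` again stands for the problem from the posted code):

```julia
using OrdinaryDiffEq, SteadyStateDiffEq, DiffEqSensitivity

# Request the Zygote vector-Jacobian product instead of ReverseDiffVJP for the adjoint.
sol = solve(ss, DynamicSS(Tsit5());
            abstol = 1f-2, reltol = 1f-2,
            sensealg = SteadyStateAdjoint(autojacvec = ZygoteVJP()))
```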
Could it be that it's just a batch size / data formatting error? I tried to bring the GPU steady-state version down to a much simpler model:

```julia
using Flux, CUDA, OrdinaryDiffEq, SteadyStateDiffEq

function LeNet5()
    down = Chain(
        Conv((5, 5), 1 => 1, relu, stride=3),
    ) |> f32
    deq = Chain(
        Conv((1, 1), 1 => 1, tanh, stride=1),
    ) |> f32
    p, re = Flux.destructure(deq)
    fc = Chain(
        Dense(10, 10) |> f32,
    ) |> f32
    tspan = (0.0f0, 1.0f0)

    function solve_ss(x)
        xg = gpu(x)
        z = re(p)(xg) |> gpu
        function dudt_(u, _p, t)
            # Solving the equation f(u) - u = du = 0
            u = convert(CuArray{Float32}, u)
            re(_p)(u + xg) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
        x = solve(ss, DynamicSS(Tsit5()), u0=z, abstol=1f-2, reltol=1f-2).u
    end

    # Build our overall model topology
    m = Chain(
        down,     # (28,28,1,4) -> (8,8,1,4)
        solve_ss, # (8,8,1,4) -> (8,8,1,4)
        fc,       # ..
    ) |> gpu
    return m, deq
end
```

and it looks like

```julia
sol = solve(ss, DynamicSS(Tsit5()), u0=gpu(z), abstol=1f-2, reltol=1f-2);
```

already takes very long (it is still running on my machine) with input u0 an 8x8x1x4 CuArray and p a 2-element CuArray... |
Where is all of the time spent? Can someone share an nvprof trace? |
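In case a concrete starting point helps with the profiling request, a hedged sketch using CUDA.jl's profiler hook (run the script under nvprof or Nsight to collect the external trace; `ss` is assumed as above):

```julia
using CUDA, OrdinaryDiffEq, SteadyStateDiffEq

# Only the wrapped region is recorded by the external profiler.
CUDA.@profile begin
    solve(ss, DynamicSS(Tsit5()); abstol=1f-2, reltol=1f-2)
end
```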
The default tspan is a Float64 Inf; with it typed as Inf32, this runs for me:

```julia
using Flux, CUDA, MLDatasets, OrdinaryDiffEq, SteadyStateDiffEq

function eval()
    device = gpu
    xtrain, _ = MLDatasets.MNIST.traindata(Float32)
    x = reshape(xtrain[:, :, 1], 28, 28, 1, 1) |> device

    deq = Chain(
        Conv((1, 1), 1 => 1, tanh, stride=1),
    ) |> f32 |> device
    p, re = Flux.destructure(deq)

    xg = device(x)
    z = re(p)(xg) |> device
    tspan = (0.0f0, 1.0f0)

    function dudt_(u, _p, t)
        # Solving the equation f(u) - u = du = 0
        du = re(_p)(u + xg) - u
        return du
    end

    ss = SteadyStateProblem(ODEProblem(dudt_, device(z), tspan, p))
    x = solve(ss, DynamicSS(Tsit5()), u0=device(z), abstol=1f-2, reltol=1f-2, tspan=Inf32).u
end
```

@ChrisRackauckas should this somehow be added in SteadyStateDiffEq.jl? |
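Regarding adding it to SteadyStateDiffEq.jl, a minimal sketch of the idea (an assumption about the shape of the fix, not actual library code): derive the terminal time from the state's element type instead of a hard-coded Float64 Inf.

```julia
# Hypothetical helper: an Inf whose type matches the state, e.g. Inf32 for Float32 arrays,
# so Float32 GPU states are not promoted to Float64.
typed_inf(u0) = convert(float(eltype(u0)), Inf)

# usage sketch with the example above:
# solve(ss, DynamicSS(Tsit5()), u0=device(z), abstol=1f-2, reltol=1f-2, tspan=typed_inf(z))
```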
Hey Frank! By "fix" do you mean this code fixes the Float32 vs Float64 problem, or the training problem? When I add the "tspan" part in my code, the training still seems to be unresponsive |
Yes, we'd need to track where that Inf is coming from but it makes sense to type it based on |
@QiyaoWei Unfortunately, I have only checked the Float32 vs Float64 problem so far. Did you try setting a smaller number for the maximum time instead of |
hmmm setting tspan to 1.0f0 does prevent training from hanging, but I receive an error after the first epoch
Not sure if that's because I changed the tspan parameter? |
Share the entire stacktrace. You shared the least helpful part of it 😝 |
Sadly the stacktrace in this case is just one line
I also tried searching for this error and found this (https://discourse.julialang.org/t/forward-differentiation-and-differential-equations/21002/2), which seems related to the tspan typecast previously mentioned |
It doesn't show anything below |
Actually, I think I previously set tspan to an incorrect value. When I add the line
I get the error
|
@frankschae shouldn't it be calculating the vjp instead of building the jacobian? |
I think that should be ok. In |
I see. I think we just need to be a bit more careful about the eltype of that Jacobian then? |
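A hedged one-liner for that eltype point (illustrative, not the DiffEqSensitivity source; `u0` is a stand-in for the steady state):

```julia
# `similar` keeps both the element type and the array type (e.g. CuArray{Float32}),
# so the Jacobian buffer matches the state instead of defaulting to Float64.
J = similar(u0, eltype(u0), length(u0), length(u0))
```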
@QiyaoWei
On CPU I get a broadcasting error related to a line after the computation of the Jacobian (which will be easy to fix, though). Can you update the example so that I can reproduce the Jacobian error? (It would probably be nice if it were a bit simpler than the original example, i.e. without the logging, with fewer packages, smaller batch sizes, and simpler NN structures.) |
Hi Frank! I have also been playing around with the tspan parameter, and I am getting the same errors (the CUDA kernel error and the broadcasting dimension mismatch, respectively). Could you elaborate on what you mean by an "example reproducing the Jacobian error"? I could certainly cut out the logging and extra packages |
I am a bit confused. Maybe I understood you wrong. I thought you got this error:
.. which I couldn't reproduce yet. |
Ah, I see what you are referring to. Actually, that error went away after I upgraded my packages and incorporated your "tspan" typecast commit. I don't think the ForwardDiff error is a problem anymore. I cleaned up the code a bit and got rid of some logging code & packages. I also simplified the model declaration from convolutional to fully connected. I chose to stick with the MNIST dataset because simplifying further might conceal the problem; for example, non-SteadyState MNIST training and non-MNIST SteadyState training both already work on GPU. To reiterate, here's the CUDA kernel error in the GPU code
And here's the broadcast error in the CPU version
|
Hey Frank! Thanks a lot for the fix in DiffEqSensitivity #419! I updated my packages and restarted Julia, but was surprised to find that running on CPU hits the same problem as above: DimensionMismatch("array could not be broadcast to match destination"). Could you kindly share the code that ran successfully on CPU for you? Thanks a lot! |
I used your code :) We need a new tag on DiffEqSensitivity @ChrisRackauckas . |
tagged |
Never mind. DiffEqSensitivity somehow decided not to update automatically :) I don't think this fix solves the GPU problem, however. I encounter the same error as before (the GPUCompiler KernelError), so I guess it has to do with something else |
Yeah so what's the current state of this? |
So Frank's fix in DiffEqSensitivity #419 solved the problem from my previous post, and SteadyState training with a fully connected architecture is now possible on either CPU or GPU. However, there is still a problem with SteadyState applied to conv layers, which I am trying to figure out.
|
Try and narrow this down to something simpler. It should be possible with just |
So if I understood you correctly, I constructed the following code, which gives a similar warning
It looks like there's some problem with constructing a full Jacobian on a conv layer |
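For reference, a hedged sketch of that kind of MWE (the layer, sizes, and the use of Zygote.forward_jacobian are illustrative assumptions, not the posted code):

```julia
using Flux, CUDA, Zygote

conv = Conv((1, 1), 1 => 1, tanh) |> gpu     # tiny conv layer on the GPU
x = CUDA.rand(Float32, 8, 8, 1, 1)           # one 8x8 single-channel input

# Flatten the input and output so a dense Jacobian is well defined.
f(v) = vec(conv(reshape(v, 8, 8, 1, 1)))

# Building the full 64x64 Jacobian is the step the warning points at.
y, J = Zygote.forward_jacobian(f, vec(x))
```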
Yeah, that catches it. Your MWE probably fails on GPUs. It looks like what's happening is that if the types of the weights and inputs don't match, and aren't simple floating-point numbers, it falls back to a code path that is written only for CPU, and it fails in that function. This comes up because the backpass of an implicit solve needs to solve a linear system. So it looks like we have to either:
|
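For context, the algebra behind "needs to solve a linear system" (standard steady-state adjoint reasoning, not quoted from the thread): with the steady state defined by f(u*, p) = 0 and a loss g(u*, p), the implicit function theorem gives du*/dp = -(∂f/∂u)^(-1) ∂f/∂p, so the backpass solves (∂f/∂u)ᵀ λ = (∂g/∂u)ᵀ and then accumulates dg/dp = ∂g/∂p - λᵀ ∂f/∂p. The Jacobian ∂f/∂u (or products with it) is exactly the object that has to be built or applied on the GPU for the conv case.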
Yes that makes sense. Also, my intuition is that Zygote.forward_jacobian will not work, because this warning happens
@frankschae what would be the best way to fix this problem? |
The issue in DiffEqSensitivity is that there's a conflation of the sensealgs in there: https://github.com/SciML/DiffEqSensitivity.jl/blob/v6.48.0/src/steadystate_adjoint.jl#L66. autojacvec is being used in two ways, while here we would want to change only the forward-mode calculation to numerical while keeping the Zygote reverse. @frankschae we should talk about resolving this, and see if we're doing too much of AbstractDifferentiation.jl by hand in doing so. |
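To make the two roles concrete, a hedged sketch (the residual f, parameters, state, and adjoint seed below are stand-ins, and this is illustrative structure, not the steadystate_adjoint.jl source):

```julia
using Zygote, FiniteDiff

# Stand-ins for the DEQ residual and the adjoint seed (assumptions for illustration).
f(u, p, t) = tanh.(p[1] .* u .+ p[2]) .- u
p  = [0.5, 0.1]
u0 = randn(4)
λ  = ones(4)

# Role 1: the reverse-mode VJP of the residual -- the part we want to keep on Zygote.
_, back = Zygote.pullback(u -> f(u, p, nothing), u0)
vjp = back(λ)[1]

# Role 2: the Jacobian ∂f/∂u needed for the linear solve in the backpass -- the part
# that could switch to a numerical/forward-mode computation instead.
J = FiniteDiff.finite_difference_jacobian(u -> f(u, p, nothing), u0)
```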
Hi all, in the process of looking at this issue I found this thread (FluxML/Flux.jl#896), might it be an alternative solution to our problem? |
Just as a pointer to anyone stumbling onto this issue. Pass |
We should update the auto heuristic for that |
Thanks @avik-pal ! That was exactly the fix this issue needed! There are some minor issues with my personal GPU right now, but the code runs without any warnings now! I'll get back to you on this after I manage to fix the GPU :) Also @ChrisRackauckas do you mean updating the adjoints documentation page? |
Try SciML/SciMLSensitivity.jl#497. That should make this all work by default. |
I see. I am closing this issue for now with functional GPU-based convolution-DEQ implementations. Please reopen if there are any further questions!
|