Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Steady State GPU Training Hangs and Quits #567

Closed
QiyaoWei opened this issue May 31, 2021 · 41 comments
Closed

Steady State GPU Training Hangs and Quits #567

QiyaoWei opened this issue May 31, 2021 · 41 comments

Comments

@QiyaoWei
Copy link
Contributor

## Classification of MNIST dataset 
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data:DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold
using Flux.Losses:logitcrossentropy
using Statistics, Random
using Logging:with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter:@showprogress
import MLDatasets
import BSON
using CUDA
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
CUDA.allowscalar(false)

# LeNet5 "constructor". 
# The model can be adapted to any image size
# and any number of output classes.
# function LeNet5(; imgsize=(28,28,1), nclasses=10) 
#   out_conv_size = (imgsize[1]÷4 - 3, imgsize[2]÷4 - 3, 16)
  
#   return Chain(
#           Conv((5, 5), imgsize[end]=>6, relu),
#           MaxPool((2, 2)),
#           Conv((5, 5), 6=>16, relu),
#           MaxPool((2, 2)),
#           flatten,
#           Dense(prod(out_conv_size), 120, relu), 
#           Dense(120, 84, relu), 
#           Dense(84, nclasses)
#         )
# end
# Build the MNIST classifier: a convolutional downsampling stem (`down`), a
# steady-state layer (`solve_ss`) and a classification head (`fc`), chained
# together and moved to the GPU.
#
# NOTE(review): `p` is extracted from `deq` before the final `|> gpu`, and
# `solve_ss` is a plain closure. Confirm that `p` lives on the same device as
# the activations and that it is part of the parameters being optimised.
function LeNet5()
    down = Chain(
        Conv((3, 3), 1 => 64, relu, stride=1) |> f32,
        GroupNorm(64, 64) |> f32,
        Conv((4, 4), 64 => 64, relu, stride=2, pad=1) |> f32,
        GroupNorm(64, 64) |> f32,
        Conv((4, 4), 64 => 64, stride=2, pad=1) |> f32,
    ) |> f32

    # Network whose fixed point defines the steady-state layer.
    deq = Chain(
        Conv((3, 3), 64 => 64, relu, stride=1, pad=1) |> f32,
        Conv((3, 3), 64 => 64, relu, stride=1, pad=1) |> f32,
    ) |> f32

    # `p` is the flat parameter vector of `deq`; `re` rebuilds `deq` from one.
    p, re = Flux.destructure(deq)

    fc = Chain(
        GroupNorm(64, 64) |> f32,
        x -> relu.(x) |> f32,
        MeanPool((6, 6)) |> f32,
        # `:` lets `reshape` infer the batch size from the input, so no
        # externally defined batch-size variable is needed.
        x -> reshape(x, (64, :)) |> f32,
        Dense(64, 10) |> f32,
    ) |> f32

    tspan = (0.0f0, 1.0f0)

    # Steady-state layer: starting from z = deq(x), integrate
    # du/dt = deq(u + x) - u until it reaches a steady state and return that state.
    function solve_ss(x)
        xg = gpu(x)
        z = re(p)(xg) |> gpu
        function dudt_(u, _p, t)
            # Solving the equation f(u) - u = du = 0
            u = convert(CuArray{Float32}, u)
            return re(_p)(u + xg) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
        return solve(ss, DynamicSS(Tsit5()), u0=z, abstol=Float32(1e-5), reltol=Float32(1e-5)).u
    end

    # Build our over-all model topology
    m = Chain(
        down,               # (28,28,1,BS) -> (6,6,64,BS)
        solve_ss,           # (6,6,64,BS) -> (6,6,64,BS)
        fc,                 # (6,6,64,BS) -> (10, BS)
    ) |> gpu

    return m
end

# Load MNIST and wrap it in mini-batch loaders.
# Images are reshaped to 28x28x1xN and labels are one-hot encoded over 0:9.
# Returns `(train_loader, test_loader)`; both use `args.batchsize`, and only the
# training loader is shuffled.
function get_data(args)
    classes = 0:9

    train_images, train_labels = MLDatasets.MNIST.traindata(Float32)
    test_images, test_labels = MLDatasets.MNIST.testdata(Float32)

    # Insert a singleton channel dimension in front of the sample dimension.
    train_x = reshape(train_images, 28, 28, 1, :)
    test_x = reshape(test_images, 28, 28, 1, :)

    train_y = onehotbatch(train_labels, classes)
    test_y = onehotbatch(test_labels, classes)

    train_loader = DataLoader((train_x, train_y); batchsize=args.batchsize, shuffle=true)
    test_loader = DataLoader((test_x, test_y); batchsize=args.batchsize)
    return train_loader, test_loader
end

# Training objective: cross-entropy computed on the raw model outputs (logits) `ŷ`
# against the targets `y`.
function loss(ŷ, y)
    return logitcrossentropy(ŷ, y)
end

# Evaluate `model` over every batch of `loader`, moving each batch with `device`.
# Returns a named tuple `(loss, acc)`: the sample-weighted mean loss and the
# accuracy in percent, both rounded with `round4`.
function eval_loss_accuracy(loader, model, device)
    loss_sum = 0.0f0
    n_correct = 0
    n_seen = 0
    for (xb, yb) in loader
        inputs = device(xb)
        targets = device(yb)
        batch_n = size(inputs)[end]

        outputs = model(inputs)
        # Weight the batch loss by the batch size so the final mean is per sample.
        loss_sum += loss(outputs, targets) * batch_n
        # Compare predicted and true class indices on the CPU.
        n_correct += sum(onecold(cpu(outputs)) .== onecold(cpu(targets)))
        n_seen += batch_n
    end
    mean_loss = round4(loss_sum / n_seen)
    accuracy = round4(n_correct / n_seen * 100)
    return (loss = mean_loss, acc = accuracy)
end

## utility functions
# Total number of scalar entries across all arrays in `Flux.params(model)`.
function num_params(model)
    return sum(length, Flux.params(model))
end

# Round `x` to four decimal digits.
function round4(x)
    return round(x; digits=4)
end

# arguments for the `train` function 
# Hyper-parameters and run-time options consumed by `train`.
# `Base.@kwdef` generates a keyword constructor, so any field can be overridden
# by name (e.g. `Args(epochs = 2)`); `train(; kws...)` forwards its keywords here.
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate passed to ADAM
    λ = 0                # L2 regularizer param, implemented as weight decay (only applied when > 0)
    batchsize = 128      # batch size used by both data loaders
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility (the RNG is only seeded when > 0)
    use_cuda = true      # if true use cuda (if available)
    infotime = 1         # report every `infotime` epochs
    checktime = 5        # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    tblogger = true      # log training with tensorboard
    savepath = "runs/"    # results path (TensorBoard logs and `model.bson` checkpoints)
end

# Train the `LeNet5` model on MNIST.
# Keyword arguments are forwarded to `Args`. The model is evaluated once before
# training and then every `args.infotime` epochs; metrics are optionally sent to
# TensorBoard, and a checkpoint is written to `<savepath>/model.bson` every
# `args.checktime` epochs (when `checktime > 0`).
function train(; kws...)
    args = Args(; kws...)
    if args.seed > 0
        Random.seed!(args.seed)
    end

    # Use the GPU only when it is requested and CUDA is functional.
    use_cuda = args.use_cuda && CUDA.functional()
    device, device_name = use_cuda ? (gpu, "GPU") : (cpu, "CPU")
    @info "Training on $(device_name)"

    ## DATA
    train_loader, test_loader = get_data(args)
    n_train, n_test = train_loader.nobs, test_loader.nobs
    @info "Dataset MNIST: $(n_train) train and $(n_test) test examples"

    ## MODEL AND OPTIMIZER
    model = device(LeNet5())
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    # Weight decay (equivalent to L2 regularization) is chained in only for λ > 0.
    base_opt = ADAM(args.η)
    opt = args.λ > 0 ? Optimiser(base_opt, WeightDecay(args.λ)) : base_opt

    ## LOGGING UTILITIES
    if args.tblogger
        tblogger = TBLogger(args.savepath, tb_overwrite)
        # No auto increment: the step is set explicitly in `report`.
        set_step_increment!(tblogger, 0)
        @info "TensorBoard logging at \"$(args.savepath)\""
    end

    # Evaluate on both splits, print the result and mirror it to TensorBoard.
    function report(epoch)
        train_stats = eval_loss_accuracy(train_loader, model, device)
        test_stats = eval_loss_accuracy(test_loader, model, device)
        println("Epoch: $epoch   Train: $(train_stats)   Test: $(test_stats)")
        if args.tblogger
            set_step!(tblogger, epoch)
            with_logger(tblogger) do
                @info "train" loss = train_stats.loss acc = train_stats.acc
                @info "test" loss = test_stats.loss acc = test_stats.acc
            end
        end
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch in 1:args.epochs
        @showprogress for (x, y) in train_loader
            xb, yb = device(x), device(y)
            gs = Flux.gradient(() -> loss(model(xb), yb), ps)
            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        if epoch % args.infotime == 0
            report(epoch)
        end

        if args.checktime > 0 && epoch % args.checktime == 0
            ispath(args.savepath) || mkpath(args.savepath)
            modelpath = joinpath(args.savepath, "model.bson")
            # Serialise a CPU copy; the names `model` and `epoch` become the saved keys.
            let model = cpu(model)
                BSON.@save modelpath model epoch
            end
            @info "Model saved in \"$(modelpath)\""
        end
    end
    return nothing
end

# Script entry point: run training with the default `Args`.
train()

See code above. Also, this issue might be related to #557. Although the fix in #557 fixed the code in that issue, I am using the same fix here, but having no output in training and the system simply quits after some time. Any help would be appreciated. Thanks!

@ChrisRackauckas
Copy link
Member

@frankschae can you take a look?

@frankschae
Copy link
Member

yup, I'll have a look tomorrow. @QiyaoWei Could you post the equivalent code that worked for you on CPU?

@frankschae
Copy link
Member

So on my machine, I do get an output when the training is aborted (only after a couple of hours, though, so there is probably some computation that is just very slow on the GPU).

┌ Warning: Interrupted. Larger maxiters is needed.
└ @ SciMLBase ~/.julia/packages/SciMLBase/Z1NtH/src/integrator_interface.jl:331

(and the use_cuda = false version aborts because it's not set fully consistently, so some portions are actually copied to the GPU.)

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 1, 2021

Allow me to respond to your question by arguing why the GPU implementation should in theory work. I will provide two working examples on GPU (non-SteadyState MNIST training and non-MNIST SteadyState training), the GPU implementation (which should replace the code above because I have done some debugging), and the CPU implementation (which unfortunately hangs, i.e. the same problem as the GPU version). As you will notice, the current GPU implementation (third code segment) is simply a combination of the two working GPU examples (first and second code segment). That is why I am confused why the training hangs in both the GPU version (third code segment) and the CPU version (fourth code segment).

Here is an example of non-SteadyState MNIST training that runs

## Classification of MNIST dataset 
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data:DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold
using Flux.Losses:logitcrossentropy
using Statistics, Random
using Logging:with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter:@showprogress
import MLDatasets
import BSON
using CUDA
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
CUDA.allowscalar(false)

# LeNet5 "constructor". 
# The model can be adapted to any image size
# and any number of output classes.
# function LeNet5(; imgsize=(28,28,1), nclasses=10) 
#   out_conv_size = (imgsize[1]÷4 - 3, imgsize[2]÷4 - 3, 16)
  
#   return Chain(
#           Conv((5, 5), imgsize[end]=>6, relu),
#           MaxPool((2, 2)),
#           Conv((5, 5), 6=>16, relu),
#           MaxPool((2, 2)),
#           flatten,
#           Dense(prod(out_conv_size), 120, relu), 
#           Dense(120, 84, relu), 
#           Dense(84, nclasses)
#         )
# end

# Plain feed-forward variant of the model (no steady-state solve): the `deq`
# convolutions are applied once, as an ordinary stage between `down` and `fc`.
function LeNet5()
    down = Chain(
        Conv((3, 3), 1 => 64, relu; stride=1),
        GroupNorm(64, 64),
        Conv((4, 4), 64 => 64, relu; stride=2, pad=1),
        GroupNorm(64, 64),
        Conv((4, 4), 64 => 64; stride=2, pad=1),
    )

    deq = Chain(
        Conv((3, 3), 64 => 64, tanh; stride=1, pad=1),
        Conv((3, 3), 64 => 64, tanh; stride=1, pad=1),
    )

    fc = Chain(
        GroupNorm(64, 64),
        a -> relu.(a),
        MeanPool((6, 6)),
        # Flatten to (features, batch); `:` infers the batch size.
        a -> reshape(a, (64, :)),
        Dense(64, 10),
    )

    # Build our over-all model topology
    return Chain(
        down,   # (28,28,1,BS) -> (6,6,64,BS)
        deq,    # (6,6,64,BS) -> (6,6,64,BS)
        fc,     # (6,6,64,BS) -> (10, BS)
    )
end

# Load MNIST and wrap it in mini-batch loaders.
# Images are reshaped to 28x28x1xN and labels are one-hot encoded over 0:9.
# Returns `(train_loader, test_loader)`; both use `args.batchsize`, and only the
# training loader is shuffled.
function get_data(args)
    classes = 0:9

    train_images, train_labels = MLDatasets.MNIST.traindata(Float32)
    test_images, test_labels = MLDatasets.MNIST.testdata(Float32)

    # Insert a singleton channel dimension in front of the sample dimension.
    train_x = reshape(train_images, 28, 28, 1, :)
    test_x = reshape(test_images, 28, 28, 1, :)

    train_y = onehotbatch(train_labels, classes)
    test_y = onehotbatch(test_labels, classes)

    train_loader = DataLoader((train_x, train_y); batchsize=args.batchsize, shuffle=true)
    test_loader = DataLoader((test_x, test_y); batchsize=args.batchsize)
    return train_loader, test_loader
end

# Training objective: cross-entropy computed on the raw model outputs (logits) `ŷ`
# against the targets `y`.
function loss(ŷ, y)
    return logitcrossentropy(ŷ, y)
end

# Evaluate `model` over every batch of `loader`, moving each batch with `device`.
# Returns a named tuple `(loss, acc)`: the sample-weighted mean loss and the
# accuracy in percent, both rounded with `round4`.
function eval_loss_accuracy(loader, model, device)
    loss_sum = 0.0f0
    n_correct = 0
    n_seen = 0
    for (xb, yb) in loader
        inputs = device(xb)
        targets = device(yb)
        batch_n = size(inputs)[end]

        outputs = model(inputs)
        # Weight the batch loss by the batch size so the final mean is per sample.
        loss_sum += loss(outputs, targets) * batch_n
        # Compare predicted and true class indices on the CPU.
        n_correct += sum(onecold(cpu(outputs)) .== onecold(cpu(targets)))
        n_seen += batch_n
    end
    mean_loss = round4(loss_sum / n_seen)
    accuracy = round4(n_correct / n_seen * 100)
    return (loss = mean_loss, acc = accuracy)
end

## utility functions
# Total number of scalar entries across all arrays in `Flux.params(model)`.
function num_params(model)
    return sum(length, Flux.params(model))
end

# Round `x` to four decimal digits.
function round4(x)
    return round(x; digits=4)
end

# arguments for the `train` function 
# Hyper-parameters and run-time options consumed by `train`.
# `Base.@kwdef` generates a keyword constructor, so any field can be overridden
# by name (e.g. `Args(epochs = 2)`); `train(; kws...)` forwards its keywords here.
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate passed to ADAM
    λ = 0                # L2 regularizer param, implemented as weight decay (only applied when > 0)
    batchsize = 128      # batch size used by both data loaders
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility (the RNG is only seeded when > 0)
    use_cuda = true      # if true use cuda (if available)
    infotime = 1         # report every `infotime` epochs
    checktime = 5        # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    tblogger = true      # log training with tensorboard
    savepath = "runs/"    # results path (TensorBoard logs and `model.bson` checkpoints)
end

# Train the `LeNet5` model on MNIST.
# Keyword arguments are forwarded to `Args`. The model is evaluated once before
# training and then every `args.infotime` epochs; metrics are optionally sent to
# TensorBoard, and a checkpoint is written to `<savepath>/model.bson` every
# `args.checktime` epochs (when `checktime > 0`).
function train(; kws...)
    args = Args(; kws...)
    if args.seed > 0
        Random.seed!(args.seed)
    end

    # Use the GPU only when it is requested and CUDA is functional.
    use_cuda = args.use_cuda && CUDA.functional()
    device, device_name = use_cuda ? (gpu, "GPU") : (cpu, "CPU")
    @info "Training on $(device_name)"

    ## DATA
    train_loader, test_loader = get_data(args)
    n_train, n_test = train_loader.nobs, test_loader.nobs
    @info "Dataset MNIST: $(n_train) train and $(n_test) test examples"

    ## MODEL AND OPTIMIZER
    model = device(LeNet5())
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    # Weight decay (equivalent to L2 regularization) is chained in only for λ > 0.
    base_opt = ADAM(args.η)
    opt = args.λ > 0 ? Optimiser(base_opt, WeightDecay(args.λ)) : base_opt

    ## LOGGING UTILITIES
    if args.tblogger
        tblogger = TBLogger(args.savepath, tb_overwrite)
        # No auto increment: the step is set explicitly in `report`.
        set_step_increment!(tblogger, 0)
        @info "TensorBoard logging at \"$(args.savepath)\""
    end

    # Evaluate on both splits, print the result and mirror it to TensorBoard.
    function report(epoch)
        train_stats = eval_loss_accuracy(train_loader, model, device)
        test_stats = eval_loss_accuracy(test_loader, model, device)
        println("Epoch: $epoch   Train: $(train_stats)   Test: $(test_stats)")
        if args.tblogger
            set_step!(tblogger, epoch)
            with_logger(tblogger) do
                @info "train" loss = train_stats.loss acc = train_stats.acc
                @info "test" loss = test_stats.loss acc = test_stats.acc
            end
        end
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch in 1:args.epochs
        @showprogress for (x, y) in train_loader
            xb, yb = device(x), device(y)
            gs = Flux.gradient(() -> loss(model(xb), yb), ps)
            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        if epoch % args.infotime == 0
            report(epoch)
        end

        if args.checktime > 0 && epoch % args.checktime == 0
            ispath(args.savepath) || mkpath(args.savepath)
            modelpath = joinpath(args.savepath, "model.bson")
            # Serialise a CPU copy; the names `model` and `epoch` become the saved keys.
            let model = cpu(model)
                BSON.@save modelpath model epoch
            end
            @info "Model saved in \"$(modelpath)\""
        end
    end
    return nothing
end

# Script entry point: run training with the default `Args`.
train()

Here is an example of non-MNIST SteadyState training that runs

using Flux
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
using CUDA
CUDA.allowscalar(false)

# FastChain with initial_params should also work
# But it's more stable to use Chain and destructure
# ann = FastChain(
#   FastDense(1, 2, relu),
#   FastDense(2, 1, tanh))
# p1 = initial_params(ann)

# Small two-layer dense network, moved to the GPU.
ann = gpu(Chain(Dense(1, 2), Dense(2, 1)))
# `p` is the flat parameter vector of `ann`; `re` rebuilds the network from one.
p, re = Flux.destructure(ann)
# Time span handed to the ODEProblem wrapped by the steady-state problem.
tspan = (0.0f0, 1.0f0)

# Steady-state layer: starting from `ann(x)`, integrate du/dt = ann(u + x) - u
# until it reaches a steady state and return that state.
function solve_ss(x)
    x_dev = gpu(x)
    u_init = gpu(re(p)(x_dev))

    # Solving the equation f(u) - u = du = 0
    # Key question: Is there any difference between
    # re(_p)(x) and re(_p)(u+x)?
    residual(u, θ, t) = re(θ)(u + x_dev) - u

    problem = SteadyStateProblem(ODEProblem(residual, gpu(u_init), tspan, p))
    sol = solve(problem, DynamicSS(Tsit5()); u0=u_init, abstol=1e-5, reltol=1e-5)

    # Alternative formulation that is left disabled:
    #   NonlinearProblem(residual, gpu(u_init), p) solved with NewtonRaphson()
    return sol.u
end

# Let's run a DEQ model on linear regression for y = 2x
# Inputs 1..10 and targets y = 2x.
X = collect(1:10)
Y = 2 .* X
# One sample per batch, reshuffled on every pass; both arrays are copied to the GPU.
data = Flux.Data.DataLoader((gpu(collect(X)), gpu(collect(Y))); batchsize=1, shuffle=true)
# CPU variant: Flux.Data.DataLoader(collect.((X, Y)), batchsize=1,shuffle=true)
opt = ADAM(0.05)

# Sum of squared differences between the target `y` and the steady state
# produced by `solve_ss(x)`.
loss(x, y) = sum(abs2, y .- solve_ss(x))

epochs = 100
for _ in 1:epochs
    # One pass over `data`, updating the flat parameter vector `p`.
    Flux.train!(loss, Flux.params(p), data, opt)
    # Print model prediction
    prediction = solve_ss([-5])
    println(prediction)
end

Here is the new GPU implementation, with the problem that training hangs

## Classification of MNIST dataset 
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data:DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold
using Flux.Losses:logitcrossentropy
using Statistics, Random
using Logging:with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter:@showprogress
import MLDatasets
import BSON
using CUDA
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
CUDA.allowscalar(false)

# LeNet5 "constructor". 
# The model can be adapted to any image size
# and any number of output classes.
# function LeNet5(; imgsize=(28,28,1), nclasses=10) 
#   out_conv_size = (imgsize[1]÷4 - 3, imgsize[2]÷4 - 3, 16)
  
#   return Chain(
#           Conv((5, 5), imgsize[end]=>6, relu),
#           MaxPool((2, 2)),
#           Conv((5, 5), 6=>16, relu),
#           MaxPool((2, 2)),
#           flatten,
#           Dense(prod(out_conv_size), 120, relu), 
#           Dense(120, 84, relu), 
#           Dense(84, nclasses)
#         )
# end

# Build the MNIST classifier: a convolutional downsampling stem (`down`), a
# steady-state layer (`solve_ss`) and a classification head (`fc`).
#
# NOTE(review): `p` is extracted from `deq` on the CPU and `solve_ss` is a
# plain closure. Confirm that `p` ends up on the same device as the
# activations and that it is part of the parameters being optimised.
function LeNet5()
    down = Chain(Conv((3, 3), 1=>64, relu, stride = 1), GroupNorm(64, 64),
                 Conv((4, 4), 64=>64, relu, stride = 2, pad=1), GroupNorm(64, 64),
                 Conv((4, 4), 64=>64, stride = 2, pad = 1))
    # Network whose fixed point defines the steady-state layer.
    deq = Chain(Conv((3, 3), 64=>64, tanh, stride=1, pad=1),
                Conv((3, 3), 64=>64, tanh, stride=1, pad=1))
    # `p` is the flat parameter vector of `deq`; `re` rebuilds `deq` from one.
    p, re = Flux.destructure(deq)
    fc = Chain(GroupNorm(64, 64), x -> relu.(x), MeanPool((6, 6)),
               x -> reshape(x, (64, :)), Dense(64,10))

    tspan = (0.0f0, 1.0f0)

    # Steady-state layer: starting from z = deq(x), integrate
    # du/dt = deq(u + x) - u until it reaches a steady state and return that state.
    function solve_ss(x)
        xg = gpu(x)
        z = re(p)(xg) |> gpu
        function dudt_(u, _p, t)
            # Solving the equation f(u) - u = du = 0
            u = convert(CuArray{Float32}, u)
            return re(_p)(u + xg) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
        return solve(ss, DynamicSS(Tsit5()), u0=z, abstol=Float32(1e-5), reltol=Float32(1e-5)).u
    end  # closes `solve_ss`; without it `LeNet5` itself is never terminated

    # Build our over-all model topology
    m = Chain(
        down,               # (28,28,1,BS) -> (6,6,64,BS)
        solve_ss,           # (6,6,64,BS) -> (6,6,64,BS)
        fc,                 # (6,6,64,BS) -> (10, BS)
    )

    return m
end

# Load MNIST and wrap it in mini-batch loaders.
# Images are reshaped to 28x28x1xN and labels are one-hot encoded over 0:9.
# Returns `(train_loader, test_loader)`; both use `args.batchsize`, and only the
# training loader is shuffled.
function get_data(args)
    classes = 0:9

    train_images, train_labels = MLDatasets.MNIST.traindata(Float32)
    test_images, test_labels = MLDatasets.MNIST.testdata(Float32)

    # Insert a singleton channel dimension in front of the sample dimension.
    train_x = reshape(train_images, 28, 28, 1, :)
    test_x = reshape(test_images, 28, 28, 1, :)

    train_y = onehotbatch(train_labels, classes)
    test_y = onehotbatch(test_labels, classes)

    train_loader = DataLoader((train_x, train_y); batchsize=args.batchsize, shuffle=true)
    test_loader = DataLoader((test_x, test_y); batchsize=args.batchsize)
    return train_loader, test_loader
end

# Training objective: cross-entropy computed on the raw model outputs (logits) `ŷ`
# against the targets `y`.
function loss(ŷ, y)
    return logitcrossentropy(ŷ, y)
end

# Evaluate `model` over every batch of `loader`, moving each batch with `device`.
# Returns a named tuple `(loss, acc)`: the sample-weighted mean loss and the
# accuracy in percent, both rounded with `round4`.
function eval_loss_accuracy(loader, model, device)
    loss_sum = 0.0f0
    n_correct = 0
    n_seen = 0
    for (xb, yb) in loader
        inputs = device(xb)
        targets = device(yb)
        batch_n = size(inputs)[end]

        outputs = model(inputs)
        # Weight the batch loss by the batch size so the final mean is per sample.
        loss_sum += loss(outputs, targets) * batch_n
        # Compare predicted and true class indices on the CPU.
        n_correct += sum(onecold(cpu(outputs)) .== onecold(cpu(targets)))
        n_seen += batch_n
    end
    mean_loss = round4(loss_sum / n_seen)
    accuracy = round4(n_correct / n_seen * 100)
    return (loss = mean_loss, acc = accuracy)
end

## utility functions
# Total number of scalar entries across all arrays in `Flux.params(model)`.
function num_params(model)
    return sum(length, Flux.params(model))
end

# Round `x` to four decimal digits.
function round4(x)
    return round(x; digits=4)
end

# arguments for the `train` function 
# Hyper-parameters and run-time options consumed by `train`.
# `Base.@kwdef` generates a keyword constructor, so any field can be overridden
# by name (e.g. `Args(epochs = 2)`); `train(; kws...)` forwards its keywords here.
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate passed to ADAM
    λ = 0                # L2 regularizer param, implemented as weight decay (only applied when > 0)
    batchsize = 128      # batch size used by both data loaders
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility (the RNG is only seeded when > 0)
    use_cuda = true      # if true use cuda (if available)
    infotime = 1         # report every `infotime` epochs
    checktime = 5        # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    tblogger = true      # log training with tensorboard
    savepath = "runs/"    # results path (TensorBoard logs and `model.bson` checkpoints)
end

# Train the `LeNet5` model on MNIST.
# Keyword arguments are forwarded to `Args`. The model is evaluated once before
# training and then every `args.infotime` epochs; metrics are optionally sent to
# TensorBoard, and a checkpoint is written to `<savepath>/model.bson` every
# `args.checktime` epochs (when `checktime > 0`).
function train(; kws...)
    args = Args(; kws...)
    if args.seed > 0
        Random.seed!(args.seed)
    end

    # Use the GPU only when it is requested and CUDA is functional.
    use_cuda = args.use_cuda && CUDA.functional()
    device, device_name = use_cuda ? (gpu, "GPU") : (cpu, "CPU")
    @info "Training on $(device_name)"

    ## DATA
    train_loader, test_loader = get_data(args)
    n_train, n_test = train_loader.nobs, test_loader.nobs
    @info "Dataset MNIST: $(n_train) train and $(n_test) test examples"

    ## MODEL AND OPTIMIZER
    model = device(LeNet5())
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    # Weight decay (equivalent to L2 regularization) is chained in only for λ > 0.
    base_opt = ADAM(args.η)
    opt = args.λ > 0 ? Optimiser(base_opt, WeightDecay(args.λ)) : base_opt

    ## LOGGING UTILITIES
    if args.tblogger
        tblogger = TBLogger(args.savepath, tb_overwrite)
        # No auto increment: the step is set explicitly in `report`.
        set_step_increment!(tblogger, 0)
        @info "TensorBoard logging at \"$(args.savepath)\""
    end

    # Evaluate on both splits, print the result and mirror it to TensorBoard.
    function report(epoch)
        train_stats = eval_loss_accuracy(train_loader, model, device)
        test_stats = eval_loss_accuracy(test_loader, model, device)
        println("Epoch: $epoch   Train: $(train_stats)   Test: $(test_stats)")
        if args.tblogger
            set_step!(tblogger, epoch)
            with_logger(tblogger) do
                @info "train" loss = train_stats.loss acc = train_stats.acc
                @info "test" loss = test_stats.loss acc = test_stats.acc
            end
        end
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch in 1:args.epochs
        @showprogress for (x, y) in train_loader
            xb, yb = device(x), device(y)
            gs = Flux.gradient(() -> loss(model(xb), yb), ps)
            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        if epoch % args.infotime == 0
            report(epoch)
        end

        if args.checktime > 0 && epoch % args.checktime == 0
            ispath(args.savepath) || mkpath(args.savepath)
            modelpath = joinpath(args.savepath, "model.bson")
            # Serialise a CPU copy; the names `model` and `epoch` become the saved keys.
            let model = cpu(model)
                BSON.@save modelpath model epoch
            end
            @info "Model saved in \"$(modelpath)\""
        end
    end
    return nothing
end

# Script entry point: run training with the default `Args`.
train()

Here is the CPU version, which also hangs

## Classification of MNIST dataset 
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data:DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold
using Flux.Losses:logitcrossentropy
using Statistics, Random
using Logging:with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter:@showprogress
import MLDatasets
import BSON
using CUDA
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
CUDA.allowscalar(false)

# LeNet5 "constructor". 
# The model can be adapted to any image size
# and any number of output classes.
# function LeNet5(; imgsize=(28,28,1), nclasses=10) 
#   out_conv_size = (imgsize[1]÷4 - 3, imgsize[2]÷4 - 3, 16)
  
#   return Chain(
#           Conv((5, 5), imgsize[end]=>6, relu),
#           MaxPool((2, 2)),
#           Conv((5, 5), 6=>16, relu),
#           MaxPool((2, 2)),
#           flatten,
#           Dense(prod(out_conv_size), 120, relu), 
#           Dense(120, 84, relu), 
#           Dense(84, nclasses)
#         )
# end

# CPU variant of the MNIST classifier: a convolutional downsampling stem
# (`down`), a steady-state layer (`solve_ss`) and a classification head (`fc`).
#
# NOTE(review): `solve_ss` is a plain closure over `p`. Confirm that `p` is
# part of the parameters being optimised.
function LeNet5()
    down = Chain(Conv((3, 3), 1=>64, relu, stride = 1), GroupNorm(64, 64),
                 Conv((4, 4), 64=>64, relu, stride = 2, pad=1), GroupNorm(64, 64),
                 Conv((4, 4), 64=>64, stride = 2, pad = 1))
    # Network whose fixed point defines the steady-state layer.
    deq = Chain(Conv((3, 3), 64=>64, tanh, stride=1, pad=1),
                Conv((3, 3), 64=>64, tanh, stride=1, pad=1))
    # `p` is the flat parameter vector of `deq`; `re` rebuilds `deq` from one.
    p, re = Flux.destructure(deq)
    fc = Chain(GroupNorm(64, 64), x -> relu.(x), MeanPool((6, 6)),
               x -> reshape(x, (64, :)), Dense(64,10))

    tspan = (0.0f0, 1.0f0)

    # Steady-state layer: starting from z = deq(x), integrate
    # du/dt = deq(u + x) - u until it reaches a steady state and return that state.
    function solve_ss(x)
        z = re(p)(x)
        function dudt_(u, _p, t)
            # Solving the equation f(u) - u = du = 0
            u = convert(Array{Float32}, u)
            return re(_p)(u + x) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, z, tspan, p))
        # Do not rebind `x` here: `dudt_` captures `x`, so assigning the solution
        # to it would box the variable and make any later call of `dudt_`
        # (e.g. one made after the forward solve has returned) see the solution
        # instead of the layer input.
        sol = solve(ss, DynamicSS(Tsit5()), u0=z, abstol=Float32(1e-5), reltol=Float32(1e-5))
        return sol.u
    end

    # Build our over-all model topology
    m = Chain(
        down,               # (28,28,1,BS) -> (6,6,64,BS)
        solve_ss,           # (6,6,64,BS) -> (6,6,64,BS)
        fc,                 # (6,6,64,BS) -> (10, BS)
    )

    return m
end

# Load MNIST and wrap it in mini-batch loaders.
# Images are reshaped to 28x28x1xN and labels are one-hot encoded over 0:9.
# Returns `(train_loader, test_loader)`; both use `args.batchsize`, and only the
# training loader is shuffled.
function get_data(args)
    classes = 0:9

    train_images, train_labels = MLDatasets.MNIST.traindata(Float32)
    test_images, test_labels = MLDatasets.MNIST.testdata(Float32)

    # Insert a singleton channel dimension in front of the sample dimension.
    train_x = reshape(train_images, 28, 28, 1, :)
    test_x = reshape(test_images, 28, 28, 1, :)

    train_y = onehotbatch(train_labels, classes)
    test_y = onehotbatch(test_labels, classes)

    train_loader = DataLoader((train_x, train_y); batchsize=args.batchsize, shuffle=true)
    test_loader = DataLoader((test_x, test_y); batchsize=args.batchsize)
    return train_loader, test_loader
end

# Training objective: cross-entropy computed on the raw model outputs (logits) `ŷ`
# against the targets `y`.
function loss(ŷ, y)
    return logitcrossentropy(ŷ, y)
end

# Evaluate `model` over every batch of `loader`, moving each batch with `device`.
# Returns a named tuple `(loss, acc)`: the sample-weighted mean loss and the
# accuracy in percent, both rounded with `round4`.
function eval_loss_accuracy(loader, model, device)
    loss_sum = 0.0f0
    n_correct = 0
    n_seen = 0
    for (xb, yb) in loader
        inputs = device(xb)
        targets = device(yb)
        batch_n = size(inputs)[end]

        outputs = model(inputs)
        # Weight the batch loss by the batch size so the final mean is per sample.
        loss_sum += loss(outputs, targets) * batch_n
        # Compare predicted and true class indices on the CPU.
        n_correct += sum(onecold(cpu(outputs)) .== onecold(cpu(targets)))
        n_seen += batch_n
    end
    mean_loss = round4(loss_sum / n_seen)
    accuracy = round4(n_correct / n_seen * 100)
    return (loss = mean_loss, acc = accuracy)
end

## utility functions
# Total number of scalar entries across all arrays in `Flux.params(model)`.
function num_params(model)
    return sum(length, Flux.params(model))
end

# Round `x` to four decimal digits.
function round4(x)
    return round(x; digits=4)
end

# arguments for the `train` function 
# Hyper-parameters and run-time options consumed by `train`.
# `Base.@kwdef` generates a keyword constructor, so any field can be overridden
# by name (e.g. `Args(epochs = 2)`); `train(; kws...)` forwards its keywords here.
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate passed to ADAM
    λ = 0                # L2 regularizer param, implemented as weight decay (only applied when > 0)
    batchsize = 128      # batch size used by both data loaders
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility (the RNG is only seeded when > 0)
    use_cuda = false      # if true use cuda (if available); disabled by default in this CPU variant
    infotime = 1         # report every `infotime` epochs
    checktime = 5        # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    tblogger = true      # log training with tensorboard
    savepath = "runs/"    # results path (TensorBoard logs and `model.bson` checkpoints)
end

# Train the `LeNet5` model on MNIST.
# Keyword arguments are forwarded to `Args`. The model is evaluated once before
# training and then every `args.infotime` epochs; metrics are optionally sent to
# TensorBoard, and a checkpoint is written to `<savepath>/model.bson` every
# `args.checktime` epochs (when `checktime > 0`).
function train(; kws...)
    args = Args(; kws...)
    if args.seed > 0
        Random.seed!(args.seed)
    end

    # Use the GPU only when it is requested and CUDA is functional.
    use_cuda = args.use_cuda && CUDA.functional()
    device, device_name = use_cuda ? (gpu, "GPU") : (cpu, "CPU")
    @info "Training on $(device_name)"

    ## DATA
    train_loader, test_loader = get_data(args)
    n_train, n_test = train_loader.nobs, test_loader.nobs
    @info "Dataset MNIST: $(n_train) train and $(n_test) test examples"

    ## MODEL AND OPTIMIZER
    model = device(LeNet5())
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    # Weight decay (equivalent to L2 regularization) is chained in only for λ > 0.
    base_opt = ADAM(args.η)
    opt = args.λ > 0 ? Optimiser(base_opt, WeightDecay(args.λ)) : base_opt

    ## LOGGING UTILITIES
    if args.tblogger
        tblogger = TBLogger(args.savepath, tb_overwrite)
        # No auto increment: the step is set explicitly in `report`.
        set_step_increment!(tblogger, 0)
        @info "TensorBoard logging at \"$(args.savepath)\""
    end

    # Evaluate on both splits, print the result and mirror it to TensorBoard.
    function report(epoch)
        train_stats = eval_loss_accuracy(train_loader, model, device)
        test_stats = eval_loss_accuracy(test_loader, model, device)
        println("Epoch: $epoch   Train: $(train_stats)   Test: $(test_stats)")
        if args.tblogger
            set_step!(tblogger, epoch)
            with_logger(tblogger) do
                @info "train" loss = train_stats.loss acc = train_stats.acc
                @info "test" loss = test_stats.loss acc = test_stats.acc
            end
        end
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch in 1:args.epochs
        @showprogress for (x, y) in train_loader
            xb, yb = device(x), device(y)
            gs = Flux.gradient(() -> loss(model(xb), yb), ps)
            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        if epoch % args.infotime == 0
            report(epoch)
        end

        if args.checktime > 0 && epoch % args.checktime == 0
            ispath(args.savepath) || mkpath(args.savepath)
            modelpath = joinpath(args.savepath, "model.bson")
            # Serialise a CPU copy; the names `model` and `epoch` become the saved keys.
            let model = cpu(model)
                BSON.@save modelpath model epoch
            end
            @info "Model saved in \"$(modelpath)\""
        end
    end
    return nothing
end

# Script entry point: run training with the default `Args` (CPU in this variant).
train()

@ChrisRackauckas
Copy link
Member

Make sure ReverseDiffVJP is avoided (ZygoteVJP required) and the RHS function is out of place. If those two hold, then I think we're hitting some GPU bug because the last time I saw some stalls as well.

@frankschae
Copy link
Member

Could it be that it's just a batchsize/data formatting error? I tried to bring the GPU steady state version down to a much simpler model:

# Simplified steady-state model used to reproduce the GPU stall.
#
# Builds three stages and chains them on the GPU:
#   * `down`     - a single strided convolution,
#   * `solve_ss` - solves `f(u + x) - u = 0` with `DynamicSS(Tsit5())`, where `f`
#                  is the `deq` chain rebuilt from its flat parameter vector,
#   * `fc`       - a single dense layer.
#
# Returns the tuple `(m, deq)`: the full chain and the inner `deq` chain.
function LeNet5()
    down = Chain(
        Conv((5, 5), 1 => 1, relu, stride=3),
    ) |> f32

    deq = Chain(
        Conv((1, 1), 1 => 1, tanh, stride=1),
    ) |> f32

    # Flat parameter vector `p` plus a function `re` that rebuilds `deq` from it.
    p, re = Flux.destructure(deq)
    fc = Chain(
        Dense(10, 10) |> f32,
    ) |> f32

    tspan = (0.0f0, 1.0f0)
    function solve_ss(x)
        xg = gpu(x)
        z = re(p)(xg) |> gpu
        function dudt_(u, _p, t)
            # Solving the equation f(u) - u = du = 0
            # `u` is deliberately NOT converted to `CuArray{Float32}`: the adjoint
            # pass evaluates this function on ForwardDiff dual numbers, and forcing
            # a Float32 element type there raises a TypeError.
            return re(_p)(u + xg) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
        # A Float32 end time keeps the integration in Float32; the untyped default
        # was reported to cause a Float32/Float64 change.
        return solve(ss, DynamicSS(Tsit5()), u0=z, abstol=1f-2, reltol=1f-2,
                     tspan=Inf32).u
    end

    # Build our over-all model topology
    m = Chain(
        down,               # (28,28,1,4) -> (8,8,1,4)
        solve_ss,           # (8,8,1,4) -> (8,8,1,4)
        fc,                 #..
    ) |> gpu

    return m, deq
end

and it looks like:

sol = solve(ss, DynamicSS(Tsit5()), u0=gpu(z), abstol=1f-2, reltol=1f-2);

already takes very long (still runs on my machine) with input u0=8x8x1x4 CuArray and p=2 Element CuArray...

@ChrisRackauckas
Copy link
Member

Where is all of the time spent? Can someone share an nvprof profile?

@frankschae
Copy link
Member

The default infinite end time (`Inf`, a Float64) causes the type to change from Float32 to Float64. The fix is to pass `tspan=Inf32` explicitly:

# Minimal reproduction of the steady-state solve on a single MNIST image.
#
# Loads the first training image, pushes it through a one-layer `deq` chain to
# get the initial state, and solves `f(u + x) - u = 0` with `DynamicSS(Tsit5())`
# using a Float32 infinite time span. Returns the `.u` field of the solution.
function eval()
    device = gpu

    # First training image, reshaped to a 28x28x1x1 array on the device.
    images, _ = MLDatasets.MNIST.traindata(Float32)
    sample = reshape(images[:, :, 1], 28, 28, 1, 1) |> device

    deq = Chain(
      Conv((1, 1), 1 => 1, tanh, stride=1),
    ) |> f32 |> device
    p, re = Flux.destructure(deq)

    xg = sample |> device
    z = device(re(p)(xg))
    tspan = (0.0f0, 1.0f0)

    # Solving the equation f(u) - u = du = 0
    rhs(u, _p, t) = re(_p)(u + xg) - u

    problem = ODEProblem(rhs, device(z), tspan, p)
    ss = SteadyStateProblem(problem)
    solution = solve(ss, DynamicSS(Tsit5()), u0=device(z), abstol=1f-2, reltol=1f-2, tspan=Inf32)
    return solution.u
end

@ChrisRackauckas should this somehow be added in SteadyStateDiffEq.jl?

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 2, 2021

Hey Frank! By "fix" do you mean this code fixes the Float32 vs Float64 problem, or the training problem? When I add the "tspan" part in my code, the training still seems to be unresponsive.

@ChrisRackauckas
Copy link
Member

Yes, we'd need to track where that Inf is coming from but it makes sense to type it based on u0 given no other information.

@frankschae
Copy link
Member

@QiyaoWei Unfortunately, I have only checked the Float32 vs Float64 problem so far. Did you try setting a smaller number for the maximum time instead of Inf?

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 3, 2021

hmmm setting tspan to 1.0f0 does prevent training from hanging, but I receive an error after the first epoch

LoadError: TypeError: in typeassert, expected Float32, got a value of type ForwardDiff.Dual{Nothing, Float32, 12}

Not sure if that's because I changed the tspan parameter?

@ChrisRackauckas
Copy link
Member

Share the entire stacktrace. You shared the least helpful part of it 😝

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 3, 2021

Sadly the stacktrace in this case is just one line

[ Info: Training on GPU
[ Info: Dataset MNIST: 60000 train and 10000 test examples
[ Info: LeNet5 model: 15910 trainable params
[ Info: TensorBoard logging at "runs/"
[ Info: Start Training
Epoch: 0   Train: (loss = 2.3069f0, acc = 14.2867)   Test: (loss = 2.3052f0, acc = 14.66)
ERROR: LoadError: TypeError: in typeassert, expected Float32, got a value of type ForwardDiff.Dual{Nothing, Float32, 12}
Stacktrace:
 [1] setindex!(A::Matrix{Float32}, x::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, i1::Int64)
   @ Base .\array.jl:839
in expression starting at e:\wqy\julia\test.jl:252

I also tried searching for this error and found this (https://discourse.julialang.org/t/forward-differentiation-and-differential-equations/21002/2), which seems related to the tspan typecast previously mentioned

@ChrisRackauckas
Copy link
Member

It doesn't show anything below [1] setindex!(A::Matrix{Float32}, x::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, i1::Int64) @ Base .\array.jl:839 in expression starting at e:\wqy\julia\test.jl:252 ? That's what I'm looking for.

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 3, 2021

Actually I think I previously set tspan to be the incorrect value. When I add the line

x = solve(ss, DynamicSS(Tsit5()), u0=device(z), abstol=1f-2, reltol=1f-2, tspan=1.0f0).u

I get the error

[ Info: Training on GPU
[ Info: Dataset MNIST: 60000 train and 10000 test examples
[ Info: LeNet5 model: 15910 trainable params
[ Info: TensorBoard logging at "runs/"
[ Info: Start Training
Epoch: 0   Train: (loss = 2.3311f0, acc = 11.9333)   Test: (loss = 2.3328f0, acc = 12.04)
ERROR: LoadError: TypeError: in typeassert, expected Float32, got a value of type ForwardDiff.Dual{Nothing, Float32, 12}
Stacktrace:
  [1] setindex!(A::Matrix{Float32}, x::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, i1::Int64)
    @ Base .\array.jl:839
  [2] _unsafe_copyto!(dest::Matrix{Float32}, doffs::Int64, src::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}}, soffs::Int64, n::Int64)
    @ Base .\array.jl:235
  [3] unsafe_copyto!
    @ .\array.jl:289 [inlined]
  [4] _copyto_impl!
    @ .\array.jl:313 [inlined]
  [5] copyto!
    @ .\array.jl:299 [inlined]
  [6] copyto!
    @ .\array.jl:325 [inlined]
  [7] copyto_axcheck!(dest::Matrix{Float32}, src::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}})
    @ Base .\abstractarray.jl:1056
  [8] Array
    @ .\array.jl:540 [inlined]
  [9] convert
    @ .\array.jl:532 [inlined]
 [10] convert(AT::Type{Matrix{Float32}}, A::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2})
    @ GPUArrays ~\.julia\packages\GPUArrays\Z5nPF\src\host\construction.jl:90
 [11] convert(AT::Type{CuArray{Float32, N} where N}, A::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2})
    @ GPUArrays ~\.julia\packages\GPUArrays\Z5nPF\src\host\construction.jl:82
 [12] (::var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}})(u::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2}, _p::CuArray{Float32, 1}, t::Nothing)    @ Main e:\wqy\julia\test.jl:96
 [13] ODEFunction
    @ ~\.julia\packages\SciMLBase\grNUR\src\scimlfunctions.jl:334 [inlined]
 [14] UDerivativeWrapper
    @ ~\.julia\packages\SciMLBase\grNUR\src\function_wrappers.jl:30 [inlined]
 [15] chunk_mode_jacobian(f::SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, x::CuArray{Float32, 2}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, 
var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2}})
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:223
 [16] jacobian(f::Function, x::CuArray{Float32, 2}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2}}, ::Val{true})
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:23
 [17] jacobian(f::Function, x::CuArray{Float32, 2}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2}}) (repeats 2 times)
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:19
 [18] jacobian(f::Function, x::CuArray{Float32, 2}, alg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve})
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\p1AlV\src\derivative_wrappers.jl:135
 [19] SteadyStateAdjointProblem(sol::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}, sensealg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, g::Nothing, dg::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}; save_idxs::Nothing)
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\p1AlV\src\steadystate_adjoint.jl:41
 [20] #_adjoint_sensitivities#57
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\sensitivity_interface.jl:65 [inlined]
 [21] #adjoint_sensitivities#54
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\sensitivity_interface.jl:6 [inlined]
 [22] steadystatebackpass
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\concrete_solve.jl:437 [inlined]
 [23] #98#back
    @ ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:65 [inlined]
 [24] #188
    @ ~\.julia\packages\Zygote\zowrf\src\lib\lib.jl:194 [inlined]
 [25] (::Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, DiffEqBase.var"#98#back#74"{DiffEqSensitivity.var"#steadystatebackpass#209"{Nothing, DynamicSS{Tsit5, Float64, Float64, Float64}, SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, Tuple{}, SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}}}}})(Δ::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 
2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [26] Pullback
    @ ~\.julia\packages\DiffEqBase\lULzQ\src\solve.jl:70 [inlined]
 [27] (::Zygote.Pullback{Tuple{DiffEqBase.var"##solve#57", Nothing, CuArray{Float32, 2}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any})(Δ::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, 
Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [28] (::Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#57", Nothing, CuArray{Float32, 2}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}})(Δ::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\lib\lib.jl:194
 [29] (::Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#57", Nothing, CuArray{Float32, 2}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}}})(Δ::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [30] Pullback
    @ ~\.julia\packages\DiffEqBase\lULzQ\src\solve.jl:68 [inlined]
 [31] (::Zygote.Pullback{Tuple{CommonSolve.var"#solve##kw", NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any})(Δ::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, 
CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, 
Nothing, Nothing})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [32] Pullback
    @ e:\wqy\julia\test.jl:107 [inlined]
 [33] (::Zygote.Pullback{Tuple{var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{Type{DynamicSS}, Tsit5}, Tuple{Zygote.var"#1546#back#129"{Zygote.var"#127#128"{Zygote.Context, GlobalRef, Float64}}, Zygote.Pullback{Tuple{SteadyStateDiffEq.var"##DynamicSS#6", Float64, Float64, Float64, Type{DynamicSS}, Tsit5}, Tuple{Zygote.Pullback{Tuple{Type{DynamicSS}, Tsit5, Float64, Float64, Float64}, 
Tuple{Zygote.Pullback{Tuple{Type{DynamicSS{Tsit5, Float64, Float64, Float64}}, Tsit5, Float64, Float64, Float64}, Tuple{Zygote.var"#1723#back#204"{Zygote.Jnew{DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, false}}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}, Zygote.Pullback{Tuple{typeof(convert), Type{Tsit5}, Tsit5}, Tuple{}}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}}}}}}}}}, 
Zygote.Pullback{Tuple{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, 
CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), 
Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, 
Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:u0, :abstol, :reltol, :tspan), T} where T<:Tuple}, Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}}, Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Nothing, true}}}}}}, 
Zygote.Pullback{Tuple{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:m, Zygote.Context, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Flux.var"#163#back#58"{Flux.var"#56#57"}}}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, DiffEqBase.var"#150#back#172"{DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, 
LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.Pullback{Tuple{Type{Tsit5}}, Tuple{}}, Zygote.Pullback{Tuple{Type{SteadyStateProblem}, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), 
Tuple{}}}, SciMLBase.StandardODEProblem}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:p}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 
2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, 
Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, CuArray{Float32, 1}}}}}}}, Zygote.Pullback{Tuple{Type{SteadyStateProblem{false, isinplace, P, F, K} where {isinplace, P, F, K}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, CuArray{Float32, 2}, CuArray{Float32, 1}}, Tuple{Zygote.Pullback{Tuple{SciMLBase.var"#_#173#175", Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Type{SteadyStateProblem{false, isinplace, P, F, K} where {isinplace, P, F, K}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, CuArray{Float32, 2}, CuArray{Float32, 1}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.Pullback{Tuple{typeof(isinplace), ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), 
CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{CuArray{Float32, 2}}, CuArray{Float32, 2}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.var"#1723#back#204"{Zygote.Jnew{SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Nothing, false}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.Pullback{Tuple{typeof(convert), Type{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), 
CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.Pullback{Tuple{typeof(convert), Type{CuArray{Float32, 1}}, CuArray{Float32, 1}}, Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2031#back#331"{Zygote.var"#pairs_namedtuple#330"{(), NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), 
Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:f}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:f, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}}}}}}, Zygote.Pullback{Tuple{typeof(isinplace), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:u0}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:u0, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, CuArray{Float32, 2}}}}}}}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Nothing, false}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}, Zygote.Pullback{Tuple{CommonSolve.var"#solve##kw", NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), 
CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}, Zygote.var"#1784#back#234"{Zygote.var"#232#233"}, Zygote.Pullback{Tuple{Type{ODEProblem}, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, CuArray{Float32, 2}, Tuple{Float32, Float32}, CuArray{Float32, 1}}, Tuple{Zygote.Pullback{Tuple{SciMLBase.var"##ODEProblem#219", Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Type{ODEProblem}, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, CuArray{Float32, 2}, Tuple{Float32, Float32}, CuArray{Float32, 1}}, Any}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2031#back#331"{Zygote.var"#pairs_namedtuple#330"{(), NamedTuple{(), Tuple{}}}}}}, Zygote.var"#1784#back#234"{Zygote.var"#232#233"}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:tspan, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Tuple{Float32, Float32}}}, Zygote.Pullback{Tuple{typeof(|>), CuArray{Float32, 2}, typeof(gpu)}, Tuple{Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}}}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [34] Pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:36 [inlined]
--- the last 2 lines are repeated 1 more time ---
 [37] (::Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, 
Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3819#back#1049"{Zygote.var"#1047#1048"}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, 
Zygote.Pullback{Tuple{var"#DiffEqArray_to_Array#1", CuArray{Float32, 2}}, Tuple{Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 2}, Tuple{Tuple{Int64, Int64}}}}, Zygote.var"#1605#back#151"{Zygote.var"#147#149"{2, UnitRange{Int64}}}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}, Zygote.ZBack{ChainRules.var"#size_pullback#1091"}, Zygote.var"#357#364"}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, 
Zygote.Pullback{Tuple{var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{Type{DynamicSS}, Tsit5}, Tuple{Zygote.var"#1546#back#129"{Zygote.var"#127#128"{Zygote.Context, GlobalRef, Float64}}, Zygote.Pullback{Tuple{SteadyStateDiffEq.var"##DynamicSS#6", Float64, Float64, Float64, Type{DynamicSS}, Tsit5}, Tuple{Zygote.Pullback{Tuple{Type{DynamicSS}, Tsit5, Float64, Float64, Float64}, Tuple{Zygote.Pullback{Tuple{Type{DynamicSS{Tsit5, Float64, Float64, Float64}}, Tsit5, Float64, Float64, Float64}, Tuple{Zygote.var"#1723#back#204"{Zygote.Jnew{DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, false}}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}, Zygote.Pullback{Tuple{typeof(convert), Type{Tsit5}, Tsit5}, Tuple{}}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}}}}}}}}}, Zygote.Pullback{Tuple{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, 
Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 
1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, 
Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, 
Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 
1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:u0, :abstol, :reltol, :tspan), T} where T<:Tuple}, Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}}, Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:m, Zygote.Context, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), 
CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Flux.var"#163#back#58"{Flux.var"#56#57"}}}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, DiffEqBase.var"#150#back#172"{DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}, 
Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.Pullback{Tuple{Type{Tsit5}}, Tuple{}}, Zygote.Pullback{Tuple{Type{SteadyStateProblem}, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, 
NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:p}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, 
NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, CuArray{Float32, 1}}}}}}}, Zygote.Pullback{Tuple{Type{SteadyStateProblem{false, isinplace, P, F, K} where {isinplace, P, F, K}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, CuArray{Float32, 2}, CuArray{Float32, 1}}, Tuple{Zygote.Pullback{Tuple{SciMLBase.var"#_#173#175", Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Type{SteadyStateProblem{false, isinplace, P, F, K} where {isinplace, P, F, K}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, CuArray{Float32, 2}, CuArray{Float32, 1}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.Pullback{Tuple{typeof(isinplace), ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{CuArray{Float32, 2}}, CuArray{Float32, 2}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.var"#1723#back#204"{Zygote.Jnew{SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Nothing, false}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.Pullback{Tuple{typeof(convert), Type{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.Pullback{Tuple{typeof(convert), Type{CuArray{Float32, 1}}, CuArray{Float32, 1}}, Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2031#back#331"{Zygote.var"#pairs_namedtuple#330"{(), NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:f}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, 
SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:f, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}}}}}}, Zygote.Pullback{Tuple{typeof(isinplace), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, Tuple{}}, 
Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:u0}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:u0, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, CuArray{Float32, 2}}}}}}}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, 
Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Nothing, false}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}, Zygote.Pullback{Tuple{CommonSolve.var"#solve##kw", NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 2}, false, 
CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 
1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}, Zygote.var"#1784#back#234"{Zygote.var"#232#233"}, Zygote.Pullback{Tuple{Type{ODEProblem}, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, CuArray{Float32, 2}, Tuple{Float32, Float32}, CuArray{Float32, 1}}, Tuple{Zygote.Pullback{Tuple{SciMLBase.var"##ODEProblem#219", Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Type{ODEProblem}, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, CuArray{Float32, 2}, Tuple{Float32, Float32}, CuArray{Float32, 1}}, Any}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2031#back#331"{Zygote.var"#pairs_namedtuple#330"{(), NamedTuple{(), Tuple{}}}}}}, Zygote.var"#1784#back#234"{Zygote.var"#232#233"}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:tspan, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Tuple{Float32, Float32}}}, Zygote.Pullback{Tuple{typeof(|>), CuArray{Float32, 2}, typeof(gpu)}, Tuple{Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}}}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing, Nothing}}}, Zygote.Pullback{Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 
2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), 
Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{typeof(flatten), CuArray{Float32, 4}}, Tuple{Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 4}, Tuple{Colon, Int64}}}, Zygote.ZBack{ChainRules.var"#size_pullback#1091"}, Zygote.Pullback{Tuple{typeof(lastindex), NTuple{4, Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#875"}}}, Zygote.var"#1593#back#145"{Zygote.var"#back#143"{4, Zygote.Context, Int64, Int64}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [38] Pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:38 [inlined]
 [39] (::Zygote.Pullback{Tuple{Chain{Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), 
CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, 
CuArray{Float32, 1}}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3819#back#1049"{Zygote.var"#1047#1048"}, 
Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{var"#DiffEqArray_to_Array#1", CuArray{Float32, 2}}, Tuple{Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 2}, Tuple{Tuple{Int64, Int64}}}}, Zygote.var"#1605#back#151"{Zygote.var"#147#149"{2, UnitRange{Int64}}}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}, Zygote.ZBack{ChainRules.var"#size_pullback#1091"}, Zygote.var"#357#364"}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.Pullback{Tuple{var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{Type{DynamicSS}, Tsit5}, Tuple{Zygote.var"#1546#back#129"{Zygote.var"#127#128"{Zygote.Context, GlobalRef, Float64}}, Zygote.Pullback{Tuple{SteadyStateDiffEq.var"##DynamicSS#6", Float64, Float64, Float64, Type{DynamicSS}, Tsit5}, Tuple{Zygote.Pullback{Tuple{Type{DynamicSS}, Tsit5, Float64, Float64, Float64}, 
Tuple{Zygote.Pullback{Tuple{Type{DynamicSS{Tsit5, Float64, Float64, Float64}}, Tsit5, Float64, Float64, Float64}, Tuple{Zygote.var"#1723#back#204"{Zygote.Jnew{DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, false}}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}, Zygote.Pullback{Tuple{typeof(convert), Type{Tsit5}, 
Tsit5}, Tuple{}}, Zygote.var"#1772#back#230"{Zygote.var"#228#229"}}}}}}}}}, Zygote.Pullback{Tuple{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, 
Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, 
Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, 
Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, 
Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:u0, :abstol, :reltol, :tspan), T} where T<:Tuple}, Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}}, Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:m, Zygote.Context, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Flux.var"#163#back#58"{Flux.var"#56#57"}}}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, DiffEqBase.var"#150#back#172"{DiffEqBase.var"#solu_adjoint#171"{SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 
1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.Pullback{Tuple{Type{Tsit5}}, Tuple{}}, Zygote.Pullback{Tuple{Type{SteadyStateProblem}, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:p}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), 
CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, 
CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 
1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, CuArray{Float32, 1}}}}}}}, Zygote.Pullback{Tuple{Type{SteadyStateProblem{false, isinplace, P, F, K} where {isinplace, P, F, K}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, CuArray{Float32, 2}, CuArray{Float32, 1}}, Tuple{Zygote.Pullback{Tuple{SciMLBase.var"#_#173#175", Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Type{SteadyStateProblem{false, isinplace, P, F, K} where {isinplace, P, F, K}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, CuArray{Float32, 2}, CuArray{Float32, 1}}, Tuple{Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.Pullback{Tuple{typeof(isinplace), ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{CuArray{Float32, 2}}, CuArray{Float32, 2}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.ZBack{ChainRules.var"#typeof_pullback#26"}, Zygote.var"#1723#back#204"{Zygote.Jnew{SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Nothing, false}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.Pullback{Tuple{typeof(convert), Type{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#441"}, Zygote.Pullback{Tuple{typeof(convert), Type{CuArray{Float32, 1}}, CuArray{Float32, 1}}, Tuple{}}}}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2031#back#331"{Zygote.var"#pairs_namedtuple#330"{(), NamedTuple{(), Tuple{}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:f}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:f, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 
2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}}}}}}, Zygote.Pullback{Tuple{typeof(isinplace), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), 
Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Val{:u0}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:u0, Zygote.Context, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 
1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, CuArray{Float32, 2}}}}}}}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Nothing, false}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:p, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}, Zygote.Pullback{Tuple{CommonSolve.var"#solve##kw", NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 2}, Float32, Float32, Float32}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}, Zygote.var"#1784#back#234"{Zygote.var"#232#233"}, Zygote.Pullback{Tuple{Type{ODEProblem}, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, CuArray{Float32, 2}, Tuple{Float32, Float32}, CuArray{Float32, 1}}, Tuple{Zygote.Pullback{Tuple{SciMLBase.var"##ODEProblem#219", Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Type{ODEProblem}, var"#dudt_#3"{CuArray{Float32, 2}, 
Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, CuArray{Float32, 2}, Tuple{Float32, Float32}, CuArray{Float32, 1}}, Any}, Zygote.Pullback{Tuple{Type{NamedTuple}}, Tuple{}}, Zygote.var"#2031#back#331"{Zygote.var"#pairs_namedtuple#330"{(), NamedTuple{(), Tuple{}}}}}}, Zygote.var"#1784#back#234"{Zygote.var"#232#233"}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:re, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}, Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:tspan, Zygote.Context, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, Tuple{Float32, Float32}}}, Zygote.Pullback{Tuple{typeof(|>), CuArray{Float32, 2}, typeof(gpu)}, Tuple{Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}}}, Zygote.Pullback{Tuple{typeof(gpu), CuArray{Float32, 2}}, Any}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing, Nothing}}}, Zygote.Pullback{Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 4}}, 
Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{typeof(flatten), 
Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 
2}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{typeof(flatten), CuArray{Float32, 4}}, Tuple{Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 4}, Tuple{Colon, Int64}}}, Zygote.ZBack{ChainRules.var"#size_pullback#1091"}, Zygote.Pullback{Tuple{typeof(lastindex), NTuple{4, Int64}}, Tuple{Zygote.ZBack{ChainRules.var"#length_pullback#875"}}}, 
Zygote.var"#1593#back#145"{Zygote.var"#back#143"{4, Zygote.Context, Int64, Int64}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}})(Δ::CuArray{Float32, 2})        
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [40] Pullback
    @ e:\wqy\julia\test.jl:232 [inlined]
 [41] (::Zygote.Pullback{Tuple{var"#7#10"{Chain{Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}, Any})(Δ::Float32)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [42] (::Zygote.var"#69#70"{Zygote.Params, Zygote.Pullback{Tuple{var"#7#10"{Chain{Tuple{Chain{Tuple{typeof(flatten), Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#2"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 1}}, var"#DiffEqArray_to_Array#1", Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}, Any}, Zygote.Context})(Δ::Float32)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface.jl:255
 [43] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface.jl:59
 [44] macro expansion
    @ e:\wqy\julia\test.jl:231 [inlined]
 [45] macro expansion
    @ ~\.julia\packages\ProgressMeter\Vf8un\src\ProgressMeter.jl:940 [inlined]
 [46] train(; kws::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Main e:\wqy\julia\test.jl:229
 [47] train()
    @ Main e:\wqy\julia\test.jl:178
 [48] top-level scope
    @ e:\wqy\julia\test.jl:252
 [49] include(fname::String)
    @ Base.MainInclude .\client.jl:444
 [50] startdebug(socket::Base.PipeEndpoint, error_handler::VSCodeDebugger.var"#3#4"{Tuple{String, String}})
    @ VSCodeDebugger.DebugAdapter ~\.vscode\extensions\julialang.language-julia-1.2.1\scripts\packages\DebugAdapter\src\packagedef.jl:93
 [51] startdebugger()
    @ VSCodeDebugger ~\.vscode\extensions\julialang.language-julia-1.2.1\scripts\packages\VSCodeDebugger\src\VSCodeDebugger.jl:38
in expression starting at e:\wqy\julia\test.jl:252

Julia debuggee finished. Press ENTER to close this terminal.

@ChrisRackauckas
Copy link
Member

@frankschae shouldn't it be calculating the vjp instead of building the jacobian?

@ChrisRackauckas
Copy link
Member

I see. I think we just need to be a bit more careful about the eltype of that Jacobian then?

@frankschae
Copy link
Member

@QiyaoWei
When I run your example on GPU, I get a very different error:

julia> train()
[ Info: Training on GPU
[ Info: Dataset MNIST: 60000 train and 10000 test examples
[ Info: LeNet5 model: 132874 trainable params
[ Info: TensorBoard logging at "runs/"
[ Info: Start Training
ERROR: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Float32, 4, 1}, Base.Broadcast.Broadcasted{Nothing, NTuple{4, Base.OneTo{Int64}}, typeof(convert), Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing, NTuple{4, Base.OneTo{Int64}}, typeof(convert), Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}}}, which is not isbits:
  .args is of type Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}} which is not isbits.
    .1 is of type CUDA.CuRefValue{DataType} which is not isbits.
      .x is of type DataType which is not isbits.
        .name is of type Core.TypeName which is not isbits.
          .name is of type Symbol which is not isbits.
          .module is of type Module which is not isbits.
          .names is of type Core.SimpleVector which is not isbits.
          .wrapper is of type Type which is not isbits.
          .cache is of type Core.SimpleVector which is not isbits.
          .linearcache is of type Core.SimpleVector which is not isbits.
          .mt is of type Core.MethodTable which is not isbits.
            .name is of type Symbol which is not isbits.
            .defs is of type Any which is not isbits.
            .leafcache is of type Any which is not isbits.
            .cache is of type Any which is not isbits.
            .kwsorter is of type Any which is not isbits.
            .module is of type Module which is not isbits.
            .backedges is of type Vector{Any} which is not isbits.
          .partial is of type Any which is not isbits.
        .super is of type DataType which is not isbits.
          .name is of type Core.TypeName which is not isbits.
            .name is of type Symbol which is not isbits.
            .module is of type Module which is not isbits.
            .names is of type Core.SimpleVector which is not isbits.
            .wrapper is of type Type which is not isbits.
            .cache is of type Core.SimpleVector which is not isbits.
            .linearcache is of type Core.SimpleVector which is not isbits.
            .mt is of type Core.MethodTable which is not isbits.
              .name is of type Symbol which is not isbits.
              .defs is of type Any which is not isbits.
              .leafcache is of type Any which is not isbits.
              .cache is of type Any which is not isbits.
              .kwsorter is of type Any which is not isbits.
              .module is of type Module which is not isbits.
              .backedges is of type Vector{Any} which is not isbits.
            .partial is of type Any which is not isbits.
          .super is of type DataType which is not isbits.
            .name is of type Core.TypeName which is not isbits.
              .name is of type Symbol which is not isbits.
              .module is of type Module which is not isbits.
              .names is of type Core.SimpleVector which is not isbits.
              .wrapper is of type Type which is not isbits.
              .cache is of type Core.SimpleVector which is not isbits.
              .linearcache is of type Core.SimpleVector which is not isbits.
              .mt is of type Core.MethodTable which is not isbits.
                .name is of type Symbol which is not isbits.
                .defs is of type Any which is not isbits.
                .leafcache is of type Any which is not isbits.
                .cache is of type Any which is not isbits.
                .kwsorter is of type Any which is not isbits.
                .module is of type Module which is not isbits.
                .backedges is of type Vector{Any} which is not isbits.
              .partial is of type Any which is not isbits.
            .super is of type DataType which is not isbits.
              .name is of type Core.TypeName which is not isbits.
                .name is of type Symbol which is not isbits.
                .module is of type Module which is not isbits.
                .names is of type Core.SimpleVector which is not isbits.
                .wrapper is of type Type which is not isbits.
                .cache is of type Core.SimpleVector which is not isbits.
                .linearcache is of type Core.SimpleVector which is not isbits.
                .mt is of type Core.MethodTable which is not isbits.
                  .name is of type Symbol which is not isbits.
                  .defs is of type Any which is not isbits.
                  .leafcache is of type Any which is not isbits.
                  .cache is of type Any which is not isbits.
                  .kwsorter is of type Any which is not isbits.
                  .module is of type Module which is not isbits.
                  .backedges is of type Vector{Any} which is not isbits.
                .partial is of type Any which is not isbits.
              .super is of type DataType which is not isbits.
                .name is of type Core.TypeName which is not isbits.
                  .name is of type Symbol which is not isbits.
                  .module is of type Module which is not isbits.
                  .names is of type Core.SimpleVector which is not isbits.
                  .wrapper is of type Type which is not isbits.
                  .cache is of type Core.SimpleVector which is not isbits.
                  .linearcache is of type Core.SimpleVector which is not isbits.
                  .mt is of type Core.MethodTable which is not isbits.
                    .name is of type Symbol which is not isbits.
                    .defs is of type Any which is not isbits.
                    .leafcache is of type Any which is not isbits.
                    .cache is of type Any which is not isbits.
                    .kwsorter is of type Any which is not isbits.
                    .module is of type Module which is not isbits.
                    .backedges is of type Vector{Any} which is not isbits.
                  .partial is of type Any which is not isbits.
                .super is of type DataType which is not isbits.
                  .name is of type Core.TypeName which is not isbits.
                    .name is of type Symbol which is not isbits.
                    .module is of type Module which is not isbits.
                    .names is of type Core.SimpleVector which is not isbits.
                    .wrapper is of type Type which is not isbits.
                    .cache is of type Core.SimpleVector which is not isbits.
                    .linearcache is of type Core.SimpleVector which is not isbits.
                    .mt is of type Core.MethodTable which is not isbits.
                    .partial is of type Any which is not isbits.
                  .super is of type DataType which is not isbits.
                    .name is of type Core.TypeName which is not isbits.
                    .super is of type DataType which is not isbits.
                    .parameters is of type Core.SimpleVector which is not isbits.
                    .types is of type Core.SimpleVector which is not isbits.
                    .names is of type Core.SimpleVector which is not isbits.
                    .instance is of type Any which is not isbits.
                  .parameters is of type Core.SimpleVector which is not isbits.
                  .types is of type Core.SimpleVector which is not isbits.
                  .names is of type Core.SimpleVector which is not isbits.
                  .instance is of type Any which is not isbits.
                .parameters is of type Core.SimpleVector which is not isbits.
                .types is of type Core.SimpleVector which is not isbits.
                .names is of type Core.SimpleVector which is not isbits.
                .instance is of type Any which is not isbits.
              .parameters is of type Core.SimpleVector which is not isbits.
              .types is of type Core.SimpleVector which is not isbits.
              .names is of type Core.SimpleVector which is not isbits.
              .instance is of type Any which is not isbits.
            .parameters is of type Core.SimpleVector which is not isbits.
            .types is of type Core.SimpleVector which is not isbits.
            .names is of type Core.SimpleVector which is not isbits.
            .instance is of type Any which is not isbits.
          .parameters is of type Core.SimpleVector which is not isbits.
          .types is of type Core.SimpleVector which is not isbits.
          .names is of type Core.SimpleVector which is not isbits.
          .instance is of type Any which is not isbits.
        .parameters is of type Core.SimpleVector which is not isbits.
        .types is of type Core.SimpleVector which is not isbits.
        .names is of type Core.SimpleVector which is not isbits.
        .instance is of type Any which is not isbits.


Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob, entry::LLVM.Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/eJOtJ/src/validation.jl:66
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/eJOtJ/src/driver.jl:309 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/PZq45/src/TimerOutput.jl:226 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/eJOtJ/src/driver.jl:308 [inlined]
  [5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module, kernel::LLVM.Function; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/eJOtJ/src/utils.jl:62
  [6] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/3VnCC/src/compiler/execution.jl:301
  [7] check_cache
    @ ~/.julia/packages/GPUCompiler/eJOtJ/src/cache.jl:47 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/GPUArrays/Z5nPF/src/host/broadcast.jl:57 [inlined]
  [9] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#16", Tuple{CUDA.CuKernelContext, CuDeviceArray{Float32, 4, 1}, Base.Broadcast.Broadcasted{Nothing, NTuple{4, Base.OneTo{Int64}}, typeof(convert), Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}}}, Int64}}}, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/eJOtJ/src/cache.jl:0
 [10] cufunction(f::GPUArrays.var"#broadcast_kernel#16", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceArray{Float32, 4, 1}, Base.Broadcast.Broadcasted{Nothing, NTuple{4, Base.OneTo{Int64}}, typeof(convert), Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceArray{Float32, 4, 1}, NTuple{4, Bool}, NTuple{4, Int64}}}}, Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/packages/CUDA/3VnCC/src/compiler/execution.jl:289
 [11] cufunction
    @ ~/.julia/packages/CUDA/3VnCC/src/compiler/execution.jl:283 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/CUDA/3VnCC/src/compiler/execution.jl:102 [inlined]
 [13] #launch_heuristic#286
    @ ~/.julia/packages/CUDA/3VnCC/src/gpuarrays.jl:17 [inlined]
 [14] launch_heuristic
    @ ~/.julia/packages/CUDA/3VnCC/src/gpuarrays.jl:17 [inlined]
 [15] copyto!
    @ ~/.julia/packages/GPUArrays/Z5nPF/src/host/broadcast.jl:63 [inlined]
 [16] copyto!
    @ ./broadcast.jl:936 [inlined]
 [17] copy
    @ ~/.julia/packages/GPUArrays/Z5nPF/src/host/broadcast.jl:47 [inlined]
 [18] materialize
    @ ./broadcast.jl:883 [inlined]
 [19] adapt_storage(T::Type{Float32}, xs::CuArray{Float32, 4})
    @ Flux ~/.julia/packages/Flux/0c9kI/src/functor.jl:73
 [20] adapt_structure(to::Type, x::CuArray{Float32, 4})
    @ Adapt ~/.julia/packages/Adapt/RGNRk/src/Adapt.jl:42
 [21] adapt
    @ ~/.julia/packages/Adapt/RGNRk/src/Adapt.jl:40 [inlined]
 [22] #125
    @ ~/.julia/packages/Flux/0c9kI/src/functor.jl:75 [inlined]
 [23] fmap(f::Flux.var"#125#126"{DataType}, x::CuArray{Float32, 4}; exclude::typeof(Functors.isleaf), cache::IdDict{Any, Any})
    @ Functors ~/.julia/packages/Functors/EWaud/src/functor.jl:56
 [24] fmap
    @ ~/.julia/packages/Functors/EWaud/src/functor.jl:55 [inlined]
 [25] paramtype
    @ ~/.julia/packages/Flux/0c9kI/src/functor.jl:75 [inlined]
 [26] f32
    @ ~/.julia/packages/Flux/0c9kI/src/functor.jl:82 [inlined]
 [27] |>
    @ ./operators.jl:858 [inlined]
 [28] #7
    @ ~/ownCloud/Private/JSoC2021/OtherIssues/SteadyStateAdjoint/DiffEqFlux567_original.jl:46 [inlined]
 [29] applychain(fs::Tuple{var"#7#9", MeanPool{2, 4}, var"#8#10", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, x::CuArray{Float32, 4}) (repeats 2 times)
    @ Flux ~/.julia/packages/Flux/0c9kI/src/layers/basic.jl:36
 [30] Chain
    @ ~/.julia/packages/Flux/0c9kI/src/layers/basic.jl:38 [inlined]
 [31] applychain(fs::Tuple{Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#7#9", MeanPool{2, 4}, var"#8#10", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, x::CuArray{Float32, 4}) (repeats 3 times)
    @ Flux ~/.julia/packages/Flux/0c9kI/src/layers/basic.jl:36
 [32] (::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#11"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#7#9", MeanPool{2, 4}, var"#8#10", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}})(x::CuArray{Float32, 4})
    @ Flux ~/.julia/packages/Flux/0c9kI/src/layers/basic.jl:38
 [33] eval_loss_accuracy(loader::DataLoader{Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}, Random._GLOBAL_RNG}, model::Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#11"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#7#9", MeanPool{2, 4}, var"#8#10", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, device::Function)
    @ Main ~/ownCloud/Private/JSoC2021/OtherIssues/SteadyStateAdjoint/DiffEqFlux567_original.jl:105
 [34] (::var"#report#17"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#11"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#7#9", MeanPool{2, 4}, var"#8#10", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, DataLoader{Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}, Random._GLOBAL_RNG}, DataLoader{Tuple{Array{Float32, 4}, Flux.OneHotArray{UInt32, 10, 1, 2, Vector{UInt32}}}, Random._GLOBAL_RNG}, Args})(epoch::Int64)
    @ Main ~/ownCloud/Private/JSoC2021/OtherIssues/SteadyStateAdjoint/DiffEqFlux567_original.jl:167
 [35] train(; kws::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Main ~/ownCloud/Private/JSoC2021/OtherIssues/SteadyStateAdjoint/DiffEqFlux567_original.jl:181
 [36] train()
    @ Main ~/ownCloud/Private/JSoC2021/OtherIssues/SteadyStateAdjoint/DiffEqFlux567_original.jl:132
 [37] top-level scope
    @ none:1

On CPU I get a broadcasting error related to a line after the computation of the Jacobian (which will be easy to fix, though). Can you update the example so that I can reproduce the Jacobian error? (It would probably be nice if it were a bit simpler than the original example, i.e. without the logging, with fewer packages, smaller batch sizes, and simpler NN structures.)

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 6, 2021

Hi Frank! I have also been playing around with the tspan parameter, and getting the same error (CUDA Kernel and broadcasting dimension mismatch, respectively). Could you elaborate on what you mean by "example reproducing the Jacobian error"? I could certainly cut out the logging and packages.

@frankschae
Copy link
Member

getting the same error (CUDA Kernel and broadcasting dimension mismatch, respectively)

I am a bit confused. Maybe I understood you wrong. I thought you got this error:

Actually I think I previously set tspan to be the incorrect value. When I add the line

x = solve(ss, DynamicSS(Tsit5()), u0=device(z), abstol=1f-2, reltol=1f-2, tspan=1.0f0).u

I get the error

[ Info: Training on GPU
[ Info: Dataset MNIST: 60000 train and 10000 test examples
[ Info: LeNet5 model: 15910 trainable params
[ Info: TensorBoard logging at "runs/"
[ Info: Start Training
Epoch: 0   Train: (loss = 2.3311f0, acc = 11.9333)   Test: (loss = 2.3328f0, acc = 12.04)
ERROR: LoadError: TypeError: in typeassert, expected Float32, got a value of type ForwardDiff.Dual{Nothing, Float32, 12}
Stacktrace:
  [1] setindex!(A::Matrix{Float32}, x::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, 
CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, i1::Int64)
    @ Base .\array.jl:839
  [2] _unsafe_copyto!(dest::Matrix{Float32}, doffs::Int64, src::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}}, soffs::Int64, n::Int64)
    @ Base .\array.jl:235
  [3] unsafe_copyto!
    @ .\array.jl:289 [inlined]
  [4] _copyto_impl!
    @ .\array.jl:313 [inlined]
  [5] copyto!
    @ .\array.jl:299 [inlined]
  [6] copyto!
    @ .\array.jl:325 [inlined]
  [7] copyto_axcheck!(dest::Matrix{Float32}, src::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}})
    @ Base .\abstractarray.jl:1056
  [8] Array
    @ .\array.jl:540 [inlined]
  [9] convert
    @ .\array.jl:532 [inlined]
 [10] convert(AT::Type{Matrix{Float32}}, A::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2})
    @ GPUArrays ~\.julia\packages\GPUArrays\Z5nPF\src\host\construction.jl:90
 [11] convert(AT::Type{CuArray{Float32, N} where N}, A::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2})
    @ GPUArrays ~\.julia\packages\GPUArrays\Z5nPF\src\host\construction.jl:82
 [12] (::var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}})(u::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2}, _p::CuArray{Float32, 1}, t::Nothing)    @ Main e:\wqy\julia\test.jl:96
 [13] ODEFunction
    @ ~\.julia\packages\SciMLBase\grNUR\src\scimlfunctions.jl:334 [inlined]
 [14] UDerivativeWrapper
    @ ~\.julia\packages\SciMLBase\grNUR\src\function_wrappers.jl:30 [inlined]
 [15] chunk_mode_jacobian(f::SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, x::CuArray{Float32, 2}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, 
var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2}})
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:223
 [16] jacobian(f::Function, x::CuArray{Float32, 2}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2}}, ::Val{true})
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:23
 [17] jacobian(f::Function, x::CuArray{Float32, 2}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 2}}) (repeats 2 times)
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:19
 [18] jacobian(f::Function, x::CuArray{Float32, 2}, alg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve})
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\p1AlV\src\derivative_wrappers.jl:135
 [19] SteadyStateAdjointProblem(sol::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}, sensealg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, g::Nothing, dg::SciMLBase.NonlinearSolution{Float32, 2, CuArray{Float32, 2}, CuArray{Float32, 2}, SteadyStateProblem{CuArray{Float32, 2}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#3"{CuArray{Float32, 2}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}; save_idxs::Nothing)
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\p1AlV\src\steadystate_adjoint.jl:41
 [20] #_adjoint_sensitivities#57
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\sensitivity_interface.jl:65 [inlined]
 [21] #adjoint_sensitivities#54
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\sensitivity_interface.jl:6 [inlined]
 [22] steadystatebackpass
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\concrete_solve.jl:437 [inlined]
 [23] #98#back

..

which I couldn't reproduce yet.

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 7, 2021

Ah I see what you are referring to. Actually that error went away after I upgraded my packages and incorporated your commit of "tspan" typecast. I don't think the ForwardDiff error is a problem anymore.

I cleaned up the code a bit and got rid of some logging code & packages. I also simplified the model declaration from convolutional to fully connected. I chose to stick with the MNIST dataset because simplifying further might conceal the problem. For example, non-SteadyState MNIST training and non-MNIST SteadyState training both already work on the GPU.

To reiterate, here's the CUDA Kernel error in the GPU code

using Flux
using Flux.Data:DataLoader
using Flux.Optimise: Optimiser
using Flux: onehotbatch, onecold
using Flux.Losses:logitcrossentropy
using ProgressMeter:@showprogress
import MLDatasets
using CUDA
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
CUDA.allowscalar(false)


# Build the model used in this reproduction (GPU variant): a dense encoder (`down`),
# a steady-state layer (`solve_ss`) and a dense classification head (`fc`),
# combined into a single `Chain`, which is returned.
function Net() 



    # Encoder: flatten each image to 784 rows, then 784 -> 200 -> 20 features.
    # NOTE(review): `->` binds looser than `|>`, so the `|> f32` on the lambda line
    # is part of the lambda body and is applied to the data on every forward pass.
    # The GPU stack trace posted below this script goes from the lambda frame
    # through `|>`, Flux's `f32` and `Adapt.adapt(Float32, ::CuArray)` before the
    # kernel compilation fails -- this looks like the trigger; confirm.
    # NOTE(review): the column count 8 is hard-coded and matches the default
    # `Args.batchsize`; any other batch size cannot be reshaped to (784, 8).
    down = Chain(
        x -> reshape(x, (784, 8)) |> f32,
        Dense(784, 200, tanh) |> f32,
        Dense(200, 20, tanh) |> f32,
    ) |> f32
    # Network whose steady state is solved for: 20 -> 10 -> 10 -> 20.
    deq = Chain(
        Dense(20, 10, tanh) |> f32,
        Dense(10, 10, tanh) |> f32,
        Dense(10, 20, tanh) |> f32,
    ) |> f32
    # Flatten `deq` into a parameter vector `p`; `re` rebuilds the chain from
    # such a vector.
    p, re = Flux.destructure(deq)
    # Classification head: 20 -> 15 -> 10.
    fc = Chain(
        Dense(20, 15, tanh) |> f32,
        Dense(15, 10, tanh) |> f32,
    ) |> f32

    tspan = (0.0f0, 1.0f0)
    # Steady-state layer: takes the encoder output `x` and returns the `.u` field
    # of the steady-state solution.
    function solve_ss(x)
        xg = gpu(x)

        # Initial state: one application of the rebuilt network to the input.
        z = re(p)(xg) |> gpu
        function dudt_(u, _p, t)
        # Residual du = f(u + xg) - u, where f is the network rebuilt from `_p`;
        # a steady state satisfies du = 0.
            re(_p)(u + xg) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, gpu(z), tspan, p))
        # The assignment is the last expression, so its value (the solution
        # state) is what `solve_ss` returns.
        x = solve(ss, DynamicSS(Tsit5()), u0=gpu(z), abstol=Float32(1e-2), reltol=Float32(1e-2), tspan=1.0f0).u
    end
  # Build the overall model topology (BS = batch size)
    m = Chain(
        down,               # (28,28,1,BS) -> (20, BS)
        solve_ss,           # (20, BS) -> (20, BS)
        fc,                 # (20, BS) -> (10, BS)
    )

    return m
end

# Load MNIST as Float32, add a singleton channel dimension to the images,
# one-hot encode the labels over the classes 0:9 and wrap both splits in
# `DataLoader`s that use `args.batchsize`. Only the training loader shuffles.
# Returns `(train_loader, test_loader)`.
function get_data(args)
    train_x, train_y = MLDatasets.MNIST.traindata(Float32)
    test_x, test_y = MLDatasets.MNIST.testdata(Float32)

    # (28, 28, N) -> (28, 28, 1, N)
    train_x = reshape(train_x, 28, 28, 1, :)
    test_x = reshape(test_x, 28, 28, 1, :)

    train_y = onehotbatch(train_y, 0:9)
    test_y = onehotbatch(test_y, 0:9)

    train_loader = DataLoader((train_x, train_y), batchsize=args.batchsize, shuffle=true)
    test_loader = DataLoader((test_x, test_y), batchsize=args.batchsize)
    return (train_loader, test_loader)
end

# Training/evaluation loss: `logitcrossentropy` of the raw model outputs `ŷ`
# against the targets `y`.
function loss(ŷ, y)
    return logitcrossentropy(ŷ, y)
end

# Evaluate `model` on every batch of `loader` after moving the batch with `device`.
# Returns a named tuple `(loss, acc)`: the sample-weighted mean loss and the
# accuracy in percent, both rounded with `round4`.
function eval_loss_accuracy(loader, model, device)
    total_loss = 0f0
    n_correct = 0
    n_seen = 0
    for (x, y) in loader
        xb = device(x)
        yb = device(y)
        ŷ = model(xb)
        nbatch = last(size(xb))      # samples are stored along the last dimension
        total_loss += loss(ŷ, yb) * nbatch
        # Compare predicted and true class indices on the CPU.
        n_correct += sum(onecold(cpu(ŷ)) .== onecold(cpu(yb)))
        n_seen += nbatch
    end
    mean_loss = round4(total_loss / n_seen)
    accuracy = round4(n_correct / n_seen * 100)
    return (loss = mean_loss, acc = accuracy)
end

# utility functions
# Round `x` to four digits after the decimal point.
function round4(x)
    return round(x; digits=4)
end

# arguments for the `train` function 
# Keyword-constructible settings for `train` (GPU variant: `use_cuda` defaults to true).
# Several fields are leftovers from a longer script and are not read by the
# `train`/`get_data` functions shown here; they are marked below.
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate passed to ADAM
    λ = 0                # L2 regularizer param; not read by `train` in this script
    batchsize = 8      # batch size used by both data loaders
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility
    use_cuda = true      # if true use cuda (if available)
    infotime = 1 	     # report every `infotime` epochs; not read by `train` in this script
    checktime = 5        # checkpoint interval in epochs; not read by `train` in this script
    tblogger = true      # tensorboard logging flag; not read by `train` in this script
    savepath = "runs/"    # results path; not read by `train` in this script
end

# Train `Net()` on MNIST with ADAM for `args.epochs` epochs.
# Keyword arguments are forwarded to `Args`. Returns `nothing`.
function train(; kws...)
    args = Args(; kws...)
    # NOTE(review): `Random` is not among the `using`/`import` lines at the top of
    # this script, so a positive seed would reach an undefined name -- confirm
    # and add `using Random` if so.
    if args.seed > 0
        Random.seed!(args.seed)
    end

    # Use the GPU only when it is both requested and functional.
    if args.use_cuda && CUDA.functional()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = device(Net())
    ps = Flux.params(model)
    opt = ADAM(args.η)

    ## TRAINING
    @info "Start Training"
    for _ in 1:args.epochs
        @showprogress for (x, y) in train_loader
            xb, yb = device(x), device(y)
            gs = Flux.gradient(() -> loss(model(xb), yb), ps)
            Flux.Optimise.update!(opt, ps, gs)
        end
    end
    return nothing
end

# Script entry point: run training with the default `Args` (GPU requested).
train()
[ Info: Training on GPU
[ Info: Dataset MNIST: 60000 train and 10000 test examples
[ Info: Start Training
ERROR: LoadError: GPUCompiler.KernelError(GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#16", Tuple{CUDA.CuKernelContext, CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#1120#1123"{typeof(convert)}, Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}}(GPUCompiler.PTXCompilerTarget(v"6.1.0", v"6.3.0", false, false, nothing, nothing, nothing, nothing), GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#16", Tuple{CUDA.CuKernelContext, CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#1120#1123"{typeof(convert)}, 
Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}(GPUArrays.var"#broadcast_kernel#16"(), Tuple{CUDA.CuKernelContext, CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#1120#1123"{typeof(convert)}, Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}, true, nothing, 0xffffffffffffffff), CUDA.CUDACompilerParams()), "passing and using non-bitstype argument", "Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var\"#1120#1123\"{typeof(convert)}, Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:\n  .args is of type Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.\n    .1 is of type CUDA.CuRefValue{DataType} which is not isbits.\n      .x is of type DataType which is not isbits.\n        .name is of type Core.TypeName which is not isbits.\n          .name is of type Symbol which is not isbits.\n  
        .module is of type Module which is not isbits.\n          .names is of type Core.SimpleVector which is not isbits.\n          .wrapper is of type Type which is not isbits.\n          .cache is of type Core.SimpleVector which is not isbits.\n          .linearcache is of type Core.SimpleVector which is not isbits.\n          .mt is of type Core.MethodTable which is not isbits.\n            .name is of type Symbol which is not isbits.\n     
       .defs is of type Any which is not isbits.\n            .leafcache is of type Any which is not isbits.\n            .cache is of type Any which is not isbits.\n            .kwsorter is of type Any which is not isbits.\n            .module is of type Module which is not isbits.\n            .backedges is of type Vector{Any} which is not isbits.\n          .partial is of type Any which is not isbits.\n        .super is of type DataType which 
is not isbits.\n          .name is of type Core.TypeName which is not isbits.\n            .name is of type Symbol which is not isbits.\n            .module is of type Module which is not isbits.\n            .names is of type Core.SimpleVector which is not isbits.\n            .wrapper is of type Type which is not isbits.\n            .cache is of type Core.SimpleVector which is not isbits.\n            .linearcache is of type Core.SimpleVector 
which is not isbits.\n            .mt is of type Core.MethodTable which is not isbits.\n              .name is of type Symbol which is not isbits.\n              .defs is of type Any which is not isbits.\n              .leafcache is of type Any which is not isbits.\n              .cache is of type Any which is not isbits.\n              .kwsorter is of type Any which is not isbits.\n              .module is of type Module which is not isbits.\n  
            .backedges is of type Vector{Any} which is not isbits.\n            .partial is of type Any which is not isbits.\n          .super is of type DataType which is not isbits.\n            .name is of type Core.TypeName which is not isbits.\n              .name is of type Symbol which is not isbits.\n              .module is of type Module which is not isbits.\n              .names is of type Core.SimpleVector which is not isbits.\n      
        .wrapper is of type Type which is not isbits.\n              .cache is of type Core.SimpleVector which is not isbits.\n              .linearcache is of type Core.SimpleVector which is not isbits.\n              .mt is of type Core.MethodTable which is not isbits.\n                .name is of type Symbol which is not isbits.\n                .defs is of type Any which is not isbits.\n                .leafcache is of type Any which is not isbits.\n                .cache is of type Any which is not isbits.\n                .kwsorter is of type Any which is not isbits.\n                .module is of type Module which is not isbits.\n                .backedges is 
of type Vector{Any} which is not isbits.\n              .partial is of type Any which is not isbits.\n            .super is of type DataType which is not isbits.\n              .name is of type Core.TypeName which is not isbits.\n                .name is of type Symbol which is not isbits.\n                .module is of type Module which is not isbits.\n                .names is of type Core.SimpleVector which is not isbits.\n                .wrapper is of type Type which is not isbits.\n                .cache is of type Core.SimpleVector which is not isbits.\n                .linearcache is of type Core.SimpleVector which is not isbits.\n                .mt is of type Core.MethodTable which is not isbits.\n                  .name is of type Symbol which is not isbits.\n                  .defs is of type Any which is not isbits.\n                  .leafcache is of type Any which is not isbits.\n                  .cache is of type Any which is not isbits.\n                  .kwsorter is of type Any which is not isbits.\n                  .module is of type Module which is not isbits.\n                  .backedges is of type Vector{Any} which is not isbits.\n                .partial is of type Any which is not isbits.\n              .super is of type DataType which is not isbits.\n                .name is of type Core.TypeName which is not isbits.\n                  .name is of type Symbol which is not isbits.\n                  .module is of type Module which is not isbits.\n                  .names is of type Core.SimpleVector which is not isbits.\n                  .wrapper is of type Type which is not isbits.\n                  .cache is of type Core.SimpleVector which is not isbits.\n                  .linearcache is of type Core.SimpleVector which is not isbits.\n  
                .mt is of type Core.MethodTable which is not isbits.\n                    .name is of type Symbol which is not isbits.\n                    .defs is of type Any which is not isbits.\n                    .leafcache is of type Any which is not isbits.\n                    .cache is of type Any which is not isbits.\n                    .kwsorter is of type Any which is not isbits.\n                    .module is of type Module which 
is not isbits.\n                    .backedges is of type Vector{Any} which is not isbits.\n                  .partial is of type Any which is not isbits.\n                .super is of type DataType which is not isbits.\n    
              .name is of type Core.TypeName which is not isbits.\n                    .name is of type Symbol which is not isbits.\n                    .module is of type Module which is not isbits.\n                    .names is of type Core.SimpleVector which is not isbits.\n                    .wrapper is of type Type which is not isbits.\n                    .cache is of type Core.SimpleVector which is not isbits.\n                    .linearcache is of type Core.SimpleVector which is not isbits.\n                    .mt is of type Core.MethodTable which is not isbits.\n                    .partial is of type Any which is not isbits.\n                  .super is of type DataType which is not isbits.\n                    .name is of type Core.TypeName which is not isbits.\n                    .super is of type DataType which is not isbits.\n                    .parameters is of type 
Core.SimpleVector which is not isbits.\n                    .types is of type Core.SimpleVector which is not isbits.\n                    .names is of type Core.SimpleVector which is not isbits.\n                    .instance is of type Any which is not isbits.\n                  .parameters is of type Core.SimpleVector which is not isbits.\n                  .types is of type Core.SimpleVector which is not isbits.\n                  .names is of type Core.SimpleVector which is not isbits.\n                  .instance is of type Any which is not isbits.\n                .parameters is of type Core.SimpleVector which is not isbits.\n                .types is of type Core.SimpleVector which is not isbits.\n                .names is of type Core.SimpleVector which is not isbits.\n                .instance is of type Any which is not isbits.\n              .parameters is of type Core.SimpleVector which is not isbits.\n              .types is of type Core.SimpleVector which is not isbits.\n              .names is of type Core.SimpleVector which is not isbits.\n              .instance is of type Any which is not isbits.\n            .parameters is of type Core.SimpleVector which is not isbits.\n            .types is of type Core.SimpleVector which is not isbits.\n            .names is of type Core.SimpleVector which is not isbits.\n  
          .instance is of type Any which is not isbits.\n          .parameters is of type Core.SimpleVector which is not isbits.\n          .types is of type Core.SimpleVector which is not isbits.\n          .names is of type Core.SimpleVector which is not isbits.\n          .instance is of type Any which is not isbits.\n        .parameters is of type Core.SimpleVector which is not isbits.\n        .types is of type Core.SimpleVector which is not isbits.\n        .names is of type Core.SimpleVector which is not isbits.\n        .instance is of type Any which is not isbits.\n", Base.StackTraces.StackFrame[])
Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob, entry::LLVM.Function)
    @ GPUCompiler ~\.julia\packages\GPUCompiler\eJOtJ\src\validation.jl:66
  [2] macro expansion
    @ ~\.julia\packages\GPUCompiler\eJOtJ\src\driver.jl:309 [inlined]
  [3] macro expansion
    @ ~\.julia\packages\TimerOutputs\PZq45\src\TimerOutput.jl:226 [inlined]
  [4] macro expansion
    @ ~\.julia\packages\GPUCompiler\eJOtJ\src\driver.jl:308 [inlined]
  [5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module, kernel::LLVM.Function; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
    @ GPUCompiler ~\.julia\packages\GPUCompiler\eJOtJ\src\utils.jl:62
  [6] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~\.julia\packages\CUDA\3VnCC\src\compiler\execution.jl:301
  [7] check_cache
    @ ~\.julia\packages\GPUCompiler\eJOtJ\src\cache.jl:47 [inlined]
  [8] cached_compilation
    @ ~\.julia\packages\GPUArrays\Z5nPF\src\host\broadcast.jl:57 [inlined]
  [9] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{GPUArrays.var"#broadcast_kernel#16", Tuple{CUDA.CuKernelContext, CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#1120#1123"{typeof(convert)}, Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}}, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler ~\.julia\packages\GPUCompiler\eJOtJ\src\cache.jl:0
 [10] cufunction(f::GPUArrays.var"#broadcast_kernel#16", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#1120#1123"{typeof(convert)}, Tuple{CUDA.CuRefValue{DataType}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~\.julia\packages\CUDA\3VnCC\src\compiler\execution.jl:289
 [11] cufunction
    @ ~\.julia\packages\CUDA\3VnCC\src\compiler\execution.jl:283 [inlined]
 [12] macro expansion
    @ ~\.julia\packages\CUDA\3VnCC\src\compiler\execution.jl:102 [inlined]
 [13] #launch_heuristic#286
    @ ~\.julia\packages\CUDA\3VnCC\src\gpuarrays.jl:17 [inlined]
 [14] launch_heuristic
    @ ~\.julia\packages\CUDA\3VnCC\src\gpuarrays.jl:17 [inlined]
 [15] copyto!
    @ ~\.julia\packages\GPUArrays\Z5nPF\src\host\broadcast.jl:63 [inlined]
 [16] copyto!
    @ .\broadcast.jl:936 [inlined]
 [17] copy
    @ ~\.julia\packages\GPUArrays\Z5nPF\src\host\broadcast.jl:47 [inlined]
 [18] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, Zygote.var"#1120#1123"{typeof(convert)}, Tuple{Base.RefValue{Type{Float32}}, CuArray{Float32, 2}}})
    @ Base.Broadcast .\broadcast.jl:883
 [19] broadcast_forward(::Function, ::Base.RefValue{Type{Float32}}, ::CuArray{Float32, 2})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\lib\broadcast.jl:226
 [20] adjoint
    @ ~\.julia\packages\Zygote\zowrf\src\lib\broadcast.jl:244 [inlined]
 [21] _pullback(::Zygote.Context, ::typeof(Base.Broadcast.broadcasted), ::CUDA.CuArrayStyle{2}, ::Function, ::Base.RefValue{Type{Float32}}, ::CuArray{Float32, 2})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57
 [22] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [23] adjoint
    @ ~\.julia\packages\Zygote\zowrf\src\lib\lib.jl:191 [inlined]
 [24] _pullback
    @ ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57 [inlined]
 [25] _pullback
    @ .\broadcast.jl:1315 [inlined]
 [26] _pullback
    @ ~\.julia\packages\Flux\0c9kI\src\functor.jl:73 [inlined]
 [27] _pullback
    @ ~\.julia\packages\Adapt\RGNRk\src\Adapt.jl:42 [inlined]
 [28] _pullback
    @ ~\.julia\packages\Adapt\RGNRk\src\Adapt.jl:40 [inlined]
 [29] _pullback(::Zygote.Context, ::typeof(Adapt.adapt), ::Type{Float32}, ::CuArray{Float32, 2})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [30] _pullback
    @ ~\.julia\packages\Flux\0c9kI\src\functor.jl:75 [inlined]
 [31] _pullback(ctx::Zygote.Context, f::Flux.var"#125#126"{DataType}, args::CuArray{Float32, 2})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [32] _pullback
    @ ~\.julia\packages\Functors\EWaud\src\functor.jl:56 [inlined]
 [33] _pullback(::Zygote.Context, ::Functors.var"##fmap#15", ::typeof(Functors.isleaf), ::IdDict{Any, Any}, ::typeof(fmap), ::Flux.var"#125#126"{DataType}, ::CuArray{Float32, 2})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [34] _pullback
    @ ~\.julia\packages\Functors\EWaud\src\functor.jl:55 [inlined]
 [35] _pullback
    @ ~\.julia\packages\Flux\0c9kI\src\functor.jl:75 [inlined]
 [36] _pullback
    @ ~\.julia\packages\Flux\0c9kI\src\functor.jl:82 [inlined]
 [37] _pullback
    @ .\operators.jl:858 [inlined]
 [38] _pullback
    @ e:\wqy\julia\test.jl:65 [inlined]
 [39] _pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:36 [inlined]
 [40] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{var"#1#2", Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}, ::CuArray{Float32, 4})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [41] _pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:38 [inlined]
 [42] _pullback(ctx::Zygote.Context, f::Chain{Tuple{var"#1#2", Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, args::CuArray{Float32, 4})        
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
--- the last 4 lines are repeated 1 more time ---
 [47] _pullback
    @ e:\wqy\julia\test.jl:188 [inlined]
 [48] _pullback(::Zygote.Context, ::var"#7#8"{Chain{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [49] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface.jl:250
 [50] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface.jl:58
 [51] macro expansion
    @ e:\wqy\julia\test.jl:187 [inlined]
 [52] macro expansion
    @ ~\.julia\packages\ProgressMeter\Vf8un\src\ProgressMeter.jl:940 [inlined]
 [53] train(; kws::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Main e:\wqy\julia\test.jl:185
 [54] train()
    @ Main e:\wqy\julia\test.jl:159
 [55] top-level scope
    @ e:\wqy\julia\test.jl:197
 [56] include(fname::String)
    @ Base.MainInclude .\client.jl:444
 [57] startdebug(socket::Base.PipeEndpoint, error_handler::VSCodeDebugger.var"#3#4"{Tuple{String, String}})
    @ VSCodeDebugger.DebugAdapter ~\.vscode\extensions\julialang.language-julia-1.2.1\scripts\packages\DebugAdapter\src\packagedef.jl:93
 [58] startdebugger()
    @ VSCodeDebugger ~\.vscode\extensions\julialang.language-julia-1.2.1\scripts\packages\VSCodeDebugger\src\VSCodeDebugger.jl:38
in expression starting at e:\wqy\julia\test.jl:197

And here's the broadcast error in the CPU version

using Flux
using Flux.Data:DataLoader
using Flux.Optimise: Optimiser
using Flux: onehotbatch, onecold
using Flux.Losses:logitcrossentropy
using ProgressMeter:@showprogress
import MLDatasets
using CUDA
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
CUDA.allowscalar(false)


# Build the model used in this reproduction (CPU variant): a dense encoder (`down`),
# a steady-state layer (`solve_ss`) and a dense classification head (`fc`),
# combined into a single `Chain`, which is returned.
function Net() 



    # Encoder: flatten each image to 784 rows, then 784 -> 200 -> 20 features.
    # NOTE(review): `->` binds looser than `|>`, so the `|> f32` on the lambda line
    # is part of the lambda body and is applied to the data on every forward pass.
    # NOTE(review): the column count 8 is hard-coded and matches the default
    # `Args.batchsize`; any other batch size cannot be reshaped to (784, 8).
    down = Chain(
        x -> reshape(x, (784, 8)) |> f32,
        Dense(784, 200, tanh) |> f32,
        Dense(200, 20, tanh) |> f32,
    ) |> f32
    # Network whose steady state is solved for: 20 -> 10 -> 10 -> 20.
    deq = Chain(
        Dense(20, 10, tanh) |> f32,
        Dense(10, 10, tanh) |> f32,
        Dense(10, 20, tanh) |> f32,
    ) |> f32
    # Flatten `deq` into a parameter vector `p`; `re` rebuilds the chain from
    # such a vector.
    p, re = Flux.destructure(deq)
    # Classification head: 20 -> 15 -> 10.
    fc = Chain(
        Dense(20, 15, tanh) |> f32,
        Dense(15, 10, tanh) |> f32,
    ) |> f32

    tspan = (0.0f0, 1.0f0)
    # Steady-state layer: takes the encoder output `x` and returns the `.u` field
    # of the steady-state solution.
    function solve_ss(x)

        # Initial state: one application of the rebuilt network to the input.
        z = re(p)(x)
        # NOTE(review): `dudt_` closes over `x`, and `x` is rebound to the solution
        # on the last line of `solve_ss`; a call to `dudt_` made after that point
        # (e.g. during the adjoint pass) would see the rebound value -- confirm
        # whether this is intended.
        function dudt_(u, _p, t)
        # Residual du = f(u + x) - u, where f is the network rebuilt from `_p`;
        # a steady state satisfies du = 0.
            re(_p)(u + x) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, z, tspan, p))
        # The assignment is the last expression, so its value (the solution
        # state) is what `solve_ss` returns.
        x = solve(ss, DynamicSS(Tsit5()), u0=z, abstol=Float32(1e-2), reltol=Float32(1e-2), tspan=1.0f0).u
    end
  # Build the overall model topology (BS = batch size)
    m = Chain(
        down,               # (28,28,1,BS) -> (20, BS)
        solve_ss,           # (20, BS) -> (20, BS)
        fc,                 # (20, BS) -> (10, BS)
    )

    return m
end

# Load MNIST as Float32, add a singleton channel dimension to the images,
# one-hot encode the labels over the classes 0:9 and wrap both splits in
# `DataLoader`s that use `args.batchsize`. Only the training loader shuffles.
# Returns `(train_loader, test_loader)`.
function get_data(args)
    train_x, train_y = MLDatasets.MNIST.traindata(Float32)
    test_x, test_y = MLDatasets.MNIST.testdata(Float32)

    # (28, 28, N) -> (28, 28, 1, N)
    train_x = reshape(train_x, 28, 28, 1, :)
    test_x = reshape(test_x, 28, 28, 1, :)

    train_y = onehotbatch(train_y, 0:9)
    test_y = onehotbatch(test_y, 0:9)

    train_loader = DataLoader((train_x, train_y), batchsize=args.batchsize, shuffle=true)
    test_loader = DataLoader((test_x, test_y), batchsize=args.batchsize)
    return (train_loader, test_loader)
end

# Training/evaluation loss: `logitcrossentropy` of the raw model outputs `ŷ`
# against the targets `y`.
function loss(ŷ, y)
    return logitcrossentropy(ŷ, y)
end

# Evaluate `model` on every batch of `loader` after moving the batch with `device`.
# Returns a named tuple `(loss, acc)`: the sample-weighted mean loss and the
# accuracy in percent, both rounded with `round4`.
function eval_loss_accuracy(loader, model, device)
    total_loss = 0f0
    n_correct = 0
    n_seen = 0
    for (x, y) in loader
        xb = device(x)
        yb = device(y)
        ŷ = model(xb)
        nbatch = last(size(xb))      # samples are stored along the last dimension
        total_loss += loss(ŷ, yb) * nbatch
        # Compare predicted and true class indices on the CPU.
        n_correct += sum(onecold(cpu(ŷ)) .== onecold(cpu(yb)))
        n_seen += nbatch
    end
    mean_loss = round4(total_loss / n_seen)
    accuracy = round4(n_correct / n_seen * 100)
    return (loss = mean_loss, acc = accuracy)
end

# utility functions
# Round `x` to four digits after the decimal point.
function round4(x)
    return round(x; digits=4)
end

# arguments for the `train` function 
# Keyword-constructible settings for `train` (CPU variant: `use_cuda` defaults to false).
# Several fields are leftovers from a longer script and are not read by the
# `train`/`get_data` functions shown here; they are marked below.
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate passed to ADAM
    λ = 0                # L2 regularizer param; not read by `train` in this script
    batchsize = 8      # batch size used by both data loaders
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility
    use_cuda = false      # if true use cuda (if available)
    infotime = 1 	     # report every `infotime` epochs; not read by `train` in this script
    checktime = 5        # checkpoint interval in epochs; not read by `train` in this script
    tblogger = true      # tensorboard logging flag; not read by `train` in this script
    savepath = "runs/"    # results path; not read by `train` in this script
end

# Train `Net()` on MNIST with ADAM for `args.epochs` epochs.
# Keyword arguments are forwarded to `Args`. Returns `nothing`.
function train(; kws...)
    args = Args(; kws...)
    # NOTE(review): `Random` is not among the `using`/`import` lines at the top of
    # this script, so a positive seed would reach an undefined name -- confirm
    # and add `using Random` if so.
    if args.seed > 0
        Random.seed!(args.seed)
    end

    # Use the GPU only when it is both requested and functional.
    if args.use_cuda && CUDA.functional()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = device(Net())
    ps = Flux.params(model)
    opt = ADAM(args.η)

    ## TRAINING
    @info "Start Training"
    for _ in 1:args.epochs
        @showprogress for (x, y) in train_loader
            xb, yb = device(x), device(y)
            gs = Flux.gradient(() -> loss(model(xb), yb), ps)
            Flux.Optimise.update!(opt, ps, gs)
        end
    end
    return nothing
end

# Script entry point: run training with the default `Args` (CPU only).
train()
[ Info: Training on CPU
[ Info: Dataset MNIST: 60000 train and 10000 test examples
[ Info: Start Training
ERROR: LoadError: DimensionMismatch("array could not be broadcast to match destination")
Stacktrace:
  [1] check_broadcast_shape
    @ .\broadcast.jl:520 [inlined]
  [2] check_broadcast_axes
    @ .\broadcast.jl:523 [inlined]
  [3] instantiate
    @ .\broadcast.jl:269 [inlined]
  [4] materialize!
    @ .\broadcast.jl:894 [inlined]
  [5] materialize!(dest::Matrix{Float32}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(identity), Tuple{Vector{Float32}}})
    @ Base.Broadcast .\broadcast.jl:891
  [6] SteadyStateAdjointProblem(sol::SciMLBase.NonlinearSolution{Float32, 2, Matrix{Float32}, Matrix{Float32}, SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}, sensealg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, g::Nothing, dg::SciMLBase.NonlinearSolution{Float32, 2, Matrix{Float32}, Matrix{Float32}, SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}; save_idxs::Nothing)
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\p1AlV\src\steadystate_adjoint.jl:65
  [7] #_adjoint_sensitivities#57
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\sensitivity_interface.jl:65 [inlined]
  [8] #adjoint_sensitivities#54
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\sensitivity_interface.jl:6 [inlined]
  [9] steadystatebackpass
    @ ~\.julia\packages\DiffEqSensitivity\p1AlV\src\concrete_solve.jl:437 [inlined]
 [10] #98#back
    @ ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:65 [inlined]
 [11] #188
    @ ~\.julia\packages\Zygote\zowrf\src\lib\lib.jl:194 [inlined]
 [12] (::Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, DiffEqBase.var"#98#back#74"{DiffEqSensitivity.var"#steadystatebackpass#209"{Nothing, DynamicSS{Tsit5, Float64, Float64, Float64}, SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, Tuple{}, SciMLBase.NonlinearSolution{Float32, 2, Matrix{Float32}, Matrix{Float32}, SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, 
Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}}}}})(Δ::SciMLBase.NonlinearSolution{Float32, 2, Matrix{Float32}, Matrix{Float32}, SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, 
Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [13] Pullback
    @ ~\.julia\packages\DiffEqBase\lULzQ\src\solve.jl:70 [inlined]
 [14] (::Zygote.Pullback{Tuple{DiffEqBase.var"##solve#57", Nothing, Matrix{Float32}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), 
Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any})(Δ::SciMLBase.NonlinearSolution{Float32, 2, Matrix{Float32}, Matrix{Float32}, SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [15] (::Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#57", Nothing, Matrix{Float32}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}})(Δ::SciMLBase.NonlinearSolution{Float32, 2, Matrix{Float32}, Matrix{Float32}, SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\lib\lib.jl:194
 [16] (::Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#57", Nothing, Matrix{Float32}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, 
DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}}})(Δ::SciMLBase.NonlinearSolution{Float32, 2, Matrix{Float32}, Matrix{Float32}, SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), 
Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [17] Pullback
    @ ~\.julia\packages\DiffEqBase\lULzQ\src\solve.jl:68 [inlined]
 [18] (::Zygote.Pullback{Tuple{CommonSolve.var"#solve##kw", NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{Matrix{Float32}, Float32, Float32, Float32}}, typeof(solve), SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any})(Δ::SciMLBase.NonlinearSolution{Float32, 2, Matrix{Float32}, Matrix{Float32}, SteadyStateProblem{Matrix{Float32}, false, Vector{Float32}, ODEFunction{false, var"#dudt_#4"{Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, 
Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [19] Pullback
    @ e:\wqy\julia\test.jl:94 [inlined]
 [20] (::Zygote.Pullback{Tuple{var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Matrix{Float32}}, Any})(Δ::Matrix{Float32})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [21] Pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:36 [inlined]
 [22] (::Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, 
Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, Matrix{Float32}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, Matrix{Float32}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, 
Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, 
Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), 
Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Matrix{Float32}}, Any}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}})(Δ::Matrix{Float32})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [23] Pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:36 [inlined]
 [24] (::Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Array{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, Matrix{Float32}}, Tuple{}}, 
Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, Matrix{Float32}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, 
Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, 
Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Matrix{Float32}}, Any}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, 
Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, 
Nothing}}}, Zygote.Pullback{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Array{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, 
Array{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, Matrix{Float32}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, 
Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.Pullback{Tuple{var"#1#2", Array{Float32, 4}}, Tuple{Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#2436#back#489"{Zygote.var"#485#487"{Array{Float32, 4}, Tuple{Tuple{Int64, Int64}}}}, Zygote.Pullback{Tuple{typeof(|>), Matrix{Float32}, typeof(f32)}, Tuple{Zygote.Pullback{Tuple{typeof(f32), Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.paramtype), Type{Float32}, Matrix{Float32}}, Tuple{Zygote.var"#1723#back#204"{Zygote.Jnew{Flux.var"#125#126"{DataType}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(fmap), Flux.var"#125#126"{DataType}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{Functors.var"##fmap#15", typeof(Functors.isleaf), IdDict{Any, Any}, typeof(fmap), Flux.var"#125#126"{DataType}, Matrix{Float32}}, Any}, Zygote.Pullback{Tuple{Type{IdDict}}, Tuple{}}}}}}}}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, 
Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}})(Δ::Matrix{Float32})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [25] Pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:38 [inlined]
 [26] (::Zygote.Pullback{Tuple{Chain{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, Array{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, 
Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}, Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Array{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Matrix{Float32}}, 
Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, Matrix{Float32}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, Matrix{Float32}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, 
Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, 
Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Matrix{Float32}}, Any}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, 
Zygote.Pullback{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Array{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}}}}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Array{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, Matrix{Float32}}, Tuple{}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, 
Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.Pullback{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Matrix{Float32}}, Tuple{Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{Matrix{Float32}, Vector{Float32}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1527"{Matrix{Float32}, Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.var"#3831#back#1053"{Zygote.var"#1051#1052"{Matrix{Float32}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Matrix{Float32}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, 
Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Symbol}, Any}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.Pullback{Tuple{var"#1#2", Array{Float32, 4}}, Tuple{Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#2436#back#489"{Zygote.var"#485#487"{Array{Float32, 4}, Tuple{Tuple{Int64, Int64}}}}, Zygote.Pullback{Tuple{typeof(|>), Matrix{Float32}, typeof(f32)}, Tuple{Zygote.Pullback{Tuple{typeof(f32), Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.paramtype), Type{Float32}, Matrix{Float32}}, Tuple{Zygote.var"#1723#back#204"{Zygote.Jnew{Flux.var"#125#126"{DataType}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(fmap), Flux.var"#125#126"{DataType}, Matrix{Float32}}, Tuple{Zygote.Pullback{Tuple{Functors.var"##fmap#15", typeof(Functors.isleaf), IdDict{Any, Any}, typeof(fmap), Flux.var"#125#126"{DataType}, Matrix{Float32}}, Any}, Zygote.Pullback{Tuple{Type{IdDict}}, Tuple{}}}}}}}}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}}}}})(Δ::Matrix{Float32})   
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [27] Pullback
    @ e:\wqy\julia\test.jl:185 [inlined]
 [28] (::Zygote.Pullback{Tuple{var"#7#8"{Chain{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}}}, Any})(Δ::Float32)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [29] (::Zygote.var"#69#70"{Zygote.Params, Zygote.Pullback{Tuple{var"#7#8"{Chain{Tuple{Chain{Tuple{var"#1#2", Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, var"#solve_ss#3"{Tuple{Float32, Float32}, Flux.var"#61#63"{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}, Vector{Float32}}, Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}}}}}, Any}, Zygote.Context})(Δ::Float32)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface.jl:255
 [30] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface.jl:59
 [31] macro expansion
    @ e:\wqy\julia\test.jl:184 [inlined]
 [32] macro expansion
    @ ~\.julia\packages\ProgressMeter\Vf8un\src\ProgressMeter.jl:940 [inlined]
 [33] train(; kws::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Main e:\wqy\julia\test.jl:182
 [34] train()
    @ Main e:\wqy\julia\test.jl:156
 [35] top-level scope
    @ e:\wqy\julia\test.jl:194
 [36] include(fname::String)
    @ Base.MainInclude .\client.jl:444
 [37] startdebug(socket::Base.PipeEndpoint, error_handler::VSCodeDebugger.var"#3#4"{Tuple{String, String}})
    @ VSCodeDebugger.DebugAdapter ~\.vscode\extensions\julialang.language-julia-1.2.1\scripts\packages\DebugAdapter\src\packagedef.jl:93
 [38] startdebugger()
    @ VSCodeDebugger ~\.vscode\extensions\julialang.language-julia-1.2.1\scripts\packages\VSCodeDebugger\src\VSCodeDebugger.jl:38
 [39] top-level scope
    @ ~\.vscode\extensions\julialang.language-julia-1.2.1\scripts\debugger\run_debugger.jl:9
in expression starting at e:\wqy\julia\test.jl:194

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 9, 2021

Hey Frank! Thanks a lot for the fix in DiffEqSensitivity419! I updated my packages and restarted Julia, but was surprised to find that running on the CPU still fails with the same error as above: DimensionMismatch("array could not be broadcast to match destination"). Could you kindly share the code that ran successfully on the CPU? Thanks a lot!

@frankschae
Copy link
Member

I used your code :) We need a new tag on DiffEqSensitivity @ChrisRackauckas .

@ChrisRackauckas
Copy link
Member

tagged

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 9, 2021

Never mind. DiffEqSensitivity somehow decided not to automatically update :) I don't think this fix solves the GPU problem, however: I encounter the same error as before (GPUCompilerKernelError), so I guess it has to do with something else.

@ChrisRackauckas
Copy link
Member

Yeah so what's the current state of this?

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Jun 10, 2021

So Frank's fix in DiffEqSensitivity419 solved the problem in my previous thread, and SteadyState training with a fully connected architecture is now possible on either CPU or GPU. However, there is still a problem when SteadyState is applied to conv layers, which I am still trying to figure out.

using Random
using Flux
using Flux.Data:DataLoader
using Flux.Optimise: Optimiser
using Flux: onehotbatch, onecold
using Flux.Losses:logitcrossentropy
using ProgressMeter:@showprogress
import MLDatasets
using CUDA
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
# Make scalar indexing of GPU arrays raise an error.
CUDA.allowscalar(false)


"""
    Net()

Build the MNIST classifier as a three-stage `Chain`:

1. `down`: a convolutional stack with `GroupNorm` layers that downsamples the input.
2. `solve_ss`: drives the two-layer convolutional block `deq` to a steady state by
   solving `deq(u + x) - u = 0` with `DynamicSS(Tsit5())`.
3. `fc`: `GroupNorm`, `relu`, mean pooling, a reshape and a `Dense(64, 10)` read-out.

Returns the assembled `Chain`.
"""
function Net()
    down = Chain(
        Conv((3, 3), 1 => 64, relu, stride=1),
        GroupNorm(64, 64),
        Conv((4, 4), 64 => 64, relu, stride=2, pad=1),
        GroupNorm(64, 64),
        Conv((4, 4), 64 => 64, stride=2, pad=1),
    )

    deq = Chain(
        Conv((3, 3), 64 => 64, relu, stride=1, pad=1),
        Conv((3, 3), 64 => 64, relu, stride=1, pad=1),
    )

    # Flatten `deq` into a parameter vector `p` plus a reconstructor `re`, so the
    # parameters can be passed to the steady-state problem explicitly.
    # NOTE(review): `p` is only reachable through the `solve_ss` closure; confirm
    # that `Flux.params(model)` and `model |> device` actually see it.
    p, re = Flux.destructure(deq)

    fc = Chain(
        GroupNorm(64, 64),
        x -> relu.(x),
        MeanPool((6, 6)),
        x -> reshape(x, (64, :)),
        Dense(64, 10),
    )

    tspan = (0.0f0, 1.0f0)
    function solve_ss(x)
        # Initial state: one plain pass of the input through the block.
        z = re(p)(x)
        function dudt_(u, _p, t)
            # Solving the equation f(u) - u = du = 0
            return re(_p)(u + x) - u
        end
        ss = SteadyStateProblem(ODEProblem(dudt_, z, tspan, p))
        # Return the solution directly instead of rebinding `x`: `x` is captured by
        # `dudt_`, and reassigning a captured variable forces Julia to box it.
        return solve(
            ss,
            DynamicSS(Tsit5()),
            u0=z,
            abstol=Float32(1e-2),
            reltol=Float32(1e-2),
            tspan=1.0f0,
        ).u
    end

    # Build our over-all model topology
    m = Chain(
        down,               # (28,28,1,BS) -> (6,6,64,BS)
        solve_ss,           # (6,6,64,BS) -> (6,6,64,BS)
        fc,                 # (6,6,64,BS) -> (10, BS)
    )

    return m
end

"""
    get_data(args)

Load the MNIST train and test splits as `Float32`, reshape the images to
`28 × 28 × 1 × N`, one-hot encode the labels over `0:9`, and wrap each split in a
`DataLoader` with batch size `args.batchsize` (only the training loader shuffles).

Returns `(train_loader, test_loader)`.
"""
function get_data(args)
    # Shared preprocessing for one (images, labels) split.
    prepare(images, labels) = (reshape(images, 28, 28, 1, :), onehotbatch(labels, 0:9))

    train_x, train_y = prepare(MLDatasets.MNIST.traindata(Float32)...)
    test_x, test_y = prepare(MLDatasets.MNIST.testdata(Float32)...)

    bs = args.batchsize
    train_loader = DataLoader((train_x, train_y); batchsize=bs, shuffle=true)
    test_loader = DataLoader((test_x, test_y); batchsize=bs)

    return train_loader, test_loader
end

"""
    loss(ŷ, y)

Logit cross-entropy between the model output `ŷ` and the targets `y`.
"""
function loss(ŷ, y)
    return logitcrossentropy(ŷ, y)
end

"""
    eval_loss_accuracy(loader, model, device)

Run `model` over every batch of `loader` (each batch is first moved with `device`)
and return a named tuple `(loss, acc)`: the sample-weighted mean loss and the
accuracy in percent, both rounded with `round4`.
"""
function eval_loss_accuracy(loader, model, device)
    total_loss = 0f0
    n_correct = 0
    n_seen = 0
    for (x, y) in loader
        xb = device(x)
        yb = device(y)
        pred = model(xb)
        nbatch = size(xb)[end]
        # `loss` is applied per batch, so weight it by the batch size.
        total_loss += loss(pred, yb) * nbatch
        n_correct += sum(onecold(cpu(pred)) .== onecold(cpu(yb)))
        n_seen += nbatch
    end
    mean_loss = round4(total_loss / n_seen)
    accuracy = round4(n_correct / n_seen * 100)
    return (loss = mean_loss, acc = accuracy)
end

# utility functions
"""
    round4(x)

Round `x` to four digits after the decimal point.
"""
function round4(x)
    return round(x; digits=4)
end

# arguments for the `train` function 
"""
    Args(; kwargs...)

Keyword-constructible settings for `train`. Every field has a default and a concrete
type, so reads such as `args.epochs` are type-stable; values passed to the
constructor are converted to the field type.
"""
Base.@kwdef mutable struct Args
    η::Float64 = 3e-4           # learning rate
    λ::Float64 = 0              # L2 regularizer param, implemented as weight decay
    batchsize::Int = 8          # batch size
    epochs::Int = 10            # number of epochs
    seed::Int = 0               # set seed > 0 for reproducibility
    use_cuda::Bool = false      # if true use cuda (if available)
    infotime::Int = 1           # report every `infotime` epochs
    checktime::Int = 5          # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    tblogger::Bool = true       # log training with tensorboard
    savepath::String = "runs/"  # results path
end

"""
    train(; kws...)

Train `Net()` on MNIST with ADAM and report train/test loss and accuracy after every
epoch. Keyword arguments are forwarded to `Args`; the fields read here are `seed`,
`use_cuda`, `η` and `epochs` (`get_data` additionally reads `batchsize`).

Returns the trained model.
"""
function train(; kws...)
    args = Args(; kws...)
    # Seeding is opt-in: the default `seed = 0` leaves the global RNG untouched.
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = Net() |> device

    ps = Flux.params(model)
    opt = ADAM(args.η)

    ## TRAINING
    @info "Start Training"
    for _ in 1:args.epochs
        @showprogress for (x, y) in train_loader
            x, y = x |> device, y |> device
            gs = Flux.gradient(ps) do
                ŷ = model(x)
                loss(ŷ, y)
            end

            Flux.Optimise.update!(opt, ps, gs)
        end

        # Named `*_metrics` so the locals do not shadow the enclosing `train` function.
        train_metrics = eval_loss_accuracy(train_loader, model, device)
        test_metrics = eval_loss_accuracy(test_loader, model, device)
        @info "train" loss = train_metrics.loss  acc = train_metrics.acc
        @info "test"  loss = test_metrics.loss   acc = test_metrics.acc
    end

    # Hand the trained model back to the caller instead of discarding it.
    return model
end

# Script entry point: run training with the default `Args` settings.
train()
┌ Warning: Slow fallback implementation invoked for conv!  You probably don't want this; check your datatypes.
│   yT = ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}
│   T1 = ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}
│   T2 = Float32
└ @ NNlib C:\Users\administered\.julia\packages\NNlib\3MZcC\src\conv.jl:206
ERROR: LoadError: scalar getindex is disallowed
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] assertscalar(op::String)
    @ GPUArrays ~\.julia\packages\GPUArrays\Z5nPF\src\host\indexing.jl:62
  [3] getindex(::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 5}, ::Int64, ::Int64, ::Int64, ::Int64, ::Vararg{Int64, N} where N)
    @ GPUArrays ~\.julia\packages\GPUArrays\Z5nPF\src\host\indexing.jl:104
  [4] conv_direct!(y::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 5}, x::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 5}, w::CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 64, 64, (1, 1, 1), (1, 
1, 1, 1, 0, 0), (1, 1, 1), false}; alpha::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, beta::Bool)
    @ NNlib ~\.julia\packages\NNlib\3MZcC\src\impl\conv_direct.jl:91
  [5] conv_direct!
    @ ~\.julia\packages\NNlib\3MZcC\src\impl\conv_direct.jl:51 [inlined]
  [6] conv!(y::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 5}, in1::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 5}, in2::CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 64, 64, (1, 1, 1), (1, 1, 
1, 1, 0, 0), (1, 1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~\.julia\packages\NNlib\3MZcC\src\conv.jl:208
  [7] conv!(y::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 5}, in1::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 5}, in2::CuArray{Float32, 5}, cdims::DenseConvDims{3, (3, 3, 1), 64, 64, (1, 1, 1), (1, 1, 
1, 1, 0, 0), (1, 1, 1), false})
    @ NNlib ~\.julia\packages\NNlib\3MZcC\src\conv.jl:206
  [8] conv!(y::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4}, x::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4}, w::CuArray{Float32, 4}, cdims::DenseConvDims{2, (3, 3), 64, 64, (1, 1), (1, 1, 1, 1), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~\.julia\packages\NNlib\3MZcC\src\conv.jl:148
  [9] conv!
    @ ~\.julia\packages\NNlib\3MZcC\src\conv.jl:148 [inlined]
 [10] conv(x::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4}, w::CuArray{Float32, 4}, cdims::DenseConvDims{2, (3, 3), 64, 64, (1, 1), (1, 1, 1, 1), (1, 1), false}; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~\.julia\packages\NNlib\3MZcC\src\conv.jl:91
 [11] conv(x::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4}, w::CuArray{Float32, 4}, cdims::DenseConvDims{2, (3, 3), 64, 64, (1, 1), (1, 1, 1, 1), (1, 1), false})
    @ NNlib ~\.julia\packages\NNlib\3MZcC\src\conv.jl:89
 [12] (::Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}})(x::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4})
    @ Flux ~\.julia\packages\Flux\0c9kI\src\layers\conv.jl:157
 [13] applychain(fs::Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}}, x::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4) 
    @ Flux ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:36
 [14] (::Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}}})(x::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4})     
    @ Flux ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:38
 [15] (::var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}})(u::CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4}, _p::CuArray{Float32, 1}, t::Nothing)
    @ Main e:\wqy\julia\mnistDEQ.jl:80
 [16] ODEFunction
    @ ~\.julia\packages\SciMLBase\grNUR\src\scimlfunctions.jl:334 [inlined]
 [17] UDerivativeWrapper
    @ ~\.julia\packages\SciMLBase\grNUR\src\function_wrappers.jl:30 [inlined]
 [18] chunk_mode_jacobian(f::SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, x::CuArray{Float32, 4}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4}})
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:223
 [19] jacobian(f::Function, x::CuArray{Float32, 4}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4}}, ::Val{true})
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:23
 [20] jacobian(f::Function, x::CuArray{Float32, 4}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12, CuArray{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, CuArray{Float32, 1}}, Float32}, Float32, 12}, 4}}) (repeats 2 times)
    @ ForwardDiff ~\.julia\packages\ForwardDiff\QOqCN\src\jacobian.jl:19
 [21] jacobian(f::Function, x::CuArray{Float32, 4}, alg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve})
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\aO4e2\src\derivative_wrappers.jl:135
 [22] SteadyStateAdjointProblem(sol::SciMLBase.NonlinearSolution{Float32, 4, CuArray{Float32, 4}, CuArray{Float32, 4}, SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}, sensealg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, g::Nothing, dg::SciMLBase.NonlinearSolution{Float32, 4, CuArray{Float32, 4}, CuArray{Float32, 4}, SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}; save_idxs::Nothing)
    @ DiffEqSensitivity ~\.julia\packages\DiffEqSensitivity\aO4e2\src\steadystate_adjoint.jl:41
 [23] #_adjoint_sensitivities#59
    @ ~\.julia\packages\DiffEqSensitivity\aO4e2\src\sensitivity_interface.jl:65 [inlined]
 [24] #adjoint_sensitivities#56
    @ ~\.julia\packages\DiffEqSensitivity\aO4e2\src\sensitivity_interface.jl:6 [inlined]
 [25] steadystatebackpass
    @ ~\.julia\packages\DiffEqSensitivity\aO4e2\src\concrete_solve.jl:439 [inlined]
 [26] #98#back
    @ ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:65 [inlined]
 [27] #188
    @ ~\.julia\packages\Zygote\zowrf\src\lib\lib.jl:194 [inlined]
 [28] (::Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, DiffEqBase.var"#98#back#76"{DiffEqSensitivity.var"#steadystatebackpass#219"{Nothing, DynamicSS{Tsit5, Float64, Float64, Float64}, SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, Tuple{}, SciMLBase.NonlinearSolution{Float32, 4, CuArray{Float32, 4}, CuArray{Float32, 4}, SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing}}}}})(Δ::SciMLBase.NonlinearSolution{Float32, 4, CuArray{Float32, 4}, CuArray{Float32, 4}, SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), 
Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [29] Pullback
    @ ~\.julia\packages\DiffEqBase\lULzQ\src\solve.jl:70 [inlined]
 [30] (::Zygote.Pullback{Tuple{DiffEqBase.var"##solve#59", Nothing, CuArray{Float32, 4}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any})(Δ::SciMLBase.NonlinearSolution{Float32, 4, CuArray{Float32, 4}, CuArray{Float32, 4}, SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [31] (::Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#59", Nothing, CuArray{Float32, 4}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}})(Δ::SciMLBase.NonlinearSolution{Float32, 4, CuArray{Float32, 4}, CuArray{Float32, 4}, SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})       
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\lib\lib.jl:194
 [32] (::Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{DiffEqBase.var"##solve#59", Nothing, CuArray{Float32, 4}, Nothing, Base.Iterators.Pairs{Symbol, Float32, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:abstol, :reltol, :tspan), Tuple{Float32, Float32, Float32}}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any}}})(Δ::SciMLBase.NonlinearSolution{Float32, 4, CuArray{Float32, 4}, CuArray{Float32, 4}, SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:59
 [33] Pullback
    @ ~\.julia\packages\DiffEqBase\lULzQ\src\solve.jl:68 [inlined]
 [34] (::Zygote.Pullback{Tuple{CommonSolve.var"#solve##kw", NamedTuple{(:u0, :abstol, :reltol, :tspan), Tuple{CuArray{Float32, 4}, Float32, Float32, Float32}}, typeof(solve), SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}}, Any})(Δ::SciMLBase.NonlinearSolution{Float32, 4, CuArray{Float32, 4}, CuArray{Float32, 4}, SteadyStateProblem{CuArray{Float32, 4}, false, CuArray{Float32, 1}, ODEFunction{false, var"#dudt_#6"{Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float64, Float64, Float64}, Nothing, Nothing})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [35] Pullback
    @ e:\wqy\julia\mnistDEQ.jl:83 [inlined]
 [36] (::Zygote.Pullback{Tuple{var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Any})(Δ::CuArray{Float32, 4})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [37] Pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:36 [inlined]
--- the last 2 lines are repeated 1 more time ---
 [40] (::Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:bias}}, 
Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:bias, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}}}}}, Zygote.ZBack{NNlib.var"#conv_pullback#183"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CuArray{Float32, 4}, CuArray{Float32, 4}, DenseConvDims{2, (3, 3), 1, 64, (1, 1), (0, 0, 0, 0), (1, 1), false}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{Flux.var"#174#175", Nothing, false}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:σ}}, 
Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:σ, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, typeof(relu)}}}}}}, Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 1}, Tuple{Int64, Int64, Colon, Int64}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 4}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:dilation}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:dilation, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#length_pullback#877"}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.ZBack{NNlib.var"#broadcasted_relu_pullback#32"{CuArray{Float32, 4}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, 
Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:pad}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:pad, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, NTuple{4, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ntuple), Flux.var"#174#175", Int64}, Any}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, 
CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), T} where T<:Tuple}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, 
Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#kw_zpullback#40"{NNlib.var"#DenseConvDims_pullback#162"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{NTuple{4, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, CuArray{Float32, 4}}, Any}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, 
Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:bias, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}}}}}, Zygote.ZBack{NNlib.var"#conv_pullback#183"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CuArray{Float32, 4}, CuArray{Float32, 4}, DenseConvDims{2, (4, 4), 64, 64, (2, 2), (1, 1, 1, 1), (1, 1), false}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{Flux.var"#174#175", Nothing, false}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:σ, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, typeof(relu)}}}}}}, Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 1}, Tuple{Int64, Int64, Colon, Int64}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 4}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:dilation}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:dilation, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#length_pullback#877"}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.ZBack{NNlib.var"#broadcasted_relu_pullback#32"{CuArray{Float32, 4}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), 
CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 
4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:pad}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:pad, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, NTuple{4, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ntuple), Flux.var"#174#175", Int64}, Any}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), T} where T<:Tuple}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Nothing, 
true}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#kw_zpullback#40"{NNlib.var"#DenseConvDims_pullback#162"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, CuArray{Float32, 4}}, Any}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, 
CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:bias, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}}}}}, Zygote.ZBack{NNlib.var"#conv_pullback#183"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CuArray{Float32, 4}, CuArray{Float32, 4}, DenseConvDims{2, (4, 4), 64, 64, (2, 2), (1, 1, 1, 1), (1, 1), false}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{Flux.var"#174#175", Nothing, false}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:σ, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, typeof(identity)}}}}}}, Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 1}, Tuple{Int64, Int64, Colon, Int64}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 4}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:dilation}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:dilation, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#length_pullback#877"}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#3819#back#1049"{Zygote.var"#1047#1048"}, 
Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, 
typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:pad}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:pad, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, NTuple{4, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ntuple), Flux.var"#174#175", Int64}, Any}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), T} where T<:Tuple}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, 
typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}, 
Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#kw_zpullback#40"{NNlib.var"#DenseConvDims_pullback#162"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 4}}, Tuple{}}}}}}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, 
CuArray{Float32, 1}}}}}}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Any}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, 
Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, CuArray{Float32, 4}}, Any}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{NTuple{4, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{var"#1#3", CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 4}}, Tuple{}}, Zygote.ZBack{NNlib.var"#broadcasted_relu_pullback#32"{CuArray{Float32, 4}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{MeanPool{2, 4}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:padding, :stride), T} where T<:Tuple}, Tuple{NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:padding, :stride), Tuple{NTuple{4, Int64}, Tuple{Int64, Int64}}}}, Tuple{NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:padding, :stride), Tuple{NTuple{4, Int64}, Tuple{Int64, Int64}}}, Nothing, true}}}}}}, Zygote.var"#kw_zpullback#40"{NNlib.var"#PoolDims_pullback#176"{Tuple{CuArray{Float32, 4}, Tuple{Int64, Int64}}}}, 
Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), MeanPool{2, 4}, Val{:pad}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), MeanPool{2, 4}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:pad, Zygote.Context, MeanPool{2, 4}, NTuple{4, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), MeanPool{2, 4}, Val{:k}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), MeanPool{2, 4}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:k, Zygote.Context, MeanPool{2, 4}, Tuple{Int64, Int64}}}}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.ZBack{NNlib.var"#meanpool_pullback#247"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CuArray{Float32, 4}, PoolDims{2, (6, 6), (6, 6), (0, 0, 0, 0), (1, 1)}, CuArray{Float32, 4}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), MeanPool{2, 4}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), MeanPool{2, 4}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, MeanPool{2, 4}, Tuple{Int64, Int64}}}}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{var"#2#4", CuArray{Float32, 4}}, Tuple{Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 4}, Tuple{Tuple{Int64, Colon}}}}, Zygote.var"#1569#back#135"{typeof(identity)}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, 
Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3819#back#1049"{Zygote.var"#1047#1048"}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1529"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}}}}}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, 
Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}}}}}}})(Δ:: C
uArray{Float32, 2})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [41] Pullback
    @ ~\.julia\packages\Flux\0c9kI\src\layers\basic.jl:38 [inlined]
 [42] (::Zygote.Pullback{Tuple{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 
4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), 
Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:bias, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}}}}}, Zygote.ZBack{NNlib.var"#conv_pullback#183"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CuArray{Float32, 4}, CuArray{Float32, 4}, DenseConvDims{2, (3, 3), 1, 64, (1, 1), (0, 0, 0, 0), (1, 1), false}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{Flux.var"#174#175", Nothing, false}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:σ, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, typeof(relu)}}}}}}, Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 1}, Tuple{Int64, Int64, Colon, Int64}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 4}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:dilation}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:dilation, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#length_pullback#877"}, Zygote.var"#1569#back#135"{typeof(identity)}, 
Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.ZBack{NNlib.var"#broadcasted_relu_pullback#32"{CuArray{Float32, 4}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:pad}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:pad, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, NTuple{4, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ntuple), Flux.var"#174#175", Int64}, Any}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, 
Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), T} where T<:Tuple}, 
Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#kw_zpullback#40"{NNlib.var"#DenseConvDims_pullback#162"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{NTuple{4, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, CuArray{Float32, 4}}, Any}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 
4}, CuArray{Float32, 1}}, 
GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:bias, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}}}}}, Zygote.ZBack{NNlib.var"#conv_pullback#183"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CuArray{Float32, 4}, CuArray{Float32, 4}, DenseConvDims{2, (4, 4), 64, 64, (2, 2), (1, 1, 1, 1), (1, 1), false}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{Flux.var"#174#175", Nothing, false}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:σ, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, typeof(relu)}}}}}}, Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 1}, Tuple{Int64, Int64, Colon, Int64}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 4}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:dilation}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, 
typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:dilation, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#length_pullback#877"}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.ZBack{NNlib.var"#broadcasted_relu_pullback#32"{CuArray{Float32, 4}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:pad}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:pad, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, NTuple{4, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ntuple), Flux.var"#174#175", Int64}, Any}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, 
Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), T} where T<:Tuple}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#kw_zpullback#40"{NNlib.var"#DenseConvDims_pullback#162"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), 
CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, CuArray{Float32, 4}}, Any}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:bias, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 1}}}}}}}, Zygote.ZBack{NNlib.var"#conv_pullback#183"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CuArray{Float32, 4}, CuArray{Float32, 4}, DenseConvDims{2, (4, 4), 64, 64, (2, 2), (1, 1, 1, 1), (1, 1), false}}}, Zygote.var"#1723#back#204"{Zygote.Jnew{Flux.var"#174#175", Nothing, false}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:σ, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, typeof(identity)}}}}}}, Zygote.var"#1689#back#190"{Zygote.var"#188#189"{Tuple{Tuple{Nothing}, Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 1}, Tuple{Int64, 
Int64, Colon, Int64}}}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 4}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:dilation}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:dilation, Zygote.Context, 
Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.ZBack{ChainRules.var"#length_pullback#877"}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#3819#back#1049"{Zygote.var"#1047#1048"}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Tuple{Int64, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:pad}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:pad, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, NTuple{4, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ntuple), Flux.var"#174#175", Int64}, Any}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, 
CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), T} where T<:Tuple}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}}, Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:stride, :padding, :dilation), Tuple{Tuple{Int64, Int64}, NTuple{4, Int64}, Tuple{Int64, Int64}}}, Nothing, true}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:weight, Zygote.Context, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}, CuArray{Float32, 4}}}}}}}, Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.var"#kw_zpullback#40"{NNlib.var"#DenseConvDims_pullback#162"{Tuple{CuArray{Float32, 4}, CuArray{Float32, 4}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 4}}, Tuple{}}}}}}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), 
CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}}}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 4}}, 
Tuple{Zygote.Pullback{Tuple{var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, CuArray{Float32, 4}}, Any}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, CuArray{Float32, 4}}, 
Any}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{NTuple{4, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{var"#1#3", CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 4}}, Tuple{}}, Zygote.ZBack{NNlib.var"#broadcasted_relu_pullback#32"{CuArray{Float32, 4}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{MeanPool{2, 4}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:padding, :stride), T} where T<:Tuple}, Tuple{NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.Pullback{Tuple{Type{NamedTuple{(:padding, :stride), Tuple{NTuple{4, Int64}, Tuple{Int64, Int64}}}}, Tuple{NTuple{4, Int64}, Tuple{Int64, Int64}}}, Tuple{Zygote.var"#1733#back#206"{Zygote.Jnew{NamedTuple{(:padding, :stride), Tuple{NTuple{4, Int64}, Tuple{Int64, Int64}}}, Nothing, true}}}}}}, 
Zygote.var"#kw_zpullback#40"{NNlib.var"#PoolDims_pullback#176"{Tuple{CuArray{Float32, 4}, Tuple{Int64, Int64}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), MeanPool{2, 4}, Val{:pad}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), MeanPool{2, 4}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:pad, Zygote.Context, MeanPool{2, 4}, NTuple{4, Int64}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), MeanPool{2, 4}, Val{:k}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), MeanPool{2, 4}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:k, Zygote.Context, MeanPool{2, 4}, Tuple{Int64, Int64}}}}}}}, Zygote.var"#1569#back#135"{typeof(identity)}, Zygote.ZBack{NNlib.var"#meanpool_pullback#247"{Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CuArray{Float32, 4}, PoolDims{2, (6, 6), (6, 6), (0, 0, 
0, 0), (1, 1)}, CuArray{Float32, 4}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), MeanPool{2, 4}, Val{:stride}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), MeanPool{2, 4}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:stride, Zygote.Context, MeanPool{2, 4}, Tuple{Int64, Int64}}}}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing, Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 4}}, Tuple{Zygote.Pullback{Tuple{var"#2#4", CuArray{Float32, 4}}, Tuple{Zygote.var"#2436#back#489"{Zygote.var"#485#487"{CuArray{Float32, 4}, Tuple{Tuple{Int64, Colon}}}}, Zygote.var"#1569#back#135"{typeof(identity)}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{Nothing}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, CuArray{Float32, 2}}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 2}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:bias}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3819#back#1049"{Zygote.var"#1047#1048"}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:σ}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}, Zygote.var"#3723#back#1017"{Zygote.var"#1013#1015"{Tuple{CuArray{Float32, 2}, CuArray{Float32, 
1}}}}, Zygote.ZBack{ChainRules.var"#times_pullback#1529"{CuArray{Float32, 2}, CuArray{Float32, 2}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Val{:weight}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}, Symbol}, Any}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}}}}}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}}}, Zygote.var"#1653#back#176"{Zygote.var"#173#175"{Tuple{}}}, Zygote.var"#1665#back#180"{Zygote.var"#178#179"}, Zygote.Pullback{Tuple{typeof(Flux.applychain), Tuple{}, CuArray{Float32, 2}}, Tuple{}}}}}}}}, Zygote.Pullback{Tuple{typeof(ZygoteRules.literal_getproperty), Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, 
typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Val{:layers}}, Tuple{Zygote.Pullback{Tuple{typeof(getproperty), Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Symbol}, Tuple{Zygote.var"#1700#back#198"{Zygote.var"#back#197"{:layers, Zygote.Context, Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, 
typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}, Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, 
GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}}}}}})(Δ::CuArray{Float32, 2})
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [43] Pullback
    @ e:\wqy\julia\mnistDEQ.jl:199 [inlined]
 [44] (::Zygote.Pullback{Tuple{var"#10#13"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}, Any})(Δ::Float32)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface2.jl:0
 [45] (::Zygote.var"#69#70"{Zygote.Params, Zygote.Pullback{Tuple{var"#10#13"{Chain{Tuple{Chain{Tuple{Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(relu), CuArray{Float32, 4}, CuArray{Float32, 1}}, GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, Conv{2, 4, typeof(identity), CuArray{Float32, 4}, CuArray{Float32, 1}}}}, var"#solve_ss#5"{Tuple{Float32, Float32}, Flux.var"#63#65"{Chain{Tuple{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}}}}, CuArray{Float32, 1}}, Chain{Tuple{GroupNorm{typeof(identity), CuArray{Float32, 1}, Float32, Nothing}, var"#1#3", MeanPool{2, 4}, var"#2#4", Dense{typeof(identity), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}}}}, Any}, Zygote.Context})(Δ::Float32)        
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface.jl:255
 [46] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~\.julia\packages\Zygote\zowrf\src\compiler\interface.jl:59
 [47] macro expansion
    @ e:\wqy\julia\mnistDEQ.jl:198 [inlined]
 [48] macro expansion
    @ ~\.julia\packages\ProgressMeter\Vf8un\src\ProgressMeter.jl:940 [inlined]
 [49] train(; kws::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Main e:\wqy\julia\mnistDEQ.jl:196
 [50] train()
    @ Main e:\wqy\julia\mnistDEQ.jl:145
 [51] top-level scope
    @ e:\wqy\julia\mnistDEQ.jl:219
 [52] include(fname::String)
    @ Base.MainInclude .\client.jl:444
 [53] startdebug(socket::Base.PipeEndpoint, error_handler::VSCodeDebugger.var"#3#4"{Tuple{String, String}})
    @ VSCodeDebugger.DebugAdapter ~\.vscode\extensions\julialang.language-julia-1.2.2\scripts\packages\DebugAdapter\src\packagedef.jl:93
 [54] startdebugger()
    @ VSCodeDebugger ~\.vscode\extensions\julialang.language-julia-1.2.2\scripts\packages\VSCodeDebugger\src\VSCodeDebugger.jl:38
in expression starting at e:\wqy\julia\mnistDEQ.jl:219

@ChrisRackauckas
Copy link
Member

Try and narrow this down to something simpler. It should be possible with just conv and ForwardDiff.jacobian.

@QiyaoWei
Copy link
Contributor Author

So, if I understood you correctly, I constructed the following code, which triggers a similar warning:

# Minimal reproduction: ask ForwardDiff for the Jacobian of a single Conv layer
# applied to a Float32 input array.
using Flux
using ForwardDiff
arr = rand(Float32, 10,10,1,1) #ForwardDiff.Dual{Nothing, Float32, 1}.(rand(Float32, 10,10,1,1))
deq = Conv((3, 3), 1 => 1, relu, stride=1)
ForwardDiff.jacobian(deq, arr)
┌ Warning: Slow fallback implementation invoked for conv!  You probably don't want this; check your datatypes.
│   yT = ForwardDiff.Dual{ForwardDiff.Tag{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Float32}, Float32, 12}
│   T1 = ForwardDiff.Dual{ForwardDiff.Tag{Conv{2, 4, typeof(relu), Array{Float32, 4}, Vector{Float32}}, Float32}, Float32, 12}
│   T2 = Float32
└ @ NNlib C:\Users\administered\.julia\packages\NNlib\3MZcC\src\conv.jl:206

It looks like there's some problem with constructing a full Jacobian on a conv layer.

@ChrisRackauckas
Copy link
Member

Yeah, that catches it. Your MWE probably fails on GPUs. It looks like what's happening is, if the types of the weights and inputs don't match, and aren't simple floating point numbers, then conv falls back to:

https://github.com/FluxML/NNlib.jl/blob/master/src/impl/conv_direct.jl#L48-L145

which is written only for CPU and thus it fails in that function. This comes up because the backpass of an implicit solve needs to solve a linear system.

https://github.com/SciML/DiffEqSensitivity.jl/blob/v6.48.0/src/steadystate_adjoint.jl#L35-L65

So it looks like we have to either:

  1. Specialize ForwardDiff on conv by using an array of structs -> struct of arrays transformation
  2. Use Zygote.forward_jacobian (@DhairyaLGandhi is that already set up to work with conv?)
  3. Have the user set that jacobian to numerical, which isn't actually too bad (since it's square), but it's annoying.

@QiyaoWei
Copy link
Contributor Author

Yes that makes sense. Also, my intuition is that Zygote.forward_jacobian will not work, because this warning happens

# Same single-Conv reproduction as before, but using Zygote.forward_jacobian
# instead of ForwardDiff.jacobian.
using Flux
using Zygote
arr = rand(Float32, 10,10,1,1)
deq = Conv((3, 3), 1 => 1, relu, stride=1)
Zygote.forward_jacobian(deq, arr)
┌ Warning: Slow fallback implementation invoked for conv!  You probably don't want this; check your datatypes.
│   yT = ForwardDiff.Dual{Nothing, Float32, 12}
│   T1 = ForwardDiff.Dual{Nothing, Float32, 12}
│   T2 = Float32
└ @ NNlib C:\Users\administered\.julia\packages\NNlib\3MZcC\src\conv.jl:206

@frankschae What would be the best way to fix this problem?

@ChrisRackauckas
Copy link
Member

The issue in DiffEqSensitivity is that there's a conflation of the sensealgs in there.

https://github.com/SciML/DiffEqSensitivity.jl/blob/v6.48.0/src/steadystate_adjoint.jl#L66
https://github.com/SciML/DiffEqSensitivity.jl/blob/v6.48.0/src/steadystate_adjoint.jl#L39

autojacvec is being used in two ways, while here we would want to change only the forward mode calculation to numerical while keeping the Zygote reverse. @frankschae we should talk about resolving this, and see if we're doing too much AbstractDifferentiation.jl by hand in doing so.

@QiyaoWei
Copy link
Contributor Author

Hi all, in the process of looking at this issue I found this thread (FluxML/Flux.jl#896), might it be an alternative solution to our problem?

@avik-pal
Copy link
Member

Just as a pointer to anyone stumbling onto this issue. Pass sensealg = SteadyStateAdjoint(autodiff = false, autojacvec = ZygoteVJP()), and then DEQs should work with convolutions and on gpu.

@ChrisRackauckas
Copy link
Member

We should update the auto heuristic for that

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Oct 1, 2021

Thanks @avik-pal ! That was exactly the fix this issue needed! There are some minor issues with my personal GPU right now, but the code runs without any warnings now! I'll get back to you on this after I manage to fix the GPU :)

Also @ChrisRackauckas do you mean updating the adjoints documentation page?

@ChrisRackauckas
Copy link
Member

Try SciML/SciMLSensitivity.jl#497 . That should make this all work by default.

@QiyaoWei
Copy link
Contributor Author

QiyaoWei commented Oct 5, 2021

I see. I am closing this issue for now with functional GPU-based convolution-DEQ implementations. Please reopen if there are any further questions!

## Classification of MNIST dataset 
## with the convolutional neural network known as LeNet5.
## This script also combines various
## packages from the Julia ecosystem with Flux.
using Flux
using Flux.Data:DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold
using Flux.Losses:logitcrossentropy
using Statistics, Random
using Logging:with_logger
using TensorBoardLogger: TBLogger, tb_overwrite, set_step!, set_step_increment!
using ProgressMeter:@showprogress
import MLDatasets
import BSON
using CUDA
using DiffEqSensitivity
using SteadyStateDiffEq
using DiffEqFlux
using OrdinaryDiffEq
CUDA.allowscalar(false)


# Build the full classifier: a convolutional downsampling stage, a steady-state
# (fixed-point) layer solved with `DynamicSS(Tsit5())`, and a pooling/dense head.
# The assembled `Chain` is passed through `gpu` before being returned.
#
# NOTE(review): the flattened weights `p` of the fixed-point block are reachable
# from the returned chain only through the `solve_ss` closure. Confirm that
# `Flux.params` on the result actually collects `p`; otherwise those weights are
# never updated by the optimiser.
function Net()
    # Downsampling front end.
    down = Chain(
        Conv((3, 3), 1 => 64, relu; stride=1),
        GroupNorm(64, 64),
        Conv((4, 4), 64 => 64, relu; stride=2, pad=1),
        GroupNorm(64, 64),
        Conv((4, 4), 64 => 64; stride=2, pad=1),
    )

    # Block whose fixed point is computed by `solve_ss`; padding keeps the
    # spatial size unchanged so `u + input` is well defined.
    deq = Chain(
        Conv((3, 3), 64 => 64, relu; stride=1, pad=1),
        Conv((3, 3), 64 => 64, relu; stride=1, pad=1),
    )

    # Flat parameter vector plus the function that rebuilds `deq` from it.
    p, re = gpu(Flux.destructure(deq))

    # Classification head.
    fc = Chain(
        GroupNorm(64, 64),
        h -> relu.(h),
        MeanPool((6, 6)),
        h -> reshape(h, (64, :)),
        Dense(64, 10),
    )

    tspan = gpu((0.0f0, 1.0f0))

    # Steady-state layer: returns the `.u` field of the steady-state solution.
    function solve_ss(x)
        input = gpu(x)
        u_init = gpu(re(p)(input))

        # Solving the equation f(u) - u = du = 0
        residual(u, _p, t) = re(_p)(u + input) - u

        ode = ODEProblem(residual, gpu(u_init), gpu(tspan), gpu(p))
        steady = SteadyStateProblem(ode)
        adjoint = SteadyStateAdjoint(autodiff=false, autojacvec=ZygoteVJP())
        sol = solve(
            steady, DynamicSS(Tsit5());
            u0=gpu(u_init),
            abstol=Float32(1f-5),
            reltol=Float32(1f-5),
            tspan=1.0f0,
            sensealg=adjoint,
        )
        return sol.u
    end

    # Over-all model topology:
    #   down      (28,28,1,BS) -> (6,6,64,BS)
    #   solve_ss  (6,6,64,BS)  -> (6,6,64,BS)
    #   fc        (6,6,64,BS)  -> (10, BS)
    return gpu(Chain(down, solve_ss, fc))
end

# Load the MNIST train and test splits as Float32, reshape the images to
# 28x28x1xN, one-hot encode the labels over 0:9, and wrap each split in a
# `DataLoader` with batch size `args.batchsize`. Only the training loader
# shuffles. Returns `(train_loader, test_loader)`.
function get_data(args)
    as_batch(images) = reshape(images, 28, 28, 1, :)
    encode(labels) = onehotbatch(labels, 0:9)

    train_x, train_y = MLDatasets.MNIST.traindata(Float32)
    test_x, test_y = MLDatasets.MNIST.testdata(Float32)

    train_x, test_x = as_batch(train_x), as_batch(test_x)
    train_y, test_y = encode(train_y), encode(test_y)

    train_loader = DataLoader(
        (train_x, train_y); batchsize=args.batchsize, shuffle=true
    )
    test_loader = DataLoader((test_x, test_y); batchsize=args.batchsize)

    return train_loader, test_loader
end

# Training objective: `logitcrossentropy` between the model output `ŷ` and the
# targets `y`.
loss(ŷ, y) = logitcrossentropy(ŷ, y)

# Evaluate `model` on every batch of `loader`, moving each batch with `device`
# first. Returns a named tuple `(loss, acc)`:
#   * `loss` - sum over batches of (batch loss * batch size), divided by the
#              total number of samples seen;
#   * `acc`  - percentage of samples whose `onecold` prediction matches the
#              `onecold` target.
# Both values are passed through `round4`.
function eval_loss_accuracy(loader, model, device)
    l = 0f0
    acc = 0
    ntot = 0
    for (x, y) in loader
        x, y = x |> device, y |> device
        ŷ = model(x)
        # The last dimension of `x` is used as the batch size.
        l += loss(ŷ, y) * size(x)[end]
        # Compare predicted and target classes on the CPU.
        acc += sum(onecold(ŷ |> cpu) .== onecold(y |> cpu))
        ntot += size(x)[end]
    end
    return (loss = l / ntot |> round4, acc = acc / ntot * 100 |> round4)
end

## utility functions
# Total number of elements across all arrays returned by `Flux.params(model)`.
function num_params(model)
    trainable = Flux.params(model)
    return sum(length, trainable)
end
# Round `x` to four digits after the decimal point.
function round4(x)
    return round(x; digits=4)
end

# arguments for the `train` function 
# Settings for the `train` function. Every field can be overridden by keyword,
# e.g. `train(epochs=2, use_cuda=false)`.
Base.@kwdef mutable struct Args
    η = 3e-4             # learning rate
    λ = 0                # L2 regularizer param, implemented as weight decay
    batchsize = 8        # batch size
    epochs = 10          # number of epochs
    seed = 0             # set seed > 0 for reproducibility
    use_cuda = true      # if true use cuda (if available)
    infotime = 1         # report every `infotime` epochs
    checktime = 5        # Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    tblogger = true      # log training with tensorboard
    savepath = "runs/"   # results path
end

# Train the network returned by `Net()` on MNIST.
#
# Keyword arguments are forwarded to `Args`. The function seeds the RNG when
# `seed > 0`, picks `gpu` or `cpu` as the device, builds the data loaders, model
# and ADAM optimiser (wrapped with `WeightDecay` when `λ > 0`), then runs
# `epochs` passes over the training loader. Loss and accuracy are printed (and
# sent to TensorBoard when `tblogger` is true) every `infotime` epochs, and the
# model is saved to `savepath/model.bson` every `checktime` epochs.
function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)
    @info "Dataset MNIST: $(train_loader.nobs) train and $(test_loader.nobs) test examples"

    ## MODEL AND OPTIMIZER
    model = Net() |> device
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params(model)

    opt = ADAM(args.η)
    if args.λ > 0 # add weight decay, equivalent to L2 regularization
        opt = Optimiser(opt, WeightDecay(args.λ))
    end

    ## LOGGING UTILITIES
    if args.tblogger
        tblogger = TBLogger(args.savepath, tb_overwrite)
        set_step_increment!(tblogger, 0) # 0 auto increment since we manually set_step!
        @info "TensorBoard logging at \"$(args.savepath)\""
    end

    # Evaluate on both loaders, print the metrics, and log them to TensorBoard
    # when enabled. `tblogger` is only read when `args.tblogger` is true, which
    # is exactly when it was defined above.
    function report(epoch)
        train = eval_loss_accuracy(train_loader, model, device)
        test = eval_loss_accuracy(test_loader, model, device)
        println("Epoch: $epoch   Train: $(train)   Test: $(test)")
        if args.tblogger
            set_step!(tblogger, epoch)
            with_logger(tblogger) do
                @info "train" loss = train.loss  acc = train.acc
                @info "test"  loss = test.loss   acc = test.acc
            end
        end
    end

    ## TRAINING
    @info "Start Training"
    report(0)
    for epoch in 1:args.epochs
        @showprogress for (x, y) in train_loader
            x, y = x |> device, y |> device
            gs = Flux.gradient(ps) do
                ŷ = model(x)
                loss(ŷ, y)
            end

            Flux.Optimise.update!(opt, ps, gs)
        end

        ## Printing and logging
        epoch % args.infotime == 0 && report(epoch)
        if args.checktime > 0 && epoch % args.checktime == 0
            !ispath(args.savepath) && mkpath(args.savepath)
            modelpath = joinpath(args.savepath, "model.bson")
            let model = cpu(model) # return model to cpu before serialization
                BSON.@save modelpath model epoch
            end
            @info "Model saved in \"$(modelpath)\""
        end
    end
end

# Run training with the default `Args` settings (no keyword overrides).
train()

@QiyaoWei QiyaoWei closed this as completed Oct 5, 2021
@ChrisRackauckas
Copy link
Member

sensealg=SteadyStateAdjoint(autodiff=false, autojacvec=ZygoteVJP()) shouldn't be necessary. It'll do that automatically in the latest version, so the first code should be fine now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants