Skip to content

Commit

Permalink
Merge branch 'dev' into depstrim
Browse files Browse the repository at this point in the history
  • Loading branch information
lungd committed Jul 24, 2021
2 parents 031861a + a4a0e90 commit 1a27f40
Show file tree
Hide file tree
Showing 18 changed files with 798 additions and 243 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ version = "0.1.0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DiffEqOperators = "9fdde737-9c7f-55bf-ade8-46b3f136cc48"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[compat]
Expand Down
11 changes: 6 additions & 5 deletions example/half-cheetah_mtk_recur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ using DiffEqSensitivity
using OrdinaryDiffEq
using DiffEqFlux
using GalacticOptim
using BlackBoxOptim
# using BlackBoxOptim
using ModelingToolkit
using IterTools: ncycle

include("half_cheetah_data_loader.jl")

Expand All @@ -26,8 +27,8 @@ function train_cheetah(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacv
return false
end

batchsize=20
seq_len=10
batchsize=10
seq_len=20
train_dl, test_dl, valid_dl = get_dl(batchsize=batchsize, seq_len=seq_len)

wiring = LTC.NCPWiring(17,17;
Expand All @@ -41,11 +42,11 @@ function train_cheetah(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacv
sys = ModelingToolkit.structural_simplify(net)

model = DiffEqFlux.FastChain(LTC.Mapper(wiring.n_in),
LTC.RecurMTK(LTC.MTKCell(wiring.n_in, wiring.n_out, sys, solver, sensealg)),
LTC.RecurMTK(LTC.MTKCell(wiring.n_in, wiring.n_out, net, sys, solver, sensealg)),
LTC.Mapper(wiring.n_out),
)

opt = Flux.Optimiser(ClipValue(1.00), ExpDecay(0.01, 0.1, 200, 1e-4), ADAM())
opt = Flux.Optimiser(ClipValue(1.00), ExpDecay(1, 0.1, 200, 1e-4), ADAM())
# opt = Optim.LBFGS()
# opt = BBO()
# opt = ParticleSwarm(;lower=lb, upper=ub)
Expand Down
9 changes: 6 additions & 3 deletions example/half-cheetah_mtk_recur_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ using DiffEqSensitivity
using OrdinaryDiffEq
using DiffEqFlux
using GalacticOptim
using BlackBoxOptim
using Flux
# using BlackBoxOptim
using ModelingToolkit
using IterTools: ncycle

include("half_cheetah_data_loader.jl")

Expand Down Expand Up @@ -40,10 +40,11 @@ function train_cheetah(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacv

net = LTC.Net(wiring, name=:net)
sys = ModelingToolkit.structural_simplify(net)
# return net, sys

model = Flux.Chain(Flux.Dense(wiring.n_in, 5, tanh),
Flux.Dense(5, wiring.n_in),
LTC.RecurMTK(LTC.MTKCell(wiring.n_in, wiring.n_out, sys, solver, sensealg)),
LTC.RecurMTK(LTC.MTKCell(wiring.n_in, wiring.n_out, net, sys, solver, sensealg)),
LTC.Mapper(wiring.n_out),
)

Expand All @@ -57,6 +58,8 @@ function train_cheetah(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacv
LTC.optimize(model, LTC.loss_seq, cbg, opt, AD, ncycle(train_dl,n)), model
end

# net, sys = train_cheetah(1)

@time res1,model = train_cheetah(100)

# @time res1,model = train_cheetah(500, AutoTsit5(Rosenbrock23()))
Expand Down
16 changes: 1 addition & 15 deletions example/half_cheetah_data_loader.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
# one batch: vector of matrices with shape (features,batchsize)
# [rand(features,batchsize) for 1:seq_len]


import NPZ: npzread

function get_dl(; seq_len=32, batchsize=16)
filepath = joinpath(@__DIR__, "half-cheetah-data")
Expand All @@ -16,17 +13,6 @@ function get_dl(; seq_len=32, batchsize=16)
valid_files = all_files[1:5]


# goal:
# size(first(train_dl)[1]) == [(17,16) for 1:32]
# x_data == [ [(17,16) for 1:32] for 1:60 ]



# sequences = (32,17,100)
# sequences = [(32,17) for 1:100]
#




train_x, train_y = _load_files(train_files, seq_len)
Expand Down
147 changes: 147 additions & 0 deletions example/outpins_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
using ModelingToolkit
@variables t
D = Differential(t)

@variables xxx(t)
[xxx]

function InPin(;name)
@parameters x
defaults = Dict(x => 13.37)
ODESystem(Equation[],t,Num[],[x]; name, defaults)
end

function OutPin(;name)
vars = @variables x(t) xx(t)
defaults = Dict(x => 0.0)
ODESystem(Equation[D(x)~xx],t,[x,xx],Num[]; name, defaults)
end

function create_pins(in::Integer, out::Integer)
inpins = [InPin(;name=Symbol("x$(i)_InPin")) for i in 1:in]
outpins = [OutPin(;name=Symbol("x$(i)_OutPin")) for i in 1:out]
inpins, outpins
end

function Network(;name)
@variables x(t)=0.0 a(t)
@parameters p=0.4
eqs = [D(x) ~ a, a ~ 0.1p]
ODESystem(eqs;name)
end

function Model(;name)
op = OutPin(;name=:op)
net = Network(;name=:net)
eqs = [op.xx ~ net.x]
# eqs = Equation[]
systems = [op,net]
ODESystem(eqs,t;name,systems)
end

# using Symbolics

model = Model(;name=:model)
sys = structural_simplify(model)
@nonamespace n = model.net
@nonamespace o = model.op
u0 = Dict(
n.x => 0.0,
o.x => 0.0,
)
p = Dict(
n.p => 0.4,
)
defs = ModelingToolkit.get_defaults(sys)
prob = ODEProblem(sys, u0, (0.0,1.0), p)
sol = solve(prob, Tsit5())
plot(sol[o.x])
plot(sol[n.x])
sol
plot(sol(sol.t,Val{1})[1,:])
plot(sol(sol.t,Val{1})[2,:])
plot(sol[1,:])
plot(sol[2,:])

p = [0.1]
u0b = rand(1,10)
function prob_func(prob,i,repeat)
u0 = @view u0b[:,i]
remake(prob; u0)
end

function output_func(sol,i)
sol[:, end], false
end

function loss(p,model)
ensemble_prob = EnsembleProblem(prob; prob_func, output_func, safetycopy=false) # TODO: safetycopy ???
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories=size(u0b,2),
sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)), save_everystep=false, save_start=false)
sum(Array(sol))
end



cbg = function (p,l;doplot=false)
display(l)
return false
end

optfun = GalacticOptim.OptimizationFunction((θ,p) -> loss(θ,model), GalacticOptim.AutoZygote())
optfunc = GalacticOptim.instantiate_function(optfun, p, GalacticOptim.AutoZygote(), nothing)
optprob = GalacticOptim.OptimizationProblem(optfunc, p)
GalacticOptim.solve(optprob, ADAM(), cb = cbg, maxiters=100)









using DiffEqCallbacks, OrdinaryDiffEq, LinearAlgebra
prob = ODEProblem((du,u,p,t) -> du .= u, rand(4), (0.0,1.0))
saved_values = SavedValues(Float64, Vector{Float64})
cb = SavingCallback((u,t,integrator)->integrator(t,Val{1})[:,1], saved_values, saveat = 0.0:0.1:1.0)
sol = solve(prob, Tsit5(), saveat=0.1, callback=cb)
as = Array(sol)
plot(s)
sol[:]
saved_values.saveval
plot(saved_values.saveval)



using ModelingToolkit
using ModelingToolkit: get_defaults
using OrdinaryDiffEq

@variables t
D = Differential(t)

function SubSys(;name)
@variables x(t)=0.0
@parameters p=0.3
ODESystem([x~p],t,[x],[p];name)
end
function Network(;name)
@variables x(t)=0.0
@named subsys = SubSys()
ODESystem([D(x)~subsys.x],t,[x],[];name,systems=[subsys])
end

@named net = Network()
sys = structural_simplify(net)
prob = ODEProblem(sys,get_defaults(sys),(0.0,0.1))
sol = solve(prob,Tsit5())

@nonamespace x = net.x
@nonamespace subsys = net.subsys
sol[x]
sol[subsys.x]


sol[net.x]
sol[net.subsys.x]
16 changes: 7 additions & 9 deletions example/sine-mtk_recur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,8 @@ using DiffEqSensitivity
using OrdinaryDiffEq
using DiffEqFlux
using GalacticOptim
using Juno
using Cthulhu
using Profile
using BlackBoxOptim
#using PProf
using ProfileView
using ModelingToolkit
import Flux: Data.DataLoader

function generate_data()
in_features = 2
Expand All @@ -30,7 +25,7 @@ end

function train_sine(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))

cbg = function (p,l,pred,y;doplot=false)
cbg = function (p,l,pred,y;doplot=true)
display(l)
if doplot
fig = plot([ŷ[end,1] for ŷ in pred], label="ŷ")
Expand All @@ -49,7 +44,7 @@ function train_sine(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacvec=
sys = ModelingToolkit.structural_simplify(net)

model = DiffEqFlux.FastChain(LTC.Mapper(wiring.n_in),
LTC.RecurMTK(LTC.MTKCell(wiring.n_in, wiring.n_out, sys, solver, sensealg)),
LTC.RecurMTK(LTC.MTKCell(wiring.n_in, wiring.n_out, net, sys, solver, sensealg)),
LTC.Mapper(wiring.n_out),
)

Expand All @@ -62,11 +57,14 @@ function train_sine(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacvec=
AD = GalacticOptim.AutoZygote()
# AD = GalacticOptim.AutoModelingToolkit()

# return model

LTC.optimize(model, LTC.loss_seq, cbg, opt, AD, train_dl)

end

@time train_sine(10)
@time model = train_sine(100)
# @time model = train_sine(100, Tsit5(), InterpolatingAdjoint(autojacvec=DiffEqSensitivity.EnzymeVJP()))
# @time traintest(1000, QNDF())
# @time traintest(1000, TRBDF2())
# @time traintest(1000, AutoTsit5(Rosenbrock23()))
9 changes: 2 additions & 7 deletions example/sine-mtk_recur_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,8 @@ using DiffEqSensitivity
using OrdinaryDiffEq
using DiffEqFlux
using GalacticOptim
using Juno
using Cthulhu
using Profile
using BlackBoxOptim
#using PProf
using ProfileView
using ModelingToolkit
import Flux: Data.DataLoader

function generate_data()
in_features = 2
Expand Down Expand Up @@ -48,7 +43,7 @@ function train_sine(n, solver=VCABM(), sensealg=InterpolatingAdjoint(autojacvec=
sys = ModelingToolkit.structural_simplify(net)

model = DiffEqFlux.Chain(Flux.Dense(wiring.n_in,wiring.n_in,tanh), LTC.Mapper(wiring.n_in),
LTC.RecurMTK(LTC.MTKCell(wiring.n_in, wiring.n_out, sys, solver, sensealg)),
LTC.RecurMTK(LTC.MTKCell(wiring.n_in, wiring.n_out, net, sys, solver, sensealg)),
LTC.Mapper(wiring.n_out),
)

Expand Down
12 changes: 8 additions & 4 deletions src/LTC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,24 @@ import DiffEqFlux: initial_params, paramlength, FastChain, FastDense, sciml_trai
using GalacticOptim
using ModelingToolkit
using Flux
using IterTools

using NNlib: sigmoid

rand_uniform(TYPE, lb,ub,dims...) = TYPE.(rand(Uniform(lb,ub),dims...))
rand_uniform(TYPE, lb,ub) = rand_uniform(TYPE, lb,ub,1)[1]

#Zygote.@nograd rand_uniform, reshape

include("layers.jl")
include("mtk_recur.jl")
include("optimization.jl")
include("losses.jl")
include("variables.jl")
# include("mkt_sysstruct.jl")
# include("zygote.jl")

include("ncp/ncp_sys_gen.jl")
include("ncp/wiring.jl")
include("systems/systems.jl")
include("systems/ncp/ncp_sys_gen.jl")
include("systems/ncp/wiring.jl")


export MTKRecur, MTKCell, Mapper, Broadcaster, get_bounds
Expand Down
2 changes: 1 addition & 1 deletion src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct Mapper{V}
end
function Mapper(in::Integer)
W = ones(Float32,in)
b = zeros(Float32,in)
b = fill(0.00001f0,in)
p = vcat(W,b)
Mapper(W, b, p, length(p))
end
Expand Down
Loading

0 comments on commit 1a27f40

Please sign in to comment.