Skip to content

Commit

Permalink
Trim dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jul 11, 2021
1 parent 7824003 commit a1f968f
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 36 deletions.
14 changes: 0 additions & 14 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,23 @@ authors = ["David Lung <[email protected]> and contributors"]
version = "0.1.0"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209"
CMAEvolutionStrategy = "8d3b24bd-414e-49e0-94fb-163cc3a3e411"
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Evolutionary = "86b6b26d-c046-49b6-aa0b-5f0f74682bd6"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
MLDataPattern = "9920b226-0b2a-5f5f-9153-9aa70a013f8b"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
ODEInterfaceDiffEq = "09606e27-ecf5-54fc-bb29-004bd9f985bf"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PProf = "e4faabce-9ead-11e9-39d9-4379958e3056"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
julia = "1"
Expand Down
18 changes: 2 additions & 16 deletions src/LTC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,14 @@ module LTC
using Reexport

using Distributions
import NPZ: npzread
export npzread
using Juno
using DiffEqBase
using OrdinaryDiffEq
using DiffEqSensitivity
using DiffEqFlux
using DiffEqFlux: initial_params, paramlength, FastChain, FastDense, sciml_train
import DiffEqFlux: initial_params, paramlength, FastChain, FastDense, sciml_train
import DifferentialEquations: PresetTimeCallback, PeriodicCallback
export sciml_train
using GalacticOptim
using ModelingToolkit
using Zygote
using Zygote: @adjoint, Numeric, literal_getproperty, accum
export Zygote
using Flux: reset!, Zeros, Data.DataLoader
using Flux: Data.DataLoader
import Flux: reset!
export DataLoader
using IterTools: ncycle
export ncycle
using Flux
using IterTools


rand_uniform(TYPE, lb,ub,dims...) = TYPE.(rand(Uniform(lb,ub),dims...))
Expand Down
8 changes: 4 additions & 4 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct Broadcaster{M,P}
paramlength::Int
end
function Broadcaster(model)
p = DiffEqFlux.initial_params(model)
p = initial_params(model)
paramlength = length(p)
Broadcaster(model,p,paramlength)
end
Expand Down Expand Up @@ -89,7 +89,7 @@ paramlength(m::FluxLayerWrapper) = m.paramlength

reset_state!(m,p) = nothing

function reset_state!(m::Union{Flux.Chain, DiffEqFlux.FastChain}, p)
function reset_state!(m::Union{Flux.Chain, FastChain}, p)
start_idx = 1
for l in m.layers
pl = paramlength(l)
Expand All @@ -111,13 +111,13 @@ function get_bounds(m::Mapper)
lb, ub
end

function get_bounds(m::Union{Flux.Chain, DiffEqFlux.FastChain})
function get_bounds(m::Union{Flux.Chain, FastChain})
lb = vcat([get_bounds(layer)[1] for layer in m.layers]...)
ub = vcat([get_bounds(layer)[2] for layer in m.layers]...)
lb, ub
end

function get_bounds(m::DiffEqFlux.FastDense)
function get_bounds(m::FastDense)
lb = vcat(fill(-10.1, m.out*m.in),
fill(-10.1, m.out)) |> f32
ub = vcat(fill(10.1, m.out*m.in),
Expand Down
2 changes: 1 addition & 1 deletion src/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function loss_seq(p, re, x, y)
return mean(Flux.Losses.mse.(ŷ,y, agg=mean)), ŷ, y
end

function loss_seq(p, m::DiffEqFlux.FastChain, x, y)
function loss_seq(p, m::FastChain, x, y)
# ŷ = m.(x, [p])

LTC.reset_state!(m, p)
Expand Down
2 changes: 1 addition & 1 deletion src/mtk_recur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mutable struct RecurMTK{T,V,S}# <:MTKLayer
state::S
end
function RecurMTK(cell; seq_len=1)
p = DiffEqFlux.initial_params(cell)
p = initial_params(cell)
RecurMTK(cell, p, length(p), cell.state0)
end
function (m::RecurMTK)(x, p=m.p)
Expand Down

0 comments on commit a1f968f

Please sign in to comment.