Add Flux 0.13 compatibility (#202)
* Add Flux 0.13 compatibility

* `params` -> `Flux.params`

* Forward-compatible (Flux 0.13 and 0.12) fix for tabular model

* Subtype `AbstractOptimiser`
lorenzoh authored Apr 20, 2022
1 parent f51ecb6 commit cbdf36b
Showing 6 changed files with 8 additions and 7 deletions.
Project.toml (3 changes: 2 additions & 1 deletion)

@@ -28,6 +28,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 MosaicViews = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389"
 Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
@@ -51,7 +52,7 @@ DataLoaders = "0.1"
 FileIO = "1.7"
 FilePathsBase = "0.9"
 FixedPointNumbers = "0.8"
-Flux = "0.12"
+Flux = "0.12, 0.13"
 FluxTraining = "0.2"
 Glob = "1"
 ImageInTerminal = "0.4"
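The widened compat entry is a union of two caret specifiers: `Flux = "0.12, 0.13"` lets the resolver pick any release from 0.12.0 up to, but not including, 0.14.0. A quick way to try this locally (illustrative REPL commands, assuming the repository's project environment is active):

```julia
using Pkg
Pkg.activate(".")                          # activate the FastAI.jl project
Pkg.add(name = "Flux", version = "0.13")   # now satisfiable under `Flux = "0.12, 0.13"`
Pkg.status("Flux")                         # should report some 0.13.x release
```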
src/Tabular/models.jl (4 changes: 2 additions & 2 deletions)

@@ -47,7 +47,7 @@ function TabularModel(

     tabularbackbone = Parallel(vcat, catbackbone, contbackbone)

-    classifierin = mapreduce(layer -> size(layer.weight)[1], +, catbackbone[2].layers;
+    classifierin = mapreduce(layer -> size(layer.weight)[1], +, Tuple(catbackbone[2].layers);
         init = contbackbone.chs)
     dropout_rates = Iterators.cycle(dropout_rates)
     classifiers = []
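Context for the `Tuple(...)` wrapper: Flux 0.13 changed the container types behind layer groupings such as `Chain` (a `Vector` rather than a `Tuple`, depending on how the container was built), and this fix makes the reduction independent of which container `.layers` happens to be. A minimal sketch of the pattern, with made-up `Dense` layers standing in for the embedding backbone:

```julia
using Flux

backbone = Chain(Dense(10, 4), Dense(20, 6))   # stand-in for `catbackbone[2]`

# Sum the layers' output widths, starting from 3 continuous features.
# `Tuple(...)` behaves the same whether `.layers` is a Tuple or a Vector.
classifierin = mapreduce(layer -> size(layer.weight)[1], +, Tuple(backbone.layers);
                         init = 3)
# classifierin == 4 + 6 + 3 == 13
```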
@@ -139,7 +139,7 @@ function tabular_embedding_backbone(embedding_sizes, dropout_rate=0.)
     emb_drop = iszero(dropout_rate) ? identity : Dropout(dropout_rate)
     Chain(
         x -> tuple(eachrow(x)...),
-        Parallel(vcat, embedslist),
+        Parallel(vcat, embedslist...),
         emb_drop
     )
 end
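The splat fixes `Parallel` construction in the same spirit: `Parallel` expects its sublayers as individual arguments, and handing it a `Vector` is not reliable across Flux 0.12/0.13 (it can end up treated as a single, non-callable layer), so the list is splatted into separate branches. A sketch with placeholder layers:

```julia
using Flux

embedslist = [Dense(3, 2), Dense(5, 4)]    # placeholders for embedding layers

p = Parallel(vcat, embedslist...)          # splatted: two separate branches

x1, x2 = rand(Float32, 3), rand(Float32, 5)
size(p(x1, x2))                            # (6,): outputs of size 2 and 4, vcat-ed
```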
src/Vision/models/blocks.jl (2 changes: 1 addition & 1 deletion)

@@ -38,7 +38,7 @@ function visionhead(
     acts = vcat([relu for _ in 1:n-2], [identity])
     pool = concat_pool ? AdaptiveConcatPool((1, 1)) : AdaptiveMeanPool((1, 1))

-    layers = [pool, flatten]
+    layers = [pool, Flux.flatten]

     for (h_in, h_out, act) in zip(hs, hs[2:end], acts)
         push!(layers, linbndrop(h_in, h_out, act=act, p=p))
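Qualifying the call as `Flux.flatten` stops relying on an unqualified `flatten` binding, which is not guaranteed across Flux versions and can clash with other packages' exports. Behavior is unchanged; `Flux.flatten` collapses all but the batch dimension:

```julia
using Flux

x = rand(Float32, 7, 7, 32, 4)   # H × W × C × batch, e.g. output of a pooling layer
size(Flux.flatten(x))            # (1568, 4): 7 * 7 * 32 features per sample
```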
src/training/discriminativelrs.jl (2 changes: 1 addition & 1 deletion)

@@ -23,7 +23,7 @@ dlro = DiscriminativeLRs(paramgroups, Dict(1 => 0., 2 => 1.))
 o = Optimiser(dlro, Descent(0.1))
 ```
 """
-struct DiscriminativeLRs
+struct DiscriminativeLRs <: Flux.Optimise.AbstractOptimiser
     pg::ParamGroups
     factorfn
 end
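Declaring the supertype matters because Flux's `Optimiser` chain (as in the docstring above: `Optimiser(dlro, Descent(0.1))`) expects its components to be `Flux.Optimise.AbstractOptimiser`s under Flux 0.13. A minimal sketch of the same pattern with a hypothetical pass-through optimiser:

```julia
using Flux
import Flux.Optimise: AbstractOptimiser, apply!

# Hypothetical optimiser that leaves gradients untouched.
struct PassThrough <: AbstractOptimiser end
apply!(o::PassThrough, x, Δ) = Δ

# Subtyping AbstractOptimiser lets it compose inside an Optimiser chain:
opt = Flux.Optimise.Optimiser(PassThrough(), Descent(0.1))
```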
src/training/lrfind.jl (2 changes: 1 addition & 1 deletion)

@@ -46,7 +46,7 @@ function lrfind(
     withfields(
         learner,
         model = modelcheckpoint,
-        params = params(modelcheckpoint),
+        params = Flux.params(modelcheckpoint),
         optimizer = deepcopy(learner.optimizer)
     ) do

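The `params` -> `Flux.params` rename (also applied in paramgroups.jl below) qualifies the call instead of depending on an exported `params`, which avoids name clashes and keeps working if Flux stops exporting it. Usage is otherwise identical:

```julia
using Flux

model = Chain(Dense(4, 3, relu), Dense(3, 1))
ps = Flux.params(model)   # Zygote.Params collecting all trainable arrays
length(ps)                # 4: two weight matrices and two bias vectors
```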
src/training/paramgroups.jl (2 changes: 1 addition & 1 deletion)

@@ -32,7 +32,7 @@ getgroup(pg::ParamGroups, x::AbstractArray) = get(pg.map, x, nothing)

 function assigngroups!(pg::ParamGroups, grouper, m)
     for (group, m_) in group(grouper, m)
-        for p in params(m_)
+        for p in Flux.params(m_)
             pg.map[p] = group
         end
     end
