destructure returns only trainable params
CarloLucibello committed Jan 15, 2022
1 parent 79dbbd6 commit 4dc70b5
Showing 8 changed files with 291 additions and 74 deletions.
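Before the per-file diffs, a minimal sketch (my addition, not part of the commit) of the behavioural change: `destructure` now walks only the `trainable` children of a model, so non-trainable state such as `BatchNorm`'s running statistics no longer appears in the flat parameter vector. The expected length below matches the new test in `test/functor.jl`.

```julia
# Illustrative sketch, assuming Flux with this change applied.
using Flux

bn = BatchNorm(3)              # trainable: β, γ; non-trainable buffers: μ, σ²
θ, re = Flux.destructure(bn)
@assert length(θ) == 6         # 3 (β) + 3 (γ); the running statistics are excluded
bn′ = re(θ)                    # rebuild the layer from the flat vector
```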
1 change: 1 addition & 0 deletions src/Flux.jl
@@ -8,6 +8,7 @@ using Zygote, MacroTools, Juno, Reexport
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd
using Functors: Functors, @functor, functor, fmap
export gradient

export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
119 changes: 118 additions & 1 deletion src/functor.jl
@@ -1,7 +1,6 @@
import Adapt: adapt, adapt_storage
using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray

trainable(m) = functor(m)[1]
@@ -38,6 +37,124 @@ Possible values include:
"""
trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)


# Flattening models to weight vectors, and back

function _restructure(m, xs)
i = 0
filter = (x, c) -> any(y -> c === y, trainable(x))
walk = filtered_walk(filter)
m̄ = fmap(m; walk) do x
x isa AbstractArray{<:Number} || return x
x = reshape(xs[i .+ (1:length(x))], size(x))
i += length(x)
return x
end
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
return m̄
end

@adjoint function _restructure(m, xs)
m̄, numel = _restructure(m, xs), length(xs)
function _restructure_pullback(dm)
xs′ = destructure(dm)[1]
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
return (nothing, xs′)
end
return m̄, _restructure_pullback
end

"""
destructure(m)
Flatten a model's parameters into a single weight vector.
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> θ, re = destructure(m);
julia> θ
67-element Vector{Float32}:
-0.1407104
...
The second return value `re` allows you to reconstruct the original network after making
modifications to the weight vector (for example, with a hypernetwork).
julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
"""
function destructure(m)
xs = Zygote.Buffer([])
collect_params!(xs, m)
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end
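As a hedged usage sketch (not in the diff): the flat vector `θ` and the closure `re` compose with `gradient`, so a loss can be differentiated with respect to `θ` directly, with the pullback routed through the `@adjoint` for `_restructure` above. The model, input sizes, and names here are only illustrative.

```julia
# Sketch: gradients with respect to the flat parameter vector, e.g. for use
# with an external optimiser that only understands plain vectors.
using Flux

m = Chain(Dense(3, 2, relu), Dense(2, 1))
x = rand(Float32, 3, 8)

θ, re = Flux.destructure(m)
∇θ = gradient(θ -> sum(re(θ)(x)), θ)[1]   # flows through _restructure's adjoint
@assert length(∇θ) == length(θ)
```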

function collect_params!(xs, m)
filter = (x, c) -> any(y -> c === y, trainable(x))
walk = filtered_walk(filter)
fmap(m; walk) do x
x isa AbstractArray{<:Number} && push!(xs, x)
return x
end
end

function filtered_walk(filter)
seen = IdSet()

function walk(f, x)
x in seen && return x
push!(seen, x)

children, reconstruct = functor(x)
mappedchildren = map(children) do c
filter(x, c) ? f(c) : c
end
reconstruct(mappedchildren)
end

return walk
end
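To show what the `filter`/`filtered_walk` pair does in practice, here is a hypothetical layer (my own example, not from the commit) whose `trainable` method exposes only one of its two fields; `collect_params!` then skips the other, because the walk only descends into children that are `===` one of the `trainable` entries.

```julia
# Hypothetical layer for illustration: `state` is a functor child but is not
# trainable, so the filtered walk never maps over it.
struct Scale
    w::Vector{Float32}
    state::Vector{Float32}   # e.g. a running buffer that should not be optimised
end
Flux.@functor Scale
Flux.trainable(s::Scale) = (s.w,)

s = Scale(rand(Float32, 4), zeros(Float32, 4))
θ, re = Flux.destructure(s)
@assert length(θ) == 4       # only `w` is flattened; `state` is left untouched
```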


"""
params(m...)
Collect trainable parameters (a.k.a. numerical arrays)
from the input model(s) `m` into a [`Zygote.Params`](@ref) object.
Only the parameters that can be reached by recursion
on the [`trainable`](@ref) children of
the tree with root `m` are collected.
# Usage
```julia-repl
julia> m = Dense(ones(2, 3), zeros(2))
Dense(3, 2) # 8 parameters
julia> ps = Flux.params(m)
Params([[1.0 1.0 1.0; 1.0 1.0 1.0], [0.0, 0.0]])
julia> x = ones(3)
3-element Vector{Float64}:
1.0
1.0
1.0
julia> gs = gradient(() -> sum(2 .* m(x)), ps)
Grads(...)
julia> gs[m.weight]
2×3 Matrix{Float64}:
2.0 2.0 2.0
2.0 2.0 2.0
```
"""
function params end

## TODO This causes some test regressions. Why?
# function params(m...)
# ps = Params()
# collect_params!(ps, m)
# return ps
# end

params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)

function params!(p::Params, x, seen = IdSet())
2 changes: 1 addition & 1 deletion src/layers/basic.jl
@@ -41,7 +41,7 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex, Base.keys

functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)
Functors.functor(::Type{<:Chain}, c) = c.layers, ls -> Chain(ls...)

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
4 changes: 2 additions & 2 deletions src/layers/show.jl
@@ -43,7 +43,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end

_show_leaflike(x) = isleaf(x) # mostly follow Functors, except for:
_show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
_show_leaflike(::Tuple{Vararg{<:Number}}) = true # e.g. stride of Conv
_show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LSTMcell
_show_leaflike(::Diagonal) = true # appears inside LayerNorm
@@ -97,7 +97,7 @@ function _big_finale(io::IO, m)
end

_childarray_sum(f, x::AbstractArray) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))
_childarray_sum(f, x) = Functors.isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))

# utility functions

53 changes: 0 additions & 53 deletions src/utils.jl
@@ -629,59 +629,6 @@ function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
end

# Flattening models to weight vectors, and back

function _restructure(m, xs)
i = 0
m̄ = fmap(m) do x
x isa AbstractArray || return x
x = reshape(xs[i.+(1:length(x))], size(x))
i += length(x)
return x
end
length(xs) == i || @warn "Expected $(i) params, got $(length(xs))"
return m̄
end

@adjoint function _restructure(m, xs)
m̄, numel = _restructure(m, xs), length(xs)
function _restructure_pullback(dm)
xs′ = destructure(dm)[1]
numel == length(xs′) || @warn "Expected $(numel) params, got $(length(xs′))"
return (nothing, xs′)
end
return m̄, _restructure_pullback
end

"""
destructure(m)
Flatten a model's parameters into a single weight vector.
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> θ, re = destructure(m);
julia> θ
67-element Vector{Float32}:
-0.1407104
...
The second return value `re` allows you to reconstruct the original network after making
modifications to the weight vector (for example, with a hypernetwork).
julia> re(θ .* 2)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
"""
function destructure(m)
xs = Zygote.Buffer([])
fmap(m) do x
x isa AbstractArray && push!(xs, x)
return x
end
return vcat(vec.(copy(xs))...), p -> _restructure(m, p)
end

# Other

165 changes: 165 additions & 0 deletions test/functor.jl
@@ -0,0 +1,165 @@
using Flux: loadparams!, Zeros, destructure

ls(dims...) = reshape(collect(Float32, 1:prod(dims)), dims...) # accepts dims in reverse order to Dense

dl(nin, nout, bias) = Dense(ls(nout, nin), bias(nout))

dm(bias) = Chain(
dl(3, 5, bias),
dl(5, 4, bias),
dl(4, 3, bias)
)

nobias(n) = Zeros()

function testdense(m, bt)
@testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
@test l1.weight == l2.weight
@test l1.bias == l2.bias
@test typeof(l1.bias) === typeof(l2.bias)
end
end

@testset "Params" begin
m = Dense(10, 5)
@test size.(params(m)) == [(5, 10), (5,)]
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)]

# Layer duplicated in same chain, params just once pls.
c = Chain(m, m)
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)]

# Self-referential array. Just want params, no stack overflow pls.
r = Any[nothing,m]
r[1] = r
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]

@testset "use params in gradient context" begin
m = Chain(Dense(3,2), Dense(2,2))
ps = Flux.params(m)
gs = gradient(() -> sum(sum(p) for p in Flux.params(m)), ps)
for p in ps
@test gs[p] ≈ ones(size(p))
end

w1, w2 = rand(2), rand(2)
ps = Flux.params(w1, w2)
gs = gradient(() -> sum(sum(p) for p in Flux.params(w1, w2)), ps)
for p in ps
@test gs[p] ≈ ones(size(p))
end

m = Chain(Dense(3,2), Dense(2,2))
g = gradient(m -> sum(params(m)[1]), m)[1]
@test g.layers[1].weight == ones(Float32, 2, 3)

gs = gradient(() -> sum(params(m)[1]), params(m))
@test gs[params(m)[1]] == ones(Float32, 2, 3)

# Tests from https://github.com/FluxML/Flux.jl/pull/1614
m = Dense(3, 2)
ps = Flux.params(m)
data = rand(Float32, 3, 5)
loss(m, x) = sum(m(x).^2)

g1 = gradient(Flux.params(m)) do
loss(m, data)
end
g2 = gradient(Flux.params(m)) do
ps = Flux.params(m) # just creating params without using them
loss(m, data)
end
g3 = gradient(Flux.params(m)) do
ps = Flux.params(m)
loss(m, data) + sum(sum(p) for p in ps)
end
g4 = gradient(Flux.params(m)) do
loss(m, data) + sum(sum(p) for p in ps)
end
g5 = gradient(Flux.params(m)) do
sum(Flux.params(m)[1]) + sum(Flux.params(m)[2])
end
g6 = gradient(Flux.params(m)) do
sum(ps[1]) + sum(ps[2])
end
@test g2[m.weight] == g1[m.weight]
@test g3[m.weight] == g1[m.weight] .+ 1
@test g4[m.weight] == g1[m.weight] .+ 1
@test all(g5[m.weight] .== 1)
@test_broken all(g6[m.weight] .== 1)
end
end


@testset "Param remapping" begin
@testset "loadparams!" begin
pars(w, b) = [w, b]

pars(w, b::Zeros) = [w, Flux.zeros32(size(w,1))]
pars(l) = pars(l.weight, l.bias)
pararray(m) = mapreduce(pars, vcat, m)
weights(m) = mapreduce(l -> [l.weight], vcat, m)
@testset "Bias type $bt" for bt in (Flux.zeros32, nobias)
m = dm(bt)
loadparams!(m, params(m))
testdense(m, bt)
end

@testset "$b1 to $b2" for (b1, b2, be) in (
(Flux.zeros32, Flux.ones32, Flux.ones32), # Load ones as bias to a model with zeros as bias -> model gets ones as bias
(Flux.ones32, nobias, Flux.zeros32), # Load Zeros as bias to a model with ones as bias -> model gets zeros as bias
(nobias, Flux.ones32, nobias), # Load ones as bias to a model with Zeros as bias -> model bias does not change
)
m1 = dm(b1)
m2 = dm(b2)
loadparams!(m1, b1 == nobias ? weights(m2) : pararray(m2))
testdense(m1, be)
end
end
end

@testset "Destructure" begin
@testset "Bias type $bt" for bt in (zeros, nobias)
m = dm(bt)
p, re = destructure(m)
testdense(re(p), bt)
end

@testset "restructure in gradient" begin
x = rand(Float32, 3, 1)
m = dm(zeros)
∇m = gradient(m -> sum(m(x)), m)[1]
p, re = destructure(m)
∇p = gradient(θ -> sum(re(θ)(x)), p)[1]
@test ∇p ≈ destructure(∇m)[1] rtol=1e-6
end

@testset "destructure with buffers" begin
p, re = destructure(BatchNorm(3))
@test length(p) == 6

# https://github.com/FluxML/Flux.jl/issues/1727
x = rand(Float32, 3, 4)
y, back = Flux.pullback(x, p) do x, p
vec(re(p)(x))
end
@test_nowarn back(y)
b = back(y)
@test size(b[1]) == size(x)
@test size(b[2]) == size(p)
end
end

@testset "Train and test mode" begin
mutable struct DummyLayer
testing::Bool
end
Flux.testmode!(m::DummyLayer, testing=true) = (m.testing = testing; m)

c = Chain(DummyLayer(true))
testmode!(c)
@test c[1].testing
trainmode!(c)
@test !c[1].testing
end