Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace unrolled foldl used to evaluate Chain with a better one #1809

Merged
merged 8 commits into from
Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10, 5, tanh)),
julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```

For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
This feature is somewhat experimental, beware!
"""
struct Chain{T<:Union{Tuple, NamedTuple}}
struct Chain{T<:Union{Tuple, NamedTuple, AbstractVector}}
layers::T
end

Expand All @@ -44,10 +48,22 @@ end

@functor Chain

applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
(c::Chain)(x) = applychain(c.layers, x)

@generated function applychain(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
symbols = vcat(:x, [gensym() for _ in 1:N])
calls = [:($(symbols[i+1]) = layers[$i]($(symbols[i]))) for i in 1:N]
Expr(:block, calls...)
end

applychain(layers::NamedTuple, x) = applychain(Tuple(layers), x)

(c::Chain)(x) = applychain(Tuple(c.layers), x)
function applychain(layers::AbstractVector, x) # type-unstable path, helps compile times
for f in layers
x = f(x)
end
x
end

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
Expand All @@ -60,6 +76,7 @@ function Base.show(io::IO, c::Chain)
end
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
Comment on lines 81 to 82
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, in addition to a hand-written foldl, the function Flux.activations is just accumulate(|>, m1.layers; init=x1). Since we don't support Julia < 1.5, we could just replace it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought accumulate would face the same issue as foldl, namely that the rrule doesn't consider init? This PR need not be concerned with activations either way, we can kick that can down the road until rrules are tweaked or someone complains about performance.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, you could likewise do accumulate(|>, (x, m.layers...)). But happy to leave it alone for now.

Expand Down
13 changes: 7 additions & 6 deletions src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ for T in [
end

function _big_show(io::IO, obj, indent::Int=0, name=nothing)
pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
children = _show_children(obj)
if all(_show_leaflike, children)
_layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
if obj isa Chain{<:NamedTuple} && children == getfield(obj, :layers)
# then we insert names -- can this be done more generically?
for k in Base.keys(obj)
Expand All @@ -35,10 +36,10 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
end
end
if indent == 0 # i.e. this is the outermost container
print(io, ")")
print(io, rpad(post, 2))
_big_finale(io, obj)
else
println(io, " "^indent, "),")
println(io, " "^indent, post, ",")
end
end
end
Expand Down Expand Up @@ -90,18 +91,18 @@ function _big_finale(io::IO, m)
noncnt = _childarray_sum(_->1, m) - length(ps)
if noncnt > 0
nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps))
printstyled(io, " "^09, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
println(io, pars, " parameters,")
printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black)
print(io, bytes, ".")
else
printstyled(io, " "^19, "# Total: ", length(ps), " arrays, "; color=:light_black)
printstyled(io, " "^18, "# Total: ", length(ps), " arrays, "; color=:light_black)
print(io, pars, " parameters, ", bytes, ".")
end
end
end

_childarray_sum(f, x::AbstractArray) = f(x)
_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))

# utility functions
Expand Down
40 changes: 40 additions & 0 deletions test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import Flux: activations
@test m == fmap(identity, m) # does not forget names

@test_throws ArgumentError Chain(layers = Dense(10, 10), two = identity) # reserved name

@test_nowarn Chain([Dense(10, 5, σ), Dense(5, 2)])(randn(Float32, 10)) # vector of layers
end

@testset "Activations" begin
Expand Down Expand Up @@ -297,3 +299,41 @@ import Flux: activations
@test_throws DimensionMismatch m(OneHotVector(3, 1000))
end
end

@testset "second derivatives" begin
m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
@test Zygote.hessian_dual(sum∘m1, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1, [1,2,3])

m1v = Chain([m1[1], m1[2]]) # vector of layers
@test Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_dual(sum∘m1, [1,2,3])
@test_broken Zygote.hessian_dual(sum∘m1v, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m1v, [1,2,3])

# NNlib's softmax gradient writes in-place
m2 = Chain(Dense(3,4,tanh), Dense(4,2), softmax)
@test_broken Zygote.hessian_dual(sum∘m2, [1,2,3]) ≈ Zygote.hessian_reverse(sum∘m2, [1,2,3])

# https://github.com/FluxML/NNlib.jl/issues/362
m3 = Chain(Conv((3,), 2 => 3, relu), Dense(2,2))
x3 = cat(Float32[1 2; 3 4; 5 6; 7 8]; dims=3)
@test_broken Zygote.hessian_dual(sum∘m3, x3) ≈ Zygote.hessian_reverse(sum∘m3, x3)
end

@testset "gradients of Chain{Vector}" begin
m1 = Chain(Dense(3,4,tanh; bias=false), Dense(4,2))
m1v = Chain([m1[1], m1[2]])
@test sum(length, params(m1)) == sum(length, params(m1v))

x1 = randn(Float32,3,5)
@test m1(x1) ≈ m1v(x1)

y1 = rand(Bool,2,5)
g1 = gradient(() -> Flux.Losses.logitcrossentropy(m1(x1), y1), params(m1))
g1v = gradient(() -> Flux.Losses.logitcrossentropy(m1v(x1), y1), params(m1v))
@test g1[m1[1].weight] ≈ g1v[m1v[1].weight]
@test g1[m1[2].bias] ≈ g1v[m1v[2].bias]

@test Flux.destructure(m1)[1] ≈ Flux.destructure(m1v)[1]
z1 = rand(22);
@test Flux.destructure(m1)[2](z1)[1].weight ≈ Flux.destructure(m1v)[2](z1)[1].weight
# Note that Flux.destructure(m1v)[2](z) has a Chain{Tuple}, as does m1v[1:2]
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Flux: OneHotArray, OneHotMatrix, OneHotVector
using Test
using Random, Statistics, LinearAlgebra
using IterTools: ncycle
using Zygote
using CUDA

Random.seed!(0)
Expand Down