Update GNNChain #202
Conversation
## TODO see if this is faster for small chains
## see https://github.com/FluxML/Flux.jl/pull/1809#discussion_r781691180
# @generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N}
#     symbols = vcat(:x, [gensym() for _ in 1:N])
#     calls = [:($(symbols[i+1]) = _applylayer(layers[$i], g, $(symbols[i]))) for i in 1:N]
#     Expr(:block, calls...)
# end
# _applychain(layers::NamedTuple, g, x) = _applychain(Tuple(layers), g, x)
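For reference, the non-generated fallback that this would replace is presumably a plain loop along these lines (a sketch only; the exact _applylayer dispatch is an assumption, not copied from the diff):

# Sketch of the plain loop-based _applychain, for comparison with the
# @generated version above: apply each layer in turn, threading the graph g.
function _applychain(layers, g::GNNGraph, x)
    for l in layers
        x = _applylayer(l, g, x)
    end
    return x
end

# Assumed dispatch: only GNN layers receive the graph.
_applylayer(l, g::GNNGraph, x) = l(x)                 # e.g. BatchNorm, Dense, Dropout
_applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)    # e.g. GCNConv, GraphConv

The @generated variant unrolls this loop at compile time, which mainly helps type inference when the tuple of layers is heterogeneous.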
note to myself: remember to benchmark this before merging
Benchmarked on a small graph / net
n, deg = 10, 4
din, d, dout = 3, 4, 2
g = GNNGraph(random_regular_graph(n, deg),
             graph_type=GRAPH_T,
             ndata=randn(Float32, din, n))
x = g.ndata.x
gnn = GNNChain(GCNConv(din => d),
               BatchNorm(d),
               x -> tanh.(x),
               GraphConv(d => d, tanh),
               Dropout(0.5),
               Dense(d, dout))
There is a performance increase with the generated _applychain, but it is not large enough for the change to be worthwhile.
julia> using BenchmarkTools
### without @generated _applychain
julia> @btime gnn(g, x)
7.469 μs (84 allocations: 12.48 KiB)
2×10 Matrix{Float32}:
-0.8186 -0.570312 -0.777638 … -0.641642 -0.684857 -0.975505
0.305567 0.559996 0.631279 0.4687 0.479899 0.321139
julia> @btime gradient(x -> sum(gnn(g, x)), x)
515.917 μs (2422 allocations: 160.52 KiB)
(Float32[0.3974119 -0.5917164 … -0.9200875 1.1957061; -0.54502636 -1.5056851 … -2.6915464 2.5114572; -0.97105116 0.7726713 … 1.0995824 -1.5013595],)
### with @generated _applychain
julia> @btime gnn(g, x)
6.825 μs (73 allocations: 11.55 KiB)
2×10 Matrix{Float32}:
-0.8186 -0.570312 -0.777638 … -0.641642 -0.684857 -0.975505
0.305567 0.559996 0.631279 0.4687 0.479899 0.321139
julia> @btime gradient(x -> sum(gnn(g, x)), x)
454.750 μs (2157 allocations: 161.00 KiB)
(Float32[-0.564121 0.3105453 … 0.19531891 -0.22819248; -0.6428803 0.13550264 … 0.9421329 -0.79201597; 0.7816532 -0.4734739 … 0.23667078 0.033573348],)
In both cases the gradient is very slow; this should be investigated further.
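A possible first step for that investigation (not part of this PR) would be to time the gradient of a single graph layer against a plain Flux layer on the same data, to see whether the overhead comes from the graph convolutions themselves or from differentiating the chain. The snippet below is only a sketch, reusing din, d, dout, g, and x from the setup above.

# Sketch: compare the pullback cost of one graph layer vs. one dense layer.
using BenchmarkTools, Flux, GraphNeuralNetworks

conv = GCNConv(din => d)
@btime gradient(z -> sum($conv($g, z)), $x)    # graph layer alone

dense = Dense(din, dout)
@btime gradient(z -> sum($dense(z)), $x)       # plain Flux layer alone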
Codecov Report
@@            Coverage Diff             @@
##           master     #202      +/-   ##
==========================================
+ Coverage   86.44%   87.35%   +0.90%
==========================================
  Files          15       15
  Lines        1365     1368       +3
==========================================
+ Hits         1180     1195      +15
+ Misses        185      173      -12
Continue to review full report at Codecov.
Keeping GNNChain in sync with the last few months of changes to the Flux.Chain implementation (FluxML/Flux.jl#1809).
TODO:
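Assuming the sync mirrors Flux's move to tuple/NamedTuple layer storage (the _applychain(layers::NamedTuple, ...) method in the diff above suggests this), a GNNChain should then also accept named layers, roughly as below; the exact constructor and field access are assumptions, not confirmed by this PR.

# Hypothetical usage after the sync, mirroring Flux.Chain's named-layer API
# (g and x as in the benchmark above).
model = GNNChain(conv1 = GCNConv(3 => 4),
                 norm  = BatchNorm(4),
                 conv2 = GraphConv(4 => 4, tanh),
                 out   = Dense(4, 2))

y = model(g, x)                    # call signature unchanged
first_layer = model.layers.conv1   # layers stored in a NamedTuple (assumed)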