generated z2d #1394
src/compiler/chainrules.jl
Outdated
    return NoTangent() # collapse all-zero case
else
    backing = NamedTuple{$fnames}(inner)
    return canonicalize(Tangent{T, typeof(backing)}(backing))
I see that the original has this `canonicalize`, but is it clear this is ever needed? By construction this includes all `fieldnames(T)` already.
It's another complicated function, and if inference & 2nd order things are the goal, then perhaps fewer such calls is better.
I simply converted the code to a generated function. Should we remove the `canonicalize` call in both paths?
It's entirely possible that I put that there, I don't remember. But today it looks unnecessary. At least trying without seems like a good idea.
The Buildkite failure seems unrelated.
Can you devise a test that only infers with the generated function version?
@ToucheSir I can't find a simple example, but it's required to make Transformers.jl type stable.
Though it's hard to find a test for it, this PR does not add any extra cost: the code is mostly a direct translation, and we preserve the non-generated path as well. Meanwhile, the instability can easily be spotted at the potential nested
@ToucheSir If you want to inspect an example, use this branch of Transformers.jl and take the gradient of a simple
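On the question of an inference test: a rough sketch of what such a check could look like, using `Test.@inferred` from the standard library. The helper `fields_as_tuple` and the struct `Layer` are hypothetical stand-ins invented for illustration. They are not part of Zygote, and this does not reproduce the Transformers.jl instability mentioned above.

```julia
using Test

# Hypothetical helper (not Zygote code): collect every field of a struct into a tuple.
# The generator runs on the type `T`, so the field list is known at compile time.
@generated function fields_as_tuple(x::T) where {T}
    exprs = Expr[]
    for n in fieldnames(T)
        push!(exprs, :(getfield(x, $(QuoteNode(n)))))
    end
    return :(($(exprs...),))
end

# Hypothetical struct used only for this sketch.
struct Layer
    weight::Float64
    name::String
end

# `@inferred` throws if the inferred return type does not match the actual one.
result = @inferred fields_as_tuple(Layer(1.5, "dense"))
```

A test of this shape only shows that the generated version infers. Whether a non-generated equivalent fails to infer depends on the Julia version and on how deeply the structs are nested.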
src/compiler/chainrules.jl
Outdated
fnames = fieldnames(T)
deltas = map(n -> get(delta, n, nothing), fnames)
primals = map(n -> getfield(primal, n), fnames)
inner = map(z2d, deltas, primals) # recurse into fields
if inner isa Tuple{Vararg{AbstractZero}}
    return NoTangent() # collapse all-zero case
else
    backing = NamedTuple{fnames}(inner)
    return Tangent{T, typeof(backing)}(backing)
end
Based on direction from Slack (where I was pointed to JuliaLang/julia#23168 (comment)), the easiest way to test both branches would be to pull this into an internal helper function `_z2d_struct_fallback`. Then we can test `z2d(...) == _z2d_struct_fallback(...)`, since we know the former will always take the generated path on our CI.
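The suggested layout could be sketched roughly as follows, using Julia's `if @generated ... else ... end` form so that the plain path lives in a separately callable helper. Everything here is a simplified toy: `ToyZero`, `toy_z2d`, `_toy_z2d_struct_fallback`, and `Point` are hypothetical names, and a plain `NamedTuple` stands in for the `Tangent` that the real code builds.

```julia
# Hypothetical stand-in for ChainRulesCore's zero tangents.
struct ToyZero end

# Leaf rules for the toy: `nothing` becomes ToyZero, numbers pass through.
toy_z2d(::Nothing, primal) = ToyZero()
toy_z2d(delta::Number, primal::Number) = delta

# Plain (non-generated) path, pulled out so tests can call it directly.
function _toy_z2d_struct_fallback(delta::NamedTuple, primal::T) where {T}
    fnames = fieldnames(T)
    deltas = map(n -> get(delta, n, nothing), fnames)
    primals = map(n -> getfield(primal, n), fnames)
    inner = map(toy_z2d, deltas, primals)  # recurse into fields
    if inner isa Tuple{Vararg{ToyZero}}
        return ToyZero()                   # collapse all-zero case
    else
        return NamedTuple{fnames}(inner)
    end
end

function toy_z2d(delta::NamedTuple, primal::T) where {T}
    if @generated
        # In this branch `delta` and `primal` are bound to the argument *types*.
        fnames = fieldnames(T)
        dnames = fieldnames(delta)
        exprs = Expr[]
        for n in fnames
            d = n in dnames ? :(getfield(delta, $(QuoteNode(n)))) : :(nothing)
            push!(exprs, :(toy_z2d($d, getfield(primal, $(QuoteNode(n))))))
        end
        return quote
            inner = ($(exprs...),)
            if inner isa Tuple{Vararg{ToyZero}}
                return ToyZero()
            else
                return NamedTuple{$fnames}(inner)
            end
        end
    else
        return _toy_z2d_struct_fallback(delta, primal)
    end
end

# Hypothetical primal type for the comparison below.
struct Point
    x::Float64
    y::Float64
end

p = Point(1.0, 2.0)
same_mixed = toy_z2d((x = 3.0, y = nothing), p) == _toy_z2d_struct_fallback((x = 3.0, y = nothing), p)
same_zero = toy_z2d((x = nothing, y = nothing), p) == _toy_z2d_struct_fallback((x = nothing, y = nothing), p)
```

Comparing the two entry points exercises both implementations, under the assumption stated in the comment that the compiler takes the generated branch for the public function.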
[Diff since v0.6.58](FluxML/Zygote.jl@v0.6.58...v0.6.59)

**Merged pull requests:**
- Actually make sure conda env dir is set on Buildkite CI (FluxML#1392) (@ToucheSir)
- generated z2d (FluxML#1394) (@chengchingwen)
- bump version (FluxML#1398) (@chengchingwen)
This PR adds a generated version of `z2d`. It should improve the type stability of nested functors.