generated z2d #1394

Merged
merged 3 commits into from
Mar 15, 2023

Conversation

chengchingwen
Member

This PR adds a generated version of z2d. It should improve type stability for nested functors.

    return NoTangent() # collapse all-zero case
else
    backing = NamedTuple{$fnames}(inner)
    return canonicalize(Tangent{T, typeof(backing)}(backing))
Member

I see that the original has this canonicalize, but is it clear this is ever needed? By construction this includes all fieldnames(T) already.

It's another complicated function, and if inference and 2nd-order things are the goal, then perhaps fewer such calls would be better.
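
The "by construction" point can be illustrated with a minimal sketch in plain Julia. `Layer` is a hypothetical struct, and `nothing` merely stands in for a zero tangent; no ChainRulesCore types are used, so this only shows that a NamedTuple built from fieldnames(T) already carries every field name of T.

```julia
# Hypothetical two-field struct, used only for illustration.
struct Layer
    weight
    bias
end

fnames = fieldnames(Layer)        # (:weight, :bias)
inner = ([1.0, 2.0], nothing)     # one entry per field; `nothing` stands in for a zero tangent
backing = NamedTuple{fnames}(inner)

# The backing NamedTuple has exactly the struct's field names, in order.
@assert keys(backing) == fieldnames(Layer)
@assert backing.bias === nothing
```

Whether canonicalize does anything else useful for such a Tangent is not settled by this sketch; it only covers the field-name argument.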

Member Author

I simply converted the code to a generated function. Should we remove the canonicalize call in both paths?

Member

It's entirely possible that I put that there, I don't remember. But today it looks unnecessary. At least trying without seems like a good idea.

@chengchingwen
Member Author

The Buildkite failure seems unrelated.

@ToucheSir
Member

Can you devise a test that only infers with the generated function version?

@chengchingwen
Member Author

@ToucheSir I can't find a simple example, but it's required to make Transformers.jl type stable.

@chengchingwen
Member Author

Though it's hard to find a test for it, this PR does not add any extra cost. The code is mostly a direct translation, and we preserve the non-generated path as well. Meanwhile, the instability can easily be spotted at the potentially nested map(z2d, ...) calls.
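
For readers unfamiliar with inference tests of the kind requested above, here is a minimal sketch using `@inferred` from Julia's Test standard library. The functions `stable` and `unstable` are hypothetical and unrelated to z2d; they only show what passing and failing inference checks look like.

```julia
using Test

# Hypothetical functions, not taken from Zygote.
# The return type depends on a runtime value, so inference yields a Union.
unstable(x) = x > 0 ? 1 : 1.0
# The return type is Float64 for every Int input.
stable(x) = x > 0 ? 1.0 : 0.0

# @inferred returns the call's value when the actual and inferred types match.
@test @inferred(stable(1)) == 1.0
# @inferred throws when the actual return type differs from the inferred one.
@test_throws ErrorException @inferred(unstable(1))
```

A z2d test would follow the same shape, but it needs an input on which only the non-generated path fails to infer, which is the part reported as hard to find.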

@chengchingwen
Member Author

@ToucheSir If you want to inspect an example, use this branch of Transformers.jl and take the gradient of a simple SelfAttention layer with and without this patch.

Comment on lines 334 to 343
fnames = fieldnames(T)
deltas = map(n -> get(delta, n, nothing), fnames)
primals = map(n -> getfield(primal, n), fnames)
inner = map(z2d, deltas, primals) # recurse into fields
if inner isa Tuple{Vararg{AbstractZero}}
    return NoTangent() # collapse all-zero case
else
    backing = NamedTuple{fnames}(inner)
    return Tangent{T, typeof(backing)}(backing)
end
Member

Based on direction from Slack (where I was pointed to JuliaLang/julia#23168 (comment)), the easiest way to test both branches would be to pull this into an internal helper function _z2d_struct_fallback. Then we can test z2d(...) == _z2d_struct_fallback(...) since we know the former will always take the generated path on our CI.
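
A minimal sketch of that testing pattern, using Julia's optionally-generated function syntax (`if @generated`). The names `Point`, `to_backing`, and `_to_backing_fallback` are hypothetical stand-ins chosen to mirror the proposed `_z2d_struct_fallback`; this is not Zygote's actual code.

```julia
# Hypothetical struct and function names; this is not Zygote's z2d.
struct Point
    x
    y
end

# Plain (non-generated) implementation, pulled out so tests can call it directly.
function _to_backing_fallback(p::T) where {T}
    fnames = fieldnames(T)
    vals = map(n -> getfield(p, n), fnames)
    return NamedTuple{fnames}(vals)
end

function to_backing(p::T) where {T}
    if @generated
        # In this branch only types are available; emit one getfield call per field.
        fnames = fieldnames(T)
        tup = Expr(:tuple)
        for n in fnames
            push!(tup.args, :(getfield(p, $(QuoteNode(n)))))
        end
        return :(NamedTuple{$fnames}($tup))
    else
        return _to_backing_fallback(p)
    end
end

# Compare the two implementations on the same input.
p = Point(1, 2.0)
@assert to_backing(p) == _to_backing_fallback(p)
```

Keeping the fallback as a separate function means its logic is exercised even when the compiler always chooses the generated branch.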

@ToucheSir ToucheSir merged commit 04b527b into FluxML:master Mar 15, 2023
marius311 added a commit to marius311/Zygote.jl that referenced this pull request Mar 20, 2023
[Diff since v0.6.58](FluxML/Zygote.jl@v0.6.58...v0.6.59)

**Merged pull requests:**
- Actually make sure conda env dir is set on Buildkite CI (FluxML#1392) (@ToucheSir)
- generated z2d (FluxML#1394) (@chengchingwen)
- bump version (FluxML#1398) (@chengchingwen)