[attrs] add linearize and vjp support #19960
Merged
There are two commits here:

1. refactor `attrs.jvp` to be simpler (IMO, though I'm not sure yet...) so that we handle input perturbations at the traceable level and the inner transformations never need to worry about them (e.g. we don't create separate attr input tracers)
2. add `attrs.linearize` and `attrs.vjp`
The plan is for all of these features to be incorporated into the normal `jax.jvp`, `jax.linearize`, and `jax.vjp`, but we're keeping them separate for now while we play with them.

The signatures generalize `jax.linearize` and `jax.vjp`, like:
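(A rough sketch; the parameter names, defaults, and docstrings here are illustrative assumptions, not the PR's exact code.)

```python
from typing import Any, Callable

def linearize(f: Callable, *primals,
              attrs: list[tuple[Any, str]] = []
              ) -> tuple[Any, Callable]:
  """Like jax.linearize, but also linearize w.r.t. the given (object,
  attr_name) pairs; the returned f_lin takes and returns a dict of attr
  tangents keyed by those pairs, alongside the usual tangents."""
  ...

def vjp(f: Callable, *primals,
        attrs: list[tuple[Any, str]] = []
        ) -> tuple[Any, Callable]:
  """Like jax.vjp, but the returned f_vjp also produces a dict of attr
  cotangents keyed by (object, attr_name) pairs."""
  ...
```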
We're currently pretty inconsistent between using lists like `list[tuple[Object, str]]` and `list[tuple[Object, str, Array]]` vs. sets and dicts like `set[tuple[Object, str]]` and `dict[tuple[Object, str], Array]`. For the latter we require the mutable objects of interest to use `object.__hash__`/`object.__eq__`. We're also being inconsistent about whether tangent/cotangent result dicts represent zeros by missing entries, or symbolic zeros, or dense zeros. We'll make these things consistent at some point.
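Concretely, the styles currently in the mix look like this (toy values, just to show the shapes):

```python
import jax.numpy as jnp

class Obj:
  pass  # relies on default object.__hash__/object.__eq__, i.e. identity

obj, x_dot = Obj(), jnp.ones(3)

pairs   = [(obj, 'x'), (obj, 'y')]   # list[tuple[Object, str]]
triples = [(obj, 'x', x_dot)]        # list[tuple[Object, str, Array]]
keys    = {(obj, 'x'), (obj, 'y')}   # set[tuple[Object, str]]
tans    = {(obj, 'x'): x_dot}        # dict[tuple[Object, str], Array]
```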
These APIs are general but pretty low-level. A neural net library, for example, might help handle some of these details for the user, e.g. a model might be able to report `trainable_params : set[(Object, str)]` to be fed into this kind of API.
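For instance, a toy layer along these lines (the class, its `trainable_params` method, the import path, and the exact calling convention are all hypothetical here, not part of this PR):

```python
import jax.numpy as jnp
from jax.experimental import attrs  # import path assumed

class Dense:
  def __init__(self, w, b):
    self.w, self.b = w, b

  def __call__(self, x):
    # read parameters through jax_getattr so the attrs machinery sees them
    w = attrs.jax_getattr(self, 'w')
    b = attrs.jax_getattr(self, 'b')
    return w @ x + b

  def trainable_params(self):
    # the kind of report a library could surface for these APIs
    return [(self, 'w'), (self, 'b')]

layer = Dense(jnp.ones((2, 2)), jnp.zeros(2))
loss = lambda x: layer(x).sum()

# differentiate w.r.t. x and the layer's attributes in one call
out, f_vjp = attrs.vjp(loss, jnp.arange(2.0), attrs=layer.trainable_params())
```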
The implementation for `attrs.linearize` and `attrs.vjp` in this PR ended up being very straightforward, leveraging one assumption we may want to relax later: so long as the linearized computation staged out into a jaxpr never involves any `jax_getattr`/`jax_setattr` (note that `custom_jvp`/`custom_vjp` rules which include `jax_getattr`/`jax_setattr` satisfy this definition), we don't need to change the partial eval or transpose machinery. The partial-evaled-out computation remains pure. That is, the JVP support we already landed is enough. So `attrs.linearize` and `attrs.vjp` are actually just like `jax.linearize` and `jax.vjp`, with two bookkeeping differences:

1. they use `attrs.jvp` under the hood rather than `jax.jvp`, and
2. in the returned `f_lin`/`f_vjp` function, they route attrs dict entries to the appropriate jaxpr inputs/outputs, as sketched below.
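To make the second difference concrete, here is a minimal sketch of that routing, assuming a flattened linearized function whose attr tangents are appended as trailing positional inputs/outputs (the helper name and the missing-entry-means-zero convention are made up for illustration):

```python
from typing import Any, Callable

def route_attrs(f_lin_flat: Callable, attrs: list[tuple[Any, str]],
                n_outs: int) -> Callable:
  # f_lin_flat takes the input tangents followed by one tangent per attrs
  # entry, and returns the output tangents followed by one per attrs entry.
  def f_lin(*tangents, attr_tangents: dict[tuple[Any, str], Any] = {}):
    # dict entries -> trailing jaxpr inputs, in attrs order; a missing entry
    # means zero here (one of the conventions mentioned above)
    extra = [attr_tangents.get(key, 0.0) for key in attrs]
    flat_outs = f_lin_flat(*tangents, *extra)
    # leading outputs are the function's output tangents; trailing ones are
    # routed back into a dict keyed by (object, attr_name)
    outs, attr_outs = flat_outs[:n_outs], flat_outs[n_outs:]
    return outs, dict(zip(attrs, attr_outs))
  return f_lin
```

The cotangent side of `f_vjp` does the mirror-image bookkeeping, with the roles of inputs and outputs swapped.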