
[attrs] add linearize and vjp support #19960

Merged: 2 commits into jax-ml:main on Feb 24, 2024

Conversation

mattjj
Collaborator

@mattjj mattjj commented Feb 24, 2024

There are two commits here:

  • 67572d3 is a small tweak that simplifies the implementation of attrs.jvp (IMO, though I'm not sure yet...): input perturbations are handled at the traceable level, so the inner transformations never need to worry about them (e.g. we don't create separate attr input tracers)
  • 2ce8c57 adds attrs.linearize and attrs.vjp

The plan is for all of these features to be incorporated into the normal jax.jvp, jax.linearize, and jax.vjp, but we're keeping them separate for now while we play with them.
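As a reminder of the baseline these features will eventually fold into, here's a minimal sketch using the existing jax.linearize API (standard JAX, nothing attrs-specific): linearize evaluates f at the primal point and returns the primal output plus a linear function applying the JVP at that point.

```python
import jax
import jax.numpy as jnp

# Linearize sin at the primal point 1.0.
y, f_lin = jax.linearize(jnp.sin, 1.0)

# f_lin applies the JVP of sin at 1.0 without re-running the primal,
# so f_lin(t) == cos(1.0) * t.
t_out = f_lin(2.0)
```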

The signatures generalize jax.linearize and jax.vjp, like:

InPrimal = OutPrimal = PyTree[Array]
def vjp(f: Callable, *primals: InPrimal, attrs: list[tuple[Object, str]]
        ) -> tuple[OutPrimal, VJPFun]:
  ...

OutCT = PyTree[Array]
ArgCTs = tuple[PyTree[Array], ...]

vjpfun : VJPFun
def vjpfun(out_ct: OutCT, *, attr_cotangents: dict[tuple[Object, str], Array] = {}
           ) -> tuple[ArgCTs, dict[tuple[Object, str], Array]]:
  ...
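For comparison, the ordinary jax.vjp signature that the sketch above generalizes: it returns the primal output and a function mapping an output cotangent to per-argument cotangents, with no attr dicts threaded through either direction.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y

x, y = 1.0, 2.0
out, f_vjp = jax.vjp(f, x, y)

# Pull back an output cotangent of 1.0 to cotangents w.r.t. x and y.
x_ct, y_ct = f_vjp(1.0)
```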

We're currently pretty inconsistent between using lists like list[tuple[Object, str]] and list[tuple[Object, str, Array]] versus sets and dicts like set[tuple[Object, str]] and dict[tuple[Object, str], Array]. The latter require the mutable objects of interest to use the default object.__hash__ / object.__eq__, i.e. identity-based hashing. We're also inconsistent about whether tangent/cotangent result dicts represent zeros by missing entries, symbolic zeros, or dense zeros. We'll make these things consistent at some point.
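A small plain-Python illustration of why the dict/set variants need identity hashing: a mutable object that keeps the default object.__hash__ / object.__eq__ hashes by identity, so an (object, attribute-name) pair remains a stable dict key even while the attribute's value mutates.

```python
class Thing:
    pass  # default object.__hash__ / object.__eq__: identity-based

t = Thing()
t.x = 1.0

# A cotangent keyed by the (object, attribute name) pair.
cts = {(t, "x"): 2.0}

# Mutating the attribute doesn't disturb the key, since hashing is by
# object identity, not by the attribute's contents.
t.x = 99.0
looked_up = cts[(t, "x")]
```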

These APIs are general but pretty low-level. A neural net library, for example, might help handle some of these details for the user, e.g. a model might be able to report trainable_params : set[(Object, str)] to be fed into this kind of API.

The implementation of attrs.linearize and attrs.vjp in this PR ended up being very straightforward, leveraging one assumption we may want to relax later: so long as the computation staged out into a jaxpr for linearization never involves any jax_getattr/jax_setattr (note that custom_jvp/custom_vjp rules which include jax_getattr/jax_setattr satisfy this condition), we don't need to change the partial eval or transpose machinery. The partial-evaled-out computation remains pure. That is, the JVP support we already landed is enough, so attrs.linearize and attrs.vjp are actually just like jax.linearize and jax.vjp, with two bookkeeping differences:

  1. call attrs.jvp under the hood rather than jax.jvp, and
  2. in the returned f_lin / f_vjp function, route attrs dict entries to the appropriate jaxpr inputs/outputs.
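The routing in point 2 can be sketched in plain Python (helper and class names here are illustrative, not the PR's internals): attr cotangents arrive as a dict keyed by (object, attribute-name) pairs, get flattened into positional values in a fixed order to feed the jaxpr, and positional attr outputs are rebuilt into a dict on the way back.

```python
def dict_to_positional(attr_cts, attr_order, zero=0.0):
    # A missing entry stands for a zero cotangent (one of the zero
    # conventions the PR notes is not yet settled).
    return [attr_cts.get(key, zero) for key in attr_order]

def positional_to_dict(outs, attr_order):
    return dict(zip(attr_order, outs))

class Obj:  # stand-in for a mutable object carrying attrs
    pass

o = Obj()
order = [(o, "w"), (o, "b")]  # fixed ordering of jaxpr attr slots
pos = dict_to_positional({(o, "b"): 3.0}, order)
back = positional_to_dict(pos, order)
```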

@mattjj mattjj requested a review from dougalm February 24, 2024 00:12
@mattjj mattjj self-assigned this Feb 24, 2024
@google-ml-butler google-ml-butler bot added labels kokoro:force-run and pull ready (Ready for copybara import and testing) Feb 24, 2024
@copybara-service copybara-service bot merged commit 072b43b into jax-ml:main Feb 24, 2024
12 of 13 checks passed
@mattjj mattjj deleted the attrs-autodiff branch February 24, 2024 04:33
3 participants