Add a "How to think in JAX" doc #4643

Closed
jakevdp opened this issue Oct 19, 2020 · 10 comments

@jakevdp
Collaborator

jakevdp commented Oct 19, 2020

Over the past year, I think there has been a bit of a transition in the usage of JAX-flavored numpy, from "JAX is a drop-in replacement for numpy" to "JAX is a tool to use beside numpy". This has manifested in many different ways (the switch from import jax.numpy as np to import jax.numpy as jnp, the deprecation of silent conversions to array within aggregates, some aspects of omnistaging, etc.), but I think we've now gotten to the point where using JAX effectively requires thinking about which operations should be staged (via jnp routines) and which should not be (via np routines).

I think it's important that we provide some entry-point into developing the correct mental model for using JAX effectively. I think this probably should take the form of a new narrative doc, which would absorb some aspect of the existing Sharp Bits doc.

A quick brainstorm of key points we should cover:

  • JAX is not a drop-in replacement for numpy, but is a numpy-like interface for staged computations.
  • Use jnp for operations that you want to be staged/optimized, and np for "compile-time" operations
  • Common example: jnp.prod(shape) vs. np.prod(shape). Why should you use the latter? (See the sketch after this list.)
  • Key difference: JAX arrays are immutable, so x[idx] = y does not work; instead use x = x.at[idx].set(y) and related methods.
  • JAX and dynamic shapes: can be staged but not compiled.
  • Other points?
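
As a rough sketch of the np-vs-jnp and immutable-update points above (the function name is just illustrative):

import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten_and_zero_first(x):
  # np.prod runs at trace time on concrete Python ints, so the result can be
  # used as a static shape; jnp.prod would instead stage a traced computation.
  flat_size = np.prod(x.shape)
  y = jnp.reshape(x, (flat_size,))
  # JAX arrays are immutable: `y[0] = 0.0` raises an error.
  # The functional update returns a new array instead.
  return y.at[0].set(0.0)

flatten_and_zero_first(jnp.ones((2, 3)))
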
@jakevdp jakevdp self-assigned this Oct 19, 2020
@lukepfister
Contributor

I think this would be great.

Do you have any preliminary pointers on when to use jnp vs np?

@jakevdp
Collaborator Author

jakevdp commented Oct 21, 2020

Briefly: use jnp when you want the calculation to be compiled / to be performed on the accelerator. Use np when you want the calculation to happen on the CPU / at compile time.
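
In code (a minimal illustration of the same point):

import numpy as np
import jax.numpy as jnp

np.arange(4.0) ** 2   # evaluated eagerly on the host CPU; returns a numpy array
jnp.arange(4.0) ** 2  # returns a JAX array on the default device (CPU/GPU/TPU),
                      # and would be staged for compilation inside a jitted function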

@Joshuaalbert
Contributor

Another point to touch on is the idea of fixed-size iterative algorithms over recursive ones. It's the main challenge in coding complicated algorithms with JAX. It's what hinders many people, I presume, and I know I have had to redesign an algorithm several times to make it work with JAX. But the results are amazing in terms of speed improvements when you do this.
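
For example (a hedged sketch; power iteration with a fixed step count stands in for the kind of algorithm meant here):

import jax
import jax.numpy as jnp
from jax import lax

NUM_STEPS = 50  # a fixed, trace-time constant rather than a data-dependent stopping rule

@jax.jit
def power_iteration(A):
  v = jnp.ones(A.shape[0])
  def body(_, v):
    w = A @ v
    return w / jnp.linalg.norm(w)
  # A fixed trip count stages out cleanly under jit, unlike recursion or a
  # Python while-loop whose depth depends on traced values.
  return lax.fori_loop(0, NUM_STEPS, body, v)

A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
power_iteration(A)  # approximates the dominant eigenvector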

Another point might be about reducing how many cond operations you use, to let accelerators work as effectively as possible. I.e. sometimes it's better to use a where and compute both branches of a switch rather than use a cond. Knowing when to do that typically requires profiling, but there are some good rules of thumb.
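
A rough sketch of that trade-off (illustrative only; which version wins depends on the cost of the branches and on profiling):

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def f_where(x):
  # Compute both branches and select elementwise; cheap branches
  # usually run faster this way on accelerators.
  return jnp.where(x > 0, jnp.sin(x), jnp.cos(x))

@jax.jit
def f_cond(x):
  # Evaluate only one branch; worthwhile when a branch is expensive.
  # Note that lax.cond needs a scalar predicate.
  return lax.cond(x > 0, jnp.sin, jnp.cos, x)

f_where(jnp.linspace(-1.0, 1.0, 5))  # works elementwise over a vector
f_cond(0.5)                          # scalar predicate, only one branch evaluated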

@NeilGirdhar
Contributor

It might also be interesting to discuss

  • static vs non-static attributes in pytrees, jitted functions, and custom derivatives (a small pytree sketch follows this list), and
  • imagining the propagation through a single code path of both primals and cotangents.
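
For the first point, a hedged sketch of a pytree with one dynamic and one static field (the Scale class here is hypothetical):

import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class Scale:
  def __init__(self, factor, name):
    self.factor = factor  # dynamic: a traced array leaf
    self.name = name      # static: part of the treedef, must be hashable

  def tree_flatten(self):
    return (self.factor,), self.name

  @classmethod
  def tree_unflatten(cls, name, children):
    return cls(children[0], name)

@jax.jit
def apply(scale, x):
  return scale.factor * x

apply(Scale(jnp.asarray(2.0), "double"), jnp.arange(3.0))
# A new `factor` reuses the compiled program; a new `name` recompiles.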

@jakevdp
Collaborator Author

jakevdp commented Dec 30, 2020

A relevant user question: #5280 (comment)

@NeilGirdhar
Contributor

Hello @jakevdp, I ended up writing up something that reminded me of this question and I thought I'd paste it here in case it's helpful. I'm sure the JAX team can write this up better, but I think it would be useful to have something like this explained in the JAX documentation:

Static versus dynamic parameters in JAX

In JAX, various jit and gradient functions accept static and dynamic parameters. These always default to being dynamic unless marked static (using static_argnums, static_argnames, or, for historical reasons, nondiff_argnums).

Static parameters must be hashable. Dynamic parameters must be pytrees, which are either

  • leaves comprising scalars or jax.numpy.ndarray instances, or else
  • aggregate objects comprising dynamic fields (which act as dynamic parameters) and static fields (which act as static parameters despite being passed as part of a dynamic parameter).

The JIT looks up the compiled program using a dictionary that's keyed by

  • the tree structure (Python types of all its components),
  • the shapes and dtypes of its array-valued leaves, and
  • the values of its static parameters.

Consequently,

  • calling the jitted function with different values of the static parameters always induces recompilation, but
  • calling the jitted function with different values (but the same shape) of the dynamic parameters never induces recompilation. They are merely arguments to the compiled program.

Dynamic parameters are replaced with tracers within the JAX-decorated functions, so

  • they cannot be used as the length of a jax.lax.scan,
  • they cannot be used in Python control flow (like if and while), but
  • they can be vectorized by vmap, and be the differentiand of grad, vjp, jvp, etc.

Static parameters are passed to the jitted function unchanged, so

  • they can be used as the length of a jax.lax.scan,
  • they can be used in Python control flow, but
  • they cannot be vectorized or be the differentiand.
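
A small sketch of the recompilation behavior described above (function and argument names are illustrative):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=1)
def repeat(x, n):
  # n is static, so it can drive Python-level logic like this list repetition.
  return jnp.concatenate([x] * n)

x = jnp.arange(3.0)
repeat(x, 2)        # compiles for n=2
repeat(x, 3)        # new static value: recompiles
repeat(x + 1.0, 3)  # same shapes/dtypes and same n: reuses the n=3 program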

@jakevdp
Collaborator Author

jakevdp commented Nov 22, 2021

Thanks - a couple clarifications to this:

  • parameters may also be static if closed over (e.g. in jit(lambda x: f(x, y)), y will be treated as static; sketched below)
  • autodiff can be done with respect to dynamic parameters
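
A small sketch of the closed-over case (names are illustrative):

import jax
import jax.numpy as jnp

def f(x, y):
  return x * y

y = jnp.pi
g = jax.jit(lambda x: f(x, y))  # y is captured by the closure, so it is baked
                                # into the traced program as a constant
g(jnp.arange(3.0))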

@NeilGirdhar
Contributor

@jakevdp Good point about closed-over parameters.

Could you elaborate on the second point if you have time?

@jakevdp
Collaborator Author

jakevdp commented Nov 22, 2021

Sorry, that wasn't very clear. I meant to bring up the difference here:

import jax

def f(x):
  if x < 0:
    return 0.0
  else:
    return x ** 2

jax.jit(f)(1.0)   # Fails: under jit, x is an abstract tracer, so `x < 0` has no concrete value
jax.grad(f)(1.0)  # Works: grad traces with concrete values, so the Python branch is fine

autodiff and JIT have different restrictions concerning operations that can be done with traced parameters.

@NeilGirdhar
Contributor

NeilGirdhar commented Nov 22, 2021

Wow, thanks a lot for correcting my understanding.

So your example shows a dynamic parameter to grad being used in a condition. It's a mystery to me that it works. I guess those tracers have the concrete value available.

When it's possible, should we jit before grad to prevent re-evaluation of the function?
