Add a "How to think in JAX" doc #4643
I think this would be great. Do you have any preliminary pointers on when to use `jnp` vs. `np`?

Briefly: use `jnp` for operations that you want to be staged/optimized, and `np` for "compile-time" operations.
Another point to touch on is the idea of fixed-size iterative algorithms over recursive ones. It's the main challenge in coding complicated algorithms with JAX. It's what hinders many people, I presume, and I know I have had to redesign an algorithm several times to make it work with JAX. But the results are amazing in terms of speed improvements when you do this. Another point might be about reducing how many …
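A minimal sketch of the fixed-size/structured-iteration idea described above (the function name and example are illustrative, not from the thread): a data-dependent Python recursion can't be traced by `jit`, but the same computation expressed with `jax.lax.while_loop` compiles fine.

```python
import jax
from jax import lax

# A data-dependent Python recursion like this fails under jit, because the
# number of steps depends on a traced value:
#
#   def count_halvings(x):
#       return 0 if x < 1 else 1 + count_halvings(x / 2)
#
# The same computation, rewritten with jax.lax structured control flow:

@jax.jit
def count_halvings(x):
    def cond(state):
        val, _ = state
        return val >= 1.0

    def body(state):
        val, n = state
        return val / 2.0, n + 1

    _, n = lax.while_loop(cond, body, (x, 0))
    return n

print(count_halvings(8.0))  # 4 halvings: 8 -> 4 -> 2 -> 1 -> 0.5
```

The redesign cost mentioned above is real: the loop body must carry all mutable state explicitly as a tuple, but in exchange the whole loop is staged and compiled as one program.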
It might also be interesting to discuss …

A relevant user question: #5280 (comment)
Hello @jakevdp, I ended up writing up something that reminded me of this question and I thought I'd paste it here in case it's helpful. I'm sure the JAX team can write this up better, but I think it would be useful to have something like this explained in the JAX documentation:

**Static versus dynamic parameters in JAX**

In JAX, various jit and gradient functions accept static and dynamic parameters. These always default to being dynamic unless marked otherwise (e.g., via `static_argnums`). Static parameters must be hashable. Dynamic parameters must be pytrees, which are either arrays or containers (lists, tuples, dicts) of pytrees.

- The JIT looks up the compiled program using a dictionary that's keyed by the values of the static parameters and the shapes/dtypes of the dynamic parameters.
- Consequently, passing a new static value, or dynamic arguments with a new shape or dtype, triggers a recompilation.
- Dynamic parameters are replaced with tracers within the JAX-decorated functions, so their concrete values cannot be used in ordinary Python control flow.
- Static parameters are passed to the jitted function unchanged, so they can be used in ordinary Python control flow.
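The static/dynamic distinction described above can be sketched with `static_argnums` (the function and example are illustrative, not from the thread):

```python
import jax
import jax.numpy as jnp
from functools import partial

# `axis` is marked static: it is used in a plain Python branch, so it must
# be a hashable compile-time constant rather than a traced value.
@partial(jax.jit, static_argnums=1)
def reduce_along(x, axis):
    if axis is None:   # ordinary Python control flow: fine, axis is static
        return jnp.sum(x)
    return jnp.sum(x, axis=axis)

x = jnp.ones((3, 4))
print(reduce_along(x, None))  # compiles once for axis=None
print(reduce_along(x, 0))     # new static value -> recompiles
print(reduce_along(x, 0))     # same static value and shape -> cache hit
```

Each distinct static value (and each distinct dynamic shape/dtype) gets its own entry in the compilation cache, which is why static arguments should come from a small set of values.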
Thanks - a couple clarifications to this: …

@jakevdp Good point about closed-over parameters. Could you elaborate on the second point if you have time?
Sorry, that wasn't very clear. I meant to bring up the difference here:

```python
import jax
from functools import partial

def f(x):
    if x < 0:
        return 0.0
    else:
        return x ** 2

jax.jit(f)(1.0)   # Fails because x is traced
jax.grad(f)(1.0)  # grad is fine with x being traced
```

Autodiff and JIT have different restrictions concerning operations that can be done with traced parameters.
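For completeness, one common way to make a value-dependent branch like the one above work under `jit` is to express it with `jnp.where`, which evaluates both branches and selects elementwise (a sketch, not from the thread):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # No Python-level branch on a traced value: jnp.where stages the
    # selection into the compiled program instead.
    return jnp.where(x < 0, 0.0, x ** 2)

print(f(1.0))            # 1.0
print(f(-2.0))           # 0.0
print(jax.grad(f)(3.0))  # 6.0 -- and now jit and grad compose
```

`jax.lax.cond` is the alternative when the branches are expensive and only one should be evaluated.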
Wow, thanks a lot for correcting my understanding. So your example shows a dynamic parameter to `grad` that wouldn't work under `jit`. When it's possible, should we …
Over the past year, I think there has been a bit of a transition in the usage of JAX-flavored numpy, from "JAX is a drop-in replacement for numpy" to "JAX is a tool to use beside numpy". This has manifested in many different ways (the switch from `import jax.numpy as np` to `import jax.numpy as jnp`, the deprecation of silent conversions to array within aggregates, some aspects of omnistaging, etc.), but I think we've now gotten to the point where using JAX effectively requires thinking about which operations should be staged (via `jnp` routines) and which should not be (via `np` routines).

I think it's important that we provide some entry-point into developing the correct mental model for using JAX effectively. I think this probably should take the form of a new narrative doc, which would absorb some aspects of the existing Sharp Bits doc.

A quick brainstorm of key points we should cover:

- `jnp` for operations that you want to be staged/optimized, and `np` for "compile-time" operations
- `jnp.prod(shape)` vs. `np.prod(shape)`. Why should you use the latter?
- `x[idx] = y` does not work. Instead `x = x.at[idx].set(y)` & related
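The brainstormed points can be sketched concretely (the function names here are illustrative, not from the issue):

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten(x):
    # Shapes are compile-time constants inside jit, so use plain numpy:
    # np.prod runs once at trace time and yields an ordinary integer,
    # whereas jnp.prod would stage a (pointless) runtime computation.
    n = np.prod(x.shape)
    return x.reshape(n)

@jax.jit
def set_first(x, y):
    # In-place updates like x[0] = y don't work on immutable JAX arrays;
    # the functional .at[].set() form returns an updated copy instead.
    return x.at[0].set(y)

x = jnp.zeros((2, 3))
print(flatten(x).shape)                  # (6,)
print(set_first(jnp.arange(3.0), 9.0))   # [9. 1. 2.]
```

Both snippets rely on the staged-vs-compile-time distinction the issue proposes documenting: `np` for trace-time bookkeeping, `jnp` for the computation you actually want compiled.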