jax.checkpoint is brilliant. We normally force static arguments to be keyword arguments, but jax.checkpoint only accepts static_argnums, so we currently have to write wrappers for modules that take static values as kwargs. Could support for static_argnames be added, as in jax.jit?
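A minimal sketch of the gap, assuming a toy function `layer` with a keyword-only flag `use_relu` (both hypothetical, for illustration). `jax.jit` can mark the flag static by name. With `jax.checkpoint`, the flag has to be moved into a positional slot through a small adapter (the hypothetical `layer_positional`) so that it can be marked with `static_argnums`:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy layer: `use_relu` is keyword-only and must be static,
# because it drives Python control flow.
def layer(x, scale, *, use_relu=False):
    y = jnp.sin(x) * scale
    return jax.nn.relu(y) if use_relu else y

# jax.jit can mark a keyword argument as static by name.
jit_layer = jax.jit(layer, static_argnames=("use_relu",))

# jax.checkpoint only takes static_argnums, so a per-function adapter
# (hypothetical) is needed to make the flag positional.
def layer_positional(x, scale, use_relu):
    return layer(x, scale, use_relu=use_relu)

ckpt_layer = jax.checkpoint(layer_positional, static_argnums=(2,))
```

This adapter is the boilerplate that has to be repeated for every function or module that takes static keyword arguments.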
The same limitation shows up in Flax Linen, whose checkpointing just lifts the module and passes it through to jax.checkpoint.
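Until something like static_argnames exists, one way to avoid a hand-written adapter per function is a generic helper. The sketch below is an assumption, not a JAX or Flax API: the hypothetical `checkpoint_static_kwargs` binds the named keyword arguments with `functools.partial`, so they are closed over as Python constants and never traced, then applies `jax.checkpoint` to the result. The toy `layer` is again hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


def checkpoint_static_kwargs(fun, static_argnames=(), **checkpoint_kwargs):
    """Hypothetical helper (not part of JAX): emulate static_argnames for
    jax.checkpoint by closing over the named keyword arguments."""
    def wrapped(*args, **kwargs):
        # Pull the static keyword arguments out of the call.
        static = {name: kwargs.pop(name) for name in static_argnames if name in kwargs}
        # Bind them as Python constants; remaining args/kwargs are traced as usual.
        remat_fun = jax.checkpoint(functools.partial(fun, **static), **checkpoint_kwargs)
        return remat_fun(*args, **kwargs)
    return wrapped


# Hypothetical toy layer with a keyword-only static flag.
def layer(x, scale, *, use_relu=False):
    y = jnp.sin(x) * scale
    return jax.nn.relu(y) if use_relu else y


ckpt_layer = checkpoint_static_kwargs(layer, static_argnames=("use_relu",))
```

Because the helper builds a fresh `jax.checkpoint` wrapper on every call, it is intended to run inside an outer `jax.jit` (or `jax.grad` under `jit`), where tracing happens once per distinct set of static values. Extra options such as `policy` are forwarded to `jax.checkpoint` through `checkpoint_kwargs`.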