Add static_argnames support in jax.checkpoint, as in jax.jit #27153

Open
oliverdutton-iso opened this issue Mar 14, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

@oliverdutton-iso
Contributor

oliverdutton-iso commented Mar 14, 2025

Hi,

`jax.checkpoint` is brilliant. We normally pass static values as keyword-only arguments and mark them static by name (with `static_argnames`, as `jax.jit` allows). `jax.checkpoint` only offers `static_argnums`, so we currently have to write wrappers for modules that receive static values as keyword arguments. Could `static_argnames` support be added to `jax.checkpoint`, as in `jax.jit`?
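For reference, here is a minimal sketch of the current gap in plain JAX (no Flax), assuming a recent JAX release. `jit_layer` and `remat_layer` are hypothetical toy functions. `jax.jit` can mark a keyword-only flag static by name, while `jax.checkpoint` identifies static arguments only by position, so the flag has to be passed positionally:

```python
import functools

import jax
import jax.numpy as jnp


# jax.jit accepts static_argnames, so a keyword-only flag can be static.
@functools.partial(jax.jit, static_argnames=("is_training",))
def jit_layer(x, *, is_training):
    # is_training is a plain Python bool here, so an ordinary `if` works.
    return x * 2.0 if is_training else x


# jax.checkpoint only offers static_argnums, which index positional
# arguments, so the flag must be passed positionally to be static.
@functools.partial(jax.checkpoint, static_argnums=(1,))
def remat_layer(x, is_training):
    return jnp.sin(x) * 2.0 if is_training else jnp.sin(x)


x = jnp.arange(3.0)
print(jit_layer(x, is_training=True))
print(jax.grad(lambda x: remat_layer(x, True).sum())(x))
```

As far as we can tell, calling `remat_layer(x, is_training=True)` by keyword is not covered by `static_argnums` and fails instead of treating the flag as static.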

For example, in Flax Linen (whose `nn.checkpoint` lifts the module and passes it on to `jax.checkpoint`), our workaround looks like this:

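```python
import functools

import flax.linen as nn


def checkpoint(
    mdl,
    *args,
    static_argnames=('is_training',),
    **kwargs,
):
  """Checkpoint a call to `mdl`, treating the kwargs named in
  `static_argnames` as static."""
  # Split the keyword arguments into static and traced groups.
  static_kwargs = {k: v for k, v in kwargs.items() if k in static_argnames}
  other_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}

  # Pack the static values into one positional argument (`static_args`) so
  # that static_argnums=(1,) can mark it static.
  @functools.partial(nn.checkpoint, static_argnums=(1,))
  def f(mdl, static_args, *args, **kwargs):
    # Re-attach the static values to their keyword names.
    return mdl(
        *args, **kwargs, **dict(zip(static_kwargs.keys(), static_args))
    )

  return f(mdl, tuple(static_kwargs.values()), *args, **other_kwargs)
```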
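A framework-free variant of the same idea, offered as a sketch under stated assumptions: `checkpoint_with_static_argnames` and `layer` are hypothetical names, not existing JAX APIs. Instead of packing the static values into a tuple at a fixed position and relying on `static_argnums` (as the Flax workaround does), it binds them with `functools.partial` before calling `jax.checkpoint`. The static values are then closed over as ordinary Python values and never traced:

```python
import functools

import jax
import jax.numpy as jnp


def checkpoint_with_static_argnames(fun, static_argnames=(), **checkpoint_kwargs):
    """Hypothetical helper: jax.checkpoint plus jit-style static_argnames.

    Keyword arguments named in `static_argnames` are bound into `fun` with
    functools.partial, so they stay plain Python values inside the
    checkpointed function. Everything else is traced as usual, and
    `checkpoint_kwargs` (e.g. `policy`, `prevent_cse`) go to jax.checkpoint.
    """
    static_argnames = frozenset(static_argnames)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        # Split keyword arguments into static (closed over) and dynamic (traced).
        static_kwargs = {k: v for k, v in kwargs.items() if k in static_argnames}
        dynamic_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}
        bound = functools.partial(fun, **static_kwargs)
        return jax.checkpoint(bound, **checkpoint_kwargs)(*args, **dynamic_kwargs)

    return wrapper


# Hypothetical toy layer with a keyword-only static flag.
def layer(x, scale, *, is_training):
    y = jnp.sin(x) * scale
    # Python control flow on is_training is fine because it is never traced.
    return y * 2.0 if is_training else y


remat_layer = checkpoint_with_static_argnames(layer, static_argnames=("is_training",))

x = jnp.arange(3.0)
grads = jax.grad(lambda x: remat_layer(x, 1.5, is_training=True).sum())(x)
print(grads)
```

The trade-off is that a new checkpointed function is built on every call; under an outer `jax.jit` that only happens while tracing.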
oliverdutton-iso added the enhancement (New feature or request) label Mar 14, 2025
@jakevdp
Collaborator

jakevdp commented Mar 14, 2025

Related to this proposal: #10614
