Freezing parameters with partial function or static_argnames argument. #13913

Answered by jakevdp
gerdm asked this question in Q&A

`static_argnames` behaves differently from closures: closed-over arrays are treated as dynamic arguments, while closed-over non-array values are treated as static.
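Here is a minimal sketch of the caching difference, assuming a recent JAX release. The counter dictionary `traces` and the functions `add` and `scale` are hypothetical names used only for illustration. A Python side effect inside a jitted function runs only while JAX traces it, so counting those increments shows when a call reuses a cached trace.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical counters: these increments run only while JAX traces the
# function, not on later calls that hit the jit cache.
traces = {"add": 0, "scale": 0}


def add(x, y):
    traces["add"] += 1
    return x + y


def scale(x, n):
    traces["scale"] += 1
    return x * n


jit_add = jax.jit(add)
jit_scale = jax.jit(scale, static_argnames="n")

x = jnp.arange(4.0)

# Freezing an array with partial: y is still passed to jit as a dynamic
# argument, so a new value with the same shape and dtype reuses the trace.
add_ones = partial(jit_add, y=jnp.ones(4))
add_zeros = partial(jit_add, y=jnp.zeros(4))
add_ones(x)
add_zeros(x)

# Freezing a Python int with static_argnames: the value of n is part of
# the cache key, so each new value triggers a fresh trace and compile.
jit_scale(x, n=2)
jit_scale(x, n=3)

print(traces)  # {'add': 1, 'scale': 2}
```

Because `partial` only fixes the argument, `y` keeps flowing through `jit` as a traced array; `static_argnames` instead makes the value itself part of the compilation cache key.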

You can see this by comparing the jaxpr of the jitted function called directly with the jaxpr when `y` is closed over via `functools.partial`:

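```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
  return x + y

x = jnp.arange(4.0)
y = jnp.ones(4)

# Both x and y are inputs to the outer jaxpr. (Output from the JAX version
# used at the time; newer releases print a different call primitive in
# place of xla_call.)
print(jax.make_jaxpr(f)(x, y))
# { lambda ; a:f32[4] b:f32[4]. let
#     c:f32[4] = xla_call[
#       call_jaxpr={ lambda ; d:f32[4] e:f32[4]. let f:f32[4] = add d e in (f,) }
#       name=f
#     ] a b
#   in (c,) }

# With y frozen via partial, the outer jaxpr takes only x as an input; y is
# captured as a constant (the `a` before the `;`) and is still passed into
# the jitted call as an array argument (the call_jaxpr has two inputs).
print(jax.make_jaxpr(partial(f, y=y))(x))
# { lambda a:f32[4]; b:f32[4]. let
#     c:f32[4] = xla_call[
#       call_jaxpr={ lambda ; d:f32[4] e:f32[4]…
```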
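A related consequence, sketched here assuming a recent JAX release (`add` and `jit_add_static` are hypothetical names): static arguments must be hashable, and `jax.Array` is not, so `static_argnames` cannot be used to freeze an array, whereas a hashable Python scalar works.

```python
import jax
import jax.numpy as jnp


def add(x, y):
    return x + y


# Marking y as static asks JAX to hash its value as part of the cache key.
jit_add_static = jax.jit(add, static_argnames="y")

x = jnp.arange(4.0)

try:
    jit_add_static(x, y=jnp.ones(4))
    static_array_ok = True
except (TypeError, ValueError) as err:
    # jax.Array is unhashable, so JAX rejects it as a static argument;
    # the exact exception type and message depend on the JAX version.
    static_array_ok = False
    print(type(err).__name__)

# A hashable Python scalar is accepted as a static argument.
print(jit_add_static(x, y=1.0))
```

The sketch catches both `TypeError` and `ValueError` because the wrapping of the hashing error has varied between JAX versions.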

Replies: 1 comment 2 replies

@gerdm

@mattjj

Answer selected by gerdm
Category
Q&A
3 participants