Hi, I'm trying to understand why it is possible to use the

Setup

First, I create the dataset:

import jax
from functools import partial
from string import ascii_uppercase
from jax.flatten_util import ravel_pytree
n_elements = 10
key = jax.random.PRNGKey(314159)
key_length, key_vals = jax.random.split(key)
keys_vals = jax.random.split(key_vals, n_elements)
letters = ascii_uppercase[:n_elements]
lengths = jax.random.choice(key_length, 20, (n_elements,))
elements = {
    letter: jax.random.normal(key, (length,))
    for letter, key, length in zip(letters, keys_vals, lengths)
}

Next, I specify the static parameters that will be passed to a function.

vals, r_fn = ravel_pytree(elements)
config = {
    "fn": r_fn,
    "more_stuff": ...
}

I'm interested in defining a function that jits over
Replies: 1 comment 2 replies
You can see this in the jaxpr for closed-over functions:

from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def f(x, y):
    return x + y

x = jnp.arange(4.0)
y = jnp.ones(4)

# Both arguments passed explicitly.
print(jax.make_jaxpr(f)(x, y))
# { lambda ; a:f32[4] b:f32[4]. let
#     c:f32[4] = xla_call[
#       call_jaxpr={ lambda ; d:f32[4] e:f32[4]. let f:f32[4] = add d e in (f,) }
#       name=f
#     ] a b
#   in (c,) }

# y bound via partial; only x is passed when tracing.
print(jax.make_jaxpr(partial(f, y=y))(x))
# { lambda a:f32[4]; b:f32[4]. let
#     c:f32[4] = xla_call[
#       call_jaxpr={ lambda ; d:f32[4] e:f32[4]. let f:f32[4] = add d e in (f,) }
#       name=f
#     ] b a
#   in (c,) }

In both cases
static_argnames operates differently than closures; in particular, closed-over arrays are treated as dynamic arguments, while closed-over non-array values are treated as static.
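To make the distinction concrete for the unravel function from the question, here is a minimal sketch of both routes. The tiny `elements` dict and the `total` function are hypothetical stand-ins, not the original poster's code, and the sketch assumes the callable returned by `ravel_pytree` is hashable, which `static_argnames` requires.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical stand-in for the `elements` dict built in the question.
elements = {"A": jnp.ones(3), "B": jnp.zeros(2)}
vals, unravel_fn = ravel_pytree(elements)


def total(flat_vals, unravel):
    # `unravel` is a plain Python callable, not an array.
    tree = unravel(flat_vals)
    return sum(jnp.sum(leaf) for leaf in tree.values())


# Closure route: bind the callable first, then jit the result.
# The callable is never a jit argument, so it is fixed at trace time.
total_closed = jax.jit(partial(total, unravel=unravel_fn))

# static_argnames route: the callable is passed at call time
# and must be hashable so jit can use it as part of its cache key.
total_static = jax.jit(total, static_argnames="unravel")

result_closed = total_closed(vals)
result_static = total_static(vals, unravel=unravel_fn)
```

Both calls should return the same value. The closure route fits a callable that never changes, while `static_argnames` would retrace whenever a different static value is passed.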