I ran into a problem dumping JAX arrays inside a complex jitted function. I want to dump two JAX arrays that are computed by the following function:

    import jax
    import jax.numpy as jnp

    def fourier_embeddings(x: jnp.ndarray, dim: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        w_key, b_key = jax.random.split(jax.random.PRNGKey(42))
        weight = jax.random.normal(w_key, shape=[dim])
        bias = jax.random.uniform(b_key, shape=[dim])
        return (
            jnp.cos(2 * jnp.pi * (x[..., None] * weight + bias)),
            weight, bias
        )

Following the community discussion, I wrote this function, hoping it would work:
The function call is as follows:

    noise_embedding, weight, bias = fourier_embeddings(
        (1 / 4) * jnp.log(noise_level / SIGMA_DATA), dim=256
    )
    jax_save_with_jit('w', weight)
    jax_save_with_jit('b', bias)

For the value type:
Unfortunately, it turns out that nothing was saved at all to the specified path.
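One workaround that sidesteps host callbacks entirely, sketched here under assumptions: because `weight` and `bias` come from a fixed `PRNGKey(42)`, they can simply be returned from the jitted function as ordinary outputs and written with `np.save` outside of `jit`. The wrapper `model`, the sample input, and the temporary directory are illustrative choices, and `SIGMA_DATA` is dropped because the question does not define it.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def fourier_embeddings(x, dim):
    w_key, b_key = jax.random.split(jax.random.PRNGKey(42))
    weight = jax.random.normal(w_key, shape=[dim])
    bias = jax.random.uniform(b_key, shape=[dim])
    return jnp.cos(2 * jnp.pi * (x[..., None] * weight + bias)), weight, bias


@jax.jit
def model(noise_level):
    # Hypothetical wrapper: weight and bias leave jit as regular outputs
    # (SIGMA_DATA from the question is omitted here).
    return fourier_embeddings((1 / 4) * jnp.log(noise_level), dim=256)


emb, weight, bias = model(jnp.array([0.5, 2.0]))

# Outside jit the outputs are concrete arrays, so plain NumPy I/O works.
out_dir = tempfile.mkdtemp()
np.save(os.path.join(out_dir, "w.npy"), np.asarray(weight))
np.save(os.path.join(out_dir, "b.npy"), np.asarray(bias))
```

This keeps all file I/O outside the traced computation, at the cost of threading the arrays out through the function's return values.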
Answered by jakevdp, Jan 13, 2025
Answer selected by drewjin
Note that host_callback is deprecated; you probably want io_callback. For more information, see https://jax.readthedocs.io/en/latest/external-callbacks.html#exploring-io-callback
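A minimal sketch of the `io_callback` route, assuming a recent JAX release where `jax.experimental.io_callback` accepts `None` as `result_shape_dtypes` for a callback that returns nothing. The helpers `_save_npy` and `save_in_jit`, the wrapper `model`, the sample input, and the temporary output directory are hypothetical names for illustration; `SIGMA_DATA` is dropped because the question does not define it.

```python
import functools
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def _save_npy(path, arr):
    # Runs on the host with concrete values; np.asarray handles both
    # np.ndarray and jax.Array arguments, depending on the JAX version.
    np.save(path, np.asarray(arr))


def save_in_jit(path, arr):
    # Hypothetical helper: the file path is bound statically via partial,
    # only the array is passed through the callback, and no outputs are
    # requested (result_shape_dtypes=None). ordered=True preserves order.
    io_callback(functools.partial(_save_npy, path), None, arr, ordered=True)


def fourier_embeddings(x, dim):
    w_key, b_key = jax.random.split(jax.random.PRNGKey(42))
    weight = jax.random.normal(w_key, shape=[dim])
    bias = jax.random.uniform(b_key, shape=[dim])
    return jnp.cos(2 * jnp.pi * (x[..., None] * weight + bias)), weight, bias


out_dir = tempfile.mkdtemp()


@jax.jit
def model(noise_level):
    emb, weight, bias = fourier_embeddings((1 / 4) * jnp.log(noise_level), dim=256)
    save_in_jit(os.path.join(out_dir, "w.npy"), weight)
    save_in_jit(os.path.join(out_dir, "b.npy"), bias)
    return emb


emb = model(jnp.array([0.5, 2.0]))
jax.effects_barrier()  # wait for pending host-side effects to finish
w = np.load(os.path.join(out_dir, "w.npy"))
```

`ordered=True` keeps the two writes in program order, and `jax.effects_barrier()` blocks until pending side effects have completed, so the files can be read immediately afterwards.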