JIT-compiled functions updating and returning non JAX types. Intended or not? #13912

Answered by jakevdp
zongyf02 asked this question in Q&A
This assertion is not entirely correct:

JIT functions require outputs to be valid JAX types

More precisely, JIT requires outputs to be valid JAX types or pytrees whose leaves are valid JAX types. Your function returns a pytree whose leaves are valid JAX types, so it is fine.

So what about dtype? In your example, dtype is not a leaf; it is part of aux_data, which you can think of as part of the typeclass defined by your pytree registration. So technically your function returns a different typeclass than its input. It might help to look at it this way:

_, input_typeclass = jax.tree_util.tree_flatten(tracker)
_, output_typeclass = jax.tree_util.tree_flatten(new_tracker)

print
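
For a self-contained version of that comparison, here is a minimal sketch. The `Tracker` class and the `update` function are hypothetical stand-ins for the code in the question (which is not reproduced here); the only assumption carried over is that the array is a leaf and dtype lives in aux_data.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Tracker:
    """Hypothetical stand-in for the class in the question."""

    def __init__(self, value, dtype):
        self.value = value  # leaf: a JAX array, traced under jit
        self.dtype = dtype  # aux_data: static metadata, not traced

    def tree_flatten(self):
        # (children, aux_data): only `value` is a leaf.
        return (self.value,), self.dtype

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


@jax.jit
def update(tracker):
    # Changes both the leaf's dtype and the dtype stored in aux_data.
    new_dtype = jnp.dtype(jnp.int32)
    return Tracker(tracker.value.astype(new_dtype), new_dtype)


tracker = Tracker(jnp.ones(3, dtype=jnp.float32), jnp.dtype(jnp.float32))
new_tracker = update(tracker)

_, input_typeclass = tree_util.tree_flatten(tracker)
_, output_typeclass = tree_util.tree_flatten(new_tracker)

# The treedefs differ because aux_data differs, even though both
# pytrees have the same class and a single array leaf.
same_typeclass = input_typeclass == output_typeclass
print(input_typeclass)
print(output_typeclass)
print(same_typeclass)  # False
```

In this sketch the jitted function runs without error because its only output leaf is a valid JAX array; the changed dtype shows up purely as a difference between the input and output treedefs.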

Answer selected by zongyf02