-
Consider this toy example of a tracker that returns the sum of all seen values. Its hidden state,

import jax
from jax import (
    numpy as jnp,
    tree_util as jtu
)
from dataclasses import dataclass
from typing import Any

@jtu.register_pytree_node_class
@dataclass
class SumTracker:
    sum: jax.Array = jnp.zeros(1)  # Hidden state
    dtype: Any = jnp.float32  # Default op dtype

    def __call__(self, x, dtype=None):
        """Returns the sum of all seen values"""
        if dtype is not None:
            self.dtype = dtype  # Override default op dtype
        self.sum = self.sum + x.sum(dtype=self.dtype)
        return self.sum, self

    def tree_flatten(self):
        children = (self.sum,)
        aux = (self.dtype,)  # dtype in aux data
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children, *aux)

As expected, calling the tracker updates its hidden state

tracker = SumTracker()
print(tracker) # SumTracker(sum=Array([0.], dtype=float32), dtype=<class 'jax.numpy.float32'>)
jitted_call = jax.jit(SumTracker.__call__, static_argnames="dtype")
_, new_tracker = jitted_call(tracker, x=jnp.ones((2, 2)), dtype=jnp.float16)
print(new_tracker) # SumTracker(sum=Array([4.], dtype=float32), dtype=<class 'jax.numpy.float16'>)

In other words, although JIT functions require outputs to be valid JAX types, I "tricked" the JIT function into operating on non-JAX types as well. Of course, retracing happens for each new

Is this intended? Or am I abusing undefined behavior?

P.S.

However, doing so in my case would cause this

_, new_tracker = jitted_call(tracker, x=jnp.ones((2, 2)), dtype=jnp.float16)

to no longer work, since JIT functions cannot return the non-JAX types, which would be found in
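The retracing mentioned in the question can be observed directly. The following is a minimal sketch, not code from the thread: `summed` and `trace_count` are hypothetical names, and the counter relies on the fact that Python side effects inside a jitted function only run while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only during tracing


def summed(x, dtype=jnp.float32):
    """Sum x in the given dtype; dtype is treated as a static argument."""
    global trace_count
    trace_count += 1  # Python side effect: executes at trace time only
    return x.sum(dtype=dtype)


jitted_sum = jax.jit(summed, static_argnames="dtype")

x = jnp.ones((2, 2))
jitted_sum(x, dtype=jnp.float16)  # first float16 call: traces
jitted_sum(x, dtype=jnp.float16)  # same static value: reuses the cached trace
jitted_sum(x, dtype=jnp.float32)  # new static value: traces again

print(trace_count)  # 2
```

Each distinct static value gets its own compiled version, so a tracker that switches between many dtypes pays a compilation cost for each one.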
Replies: 2 comments
-
This assertion is not entirely correct:

More precisely, JIT requires outputs to be valid JAX types or pytrees whose leaves are valid JAX types. Your function returns a pytree whose leaves are valid JAX types, so it is fine.

So what about dtype? Well, dtype in your example is not a leaf, but is part of aux_data, which you can think of as being part of the typeclass defined by your pytree registration. So technically your function returns a different typeclass than the input. It might help to look at it this way:

_, input_typeclass = jax.tree_util.tree_flatten(tracker)
_, output_typeclass = jax.tree_util.tree_flatten(new_tracker)
print(input_typeclass)
# PyTreeDef(CustomNode(SumTracker[(<class 'jax.numpy.float32'>,)], [*]))
print(output_typeclass)
# PyTreeDef(CustomNode(SumTracker[(<class 'jax.numpy.float16'>,)], [*]))
print(input_typeclass == output_typeclass)
# False

The input and output have different types, and this is fine: JAX generally allows you to define functions whose outputs are of a different type than the inputs.

As you found, this is not possible. PyTree leaves must be valid JAX types or pytrees thereof, and a dtype is neither a valid JAX type, nor is it registered as a pytree. All in all, it looks like this is working as intended.
-
That's good to know! I can see that being quite useful in updating the
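The aux-data point from the first reply can also be checked in a self-contained way. This is only a sketch under stated assumptions: `Tracker` and `update` are hypothetical names rather than code from the thread, and the update is written functionally (it returns a new object instead of mutating `self`), which is one possible design, not necessarily the original poster's.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import tree_util as jtu


@jtu.register_pytree_node_class
class Tracker:  # hypothetical stand-in for the SumTracker in the question
    def __init__(self, total, dtype):
        self.total = total  # leaf: a JAX array
        self.dtype = dtype  # aux data: becomes part of the treedef

    def tree_flatten(self):
        return (self.total,), (self.dtype,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children, *aux)


@partial(jax.jit, static_argnames="dtype")
def update(tracker, x, dtype):
    # Build a new Tracker; the static dtype ends up in the output's aux data.
    return Tracker(tracker.total + x.sum(dtype=dtype), dtype)


old_tracker = Tracker(jnp.zeros(1), jnp.float32)
new_tracker = update(old_tracker, jnp.ones((2, 2)), dtype=jnp.float16)

# Only the array is a leaf; the dtype lives in the tree structure.
in_leaves, in_def = jtu.tree_flatten(old_tracker)
out_leaves, out_def = jtu.tree_flatten(new_tracker)

print(len(in_leaves), len(out_leaves))  # 1 1
print(in_def == out_def)  # False
```

Since the dtype never appears among the leaves, JIT only ever sees the array, and the change of dtype shows up as a change in the output's tree structure.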