Avoid jitting jax functions inside a jitted function #14333
Replies: 1 comment 5 replies
-
In general you cannot avoid jit-compiling JAX functions within a JIT-compiled function. The problem is that the return value of `jnp.iscomplex(x)` is a traced value, and Python control flow such as `if` cannot branch on a traced value. There are a few ways to work around this.

Use static properties in control flow

You can use Python control flow that branches on static rather than dynamic attributes. Examples of static attributes are the shape and dtype of an array:

```python
import jax
import jax.numpy as jnp

x = jnp.array(3)  # a fixed array closed over by dummy()

@jax.jit
def dummy():
    if jnp.issubdtype(x.dtype, jnp.complexfloating):
        print("Is complex!")
```

Use NumPy for static computations

If you'd like to do static computations on static values, you can do so with NumPy rather than `jax.numpy`:

```python
import jax
import numpy as np

x = np.array(3)  # note: numpy not jax.numpy

@jax.jit
def dummy():
    if np.iscomplex(x):  # note: numpy not jax.numpy
        print("Is complex!")
```

External callbacks

If you're doing something more sophisticated that you would like executed at runtime, you can do this with external callbacks, although this typically comes with performance penalties due to the data transfer and synchronization between host and device. You can read more about that here: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
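A minimal sketch of the callback approach, assuming a JAX version that provides `jax.pure_callback`; the helper name `host_iscomplex` is hypothetical. The check moves into a host-side function that runs at runtime on concrete values, so plain NumPy works there:

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_iscomplex(v):
    # Hypothetical helper: runs on the host at runtime, where `v` holds
    # concrete values, so ordinary NumPy can be used here.
    return np.asarray(np.iscomplex(v), dtype=np.bool_)

@jax.jit
def dummy(x):
    # Describe the callback's output so JAX can trace around it.
    out_type = jax.ShapeDtypeStruct(x.shape, np.bool_)
    return jax.pure_callback(host_iscomplex, out_type, x)

print(dummy(jnp.array(3)), dummy(jnp.array(1 + 2j)))
```

Inside the jitted function the callback's result is still a traced array, so branching on it requires `jax.lax.cond` or `jnp.where` rather than a Python `if`.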
-
Is there any way to disable the jitting of JAX functions inside a jitted function?
If I call the `dummy` function defined below, an error will be raised, because `jnp.iscomplex(x)` returns a tracer. But `x` is fixed, so I'd want `jnp.iscomplex(x)` to return `False`.
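A minimal sketch of the failing pattern described above, assuming (hypothetically) that `x` is a fixed module-level array and that `dummy` branches on `jnp.iscomplex(x)` with a Python `if`:

```python
import jax
import jax.numpy as jnp

x = jnp.array(3)  # hypothetical fixed value, closed over by dummy()

@jax.jit
def dummy():
    # Under jit, jnp.iscomplex(x) is staged out and returns a tracer,
    # even though x itself is a fixed array.
    if jnp.iscomplex(x):  # converting the tracer to bool raises here
        print("Is complex!")

raised = False
try:
    dummy()
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions raise TracerBoolConversionError, which is a
    # subclass of ConcretizationTypeError.
    raised = True
    print(type(err).__name__)
```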