diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 3394eaaa0a3d..daacefb135e9 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -22,6 +22,7 @@ import jax from jax import random from jax._src.interpreters import mlir +from jax._src import test_util as jtu from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu import * # noqa: F403 import jax.numpy as jnp