We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 92aa9a3 + 6378984 commit 1a2737bCopy full SHA for 1a2737b
jax/experimental/mosaic/gpu/examples/flash_attention.py
@@ -22,6 +22,7 @@
22
import jax
23
from jax import random
24
from jax._src.interpreters import mlir
25
+from jax._src import test_util as jtu
26
from jax.experimental.mosaic.gpu import profiler
27
from jax.experimental.mosaic.gpu import * # noqa: F403
28
import jax.numpy as jnp
0 commit comments