Skip to content

Commit

Permalink
Add back the import of jtu in flash_attention.py
Browse files Browse the repository at this point in the history
This was erroneously removed in de3191f.
  • Loading branch information
andportnoy authored Oct 22, 2024
1 parent 2b7b074 commit 6378984
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions jax/experimental/mosaic/gpu/examples/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax
from jax import random
from jax._src.interpreters import mlir
from jax._src import test_util as jtu
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.mosaic.gpu import * # noqa: F403
import jax.numpy as jnp
Expand Down

0 comments on commit 6378984

Please sign in to comment.