Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# We need to import the CUDA kernels after importing torch
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
from .flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
else:
import flash_attn_2_cuda as flash_attn_gpu

Expand Down Expand Up @@ -127,7 +127,10 @@ def _flash_attn_forward_fake(
softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
if return_softmax:
p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
if torch.cuda.is_available() and torch.version.hip:
p = torch.empty((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout)
else:
p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)

return out, softmax_lse, p, rng_state
Expand Down Expand Up @@ -220,10 +223,11 @@ def _flash_attn_varlen_forward_fake(
out = torch.empty_like(q)
softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
seqlen_k_rounded = round_multiple(max_seqlen_k, 128)
if return_softmax:
p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout)
if torch.cuda.is_available() and torch.version.hip:
p = torch.empty((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout)
else:
p = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
return out, softmax_lse, p, rng_state

Expand Down Expand Up @@ -315,7 +319,10 @@ def _flash_attn_backward_fake(
if dv is None:
dv = torch.empty_like(v)
batch_size, seqlen_q, num_heads, _ = q.shape
softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
if torch.cuda.is_available() and torch.version.hip:
softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32)
else:
softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)

return softmax_d

Expand Down Expand Up @@ -426,7 +433,10 @@ def _flash_attn_varlen_backward_fake(
dk = torch.empty_like(k)
if dv is None:
dv = torch.empty_like(v)
softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
if torch.cuda.is_available() and torch.version.hip:
softmax_d = torch.empty((num_heads, total_q), device=q.device, dtype=torch.float32)
else:
softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)

return softmax_d

Expand Down
4 changes: 4 additions & 0 deletions flash_attn/flash_attn_triton_amd/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import interface_v2 as flash_attn_2
from . import interface_v3 as flash_attn_3

__all__ = ["flash_attn_2", "flash_attn_3"]
Loading