
Commit 40654e4

fsx950223 committed: add rocm aiter backend
Signed-off-by: fsx950223 <[email protected]>
1 parent efe59bd commit 40654e4

4 files changed: +614 additions, -222 deletions


vllm/platforms/rocm.py

Lines changed: 3 additions & 2 deletions
@@ -190,10 +190,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if envs.VLLM_USE_V1:
-            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
+            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
+                and on_mi250_mi300():
                 logger.info("Using Flash Attention backend on V1 engine.")
                 return ("vllm.v1.attention.backends."
-                        "flash_attn.FlashAttentionBackend")
+                        "rocm_aiter_fa.AiterFlashAttentionBackend")
             else:
                 logger.info("Using Triton Attention backend on V1 engine.")
                 return ("vllm.v1.attention.backends."

vllm/v1/attention/backends/flash_attn.py

Lines changed: 22 additions & 211 deletions
@@ -30,144 +30,6 @@
 if current_platform.is_cuda():
     from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                       get_scheduler_metadata)
-if current_platform.is_rocm():
-    import aiter
-
-    from vllm.attention.ops.triton_unified_attention import unified_attention
-    from vllm.triton_utils import tl, triton
-    from vllm.utils import direct_register_custom_op
-
-    @triton.jit
-    def _vllm_layout_trans_kernel(
-        k_buffer_ptr,
-        v_buffer_ptr,
-        k_values_ptr,
-        v_values_ptr,
-        b_seq_lens_loc,
-        block_table,
-        block_table_stride_0,
-        E_DIM: tl.constexpr,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        batch_idx = tl.program_id(0)
-        block_idx = tl.program_id(1)
-        batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
-                                      tl.arange(0, 2))
-        batch_token_start, batch_token_end = tl.split(batch_token_indexes)
-        seq_len = batch_token_end - batch_token_start
-        if block_idx * BLOCK_SIZE < seq_len:
-            block_mask = (block_idx * BLOCK_SIZE +
-                          tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
-
-            kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
-                             block_idx)
-
-            kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
-                0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
-            k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
-                             mask=block_mask,
-                             other=0.0)
-            v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
-                             mask=block_mask,
-                             other=0.0)
-
-            kv_values_off = batch_token_start * E_DIM + \
-                block_idx * BLOCK_SIZE * E_DIM + \
-                tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
-                tl.arange(0, E_DIM)[None, :]
-            tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
-            tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
-
-    def vllm_layout_trans(b_seq_lens_loc, block_table, k_buffer, v_buffer,
-                          max_seq_len, total_tokens):
-        H_KV = v_buffer.shape[2]
-        D = v_buffer.shape[3]
-        BLOCK_SIZE = v_buffer.shape[1]
-        dtype = k_buffer.dtype
-        k_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
-                               device="cuda")
-        v_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
-                               device="cuda")
-
-        grid = (block_table.shape[0],
-                (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
-
-        _vllm_layout_trans_kernel[grid](k_buffer,
-                                        v_buffer,
-                                        k_values,
-                                        v_values,
-                                        b_seq_lens_loc,
-                                        block_table,
-                                        block_table.stride(0),
-                                        E_DIM=H_KV * D,
-                                        BLOCK_SIZE=BLOCK_SIZE)
-
-        return k_values, v_values
-
-    def flash_attn_varlen_func_impl(
-        q: torch.Tensor,
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        out: torch.Tensor,
-        cu_seqlens_q: torch.Tensor,
-        cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
-        max_seqlen_q: int,
-        max_seqlen_k: int,
-        softmax_scale: float,
-        window_size: Optional[list[int]],  # -1 means infinite context window
-        alibi_slopes: Optional[list[float]],
-        block_table: torch.Tensor,
-    ) -> torch.Tensor:
-        k, v = vllm_layout_trans(cu_seqlens_k, block_table, k_cache, v_cache,
-                                 max_seqlen_k, total_tokens)
-        output = aiter.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=cu_seqlens_q,
-            max_seqlen_q=max_seqlen_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_k=max_seqlen_k,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            window_size=window_size,
-            out=out,
-        )
-        return output
-
-    def flash_attn_varlen_func_fake(
-        q: torch.Tensor,
-        k_cache: torch.Tensor,
-        v_cache: torch.Tensor,
-        out: torch.Tensor,
-        cu_seqlens_q: torch.Tensor,
-        cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
-        max_seqlen_q: int,
-        max_seqlen_k: int,
-        softmax_scale: float,
-        window_size: Optional[list[int]],  # -1 means infinite context window
-        alibi_slopes: Optional[list[float]],
-        block_table: torch.Tensor,
-    ) -> torch.Tensor:
-        return torch.empty(q.shape[0],
-                           q.shape[1],
-                           v_cache.shape[-2],
-                           dtype=torch.float8_e4m3fnuz,
-                           device="cuda")
-
-    try:
-        direct_register_custom_op("flash_attn_varlen_func",
-                                  flash_attn_varlen_func_impl, ["out"],
-                                  flash_attn_varlen_func_fake)
-        flash_attn_varlen_func = torch.ops.vllm.flash_attn_varlen_func
-
-    except AttributeError:
-        flash_attn_varlen_func = flash_attn_varlen_func_impl
 
 logger = init_logger(__name__)
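Context for the removal above: `_vllm_layout_trans_kernel` and `vllm_layout_trans` gather the paged KV cache (laid out as `[num_blocks, BLOCK_SIZE, H_KV, D]`) into the contiguous `[total_tokens, H_KV, D]` tensors that `aiter.flash_attn_varlen_func` consumes; with this commit the ROCm path presumably moves into the new `rocm_aiter_fa` backend rather than living in `flash_attn.py`. A rough pure-PyTorch equivalent of that gather, shown only as an illustration (this is not code from the commit):

import torch


def layout_trans_reference(cu_seq_lens: torch.Tensor, block_table: torch.Tensor,
                           k_cache: torch.Tensor, v_cache: torch.Tensor,
                           total_tokens: int):
    """Gather paged KV blocks into contiguous [total_tokens, H_KV, D] tensors."""
    _, block_size, h_kv, d = k_cache.shape  # [num_blocks, BLOCK_SIZE, H_KV, D]
    k_out = torch.empty((total_tokens, h_kv, d),
                        dtype=k_cache.dtype, device=k_cache.device)
    v_out = torch.empty_like(k_out)
    for req in range(block_table.shape[0]):
        start = int(cu_seq_lens[req])
        end = int(cu_seq_lens[req + 1])
        seq_len = end - start
        num_req_blocks = (seq_len + block_size - 1) // block_size
        blocks = block_table[req, :num_req_blocks]  # block ids for this request
        # Flatten the selected blocks and keep only the first seq_len tokens.
        k_out[start:end] = k_cache[blocks].reshape(-1, h_kv, d)[:seq_len]
        v_out[start:end] = v_cache[blocks].reshape(-1, h_kv, d)[:seq_len]
    return k_out, v_out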

@@ -223,8 +85,6 @@ class FlashAttentionMetadata:
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
-    cu_seq_lens: torch.Tensor
-    total_tokens: int
     block_table: torch.Tensor
     slot_mapping: torch.Tensor

@@ -466,7 +326,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        total_tokens = self.runner.seq_lens_np[:num_reqs].sum()
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
@@ -481,13 +340,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
 
         slot_mapping = block_table.slot_mapping[:num_actual_tokens]
 
-        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
-                                  dtype=torch.int32,
-                                  device="cuda")
-        torch.cumsum(seq_lens,
-                     dim=0,
-                     dtype=cu_seq_lens.dtype,
-                     out=cu_seq_lens[1:])
         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
             # For the AOT scheduler we need the sliding window value to be
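The `cu_seq_lens` prefix sums removed here (together with the metadata fields dropped above) are what the aiter varlen kernel receives as `cu_seqlens_k`; presumably the new AITER backend now builds them itself. A standalone illustration of the same computation (the example lengths are invented):

import torch

seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # per-request KV lengths
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
# cu_seq_lens is now tensor([0, 5, 8, 15]); request i owns tokens
# cu_seq_lens[i]:cu_seq_lens[i + 1] in the flattened varlen layout.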
@@ -601,8 +453,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
-            cu_seq_lens=cu_seq_lens,
-            total_tokens=total_tokens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
@@ -768,67 +618,28 @@ def forward(
             scheduler_metadata = attn_metadata.scheduler_metadata
 
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
-            if current_platform.is_rocm():
-                cu_seq_lens = attn_metadata.cu_seq_lens
-                total_tokens = attn_metadata.total_tokens
-                if max_seqlen_q <= 1:
-                    unified_attention(
-                        q=query[:num_actual_tokens],
-                        k=key_cache,
-                        v=value_cache,
-                        out=output[:num_actual_tokens],
-                        cu_seqlens_q=cu_seqlens_q,
-                        max_seqlen_q=max_seqlen_q,
-                        seqused_k=seqused_k,
-                        max_seqlen_k=max_seqlen_k,
-                        softmax_scale=self.scale,
-                        causal=True,
-                        alibi_slopes=self.alibi_slopes,
-                        window_size=self.sliding_window,
-                        block_table=block_table,
-                        softcap=self.logits_soft_cap,
-                        q_descale=None,  # Not supported
-                        k_descale=layer._k_scale.expand(descale_shape),
-                        v_descale=layer._v_scale.expand(descale_shape),
-                    )
-                else:
-                    flash_attn_varlen_func(
-                        query[:num_actual_tokens],
-                        key_cache,
-                        value_cache,
-                        out=output[:num_actual_tokens],
-                        cu_seqlens_q=cu_seqlens_q,
-                        max_seqlen_q=max_seqlen_q,
-                        max_seqlen_k=max_seqlen_k,
-                        total_tokens=total_tokens,
-                        softmax_scale=self.scale,
-                        alibi_slopes=self.alibi_slopes,
-                        window_size=list(self.sliding_window),
-                        block_table=block_table,
-                        cu_seqlens_k=cu_seq_lens,
-                    )
-            else:
-                flash_attn_varlen_func(
-                    q=query[:num_actual_tokens],
-                    k=key_cache,
-                    v=value_cache,
-                    out=output[:num_actual_tokens],
-                    cu_seqlens_q=cu_seqlens_q,
-                    max_seqlen_q=max_seqlen_q,
-                    seqused_k=seqused_k,
-                    max_seqlen_k=max_seqlen_k,
-                    softmax_scale=self.scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                    window_size=self.sliding_window,
-                    block_table=block_table,
-                    softcap=self.logits_soft_cap,
-                    scheduler_metadata=scheduler_metadata,
-                    fa_version=self.vllm_flash_attn_version,
-                    q_descale=layer._q_scale.expand(descale_shape),
-                    k_descale=layer._k_scale.expand(descale_shape),
-                    v_descale=layer._v_scale.expand(descale_shape),
-                )
+
+            flash_attn_varlen_func(
+                q=query[:num_actual_tokens],
+                k=key_cache,
+                v=value_cache,
+                out=output[:num_actual_tokens],
+                cu_seqlens_q=cu_seqlens_q,
+                max_seqlen_q=max_seqlen_q,
+                seqused_k=seqused_k,
+                max_seqlen_k=max_seqlen_k,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                window_size=self.sliding_window,
+                block_table=block_table,
+                softcap=self.logits_soft_cap,
+                scheduler_metadata=scheduler_metadata,
+                fa_version=self.vllm_flash_attn_version,
+                q_descale=layer._q_scale.expand(descale_shape),
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
+            )
             return output
 
         assert not use_local_attn, (
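
One detail the surviving CUDA-path call keeps is that the per-tensor Q/K/V scales are broadcast to `(num_requests, num_kv_heads)` via `expand` before being passed as `q_descale`/`k_descale`/`v_descale`. A small standalone illustration (the shapes and values here are invented):

import torch

num_requests, num_kv_heads = 4, 8
cu_seqlens_q = torch.tensor([0, 2, 5, 6, 9])  # length == num_requests + 1
descale_shape = (cu_seqlens_q.shape[0] - 1, num_kv_heads)

k_scale = torch.tensor(0.02)               # per-tensor scale factor
k_descale = k_scale.expand(descale_shape)  # broadcast view, no copy
assert k_descale.shape == (num_requests, num_kv_heads)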
