Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def _get_backend_priorities(
if is_aiter_found_and_supported():
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
backends.append(AttentionBackendEnum.TRITON_ATTN)
backends.append(AttentionBackendEnum.TURBOQUANT)

return backends

Expand Down
14 changes: 3 additions & 11 deletions vllm/v1/attention/backends/turboquant_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,7 @@ def _prefill_attention(
# max_query_len == max_seq_len means no request has prior cached KV.
# Both are Python ints — no GPU sync.
if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len:
output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype)
flash_attn_varlen_func(
return flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand All @@ -518,9 +517,7 @@ def _prefill_attention(
max_seqlen_k=attn_metadata.max_query_len,
softmax_scale=self.scale,
causal=True,
out=output,
)
return output

# Continuation or no flash_attn: per-request attention.
# For continuation chunks (seq_len > q_len), we must attend to
Expand Down Expand Up @@ -557,10 +554,9 @@ def _prefill_attention(
if q_len == seq_len:
# First-chunk prefill: all K/V are in the current batch.
if _HAS_FLASH_ATTN:
out = torch.empty_like(q_seq)
_cu_2[1] = q_len
cu = _cu_2
flash_attn_varlen_func(
out = flash_attn_varlen_func(
q=q_seq,
k=k_seq,
v=v_seq,
Expand All @@ -570,7 +566,6 @@ def _prefill_attention(
max_seqlen_k=q_len,
softmax_scale=self.scale,
causal=True,
out=out,
)
else:
q_t = q_seq.transpose(0, 1).contiguous()
Expand Down Expand Up @@ -733,10 +728,9 @@ def _continuation_prefill(

# Attention: q_len queries attending to seq_len K/V with causal mask
if _HAS_FLASH_ATTN:
output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype)
cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
flash_attn_varlen_func(
return flash_attn_varlen_func(
q=query,
k=k_full,
v=v_full,
Expand All @@ -746,9 +740,7 @@ def _continuation_prefill(
max_seqlen_k=seq_len,
softmax_scale=self.scale,
causal=True,
out=output,
)
return output
else:
# SDPA fallback: expand KV for GQA, build causal mask
q_t = query.transpose(0, 1).unsqueeze(0) # (1, Hq, q_len, D)
Expand Down
12 changes: 6 additions & 6 deletions vllm/v1/attention/ops/triton_turboquant_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ def _tq_decode_stage1(
Block_table_ptr + bt_base + page_idx,
mask=kv_mask,
other=0,
)
).to(tl.int64)

slot_bases = (
block_nums * stride_cache_block
+ page_off * stride_cache_pos
+ kv_head * stride_cache_head
+ page_off.to(tl.int64) * stride_cache_pos
+ tl.cast(kv_head, tl.int64) * stride_cache_head
)

# ============================================================
Expand Down Expand Up @@ -350,11 +350,11 @@ def _tq_full_dequant_kv(

page_idx = pos // BLOCK_SIZE
page_off = pos % BLOCK_SIZE
block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx)
block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx).to(tl.int64)
slot_base = (
block_num * stride_cache_block
+ page_off * stride_cache_pos
+ hid * stride_cache_head
+ tl.cast(page_off, tl.int64) * stride_cache_pos
+ tl.cast(hid, tl.int64) * stride_cache_head
)

d_offs = tl.arange(0, BLOCK_D)
Expand Down
18 changes: 12 additions & 6 deletions vllm/v1/attention/ops/triton_turboquant_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,13 @@ def _tq_fused_store_fp8(
slot = tl.load(Slot_mapping_ptr + token_idx)
if slot < 0:
return
blk = slot // BLOCK_SIZE
off = slot % BLOCK_SIZE
blk = (slot // BLOCK_SIZE).to(tl.int64)
off = (slot % BLOCK_SIZE).to(tl.int64)
head_idx_i64 = tl.cast(head_idx, tl.int64)
slot_base = (
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
blk * stride_cache_block
+ off * stride_cache_pos
+ head_idx_i64 * stride_cache_head
)

base = pid * D
Expand Down Expand Up @@ -259,10 +262,13 @@ def _tq_fused_store_mse(
slot = tl.load(Slot_mapping_ptr + token_idx)
if slot < 0:
return
blk = slot // BLOCK_SIZE
off = slot % BLOCK_SIZE
blk = (slot // BLOCK_SIZE).to(tl.int64)
off = (slot % BLOCK_SIZE).to(tl.int64)
head_idx_i64 = tl.cast(head_idx, tl.int64)
slot_base = (
blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head
blk * stride_cache_block
+ off * stride_cache_pos
+ head_idx_i64 * stride_cache_head
)

base = pid * D
Expand Down
Loading