diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 11a85e924b8f..89714f00f640 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -382,6 +382,7 @@ def _get_backend_priorities( if is_aiter_found_and_supported(): backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN) backends.append(AttentionBackendEnum.TRITON_ATTN) + backends.append(AttentionBackendEnum.TURBOQUANT) return backends diff --git a/vllm/v1/attention/backends/turboquant_attn.py b/vllm/v1/attention/backends/turboquant_attn.py index 279fcb04ace4..2659cbffa824 100644 --- a/vllm/v1/attention/backends/turboquant_attn.py +++ b/vllm/v1/attention/backends/turboquant_attn.py @@ -507,8 +507,7 @@ def _prefill_attention( # max_query_len == max_seq_len means no request has prior cached KV. # Both are Python ints — no GPU sync. if _HAS_FLASH_ATTN and attn_metadata.max_query_len == attn_metadata.max_seq_len: - output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype) - flash_attn_varlen_func( + return flash_attn_varlen_func( q=query, k=key, v=value, @@ -518,9 +517,7 @@ def _prefill_attention( max_seqlen_k=attn_metadata.max_query_len, softmax_scale=self.scale, causal=True, - out=output, ) - return output # Continuation or no flash_attn: per-request attention. # For continuation chunks (seq_len > q_len), we must attend to @@ -557,10 +554,9 @@ def _prefill_attention( if q_len == seq_len: # First-chunk prefill: all K/V are in the current batch. if _HAS_FLASH_ATTN: - out = torch.empty_like(q_seq) _cu_2[1] = q_len cu = _cu_2 - flash_attn_varlen_func( + out = flash_attn_varlen_func( q=q_seq, k=k_seq, v=v_seq, @@ -570,7 +566,6 @@ def _prefill_attention( max_seqlen_k=q_len, softmax_scale=self.scale, causal=True, - out=out, ) else: q_t = q_seq.transpose(0, 1).contiguous() @@ -733,10 +728,9 @@ def _continuation_prefill( # Attention: q_len queries attending to seq_len K/V with causal mask if _HAS_FLASH_ATTN: - output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype) cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32) cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32) - flash_attn_varlen_func( + return flash_attn_varlen_func( q=query, k=k_full, v=v_full, @@ -746,9 +740,7 @@ def _continuation_prefill( max_seqlen_k=seq_len, softmax_scale=self.scale, causal=True, - out=output, ) - return output else: # SDPA fallback: expand KV for GQA, build causal mask q_t = query.transpose(0, 1).unsqueeze(0) # (1, Hq, q_len, D) diff --git a/vllm/v1/attention/ops/triton_turboquant_decode.py b/vllm/v1/attention/ops/triton_turboquant_decode.py index 8b276e31eafb..fbedddea8fac 100644 --- a/vllm/v1/attention/ops/triton_turboquant_decode.py +++ b/vllm/v1/attention/ops/triton_turboquant_decode.py @@ -137,12 +137,12 @@ def _tq_decode_stage1( Block_table_ptr + bt_base + page_idx, mask=kv_mask, other=0, - ) + ).to(tl.int64) slot_bases = ( block_nums * stride_cache_block - + page_off * stride_cache_pos - + kv_head * stride_cache_head + + page_off.to(tl.int64) * stride_cache_pos + + tl.cast(kv_head, tl.int64) * stride_cache_head ) # ============================================================ @@ -350,11 +350,11 @@ def _tq_full_dequant_kv( page_idx = pos // BLOCK_SIZE page_off = pos % BLOCK_SIZE - block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx) + block_num = tl.load(Block_table_ptr + bid * stride_bt_b + page_idx).to(tl.int64) slot_base = ( block_num * stride_cache_block - + page_off * stride_cache_pos - + hid * stride_cache_head + + tl.cast(page_off, tl.int64) * stride_cache_pos + + tl.cast(hid, tl.int64) * stride_cache_head ) d_offs = tl.arange(0, BLOCK_D) diff --git a/vllm/v1/attention/ops/triton_turboquant_store.py b/vllm/v1/attention/ops/triton_turboquant_store.py index 3da3347d5df5..3ad2d41488e7 100644 --- a/vllm/v1/attention/ops/triton_turboquant_store.py +++ b/vllm/v1/attention/ops/triton_turboquant_store.py @@ -174,10 +174,13 @@ def _tq_fused_store_fp8( slot = tl.load(Slot_mapping_ptr + token_idx) if slot < 0: return - blk = slot // BLOCK_SIZE - off = slot % BLOCK_SIZE + blk = (slot // BLOCK_SIZE).to(tl.int64) + off = (slot % BLOCK_SIZE).to(tl.int64) + head_idx_i64 = tl.cast(head_idx, tl.int64) slot_base = ( - blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head + blk * stride_cache_block + + off * stride_cache_pos + + head_idx_i64 * stride_cache_head ) base = pid * D @@ -259,10 +262,13 @@ def _tq_fused_store_mse( slot = tl.load(Slot_mapping_ptr + token_idx) if slot < 0: return - blk = slot // BLOCK_SIZE - off = slot % BLOCK_SIZE + blk = (slot // BLOCK_SIZE).to(tl.int64) + off = (slot % BLOCK_SIZE).to(tl.int64) + head_idx_i64 = tl.cast(head_idx, tl.int64) slot_base = ( - blk * stride_cache_block + off * stride_cache_pos + head_idx * stride_cache_head + blk * stride_cache_block + + off * stride_cache_pos + + head_idx_i64 * stride_cache_head ) base = pid * D