86 changes: 86 additions & 0 deletions vllm/_xpu_ops.py
@@ -429,6 +429,92 @@ def cp_gather_indexer_k_quant_cache(
        )
        dst_scale[:] = kv_cache_flat[scale_indices]

    @staticmethod
    def fp8_mqa_logits_torch(
        q: torch.Tensor,
        kv: tuple[torch.Tensor, torch.Tensor],
        weights: torch.Tensor,
        cu_seqlen_ks: torch.Tensor,
        cu_seqlen_ke: torch.Tensor,
    ) -> torch.Tensor:
        kv_fp8, scale = kv
        seq_len_kv = kv_fp8.shape[0]
        k = kv_fp8.to(torch.bfloat16)
        q = q.to(torch.bfloat16)

        mask_lo = (
            torch.arange(0, seq_len_kv, device=q.device)[None, :]
            >= cu_seqlen_ks[:, None]
        )
        mask_hi = (
            torch.arange(0, seq_len_kv, device=q.device)[None, :]
            < cu_seqlen_ke[:, None]
        )
        mask = mask_lo & mask_hi

        score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
        logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
Comment on lines +455 to +456

critical

The use of torch.einsum to compute a [H, M, N] tensor poses a critical risk of out-of-memory (OOM) errors for large sequence lengths. For instance, with 128 heads and a 1GB logits budget, this intermediate tensor would require 128GB of memory. It is highly recommended to iterate over the heads and accumulate the results to keep memory usage bounded by [M, N].

Suggested change
-        score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
-        logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
+        logits = torch.zeros((q.shape[0], seq_len_kv),
+                             device=q.device,
+                             dtype=torch.float32)
+        for h in range(q.shape[1]):
+            score_h = (q[:, h, :] @ k.T).float() * scale
+            logits += score_h.relu() * weights[:, h, None]

        logits = logits.masked_fill(~mask, float("-inf"))
        return logits

    @staticmethod
    def fp8_paged_mqa_logits_torch(
        q: torch.Tensor,
        kv_cache: torch.Tensor,
        weights: torch.Tensor,
        context_lens: torch.Tensor,
        block_tables: torch.Tensor,
        max_model_len: int,
    ) -> torch.Tensor:
        from vllm.utils.math_utils import cdiv

        fp8_dtype = current_platform.fp8_dtype()
        batch_size, next_n, _, dim = q.size()
        kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
        scale = scale.contiguous().view(torch.float)
        q = q.float()
        kv_cache = kv_cache.view(fp8_dtype).float() * scale

high

Dequantizing the entire kv_cache block pool into float32 is extremely memory-intensive and inefficient, especially for large context windows where the block pool can be very large. This allocation happens on every call and can lead to OOM or significant performance degradation. It is better to dequantize only the specific blocks required for the current batch inside the loop.
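
For illustration, a minimal sketch of that approach (the standalone form, the function name, and the import paths are illustrative, not part of this PR; the packed cache layout with an fp8 payload followed by scale bytes in the last dimension is read off the code above):

import torch

from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv


def fp8_paged_mqa_logits_torch_blockwise(  # hypothetical name, sketch only
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    # Keep the block pool packed; take views only, no pool-wide dequantization.
    kv_raw, scale_raw = kv_cache[..., :dim], kv_cache[..., dim:]
    q = q.float()
    _, block_size, _, _ = kv_raw.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    for i in range(batch_size):
        context_len = context_lens[i].item()
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_idx in range(cdiv(context_len, block_size)):
            block_id = block_tables[i][block_idx]
            # Dequantize only the block referenced by this batch element.
            block_scale = scale_raw[block_id].contiguous().view(torch.float)
            kx = kv_raw[block_id].contiguous().view(fp8_dtype).float() * block_scale
            qx = q[i]
            k_offsets = torch.arange(
                block_idx * block_size,
                (block_idx + 1) * block_size,
                device=q.device,
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_idx * block_size : (block_idx + 1) * block_size,
            ] = torch.where(
                k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")
            )
    return logits

Peak temporary memory is then bounded by one dequantized block plus a per-block score tile, instead of the fully dequantized pool.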

        _, block_size, _, _ = kv_cache.size()
        logits = torch.full(
            [batch_size * next_n, max_model_len],
            float("-inf"),
            device=q.device,
            dtype=torch.float32,
        )
        for i in range(batch_size):
            context_len = context_lens[i].item()
            q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)
            weight_slice = (
                weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
            )
            for block_idx in range(cdiv(context_len, block_size)):
                block_id = block_tables[i][block_idx]
                qx, kx = q[i], kv_cache[block_id]
                k_offsets = torch.arange(
                    block_idx * block_size,
                    (block_idx + 1) * block_size,
                    device=q.device,
                )
                mask = (k_offsets[None, :] < context_len) & (
                    k_offsets[None, :] <= q_offsets[:, None]
                )
                s = torch.where(
                    mask[None, :, :],
                    (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                        logits.dtype
                    ),
                    float("-inf"),
                )
                s = torch.relu(s) * weight_slice[..., None]
                s = s.sum(dim=0)
                logits[
                    i * next_n : (i + 1) * next_n,
                    block_idx * block_size : (block_idx + 1) * block_size,
                ] = torch.where(
                    k_offsets[None, :] <= q_offsets[:, None], s, float("-inf")
                )
        return logits

    @staticmethod
    def top_k_per_row_prefill(
        logits: torch.Tensor,
55 changes: 37 additions & 18 deletions vllm/model_executor/layers/sparse_attn_indexer.py
@@ -120,14 +120,23 @@ def sparse_attn_indexer(
                chunk.cu_seq_lens,
            )

-            logits = fp8_mqa_logits(
-                q_fp8[chunk.token_start : chunk.token_end],
-                (k_fp8, k_scale.view(torch.float32).flatten()),
-                weights[chunk.token_start : chunk.token_end],
-                chunk.cu_seqlen_ks,
-                chunk.cu_seqlen_ke,
-                clean_logits=False,
-            )
+            if current_platform.is_xpu():
+                logits = ops.fp8_mqa_logits_torch(
+                    q_fp8[chunk.token_start : chunk.token_end],
+                    (k_fp8, k_scale.view(torch.float32).flatten()),
+                    weights[chunk.token_start : chunk.token_end],
+                    chunk.cu_seqlen_ks,
+                    chunk.cu_seqlen_ke,
+                )
+            else:
+                logits = fp8_mqa_logits(
+                    q_fp8[chunk.token_start : chunk.token_end],
+                    (k_fp8, k_scale.view(torch.float32).flatten()),
+                    weights[chunk.token_start : chunk.token_end],
+                    chunk.cu_seqlen_ks,
+                    chunk.cu_seqlen_ke,
+                    clean_logits=False,
+                )
            num_rows = logits.shape[0]

            topk_indices = topk_indices_buffer[
@@ -191,16 +200,26 @@ def sparse_attn_indexer(
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
-        logits = fp8_paged_mqa_logits(
-            padded_q_fp8_decode_tokens,
-            kv_cache,
-            weights[:num_padded_tokens],
-            decode_metadata.seq_lens,
-            decode_metadata.block_table,
-            decode_metadata.schedule_metadata,
-            max_model_len=max_model_len,
-            clean_logits=False,
-        )
+        if current_platform.is_xpu():
+            logits = ops.fp8_paged_mqa_logits_torch(
+                padded_q_fp8_decode_tokens,
+                kv_cache,
+                weights[:num_padded_tokens],
+                decode_metadata.seq_lens,
+                decode_metadata.block_table,
+                max_model_len=max_model_len,
+            )
+        else:
+            logits = fp8_paged_mqa_logits(
+                padded_q_fp8_decode_tokens,
+                kv_cache,
+                weights[:num_padded_tokens],
+                decode_metadata.seq_lens,
+                decode_metadata.block_table,
+                decode_metadata.schedule_metadata,
+                max_model_len=max_model_len,
+                clean_logits=False,
+            )
        num_rows = logits.shape[0]
        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
