diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 226bc436058d..d124e30a3f67 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -454,6 +454,16 @@ def __init__(self,
             parallel_config)
         self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
+
+        self.speculative_config = vllm_config.speculative_config
+        # Set reorder_batch_threshold based on the speculative config.
+        if (self.speculative_config is not None and
+                self.speculative_config.num_speculative_tokens is not None):
+            self.reorder_batch_threshold = (  # type: ignore[misc]
+                1 + self.speculative_config.num_speculative_tokens)
+        else:
+            self.reorder_batch_threshold = 1  # type: ignore[misc]
+
         try:
             self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
@@ -662,9 +672,10 @@ def build(self,
         num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
                                    query_seq_lens_cpu)
+
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
-            split_decodes_and_prefills(common_attn_metadata,
-                decode_threshold=self.reorder_batch_threshold)
+            split_decodes_and_prefills(common_attn_metadata,
+                self.reorder_batch_threshold, require_uniform=True)
 
         # Note(hc): update seq_lens of decode reqs under DCP.
         if self.dcp_world_size > 1:
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 2f13f19218d9..9c33e2217a16 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -184,8 +184,23 @@ def _forward_decode(
             q = torch.cat(q, dim=-1)
         assert isinstance(q, torch.Tensor)
+
+        batch_size = attn_metadata.decode.seq_lens.shape[0]
+        total_tokens = q.shape[0]
+        num_heads = q.shape[1]
+        head_dim = q.shape[2]
+
+        # Reshape the flat query into a uniform (batch_size, seq_len) layout.
+        if total_tokens % batch_size == 0:
+            seq_len = total_tokens // batch_size
+            q = q.view(batch_size, seq_len, num_heads, head_dim)
+        else:
+            raise ValueError(
+                f"total_tokens={total_tokens} is not divisible by "
+                f"batch_size={batch_size}; expected a uniform decode batch.")
+
         o, lse = flash_mla_with_kvcache(
-            q=q.unsqueeze(1),  # Add seqlen dim of 1 (decode)
+            q=q,
             k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             block_table=attn_metadata.decode.block_table,
             cache_seqlens=attn_metadata.decode.seq_lens,
@@ -199,4 +214,6 @@ def _forward_decode(
             descale_k=layer._k_scale.reshape(1),
         )
 
+        o = o.view(total_tokens, num_heads, self.kv_lora_rank)
+
         return o, lse
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index b286a4ba9fe5..bf3f6da322c1 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -645,35 +645,41 @@ def subclass_attention_backend(
 def split_decodes_and_prefills(
     common_attn_metadata: CommonAttentionMetadata,
     decode_threshold: int = 1,
+    require_uniform: bool = False,
 ) -> tuple[int, int, int, int]:
     """
     Assuming a reordered batch, finds the boundary between prefill and decode
     requests.
-
     Args:
         common_attn_metadata: CommonAttentionMetadata object containing the
             batch metadata.
         decode_threshold: The maximum query length to be considered a decode.
-
+        require_uniform: If True, only selects decode requests with the same
+            query length for uniform batching.
+            If False, selects all decode requests regardless of length
+            variation.
+
     Returns:
         num_decodes: The number of decode requests.
         num_prefills: The number of prefill requests.
         num_decode_tokens: The number of tokens in the decode requests.
         num_prefill_tokens: The number of tokens in the prefill requests.
     """
+
+    if require_uniform:
+        return split_decodes_and_prefills_uniform(common_attn_metadata,
+                                                  decode_threshold)
+
     max_query_len = common_attn_metadata.max_query_len
     num_reqs = common_attn_metadata.num_reqs
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
-
     if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0
-
     query_lens = query_start_loc[1:] - query_start_loc[:-1]
     is_prefill = query_lens > decode_threshold
     if not torch.any(is_prefill):
         return num_reqs, 0, num_tokens, 0
-
     first_prefill = is_prefill.int().argmax(dim=-1).item()
     assert torch.all(query_lens[first_prefill:] > decode_threshold)
     assert torch.all(query_lens[:first_prefill] <= decode_threshold)
@@ -684,6 +690,53 @@ def split_decodes_and_prefills(
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
 
 
+def split_decodes_and_prefills_uniform(
+    common_attn_metadata: CommonAttentionMetadata,
+    decode_threshold: int = 1,
+) -> tuple[int, int, int, int]:
+    """
+    Similar to split_decodes_and_prefills but ensures the decode batch is
+    uniform: only decode requests with the same query length are selected.
+    """
+    num_reqs = common_attn_metadata.num_reqs
+    num_tokens = common_attn_metadata.num_actual_tokens
+    query_start_loc = common_attn_metadata.query_start_loc_cpu
+    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    # Find all requests short enough to be treated as decodes.
+    decode_candidates = query_lens <= decode_threshold
+
+    if not torch.any(decode_candidates):
+        return 0, num_reqs, 0, num_tokens
+
+    first_len = None
+    first_prefill = 0
+
+    # Find the longest run of equal-length decode requests at the front.
+    for i in range(num_reqs):
+        current_len = query_lens[i].item()
+        if current_len > decode_threshold:
+            # Prefill request, stop.
+            break
+        if first_len is None:
+            # The first decode request sets the uniform length.
+            first_len = current_len
+            first_prefill = 1
+        elif current_len == first_len:
+            # Same length, extend the decode run.
+            first_prefill = i + 1
+        else:
+            # Different length, stop.
+            break
+
+    num_decodes = first_prefill
+    num_prefills = num_reqs - num_decodes
+    num_decode_tokens = query_start_loc[first_prefill].item(
+    ) if first_prefill < len(query_start_loc) else num_tokens
+    num_prefill_tokens = num_tokens - num_decode_tokens
+
+    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
+
+
 def reorder_batch_to_split_decodes_and_prefills(
     input_batch: "InputBatch",
     scheduler_output: "SchedulerOutput",
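The `require_uniform=True` path above only keeps the leading run of decode requests that share the same query length, so a speculative-decoding batch stays rectangular for the FlashMLA decode kernel. Below is a minimal standalone sketch of that selection logic; it mirrors `split_decodes_and_prefills_uniform` but works directly on a tensor of per-request query lengths, and the helper name `split_uniform_sketch` and the example batches are illustrative only, not part of the diff:

```python
# Standalone sketch of the uniform decode/prefill split (not the vLLM API).
import torch


def split_uniform_sketch(query_lens: torch.Tensor, decode_threshold: int):
    """Return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)."""
    num_reqs = query_lens.numel()
    num_tokens = int(query_lens.sum())
    num_decodes = 0
    first_len = None
    for qlen in query_lens.tolist():
        if qlen > decode_threshold:
            break  # prefill request, stop
        if first_len is None:
            first_len = qlen  # the first decode fixes the uniform length
        elif qlen != first_len:
            break  # length changed, stop so the decode batch stays uniform
        num_decodes += 1
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    return (num_decodes, num_reqs - num_decodes, num_decode_tokens,
            num_tokens - num_decode_tokens)


# With one speculative token (decode_threshold = 1 + 1 = 2), a reordered batch
# of query lengths [2, 2, 2, 1, 7] keeps only the three uniform 2-token
# decodes; the lone 1-token request falls onto the prefill side of the split.
assert split_uniform_sketch(torch.tensor([2, 2, 2, 1, 7]), 2) == (3, 2, 6, 8)
# Without speculative tokens (threshold 1) this matches the existing behavior.
assert split_uniform_sketch(torch.tensor([1, 1, 5]), 1) == (2, 1, 2, 5)
```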
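On the FlashMLA side, a uniform decode batch lets `_forward_decode` view the flattened `(total_tokens, num_heads, head_dim)` query as `(batch_size, seq_len, num_heads, head_dim)` for `flash_mla_with_kvcache`, instead of unsqueezing a sequence length of 1, and then flatten the output back. The snippet below is a self-contained sketch of just that reshaping; the tensor sizes are placeholders chosen for illustration, not values taken from the diff:

```python
import torch

# Illustrative sizes only: 4 requests, 2 query tokens each (1 verified + 1 draft).
batch_size, seq_len, num_heads, head_dim, kv_lora_rank = 4, 2, 16, 576, 512
q = torch.randn(batch_size * seq_len, num_heads, head_dim)

total_tokens = q.shape[0]
assert total_tokens % batch_size == 0, "decode batch must be uniform"
# Flattened token-major query -> batched (batch, seq_len, heads, head_dim).
q_batched = q.view(batch_size, total_tokens // batch_size, num_heads, head_dim)

# Stand-in for the kernel output: (batch, seq_len, heads, kv_lora_rank).
o = torch.randn(batch_size, seq_len, num_heads, kv_lora_rank)
# Flatten back to the token-major layout the rest of the MLA path expects.
o_flat = o.view(total_tokens, num_heads, kv_lora_rank)
assert o_flat.shape == (total_tokens, num_heads, kv_lora_rank)
```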