diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py
index 82526f0e875..c29c1497135 100644
--- a/python/sglang/srt/layers/attention/ascend_backend.py
+++ b/python/sglang/srt/layers/attention/ascend_backend.py
@@ -41,37 +41,6 @@ class ForwardMetadata:


 class AscendAttnBackend(AttentionBackend):
-    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
-        mask_flag = torch.tril(
-            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
-        ).view(max_seq_len, max_seq_len)
-        mask_flag = ~mask_flag
-        if dtype == torch.float16:
-            mask_value = torch.finfo(torch.float32).min
-        else:
-            mask_value = 1
-        self.mask = (
-            torch.masked_fill(
-                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
-            )
-            .to(dtype)
-            .to(self.device)
-        )
-        self.mask_len = max_seq_len
-
-    def get_verify_buffers_to_fill_after_draft(self):
-        """
-        Return buffers for verify attention kernels that needs to be filled after draft.
-
-        Typically, these are tree mask and position buffers.
-        """
-        return [None, None]
-
-    def update_verify_buffers_to_fill_after_draft(
-        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
-    ):
-        pass
-
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
         self.forward_metadata = None
@@ -106,34 +75,103 @@ def __init__(self, model_runner: ModelRunner):
         self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu()
         self.mtp_mask = ~self.mtp_mask

+    def gen_attention_mask(self, max_seq_len: int, dtype=torch.float16):
+        mask_flag = torch.tril(
+            torch.ones((max_seq_len, max_seq_len), dtype=torch.bool)
+        ).view(max_seq_len, max_seq_len)
+        mask_flag = ~mask_flag
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
+        self.mask = (
+            torch.masked_fill(
+                torch.zeros(size=(max_seq_len, max_seq_len)), mask_flag, mask_value
+            )
+            .to(dtype)
+            .to(self.device)
+        )
+        self.mask_len = max_seq_len
+
+    def get_verify_buffers_to_fill_after_draft(self):
+        """
+        Return buffers for verify attention kernels that needs to be filled after draft.
+
+        Typically, these are tree mask and position buffers.
+        """
+        return [None, None]
+
+    def update_verify_buffers_to_fill_after_draft(
+        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
+    ):
+        pass
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
         tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()
-        seq_lens_max = forward_batch.seq_lens.max()
-        if forward_batch.forward_mode.is_target_verify():
-            seq_lens_max += self.speculative_num_draft_tokens
-        self.forward_metadata.block_tables = (
-            forward_batch.req_to_token_pool.req_to_token[
-                forward_batch.req_pool_indices, :seq_lens_max
-            ][:, :: self.page_size]
-            // self.page_size
-        )
-        if forward_batch.extend_seq_lens is not None:
-            self.forward_metadata.extend_seq_lens_cpu_int = (
-                forward_batch.extend_seq_lens.cpu().int()
-            )
-        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-        if (
-            not forward_batch.forward_mode.is_draft_extend_v2()
-            and not forward_batch.forward_mode.is_draft_extend()
-            and not forward_batch.forward_mode.is_target_verify()
-        ):
-            seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
-            self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
-
-        if forward_batch.forward_mode.is_target_verify():
-            self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens
+        if forward_batch.forward_mode.is_mixed():
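+            # In a mixed batch the prefill requests come first and the
+            # `running_decode_bs` running decode requests are appended at the
+            # end, so slicing at `bs_prefill` separates the two stages.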
+ """ + return [None, None] + + def update_verify_buffers_to_fill_after_draft( + self, spec_info: SpecInput, cuda_graph_bs: Optional[int] + ): + pass + def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" tp_size = get_attention_tp_size() self.forward_metadata = ForwardMetadata() - seq_lens_max = forward_batch.seq_lens.max() - if forward_batch.forward_mode.is_target_verify(): - seq_lens_max += self.speculative_num_draft_tokens - self.forward_metadata.block_tables = ( - forward_batch.req_to_token_pool.req_to_token[ - forward_batch.req_pool_indices, :seq_lens_max - ][:, :: self.page_size] - // self.page_size - ) - if forward_batch.extend_seq_lens is not None: - self.forward_metadata.extend_seq_lens_cpu_int = ( - forward_batch.extend_seq_lens.cpu().int() + if forward_batch.forward_mode.is_mixed(): + bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs + seq_lens_max_mix = ( + forward_batch.seq_lens[:bs_prefill].max(), + forward_batch.seq_lens[bs_prefill:].max(), ) - self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() - if ( - not forward_batch.forward_mode.is_draft_extend_v2() - and not forward_batch.forward_mode.is_draft_extend() - and not forward_batch.forward_mode.is_target_verify() - ): - seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu) - self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum - - if forward_batch.forward_mode.is_target_verify(): - self.forward_metadata.seq_lens_cpu_int += self.speculative_num_draft_tokens + req_pool_indices_mix = ( + forward_batch.req_pool_indices[:bs_prefill], + forward_batch.req_pool_indices[bs_prefill:], + ) + self.forward_metadata.block_tables_mix = ( + forward_batch.req_to_token_pool.req_to_token[ + req_pool_indices_mix[0], : seq_lens_max_mix[0] + ][:, :: self.page_size] + // self.page_size, + forward_batch.req_to_token_pool.req_to_token[ + req_pool_indices_mix[1], : seq_lens_max_mix[1] + ][:, :: self.page_size] + // self.page_size, + ) + if forward_batch.extend_seq_lens is not None: + self.forward_metadata.extend_seq_lens_cpu_int_mix = ( + forward_batch.extend_seq_lens.cpu().int()[:bs_prefill], + forward_batch.extend_seq_lens.cpu().int()[bs_prefill:], + ) + self.forward_metadata.seq_lens_cpu_int_mix = ( + forward_batch.seq_lens_cpu.int()[:bs_prefill], + forward_batch.seq_lens_cpu.int()[bs_prefill:], + ) + self.forward_metadata.seq_lens_list_cumsum_mix = np.cumsum( + forward_batch.extend_seq_lens_cpu[:bs_prefill] + ), np.cumsum(forward_batch.extend_seq_lens_cpu[bs_prefill:]) + else: + seq_lens_max = forward_batch.seq_lens.max() + if forward_batch.forward_mode.is_target_verify(): + seq_lens_max += self.speculative_num_draft_tokens + self.forward_metadata.block_tables = ( + forward_batch.req_to_token_pool.req_to_token[ + forward_batch.req_pool_indices, :seq_lens_max + ][:, :: self.page_size] + // self.page_size + ) + if forward_batch.extend_seq_lens is not None: + self.forward_metadata.extend_seq_lens_cpu_int = ( + forward_batch.extend_seq_lens.cpu().int() + ) + self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() + if ( + not forward_batch.forward_mode.is_draft_extend_v2() + and not forward_batch.forward_mode.is_draft_extend() + and not forward_batch.forward_mode.is_target_verify() + ): + seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu) + self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum + + if forward_batch.forward_mode.is_target_verify(): + 
+        out_cache_loc = forward_batch.out_cache_loc
+
+        # Calculates the batch sizes for prefill and decode stages
+        out_cache_len_prefill = len(out_cache_loc) - forward_batch.running_decode_bs
+        bs_prefill = forward_batch.batch_size - forward_batch.running_decode_bs
+
+        # Splits the input tensors into two parts (prefill and decode stages) based on batch sizes
+        q_prefill, q_decode = q[:out_cache_len_prefill], q[out_cache_len_prefill:]
+        k_prefill, k_decode = k[:out_cache_len_prefill], k[out_cache_len_prefill:]
+        v_prefill, v_decode = v[:out_cache_len_prefill], v[out_cache_len_prefill:]
+        loc_prefill, loc_decode = (
+            out_cache_loc[:out_cache_len_prefill],
+            out_cache_loc[out_cache_len_prefill:],
+        )
+
+        forward_batch.out_cache_loc = loc_prefill
+        if not self.use_mla:
+            if self.use_fia:
+                extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
+                forward_batch.extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu[
+                    :bs_prefill
+                ]
+            else:
+                self.forward_metadata.extend_seq_lens_cpu_int = (
+                    self.forward_metadata.extend_seq_lens_cpu_int_mix[0]
+                )
+            self.forward_metadata.block_tables = (
+                self.forward_metadata.block_tables_mix[0]
+            )
+            self.forward_metadata.seq_lens_cpu_int = (
+                self.forward_metadata.seq_lens_cpu_int_mix[0]
+            )
+        else:
+            forward_batch.num_token_non_padded_cpu = forward_batch.prefill_input_ids
+            self.forward_metadata.seq_lens_list_cumsum = (
+                self.forward_metadata.seq_lens_list_cumsum_mix[0]
+            )
+
+        # Performs the forward pass for the prefill stage
+        output_prefill = self.forward_extend(
+            q_prefill,
+            k_prefill,
+            v_prefill,
+            layer,
+            forward_batch,
+            save_kv_cache=save_kv_cache,
+            **kwargs,
+        )
+
+        forward_batch.out_cache_loc = loc_decode
+        self.forward_metadata.block_tables = self.forward_metadata.block_tables_mix[1]
+        self.forward_metadata.seq_lens_cpu_int = (
+            self.forward_metadata.seq_lens_cpu_int_mix[1]
+        )
+
+        # Performs the forward pass for the decode stage
+        output_decode = self.forward_decode(
+            q_decode,
+            k_decode,
+            v_decode,
+            layer,
+            forward_batch,
+            save_kv_cache=save_kv_cache,
+            **kwargs,
+        )
+
+        # Resets forward_metadata and forward_batch properties after processing
+        forward_batch.out_cache_loc = out_cache_loc
+        forward_batch.num_token_non_padded_cpu = None
+        self.forward_metadata.extend_seq_lens_cpu_int = None
+        self.forward_metadata.seq_lens_list_cumsum = None
+        self.forward_metadata.block_tables = None
+        self.forward_metadata.seq_lens_cpu_int = None
+        if not self.use_mla and self.use_fia:
+            forward_batch.extend_seq_lens_cpu = extend_seq_lens_cpu
+
+        # Concatenates and returns the outputs from both parts
+        return torch.cat([output_prefill, output_decode], dim=0)
+

 class AscendAttnMultiStepDraftBackend:
     """
diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py
index dcbf1c8fdf1..b77aa41f96a 100644
--- a/python/sglang/srt/layers/attention/base_attn_backend.py
+++ b/python/sglang/srt/layers/attention/base_attn_backend.py
@@ -97,6 +97,16 @@ def forward(
                 save_kv_cache=save_kv_cache,
                 **kwargs,
             )
+        elif forward_batch.forward_mode.is_mixed():
+            return self.forward_mixed(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache=save_kv_cache,
+                **kwargs,
+            )
         else:
             return self.forward_extend(
                 q,
@@ -132,6 +142,18 @@ def forward_extend(
         """Run a forward for extend."""
         raise NotImplementedError()

+    def forward_mixed(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        """Run a forward for mixed prefill and decode."""
+        raise NotImplementedError()
+
     def support_triton(self):
         """Check if the current backend supports triton."""
         return True
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 11f13383906..f165554f3ac 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1076,6 +1076,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # hicache pointer for synchronizing data loading from CPU to GPU
     hicache_consumer_index: int = -1

+    # For mixed chunk
+    running_decode_bs: int = 0
+    prefill_input_ids: int = 0
+
     @classmethod
     def init_new(
         cls,
@@ -1393,12 +1397,14 @@ def prepare_for_split_prefill(self):
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
         running_bs = running_batch.batch_size()
+        self.running_decode_bs = running_bs

         for req in running_batch.reqs:
             req.fill_ids = req.origin_input_ids + req.output_ids
             req.extend_input_len = 1

         input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        self.prefill_input_ids = self.input_ids
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])

         self.merge_batch(running_batch)
@@ -1791,6 +1797,8 @@ def get_model_worker_batch(
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_ids=[req.lora_id for req in self.reqs],
             sampling_info=self.sampling_info,
+            running_decode_bs=self.running_decode_bs,
+            prefill_input_ids=self.prefill_input_ids,
             input_embeds=self.input_embeds,
             token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
@@ -1900,6 +1908,10 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo

+    # For mixed chunk
+    running_decode_bs: int
+    prefill_input_ids: int
+
     # The original sequence lengths, Qwen-1M related
     orig_seq_lens: Optional[torch.Tensor] = None

diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 9a43121e7ab..1572616338f 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -1,5 +1,7 @@
 from __future__ import annotations

+from sglang.srt.utils.common import is_npu
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -331,14 +333,21 @@ def __init__(
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
-        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
-        if self.rem_chunk_tokens is not None:
-            self.rem_chunk_tokens -= mixed_with_decode_tokens
-
-        self.rem_total_token_offset = mixed_with_decode_tokens
-        self.cur_rem_token_offset = mixed_with_decode_tokens
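+        # On NPU, keep the full input-token budget and start both token offsets
+        # at zero rather than reserving room for the mixed-in decode tokens;
+        # the Ascend backend splits mixed batches itself in forward_mixed.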
+        if is_npu():
+            self.rem_input_tokens = rem_input_tokens
+            self.rem_total_token_offset = 0
+            self.cur_rem_token_offset = 0
+        else:
+            self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
+            if self.rem_chunk_tokens is not None:
+                self.rem_chunk_tokens -= mixed_with_decode_tokens
+            self.rem_total_token_offset = mixed_with_decode_tokens
+            self.cur_rem_token_offset = mixed_with_decode_tokens

         self.req_states = None
         self.can_run_list = []
         self.preempt_list = []
diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py
index 4fefac941aa..7f537beb00e 100644
--- a/python/sglang/srt/mem_cache/allocator.py
+++ b/python/sglang/srt/mem_cache/allocator.py
@@ -558,6 +558,10 @@ def free(self, free_index: torch.Tensor):
                 self.release_pages = torch.cat((free_page_indices, self.release_pages))
             else:
                 self.free_pages = torch.cat((free_page_indices, self.free_pages))
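+                # torch.unique returns the page indices sorted and deduplicated,
+                # so repeated frees of the same page cannot accumulate duplicate
+                # entries in free_pages.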
+                self.free_pages = torch.unique(torch.sort(self.free_pages)[0])
         else:
             self.free_group.append(free_index)

diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 80f8c65648c..fda9e0fb9dc 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -639,6 +639,8 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
             lora_ids=lora_ids,
+            running_decode_bs=None,
+            prefill_input_ids=None,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index df0461804d6..49c51414bb7 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -184,6 +184,10 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

+    # For mixed chunk
+    running_decode_bs: int
+    prefill_input_ids: int
+
     # The original sequence length without being chunked. Qwen-1M related.
     orig_seq_lens: Optional[torch.Tensor] = None

@@ -360,6 +364,8 @@ def init_new(
             token_type_ids=batch.token_type_ids,
             tbo_split_seq_index=batch.tbo_split_seq_index,
             dimensions=batch.dimensions,
+            running_decode_bs=batch.running_decode_bs,
+            prefill_input_ids=batch.prefill_input_ids,
         )

         device = model_runner.device