diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
index 34229658598b..cd57120fc027 100644
--- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
+++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
@@ -5,7 +5,9 @@
 from typing import TYPE_CHECKING
 
 import torch
+import torch.nn.functional as F
 
+from sglang.srt.environ import envs
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
@@ -162,6 +164,16 @@ def process_prebuilt(
         hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
         hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
 
+        enable_spec_v2_zero_bubble = envs.SGLANG_SPEC_V2_ZERO_BUBBLE.get()
+
+        if enable_spec_v2_zero_bubble and server_args.speculative_num_steps > 1:
+            topk_pad_size = (
+                server_args.speculative_num_steps * num_states - topk_p.shape[-1]
+            )
+
+            topk_p = F.pad(topk_p, (0, topk_pad_size))
+            topk_index = F.pad(topk_index, (0, topk_pad_size))
+
         # local import to avoid circular import
         from sglang.srt.speculative.eagle_info import EagleDraftInput
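Reviewer note on the padding above: with zero bubble enabled, `topk_p`/`topk_index` are right-padded to a fixed width of `speculative_num_steps * topk` so downstream consumers (and captured CUDA graphs) always see one static shape. A minimal, self-contained sketch of the same `F.pad` call with illustrative sizes (the names here are placeholders, not sglang API):

```python
import torch
import torch.nn.functional as F

topk, num_steps, batch_size = 4, 3, 2
topk_p = torch.rand(batch_size, topk)                    # last dim currently = topk
topk_index = torch.randint(0, 32000, (batch_size, topk))

pad = num_steps * topk - topk_p.shape[-1]                # grow last dim to num_steps * topk
topk_p = F.pad(topk_p, (0, pad))                         # (left, right) padding; zeros on the right
topk_index = F.pad(topk_index, (0, pad))
assert topk_p.shape == (batch_size, num_steps * topk)
```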
diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index 797502679b47..7e9423fbc6cf 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -449,6 +449,7 @@ class Envs:
     SGLANG_SPEC_ENABLE_STRICT_FILTER_CHECK = EnvBool(True)
     SGLANG_SPEC_NAN_DETECTION = EnvBool(False)
     SGLANG_SPEC_OOB_DETECTION = EnvBool(False)
+    SGLANG_SPEC_V2_ZERO_BUBBLE = EnvBool(False)
 
     # VLM
     SGLANG_VLM_CACHE_SIZE_MB = EnvInt(100)
diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py
index dc40a0fb8843..c15a68ef51df 100644
--- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py
+++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py
@@ -179,10 +179,10 @@ def replay(
         # Replay
         if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
             if forward_batch.forward_mode.is_target_verify():
-                seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
+                seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_tokens_per_bs
                 seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
             else:
-                seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                seq_lens = forward_batch.seq_lens_cpu.tolist() + [0] * (
                     self.bs - self.raw_bs
                 )
             thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 831b3b6a003e..e847b19cc4ad 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -63,7 +63,7 @@
     is_npu,
     support_triton,
 )
-from sglang.srt.utils.common import ceil_align
+from sglang.srt.utils.common import ceil_align, is_pin_memory_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -486,6 +486,7 @@ def init_new(
             rids=[req.rid for req in batch.reqs],
         )
         device = model_runner.device
+        _pin = is_pin_memory_available(device)
 
         if batch.extend_input_logprob_token_ids is not None:
             ret.extend_input_logprob_token_ids_gpu = (
@@ -494,9 +495,9 @@ def init_new(
 
         num_tokens = len(batch.input_ids) if batch.input_ids is not None else 0
         if enable_num_token_non_padded(model_runner.server_args):
-            ret.num_token_non_padded = torch.tensor(num_tokens, dtype=torch.int32).to(
-                device, non_blocking=True
-            )
+            ret.num_token_non_padded = torch.tensor(
+                num_tokens, dtype=torch.int32, pin_memory=_pin
+            ).to(device, non_blocking=True)
             ret.num_token_non_padded_cpu = num_tokens
 
         # For MLP sync
@@ -516,15 +517,18 @@ def init_new(
             ret.original_global_num_tokens_cpu = batch.global_num_tokens
             ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64, pin_memory=_pin
            ).to(device, non_blocking=True)
 
             ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64, pin_memory=_pin
             ).to(device, non_blocking=True)
 
         if ret.forward_mode.is_idle():
+            if _is_npu:
+                # This synchronization is necessary to prevent the system from hanging on NPU.
+                torch.npu.synchronize()
             ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
             return ret
 
@@ -540,6 +544,7 @@ def init_new(
                     for i in range(block_offset, block_offset + block_size)
                 ],
                 dtype=positions_dtype,
+                pin_memory=_pin,
             ).to(device, non_blocking=True)
         elif (
             ret.spec_info is not None
@@ -555,10 +560,10 @@ def init_new(
             assert isinstance(batch.extend_seq_lens, list)
             assert isinstance(batch.extend_prefix_lens, list)
             ret.extend_seq_lens = torch.tensor(
-                batch.extend_seq_lens, dtype=torch.int32
+                batch.extend_seq_lens, dtype=torch.int32, pin_memory=_pin
             ).to(device, non_blocking=True)
             ret.extend_prefix_lens = torch.tensor(
-                batch.extend_prefix_lens, dtype=torch.int32
+                batch.extend_prefix_lens, dtype=torch.int32, pin_memory=_pin
             ).to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position(
@@ -761,6 +766,7 @@ def _compute_mrope_positions(
         # batch_size * [3 * seq_len]
         batch_size = self.seq_lens_cpu.shape[0]
         mrope_positions_list = [[]] * batch_size
+        _pin = is_pin_memory_available(model_runner.device)
         for batch_idx in range(batch_size):
             mm_input = batch.multimodal_inputs[batch_idx]
             if self.forward_mode.is_decode():
@@ -812,10 +818,20 @@ def _compute_mrope_positions(
                 )
                 mrope_positions_list[batch_idx] = mrope_positions
 
-        self.mrope_positions = torch.cat(
-            [pos for pos in mrope_positions_list],
-            dim=1,
-        ).to(dtype=torch.int64, device=model_runner.device, non_blocking=True)
+        if _pin:
+            self.mrope_positions = (
+                torch.cat(
+                    [pos for pos in mrope_positions_list],
+                    dim=1,
+                )
+                .pin_memory()
+                .to(dtype=torch.int64, device=model_runner.device, non_blocking=True)
+            )
+        else:
+            self.mrope_positions = torch.cat(
+                [pos for pos in mrope_positions_list],
+                dim=1,
+            ).to(dtype=torch.int64, device=model_runner.device, non_blocking=True)
 
     def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
         if value == 0:
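The `pin_memory=_pin` changes above all follow one pattern: a host tensor created in pinned (page-locked) memory can be copied to the device with a genuinely asynchronous `non_blocking=True`, whereas a pageable tensor makes that copy silently synchronous. A hedged sketch of the idiom, assuming a CUDA host (not sglang code):

```python
import torch

pin = torch.cuda.is_available()  # stand-in for is_pin_memory_available(device)
host = torch.tensor([5, 7, 9], dtype=torch.int64, pin_memory=pin)
if pin:
    dev = host.to("cuda", non_blocking=True)  # true async H2D copy from pinned memory
```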
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index bde3c135e8e5..1a86e70afc8c 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -348,6 +348,7 @@ def __init__(
         self.init_new_workspace = False
         self.draft_model_idx = draft_model_idx
         self.enable_hisparse = server_args.enable_hisparse
+        self.enable_spec_v2_zero_bubble = envs.SGLANG_SPEC_V2_ZERO_BUBBLE.get()
 
         self.remote_instance_transfer_engine = None
         self.remote_instance_transfer_engine_session_id = ""
@@ -2920,6 +2921,7 @@ def _forward_raw(
             and forward_batch.global_num_tokens_gpu is not None
             and require_gathered_buffer(self.server_args)
             and not is_nsa_enable_prefill_cp()
+            and not self.enable_spec_v2_zero_bubble
         ):
             forward_batch.adjust_num_token_non_padded_for_attn_tp(
                 server_args=self.server_args,
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 40e859b2d6d6..fa410621a237 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -6,6 +6,7 @@
 
 import torch
 
+from sglang.srt.environ import envs
 from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
@@ -79,6 +80,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
         )
         self.enable_pdmux = False
         self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
+        self.enable_spec_v2_zero_bubble = envs.SGLANG_SPEC_V2_ZERO_BUBBLE.get()
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
@@ -329,7 +331,13 @@ def run_once():
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
 
-            ret = self.eagle_worker.draft_forward(forward_batch)
+            if self.enable_spec_v2_zero_bubble:
+                assert hasattr(
+                    self.eagle_worker, "draft_forward_zero_bubble"
+                ), "`Spec v2 zero bubble` is currently only supported when both the overlap scheduler and the EAGLE algorithm are enabled"
+                ret = self.eagle_worker.draft_forward_zero_bubble(forward_batch)
+            else:
+                ret = self.eagle_worker.draft_forward(forward_batch)
 
             forward_batch.out_cache_loc = output_cache_loc_backup
             forward_batch.spec_info.hidden_states = hidden_states_backup
@@ -348,6 +356,10 @@ def run_once():
 
     def _postprocess_output_to_raw_bs(self, out, raw_bs):
         # Keep the variables name for readability
+        if self.enable_spec_v2_zero_bubble:
+            ret_topk_p_list, ret_topk_index_list = (t[:raw_bs] for t in out)
+            return ret_topk_p_list, ret_topk_index_list
+
         parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out)
         return parent_list, top_scores_index, draft_tokens
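For context on `_postprocess_output_to_raw_bs`: CUDA graphs replay at a padded batch size, so the zero-bubble branch returns the two stacked top-k tensors and slices them back to the live batch. A toy version with made-up sizes:

```python
import torch

raw_bs, padded_bs, width = 3, 8, 4
out = (torch.rand(padded_bs, width), torch.randint(0, 100, (padded_bs, width)))
topk_p, topk_index = (t[:raw_bs] for t in out)  # drop the padded rows
```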
diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py
index 86f72527f0d3..3c558cc9761a 100644
--- a/python/sglang/srt/speculative/eagle_worker_v2.py
+++ b/python/sglang/srt/speculative/eagle_worker_v2.py
@@ -4,6 +4,7 @@
 from typing import List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 
 from sglang.srt.environ import envs
 from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_extend_npu_graph_runner import (
@@ -27,7 +28,11 @@
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.base_spec_worker import BaseDraftWorker, BaseSpecWorker
 from sglang.srt.speculative.draft_utils import DraftBackendFactory
@@ -115,6 +120,7 @@ def __init__(
         self.speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.enable_spec_v2_zero_bubble = envs.SGLANG_SPEC_V2_ZERO_BUBBLE.get()
 
         # Do not capture cuda graph in `TpModelWorker` init,
         # will capture later with init_cuda_graphs()
@@ -317,7 +323,76 @@ def init_cuda_graphs(self):
             f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
-    def draft(self, model_worker_batch: ModelWorkerBatch):
+    @staticmethod
+    def draft_wrapper(draft_func):
+        def wrapper(self, model_worker_batch: ModelWorkerBatch):
+            parent_list, top_scores_index, draft_tokens, verified_id = draft_func(
+                self, model_worker_batch
+            )
+
+            if model_worker_batch.forward_mode.is_idle():
+                return EagleVerifyInput.create_idle_input(
+                    self.topk,
+                    self.speculative_num_steps,
+                    self.speculative_num_draft_tokens,
+                )
+
+            # Build tree mask
+            # Directly write to cuda graph buffers for verify attn
+            tree_mask_buf, position_buf = (
+                self.target_worker.model_runner.attn_backend.get_verify_buffers_to_fill_after_draft()
+            )
+
+            (
+                tree_mask,
+                position,
+                retrive_index,
+                retrive_next_token,
+                retrive_next_sibling,
+                draft_tokens,
+            ) = build_tree_kernel_efficient(
+                verified_id,
+                parent_list,
+                top_scores_index,
+                draft_tokens,
+                model_worker_batch.seq_lens,
+                model_worker_batch.seq_lens_sum,
+                self.topk,
+                self.speculative_num_steps,
+                self.speculative_num_draft_tokens,
+                self.tree_mask_mode,
+                tree_mask_buf,
+                position_buf,
+            )
+
+            return EagleVerifyInput(
+                draft_token=draft_tokens,
+                custom_mask=tree_mask,
+                positions=position,
+                retrive_index=retrive_index,
+                retrive_next_token=retrive_next_token,
+                retrive_next_sibling=retrive_next_sibling,
+                retrive_cum_len=None,
+                spec_steps=self.speculative_num_steps,
+                topk=self.topk,
+                draft_token_num=self.speculative_num_draft_tokens,
+                capture_hidden_mode=None,
+                seq_lens_sum=None,
+                seq_lens_cpu=None,
+            )
+
+        return wrapper
+
+    @draft_wrapper
+    def prepare_verify_fully_async_decoding(self, model_worker_batch):
+        draft_input = model_worker_batch.spec_info
+        parent_list, top_scores_index, draft_tokens = self.draft_forward_for_prepare(
+            draft_input
+        )
+        return parent_list, top_scores_index, draft_tokens, draft_input.verified_id
+
+    @draft_wrapper
+    def draft(self, model_worker_batch):
         draft_input: EagleDraftInput = model_worker_batch.spec_info
         forward_batch, can_cuda_graph = draft_input.prepare_for_v2_draft(
             self.req_to_token_pool,
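The `draft_wrapper` above factors the shared tail (idle short-circuit, tree-mask build, `EagleVerifyInput` construction) out of two entry points that differ only in how they obtain the draft tuple. A stripped-down model of the pattern with placeholder names; note that calling a `staticmethod` as a decorator inside its own class body requires Python 3.10+:

```python
class DraftWorkerSketch:
    @staticmethod
    def draft_wrapper(draft_func):
        def wrapper(self, batch):
            parents, indices, tokens, verified = draft_func(self, batch)
            # shared post-processing (tree mask + verify input in the real code)
            return ("verify_input", parents, indices, tokens, verified)
        return wrapper

    @draft_wrapper  # staticmethod-as-decorator needs Python >= 3.10
    def draft(self, batch):
        return 1, 2, 3, 4  # stand-ins for the real tensors

print(DraftWorkerSketch().draft(None))
```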
@@ -345,57 +420,119 @@ def draft(self, model_worker_batch: ModelWorkerBatch):
                 forward_batch
             )
 
-        if model_worker_batch.forward_mode.is_idle():
-            return EagleVerifyInput.create_idle_input(
-                self.topk,
-                self.speculative_num_steps,
-                self.speculative_num_draft_tokens,
-            )
+        return parent_list, top_scores_index, draft_tokens, draft_input.verified_id
+
+    def draft_zero_bubble(
+        self,
+        model_worker_batch: ModelWorkerBatch,
+        batch_result: GenerationBatchResult,
+        draft_input: EagleDraftInput,
+    ):
+        if self.speculative_num_steps <= 1:
+            return
 
-        # Build tree mask
-        # Directly write to cuda graph buffers for verify attn
-        tree_mask_buf, position_buf = (
-            self.target_worker.model_runner.attn_backend.get_verify_buffers_to_fill_after_draft()
+        model_worker_batch.forward_mode = (
+            ForwardMode.IDLE
+            if model_worker_batch.forward_mode.is_idle()
+            else ForwardMode.DECODE
         )
+        model_worker_batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+        # To ensure an accurate acceptance length, seq_lens_cpu synchronization would be needed here.
+        # However, that synchronization contradicts the intent and benefit of spec_v2_zero_bubble.
+        # As a result, spec_v2_zero_bubble is best suited to architectures like DeepSeek-V3.2
+        # that don't need seq_lens_cpu. For models dependent on seq_lens_cpu,
+        # skipping this synchronization might affect the acceptance length.
+        # model_worker_batch.seq_lens_cpu = model_worker_batch.seq_lens.to("cpu")
 
-        (
-            tree_mask,
-            position,
-            retrive_index,
-            retrive_next_token,
-            retrive_next_sibling,
-            draft_tokens,
-        ) = build_tree_kernel_efficient(
-            draft_input.verified_id,
-            parent_list,
-            top_scores_index,
-            draft_tokens,
-            model_worker_batch.seq_lens,
-            model_worker_batch.seq_lens_sum,
+        forward_batch, can_cuda_graph = draft_input.prepare_for_v2_draft(
+            self.req_to_token_pool,
+            model_worker_batch,
+            self.cuda_graph_runner,
+            self.draft_runner,
             self.topk,
             self.speculative_num_steps,
-            self.speculative_num_draft_tokens,
-            self.tree_mask_mode,
-            tree_mask_buf,
-            position_buf,
         )
 
-        return EagleVerifyInput(
-            draft_token=draft_tokens,
-            custom_mask=tree_mask,
-            positions=position,
-            retrive_index=retrive_index,
-            retrive_next_token=retrive_next_token,
-            retrive_next_sibling=retrive_next_sibling,
-            retrive_cum_len=None,
-            spec_steps=self.speculative_num_steps,
-            topk=self.topk,
-            draft_token_num=self.speculative_num_draft_tokens,
-            capture_hidden_mode=None,
-            seq_lens_sum=None,
-            seq_lens_cpu=None,
+        forward_batch.spec_info.hidden_states = (
+            batch_result.next_draft_input.hidden_states
+        )
+        forward_batch.spec_info.topk_p = batch_result.next_draft_input.topk_p
+        forward_batch.spec_info.topk_index = batch_result.next_draft_input.topk_index
+
+        # Run draft
+        if can_cuda_graph:
+            ret_topk_p, ret_topk_index = self.cuda_graph_runner.replay(
+                forward_batch,
+            )
+        else:
+            if not forward_batch.forward_mode.is_idle():
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
+            ret_topk_p, ret_topk_index = self.draft_forward_zero_bubble(forward_batch)
+
+        assert isinstance(ret_topk_p, torch.Tensor) and isinstance(
+            ret_topk_index, torch.Tensor
+        )
+        next_draft_input = batch_result.next_draft_input
+        ret_topk_p_list = [next_draft_input.topk_p, ret_topk_p]
+        ret_topk_index_list = [next_draft_input.topk_index, ret_topk_index]
+        (
+            next_draft_input.topk_p,
+            next_draft_input.topk_index,
+            next_draft_input.hidden_states,
+        ) = (
+            torch.cat(ret_topk_p_list, dim=1).clone(),
+            torch.cat(ret_topk_index_list, dim=1).clone(),
+            None,  # When `spec_v2_zero_bubble` is enabled, we don't need to save hidden_states for the next step
+        )
+
+    def draft_forward_for_prepare(self, spec_info):
+        topk_p, topk_index, hidden_states = (
+            spec_info.topk_p,
+            spec_info.topk_index,
+            spec_info.hidden_states,
+        )
+        if self.hot_token_id is not None:
+            topk_index = self.hot_token_id[topk_index]
+
+        # Forward multiple steps
+        input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
+            0, topk_p, topk_index, hidden_states, None, self.topk
         )
+        score_list = [
+            tree_info[0][:, :, i].unsqueeze(-1)
+            for i in range(self.speculative_num_steps)
+        ]
+        token_list = [
+            tree_info[1][:, i].unsqueeze(-1) for i in range(self.speculative_num_steps)
+        ]
+        parents_list = [tree_info[2]] + [
+            torch.full((tree_info[2].size(0), 1), i, dtype=torch.long, device="cuda")
+            for i in range(1, self.speculative_num_steps)
+        ]
+
+        # Organize the results
+        score_list = torch.cat(score_list, dim=1).flatten(
+            1
+        )  # b, n, topk; n = 1 + (num_steps - 1) * self.topk
+        ss_token_list = torch.cat(
+            token_list, dim=1
+        )  # b, (self.topk + (num_steps - 1) * self.topk)
+        top_scores = torch.topk(
+            score_list, self.speculative_num_draft_tokens - 1, dim=-1
+        )
+        top_scores_index = top_scores.indices
+        top_scores_index = torch.sort(top_scores_index).values
+        draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
+
+        if len(parents_list) > 1:
+            parent_list = torch.cat(parents_list[:-1], dim=1)
+        else:
+            batch_size = parents_list[0].shape[0]
+            parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)
+
+        return parent_list, top_scores_index, draft_tokens
+
     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
         spec_info: EagleDraftInput = forward_batch.spec_info
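The selection logic in `draft_forward_for_prepare` is easiest to see with toy numbers: take the top `speculative_num_draft_tokens - 1` scores, sort the winning indices back into positional (tree) order, then gather the corresponding token ids:

```python
import torch

scores = torch.tensor([[0.1, 0.9, 0.3, 0.7]])            # batch x candidates
tokens = torch.tensor([[11, 22, 33, 44]])
top = torch.topk(scores, k=3, dim=-1)                    # indices [[1, 3, 2]]
idx = torch.sort(top.indices).values                     # [[1, 2, 3]] -> keep tree order
draft_tokens = torch.gather(tokens, index=idx, dim=1)    # [[22, 33, 44]]
```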
@@ -429,6 +566,7 @@ def draft_forward(self, forward_batch: ForwardBatch):
             input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                 i, topk_p, topk_index, hidden_states, scores, self.topk
             )
+
             score_list.append(tree_info[0])
             token_list.append(tree_info[1])
             parents_list.append(tree_info[2])
@@ -489,6 +627,73 @@ def draft_forward(self, forward_batch: ForwardBatch):
 
         return parent_list, top_scores_index, draft_tokens
 
+    def draft_forward_zero_bubble(self, forward_batch: ForwardBatch):
+        # Parse args
+        spec_info: EagleDraftInput = forward_batch.spec_info
+        out_cache_loc = forward_batch.out_cache_loc
+        topk_p, topk_index, hidden_states = (
+            spec_info.topk_p,
+            spec_info.topk_index,
+            spec_info.hidden_states,
+        )
+
+        maybe_detect_nan(topk_p, "draft_forward_zero_bubble: NaN in initial topk_p from spec_info")
+
+        if self.hot_token_id is not None:
+            topk_index = self.hot_token_id[topk_index]
+
+        out_cache_loc = out_cache_loc.reshape(
+            forward_batch.batch_size, self.topk, self.speculative_num_steps
+        )
+        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
+            self.speculative_num_steps, -1
+        )
+
+        # Return values
+        ret_topk_p_list: List[torch.Tensor] = []
+        ret_topk_index_list: List[torch.Tensor] = []
+
+        # Forward multiple steps
+        scores = None
+        for i in range(self.speculative_num_steps):
+            if i == self.speculative_num_steps - 1:
+                break
+
+            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
+                i, topk_p, topk_index, hidden_states, scores, self.topk
+            )
+
+            # Set inputs
+            forward_batch.input_ids = input_ids
+            forward_batch.out_cache_loc = out_cache_loc[i]
+            forward_batch.positions.add_(1)
+            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
+            spec_info.hidden_states = hidden_states
+
+            # Run forward
+            logits_output = self.draft_runner.forward(
+                forward_batch, skip_attn_backend_init=True
+            ).logits_output
+            maybe_detect_nan(logits_output.next_token_logits, f"draft_forward_zero_bubble step {i}")
+            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
+            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+            hidden_states = logits_output.hidden_states
+            maybe_detect_oob(
+                topk_index,
+                0,
+                logits_output.next_token_logits.shape[-1],
+                f"draft_forward_zero_bubble step {i}: topk_index OOB vs vocab_size={logits_output.next_token_logits.shape[-1]}",
+            )
+            # Save return values
+            ret_topk_p_list.append(topk_p)
+            ret_topk_index_list.append(topk_index)
+            if self.hot_token_id is not None:
+                topk_index = self.hot_token_id[topk_index]
+
+        ret_topk_p = torch.cat(ret_topk_p_list, dim=1)
+        ret_topk_index = torch.cat(ret_topk_index_list, dim=1)
+        return ret_topk_p, ret_topk_index
+
     def draft_extend(self):
         pass
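One more note on `draft_forward_zero_bubble`: the `out_cache_loc` reshape/permute regroups the flat per-request KV-cache slots (laid out as `topk * num_steps`) so that step `i` can grab all of its slots in one slice. Illustrative sizes:

```python
import torch

bs, topk, steps = 2, 3, 4
out_cache_loc = torch.arange(bs * topk * steps)          # flat slot ids
per_step = (
    out_cache_loc.reshape(bs, topk, steps)
    .permute(2, 0, 1)                                    # -> (steps, bs, topk)
    .reshape(steps, -1)                                  # row i = slots written by step i
)
step0_slots = per_step[0]
```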
@@ -539,9 +744,12 @@ def _draft_extend_for_prefill(
 
         # Update spec_info for the next draft step
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
-        next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
-            probs, self.topk, dim=-1
-        )
+        topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+        if self.enable_spec_v2_zero_bubble and self.speculative_num_steps > 1:
+            topk_pad_size = self.speculative_num_steps * self.topk - topk_p.shape[-1]
+            topk_p = F.pad(topk_p, (0, topk_pad_size))
+            topk_index = F.pad(topk_index, (0, topk_pad_size))
+        next_draft_input.topk_p, next_draft_input.topk_index = topk_p, topk_index
         next_draft_input.hidden_states = logits_output.hidden_states
 
         return next_draft_input
@@ -621,6 +829,10 @@ def _draft_extend_for_decode(
             ret_hidden_states,
         )
 
+        # If spec_v2_zero_bubble is enabled, the draft runs after draft_extend rather than before verify
+        if self.enable_spec_v2_zero_bubble:
+            self.draft_zero_bubble(batch, batch_result, draft_input)
+
 
 class EAGLEWorkerV2(BaseSpecWorker):
     def __init__(
@@ -677,11 +889,11 @@ def __init__(
         self.plan_stream, self.plan_stream_ctx = _get_plan_stream(self.device)
 
     @property
-    def target_worker(self):
+    def target_worker(self) -> TpModelWorker:
         return self._target_worker
 
     @property
-    def draft_worker(self):
+    def draft_worker(self) -> EagleDraftWorker:
         return self._draft_worker
 
     def clear_cache_pool(self):
@@ -715,19 +927,29 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
             return batch_output
         else:
             if model_worker_batch.spec_info is None:
+                topk = self.topk
+                if self.draft_worker.enable_spec_v2_zero_bubble:
+                    topk *= self.speculative_num_steps
                 model_worker_batch.spec_info = EagleDraftInput.create_idle_input(
                     device=self.device,
                     hidden_size=self.target_worker.model_config.hidden_size,
                     dtype=self.target_worker.model_config.dtype,
-                    topk=self.topk,
+                    topk=topk,
                     capture_hidden_mode=CaptureHiddenMode.LAST,
                 )
             with self.draft_worker.draft_tp_context(
                 self.draft_worker.draft_runner.tp_group
             ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
-                verify_input: EagleVerifyInput = self.draft_worker.draft(
-                    model_worker_batch
-                )
+                if self.draft_worker.enable_spec_v2_zero_bubble:
+                    verify_input: EagleVerifyInput = (
+                        self.draft_worker.prepare_verify_fully_async_decoding(
+                            model_worker_batch
+                        )
+                    )
+                else:
+                    verify_input: EagleVerifyInput = self.draft_worker.draft(
+                        model_worker_batch
+                    )
             assert verify_input.is_verify_input()
             model_worker_batch.spec_info = verify_input
             batch_output = self.verify(model_worker_batch)
diff --git a/python/sglang/srt/speculative/spec_utils.py b/python/sglang/srt/speculative/spec_utils.py
index 05421c202cae..e96e9b99ab5f 100644
--- a/python/sglang/srt/speculative/spec_utils.py
+++ b/python/sglang/srt/speculative/spec_utils.py
@@ -53,8 +53,11 @@ def spec_need_hidden_states(server_args: Optional[ServerArgs] = None) -> bool:
     if server_args is None:
         server_args = get_global_server_args()
 
+    # When the `spec_v2_zero_bubble` feature is enabled, we don't need to save hidden_states for the next step
+    enable_spec_v2_zero_bubble = envs.SGLANG_SPEC_V2_ZERO_BUBBLE.get()
+
     # TODO(lsyin): also skip when 1) step = 1 or 2) standalone draft model
-    return not server_args.enable_multi_layer_eagle
+    return not server_args.enable_multi_layer_eagle and not enable_spec_v2_zero_bubble
 
 
 @triton.jit
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py
index a6be5c606c4f..c5b2f6824aba 100644
--- a/python/sglang/srt/utils/common.py
+++ b/python/sglang/srt/utils/common.py
@@ -601,12 +601,18 @@ def get_available_gpu_memory(
     return free_gpu_memory / (1 << 30)
 
 
+_is_cuda = is_cuda()
+_is_npu = is_npu()
+
+
 def is_pin_memory_available(device=None) -> bool:
-    if not torch.cuda.is_available():
-        return False
-    if device is not None and str(device) == "cpu":
-        return False
-    return True
+    if device is None:
+        return _is_cuda or _is_npu
+    if str(device) == "cuda":
+        return _is_cuda
+    if str(device) == "npu":
+        return _is_npu
+    return False
 
 
 class LayerFn(Protocol):
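One observation on the rewritten `is_pin_memory_available`: it compares `str(device)` against exact strings, and `str(torch.device("cuda", 0))` is `"cuda:0"`, which would fall through to `return False`. Whether callers ever pass indexed devices here is worth double-checking; the snippet below just demonstrates the normalization and makes no claim about sglang call sites:

```python
import torch

print(str(torch.device("cuda")))     # "cuda"   -> matches the helper's check
print(str(torch.device("cuda", 0)))  # "cuda:0" -> would not match "cuda" exactly
```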
diff --git a/test/registered/8-gpu-models/test_dsa_models_mtp.py b/test/registered/8-gpu-models/test_dsa_models_mtp.py
index 11b311ba8c74..2d7f21ced4db 100644
--- a/test/registered/8-gpu-models/test_dsa_models_mtp.py
+++ b/test/registered/8-gpu-models/test_dsa_models_mtp.py
@@ -365,5 +365,91 @@ def test_bs_1_speed(self):
         self.assertGreater(speed, 150)
 
 
+class TestDeepseekV32TPMTPV2ZeroBubble(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = FULL_DEEPSEEK_V32_MODEL_PATH
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = [
+            "--trust-remote-code",
+            "--tp",
+            "8",
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "1",
+            "--speculative-num-draft-tokens",
+            "4",
+            "--mem-frac",
+            "0.7",
+            "--model-loader-extra-config",
+            '{"enable_multithread_load": true, "num_threads": 64}',
+        ]
+        with envs.SGLANG_ENABLE_SPEC_V2.override(
+            True
+        ), envs.SGLANG_SPEC_V2_ZERO_BUBBLE.override(True):
+            cls.process = popen_launch_server(
+                cls.model,
+                cls.base_url,
+                timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+                other_args=other_args,
+            )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_a_gsm8k(
+        self,
+    ):  # Append an "a" to make this test run first (alphabetically) to warm up the server
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="gsm8k",
+            api="completion",
+            max_tokens=512,
+            num_examples=500,
+            num_threads=500,
+            num_shots=20,
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+
+        server_info = requests.get(self.base_url + "/server_info")
+        avg_spec_accept_length = server_info.json()["internal_states"][0][
+            "avg_spec_accept_length"
+        ]
+        print(f"{avg_spec_accept_length=}")
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_gsm8k (deepseek-v32 mtp)\n"
+                f'{metrics["score"]=:.3f}\n'
+                f"{avg_spec_accept_length=:.2f}\n"
+            )
+        self.assertGreater(metrics["score"], 0.94)
+        self.assertGreater(avg_spec_accept_length, 2.7)
+
+    def test_bs_1_speed(self):
+        args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
+        acc_length, speed = send_one_prompt(args)
+
+        print(f"{acc_length=:.2f} {speed=:.2f}")
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_bs_1_speed (deepseek-v32 mtp)\n"
+                f"{acc_length=:.2f}\n"
+                f"{speed=:.2f} token/s\n"
+            )
+
+        self.assertGreater(acc_length, 2.7)
+        self.assertGreater(speed, 180)
+
+
 if __name__ == "__main__":
     unittest.main()
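Finally, on the `override` fix in `setUpClass` above: the original patch used `with a and b:`, which evaluates `a and b` first; for truthy context-manager objects that expression is just `b`, so only the second override would ever be entered and `SGLANG_ENABLE_SPEC_V2` would silently stay unset. The comma form enters both managers. A self-contained demonstration with dummy context managers:

```python
from contextlib import contextmanager

@contextmanager
def flag(name):
    print(f"enter {name}")
    yield
    print(f"exit {name}")

with flag("SPEC_V2") and flag("ZERO_BUBBLE"):  # only ZERO_BUBBLE is entered
    pass

with flag("SPEC_V2"), flag("ZERO_BUBBLE"):     # both are entered and exited
    pass
```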