From 37b30ae83e58b1fd05397e0931d38d6bba75e238 Mon Sep 17 00:00:00 2001 From: Alison Shao Date: Sat, 28 Feb 2026 06:02:15 -0800 Subject: [PATCH] Revert "[SGL] sync patch: Remove sync points, prefill cudagraph for DP, disable cache reset in mem check (#19190)" This reverts commit b5a8e4179ea7577291ed2f11ad4563560fc9b66c. --- .../srt/batch_overlap/two_batch_overlap.py | 1 - python/sglang/srt/connector/__init__.py | 2 +- python/sglang/srt/layers/logits_processor.py | 17 ++--- python/sglang/srt/managers/schedule_batch.py | 36 +++------- python/sglang/srt/mem_cache/memory_pool.py | 69 +++++++------------ .../srt/model_executor/forward_batch_info.py | 2 - .../sglang/srt/model_executor/model_runner.py | 2 +- 7 files changed, 42 insertions(+), 87 deletions(-) diff --git a/python/sglang/srt/batch_overlap/two_batch_overlap.py b/python/sglang/srt/batch_overlap/two_batch_overlap.py index c05e10b5b387..d167e65c7f36 100644 --- a/python/sglang/srt/batch_overlap/two_batch_overlap.py +++ b/python/sglang/srt/batch_overlap/two_batch_overlap.py @@ -685,7 +685,6 @@ def filter_batch( for key in [ "forward_mode", "is_extend_in_batch", - "all_extend_in_batch", "return_logprob", "req_to_token_pool", "token_to_kv_pool", diff --git a/python/sglang/srt/connector/__init__.py b/python/sglang/srt/connector/__init__.py index 053ab7c49eca..c9663a836d14 100644 --- a/python/sglang/srt/connector/__init__.py +++ b/python/sglang/srt/connector/__init__.py @@ -22,7 +22,7 @@ class ConnectorType(str, enum.Enum): INSTANCE = "instance" -def create_remote_connector(url, device=None, **kwargs) -> BaseConnector: +def create_remote_connector(url, device, **kwargs) -> BaseConnector: connector_type = parse_connector_type(url) if connector_type == "redis": return RedisConnector(url) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 5362357f9036..aff05bf42703 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -519,11 +519,11 @@ def _get_pruned_states( if hidden_states_before_norm is not None: pruned_states_before_norm = torch.cat(pruned_states_before_norm_list) sample_indices = torch.tensor( - sample_indices, dtype=torch.int64, pin_memory=True - ).to(pruned_states.device, non_blocking=True) + sample_indices, device=pruned_states.device, dtype=torch.int64 + ) input_logprob_indices = torch.tensor( - input_logprob_indices, dtype=torch.int64, pin_memory=True - ).to(pruned_states.device, non_blocking=True) + input_logprob_indices, device=pruned_states.device, dtype=torch.int64 + ) return ( pruned_states, @@ -590,24 +590,19 @@ def _get_hidden_states_to_store( def _expand_metadata_for_logprobs( self, logits_metadata: LogitsMetadata, device: torch.device ): - # Avoid implicit device sync inside repeat_interleave by providing output_size, - # which we can compute from CPU metadata. - total_pruned_len = sum(logits_metadata.extend_logprob_pruned_lens_cpu) pruned_lens = torch.tensor( logits_metadata.extend_logprob_pruned_lens_cpu, - pin_memory=True, - ).to(device, non_blocking=True) + device=device, + ) if logits_metadata.temp_scaled_logprobs: logits_metadata.temperature = torch.repeat_interleave( logits_metadata.temperature.view(-1), pruned_lens, - output_size=total_pruned_len, ).view(-1, 1) if logits_metadata.top_p_normalized_logprobs: logits_metadata.top_p = torch.repeat_interleave( logits_metadata.top_p, pruned_lens, - output_size=total_pruned_len, ) def process_input_logprobs(self, input_logits, logits_metadata: LogitsMetadata): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 96649cbe6f4c..03f86c6152d0 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1226,7 +1226,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): global_num_tokens: Optional[List[int]] = None global_num_tokens_for_logprob: Optional[List[int]] = None is_extend_in_batch: bool = False - all_extend_in_batch: bool = False can_run_dp_cuda_graph: bool = False tbo_split_seq_index: Optional[int] = None global_forward_mode: Optional[ForwardMode] = None @@ -1986,34 +1985,22 @@ def prepare_for_decode(self): self.seq_lens_sum += bs if get_global_server_args().enable_mamba_extra_buffer(): - # Build indices fully on GPU without scalar extraction. - # Each slice is shape [1]; cat -> [bs]. - if len(self.reqs) == 0: - self.mamba_track_indices = torch.empty( - (0,), dtype=torch.int64, device=self.device - ) - else: - self.mamba_track_indices = torch.cat( - [ - ( - req.mamba_ping_pong_track_buffer[1:] - if req.mamba_next_track_idx == 1 - else req.mamba_ping_pong_track_buffer[:1] - ) - for req in self.reqs - ], - dim=0, - ).to(torch.int64) - - # Keep mask construction in the pinned-tensor form. + self.mamba_track_indices = torch.tensor( + [ + req.mamba_ping_pong_track_buffer[req.mamba_next_track_idx] + for req in self.reqs + ], + dtype=torch.int64, + device=self.device, + ) self.mamba_track_mask = torch.tensor( [ sl % get_global_server_args().mamba_track_interval == 0 for sl in self.seq_lens_cpu ], dtype=torch.bool, - pin_memory=True, - ).to(device=self.device, non_blocking=True) + device=self.device, + ) def maybe_wait_verify_done(self): if self.is_spec_v2: @@ -2183,7 +2170,6 @@ def get_model_worker_batch( global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, is_extend_in_batch=self.is_extend_in_batch, - all_extend_in_batch=self.all_extend_in_batch, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, tbo_split_seq_index=self.tbo_split_seq_index, global_forward_mode=self.global_forward_mode, @@ -2241,7 +2227,6 @@ def copy(self): global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, - all_extend_in_batch=self.all_extend_in_batch, is_extend_in_batch=self.is_extend_in_batch, is_prefill_only=self.is_prefill_only, seq_lens_cpu=self.seq_lens_cpu, @@ -2346,7 +2331,6 @@ class ModelWorkerBatch: global_num_tokens: Optional[List[int]] global_num_tokens_for_logprob: Optional[List[int]] is_extend_in_batch: bool - all_extend_in_batch: bool can_run_dp_cuda_graph: bool tbo_split_seq_index: Optional[int] global_forward_mode: Optional[ForwardMode] diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 11c45ff1cc8f..d654b7f31944 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -343,18 +343,10 @@ def alloc(self, need_size: int) -> Optional[torch.Tensor]: select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] - # clear at alloc time — expand a scalar GPU zero to the right shape, no CPU-GPU sync + # clear at alloc time, fill allocated slots with zeros for i in range(len(self.mamba_cache.conv)): - t = self.mamba_cache.conv[i] - z = torch.zeros(1, dtype=t.dtype, device=t.device).expand( - t.shape[0], need_size, *t.shape[2:] - ) - t[:, select_index] = z - t = self.mamba_cache.temporal - z = torch.zeros(1, dtype=t.dtype, device=t.device).expand( - t.shape[0], need_size, *t.shape[2:] - ) - t[:, select_index] = z + self.mamba_cache.conv[i][:, select_index] = 0 + self.mamba_cache.temporal[:, select_index] = 0 return select_index @@ -522,8 +514,8 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: if select_index is None: return None - mamba_indices: list[torch.Tensor] = [] - mamba_ping_pong_track_buffers: list[torch.Tensor] = [] + mamba_index = [] + mamba_ping_pong_track_buffer_list = [] for req in reqs: mid = None if req.mamba_pool_idx is not None: # for radix cache @@ -535,7 +527,7 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: ), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size. {mid=}, {self.mamba_pool.size=}, {self.mamba_pool.available_size()=}, {len(reqs)=}" mid = mid[0] req.mamba_pool_idx = mid - mamba_indices.append(mid) + mamba_index.append(mid) if self.enable_mamba_extra_buffer: if req.mamba_ping_pong_track_buffer is None: req.mamba_ping_pong_track_buffer = self.mamba_pool.alloc( @@ -545,22 +537,26 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: req.mamba_ping_pong_track_buffer is not None ), "Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio." req.mamba_next_track_idx = 0 - mamba_ping_pong_track_buffers.append(req.mamba_ping_pong_track_buffer) + mamba_ping_pong_track_buffer_list.append( + req.mamba_ping_pong_track_buffer.tolist() + ) assert len(select_index) == len( - mamba_indices + mamba_index ), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size." if self.enable_mamba_extra_buffer: assert len(select_index) == len( - mamba_ping_pong_track_buffers + mamba_ping_pong_track_buffer_list ), f"Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio." - mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32) - self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor + self.req_index_to_mamba_index_mapping[select_index] = torch.tensor( + mamba_index, dtype=torch.int32, device=self.device + ) if self.enable_mamba_extra_buffer: - ping_pong_tensor = torch.stack(mamba_ping_pong_track_buffers).to( - dtype=torch.int32 - ) self.req_index_to_mamba_ping_pong_track_buffer_mapping[select_index] = ( - ping_pong_tensor + torch.tensor( + mamba_ping_pong_track_buffer_list, + dtype=torch.int32, + device=self.device, + ) ) return select_index @@ -597,28 +593,11 @@ def free_mamba_cache( 0, 1, ], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, {mamba_ping_pong_track_buffer_to_keep=}" - # Avoid Python-list advanced indexing on a device tensor. - # The ping-pong buffer size is either 2 (normal) or 1 (spec decode). - if self.mamba_ping_pong_track_buffer_size == 2: - idx_to_free = 1 - mamba_ping_pong_track_buffer_to_keep - mamba_ping_pong_track_buffer_to_free = ( - mamba_ping_pong_track_buffer_to_free[ - idx_to_free : idx_to_free + 1 - ] - ) - else: - assert self.mamba_ping_pong_track_buffer_size == 1, ( - f"Unexpected mamba_ping_pong_track_buffer_size=" - f"{self.mamba_ping_pong_track_buffer_size}" - ) - assert mamba_ping_pong_track_buffer_to_keep == 0, ( - "mamba_ping_pong_track_buffer_to_keep must be 0 when " - "mamba_ping_pong_track_buffer_size is 1" - ) - # Keep the only slot, so free nothing. - mamba_ping_pong_track_buffer_to_free = ( - mamba_ping_pong_track_buffer_to_free[0:0] - ) + idx_to_free = list(range(self.mamba_ping_pong_track_buffer_size)) + idx_to_free.remove(mamba_ping_pong_track_buffer_to_keep) + mamba_ping_pong_track_buffer_to_free = ( + mamba_ping_pong_track_buffer_to_free[idx_to_free] + ) self.mamba_pool.free(mamba_ping_pong_track_buffer_to_free) def clear(self): diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 84c64b08d5dd..0e84ec8aab27 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -338,7 +338,6 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime global_dp_buffer_len: Optional[int] = None is_extend_in_batch: bool = False - all_extend_in_batch: bool = False can_run_dp_cuda_graph: bool = False global_forward_mode: Optional[ForwardMode] = None @@ -405,7 +404,6 @@ def init_new( top_logprobs_nums=batch.top_logprobs_nums, token_ids_logprobs=batch.token_ids_logprobs, is_extend_in_batch=batch.is_extend_in_batch, - all_extend_in_batch=batch.all_extend_in_batch, can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, global_forward_mode=batch.global_forward_mode, is_prefill_only=batch.is_prefill_only, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index cee3dd2ea325..219473245f96 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1134,7 +1134,7 @@ def update_weights_from_disk( """Update engine weights in-place from the disk.""" logger.info( f"Update engine weights online from disk begin. " - f"avail mem={get_available_gpu_memory(self.device, self.gpu_id, empty_cache=False):.2f} GB" + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) target_device = torch.device(self.device)