diff --git a/python/sglang/srt/batch_overlap/two_batch_overlap.py b/python/sglang/srt/batch_overlap/two_batch_overlap.py index cfd2a54ed132..11ae504e8512 100644 --- a/python/sglang/srt/batch_overlap/two_batch_overlap.py +++ b/python/sglang/srt/batch_overlap/two_batch_overlap.py @@ -683,6 +683,7 @@ def filter_batch( for key in [ "forward_mode", "is_extend_in_batch", + "all_extend_in_batch", "return_logprob", "req_to_token_pool", "token_to_kv_pool", diff --git a/python/sglang/srt/connector/__init__.py b/python/sglang/srt/connector/__init__.py index c9663a836d14..053ab7c49eca 100644 --- a/python/sglang/srt/connector/__init__.py +++ b/python/sglang/srt/connector/__init__.py @@ -22,7 +22,7 @@ class ConnectorType(str, enum.Enum): INSTANCE = "instance" -def create_remote_connector(url, device, **kwargs) -> BaseConnector: +def create_remote_connector(url, device=None, **kwargs) -> BaseConnector: connector_type = parse_connector_type(url) if connector_type == "redis": return RedisConnector(url) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index aff05bf42703..5362357f9036 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -519,11 +519,11 @@ def _get_pruned_states( if hidden_states_before_norm is not None: pruned_states_before_norm = torch.cat(pruned_states_before_norm_list) sample_indices = torch.tensor( - sample_indices, device=pruned_states.device, dtype=torch.int64 - ) + sample_indices, dtype=torch.int64, pin_memory=True + ).to(pruned_states.device, non_blocking=True) input_logprob_indices = torch.tensor( - input_logprob_indices, device=pruned_states.device, dtype=torch.int64 - ) + input_logprob_indices, dtype=torch.int64, pin_memory=True + ).to(pruned_states.device, non_blocking=True) return ( pruned_states, @@ -590,19 +590,24 @@ def _get_hidden_states_to_store( def _expand_metadata_for_logprobs( self, logits_metadata: LogitsMetadata, device: torch.device ): + # Avoid implicit device sync inside repeat_interleave by providing output_size, + # which we can compute from CPU metadata. + total_pruned_len = sum(logits_metadata.extend_logprob_pruned_lens_cpu) pruned_lens = torch.tensor( logits_metadata.extend_logprob_pruned_lens_cpu, - device=device, - ) + pin_memory=True, + ).to(device, non_blocking=True) if logits_metadata.temp_scaled_logprobs: logits_metadata.temperature = torch.repeat_interleave( logits_metadata.temperature.view(-1), pruned_lens, + output_size=total_pruned_len, ).view(-1, 1) if logits_metadata.top_p_normalized_logprobs: logits_metadata.top_p = torch.repeat_interleave( logits_metadata.top_p, pruned_lens, + output_size=total_pruned_len, ) def process_input_logprobs(self, input_logits, logits_metadata: LogitsMetadata): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c0799579802e..ca419aa2d6c6 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1253,6 +1253,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): global_num_tokens: Optional[List[int]] = None global_num_tokens_for_logprob: Optional[List[int]] = None is_extend_in_batch: bool = False + all_extend_in_batch: bool = False can_run_dp_cuda_graph: bool = False tbo_split_seq_index: Optional[int] = None global_forward_mode: Optional[ForwardMode] = None @@ -2012,22 +2013,34 @@ def prepare_for_decode(self): self.seq_lens_sum += bs if get_global_server_args().enable_mamba_extra_buffer(): - self.mamba_track_indices = torch.tensor( - [ - req.mamba_ping_pong_track_buffer[req.mamba_next_track_idx] - for req in self.reqs - ], - dtype=torch.int64, - device=self.device, - ) + # Build indices fully on GPU without scalar extraction. + # Each slice is shape [1]; cat -> [bs]. + if len(self.reqs) == 0: + self.mamba_track_indices = torch.empty( + (0,), dtype=torch.int64, device=self.device + ) + else: + self.mamba_track_indices = torch.cat( + [ + ( + req.mamba_ping_pong_track_buffer[1:] + if req.mamba_next_track_idx == 1 + else req.mamba_ping_pong_track_buffer[:1] + ) + for req in self.reqs + ], + dim=0, + ).to(torch.int64) + + # Keep mask construction in the pinned-tensor form. self.mamba_track_mask = torch.tensor( [ sl % get_global_server_args().mamba_track_interval == 0 for sl in self.seq_lens_cpu ], dtype=torch.bool, - device=self.device, - ) + pin_memory=True, + ).to(device=self.device, non_blocking=True) def maybe_wait_verify_done(self): if self.is_spec_v2: @@ -2197,6 +2210,7 @@ def get_model_worker_batch( global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, is_extend_in_batch=self.is_extend_in_batch, + all_extend_in_batch=self.all_extend_in_batch, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, tbo_split_seq_index=self.tbo_split_seq_index, global_forward_mode=self.global_forward_mode, @@ -2254,6 +2268,7 @@ def copy(self): global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, + all_extend_in_batch=self.all_extend_in_batch, is_extend_in_batch=self.is_extend_in_batch, is_prefill_only=self.is_prefill_only, seq_lens_cpu=self.seq_lens_cpu, @@ -2358,6 +2373,7 @@ class ModelWorkerBatch: global_num_tokens: Optional[List[int]] global_num_tokens_for_logprob: Optional[List[int]] is_extend_in_batch: bool + all_extend_in_batch: bool can_run_dp_cuda_graph: bool tbo_split_seq_index: Optional[int] global_forward_mode: Optional[ForwardMode] diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 1d917137c68d..8e6d37a4231d 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -337,10 +337,18 @@ def alloc(self, need_size: int) -> Optional[torch.Tensor]: select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] - # clear at alloc time, fill allocated slots with zeros + # clear at alloc time — expand a scalar GPU zero to the right shape, no CPU-GPU sync for i in range(len(self.mamba_cache.conv)): - self.mamba_cache.conv[i][:, select_index] = 0 - self.mamba_cache.temporal[:, select_index] = 0 + t = self.mamba_cache.conv[i] + z = torch.zeros(1, dtype=t.dtype, device=t.device).expand( + t.shape[0], need_size, *t.shape[2:] + ) + t[:, select_index] = z + t = self.mamba_cache.temporal + z = torch.zeros(1, dtype=t.dtype, device=t.device).expand( + t.shape[0], need_size, *t.shape[2:] + ) + t[:, select_index] = z return select_index @@ -503,8 +511,8 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: if select_index is None: return None - mamba_index = [] - mamba_ping_pong_track_buffer_list = [] + mamba_indices: list[torch.Tensor] = [] + mamba_ping_pong_track_buffers: list[torch.Tensor] = [] for req in reqs: mid = None if req.mamba_pool_idx is not None: # for radix cache @@ -516,7 +524,7 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: ), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size. {mid=}, {self.mamba_pool.size=}, {self.mamba_pool.available_size()=}, {len(reqs)=}" mid = mid[0] req.mamba_pool_idx = mid - mamba_index.append(mid) + mamba_indices.append(mid) if self.enable_mamba_extra_buffer: if req.mamba_ping_pong_track_buffer is None: req.mamba_ping_pong_track_buffer = self.mamba_pool.alloc( @@ -526,26 +534,22 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: req.mamba_ping_pong_track_buffer is not None ), "Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio." req.mamba_next_track_idx = 0 - mamba_ping_pong_track_buffer_list.append( - req.mamba_ping_pong_track_buffer.tolist() - ) + mamba_ping_pong_track_buffers.append(req.mamba_ping_pong_track_buffer) assert len(select_index) == len( - mamba_index + mamba_indices ), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size." if self.enable_mamba_extra_buffer: assert len(select_index) == len( - mamba_ping_pong_track_buffer_list + mamba_ping_pong_track_buffers ), f"Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio." - self.req_index_to_mamba_index_mapping[select_index] = torch.tensor( - mamba_index, dtype=torch.int32, device=self.device - ) + mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32) + self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor if self.enable_mamba_extra_buffer: + ping_pong_tensor = torch.stack(mamba_ping_pong_track_buffers).to( + dtype=torch.int32 + ) self.req_index_to_mamba_ping_pong_track_buffer_mapping[select_index] = ( - torch.tensor( - mamba_ping_pong_track_buffer_list, - dtype=torch.int32, - device=self.device, - ) + ping_pong_tensor ) return select_index @@ -582,11 +586,28 @@ def free_mamba_cache( 0, 1, ], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, {mamba_ping_pong_track_buffer_to_keep=}" - idx_to_free = list(range(self.mamba_ping_pong_track_buffer_size)) - idx_to_free.remove(mamba_ping_pong_track_buffer_to_keep) - mamba_ping_pong_track_buffer_to_free = ( - mamba_ping_pong_track_buffer_to_free[idx_to_free] - ) + # Avoid Python-list advanced indexing on a device tensor. + # The ping-pong buffer size is either 2 (normal) or 1 (spec decode). + if self.mamba_ping_pong_track_buffer_size == 2: + idx_to_free = 1 - mamba_ping_pong_track_buffer_to_keep + mamba_ping_pong_track_buffer_to_free = ( + mamba_ping_pong_track_buffer_to_free[ + idx_to_free : idx_to_free + 1 + ] + ) + else: + assert self.mamba_ping_pong_track_buffer_size == 1, ( + f"Unexpected mamba_ping_pong_track_buffer_size=" + f"{self.mamba_ping_pong_track_buffer_size}" + ) + assert mamba_ping_pong_track_buffer_to_keep == 0, ( + "mamba_ping_pong_track_buffer_to_keep must be 0 when " + "mamba_ping_pong_track_buffer_size is 1" + ) + # Keep the only slot, so free nothing. + mamba_ping_pong_track_buffer_to_free = ( + mamba_ping_pong_track_buffer_to_free[0:0] + ) self.mamba_pool.free(mamba_ping_pong_track_buffer_to_free) def clear(self): diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 2345235324e4..75cae580c68a 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -338,6 +338,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime global_dp_buffer_len: Optional[int] = None is_extend_in_batch: bool = False + all_extend_in_batch: bool = False can_run_dp_cuda_graph: bool = False global_forward_mode: Optional[ForwardMode] = None @@ -401,6 +402,7 @@ def init_new( top_logprobs_nums=batch.top_logprobs_nums, token_ids_logprobs=batch.token_ids_logprobs, is_extend_in_batch=batch.is_extend_in_batch, + all_extend_in_batch=batch.all_extend_in_batch, can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, global_forward_mode=batch.global_forward_mode, is_prefill_only=batch.is_prefill_only, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 16a28f099c74..3f09c61a59cb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1122,7 +1122,7 @@ def update_weights_from_disk( """Update engine weights in-place from the disk.""" logger.info( f"Update engine weights online from disk begin. " - f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id, empty_cache=False):.2f} GB" ) target_device = torch.device(self.device)