diff --git a/python/sglang/srt/batch_overlap/two_batch_overlap.py b/python/sglang/srt/batch_overlap/two_batch_overlap.py index d167e65c7f36..c05e10b5b387 100644 --- a/python/sglang/srt/batch_overlap/two_batch_overlap.py +++ b/python/sglang/srt/batch_overlap/two_batch_overlap.py @@ -685,6 +685,7 @@ def filter_batch( for key in [ "forward_mode", "is_extend_in_batch", + "all_extend_in_batch", "return_logprob", "req_to_token_pool", "token_to_kv_pool", diff --git a/python/sglang/srt/connector/__init__.py b/python/sglang/srt/connector/__init__.py index c9663a836d14..053ab7c49eca 100644 --- a/python/sglang/srt/connector/__init__.py +++ b/python/sglang/srt/connector/__init__.py @@ -22,7 +22,7 @@ class ConnectorType(str, enum.Enum): INSTANCE = "instance" -def create_remote_connector(url, device, **kwargs) -> BaseConnector: +def create_remote_connector(url, device=None, **kwargs) -> BaseConnector: connector_type = parse_connector_type(url) if connector_type == "redis": return RedisConnector(url) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 03f86c6152d0..e6921f282a02 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1226,6 +1226,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): global_num_tokens: Optional[List[int]] = None global_num_tokens_for_logprob: Optional[List[int]] = None is_extend_in_batch: bool = False + all_extend_in_batch: bool = False can_run_dp_cuda_graph: bool = False tbo_split_seq_index: Optional[int] = None global_forward_mode: Optional[ForwardMode] = None @@ -1985,21 +1986,33 @@ def prepare_for_decode(self): self.seq_lens_sum += bs if get_global_server_args().enable_mamba_extra_buffer(): - self.mamba_track_indices = torch.tensor( - [ - req.mamba_ping_pong_track_buffer[req.mamba_next_track_idx] - for req in self.reqs - ], - dtype=torch.int64, - device=self.device, - ) - self.mamba_track_mask = torch.tensor( - [ - sl % get_global_server_args().mamba_track_interval == 0 - for sl in self.seq_lens_cpu - ], - dtype=torch.bool, - device=self.device, + if len(self.reqs) == 0: + self.mamba_track_indices = torch.empty( + (0,), dtype=torch.int64, device=self.device + ) + else: + # already on device + all_buffers = torch.stack( + [req.mamba_ping_pong_track_buffer for req in self.reqs] + ) + idx = ( + torch.tensor( + [req.mamba_next_track_idx for req in self.reqs], + dtype=torch.int64, + pin_memory=True, + ) + .unsqueeze(1) + .to(device=all_buffers.device, non_blocking=True) + ) + self.mamba_track_indices = ( + torch.gather(all_buffers, 1, idx).squeeze(1).to(torch.int64) + ) + + # async H2D + self.mamba_track_mask = ( + (self.seq_lens_cpu % get_global_server_args().mamba_track_interval == 0) + .pin_memory() + .to(device=self.device, non_blocking=True) ) def maybe_wait_verify_done(self): @@ -2170,6 +2183,7 @@ def get_model_worker_batch( global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, is_extend_in_batch=self.is_extend_in_batch, + all_extend_in_batch=self.all_extend_in_batch, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, tbo_split_seq_index=self.tbo_split_seq_index, global_forward_mode=self.global_forward_mode, @@ -2227,6 +2241,7 @@ def copy(self): global_num_tokens=self.global_num_tokens, global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, + all_extend_in_batch=self.all_extend_in_batch, is_extend_in_batch=self.is_extend_in_batch, is_prefill_only=self.is_prefill_only, seq_lens_cpu=self.seq_lens_cpu, @@ -2331,6 +2346,7 @@ class ModelWorkerBatch: global_num_tokens: Optional[List[int]] global_num_tokens_for_logprob: Optional[List[int]] is_extend_in_batch: bool + all_extend_in_batch: bool can_run_dp_cuda_graph: bool tbo_split_seq_index: Optional[int] global_forward_mode: Optional[ForwardMode] diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index d654b7f31944..11c45ff1cc8f 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -343,10 +343,18 @@ def alloc(self, need_size: int) -> Optional[torch.Tensor]: select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] - # clear at alloc time, fill allocated slots with zeros + # clear at alloc time — expand a scalar GPU zero to the right shape, no CPU-GPU sync for i in range(len(self.mamba_cache.conv)): - self.mamba_cache.conv[i][:, select_index] = 0 - self.mamba_cache.temporal[:, select_index] = 0 + t = self.mamba_cache.conv[i] + z = torch.zeros(1, dtype=t.dtype, device=t.device).expand( + t.shape[0], need_size, *t.shape[2:] + ) + t[:, select_index] = z + t = self.mamba_cache.temporal + z = torch.zeros(1, dtype=t.dtype, device=t.device).expand( + t.shape[0], need_size, *t.shape[2:] + ) + t[:, select_index] = z return select_index @@ -514,8 +522,8 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: if select_index is None: return None - mamba_index = [] - mamba_ping_pong_track_buffer_list = [] + mamba_indices: list[torch.Tensor] = [] + mamba_ping_pong_track_buffers: list[torch.Tensor] = [] for req in reqs: mid = None if req.mamba_pool_idx is not None: # for radix cache @@ -527,7 +535,7 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: ), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size. {mid=}, {self.mamba_pool.size=}, {self.mamba_pool.available_size()=}, {len(reqs)=}" mid = mid[0] req.mamba_pool_idx = mid - mamba_index.append(mid) + mamba_indices.append(mid) if self.enable_mamba_extra_buffer: if req.mamba_ping_pong_track_buffer is None: req.mamba_ping_pong_track_buffer = self.mamba_pool.alloc( @@ -537,26 +545,22 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]: req.mamba_ping_pong_track_buffer is not None ), "Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio." req.mamba_next_track_idx = 0 - mamba_ping_pong_track_buffer_list.append( - req.mamba_ping_pong_track_buffer.tolist() - ) + mamba_ping_pong_track_buffers.append(req.mamba_ping_pong_track_buffer) assert len(select_index) == len( - mamba_index + mamba_indices ), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size." if self.enable_mamba_extra_buffer: assert len(select_index) == len( - mamba_ping_pong_track_buffer_list + mamba_ping_pong_track_buffers ), f"Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio." - self.req_index_to_mamba_index_mapping[select_index] = torch.tensor( - mamba_index, dtype=torch.int32, device=self.device - ) + mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32) + self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor if self.enable_mamba_extra_buffer: + ping_pong_tensor = torch.stack(mamba_ping_pong_track_buffers).to( + dtype=torch.int32 + ) self.req_index_to_mamba_ping_pong_track_buffer_mapping[select_index] = ( - torch.tensor( - mamba_ping_pong_track_buffer_list, - dtype=torch.int32, - device=self.device, - ) + ping_pong_tensor ) return select_index @@ -593,11 +597,28 @@ def free_mamba_cache( 0, 1, ], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, {mamba_ping_pong_track_buffer_to_keep=}" - idx_to_free = list(range(self.mamba_ping_pong_track_buffer_size)) - idx_to_free.remove(mamba_ping_pong_track_buffer_to_keep) - mamba_ping_pong_track_buffer_to_free = ( - mamba_ping_pong_track_buffer_to_free[idx_to_free] - ) + # Avoid Python-list advanced indexing on a device tensor. + # The ping-pong buffer size is either 2 (normal) or 1 (spec decode). + if self.mamba_ping_pong_track_buffer_size == 2: + idx_to_free = 1 - mamba_ping_pong_track_buffer_to_keep + mamba_ping_pong_track_buffer_to_free = ( + mamba_ping_pong_track_buffer_to_free[ + idx_to_free : idx_to_free + 1 + ] + ) + else: + assert self.mamba_ping_pong_track_buffer_size == 1, ( + f"Unexpected mamba_ping_pong_track_buffer_size=" + f"{self.mamba_ping_pong_track_buffer_size}" + ) + assert mamba_ping_pong_track_buffer_to_keep == 0, ( + "mamba_ping_pong_track_buffer_to_keep must be 0 when " + "mamba_ping_pong_track_buffer_size is 1" + ) + # Keep the only slot, so free nothing. + mamba_ping_pong_track_buffer_to_free = ( + mamba_ping_pong_track_buffer_to_free[0:0] + ) self.mamba_pool.free(mamba_ping_pong_track_buffer_to_free) def clear(self): diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 0e84ec8aab27..84c64b08d5dd 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -338,6 +338,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime global_dp_buffer_len: Optional[int] = None is_extend_in_batch: bool = False + all_extend_in_batch: bool = False can_run_dp_cuda_graph: bool = False global_forward_mode: Optional[ForwardMode] = None @@ -404,6 +405,7 @@ def init_new( top_logprobs_nums=batch.top_logprobs_nums, token_ids_logprobs=batch.token_ids_logprobs, is_extend_in_batch=batch.is_extend_in_batch, + all_extend_in_batch=batch.all_extend_in_batch, can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, global_forward_mode=batch.global_forward_mode, is_prefill_only=batch.is_prefill_only, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 219473245f96..cee3dd2ea325 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1134,7 +1134,7 @@ def update_weights_from_disk( """Update engine weights in-place from the disk.""" logger.info( f"Update engine weights online from disk begin. " - f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + f"avail mem={get_available_gpu_memory(self.device, self.gpu_id, empty_cache=False):.2f} GB" ) target_device = torch.device(self.device)