Merged
python/sglang/srt/batch_overlap/two_batch_overlap.py (1 change: 0 additions & 1 deletion)

@@ -685,7 +685,6 @@ def filter_batch(
     for key in [
         "forward_mode",
         "is_extend_in_batch",
-        "all_extend_in_batch",
         "return_logprob",
         "req_to_token_pool",
         "token_to_kv_pool",
python/sglang/srt/connector/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -22,7 +22,7 @@ class ConnectorType(str, enum.Enum):
     INSTANCE = "instance"


-def create_remote_connector(url, device=None, **kwargs) -> BaseConnector:
+def create_remote_connector(url, device, **kwargs) -> BaseConnector:
     connector_type = parse_connector_type(url)
     if connector_type == "redis":
         return RedisConnector(url)
python/sglang/srt/layers/logits_processor.py (17 changes: 6 additions & 11 deletions)

@@ -519,11 +519,11 @@ def _get_pruned_states(
         if hidden_states_before_norm is not None:
             pruned_states_before_norm = torch.cat(pruned_states_before_norm_list)
         sample_indices = torch.tensor(
-            sample_indices, dtype=torch.int64, pin_memory=True
-        ).to(pruned_states.device, non_blocking=True)
+            sample_indices, device=pruned_states.device, dtype=torch.int64
+        )
         input_logprob_indices = torch.tensor(
-            input_logprob_indices, dtype=torch.int64, pin_memory=True
-        ).to(pruned_states.device, non_blocking=True)
+            input_logprob_indices, device=pruned_states.device, dtype=torch.int64
+        )

         return (
             pruned_states,

@@ -590,24 +590,19 @@ def _get_hidden_states_to_store(
     def _expand_metadata_for_logprobs(
         self, logits_metadata: LogitsMetadata, device: torch.device
     ):
-        # Avoid implicit device sync inside repeat_interleave by providing output_size,
-        # which we can compute from CPU metadata.
-        total_pruned_len = sum(logits_metadata.extend_logprob_pruned_lens_cpu)
         pruned_lens = torch.tensor(
             logits_metadata.extend_logprob_pruned_lens_cpu,
-            pin_memory=True,
-        ).to(device, non_blocking=True)
+            device=device,
+        )
         if logits_metadata.temp_scaled_logprobs:
             logits_metadata.temperature = torch.repeat_interleave(
                 logits_metadata.temperature.view(-1),
                 pruned_lens,
-                output_size=total_pruned_len,
             ).view(-1, 1)
         if logits_metadata.top_p_normalized_logprobs:
             logits_metadata.top_p = torch.repeat_interleave(
                 logits_metadata.top_p,
                 pruned_lens,
-                output_size=total_pruned_len,
             )

     def process_input_logprobs(self, input_logits, logits_metadata: LogitsMetadata):
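Reviewer note on the deleted lines: `output_size` lets `torch.repeat_interleave` skip the `repeats.sum()` device-to-host readback that otherwise synchronizes the stream when `repeats` is a GPU tensor; the revert drops that hint (and the pinned-memory async copies), presumably trading the extra sync for simpler code. A minimal standalone sketch of the pattern (toy data, not the sglang code itself):

```python
# Sketch of the sync-avoidance trick this hunk removes. When `repeats` lives
# on the GPU and output_size is omitted, repeat_interleave must read
# repeats.sum() back to the host, forcing a device sync; passing output_size
# (computable from CPU-side metadata) skips that readback. Runs on CPU too.
import torch

pruned_lens_cpu = [2, 3, 1]            # per-request pruned lengths (CPU metadata)
total = sum(pruned_lens_cpu)           # known without touching the device
pruned_lens = torch.tensor(pruned_lens_cpu)

temperature = torch.tensor([0.7, 1.0, 0.9])
expanded = torch.repeat_interleave(
    temperature, pruned_lens, output_size=total  # no implicit repeats.sum() readback
)
assert expanded.shape == (total,)
print(expanded)  # tensor([0.7000, 0.7000, 1.0000, 1.0000, 1.0000, 0.9000])
```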
python/sglang/srt/managers/schedule_batch.py (36 changes: 10 additions & 26 deletions)

@@ -1226,7 +1226,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     is_extend_in_batch: bool = False
-    all_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None

@@ -1986,34 +1985,22 @@ def prepare_for_decode(self):
         self.seq_lens_sum += bs

         if get_global_server_args().enable_mamba_extra_buffer():
-            # Build indices fully on GPU without scalar extraction.
-            # Each slice is shape [1]; cat -> [bs].
-            if len(self.reqs) == 0:
-                self.mamba_track_indices = torch.empty(
-                    (0,), dtype=torch.int64, device=self.device
-                )
-            else:
-                self.mamba_track_indices = torch.cat(
-                    [
-                        (
-                            req.mamba_ping_pong_track_buffer[1:]
-                            if req.mamba_next_track_idx == 1
-                            else req.mamba_ping_pong_track_buffer[:1]
-                        )
-                        for req in self.reqs
-                    ],
-                    dim=0,
-                ).to(torch.int64)
-
-            # Keep mask construction in the pinned-tensor form.
+            self.mamba_track_indices = torch.tensor(
+                [
+                    req.mamba_ping_pong_track_buffer[req.mamba_next_track_idx]
+                    for req in self.reqs
+                ],
+                dtype=torch.int64,
+                device=self.device,
+            )
             self.mamba_track_mask = torch.tensor(
                 [
                     sl % get_global_server_args().mamba_track_interval == 0
                     for sl in self.seq_lens_cpu
                 ],
                 dtype=torch.bool,
-                pin_memory=True,
-            ).to(device=self.device, non_blocking=True)
+                device=self.device,
+            )

     def maybe_wait_verify_done(self):
         if self.is_spec_v2:

@@ -2183,7 +2170,6 @@ def get_model_worker_batch(
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             is_extend_in_batch=self.is_extend_in_batch,
-            all_extend_in_batch=self.all_extend_in_batch,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             tbo_split_seq_index=self.tbo_split_seq_index,
             global_forward_mode=self.global_forward_mode,

@@ -2241,7 +2227,6 @@ def copy(self):
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-            all_extend_in_batch=self.all_extend_in_batch,
             is_extend_in_batch=self.is_extend_in_batch,
             is_prefill_only=self.is_prefill_only,
             seq_lens_cpu=self.seq_lens_cpu,

@@ -2346,7 +2331,6 @@ class ModelWorkerBatch:
     global_num_tokens: Optional[List[int]]
     global_num_tokens_for_logprob: Optional[List[int]]
     is_extend_in_batch: bool
-    all_extend_in_batch: bool
     can_run_dp_cuda_graph: bool
     tbo_split_seq_index: Optional[int]
     global_forward_mode: Optional[ForwardMode]
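For context on the `prepare_for_decode` hunk: the deleted path built `mamba_track_indices` entirely on-device, while the restored one round-trips each per-request value through Python (one host read per element when the buffers are GPU tensors). A toy sketch of the difference, with made-up data standing in for the sglang types:

```python
# Illustration only, not the sglang code. If each request's ping-pong buffer
# is a GPU tensor, building a Python list of buffer[idx] values and passing
# it to torch.tensor() pulls every element back to the host. The deleted
# variant stayed on-device by slicing each buffer and concatenating the
# [1]-shaped views.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
buffers = [torch.tensor([10, 11], device=device),   # per-request ping-pong slots
           torch.tensor([20, 21], device=device)]
next_idx = [1, 0]                                   # slot each request uses next

# Host round-trip version (what the code returns to): int(buf[i]) syncs per element.
indices_host = torch.tensor(
    [int(buf[i]) for buf, i in zip(buffers, next_idx)],
    dtype=torch.int64, device=device,
)

# On-device version (what the diff deletes): slicing keeps a [1] view, cat -> [bs].
indices_dev = torch.cat(
    [buf[1:] if i == 1 else buf[:1] for buf, i in zip(buffers, next_idx)]
).to(torch.int64)

assert torch.equal(indices_host, indices_dev)  # both give tensor([11, 20])
```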
python/sglang/srt/mem_cache/memory_pool.py (69 changes: 24 additions & 45 deletions)

@@ -343,18 +343,10 @@ def alloc(self, need_size: int) -> Optional[torch.Tensor]:

         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
-        # clear at alloc time — expand a scalar GPU zero to the right shape, no CPU-GPU sync
+        # clear at alloc time, fill allocated slots with zeros
         for i in range(len(self.mamba_cache.conv)):
-            t = self.mamba_cache.conv[i]
-            z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
-                t.shape[0], need_size, *t.shape[2:]
-            )
-            t[:, select_index] = z
-            t = self.mamba_cache.temporal
-            z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
-                t.shape[0], need_size, *t.shape[2:]
-            )
-            t[:, select_index] = z
+            self.mamba_cache.conv[i][:, select_index] = 0
+        self.mamba_cache.temporal[:, select_index] = 0

         return select_index
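Both clearing styles zero the freshly allocated slots: the deleted one materialized a broadcasted zeros view and copied it in, while the restored one assigns the Python scalar 0 directly, which is itself a single fill kernel with no host sync. A self-contained sketch with toy shapes (not the real mamba cache layout):

```python
# Toy demonstration that the two clearing styles in this hunk are equivalent.
import torch

t = torch.ones(4, 8, 3)                  # [layers, slots, state]
select_index = torch.tensor([1, 5])      # slots just handed out

# Deleted style: expand a single zero to the target shape, then assign.
z = torch.zeros(1, dtype=t.dtype).expand(t.shape[0], len(select_index), *t.shape[2:])
t[:, select_index] = z

# Restored style: let indexing assignment broadcast the scalar.
t2 = torch.ones(4, 8, 3)
t2[:, select_index] = 0

assert torch.equal(t, t2)
```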

@@ -522,8 +514,8 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
         if select_index is None:
             return None

-        mamba_indices: list[torch.Tensor] = []
-        mamba_ping_pong_track_buffers: list[torch.Tensor] = []
+        mamba_index = []
+        mamba_ping_pong_track_buffer_list = []
         for req in reqs:
             mid = None
             if req.mamba_pool_idx is not None:  # for radix cache

@@ -535,7 +527,7 @@
                 ), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size. {mid=}, {self.mamba_pool.size=}, {self.mamba_pool.available_size()=}, {len(reqs)=}"
                 mid = mid[0]
             req.mamba_pool_idx = mid
-            mamba_indices.append(mid)
+            mamba_index.append(mid)
             if self.enable_mamba_extra_buffer:
                 if req.mamba_ping_pong_track_buffer is None:
                     req.mamba_ping_pong_track_buffer = self.mamba_pool.alloc(

@@ -545,22 +537,26 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
                     req.mamba_ping_pong_track_buffer is not None
                 ), "Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio."
                 req.mamba_next_track_idx = 0
-                mamba_ping_pong_track_buffers.append(req.mamba_ping_pong_track_buffer)
+                mamba_ping_pong_track_buffer_list.append(
+                    req.mamba_ping_pong_track_buffer.tolist()
+                )
         assert len(select_index) == len(
-            mamba_indices
+            mamba_index
         ), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size."
         if self.enable_mamba_extra_buffer:
             assert len(select_index) == len(
-                mamba_ping_pong_track_buffers
+                mamba_ping_pong_track_buffer_list
             ), f"Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio."
-        mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32)
-        self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor
+        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
+            mamba_index, dtype=torch.int32, device=self.device
+        )
         if self.enable_mamba_extra_buffer:
-            ping_pong_tensor = torch.stack(mamba_ping_pong_track_buffers).to(
-                dtype=torch.int32
-            )
             self.req_index_to_mamba_ping_pong_track_buffer_mapping[select_index] = (
-                ping_pong_tensor
+                torch.tensor(
+                    mamba_ping_pong_track_buffer_list,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
             )
         return select_index

@@ -597,28 +593,11 @@ def free_mamba_cache(
             0,
             1,
         ], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, {mamba_ping_pong_track_buffer_to_keep=}"
-        # Avoid Python-list advanced indexing on a device tensor.
-        # The ping-pong buffer size is either 2 (normal) or 1 (spec decode).
-        if self.mamba_ping_pong_track_buffer_size == 2:
-            idx_to_free = 1 - mamba_ping_pong_track_buffer_to_keep
-            mamba_ping_pong_track_buffer_to_free = (
-                mamba_ping_pong_track_buffer_to_free[
-                    idx_to_free : idx_to_free + 1
-                ]
-            )
-        else:
-            assert self.mamba_ping_pong_track_buffer_size == 1, (
-                f"Unexpected mamba_ping_pong_track_buffer_size="
-                f"{self.mamba_ping_pong_track_buffer_size}"
-            )
-            assert mamba_ping_pong_track_buffer_to_keep == 0, (
-                "mamba_ping_pong_track_buffer_to_keep must be 0 when "
-                "mamba_ping_pong_track_buffer_size is 1"
-            )
-            # Keep the only slot, so free nothing.
-            mamba_ping_pong_track_buffer_to_free = (
-                mamba_ping_pong_track_buffer_to_free[0:0]
-            )
+        idx_to_free = list(range(self.mamba_ping_pong_track_buffer_size))
+        idx_to_free.remove(mamba_ping_pong_track_buffer_to_keep)
+        mamba_ping_pong_track_buffer_to_free = (
+            mamba_ping_pong_track_buffer_to_free[idx_to_free]
+        )
         self.mamba_pool.free(mamba_ping_pong_track_buffer_to_free)

     def clear(self):
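The restored `free_mamba_cache` logic in one toy example: build the complement of the kept slot as a Python list and use it for advanced indexing. Note that with `size == 1` and `keep == 0` the list is empty and indexing yields an empty tensor, so nothing is freed, which matches the deleted special case:

```python
# Toy sketch of the restored free path (plain ints stand in for cache indices).
import torch

track_buffer = torch.tensor([7, 9])   # ping-pong pair for one request
size, keep = 2, 1                     # keep slot 1, free the rest

idx_to_free = list(range(size))
idx_to_free.remove(keep)              # -> [0]
to_free = track_buffer[idx_to_free]   # advanced indexing: tensor([7])
print(to_free)
```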
python/sglang/srt/model_executor/forward_batch_info.py (2 changes: 0 additions & 2 deletions)

@@ -338,7 +338,6 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
-    all_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None


@@ -405,7 +404,6 @@ def init_new(
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             is_extend_in_batch=batch.is_extend_in_batch,
-            all_extend_in_batch=batch.all_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
             is_prefill_only=batch.is_prefill_only,
python/sglang/srt/model_executor/model_runner.py (2 changes: 1 addition & 1 deletion)

@@ -1134,7 +1134,7 @@ def update_weights_from_disk(
         """Update engine weights in-place from the disk."""
         logger.info(
             f"Update engine weights online from disk begin. "
-            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id, empty_cache=False):.2f} GB"
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

         target_device = torch.device(self.device)
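This call site reverts to the helper's default `empty_cache` behavior before logging available memory. As a rough sketch of what such a measurement presumably does (behavior inferred from this call site, not sglang's actual utility):

```python
# Hypothetical stand-in for a get_available_gpu_memory-style helper: report
# free bytes from the driver, optionally after releasing cached allocator
# blocks so the number reflects truly allocatable memory.
import torch

def available_gpu_memory_gb(gpu_id: int = 0, empty_cache: bool = True) -> float:
    if empty_cache:
        torch.cuda.empty_cache()  # return cached blocks to the driver first
    free_bytes, _total = torch.cuda.mem_get_info(gpu_id)
    return free_bytes / (1 << 30)

if torch.cuda.is_available():
    print(f"avail mem={available_gpu_memory_gb():.2f} GB")
```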