Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/batch_overlap/two_batch_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ def filter_batch(
for key in [
"forward_mode",
"is_extend_in_batch",
"all_extend_in_batch",
"return_logprob",
"req_to_token_pool",
"token_to_kv_pool",
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/connector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ConnectorType(str, enum.Enum):
INSTANCE = "instance"


def create_remote_connector(url, device, **kwargs) -> BaseConnector:
def create_remote_connector(url, device=None, **kwargs) -> BaseConnector:
connector_type = parse_connector_type(url)
if connector_type == "redis":
return RedisConnector(url)
Expand Down
46 changes: 31 additions & 15 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
is_extend_in_batch: bool = False
all_extend_in_batch: bool = False
can_run_dp_cuda_graph: bool = False
tbo_split_seq_index: Optional[int] = None
global_forward_mode: Optional[ForwardMode] = None
Expand Down Expand Up @@ -1985,21 +1986,33 @@ def prepare_for_decode(self):
self.seq_lens_sum += bs

if get_global_server_args().enable_mamba_extra_buffer():
self.mamba_track_indices = torch.tensor(
[
req.mamba_ping_pong_track_buffer[req.mamba_next_track_idx]
for req in self.reqs
],
dtype=torch.int64,
device=self.device,
)
self.mamba_track_mask = torch.tensor(
[
sl % get_global_server_args().mamba_track_interval == 0
for sl in self.seq_lens_cpu
],
dtype=torch.bool,
device=self.device,
if len(self.reqs) == 0:
self.mamba_track_indices = torch.empty(
(0,), dtype=torch.int64, device=self.device
)
else:
# already on device
all_buffers = torch.stack(
[req.mamba_ping_pong_track_buffer for req in self.reqs]
)
idx = (
torch.tensor(
[req.mamba_next_track_idx for req in self.reqs],
dtype=torch.int64,
pin_memory=True,
)
.unsqueeze(1)
.to(device=all_buffers.device, non_blocking=True)
)
self.mamba_track_indices = (
torch.gather(all_buffers, 1, idx).squeeze(1).to(torch.int64)
)

# async H2D
self.mamba_track_mask = (
(self.seq_lens_cpu % get_global_server_args().mamba_track_interval == 0)
.pin_memory()
.to(device=self.device, non_blocking=True)
)

def maybe_wait_verify_done(self):
Expand Down Expand Up @@ -2170,6 +2183,7 @@ def get_model_worker_batch(
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
is_extend_in_batch=self.is_extend_in_batch,
all_extend_in_batch=self.all_extend_in_batch,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
tbo_split_seq_index=self.tbo_split_seq_index,
global_forward_mode=self.global_forward_mode,
Expand Down Expand Up @@ -2227,6 +2241,7 @@ def copy(self):
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
all_extend_in_batch=self.all_extend_in_batch,
is_extend_in_batch=self.is_extend_in_batch,
is_prefill_only=self.is_prefill_only,
seq_lens_cpu=self.seq_lens_cpu,
Expand Down Expand Up @@ -2331,6 +2346,7 @@ class ModelWorkerBatch:
global_num_tokens: Optional[List[int]]
global_num_tokens_for_logprob: Optional[List[int]]
is_extend_in_batch: bool
all_extend_in_batch: bool
can_run_dp_cuda_graph: bool
tbo_split_seq_index: Optional[int]
global_forward_mode: Optional[ForwardMode]
Expand Down
69 changes: 45 additions & 24 deletions python/sglang/srt/mem_cache/memory_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,18 @@ def alloc(self, need_size: int) -> Optional[torch.Tensor]:

select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
# clear at alloc time, fill allocated slots with zeros
# clear at alloc time — expand a scalar GPU zero to the right shape, no CPU-GPU sync
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, select_index] = 0
self.mamba_cache.temporal[:, select_index] = 0
t = self.mamba_cache.conv[i]
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z
t = self.mamba_cache.temporal
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z
Comment on lines 347 to +357
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for creating a zero tensor and assigning it to a slice of the cache is duplicated for `self.mamba_cache.conv` and `self.mamba_cache.temporal`. You can refactor this into a single loop to improve code clarity and reduce duplication.

Suggested change
for i in range(len(self.mamba_cache.conv)):
self.mamba_cache.conv[i][:, select_index] = 0
self.mamba_cache.temporal[:, select_index] = 0
t = self.mamba_cache.conv[i]
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z
t = self.mamba_cache.temporal
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z
for t in self.mamba_cache.conv + [self.mamba_cache.temporal]:
z = torch.zeros(1, dtype=t.dtype, device=t.device).expand(
t.shape[0], need_size, *t.shape[2:]
)
t[:, select_index] = z


return select_index

Expand Down Expand Up @@ -514,8 +522,8 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
if select_index is None:
return None

mamba_index = []
mamba_ping_pong_track_buffer_list = []
mamba_indices: list[torch.Tensor] = []
mamba_ping_pong_track_buffers: list[torch.Tensor] = []
for req in reqs:
mid = None
if req.mamba_pool_idx is not None: # for radix cache
Expand All @@ -527,7 +535,7 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size. {mid=}, {self.mamba_pool.size=}, {self.mamba_pool.available_size()=}, {len(reqs)=}"
mid = mid[0]
req.mamba_pool_idx = mid
mamba_index.append(mid)
mamba_indices.append(mid)
if self.enable_mamba_extra_buffer:
if req.mamba_ping_pong_track_buffer is None:
req.mamba_ping_pong_track_buffer = self.mamba_pool.alloc(
Expand All @@ -537,26 +545,22 @@ def alloc(self, reqs: List["Req"]) -> Optional[List[int]]:
req.mamba_ping_pong_track_buffer is not None
), "Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio."
req.mamba_next_track_idx = 0
mamba_ping_pong_track_buffer_list.append(
req.mamba_ping_pong_track_buffer.tolist()
)
mamba_ping_pong_track_buffers.append(req.mamba_ping_pong_track_buffer)
assert len(select_index) == len(
mamba_index
mamba_indices
), f"Not enough space for mamba cache, try to increase --mamba-full-memory-ratio or --max-mamba-cache-size."
if self.enable_mamba_extra_buffer:
assert len(select_index) == len(
mamba_ping_pong_track_buffer_list
mamba_ping_pong_track_buffers
), f"Not enough space for mamba ping pong idx, try to increase --mamba-full-memory-ratio."
self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
mamba_index, dtype=torch.int32, device=self.device
)
mamba_index_tensor = torch.stack(mamba_indices).to(dtype=torch.int32)
self.req_index_to_mamba_index_mapping[select_index] = mamba_index_tensor
if self.enable_mamba_extra_buffer:
ping_pong_tensor = torch.stack(mamba_ping_pong_track_buffers).to(
dtype=torch.int32
)
self.req_index_to_mamba_ping_pong_track_buffer_mapping[select_index] = (
torch.tensor(
mamba_ping_pong_track_buffer_list,
dtype=torch.int32,
device=self.device,
)
ping_pong_tensor
)
return select_index

Expand Down Expand Up @@ -593,11 +597,28 @@ def free_mamba_cache(
0,
1,
], f"mamba_ping_pong_track_buffer_to_keep must be 0 or 1, {mamba_ping_pong_track_buffer_to_keep=}"
idx_to_free = list(range(self.mamba_ping_pong_track_buffer_size))
idx_to_free.remove(mamba_ping_pong_track_buffer_to_keep)
mamba_ping_pong_track_buffer_to_free = (
mamba_ping_pong_track_buffer_to_free[idx_to_free]
)
# Avoid Python-list advanced indexing on a device tensor.
# The ping-pong buffer size is either 2 (normal) or 1 (spec decode).
if self.mamba_ping_pong_track_buffer_size == 2:
idx_to_free = 1 - mamba_ping_pong_track_buffer_to_keep
mamba_ping_pong_track_buffer_to_free = (
mamba_ping_pong_track_buffer_to_free[
idx_to_free : idx_to_free + 1
]
)
else:
assert self.mamba_ping_pong_track_buffer_size == 1, (
f"Unexpected mamba_ping_pong_track_buffer_size="
f"{self.mamba_ping_pong_track_buffer_size}"
)
assert mamba_ping_pong_track_buffer_to_keep == 0, (
"mamba_ping_pong_track_buffer_to_keep must be 0 when "
"mamba_ping_pong_track_buffer_size is 1"
)
# Keep the only slot, so free nothing.
mamba_ping_pong_track_buffer_to_free = (
mamba_ping_pong_track_buffer_to_free[0:0]
)
self.mamba_pool.free(mamba_ping_pong_track_buffer_to_free)

def clear(self):
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin):
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
global_dp_buffer_len: Optional[int] = None
is_extend_in_batch: bool = False
all_extend_in_batch: bool = False
can_run_dp_cuda_graph: bool = False
global_forward_mode: Optional[ForwardMode] = None

Expand Down Expand Up @@ -404,6 +405,7 @@ def init_new(
top_logprobs_nums=batch.top_logprobs_nums,
token_ids_logprobs=batch.token_ids_logprobs,
is_extend_in_batch=batch.is_extend_in_batch,
all_extend_in_batch=batch.all_extend_in_batch,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
global_forward_mode=batch.global_forward_mode,
is_prefill_only=batch.is_prefill_only,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ def update_weights_from_disk(
"""Update engine weights in-place from the disk."""
logger.info(
f"Update engine weights online from disk begin. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id, empty_cache=False):.2f} GB"
)

target_device = torch.device(self.device)
Expand Down
Loading