Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ class Envs:
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)

# RoPE cache configuration
SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2)
SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256)
SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)

# Overlap Spec V2
SGLANG_ENABLE_SPEC_V2 = EnvBool(False)
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)
Expand Down
33 changes: 33 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,39 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
cache = torch.cat((cos, sin), dim=-1)
return cache

def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
"""Ensure cos_sin_cache length > needed_max_pos."""
from sglang.srt.environ import envs

cur_len = int(self.cos_sin_cache.shape[0])
if needed_max_pos < cur_len:
return

# Align to reduce realloc frequency
align = envs.SGLANG_ROPE_CACHE_ALIGN.value
new_len = ((needed_max_pos + align) // align) * align
device = self.cos_sin_cache.device
dtype = self.cos_sin_cache.dtype

# Compute inv_freq on same device
inv_freq = self._compute_inv_freq(self.base).to(device=device)

# Incremental computation for new positions only
start = cur_len
t_new = torch.arange(start, new_len, dtype=inv_freq.dtype, device=device)
if t_new.numel() == 0:
return

freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
cos_new = freqs_new.cos()
sin_new = freqs_new.sin()
new_rows = torch.cat((cos_new, sin_new), dim=-1).to(dtype=dtype)

# Update cache with new rows
self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0).to(
device=device, dtype=dtype
)

def forward_native(
self,
positions: torch.Tensor,
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
is_npu,
log_info_on_rank0,
monkey_patch_p2p_access_check,
reserve_rope_cache_for_long_sequences,
set_cuda_arch,
slow_rank_detector,
xpu_has_xmx_support,
Expand Down Expand Up @@ -862,6 +863,14 @@ def load_model(self):
self.pp_rank,
)

# Pre-expand RoPE cache before CUDA Graph capture
reserve_rope_cache_for_long_sequences(
self.model,
self.server_args,
self.model_config,
logger,
)

if self.server_args.elastic_ep_backend == "mooncake":
# Mooncake does not support `monitored_barrier`
dist.barrier(group=get_tp_group().cpu_group)
Expand Down
55 changes: 55 additions & 0 deletions python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3633,6 +3633,61 @@ def decorator(fn):
return decorator


def reserve_rope_cache_for_long_sequences(
model, server_args, model_config, logger=None
):
"""Pre-expand RoPE cache for long sequences and speculative decoding."""
from sglang.srt.environ import envs

if logger is None:
import logging

logger = logging.getLogger(__name__)

SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
ALIGN = envs.SGLANG_ROPE_CACHE_ALIGN.value

# 1) Estimate base context upper bound
base_ctx = (
getattr(server_args, "context_length", None)
or getattr(model_config, "context_len", None)
or getattr(model_config, "max_model_len", None)
or getattr(model_config.hf_text_config, "max_position_embeddings", None)
or 2048
)

# 2) Speculative decoding expansion
steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
reserve = base_ctx + steps * draft * SAFETY_FACTOR + MARGIN

# 3) Align to reduce reallocation frequency
reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN

logger.info(
f"RoPE cache reserve={reserve} (cap={base_ctx}, steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})"
)

# Recursively expand all RoPE layers
def reserve_rope_cache_recursive(module):
for child in module.children():
if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr(
child, "cos_sin_cache"
):
old_len = child.cos_sin_cache.shape[0]
child._ensure_cos_sin_cache_length(reserve - 1)
new_len = child.cos_sin_cache.shape[0]
if new_len > old_len:
logger.info(
f"Expanded RoPE cache from {old_len} to {new_len} positions"
)
else:
reserve_rope_cache_recursive(child)

reserve_rope_cache_recursive(model)


# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def calc_diff(x, y):
x, y = x.double(), y.double()
Expand Down
14 changes: 7 additions & 7 deletions test/srt/test_priority_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,35 +291,35 @@ def test_priority_scheduling_with_multiple_running_requests_preemption(self):
def test_priority_scheduling_preemption_token_offset_calculation(self):
"""
Verify correct token offset calculation during preemption.

This test specifically targets the bug where rem_total_token_offset was incorrectly
calculated using the incoming request's tokens instead of the preempted request's tokens
(related to issue #13111 and PR #13201).

THE BUG:
In schedule_policy.py line 700, the code was using:
self.rem_total_token_offset -= self._get_running_request_total_token_offset(req)
Instead of:
self.rem_total_token_offset -= self._get_running_request_total_token_offset(running_req)

WHY THIS TEST CATCHES THE BUG:
- Request 1 (preempted): 8000 tokens - This is what SHOULD be freed
- Request 3 (incoming): 1000 tokens - This is what WAS freed (bug)
- Token difference: 8000 - 1000 = 7000 tokens incorrectly accounted

With the bug, the system thinks it only freed 1000 tokens instead of 8000 tokens.
This causes incorrect memory accounting and can lead to:
1. Scheduler believes less memory is available than actually is
2. Subsequent requests (like Request 4) may fail to schedule or cause issues
3. Memory calculations become increasingly inaccurate with each preemption

The test creates a scenario where:
1. A low-priority request with many tokens (8000) starts running
2. A high-priority request with few tokens (1000) arrives and triggers preemption
3. The system must correctly free 8000 tokens from the preempted request
4. Additional requests can be scheduled only if tokens were correctly freed
5. Execution order validates priority-based scheduling works correctly

The large token difference (8x) makes the bug's impact obvious and testable.
"""
responses = asyncio.run(
Expand Down Expand Up @@ -360,7 +360,7 @@ def test_priority_scheduling_preemption_token_offset_calculation(self):
_verify_genereate_responses(
responses, expected_status_and_error_messages, e2e_latencies
)

# Verify execution order: high priority requests finish before low priority ones
# Request 3 (priority 100) should finish first
# Request 4 (priority 50) should finish second
Expand Down
Loading