diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 9c57b428f44f..dd18da48a43b 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -277,6 +277,11 @@ class Envs: SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096) SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256) + # RoPE cache configuration + SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2) + SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256) + SGLANG_ROPE_CACHE_ALIGN = EnvInt(128) + # Overlap Spec V2 SGLANG_ENABLE_SPEC_V2 = EnvBool(False) SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 8678aaef364c..b14ceaed17f8 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -169,6 +169,39 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: cache = torch.cat((cos, sin), dim=-1) return cache + def _ensure_cos_sin_cache_length(self, needed_max_pos: int): + """Ensure cos_sin_cache length > needed_max_pos.""" + from sglang.srt.environ import envs + + cur_len = int(self.cos_sin_cache.shape[0]) + if needed_max_pos < cur_len: + return + + # Align to reduce realloc frequency + align = envs.SGLANG_ROPE_CACHE_ALIGN.value + new_len = ((needed_max_pos + align) // align) * align + device = self.cos_sin_cache.device + dtype = self.cos_sin_cache.dtype + + # Compute inv_freq on same device + inv_freq = self._compute_inv_freq(self.base).to(device=device) + + # Incremental computation for new positions only + start = cur_len + t_new = torch.arange(start, new_len, dtype=inv_freq.dtype, device=device) + if t_new.numel() == 0: + return + + freqs_new = torch.einsum("i,j->ij", t_new, inv_freq) + cos_new = freqs_new.cos() + sin_new = freqs_new.sin() + new_rows = torch.cat((cos_new, sin_new), dim=-1).to(dtype=dtype) + + # Update cache with new rows + self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0).to( + device=device, dtype=dtype + ) + def forward_native( self, positions: torch.Tensor, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b95759af489d..47fde96ecd2e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -155,6 +155,7 @@ is_npu, log_info_on_rank0, monkey_patch_p2p_access_check, + reserve_rope_cache_for_long_sequences, set_cuda_arch, slow_rank_detector, xpu_has_xmx_support, @@ -862,6 +863,14 @@ def load_model(self): self.pp_rank, ) + # Pre-expand RoPE cache before CUDA Graph capture + reserve_rope_cache_for_long_sequences( + self.model, + self.server_args, + self.model_config, + logger, + ) + if self.server_args.elastic_ep_backend == "mooncake": # Mooncake does not support `monitored_barrier` dist.barrier(group=get_tp_group().cpu_group) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0ac752b93fb9..08560eae131a 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -3633,6 +3633,61 @@ def decorator(fn): return decorator +def reserve_rope_cache_for_long_sequences( + model, server_args, model_config, logger=None +): + """Pre-expand RoPE cache for long sequences and speculative decoding.""" + from sglang.srt.environ import envs + + if logger is None: + import logging + + logger = logging.getLogger(__name__) + + SAFETY_FACTOR = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value + MARGIN = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value + ALIGN = envs.SGLANG_ROPE_CACHE_ALIGN.value + + # 1) Estimate base context upper bound + base_ctx = ( + getattr(server_args, "context_length", None) + or getattr(model_config, "context_len", None) + or getattr(model_config, "max_model_len", None) + or getattr(model_config.hf_text_config, "max_position_embeddings", None) + or 2048 + ) + + # 2) Speculative decoding expansion + steps = int(getattr(server_args, "speculative_num_steps", 0) or 0) + draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0) + reserve = base_ctx + steps * draft * SAFETY_FACTOR + MARGIN + + # 3) Align to reduce reallocation frequency + reserve = (reserve + ALIGN - 1) // ALIGN * ALIGN + + logger.info( + f"RoPE cache reserve={reserve} (cap={base_ctx}, steps={steps}, draft={draft}, k={SAFETY_FACTOR}, margin={MARGIN})" + ) + + # Recursively expand all RoPE layers + def reserve_rope_cache_recursive(module): + for child in module.children(): + if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr( + child, "cos_sin_cache" + ): + old_len = child.cos_sin_cache.shape[0] + child._ensure_cos_sin_cache_length(reserve - 1) + new_len = child.cos_sin_cache.shape[0] + if new_len > old_len: + logger.info( + f"Expanded RoPE cache from {old_len} to {new_len} positions" + ) + else: + reserve_rope_cache_recursive(child) + + reserve_rope_cache_recursive(model) + + # Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py def calc_diff(x, y): x, y = x.double(), y.double() diff --git a/test/srt/test_priority_scheduling.py b/test/srt/test_priority_scheduling.py index 923bcb2d32eb..88d955fa0571 100644 --- a/test/srt/test_priority_scheduling.py +++ b/test/srt/test_priority_scheduling.py @@ -291,35 +291,35 @@ def test_priority_scheduling_with_multiple_running_requests_preemption(self): def test_priority_scheduling_preemption_token_offset_calculation(self): """ Verify correct token offset calculation during preemption. - + This test specifically targets the bug where rem_total_token_offset was incorrectly calculated using the incoming request's tokens instead of the preempted request's tokens (related to issue #13111 and PR #13201). - + THE BUG: In schedule_policy.py line 700, the code was using: self.rem_total_token_offset -= self._get_running_request_total_token_offset(req) Instead of: self.rem_total_token_offset -= self._get_running_request_total_token_offset(running_req) - + WHY THIS TEST CATCHES THE BUG: - Request 1 (preempted): 8000 tokens - This is what SHOULD be freed - Request 3 (incoming): 1000 tokens - This is what WAS freed (bug) - Token difference: 8000 - 1000 = 7000 tokens incorrectly accounted - + With the bug, the system thinks it only freed 1000 tokens instead of 8000 tokens. This causes incorrect memory accounting and can lead to: 1. Scheduler believes less memory is available than actually is 2. Subsequent requests (like Request 4) may fail to schedule or cause issues 3. Memory calculations become increasingly inaccurate with each preemption - + The test creates a scenario where: 1. A low-priority request with many tokens (8000) starts running 2. A high-priority request with few tokens (1000) arrives and triggers preemption 3. The system must correctly free 8000 tokens from the preempted request 4. Additional requests can be scheduled only if tokens were correctly freed 5. Execution order validates priority-based scheduling works correctly - + The large token difference (8x) makes the bug's impact obvious and testable. """ responses = asyncio.run( @@ -360,7 +360,7 @@ def test_priority_scheduling_preemption_token_offset_calculation(self): _verify_genereate_responses( responses, expected_status_and_error_messages, e2e_latencies ) - + # Verify execution order: high priority requests finish before low priority ones # Request 3 (priority 100) should finish first # Request 4 (priority 50) should finish second