Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ class Envs:
SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE = EnvInt(4096)
SGLANG_TRITON_DECODE_SPLIT_TILE_SIZE = EnvInt(256)

# RoPE cache configuration
SGLANG_SPEC_EXPANSION_SAFETY_FACTOR = EnvInt(2)
SGLANG_ROPE_CACHE_SAFETY_MARGIN = EnvInt(256)
SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)

# Overlap Spec V2
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)

Expand Down
30 changes: 30 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,36 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
cache = torch.cat((cos, sin), dim=-1)
return cache

def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
"""Ensure cos_sin_cache length > needed_max_pos."""
cur_len = int(self.cos_sin_cache.shape[0])
if needed_max_pos < cur_len:
return

# Align to 128 to reduce realloc frequency
new_len = ((needed_max_pos + 128) // 128) * 128
device = self.cos_sin_cache.device
dtype = self.cos_sin_cache.dtype

# Compute inv_freq on same device
inv_freq = self._compute_inv_freq(self.base).to(device=device)

# Incremental computation for new positions only
start = cur_len
t_new = torch.arange(start, new_len, dtype=inv_freq.dtype, device=device)
if t_new.numel() == 0:
return

freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
cos_new = freqs_new.cos()
sin_new = freqs_new.sin()
new_rows = torch.cat((cos_new, sin_new), dim=-1).to(dtype=dtype)

# Update cache with new rows
self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0).to(
device=device, dtype=dtype
)

def forward_native(
self,
positions: torch.Tensor,
Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
log_info_on_rank0,
monkey_patch_p2p_access_check,
monkey_patch_vllm_gguf_config,
reserve_rope_cache_for_long_sequences,
set_cuda_arch,
slow_rank_detector,
)
Expand Down Expand Up @@ -898,6 +899,15 @@ def load_model(self):
f"mem usage={self.weight_load_mem_usage:.2f} GB."
)

# Pre-expand RoPE cache before CUDA Graph capture
reserve_rope_cache_for_long_sequences(
self.model,
self.server_args,
self.model_config,
self.req_to_token_pool,
logger,
)

if self.server_args.elastic_ep_backend == "mooncake":
# Mooncake does not support `monitored_barrier`
dist.barrier(group=get_tp_group().cpu_group)
Expand Down
58 changes: 58 additions & 0 deletions python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3460,3 +3460,61 @@ def decorator(fn):
return CachedKernel(fn, key_fn)

return decorator


def reserve_rope_cache_for_long_sequences(
    model, server_args, model_config, req_to_token_pool=None, logger=None
):
    """Pre-expand RoPE cos/sin caches for long sequences and speculative decoding.

    Walks ``model`` recursively and, for every module exposing both a
    ``cos_sin_cache`` tensor and an ``_ensure_cos_sin_cache_length`` method,
    grows the cache to cover the estimated maximum position. Intended to run
    once after model load (before CUDA graph capture) so no reallocation
    happens at runtime.

    Args:
        model: Root module to scan for RoPE layers.
        server_args: Server configuration; ``context_length`` and the
            ``speculative_*`` fields are read if present.
        model_config: Model configuration, consulted as a fallback for the
            base context length.
        req_to_token_pool: Optional pool whose ``max_context_len`` reflects
            the true runtime input capacity (including any extra length).
        logger: Optional logger; a module-level logger is used when omitted.
    """
    from sglang.srt.environ import envs

    if logger is None:
        import logging

        logger = logging.getLogger(__name__)

    safety_factor = envs.SGLANG_SPEC_EXPANSION_SAFETY_FACTOR.value
    margin = envs.SGLANG_ROPE_CACHE_SAFETY_MARGIN.value
    align = envs.SGLANG_ROPE_CACHE_ALIGN.value

    # 1) Estimate the base context upper bound, most specific source first.
    #    The hf_text_config lookup is guarded too: model_config may not
    #    carry that attribute, and an unguarded access would raise.
    base_ctx = (
        getattr(server_args, "context_length", None)
        or getattr(model_config, "context_len", None)
        or getattr(model_config, "max_model_len", None)
        or getattr(
            getattr(model_config, "hf_text_config", None),
            "max_position_embeddings",
            None,
        )
        or 2048
    )

    # 2) Runtime input capacity (including extra_len from req_to_token_pool).
    inferred_cap = getattr(req_to_token_pool, "max_context_len", None) or base_ctx

    # 3) Speculative decoding expansion: draft tokens can push positions past
    #    the input capacity, so reserve extra headroom scaled by safety_factor.
    steps = int(getattr(server_args, "speculative_num_steps", 0) or 0)
    draft = int(getattr(server_args, "speculative_num_draft_tokens", 0) or 0)
    reserve = inferred_cap + steps * draft * safety_factor + margin

    # 4) Align upward to reduce reallocation frequency.
    reserve = (reserve + align - 1) // align * align

    logger.info(
        "RoPE cache reserve=%d (base=%s, cap=%s, steps=%d, draft=%d, k=%d, margin=%d)",
        reserve,
        base_ctx,
        inferred_cap,
        steps,
        draft,
        safety_factor,
        margin,
    )

    # Recursively expand all RoPE layers. A module that itself owns a RoPE
    # cache is treated as a leaf (its children are not descended into),
    # matching the original traversal.
    def reserve_rope_cache_recursive(module):
        for child in module.children():
            if hasattr(child, "_ensure_cos_sin_cache_length") and hasattr(
                child, "cos_sin_cache"
            ):
                old_len = child.cos_sin_cache.shape[0]
                # reserve - 1 because _ensure guarantees length > its
                # argument, yielding a final length >= reserve.
                child._ensure_cos_sin_cache_length(reserve - 1)
                new_len = child.cos_sin_cache.shape[0]
                if new_len > old_len:
                    logger.info(
                        "Expanded RoPE cache from %d to %d positions",
                        old_len,
                        new_len,
                    )
            else:
                reserve_rope_cache_recursive(child)

    reserve_rope_cache_recursive(model)
Loading