Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def __init__(
self.is_hybrid_swa = isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache)
self.is_ssm_radix_cache = isinstance(self.tree_cache, MambaRadixCache)

self.priority_scheduling_preemption_threshold = (
priority_scheduling_preemption_threshold
Expand All @@ -387,7 +387,7 @@ def rem_total_tokens(self):
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn_cache:
elif self.is_ssm_radix_cache:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
Expand All @@ -409,7 +409,7 @@ def cur_rem_tokens(self):
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn_cache:
elif self.is_ssm_radix_cache:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
Expand Down
7 changes: 5 additions & 2 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,10 @@ def __init__(

# Hybrid memory pool
self.is_hybrid_swa = self.tp_worker.is_hybrid_swa
self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None
self.is_ssm_model = (
self.tp_worker.model_runner.hybrid_gdn_config is not None
or self.tp_worker.model_runner.mamba2_config is not None
)

if self.is_hybrid_swa:
self.sliding_window_size = self.tp_worker.sliding_window_size
Expand Down Expand Up @@ -762,7 +765,7 @@ def init_cache_with_memory_pool(self):
self.tree_cache = SWARadixCache(
params=params, sliding_window_size=self.sliding_window_size
)
elif self.is_hybrid_gdn:
elif self.is_ssm_model:
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache

self.tree_cache = MambaRadixCache(params)
Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/managers/scheduler_metrics_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def log_prefill_stats(
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
elif self.is_hybrid_gdn:
elif self.is_ssm_model:
(
full_num_used,
_,
Expand Down Expand Up @@ -166,7 +166,7 @@ def log_prefill_stats(
self.stats.token_usage = token_usage
if self.is_hybrid_swa:
self.stats.swa_token_usage = swa_token_usage
if self.is_hybrid_gdn:
if self.is_ssm_model:
self.stats.mamba_usage = mamba_usage
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
Expand Down Expand Up @@ -238,7 +238,7 @@ def log_decode_stats(
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
elif self.is_hybrid_gdn:
elif self.is_ssm_model:
(
full_num_used,
mamba_used,
Expand Down Expand Up @@ -315,7 +315,7 @@ def log_decode_stats(
self.stats.token_usage = token_usage
if self.is_hybrid_swa:
self.stats.swa_token_usage = swa_token_usage
if self.is_hybrid_gdn:
if self.is_ssm_model:
self.stats.mamba_usage = mamba_usage
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
Expand Down Expand Up @@ -401,7 +401,7 @@ def get_load(self: Scheduler, _: GetLoadReqInput = None) -> GetLoadReqOutput:
if self.is_hybrid_swa:
full_num_used, swa_num_used, *_ = self._get_swa_token_info()
num_tokens = max(full_num_used, swa_num_used)
elif self.is_hybrid_gdn:
elif self.is_ssm_model:
num_tokens = self._get_mamba_token_info()[0]
else:
num_tokens = self._get_token_info()[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _check_req_pool(self: Scheduler):
def check_memory(self: Scheduler):
if self.is_hybrid_swa:
memory_leak, token_msg = self._check_hybrid_memory()
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
elif self.is_ssm_model and isinstance(self.tree_cache, MambaRadixCache):
memory_leak, token_msg = self._check_mamba_memory()
else:
memory_leak, token_msg = self._check_radix_cache_memory()
Expand Down Expand Up @@ -239,7 +239,7 @@ def check_memory(self: Scheduler):
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
elif self.is_hybrid_gdn:
elif self.is_ssm_model:
(
num_used,
_,
Expand Down Expand Up @@ -278,7 +278,7 @@ def check_memory(self: Scheduler):

def check_tree_cache(self: Scheduler):
if (self.is_hybrid_swa and isinstance(self.tree_cache, SWARadixCache)) or (
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
self.is_ssm_model and isinstance(self.tree_cache, MambaRadixCache)
):
self.tree_cache.sanity_check()

Expand Down Expand Up @@ -322,7 +322,7 @@ def watchdog_thread(self: Scheduler):
# Print batch size and memory pool info to check whether there are de-sync issues.
if self.is_hybrid_swa:
_, info_msg = self._check_hybrid_memory()
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
elif self.is_ssm_model and isinstance(self.tree_cache, MambaRadixCache):
_, info_msg = self._check_mamba_memory()
else:
_, info_msg = self._check_radix_cache_memory()
Expand Down
5 changes: 0 additions & 5 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,6 @@ def initialize(self, min_per_gpu_memory: float):
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid_swa = self.model_config.is_hybrid_swa = True

if config := self.mamba2_config:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can not enable overlap schedule since it is not fully compatible with MambaRadixCache currently. Fork mamba cache in MambaRadixCache has data race with forward in next batch.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, and I see that the check in the server args initialization only handles Qwen3-Next models. This means that it won't be disabled for JetNemotron models, which also uses GDN attention. I'll update the check there to check for all currently supported architectures, but a better approach is needed there, I think

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I think I forgot to disable for JetNemotron models. It also needs to be disabled.

class_name = config.__class__.__name__
logger.warning(f"{class_name} model detected, disable radix cache")
self.server_args.disable_radix_cache = True

# For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
# models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
# determine the number of layers.
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,18 @@ def _handle_model_specific_adjustments(self):
logger.info(
"Use flashinfer_trtllm as MoE runner backend on sm100 for Qwen3NextForCausalLM"
)
elif model_arch in [
"NemotronHForCausalLM",
"FalconH1ForCausalLM",
"JetNemotronForCausalLM",
"JetVLMForConditionalGeneration",
]:
if not self.disable_radix_cache:
logger.warning(
"Disabling overlap schedule since MambaRadixCache is not compatible with "
"overlap schedule currently, try to use --disable-radix-cache if overlap schedule is necessary"
)
self.disable_overlap_schedule = True

def _handle_sampling_backend(self):
if self.sampling_backend is None:
Expand Down
Loading