diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index a99b6b16cc4d..d6d1172e2b38 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -362,7 +362,7 @@ def __init__( self.is_hybrid_swa = isinstance( self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator ) - self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache) + self.is_ssm_radix_cache = isinstance(self.tree_cache, MambaRadixCache) self.priority_scheduling_preemption_threshold = ( priority_scheduling_preemption_threshold @@ -387,7 +387,7 @@ def rem_total_tokens(self): self.token_to_kv_pool_allocator.swa_available_size() + self.tree_cache.swa_evictable_size(), ) - elif self.is_hybrid_gdn_cache: + elif self.is_ssm_radix_cache: available_and_evictable = ( self.token_to_kv_pool_allocator.available_size() + self.tree_cache.full_evictable_size() @@ -409,7 +409,7 @@ def cur_rem_tokens(self): self.token_to_kv_pool_allocator.swa_available_size() + self.tree_cache.swa_evictable_size(), ) - elif self.is_hybrid_gdn_cache: + elif self.is_ssm_radix_cache: available_and_evictable = ( self.token_to_kv_pool_allocator.available_size() + self.tree_cache.full_evictable_size() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 59ac30e5b0fc..b801fd8f8e63 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -398,7 +398,10 @@ def __init__( # Hybrid memory pool self.is_hybrid_swa = self.tp_worker.is_hybrid_swa - self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None + self.is_ssm_model = ( + self.tp_worker.model_runner.hybrid_gdn_config is not None + or self.tp_worker.model_runner.mamba2_config is not None + ) if self.is_hybrid_swa: self.sliding_window_size = self.tp_worker.sliding_window_size @@ -762,7 +765,7 @@ def init_cache_with_memory_pool(self): self.tree_cache = SWARadixCache( params=params, sliding_window_size=self.sliding_window_size ) - elif self.is_hybrid_gdn: + elif self.is_ssm_model: from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache self.tree_cache = MambaRadixCache(params) diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 853341d6e22c..f121ed3d4067 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -112,7 +112,7 @@ def log_prefill_stats( f"full token usage: {full_token_usage:.2f}, " f"swa token usage: {swa_token_usage:.2f}, " ) - elif self.is_hybrid_gdn: + elif self.is_ssm_model: ( full_num_used, _, @@ -166,7 +166,7 @@ def log_prefill_stats( self.stats.token_usage = token_usage if self.is_hybrid_swa: self.stats.swa_token_usage = swa_token_usage - if self.is_hybrid_gdn: + if self.is_ssm_model: self.stats.mamba_usage = mamba_usage self.stats.num_queue_reqs = len(self.waiting_queue) self.stats.num_grammar_queue_reqs = len(self.grammar_queue) @@ -238,7 +238,7 @@ def log_decode_stats( f"#swa token: {swa_num_used}, " f"swa token usage: {swa_token_usage:.2f}, " ) - elif self.is_hybrid_gdn: + elif self.is_ssm_model: ( full_num_used, mamba_used, @@ -315,7 +315,7 @@ def log_decode_stats( self.stats.token_usage = token_usage if self.is_hybrid_swa: self.stats.swa_token_usage = swa_token_usage - if self.is_hybrid_gdn: + if self.is_ssm_model: self.stats.mamba_usage = mamba_usage self.stats.gen_throughput = self.last_gen_throughput self.stats.num_queue_reqs = len(self.waiting_queue) @@ -401,7 +401,7 @@ def get_load(self: Scheduler, _: GetLoadReqInput = None) -> GetLoadReqOutput: if self.is_hybrid_swa: full_num_used, swa_num_used, *_ = self._get_swa_token_info() num_tokens = max(full_num_used, swa_num_used) - elif self.is_hybrid_gdn: + elif self.is_ssm_model: num_tokens = self._get_mamba_token_info()[0] else: num_tokens = self._get_token_info()[0] diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index 70ffc90bf8fb..09f1ba548daa 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -204,7 +204,7 @@ def _check_req_pool(self: Scheduler): def check_memory(self: Scheduler): if self.is_hybrid_swa: memory_leak, token_msg = self._check_hybrid_memory() - elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache): + elif self.is_ssm_model and isinstance(self.tree_cache, MambaRadixCache): memory_leak, token_msg = self._check_mamba_memory() else: memory_leak, token_msg = self._check_radix_cache_memory() @@ -239,7 +239,7 @@ def check_memory(self: Scheduler): ) = self._get_swa_token_info() num_used = max(full_num_used, swa_num_used) token_usage = max(full_token_usage, swa_token_usage) - elif self.is_hybrid_gdn: + elif self.is_ssm_model: ( num_used, _, @@ -278,7 +278,7 @@ def check_memory(self: Scheduler): def check_tree_cache(self: Scheduler): if (self.is_hybrid_swa and isinstance(self.tree_cache, SWARadixCache)) or ( - self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache) + self.is_ssm_model and isinstance(self.tree_cache, MambaRadixCache) ): self.tree_cache.sanity_check() @@ -322,7 +322,7 @@ def watchdog_thread(self: Scheduler): # Print batch size and memory pool info to check whether there are de-sync issues. if self.is_hybrid_swa: _, info_msg = self._check_hybrid_memory() - elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache): + elif self.is_ssm_model and isinstance(self.tree_cache, MambaRadixCache): _, info_msg = self._check_mamba_memory() else: _, info_msg = self._check_radix_cache_memory() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 06f384e0ed65..1f628e0f3d53 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -441,11 +441,6 @@ def initialize(self, min_per_gpu_memory: float): if architectures and not any("Llama4" in arch for arch in architectures): self.is_hybrid_swa = self.model_config.is_hybrid_swa = True - if config := self.mamba2_config: - class_name = config.__class__.__name__ - logger.warning(f"{class_name} model detected, disable radix cache") - self.server_args.disable_radix_cache = True - # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to # determine the number of layers. diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a07b71c5a02e..33fda34750c9 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1279,6 +1279,18 @@ def _handle_model_specific_adjustments(self): logger.info( "Use flashinfer_trtllm as MoE runner backend on sm100 for Qwen3NextForCausalLM" ) + elif model_arch in [ + "NemotronHForCausalLM", + "FalconH1ForCausalLM", + "JetNemotronForCausalLM", + "JetVLMForConditionalGeneration", + ]: + if not self.disable_radix_cache: + logger.warning( + "Disabling overlap schedule since MambaRadixCache is not compatible with " + "overlap schedule currently, try to use --disable-radix-cache if overlap schedule is necessary" + ) + self.disable_overlap_schedule = True def _handle_sampling_backend(self): if self.sampling_backend is None: