diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index cd62f3f76d52..3cd9cd661092 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -313,8 +313,7 @@ def _derive_hybrid_model(self): self.swa_attention_layer_ids, self.full_attention_layer_ids = ( get_hybrid_layer_ids( self.hf_config.architectures, - self.hf_text_config.num_hidden_layers, - getattr(self.hf_text_config, "hybrid_layer_pattern", None), + self.hf_text_config, ) ) @@ -1197,6 +1196,7 @@ def is_hybrid_swa_model(model_architectures: List[str]): hybrid_swa_archs = { "Llama4ForConditionalGeneration", + "GptOssForCausalLM", "MiMoV2FlashForCausalLM", "MiMoV2MTP", } @@ -1205,9 +1205,9 @@ def is_hybrid_swa_model(model_architectures: List[str]): def get_hybrid_layer_ids( model_architectures: List[str], - num_hidden_layers: int, - hybrid_layer_pattern: Optional[List[int]] = None, + hf_text_config: PretrainedConfig, ): + num_hidden_layers = hf_text_config.num_hidden_layers if "Llama4ForConditionalGeneration" in model_architectures: swa_attention_layer_ids = [ i for i in range(num_hidden_layers) if (i + 1) % 4 != 0 @@ -1215,7 +1215,16 @@ def get_hybrid_layer_ids( full_attention_layer_ids = [ i for i in range(num_hidden_layers) if (i + 1) % 4 == 0 ] + elif "GptOssForCausalLM" in model_architectures: + layer_types = getattr(hf_text_config, "layer_types", None) + swa_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "sliding_attention" + ] + full_attention_layer_ids = [ + i for i, x in enumerate(layer_types) if x == "full_attention" + ] elif "MiMoV2FlashForCausalLM" in model_architectures: + hybrid_layer_pattern = getattr(hf_text_config, "hybrid_layer_pattern", None) swa_attention_layer_ids = [ i for i in range(num_hidden_layers) if hybrid_layer_pattern[i] == 1 ] diff --git a/python/sglang/srt/models/utils.py b/python/sglang/srt/models/utils.py index ea2983edc820..a742184620d5 100644 --- a/python/sglang/srt/models/utils.py +++ b/python/sglang/srt/models/utils.py @@ -24,6 +24,7 @@ from sglang.jit_kernel.norm import can_use_fused_inplace_qknorm, fused_inplace_qknorm from sglang.srt.environ import envs from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import get_current_device_stream_fast, is_cuda @@ -109,6 +110,7 @@ def enable_fused_set_kv_buffer(forward_batch: ForwardBatch): _is_cuda and hasattr(forward_batch.token_to_kv_pool, "dtype") and forward_batch.token_to_kv_pool.dtype == torch.bfloat16 + and not isinstance(forward_batch.token_to_kv_pool, SWAKVPool) ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 75b79a15b6dd..6fd8036540f4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1261,6 +1261,16 @@ def _handle_model_specific_adjustments(self): f"- Decode: {decode_attn_backend}\n" ) + if ( + prefill_attn_backend == "trtllm_mha" + or decode_attn_backend == "trtllm_mha" + ): + # TODO: support swa kv indices translation for trtllm_mha attention backend + self.swa_full_tokens_ratio = 1.0 + logger.warning( + "Set swa_full_tokens_ratio to 1.0 for GPT-OSS model with trtllm_mha attention backend." + ) + quant_method = get_quantization_config(hf_config) is_mxfp4_quant_format = quant_method == "mxfp4" if is_mxfp4_quant_format: @@ -1288,7 +1298,6 @@ def _handle_model_specific_adjustments(self): assert ( self.ep_size == 1 ), "Triton kernel MoE is only supported when ep_size == 1" - self.disable_hybrid_swa_memory = True elif "MiMoV2FlashForCausalLM" in model_arch: if self.speculative_algorithm == "EAGLE":