diff --git a/python/sglang/srt/layers/gemma4_fused_ops.py b/python/sglang/srt/layers/gemma4_fused_ops.py index e30027776bb3..127c42b29573 100644 --- a/python/sglang/srt/layers/gemma4_fused_ops.py +++ b/python/sglang/srt/layers/gemma4_fused_ops.py @@ -93,6 +93,76 @@ def gemma_rmsnorm_residual_scalar( return out +def gemma4_arf_rmsnorm_residual_scalar( + x: torch.Tensor, + weight: torch.Tensor, + residual: torch.Tensor, + scalar: torch.Tensor, + eps: float = 1e-6, + use_attn_tp_group: bool = True, +) -> torch.Tensor: + """Fused TP all-reduce + (rmsnorm(x) + residual) * scalar for Gemma-4 + dense post-FF combine. + + Numerically equivalent to:: + + x_reduced = tensor_model_parallel_all_reduce(x) + return gemma_rmsnorm_residual_scalar(x_reduced, weight, residual, scalar, eps) + + but, when FlashInfer's fused AllReduce+RMSNorm pattern is applicable on + this step (Hopper/Blackwell, ``--enable-flashinfer-allreduce-fusion``, + batch <= ``FUSE_ALLREDUCE_MAX_BATCH_SIZE``, workspace healthy, etc.), + collapses the TP all-reduce and the residual-add+RMSNorm into a single + TRT-LLM communication kernel that overlaps the collective with the norm + math. The final ``* scalar`` tail runs as a one-launch broadcast mul + (cheap; vectorized point-wise op). + + Caller contract: + * The caller is responsible for passing ``skip_all_reduce=True`` to + the upstream ``RowParallelLinear`` whose output is ``x`` so the + all-reduce is not double-counted. + * ``x`` must be the still-TP-sharded output of that ``down_proj`` + (i.e. the value RowParallelLinear would have all-reduced). + * ``residual`` is the full pre-FF hidden state (already replicated). + * ``scalar`` is the Gemma-4 ``layer_scalar`` persistent buffer + (shape ``[1]``). + * ``use_attn_tp_group=True`` selects the attention-TP group's + FlashInfer workspace; for Gemma-4 (no DP-attn, no MoE-TP split) + this is the full TP group. + + When the fused path is not applicable, falls back to the explicit + ``tensor_model_parallel_all_reduce`` + ``gemma_rmsnorm_residual_scalar`` + sequence with bit-identical semantics to the pre-fusion code path. + """ + # Lazy imports to avoid pulling in distributed/communicator at module + # load time (matches the convention used by other call sites of + # ``flashinfer_allreduce_residual_rmsnorm`` in SGLang). + from sglang.srt.distributed import tensor_model_parallel_all_reduce + from sglang.srt.layers.communicator import apply_flashinfer_allreduce_fusion + from sglang.srt.layers.flashinfer_comm_fusion import ( + flashinfer_allreduce_residual_rmsnorm, + ) + + if x.is_cuda and x.dim() == 2 and apply_flashinfer_allreduce_fusion(x.shape[0]): + norm_out, _residual_out = flashinfer_allreduce_residual_rmsnorm( + input_tensor=x, + residual=residual, + weight=weight, + eps=eps, + use_attn_tp_group=use_attn_tp_group, + ) + if norm_out is not None: + # FlashInfer succeeded; apply the Gemma-4 layer_scalar tail. + # The mul is fused by the eager bf16 elementwise path; one + # extra launch on top of the fused AR+RMSNorm. ``scalar`` is + # shape ``[1]`` so broadcasting is free. + return norm_out * scalar + + # Fallback: identical to the pre-fusion code path. + x_reduced = tensor_model_parallel_all_reduce(x) + return gemma_rmsnorm_residual_scalar(x_reduced, weight, residual, scalar, eps) + + @triton.jit def _gemma_dual_rmsnorm_residual_kernel( X1_ptr, diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py index 2af6216daf47..f5dbe2557650 100644 --- a/python/sglang/srt/models/gemma3_causal.py +++ b/python/sglang/srt/models/gemma3_causal.py @@ -107,10 +107,18 @@ def __init__( self.act_fn = GeluAndMul() self.prefix = prefix - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, skip_all_reduce: bool = False) -> torch.Tensor: + """Forward pass. + + When ``skip_all_reduce=True``, the ``RowParallelLinear.down_proj`` + omits its TP all-reduce so the caller can fuse it into a downstream + operation (see ``gemma4_arf_rmsnorm_residual_scalar`` for the + Gemma-4 post-FF combine fusion). The default is to all-reduce + in-line for back-compat with every other caller. + """ gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) - x, _ = self.down_proj(x) + x, _ = self.down_proj(x, skip_all_reduce=skip_all_reduce) return x diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d4192f947744..174bc8e67f70 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1487,9 +1487,9 @@ def _handle_gpu_memory_settings(self, gpu_mem): ) self.cuda_graph_bs = self._generate_cpu_graph_batch_sizes() - assert ( - self.torch_compile_max_bs > 0 - ), "cuda_graph_bs should contain positive batch sizes" + assert self.torch_compile_max_bs > 0, ( + "cuda_graph_bs should contain positive batch sizes" + ) self.cuda_graph_max_bs = self.torch_compile_max_bs if self.piecewise_cuda_graph_max_tokens is None: @@ -1815,12 +1815,12 @@ def _handle_model_specific_adjustments(self): else: self.enable_dp_attention = True self.moe_dense_tp_size = 1 - assert ( - self.dp_size == 1 - ), "For round-robin split mode, dp attention is not supported." - assert ( - self.tp_size <= 8 - ), "Context parallel only supports single machine (tp_size <= 8). Cross-machine CP has precision issues." + assert self.dp_size == 1, ( + "For round-robin split mode, dp attention is not supported." + ) + assert self.tp_size <= 8, ( + "Context parallel only supports single machine (tp_size <= 8). Cross-machine CP has precision issues." + ) self.attn_cp_size = self.tp_size // self.dp_size logger.warning( @@ -1861,9 +1861,9 @@ def _handle_model_specific_adjustments(self): self._set_default_dsa_backends(self.kv_cache_dtype, major) if self.enable_dsa_prefill_context_parallel: - assert ( - self.disaggregation_mode != "decode" - ), "CP is only supported for prefill when PD disaggregation, please remove --enable-dsa-prefill-context-parallel." + assert self.disaggregation_mode != "decode", ( + "CP is only supported for prefill when PD disaggregation, please remove --enable-dsa-prefill-context-parallel." + ) else: # DeepSeek V3/R1/V3.1 @@ -2098,9 +2098,9 @@ def _handle_model_specific_adjustments(self): ) if self.moe_runner_backend == "triton_kernel": - assert ( - self.ep_size == 1 - ), "Triton kernel MoE is only supported when ep_size == 1" + assert self.ep_size == 1, ( + "Triton kernel MoE is only supported when ep_size == 1" + ) elif model_arch in MIMO_V2_MODEL_ARCHS: if model_arch == "MiMoV2ForCausalLM" and not self.encoder_only: @@ -2182,7 +2182,9 @@ def _handle_model_specific_adjustments(self): "ascend", "trtllm_mha", "intel_xpu", - }, f"fa3, aiter, triton, ascend, trtllm_mha or intel_xpu is required for Llama4 model but got {self.attention_backend}" + }, ( + f"fa3, aiter, triton, ascend, trtllm_mha or intel_xpu is required for Llama4 model but got {self.attention_backend}" + ) if is_sm100_supported() and self.moe_runner_backend == "auto": if self.quantization in {"fp8", "modelopt_fp8"}: self.moe_runner_backend = "flashinfer_trtllm" @@ -2316,9 +2318,9 @@ def _handle_model_specific_adjustments(self): self.disable_hybrid_swa_memory = True # https://docs.sglang.ai/advanced_features/attention_backend.html accepted_backends = ["fa3", "triton", "trtllm_mha"] - assert ( - self.attention_backend in accepted_backends - ), f"One of the attention backends in {accepted_backends} is required for {model_arch}, but got {self.attention_backend}" + assert self.attention_backend in accepted_backends, ( + f"One of the attention backends in {accepted_backends} is required for {model_arch}, but got {self.attention_backend}" + ) elif model_arch in ["Olmo2ForCausalLM"]: # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with Olmo3 model. logger.warning( @@ -2337,9 +2339,9 @@ def _handle_model_specific_adjustments(self): # Flashinfer appears to degrade performance when sliding window attention # is used for the Olmo2 architecture. Olmo2 does not use sliding window attention # but Olmo3 does. - assert ( - self.attention_backend != "flashinfer" - ), "FlashInfer backend can significantly degrade the performance of Olmo3 models." + assert self.attention_backend != "flashinfer", ( + "FlashInfer backend can significantly degrade the performance of Olmo3 models." + ) logger.info( f"Using {self.attention_backend} as attention backend for {model_arch}." @@ -2530,6 +2532,13 @@ def _handle_model_specific_adjustments(self): "Qwen3_5MoeForConditionalGeneration", "InternS2PreviewForConditionalGeneration", "Qwen3_5ForConditionalGeneration", + # Gemma-4 (dense + ConditionalGeneration): the post-FF + # combine path opts into ``gemma4_arf_rmsnorm_residual_scalar`` + # which calls FlashInfer's kARResidualRMSNorm pattern and + # then applies the Gemma layer_scalar tail. See + # ``Gemma4DecoderLayer.forward`` in models/gemma4_causal.py. + "Gemma4ForCausalLM", + "Gemma4ForConditionalGeneration", ] and (is_sm90_supported() or is_sm100_supported()) and self.tp_size > 1 @@ -2576,9 +2585,9 @@ def _handle_mamba_radix_cache( return if not support_mamba_cache_extra_buffer: - assert ( - not self.enable_mamba_extra_buffer() - ), f"mamba extra_buffer is not supported for {model_arch} model" + assert not self.enable_mamba_extra_buffer(), ( + f"mamba extra_buffer is not supported for {model_arch} model" + ) if self.enable_mamba_extra_buffer(): # extra_buffer if self.disable_radix_cache: @@ -2588,23 +2597,25 @@ def _handle_mamba_radix_cache( "Please use --mamba-scheduler-strategy no_buffer instead." ) - assert ( - is_cuda() or is_musa() or is_npu() - ), "Mamba extra_buffer is only supported on CUDA and MUSA and NPU devices with FLA backend" + assert is_cuda() or is_musa() or is_npu(), ( + "Mamba extra_buffer is only supported on CUDA and MUSA and NPU devices with FLA backend" + ) if self.speculative_num_draft_tokens is not None: - assert ( - self.mamba_track_interval >= self.speculative_num_draft_tokens - ), f"mamba_track_interval {self.mamba_track_interval} must be greater than or equal to speculative_num_draft_tokens {self.speculative_num_draft_tokens}" + assert self.mamba_track_interval >= self.speculative_num_draft_tokens, ( + f"mamba_track_interval {self.mamba_track_interval} must be greater than or equal to speculative_num_draft_tokens {self.speculative_num_draft_tokens}" + ) if self.page_size is not None: - assert ( - self.mamba_track_interval % self.page_size == 0 - ), f"mamba_track_interval {self.mamba_track_interval} must be divisible by page_size {self.page_size}" + assert self.mamba_track_interval % self.page_size == 0, ( + f"mamba_track_interval {self.mamba_track_interval} must be divisible by page_size {self.page_size}" + ) assert ( max(FLA_CHUNK_SIZE, self.page_size) % min(FLA_CHUNK_SIZE, self.page_size) == 0 - ), f"For SSM models with extra buffer, either FLA_CHUNK_SIZE or page_size must be divisible by the other, got {FLA_CHUNK_SIZE=}, {self.page_size=}" + ), ( + f"For SSM models with extra buffer, either FLA_CHUNK_SIZE or page_size must be divisible by the other, got {FLA_CHUNK_SIZE=}, {self.page_size=}" + ) elif not self.disable_radix_cache: # no_buffer if self.page_size is not None and self.page_size != 1: logger.warning( @@ -2749,9 +2760,9 @@ def _handle_attention_backend_compatibility(self): "Cuda graph is disabled because of using torch Flex Attention backend" ) self.disable_cuda_graph = True - assert ( - self.speculative_algorithm is None - ), "Speculative decoding is currently not supported with Flex Attention backend" + assert self.speculative_algorithm is None, ( + "Speculative decoding is currently not supported with Flex Attention backend" + ) # Whisper's encoder token padding conflicts with prefix caching. # Only disable for Whisper; other encoder-decoder models (e.g., mllama) use radix cache. @@ -3097,40 +3108,40 @@ def _handle_linear_attn_backend(self): def _handle_context_parallelism(self): if self.attn_cp_size > 1: # The tp_size is the world size, not the real tensor parallel size - assert ( - self.tp_size % self.attn_cp_size == 0 - ), "tp_size must be divisible by attn_cp_size" - assert ( - self.tp_size % (self.dp_size * self.attn_cp_size) == 0 - ), "tp_size must be divisible by dp_size * attn_cp_size" + assert self.tp_size % self.attn_cp_size == 0, ( + "tp_size must be divisible by attn_cp_size" + ) + assert self.tp_size % (self.dp_size * self.attn_cp_size) == 0, ( + "tp_size must be divisible by dp_size * attn_cp_size" + ) - assert ( - not self.enable_aiter_allreduce_fusion - ), "Aiter allreduce fusion is not supported with context parallelism" + assert not self.enable_aiter_allreduce_fusion, ( + "Aiter allreduce fusion is not supported with context parallelism" + ) if self.moe_dp_size > 1: # The tp_size is the world size, not the real tensor parallel size - assert ( - self.tp_size % self.moe_dp_size == 0 - ), "tp_size must be divisible by moe_dp_size" - assert ( - self.ep_size * self.moe_dp_size <= self.tp_size - ), "ep_size * moe_dp_size must be less than or equal to tp_size" + assert self.tp_size % self.moe_dp_size == 0, ( + "tp_size must be divisible by moe_dp_size" + ) + assert self.ep_size * self.moe_dp_size <= self.tp_size, ( + "ep_size * moe_dp_size must be less than or equal to tp_size" + ) assert self.pp_size == 1, "PP is not supported with context parallelism" if self.ep_size > 1: - assert ( - self.ep_size * self.moe_dp_size == self.tp_size - ), "ep_size * moe_dp_size must be equal to tp_size" + assert self.ep_size * self.moe_dp_size == self.tp_size, ( + "ep_size * moe_dp_size must be equal to tp_size" + ) - assert ( - not self.enable_aiter_allreduce_fusion - ), "Aiter allreduce fusion is not supported with context parallelism" + assert not self.enable_aiter_allreduce_fusion, ( + "Aiter allreduce fusion is not supported with context parallelism" + ) if self.attn_cp_size != self.moe_dp_size: - assert ( - self.moe_dp_size == 1 - ), "attn_cp_size != moe_dp_size is only supported when moe_dp_size == 1" + assert self.moe_dp_size == 1, ( + "attn_cp_size != moe_dp_size is only supported when moe_dp_size == 1" + ) def _handle_data_parallelism(self): if self.dp_size == 1: @@ -3146,9 +3157,9 @@ def _handle_data_parallelism(self): ) if self.enable_dp_lm_head: - assert ( - self.enable_dp_attention - ), "Please enable dp attention when setting enable_dp_lm_head. " + assert self.enable_dp_attention, ( + "Please enable dp attention when setting enable_dp_lm_head. " + ) def _handle_moe_kernel_config(self): if self.quantization == "mxfp8": @@ -3172,20 +3183,26 @@ def _handle_moe_kernel_config(self): "modelopt_fp8", "modelopt_mixed", None, - ], f"Invalid quantization '{self.quantization}'. \nFlashInfer Cutlass MOE supports only: 'modelopt_fp4', 'modelopt_fp8', 'modelopt_mixed', or bfloat16 (None)." + ], ( + f"Invalid quantization '{self.quantization}'. \nFlashInfer Cutlass MOE supports only: 'modelopt_fp4', 'modelopt_fp8', 'modelopt_mixed', or bfloat16 (None)." + ) assert self.ep_size in [ 1, self.tp_size, - ], "The expert parallel size must be 1 or the same as the tensor parallel size" + ], ( + "The expert parallel size must be 1 or the same as the tensor parallel size" + ) if self.moe_runner_backend == "flashinfer_cutedsl": - assert self.quantization in [ - "modelopt_fp4" - ], f"Invalid quantization '{self.quantization}'. \nFlashInfer CuteDSL MOE currently supports only: 'modelopt_fp4'." + assert self.quantization in ["modelopt_fp4"], ( + f"Invalid quantization '{self.quantization}'. \nFlashInfer CuteDSL MOE currently supports only: 'modelopt_fp4'." + ) assert self.ep_size in [ 1, self.tp_size, - ], "The expert parallel size must be 1 or the same as the tensor parallel size" + ], ( + "The expert parallel size must be 1 or the same as the tensor parallel size" + ) assert self.moe_a2a_backend in [ "none", "deepep", @@ -3208,7 +3225,9 @@ def _handle_moe_kernel_config(self): "modelopt_mixed", "compressed-tensors", None, - ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', 'modelopt_mixed', 'compressed-tensors', or bfloat16 (None)." + ], ( + f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', 'modelopt_mixed', 'compressed-tensors', or bfloat16 (None)." + ) self.disable_shared_experts_fusion = True logger.warning( "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set." @@ -3220,7 +3239,9 @@ def _handle_moe_kernel_config(self): "mxfp8", "modelopt_fp4", None, - ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', 'modelopt_fp4', or bfloat16 (None)." + ], ( + f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', 'modelopt_fp4', or bfloat16 (None)." + ) self.disable_shared_experts_fusion = True logger.warning( "FlashInfer TRTLLM routed MoE is enabled. --disable-shared-experts-fusion is automatically set." @@ -3239,9 +3260,9 @@ def _handle_moe_kernel_config(self): "fp8", "mxfp8", ]: - assert ( - self.ep_size == 1 - ), "FP8/MXFP8 Cutlass MoE is only supported with ep_size == 1" + assert self.ep_size == 1, ( + "FP8/MXFP8 Cutlass MoE is only supported with ep_size == 1" + ) # TODO(yuwei): Fix piecewise cuda graph support for bypassed topk MoE backends. # Exception: GptOssForCausalLM wraps the entire MoE block in its own @@ -3328,13 +3349,13 @@ def _handle_a2a_moe(self): f"Wrong value of {fuse_mode=}, the NPU only support 1 or 2." ) elif fuse_mode == 2: - assert ( - self.quantization == "modelslim" - ), "When fuse_mode is set to 2, the NPU supports only ModelSlim quantization." + assert self.quantization == "modelslim", ( + "When fuse_mode is set to 2, the NPU supports only ModelSlim quantization." + ) if self.moe_a2a_backend == "flashinfer": - assert ( - self.enable_dp_attention and self.dp_size == self.tp_size - ), "Flashinfer MoE A2A is only supported with dp_size == tp_size and --enable-dp-attention" + assert self.enable_dp_attention and self.dp_size == self.tp_size, ( + "Flashinfer MoE A2A is only supported with dp_size == tp_size and --enable-dp-attention" + ) self.ep_size = self.tp_size logger.warning( f"Flashinfer MoE A2A is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." @@ -3353,7 +3374,9 @@ def _handle_a2a_moe(self): assert self.moe_runner_backend in [ "flashinfer_cutlass", "flashinfer_cutedsl", - ], "Flashinfer MoE A2A is only supported with flashinfer_cutlass or flashinfer_cutedsl moe runner backend" + ], ( + "Flashinfer MoE A2A is only supported with flashinfer_cutlass or flashinfer_cutedsl moe runner backend" + ) if ( self.moe_runner_backend == "flashinfer_cutedsl" and self.max_prefill_tokens is not None @@ -3401,7 +3424,9 @@ def _handle_a2a_moe(self): if self.chunked_prefill_size > 0 and self.disaggregation_mode != "decode": assert ( self.chunked_prefill_size - ) <= envs.SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK.get(), "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK (default 4096) must be larger or equal to chunked_prefill_size" + ) <= envs.SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK.get(), ( + "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK (default 4096) must be larger or equal to chunked_prefill_size" + ) def _handle_eplb_and_dispatch(self): if self.enable_eplb and (self.expert_distribution_recorder_mode is None): @@ -3426,7 +3451,9 @@ def _handle_elastic_ep(self): assert self.eplb_algorithm in [ "elasticity_aware", "elasticity_aware_hierarchical", - ], "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware(_hierarchical)'." + ], ( + "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware(_hierarchical)'." + ) assert self.pp_size == 1, "PP size should be set to 1 under elastic EP" @@ -3435,9 +3462,9 @@ def _handle_elastic_ep(self): self.mooncake_ib_device ) if self.elastic_ep_rejoin: - assert ( - self.elastic_ep_backend is not None - ), "Elastic EP rejoin requires elastic_ep_backend to be set." + assert self.elastic_ep_backend is not None, ( + "Elastic EP rejoin requires elastic_ep_backend to be set." + ) def _handle_expert_distribution_metrics(self): if self.enable_expert_distribution_metrics and ( @@ -3806,9 +3833,9 @@ def _handle_pd_disaggregation(self): self.mamba_scheduler_strategy = "no_buffer" elif self.disaggregation_mode == "prefill": - assert ( - self.disaggregation_transfer_backend != "fake" - ), "Prefill server does not support 'fake' as the transfer backend" + assert self.disaggregation_transfer_backend != "fake", ( + "Prefill server does not support 'fake' as the transfer backend" + ) self.disable_cuda_graph = True @@ -7031,9 +7058,9 @@ def mamba_cache_chunk_size(self) -> int: def check_server_args(self): # Check parallel size constraints - assert ( - self.tp_size * self.pp_size - ) % self.nnodes == 0, "tp_size must be divisible by number of nodes" + assert (self.tp_size * self.pp_size) % self.nnodes == 0, ( + "tp_size must be divisible by number of nodes" + ) assert ( self.pp_max_micro_batch_size is None or self.pp_max_micro_batch_size >= 1 @@ -7053,7 +7080,9 @@ def check_server_args(self): if self.pp_size > 1: assert ( self.disable_overlap_schedule and self.speculative_algorithm is None - ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding" + ), ( + "Pipeline parallelism is not compatible with overlap schedule, speculative decoding" + ) assert not ( self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention @@ -7080,32 +7109,32 @@ def check_server_args(self): # Check speculative decoding if self.speculative_algorithm is not None: - assert ( - not self.enable_mixed_chunk - ), "enable_mixed_chunk is required for speculative decoding" + assert not self.enable_mixed_chunk, ( + "enable_mixed_chunk is required for speculative decoding" + ) # Check chunked prefill # Skip validation if chunked prefill is disabled (i.e., size <= 0). # Skip validation if disaggregation mode is decode. if self.chunked_prefill_size > 0 and self.disaggregation_mode != "decode": - assert ( - self.chunked_prefill_size % self.page_size == 0 - ), "chunked_prefill_size must be divisible by page_size" + assert self.chunked_prefill_size % self.page_size == 0, ( + "chunked_prefill_size must be divisible by page_size" + ) # Check pdmux if self.enable_pdmux: - assert ( - self.pp_size == 1 - ), "PD-Multiplexing is only supported with pipeline parallelism disabled (pp_size=1)." - assert ( - self.chunked_prefill_size == -1 - ), "PD-Multiplexing is not compatible with chunked prefill." - assert ( - self.disaggregation_mode == "null" - ), "PD-Multiplexing is not compatible with disaggregation mode." - assert ( - self.disable_overlap_schedule - ), "PD-Multiplexing is not compatible with overlap schedule." + assert self.pp_size == 1, ( + "PD-Multiplexing is only supported with pipeline parallelism disabled (pp_size=1)." + ) + assert self.chunked_prefill_size == -1, ( + "PD-Multiplexing is not compatible with chunked prefill." + ) + assert self.disaggregation_mode == "null", ( + "PD-Multiplexing is not compatible with disaggregation mode." + ) + assert self.disable_overlap_schedule, ( + "PD-Multiplexing is not compatible with overlap schedule." + ) # NOTE: CUDA Green Context may encounter potential issues with CudaGraph on torch 2.7.x – 2.8.x, leading to performance degradation. import torch @@ -7131,7 +7160,9 @@ def check_server_args(self): assert self.schedule_policy in [ "fcfs", "lof", - ], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported." + ], ( + f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported." + ) if self.default_priority_value is None: logger.warning( "--default-priority-value is not set while --enable-priority-scheduling is enabled. " @@ -7153,9 +7184,9 @@ def check_server_args(self): validate_hisparse(self) - assert ( - self.schedule_conservativeness >= 0 - ), "schedule_conservativeness must be non-negative" + assert self.schedule_conservativeness >= 0, ( + "schedule_conservativeness must be non-negative" + ) if self.model_impl == "mindspore": assert is_npu(), "MindSpore model impl is only supported on Ascend npu." @@ -7274,9 +7305,9 @@ def check_lora_server_args(self): pinned=False, ) elif isinstance(lora_path, dict): - assert ( - "lora_name" in lora_path and "lora_path" in lora_path - ), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}" + assert "lora_name" in lora_path and "lora_path" in lora_path, ( + f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}" + ) lora_ref = LoRARef( lora_id=LoRARef.deterministic_id( lora_path["lora_name"], lora_path["lora_path"] @@ -7314,14 +7345,16 @@ def check_lora_server_args(self): if self.lora_target_modules: self.lora_target_modules = set(self.lora_target_modules) if "all" in self.lora_target_modules: - assert ( - len(self.lora_target_modules) == 1 - ), "If 'all' is specified in --lora-target-modules, it should be the only module specified." + assert len(self.lora_target_modules) == 1, ( + "If 'all' is specified in --lora-target-modules, it should be the only module specified." + ) # Ensure sufficient information is provided for LoRA initialization. assert self.lora_paths or ( self.max_lora_rank and self.lora_target_modules - ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." + ), ( + "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." + ) # Validate max_loaded_loras if self.max_loaded_loras is not None: @@ -7343,9 +7376,9 @@ def check_lora_server_args(self): if self.lora_use_virtual_experts: logger.info("Virtual expert computation enabled.") - assert ( - self.lora_drain_wait_threshold >= 0.0 - ), "--lora-drain-wait-threshold must be non-negative." + assert self.lora_drain_wait_threshold >= 0.0, ( + "--lora-drain-wait-threshold must be non-negative." + ) def validate_buckets_rule(self, arg_name: str, buckets_rule: List[str]): if not buckets_rule: @@ -7357,43 +7390,45 @@ def validate_buckets_rule(self, arg_name: str, buckets_rule: List[str]): "tse", "default", "custom", - ], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'custom'" + ], ( + f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'custom'" + ) if rule == "tse": - assert ( - len(buckets_rule) == 4 - ), f"{arg_name} TSE rule requires exactly 4 parameters: ['tse', middle, base, count], got {len(buckets_rule)}" + assert len(buckets_rule) == 4, ( + f"{arg_name} TSE rule requires exactly 4 parameters: ['tse', middle, base, count], got {len(buckets_rule)}" + ) try: middle = float(buckets_rule[1]) base = float(buckets_rule[2]) count = int(buckets_rule[3]) except (ValueError, IndexError): - assert ( - False - ), f"{arg_name} TSE rule parameters must be: ['tse', , , ]" + assert False, ( + f"{arg_name} TSE rule parameters must be: ['tse', , , ]" + ) assert base > 1, f"{arg_name} TSE base must be larger than 1, got: {base}" assert count > 0, f"{arg_name} TSE count must be positive, got: {count}" assert middle > 0, f"{arg_name} TSE middle must be positive, got: {middle}" elif rule == "default": - assert ( - len(buckets_rule) == 1 - ), f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}" + assert len(buckets_rule) == 1, ( + f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}" + ) elif rule == "custom": - assert ( - len(buckets_rule) >= 2 - ), f"{arg_name} custom rule requires at least one bucket value: ['custom', value1, ...]" + assert len(buckets_rule) >= 2, ( + f"{arg_name} custom rule requires at least one bucket value: ['custom', value1, ...]" + ) try: bucket_values = [float(x) for x in buckets_rule[1:]] except ValueError: assert False, f"{arg_name} custom rule bucket values must be numeric" - assert len(set(bucket_values)) == len( - bucket_values - ), f"{arg_name} custom rule bucket values should not contain duplicates" - assert all( - val >= 0 for val in bucket_values - ), f"{arg_name} custom rule bucket values should be non-negative" + assert len(set(bucket_values)) == len(bucket_values), ( + f"{arg_name} custom rule bucket values should not contain duplicates" + ) + assert all(val >= 0 for val in bucket_values), ( + f"{arg_name} custom rule bucket values should be non-negative" + ) def adjust_mem_fraction_for_vlm(self, model_config): vision_config = getattr(model_config.hf_config, "vision_config", None) diff --git a/test/registered/unit/layers/test_gemma4_arf_ops.py b/test/registered/unit/layers/test_gemma4_arf_ops.py new file mode 100644 index 000000000000..916ef7d6a5ff --- /dev/null +++ b/test/registered/unit/layers/test_gemma4_arf_ops.py @@ -0,0 +1,208 @@ +"""Unit tests for ``gemma4_arf_rmsnorm_residual_scalar``. + +Three coverage points: +1. Success path: when the FlashInfer fused kernel returns a non-None + ``norm_out``, the wrapper returns ``norm_out * scalar`` and does NOT + call ``tensor_model_parallel_all_reduce``. +2. Fallback path: when the FlashInfer fused kernel returns ``(None, None)`` + (e.g. flashinfer unavailable, workspace not ready, non-contig input, + batch too large), the wrapper falls back to + ``tensor_model_parallel_all_reduce`` + ``gemma_rmsnorm_residual_scalar`` + with bit-identical semantics. +3. Predicate-off path: when ``apply_flashinfer_allreduce_fusion`` returns + False (e.g. flag disabled), the wrapper takes the fallback path + directly without even calling FlashInfer. +""" + +import unittest +from unittest.mock import patch + +import torch + +from sglang.srt.layers.gemma4_fused_ops import gemma4_arf_rmsnorm_residual_scalar + + +class TestGemma4ArfRmsnormResidualScalar(unittest.TestCase): + """All three branches of the new wrapper, with FlashInfer + all-reduce + fully mocked so the test runs on CPU.""" + + def setUp(self): + # Use a tiny CUDA tensor to satisfy the ``x.is_cuda`` gate in the + # wrapper without requiring real CUDA: the wrapper's branch + # condition reads attributes; we provide a fake CUDA tensor via + # mocking ``is_cuda`` to True on a CPU tensor. + self.T = 4 # tokens + self.H = 16 # hidden + self.scalar_val = 2.5 + # All test tensors live on CPU; we patch ``is_cuda`` per-test. + self.x = torch.randn(self.T, self.H) + self.weight = torch.randn(self.H) + self.residual = torch.randn(self.T, self.H) + self.scalar = torch.tensor([self.scalar_val]) + + def _force_cuda(self, *tensors): + for t in tensors: + # SimpleNamespace would break PyTorch ops; we instead patch the + # is_cuda *property* on the tensor's class via a context. + pass + + def test_success_path_uses_flashinfer_and_applies_scalar(self): + # Sentinel "fused" tensor returned by flashinfer. In real life it + # would be (norm(allreduce(x) + residual) * weight). Here it's + # just a known tensor so the test can assert ``out == sentinel * + # scalar`` without re-implementing the kernel. + sentinel_norm = torch.full_like(self.x, fill_value=1.5) + sentinel_residual = torch.full_like(self.residual, fill_value=0.5) + + with ( + patch( + "sglang.srt.layers.gemma4_fused_ops.gemma_rmsnorm_residual_scalar" + ) as mock_kernel, + patch( + "sglang.srt.layers.communicator.apply_flashinfer_allreduce_fusion", + return_value=True, + ), + patch( + "sglang.srt.layers.flashinfer_comm_fusion.flashinfer_allreduce_residual_rmsnorm", + return_value=(sentinel_norm, sentinel_residual), + ), + patch("sglang.srt.distributed.tensor_model_parallel_all_reduce") as mock_ar, + patch.object(torch.Tensor, "is_cuda", property(lambda self: True)), + ): + out = gemma4_arf_rmsnorm_residual_scalar( + self.x, + self.weight, + self.residual, + self.scalar, + eps=1e-6, + ) + + # The wrapper should return sentinel_norm * scalar. + expected = sentinel_norm * self.scalar + torch.testing.assert_close(out, expected) + # And critically: neither the AR helper nor the fallback kernel + # was invoked, because the fused path succeeded. + mock_ar.assert_not_called() + mock_kernel.assert_not_called() + + def test_fallback_when_flashinfer_returns_none(self): + # FlashInfer's wrapper returns (None, None) when its preconditions + # aren't met at runtime (e.g. workspace init failed, + # non-contiguous tensors). Wrapper should fall back to the + # ``tensor_model_parallel_all_reduce + gemma_rmsnorm_residual_scalar`` + # pair with bit-identical semantics. + reduced_sentinel = torch.full_like(self.x, fill_value=7.7) + kernel_sentinel = torch.full_like(self.x, fill_value=3.3) + + with ( + patch( + "sglang.srt.layers.gemma4_fused_ops.gemma_rmsnorm_residual_scalar", + return_value=kernel_sentinel, + ) as mock_kernel, + patch( + "sglang.srt.layers.communicator.apply_flashinfer_allreduce_fusion", + return_value=True, + ), + patch( + "sglang.srt.layers.flashinfer_comm_fusion.flashinfer_allreduce_residual_rmsnorm", + return_value=(None, None), + ), + patch( + "sglang.srt.distributed.tensor_model_parallel_all_reduce", + return_value=reduced_sentinel, + ) as mock_ar, + patch.object(torch.Tensor, "is_cuda", property(lambda self: True)), + ): + out = gemma4_arf_rmsnorm_residual_scalar( + self.x, + self.weight, + self.residual, + self.scalar, + eps=1e-6, + ) + + # Wrapper must have called the AR + fallback kernel. + mock_ar.assert_called_once_with(self.x) + mock_kernel.assert_called_once() + # And returned the kernel's output verbatim (no extra scalar mul + # because the fallback kernel already applies the scalar). + self.assertIs(out, kernel_sentinel) + + def test_predicate_off_uses_fallback_directly(self): + # When apply_flashinfer_allreduce_fusion(...) is False (e.g. flag + # disabled), the wrapper must take the fallback path without even + # invoking flashinfer_allreduce_residual_rmsnorm. + reduced_sentinel = torch.full_like(self.x, fill_value=7.7) + kernel_sentinel = torch.full_like(self.x, fill_value=3.3) + + with ( + patch( + "sglang.srt.layers.gemma4_fused_ops.gemma_rmsnorm_residual_scalar", + return_value=kernel_sentinel, + ) as mock_kernel, + patch( + "sglang.srt.layers.communicator.apply_flashinfer_allreduce_fusion", + return_value=False, + ), + patch( + "sglang.srt.layers.flashinfer_comm_fusion.flashinfer_allreduce_residual_rmsnorm", + ) as mock_fi, + patch( + "sglang.srt.distributed.tensor_model_parallel_all_reduce", + return_value=reduced_sentinel, + ) as mock_ar, + patch.object(torch.Tensor, "is_cuda", property(lambda self: True)), + ): + out = gemma4_arf_rmsnorm_residual_scalar( + self.x, + self.weight, + self.residual, + self.scalar, + eps=1e-6, + ) + + mock_fi.assert_not_called() + mock_ar.assert_called_once_with(self.x) + mock_kernel.assert_called_once() + self.assertIs(out, kernel_sentinel) + + def test_non_cuda_input_takes_fallback(self): + # CPU tensors short-circuit through the fallback (the ``is_cuda`` + # gate prevents flashinfer from ever being called). + reduced_sentinel = torch.full_like(self.x, fill_value=7.7) + kernel_sentinel = torch.full_like(self.x, fill_value=3.3) + + with ( + patch( + "sglang.srt.layers.gemma4_fused_ops.gemma_rmsnorm_residual_scalar", + return_value=kernel_sentinel, + ) as mock_kernel, + patch( + "sglang.srt.layers.communicator.apply_flashinfer_allreduce_fusion", + ) as mock_pred, + patch( + "sglang.srt.layers.flashinfer_comm_fusion.flashinfer_allreduce_residual_rmsnorm", + ) as mock_fi, + patch( + "sglang.srt.distributed.tensor_model_parallel_all_reduce", + return_value=reduced_sentinel, + ) as mock_ar, + ): + # Note: NOT patching is_cuda — leave it False on the CPU tensor. + out = gemma4_arf_rmsnorm_residual_scalar( + self.x, + self.weight, + self.residual, + self.scalar, + eps=1e-6, + ) + + mock_pred.assert_not_called() # short-circuited before predicate + mock_fi.assert_not_called() + mock_ar.assert_called_once_with(self.x) + mock_kernel.assert_called_once() + self.assertIs(out, kernel_sentinel) + + +if __name__ == "__main__": + unittest.main()