diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3b4659631fc8..b990c80d8d59 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -518,6 +518,7 @@ class ServerArgs: moe_runner_backend: str = "auto" flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default" enable_flashinfer_allreduce_fusion: bool = False + enforce_disable_flashinfer_allreduce_fusion: bool = False enable_aiter_allreduce_fusion: bool = False deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" ep_num_redundant_experts: int = 0 @@ -2080,6 +2081,14 @@ def _handle_model_specific_adjustments(self): ): self.enable_flashinfer_allreduce_fusion = True + # Apply enforce_disable_flashinfer_allreduce_fusion after all model-specific adjustments + if self.enforce_disable_flashinfer_allreduce_fusion: + self.enable_flashinfer_allreduce_fusion = False + logger.info( + "FlashInfer allreduce fusion is forcibly disabled " + "via --enforce-disable-flashinfer-allreduce-fusion." + ) + def _handle_mamba_radix_cache( self, model_arch: str, @@ -4841,6 +4850,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable FlashInfer allreduce fusion with Residual RMSNorm.", ) + parser.add_argument( + "--enforce-disable-flashinfer-allreduce-fusion", + action="store_true", + help="Enforce disable FlashInfer allreduce fusion.", + ) parser.add_argument( "--enable-aiter-allreduce-fusion", action="store_true",