diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c16a0797fac4..aea366d55f7d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1706,13 +1706,6 @@ def _handle_model_specific_adjustments(self): quant_method = get_quantization_config(hf_config) is_mxfp4_quant_format = quant_method == "mxfp4" - if is_blackwell_supported(): - # workaround for https://github.com/flashinfer-ai/flashinfer/issues/2006 - if not self.enable_dp_attention and self.nnodes == 1: - self.enable_flashinfer_allreduce_fusion = True - logger.info( - "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM" - ) if not self.enable_dp_attention and self.nnodes == 1 and is_hip(): # TODO (Hubert): Put this back later # self.enable_aiter_allreduce_fusion = True @@ -2076,6 +2069,7 @@ def _handle_model_specific_adjustments(self): "Qwen3_5ForConditionalGeneration", ] and (is_sm90_supported() or is_sm100_supported()) + and self.tp_size > 1 and not self.enable_dp_attention and self.attn_cp_size <= 1 and self.nnodes == 1 @@ -2083,6 +2077,9 @@ def _handle_model_specific_adjustments(self): and self.moe_a2a_backend == "none" ): self.enable_flashinfer_allreduce_fusion = True + logger.info( + f"Auto-enabling FlashInfer AllReduce Fusion on SM90/SM10X for {model_arch}" + ) def _handle_mamba_radix_cache( self,