diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 44dc3d67bb98..cdf03ea1ad27 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -711,6 +711,22 @@ def __init__(self, config: VllmConfig) -> None: if self.tp_size <= 1: logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.") return + if config.parallel_config.pipeline_parallel_size > 1: + # flashinfer has issues with multi-group TP configurations. + # See: https://github.com/flashinfer-ai/flashinfer/issues/2647 + logger.warning_once( + "AllReduce fusion pass is disabled for PP+TP due to " + "a flashinfer device assignment limitation." + ) + return + if config.parallel_config.data_parallel_size > 1: + # DP > 1 likewise yields multiple TP groups, hitting the same flashinfer issue. + # See: https://github.com/flashinfer-ai/flashinfer/issues/2647 + logger.warning_once( + "AllReduce fusion pass is disabled for DP+TP due to " + "a flashinfer device assignment limitation." + ) + return self.patterns: PatternMatcherPass = PatternMatcherPass( pass_name="all_reduce_fusion_pass" )