diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index 713917f2168..2b988000374 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -419,11 +419,12 @@ def __init__(
                 overridden_tp_size=1 if self.enable_attention_dp else None,
                 layer_idx=layer_idx,
             )
-
+            # TODO(TRTLLM-7809): Fix fusion with PP>1
             self.fusion_config.PRE_MLP_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
-            self.fusion_config.POST_MLP_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
+            ) and not self.enable_attention_dp and self.enable_fusion and not model_config.mapping.has_pp(
+            )
+            self.fusion_config.POST_MLP_FUSION = self.fusion_config.PRE_MLP_FUSION
+
         else:
             self.feed_forward = Llama4MoE(
                 num_experts=config.num_local_experts,
@@ -437,9 +438,9 @@ def __init__(
                 layer_idx=layer_idx)
 
             self.fusion_config.PRE_MOE_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
-            self.fusion_config.POST_MOE_FUSION = model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and self.enable_fusion
+            ) and not self.enable_attention_dp and self.enable_fusion and not model_config.mapping.has_pp(
+            )
+            self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION
 
         self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                        eps=config.rms_norm_eps,
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 481b63de168..588724b3c93 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -696,8 +696,8 @@ def test_chunked_prefill(self, attn_backend):
     @parametrize_with_ids("cuda_graph", [False, True])
     @pytest.mark.parametrize(
         "tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4), (8, 1, 8), (4, 1, 1),
-                                    (4, 1, 2), (4, 1, 4)],
-        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4"])
+                                    (4, 1, 2), (4, 1, 4), (4, 2, 1)],
+        ids=["tp8", "tp8ep4", "tp8ep8", "tp4", "tp4ep2", "tp4ep4", "tp4pp2"])
     def test_fp8(self, cuda_graph, tp_size, pp_size, ep_size):
         if get_device_memory() < 140000 and get_device_count() < 8:
             pytest.skip("Not enough memory for this test")
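
Note (not part of the patch): a minimal sketch of the gating condition the diff introduces, for sanity-checking it in isolation. The Mapping class below is a hypothetical stand-in, not tensorrt_llm's real Mapping; the only assumption carried over is that has_tp()/has_pp() test tp_size > 1 / pp_size > 1.

    from dataclasses import dataclass


    @dataclass
    class Mapping:
        # Hypothetical stand-in: assumes has_tp()/has_pp() simply test
        # tp_size > 1 / pp_size > 1.
        tp_size: int = 1
        pp_size: int = 1

        def has_tp(self) -> bool:
            return self.tp_size > 1

        def has_pp(self) -> bool:
            return self.pp_size > 1


    def fusion_enabled(mapping: Mapping, enable_attention_dp: bool,
                       enable_fusion: bool) -> bool:
        # Mirrors the updated PRE_MLP_FUSION / PRE_MOE_FUSION expression:
        # fusion requires TP, no attention-DP, fusion turned on, and no PP.
        return (mapping.has_tp() and not enable_attention_dp and enable_fusion
                and not mapping.has_pp())


    # tp4 keeps fusion enabled; the new tp4pp2 test configuration disables it.
    assert fusion_enabled(Mapping(tp_size=4), False, True)
    assert not fusion_enabled(Mapping(tp_size=4, pp_size=2), False, True)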