diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
index 4045f43ab3f..09caf8ed417 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -199,9 +199,7 @@ def forward_impl(
             topk_group = None
             routed_scaling_factor = None
 
-        # Don't support post-quant allgather for fp8 block scale and has_w4a16_mxfp4 for now.
-        is_post_quant_allgather_supported = self.has_nvfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8
-        run_post_quant_allgather = self.use_dp and self.parallel_size > 1 and is_post_quant_allgather_supported
+        run_post_quant_allgather = self.use_dp and self.parallel_size > 1
 
         x_sf = None
         token_selected_experts = None
@@ -239,6 +237,11 @@ def forward_impl(
                     x, False, alignment=self.quant_method.weight_alignment)
                 # Update x_row and x_col to the padded shape
                 x_row, x_col = x.shape[0], x.shape[1]
+            elif self.has_deepseek_fp8_block_scales:
+                pass
+            elif self.has_w4a16_mxfp4:
+                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+                x = torch.nn.functional.pad(x, (0, pad_size))
             else:
                 raise ValueError(
                     f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
@@ -266,8 +269,8 @@ def forward_impl(
             x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
 
             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 x_val,
                 x_scale,
                 self.w3_w1_weight,
@@ -284,6 +287,8 @@ def forward_impl(
                 self.expert_size_per_partition,  # local_expert_size
                 routed_scaling_factor,
                 self.routing_method.routing_method_type,
+                topk_weights=token_final_scales,
+                topk_ids=token_selected_experts,
             )
         elif self.has_nvfp4:
             scale_factor_use_ue8m0 = False
@@ -324,8 +329,8 @@ def forward_impl(
                 routed_scaling_factor,
                 self.routing_method.routing_method_type,
                 do_finalize=do_finalize,
-                topk_ids=token_selected_experts,
                 topk_weights=token_final_scales,
+                topk_ids=token_selected_experts,
             )
 
             if not do_finalize:
@@ -335,14 +340,17 @@ def forward_impl(
             final_hidden_states = outputs[0]
         elif self.has_w4a16_mxfp4:
             assert x.dtype == torch.bfloat16
+            if not run_post_quant_allgather:
+                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+                x = torch.nn.functional.pad(x, (0, pad_size))
+            else:
+                x = x
 
-            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-            x = torch.nn.functional.pad(x, (0, pad_size))
             intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
                 -2] // 2
             final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
-                router_logits if not run_post_quant_allgather else None,
-                routing_bias if not run_post_quant_allgather else None,
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index ea39655efc2..cde54a61762 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -341,7 +341,6 @@ accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[disable_gemm_allreduce_plug
 accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights SKIP (https://nvbugs/5532023)
 accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin] SKIP (https://nvbugs/5532023)
 accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_tp2cp2 SKIP (https://nvbugs/5532023)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/5537738)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5503479)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5541494)
 unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5541545)
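
The following is a minimal, self-contained sketch (not TensorRT-LLM code; all shapes are hypothetical and chosen for illustration) of the padding arithmetic that the new `has_w4a16_mxfp4` branches above perform: the activation's hidden dimension is zero-padded up to twice the packed weight width, presumably because the mxfp4 (e2m1) weights store two 4-bit values per byte.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only (not taken from TensorRT-LLM).
num_tokens, hidden_size = 4, 2880
x = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)

# Packed fused w3/w1 weight: the last dim holds 2 mxfp4 values per byte,
# so the unpacked hidden size is w3_w1_weight.shape[-1] * 2.
w3_w1_weight = torch.empty(8, 4096, 1536, dtype=torch.uint8)

# Same arithmetic as the diff: zero-pad the activation's hidden dim up to
# the unpacked weight width before handing it to the MoE kernel
# (or, with post-quant allgather, before the allgather).
pad_size = w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x_padded = F.pad(x, (0, pad_size))

assert x_padded.shape == (num_tokens, w3_w1_weight.shape[-1] * 2)
```

The other functional change, passing `router_logits`/`routing_bias` as `None` and supplying `topk_weights=token_final_scales` / `topk_ids=token_selected_experts` when `run_post_quant_allgather` is set, appears to reflect that routing is computed before the allgather, so the MoE runner only needs the already-selected experts and their scales rather than the raw routing inputs.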