28 changes: 18 additions & 10 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -199,9 +199,7 @@ def forward_impl(
             topk_group = None
             routed_scaling_factor = None
 
-        # Don't support post-quant allgather for fp8 block scale and has_w4a16_mxfp4 for now.
-        is_post_quant_allgather_supported = self.has_nvfp4 or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8
-        run_post_quant_allgather = self.use_dp and self.parallel_size > 1 and is_post_quant_allgather_supported
+        run_post_quant_allgather = self.use_dp and self.parallel_size > 1
 
         x_sf = None
         token_selected_experts = None
@@ -239,6 +237,11 @@ def forward_impl(
                     x, False, alignment=self.quant_method.weight_alignment)
                 # Update x_row and x_col to the padded shape
                 x_row, x_col = x.shape[0], x.shape[1]
+            elif self.has_deepseek_fp8_block_scales:
+                pass
+            elif self.has_w4a16_mxfp4:
+                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+                x = torch.nn.functional.pad(x, (0, pad_size))
             else:
                 raise ValueError(
                     f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
@@ -266,8 +269,8 @@ def forward_impl(
                 x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
 
             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 x_val,
                 x_scale,
                 self.w3_w1_weight,
@@ -284,6 +287,8 @@ def forward_impl(
                 self.expert_size_per_partition,  # local_expert_size
                 routed_scaling_factor,
                 self.routing_method.routing_method_type,
+                topk_weights=token_final_scales,
+                topk_ids=token_selected_experts,
             )
         elif self.has_nvfp4:
             scale_factor_use_ue8m0 = False
@@ -324,8 +329,8 @@ def forward_impl(
                 routed_scaling_factor,
                 self.routing_method.routing_method_type,
                 do_finalize=do_finalize,
-                topk_ids=token_selected_experts,
                 topk_weights=token_final_scales,
+                topk_ids=token_selected_experts,
             )
 
             if not do_finalize:
@@ -335,14 +340,17 @@ def forward_impl(
                 final_hidden_states = outputs[0]
         elif self.has_w4a16_mxfp4:
             assert x.dtype == torch.bfloat16
+            if not run_post_quant_allgather:
+                pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
+                x = torch.nn.functional.pad(x, (0, pad_size))
+            else:
+                x = x
 
-            pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
-            x = torch.nn.functional.pad(x, (0, pad_size))
             intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
                 -2] // 2
             final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
-                router_logits,
-                routing_bias,
+                router_logits if not run_post_quant_allgather else None,
+                routing_bias if not run_post_quant_allgather else None,
                 x,
                 self.w3_w1_weight,
                 self.w3_w1_weight_scale,
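Across the runner calls above, the same pattern repeats: when `run_post_quant_allgather` is taken, routing has already been resolved per rank before the allgather, so the kernels receive `None` for `router_logits`/`routing_bias` and the gathered top-k results via `topk_weights`/`topk_ids` instead. Below is a hedged sketch of that call-site pattern; `moe_runner` and the argument handling are stand-ins, not the actual `torch.ops.trtllm` signatures.

```python
# Sketch of the dispatch pattern, not the real kernel API.
from typing import Optional

import torch


def call_moe_runner(moe_runner,
                    router_logits: torch.Tensor,
                    routing_bias: Optional[torch.Tensor],
                    run_post_quant_allgather: bool,
                    token_selected_experts: Optional[torch.Tensor],
                    token_final_scales: Optional[torch.Tensor],
                    **kernel_args):
    # On the allgather path the top-k ids/weights were computed before the
    # quantized activations were gathered, so the kernel skips its own routing
    # and the logits/bias inputs are dropped.
    return moe_runner(
        router_logits if not run_post_quant_allgather else None,
        routing_bias if not run_post_quant_allgather else None,
        topk_weights=token_final_scales,  # populated only on the allgather path
        topk_ids=token_selected_experts,
        **kernel_args,
    )
```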
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -341,7 +341,6 @@ accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[disable_gemm_allreduce_plug
 accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights SKIP (https://nvbugs/5532023)
 accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin] SKIP (https://nvbugs/5532023)
 accuracy/test_llm_api.py::TestLlama3_1_8BInstruct::test_tp2cp2 SKIP (https://nvbugs/5532023)
-accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput_mtp_trtllm] SKIP (https://nvbugs/5537738)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5503479)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5541494)
 unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout] SKIP (https://nvbugs/5541545)