11 changes: 9 additions & 2 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -325,8 +325,12 @@ def quantize_input(self, x, post_quant_comm: bool = True):
x, False, alignment=self.quant_method.input_hidden_alignment)
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_deepseek_fp8_block_scales:
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
x_row = x.shape[0]
# For SM100+, fp8_quantize_1x128 returns x_sf with shape (blocked_n, num_tokens),
# but moe_a2a_dispatch requires all payloads to have first dim = num_tokens.
# We could transpose x_sf before dispatch and transpose it back after receive,
# but that may introduce a perf regression, so we don't support post_quant_comm
# for fp8_block_scales.
# TODO: Consider removing this constraint of the OneSided AlltoAll.
pass
elif self.has_w4a16_mxfp4:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
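
The comment above rejects a transpose-based workaround on perf grounds. As a minimal, hypothetical sketch of that idea in plain torch (shapes are illustrative and the actual moe_a2a_dispatch call is omitted), the layout mismatch and the round-trip transposes would look like:

import torch

# Hypothetical shapes illustrating the mismatch described above; this is a
# sketch of the rejected workaround, not the module's actual dispatch path.
num_tokens, hidden = 8, 512
blocked_n = (hidden + 127) // 128            # one scale per 128-element block

x = torch.randn(num_tokens, hidden)          # payload dim 0 == num_tokens: dispatchable
x_sf = torch.randn(blocked_n, num_tokens)    # SM100+ scale layout: dim 0 != num_tokens

# Rejected workaround: make dim 0 == num_tokens before the all-to-all ...
x_sf_sendable = x_sf.transpose(0, 1).contiguous()            # (num_tokens, blocked_n)
# ... then restore the kernel-expected layout on the receive side.
x_sf_restored = x_sf_sendable.transpose(0, 1).contiguous()   # (blocked_n, num_tokens)
assert x_sf_restored.shape == x_sf.shape

The extra transpose/contiguous copies on both sides are the perf concern the comment cites, which is why the diff instead defers quantization to run_moe in the second hunk below.
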
@@ -405,6 +409,9 @@ def run_moe(

if self.has_deepseek_fp8_block_scales:
assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False"
# fp8_block_scale_moe_runner needs a 2D shape for x_sf and only supports SM100+
if x_sf is None:
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)

final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
router_logits,