@@ -470,8 +470,15 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
     # Move kernel call outside context manager to avoid graph breaks
     # during torch.compile for piecewise CUDA graph.
     # Use custom op wrapper for torch.compile compatibility.
+
+    # The DeepSeekV3 routing method requires float32 router logits.
Collaborator: @leejnau @trevor-m is this true? If so, why didn't we run into issues before?

Collaborator: Maybe this will be fixed by flashinfer-ai/flashinfer#2993?

Collaborator: The block-scale path already had this fix; I think we just didn't use per-tensor scaling before?

Collaborator: Got it. We have never run DSV3/R1 with per-tensor FP8 before.
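For context on the distinction this thread turns on, below is a rough illustrative sketch of per-tensor versus block-scale FP8 quantization. The function names and the 128-wide tile size are assumptions for illustration, not the PR's code; DeepSeek-V3-style checkpoints typically ship block-scale (tile-wise) FP8 weights, whereas this kernel path uses a single scale per tensor.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def quant_fp8_per_tensor(x: torch.Tensor):
    # One scale for the whole tensor: a single outlier stretches the
    # scale (and costs precision) for every element.
    scale = x.abs().amax().clamp(min=1e-12) / FP8_MAX
    return (x / scale).to(torch.float8_e4m3fn), scale

def quant_fp8_block_scale(x: torch.Tensor, block: int = 128):
    # One scale per (block x block) tile; assumes a 2-D tensor whose
    # dimensions are divisible by `block`. Outliers only affect their tile.
    m, n = x.shape
    tiles = x.reshape(m // block, block, n // block, block)
    scales = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12) / FP8_MAX
    tiles_q = (tiles / scales).to(torch.float8_e4m3fn)
    return tiles_q.reshape(m, n), scales.squeeze(1).squeeze(-1)
```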

+    if routing_method_type == RoutingMethodType.DeepSeekV3:
+        router_logits = router_logits.to(torch.float32)
+    else:
+        router_logits = router_logits.to(torch.bfloat16)
+
     output = trtllm_fp8_per_tensor_scale_moe_wrapper(
-        routing_logits=router_logits.to(torch.bfloat16),
+        routing_logits=router_logits,
         routing_bias=routing_bias_cast,
         hidden_states=a_q,
         gemm1_weights=quant_info.w13_weight,
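A side note on the "custom op wrapper" the diff comment mentions: the usual pattern is to register the opaque kernel as a PyTorch custom op with a fake (meta) implementation, so torch.compile captures the call as a single graph node instead of graph-breaking inside it. A minimal sketch, assuming illustrative names (`demo::trtllm_moe` is not the PR's actual op):

```python
import torch

@torch.library.custom_op("demo::trtllm_moe", mutates_args=())
def trtllm_moe(routing_logits: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    # A real implementation would call the flashinfer TRT-LLM FP8 MoE
    # kernel here; torch.compile treats this body as opaque.
    return hidden_states.clone()

@trtllm_moe.register_fake
def _(routing_logits: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation so the compiler can trace shapes
    # without running the kernel.
    return torch.empty_like(hidden_states)
```

Calling `trtllm_moe(...)` inside a compiled region then appears as a single node in the FX graph, which avoids the graph breaks the diff comment refers to.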