[DSV4] Add BF16 and MXFP8 A2A support for flashinfer a2a one sided #40960
Conversation
Documentation preview: https://vllm--40960.org.readthedocs.build/en/40960/
Code Review
This pull request introduces support for bf16 dispatch and deferred input quantization within the flashinfer_nvlink_one_sided MoE kernel, specifically to accommodate expert types like trtllm_mxfp4. The implementation updates the MoeAlltoAll workspace to support dynamic payload sizes and modifies the preparation logic to handle cases where quantization is performed post-dispatch. Feedback was provided regarding a potential AttributeError when quant_config is None and the necessity of updating existing validation checks to ensure the new logic is reachable.
if defer_input_quant or quant_config.quant_dtype is None:
    # Experts (e.g. trtllm_mxfp4 with mxfp8 activations) quantize
    # post-dispatch; ship bf16 tokens with no per-token scale payload.
    dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 2, False
elif quant_config.quant_dtype == "nvfp4":
    dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 0, True
else:
    raise NotImplementedError(
        "flashinfer_nvlink_one_sided dispatch only supports nvfp4, "
        "bf16, and defer_input_quant paths today; got "
        f"quant_dtype={quant_config.quant_dtype!r}"
    )
This logic is currently unreachable for non-nvfp4 models (such as mxfp4) because of the ValueError check at lines 232-239 (outside this diff hunk). To support BF16 and deferred quantization for other formats, that validation block should be removed or updated. Additionally, the code should safely handle cases where quant_config is None (e.g., for unquantized models) to avoid an AttributeError when accessing quant_config.quant_dtype.
quant_dtype = quant_config.quant_dtype if quant_config is not None else None
if defer_input_quant or quant_dtype is None:
# Experts (e.g. trtllm_mxfp4 with mxfp8 activations) quantize
# post-dispatch; ship bf16 tokens with no per-token scale payload.
dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 2, False
elif quant_dtype == "nvfp4":
dispatch_dtype_bytes_per_elem, dispatch_has_fp8_scale = 0, True
else:
raise NotImplementedError(
"flashinfer_nvlink_one_sided dispatch only supports nvfp4, "
"bf16, and defer_input_quant paths today; got "
f"quant_dtype={quant_dtype!r}"
    )
I think we need to change this condition as well; otherwise, the branch would exit early here for bf16.
LGTM overall. I have confirmed that the gen step time is ~40ms for AG/RS and ~35ms for NVLinkOneSided A2A under the following config:
A small missing piece is …
Thanks. We do have a pre-allocated workspace buffer, but right now we just manually copy into it. I will take a look at the trtllm code!
Yes, the point is that a portion of the workspace is viewed as a Tensor and passed into the MoE op as the output Tensor, so that the MoE op writes its output directly onto the workspace.
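A minimal sketch of that idea, assuming a flat byte workspace and a hypothetical moe_op that accepts an out= tensor; the names, shapes, and offsets here are illustrative, not the actual flashinfer API:

import torch

# Hypothetical pre-allocated flat workspace (bytes), sized once at init time.
workspace = torch.empty(64 * 1024 * 1024, dtype=torch.uint8, device="cuda")

def moe_output_view(workspace, offset, num_tokens, hidden_dim):
    # View a slice of the workspace as the bf16 MoE output tensor (no copy).
    nbytes = num_tokens * hidden_dim * 2  # bf16 is 2 bytes per element
    flat = workspace[offset:offset + nbytes].view(torch.bfloat16)
    return flat.view(num_tokens, hidden_dim)

# out = moe_output_view(workspace, offset=0, num_tokens=128, hidden_dim=2880)
# moe_op(..., out=out)  # hypothetical op that writes directly onto the workspace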
…ll2all
The one-sided MoeAlltoAll dispatch workspace was hardcoded for nvfp4
hidden states + fp8 scales, so any other activation dtype overran the
buffer. Parameterize the workspace sizing by bytes-per-elem and whether
an fp8 scale payload is present, then route non-nvfp4 quant configs to
a bf16 dispatch (2 B/elem, no scale) via a new defer_input_quant hint.
trtllm_mxfp4 experts already advertise expects_unquantized_inputs=True
(they call mxfp8_quantize internally). Wire make_mxfp4_moe_kernel to
pass that signal into maybe_make_prepare_finalize, and have the one-
sided prepare() honor the per-call defer_input_quant flag by shipping
a1 as bf16 with no scale payload. Two-sided already handled this.
NOTE: the flashinfer moe_a2a_dispatch C++ kernel only templates top_k
in {1, 2, 4, 8}; models with other top_k (e.g. DeepSeek-V4 top_k=6)
must use flashinfer_nvlink_two_sided instead.
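A sketch of the kind of guard this constraint implies; the selection helper and its signature are hypothetical, not the actual vLLM dispatch logic:

_ONE_SIDED_SUPPORTED_TOP_K = (1, 2, 4, 8)

def select_a2a_backend(top_k, preferred="flashinfer_nvlink_one_sided"):
    # Fall back to the two-sided A2A when the one-sided kernel has no top_k template.
    if preferred == "flashinfer_nvlink_one_sided" and top_k not in _ONE_SIDED_SUPPORTED_TOP_K:
        # e.g. DeepSeek-V4 uses top_k=6, which has no template instantiation
        return "flashinfer_nvlink_two_sided"
    return preferred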
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
) Padded rows in the [ep_size, max_num_tokens, ...] workspace retain stale topk_ids from prior dispatch calls (the workspace is zeroed only once at init). Those stale ids cause the downstream trtllm_fp4 grouped GEMM to do phantom work for random local experts every layer, which (a) inflates expert GEMM time and (b) creates the cross-rank skew that the combine kernel then has to wait on.

Setting `invalid_token_expert_id` to `num_experts` (one past the valid expert range) makes the flashinfer worker overwrite all top_k topk_ids slots of padded rows with that sentinel (moeA2ASanitizeExpertIdsKernel in moeAlltoAllKernels.cu); the trtllm grouped GEMM then sees those rows as routed to no local expert (outside [local_expert_offset, local_expert_offset + local_num_experts)) and skips them.

Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
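An illustrative Python rendering of the sanitization described above; the real logic lives in the moeA2ASanitizeExpertIdsKernel CUDA kernel, and the tensor names and shapes here are assumptions:

import torch

def sanitize_padded_topk_ids(topk_ids, valid_token_counts, invalid_token_expert_id):
    # topk_ids: [ep_size, max_num_tokens, top_k] expert ids per dispatched row.
    # valid_token_counts: [ep_size] number of real (non-padded) rows per source rank.
    ep_size, max_num_tokens, _ = topk_ids.shape
    row_idx = torch.arange(max_num_tokens, device=topk_ids.device)
    # Rows at or beyond the valid count are padding left over from earlier dispatches.
    pad_mask = row_idx[None, :] >= valid_token_counts[:, None]  # [ep_size, max_num_tokens]
    # Overwrite every top_k slot of padded rows with the sentinel (e.g. num_experts),
    # which falls outside [local_expert_offset, local_expert_offset + local_num_experts).
    topk_ids[pad_mask] = invalid_token_expert_id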
When prepare-side mxfp8_quantize pads K to mx_alignment (e.g. gpt-oss hidden=2880 -> 3072 with align=256), the pre-PR torch.empty_like(hidden_states) naturally produced an unpadded output because hidden_states was the original bf16 input. With prepare-side quantize, hidden_states entering apply() is the padded fp8 tensor, so allocating the output by self.hidden_dim (the post-roundup padded value from maybe_roundup_sizes) propagates padding into lm_head.

Use moe_config.hidden_dim_unpadded so trtllm internally truncates back to the original hidden size, matching pre-PR behavior. Apply the same fix to the modular workspace_shapes for non-aligned hidden sizes with EP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
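A small sketch of the padding arithmetic and the allocation fix being described; round_up is a hypothetical helper, and hidden_dim_unpadded mirrors the field named in the commit text:

import torch

def round_up(x, align):
    return (x + align - 1) // align * align

hidden_dim_unpadded = 2880                 # e.g. gpt-oss hidden size
mx_alignment = 256
hidden_dim_padded = round_up(hidden_dim_unpadded, mx_alignment)  # 3072

# With prepare-side quantize, activations entering apply() are [..., 3072],
# so the expert output must be allocated with the unpadded hidden dim:
num_tokens = 128
out = torch.empty(num_tokens, hidden_dim_unpadded, dtype=torch.bfloat16)
# Allocating with hidden_dim_padded instead would let the padding leak into lm_head.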
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
…40960)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@users.noreply.github.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
(cherry picked from commit b4806c8)
Purpose
Originally, the Flashinfer one-sided A2A only supported nvfp4 dispatch. This PR adds BF16 and MXFP8 dispatch.
Test Plan
gsm8k on V4-Flash
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.