Conversation
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
📝 Walkthrough

This PR extends the MoE kernel infrastructure to support precomputed routing data (expert indices and weights) alongside traditional routing logits, enabling alternative routing input paths. The changes propagate `expert_indices` and `expert_weights` through launcher constructors, add optional routing parameters to the Python APIs, introduce a new `trtllm_bf16_routed_moe` function, and expand test coverage for routed MoE inference.
Summary of Changes

Hello @IwakuraRein, I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the Mixture of Experts (MoE) functionality by introducing a new API for bfloat16 operations with pre-computed routing. It also refactors existing MoE implementations across different data types to provide greater control over output finalization, allowing for the retrieval of intermediate results. These changes aim to improve flexibility and integration within MoE pipelines.
Code Review
This pull request introduces a new API, trtllm_bf16_routed_moe, and enhances existing MoE functions to support pre-computed routing and the option to return intermediate results. The changes are consistently applied across the C++ kernel launcher and Python bindings, improving flexibility and control over the MoE computation flow. New tests have been added to validate the functionality of the BF16 routed MoE. The do_finalize parameter is a valuable addition, allowing users to either get the final output or inspect intermediate tensors for further processing.
```cpp
if (has_precomputed_indices) {
  // Use expert_indices directly
  workspace.routing_expert_indexes =
      static_cast<int*>(const_cast<void*>(expert_indices.data_ptr()));
```
The use of const_cast<void*>(expert_indices.data_ptr()) here bypasses const-correctness. While it might be necessary due to the underlying API expecting a non-const pointer, it's important to ensure that the expert_indices data is indeed treated as read-only within the kernel to prevent unintended modifications. If the kernel truly modifies this data, it should be explicitly copied to a mutable buffer.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/fused_moe/core.py (1)
Lines 1317-1371: ⚠️ Potential issue | 🟠 Major

Validate `topk_ids`/`expert_weights` when `routing_logits` is absent.

Without dtype/shape checks, empty or malformed packed routing data can reach the kernel and produce undefined routing. Add explicit validation for `topk_ids` (int32, `[num_tokens, top_k]`) and `expert_weights` (if provided) before proceeding.

🔧 Suggested guardrails
```diff
 assert routing_logits is not None or topk_ids is not None, (
     "either routing_logits or topk_ids must be provided"
 )
@@
 num_tokens = hidden_states.shape[0]
 hidden_size = hidden_states.shape[-1]
+if routing_logits is None:
+    if topk_ids is None:
+        raise ValueError("topk_ids must be provided when routing_logits is None.")
+    if topk_ids.dtype != torch.int32:
+        raise TypeError("topk_ids must be int32 when routing_logits is None.")
+    if topk_ids.ndim != 2 or topk_ids.shape[0] != num_tokens or topk_ids.shape[1] != top_k:
+        raise ValueError("topk_ids must have shape [num_tokens, top_k].")
+    if expert_weights is not None and expert_weights.numel() > 0:
+        if expert_weights.dtype != torch.bfloat16:
+            raise TypeError("expert_weights must be bfloat16.")
+        if expert_weights.shape != topk_ids.shape:
+            raise ValueError("expert_weights must match topk_ids shape.")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 1317 - 1371, When routing_logits is None, validate that topk_ids and expert_weights are well-formed before use: in the function handling routing (the block that currently assigns topk_ids/topk_ids = topk_ids and expert_weights = ...), assert topk_ids is a torch.Tensor of dtype torch.int32 and shape [num_tokens, top_k], and if expert_weights is provided assert it is a torch.Tensor with matching first two dims [num_tokens, top_k] and a sensible dtype (e.g., routing_logits.dtype if available or torch.bfloat16), otherwise raise a clear ValueError; use the local symbols topk_ids, expert_weights, num_tokens, and top_k to implement these checks so malformed/empty packed routing data cannot reach the kernel.
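The suggested checks are torch-specific, but the shape/dtype contract they enforce can be sketched framework-agnostically. In this sketch `FakeTensor` and `check_routed_inputs` are hypothetical names, not FlashInfer API; the point is the validation order, not the types:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor carrying only shape and dtype."""
    shape: Tuple[int, ...]
    dtype: str


def check_routed_inputs(
    topk_ids: Optional[FakeTensor],
    expert_weights: Optional[FakeTensor],
    num_tokens: int,
    top_k: int,
) -> None:
    # Mirrors the guardrails suggested above: reject malformed packed routing
    # before it can reach the kernel.
    if topk_ids is None:
        raise ValueError("topk_ids must be provided when routing_logits is None.")
    if topk_ids.dtype != "int32":
        raise TypeError("topk_ids must be int32 when routing_logits is None.")
    if topk_ids.shape != (num_tokens, top_k):
        raise ValueError("topk_ids must have shape [num_tokens, top_k].")
    if expert_weights is not None:
        if expert_weights.dtype != "bfloat16":
            raise TypeError("expert_weights must be bfloat16.")
        if expert_weights.shape != topk_ids.shape:
            raise ValueError("expert_weights must match topk_ids shape.")
```

A well-formed pair passes silently; a shape mismatch raises before any kernel work happens.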
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 463-474: In check_routing(), enforce that exactly one routing
source is provided: if routing_logits is null/absent then expert_indices must be
non-empty, and if expert_indices is provided/non-empty then routing_logits must
be absent; error out when both are present or both absent. Modify
FusedMoeLauncher::check_routing_common() caller logic in check_routing() to add
TVM_FFI_ICHECK-style guards referencing expert_indices, routing_logits, and
hidden_states (and args->top_k where relevant) so you reject the case of empty
expert_indices with no routing_logits and the case where both routing_logits and
a non-empty expert_indices are supplied.
In `@flashinfer/fused_moe/core.py`:
- Around line 1448-1452: The function _fake_trtllm_bf16_moe has parameters
(routing_logits, routing_bias, expert_indices, expert_weights, hidden_states)
flagged as unused by Ruff; to silence the lint while preserving the signature,
rename the unused parameters by prefixing them with an underscore (e.g.,
routing_logits -> _routing_logits, routing_bias -> _routing_bias, expert_indices
-> _expert_indices, expert_weights -> _expert_weights, and if hidden_states is
unused rename to _hidden_states) in the _fake_trtllm_bf16_moe definition and any
internal references so the signature stays compatible but Ruff no longer reports
them as unused.
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 403-530: The test's call to shuffle_matrix_a(..., 64) produces a
shuffle_block_size of 16 that doesn't match
BF16Moe.prepare_static_weights_for_kernel which uses epilogue_tile_m=128
(shuffle_block_size=32); update the test to use the same tile param (use 128
instead of 64) or derive the tile size from
BF16Moe.prepare_static_weights_for_kernel so that shuffle_matrix_a and the
production preprocessing use the same epilogue_tile_m/shuffle_block_size,
ensuring gemm1_weights/gemm2_weights are shuffled into the identical layout the
kernel expects.
---
Outside diff comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 1317-1371: When routing_logits is None, validate that topk_ids and
expert_weights are well-formed before use: in the function handling routing (the
block that currently assigns topk_ids/topk_ids = topk_ids and expert_weights =
...), assert topk_ids is a torch.Tensor of dtype torch.int32 and shape
[num_tokens, top_k], and if expert_weights is provided assert it is a
torch.Tensor with matching first two dims [num_tokens, top_k] and a sensible
dtype (e.g., routing_logits.dtype if available or torch.bfloat16), otherwise
raise a clear ValueError; use the local symbols topk_ids, expert_weights,
num_tokens, and top_k to implement these checks so malformed/empty packed
routing data cannot reach the kernel.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- `csrc/trtllm_fused_moe_kernel_launcher.cu`
- `flashinfer/__init__.py`
- `flashinfer/fused_moe/__init__.py`
- `flashinfer/fused_moe/core.py`
- `tests/moe/test_trtllm_gen_fused_moe.py`
- `tests/moe/test_trtllm_gen_routed_fused_moe.py`
```cpp
void check_routing() const override {
  FusedMoeLauncher::check_routing_common();
  if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) {
    // Pre-computed routing: expert_indices is a packed tensor
    // Format: (expert_id << 16) | (weight_bf16.view(int16))
    TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
    TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
        << "expert_indices and hidden_states must have same number of tokens.";
    TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
        << "expert_indices dim1 must match top_k.";
    TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
  }
```
Enforce exactly one routing source to avoid undefined routing data.
If routing_logits is absent and expert_indices is empty, routing runs with uninitialized indices; if both are present, routing can overwrite user-provided indices. Add an explicit exclusivity/required check.
🔧 Suggested validation
```diff
 void check_routing() const override {
   FusedMoeLauncher::check_routing_common();
-  if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) {
+  bool has_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
+  bool has_logits = routing_logits.has_value();
+  TVM_FFI_ICHECK(has_precomputed || has_logits)
+      << "either routing_logits or expert_indices must be provided.";
+  TVM_FFI_ICHECK(!(has_precomputed && has_logits))
+      << "provide either routing_logits or expert_indices, not both.";
+  if (has_precomputed) {
     // Pre-computed routing: expert_indices is a packed tensor
     // Format: (expert_id << 16) | (weight_bf16.view(int16))
     TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
     TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
         << "expert_indices and hidden_states must have same number of tokens.";
     TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
         << "expert_indices dim1 must match top_k.";
     TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
   }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
void check_routing() const override {
  FusedMoeLauncher::check_routing_common();
  bool has_precomputed = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
  bool has_logits = routing_logits.has_value();
  TVM_FFI_ICHECK(has_precomputed || has_logits)
      << "either routing_logits or expert_indices must be provided.";
  TVM_FFI_ICHECK(!(has_precomputed && has_logits))
      << "provide either routing_logits or expert_indices, not both.";
  if (has_precomputed) {
    // Pre-computed routing: expert_indices is a packed tensor
    // Format: (expert_id << 16) | (weight_bf16.view(int16))
    TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
    TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
        << "expert_indices and hidden_states must have same number of tokens.";
    TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
        << "expert_indices dim1 must match top_k.";
    TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
  }
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 463 - 474, In
check_routing(), enforce that exactly one routing source is provided: if
routing_logits is null/absent then expert_indices must be non-empty, and if
expert_indices is provided/non-empty then routing_logits must be absent; error
out when both are present or both absent. Modify
FusedMoeLauncher::check_routing_common() caller logic in check_routing() to add
TVM_FFI_ICHECK-style guards referencing expert_indices, routing_logits, and
hidden_states (and args->top_k where relevant) so you reject the case of empty
expert_indices with no routing_logits and the case where both routing_logits and
a non-empty expert_indices are supplied.
```python
    routing_logits: Optional[torch.Tensor],
    routing_bias: Optional[torch.Tensor],
    expert_indices: Optional[torch.Tensor],
    expert_weights: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
```
Silence unused-argument lint in _fake_trtllm_bf16_moe.
Ruff flags these parameters as unused; prefixing with _ keeps lint clean while preserving signature parity.
🔧 Minimal lint-safe rename
```diff
 def _fake_trtllm_bf16_moe(
-    routing_logits: Optional[torch.Tensor],
-    routing_bias: Optional[torch.Tensor],
-    expert_indices: Optional[torch.Tensor],
-    expert_weights: Optional[torch.Tensor],
+    _routing_logits: Optional[torch.Tensor],
+    _routing_bias: Optional[torch.Tensor],
+    _expert_indices: Optional[torch.Tensor],
+    _expert_weights: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     gemm1_weights: torch.Tensor,
     gemm2_weights: torch.Tensor,
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def _fake_trtllm_bf16_moe(
    _routing_logits: Optional[torch.Tensor],
    _routing_bias: Optional[torch.Tensor],
    _expert_indices: Optional[torch.Tensor],
    _expert_weights: Optional[torch.Tensor],
    hidden_states: torch.Tensor,
```
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 1448-1448: Unused function argument: routing_logits
(ARG001)
[warning] 1449-1449: Unused function argument: routing_bias
(ARG001)
[warning] 1450-1450: Unused function argument: expert_indices
(ARG001)
[warning] 1451-1451: Unused function argument: expert_weights
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fused_moe/core.py` around lines 1448 - 1452, The function
_fake_trtllm_bf16_moe has parameters (routing_logits, routing_bias,
expert_indices, expert_weights, hidden_states) flagged as unused by Ruff; to
silence the lint while preserving the signature, rename the unused parameters by
prefixing them with an underscore (e.g., routing_logits -> _routing_logits,
routing_bias -> _routing_bias, expert_indices -> _expert_indices, expert_weights
-> _expert_weights, and if hidden_states is unused rename to _hidden_states) in
the _fake_trtllm_bf16_moe definition and any internal references so the
signature stays compatible but Ruff no longer reports them as unused.
```python
@pytest.mark.parametrize("num_tokens", [8, 64])
@pytest.mark.parametrize("hidden_size", [1024, 2048])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
@pytest.mark.parametrize("num_experts", [8, 16])
@pytest.mark.parametrize("top_k", [2, 4])
@pytest.mark.parametrize(
    "routing_method_type",
    [
        RoutingMethodType.Renormalize,
    ],
)
def test_trtllm_gen_bf16_routed_fused_moe(
    num_tokens: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    num_experts: int,
    routing_method_type: RoutingMethodType,
):
    """Test that BF16 routed MoE matches standard routing."""
    compute_capability = get_compute_capability(torch.device(device="cuda"))
    if compute_capability[0] not in [10]:
        pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
    torch.manual_seed(42)
    device = torch.device("cuda:0")
    enable_pdl = device_support_pdl(device)

    # Generate random routing logits for reference
    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
        torch.bfloat16
    )

    # Generate random hidden states in BF16
    hidden_states = (
        torch.randn(num_tokens, hidden_size, device=device).to(torch.bfloat16) * 0.1
    )

    # Generate weights
    gemm1_weights = torch.randn(
        num_experts, 2 * intermediate_size, hidden_size, device=device
    ).to(torch.bfloat16)
    gemm2_weights = torch.randn(
        num_experts, hidden_size, intermediate_size, device=device
    ).to(torch.bfloat16)

    gemm1_weights_shuffled = []
    gemm2_weights_shuffled = []
    for i in range(num_experts):
        tmp_weights1 = shuffle_matrix_a(gemm1_weights[i].view(torch.uint8), 64)
        tmp_weights2 = shuffle_matrix_a(gemm2_weights[i].view(torch.uint8), 64)
        block_k = 128
        gemm1_weights_shuffled.append(convert_to_block_layout(tmp_weights1, block_k))
        gemm2_weights_shuffled.append(convert_to_block_layout(tmp_weights2, block_k))
    gemm1_weights = torch.stack(gemm1_weights_shuffled).view(torch.bfloat16)
    gemm2_weights = torch.stack(gemm2_weights_shuffled).view(torch.bfloat16)

    # Run reference with routing_logits
    reference_output = trtllm_bf16_moe(
        routing_logits=routing_logits,
        routing_bias=None,
        hidden_states=hidden_states,
        gemm1_weights=gemm1_weights,
        gemm2_weights=gemm2_weights,
        num_experts=num_experts,
        top_k=top_k,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type.value,
        use_shuffled_weight=True,
        weight_layout=WeightLayout.BlockMajorK,
        do_finalize=True,
        enable_pdl=enable_pdl,
    ).to(torch.float)

    # Compute routing using reference implementation
    if routing_method_type == RoutingMethodType.Renormalize:
        permute_info, expert_weights_ref = routing_reference_renormalize(
            routing_logits, top_k, num_experts, 8
        )
    elif routing_method_type == RoutingMethodType.RenormalizeNaive:
        permute_info, expert_weights_ref = routing_reference_renormalize_naive(
            routing_logits, top_k, num_experts, 8
        )
    elif routing_method_type == RoutingMethodType.TopK:
        permute_info, expert_weights_ref = routing_reference_topk(
            routing_logits, top_k, num_experts, 8
        )
    topk_ids = permute_info["topKIndices"].to(torch.int32)
    expert_weights = expert_weights_ref.view(num_tokens, num_experts)[
        torch.arange(num_tokens, device=device).unsqueeze(1), topk_ids
    ].to(torch.bfloat16)

    # Pack topk_ids and expert_weights into a single tensor
    # Format: (expert_id << 16) | (weight_bf16.view(int16))
    packed_topk_ids = (topk_ids << 16) | expert_weights.view(torch.int16).to(
        torch.int32
    )

    # Run with pre-computed routing (packed format)
    output = trtllm_bf16_routed_moe(
        topk_ids=packed_topk_ids,
        hidden_states=hidden_states,
        gemm1_weights=gemm1_weights,
        gemm2_weights=gemm2_weights,
        num_experts=num_experts,
        top_k=top_k,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        routing_method_type=routing_method_type.value,
        use_shuffled_weight=True,
        weight_layout=WeightLayout.BlockMajorK,
        do_finalize=True,
        enable_pdl=enable_pdl,
    ).to(torch.float)

    mask = torch.isclose(output, reference_output, rtol=1e-2, atol=1e-2)

    # Mismatch percentage
    mismatch_pct = (~mask).float().mean().item() * 100
    assert mismatch_pct < 10, f"Mismatch percentage is {mismatch_pct:.2f}%"
```
🧩 Analysis chain
cat > /tmp/verify_tile_size.py << 'EOF'
# Test epilogue_tile_m values and their corresponding shuffle_block_size
values = [64, 128]
for epilogue_tile_m in values:
shuffle_block_size = 16
if epilogue_tile_m % 128 == 0:
shuffle_block_size = 32
print(f"epilogue_tile_m={epilogue_tile_m} => shuffle_block_size={shuffle_block_size}")
EOF
python /tmp/verify_tile_size.pyRepository: flashinfer-ai/flashinfer
Length of output: 156
🏁 Script executed:
#!/bin/bash
# Check if there are comments explaining the 64 tile size choice in the routed test
sed -n '403,460p' tests/moe/test_trtllm_gen_routed_fused_moe.py | head -60Repository: flashinfer-ai/flashinfer
Length of output: 2281
Fix hardcoded tile size to match BF16Moe preprocessing path.
The test uses shuffle_matrix_a(..., 64) which produces a shuffle_block_size=16 row permutation, but BF16Moe.prepare_static_weights_for_kernel uses epilogue_tile_m=128, producing shuffle_block_size=32. This mismatch means the test shuffles weights with a different layout than the production code path, potentially allowing the test to pass while feeding incorrect weights to the kernel. Use epilogue_tile_m=128 to match the BF16Moe preprocessing logic, or document why a different tile size is required.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py` around lines 403 - 530, The
test's call to shuffle_matrix_a(..., 64) produces a shuffle_block_size of 16
that doesn't match BF16Moe.prepare_static_weights_for_kernel which uses
epilogue_tile_m=128 (shuffle_block_size=32); update the test to use the same
tile param (use 128 instead of 64) or derive the tile size from
BF16Moe.prepare_static_weights_for_kernel so that shuffle_matrix_a and the
production preprocessing use the same epilogue_tile_m/shuffle_block_size,
ensuring gemm1_weights/gemm2_weights are shuffled into the identical layout the
kernel expects.
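The fix the bot recommends can also be made structural: derive the shuffle block size from a single tile-size constant rather than hardcoding `64` in the test. This sketch captures the mapping the verification script observed (128-aligned epilogue tiles use 32-row shuffle blocks, others use 16); `shuffle_block_size_for` and `EPILOGUE_TILE_M` are illustrative names, not FlashInfer API:

```python
def shuffle_block_size_for(epilogue_tile_m: int) -> int:
    # Mirrors the heuristic confirmed above: epilogue tiles that are a
    # multiple of 128 rows shuffle in 32-row blocks, everything else in 16.
    return 32 if epilogue_tile_m % 128 == 0 else 16


# Keep a single constant in sync with the production preprocessing
# (BF16Moe.prepare_static_weights_for_kernel uses epilogue_tile_m=128).
EPILOGUE_TILE_M = 128

print(shuffle_block_size_for(64), shuffle_block_size_for(EPILOGUE_TILE_M))  # -> 16 32
```

Passing `EPILOGUE_TILE_M` to `shuffle_matrix_a` in the test would then guarantee the test and kernel paths shuffle weights into the same layout.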
/bot run

[FAILED] Pipeline #44689639: 10/20 passed
this is ready for merging |
jimmyzho left a comment:

lgtm based on other's approvals to help unblock
## 📌 Description

Add `trtllm_bf16_routed_moe` api

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

`pytest tests/moe/test_trtllm_gen_routed_fused_moe.py::test_trtllm_gen_bf16_routed_fused_moe`

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Added support for pre-computed routing in MoE operations, enabling flexible routing input strategies.
  * New routed MoE APIs now available: BF16 and FP8 variants support pre-packed top-k routing information.
  * Introduced dual-path mechanism allowing MoE operations to accept either routing logits or pre-computed routing data.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>