Conversation
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Mixture-of-Experts (MoE) functionality by adding support for MXInt4 block-scale MoE with pre-computed routing. This change is crucial for optimizing performance in advanced use cases, such as CUDA Graph capture and distributed MoE systems, where routing decisions can be prepared beforehand. By allowing external routing inputs, the system gains flexibility and efficiency, avoiding redundant computations and enabling more complex inference pipelines.
Activity
📝 Walkthrough

Adds pre-computed (packed top-k) routing support to fused MoE: C++ launchers and Python ops now accept expert indices and expert weights, include validation and wiring for dual routing modes (routing_logits or pre-computed), and a new MXInt4 routed test was added.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant MoE_API
    participant MoE_Kernel
    participant Expert
    rect rgba(100,150,200,0.5)
        Note over Client,MoE_Kernel: Routing via logits
        Client->>MoE_API: trtllm_mxint4_block_scale_moe(routing_logits,...)
        MoE_API->>MoE_Kernel: call with routing_logits
        MoE_Kernel->>MoE_Kernel: compute top-k & expert weights
        MoE_Kernel->>Expert: dispatch per-expert GEMM
        Expert-->>MoE_Kernel: aggregated outputs
        MoE_Kernel-->>MoE_API: final output
        MoE_API-->>Client: result
    end
    rect rgba(150,200,100,0.5)
        Note over Client,MoE_Kernel: Pre-computed routing (packed)
        Client->>MoE_API: trtllm_mxint4_block_scale_routed_moe(topk_ids,expert_weights,...)
        MoE_API->>MoE_Kernel: call with expert_indices & expert_weights
        MoE_Kernel->>MoE_Kernel: validate/wire precomputed routing
        MoE_Kernel->>Expert: dispatch per-expert GEMM using precomputed indices/weights
        Expert-->>MoE_Kernel: aggregated outputs
        MoE_Kernel-->>MoE_API: final output
        MoE_API-->>Client: result
    end
```
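The two routing modes in the diagram can be sketched in plain Python. This is an illustrative single-token sketch, not flashinfer's API: function names and shapes are hypothetical, real kernels operate on batched tensors, and renormalizing the softmax over only the selected experts is one common convention.

```python
import math

def topk_from_logits(logits, k):
    # logits: list of per-expert scores for one token.
    # Select the k highest-scoring experts, then softmax over just those.
    idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    exps = [math.exp(logits[i]) for i in idx]
    s = sum(exps)
    return idx, [e / s for e in exps]  # renormalized over selected experts

def moe_token(hidden, experts, logits=None, topk_ids=None, topk_w=None, k=2):
    # Dual routing modes, mirroring the sequence diagram:
    # mode 1 derives (ids, weights) from logits; mode 2 uses the
    # pre-computed pair supplied by the caller.
    if logits is not None:
        topk_ids, topk_w = topk_from_logits(logits, k)
    out = [0.0] * len(hidden)
    for eid, w in zip(topk_ids, topk_w):
        y = experts[eid](hidden)  # per-expert FFN (here: any callable)
        out = [o + w * yi for o, yi in zip(out, y)]
    return out
```

Both call styles produce the same weighted mixture; the routed variant simply skips the top-k computation inside the kernel.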
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces support for pre-computed routing in the mxint4 MoE kernel, which is a valuable optimization for scenarios like CUDA graph capture. The changes are well-structured across the C++ kernel, Python bindings, and tests. My review focuses on improving code quality by addressing some minor code duplication and redundancies. Overall, this is a solid contribution.
```cpp
if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) {
  // Pre-computed routing: expert_indices is a packed tensor
  // Format: (expert_id << 16) | (weight_bf16.view(int16))
  TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
```

```cpp
expert_weights =
    alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
// Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
bool has_precomputed_indices = expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
```
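The packed format noted in the comment above — `(expert_id << 16) | (weight_bf16.view(int16))` — can be illustrated in pure Python. This is a sketch with hypothetical helper names; bf16 conversion here is plain truncation of the float32 bit pattern (no round-to-nearest-even), and it assumes expert ids fit in 16 bits:

```python
import struct

def bf16_bits(x: float) -> int:
    # Truncate float32 to bfloat16: keep the top 16 bits of the
    # IEEE-754 float32 bit pattern (drops the low 16 mantissa bits).
    f32 = struct.unpack("<I", struct.pack("<f", x))[0]
    return (f32 >> 16) & 0xFFFF

def pack_routing(expert_id: int, weight: float) -> int:
    # One int32 slot per (token, k): expert id in the high half,
    # bf16 weight bits in the low half.
    return (expert_id << 16) | bf16_bits(weight)

def unpack_routing(packed: int):
    # Inverse: split the halves and re-expand bf16 bits to float32.
    expert_id = (packed >> 16) & 0xFFFF
    bits = (packed & 0xFFFF) << 16
    weight = struct.unpack("<f", struct.pack("<I", bits))[0]
    return expert_id, weight
```

Values exactly representable in bf16 (such as 0.5) round-trip losslessly; others lose low-mantissa precision at pack time.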
The condition `expert_indices.ndim() == 2 && expert_indices.size(0) > 0` is repeated in `check_routing`, `prepare_routing`, and `run`. To improve maintainability and reduce code duplication, consider encapsulating this logic in a private helper method, for example:
```cpp
private:
  bool has_precomputed_indices() const {
    // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
    return expert_indices.ndim() == 2 && expert_indices.size(0) > 0;
  }
```

You could then call `has_precomputed_indices()` where this check is needed. A similar helper could be created for checking pre-computed weights.
flashinfer/fused_moe/core.py (Outdated)

```python
# When routing_logits is None, we either have topk_ids/expert_weights,
# packed into a single tensor as topk_ids
# or have them individually as topk_ids and expert_weights respectively
topk_ids = topk_ids
```
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/fused_moe/core.py (2)
2125-2131: ⚠️ Potential issue | 🔴 Critical

Routed mode can break autotuning by passing None as routing_logits input.

When routing_logits is None, this inputs list still includes None; later tuner runner paths dereference routing_logits.shape[0].

🔧 Suggested fix

```diff
-        inputs = [
-            output,
-            routing_logits,
-            topk_ids,
-            expert_weights,
-            hidden_states,
-        ]
+        routing_logits_for_tuning = (
+            torch.empty(
+                num_tokens, num_experts, dtype=routing_dtype, device="meta"
+            )
+            if routing_logits is None
+            else routing_logits
+        )
+        inputs = [
+            output,
+            routing_logits_for_tuning,
+            topk_ids,
+            expert_weights,
+            hidden_states,
+        ]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 2125 - 2131, The inputs list construction currently includes routing_logits even when it's None, which later causes dereferences like routing_logits.shape[0]; update the code that builds the inputs list (the variable named inputs in fused_moe/core.py) to omit routing_logits when it is None (or replace it with a safe sentinel tensor) so downstream tuner/runner paths never receive a None; specifically, gate the inclusion of routing_logits in the inputs list (or ensure routing_logits is a valid tensor before building inputs) so references to routing_logits.shape[...] are safe.
2055-2090: ⚠️ Potential issue | 🟠 Major

Honor the caller-provided output tensor instead of always reallocating.

Line 2088 always creates a new tensor, so an explicit output passed by callers is silently ignored.

🔧 Suggested fix

```diff
-    output: torch.Tensor,
+    output: Optional[torch.Tensor],
@@
-    # Create workspace buffers
-    output = torch.empty(
-        num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
-    )
+    # Create workspace buffers
+    if output is None:
+        output = torch.empty(
+            num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
+        )
+    else:
+        check_shape_dtype_device(
+            output,
+            (num_tokens, hidden_size),
+            torch.bfloat16,
+            hidden_states.device,
+            "output",
+        )
```
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1191-1200: The code currently forwards
expert_weights_in.data_ptr() to workspace.expert_weights based only on ndim and
non-empty size; instead, validate the full shape, dtype and device before
binding raw pointers: ensure expert_weights_in.ndim() == 2,
expert_weights_in.size(0) == args->num_tokens, expert_weights_in.size(1) ==
args->top_k, expert_weights_in.dtype() == dl_bfloat16 and
expert_weights_in.device() == hidden_states.device(); only if all checks pass
set workspace.expert_weights = const_cast<void*>(expert_weights_in.data_ptr()),
otherwise allocate expert_weights via alloc_tensor({args->num_tokens,
args->top_k}, dl_bfloat16, hidden_states.device()) and set
workspace.expert_weights to that buffer; also ensure the tensor is contiguous
(or document required layout) before passing its data_ptr().
---
Outside diff comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2125-2131: The inputs list construction currently includes
routing_logits even when it's None, which later causes dereferences like
routing_logits.shape[0]; update the code that builds the inputs list (the
variable named inputs in fused_moe/core.py) to omit routing_logits when it is
None (or replace it with a safe sentinel tensor) so downstream tuner/runner
paths never receive a None; specifically, gate the inclusion of routing_logits
in the inputs list (or ensure routing_logits is a valid tensor before building
inputs) so references to routing_logits.shape[...] are safe.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- csrc/trtllm_fused_moe_kernel_launcher.cu
- flashinfer/fused_moe/__init__.py
- flashinfer/fused_moe/core.py
- tests/moe/test_trtllm_gen_routed_fused_moe.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2087-2105: The function currently overwrites the provided output
buffer by unconditionally allocating a new tensor named output; change the
allocation logic in fused_moe.core (the function around where output is created)
to respect an incoming output parameter by only allocating torch.empty(...) when
output is None or not provided, and otherwise validate that the provided output
has the expected shape/dtype/device before using it; mirror the conditional
pattern used in trtllm_fp4_block_scale_moe_op for in-place support, and update
the function signature/annotation for output to Optional[torch.Tensor] to
reflect that None is acceptable.
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1165-1173: ⚠️ Potential issue | 🟠 Major

Precomputed expert_weights_in still needs device validation before raw-pointer binding.

expert_weights_in.data_ptr() is bound directly without checking it is on the same device as hidden_states. Cross-device binding here can cause invalid access during routing.

🔧 Proposed fix

```diff
   if (has_precomputed_weights()) {
     // Pre-computed expert weights: validate shape and dtype
     TVM_FFI_ICHECK_EQ(expert_weights_in.size(0), hidden_states.size(0))
         << "expert_weights_in and hidden_states must have same number of tokens.";
     TVM_FFI_ICHECK_EQ(expert_weights_in.size(1), args->top_k)
         << "expert_weights_in dim1 must match top_k.";
     TVM_FFI_ICHECK_EQ(expert_weights_in.dtype(), dl_bfloat16)
         << "expert_weights_in must be bfloat16.";
+    TVM_FFI_ICHECK_EQ(expert_weights_in.device().device_type, hidden_states.device().device_type)
+        << "expert_weights_in must be on the same device type as hidden_states.";
+    TVM_FFI_ICHECK_EQ(expert_weights_in.device().device_id, hidden_states.device().device_id)
+        << "expert_weights_in must be on the same device id as hidden_states.";
   }
```

Also applies to: 1197-1199
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1165 - 1173, the precomputed expert_weights_in is validated for shape and dtype but not for device, so before calling expert_weights_in.data_ptr() (in the has_precomputed_weights() branch and the other binding site referencing expert_weights_in.data_ptr()) add a device check to ensure expert_weights_in.device() (or .is_cuda and .device.index) matches hidden_states.device() (or the CUDA device used by hidden_states) and error out (or copy/move the tensor) with a clear message if they differ; update both the check block around has_precomputed_weights() and the other raw-pointer binding locations that use expert_weights_in.data_ptr() accordingly.

flashinfer/fused_moe/core.py (1)

2087-2098: ⚠️ Potential issue | 🔴 Critical

Provided output is not shape-validated before kernel use.

Only dtype/device are validated. A mismatched shape can lead to invalid writes when the kernel assumes [num_tokens, hidden_size].

🛡️ Proposed fix

```diff
     else:
         check_shape_dtype_device(
-            output, None, torch.bfloat16, hidden_states.device, "output"
+            output,
+            (num_tokens, hidden_size),
+            torch.bfloat16,
+            hidden_states.device,
+            "output",
         )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 2087 - 2098, The provided output buffer isn't shape-validated before kernel use, risking invalid writes; update the validation call in the block that handles a provided output so check_shape_dtype_device verifies shape (num_tokens, hidden_size) as well as dtype and device (i.e., replace the None shape argument with the tuple (num_tokens, hidden_size) when calling check_shape_dtype_device for "output"), ensuring the output tensor matches hidden_states' expected [num_tokens, hidden_size] layout before the kernel runs.
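The shape check this prompt asks for amounts to a small, framework-agnostic guard. A minimal sketch of the idea — `check_output_buffer` is a hypothetical name, not part of flashinfer's API:

```python
def check_output_buffer(out_shape, num_tokens, hidden_size):
    # Reject any output buffer whose shape does not match the
    # [num_tokens, hidden_size] layout the kernel assumes; dtype/device
    # checks alone cannot prevent out-of-bounds writes from an
    # undersized buffer.
    if len(out_shape) != 2:
        raise ValueError(f"output must be 2D, got {len(out_shape)}D")
    if tuple(out_shape) != (num_tokens, hidden_size):
        raise ValueError(
            f"output shape {tuple(out_shape)} must equal "
            f"({num_tokens}, {hidden_size})"
        )
```

A correctly sized buffer passes silently; any mismatch fails loudly before the kernel launch.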
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2069-2077: The code passes routing_logits (which can be None in
routed mode) into the AutoTuner/runner, and the runner later dereferences
routing_logits.shape[0], causing a crash; update the
trtllm_mxint4_block_scale_moe_op call site to create a safe substitute when
routing_logits is None (e.g., compute routing_dtype as you already do, then set
routing_logits_for_tuner = routing_logits if routing_logits is not None else
torch.empty((topk_ids.shape[0], 0), dtype=routing_dtype, device=topk_ids.device)
or alternatively pass topk_ids-derived shape info instead), and pass
routing_logits_for_tuner into AutoTuner/runner; repeat the same defensive
substitution where the other similar call is made (the block around the other
occurrence) so the runner never dereferences None.
---
Duplicate comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1165-1173: The precomputed expert_weights_in is validated for
shape and dtype but not for device, so before calling
expert_weights_in.data_ptr() (in the has_precomputed_weights() branch and the
other binding site referencing expert_weights_in.data_ptr()) add a device check
to ensure expert_weights_in.device() (or .is_cuda and .device.index) matches
hidden_states.device() (or the CUDA device used by hidden_states) and error out
(or copy/move the tensor) with a clear message if they differ; update both the
check block around has_precomputed_weights() and the other raw-pointer binding
locations that use expert_weights_in.data_ptr() accordingly.
In `@flashinfer/fused_moe/core.py`:
- Around line 2087-2098: The provided output buffer isn't shape-validated before
kernel use, risking invalid writes; update the validation call in the block that
handles a provided output so check_shape_dtype_device verifies shape
(num_tokens, hidden_size) as well as dtype and device (i.e., replace the None
shape argument with the tuple (num_tokens, hidden_size) when calling
check_shape_dtype_device for "output"), ensuring the output tensor matches
hidden_states' expected [num_tokens, hidden_size] layout before the kernel runs.
♻️ Duplicate comments (1)
flashinfer/fused_moe/core.py (1)
2095-2098: ⚠️ Potential issue | 🟡 Minor

Missing shape validation for user-provided output tensor.

When the user provides an output tensor, only dtype and device are validated, but shape is not checked. The trtllm_fp4_block_scale_moe_op (lines 1883-1888) includes additional shape assertions that are missing here. This inconsistency could allow incorrectly sized output tensors to pass validation silently.

🛡️ Proposed fix to add shape validation

```diff
     else:
         check_shape_dtype_device(
             output, None, torch.bfloat16, hidden_states.device, "output"
         )
+        assert output.shape[0] == num_tokens, (
+            f"output.shape[0]={output.shape[0]} must be equal to {num_tokens}"
+        )
+        assert output.shape[1] == hidden_size, (
+            f"output.shape[1]={output.shape[1]} must be equal to {hidden_size}"
+        )

     if routing_logits is not None:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 2095 - 2098, The user-provided output tensor currently only has dtype/device checked via check_shape_dtype_device(output, None, torch.bfloat16, hidden_states.device, "output"); update this to validate shape as well (match hidden_states' expected shape) — either call check_shape_dtype_device with the expected shape (e.g., hidden_states.shape or the computed output shape used elsewhere) instead of None, or add an explicit assertion comparing output.shape to the expected shape; mirror the shape assertions used in trtllm_fp4_block_scale_moe_op to ensure consistency.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/fused_moe/core.py`:
- Around line 2095-2098: The user-provided output tensor currently only has
dtype/device checked via check_shape_dtype_device(output, None, torch.bfloat16,
hidden_states.device, "output"); update this to validate shape as well (match
hidden_states' expected shape) — either call check_shape_dtype_device with the
expected shape (e.g., hidden_states.shape or the computed output shape used
elsewhere) instead of None, or add an explicit assertion comparing output.shape
to the expected shape; mirror the shape assertions used in
trtllm_fp4_block_scale_moe_op to ensure consistency.
[FAILED] Pipeline #45182758: 8/20 passed

Hi @zyongye, can you resolve the merge conflict?
Head branch was pushed to by a user without write access
6131bd3 to 9f3a99d
♻️ Duplicate comments (2)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1193-1201: ⚠️ Potential issue | 🟠 Major

Add device checks before binding precomputed routing buffers to raw pointers.

expert_indices / expert_weights_in are shape/dtype-checked, but not device-checked. Binding cross-device pointers into CUDA kernels is unsafe.

🔧 Proposed fix

```diff
   if (has_precomputed_indices()) {
     // Pre-computed routing: expert_indices is a packed tensor
     // Format: (expert_id << 16) | (weight_bf16.view(int16))
     TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
         << "expert_indices and hidden_states must have same number of tokens.";
     TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
         << "expert_indices dim1 must match top_k.";
     TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
+    TVM_FFI_ICHECK_EQ(expert_indices.device(), hidden_states.device())
+        << "expert_indices must be on the same device as hidden_states.";
   }
   if (has_precomputed_weights()) {
     // Pre-computed expert weights: validate shape and dtype
     TVM_FFI_ICHECK_EQ(expert_weights_in.size(0), hidden_states.size(0))
         << "expert_weights_in and hidden_states must have same number of tokens.";
     TVM_FFI_ICHECK_EQ(expert_weights_in.size(1), args->top_k)
         << "expert_weights_in dim1 must match top_k.";
     TVM_FFI_ICHECK_EQ(expert_weights_in.dtype(), dl_bfloat16)
         << "expert_weights_in must be bfloat16.";
+    TVM_FFI_ICHECK_EQ(expert_weights_in.device(), hidden_states.device())
+        << "expert_weights_in must be on the same device as hidden_states.";
   }
```

Also applies to: 1216-1227
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1193 - 1201, the precomputed routing arrays expert_indices and expert_weights_in are validated for shape/dtype but not checked for device, which can lead to unsafe cross-device raw pointer bindings; update the checks (around has_precomputed_weights(), and the similar block at 1216-1227) to verify each array's device type and device id match hidden_states (e.g., ensure expert_indices.device().device_type == hidden_states.device().device_type and expert_indices.device().device_id == hidden_states.device().device_id, same for expert_weights_in) and raise an error if they differ so raw pointer binding into CUDA kernels only occurs when all arrays are on the same GPU.

flashinfer/fused_moe/core.py (1)

2120-2124: ⚠️ Potential issue | 🔴 Critical

Validate provided output shape before passing it to the kernel.

Current checks only enforce dtype/device. If output is undersized, the kernel can write out of bounds because it assumes [num_tokens, hidden_size].

🐛 Proposed fix

```diff
     else:
         check_shape_dtype_device(
             output, None, torch.bfloat16, hidden_states.device, "output"
         )
+        assert output.dim() == 2, f"output must be 2D, got {output.dim()}D"
+        assert output.shape[0] == num_tokens, (
+            f"output.shape[0]={output.shape[0]} must equal num_tokens={num_tokens}"
+        )
+        assert output.shape[1] >= hidden_size, (
+            f"output.shape[1]={output.shape[1]} must be >= hidden_size={hidden_size}"
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 2120 - 2124, The code currently only checks dtype/device for `output` via `check_shape_dtype_device` but must also validate its shape to prevent out-of-bounds writes by the kernel; add a shape validation that `output.shape == (num_tokens, hidden_size)` (or raise a clear error) immediately after the dtype/device check (before the `if routing_logits is not None` branch) using the same symbols (`output`, `num_tokens`, `hidden_size`, `hidden_states`) so the kernel assumptions are enforced; keep the check adjacent to the existing `check_shape_dtype_device` call in `fused_moe/core.py`.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1193-1201: The precomputed routing arrays expert_indices and
expert_weights_in are validated for shape/dtype but not checked for device,
which can lead to unsafe cross-device raw pointer bindings; update the checks
(around has_precomputed_weights(), and the similar block at 1216-1227) to verify
each array's device type and device id match hidden_states (e.g., ensure
expert_indices.device().device_type == hidden_states.device().device_type and
expert_indices.device().device_id == hidden_states.device().device_id, same for
expert_weights_in) and raise an error if they differ so raw pointer binding into
CUDA kernels only occurs when all arrays are on the same GPU.
In `@flashinfer/fused_moe/core.py`:
- Around line 2120-2124: The code currently only checks dtype/device for
`output` via `check_shape_dtype_device` but must also validate its shape to
prevent out-of-bounds writes by the kernel; add a shape validation that
`output.shape == (num_tokens, hidden_size)` (or raise a clear error) immediately
after the dtype/device check (before the `if routing_logits is not None` branch)
using the same symbols (`output`, `num_tokens`, `hidden_size`, `hidden_states`)
so the kernel assumptions are enforced; keep the check adjacent to the existing
`check_shape_dtype_device` call in `fused_moe/core.py`.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cc1b6b64-b38b-46c4-981a-eeaab559e41a
📒 Files selected for processing (4)
- csrc/trtllm_fused_moe_kernel_launcher.cu
- flashinfer/fused_moe/__init__.py
- flashinfer/fused_moe/core.py
- tests/moe/test_trtllm_gen_routed_fused_moe.py
/bot run

[SUCCESS] Pipeline #45441610: 10/20 passed

@flashinfer-bot run
📌 Description
Add mxint4 routed moe version
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (e.g., unittest, etc.).

Reviewer Notes
Summary by CodeRabbit

New Features
- Pre-computed (packed top-k) routing support for the MXInt4 block-scale fused MoE, accepting expert indices and expert weights alongside the existing routing_logits path.

Tests
- New MXInt4 routed MoE test (tests/moe/test_trtllm_gen_routed_fused_moe.py).