Add support for Relu2 in BF16 fused MoE#2864
amitz-nv wants to merge 3 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Activation type is now a runtime-configurable parameter throughout the BF16 MoE stack: the Python public APIs accept an `activation_type` argument that is validated and forwarded through the C++ bindings to the kernel launcher.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant PythonClient as Python API
    participant Core as flashinfer.fused_moe.core
    participant CppBinding as csrc trtllm entrypoint
    participant Launcher as Bf16MoeLauncher
    participant Kernel as TRT-LLM kernel
    PythonClient->>Core: call trtllm_bf16_moe(..., activation_type=int)
    Core->>Core: validateAndCastActivationType(int) -> ActivationType
    Core->>CppBinding: trtllm_bf16_moe(..., activation_type)
    CppBinding->>Launcher: init(..., activation_type)
    Launcher->>Kernel: init_common(..., activation_type / isGatedActivation)
    Kernel-->>Launcher: kernel configured
    Launcher-->>CppBinding: ready
    CppBinding-->>Core: results returned
```
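The validation step in the flow above (`validateAndCastActivationType`) is a C++ helper; as a minimal illustration of the idea, here is a Python sketch. The enum names and values below are assumptions modeled loosely on TRT-LLM's `ActivationType` and may not match the real enum:

```python
from enum import IntEnum

class ActivationType(IntEnum):
    # Assumed subset of the TRT-LLM activation enum; actual values may differ.
    Gelu = 0
    Relu = 1
    Silu = 2
    Swiglu = 3
    Geglu = 4
    Relu2 = 5

def validate_and_cast_activation_type(activation_type: int) -> ActivationType:
    """Sketch of the C++ validateAndCastActivationType() idea: fail fast on
    out-of-range integers at the API boundary, not inside kernel setup."""
    try:
        return ActivationType(activation_type)
    except ValueError:
        raise ValueError(f"Unsupported activation_type: {activation_type}")
```

The point is that a raw integer crossing an FFI boundary is checked once, deterministically, before any launcher state is built from it.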
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the BF16 fused Mixture-of-Experts (MoE) functionality by integrating support for the Relu2 activation function. The changes extend the core C++ kernel and its Python bindings to allow specifying the activation type, moving beyond a fixed activation. This provides greater flexibility for model architectures using BF16 MoE and is accompanied by updated test cases that confirm the new activation's behavior.

Highlights
Code Review
This pull request introduces dynamic activation function selection for BF16 Mixture-of-Experts (MoE) operations. Previously, the activation type was hardcoded to Swiglu. The changes involve modifying C++ kernel launcher signatures and implementations to accept an ActivationType parameter, propagating this parameter through the Python frontend functions, and updating test cases to reflect and validate this new configurability. Test configurations for specific models and intermediate sizes were also adjusted, and BF16 was added to the list of supported quantization modes in test utilities. I have no feedback to provide as there were no review comments.
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
…it run --all-files' Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
1681-1697: ⚠️ Potential issue | 🟠 Major

Validate `activation_type` before the BF16 cast.

Line 1697 bypasses the new `validateAndCastActivationType()` helper and feeds unchecked values into `isGatedActivation()`/`Runner`. For a public `int64_t` FFI parameter, bad inputs should fail here with a deterministic `ICHECK`, not later inside runner setup.

Suggested fix

```diff
- auto const activation = static_cast<ActivationType>(activation_type);
+ auto const activation = validateAndCastActivationType(activation_type);
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1681-1697: the function currently casts the public `int64_t` activation_type directly via `static_cast<ActivationType>` and proceeds, which can allow invalid values into `isGatedActivation()` and `Runner`. Replace that cast with a call to `validateAndCastActivationType(activation_type)` before any use so the value is deterministically checked (`ICHECK`) and converted; update all subsequent references that use `activation` (and any branching like `isGatedActivation(activation)`) to use the validated result; ensure `validateAndCastActivationType` is called in this function before any Runner construction or gated-activation checks.

tests/moe/test_trtllm_gen_fused_moe.py (1)
1439-1443: ⚠️ Potential issue | 🟠 Major

The new gated/non-gated flag is still aliased by the permute-index cache.

Line 1443 passes `is_gated_act_gemm`, but `_maybe_get_cached_w3_w1_permute_indices()` still memoizes only on `("w3_w1", dst_w3_w1_weight.shape)` in `flashinfer/fused_moe/core.py`. Since `cache_permute_indices` is module-scoped, a gated BF16 case can poison a later Relu2 case with the same viewed shape, making this coverage order-dependent and permuting FC1 rows incorrectly.

Possible fix in `flashinfer/fused_moe/core.py`

```diff
- cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+ cache_key = (
+     "w3_w1",
+     dst_w3_w1_weight.shape,
+     epilogue_tile_m,
+     num_elts_per_sf,
+     is_gated_act_gemm,
+ )
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 1439-1443: the permute-index cache (`_maybe_get_cached_w3_w1_permute_indices`) is currently keyed only by `("w3_w1", dst_w3_w1_weight.shape)`, so a cached entry from a gated BF16 case can be reused for a non-gated case. Update the cache key in `flashinfer/fused_moe/core.py` to include the gated flag (`is_gated_act_gemm`) or the activation type so the memoization distinguishes gated vs non-gated variants (e.g., include `is_gated_act_gemm` in the tuple key when reading/writing `cache_permute_indices`) to prevent cross-contamination.

flashinfer/fused_moe/core.py (1)
1323-1350: ⚠️ Potential issue | 🟡 Minor

Pre-existing signature mismatch in fake op.

The `activation_type` addition (line 1345) is correct. However, the fake op signature is missing `routed_scaling_factor: Optional[float]` between `local_num_experts` and `routing_method_type` compared to the real op at lines 1190-1213. This pre-existing mismatch should be addressed to ensure the fake op mirrors the real op exactly.

🔧 Proposed fix to add missing parameter

```diff
  local_expert_offset: int,
  local_num_experts: int,
+ routed_scaling_factor: Optional[float],
  routing_method_type: int,
  use_shuffled_weight: bool,
```

Based on learnings: "When reviewing files that define fake ops decorated with register_fake_op, ensure the function signatures exactly mirror the real op they stand in for."

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/fused_moe/core.py` around lines 1323-1350: the fake op `_fake_trtllm_bf16_moe` has a signature mismatch. Add the missing parameter `routed_scaling_factor: Optional[float]` (default None) between `local_num_experts` and `routing_method_type` so the fake op exactly mirrors the real op signature; include the parameter in the function signature (it can remain unused) and keep `activation_type` and other params unchanged to ensure parity with the real operator.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1681-1697: The function currently casts the public int64_t
activation_type directly via static_cast<ActivationType> and proceeds, which can
allow invalid values into isGatedActivation() and Runner; replace that cast with
a call to validateAndCastActivationType(activation_type) before any use so the
value is deterministically checked (ICHECK) and converted; update all subsequent
references that use activation (and any branching like
isGatedActivation(activation)) to use the validated result; ensure
validateAndCastActivationType is called in this function before any Runner
construction or gated-activation checks.
In `@flashinfer/fused_moe/core.py`:
- Around line 1323-1350: The fake op _fake_trtllm_bf16_moe has a signature
mismatch: add the missing parameter routed_scaling_factor: Optional[float]
(default None) between local_num_experts and routing_method_type so the fake op
exactly mirrors the real op signature; include the parameter in the function
signature (but it can remain unused) and keep the activation_type and other
params unchanged to ensure parity with the real operator.
In `@tests/moe/test_trtllm_gen_fused_moe.py`:
- Around line 1439-1443: The permute-index cache
(_maybe_get_cached_w3_w1_permute_indices) is currently keyed only by ("w3_w1",
dst_w3_w1_weight.shape) so a cached entry from a gated BF16 case can be reused
for a non-gated case; update the cache key in flashinfer/fused_moe/core.py to
include the gated flag (is_gated_act_gemm) or the activation type so the
memoization distinguishes gated vs non-gated variants (e.g., include
is_gated_act_gemm in the tuple key when reading/writing cache_permute_indices)
to prevent cross-contamination.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5a59e148-d201-4efe-bf55-ea14b1ac3535
📒 Files selected for processing (5)

- csrc/trtllm_fused_moe_kernel_launcher.cu
- flashinfer/fused_moe/__init__.py
- flashinfer/fused_moe/core.py
- tests/moe/test_trtllm_gen_fused_moe.py
- tests/moe/utils.py
Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Force-pushed from f3dae20 to 62e38fd
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/moe/test_trtllm_gen_fused_moe.py (1)
1439-1443: ⚠️ Potential issue | 🟠 Major

Include `is_gated_act_gemm` in the permute-cache key.

Passing the flag here still reuses whatever `_maybe_get_cached_w3_w1_permute_indices()` cached first, because the helper currently keys only on `("w3_w1", shape)`. With the module-scoped `cache_permute_indices` fixture, gated and non-gated cases that collapse to the same `view(torch.uint8)` shape can therefore reuse the wrong row order, so the BF16 shuffle becomes test-order dependent.

Please fix this in `flashinfer/fused_moe/core.py` by keying the cache on the activation mode as well, instead of only passing the flag at the call site.

Suggested helper-side fix

```diff
- cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+ cache_key = (
+     "w3_w1",
+     dst_w3_w1_weight.shape,
+     epilogue_tile_m,
+     num_elts_per_sf,
+     is_gated_act_gemm,
+ )
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/moe/test_trtllm_gen_fused_moe.py` around lines 1439-1443: the cached permute indices helper `_maybe_get_cached_w3_w1_permute_indices` currently keys only on `("w3_w1", shape)`, which allows gated and non-gated tensors with identical uint8 views to collide. Change the helper to include the `is_gated_act_gemm` boolean in the cache key (e.g., `("w3_w1", shape, is_gated_act_gemm)`) and update any cache lookups/insertions that use `cache_permute_indices` so gated and non-gated cases store and retrieve distinct entries while leaving the call sites (which already pass `is_gated_act_gemm`) unchanged.
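The collision described above is easy to reproduce with a toy memoization cache. This is a sketch of the keying fix only; the permutation body is a placeholder, not the real shuffle logic:

```python
cache = {}

def get_permute_indices(shape, is_gated_act_gemm):
    # Keying on shape alone would let a gated entry be silently reused for a
    # non-gated call with the same viewed shape -- the bug described above.
    # Including the flag in the key keeps the two variants distinct.
    key = ("w3_w1", shape, is_gated_act_gemm)
    if key not in cache:
        n = shape[0]
        if is_gated_act_gemm:
            # Placeholder for the real w3/w1 interleave: pair up the halves.
            half = n // 2
            cache[key] = [i // 2 + (i % 2) * half for i in range(n)]
        else:
            # Non-gated FC1 keeps rows in order (again, illustrative only).
            cache[key] = list(range(n))
    return cache[key]
```

With the flag in the key, a gated call and a Relu2 call on the same shape populate separate entries, so test results no longer depend on which case runs first.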
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d7102610-1265-42b5-add1-237337333f3c
📒 Files selected for processing (5)

- csrc/trtllm_fused_moe_kernel_launcher.cu
- flashinfer/fused_moe/__init__.py
- flashinfer/fused_moe/core.py
- tests/moe/test_trtllm_gen_fused_moe.py
- tests/moe/utils.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/moe/utils.py
- flashinfer/fused_moe/core.py
```diff
  bool enable_pdl, Array<int64_t> config_index, int64_t activation_type) {
    // Basic type validation
    auto dtype = hidden_states.dtype();
-   auto activation = static_cast<ActivationType>(activation_type);
+   auto activation = validateAndCastActivationType(activation_type);
```
Unify the FP8 per-tensor activation contract.
trtllm_fp8_per_tensor_scale_moe() now accepts any valid activation enum, but trtllm_get_valid_moe_configs() later in this file still rejects non-gated per-tensor activations. Since Fp8PerTensorLauncher::prepare_moe() also keeps the gated FC1/gate-scale layout, direct execution and autotune currently advertise different rules for the same call. Please make these entry points agree—either reject non-gated activations here too, or lift the old restriction there.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1758 - 1761,
trtllm_fp8_per_tensor_scale_moe currently accepts any activation enum while
trtllm_get_valid_moe_configs and Fp8PerTensorLauncher::prepare_moe assume the
gated FC1/gate-scale layout; make them consistent by enforcing the gated-only
contract at the entry point: after calling
validateAndCastActivationType(activation_type) in
trtllm_fp8_per_tensor_scale_moe, check that the returned activation is one of
the gated activation variants used by Fp8PerTensorLauncher::prepare_moe (reject
non-gated enums) and return an error (or throw) if not; alternatively, if you
prefer to permit non-gated activations, update trtllm_get_valid_moe_configs and
Fp8PerTensorLauncher::prepare_moe to accept the non-gated layout—but pick one
approach and apply it consistently across trtllm_fp8_per_tensor_scale_moe,
trtllm_get_valid_moe_configs, and Fp8PerTensorLauncher::prepare_moe so both
autotune and direct execution advertise the same activation contract.
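The first option in that prompt (reject non-gated activations at the entry point) can be sketched in Python for clarity. The gated set below is an assumption, and the function is an illustrative stand-in for the C++ `trtllm_fp8_per_tensor_scale_moe`, not its actual implementation:

```python
# Assumed gated activation variants; the real set lives in the C++ enum.
GATED_ACTIVATIONS = {"Swiglu", "Geglu"}

def is_gated_activation(activation: str) -> bool:
    return activation in GATED_ACTIVATIONS

def fp8_per_tensor_scale_moe(activation: str) -> str:
    # Enforce the gated-only contract at the public entry point, so direct
    # execution and autotune config enumeration advertise the same rule.
    if not is_gated_activation(activation):
        raise ValueError(
            f"fp8 per-tensor MoE requires a gated activation, got {activation}"
        )
    # ... kernel launch would happen here ...
    return "launched"
```

Whichever option is chosen, the key design point is that the check (or its removal) must appear in both the execution path and the config-enumeration path, so the two never disagree about what is supported.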
📌 Description
PR is pending uploading of Relu2 BF16 FC1 batched GEMM kernels to artifactory
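For context: Relu2 is squared ReLU, a non-gated activation, whereas the previously hardcoded Swiglu is gated and consumes a doubled FC1 output (the w1/w3 halves). That difference is why `isGatedActivation()` and the FC1 weight layout now depend on the runtime activation. A scalar sketch of the two activations (my own illustration, not the kernel code):

```python
import math

def relu2(x: float) -> float:
    # Squared ReLU: relu(x) ** 2. Non-gated, so FC1 emits intermediate_size
    # columns and the activation consumes them directly.
    return max(x, 0.0) ** 2

def swiglu(x: float, gate: float) -> float:
    # Gated: silu(gate) * x. FC1 must emit 2 * intermediate_size columns
    # (the value and gate halves), hence the different weight layout.
    return (gate / (1.0 + math.exp(-gate))) * x
```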
- Added `activation_type` to the external API: `trtllm_bf16_moe`, `trtllm_bf16_routed_moe`, `Bf16MoeLauncher::init`
- Updated `tests/moe/test_trtllm_gen_fused_moe.py` to test that with deepseek routing; fixed the Nemotron 3 `intermediate_size` test param to match Nemotron 3 Super
- Ran `pre-commit run --all-files`

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Added or updated tests as needed (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests