fix: Handle zeros in Mistral Large 3 MoE inference#2238
yzh119 merged 9 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Adds uniform tokens-per-batch support, GELU activation, pre-activation scaling propagation, TMA padding control, mmaK-aware SMEM calculations, an inline BatchedGemmInterface::run with module caching and kernel launch logic, a clamped per-CTA scale in the CUDA kernel, artifact path/checksum updates, and test hooks for zeroed hidden states.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant BatchedGemm as BatchedGemmInterface::run
    participant Options as BatchedGemmOptions
    participant KernelParams as KernelParams::setKernelParams
    participant Traits as KernelTraits
    participant TMA as TmaDescriptor
    participant CUDA as CUDA Runtime / Module Cache
    Client->>BatchedGemm: call run(config, data)
    BatchedGemm->>Options: derive options (batching, activations, fp8)
    BatchedGemm->>KernelParams: setKernelParams(..., ptrScaleAct, ...)
    KernelParams->>Traits: init traits (mmaK, numEpilogueWarps)
    Traits->>Traits: compute SMEM/TMEM using mmaK-aware functions
    KernelParams->>TMA: build A/B/C descriptors (doPad/doSwizzle)
    BatchedGemm->>CUDA: load/lookup module in context cache
    BatchedGemm->>CUDA: launch kernel with grid/cluster params
    CUDA->>Client: kernel completes / results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @dbari, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses a critical issue preventing successful inference of Mistral Large 3 models using FlashInfer.
Code Review
This pull request addresses a division-by-zero issue in Mistral Large 3 MoE inference when an expert's weights are all zero. The fix in activationDeepSeekKernel prevents this by adding a small epsilon to the scaling factor. The changes are accompanied by a new test case that uses zero-valued hidden states to validate the fix, and a minor include fix for GCC 11 compatibility. The overall approach is sound. I have one suggestion to make the fix in the CUDA kernel more robust by using fmaxf to avoid potential precision issues with very small, non-zero values.
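The precision argument can be illustrated with a small host-side sketch (a Python stand-in for the CUDA code; `E4M3_MAX` and both function names are illustrative, not from the PR):

```python
import sys

E4M3_MAX = 448.0  # illustrative: max magnitude representable in FP8 E4M3

def scale_with_epsilon(a_max, eps=1e-10):
    # Adding an epsilon perturbs every scale, and for tiny non-zero a_max
    # the result can be dominated by eps rather than the data.
    return a_max / E4M3_MAX + eps

def scale_with_clamp(a_max):
    # fmaxf-style clamp: it only kicks in when the scale would be zero,
    # leaving all other scales bit-exact.
    return max(a_max / E4M3_MAX, sys.float_info.min)

# With an all-zero expert the clamped scale stays strictly positive,
# so the later division by the scale is finite.
assert scale_with_clamp(0.0) > 0.0
# For normal inputs the clamp is a no-op: the scale is exactly a_max / E4M3_MAX.
assert scale_with_clamp(448.0) == 1.0
```

Note that the sketch uses the double-precision `sys.float_info.min` where the kernel would use the single-precision `std::numeric_limits<float>::min()`; the clamping behavior is the same.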
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
The PR is ready to be reviewed: the functionality is implemented and all tests pass locally. Since the goal is to use this in vLLM and SGLang to allow FlashInfer MoE FP8 for Mistral Large 3, which release could this be integrated in? Thanks in advance for reviews and comments.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
949-961: Apply the same division-by-zero protection here.

This kernel has the same pattern as the fixed `activationDeepSeekKernel`: it computes a scale by dividing by `E4m3MaxVal` (lines 951, 953) and later divides by that scale (line 961). If `aMax` is zero (when all accumulated expert outputs are zero), this will cause a division by zero.

For consistency and robustness, apply the same fix as in `activationDeepSeekKernel`:

🔧 Proposed fix

```diff
 if (threadIdx.x == 0) {
   if (params.outDqSfsPtr) {
-    s_scaleOut = aMax / E4m3MaxVal;
+    // Make sure the scale is strictly positive to avoid division by zero in case the
+    // maximum is zero.
+    s_scaleOut = fmaxf(aMax / E4m3MaxVal, std::numeric_limits<float>::min());
     int const scaleOut_idx = tokenIdx + hiddenIdx / 128 * params.numTokens;
-    params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal;
+    params.outDqSfsPtr[scaleOut_idx] = s_scaleOut;
   } else {
     s_scaleOut = 1.0f;
   }
 }
```

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
141-150: Avoid unused-parameter warnings in `getNumSmemBitsPerElt(mmaK)`.

`mmaK` is currently unused; if you build with `-Wextra`/`-Werror`, this can fail compilation.

Proposed fix

```diff
 inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind, int mmaK) {
+  (void) mmaK;
   if (mmaKind == tg::MmaKind::Auto) {
     throw std::runtime_error("mmaKind != tg::MmaKind::Auto");
   }
   if (mmaKind == tg::MmaKind::MxFp8Fp6Fp4) {
     return 8;
   } else {
     return tg::dtypeGetNumBits(dtype);
   }
 }
```

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (1)
330-381: Input buffer size mismatch in the uniform-tokens-per-batch broadcast case.

When `params.batchStrideInCtas == 0`, the TMA descriptor semantics (per KernelParamsDecl.h) indicate the input is broadcast across batches: the logical shape is `[M, K]` or `[N, K]` (a single batch). However, `totalNumPaddedTokens` is correctly sized to one batch while `ctaOffset` is unconditionally set to `maxNumCtas`, and subsequent TMA descriptor shapes are computed using `ctaOffset * mTile[M|N]`. This sizes the descriptors for `maxNumCtas * tile` tokens instead of the single batch size, causing a mismatch between the declared buffer size (`totalNumPaddedTokens`) and the descriptor's logical access range.
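For illustration, the mismatch described above can be reproduced with toy numbers (all names and values here are hypothetical stand-ins for the fields mentioned in the comment):

```python
def padded_tokens(num_tokens, tile):
    # Round the token count up to a multiple of the tile size.
    return -(-num_tokens // tile) * tile

# Hypothetical shapes: 3 batches of 16 tokens each, tileM = 8.
tile_m = 8
tokens_per_batch = 16
num_batches = 3
max_num_ctas = num_batches * (tokens_per_batch // tile_m)  # CTAs across all batches

# Broadcast case (batchStrideInCtas == 0): the buffer is declared for ONE batch...
total_num_padded_tokens = padded_tokens(tokens_per_batch, tile_m)
# ...but the descriptor extent is derived from ctaOffset = maxNumCtas.
descriptor_extent = max_num_ctas * tile_m

# The descriptor's logical access range exceeds the declared buffer size.
assert descriptor_extent > total_num_padded_tokens
```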
🤖 Fix all issues with AI agents
In @tests/moe/utils.py:
- Around line 55-57: Update the misleading comment above the skip logic to
reflect what the code actually does: it skips zero-input tests when
zero_hidden_states is true and the implementation is NOT FP8 Block Scale MoE.
Edit the comment that currently reads "Skip checking zero input for FP8 Block
Scale MoE" to something like "Skip zero-input tests for non-FP8 Block Scale MoE
implementations" near the check using zero_hidden_states and
is_fp8_block_scale_moe so the comment matches the pytest.skip behavior.
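A minimal sketch of the skip predicate with the corrected comment (a plain-Python stand-in for the pytest logic; `should_skip` is a hypothetical helper, not the actual function in `tests/moe/utils.py`):

```python
def should_skip(zero_hidden_states, is_fp8_block_scale_moe):
    # Skip zero-input tests for implementations other than FP8 Block Scale MoE:
    # only that path carries the division-by-zero fix being tested.
    return zero_hidden_states and not is_fp8_block_scale_moe

assert should_skip(True, False) is True    # zero input, other impl -> skip
assert should_skip(True, True) is False    # zero input, FP8 Block Scale MoE -> run
assert should_skip(False, False) is False  # random input -> always run
```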
🧹 Nitpick comments (6)
csrc/trtllm_fused_moe_dev_kernel.cu (1)
17-30: Consider adding an explicit include for `std::numeric_limits`.

Line 305 uses `std::numeric_limits<float>::min()`, but there's no explicit include for `<limits>`. While it may be available transitively through other headers, adding an explicit include improves code clarity and portability:

```diff
 #include <algorithm>
 #include <cub/cub.cuh>
 #include <cuda/functional>
 #include <cuda/std/functional>
 #include <cuda/std/type_traits>
+#include <limits>
```

Alternatively, you could use `FLT_MIN` from `<cfloat>`, which is more commonly used in CUDA code.

tests/moe/test_trtllm_gen_fused_moe.py (1)
2477-2632: Consider extending zero-hidden-states coverage to other routing tests.

The `zero_hidden_states` parametrization is currently only added to `test_renormalize_routing`. Based on the PR objectives (fixing a division by zero in quantization for Mistral Large 3), the kernel-level fix should protect against zero inputs regardless of the routing method.

Should `test_deepseekv3_routing`, `test_topk_routing`, and `test_llama4_routing` also include the `zero_hidden_states` parametrization for FP8BlockScaleMoe to ensure comprehensive coverage?

If this is an intentional scope limitation for this PR, the current implementation is acceptable. Otherwise, consider adding the parametrization to the other routing tests as well.
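One way to sketch the suggested coverage (a hypothetical grid; the real tests would carry `@pytest.mark.parametrize("zero_hidden_states", [False, True])` on each routing test):

```python
import itertools

# Routing tests named in the review; only the first is parametrized so far.
routing_tests = [
    "test_renormalize_routing",
    "test_deepseekv3_routing",
    "test_topk_routing",
    "test_llama4_routing",
]

# Each test would gain the zero_hidden_states parametrization, yielding one
# zero-input and one random-input case per routing method.
grid = list(itertools.product(routing_tests, [False, True]))

assert len(grid) == 8
assert ("test_deepseekv3_routing", True) in grid
```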
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (3)
169-173: Clarify invariants for `mNumEpilogueWarps` usage (multiple of 4, non-negative).

You scale `extraGmemCMultiplier` by `numEpilogueWarps / 4` (integer division). If `numEpilogueWarps` is not a multiple of 4, this silently rounds down. Consider asserting the expected granularity (or documenting it) at the boundary where options are validated.

Also applies to: 279-281
205-245: SMEM sizing updates look consistent, but consider guarding against negative/overflow sizing.

The new sizing expressions now depend on `getNumSmemBitsPerElt(..., mmaK)`; that's fine, but this path is a common source of accidental negatives/overflow when tiles/stages are large. If there's a central "options validation" phase, it'd be good to ensure these products fit in `int32_t` before storing them into the allocator vectors.
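A sketch of the kind of guard suggested above, in Python for brevity (the C++ version would accumulate in `int64_t` before narrowing to `int32_t`; all names here are illustrative):

```python
INT32_MAX = 2**31 - 1

def smem_bytes(tile_m, tile_k, num_stages, bits_per_elt):
    # Python ints don't overflow; in C++ this product would be accumulated
    # in int64_t and range-checked before being stored into an int32_t.
    total_bytes = tile_m * tile_k * num_stages * bits_per_elt // 8
    if total_bytes > INT32_MAX:
        raise ValueError(f"SMEM sizing overflows int32: {total_bytes} bytes")
    return total_bytes

assert smem_bytes(128, 128, 4, 8) == 65536  # 64 KiB, well within range
try:
    smem_bytes(1 << 16, 1 << 16, 64, 8)     # pathological tile/stage combination
except ValueError:
    pass
else:
    raise AssertionError("expected the overflow guard to fire")
```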
472-532: SfA/SfB TMEM sizing: make constants `const` and consider naming consistency.

`kGroupSize` is constant but non-const, and the per-stage variables are duplicated for A/B. Minor readability win: mark constants `const` and consider a tiny helper to compute `numColsPerStage` to avoid drift between A and B.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1)
502-625: The `run()` module-cache path needs error handling and a thread-safety clarification.

A few concrete risks in the new inline `run()`:

- The return codes from `cuCtxGetCurrent`, `cuCtxGetId`, `cuModuleGetFunction`, and `cuModuleUnload` are ignored (failures can turn into confusing launch errors).
- `ModuleCache` is a plain `unordered_map`; if callers share it across threads, inserts/reads race.
- Keying uses `cuCtxGetId` + function name; please confirm `cuCtxGetId` is available for your supported CUDA versions, or fall back to keying by the `CUcontext` pointer value.

Possible tightening (sketch)

```diff
-  cuCtxGetCurrent(&ctx);
-  cuCtxGetId(ctx, &ctxId);
+  CUresult st = cuCtxGetCurrent(&ctx);
+  if (st != CUDA_SUCCESS) return st;
+  st = cuCtxGetId(ctx, &ctxId);
+  if (st != CUDA_SUCCESS) return st;
   ...
-  cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
+  st = cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
+  if (st != CUDA_SUCCESS) return st;
   ...
-  cuModuleUnload(cuModule);
+  st = cuModuleUnload(cuModule);
+  if (st != CUDA_SUCCESS) return st;
```
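For the thread-safety point, a minimal sketch of a lock-guarded cache (a Python stand-in for a `std::mutex`-protected `unordered_map`; the class and method names are illustrative):

```python
import threading

class ModuleCache:
    """Sketch: thread-safe (context id, function name) -> loaded module map."""

    def __init__(self):
        self._lock = threading.Lock()
        self._modules = {}

    def get_or_load(self, ctx_id, function_name, loader):
        key = (ctx_id, function_name)
        with self._lock:
            # Insert under the lock so concurrent callers can't race the map;
            # the C++ analogue would guard the unordered_map with a std::mutex.
            if key not in self._modules:
                self._modules[key] = loader()
            return self._modules[key]

cache = ModuleCache()
loads = []

def fake_loader():
    loads.append(1)           # count how many times the module is really loaded
    return "module-for-ctx-1"

first = cache.get_or_load(1, "bmm_kernel", fake_loader)
second = cache.get_or_load(1, "bmm_kernel", fake_loader)
assert first == second == "module-for-ctx-1"
assert len(loads) == 1        # the second lookup was served from the cache
```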
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (1)
`include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h` is excluded by `!**/gen/**`
📒 Files selected for processing (13)
- csrc/trtllm_fused_moe_dev_kernel.cu
- flashinfer/artifacts.py
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h
- include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
- tests/moe/test_trtllm_gen_fused_moe.py
- tests/moe/utils.py
🧰 Additional context used
📓 Path-based instructions (3)
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Files:
csrc/trtllm_fused_moe_dev_kernel.cu
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
- Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
- For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`
- Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/moe/test_trtllm_gen_fused_moe.py
tests/moe/utils.py
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
- Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
- Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/artifacts.py
🧠 Learnings (2)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
Applied to files:
tests/moe/utils.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Applied to files:
tests/moe/utils.py
🧬 Code graph analysis (5)
tests/moe/utils.py (1)
flashinfer/utils.py (1)
get_compute_capability(258-261)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
Dtype(43-274)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (3)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (1)
RouteImpl (28-57)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)
gemm (148-153)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (1)
divUpMul (576-578)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
std (224-625), std (234-239), std (274-295), std (285-291)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h (3)
trtllm (28-90), gen (29-89), launchKernel (34-84)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
gemm (30-419)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
gemm(30-294)
🔇 Additional comments (21)
flashinfer/artifacts.py (2)
92-92: LGTM: Artifact path updated for fixed quantization.

The TRTLLM_GEN_BMM artifact path has been updated to reference the new cubin build that includes the quantization fix for Mistral Large 3. The format is consistent with other artifact paths.

113-113: LGTM: Checksum updated consistently with the artifact path.

The checksum has been correctly updated to match the new TRTLLM_GEN_BMM artifact. The SHA256 format is valid (64 hex characters), and this value will be verified during cubin download by the `verify_cubin()` function.

csrc/trtllm_fused_moe_dev_kernel.cu (1)

302-309: LGTM! Division-by-zero protection correctly implemented.

The fix properly guards against division by zero when quantizing zero activations by clamping the scale to `std::numeric_limits<float>::min()`. This ensures line 326's division by `scaleOut` never encounters a zero denominator. The comment clearly documents the intent.

tests/moe/test_trtllm_gen_fused_moe.py (2)
2290-2302: LGTM!

The addition of the `zero_hidden_states` parameter to `run_moe_test` and its propagation to `skip_checks` is implemented correctly. The default value of `False` maintains backward compatibility.

2345-2348: LGTM!

The conditional initialization of hidden states using either `torch.zeros` or `torch.randn` is implemented cleanly and correctly handles both zero and random input scenarios.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)
188-189: LGTM! Formatting refinement.

The dumpOptions formatting has been streamlined to emit the `mActType` field inline rather than across multiple stream insertions, improving consistency with similar changes in GemmOptions.h.
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h (1)
93-98: LGTM! GELU activation correctly added.

The GELU (Gaussian Error Linear Unit) activation is properly defined with accurate mathematical documentation. The phi-function approximation using tanh matches the standard fast GELU implementation.
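As a sanity check of the documented formula, the tanh-based fast GELU can be evaluated directly (a sketch only; the device-side kernel implementation is what's actually under review):

```python
import math

def gelu_tanh(x):
    # Fast GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

assert gelu_tanh(0.0) == 0.0                  # GELU(0) = 0
assert abs(gelu_tanh(10.0) - 10.0) < 1e-6     # ~identity for large positive x
assert abs(gelu_tanh(-10.0)) < 1e-6           # ~zero for large negative x
assert 0.84 < gelu_tanh(1.0) < 0.85           # exact GELU(1) ≈ 0.8413
```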
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h (2)
210-215: LGTM! Pre-activation scaling support added.

The new `ptrScaleAct` field is well-documented and correctly scoped for non-linear activations (GELU, Relu2) when input scaling is needed. The shape specification `[B]` is clear.
466-487: LGTM! Uniform batching parameters added.

The new fields `totalNumOutputPaddedTokens`, `ctasInTokenDimPerBatch`, and `batchStrideInCtas` are clearly documented and appropriately scoped for the uniform tokens-per-batch feature.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)
19-19: LGTM! Required include added.

The `<algorithm>` header is correctly added to support `std::all_of` usage in the uniform batching validation logic (lines 385-390).
378-429: Comprehensive validation for the uniform tokens-per-batch feature.

The validation logic is thorough and correctly enforces constraints:

- Verifies uniformity using `std::all_of`
- Validates batch stride alignment (either 0 or `divUpMul(firstValue, tileTokensDim)`)
- Checks incompatibilities with DeepSeek FP8, per-token SF, bias, routing, fused activation, and block formats
- Provides clear error messages

The batch stride validation at lines 402-407 ensures proper tile-aligned padding, which is necessary for efficient batched operations.
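The stride rule can be sketched with a small model of the validation (a Python stand-in; `div_up_mul` mirrors the referenced `gemm::divUpMul`, and the helper name is hypothetical):

```python
def div_up_mul(a, b):
    # Round a up to the nearest multiple of b (mirrors gemm::divUpMul).
    return (a + b - 1) // b * b

def validate_uniform_batching(tokens_per_batch, batch_stride, tile_tokens_dim):
    first = tokens_per_batch[0]
    # All batches must hold the same number of tokens...
    if any(n != first for n in tokens_per_batch):
        raise ValueError("tokens-per-batch must be uniform")
    # ...and the stride is either 0 (broadcast) or the tile-aligned token count.
    if batch_stride not in (0, div_up_mul(first, tile_tokens_dim)):
        raise ValueError("stride must be 0 or divUpMul(firstValue, tileTokensDim)")

assert div_up_mul(100, 64) == 128
validate_uniform_batching([100, 100, 100], 128, 64)  # aligned stride: accepted
validate_uniform_batching([100, 100, 100], 0, 64)    # broadcast case: accepted
try:
    validate_uniform_batching([100, 100, 100], 100, 64)  # unaligned stride
except ValueError:
    pass
else:
    raise AssertionError("expected the unaligned stride to be rejected")
```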
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (2)
754-768: LGTM! MmaK default handling refined.

The Fp8Fp6Fp4 MmaKind now has explicit default mmaK=32 handling in its own conditional block, improving code clarity by separating it from the MxFp4NvFp4 path.
1481-1486: LGTM! Activation scaling logic correctly implemented.

The `getDoesScaleAct` function correctly determines when separate pre-activation scaling is needed:

- Returns `true` only for non-linear activations (Gelu, Relu2), when input scaling is required (via `getDoesScaleAb`)
- Returns `false` for linear activations (None), since scaling can be fused with output scaling

This aligns with the new `ptrScaleAct` parameter added in KernelParamsDecl.h.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3)
19-23: Good: the `<functional>` include matches the `std::reference_wrapper` usage.
This should fix the GCC 11 compilation issue described.
96-124: Doc update: `batchStrideInTokens == 0` semantics are clearer now.

No issues—this helps disambiguate the "broadcast vs. per-batch packed" cases for A/B.

Also applies to: 167-196
276-282: The `mPtrScaleAct` addition is fine, but please verify downstream invariants with `scaleC`.

The comment says: when `mPtrScaleAct` is used, `scaleC` should be "quantScaleC only". It's worth ensuring validation rejects inconsistent combinations (e.g., activation enabled but `mPtrScaleAct == nullptr` when required, or `scaleC` already pre-multiplied).

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2)
40-44: `doPad`-driven padding is a cleaner API than `mmaKind`-driven branching.

This makes call sites explicit and avoids "hidden coupling" to the MMA kind.

Also applies to: 56-64
71-91: Double-check the intended interaction between `padMultiplier` and swizzle/box sizing for MxInt4.

The `padMultiplier` feeds `fastestDimTileSizeBytes` and `numEltsPerUInt32`, which is consistent if padding is strictly an SMEM layout concern for 4-bit packing. Just verify that kernels consuming this descriptor expect these exact semantics for the padded 16B path.

Also applies to: 118-124
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (3)
299-323: `ptrScaleAct` plumbing is straightforward; keep validation in sync.

The assignment to `params.ptrScaleAct` is correct; just ensure any option-validation logic enforces the expected "scaleC vs scaleAct" contract when non-linear activation is enabled.
401-408: `doPadA`/`doPadB` + updated `buildNdTmaDescriptor` call sites look consistent.

Nice to see the padding decision localized and explicitly passed into the descriptor builder.

Also applies to: 421-441, 537-554
487-491: Good: non-padded descriptors for SF/C are explicit (`doPad = false`).

This reduces ambiguity compared to the prior `mmaKind`-based selection.

Also applies to: 517-520, 600-603
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
📌 Description
This PR provides fixes that are needed to infer Mistral Large 3 using FlashInfer. In the current checkpoint, one expert is zeroed out, causing a division by zero in the dynamic quantization that takes place in MoE when this particular expert is selected.

Changes in this PR:

- Updated artifact path and checksum in `artifacts.py`
- Division-by-zero fix in `activationDeepSeekKernel`

Minor unrelated change:

- Included `functional` in `BatchedGemmInterface.h` to solve a compilation problem with GCC 11
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
unittest, etc.).

Reviewer Notes