
fix: Handle zeros in Mistral Large 3 MoE inference #2238

Merged
yzh119 merged 9 commits into flashinfer-ai:main from dbari:dbariamis/fix-mistral-large-3
Jan 14, 2026

Conversation

@dbari
Contributor

@dbari dbari commented Dec 18, 2025

📌 Description

This PR provides the fixes needed to run Mistral Large 3 inference with FlashInfer. In the current checkpoint, one expert is zeroed out, which causes a division by zero in the dynamic quantization performed in the MoE path whenever that expert is selected.

Changes in this PR:

  • Use new cubins with fixed quantization in artifacts.py
  • Handle zeros in quantization in activationDeepSeekKernel
  • Add test for zero input to verify fix

Minor unrelated change:

  • Add include for functional in BatchedGemmInterface.h to solve a compilation problem with GCC 11

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • GELU activation added for batched operations.
    • Uniform tokens-per-batch batching mode and related batching metadata.
    • Per-token/pre-activation scaling support for output activations.
  • Bug Fixes

    • Clamped scale values to ensure positive, stable scaling and avoid division-by-zero in MOE paths.
  • Tests

    • MOE tests extended to cover both random and zero-filled hidden-state initialization.
  • Chores

    • Updated internal runtime artifact references and checksums.


@coderabbitai
Contributor

coderabbitai bot commented Dec 18, 2025

📝 Walkthrough

Adds uniform tokens-per-batch support, GELU activation, pre-activation scaling propagation, TMA padding control, mmaK-aware SMEM calculations, inline BatchedGemmInterface::run with module caching and kernel launch logic, a clamped per-CTA scale in CUDA kernel, artifact path/checksum updates, and test hooks for zeroed hidden states.

Changes

Cohort / File(s) / Summary:

  • Kernel behavior & safety — csrc/trtllm_fused_moe_dev_kernel.cu: Clamp the per-CTA scale to a positive minimum before storing to s_scaleOutArr and params.outDqSfsPtr to avoid zero/negative divides.
  • Artifact manifest — flashinfer/artifacts.py: Updated TRTLLM_GEN_BMM artifact path and checksum constants.
  • Interface & runtime flow — include/.../BatchedGemmInterface.h: Inlined BatchedGemmInterface::run; added workspace/barrier allocation, kernel module caching/loading, kernel launch, and optional PDL checks; added mPtrScaleAct to BatchedGemmData.
  • Options & activation — include/.../BatchedGemmOptions.h, include/.../Enums.h, include/.../GemmGatedActOptions.h, include/.../GemmOptions.h: Added mBatchStrideInTokens and mIsUniformNumTokensPerBatch with validation and uniform-batch checks; added a Gelu enum; formatting and dump changes; added the getDoesScaleAct helper and refined MmaK handling logic.
  • Kernel params & descriptors — include/.../KernelParams.h, include/.../KernelParamsDecl.h, include/.../TmaDescriptor.h: Added ptrScaleAct, totalNumOutputPaddedTokens, ctasInTokenDimPerBatch, and batchStrideInCtas; extended the setKernelParams signature to accept ptrScaleAct; changed the TMA descriptor API to accept a doPad flag (removing mmaKind-based padding).
  • Kernel traits & memory calc — include/.../KernelTraits.h: getNumSmemBitsPerElt now takes mmaK; added mNumEpilogueWarps; refactored SMEM/TMEM and per-stage SF calculations to use mmaK-aware paths.
  • Tests & utilities — tests/moe/test_trtllm_gen_fused_moe.py, tests/moe/utils.py: Introduced the zero_hidden_states test param; tests can use zeroed hidden states and skip when incompatible with the MoE impl.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant BatchedGemm as BatchedGemmInterface::run
    participant Options as BatchedGemmOptions
    participant KernelParams as KernelParams::setKernelParams
    participant Traits as KernelTraits
    participant TMA as TmaDescriptor
    participant CUDA as CUDA Runtime / Module Cache

    Client->>BatchedGemm: call run(config, data)
    BatchedGemm->>Options: derive options (batching, activations, fp8)
    BatchedGemm->>KernelParams: setKernelParams(..., ptrScaleAct, ...)
    KernelParams->>Traits: init traits (mmaK, numEpilogueWarps)
    Traits->>Traits: compute SMEM/TMEM using mmaK-aware functions
    KernelParams->>TMA: build A/B/C descriptors (doPad/doSwizzle)
    BatchedGemm->>CUDA: load/lookup module in context cache
    BatchedGemm->>CUDA: launch kernel with grid/cluster params
    CUDA->>Client: kernel completes / results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • joker-eph
  • cyx-6
  • wenscarl
  • yzh119
  • jiahanc
  • aleozlx
  • nvmbreughe
  • bkryu

Poem

🐇 I scaled each CTA with careful care,
Gelu tucked neatly in activation's lair,
Tokens march uniform, paddings align,
Modules cached, kernels launched on time—
A hopping rabbit cheers the code's fine flair!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 53.85%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title clearly summarizes the main fix: handling zeros in Mistral Large 3 MoE inference, which matches the primary changes throughout the PR.
  • Description check ✅ Passed — The description is comprehensive and well structured, covering all major changes (quantization fixes, test additions, and compilation fixes) with clear explanations of why each change was needed.



📜 Recent review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 70aecb7 and ed24b65.

📒 Files selected for processing (1)
  • tests/moe/utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/moe/utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs


@gemini-code-assist
Contributor

Summary of Changes

Hello @dbari, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request addresses a critical issue preventing the successful inference of Mistral Large 3 models using flashinfer, where a zeroed-out expert could lead to a division-by-zero error during dynamic quantization in the MoE layer. The core change involves adding a small epsilon to the scaling factor to ensure numerical stability. To validate this fix, a new test scenario has been introduced that specifically checks behavior with zero inputs. Additionally, a minor, unrelated change was made to include a necessary header for GCC 11 compatibility.

Highlights

  • Division by Zero Fix: Implemented a fix in the activationDeepSeekKernel to prevent division by zero during dynamic quantization in Mixture-of-Experts (MoE) inference, specifically addressing an issue encountered with Mistral Large 3 where an expert's output could be zero.
  • New Test Case for Zero Input: Added a new test case to verify the robustness of the MoE quantization process when presented with zero input hidden states, ensuring the division-by-zero fix is effective.
  • GCC 11 Compilation Fix: Included <functional> header in BatchedGemmInterface.h to resolve a compilation error that occurred with GCC 11.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a division-by-zero issue in Mistral Large 3 MoE inference when an expert's weights are all zero. The fix in activationDeepSeekKernel prevents this by adding a small epsilon to the scaling factor. The changes are accompanied by a new test case that uses zero-valued hidden states to validate the fix, and a minor include fix for GCC 11 compatibility. The overall approach is sound. I have one suggestion to make the fix in the CUDA kernel more robust by using fmaxf to avoid potential precision issues with very small, non-zero values.

@dbari dbari force-pushed the dbariamis/fix-mistral-large-3 branch from f269974 to 70aecb7 Compare January 9, 2026 10:18
@dbari dbari marked this pull request as ready for review January 9, 2026 10:19
@dbari
Contributor Author

dbari commented Jan 9, 2026

The PR is ready for review: the functionality is implemented and all tests pass locally.

Since the goal is to use this in vLLM and SGLang to enable FlashInfer MoE FP8 for Mistral Large 3, into which release could this be integrated?

Thanks in advance for reviews and comments.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
csrc/trtllm_fused_moe_dev_kernel.cu (1)

949-961: Apply the same division-by-zero protection here.

This kernel has the same pattern as the fixed activationDeepSeekKernel: it computes a scale by dividing by E4m3MaxVal (Line 951, 953) and later divides by that scale (Line 961). If aMax is zero (when all accumulated expert outputs are zero), this will cause division by zero.

For consistency and robustness, apply the same fix as in activationDeepSeekKernel:

🔧 Proposed fix
       if (threadIdx.x == 0) {
         if (params.outDqSfsPtr) {
-          s_scaleOut = aMax / E4m3MaxVal;
+          // Make sure the scale is strictly positive to avoid division by zero in case the
+          // maximum is zero.
+          s_scaleOut = fmaxf(aMax / E4m3MaxVal, std::numeric_limits<float>::min());
           int const scaleOut_idx = tokenIdx + hiddenIdx / 128 * params.numTokens;
-          params.outDqSfsPtr[scaleOut_idx] = aMax / E4m3MaxVal;
+          params.outDqSfsPtr[scaleOut_idx] = s_scaleOut;
         } else {
           s_scaleOut = 1.0f;
         }
       }
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)

141-150: Avoid unused-parameter warnings in getNumSmemBitsPerElt (mmaK).
mmaK is currently unused; if you build with -Wextra/-Werror, this can fail compilation.

Proposed fix
 inline int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind, int mmaK) {
+  (void) mmaK;
   if (mmaKind == tg::MmaKind::Auto) {
     throw std::runtime_error("mmaKind != tg::MmaKind::Auto");
   }
   if (mmaKind == tg::MmaKind::MxFp8Fp6Fp4) {
     return 8;
   } else {
     return tg::dtypeGetNumBits(dtype);
   }
 }
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (1)

330-381: Input buffer size mismatch in uniform-tokens-per-batch broadcast case.

When params.batchStrideInCtas == 0, the TMA descriptor semantics (per KernelParamsDecl.h) indicate the input is broadcast across batches: logical shape is [M, K] or [N, K] (single batch). However, totalNumPaddedTokens is correctly sized to one batch while ctaOffset is unconditionally set to maxNumCtas, and subsequent TMA descriptor shapes are computed using ctaOffset * mTile[M|N]. This sizes the descriptors for maxNumCtas * tile tokens instead of the single batch size, causing a mismatch between the declared buffer size (totalNumPaddedTokens) and the descriptor's logical access range.

🤖 Fix all issues with AI agents
In @tests/moe/utils.py:
- Around line 55-57: Update the misleading comment above the skip logic to
reflect what the code actually does: it skips zero-input tests when
zero_hidden_states is true and the implementation is NOT FP8 Block Scale MoE.
Edit the comment that currently reads "Skip checking zero input for FP8 Block
Scale MoE" to something like "Skip zero-input tests for non-FP8 Block Scale MoE
implementations" near the check using zero_hidden_states and
is_fp8_block_scale_moe so the comment matches the pytest.skip behavior.
🧹 Nitpick comments (6)
csrc/trtllm_fused_moe_dev_kernel.cu (1)

17-30: Consider adding explicit include for std::numeric_limits.

Line 305 uses std::numeric_limits<float>::min(), but there's no explicit include for <limits>. While it may be available transitively through other headers, adding an explicit include improves code clarity and portability:

 #include <algorithm>
 #include <cub/cub.cuh>
 #include <cuda/functional>
 #include <cuda/std/functional>
 #include <cuda/std/type_traits>
+#include <limits>

Alternatively, you could use FLT_MIN from <cfloat> which is more commonly used in CUDA code.

tests/moe/test_trtllm_gen_fused_moe.py (1)

2477-2632: Consider extending zero hidden states coverage to other routing tests.

The zero_hidden_states parametrization is currently only added to test_renormalize_routing. Based on the PR objectives (fixing division-by-zero in quantization for Mistral Large 3), the kernel-level fix should protect against zero inputs regardless of the routing method.

Should test_deepseekv3_routing, test_topk_routing, and test_llama4_routing also include the zero_hidden_states parametrization for FP8BlockScaleMoe to ensure comprehensive coverage?

If this is intentional scope limitation for this PR, the current implementation is acceptable. Otherwise, consider adding the parametrization to other routing tests as well.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (3)

169-173: Clarify invariants for mNumEpilogueWarps usage (multiple-of-4, non-negative).
You scale extraGmemCMultiplier by numEpilogueWarps / 4 (integer division). If numEpilogueWarps is not a multiple of 4, this silently rounds down. Consider asserting the expected granularity (or documenting it) at the boundary where options are validated.

Also applies to: 279-281


205-245: SMEM sizing updates look consistent, but consider guarding for negative/overflow sizing.
The new sizing expressions now depend on getNumSmemBitsPerElt(..., mmaK); that’s fine, but this path is a common source of accidental negative/overflow when tiles/stages are large. If there’s a central “options validation” phase, it’d be good to ensure these products fit in int32_t before storing them into the allocator vectors.


472-532: SfA/SfB TMEM sizing: make constants const and consider naming consistency.
kGroupSize is constant but non-const, and the per-stage variables are duplicated for A/B. Minor readability win: mark constants const and consider a tiny helper to compute numColsPerStage to avoid drift between A/B.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (1)

502-625: run() module-cache path needs error handling + thread-safety clarification.
A few concrete risks in the new inline run():

  • You ignore return codes from cuCtxGetCurrent, cuCtxGetId, cuModuleGetFunction, cuModuleUnload (failures can turn into confusing launch errors).
  • ModuleCache is a plain unordered_map; if callers share it across threads, inserts/reads race.
  • Keying uses cuCtxGetId + function name; please confirm cuCtxGetId is available for your supported CUDA versions, or fall back to keying by the CUcontext pointer value.
Possible tightening (sketch)
-      cuCtxGetCurrent(&ctx);
-      cuCtxGetId(ctx, &ctxId);
+      CUresult st = cuCtxGetCurrent(&ctx);
+      if (st != CUDA_SUCCESS) return st;
+      st = cuCtxGetId(ctx, &ctxId);
+      if (st != CUDA_SUCCESS) return st;
...
-        cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
+        st = cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
+        if (st != CUDA_SUCCESS) return st;
...
-      cuModuleUnload(cuModule);
+      st = cuModuleUnload(cuModule);
+      if (st != CUDA_SUCCESS) return st;
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bd2b033 and 70aecb7.

⛔ Files ignored due to path filters (1)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h is excluded by !**/gen/**
📒 Files selected for processing (13)
  • csrc/trtllm_fused_moe_dev_kernel.cu
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/utils.py
🧰 Additional context used
📓 Path-based instructions (3)
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/trtllm_fused_moe_dev_kernel.cu
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/moe/test_trtllm_gen_fused_moe.py
  • tests/moe/utils.py
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/artifacts.py
🧠 Learnings (2)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Applied to files:

  • tests/moe/utils.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon

Applied to files:

  • tests/moe/utils.py
🧬 Code graph analysis (5)
tests/moe/utils.py (1)
flashinfer/utils.py (1)
  • get_compute_capability (258-261)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h (1)
  • Dtype (43-274)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (3)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h (1)
  • RouteImpl (28-57)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)
  • gemm (148-153)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (1)
  • divUpMul (576-578)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
  • std (224-625)
  • std (234-239)
  • std (274-295)
  • std (285-291)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h (3)
  • trtllm (28-90)
  • gen (29-89)
  • launchKernel (34-84)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (2)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h (1)
  • gemm (30-419)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (1)
  • gemm (30-294)
🔇 Additional comments (21)
flashinfer/artifacts.py (2)

92-92: LGTM: Artifact path updated for fixed quantization.

The TRTLLM_GEN_BMM artifact path has been updated to reference the new cubin build that includes the quantization fix for Mistral Large 3. The format is consistent with other artifact paths.


113-113: LGTM: Checksum updated consistently with artifact path.

The checksum has been correctly updated to match the new TRTLLM_GEN_BMM artifact. The SHA256 format is valid (64 hex characters), and this value will be verified during cubin download by the verify_cubin() function.

csrc/trtllm_fused_moe_dev_kernel.cu (1)

302-309: LGTM! Division-by-zero protection correctly implemented.

The fix properly guards against division by zero when quantizing zero activations by clamping the scale to std::numeric_limits<float>::min(). This ensures Line 326's division by scaleOut never encounters a zero denominator. The comment clearly documents the intent.

tests/moe/test_trtllm_gen_fused_moe.py (2)

2290-2302: LGTM!

The addition of the zero_hidden_states parameter to run_moe_test and its propagation to skip_checks is implemented correctly. The default value of False maintains backward compatibility.


2345-2348: LGTM!

The conditional initialization of hidden states using either torch.zeros or torch.randn is implemented cleanly and correctly handles both zero and random input scenarios.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h (1)

188-189: LGTM! Formatting refinement.

The dumpOptions formatting has been streamlined to emit the mActType field inline rather than across multiple stream insertions, improving consistency with similar changes in GemmOptions.h.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h (1)

93-98: LGTM! GELU activation correctly added.

The GELU (Gaussian Error Linear Unit) activation is properly defined with accurate mathematical documentation. The phi-function approximation using tanh matches the standard fast GELU implementation.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h (2)

210-215: LGTM! Pre-activation scaling support added.

The new ptrScaleAct field is well-documented and correctly scoped for non-linear activations (GELU, Relu2) when input scaling is needed. The shape specification [B] is clear.


466-487: LGTM! Uniform batching parameters added.

The new fields totalNumOutputPaddedTokens, ctasInTokenDimPerBatch, and batchStrideInCtas are clearly documented and appropriately scoped for the uniform tokens-per-batch feature.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h (2)

19-19: LGTM! Required include added.

The <algorithm> header is correctly added to support std::all_of usage in the uniform batching validation logic (lines 385-390).


378-429: Comprehensive validation for uniform tokens-per-batch feature.

The validation logic is thorough and correctly enforces constraints:

  • Verifies uniformity using std::all_of
  • Validates batch stride alignment (either 0 or divUpMul(firstValue, tileTokensDim))
  • Checks incompatibilities with DeepSeek FP8, per-token SF, bias, routing, fused activation, and block formats
  • Provides clear error messages

The batch stride validation at lines 402-407 ensures proper tile-aligned padding, which is necessary for efficient batched operations.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h (2)

754-768: LGTM! MmaK default handling refined.

The Fp8Fp6Fp4 MmaKind now has explicit default mmaK=32 handling in its own conditional block, improving code clarity by separating it from the MxFp4NvFp4 path.


1481-1486: LGTM! Activation scaling logic correctly implemented.

The getDoesScaleAct function correctly determines when separate pre-activation scaling is needed:

  • Returns true only for non-linear activations (Gelu, Relu2)
  • When input scaling is required (via getDoesScaleAb)
  • Returns false for linear activations (None) since scaling can be fused with output scaling

This aligns with the new ptrScaleAct parameter added in KernelParamsDecl.h.

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h (3)

19-23: Good: <functional> include matches std::reference_wrapper usage.
This should fix the GCC 11 compilation issue described.


96-124: Doc update: batchStrideInTokens==0 semantics are clearer now.
No issues—this helps disambiguate the “broadcast vs. per-batch packed” cases for A/B.

Also applies to: 167-196


276-282: mPtrScaleAct addition is fine, but please verify downstream invariants with scaleC.
The comment says: when mPtrScaleAct is used, scaleC should be “quantScaleC only”. It’s worth ensuring validation rejects inconsistent combinations (e.g., activation enabled but mPtrScaleAct==nullptr when required, or scaleC already pre-multiplied).

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h (2)

40-44: doPad-driven padding is a cleaner API than mmaKind-driven branching.
This makes call sites explicit and avoids “hidden coupling” to MMA kind.

Also applies to: 56-64


71-91: Double-check the intended interaction between padMultiplier and swizzle/box sizing for MxInt4.
The padMultiplier feeds fastestDimTileSizeBytes and numEltsPerUInt32, which is consistent if padding is strictly a SMEM layout concern for 4-bit packing. Just verify that kernels consuming this descriptor expect these exact semantics for the padded 16B path.

Also applies to: 118-124

include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (3)

299-323: ptrScaleAct plumbing is straightforward; keep validation in sync.
Assignment to params.ptrScaleAct is correct; just ensure any option-validation logic enforces the expected “scaleC vs scaleAct” contract when non-linear activation is enabled.


401-408: doPadA/doPadB + updated buildNdTmaDescriptor call sites look consistent.
Nice to see the padding decision localized and explicitly passed into the descriptor builder.

Also applies to: 421-441, 537-554


487-491: Good: non-padded descriptors for SF/C are explicit (doPad=false).
This reduces ambiguity compared to the prior mmaKind-based selection.

Also applies to: 517-520, 600-603

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
@yzh119 yzh119 requested a review from nekorobov January 9, 2026 18:34
@yzh119 yzh119 merged commit 36380e2 into flashinfer-ai:main Jan 14, 2026
6 checks passed
@dbari dbari deleted the dbariamis/fix-mistral-large-3 branch January 15, 2026 08:28
