
[feat] trtllm-gen mxfp8 gemm#2653

Merged
aleozlx merged 38 commits into flashinfer-ai:main from IwakuraRein:mxfp8-gemm
Mar 21, 2026

Conversation

@IwakuraRein (Collaborator) commented Feb 28, 2026

📌 Description

  • Create flashinfer/tllm_enums.py to centralize trtllm-gen related enums.
  • Add a trtllm backend to mm_mxfp8.
    • API change: use_8x4_sf_layout is added as the last argument.
  • Refactor get_trtllm_gemm_module().
  • Refactor fp8Quantize.cpp so it supports either the 128x4 or the 8x4 swizzle layout. This is needed because the first matrix of a trtllm-gen mxfp8 GEMM can use the 8x4 swizzle layout.
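The 128x4 vs 8x4 distinction can be pictured as two tilings of the scale-factor tensor. The sketch below is an illustrative model only — the function name and the row-major order inside each tile are assumptions; the actual CUDA layouts interleave rows within a tile in a finer pattern that this sketch omits:

```python
def swizzled_sf_index(row: int, col: int, n_rows: int, n_cols: int,
                      row_tile: int = 128, col_tile: int = 4) -> int:
    """Hypothetical sketch (not FlashInfer's kernel code): linear offset of a
    scale factor stored tile-by-tile. The PR's 128x4 and 8x4 layouts map to
    row_tile=128 and row_tile=8, both with col_tile=4."""
    assert n_rows % row_tile == 0 and n_cols % col_tile == 0
    tiles_per_row = n_cols // col_tile                       # tiles along columns
    tile_idx = (row // row_tile) * tiles_per_row + (col // col_tile)
    inner = (row % row_tile) * col_tile + (col % col_tile)   # row-major in tile
    return tile_idx * (row_tile * col_tile) + inner

# Whatever the tile shape, the mapping must be a bijection onto the full range:
for rt in (128, 8):
    seen = {swizzled_sf_index(r, c, 128, 8, row_tile=rt)
            for r in range(128) for c in range(8)}
    assert seen == set(range(128 * 8))
```

The bijection check is the property that matters: both layouts are pure permutations of the linear scale-factor storage, which is why fp8Quantize.cpp can select between them with a single layout argument.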

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

Unit tests: pytest tests/gemm/test_mm_mxfp8.py -k trtllm

Benchmark: python benchmarks/flashinfer_benchmark.py --routine mm_mxfp8 --backend trtllm [--use_128x4_sf_layout]

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • TRTLLM backend integrated into GEMM/FP4/MXFP8 paths and tests; new TRTLLM runners exposed.
  • Improvements

    • Centralized enums for layouts/dtypes; explicit SF-layout options (128x4, 8x4, linear).
    • Per-backend MXFP8 input/layout handling, swizzling, quantization, sparsity, and per-token pre-activation scaling enhanced.
    • More detailed tactic selection and runtime reporting.
  • Breaking Changes

    • Public APIs/signatures updated to accept explicit layout, dtype, sparsity, and new SF-layout parameters—update callers.

@coderabbitai (bot, Contributor) commented Feb 28, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a TRTLLM backend and runners for quantized GEMM (MXFP8/FP4/FP8), centralizes TRTLLM enums into flashinfer.tllm_enums, replaces boolean SF-layout flags with a numeric/enum sf_layout across the C++ and Python APIs, and extends the GEMM headers/interfaces with sparsity, scheduler, and TMA padding support.

Changes

Cohort / File(s) — Summary

  • Enums & central exports — flashinfer/tllm_enums.py, flashinfer/__init__.py, flashinfer/fused_moe/core.py, flashinfer/fp4_quantization.py — New tllm_enums module introduced and used across modules; in-file enum definitions removed and public exports consolidated to tllm_enums.
  • TRTLLM GEMM runner & Python glue — csrc/trtllm_gemm_runner.cu, flashinfer/gemm/gemm_base.py, flashinfer/artifacts.py — Runner APIs changed to accept explicit input/output dtypes and layoutA; gemm_base exposes TRTLLM runners/dispatch; artifact path for TRTLLM GEMM updated.
  • FP8 quantize C++ API signatures — csrc/.../thop/fp8Quantize.cpp, csrc/.../thop/fp8Quantize.h — Replaced boolean is_sf_swizzled_layout with numeric sfSwizzleLayout (int64) in public/native signatures and internal layout selection.
  • Python FP8 API & wiring — flashinfer/fp8_quantization.py, flashinfer/fp4_quantization.py — Python APIs now accept/propagate numeric sf_layout / optional sf_swizzle_layout (SfLayout values); validation and kernel-call sites updated; local SfLayout removed in favor of tllm_enums.
  • Public GEMM headers & kernel interface — include/.../GemmInterface.h, GemmOptions.h, KernelParams.h, KernelParamsDecl.h, KernelTraits.h, TmaDescriptor.h, Enums.h — Large public header changes: new enums (EltwiseActType), scheduler values, sparsity support, new ptrScaleAct/tmaSparsityInfoA, updated GemmInterface/GemmOptions/Kernels signatures, doPad in TMA descriptor, memory/grid/workspace helpers.
  • Kernel params, sparsity & TMA changes — include/.../KernelParams.h, KernelParamsDecl.h, KernelTraits.h, TmaDescriptor.h — Sparsity-aware shape/stride generation, added sparsity-info descriptors and offsets, updated descriptor construction (doPad) and SF element sizing.
  • Benchmarks & tests — benchmarks/routines/gemm.py, benchmarks/routines/flashinfer_benchmark_utils.py, tests/gemm/test_mm_mxfp8.py — Added trtllm backend in benchmarks/tests; per-backend input preparation (reshape/shuffle/quantize) and propagation; tests updated for SfLayout variants and cosine-similarity thresholds.
  • Small packaging/artifact changes — flashinfer/__init__.py, flashinfer/artifacts.py — Re-export enums from tllm_enums; artifact constant for TRTLLM GEMM updated (hash/path).
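The boolean-to-enum migration summarized above can be modeled in a few lines. The member names and integer values here are illustrative stand-ins, not FlashInfer's actual definitions:

```python
from enum import IntEnum

class SfLayout(IntEnum):
    """Illustrative stand-in for the centralized SF-layout enum."""
    LAYOUT_128x4 = 0    # swizzled, 128-row x 4-col tiles
    LAYOUT_8x4 = 1      # swizzled, 8-row x 4-col tiles
    LAYOUT_LINEAR = 2   # plain row-major

def sf_layout_from_legacy(is_sf_swizzled_layout: bool) -> SfLayout:
    # Old callers passed a bool; mapping it onto the richer enum lets the
    # 8x4 variant be expressed without adding yet another flag.
    return SfLayout.LAYOUT_128x4 if is_sf_swizzled_layout else SfLayout.LAYOUT_LINEAR

assert sf_layout_from_legacy(True) is SfLayout.LAYOUT_128x4
assert int(SfLayout.LAYOUT_8x4) == 1  # IntEnum crosses the C++ boundary as a plain int64
```

Using IntEnum (rather than Enum) is what makes the value directly passable into the C++ signatures that now take a numeric sfSwizzleLayout.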

Sequence Diagram

sequenceDiagram
    participant Client as Python Client
    participant API as flashinfer API
    participant Dispatcher as Backend Dispatcher
    participant TrtRunner as TRTLLM Runner
    participant KernelMod as TRTLLM Kernel Module
    participant GPU as CUDA Device

    Client->>API: call mm_mxfp8(..., backend="trtllm", sf_layout=...)
    API->>Dispatcher: select backend and prepare per-backend inputs
    Dispatcher->>TrtRunner: pass inputs, sf_layout, dtypes, options
    TrtRunner->>KernelMod: request tactics/configs (metadata)
    KernelMod->>TrtRunner: return tactics/configs
    loop try tactics
      TrtRunner->>GPU: load/instantiate kernel with chosen tactic
    end
    TrtRunner->>GPU: execute trtllm_gemm with inputs, scales, ptrScaleAct, sparsity info
    GPU->>Client: return output
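The "try tactics" loop in the diagram can be sketched generically. All names below (Tactic, is_valid, select_tactic) are hypothetical, not the runner's real API:

```python
from dataclasses import dataclass
from typing import Callable, Iterable, Optional

@dataclass
class Tactic:
    """Hypothetical stand-in for one kernel config/tactic."""
    name: str
    tile_m: int
    tile_n: int

def select_tactic(tactics: Iterable[Tactic],
                  is_valid: Callable[[Tactic], bool]) -> Optional[Tactic]:
    # Walk the kernel module's candidate configs in priority order and
    # return the first one that validates against the problem shape.
    for t in tactics:
        if is_valid(t):
            return t
    return None  # caller raises a descriptive "no valid tactic" error

candidates = [Tactic("big", 128, 256), Tactic("small", 64, 64)]
picked = select_tactic(candidates, lambda t: t.tile_m <= 64)
assert picked is not None and picked.name == "small"
```

The PR's improved error messages fire on the None branch: instead of a bare failure, the runner reports which dtypes/shapes ruled out every candidate.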

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

op: gemm, run-ci

Suggested reviewers

  • aleozlx
  • Anerudhan
  • cyx-6
  • djmmoss
  • yongwww
  • bkryu
  • jimmyzho
  • jiahanc
  • nv-yunzheq
  • yzh119

Poem

🐰 I hopped through enums, tidy and bright,

swapped booleans for numbers, set layouts right.
TRTLLM shuffled scales and learned to hide gaps,
sparsity and tactics drew new little maps.
GEMM kernels hum — carrots for the night!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning — Docstring coverage is 50.78%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed — The title '[feat] trtllm-gen mxfp8 gemm' directly describes the main feature addition of TRTLLM backend support for MXFP8 GEMM operations, which aligns with the primary change across the codebase.
  • Description check — ✅ Passed — The description covers the key changes (new tllm_enums.py, trtllm backend for mm_mxfp8 with the use_8x4_sf_layout parameter, refactoring of get_trtllm_gemm_module and fp8Quantize.cpp), provides test commands, and confirms completion of the pre-commit and testing checklists, though the related-issues section is left empty.


@gemini-code-assist (bot, Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new feature that enables MxFp8 GEMM operations, significantly expanding the mixed-precision capabilities of the system. The changes involve deep modifications across the C++ and Python components, focusing on improving the flexibility and robustness of GEMM configurations. Key updates include explicit handling of input/output data types, support for valid problem dimensions, and the introduction of sparsity and advanced CUDA architecture-specific optimizations. The refactoring also centralizes enum definitions for better code maintainability.

Highlights

  • MxFp8 GEMM Support: Introduced comprehensive support for MxFp8 General Matrix Multiplication (GEMM) within the TensorRT-LLM generation framework, enhancing mixed-precision computation capabilities.
  • Refactored GEMM Runner Logic: The core trtllm_gemm_runner.cu file has been significantly refactored to explicitly handle various input and output data types, and to incorporate mValidM/N/K dimensions for more precise problem definition.
  • Centralized Enum Definitions: All TensorRT-LLM related IntEnum definitions and helper functions have been moved to a new flashinfer/tllm_enums.py file, improving code organization and reusability across the Python codebase.
  • Enhanced GEMM Options and Interface: The GemmOptions.h and GemmInterface.h headers received numerous updates, including support for EltwiseActType, Sparsity, flexible cluster dimensions, and refined scaling factor block size handling, enabling more advanced GEMM configurations.
  • Sparsity and CUDA Architecture Declarations: New header files CudaArchDecl.h and SparsityDecl.h were added to define CUDA architecture versions and structured sparsity modes, respectively, providing clearer abstractions for these concepts.


Changelog
  • csrc/trtllm_gemm_runner.cu
    • Improved error messages for tactic selection failures, providing more detailed context.
    • Added mValidM, mValidN, and mValidK fields to gemmData.mProblemDimensions for accurate problem dimension tracking.
    • Modified selectHeuristic to generalize tactic selection based on eltType instead of specific Dtype enums.
    • Updated trtllm_gemm and trtllm_gemm_tactics function signatures to accept explicit input_dtype and output_dtype parameters.
    • Removed the local Dtype enum definition, now using the centralized gemm::trtllm::gen::Dtype.
  • flashinfer/fused_moe/core.py
    • Removed local IntEnum definitions (RoutingMethodType, ActivationType, DtypeTrtllmGen, WeightLayout, GatedActType, Fp8QuantizationType) and their associated helper functions.
    • Imported all relevant enum definitions from the new flashinfer.tllm_enums module.
  • flashinfer/gemm/gemm_base.py
    • Imported DtypeTrtllmGen from flashinfer.tllm_enums.
    • Removed the standalone get_trtllm_gemm_module function, integrating its functionality into a new TrtllmGemmRunner class.
    • Added use_8x4_sf_layout parameter to _check_mm_mxfp8_problem_size and _cutlass_gemm_mxfp8_requirement.
    • Introduced _trtllm_gemm_mxfp8_requirement to define requirements for the TRT-LLM backend for MxFp8 GEMM.
    • Updated _heuristic_func_mm_mxfp8 to consider the trtllm backend.
    • Modified mm_mxfp8 function signature to include use_8x4_sf_layout and support trtllm as a backend option.
    • Updated the backend_to_runner_factory dictionary to include the trtllm backend for MxFp8 GEMM.
    • Modified mm_fp4 to utilize the refactored get_trtllm_gemm_module structure.
    • Updated gemm_fp8_nt_groupwise to pass input_dtype and output_dtype to the underlying trtllm_gemm call.
    • Refactored get_trtllm_gemm_module to return a SimpleNamespace containing factories for generic, FP4, and MxFp8 TRT-LLM GEMM runners.
  • flashinfer/tllm_enums.py
    • Added a new Python file to centralize IntEnum definitions for RoutingMethodType, ActivationType, DtypeTrtllmGen, WeightLayout, GatedActType, and Fp8QuantizationType.
    • Included helper functions trtllm_gen_dtype_has_scale and deduce_trtllm_gen_tensor_dtype in this new module.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h
    • Updated the copyright year to 2026.
    • Added EltwiseActType enum to specify element-wise activation functions.
    • Expanded TileScheduler enum with StaticPersistent and PersistentSm90 options.
    • Introduced isPersistentScheduler and supportsCleanEarlyExit helper functions.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h
    • Updated the copyright year to 2026.
    • Removed static assertions and namespace related to TLLM_GEN_GEMM_CUBIN_PATH.
    • Added mValidM, mValidN, and mValidK to GemmData::mProblemDimensions to specify valid data ranges.
    • Updated documentation for GemmData::InputBuffers::mPtrA and mPtrSfA to clarify sparsity and scaling block size implications.
    • Added mPtrSparsityInfoA to GemmData::InputBuffers for structured sparsity metadata.
    • Updated documentation for GemmData::InputBuffers::mPtrSfB regarding scaling block sizes.
    • Added mPtrScaleAct to GemmData::InputBuffers for pre-activation scaling factors.
    • Updated documentation for GemmData::OutputBuffers::mPtrSfC regarding scaling block sizes.
    • Refactored GemmInterface constructor to accept rankId, exportsCubin, and numRotations.
    • Inlined implementations for getWorkspaceSizeInBytes, getGemmConfigs, getNumGemmConfigs, getOptionsFromConfigAndData, and isValidConfig.
    • Modified run method to support TLLM_GEN_EXPORT_INTERFACE and TLLM_GEN_EXPORT_FLASHINFER preprocessor directives, and flexible grid sizing.
    • Added getFixedGridSize helper function for fixed grid dimension calculations.
    • Removed alignPtr from the GemmInterface class.
    • Updated runInitBeforeWorldSync to correctly use getGridSize.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h
    • Updated the copyright year to 2026.
    • Included CudaArchDecl.h and SparsityDecl.h for new enum definitions.
    • Defined GEMM_UPDATE_OR_ERROR macro for conditional option updates or error returns.
    • Updated GemmOptions constructor and members to include mClcFastDrain, mEltwiseActType, mFallbackClusterDimX/Y/Z, mFuseUtccpWithUtcmma, mNumEpilogueWarps, mNumRegsCopySparsityInfo, mSfBlockSizeB/C, mSparsityA, mUseFlexibleClusterDims, mUseMaxTmemOverlap, and mValidM/N/K.
    • Changed mSfBlockSizeA from an std::optional<int32_t> to a direct int.
    • Renamed mUseShuffledMatrixA to mUseShuffledMatrix for broader applicability.
    • Updated SmVersion to use tg::CudaArch for consistency.
    • Modified dumpOptions to reflect new members and conditionally dump runtime parameters.
    • Added srcToDstBlk16RowMap, srcToDstBlk32RowMap vectors, and getShuffleIndices function for matrix shuffling.
    • Extensively updated checkAndUpdateGemmOptions to use tg::CudaArch, handle mValidM/N/K initialization and validation, add checks for BlockMajorK layout, support MxInt4 in A casts, include sparsity checks, refine mmaK logic for sparse matrices, and update scaling factor block size logic.
    • Added TLLM_PUBLIC_RELEASE specific checks for E2m1 to E4m3 cast and mPatchF2fp.
    • Updated mSfLayoutB supported layouts and mDtypeC block format checks.
    • Refined mUseShuffledMatrix checks and added new validations for mClusterDimX, mClusterDimY, mUseFlexibleClusterDims, mUseMaxTmemOverlap, and mNumEpilogueWarps.
    • Updated KernelTraits initialization with the new set of parameters.
    • Introduced helper functions getDoesScaleC, getDoesScaleAb, getDoesScaleAct, getKernelDoesScaleC for scaling factor logic.
    • Added loadCubinData template function to manage cubin loading based on TLLM_GEN_EXPORT_FLASHINFER and TLLM_GEN_GEMM_CUBIN_PATH.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h
    • Updated the copyright year to 2026.
    • Included <tuple> and SparsityDecl.h.
    • Modified makeTmaShapeStrideAb to account for isSparseA, validM/N/K dimensions, and doPad for TMA descriptor building.
    • Added makeTmaShapeStrideSparsityInfoA to create TMA shape/stride for sparsity information.
    • Updated makeTmaShapeStrideSfAb to accept numEltsPerSf as a parameter.
    • Extended setKernelParams function signature to include ptrSparsityInfoA and ptrScaleAct.
    • Updated setKernelParams logic to handle isSparseA, doPadA/B, dTypeSfA/B for MxInt4, and to build tmaSparsityInfoA.
    • Modified buildNdTmaDescriptor calls for tmaA, tmaB, and tmaC to include the doPad parameter.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParamsDecl.h
    • Updated the copyright year to 2026.
    • Included <cuda.h>.
    • Updated comment for tmaSfB to include MxInt4 format.
    • Added tmaSparsityInfoA member for structured sparsity information.
    • Added ptrScaleAct member for pre-activation scaling factors.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h
    • Updated the copyright year to 2026.
    • Included <cstdio> and SparsityDecl.h.
    • Modified getNumSmemBitsPerElt to accept mmaK and isSparseA parameters.
    • Updated KernelTraits constructor signature to include sparsityA, numEltsPerSfA/B, fuseUtccpWithUtcmma, useMaxTmemOverlap, and numEpilogueWarps.
    • Adjusted KernelTraits constructor logic to handle isSparseA for shared memory allocation of LoadA/B.
    • Added numEpilogueWarps multiplier for extraGmemCMultiplier calculation.
    • Added a new shared memory chunk for smemSparsityInfoA.
    • Split smemPerTokenSf into separate smemPerTokenSfA and smemPerTokenSfB chunks.
    • Removed a TODO comment related to smemBlockAmax.
    • Updated the print output for shared memory chunks to reflect smemPerTokenSfA/B.
    • Modified numTmemColsD calculation to incorporate mUseMaxTmemOverlap.
    • Updated numTmemColsSfA/B calculations to use kGroupSize, numEltsPerSfA/B, and mFuseUtccpWithUtcmma.
    • Added a new TMEM chunk for tmemSparsityInfoA.
    • Added mFuseUtccpWithUtcmma, mUseMaxTmemOverlap, and mNumEpilogueWarps as members.
    • Introduced getSmemOffsetPerTokenSfA/B and getSmemOffsetSparsityInfoA functions.
    • Added getTmemOffsetSparsityInfoA function.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h
    • Updated the copyright year to 2026.
    • Included <sstream> for string stream operations.
    • Removed cutlass/cutlass.h and cutlass/half.h includes.
    • Modified buildNdTmaDescriptor signature to remove mmaKind and add a doPad boolean parameter.
    • Updated buildNdTmaDescriptor logic for tmaDataFormat to support UInt8 and MxInt4.
    • Adjusted buildNdTmaDescriptor logic for swizzleType to include E2m1 and UInt8.
    • Updated buildSfTmaDescriptor logic for tmaDataFormat to support Bfloat16.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaArchDecl.h
    • Added a new header file defining the CudaArch enum (Sm90a, Sm100a, Sm100f, Sm103a).
    • Included helper functions isArchHopper, isArchBlackwell, cudaArchToString, and stringToCudaArch.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CudaKernelLauncher.h
    • Updated the copyright year to 2026.
    • Introduced launchKernelFlexibleCgaSizes to support kernel launches with flexible cluster dimensions.
    • Updated launchKernel to explicitly use CU_LAUNCH_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES and CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION attributes.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h
    • Updated the copyright year to 2026.
    • Added MxInt4 to the Dtype enum definition.
    • Updated dtypeToString to include the string representation for MxInt4.
    • Modified dtypeNumEltsPerSf to accept a useSparsity parameter and include MxInt4.
    • Updated dtypeGetBlockSfType to correctly return Bfloat16 for MxInt4.
    • Added dtypeNeedsPadding function to determine if a dtype requires padding based on MMA kind and sparsity.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/MmaDecl.h
    • Updated the copyright year to 2026.
    • Modified getTmemColStridePerGroup to accept kGroupSize and use mmaMn for more accurate stride calculation.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.h
    • Updated the copyright year to 2026.
    • Removed a specific NVIDIA bug reference from the R128c4 layout description.
    • Added sfLayoutCanUseUtccp function to check if a scaling factor layout can use UTCCP.
  • include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SparsityDecl.h
    • Added a new header file defining the Sparsity enum (Dense, Any_1_2, Any_2_4, Pairwise_4_8) for structured sparsity modes.
    • Included helper functions isSparse, sparsityToString, getSparsityChunkSize, and getNumBytesSparsityInfo.
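The new Sparsity enum's Any_2_4 mode keeps at most 2 nonzeros in every group of 4 elements. A minimal pure-Python model of the compression (the function name is assumed; the real kernels pack the kept positions as 2-bit indices into the sparsity-info metadata tensor):

```python
def compress_2_4(group):
    """Compress one group of 4 values under 2:4 structured sparsity.

    Returns (values, indices): the (up to) 2 kept nonzeros and their
    positions within the group, padded with zeros to a fixed size of 2
    so storage is uniform.
    """
    assert len(group) == 4
    nonzero = [(i, v) for i, v in enumerate(group) if v != 0]
    assert len(nonzero) <= 2, "group violates 2:4 sparsity"
    # Pad to exactly 2 entries so the compressed storage is fixed-size.
    while len(nonzero) < 2:
        nonzero.append((len(nonzero), 0))
    indices = [i for i, _ in nonzero]
    values = [v for _, v in nonzero]
    return values, indices

vals, idx = compress_2_4([0, 3, 0, 7])
assert vals == [3, 7] and idx == [1, 3]
```

This halves the dense payload for A while adding a small metadata stream — which is what the new mPtrSparsityInfoA buffer and tmaSparsityInfoA descriptor carry to the kernel.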

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces support for mxfp8 GEMM using trtllm-gen kernels, which involves a substantial refactoring of both Python and C++ code to create a more generalized GEMM infrastructure. The Python code has been improved by centralizing enums and creating a generic GEMM runner factory. The C++ side sees extensive updates to support new hardware features like sparsity and flexible cluster dimensions. The changes are well-structured and enhance the project's capabilities. I have one suggestion to improve code clarity and remove a redundancy in the newly added enum definitions.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
…antize

@IwakuraRein changed the title from "[feat[ trtllm-gen mxfp8 gemm" to "[feat] trtllm-gen mxfp8 gemm" on Mar 9, 2026
@IwakuraRein marked this pull request as ready for review on March 9, 2026 21:19
@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !394 has been updated with latest changes, and the CI pipeline #46296482 is currently running. I'll report back once the pipeline job completes.

@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !394 has been updated with latest changes, and the CI pipeline #46298268 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #46298268: 10/20 passed

@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !394 has been updated with latest changes, and the CI pipeline #46494877 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[FAILED] Pipeline #46494877: 6/20 passed

@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !394 has been updated with latest changes, and the CI pipeline #46541481 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator)

[CANCELING] Pipeline #46541481: canceled

@IwakuraRein (Collaborator, Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !394 has been updated with latest changes, and the CI pipeline #46551653 is currently running. I'll report back once the pipeline job completes.

@aleozlx enabled auto-merge (squash) on March 20, 2026 00:32
@flashinfer-bot (Collaborator)

[SUCCESS] Pipeline #46551653: 14/20 passed

@aleozlx merged commit c938604 into flashinfer-ai:main on Mar 21, 2026
29 checks passed