
Implement override shape support for cuDNN GEMM operations#2790

Merged
yanqinz2 merged 12 commits into main from yanqinz/dynamic-shape
Mar 19, 2026

Conversation

@yanqinz2
Collaborator

@yanqinz2 yanqinz2 commented Mar 15, 2026

📌 Description

Add override shape support for the cuDNN backend, with test examples.

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Summary by CodeRabbit

  • New Features

    • cuDNN override-shape support for GEMM so a single cached graph can handle multiple M dimensions at runtime without rebuilds.
    • Extended support to BF16, FP4 (NVFP4), MXFP8 and per-tensor FP8 data types.
    • Added a runtime availability check to safely gate override-shape usage on unsupported backends.
  • Tests

    • Added end-to-end tests validating override-shape GEMM across supported data types and dynamic M scenarios.
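The pattern summarized above — compile one graph for a cached M and re-execute it for different runtime M values — can be sketched in plain Python. The names build_graph/execute_graph below are illustrative stand-ins, not the PR's actual API:

```python
import functools

# Hypothetical stand-in for an expensive cuDNN graph compilation.
# The cache key deliberately excludes M, so one compiled "graph"
# serves every runtime M.
@functools.cache
def build_graph(n, k, dtype):
    return {"n": n, "k": k, "dtype": dtype}

def execute_graph(graph, m):
    # At execute time the runtime M would be supplied via
    # override_shapes/override_strides; no recompilation happens here.
    return (m, graph["n"])  # shape of the hypothetical output

graph = build_graph(4096, 4096, "bf16")
for m in (1, 8, 64, 512):  # dynamic M values
    assert execute_graph(graph, m) == (m, 4096)

assert build_graph.cache_info().misses == 1  # compiled exactly once
```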

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the flexibility of GEMM operations by integrating cuDNN's override shape feature. This allows the system to execute pre-compiled GEMM graphs with dynamically changing M dimensions, which is crucial for performance in scenarios with variable input sizes, such as large language model inference. The changes include new functions for building and executing these dynamic-shape-enabled graphs for various data types, along with robust version checks and comprehensive tests.

Highlights

  • Dynamic Shape Support for cuDNN GEMM: Introduced the capability to use cuDNN's override shape feature for various GEMM operations (BF16, FP4, MXFP8, and FP8 per-tensor quantized), allowing for dynamic M dimensions without graph recompilation.
  • cuDNN Version Compatibility Check: Added a mechanism to check for the minimum required cuDNN backend version (9.21.0) to ensure the override shape functionality is supported.
  • New Test Suite for Dynamic Shapes: Included a dedicated test file to validate the correctness and efficiency of the cuDNN override shape GEMM implementations across different dynamic M dimensions for BF16, NVFP4, and MXFP8 data types.
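The version gate in the second highlight can be sketched as a pair of soft/hard checks. The 9.21.0 threshold comes from the PR; the integer version encoding shown here is an assumption about how the backend reports its version, used only to make the sketch self-contained:

```python
# Minimal sketch of the availability gate. The PR requires cuDNN backend
# >= 9.21.0; we assume an integer encoding major*10000 + minor*100 + patch.
CUDNN_MIN_VERSION_OVERRIDE_SHAPE = 9 * 10000 + 21 * 100 + 0  # 9.21.0

def is_override_shape_available(backend_version, cudnn_available=True):
    """Soft check: returns False instead of raising on unsupported setups."""
    if not cudnn_available:
        return False
    return backend_version >= CUDNN_MIN_VERSION_OVERRIDE_SHAPE

def check_override_shape_availability(backend_version):
    """Hard check: raises with an actionable message when unsupported."""
    if backend_version < CUDNN_MIN_VERSION_OVERRIDE_SHAPE:
        raise RuntimeError(
            f"cuDNN backend {backend_version} < required "
            f"{CUDNN_MIN_VERSION_OVERRIDE_SHAPE} for override-shape GEMM"
        )

assert is_override_shape_available(92100)
assert not is_override_shape_available(92000)
assert not is_override_shape_available(92100, cudnn_available=False)
```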


Changelog
  • flashinfer/gemm/__init__.py
    • Imported new functions related to cuDNN override shape support for BF16, FP4, MXFP8, and FP8 per-tensor GEMM operations.
  • flashinfer/gemm/gemm_base.py
    • Added CUDNN_MIN_VERSION_OVERRIDE_SHAPE constant and utility functions _check_cudnn_override_shape_availability and is_cudnn_override_shape_available to manage cuDNN version requirements.
    • Implemented build_cudnn_fp4_gemm_graph_override_shape and execute_cudnn_fp4_gemm_graph_override_shape for FP4 GEMM with dynamic M dimension support.
    • Implemented build_cudnn_mxfp8_gemm_graph_override_shape and execute_cudnn_mxfp8_gemm_graph_override_shape for MXFP8 GEMM with dynamic M dimension support.
    • Implemented build_cudnn_gemm_with_per_tensor_q_graph_override_shape and execute_cudnn_gemm_with_per_tensor_q_graph_override_shape for FP8 per-tensor quantized GEMM with dynamic M dimension support.
    • Implemented build_cudnn_gemm_bf16_graph_override_shape and execute_cudnn_gemm_bf16_graph_override_shape for BF16 GEMM with dynamic M dimension support.
    • Added corresponding internal helper functions like _cudnn_gemm_fp4_override_shape, _cudnn_gemm_mxfp8_override_shape, and _cudnn_gemm_fp8_override_shape.
  • tests/gemm/test_cudnn_override_shape.py
    • Added a new test file to verify the functionality of cuDNN override shape for BF16, NVFP4, and MXFP8 GEMM operations.
    • Included tests that parametrize cache_m and dynamic_ms to ensure a single compiled graph can handle various M dimensions.
    • Added skip conditions for environments without cuDNN, insufficient cuDNN backend version, or non-SM100+ compute capability.
Activity
  • The author, yanqinz2, created this pull request with a clear description of adding override shape support for the cuDNN backend with test examples.
  • Pre-commit checks were marked as completed by the author.
  • Tests were indicated as needing to be added or updated, and all tests passing was marked as pending.

@coderabbitai
Contributor

coderabbitai bot commented Mar 15, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

📝 Walkthrough

Adds cuDNN "override-shape" GEMM support: runtime availability checks, cached graph builders/executors (BF16, FP4, MXFP8, per-tensor FP8) allowing a single compiled graph to execute with varying M via override_shapes/override_strides.

Changes

Cohort / File(s) — Summary

  • Module Exports (flashinfer/gemm/__init__.py): Re-exports the new override-shape symbols from flashinfer.gemm_base and extends __all__ with the availability check plus the paired build_*/execute_*_override_shape APIs.
  • Override-Shape GEMM Implementation (flashinfer/gemm/gemm_base.py): Adds _check_cudnn_override_shape_availability, is_cudnn_override_shape_available, _OVERRIDE_SHAPE_CACHE_M, multiple build_*_override_shape / execute_*_override_shape / _get_*_override_shape implementations for BF16, FP4, MXFP8, and per-tensor FP8, data-type mapping updates, 3D/packed-shape helpers, and wrapper routing that enables dynamic-M execution using override shapes/strides.
  • Tests (tests/gemm/test_cudnn_override_shape.py): New test module that compiles graphs with a cached M and exercises dynamic M values via override_shapes/override_strides for BF16, NVFP4, and MXFP8, with environment and backend-version guards.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test Suite
    participant Lib as flashinfer.gemm (API)
    participant Cache as Graph Cache
    participant cuDNN as cuDNN Runtime

    Test->>Lib: call build_*_override_shape(cache_m, ...)
    Lib->>cuDNN: compile graph (reserve/cache with cache_m)
    Lib->>Cache: store compiled graph keyed by params+cache_m
    Test->>Lib: call execute_*_override_shape(graph, a, b, override_shapes, override_strides)
    Lib->>Cache: fetch compiled graph
    Lib->>cuDNN: execute graph with override_shapes/override_strides (dynamic M)
    cuDNN-->>Lib: return result tensor
    Lib-->>Test: return output (no rebuild)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

op: gemm

Suggested reviewers

  • aleozlx
  • yongwww
  • bkryu
  • jimmyzho
  • nv-yunzheq

Poem

🐰 I built one graph and left it wide,
Many M’s to hop through, no need to hide.
BF16, FP4, MXFP8 in tow,
Override the shape — watch performance grow. ✨

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 53.85%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the title clearly and specifically summarizes the main change: implementing override shape support for cuDNN GEMM operations, which aligns with the substantial additions of override-shape functions and public API exports.
  • Description check — ✅ Passed: the description covers the main objective (override shape support for cuDNN) and confirms tests were added. Pre-commit checks are marked complete. However, it lacks detail on the 'Related Issues' section and does not fully confirm all tests are passing.




Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces dynamic shape support for the cuDNN backend, which is a significant feature for improving performance by avoiding graph recompilations. The implementation adds several new functions for building and executing cuDNN graphs with override shapes for BF16, FP4, and FP8 GEMM operations. The changes are well-structured, but I've identified a few critical issues in the FP4 implementation and the new tests.

Specifically, there appears to be a bug in the shape calculation for FP4 graphs and an issue with workspace buffer handling that could lead to performance degradation. Additionally, the new tests for FP4 and MXFP8 dynamic shapes are missing correctness assertions, which is a crucial gap in validation. My review provides detailed feedback and suggestions to address these points.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/__init__.py`:
- Around line 21-32: The new override-shape symbols imported from gemm_base
(is_cudnn_override_shape_available, CUDNN_MIN_VERSION_OVERRIDE_SHAPE,
build_cudnn_gemm_bf16_graph_override_shape,
execute_cudnn_gemm_bf16_graph_override_shape,
build_cudnn_fp4_gemm_graph_override_shape,
execute_cudnn_fp4_gemm_graph_override_shape,
build_cudnn_mxfp8_gemm_graph_override_shape,
execute_cudnn_mxfp8_gemm_graph_override_shape,
build_cudnn_gemm_with_per_tensor_q_graph_override_shape,
execute_cudnn_gemm_with_per_tensor_q_graph_override_shape) are not listed in the
module's __all__ export list; update the __all__ variable in
flashinfer.gemm.__init__ to include these exact symbol names so they are
exported on from flashinfer.gemm import * and by re-exports.

In `@flashinfer/gemm/gemm_base.py`:
- Around line 2007-2021: The override path recomputes FP4 packed shapes/strides
(e.g., a_shape, a_stride, b_shape, b_stride, a_descale_shape, a_descale_stride,
b_descale_shape, b_descale_stride) instead of reusing the canonical helper,
causing batched layouts to diverge; replace the manual math with a call to the
existing helper _get_real_fp4_shape_from_packed_uint8 (and use its returned
shape/stride tuples) and keep _calculate_block_scale_dims only for block scale
dims, ensuring both this block (around a_shape/a_stride) and the similar block
at the other location (around lines mentioned) use the same helper-derived FP4
metadata so packed-FP4 layout matches the non-override path.
- Around line 1977-1993: The cached builder
build_cudnn_fp4_gemm_graph_override_shape currently includes a_descale_n_dim in
its parameter list, which varies with M and causes unnecessary cache churn;
remove a_descale_n_dim from the cached function signature (and from the
analogous cached helper at the other location) so the cache key no longer
depends on M, update callers (e.g., where
_get_cudnn_fp4_gemm_graph_override_shape passes expanded_a_descale_shape[1]) to
stop passing that component, and if the builder needs a Descales shape value
compute or derive it inside the function from the stable values you already pass
(or ignore it if unused) to preserve a single reusable plan.
- Around line 1727-1746: The current checks only inspect cudnn.backend_version()
but must also gate on the cuDNN-frontend Python package (frontend) supporting
override-shape; update both _check_cudnn_override_shape_availability and
is_cudnn_override_shape_available to also verify the frontend version or feature
presence: e.g., check a frontend version string (cudnn.__version__ or
cudnn.frontend.__version__) >= the minimum frontend release that added
override-shape, and/or perform feature detection with hasattr / inspection
(confirm presence of the called parameter or the newer execute(...) API with
override_uids/override_shapes/override_strides using inspect.signature or a
small try/except TypeError probe). If the frontend is too old or the feature is
missing, return False in is_cudnn_override_shape_available and raise a clear
RuntimeError in _check_cudnn_override_shape_availability; reference the
functions _check_cudnn_override_shape_availability and
is_cudnn_override_shape_available when making the changes.

In `@tests/gemm/test_cudnn_override_shape.py`:
- Around line 194-221: The test currently only runs
execute_cudnn_fp4_gemm_graph_override_shape in the dynamic_ms loop and
synchronizes, so failures like zeros/NaNs or wrong numerics won't be caught;
update the loop that calls execute_cudnn_fp4_gemm_graph_override_shape (and the
similar loop at 292-316) to validate outputs: compute a reference output (e.g.,
by calling the static-shape path such as execute_cudnn_fp4_gemm_graph or another
known-good implementation with the same inputs a_packed, b_packed, a_descale,
b_descale, workspace, tactic) and then assert torch.isfinite(out).all() and that
out is close to the reference using torch.allclose or
torch.testing.assert_allclose with a small atol/rtol appropriate for FP4/FP8; if
you prefer a cheaper check add at minimum finiteness plus a basic relative-error
threshold, and use the exact symbols
execute_cudnn_fp4_gemm_graph_override_shape, execute_cudnn_fp4_gemm_graph,
a_packed, b_packed, a_descale, b_descale, out, workspace, tactic so the
assertions are added next to the existing invocation.
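The correctness assertion the review asks for can be illustrated without torch. This is a hedged sketch of a tolerance check in the spirit of torch.testing.assert_close; the loose rtol/atol values reflect low-precision formats like FP4/FP8 and are illustrative, not the project's actual tolerances:

```python
import math

def assert_close(actual, reference, rtol=1e-2, atol=1e-2):
    """Finiteness plus relative/absolute tolerance check, element-wise."""
    for a, r in zip(actual, reference):
        assert math.isfinite(a), f"non-finite value {a}"
        assert abs(a - r) <= atol + rtol * abs(r), (a, r)

ref = [1.0, -2.5, 3.75]        # e.g. a known-good static-shape result
out = [1.005, -2.49, 3.76]     # e.g. the override-shape result
assert_close(out, ref)          # passes within tolerance
```

A NaN or all-zero output, which the bare-execution test would miss, trips the finiteness or tolerance assertion here.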

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2d681695-bc90-4fa7-9aa3-d593f21b2991

📥 Commits

Reviewing files that changed from the base of the PR and between 4781b42 and 1803c4e.

📒 Files selected for processing (3)
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_cudnn_override_shape.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
flashinfer/gemm/gemm_base.py (3)

2146-2158: ⚠️ Potential issue | 🟠 Major

FP4 override shape/stride should reuse the canonical helper to avoid layout drift.

Manual * 2 shape/stride math here can diverge from _get_real_fp4_shape_from_packed_uint8() behavior (especially across layout/batch cases).

Suggested fix
-    override_shapes = [
-        [a.shape[0], a.shape[1], a.shape[2] * 2],
-        [b.shape[0], b.shape[1] * 2, b.shape[2]],
+    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+    override_shapes = [
+        list(real_a_shape),
+        list(real_b_shape),
         a_descale.shape,
         b_descale.shape,
         c_final.shape,
     ]
     override_strides = [
-        [a.stride()[0], a.stride()[1] * 2, a.stride()[2]],
-        [b.stride()[0], b.stride()[1], b.stride()[2] * 2],
+        list(real_a_stride),
+        list(real_b_stride),
         a_descale.stride(),
         b_descale.stride(),
         c_final.stride(),
     ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2146 - 2158, The override_shapes
and override_strides calculations currently hard-code "* 2" adjustments which
can drift from the canonical logic; replace the manual math with calls to the
canonical helper _get_real_fp4_shape_from_packed_uint8() (and its
stride-equivalent or by deriving strides from that helper) for tensors a and b
so the computed shapes/strides for a, b (and keep a_descale.shape,
b_descale.shape, c_final.shape and their strides) match the canonical
FP4-unpacking behavior and avoid layout/batch mismatches.
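The shape/stride math the reviewer wants centralized can be sketched as follows. This is an illustrative reconstruction of what a helper like _get_real_fp4_shape_from_packed_uint8 must compute (the actual implementation lives in flashinfer/gemm/gemm_base.py): a packed uint8 tensor stores two FP4 values per byte along its innermost dimension, so the logical tensor has that dimension doubled, and every stride measured in FP4 elements doubles except the unit stride of the packed dimension itself:

```python
def real_fp4_shape_and_stride(packed_shape, packed_stride):
    # The packed (two-values-per-byte) dimension is the one with the
    # smallest stride, i.e. the innermost contiguous dimension.
    packed_dim = packed_stride.index(min(packed_stride))
    shape = [
        s * 2 if i == packed_dim else s for i, s in enumerate(packed_shape)
    ]
    stride = [
        st if i == packed_dim else st * 2
        for i, st in enumerate(packed_stride)
    ]
    return shape, stride

# Packed A of logical shape (batch=2, M=4, K=8), stored as (2, 4, 4) uint8.
shape, stride = real_fp4_shape_and_stride((2, 4, 4), (16, 4, 1))
assert shape == [2, 4, 8]
assert stride == [32, 8, 1]  # note: the batch stride doubles too
```

This is exactly why the manual "* 2" on a single stride is flagged: for batched tensors the batch stride must double as well, which a shared helper gets right in one place.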

1977-1993: ⚠️ Potential issue | 🟠 Major

a_descale_n_dim still leaks M into the FP4 override graph cache key.

build_cudnn_fp4_gemm_graph_override_shape() is cached, and a_descale_n_dim varies with M (Line 2203-2204). That defeats single-plan reuse across dynamic M.

Suggested fix
 `@functools.cache`
 def build_cudnn_fp4_gemm_graph_override_shape(
     batch,
     n,
     k,
-    a_descale_n_dim,
     a_descale_k_dim,
     b_descale_k_dim,
     b_descale_n_dim,
@@
-    # a_descale N-dimension (dim[1]) depends on M, so we pass it separately
-    a_descale_n_dim = expanded_a_descale_shape[1]
-
     return build_cudnn_fp4_gemm_graph_override_shape(
         batch=batch,
         n=n,
         k=k,
-        a_descale_n_dim=a_descale_n_dim,
         a_descale_k_dim=a_descale_k_dim,
         b_descale_k_dim=b_descale_k_dim,
         b_descale_n_dim=b_descale_n_dim,

Also applies to: 2203-2211

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 1977 - 1993, The cached function
build_cudnn_fp4_gemm_graph_override_shape is leaking M into its cache key via
the a_descale_n_dim parameter; remove M-varying data from the cached signature
by eliminating a_descale_n_dim (or replacing it with a normalized/boolean
indicator) and instead compute any M-dependent descaling inside the function
body or upstream per-call (so the cache key remains stable). Update callers that
pass a_descale_n_dim to instead pass the fixed/normalized indicator (or stop
passing it) and adjust logic in build_cudnn_fp4_gemm_graph_override_shape (and
the similar code around the 2203-2211 region) to derive the actual per-M descale
values at runtime rather than as part of the cached parameters.
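The cache-churn problem described above is easy to demonstrate with functools.cache. The dimension formula below is a hypothetical M-derived value, used only to show how an M-dependent argument fragments an otherwise reusable plan cache:

```python
import functools

@functools.cache
def build_plan_bad(n, k, a_descale_n_dim):  # M-dependent arg in the key
    return object()

@functools.cache
def build_plan_good(n, k):                  # key holds only static dims
    return object()

for m in (128, 256, 384, 512):              # dynamic M values
    a_descale_n_dim = (m + 127) // 128      # hypothetical M-derived dim
    build_plan_bad(4096, 4096, a_descale_n_dim)
    build_plan_good(4096, 4096)

assert build_plan_bad.cache_info().misses == 4   # one rebuild per M
assert build_plan_good.cache_info().misses == 1  # single reusable plan
```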

1727-1747: ⚠️ Potential issue | 🟠 Major

Override-shape availability check is still backend-only and can misreport support.

_check_cudnn_override_shape_availability() / is_cudnn_override_shape_available() only gate on cudnn.backend_version(). That can return True even when the installed Python frontend lacks is_override_shape_enabled/override execute kwargs, causing runtime failures later. Also, Line 1746 catches a blind Exception, masking real causes.

For NVIDIA cuDNN Python frontend, which version first supports `pygraph(..., is_override_shape_enabled=True)` and `execute_plan_at_index(..., override_uids, override_shapes, override_strides)`? Is backend_version() alone sufficient to guarantee those APIs exist?
Suggested fix
+import inspect
+
+def _has_cudnn_override_shape_frontend() -> bool:
+    try:
+        pygraph_sig = inspect.signature(cudnn.pygraph)
+        if "is_override_shape_enabled" not in pygraph_sig.parameters:
+            return False
+    except (AttributeError, TypeError, ValueError):
+        return False
+    return True
+
 def _check_cudnn_override_shape_availability():
     _check_cudnn_availability()
+    if not _has_cudnn_override_shape_frontend():
+        raise RuntimeError("cuDNN frontend override-shape API is unavailable.")
     backend_version = cudnn.backend_version()
     if backend_version < CUDNN_MIN_VERSION_OVERRIDE_SHAPE:
         raise RuntimeError(...)
 
 def is_cudnn_override_shape_available() -> bool:
     if not CUDNN_AVAILABLE:
         return False
     try:
-        return cudnn.backend_version() >= CUDNN_MIN_VERSION_OVERRIDE_SHAPE
-    except Exception:
+        return _has_cudnn_override_shape_frontend() and (
+            cudnn.backend_version() >= CUDNN_MIN_VERSION_OVERRIDE_SHAPE
+        )
+    except (AttributeError, TypeError, ValueError):
         return False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 1727 - 1747, The current checks in
_check_cudnn_override_shape_availability and is_cudnn_override_shape_available
rely only on cudnn.backend_version() which can misreport support if the Python
frontend lacks the new APIs; update both functions to also verify the frontend
provides the required symbols (e.g., hasattr(cudnn, "is_override_shape_enabled")
and hasattr(cudnn, "execute_plan_at_index")) and, for execute_plan_at_index,
optionally inspect its signature to ensure it accepts
override_uids/override_shapes/override_strides; if the frontend checks fail,
raise a clear RuntimeError in _check_cudnn_override_shape_availability and
return False in is_cudnn_override_shape_available. Replace the broad except
Exception in is_cudnn_override_shape_available with targeted exception handling
(AttributeError or TypeError) so real errors are not masked, and still respect
CUDNN_AVAILABLE and the CUDNN_MIN_VERSION_OVERRIDE_SHAPE version check.
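The feature-detection step suggested above can be sketched with inspect.signature against a local stub. The stub functions stand in for the installed cuDNN frontend's pygraph; a real probe would inspect that instead:

```python
import inspect

# Stubs standing in for old and new cuDNN frontend pygraph constructors.
def pygraph_new(name, is_override_shape_enabled=False):
    ...

def pygraph_old(name):
    ...

def supports_override_shape(pygraph_fn):
    try:
        sig = inspect.signature(pygraph_fn)
    except (TypeError, ValueError):  # builtins may not expose a signature
        return False
    return "is_override_shape_enabled" in sig.parameters

assert supports_override_shape(pygraph_new)
assert not supports_override_shape(pygraph_old)
```

Combining this frontend probe with the backend_version() threshold closes the gap the reviewer describes: a new backend paired with an old frontend package now reports unavailable instead of failing at execute time.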
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

2008-2008: Use _-prefixed unpacking for intentionally unused values.

A few unpacked vars are unused (block_scale_dim_k, real_a_stride, real_b_stride, expanded_a_descale_stride, expanded_b_descale_stride). Prefixing with _ keeps intent clear and quiets lint noise.

Also applies to: 2249-2255

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` at line 2008, The unpacking in gemm_base.py
assigns several variables that are intentionally unused (e.g.,
block_scale_dim_k, real_a_stride, real_b_stride, expanded_a_descale_stride,
expanded_b_descale_stride); update those unpack targets to use a leading
underscore (for example _block_scale_dim_k, _real_a_stride, etc.) so the intent
is clear and linters stop flagging them. Locate the unpack expressions such as
the call to _calculate_block_scale_dims (where block_scale_dim_m, _,
block_scale_dim_k = ...) and the unpackings around lines 2249-2255, and rename
each intentionally unused variable to an _-prefixed name while leaving used
locals unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2113-2171: The override-shape executors
execute_cudnn_fp4_gemm_graph_override_shape and
execute_cudnn_mxfp8_gemm_graph_override_shape call graph.execute_plan_at_index
without verifying the provided workspace_buffer is large enough; add a guard
that fetches required_size = graph.get_workspace_size(plan_index or default) and
compares required_size to workspace_buffer.numel(), and if workspace_buffer is
too small either raise a clear error (including required_size and provided size)
or reallocate/resize the workspace_buffer before calling
graph.execute_plan_at_index; locate the check around the call site to
graph.execute_plan_at_index in both functions and perform this size validation
using the same handle/stream logic already present.
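The suggested guard amounts to a size comparison before dispatch. The sketch below uses plain byte counts; required_bytes stands in for graph.get_workspace_size() and provided_bytes for the caller's workspace buffer size:

```python
def check_workspace(required_bytes, provided_bytes):
    """Fail fast with a clear message instead of letting cuDNN execute
    against an undersized workspace buffer."""
    if provided_bytes < required_bytes:
        raise RuntimeError(
            f"workspace too small: need {required_bytes} bytes, "
            f"got {provided_bytes}"
        )

check_workspace(required_bytes=1 << 20, provided_bytes=4 << 20)  # ok

failed = False
try:
    check_workspace(required_bytes=4 << 20, provided_bytes=1 << 20)
except RuntimeError:
    failed = True
assert failed
```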


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 51adb811-10c1-4625-904c-063b0b235aa4

📥 Commits

Reviewing files that changed from the base of the PR and between 1803c4e and 9c21ddf.

📒 Files selected for processing (2)
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (3)
flashinfer/gemm/gemm_base.py (2)

2204-2228: ⚠️ Potential issue | 🔴 Critical

Remove the stale descale-dimension kwargs from the cached FP4 builder call.

build_cudnn_fp4_gemm_graph_override_shape() no longer accepts a_descale_n_dim, a_descale_k_dim, b_descale_k_dim, or b_descale_n_dim. This call now raises TypeError before the override graph is built.

Suggested diff
-    expanded_a_descale_shape, _ = _expand_block_scale_tensor_shape(a_descale, batch)
-    expanded_b_descale_shape, _ = _expand_block_scale_tensor_shape(b_descale, batch)
-
-    # Scale dimension sizes that are independent of M
-    a_descale_k_dim = expanded_a_descale_shape[2]
-    b_descale_k_dim = expanded_b_descale_shape[1]
-    b_descale_n_dim = expanded_b_descale_shape[2]
-    # a_descale N-dimension (dim[1]) depends on M, so we pass it separately
-    a_descale_n_dim = expanded_a_descale_shape[1]
-
     return build_cudnn_fp4_gemm_graph_override_shape(
         batch=batch,
         n=n,
         k=k,
-        a_descale_n_dim=a_descale_n_dim,
-        a_descale_k_dim=a_descale_k_dim,
-        b_descale_k_dim=b_descale_k_dim,
-        b_descale_n_dim=b_descale_n_dim,
         ab_type=cudnn.data_type.FP4_E2M1,
         o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
         block_size=block_size,
         device=a.device,
         alpha_is_not_none=alpha is not None,
         use_nvfp4=use_nvfp4,
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2204 - 2228, The call to
build_cudnn_fp4_gemm_graph_override_shape in the gemm builder is passing stale
kwargs a_descale_n_dim, a_descale_k_dim, b_descale_k_dim, and b_descale_n_dim
which the function no longer accepts; remove those four keyword arguments from
the return call and only pass the remaining valid parameters (batch, n, k,
ab_type, o_type, block_size, device, alpha_is_not_none, use_nvfp4) so the cached
FP4 builder call no longer raises a TypeError in
build_cudnn_fp4_gemm_graph_override_shape.

2154-2166: ⚠️ Potential issue | 🔴 Critical

Reuse the canonical FP4 shape/stride helper here.

The manual override metadata only doubles the inner stride. For batched packed FP4 tensors, the batch stride also needs to be doubled, otherwise batch > 1 executions read the wrong logical layout. Reusing _get_real_fp4_shape_from_packed_uint8() keeps the override path aligned with the non-override path.

Suggested diff
-    override_shapes = [
-        [a.shape[0], a.shape[1], a.shape[2] * 2],
-        [b.shape[0], b.shape[1] * 2, b.shape[2]],
+    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+    override_shapes = [
+        list(real_a_shape),
+        list(real_b_shape),
         a_descale.shape,
         b_descale.shape,
         c_final.shape,
     ]
     override_strides = [
-        [a.stride()[0], a.stride()[1] * 2, a.stride()[2]],
-        [b.stride()[0], b.stride()[1], b.stride()[2] * 2],
+        list(real_a_stride),
+        list(real_b_stride),
         a_descale.stride(),
         b_descale.stride(),
         c_final.stride(),
     ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2154 - 2166, The override path
currently builds override_shapes/override_strides manually and only doubles the
inner stride, which breaks batched packed FP4 tensors; replace the manual
constructions for a and b with calls to the canonical helper
_get_real_fp4_shape_from_packed_uint8(a) and
_get_real_fp4_shape_from_packed_uint8(b) (or its stride-aware variant) to
produce the correct shape and stride tuples, then set override_shapes to
[real_shape_a, real_shape_b, a_descale.shape, b_descale.shape, c_final.shape]
and override_strides to [real_stride_a, real_stride_b, a_descale.stride(),
b_descale.stride(), c_final.stride()] so batch and inner strides are adjusted
consistently with the non-override path.
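The conversion the review asks to reuse can be sketched in plain Python. The helper below is an illustrative stand-in for `_get_real_fp4_shape_from_packed_uint8()`, assuming the usual packing of two FP4 nibbles per uint8; it is not FlashInfer's actual implementation:

```python
def real_fp4_shape_stride(packed_shape, packed_stride, packed_dim):
    """Map a packed-uint8 view (two FP4 nibbles per byte) to its logical
    FP4 shape and element strides.

    The logical size along packed_dim doubles. Every stride EXCEPT the
    packed dim's own stride also doubles, because one byte-step along any
    other dim now spans two logical elements -- including the batch
    stride, which the hand-rolled override metadata missed.
    """
    shape = [2 * s if i == packed_dim else s for i, s in enumerate(packed_shape)]
    stride = [st if i == packed_dim else 2 * st
              for i, st in enumerate(packed_stride)]
    return shape, stride

# a is packed along its last dim: contiguous uint8 (batch, m, k // 2)
a_shape, a_stride = real_fp4_shape_stride((4, 8, 16), (128, 16, 1), packed_dim=2)
# -> shape [4, 8, 32], stride [256, 32, 1]: the batch stride doubles too
```

Applying the same rule to both operands reproduces the reviewer's pattern (all strides double except the packed dimension's own), which is exactly where the manual metadata dropped the batch-stride factor.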
tests/gemm/test_cudnn_override_shape.py (1)

186-213: ⚠️ Potential issue | 🟠 Major

Add output checks to the FP4/MXFP8 override-shape loops.

These cases only launch the kernels and synchronize, so numerically wrong outputs still pass. Please compare out against the corresponding static-shape path, or at minimum assert something meaningful about the result inside each loop.

Also applies to: 284-308

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_cudnn_override_shape.py` around lines 186 - 213, The
FP4/MXFP8 override-shape loops currently only launch kernels and synchronize
without validating results; update the loop that iterates dynamic_ms (and the
similar loop later) to compute a reference result using the static-shape path
(or a known-correct function) and compare it to out after
execute_cudnn_fp4_gemm_graph_override_shape, e.g. call the same inputs through
the static graph/function that produces the expected tensor and assert
torch.allclose(out, expected, atol=..., rtol=...) or at minimum assert
non-NaN/non-zero statistics; reference symbols to update: dynamic_ms loop,
execute_cudnn_fp4_gemm_graph_override_shape, out, a_packed, b_packed, a_descale,
b_descale, workspace, and the static-shape executor used elsewhere in the test
so the override-shape branch actually verifies numerical correctness.
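The requested check can be sketched without torch. Here `run_override` and `run_static` are hypothetical stand-ins for the override-shape executor and the static-shape reference path:

```python
def allclose(xs, ys, rtol=1e-2, atol=1e-2):
    """Elementwise closeness check over flat float sequences."""
    return len(xs) == len(ys) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(xs, ys)
    )

def check_dynamic_ms(run_override, run_static, dynamic_ms):
    """Validate each dynamic M against the static-shape reference,
    instead of only launching the kernel and synchronizing."""
    for m in dynamic_ms:
        out = run_override(m)
        ref = run_static(m)
        assert allclose(out, ref), f"override-shape mismatch at M={m}"
```

In the real test, `run_override` would wrap `execute_cudnn_fp4_gemm_graph_override_shape` and `run_static` the existing static-shape executor, with `torch.allclose` doing the comparison on device tensors.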

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ecf9f8ce-2c28-4a2d-bd66-f66115001391

📥 Commits

Reviewing files that changed from the base of the PR and between 9c21ddf and 5ff1beb.

📒 Files selected for processing (3)
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_cudnn_override_shape.py

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/gemm/test_cudnn_override_shape.py`:
- Around line 294-296: The tuple unpacking binds an unused variable
block_scale_dim_m_cache from the _calculate_block_scale_dims(...) call; to fix,
drop or rename that binding to a throwaway name (e.g., replace
"block_scale_dim_m_cache, block_scale_dim_n, block_scale_dim_k =
(_calculate_block_scale_dims(cache_m, n, k, block_size))" with "_,
block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims(cache_m, n,
k, block_size)") so Ruff RUF059 is resolved while preserving the call and the
other two bindings.
- Around line 317-320: Test sampling only used non-negative MXFP8 patterns;
change the uint8 sampling so it covers the full 0..255 space (including values
with the sign bit set) while still avoiding the NaN bit-patterns 0x7F and 0xFF:
generate b with torch.randint(0, 256, (1, n, k), dtype=torch.uint8,
device=device).transpose(1,2) and then replace any occurrences of 0x7F or 0xFF
with other finite encodings (e.g., random picks from 0..254 excluding 0x7F), so
the tensor b exercises negative E4M3 encodings; apply the same change to the
other occurrence at lines 332-333.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 23175cf6-6c1b-4dce-beab-fdd6a55349b2

📥 Commits

Reviewing files that changed from the base of the PR and between 5ff1beb and b4c92c6.

📒 Files selected for processing (1)
  • tests/gemm/test_cudnn_override_shape.py

@yzh119 yzh119 changed the title from "Yanqinz/dynamic shape" to "Implement override shape support for cuDNN GEMM operations" Mar 17, 2026
@yzh119 yzh119 added the run-ci label Mar 17, 2026
@yzh119
Collaborator

yzh119 commented Mar 17, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !421 has been created, and the CI pipeline #46370787 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46370787: 14/20 passed

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (2)
flashinfer/gemm/gemm_base.py (2)

2205-2228: ⚠️ Potential issue | 🔴 Critical

Drop the removed FP4 cache-key kwargs from this helper.

build_cudnn_fp4_gemm_graph_override_shape() no longer accepts the descale-dimension kwargs, but this helper still passes them. The first call into _get_cudnn_fp4_gemm_graph_override_shape() will raise TypeError instead of returning a graph.

Suggested fix
-    expanded_a_descale_shape, _ = _expand_block_scale_tensor_shape(a_descale, batch)
-    expanded_b_descale_shape, _ = _expand_block_scale_tensor_shape(b_descale, batch)
-
-    # Scale dimension sizes that are independent of M
-    a_descale_k_dim = expanded_a_descale_shape[2]
-    b_descale_k_dim = expanded_b_descale_shape[1]
-    b_descale_n_dim = expanded_b_descale_shape[2]
-    # a_descale N-dimension (dim[1]) depends on M, so we pass it separately
-    a_descale_n_dim = expanded_a_descale_shape[1]
-
     return build_cudnn_fp4_gemm_graph_override_shape(
         batch=batch,
         n=n,
         k=k,
-        a_descale_n_dim=a_descale_n_dim,
-        a_descale_k_dim=a_descale_k_dim,
-        b_descale_k_dim=b_descale_k_dim,
-        b_descale_n_dim=b_descale_n_dim,
         ab_type=cudnn.data_type.FP4_E2M1,
         o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
         block_size=block_size,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2205 - 2228, The helper currently
computes expanded descale dims via _expand_block_scale_tensor_shape and then
passes a_descale_n_dim, a_descale_k_dim, b_descale_k_dim, b_descale_n_dim into
build_cudnn_fp4_gemm_graph_override_shape, but that function no longer accepts
those kwargs; remove those four descale-dimension keyword arguments from the
call (keep all other args like batch, n, k, ab_type, o_type, block_size, device,
alpha_is_not_none, use_nvfp4, etc.) so the call matches the new
build_cudnn_fp4_gemm_graph_override_shape/_get_cudnn_fp4_gemm_graph_override_shape
signature and no TypeError is raised. Ensure any computed descale dim variables
(a_descale_n_dim, a_descale_k_dim, b_descale_k_dim, b_descale_n_dim) are not
passed and can be removed if unused.

2155-2167: ⚠️ Potential issue | 🟠 Major

Reuse the canonical FP4 shape/stride helper here.

This override metadata is still hand-derived, and it only scales the packed dimension stride. For batch > 1, the logical batch stride also doubles for packed FP4, so cuDNN will read the wrong batch slice. The current tests don't catch it because they only run with batch == 1.

Suggested fix
+    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+
     override_shapes = [
-        [a.shape[0], a.shape[1], a.shape[2] * 2],
-        [b.shape[0], b.shape[1] * 2, b.shape[2]],
-        a_descale.shape,
-        b_descale.shape,
-        c_final.shape,
+        list(real_a_shape),
+        list(real_b_shape),
+        list(a_descale.shape),
+        list(b_descale.shape),
+        list(c_final.shape),
     ]
     override_strides = [
-        [a.stride()[0], a.stride()[1] * 2, a.stride()[2]],
-        [b.stride()[0], b.stride()[1], b.stride()[2] * 2],
-        a_descale.stride(),
-        b_descale.stride(),
-        c_final.stride(),
+        list(real_a_stride),
+        list(real_b_stride),
+        list(a_descale.stride()),
+        list(b_descale.stride()),
+        list(c_final.stride()),
     ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2155 - 2167, The
override_shapes/override_strides block is hand-derived and only scales the
packed dimension stride, which breaks when batch > 1; replace the manual scaling
for the FP4-packed inputs with the canonical FP4 shape/stride helper (the
module's FP4 helper used elsewhere) rather than hand-manipulating
a.shape/a.stride and b.shape/b.stride. Call the helper for a and b to produce
their overridden shapes and strides (e.g., use
get_canonical_fp4_shape_stride(a.shape, a.stride) and same for b) and leave
a_descale.shape/stride, b_descale.shape/stride and c_final.shape/stride as-is;
update override_shapes to use the helper-returned shapes and override_strides to
use the helper-returned strides so the batch stride is correctly doubled for
packed FP4.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1eade161-0775-4982-aaad-5442e4ad5f0d

📥 Commits

Reviewing files that changed from the base of the PR and between b4c92c6 and ffb9f95.

📒 Files selected for processing (2)
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_cudnn_override_shape.py

@dhiraj113
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !421 has been updated with latest changes, and the CI pipeline #46472647 is currently running. I'll report back once the pipeline job completes.

@yanqinz2 yanqinz2 enabled auto-merge (squash) March 19, 2026 02:56
@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46472647: 14/20 passed

@nv-yunzheq
Collaborator

@yanqinz2 It seems this PR introduces a number of new APIs for cuDNN dynamic-shape support. However, we already have functions like mm_fp4 for FP4 GEMM with cuDNN as a backend option. Would it be possible to fold the dynamic-shape support into those functions instead of creating new APIs?

@yanqinz2 yanqinz2 disabled auto-merge March 19, 2026 15:56
@yanqinz2 yanqinz2 merged commit 623db38 into main Mar 19, 2026
28 of 33 checks passed
@yanqinz2 yanqinz2 deleted the yanqinz/dynamic-shape branch March 19, 2026 15:56
@coderabbitai coderabbitai bot mentioned this pull request Mar 29, 2026