Yanqinz/dynamic shape unified api#2910

Open
yanqinz2 wants to merge 4 commits into main from yanqinz/dynamic-shape-unified-api

Conversation

@yanqinz2
Collaborator

@yanqinz2 yanqinz2 commented Mar 29, 2026

📌 Description

Add cuDNN dynamic-shape support for bf16 and fp4 GEMM.

🔍 Related Issues

https://nvbugspro.nvidia.com/bug/5539146

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor

    • Standardized cuDNN GEMM public API names for consistency.
    • Removed redundant cached graph wrappers for a cleaner API surface.
  • New Features

    • Added an optional policy parameter to cuDNN graph builders (defaults to heuristic choice).
  • Bug Fixes

    • Improved workspace buffer allocation for dynamic-shape operations.
    • Corrected shape/stride handling for override-shape execution and added safer fallback behavior when override support is unavailable.

@coderabbitai
Contributor

coderabbitai bot commented Mar 29, 2026

📝 Walkthrough

Walkthrough

Refactors cuDNN GEMM override-shape APIs (renaming FP4/MXFP8 exports), adds an optional policy parameter to graph builders, updates execution paths to compute override shapes/strides from expanded tensors, manages workspace reallocation, and introduces BF16 override-shape caching and conditional execution when available.

Changes

Cohort / File(s) Summary
Public API Exports
flashinfer/gemm/__init__.py
Replaced exported function names for FP4 and MXFP8 override-shape APIs: removed old names and added build_cudnn_gemm_fp4_graph_override_shape, execute_cudnn_gemm_fp4_graph_override_shape, build_cudnn_gemm_mxfp8_graph_override_shape, execute_cudnn_gemm_mxfp8_graph_override_shape.
Core Implementation
flashinfer/gemm/gemm_base.py
Renamed FP4/MXFP8 override-shape builders/runners; added optional policy parameter (defaults to heuristics) and call to graph.build_plans(policy); changed override-shape execution to derive shapes/strides from expanded packed tensors and block-scale tensors; reallocate workspace buffer when too small; added BF16 override-shape caching using power-of-two m and conditional fallback to non-override path.
Tests
tests/gemm/test_cudnn_override_shape.py
Updated imports and call sites to use renamed FP4 and MXFP8 build/execute function names.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • bkryu
  • aleozlx
  • dhiraj113

Poem

🐰 I hopped through code with nimble feet,
Renamed the graphs so names now meet,
Policies chosen, plans built anew,
Workspaces grow when bytes are due,
BF16 caches hum — a rabbit's small feat.

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 35.00%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check | ❓ Inconclusive | The title 'Yanqinz/dynamic shape unified api' is a branch name rather than a description of the main change. | Use a more descriptive title, such as 'Rename cuDNN GEMM override-shape APIs for bf16 and fp4' or 'Add dynamic-shape support for bf16 and fp4 GEMM'.
✅ Passed checks (1 passed)
Check name | Status | Explanation
Description check | ✅ Passed | The description includes the required sections (Description, Related Issues, Checklist) with key information, though the Tests checklist items are unchecked despite significant code changes.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the cuDNN GEMM implementation to provide consistent dynamic-shape (override-shape) support across FP4, MXFP8, and BF16 data types. Key changes include renaming functions for better naming consistency, adding a policy parameter to graph builders, and integrating override-shape logic into the TunableRunner classes for BF16 and FP4. The review feedback focuses on improving the efficiency of workspace buffer handling by suggesting that the code raise a ValueError for undersized buffers instead of performing local re-allocations, and recommends refactoring duplicated logic within the runner classes to improve maintainability.

Comment on lines +2158 to +2161
if workspace_buffer.numel() < graph.get_workspace_size():
    workspace_buffer = torch.empty(
        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
    )
Contributor


medium

Re-assigning workspace_buffer here only changes the local variable. The caller's buffer remains unchanged and potentially undersized, leading to re-allocation on every call if the initial buffer is insufficient. This is inefficient.

A better approach would be to raise a ValueError if the buffer is too small, forcing the caller to provide a sufficiently sized buffer.

Suggested change
-if workspace_buffer.numel() < graph.get_workspace_size():
-    workspace_buffer = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+if workspace_buffer.numel() < graph.get_workspace_size():
+    raise ValueError(
+        f"workspace_buffer is too small. Need at least {graph.get_workspace_size()} elements, but got {workspace_buffer.numel()}."
+    )
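The inefficiency described here comes from Python's argument-passing semantics: rebinding a parameter inside a function only changes the local name, never the caller's variable. A minimal stand-alone illustration (plain lists standing in for torch tensors):

```python
def grow_local(buf, needed):
    # Rebinding 'buf' only changes this function's local name; the
    # caller's object is untouched, so an undersized caller buffer
    # triggers a fresh allocation on every single call.
    if len(buf) < needed:
        buf = [0] * needed
    return len(buf)

caller_buf = [0] * 4
grow_local(caller_buf, 16)  # allocates a new local buffer
# caller_buf still has length 4; the next call allocates again.
```

Raising an error instead (or resizing the caller's tensor in place) makes the cost visible to the caller exactly once.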

Comment on lines +2370 to +2373
if workspace_buffer.numel() < graph.get_workspace_size():
    workspace_buffer = torch.empty(
        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
    )
Contributor


medium

Similar to other execute_* functions in this file, re-assigning workspace_buffer here is inefficient as the change is local. If the caller passes an undersized buffer, it will be re-allocated on every call. Consider raising a ValueError instead to enforce that the caller provides a buffer of adequate size.

Suggested change
-if workspace_buffer.numel() < graph.get_workspace_size():
-    workspace_buffer = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+if workspace_buffer.numel() < graph.get_workspace_size():
+    raise ValueError(
+        f"workspace_buffer is too small. Need at least {graph.get_workspace_size()} elements, but got {workspace_buffer.numel()}."
+    )

Comment on lines +2961 to +2962
if is_cudnn_override_shape_available():
    graph = self._get_override_graph(a, b, out)
Contributor


medium

The condition is_cudnn_override_shape_available() and the call to self._get_override_graph(a, b, out) are duplicated in get_valid_tactics and forward. This could be refactored to improve maintainability and avoid redundant graph lookups/builds, even with caching. Consider creating a helper method that retrieves the correct graph based on availability, which can be called by both get_valid_tactics and forward.
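The suggested refactor could look roughly like this; the class and method names below are illustrative stand-ins, not the actual TunableRunner API:

```python
def is_cudnn_override_shape_available() -> bool:
    # Stand-in for the real capability check.
    return True


class Runner:
    """Sketch of a single decision point for graph selection."""

    def _get_override_graph(self, a, b, out):
        return "override-graph"  # placeholder for the cached override graph

    def _get_regular_graph(self, a, b, out):
        return "regular-graph"  # placeholder for the non-override path

    def _select_graph(self, a, b, out):
        # Called from both get_valid_tactics() and forward(), so the
        # availability condition lives in exactly one place and the
        # graph lookup happens through a single code path.
        if is_cudnn_override_shape_available():
            return self._get_override_graph(a, b, out)
        return self._get_regular_graph(a, b, out)
```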

Comment on lines +4125 to +4128
if is_cudnn_override_shape_available() and alpha is None:
    graph = self._get_override_graph(
        a, b, alpha, out_dtype, block_size, use_nvfp4
    )
Contributor


medium

The condition is_cudnn_override_shape_available() and alpha is None is duplicated in get_valid_tactics and forward. This could lead to maintenance issues if the condition changes. Consider refactoring this logic into a helper method to determine which execution path to take. This would also avoid calling self._get_override_graph twice (once in get_valid_tactics and once in forward), improving efficiency even with caching.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tests/gemm/test_cudnn_override_shape.py (1)

17-27: Import the public helpers through flashinfer.gemm in this test.

Right now this bypasses flashinfer.gemm, so the test still passes even if the package re-export layer regresses. Since this PR changes that surface, the test should exercise it.

🧪 Minimal import split
-from flashinfer.gemm.gemm_base import (
-    CUDNN_AVAILABLE,
-    build_cudnn_gemm_bf16_graph_override_shape,
-    execute_cudnn_gemm_bf16_graph_override_shape,
-    build_cudnn_gemm_fp4_graph_override_shape,
-    execute_cudnn_gemm_fp4_graph_override_shape,
-    build_cudnn_gemm_mxfp8_graph_override_shape,
-    execute_cudnn_gemm_mxfp8_graph_override_shape,
-    is_cudnn_override_shape_available,
-    _calculate_block_scale_dims,
-)
+from flashinfer.gemm import (
+    build_cudnn_gemm_bf16_graph_override_shape,
+    execute_cudnn_gemm_bf16_graph_override_shape,
+    build_cudnn_gemm_fp4_graph_override_shape,
+    execute_cudnn_gemm_fp4_graph_override_shape,
+    build_cudnn_gemm_mxfp8_graph_override_shape,
+    execute_cudnn_gemm_mxfp8_graph_override_shape,
+    is_cudnn_override_shape_available,
+)
+from flashinfer.gemm.gemm_base import CUDNN_AVAILABLE, _calculate_block_scale_dims
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_cudnn_override_shape.py` around lines 17 - 27, The test
directly imports helpers from flashinfer.gemm.gemm_base instead of exercising
the package re-export layer; update the import to import the public helpers from
flashinfer.gemm (e.g. import CUDNN_AVAILABLE,
build_cudnn_gemm_bf16_graph_override_shape,
execute_cudnn_gemm_bf16_graph_override_shape,
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape,
execute_cudnn_gemm_mxfp8_graph_override_shape,
is_cudnn_override_shape_available, _calculate_block_scale_dims) via from
flashinfer.gemm import <symbols> so the test fails if the package re-export
surface regresses.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/__init__.py`:
- Around line 25-28: Restore backward-compatible aliases for the renamed
override-shape exports by reintroducing the old names as simple assignments to
the new symbols: for example, set the previous FP4/MXFP8 export names equal to
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape, and
execute_cudnn_gemm_mxfp8_graph_override_shape in flashinfer.gemm.__init__.py so
old imports continue to work; also add the same alias assignments in the
flashinfer.gemm.gemm_base module if that path is a supported public import so
both import surfaces mirror each other.

In `@flashinfer/gemm/gemm_base.py`:
- Around line 2121-2124: Replace the runtime assertion with explicit input
validation that raises a ValueError: instead of using "assert real_a_stride[2]
== 1 and real_b_stride[1] == 1", check those conditions and raise ValueError
with a clear message (e.g., "a and b must be k-major") so invalid
caller-provided strides cannot slip through when Python assertions are disabled;
update the code around _get_bf16_3d_shape_stride and the
real_a_stride/real_b_stride checks accordingly.

---

Nitpick comments:
In `@tests/gemm/test_cudnn_override_shape.py`:
- Around line 17-27: The test directly imports helpers from
flashinfer.gemm.gemm_base instead of exercising the package re-export layer;
update the import to import the public helpers from flashinfer.gemm (e.g. import
CUDNN_AVAILABLE, build_cudnn_gemm_bf16_graph_override_shape,
execute_cudnn_gemm_bf16_graph_override_shape,
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape,
execute_cudnn_gemm_mxfp8_graph_override_shape,
is_cudnn_override_shape_available, _calculate_block_scale_dims) via from
flashinfer.gemm import <symbols> so the test fails if the package re-export
surface regresses.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 463f5a46-a140-49a0-8802-735d5fba51af

📥 Commits

Reviewing files that changed from the base of the PR and between 904fa8c and 368656f.

📒 Files selected for processing (3)
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_cudnn_override_shape.py

Comment on lines +25 to +28
build_cudnn_gemm_fp4_graph_override_shape as build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape as execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape as build_cudnn_gemm_mxfp8_graph_override_shape,
execute_cudnn_gemm_mxfp8_graph_override_shape as execute_cudnn_gemm_mxfp8_graph_override_shape,
Contributor


⚠️ Potential issue | 🟠 Major

Preserve the old override-shape export names as aliases.

These symbols were already part of the flashinfer.gemm surface. Replacing them outright means downstream imports fail immediately on upgrade. Please keep the old FP4/MXFP8 names as shims for at least one release, or call out the breaking change explicitly. If flashinfer.gemm.gemm_base is also a supported import path, mirror the aliases there too.

↩️ Compatibility shim example
 from .gemm_base import (
     is_cudnn_override_shape_available as is_cudnn_override_shape_available,
     build_cudnn_gemm_bf16_graph_override_shape as build_cudnn_gemm_bf16_graph_override_shape,
     execute_cudnn_gemm_bf16_graph_override_shape as execute_cudnn_gemm_bf16_graph_override_shape,
     build_cudnn_gemm_fp4_graph_override_shape as build_cudnn_gemm_fp4_graph_override_shape,
     execute_cudnn_gemm_fp4_graph_override_shape as execute_cudnn_gemm_fp4_graph_override_shape,
     build_cudnn_gemm_mxfp8_graph_override_shape as build_cudnn_gemm_mxfp8_graph_override_shape,
     execute_cudnn_gemm_mxfp8_graph_override_shape as execute_cudnn_gemm_mxfp8_graph_override_shape,
     build_cudnn_gemm_with_per_tensor_q_graph_override_shape as build_cudnn_gemm_with_per_tensor_q_graph_override_shape,
     execute_cudnn_gemm_with_per_tensor_q_graph_override_shape as execute_cudnn_gemm_with_per_tensor_q_graph_override_shape,
 )
+
+# Backward-compat aliases
+build_cudnn_fp4_gemm_graph_override_shape = build_cudnn_gemm_fp4_graph_override_shape
+execute_cudnn_fp4_gemm_graph_override_shape = execute_cudnn_gemm_fp4_graph_override_shape
+build_cudnn_mxfp8_gemm_graph_override_shape = build_cudnn_gemm_mxfp8_graph_override_shape
+execute_cudnn_mxfp8_gemm_graph_override_shape = execute_cudnn_gemm_mxfp8_graph_override_shape
@@
     "build_cudnn_gemm_fp4_graph_override_shape",
     "execute_cudnn_gemm_fp4_graph_override_shape",
     "build_cudnn_gemm_mxfp8_graph_override_shape",
     "execute_cudnn_gemm_mxfp8_graph_override_shape",
+    "build_cudnn_fp4_gemm_graph_override_shape",
+    "execute_cudnn_fp4_gemm_graph_override_shape",
+    "build_cudnn_mxfp8_gemm_graph_override_shape",
+    "execute_cudnn_mxfp8_gemm_graph_override_shape",
     "build_cudnn_gemm_with_per_tensor_q_graph_override_shape",
     "execute_cudnn_gemm_with_per_tensor_q_graph_override_shape",
 ] + _cute_dsl_kernels

Also applies to: 82-85

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/__init__.py` around lines 25 - 28, Restore
backward-compatible aliases for the renamed override-shape exports by
reintroducing the old names as simple assignments to the new symbols: for
example, set the previous FP4/MXFP8 export names equal to
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape, and
execute_cudnn_gemm_mxfp8_graph_override_shape in flashinfer.gemm.__init__.py so
old imports continue to work; also add the same alias assignments in the
flashinfer.gemm.gemm_base module if that path is a supported public import so
both import surfaces mirror each other.

Comment on lines +2121 to 2124
c_shape, c_stride = _get_bf16_3d_shape_stride(c_final)

assert real_a_stride[2] == 1 and real_b_stride[1] == 1, "a and b must be k-major"

Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, locate and examine the specific lines in question
wc -l flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 101


🏁 Script executed:

# Read the section around lines 2121-2124 to see the assertion
sed -n '2110,2135p' flashinfer/gemm/gemm_base.py

Length of output: 963


🏁 Script executed:

# Look at the function containing this assertion
sed -n '2080,2130p' flashinfer/gemm/gemm_base.py

Length of output: 1583


🏁 Script executed:

# Search for other assert statements related to layout in this file
rg 'assert.*stride' flashinfer/gemm/gemm_base.py -A 1 -B 1

Length of output: 155


🏁 Script executed:

# Check if there are other ValueError patterns for layout validation
rg 'raise ValueError.*stride|raise ValueError.*layout|raise ValueError.*major' flashinfer/gemm/gemm_base.py -A 1 -B 1

Length of output: 831


🏁 Script executed:

# Search for other assertions in executor functions, particularly override-shape ones
rg 'def execute.*override' flashinfer/gemm/gemm_base.py -A 50 | rg 'assert'

Length of output: 153


🏁 Script executed:

# Get more context on the function signature and documentation
sed -n '2090,2105p' flashinfer/gemm/gemm_base.py

Length of output: 330


🏁 Script executed:

# Check for other assertions in executor functions that validate caller input
rg 'def execute_cudnn.*\(' flashinfer/gemm/gemm_base.py | head -20

Length of output: 484


🏁 Script executed:

# Search within other executor functions for assertions
rg 'def execute_' flashinfer/gemm/gemm_base.py -A 30 | grep -E '(def execute_|assert )'

Length of output: 570


🏁 Script executed:

# Check the non-override FP4 executor for comparison
rg 'def execute_cudnn_gemm_fp4_graph\(' flashinfer/gemm/gemm_base.py -A 40 | head -60

Length of output: 1173


🏁 Script executed:

# Also check MXFP8 override executor for validation patterns
rg 'def execute_cudnn_gemm_mxfp8_graph_override_shape' flashinfer/gemm/gemm_base.py -A 50 | grep -E '(assert|raise ValueError|stride)'

Length of output: 244


🏁 Script executed:

# Check if the non-override FP4 function validates strides differently
rg 'def execute_cudnn_gemm_fp4_graph\(' flashinfer/gemm/gemm_base.py -A 60 | grep -E '(stride|assert|ValueError)'

Length of output: 50


🏁 Script executed:

# Look at the full MXFP8 override executor to see if it has similar validation
sed -n '2200,2350p' flashinfer/gemm/gemm_base.py | grep -E '(def execute|assert|stride|raise ValueError)'

Length of output: 847


🏁 Script executed:

# Look at the full MXFP8 override executor function to see complete validation pattern
rg 'def execute_cudnn_gemm_mxfp8_graph_override_shape' flashinfer/gemm/gemm_base.py -A 80 | head -100

Length of output: 2240


Use ValueError instead of assert for FP4 layout validation.

The assertion at line 2123 validates caller-provided tensor strides. Under python -O, it will be stripped, allowing invalid layouts to pass to cuDNN. Input validation must use exceptions, not assertions. This is consistent with the codebase's pattern of using ValueError for constraint validation.

🛠️ Proposed fix
-    assert real_a_stride[2] == 1 and real_b_stride[1] == 1, "a and b must be k-major"
+    if real_a_stride[2] != 1 or real_b_stride[1] != 1:
+        raise ValueError("a and b must be k-major")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2121 - 2124, Replace the runtime
assertion with explicit input validation that raises a ValueError: instead of
using "assert real_a_stride[2] == 1 and real_b_stride[1] == 1", check those
conditions and raise ValueError with a clear message (e.g., "a and b must be
k-major") so invalid caller-provided strides cannot slip through when Python
assertions are disabled; update the code around _get_bf16_3d_shape_stride and
the real_a_stride/real_b_stride checks accordingly.

Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/gemm/gemm_base.py (2)

1955-1966: ⚠️ Potential issue | 🟠 Major

Keep deprecated aliases for the renamed override-shape helpers.

These functions are part of the exported flashinfer.gemm surface, so renaming them without keeping the previous symbols is a hard import break for downstream users. Please keep the old FP4/MXFP8 names as aliases for at least one release cycle, or explicitly treat this PR as a breaking API change.

Also applies to: 2090-2100, 2209-2219, 2321-2330

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 1955 - 1966, The renamed
override-shape helper functions must keep their old FP4/MXFP8 symbols as
deprecated aliases: add thin alias definitions that point the old names to the
new implementations (e.g., define the previous FP4/MXFP8 function names to
return/call build_cudnn_gemm_fp4_graph_override_shape and the other two renamed
helpers referenced at the other ranges) and emit a warnings.warn(...,
DeprecationWarning) when those aliases are called so downstream code keeps
working for one release cycle while signaling the rename; ensure the alias names
exactly match the previous exported symbols and reference the new functions (use
the new function names from this file such as
build_cudnn_gemm_fp4_graph_override_shape) and include a short deprecation
message indicating the new name.
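The deprecation pattern described in this prompt can be sketched as follows. The stub body is a placeholder; only the old and new function names come from this PR's discussion:

```python
import warnings

def build_cudnn_gemm_fp4_graph_override_shape(*args, **kwargs):
    # Placeholder for the renamed implementation.
    return "fp4-graph"

def build_cudnn_fp4_gemm_graph_override_shape(*args, **kwargs):
    """Deprecated alias kept for one release cycle."""
    warnings.warn(
        "build_cudnn_fp4_gemm_graph_override_shape is deprecated; use "
        "build_cudnn_gemm_fp4_graph_override_shape instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return build_cudnn_gemm_fp4_graph_override_shape(*args, **kwargs)
```

`stacklevel=2` attributes the warning to the caller's import site rather than to the shim itself.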

2321-2365: ⚠️ Potential issue | 🟠 Major

Apply shape normalization to MXFP8 override-shape executor for consistency.

execute_cudnn_gemm_mxfp8_graph_override_shape() forwards raw tensor shapes/strides, whereas execute_cudnn_gemm_bf16_graph_override_shape() uses _get_bf16_3d_shape_stride() to normalize 2D inputs to 3D. Even though MXFP8 currently enforces 3D inputs, the inconsistency should be resolved by using the same normalization pattern:

Suggested normalization
 def execute_cudnn_gemm_mxfp8_graph_override_shape(
     graph,
     a,
     b,
@@
 ):
     """Execute MXFP8 GEMM cuDNN graph with dynamic-shape overrides."""
+    a_shape, a_stride = _get_bf16_3d_shape_stride(a)
+    b_shape, b_stride = _get_bf16_3d_shape_stride(b)
+    batch = a_shape[0]
+    a_descale_shape, a_descale_stride = _expand_block_scale_tensor_shape(
+        a_descale, batch
+    )
+    b_descale_shape, b_descale_stride = _expand_block_scale_tensor_shape(
+        b_descale, batch
+    )
+    c_shape, c_stride = _get_bf16_3d_shape_stride(c_final)
+
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
         UIDs.BLOCK_DESCALE_A_UID.value: a_descale,
@@
     override_shapes = [
-        list(a.shape),
-        list(b.shape),
-        list(a_descale.shape),
-        list(b_descale.shape),
-        list(c_final.shape),
+        list(a_shape),
+        list(b_shape),
+        list(a_descale_shape),
+        list(b_descale_shape),
+        list(c_shape),
     ]
     override_strides = [
-        list(a.stride()),
-        list(b.stride()),
-        list(a_descale.stride()),
-        list(b_descale.stride()),
-        list(c_final.stride()),
+        list(a_stride),
+        list(b_stride),
+        list(a_descale_stride),
+        list(b_descale_stride),
+        list(c_stride),
     ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2321 - 2365, The override-shape
executor execute_cudnn_gemm_mxfp8_graph_override_shape currently forwards raw
tensor shapes/strides; update it to use the same normalization used by the BF16
path by calling _get_bf16_3d_shape_stride() for each input (a, b, a_descale,
b_descale, c_final) and use the returned normalized shape and stride values when
building override_shapes and override_strides; ensure you replace the direct
list(tensor.shape)/list(tensor.stride()) calls with the normalized shape/stride
results so the MXFP8 override-shape logic is consistent with
execute_cudnn_gemm_bf16_graph_override_shape.
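For illustration, a 2D-to-3D normalization helper of the kind `_get_bf16_3d_shape_stride` appears to provide might look like this; the real implementation may differ, and this sketch operates on plain shape/stride lists rather than tensors:

```python
def normalize_to_3d(shape, stride):
    # Expand a 2D (m, k) shape/stride pair into a batch-1 3D view; the
    # batch stride spans the whole 2D tensor so no data is moved.
    # 3D inputs pass through unchanged.
    if len(shape) == 2:
        (m, k), (sm, sk) = shape, stride
        return [1, m, k], [m * sm, sm, sk]
    return list(shape), list(stride)
```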
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)

2113-2115: ⚠️ Potential issue | 🟡 Minor

Use ValueError instead of assert for the k-major guard.

Line 2115 is validating caller input. Under python -O, that assert disappears and invalid layouts can slip through to cuDNN.

🛠️ Safer runtime validation
-    assert real_a_stride[2] == 1 and real_b_stride[1] == 1, "a and b must be k-major"
+    if real_a_stride[2] != 1 or real_b_stride[1] != 1:
+        raise ValueError("a and b must be k-major")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 2113 - 2115, Replace the runtime
assertion that validates k-major layout with an explicit exception: instead of
using assert real_a_stride[2] == 1 and real_b_stride[1] == 1, raise a ValueError
with the same descriptive message so the check remains active under python -O;
locate the check around the call to _get_bf16_3d_shape_stride and update the
validation (referring to variables real_a_stride and real_b_stride) to raise
ValueError("a and b must be k-major") when the condition fails.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 1955-1966: The renamed override-shape helper functions must keep
their old FP4/MXFP8 symbols as deprecated aliases: add thin alias definitions
that point the old names to the new implementations (e.g., define the previous
FP4/MXFP8 function names to return/call
build_cudnn_gemm_fp4_graph_override_shape and the other two renamed helpers
referenced at the other ranges) and emit a warnings.warn(...,
DeprecationWarning) when those aliases are called so downstream code keeps
working for one release cycle while signaling the rename; ensure the alias names
exactly match the previous exported symbols and reference the new functions (use
the new function names from this file such as
build_cudnn_gemm_fp4_graph_override_shape) and include a short deprecation
message indicating the new name.
- Around line 2321-2365: The override-shape executor
execute_cudnn_gemm_mxfp8_graph_override_shape currently forwards raw tensor
shapes/strides; update it to use the same normalization used by the BF16 path by
calling _get_bf16_3d_shape_stride() for each input (a, b, a_descale, b_descale,
c_final) and use the returned normalized shape and stride values when building
override_shapes and override_strides; ensure you replace the direct
list(tensor.shape)/list(tensor.stride()) calls with the normalized shape/stride
results so the MXFP8 override-shape logic is consistent with
execute_cudnn_gemm_bf16_graph_override_shape.

---

Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2113-2115: Replace the runtime assertion that validates k-major
layout with an explicit exception: instead of using assert real_a_stride[2] == 1
and real_b_stride[1] == 1, raise a ValueError with the same descriptive message
so the check remains active under python -O; locate the check around the call to
_get_bf16_3d_shape_stride and update the validation (referring to variables
real_a_stride and real_b_stride) to raise ValueError("a and b must be k-major")
when the condition fails.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6aab590b-9860-40d0-825b-b2ff37180d78

📥 Commits

Reviewing files that changed from the base of the PR and between 368656f and da99fe3.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py

