Implement override shape support for cuDNN GEMM operations #2790
Conversation
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the flexibility of GEMM operations by integrating cuDNN's override-shape feature. It allows the system to execute pre-compiled GEMM graphs with dynamically changing M dimensions, which is crucial for performance in scenarios with variable input sizes, such as large language model inference. The changes include new functions for building and executing these dynamic-shape-enabled graphs for various data types, along with robust version checks and comprehensive tests.
📝 Walkthrough
Adds cuDNN "override-shape" GEMM support: runtime availability checks, cached graph builders/executors (BF16, FP4, MXFP8, per-tensor FP8) allowing a single compiled graph to execute with varying M via override_shapes/override_strides.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test Suite
    participant Lib as flashinfer.gemm (API)
    participant Cache as Graph Cache
    participant cuDNN as cuDNN Runtime
    Test->>Lib: call build_*_override_shape(cache_m, ...)
    Lib->>cuDNN: compile graph (reserve/cache with cache_m)
    Lib->>Cache: store compiled graph keyed by params+cache_m
    Test->>Lib: call execute_*_override_shape(graph, a, b, override_shapes, override_strides)
    Lib->>Cache: fetch compiled graph
    Lib->>cuDNN: execute graph with override_shapes/override_strides (dynamic M)
    cuDNN-->>Lib: return result tensor
    Lib-->>Test: return output (no rebuild)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces dynamic shape support for the cuDNN backend, which is a significant feature for improving performance by avoiding graph recompilations. The implementation adds several new functions for building and executing cuDNN graphs with override shapes for BF16, FP4, and FP8 GEMM operations. The changes are well-structured, but I've identified a few critical issues in the FP4 implementation and the new tests.
Specifically, there appears to be a bug in the shape calculation for FP4 graphs and an issue with workspace buffer handling that could lead to performance degradation. Additionally, the new tests for FP4 and MXFP8 dynamic shapes are missing correctness assertions, which is a crucial gap in validation. My review provides detailed feedback and suggestions to address these points.
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/__init__.py`:
- Around line 21-32: The new override-shape symbols imported from gemm_base
(is_cudnn_override_shape_available, CUDNN_MIN_VERSION_OVERRIDE_SHAPE,
build_cudnn_gemm_bf16_graph_override_shape,
execute_cudnn_gemm_bf16_graph_override_shape,
build_cudnn_fp4_gemm_graph_override_shape,
execute_cudnn_fp4_gemm_graph_override_shape,
build_cudnn_mxfp8_gemm_graph_override_shape,
execute_cudnn_mxfp8_gemm_graph_override_shape,
build_cudnn_gemm_with_per_tensor_q_graph_override_shape,
execute_cudnn_gemm_with_per_tensor_q_graph_override_shape) are not listed in the
module's __all__ export list; update the __all__ variable in
flashinfer.gemm.__init__ to include these exact symbol names so they are
exported on from flashinfer.gemm import * and by re-exports.
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2007-2021: The override path recomputes FP4 packed shapes/strides
(e.g., a_shape, a_stride, b_shape, b_stride, a_descale_shape, a_descale_stride,
b_descale_shape, b_descale_stride) instead of reusing the canonical helper,
causing batched layouts to diverge; replace the manual math with a call to the
existing helper _get_real_fp4_shape_from_packed_uint8 (and use its returned
shape/stride tuples) and keep _calculate_block_scale_dims only for block scale
dims, ensuring both this block (around a_shape/a_stride) and the similar block
at the other location (around lines mentioned) use the same helper-derived FP4
metadata so packed-FP4 layout matches the non-override path.
- Around line 1977-1993: The cached builder
build_cudnn_fp4_gemm_graph_override_shape currently includes a_descale_n_dim in
its parameter list, which varies with M and causes unnecessary cache churn;
remove a_descale_n_dim from the cached function signature (and from the
analogous cached helper at the other location) so the cache key no longer
depends on M, update callers (e.g., where
_get_cudnn_fp4_gemm_graph_override_shape passes expanded_a_descale_shape[1]) to
stop passing that component, and if the builder needs a Descales shape value
compute or derive it inside the function from the stable values you already pass
(or ignore it if unused) to preserve a single reusable plan.
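The cache-churn mechanism behind this comment can be sketched in isolation: `functools.cache` keys on the full argument tuple, so any parameter that varies with M forces a fresh compile per M, while a stable signature lets one compiled graph serve all M values. The builder name and parameters below are illustrative, not the real flashinfer signature:

```python
import functools

build_calls = []  # records actual (non-cached) builds

@functools.cache
def build_graph(batch, n, k):
    # Stand-in for the cached cuDNN graph builder: the cache key is
    # (batch, n, k) only, so the graph is reused across varying runtime M.
    build_calls.append((batch, n, k))
    return object()  # stands in for a compiled graph

g1 = build_graph(1, 128, 256)  # compiles once
g2 = build_graph(1, 128, 256)  # cache hit, even though runtime M may differ
assert g1 is g2 and len(build_calls) == 1
```

Had an M-derived value like `a_descale_n_dim` been part of the signature, each distinct M would have produced a distinct cache key and a new compile, which is exactly what the review asks to avoid.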
- Around line 1727-1746: The current checks only inspect cudnn.backend_version()
but must also gate on the cuDNN-frontend Python package (frontend) supporting
override-shape; update both _check_cudnn_override_shape_availability and
is_cudnn_override_shape_available to also verify the frontend version or feature
presence: e.g., check a frontend version string (cudnn.__version__ or
cudnn.frontend.__version__) >= the minimum frontend release that added
override-shape, and/or perform feature detection with hasattr / inspection
(confirm presence of the called parameter or the newer execute(...) API with
override_uids/override_shapes/override_strides using inspect.signature or a
small try/except TypeError probe). If the frontend is too old or the feature is
missing, return False in is_cudnn_override_shape_available and raise a clear
RuntimeError in _check_cudnn_override_shape_availability; reference the
functions _check_cudnn_override_shape_availability and
is_cudnn_override_shape_available when making the changes.
In `@tests/gemm/test_cudnn_override_shape.py`:
- Around line 194-221: The test currently only runs
execute_cudnn_fp4_gemm_graph_override_shape in the dynamic_ms loop and
synchronizes, so failures like zeros/NaNs or wrong numerics won't be caught;
update the loop that calls execute_cudnn_fp4_gemm_graph_override_shape (and the
similar loop at 292-316) to validate outputs: compute a reference output (e.g.,
by calling the static-shape path such as execute_cudnn_fp4_gemm_graph or another
known-good implementation with the same inputs a_packed, b_packed, a_descale,
b_descale, workspace, tactic) and then assert torch.isfinite(out).all() and that
out is close to the reference using torch.allclose or
torch.testing.assert_allclose with a small atol/rtol appropriate for FP4/FP8; if
you prefer a cheaper check add at minimum finiteness plus a basic relative-error
threshold, and use the exact symbols
execute_cudnn_fp4_gemm_graph_override_shape, execute_cudnn_fp4_gemm_graph,
a_packed, b_packed, a_descale, b_descale, out, workspace, tactic so the
assertions are added next to the existing invocation.
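The minimal finiteness-plus-closeness check can be sketched in plain Python; the real test would apply `torch.isfinite(out).all()` and `torch.testing.assert_close` to tensors, and the loose tolerances here are illustrative placeholders for FP4/FP8-appropriate values:

```python
import math

def check_output(out, ref, rtol=1e-1, atol=1e-1):
    # Minimal stand-in for the suggested assertions: every value is finite
    # and close to a reference computed by a known-good path.
    assert all(math.isfinite(x) for x in out), "non-finite output"
    for o, r in zip(out, ref):
        assert math.isclose(o, r, rel_tol=rtol, abs_tol=atol), f"{o} vs {r}"

# Passes: small relative error within FP4/FP8-style tolerance.
check_output([1.0, 2.05], [1.0, 2.0])
```

A silent all-zeros or all-NaN kernel output, which the current launch-and-synchronize loop would miss, fails this check immediately.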
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2d681695-bc90-4fa7-9aa3-d593f21b2991
📒 Files selected for processing (3)
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_cudnn_override_shape.py
Actionable comments posted: 1
♻️ Duplicate comments (3)
flashinfer/gemm/gemm_base.py (3)
2146-2158: ⚠️ Potential issue | 🟠 Major
FP4 override shape/stride should reuse the canonical helper to avoid layout drift.
Manual `* 2` shape/stride math here can diverge from `_get_real_fp4_shape_from_packed_uint8()` behavior (especially across layout/batch cases).
Suggested fix
```diff
-    override_shapes = [
-        [a.shape[0], a.shape[1], a.shape[2] * 2],
-        [b.shape[0], b.shape[1] * 2, b.shape[2]],
+    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+    override_shapes = [
+        list(real_a_shape),
+        list(real_b_shape),
         a_descale.shape,
         b_descale.shape,
         c_final.shape,
     ]
     override_strides = [
-        [a.stride()[0], a.stride()[1] * 2, a.stride()[2]],
-        [b.stride()[0], b.stride()[1], b.stride()[2] * 2],
+        list(real_a_stride),
+        list(real_b_stride),
         a_descale.stride(),
         b_descale.stride(),
         c_final.stride(),
     ]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2146 - 2158, The override_shapes and override_strides calculations currently hard-code "* 2" adjustments which can drift from the canonical logic; replace the manual math with calls to the canonical helper _get_real_fp4_shape_from_packed_uint8() (and its stride-equivalent or by deriving strides from that helper) for tensors a and b so the computed shapes/strides for a, b (and keep a_descale.shape, b_descale.shape, c_final.shape and their strides) match the canonical FP4-unpacking behavior and avoid layout/batch mismatches.
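The packed-uint8 FP4 layout math that the canonical helper encapsulates can be sketched as follows. This is a simplified stand-in for `_get_real_fp4_shape_from_packed_uint8`, under the assumption that each uint8 byte packs two FP4 values along one dimension; it also demonstrates the review's point that strides of the *other* dimensions, including the batch stride, must double:

```python
def real_fp4_shape_from_packed_uint8(shape, stride, packed_dim=-1):
    # Logical size of the packed dim doubles (two FP4 values per byte);
    # strides of every other dim double, because they are now measured in
    # FP4 elements rather than bytes. The packed dim keeps stride 1 scale.
    ndim = len(shape)
    packed_dim %= ndim
    real_shape = [s * 2 if i == packed_dim else s for i, s in enumerate(shape)]
    real_stride = [st if i == packed_dim else st * 2 for i, st in enumerate(stride)]
    return real_shape, real_stride

# A contiguous packed (batch=2, m=4, k_packed=8) uint8 tensor:
shape, stride = (2, 4, 8), (32, 8, 1)
print(real_fp4_shape_from_packed_uint8(shape, stride))
# → ([2, 4, 16], [64, 16, 1]) — note the batch stride doubles too
```

For the B operand in the review's layout the packed dimension is the middle one, so `packed_dim=1` would be passed instead; hand-written math that only doubles the inner stride misses the batch-stride doubling, which is exactly the `batch > 1` hazard flagged above.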
1977-1993: ⚠️ Potential issue | 🟠 Major
`a_descale_n_dim` still leaks M into the FP4 override graph cache key.
`build_cudnn_fp4_gemm_graph_override_shape()` is cached, and `a_descale_n_dim` varies with M (lines 2203-2204). That defeats single-plan reuse across dynamic M.
Suggested fix
```diff
 @functools.cache
 def build_cudnn_fp4_gemm_graph_override_shape(
     batch,
     n,
     k,
-    a_descale_n_dim,
     a_descale_k_dim,
     b_descale_k_dim,
     b_descale_n_dim,
@@
-    # a_descale N-dimension (dim[1]) depends on M, so we pass it separately
-    a_descale_n_dim = expanded_a_descale_shape[1]
-
     return build_cudnn_fp4_gemm_graph_override_shape(
         batch=batch,
         n=n,
         k=k,
-        a_descale_n_dim=a_descale_n_dim,
         a_descale_k_dim=a_descale_k_dim,
         b_descale_k_dim=b_descale_k_dim,
         b_descale_n_dim=b_descale_n_dim,
```
Also applies to: 2203-2211
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 1977 - 1993, The cached function build_cudnn_fp4_gemm_graph_override_shape is leaking M into its cache key via the a_descale_n_dim parameter; remove M-varying data from the cached signature by eliminating a_descale_n_dim (or replacing it with a normalized/boolean indicator) and instead compute any M-dependent descaling inside the function body or upstream per-call (so the cache key remains stable). Update callers that pass a_descale_n_dim to instead pass the fixed/normalized indicator (or stop passing it) and adjust logic in build_cudnn_fp4_gemm_graph_override_shape (and the similar code around the 2203-2211 region) to derive the actual per-M descale values at runtime rather than as part of the cached parameters.
1727-1747: ⚠️ Potential issue | 🟠 Major
Override-shape availability check is still backend-only and can misreport support.
`_check_cudnn_override_shape_availability()` / `is_cudnn_override_shape_available()` only gate on `cudnn.backend_version()`. That can return `True` even when the installed Python frontend lacks `is_override_shape_enabled` / the override execute kwargs, causing runtime failures later. Also, line 1746 catches a blind `Exception`, masking real causes.
Open question: for the NVIDIA cuDNN Python frontend, which version first supports `pygraph(..., is_override_shape_enabled=True)` and `execute_plan_at_index(..., override_uids, override_shapes, override_strides)`? Is `backend_version()` alone sufficient to guarantee those APIs exist?
Suggested fix
```diff
+import inspect
+
+def _has_cudnn_override_shape_frontend() -> bool:
+    try:
+        pygraph_sig = inspect.signature(cudnn.pygraph)
+        if "is_override_shape_enabled" not in pygraph_sig.parameters:
+            return False
+    except (AttributeError, TypeError, ValueError):
+        return False
+    return True
+
 def _check_cudnn_override_shape_availability():
     _check_cudnn_availability()
+    if not _has_cudnn_override_shape_frontend():
+        raise RuntimeError("cuDNN frontend override-shape API is unavailable.")
     backend_version = cudnn.backend_version()
     if backend_version < CUDNN_MIN_VERSION_OVERRIDE_SHAPE:
         raise RuntimeError(...)

 def is_cudnn_override_shape_available() -> bool:
     if not CUDNN_AVAILABLE:
         return False
     try:
-        return cudnn.backend_version() >= CUDNN_MIN_VERSION_OVERRIDE_SHAPE
-    except Exception:
+        return _has_cudnn_override_shape_frontend() and (
+            cudnn.backend_version() >= CUDNN_MIN_VERSION_OVERRIDE_SHAPE
+        )
+    except (AttributeError, TypeError, ValueError):
         return False
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 1727 - 1747, The current checks in _check_cudnn_override_shape_availability and is_cudnn_override_shape_available rely only on cudnn.backend_version() which can misreport support if the Python frontend lacks the new APIs; update both functions to also verify the frontend provides the required symbols (e.g., hasattr(cudnn, "is_override_shape_enabled") and hasattr(cudnn, "execute_plan_at_index")) and, for execute_plan_at_index, optionally inspect its signature to ensure it accepts override_uids/override_shapes/override_strides; if the frontend checks fail, raise a clear RuntimeError in _check_cudnn_override_shape_availability and return False in is_cudnn_override_shape_available. Replace the broad except Exception in is_cudnn_override_shape_available with targeted exception handling (AttributeError or TypeError) so real errors are not masked, and still respect CUDNN_AVAILABLE and the CUDNN_MIN_VERSION_OVERRIDE_SHAPE version check.
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
2008: Use `_`-prefixed unpacking for intentionally unused values.
A few unpacked vars are unused (`block_scale_dim_k`, `real_a_stride`, `real_b_stride`, `expanded_a_descale_stride`, `expanded_b_descale_stride`). Prefixing with `_` keeps intent clear and quiets lint noise.
Also applies to: 2249-2255
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` at line 2008, The unpacking in gemm_base.py assigns several variables that are intentionally unused (e.g., block_scale_dim_k, real_a_stride, real_b_stride, expanded_a_descale_stride, expanded_b_descale_stride); update those unpack targets to use a leading underscore (for example _block_scale_dim_k, _real_a_stride, etc.) so the intent is clear and linters stop flagging them. Locate the unpack expressions such as the call to _calculate_block_scale_dims (where block_scale_dim_m, _, block_scale_dim_k = ...) and the unpackings around lines 2249-2255, and rename each intentionally unused variable to an _-prefixed name while leaving used locals unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2113-2171: The override-shape executors
execute_cudnn_fp4_gemm_graph_override_shape and
execute_cudnn_mxfp8_gemm_graph_override_shape call graph.execute_plan_at_index
without verifying the provided workspace_buffer is large enough; add a guard
that fetches required_size = graph.get_workspace_size(plan_index or default) and
compares required_size to workspace_buffer.numel(), and if workspace_buffer is
too small either raise a clear error (including required_size and provided size)
or reallocate/resize the workspace_buffer before calling
graph.execute_plan_at_index; locate the check around the call site to
graph.execute_plan_at_index in both functions and perform this size validation
using the same handle/stream logic already present.
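The suggested guard can be sketched in isolation. Function and parameter names here are illustrative, not the flashinfer API; the real check would compare `graph.get_workspace_size(...)` against the provided buffer's byte size before calling `graph.execute_plan_at_index`:

```python
def check_workspace(required_bytes: int, provided_bytes: int) -> None:
    # Fail fast with both sizes in the message, as the review suggests,
    # instead of letting cuDNN scribble past an undersized buffer.
    if provided_bytes < required_bytes:
        raise RuntimeError(
            f"cuDNN workspace too small: requires {required_bytes} bytes, "
            f"got {provided_bytes}"
        )

check_workspace(1024, 4096)  # large enough: no error
try:
    check_workspace(4096, 1024)
except RuntimeError as err:
    assert "4096" in str(err)
```

Raising is the conservative choice; silently reallocating the workspace is the alternative the comment mentions, at the cost of a hidden allocation on the hot path.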
📒 Files selected for processing (2)
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
♻️ Duplicate comments (3)
flashinfer/gemm/gemm_base.py (2)
2204-2228: ⚠️ Potential issue | 🔴 Critical
Remove the stale descale-dimension kwargs from the cached FP4 builder call.
`build_cudnn_fp4_gemm_graph_override_shape()` no longer accepts `a_descale_n_dim`, `a_descale_k_dim`, `b_descale_k_dim`, or `b_descale_n_dim`. This call now raises `TypeError` before the override graph is built.
Suggested diff
```diff
-    expanded_a_descale_shape, _ = _expand_block_scale_tensor_shape(a_descale, batch)
-    expanded_b_descale_shape, _ = _expand_block_scale_tensor_shape(b_descale, batch)
-
-    # Scale dimension sizes that are independent of M
-    a_descale_k_dim = expanded_a_descale_shape[2]
-    b_descale_k_dim = expanded_b_descale_shape[1]
-    b_descale_n_dim = expanded_b_descale_shape[2]
-    # a_descale N-dimension (dim[1]) depends on M, so we pass it separately
-    a_descale_n_dim = expanded_a_descale_shape[1]
-
     return build_cudnn_fp4_gemm_graph_override_shape(
         batch=batch,
         n=n,
         k=k,
-        a_descale_n_dim=a_descale_n_dim,
-        a_descale_k_dim=a_descale_k_dim,
-        b_descale_k_dim=b_descale_k_dim,
-        b_descale_n_dim=b_descale_n_dim,
         ab_type=cudnn.data_type.FP4_E2M1,
         o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
         block_size=block_size,
         device=a.device,
         alpha_is_not_none=alpha is not None,
         use_nvfp4=use_nvfp4,
     )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2204 - 2228, The call to build_cudnn_fp4_gemm_graph_override_shape in the gemm builder is passing stale kwargs a_descale_n_dim, a_descale_k_dim, b_descale_k_dim, and b_descale_n_dim which the function no longer accepts; remove those four keyword arguments from the return call and only pass the remaining valid parameters (batch, n, k, ab_type, o_type, block_size, device, alpha_is_not_none, use_nvfp4) so the cached FP4 builder call no longer raises a TypeError in build_cudnn_fp4_gemm_graph_override_shape.
2154-2166: ⚠️ Potential issue | 🔴 Critical
Reuse the canonical FP4 shape/stride helper here.
The manual override metadata only doubles the inner stride. For batched packed FP4 tensors, the batch stride also needs to be doubled, otherwise `batch > 1` executions read the wrong logical layout. Reusing `_get_real_fp4_shape_from_packed_uint8()` keeps the override path aligned with the non-override path.
Suggested diff
```diff
-    override_shapes = [
-        [a.shape[0], a.shape[1], a.shape[2] * 2],
-        [b.shape[0], b.shape[1] * 2, b.shape[2]],
+    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+    override_shapes = [
+        list(real_a_shape),
+        list(real_b_shape),
         a_descale.shape,
         b_descale.shape,
         c_final.shape,
     ]
     override_strides = [
-        [a.stride()[0], a.stride()[1] * 2, a.stride()[2]],
-        [b.stride()[0], b.stride()[1], b.stride()[2] * 2],
+        list(real_a_stride),
+        list(real_b_stride),
         a_descale.stride(),
         b_descale.stride(),
         c_final.stride(),
     ]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2154 - 2166, the override path currently builds override_shapes/override_strides manually and only doubles the inner stride, which breaks batched packed FP4 tensors; replace the manual constructions for a and b with calls to the canonical helper _get_real_fp4_shape_from_packed_uint8(a) and _get_real_fp4_shape_from_packed_uint8(b) (or its stride-aware variant) to produce the correct shape and stride tuples, then set override_shapes to [real_shape_a, real_shape_b, a_descale.shape, b_descale.shape, c_final.shape] and override_strides to [real_stride_a, real_stride_b, a_descale.stride(), b_descale.stride(), c_final.stride()] so batch and inner strides are adjusted consistently with the non-override path.

tests/gemm/test_cudnn_override_shape.py (1)
186-213: ⚠️ Potential issue | 🟠 Major
Add output checks to the FP4/MXFP8 override-shape loops.
These cases only launch the kernels and synchronize, so numerically wrong outputs still pass. Please compare `out` against the corresponding static-shape path, or at minimum assert something meaningful about the result inside each loop.
Also applies to: 284-308
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_cudnn_override_shape.py` around lines 186 - 213, The FP4/MXFP8 override-shape loops currently only launch kernels and synchronize without validating results; update the loop that iterates dynamic_ms (and the similar loop later) to compute a reference result using the static-shape path (or a known-correct function) and compare it to out after execute_cudnn_fp4_gemm_graph_override_shape, e.g. call the same inputs through the static graph/function that produces the expected tensor and assert torch.allclose(out, expected, atol=..., rtol=...) or at minimum assert non-NaN/non-zero statistics; reference symbols to update: dynamic_ms loop, execute_cudnn_fp4_gemm_graph_override_shape, out, a_packed, b_packed, a_descale, b_descale, workspace, and the static-shape executor used elsewhere in the test so the override-shape branch actually verifies numerical correctness.
📒 Files selected for processing (3)
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_cudnn_override_shape.py
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/gemm/test_cudnn_override_shape.py`:
- Around line 294-296: The tuple unpacking binds an unused variable
block_scale_dim_m_cache from the _calculate_block_scale_dims(...) call; to fix,
drop or rename that binding to a throwaway name (e.g., replace
"block_scale_dim_m_cache, block_scale_dim_n, block_scale_dim_k =
(_calculate_block_scale_dims(cache_m, n, k, block_size))" with "_,
block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims(cache_m, n,
k, block_size)") so Ruff RUF059 is resolved while preserving the call and the
other two bindings.
- Around line 317-320: Test sampling only used non-negative MXFP8 patterns;
change the uint8 sampling so it covers the full 0..255 space (including values
with the sign bit set) while still avoiding the NaN bit-patterns 0x7F and 0xFF:
generate b with torch.randint(0, 256, (1, n, k), dtype=torch.uint8,
device=device).transpose(1,2) and then replace any occurrences of 0x7F or 0xFF
with other finite encodings (e.g., random picks from 0..254 excluding 0x7F), so
the tensor b exercises negative E4M3 encodings; apply the same change to the
other occurrence at lines 332-333.
📒 Files selected for processing (1)
tests/gemm/test_cudnn_override_shape.py
/bot run

[SUCCESS] Pipeline #46370787: 14/20 passed
♻️ Duplicate comments (2)
flashinfer/gemm/gemm_base.py (2)
2205-2228: ⚠️ Potential issue | 🔴 Critical
Drop the removed FP4 cache-key kwargs from this helper.
`build_cudnn_fp4_gemm_graph_override_shape()` no longer accepts the descale-dimension kwargs, but this helper still passes them. The first call into `_get_cudnn_fp4_gemm_graph_override_shape()` will raise `TypeError` instead of returning a graph.
Suggested fix
```diff
-    expanded_a_descale_shape, _ = _expand_block_scale_tensor_shape(a_descale, batch)
-    expanded_b_descale_shape, _ = _expand_block_scale_tensor_shape(b_descale, batch)
-
-    # Scale dimension sizes that are independent of M
-    a_descale_k_dim = expanded_a_descale_shape[2]
-    b_descale_k_dim = expanded_b_descale_shape[1]
-    b_descale_n_dim = expanded_b_descale_shape[2]
-    # a_descale N-dimension (dim[1]) depends on M, so we pass it separately
-    a_descale_n_dim = expanded_a_descale_shape[1]
-
     return build_cudnn_fp4_gemm_graph_override_shape(
         batch=batch,
         n=n,
         k=k,
-        a_descale_n_dim=a_descale_n_dim,
-        a_descale_k_dim=a_descale_k_dim,
-        b_descale_k_dim=b_descale_k_dim,
-        b_descale_n_dim=b_descale_n_dim,
         ab_type=cudnn.data_type.FP4_E2M1,
         o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
         block_size=block_size,
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2205 - 2228, The helper currently computes expanded descale dims via _expand_block_scale_tensor_shape and then passes a_descale_n_dim, a_descale_k_dim, b_descale_k_dim, b_descale_n_dim into build_cudnn_fp4_gemm_graph_override_shape, but that function no longer accepts those kwargs; remove those four descale-dimension keyword arguments from the call (keep all other args like batch, n, k, ab_type, o_type, block_size, device, alpha_is_not_none, use_nvfp4, etc.) so the call matches the new build_cudnn_fp4_gemm_graph_override_shape/_get_cudnn_fp4_gemm_graph_override_shape signature and no TypeError is raised. Ensure any computed descale dim variables (a_descale_n_dim, a_descale_k_dim, b_descale_k_dim, b_descale_n_dim) are not passed and can be removed if unused.
2155-2167: ⚠️ Potential issue | 🟠 Major
Reuse the canonical FP4 shape/stride helper here.
This override metadata is still hand-derived, and it only scales the packed-dimension stride. For `batch > 1`, the logical batch stride also doubles for packed FP4, so cuDNN will read the wrong batch slice. The current tests don't catch it because they only run with `batch == 1`.

Suggested fix:
```diff
+    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+
     override_shapes = [
-        [a.shape[0], a.shape[1], a.shape[2] * 2],
-        [b.shape[0], b.shape[1] * 2, b.shape[2]],
-        a_descale.shape,
-        b_descale.shape,
-        c_final.shape,
+        list(real_a_shape),
+        list(real_b_shape),
+        list(a_descale.shape),
+        list(b_descale.shape),
+        list(c_final.shape),
     ]
     override_strides = [
-        [a.stride()[0], a.stride()[1] * 2, a.stride()[2]],
-        [b.stride()[0], b.stride()[1], b.stride()[2] * 2],
-        a_descale.stride(),
-        b_descale.stride(),
-        c_final.stride(),
+        list(real_a_stride),
+        list(real_b_stride),
+        list(a_descale.stride()),
+        list(b_descale.stride()),
+        list(c_final.stride()),
     ]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2155 - 2167, The override_shapes/override_strides block is hand-derived and only scales the packed dimension stride, which breaks when batch > 1; replace the manual scaling for the FP4-packed inputs with the canonical FP4 shape/stride helper (the module's FP4 helper used elsewhere) rather than hand-manipulating a.shape/a.stride and b.shape/b.stride. Call the helper for a and b to produce their overridden shapes and strides (e.g., use get_canonical_fp4_shape_stride(a.shape, a.stride) and same for b) and leave a_descale.shape/stride, b_descale.shape/stride and c_final.shape/stride as-is; update override_shapes to use the helper-returned shapes and override_strides to use the helper-returned strides so the batch stride is correctly doubled for packed FP4.
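The stride arithmetic behind this finding: a packed-FP4 uint8 tensor holds two 4-bit values per byte, so the logical (nibble-granularity) view doubles the packed dimension's extent and doubles every stride *except* the packed dimension's own. A hedged pure-Python sketch of such a helper (the actual `_get_real_fp4_shape_from_packed_uint8` in FlashInfer may differ in name and details):

```python
def real_fp4_shape_stride(packed_shape, packed_stride, packed_dim=-1):
    """Logical FP4 shape/stride (in nibble units) for a uint8 tensor packed
    along `packed_dim`. Hypothetical sketch, not the FlashInfer helper."""
    ndim = len(packed_shape)
    packed_dim %= ndim
    # The packed dimension holds two FP4 values per byte, so its extent doubles.
    shape = [d * 2 if i == packed_dim else d for i, d in enumerate(packed_shape)]
    # Every *other* stride doubles when measured in nibbles instead of bytes;
    # the packed dimension's own stride stays (adjacent nibbles are 1 apart).
    stride = [s if i == packed_dim else s * 2 for i, s in enumerate(packed_stride)]
    return shape, stride
```

For a contiguous `[2, 4, 8]` uint8 tensor with strides `[32, 8, 1]`, this yields logical shape `[2, 4, 16]` and strides `[64, 16, 1]` — note the batch stride doubles too, which is exactly what the flagged hand-derived code missed.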
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2205-2228: The helper currently computes expanded descale dims via
_expand_block_scale_tensor_shape and then passes a_descale_n_dim,
a_descale_k_dim, b_descale_k_dim, b_descale_n_dim into
build_cudnn_fp4_gemm_graph_override_shape, but that function no longer accepts
those kwargs; remove those four descale-dimension keyword arguments from the
call (keep all other args like batch, n, k, ab_type, o_type, block_size, device,
alpha_is_not_none, use_nvfp4, etc.) so the call matches the new
build_cudnn_fp4_gemm_graph_override_shape/_get_cudnn_fp4_gemm_graph_override_shape
signature and no TypeError is raised. Ensure any computed descale dim variables
(a_descale_n_dim, a_descale_k_dim, b_descale_k_dim, b_descale_n_dim) are not
passed and can be removed if unused.
- Around line 2155-2167: The override_shapes/override_strides block is
hand-derived and only scales the packed dimension stride, which breaks when
batch > 1; replace the manual scaling for the FP4-packed inputs with the
canonical FP4 shape/stride helper (the module's FP4 helper used elsewhere)
rather than hand-manipulating a.shape/a.stride and b.shape/b.stride. Call the
helper for a and b to produce their overridden shapes and strides (e.g., use
get_canonical_fp4_shape_stride(a.shape, a.stride) and same for b) and leave
a_descale.shape/stride, b_descale.shape/stride and c_final.shape/stride as-is;
update override_shapes to use the helper-returned shapes and override_strides to
use the helper-returned strides so the batch stride is correctly doubled for
packed FP4.
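The batch-stride bug restated above can be demonstrated numerically. A small pure-Python illustration (hypothetical dimensions, contiguous packing assumed) of how an un-doubled batch stride lands inside the previous batch slice:

```python
# Two FP4 values per uint8 byte; addresses below are in nibble units.
batch, m, k = 2, 2, 8                 # logical FP4 dims; packed uint8 K is k // 2
nibbles = list(range(batch * m * k))  # one distinct value per logical element

packed_stride = (m * k // 2, k // 2, 1)  # uint8 strides of the packed tensor

# Correct nibble strides: double every stride except the packed (last) dim's.
good = (packed_stride[0] * 2, packed_stride[1] * 2, packed_stride[2])
# Buggy strides as in the flagged code: the batch stride is left un-doubled.
bad = (packed_stride[0], packed_stride[1] * 2, packed_stride[2])

def first_element_of_batch1(strides):
    # Linear address of logical element [1, 0, 0] under the given strides.
    return nibbles[strides[0]]

print(first_element_of_batch1(good))  # 16: the true start of batch 1
print(first_element_of_batch1(bad))   # 8: lands in the middle of batch 0
```

With `batch == 1` the batch stride is never used, which is why the existing tests pass.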
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1eade161-0775-4982-aaad-5442e4ad5f0d
📒 Files selected for processing (2)
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_cudnn_override_shape.py
/bot run
[SUCCESS] Pipeline #46472647: 14/20 passed
@yanqinz2 It seems the PR introduces a bunch of new APIs for cuDNN dynamic-shape support. However, we already have functions like `mm_fp4` for FP4 GEMM with cuDNN as a backend option. Is it possible to fit the dynamic-shape support into those functions instead of creating new APIs?
📌 Description
Add override shape support for the cuDNN backend, with test examples.
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Summary by CodeRabbit
New Features
Tests