Conversation
📝 Walkthrough
Refactors cuDNN GEMM override-shape APIs (renaming the FP4/MXFP8 exports) and adds an optional policy parameter to the graph builders.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Code Review
This pull request refactors the cuDNN GEMM implementation to provide consistent dynamic-shape (override-shape) support across FP4, MXFP8, and BF16 data types. Key changes include renaming functions for better naming consistency, adding a policy parameter to graph builders, and integrating override-shape logic into the TunableRunner classes for BF16 and FP4. The review feedback focuses on improving the efficiency of workspace buffer handling by suggesting that the code raise a ValueError for undersized buffers instead of performing local re-allocations, and recommends refactoring duplicated logic within the runner classes to improve maintainability.
```python
if workspace_buffer.numel() < graph.get_workspace_size():
    workspace_buffer = torch.empty(
        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
    )
```
Re-assigning workspace_buffer here only changes the local variable. The caller's buffer remains unchanged and potentially undersized, leading to re-allocation on every call if the initial buffer is insufficient. This is inefficient.
A better approach would be to raise a ValueError if the buffer is too small, forcing the caller to provide a sufficiently sized buffer.
```diff
 if workspace_buffer.numel() < graph.get_workspace_size():
-    workspace_buffer = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+    raise ValueError(
+        f"workspace_buffer is too small. Need at least {graph.get_workspace_size()} elements, but got {workspace_buffer.numel()}."
+    )
```
```python
if workspace_buffer.numel() < graph.get_workspace_size():
    workspace_buffer = torch.empty(
        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
    )
```
Similar to other execute_* functions in this file, re-assigning workspace_buffer here is inefficient as the change is local. If the caller passes an undersized buffer, it will be re-allocated on every call. Consider raising a ValueError instead to enforce that the caller provides a buffer of adequate size.
```diff
 if workspace_buffer.numel() < graph.get_workspace_size():
-    workspace_buffer = torch.empty(
-        graph.get_workspace_size(), device=a.device, dtype=torch.uint8
-    )
+    raise ValueError(
+        f"workspace_buffer is too small. Need at least {graph.get_workspace_size()} elements, but got {workspace_buffer.numel()}."
+    )
```
```python
if is_cudnn_override_shape_available():
    graph = self._get_override_graph(a, b, out)
```
The condition is_cudnn_override_shape_available() and the call to self._get_override_graph(a, b, out) are duplicated in get_valid_tactics and forward. This could be refactored to improve maintainability and avoid redundant graph lookups/builds, even with caching. Consider creating a helper method that retrieves the correct graph based on availability, which can be called by both get_valid_tactics and forward.
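A rough sketch of the suggested refactor. The helper name `_select_graph` and the runner internals below are hypothetical stand-ins for the actual TunableRunner code, shown only to illustrate keeping the availability condition in one place:

```python
class BF16GemmRunner:
    """Hypothetical runner with a single graph-selection point."""

    def __init__(self, override_available):
        self._override_available = override_available

    def _get_override_graph(self, a, b, out):
        return ("override-graph", a, b, out)  # stand-in for a cached cuDNN graph

    def _get_static_graph(self, a, b, out):
        return ("static-graph", a, b, out)  # stand-in for a cached cuDNN graph

    def _select_graph(self, a, b, out):
        # The availability check lives in exactly one place, so
        # get_valid_tactics and forward cannot drift apart.
        if self._override_available:
            return self._get_override_graph(a, b, out)
        return self._get_static_graph(a, b, out)

    def get_valid_tactics(self, a, b, out):
        return [self._select_graph(a, b, out)]

    def forward(self, a, b, out):
        return self._select_graph(a, b, out)

runner = BF16GemmRunner(override_available=True)
print(runner.forward("a", "b", "out")[0])  # override-graph
```

With caching inside the `_get_*` methods, calling the helper from both paths costs only a dictionary lookup, while the branching logic stays in one spot.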
flashinfer/gemm/gemm_base.py (Outdated)
```python
if is_cudnn_override_shape_available() and alpha is None:
    graph = self._get_override_graph(
        a, b, alpha, out_dtype, block_size, use_nvfp4
    )
```
The condition is_cudnn_override_shape_available() and alpha is None is duplicated in get_valid_tactics and forward. This could lead to maintenance issues if the condition changes. Consider refactoring this logic into a helper method to determine which execution path to take. This would also avoid calling self._get_override_graph twice (once in get_valid_tactics and once in forward), improving efficiency even with caching.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/gemm/test_cudnn_override_shape.py (1)
**17-27:** Import the public helpers through `flashinfer.gemm` in this test.

Right now this bypasses `flashinfer.gemm`, so the test still passes even if the package re-export layer regresses. Since this PR changes that surface, the test should exercise it.

🧪 Minimal import split

```diff
-from flashinfer.gemm.gemm_base import (
-    CUDNN_AVAILABLE,
-    build_cudnn_gemm_bf16_graph_override_shape,
-    execute_cudnn_gemm_bf16_graph_override_shape,
-    build_cudnn_gemm_fp4_graph_override_shape,
-    execute_cudnn_gemm_fp4_graph_override_shape,
-    build_cudnn_gemm_mxfp8_graph_override_shape,
-    execute_cudnn_gemm_mxfp8_graph_override_shape,
-    is_cudnn_override_shape_available,
-    _calculate_block_scale_dims,
-)
+from flashinfer.gemm import (
+    build_cudnn_gemm_bf16_graph_override_shape,
+    execute_cudnn_gemm_bf16_graph_override_shape,
+    build_cudnn_gemm_fp4_graph_override_shape,
+    execute_cudnn_gemm_fp4_graph_override_shape,
+    build_cudnn_gemm_mxfp8_graph_override_shape,
+    execute_cudnn_gemm_mxfp8_graph_override_shape,
+    is_cudnn_override_shape_available,
+)
+from flashinfer.gemm.gemm_base import CUDNN_AVAILABLE, _calculate_block_scale_dims
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_cudnn_override_shape.py` around lines 17 - 27, The test directly imports helpers from flashinfer.gemm.gemm_base instead of exercising the package re-export layer; update the import to import the public helpers from flashinfer.gemm (e.g. import CUDNN_AVAILABLE, build_cudnn_gemm_bf16_graph_override_shape, execute_cudnn_gemm_bf16_graph_override_shape, build_cudnn_gemm_fp4_graph_override_shape, execute_cudnn_gemm_fp4_graph_override_shape, build_cudnn_gemm_mxfp8_graph_override_shape, execute_cudnn_gemm_mxfp8_graph_override_shape, is_cudnn_override_shape_available, _calculate_block_scale_dims) via from flashinfer.gemm import <symbols> so the test fails if the package re-export surface regresses.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/__init__.py`:
- Around line 25-28: Restore backward-compatible aliases for the renamed
override-shape exports by reintroducing the old names as simple assignments to
the new symbols: for example, set the previous FP4/MXFP8 export names equal to
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape, and
execute_cudnn_gemm_mxfp8_graph_override_shape in flashinfer.gemm.__init__.py so
old imports continue to work; also add the same alias assignments in the
flashinfer.gemm.gemm_base module if that path is a supported public import so
both import surfaces mirror each other.
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2121-2124: Replace the runtime assertion with explicit input
validation that raises a ValueError: instead of using "assert real_a_stride[2]
== 1 and real_b_stride[1] == 1", check those conditions and raise ValueError
with a clear message (e.g., "a and b must be k-major") so invalid
caller-provided strides cannot slip through when Python assertions are disabled;
update the code around _get_bf16_3d_shape_stride and the
real_a_stride/real_b_stride checks accordingly.
---
Nitpick comments:
In `@tests/gemm/test_cudnn_override_shape.py`:
- Around line 17-27: The test directly imports helpers from
flashinfer.gemm.gemm_base instead of exercising the package re-export layer;
update the import to import the public helpers from flashinfer.gemm (e.g. import
CUDNN_AVAILABLE, build_cudnn_gemm_bf16_graph_override_shape,
execute_cudnn_gemm_bf16_graph_override_shape,
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape,
execute_cudnn_gemm_mxfp8_graph_override_shape,
is_cudnn_override_shape_available, _calculate_block_scale_dims) via from
flashinfer.gemm import <symbols> so the test fails if the package re-export
surface regresses.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 463f5a46-a140-49a0-8802-735d5fba51af
📒 Files selected for processing (3)
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_cudnn_override_shape.py
```python
    build_cudnn_gemm_fp4_graph_override_shape as build_cudnn_gemm_fp4_graph_override_shape,
    execute_cudnn_gemm_fp4_graph_override_shape as execute_cudnn_gemm_fp4_graph_override_shape,
    build_cudnn_gemm_mxfp8_graph_override_shape as build_cudnn_gemm_mxfp8_graph_override_shape,
    execute_cudnn_gemm_mxfp8_graph_override_shape as execute_cudnn_gemm_mxfp8_graph_override_shape,
```
Preserve the old override-shape export names as aliases.
These symbols were already part of the flashinfer.gemm surface. Replacing them outright means downstream imports fail immediately on upgrade. Please keep the old FP4/MXFP8 names as shims for at least one release, or call out the breaking change explicitly. If flashinfer.gemm.gemm_base is also a supported import path, mirror the aliases there too.
↩️ Compatibility shim example

```diff
 from .gemm_base import (
     is_cudnn_override_shape_available as is_cudnn_override_shape_available,
     build_cudnn_gemm_bf16_graph_override_shape as build_cudnn_gemm_bf16_graph_override_shape,
     execute_cudnn_gemm_bf16_graph_override_shape as execute_cudnn_gemm_bf16_graph_override_shape,
     build_cudnn_gemm_fp4_graph_override_shape as build_cudnn_gemm_fp4_graph_override_shape,
     execute_cudnn_gemm_fp4_graph_override_shape as execute_cudnn_gemm_fp4_graph_override_shape,
     build_cudnn_gemm_mxfp8_graph_override_shape as build_cudnn_gemm_mxfp8_graph_override_shape,
     execute_cudnn_gemm_mxfp8_graph_override_shape as execute_cudnn_gemm_mxfp8_graph_override_shape,
     build_cudnn_gemm_with_per_tensor_q_graph_override_shape as build_cudnn_gemm_with_per_tensor_q_graph_override_shape,
     execute_cudnn_gemm_with_per_tensor_q_graph_override_shape as execute_cudnn_gemm_with_per_tensor_q_graph_override_shape,
 )
+
+# Backward-compat aliases
+build_cudnn_fp4_gemm_graph_override_shape = build_cudnn_gemm_fp4_graph_override_shape
+execute_cudnn_fp4_gemm_graph_override_shape = execute_cudnn_gemm_fp4_graph_override_shape
+build_cudnn_mxfp8_gemm_graph_override_shape = build_cudnn_gemm_mxfp8_graph_override_shape
+execute_cudnn_mxfp8_gemm_graph_override_shape = execute_cudnn_gemm_mxfp8_graph_override_shape
@@
     "build_cudnn_gemm_fp4_graph_override_shape",
     "execute_cudnn_gemm_fp4_graph_override_shape",
     "build_cudnn_gemm_mxfp8_graph_override_shape",
     "execute_cudnn_gemm_mxfp8_graph_override_shape",
+    "build_cudnn_fp4_gemm_graph_override_shape",
+    "execute_cudnn_fp4_gemm_graph_override_shape",
+    "build_cudnn_mxfp8_gemm_graph_override_shape",
+    "execute_cudnn_mxfp8_gemm_graph_override_shape",
     "build_cudnn_gemm_with_per_tensor_q_graph_override_shape",
     "execute_cudnn_gemm_with_per_tensor_q_graph_override_shape",
 ] + _cute_dsl_kernels
```

Also applies to: 82-85
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/__init__.py` around lines 25 - 28, Restore
backward-compatible aliases for the renamed override-shape exports by
reintroducing the old names as simple assignments to the new symbols: for
example, set the previous FP4/MXFP8 export names equal to
build_cudnn_gemm_fp4_graph_override_shape,
execute_cudnn_gemm_fp4_graph_override_shape,
build_cudnn_gemm_mxfp8_graph_override_shape, and
execute_cudnn_gemm_mxfp8_graph_override_shape in flashinfer.gemm.__init__.py so
old imports continue to work; also add the same alias assignments in the
flashinfer.gemm.gemm_base module if that path is a supported public import so
both import surfaces mirror each other.
```python
c_shape, c_stride = _get_bf16_3d_shape_stride(c_final)

assert real_a_stride[2] == 1 and real_b_stride[1] == 1, "a and b must be k-major"
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, locate and examine the specific lines in question
wc -l flashinfer/gemm/gemm_base.py
```

Repository: flashinfer-ai/flashinfer
Length of output: 101

🏁 Script executed:

```shell
# Read the section around lines 2121-2124 to see the assertion
sed -n '2110,2135p' flashinfer/gemm/gemm_base.py
```

Length of output: 963

🏁 Script executed:

```shell
# Look at the function containing this assertion
sed -n '2080,2130p' flashinfer/gemm/gemm_base.py
```

Length of output: 1583

🏁 Script executed:

```shell
# Search for other assert statements related to layout in this file
rg 'assert.*stride' flashinfer/gemm/gemm_base.py -A 1 -B 1
```

Length of output: 155

🏁 Script executed:

```shell
# Check if there are other ValueError patterns for layout validation
rg 'raise ValueError.*stride|raise ValueError.*layout|raise ValueError.*major' flashinfer/gemm/gemm_base.py -A 1 -B 1
```

Length of output: 831

🏁 Script executed:

```shell
# Search for other assertions in executor functions, particularly override-shape ones
rg 'def execute.*override' flashinfer/gemm/gemm_base.py -A 50 | rg 'assert'
```

Length of output: 153

🏁 Script executed:

```shell
# Get more context on the function signature and documentation
sed -n '2090,2105p' flashinfer/gemm/gemm_base.py
```

Length of output: 330

🏁 Script executed:

```shell
# Check for other assertions in executor functions that validate caller input
rg 'def execute_cudnn.*\(' flashinfer/gemm/gemm_base.py | head -20
```

Length of output: 484

🏁 Script executed:

```shell
# Search within other executor functions for assertions
rg 'def execute_' flashinfer/gemm/gemm_base.py -A 30 | grep -E '(def execute_|assert )'
```

Length of output: 570

🏁 Script executed:

```shell
# Check the non-override FP4 executor for comparison
rg 'def execute_cudnn_gemm_fp4_graph\(' flashinfer/gemm/gemm_base.py -A 40 | head -60
```

Length of output: 1173

🏁 Script executed:

```shell
# Also check MXFP8 override executor for validation patterns
rg 'def execute_cudnn_gemm_mxfp8_graph_override_shape' flashinfer/gemm/gemm_base.py -A 50 | grep -E '(assert|raise ValueError|stride)'
```

Length of output: 244

🏁 Script executed:

```shell
# Check if the non-override FP4 function validates strides differently
rg 'def execute_cudnn_gemm_fp4_graph\(' flashinfer/gemm/gemm_base.py -A 60 | grep -E '(stride|assert|ValueError)'
```

Length of output: 50

🏁 Script executed:

```shell
# Look at the full MXFP8 override executor to see if it has similar validation
sed -n '2200,2350p' flashinfer/gemm/gemm_base.py | grep -E '(def execute|assert|stride|raise ValueError)'
```

Length of output: 847

🏁 Script executed:

```shell
# Look at the full MXFP8 override executor function to see complete validation pattern
rg 'def execute_cudnn_gemm_mxfp8_graph_override_shape' flashinfer/gemm/gemm_base.py -A 80 | head -100
```

Length of output: 2240
**Use `ValueError` instead of `assert` for FP4 layout validation.**

The assertion at line 2123 validates caller-provided tensor strides. Under `python -O`, it will be stripped, allowing invalid layouts to pass to cuDNN. Input validation must use exceptions, not assertions. This is consistent with the codebase's pattern of using `ValueError` for constraint validation.

🛠️ Proposed fix

```diff
-    assert real_a_stride[2] == 1 and real_b_stride[1] == 1, "a and b must be k-major"
+    if real_a_stride[2] != 1 or real_b_stride[1] != 1:
+        raise ValueError("a and b must be k-major")
```

🤖 Prompt for AI Agents
+ raise ValueError("a and b must be k-major")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/gemm_base.py` around lines 2121 - 2124, Replace the runtime
assertion with explicit input validation that raises a ValueError: instead of
using "assert real_a_stride[2] == 1 and real_b_stride[1] == 1", check those
conditions and raise ValueError with a clear message (e.g., "a and b must be
k-major") so invalid caller-provided strides cannot slip through when Python
assertions are disabled; update the code around _get_bf16_3d_shape_stride and
the real_a_stride/real_b_stride checks accordingly.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/gemm/gemm_base.py (2)
**1955-1966:** ⚠️ Potential issue | 🟠 Major

Keep deprecated aliases for the renamed override-shape helpers.

These functions are part of the exported `flashinfer.gemm` surface, so renaming them without keeping the previous symbols is a hard import break for downstream users. Please keep the old FP4/MXFP8 names as aliases for at least one release cycle, or explicitly treat this PR as a breaking API change.

Also applies to: 2090-2100, 2209-2219, 2321-2330
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 1955 - 1966, The renamed override-shape helper functions must keep their old FP4/MXFP8 symbols as deprecated aliases: add thin alias definitions that point the old names to the new implementations (e.g., define the previous FP4/MXFP8 function names to return/call build_cudnn_gemm_fp4_graph_override_shape and the other two renamed helpers referenced at the other ranges) and emit a warnings.warn(..., DeprecationWarning) when those aliases are called so downstream code keeps working for one release cycle while signaling the rename; ensure the alias names exactly match the previous exported symbols and reference the new functions (use the new function names from this file such as build_cudnn_gemm_fp4_graph_override_shape) and include a short deprecation message indicating the new name.
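A compact sketch of the alias-with-warning pattern the prompt describes. The builder body and the old export name below are placeholders for illustration, not the real flashinfer symbols:

```python
import warnings

def build_cudnn_gemm_fp4_graph_override_shape(*args, **kwargs):
    # Placeholder body standing in for the real builder in gemm_base.py.
    return "fp4-override-graph"

def _deprecated_alias(new_fn, old_name):
    """Wrap new_fn so calls via the old name emit a DeprecationWarning."""
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_fn.__name__} instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_fn(*args, **kwargs)
    wrapper.__name__ = old_name
    return wrapper

# Hypothetical previous export name, kept working for one release cycle.
build_cudnn_fp4_gemm_graph_override_shape = _deprecated_alias(
    build_cudnn_gemm_fp4_graph_override_shape,
    "build_cudnn_fp4_gemm_graph_override_shape",
)
```

Old imports keep resolving, the call still reaches the new implementation, and callers get a one-line signal to migrate before the alias is removed.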
**2321-2365:** ⚠️ Potential issue | 🟠 Major

Apply shape normalization to the MXFP8 override-shape executor for consistency.

`execute_cudnn_gemm_mxfp8_graph_override_shape()` forwards raw tensor shapes/strides, whereas `execute_cudnn_gemm_bf16_graph_override_shape()` uses `_get_bf16_3d_shape_stride()` to normalize 2D inputs to 3D. Even though MXFP8 currently enforces 3D inputs, the inconsistency should be resolved by using the same normalization pattern:

Suggested normalization

```diff
 def execute_cudnn_gemm_mxfp8_graph_override_shape(
     graph,
     a,
     b,
     @@
 ):
     """Execute MXFP8 GEMM cuDNN graph with dynamic-shape overrides."""
+    a_shape, a_stride = _get_bf16_3d_shape_stride(a)
+    b_shape, b_stride = _get_bf16_3d_shape_stride(b)
+    batch = a_shape[0]
+    a_descale_shape, a_descale_stride = _expand_block_scale_tensor_shape(
+        a_descale, batch
+    )
+    b_descale_shape, b_descale_stride = _expand_block_scale_tensor_shape(
+        b_descale, batch
+    )
+    c_shape, c_stride = _get_bf16_3d_shape_stride(c_final)
+
     variant_pack = {
         UIDs.A_UID.value: a,
         UIDs.B_UID.value: b,
         UIDs.BLOCK_DESCALE_A_UID.value: a_descale,
     @@
     override_shapes = [
-        list(a.shape),
-        list(b.shape),
-        list(a_descale.shape),
-        list(b_descale.shape),
-        list(c_final.shape),
+        list(a_shape),
+        list(b_shape),
+        list(a_descale_shape),
+        list(b_descale_shape),
+        list(c_shape),
     ]
     override_strides = [
-        list(a.stride()),
-        list(b.stride()),
-        list(a_descale.stride()),
-        list(b_descale.stride()),
-        list(c_final.stride()),
+        list(a_stride),
+        list(b_stride),
+        list(a_descale_stride),
+        list(b_descale_stride),
+        list(c_stride),
     ]
```
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2321 - 2365, The override-shape executor execute_cudnn_gemm_mxfp8_graph_override_shape currently forwards raw tensor shapes/strides; update it to use the same normalization used by the BF16 path by calling _get_bf16_3d_shape_stride() for each input (a, b, a_descale, b_descale, c_final) and use the returned normalized shape and stride values when building override_shapes and override_strides; ensure you replace the direct list(tensor.shape)/list(tensor.stride()) calls with the normalized shape/stride results so the MXFP8 override-shape logic is consistent with execute_cudnn_gemm_bf16_graph_override_shape.
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)
**2113-2115:** ⚠️ Potential issue | 🟡 Minor

Use `ValueError` instead of `assert` for the k-major guard.

Line 2115 is validating caller input. Under `python -O`, that `assert` disappears and invalid layouts can slip through to cuDNN.

🛠️ Safer runtime validation

```diff
-    assert real_a_stride[2] == 1 and real_b_stride[1] == 1, "a and b must be k-major"
+    if real_a_stride[2] != 1 or real_b_stride[1] != 1:
+        raise ValueError("a and b must be k-major")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 2113 - 2115, Replace the runtime assertion that validates k-major layout with an explicit exception: instead of using assert real_a_stride[2] == 1 and real_b_stride[1] == 1, raise a ValueError with the same descriptive message so the check remains active under python -O; locate the check around the call to _get_bf16_3d_shape_stride and update the validation (referring to variables real_a_stride and real_b_stride) to raise ValueError("a and b must be k-major") when the condition fails.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 1955-1966: The renamed override-shape helper functions must keep
their old FP4/MXFP8 symbols as deprecated aliases: add thin alias definitions
that point the old names to the new implementations (e.g., define the previous
FP4/MXFP8 function names to return/call
build_cudnn_gemm_fp4_graph_override_shape and the other two renamed helpers
referenced at the other ranges) and emit a warnings.warn(...,
DeprecationWarning) when those aliases are called so downstream code keeps
working for one release cycle while signaling the rename; ensure the alias names
exactly match the previous exported symbols and reference the new functions (use
the new function names from this file such as
build_cudnn_gemm_fp4_graph_override_shape) and include a short deprecation
message indicating the new name.
- Around line 2321-2365: The override-shape executor
execute_cudnn_gemm_mxfp8_graph_override_shape currently forwards raw tensor
shapes/strides; update it to use the same normalization used by the BF16 path by
calling _get_bf16_3d_shape_stride() for each input (a, b, a_descale, b_descale,
c_final) and use the returned normalized shape and stride values when building
override_shapes and override_strides; ensure you replace the direct
list(tensor.shape)/list(tensor.stride()) calls with the normalized shape/stride
results so the MXFP8 override-shape logic is consistent with
execute_cudnn_gemm_bf16_graph_override_shape.
---
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 2113-2115: Replace the runtime assertion that validates k-major
layout with an explicit exception: instead of using assert real_a_stride[2] == 1
and real_b_stride[1] == 1, raise a ValueError with the same descriptive message
so the check remains active under python -O; locate the check around the call to
_get_bf16_3d_shape_stride and update the validation (referring to variables
real_a_stride and real_b_stride) to raise ValueError("a and b must be k-major")
when the condition fails.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6aab590b-9860-40d0-825b-b2ff37180d78
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py
📌 Description
Add cudnn dynamic shape support for bf16 and fp4 gemm
🔍 Related Issues
https://nvbugspro.nvidia.com/bug/5539146
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Refactor
New Features
Bug Fixes