Conversation
Implement `mm_mxfp8` / `bmm_mxfp8` support for SM12x GPUs (RTX PRO 6000, RTX 5090, RTX 5080) via CUTLASS 4.x blockscaled GEMM.

Key changes:

- New CUTLASS kernel template: `mxfp8_gemm_cutlass_template_sm120.h` using `Sm1xxBlkScaledConfig` with 3 CTA tile configs (128×128×128, 256×128×128, 128×256×128)
- New C++ launcher: `csrc/mxfp8_gemm_cutlass_sm120.cu` with full input validation (K%32==0, N%32==0, M unconstrained, swizzled-only scale format)
- SM120 kernel hardcodes the hardware-native swizzled scale layout (`Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA/SFB`); linear (2D) scale is explicitly rejected with a clear error message
- JIT module generator `gen_gemm_sm120_module_cutlass_mxfp8()` in `flashinfer/jit/gemm/core.py`
- AOT registration under `has_sm120 or has_sm121` in `flashinfer/aot.py`
- Updated `mm_mxfp8` docstring with SM12x swizzled-only note
- Added `mm_mxfp8` to `docs/api/gemm.rst`
- Test suite: `tests/gemm/test_mm_mxfp8_sm120.py` covering arbitrary M (including non-multiples of 128), N/K in {128, 256, 512, 1024}, both bf16/fp16 output, cos_sim > 0.99 accuracy threshold

AI-assisted
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
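For intuition about the swizzled-only scale format, here is a small sketch of how the scale-factor storage for one operand could be sized. The 128-row / 4-column padding constants follow the CUTLASS `Sm1xxBlkScaledConfig` convention as I understand it; they are an assumption for illustration, not code from this PR:

```python
import math

SF_VEC_SIZE = 32  # one FP8 scale factor per 32 elements along K
ROW_TILE = 128    # swizzled layout pads rows to multiples of 128 (assumed)
COL_TILE = 4      # ... and scale columns to multiples of 4 (assumed)


def swizzled_sf_count(m: int, k: int) -> int:
    """Number of scale-factor entries stored for an (m x k) MXFP8 operand."""
    if k % SF_VEC_SIZE != 0:
        raise ValueError("K must be a multiple of 32 for MXFP8")
    padded_rows = math.ceil(m / ROW_TILE) * ROW_TILE
    padded_cols = math.ceil((k // SF_VEC_SIZE) / COL_TILE) * COL_TILE
    return padded_rows * padded_cols
```

Note how M is padded rather than constrained, which matches the launcher accepting arbitrary M while rejecting K % 32 != 0.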
📝 Walkthrough

This PR adds SM120 MXFP8 GEMM support by implementing SM120-specific kernel templates, dispatchers, JIT generators, and Python integration. It includes new CUDA kernels, template instantiations, runtime dispatchers with configuration selection, documentation, and comprehensive tests for the SM120 architecture.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant PyAPI as Python API<br/>(mm_mxfp8)
    participant PyGemm as gemm_base.py<br/>(module routing)
    participant Dispatch as SM120 Dispatcher<br/>(mxfp8_gemm_cutlass_sm120.cu)
    participant Runner as CutlassMxfp8GemmRunnerSm120<br/>(template_sm120.h)
    participant Kernel as CUTLASS Kernel<br/>(SM120 arch)
    participant CUDA as CUDA Device

    PyAPI->>PyGemm: get_cutlass_mxfp8_gemm_module(sm_major=12)
    PyGemm->>Dispatch: Load SM120 MXFP8 module
    PyAPI->>Dispatch: mxfp8_gemm(mat1, mat2, scales, workspace, tactic)
    Dispatch->>Dispatch: Validate inputs (dtypes, ranks, scale layout)
    Dispatch->>Dispatch: Select CutlassGemmConfig from tactic
    Dispatch->>Runner: gemm(A, B, scales, workspace)
    Runner->>Runner: dispatchToArch (select CTA shape)
    Runner->>Kernel: Launch kernel with CTA dimensions
    Kernel->>CUDA: Execute MXFP8 block-scaled GEMM
    CUDA-->>Kernel: Return results
    Kernel-->>Runner: Workspace size / Status
    Runner-->>Dispatch: Completion
    Dispatch-->>PyAPI: Output tensor
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: 4 passed, 1 failed (warning)
Code Review
This pull request implements MXFP8 GEMM support for SM120 (Blackwell) GPUs using CUTLASS. The changes include the core CUDA implementation, JIT and AOT build-system integration, documentation updates, and new test cases. The feedback below suggests refactoring the duplicated GEMM-configuration retrieval logic into a shared helper function to improve maintainability.
```cpp
int64_t mxfp8_gemm_tactic_num_sm120() {
  auto getCutlassConfigs = []() {
    CutlassMxfp8GemmRunnerSm120<__nv_bfloat16> gemmRunner;
    return gemmRunner.getConfigs();
  };
  static int64_t totalTactics = getCutlassConfigs().size();
  return totalTactics;
}
```
The logic for retrieving CUTLASS GEMM configurations is duplicated from `getMxfp8GemmConfigSm120`. To avoid the duplication and improve maintainability, introduce a shared helper function.
For example, you could create a helper function within a `detail` namespace inside `torch_ext`:

```cpp
namespace torch_ext {
namespace detail {

inline const std::vector<CutlassGemmConfig>& GetMxfp8GemmConfigsSm120() {
  static const std::vector<CutlassGemmConfig> kGlobalConfigs = []() {
    CutlassMxfp8GemmRunnerSm120<__nv_bfloat16> gemmRunner;
    return gemmRunner.getConfigs();
  }();
  return kGlobalConfigs;
}

}  // namespace detail
}  // namespace torch_ext
```

Then, `getMxfp8GemmConfigSm120` and `mxfp8_gemm_tactic_num_sm120` can both use this helper:
```cpp
// In getMxfp8GemmConfigSm120
const auto& globalConfigs = detail::GetMxfp8GemmConfigsSm120();

// In mxfp8_gemm_tactic_num_sm120
int64_t mxfp8_gemm_tactic_num_sm120() {
  return detail::GetMxfp8GemmConfigsSm120().size();
}
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
3409-3420: ⚠️ Potential issue | 🟠 Major — Reject SM12x-incompatible CUTLASS inputs during backend selection.

The docstring now says SM12x CUTLASS only supports the 1D `layout_128x4` scale layout, but `_cutlass_gemm_mxfp8_requirement()` still returns `True` unconditionally. That means `backend="auto"` can still pick CUTLASS for 2D scales, `use_8x4_sf_layout=True`, or `N`/`K` values that the SM120 launcher rejects, and the failure happens late instead of at API validation time.

♻️ Suggested guard:

```diff
 @supported_compute_capability([100, 103, 110, 120, 121])
 def _cutlass_gemm_mxfp8_requirement(
     a: torch.Tensor,
     b: torch.Tensor,
     @@
     use_8x4_sf_layout: bool = True,
     backend: Literal["cutlass", "cute-dsl", "trtllm", "auto"] = "auto",
 ):
+    if is_sm12x_supported(a.device):
+        if a.shape[1] % 32 != 0 or b.shape[1] % 32 != 0:
+            raise ValueError(
+                "SM120/SM121 CUTLASS MXFP8 requires K and N to be multiples of 32."
+            )
+        if use_8x4_sf_layout:
+            raise ValueError(
+                "SM120/SM121 CUTLASS MXFP8 only supports SfLayout.layout_128x4."
+            )
+        if a_descale.ndim != 1 or b_descale.ndim != 1:
+            raise ValueError(
+                "SM120/SM121 CUTLASS MXFP8 only supports 1D swizzled scale tensors."
+            )
     return True
```

Also applies to: 3933-3935
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 3409 - 3420, Update _cutlass_gemm_mxfp8_requirement to proactively reject inputs that SM12x CUTLASS cannot support: detect when backend selection might pick CUTLASS (backend in {"auto","cutlass"}) and return False for SM12x GPUs if any of the SM12x-only constraints are violated — e.g., a_descale or b_descale are not 1D layout_128x4-style scales, use_8x4_sf_layout is True when SM12x doesn't support that layout, or N/K tensor dimensions/sizes fall into ranges the SM120 launcher rejects; apply the same validation logic to the analogous check around lines 3933-3935 so backend="auto" won't choose CUTLASS for incompatible 2D scales, layouts, or N/K values. Ensure you reference and use the function name _cutlass_gemm_mxfp8_requirement (and the similar guard at 3933-3935) to locate and implement these condition checks and return False when constraints are not met.
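As a standalone illustration of the guard this comment asks for, the checks can be collected into one helper. The function name and scalar arguments below are hypothetical simplifications, not the actual flashinfer signatures:

```python
def check_sm12x_mxfp8_inputs(
    k: int, n: int, a_scale_ndim: int, b_scale_ndim: int, use_8x4_sf_layout: bool
) -> None:
    """Raise early for inputs the SM120/SM121 CUTLASS MXFP8 launcher rejects."""
    if k % 32 != 0 or n % 32 != 0:
        raise ValueError(
            "SM120/SM121 CUTLASS MXFP8 requires K and N to be multiples of 32."
        )
    if use_8x4_sf_layout:
        raise ValueError(
            "SM120/SM121 CUTLASS MXFP8 only supports SfLayout.layout_128x4."
        )
    if a_scale_ndim != 1 or b_scale_ndim != 1:
        raise ValueError(
            "SM120/SM121 CUTLASS MXFP8 only supports 1D swizzled scale tensors."
        )
```

Running the same checks at API-validation time and in the C++ launcher keeps failures early and the error messages consistent.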
🧹 Nitpick comments (3)
include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h (1)
107-113: Document why the workspace probe swallows `std::runtime_error`. This catch is part of config probing, but without a short rationale it reads like accidental error suppression.
Based on learnings, swallowed `std::runtime_error` probes in `getWorkspaceSizeImpl` are acceptable here, but they should carry a brief rationale comment.

Suggested change:
```diff
-    } catch (std::runtime_error&) {
+    } catch (std::runtime_error&) {
+      // Swallow errors when a candidate config is ineligible or exceeds SMEM.
       continue;
     }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h` around lines 107 - 113, The workspace-size probe in getWorkspaceSizeImpl currently catches and swallows std::runtime_error thrown by dispatchToArch (used to probe candidate gemmConfig) without explanation; add a brief comment above the try/catch explaining that probing may throw for unsupported/config-invalid configs and that these exceptions are intentionally ignored to allow trying other configs (i.e., this is a non-fatal probe, not a logic bug), referencing dispatchToArch and workspace_size so readers know why the catch exists and that it only applies to probing, not to actual execution.

tests/gemm/test_mm_mxfp8_sm120.py (2)
28-45: The SM120-specific paths still need direct coverage. This file only exercises 2D swizzled cases with `N`/`K` values that are multiples of 128. Please add at least one 3D batched case, one `N`/`K` case that's only a multiple of 32, and one `pytest.raises` case for linear scales so the new `bmm_mxfp8`, alignment, and rejection paths are actually tested.
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_mm_mxfp8_sm120.py` around lines 28 - 45, Extend the test matrix in tests/gemm/test_mm_mxfp8_sm120.py to cover SM120-specific code paths by (1) adding a 3D batched input case that calls _prepare_mxfp8 and exercises bmm_mxfp8 (e.g., shapes with batch>1) so the batched branch is executed; (2) adding at least one case where N and K are multiples of 32 but not 128 (e.g., 96 or 160) to hit the alignment path that handles 32-granularity; and (3) adding a pytest.raises test that calls mxfp8_quantize/_prepare_mxfp8 (using SfLayout.layout_linear) with an input shape that should trigger the linear-scale rejection so the code path that raises for invalid linear scales is covered; ensure the new tests reference _prepare_mxfp8, mxfp8_quantize, bmm_mxfp8 and use SfLayout.layout_linear / SfLayout.layout_128x4 to select swizzled vs linear flows.
81-109: `all_tactics` only exercises the default path once. This proves the module reports a tactic count, but it never dispatches tactics 1 and 2. A broken secondary CTA instantiation would still pass here, so either iterate each tactic explicitly or rename the test to match its real coverage.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_mm_mxfp8_sm120.py` around lines 81 - 109, The test test_mm_mxfp8_sm120_all_tactics only exercises the default tactic; query module.mxfp8_gemm_tactic_num() and loop i from 0 to num_tactics-1, invoking the same execution path for each tactic (use the module or mm_mxfp8 API that accepts a tactic/index parameter or call the module’s tactic-dispatch method) and validate shape, finiteness and cosine similarity per-tactic; alternatively if you intend to only check tactic count, rename the test to reflect that (e.g., test_mm_mxfp8_sm120_reports_tactics) instead of claiming it exercises all tactics.
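To make the per-tactic check concrete, here is a sketch of what iterating every tactic could look like. `run_tactic` stands in for whatever dispatch hook the module exposes (an assumption, not the real API), and the cosine-similarity helper mirrors the test suite's > 0.99 threshold:

```python
import math


def cosine_similarity(x, y):
    """Plain-Python cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return dot / norm


def check_all_tactics(run_tactic, num_tactics, reference, threshold=0.99):
    """Dispatch every tactic index and compare each output to the reference."""
    for tactic in range(num_tactics):
        out = run_tactic(tactic)
        sim = cosine_similarity(out, reference)
        assert sim > threshold, f"tactic {tactic}: cos_sim {sim} <= {threshold}"
```

Looping over `range(num_tactics)` is what actually exercises the secondary CTA instantiations; a single default-tactic call never reaches them.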
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d7b77a63-ee11-4a91-8929-f087c895b0d9
📒 Files selected for processing (11)
- csrc/mxfp8_gemm_cutlass_sm120.cu
- csrc/mxfp8_gemm_cutlass_sm120.jinja
- docs/api/gemm.rst
- flashinfer/aot.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/__init__.py
- flashinfer/jit/gemm/core.py
- include/flashinfer/gemm/cutlass_gemm_configs.h
- include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h
- include/flashinfer/gemm/mxfp8_gemm_template_sm120.h
- tests/gemm/test_mm_mxfp8_sm120.py
```rst
.. autosummary::
   :toctree: ../generated

   mm_mxfp8
```
Add `bmm_mxfp8` to the generated GEMM docs.

The new section only exposes `mm_mxfp8`, so the batched MXFP8 API stays undiscoverable in the generated reference docs.
📝 Suggested update:

```diff
 .. autosummary::
    :toctree: ../generated

    mm_mxfp8
+   bmm_mxfp8
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| .. autosummary:: | |
| :toctree: ../generated | |
| mm_mxfp8 | |
| .. autosummary:: | |
| :toctree: ../generated | |
| mm_mxfp8 | |
| bmm_mxfp8 |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/api/gemm.rst` around lines 30 - 33, The autosummary list for GEMM
currently only documents mm_mxfp8, omitting the batched API bmm_mxfp8; update
the autosummary block to include bmm_mxfp8 alongside mm_mxfp8 so the batched
MXFP8 function is included in the generated reference docs (i.e., add the symbol
name bmm_mxfp8 into the same autosummary list that contains mm_mxfp8).
```python
@functools.cache
def get_gemm_sm120_module_cutlass_mxfp8():
    """Get the SM120/121 MXFP8 GEMM module."""
    module = gen_gemm_sm120_module_cutlass_mxfp8().build_and_load()
    return _create_cutlass_mxfp8_gemm_module(
        module, "flashinfer::cutlass_mxfp8_gemm", "cutlass_mxfp8_gemm"
    )


def get_cutlass_mxfp8_gemm_module(
    sm_major: int,
):
    if sm_major in [10, 11]:
        return get_gemm_sm100_module_cutlass_mxfp8()
    elif sm_major in [12]:
        return get_gemm_sm120_module_cutlass_mxfp8()
```
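One detail the excerpt leaves open is what happens for an `sm_major` outside the listed values. A defensive version of this routing (purely illustrative, not the flashinfer implementation) would fail loudly rather than fall through:

```python
def select_mxfp8_module_name(sm_major: int) -> str:
    """Map an SM major version to an MXFP8 GEMM module family name."""
    routes = {10: "sm100", 11: "sm100", 12: "sm120"}
    try:
        return routes[sm_major]
    except KeyError:
        # An unsupported architecture should be an explicit error, not None.
        raise ValueError(f"CUTLASS MXFP8 GEMM unsupported on SM{sm_major}x") from None
```

An explicit `ValueError` surfaces misrouted architectures at module-selection time instead of as a `None` dereference later.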
SM120 CUTLASS is wired only into `mm_mxfp8()`.

`get_gemm_sm120_module_cutlass_mxfp8()` is only consumed by the 2D path. `bmm_mxfp8()` below still hard-codes the cuDNN runner via `mxfp8_gemm_sm100(..., ["cudnn"])`, and `_cudnn_bmm_mxfp8_requirement` still only allows SM100/103, so batched MXFP8 remains unavailable on SM120/121.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/gemm_base.py` around lines 3271 - 3286, The batched MXFP8
path isn't using the SM120/121 CUTLASS module: update the bmm path to support
SM120/121 by wiring the CUTLASS runner there and relaxing the cuDNN requirement
checks. Specifically, modify bmm_mxfp8 to prefer the CUTLASS module for sm_major
12 (use get_cutlass_mxfp8_gemm_module or call the equivalent cutlass runner
instead of hard-coding mxfp8_gemm_sm100(..., ["cudnn"])), and update
_cudnn_bmm_mxfp8_requirement to allow SM120/121 (or to defer to the CUTLASS
option) so batched MXFP8 can run on SM120/121 with the CUTLASS implementation.
```python
def _is_sm120_available() -> bool:
    try:
        cc = get_compute_capability(torch.device("cuda"))
        return cc[0] == 12
    except Exception:
        return False
```
🧩 Analysis chain

🏁 Scripts executed (repository: flashinfer-ai/flashinfer):

```shell
# Find the implementation of get_compute_capability
fd -t f "*.py" | xargs grep -l "def get_compute_capability" | head -5

# Look at the test file being reviewed
cat -n tests/gemm/test_mm_mxfp8_sm120.py | head -30

# Search for other uses of get_compute_capability in tests to see patterns
rg "get_compute_capability" tests/ -A 3 -B 1 | head -50

# Find where get_compute_capability is defined
rg "def get_compute_capability" -A 10

# Check how it's used in similar helper functions across the codebase
rg "_is_sm|_require_sm|skip.*sm" tests/ --type py -A 2 -B 1 | head -80

# Look at any docstring or error handling in flashinfer.utils
fd flashinfer utils.py -type f | head -5

# Check if CUDA availability is checked explicitly anywhere
rg "torch.cuda.is_available" tests/ --type py -B 2 | head -30

# Look at the actual implementation details of get_device_capability
python3 << 'EOF'
import torch
try:
    # Check what exceptions can be raised
    help(torch.cuda.get_device_capability)
except Exception as e:
    print(f"Error: {e}")
EOF

# Find similar patterns where exception handling is used for device checks
rg "except.*RuntimeError|except.*ValueError" tests/ --type py -B 3 | head -40

# Check if there's a pattern of checking torch.cuda.is_available before device checks
rg "torch.cuda.is_available" tests/gemm/ -A 2 -B 1

# Look at the specific exceptions that torch.cuda.get_device_capability can raise
rg "get_device_capability" --type py -B 2 -A 5 | head -60

# Check what ValueError message would be if device type is wrong
cat -n flashinfer/utils.py | grep -A 5 "get_compute_capability"

# See if the test file initializes CUDA before calling _is_sm120_available
cat -n tests/gemm/test_mm_mxfp8_sm120.py | grep -A 15 "_is_sm120_available\|_skip_if_not_sm120\|def test_"
```
Replace `except Exception` with specific exception handling or add an availability check.

The current code catches all exceptions, which can silently mask unexpected failures. Either:

- Catch only `RuntimeError` (which `torch.cuda.get_device_capability()` raises when CUDA is unavailable), or
- Add a `torch.cuda.is_available()` check first, matching the pattern used elsewhere in the codebase (e.g., `tests/gdn/test_prefill_delta_rule.py`, `tests/model_optimizations/test_tinygemm2.py`).
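One way to apply that recommendation while keeping the helper unit-testable without a GPU is to inject the CUDA queries. The signature below is an illustrative refactor, not the shape the test file must adopt:

```python
def is_sm12x(cuda_available, get_capability) -> bool:
    """cuda_available: () -> bool; get_capability: () -> (major, minor)."""
    if not cuda_available():
        return False
    try:
        major, _minor = get_capability()
    except RuntimeError:
        # Raised by torch when the device query fails; anything else propagates.
        return False
    return major == 12
```

In the test file the callables would be `torch.cuda.is_available` and a wrapper around `get_compute_capability(torch.device("cuda"))`, so only `RuntimeError` is swallowed and genuine bugs still surface.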
🧰 Tools
🪛 Ruff (0.15.7)
[warning] 19-19: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gemm/test_mm_mxfp8_sm120.py` around lines 15 - 20, The helper
_is_sm120_available should avoid catching all exceptions; update it to first
check torch.cuda.is_available() and then call
get_compute_capability(torch.device("cuda")), or alternatively catch only
RuntimeError from get_compute_capability; specifically modify the
_is_sm120_available function to return False immediately if
torch.cuda.is_available() is False, and when calling get_compute_capability (the
symbol in question) catch RuntimeError only and use that to return False instead
of a broad except Exception.
is this compatible with DGX Spark? @samuellees
Resolves #2728
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Documentation
Tests