[Language] support T.gemm_sp_v2 on sm80 and sm89
#1056
base: main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough
Adds a new gemm_sp_v2 TileLang API and TL operator (GemmSPPy) with FFI wiring and GemmSPWarpPolicy, a SparseTensorCore intrinsics emitter plus many SP layout helpers, refactors metadata layout to make_cutlass_metadata_layout, updates examples/benchmarks/tests, and exposes related utilities and tests.
Changes
Sequence Diagram(s)
sequenceDiagram
participant User
participant TileLang as gemm_sp_v2
participant TL_Op as GemmSPPy
participant TileOp as GemmSPMMA
participant Emitter as SparseTensorCoreIntrinEmitter
User->>TileLang: call gemm_sp_v2(A_sparse, E, B, C, ...)
TileLang->>TL_Op: construct GemmSPPy node (buffers, args)
TL_Op->>TileOp: infer_layout(target, thread_nums)
TileOp->>Emitter: build emitter for pattern (ss/sr/rs/rr)
Emitter->>Emitter: ldmatrix / make_mma_load_layout / mma_sp / stmatrix
Emitter-->>TileOp: fragment/layout map
TileOp-->>TL_Op: layout_map
TL_Op->>TileOp: lower(target, thread_nums, thread_var)
TileOp->>Emitter: emit lowering -> prim_func
TileOp-->>TL_Op: lowered kernel
TL_Op-->>User: compiled kernel handle
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Pre-merge checks: ❌ failed checks (1 warning), ✅ passed checks (2 passed)
Force-pushed from 016dd1c to 122abb5.
Actionable comments posted: 1
🧹 Nitpick comments (1)
docs/deeplearning_operators/matmul_sparse.md (1)
39-39: Improve link text for accessibility.
The duplicate `[here]` links violate MD059 and make the doc harder to navigate with assistive tech. Please replace them with descriptive titles (e.g., `[PyTorch sparse kernel]`, `[vLLM sparse kernel]`).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- docs/deeplearning_operators/matmul_sparse.md (1 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1 hunks)
- tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
- tilelang/language/experimental/gemm_sp.py (2 hunks)
- tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits); the layout template parameter, however, should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h, where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2025-09-15T10:51:06.985Z (same learning as above)
Applied to files:
- docs/deeplearning_operators/matmul_sparse.md
- tilelang/intrinsics/mma_sp_macro_generator.py
🧬 Code graph analysis (4)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (6)
tilelang/tileop/gemm_sp/gemm_sp_base.py (20)
- GemmSPBase (11-127), infer_layout (14-15), policy (126-127), M (33-34), N (37-38), in_dtype (57-59), e_dtype (53-54), accum_dtype (62-63), trans_A (45-46), trans_B (49-50), K (41-42), is_gemm_ss (20-21), A (66-67), B (74-75), C (78-79), is_gemm_sr (23-24), is_gemm_rs (26-27), is_gemm_rr (29-30), lower (17-18), E (70-71)
tilelang/layout/swizzle.py (1)
- make_swizzled_layout (10-18)
tilelang/intrinsics/mma_sp_macro_generator.py (6)
- make_mma_store_layout (789-858), make_mma_load_layout (646-787), ldmatrix_a (290-354), ldmatrix_e (356-418), ldmatrix_b (420-516), mma_sp (518-589)
tilelang/utils/language.py (2)
- is_shared (25-39), is_fragment (81-91)
tilelang/transform/simplify.py (1)
- _Simplify (31-49)
tilelang/tileop/gemm_sp/__init__.py (2)
- infer_layout (56-61), lower (63-69)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (1)
- get_buffer_region_from_load (137-159)
tilelang/language/gemm.py (10)
- legalize_arguments (48-59, 251-262), retrieve_shape (66-83, 269-286), retrieve_stride (85-111, 288-314), retrieve_ptr (140-175, 343-378), retrieve_offset (177-195, 380-398)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (6)
tilelang/utils/sparse.py (3)
- compress (77-106), randn_semi_sparse (109-128), randint_semi_sparse (131-158)
tilelang/utils/tensor.py (2)
- torch_assert_close (237-329), map_torch_type (37-54)
tilelang/layout/gemm_sp.py (1)
- make_cutlass_metadata_layout (136-150)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
- SparseTensorCoreIntrinEmitter (40-858)
tilelang/language/experimental/gemm_sp.py (1)
- gemm_sp_v2 (91-307)
tilelang/layout/swizzle.py (1)
- make_swizzled_layout (10-18)
tilelang/intrinsics/mma_sp_macro_generator.py (5)
tilelang/intrinsics/utils.py (2)
- mma_store_index_map (81-82), get_ldmatrix_offset (21-63)
tilelang/utils/language.py (1)
- is_fragment (81-91)
tilelang/intrinsics/mma_sp_layout.py (20)
- shared_16x16_to_mma_sp_layout_sr_a (14-15), shared_16x16_to_mma_sp_layout_sr_b (18-20), shared_16x32_to_mma_sp_layout_sr_a (23-24), shared_16x32_to_mma_sp_layout_sr_b (27-29), shared_16x64_to_mma_sp_layout_sr_a (32-33), shared_16x64_to_mma_sp_layout_sr_b (36-38), mma_sp_load_a_32x4_to_shared_16x16_layout (41-42), mma_sp_load_a_32x8_to_shared_16x32_layout (45-46), mma_sp_load_a_32x16_to_shared_16x64_layout (49-50), mma_sp_load_b_32x8_to_shared_16x16_layout (53-56), mma_sp_load_b_32x16_to_shared_16x32_layout (59-62), mma_sp_load_b_32x32_to_shared_16x64_layout (65-68), metadata_8bit_load_32x4_to_shared_16x4_layout_32bit (75-80), metadata_16bit_load_32x2_to_shared_16x2_layout_32bit (83-88), metadata_8bit_load_32x4_to_shared_16x4_layout_16bit (91-94), metadata_16bit_load_32x2_to_shared_16x2_layout_16bit (97-100), metadata_8bit_load_32x4_to_shared_16x4_layout_8bit (107-112), metadata_16bit_load_32x2_to_shared_16x4_layout_8bit (115-120), metadata_32bit_load_32x1_to_shared_16x2_layout_8bit (123-129), get_ldmatrix_offset_b (156-190)
tilelang/language/tir/op.py (2)
- ptx_ldmatrix (1313-1349), ptx_mma_sp (964-1062)
tilelang/layout/fragment.py (2)
- replicate (147-161), repeat (124-145)
🪛 markdownlint-cli2 (0.18.1)
docs/deeplearning_operators/matmul_sparse.md
39-39: Link text should be descriptive
(MD059, descriptive-link-text)
39-39: Link text should be descriptive
(MD059, descriptive-link-text)
176-176: Link text should be descriptive
(MD059, descriptive-link-text)
🪛 Ruff (0.14.3)
tilelang/tileop/gemm_sp/gemm_sp_mma.py
57-58: Avoid specifying long messages outside the exception class
(TRY003)
232-233: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/experimental/gemm_sp.py
161-162: Prefer TypeError exception for invalid type
(TRY004)
161-162: Avoid specifying long messages outside the exception class
(TRY003)
189-190: Prefer TypeError exception for invalid type
(TRY004)
189-190: Avoid specifying long messages outside the exception class
(TRY003)
253-254: Prefer TypeError exception for invalid type
(TRY004)
253-254: Avoid specifying long messages outside the exception class
(TRY003)
273-274: Prefer TypeError exception for invalid type
(TRY004)
273-274: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_sp_macro_generator.py
52-60: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
62-119: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
121-129: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
178-180: Avoid specifying long messages outside the exception class
(TRY003)
212-212: Avoid specifying long messages outside the exception class
(TRY003)
313-313: Avoid specifying long messages outside the exception class
(TRY003)
381-381: Avoid specifying long messages outside the exception class
(TRY003)
390-390: Avoid specifying long messages outside the exception class
(TRY003)
395-395: Avoid specifying long messages outside the exception class
(TRY003)
397-397: Avoid specifying long messages outside the exception class
(TRY003)
445-445: Avoid specifying long messages outside the exception class
(TRY003)
696-696: Avoid specifying long messages outside the exception class
(TRY003)
713-713: Avoid specifying long messages outside the exception class
(TRY003)
771-771: Avoid specifying long messages outside the exception class
(TRY003)
785-785: Avoid specifying long messages outside the exception class
(TRY003)
```python
dtype_abbrv = {
    "float16": "fp16",
    "bfloat16": "bf16",
    "float32": "fp32",
    "int8": "int8",
    "int32": "int32",
    "float8_e4m3": "e4m3",
    "float8_e5m2": "e5m2",
}

E_FACTOR_MAP = {  # e_kdim = mma_kdim // e_factor
    "float": {
        "int16": 8,
        "uint16": 8,
    },
    "float32": {
        "int16": 8,
        "uint16": 8,
    },
    "float16": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "bfloat16": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "int8": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "uint8": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "float8_e4m3": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
    "float8_e5m2": {
        "int8": 8,
        "uint8": 8,
        "int16": 16,
        "uint16": 16,
        "int32": 32,
        "uint32": 32,
    },
}

E_REPLICATE_FACTOR = {  # metadata replicated every 4 consecutive threads
    "float32": 2,
    "float16": 2,  # 2 of 4 consecutive threads provide metadata
    "bfloat16": 2,
    "int8": 1,  # 4 of 4 consecutive threads provide metadata
    "uint8": 1,
    "float8_e4m3": 1,
    "float8_e5m2": 1,
}
```
Add tf32 dtype metadata to emitter tables.
E_FACTOR_MAP already exposes "float" entries, but dtype_abbrv and E_REPLICATE_FACTOR don’t, so creating the emitter with in_dtype="float" (tf32) raises a KeyError. That breaks the advertised tf32 path. Please add the missing mappings so we can actually emit tf32 kernels.
Apply this diff:
```diff
 dtype_abbrv = {
+    "float": "tf32",
     "float16": "fp16",
     "bfloat16": "bf16",
     "float32": "fp32",
@@
 E_REPLICATE_FACTOR = {  # metadata replicate every 4 consecutive threads
+    "float": 2,
     "float32": 2,
     "float16": 2,  # 2 of 4 consecutive threads provides
     "bfloat16": 2,
```
🧰 Tools
🪛 Ruff (0.14.3)
52-60: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
62-119: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
121-129: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
🤖 Prompt for AI Agents
In tilelang/intrinsics/mma_sp_macro_generator.py around lines 52 to 129, the
emitter tables are missing entries for the tf32 dtype key "float", causing a
KeyError when in_dtype="float"; add "float": "tf32" to the dtype_abbrv mapping
and add "float": 2 to E_REPLICATE_FACTOR (matching float32 behavior) so the
existing E_FACTOR_MAP "float" entries can be used without error.
We're good to go if we can resolve the conflict, and then I think we can let this PR in.
Actionable comments posted: 0
🧹 Nitpick comments (2)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)
5-11: Scope TF32 configuration to avoid global side effects
Setting `torch.backends.cuda.matmul.allow_tf32 = False` at module import time affects all CUDA matmul calls in the test process, not just this module. Consider moving this into the specific tests (or a context manager) so other tests that rely on TF32 behavior are not implicitly changed.
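One way to scope the flag is a save-and-restore context manager. The sketch below uses a stand-in namespace object so it runs without torch; in the real test module the `backend` argument would be `torch.backends.cuda.matmul`, and the `tf32_disabled` helper name is hypothetical.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for torch.backends.cuda.matmul so the sketch is self-contained.
matmul_backend = SimpleNamespace(allow_tf32=True)

@contextmanager
def tf32_disabled(backend=matmul_backend):
    # Save the current flag, disable TF32 for the block, and always
    # restore it so other tests see the process-wide default they expect.
    prev = backend.allow_tf32
    backend.allow_tf32 = False
    try:
        yield
    finally:
        backend.allow_tf32 = prev

with tf32_disabled():
    assert matmul_backend.allow_tf32 is False
assert matmul_backend.allow_tf32 is True  # restored after the block
```

The `finally` clause matters: the flag is restored even if the test body raises, which is what keeps the setting from leaking into later tests.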
81-103: Metadata layout wiring: sm80 is nicely aligned to emitter; please double-check sm90 dtype choice
For sm80, deriving `metadata_dtype` (`int32` for 8-bit, otherwise `int16`) and `E_factor` from `SparseTensorCoreIntrinEmitter.E_FACTOR_MAP`, then passing `mma_dtype=in_dtype` into `make_cutlass_metadata_layout`, is a good match to the intrinsics' expectations and the CUTLASS-style layout pattern. This aligns with the established "layout depends on MMA element type" approach. Based on learnings.
For sm90, `E`/`E_shared` are hard-coded as `uint8` and `make_cutlass_metadata_layout(..., mma_dtype="float16", arch="9.0", block_k=block_K)` is used even when `in_dtype` is `int8` or `float8_e4m3`. If the Hopper metadata layout truly varies with the MMA operand type (as it does on sm80), you may want to mirror the sm80 pattern and drive `mma_dtype` (and possibly metadata dtype/E_factor) from `in_dtype` as well.
Please double-check that:
- The `E_factor` and metadata dtype used here for each supported `in_dtype` match what `SparseTensorCoreIntrinEmitter` expects on sm90, and
- `mma_dtype="float16"` is correct for non-fp16 inputs, or adjust it to be `in_dtype` where appropriate.
Also applies to: 124-148
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings: the CUTLASS crosswise/sizeof_bits learning quoted above (botbw, tile-ai/tilelang PR 691).
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (4)
tilelang/utils/sparse.py (2)
- randn_semi_sparse (109-128), randint_semi_sparse (131-158)
tilelang/layout/gemm_sp.py (1)
- make_cutlass_metadata_layout (136-150)
tilelang/utils/tensor.py (2)
- torch_assert_close (237-329), map_torch_type (37-54)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
- SparseTensorCoreIntrinEmitter (40-858)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (2)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)
14-44: Dense input generator looks consistent with kernel/transposition contracts
The shape/dtype logic in `generate_dense_input` matches the later `_matmul` conventions: after applying `trans_A`/`trans_B`, you always get `(M, K) @ (K, N)`, and `map_torch_type(in_dtype)` keeps A/B consistent across integer and float/float8 cases. This should keep the tests well-aligned with the kernels.
195-202: Centralized input generation and comparison path looks correct and robust
Reusing `generate_dense_input` in `run_gemm_sp` and comparing `C_sp` vs. dense `C` via `_matmul` plus `torch_assert_close` on `float32` (with explicit upcast for `float8`/`int8`) is a solid improvement over ad-hoc checks. Shapes line up with the kernels, and the tolerances (`rtol=atol=1e-3`) are reasonable for these dtypes.
Also applies to: 215-226
Actionable comments posted: 1
🧹 Nitpick comments (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (1)
166-178: Remove unused helper functions `normalize` and `calc_diff`.
Search results confirm these functions are not called anywhere in the codebase. They appear to be leftover from prior validation logic now handled by `torch_assert_close`. Remove them to reduce unnecessary code.
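For reference, the criterion that `torch_assert_close`-style checks apply is the standard elementwise test `|actual - expected| <= atol + rtol * |expected|`, which subsumes what a hand-rolled `calc_diff` would do. A pure-Python sketch (the `assert_close` helper here is illustrative, not the project's implementation):

```python
def assert_close(actual, expected, rtol=1e-3, atol=1e-3):
    # Elementwise |a - e| <= atol + rtol * |e|, the usual closeness criterion.
    for i, (a, e) in enumerate(zip(actual, expected)):
        tol = atol + rtol * abs(e)
        if abs(a - e) > tol:
            raise AssertionError(
                f"index {i}: |{a} - {e}| = {abs(a - e):.3e} > {tol:.3e}")

assert_close([1.0005, 2.0], [1.0, 2.0])  # within tolerance, no error
try:
    assert_close([1.5], [1.0])
except AssertionError as err:
    print("caught:", err)
```

Centralizing on one such criterion (rather than per-test `normalize`/`calc_diff` helpers) keeps tolerances explicit and comparable across dtypes.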
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings: the CUTLASS crosswise/sizeof_bits learning quoted above (botbw, tile-ai/tilelang PR 691).
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (4)
tilelang/utils/sparse.py (3)
- compress (77-106), randn_semi_sparse (109-128), randint_semi_sparse (131-158)
tilelang/layout/gemm_sp.py (1)
- make_cutlass_metadata_layout (136-150)
tilelang/utils/tensor.py (2)
- torch_assert_close (237-329), map_torch_type (37-54)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
- SparseTensorCoreIntrinEmitter (40-858)
🔇 Additional comments (5)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (5)
5-12: LGTM! The new imports align with the gemm_sp_v2 updates, and disabling TF32 ensures consistent numerical behavior in tests.
14-44: LGTM! The input generation logic correctly handles different data types, using appropriate ranges for integer types and generating float types via float32 intermediates to support float8 variants.
124-147: LGTM! The dynamic determination of `metadata_dtype` and `E_factor` from `SparseTensorCoreIntrinEmitter.E_FACTOR_MAP`, along with passing `mma_dtype=in_dtype` to the layout helper, provides appropriate flexibility for various input data types on SM80.
195-226: LGTM! The refactored validation logic using `torch_assert_close` is cleaner and more maintainable than the previous comparison approach, with appropriate tolerances for sparse matrix multiplication tests.
316-368: LGTM! The test suite provides comprehensive coverage across architectures (SM80/SM90), data types (float16, float8_e4m3, int8), and configurations (various block sizes, transpositions, and pipeline stages).
Actionable comments posted: 0
🧹 Nitpick comments (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (1)
38-43: Consider passing target dtype directly to avoid redundant conversion.
The `randn_semi_sparse` function already handles dtype conversion internally (creates in float32, applies sparsity mask, converts to target dtype). Passing `dtype=torch.float32` then calling `.to(map_torch_type(in_dtype))` adds an extra no-op conversion step when the input is already float32.
Apply this diff to streamline the conversion:
```diff
-    A = randn_semi_sparse(
-        M, K, dtype=torch.float32, device='cuda',
-        transposed=trans_A).to(map_torch_type(in_dtype))
-    B = torch.randn(
-        (N, K) if trans_B else (K, N), device='cuda',
-        dtype=torch.float32).to(map_torch_type(in_dtype))
+    A = randn_semi_sparse(
+        M, K, dtype=map_torch_type(in_dtype), device='cuda',
+        transposed=trans_A)
+    B = torch.randn(
+        (N, K) if trans_B else (K, N), device='cuda',
+        dtype=map_torch_type(in_dtype))
```
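For context on what `randn_semi_sparse` produces: a 2:4 semi-structured pattern keeps at most two nonzeros in every group of four consecutive elements along K, which is what sparse tensor cores consume. A dependency-free sketch of the masking step (keep the two largest magnitudes per group; the `semi_sparse_2_4` helper name is illustrative, not the library's implementation):

```python
def semi_sparse_2_4(row):
    # Zero out the two smallest-magnitude values in each group of 4,
    # producing the 2:4 pattern sparse tensor cores expect along K.
    assert len(row) % 4 == 0
    out = list(row)
    for g in range(0, len(row), 4):
        group = list(range(g, g + 4))
        # Indices of the two smallest magnitudes in this group.
        drop = sorted(group, key=lambda i: abs(out[i]))[:2]
        for i in drop:
            out[i] = 0.0
    return out

dense = [0.9, -0.1, 0.4, 0.05, -2.0, 1.0, 0.2, -0.3]
print(semi_sparse_2_4(dense))  # at most two nonzeros per group of four
```

The compressed representation then stores only the surviving values plus per-group metadata encoding which two of the four positions were kept, which is what `compress` and the metadata layouts in this PR deal with.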
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learnings: the two CUTLASS crosswise/sizeof_bits learnings quoted above (botbw, tile-ai/tilelang PR 691).
Applied to files:
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (6)
tilelang/utils/sparse.py (3)
- compress (77-106), randn_semi_sparse (109-128), randint_semi_sparse (131-158)
tilelang/layout/gemm_sp.py (1)
- make_cutlass_metadata_layout (136-150)
tilelang/utils/tensor.py (2)
- torch_assert_close (237-329), map_torch_type (37-54)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
- SparseTensorCoreIntrinEmitter (40-858)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1)
- generate_dense_input (136-165)
tilelang/language/experimental/gemm_sp.py (1)
- gemm_sp (10-87)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (7)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7)
10-10: Good practice: TF32 disabled for test precision. Disabling TF32 ensures numerical accuracy when comparing sparse GEMM results against dense reference implementations.
84-88: Past issue resolved: `mma_dtype` now correctly uses `in_dtype`. The previous review flagged hardcoded `mma_dtype="float16"`, which would produce incorrect metadata layouts for int8/float8 dtypes. This has been fixed: `mma_dtype=in_dtype` ensures the layout atoms and BlockK are computed correctly for all supported dtypes.
124-125: Good: Dynamic metadata dtype selection for SM80. The `metadata_dtype` is now correctly selected based on input bit width (int32 for 8-bit, int16 for 16-bit), with `E_factor` computed from the intrinsic emitter's lookup table. This enables proper support for int8 and other 8-bit dtypes.
146-147: Past issue resolved: SM80 metadata layout now uses `in_dtype`. Similar to SM90, the hardcoded `mma_dtype="float16"` has been replaced with `mma_dtype=in_dtype`, ensuring correct metadata layout generation for all supported dtypes, including int8 (tested at lines 366-372).
223-230: Good: Unified validation using `torch_assert_close`. Replacing the ad-hoc diff calculation with `torch_assert_close` provides consistent, configurable comparison across dtypes. The float32 normalization ensures compatibility when comparing results from different accumulation types.
342-344: Good test coverage: Verifies int8/float8 support on SM90. These tests validate that the fixed `mma_dtype=in_dtype` correctly handles 8-bit dtypes with appropriate metadata layouts, addressing the past review concerns.
366-372: Good test coverage: Verifies int8 support on SM80. Comprehensive int8 testing across multiple configurations validates the dynamic `metadata_dtype` selection and the fixed `mma_dtype` parameter, ensuring correct operation for 8-bit dtypes on SM80 architecture.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tilelang/layout/gemm_sp.py (1)
115-120: Consider extracting exception messages into constants.
Static analysis suggests avoiding long exception messages directly in raise statements for better maintainability.
Example refactor:
```diff
+METADATA_BIT_MISMATCH_MSG_16 = "metadata should be 16 bit, got {}"
+METADATA_BIT_MISMATCH_MSG_32 = "metadata should be 32 bit, got {}"
+
     if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
-        raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
+        raise ValueError(METADATA_BIT_MISMATCH_MSG_16.format(buffer.dtype))
     if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"
                      ] and buffer.dtype not in ["uint32", "int32"]:
-        raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
+        raise ValueError(METADATA_BIT_MISMATCH_MSG_32.format(buffer.dtype))
```
Based on learnings
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/layout/gemm_sp.py (4 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learnings: the two CUTLASS crosswise/sizeof_bits learnings quoted above (botbw, tile-ai/tilelang PR 691).
Applied to files:
- tilelang/layout/gemm_sp.py
🧬 Code graph analysis (1)
tilelang/layout/gemm_sp.py (2)
tilelang/layout/layout.py (1)
- Layout (13-148)
tilelang/contrib/nvcc.py (2)
- get_target_compute_version (409-450), parse_compute_version (453-475)
🪛 Ruff (0.14.4)
tilelang/layout/gemm_sp.py
116-116: Avoid specifying long messages outside the exception class
(TRY003)
120-120: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (5)
tilelang/layout/gemm_sp.py (5)
34-34: LGTM! FP8 dtype support correctly added. The extension to support "float8_e4m3" and "float8_e5m2" aligns with the PR objectives and is correctly implemented with 8-bit mappings.
Also applies to: 45-46
138-151: Clarify the handling of extra_args for SM8x architecture.
The dispatcher passes `**extra_args` to the SM90 implementation (line 148) but not to the SM8x implementation (line 150). While the SM8x implementation doesn't currently accept `block_k`, this could cause silent failures if callers provide extra arguments expecting them to be used.
Consider one of the following:
1. If SM8x truly doesn't need extra_args, add explicit validation to warn users if extra_args are provided for SM8x:
```python
if compute_version >= (9, 0):
    return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args)
elif compute_version >= (8, 0):
    if extra_args:
        warnings.warn(f"extra_args {extra_args} ignored for SM8x architecture", stacklevel=2)
    return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype)
```
2. If SM8x should accept and ignore extra_args for API consistency, update the SM8x signature:
```python
def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str, **extra_args):
    # Ignore extra_args for SM8x
    ...
```
Which approach aligns better with the intended API design?
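The first option can be exercised standalone. The sketch below uses stand-in layout functions (not the real tilelang API) to show the warn-and-ignore dispatcher pattern under discussion:

```python
import warnings

# Stand-ins for the architecture-specific layout builders.
def layout_sm90(buffer, mma_dtype, block_k=None):
    return ("sm90", mma_dtype, block_k)

def layout_sm8x(buffer, mma_dtype):
    return ("sm8x", mma_dtype)

def make_metadata_layout(buffer, mma_dtype, compute_version, **extra_args):
    # Dispatcher mirroring the pattern above: forward extra_args on SM90,
    # warn-and-ignore on SM8x instead of failing silently.
    if compute_version >= (9, 0):
        return layout_sm90(buffer, mma_dtype, **extra_args)
    elif compute_version >= (8, 0):
        if extra_args:
            warnings.warn(f"extra_args {extra_args} ignored for SM8x",
                          stacklevel=2)
        return layout_sm8x(buffer, mma_dtype)
    raise ValueError(f"unsupported compute version {compute_version}")

print(make_metadata_layout(None, "float16", (9, 0), block_k=64))
print(make_metadata_layout(None, "float16", (8, 0), block_k=64))  # warns
```

The warning makes the ignored argument visible at the call site, which is the cheapest way to surface the asymmetry without changing the SM8x signature.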
21-21: LGTM! Public API exposure provides useful flexibility. Making `make_cutlass_metadata_layout_sm90` and `make_cutlass_metadata_layout_sm8x` public allows advanced users to directly call architecture-specific implementations when needed, while the dispatcher function provides a convenient high-level API.
Also applies to: 108-108
118-120: LGTM! FP8/int8 validation correctly enforces the 32-bit buffer requirement. The validation ensures that fp8 and int8 sparse metadata uses 32-bit buffers on SM8x, which is the correct requirement for these data types.
122-133: Implementation verified against PyTorch reference; no issues found. The new `ColumnMajorInterleaved` function correctly implements the CUTLASS metadata layout algorithm. Comparison with the PyTorch reference confirms:

- The group/interweave parameters match (`bits == 16` → itemsize == 2)
- The row reordering formula is identical
- The topright/bottomleft swizzle logic is equivalent (using bitwise arithmetic rather than boolean comparison)
- The final offset computation matches the reference; returning `(offset // k, offset % k)` appropriately converts the flat offset to matrix indices
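The reordering the bullets above describe can be reproduced in plain Python. The sketch below mirrors the PyTorch reference algorithm as summarized here (group/interweave row reorder plus the 2x2 swizzle); it is an illustration for checking equivalence, not the project's implementation:

```python
def meta_reorder_offsets(m, k, meta_itemsize):
    """Flat destination offset for each (row, col) of an m x k metadata tile."""
    group = 32 if meta_itemsize == 2 else 16
    interweave = 4 if meta_itemsize == 2 else 2
    offsets = []
    for row in range(m):
        # Row reordering within each `group`-row block.
        dst_row_base = row // group * group + (row % 8) * interweave + (row % group) // 8
        for col in range(k):
            dst_row, dst_col = dst_row_base, col
            # Swizzle: swap (even row, odd col) with (odd row, even col)
            # inside each 2x2 tile.
            topright = int(dst_row % 2 == 0 and dst_col % 2 == 1)
            bottomleft = int(dst_row % 2 == 1 and dst_col % 2 == 0)
            dst_row += topright - bottomleft
            dst_col += bottomleft - topright
            offsets.append(dst_col * m + dst_row)  # column-major flat offset
    return offsets
```

Because the row reorder is a per-block permutation and the swizzle is an involution, the offsets form a permutation of 0..m*k-1, which is a cheap sanity check for any reimplementation.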
Actionable comments posted: 1
♻️ Duplicate comments (6)
src/op/gemm_sp_py.cc (3)
68-68: Use `GemmSPWarpPolicy` instead of `GemmWarpPolicy`. The policy should be constructed as `GemmSPWarpPolicy` to use sparse-specific warp partition logic, as noted in previous reviews. The header (src/op/gemm_sp_py.h:36) declares `policy` as `GemmWarpPolicy`, but the sparse GEMM path requires `GemmSPWarpPolicy` for correct atom-size adjustments. Apply this diff:

```diff
-  node->policy = GemmWarpPolicy(args[10].as<IntImm>().value()->value);
+  node->policy = GemmSPWarpPolicy(args[10].as<IntImm>().value()->value);
```
213-224: Add error handling for `std::stoi`. If the architecture string contains non-numeric characters after "sm_" (e.g., "sm_abc"), `std::stoi` will throw `std::invalid_argument`, potentially causing crashes. This was previously flagged but remains unaddressed. Apply this diff:

```diff
 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
   ICHECK(s.has_value());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
+    try {
+      arch_int = std::stoi(arch.substr(3));
+    } catch (const std::exception& e) {
+      LOG(WARNING) << "Failed to parse architecture from '" << arch
+                   << "': " << e.what();
+      arch_int = 0;
+    }
   } else {
     arch_int = 0;
   }
   return arch_int;
 }
```
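The same guard is easy to state in Python, which may help when mirroring the behavior on the Python side of the codebase. A sketch; `parse_sm_arch` is a hypothetical helper, not an existing function:

```python
def parse_sm_arch(arch: str) -> int:
    """Parse an 'sm_XX' architecture string; return 0 for anything malformed."""
    if not arch.startswith("sm_"):
        return 0
    try:
        return int(arch[3:])
    except ValueError:
        # Mirrors the guarded std::stoi above: fall back instead of raising.
        return 0
```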
226-231: Fix the `computeWarpPartition` call signature. The call to `policy->computeWarpPartition` passes a `GemmInst` enum where a `bool use_wgmma` is expected, and omits the required `bits` parameter. This was previously flagged but remains unaddressed. Apply this diff:

```diff
 auto block_size = *as_const_int(T.thread_bounds->extent);
 GemmInst gemm_inst = GetGemmInst(block_size, T.target);
+bool use_wgmma = (gemm_inst == GemmInst::kWGMMA);
 auto [warp_m, warp_n] =
-    policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst);
+    policy->computeWarpPartition(M, N, block_size, T.target, use_wgmma, A->dtype.bits());
```

src/op/gemm_sp_py.h (1)
36-36: Declare `policy` as `GemmSPWarpPolicy` instead of `GemmWarpPolicy`. The policy should be declared as `mutable GemmSPWarpPolicy` to use sparse-specific warp partition logic with atom-size adjustments. This was previously flagged but remains unaddressed. Apply this diff:

```diff
-  mutable GemmWarpPolicy policy;
+  mutable GemmSPWarpPolicy policy;
```

tilelang/intrinsics/mma_sp_macro_generator.py (2)
52-129: Add tf32 dtype metadata to the emitter tables. `E_FACTOR_MAP` supports `"float"` entries (lines 63-66), but `dtype_abbrv` (lines 52-60) and `E_REPLICATE_FACTOR` (lines 121-129) don't, causing a `KeyError` when creating an emitter with `in_dtype="float"` (tf32). This was previously flagged but remains unaddressed. Apply this diff:

```diff
 dtype_abbrv = {
+    "float": "tf32",
     "float16": "fp16",
     "bfloat16": "bf16",
     "float32": "fp32",

 E_REPLICATE_FACTOR = {
     # metadata replicate every 4 consecutive threads
+    "float": 2,
     "float32": 2,
     "float16": 2,  # 2 of 4 consecutive threads provides
     "bfloat16": 2,
```
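The failure mode is a plain dictionary miss, which the following reduced model demonstrates with the `"float"` entries already added. The tables are trimmed to the dtypes mentioned above and are not the full emitter tables:

```python
dtype_abbrv = {
    "float": "tf32",  # entry the review says must be added
    "float16": "fp16",
    "bfloat16": "bf16",
    "float32": "fp32",
}

E_REPLICATE_FACTOR = {
    "float": 2,  # entry the review says must be added
    "float32": 2,
    "float16": 2,
    "bfloat16": 2,
}

def emitter_metadata(in_dtype: str):
    # Without the "float" entries, both lookups raise KeyError for tf32 inputs.
    return dtype_abbrv[in_dtype], E_REPLICATE_FACTOR[in_dtype]
```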
423-435: Guard ldmatrix for int8 + transpose correctly. The comment at line 434 states ldmatrix is unavailable for int8 with transposed layout, but the condition at line 435 disables ldmatrix for the opposite case (int8 + not transposed). This was previously flagged but remains unaddressed. Apply this diff:

```diff
-        ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
+        ldmatrix_available = not (DataType(b_dtype).bits != 16 and b_transposed)
```
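The corrected predicate can be isolated as a pure function to make its truth table explicit. This is a sketch mirroring the one-line fix, with the dtype reduced to its bit width:

```python
def ldmatrix_available_for_b(b_dtype_bits: int, b_transposed: bool) -> bool:
    # ldmatrix is unavailable only for non-16-bit B in the transposed layout,
    # matching the intent stated in the emitter's comment.
    return not (b_dtype_bits != 16 and b_transposed)
```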
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- src/op/gemm_sp_py.cc (1 hunks)
- src/op/gemm_sp_py.h (1 hunks)
- tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
- tilelang/language/experimental/gemm_sp.py (3 hunks)
- tilelang/tileop/gemm_sp/gemm_sp_base.py (1 hunks)
- tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/op/gemm_sp_py.cc
tilelang/intrinsics/mma_sp_macro_generator.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/op/gemm_sp_py.cc
tilelang/intrinsics/mma_sp_macro_generator.py
🧬 Code graph analysis (4)
src/op/gemm_sp_py.h (7)
- src/op/gemm_sp.h (2): tl (15-111), RegisterReflection (27-33)
- src/op/operator.h (2): TileOperatorNode (55-65), TileOperator (67-71)
- src/op/gemm_sp_py.cc (11): CheckWGMMA (145-197, 145), Lower (226-260, 226), InferLayout (262-277, 262-263), Clone (94-97, 94), GetGemmInst (99-113, 99), GemmSPPy (51-84)
- tilelang/tileop/gemm_sp/gemm_sp_base.py (17): A (70-71), E (74-75), B (78-79), C (82-83), trans_A (45-46), trans_B (49-50), trans_E (53-54), M (33-34), N (37-38), K (41-42), stride_A (102-103), stride_B (106-107), offset_A (110-111), offset_B (114-115), clear_accum (118-119), wg_wait (126-127), policy (130-131)
- src/op/gemm.h (6): GemmWarpPolicy (59-83, 64-68, 70-74, 76-82), RegisterReflection (35-41, 106-128)
- tilelang/ir.py (1): GemmWarpPolicy (30-39)
- tilelang/tileop/gemm_sp/__init__.py (1): GemmSPPy (29-69)
tilelang/language/experimental/gemm_sp.py (2)
- tilelang/utils/language.py (1): get_buffer_region_from_load (162-184)
- tilelang/language/gemm.py (1): legalize_arguments (36-47)
src/op/gemm_sp_py.cc (4)
- src/op/gemm_sp_py.h (2): GemmSPPy (81-87), RegisterReflection (41-66)
- tilelang/tileop/gemm_sp/__init__.py (1): GemmSPPy (29-69)
- tilelang/tileop/gemm_sp/gemm_sp_base.py (3): policy (130-131), M (33-34), N (37-38)
- src/op/gemm_sp.h (1): RegisterReflection (27-33)
tilelang/intrinsics/mma_sp_macro_generator.py (6)
- tilelang/intrinsics/utils.py (2): mma_store_index_map (82-83), get_ldmatrix_offset (22-64)
- tilelang/utils/language.py (1): is_fragment (105-116)
- tilelang/intrinsics/mma_sp_layout.py (18): shared_16x16_to_mma_sp_layout_sr_a (14-15), shared_16x16_to_mma_sp_layout_sr_b (18-20), shared_16x32_to_mma_sp_layout_sr_a (23-24), shared_16x32_to_mma_sp_layout_sr_b (27-29), shared_16x64_to_mma_sp_layout_sr_a (32-33), shared_16x64_to_mma_sp_layout_sr_b (36-38), mma_sp_load_a_32x4_to_shared_16x16_layout (41-42), mma_sp_load_a_32x8_to_shared_16x32_layout (45-46), mma_sp_load_a_32x16_to_shared_16x64_layout (49-50), mma_sp_load_b_32x8_to_shared_16x16_layout (53-56), mma_sp_load_b_32x16_to_shared_16x32_layout (59-62), mma_sp_load_b_32x32_to_shared_16x64_layout (65-68), metadata_8bit_load_32x4_to_shared_16x4_layout_32bit (75-80), metadata_16bit_load_32x2_to_shared_16x2_layout_32bit (83-88), metadata_8bit_load_32x4_to_shared_16x4_layout_16bit (91-94), metadata_16bit_load_32x2_to_shared_16x2_layout_16bit (97-100), metadata_32bit_load_32x1_to_shared_16x2_layout_8bit (123-129), get_ldmatrix_offset_b (156-190)
- tilelang/intrinsics/mma_macro_generator.py (4): _initialize_k_dim (106-109, 849-850), _initialize_abbrev (116-119, 857-860)
- tilelang/language/tir/op.py (2): ptx_ldmatrix (1313-1349), ptx_mma_sp (964-1062)
- tilelang/layout/fragment.py (2): replicate (147-161), repeat (124-145)
🪛 Ruff (0.14.5)
tilelang/language/experimental/gemm_sp.py
163-164: Prefer TypeError exception for invalid type
(TRY004)
163-164: Avoid specifying long messages outside the exception class
(TRY003)
191-192: Prefer TypeError exception for invalid type
(TRY004)
191-192: Avoid specifying long messages outside the exception class
(TRY003)
255-256: Prefer TypeError exception for invalid type
(TRY004)
255-256: Avoid specifying long messages outside the exception class
(TRY003)
275-276: Prefer TypeError exception for invalid type
(TRY004)
275-276: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/intrinsics/mma_sp_macro_generator.py
52-60: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
62-119: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
121-129: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
180-182: Avoid specifying long messages outside the exception class
(TRY003)
214-214: Avoid specifying long messages outside the exception class
(TRY003)
315-315: Avoid specifying long messages outside the exception class
(TRY003)
384-384: Avoid specifying long messages outside the exception class
(TRY003)
393-393: Avoid specifying long messages outside the exception class
(TRY003)
398-398: Avoid specifying long messages outside the exception class
(TRY003)
400-400: Avoid specifying long messages outside the exception class
(TRY003)
448-448: Avoid specifying long messages outside the exception class
(TRY003)
699-699: Avoid specifying long messages outside the exception class
(TRY003)
716-716: Avoid specifying long messages outside the exception class
(TRY003)
774-774: Avoid specifying long messages outside the exception class
(TRY003)
788-788: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tileop/gemm_sp/gemm_sp_mma.py
58-59: Avoid specifying long messages outside the exception class
(TRY003)
234-235: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
tilelang/language/experimental/gemm_sp.py (1)
92-310: LGTM! The `gemm_sp_v2` implementation correctly handles input normalization, shape/stride/offset retrieval, and validation. The helper functions appropriately support `Buffer`, `BufferRegion`, and `BufferLoad` types, and the extensive error messages aid debugging.

tilelang/tileop/gemm_sp/gemm_sp_mma.py (1)
1-236: LGTM! The `GemmSPMMA` class correctly implements sparse tensor-core MMA tiling with proper layout inference and lowering for all four dataflow patterns (ss, sr, rs, rr). The per-variant kernel generation properly allocates local buffers and orchestrates load/MMA operations.

tilelang/tileop/gemm_sp/gemm_sp_base.py (1)
1-131: LGTM! The `GemmSPBase` class provides a clean abstraction with proper delegation to the underlying node. The property accessors correctly expose all necessary fields, and the dtype consistency check ensures A and B have matching types.
Actionable comments posted: 2
🧹 Nitpick comments (3)
tilelang/layout/gemm_sp.py (1)
115-120: Optional: Consider a custom exception class for metadata validation errors. Static analysis suggests avoiding long messages in exception constructors. However, for this use case, the descriptive error messages are more valuable than strict adherence to TRY003. This is purely a style consideration and can be safely deferred.
tilelang/language/experimental/gemm_sp.py (2)
128-139: Extract the duplicated `legalize_arguments` helper to avoid code duplication. The `legalize_arguments` helper is identical to the one defined in the `gemm_sp` function (lines 46-57). Consider extracting it to module level to follow the DRY principle. Apply this diff to extract the helper:

```diff
+def _legalize_arguments(arg: tir.Buffer | tir.Var):
+    """Convert let-bound variables to their corresponding buffers.
+
+    Args:
+        arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
+
+    Returns:
+        Union[tir.Buffer, tir.Var]: The legalized argument
+    """
+    if isinstance(arg, tir.Var) and T.has_let_value(arg):
+        return T.get_let_value(arg).buffer
+    return arg
+
+
 def gemm_sp(
     A_sparse: tir.Buffer | tir.Var,
```

Then replace the nested helper definitions with calls to `_legalize_arguments` in both functions.
234-234: Minor typo in comment. Apply this diff:

```diff
-        # not offset the last two dimension
+        # not offset the last two dimensions
```
📒 Files selected for processing (2)
- tilelang/language/experimental/gemm_sp.py (2 hunks)
- tilelang/layout/gemm_sp.py (5 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
tilelang/layout/gemm_sp.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
tilelang/layout/gemm_sp.py
🧬 Code graph analysis (1)
tilelang/layout/gemm_sp.py (1)
- tilelang/contrib/nvcc.py (2): get_target_compute_version (409-450), parse_compute_version (453-475)
🪛 Ruff (0.14.5)
tilelang/language/experimental/gemm_sp.py
162-163: Prefer TypeError exception for invalid type
(TRY004)
162-163: Avoid specifying long messages outside the exception class
(TRY003)
190-191: Prefer TypeError exception for invalid type
(TRY004)
190-191: Avoid specifying long messages outside the exception class
(TRY003)
254-255: Prefer TypeError exception for invalid type
(TRY004)
254-255: Avoid specifying long messages outside the exception class
(TRY003)
274-275: Prefer TypeError exception for invalid type
(TRY004)
274-275: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/layout/gemm_sp.py
116-116: Avoid specifying long messages outside the exception class
(TRY003)
120-120: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (9)
tilelang/layout/gemm_sp.py (5)
108-124: Approve SM8x metadata layout with dtype-specific bit widths. The function correctly handles different metadata bit widths: 16-bit for float16/bfloat16 and 32-bit for float8/int8 types. The `group` and `interweave` parameters adjust accordingly (32/4 for 16-bit, 16/2 for 32-bit). The PyTorch reference is helpful for verification.
138-152: LGTM: Public dispatcher with architecture-aware routing. The dispatcher correctly routes to SM90 or SM8x implementations based on compute version. The selective passing of `**extra_args` to SM90 only is appropriate since SM8x doesn't require the `block_k` parameter. Architecture auto-detection provides good ergonomics.
21-47: Float8 support for CUTLASS SM90 is confirmed and correctly implemented. The web search confirms that CUTLASS SM90 fully supports float8_e4m3 and float8_e5m2 for sparse GEMM operations. cuSPARSELt includes Hopper (SM 9.0) sparse GEMM kernels for E4M3 and E5M2 inputs/outputs, and CUTLASS provides TensorOp and SpTensorOp kernels targeting SM90. The code changes adding these data types to the supported list are correct and compatible with the underlying CUTLASS kernels.
126-135: Based on my web search and analysis, I can now provide my findings.

Mathematical Verification: The offset calculation in the code snippet follows the CUTLASS ColumnMajorInterleaved layout pattern. The formula `(j // 2) * m * 2 + i * 2 + (j % 2)` is consistent with CUTLASS's column-major interleaved mapping for k=2 (stride-2 interleaving), where:

- `(j // 2) * m * 2` is the column index times the padded row extent
- `i * 2 + (j % 2)` is the per-element offset component

The topright/bottomleft bit-manipulation logic is a standard pattern used in sparse tensor libraries (PyTorch, CUTLASS, cuSPARSELt) to encode position mappings within 4-element windows for 2:4 semi-structured sparsity.

Verification Status:

- ✓ Offset calculation structure matches known CUTLASS/PyTorch patterns
- ✓ Bit manipulation approach is standard in sparse tensor encodings
- ⚠ Full correctness verification requires test cases (unable to access the codebase directly)

The algorithm is mathematically sound and follows industry-standard patterns for semi-structured sparse metadata layout. However, manual testing or integration tests would be needed to confirm the implementation works correctly with the specific parameters (`group`, `interweave`, `m`, `k`) and sparsity patterns used in TileLang.
Verify the rewritten ColumnMajorInterleaved logic against PyTorch and CUTLASS sparse metadata patterns. The offset calculation follows the standard column-major interleaved layout (matching CUTLASS documentation), and the topright/bottomleft bit manipulation is a known pattern in sparse tensor libraries for encoding position indices in 2:4 semi-structured sparsity. The algorithm is mathematically sound based on industry references. Recommend validating with existing test cases to confirm parameter values (`group`, `interweave`, `m`, `k`) and sparse pattern compatibility.

tilelang/language/experimental/gemm_sp.py (4)
200-215: LGTM: Shape validation logic is sound. The shape validation correctly handles 2D matrices and higher-order tensors with appropriate checks. The K dimension calculation accounts for the 2:4 sparsity pattern (factor of 2 on line 213) and validates consistency between matrices A and B.
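The 2:4 shape relationship being validated can be sketched as a standalone checker: A_sparse stores M x (K/2) surviving values, so K is recovered by doubling its inner dimension before comparing against B. `check_sparse_gemm_shapes` is illustrative, not the library's function:

```python
def check_sparse_gemm_shapes(a_sparse_shape, b_shape, trans_b=False):
    m, k_half = a_sparse_shape
    k = k_half * 2  # 2:4 sparsity keeps 2 of every 4 values along K
    # B is K x N normally, N x K when transposed.
    kb, n = (b_shape[1], b_shape[0]) if trans_b else (b_shape[0], b_shape[1])
    if kb != k:
        raise ValueError(f"K mismatch: A_sparse implies K={k}, B provides K={kb}")
    return m, n, k
```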
277-282: LGTM: Offset validation ensures proper matrix alignment. The assertions correctly enforce that only the last dimension can be offset, which aligns with GEMM semantics where the M dimension of A and K dimension of B must be fully utilized.
284-309: LGTM: Intrinsic call properly assembles all parameters. The intrinsic call correctly passes all computed values, including pointers, dimensions, strides, and offsets. The expanded parameter set compared to `gemm_sp` provides greater flexibility for handling various buffer layouts and access patterns.
146-146: Python version compatibility for `list[int]` annotations requires project-specific verification. The concern is technically sound: the lowercase `list[int]` syntax requires Python 3.9+ (PEP 585). However, I was unable to access the repository to verify:

- The project's declared Python version requirement (`python_requires` in setup.py/pyproject.toml)
- Current imports in `tilelang/language/experimental/gemm_sp.py`
- Whether `List` from `typing` is already imported
- Usage consistency across the codebase

To complete this verification, check:

- The project's minimum Python version in `pyproject.toml` or `setup.py`
- If Python < 3.9, apply the suggested changes (import `List` from `typing` and use `List[int]`)
- If Python ≥ 3.9, the current `list[int]` syntax is correct and preferred
| """Perform a General Matrix Multiplication (GEMM) operation. | ||
| This function computes C = A @ B where A and B can optionally be transposed. | ||
| The operation supports various warp policies and accumulation modes. | ||
| Args: | ||
| A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements | ||
| E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E | ||
| B (Union[tir.Buffer, tir.Var]): Second input matrix | ||
| C (Union[tir.Buffer, tir.Var]): Output matrix for results | ||
| transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False. | ||
| transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False. | ||
| policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square. | ||
| clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False. | ||
| k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1. | ||
| wg_wait (int, optional): Warp group wait count. Defaults to 0. | ||
| Returns: | ||
| tir.Call: A handle to the GEMM operation | ||
| Raises: | ||
| AssertionError: If the K dimensions of matrices A and B don't match | ||
| """ |
Fix docstring inaccuracies.

The docstring has two issues:

- Line 104 says "General Matrix Multiplication" but this is a sparse GEMM operation; it should say "Sparse General Matrix Multiplication (GEMM-sp)".
- The `transpose_E` parameter (line 98) is missing from the docstring's Args section.

Apply this diff to fix the docstring:

```diff
-    """Perform a General Matrix Multiplication (GEMM) operation.
+    """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation.

     This function computes C = A @ B where A and B can optionally be transposed.
     The operation supports various warp policies and accumulation modes.

     Args:
         A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements
         E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E
         B (Union[tir.Buffer, tir.Var]): Second input matrix
         C (Union[tir.Buffer, tir.Var]): Output matrix for results
         transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
         transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_E (bool, optional): Whether to transpose metadata matrix E. Defaults to False.
         policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
```

🤖 Prompt for AI Agents
In tilelang/language/experimental/gemm_sp.py around lines 104 to 126, the
docstring title and Args are inaccurate: change the short description from
"General Matrix Multiplication" to "Sparse General Matrix Multiplication
(GEMM-sp)" and add the missing transpose_E parameter to the Args list (e.g.,
transpose_E (bool, optional): Whether to transpose metadata E. Defaults to
False.), keeping style and default value consistent with transpose_A/transpose_B
entries.
Actionable comments posted: 0
♻️ Duplicate comments (5)
tilelang/language/experimental/gemm_sp.py (3)
105-105: Update docstring title to reflect the sparse GEMM operation. The docstring says "General Matrix Multiplication" but this function performs a sparse GEMM operation. Apply this diff:

```diff
-    """Perform a General Matrix Multiplication (GEMM) operation.
+    """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation.
```

Based on past review comments, this issue was previously flagged but the fix may have been incomplete.
115-120: Document the `transpose_E` parameter in the docstring. The function signature includes `transpose_E` (line 99), but it's missing from the Args section of the docstring. Apply this diff:

```diff
     transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
     transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+    transpose_E (bool, optional): Whether to transpose metadata matrix E. Defaults to False.
     policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
```
258-258: Fix the incorrect return type annotation. The return type is annotated as `tir.PrimExpr`, but the function returns `list[int]` at lines 261, 267, and 273. Apply this diff:

```diff
-    def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
+    def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> list[int]:
         """Retrieve the offset of the buffer or buffer region."""
```

Based on past review comments, this issue was previously identified but not yet addressed.
src/op/gemm_sp.cc (1)
321-322: Return the computed warp partition from the FFI binding. The lambda calls `policy->computeWarpPartition(...)`, which returns `std::pair<int, int>` (see lines 22-23 and line 63), but the result is discarded and the lambda returns void. Python callers expect to receive the computed `(m_warp, n_warp)` values. Apply this diff:

```diff
-          policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);
-          return;
+          return policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);
```

Note: A past review comment indicated this was addressed in commits 354e9af to b2871dd, but the current code still shows the issue. Please verify whether the fix was applied to a different branch or whether this still needs to be corrected.
tilelang/intrinsics/mma_macro_generator.py (1)
424-425: Verify reachability: the 16-bit dtype branch of the non-ldmatrix path appears unreachable. Similar to matrix A, the condition at line 416 ensures `ldmatrix_available = True` for all 16-bit dtypes (the condition simplifies to `bits == 16 or b_transposed`). This makes the check at line 424 unreachable. This issue mirrors the one flagged for matrix A at lines 296-297. If the earlier verification confirms this is intentional for future sparse support, this code should also be documented accordingly.
🧹 Nitpick comments (3)
src/tl_templates/cuda/debug.h (1)
111-119: Specialization is fine; consider trait-based reuse and an RTC include check. The `uint16_t` specialization looks correct and the `%u` + `uint32_t` cast is safe. Two small follow-ups:

- For consistency with the rest of this file, you might instead specialize `PrintTraits<uint16_t>` (or alias it to `PrintTraits<unsigned short>`) and let the generic `debug_print_buffer_value` call through the trait, so `debug_print_var<uint16_t>` also prints a meaningful value instead of going through the generic "dtype=unknown" path.
- This specialization introduces an unconditional use of `uint16_t`/`uint32_t`; please double-check that these typedefs are available in all compile modes (especially under `__CUDACC_RTC__`, where `<cstdint>` is currently skipped) so NVRTC builds don't fail.

tilelang/language/experimental/gemm_sp.py (1)
129-140: Consider extracting the duplicated `legalize_arguments` helper. The `legalize_arguments` function (lines 129-140) is identical to the one in `gemm_sp` (lines 46-57 in tilelang/language/gemm.py, per the relevant snippets). Extracting this to a shared utility would reduce duplication. You could move `legalize_arguments` to a shared location (e.g., `tilelang.utils.language`) and import it in both functions, similar to how `get_buffer_region_from_load` is imported.

tilelang/layout/gemm_sp.py (1)
137-151: Consider validating required parameters for SM90. The dispatcher passes `**extra_args` to `make_cutlass_metadata_layout_sm90`, which requires a `block_k` parameter. If `extra_args` doesn't contain `block_k`, this will raise a `TypeError` at runtime. Consider adding parameter validation or making the requirement explicit:

```diff
 def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer,
                                  mma_dtype: str = "float16",
                                  arch: str | None = None,
-                                 **extra_args):
+                                 block_k: int | None = None,
+                                 **extra_args):
     if arch is None:
         arch = nvcc.get_target_compute_version()
     compute_version = nvcc.parse_compute_version(arch)
     if compute_version >= (9, 0):
+        if block_k is None:
+            raise ValueError("block_k is required for SM90 and above")
-        return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args)
+        return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, block_k=block_k, **extra_args)
     elif compute_version >= (8, 0):
         return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype)
     else:
         raise NotImplementedError(f"Unsupported architecture: {arch}")
```
📒 Files selected for processing (9)
- src/op/gemm_sp.cc (1 hunks)
- src/op/gemm_sp.h (1 hunks)
- src/tl_templates/cuda/debug.h (1 hunks)
- tilelang/intrinsics/mma_macro_generator.py (3 hunks)
- tilelang/language/__init__.py (1 hunks)
- tilelang/language/experimental/gemm_sp.py (2 hunks)
- tilelang/layout/gemm_sp.py (5 hunks)
- tilelang/profiler/__init__.py (1 hunks)
- tilelang/utils/tensor.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- tilelang/utils/tensor.py
- src/op/gemm_sp.h
- tilelang/profiler/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
tilelang/layout/gemm_sp.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
tilelang/layout/gemm_sp.py
🧬 Code graph analysis (4)
src/op/gemm_sp.cc (3)
- src/op/gemm_sp.h (1): RegisterReflection (27-33)
- src/op/gemm.h (2): RegisterReflection (35-41, 106-128)
- tilelang/tileop/gemm_sp/gemm_sp_base.py (3): policy (130-131), M (33-34), N (37-38)
tilelang/language/__init__.py (1)
- tilelang/language/experimental/gemm_sp.py (2): gemm_sp (10-88), gemm_sp_v2 (92-310)
tilelang/layout/gemm_sp.py (1)
- tilelang/contrib/nvcc.py (2): get_target_compute_version (401-442), parse_compute_version (445-467)
tilelang/language/experimental/gemm_sp.py (4)
- tilelang/utils/language.py (1): get_buffer_region_from_load (162-193)
- tilelang/tileop/gemm_sp/gemm_sp_base.py (6): E (74-75), B (78-79), C (82-83), policy (130-131), M (33-34), N (37-38)
- src/op/gemm.h (4): GemmWarpPolicy (59-83, 64-68, 70-74, 76-82)
- tilelang/language/gemm.py (1): legalize_arguments (36-47)
🪛 Ruff (0.14.5)
tilelang/language/__init__.py
54-54: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/layout/gemm_sp.py
115-115: Avoid specifying long messages outside the exception class
(TRY003)
119-119: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/language/experimental/gemm_sp.py
163-164: Prefer TypeError exception for invalid type
(TRY004)
163-164: Avoid specifying long messages outside the exception class
(TRY003)
191-192: Prefer TypeError exception for invalid type
(TRY004)
191-192: Avoid specifying long messages outside the exception class
(TRY003)
255-256: Prefer TypeError exception for invalid type
(TRY004)
255-256: Avoid specifying long messages outside the exception class
(TRY003)
275-276: Prefer TypeError exception for invalid type
(TRY004)
275-276: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (8)
tilelang/language/__init__.py (1)
54-54: LGTM! Public API export is correct. The addition of `gemm_sp_v2` to the public API export is appropriate and consistent with the pattern used for other experimental features. The static analysis hint about the unused `noqa` directive is a false positive: the directive is standard practice for public API modules where imports are re-exported rather than used locally.

src/op/gemm_sp.cc (1)
313-324: LGTM! FFI registration structure is well-organized. The FFI registration block properly registers both reflection entries and the global function definition, following the established pattern from `GemmWarpPolicyNode` (see src/op/gemm.h:34-40). The registration structure enables proper exposure of the `GemmSPWarpPolicy` type and its methods to the Python FFI layer.

tilelang/intrinsics/mma_macro_generator.py (2)
25-25: LGTM! New layout imports follow existing conventions. The newly imported layout functions follow the established naming pattern and are used in the ldmatrix fallback paths below.
Also applies to: 28-28
296-297: I'm unable to access the repository due to persistent clone failures. Without being able to examine the actual code at lines 280-310 in `tilelang/intrinsics/mma_macro_generator.py`, I cannot definitively verify whether the 16-bit dtype path is unreachable or confirm the logic of the `ldmatrix_available` condition at line 288. To complete this verification, I would need to:

- Examine the complete context around lines 280-310
- Verify the actual `ldmatrix_available` logic
- Search for sparse tensor constraints or code paths that might affect 16-bit handling
The review comment requires manual verification of the reachability claim.
tilelang/layout/gemm_sp.py (4)
33-46: LGTM! fp8 dtype support correctly added. The addition of `float8_e4m3` and `float8_e5m2` to both the validation check and the bits_map is correct and consistent with the 8-bit handling logic.
78-80: LGTM! 8-bit dtype handling correctly parameterized. The parameterization of shape and stride calculations using `block_k` for 8-bit types is consistent with how other bit-widths handle block_k (through NumK).
117-119: Verify dtype support consistency across SM8x and SM90; uint8 may be missing from SM90. Line 117 includes `"uint8"` in the SM8x validation check. According to CUTLASS architecture documentation, both SM8x (Ampere) and SM90 (Hopper) have hardware support for INT8/UINT8 tensor-core operations. If the SM90 layout function (line 33) omits `"uint8"` from its dtype list, this represents an inconsistency that should be addressed to ensure feature parity across architectures, unless there is an explicit architectural limitation or design decision not shown in the provided context.
107-134: Based on my verification efforts, I found important information but cannot fully validate the implementation. Here's what I discovered:

PyTorch reference implementation (from `_calculate_meta_reordering_scatter_offsets`):

- Uses `group_x = 64` (fixed) and `group_y = 32` or `16` (based on `dtype.itemsize`)
- Has a specific row reordering formula with multiple steps
- Uses `interleave = 2` (fixed)
- Final offset: `cols_maj * m * interleave + dst_rows * interleave + cols_min`

TileLang implementation (from the review):

- Uses `group = 32` or `16` (not 64)
- Uses `interweave = 4` or `2` (variable)
- Row calculation: `i = i // group * group + (i % 8) * interweave + (i % group) // 8`
- Final offset: `(j // 2) * m * 2 + i * 2 + (j % 2)`

The algorithms differ substantially from the PyTorch reference. However, I could not find:

- Public documentation of TileLang's specific sparse metadata layout algorithm
- Tests or specifications validating the implementation
- An explanation of why the approach differs from PyTorch's CUTLASS reference

The algorithm used in `ColumnMajorInterleaved` differs significantly from the PyTorch/CUTLASS reference implementation. Verify whether this is intentional or a bug by:

- Checking the original implementation before this change for comparison
- Running existing sparse GEMM tests to confirm correctness
- Documenting the reasoning if this is an intentional algorithmic divergence from CUTLASS
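To make the comparison above concrete, the two quoted offset formulas can be transcribed into plain Python. This is a standalone illustration, not the actual TileLang or PyTorch code; the defaults for `group`, `interweave`, and `interleave` follow the values quoted in the comment above.

```python
def tilelang_meta_offset(i: int, j: int, m: int, group: int = 32, interweave: int = 4) -> int:
    """Transcription of the quoted TileLang metadata reordering.

    i, j: logical row/column of a metadata element; m: number of rows.
    """
    # Row remap quoted in the review comment.
    i = i // group * group + (i % 8) * interweave + (i % group) // 8
    # Final column-major-interleaved offset quoted in the review comment.
    return (j // 2) * m * 2 + i * 2 + (j % 2)


def pytorch_meta_offset(dst_rows: int, cols_maj: int, cols_min: int, m: int, interleave: int = 2) -> int:
    """Transcription of the quoted PyTorch reference offset (interleave fixed at 2)."""
    return cols_maj * m * interleave + dst_rows * interleave + cols_min


# Both layouts agree at the origin but diverge elsewhere.
print(tilelang_meta_offset(0, 0, m=32))  # -> 0
print(tilelang_meta_offset(9, 1, m=32))  # row remap maps i=9 to 5, then -> 11
```

Enumerating both functions over a small grid is a quick way to see exactly where the two layouts first disagree.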
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/op/gemm_sp_py.cc (1)
216-227: Guard `std::stoi` in `GetArchInt` to avoid exceptions on malformed `arch`

`GetArchInt` still calls `std::stoi(arch.substr(3))` directly. If the `arch` string is `"sm_abc"` or otherwise non-numeric after `"sm_"`, this will throw `std::invalid_argument` or `std::out_of_range` and crash instead of falling back to `0`.

You can make this more robust by catching conversion errors and logging a warning:

```diff
 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
   ICHECK(s.has_value());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
+    try {
+      arch_int = std::stoi(arch.substr(3));
+    } catch (const std::exception &e) {
+      LOG(WARNING) << "Failed to parse architecture number from '" << arch
+                   << "': " << e.what();
+      arch_int = 0;
+    }
   } else {
     arch_int = 0;
   }
   return arch_int;
 }
```

This keeps the documented behavior ("return 0 if it doesn't match `sm_<num>`") even when the attribute content is unexpected.
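The guarded behavior can be mirrored in a quick Python sketch (illustrative only; the actual fix is the C++ change above):

```python
def get_arch_int(arch: str) -> int:
    """Return the numeric suffix of an 'sm_<num>' arch string, or 0 on any
    mismatch or parse failure (mirrors the guarded C++ GetArchInt above)."""
    if not arch.startswith("sm_"):
        return 0
    try:
        return int(arch[3:])
    except ValueError:
        return 0

print(get_arch_int("sm_89"))   # -> 89
print(get_arch_int("sm_abc"))  # -> 0 (would have thrown in the unguarded version)
print(get_arch_int("cuda"))    # -> 0
```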
🧹 Nitpick comments (2)
tilelang/language/experimental/gemm_sp.py (1)
112-134: Align `gemm_sp_v2` docstring with sparse semantics and arguments

The docstring still describes a generic GEMM and doesn't mention `transpose_E`, even though this is a sparse GEMM with explicit metadata and a `transpose_E` parameter.

Consider updating the text to make this clearer, e.g.:

```diff
-    """Perform a General Matrix Multiplication (GEMM) operation.
+    """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation.
@@
-        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_E (bool, optional): Whether to transpose metadata matrix E. Defaults to False.
```

This keeps the Python API docs in sync with the underlying operator semantics.
src/op/gemm_sp_py.cc (1)
24-45: Update ctor argument layout comment to include `trans_E`

The implementation reads `trans_E` from `args[6]`, but the "expected layout" comment only lists `trans_A` and `trans_B`. To avoid confusion when maintaining the Python side, consider updating the comment block to reflect the extra flag, e.g.:

```diff
- * [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
- *  M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
+ * [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
+ *  trans_E (Bool),
+ *  M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
  *  stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
  *  (optional) kPack (Int), (optional) wg_wait (Int)]
```

The ctor logic itself looks consistent with the Python call site.

Also applies to: 65-85
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/op/gemm_sp_py.cc (1 hunks)
- src/op/gemm_sp_py.h (1 hunks)
- tilelang/language/experimental/gemm_sp.py (2 hunks)
- tilelang/tileop/gemm_sp/gemm_sp_base.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tilelang/tileop/gemm_sp/gemm_sp_base.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/op/gemm_sp_py.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/op/gemm_sp_py.cc
🧬 Code graph analysis (3)
src/op/gemm_sp_py.h (1)
src/op/gemm_sp.h (2)
- tvm (13-117)
- RegisterReflection (27-33)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (5)
- to_buffer_region (196-238), retrieve_shape (241-258), retrieve_stride (261-279), retrieve_offset (359-376), prim_expr_equal (390-405)
tilelang/language/utils.py (1)
- buffer_region_to_tile_region (23-35)
src/op/gemm_sp_py.cc (3)
src/op/gemm_sp_py.h (1)
- GemmSPPy (83-89)
tilelang/tileop/gemm_sp/__init__.py (1)
- GemmSPPy (29-69)
tilelang/tileop/gemm_sp/gemm_sp_base.py (3)
- policy (130-131), M (33-34), N (37-38)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (2)
tilelang/language/experimental/gemm_sp.py (1)
6-14: Imports and existing `gemm_sp` wiring look consistent. Using the shared helpers (`to_buffer_region`, `retrieve_shape`/`retrieve_stride`/`retrieve_offset`, `prim_expr_equal`, `buffer_region_to_tile_region`) in this module keeps the sparse GEMM path aligned with the common utilities; the existing `gemm_sp` implementation remains consistent and doesn't introduce new issues in this diff.

Also applies to: 53-95

src/op/gemm_sp_py.h (1)

21-88: `GemmSPPy` node/operator definition and reflection look consistent. The layout of `GemmSPPyNode` (buffers, regions, transposes, sizes, strides, offsets, `clear_accum`, `kPack`/`wg_wait`, and `policy`) matches how the ctor in `gemm_sp_py.cc` initializes these fields, and `RegisterReflection` exposes the expected attributes for Python/FFI use. The `GemmSPPy` wrapper and `TVM_FFI` macros also follow the existing TileOperator patterns.

No issues from the header side in this diff.
```python
A_offset = retrieve_offset(A_sparse)
B_offset = retrieve_offset(B)
assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
offset_a = A_offset[-1]
offset_b = B_offset[-1]

A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
E_arg = buffer_region_to_tile_region(E_region, "r", [r for r in E_shape])
B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
```
Use `BufferRegion` when computing offsets, not the raw `Buffer`

`retrieve_offset(A_sparse)` / `retrieve_offset(B)` operate on the underlying `tir.Buffer`, which always returns zeros for all dimensions. This means:

- The assertions on `A_offset[-2]` / `B_offset[-2]` will pass even if the actual region has a non-zero minimum on that dimension.
- `offset_a` / `offset_b` passed into `tl.gemm_sp_py` will ignore any per-region minima encoded in `A_region` / `B_region`.

Since you already have `A_region` / `B_region` as `tir.BufferRegion`s, it's more accurate to derive offsets from those so sub-regions are handled correctly:

```diff
-    A_offset = retrieve_offset(A_sparse)
-    B_offset = retrieve_offset(B)
+    A_offset = retrieve_offset(A_region)
+    B_offset = retrieve_offset(B_region)
```

This keeps the offset computation consistent with how `retrieve_offset` is defined in `tilelang.utils.language` and ensures non-zero minima on the last dimension are reflected in `offset_a` / `offset_b`.
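The buffer-versus-region distinction can be illustrated with a toy model. The classes below are hypothetical stand-ins, not the real `tir.Buffer`/`tir.BufferRegion`; they only capture why offsets from a raw buffer are always zero while a region carries its minima.

```python
from dataclasses import dataclass

@dataclass
class Buffer:
    """Toy stand-in for tir.Buffer: shape only, no region information."""
    shape: tuple

@dataclass
class BufferRegion:
    """Toy stand-in for tir.BufferRegion: carries per-dimension minima."""
    buffer: Buffer
    mins: tuple
    extents: tuple

def retrieve_offset(obj):
    """Mimic the behavior described above: a raw Buffer yields all-zero
    offsets, while a BufferRegion exposes its actual minima."""
    if isinstance(obj, Buffer):
        return [0] * len(obj.shape)
    return list(obj.mins)

buf = Buffer(shape=(128, 64))
region = BufferRegion(buffer=buf, mins=(0, 16), extents=(128, 32))

print(retrieve_offset(buf))     # [0, 0]: the region's minima are lost
print(retrieve_offset(region))  # [0, 16]: last-dim offset preserved
```

This is exactly why the diff above swaps `A_sparse`/`B` for `A_region`/`B_region`: only the region form can report a non-zero minimum.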
🤖 Prompt for AI Agents
In tilelang/language/experimental/gemm_sp.py around lines 188-198, the code
currently calls retrieve_offset on the raw Buffers (A_sparse and B) which always
yields zeros; replace those calls to use the corresponding BufferRegion objects
(A_region and B_region) so offsets reflect per-region minima; update the
assertions to check the second-to-last element of the offsets returned from the
regions and set offset_a/offset_b from the last element of those region-derived
offsets before passing them into tl.gemm_sp_py.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tilelang/utils/sparse.py (2)
109-128: Update docstring to reflect conditional sparsity pattern. The docstring states "2:4 sparsity" but the implementation now uses 1:2 sparsity for `float32`. Consider updating the documentation to clarify this behavior.

```diff
 def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transposed: bool = False):
     """
-    Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension.
+    Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension
+    (or 1:2 sparsity for float32).

     Args:
         M (int): Number of rows
         K (int): Number of columns
```

161-184: Update docstring to reflect conditional sparsity pattern. Same as `randn_semi_sparse`, the docstring states "2:4 sparsity" but uses 1:2 sparsity for `float32`.

```diff
 def arange_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transposed: bool = False):
     """
-    Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension.
+    Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity
+    along the K dimension (or 1:2 sparsity for float32).

     Args:
         M (int): Number of rows
         K (int): Number of columns
```
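As a reference for the 2:4 pattern these docstrings describe, here is a dependency-free sketch of pruning a row to 2:4 sparsity (keep the two largest-magnitude values in each group of four). This illustrates the pattern only; it is not the actual `randn_semi_sparse` implementation.

```python
def prune_2_to_4(row):
    """Zero out the two smallest-magnitude entries in each group of 4."""
    assert len(row) % 4 == 0
    out = list(row)
    for g in range(0, len(row), 4):
        group = list(range(g, g + 4))
        # Indices of the two smallest-magnitude elements in this group.
        drop = sorted(group, key=lambda i: abs(out[i]))[:2]
        for i in drop:
            out[i] = 0.0
    return out

row = [0.5, -2.0, 0.1, 3.0, -1.0, 0.2, 0.9, -0.3]
print(prune_2_to_4(row))  # [0.0, -2.0, 0.0, 3.0, -1.0, 0.0, 0.9, 0.0]
```

The 1:2 variant used for `float32` is the same idea with groups of two: keep the single largest-magnitude value per pair.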
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- tilelang/profiler/__init__.py (2 hunks)
- tilelang/utils/sparse.py (5 hunks)
- tilelang/utils/tensor.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/utils/sparse.py (1)
tilelang/utils/tensor.py (2)
- is_float8_dtype (8-14)
- fp8_remove_negative_zeros_ (17-21)
tilelang/profiler/__init__.py (1)
tilelang/utils/tensor.py (1)
- is_float8_dtype (8-14)
tilelang/utils/tensor.py (3)
tilelang/language/v2/dtypes.py (5)
- float8_e5m2 (266-266), float8_e5m2fnuz (273-273), float8_e4m3fn (252-252), float8_e4m3fnuz (259-259), uint8 (182-182)
tilelang/language/proxy.py (1)
- Tensor (252-253)
tilelang/language/customize.py (1)
- view (55-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (8)
tilelang/utils/tensor.py (3)
8-14: LGTM! Clean implementation of float8 dtype detection using set membership. This centralizes the float8 check for reuse across the codebase.
17-21: LGTM! Correctly uses PyTorch's in-place operation naming convention (trailing underscore). The implementation properly handles negative zeros in float8 by viewing the underlying bytes as uint8 and zeroing out positions where the original value was zero.
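The negative-zero trick described above can be illustrated without torch: for the e4m3fn and e5m2 encodings the sign bit is the MSB, so -0.0 is the single byte 0x80 and +0.0 is 0x00 (the fnuz variants instead reuse 0x80 for NaN). Below is a dependency-free sketch of the same byte-level normalization; it is a hypothetical helper, not the actual `fp8_remove_negative_zeros_`.

```python
def remove_negative_zeros_fp8(raw: bytearray) -> bytearray:
    """Map the fp8 bit pattern 0x80 (-0.0 in e4m3fn/e5m2) to 0x00 (+0.0) in place.

    Mirrors viewing a tensor's storage as uint8 and zeroing the bytes of
    elements that compare equal to zero.
    """
    for idx, byte in enumerate(raw):
        if byte == 0x80:  # sign bit set, all exponent/mantissa bits clear
            raw[idx] = 0x00
    return raw

data = bytearray([0x80, 0x3C, 0x00, 0x80])
print(list(remove_negative_zeros_fp8(data)))  # [0, 60, 0, 0]
```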
24-31: LGTM! The enum extension provides new tensor supply strategies. Note that `Normal` (line 27) and `Randn` (line 28) represent different distributions: `Normal` uses `normal_(-1.0, 1.0)` with custom mean/std, while `Randn` uses the standard normal distribution via `torch.randn`.

tilelang/profiler/__init__.py (2)
9-14: LGTM!Good refactoring to import the centralized
is_float8_dtypehelper instead of maintaining a local implementation. This improves code reusability and consistency.
129-137: LGTM! Correct handling of float8 dtypes by converting to float32 before comparison, since float8 types have limited operation support in PyTorch.
tilelang/utils/sparse.py (3)
6-6: LGTM! Appropriate import of float8 utilities to support the compression path enhancements.
131-158: LGTM! The new `randint_semi_sparse` function correctly follows the same pattern as `randn_semi_sparse`. Note that the docstring should also mention the 1:2 sparsity for `float32` (though `float32` is unusual for integer tensors).
92-103: Automated verification of the metadata layout consistency concern with transposed inputs did not complete; `compress_sm80`, `decompress`, and the handling of the `transposed` flag should be checked manually.
Checklist
4090 mini benchmark (2 experiments):
- Torch (CUTLASS backend)
- Torch (cuSPARSELt backend)
- TileLang Sparse (fp32 accum)
Summary by CodeRabbit
New Features
Documentation
Updates
Tests