Skip to content

Conversation

@botbw
Copy link
Contributor

@botbw botbw commented Oct 17, 2025

Checklist

  • bf16/fp16
  • customized metadata layout
  • tf32 (with precision issue(
  • int8
  • fp8
  • different scopes
    • sss
    • srs
    • rss
    • rrs
    • metadata in register (?)
  • custom compression utils example
  • Doc

4090 mini benchmark

mnk Ref TFLOPS (CuBLAS Dense)
(2 experiments)
Ref TFLOPS
(Torch, CUTlASS backend)
Ref TFLOPS
(Torch, CUSPARSELT backend)
Best TFLOPS
(TileLang Sparse, fp32 accum)
16384 154.373 / 156.286 194.450 183.236 278.731

Summary by CodeRabbit

  • New Features

    • Public GEMM SP v2 and GemmSPPy operator; configurable sparse Tensor‑Core emitter with richer sparse MMA and transpose/metadata options.
  • Documentation

    • Added sparse matmul guide with examples, migration notes, and index entry.
  • Updates

    • Adopted Cutlass-style metadata layouts, added float8 support and fp8 utilities, improved load/layout helpers and debug/packing utilities.
  • Tests

    • New examples and end‑to‑end tests covering sparse GEMM variants, dtypes, and custom compression workflows.

✏️ Tip: You can customize this high-level summary in your review settings.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 17, 2025

Walkthrough

Adds a new gemm_sp_v2 TileLang API and TL operator (GemmSPPy) with FFI wiring and GemmSPWarpPolicy, a SparseTensorCore intrinsics emitter plus many SP layout helpers, refactors metadata layout to make_cutlass_metadata_layout, updates examples/benchmarks/tests, and exposes related utilities and tests.

Changes

Cohort / File(s) Summary
FFI / C++ op & reflection
src/op/gemm_sp.cc, src/op/gemm_sp.h
Register tl.GemmSPWarpPolicy, add reflection registration and FFI entry GemmSPWarpPolicyComputeWarpPartition.
C++ TileOp wiring
src/op/gemm_sp_py.cc, src/op/gemm_sp_py.h, src/op/gemm_sp_py.cc
New GemmSPPy TileOperator: arg deserialization, Clone, GetGemmInst/CheckWGMMA, Lower/InferLayout, TL op registration and static reflection init.
TileLang public API
tilelang/language/experimental/gemm_sp.py, tilelang/language/__init__.py
Add and export gemm_sp_v2 that legalizes arguments, computes shapes/strides/offsets, checks K consistency, and dispatches to tl.gemm_sp_py.
TileOp layer (Python)
tilelang/tileop/gemm_sp/*, tilelang/tileop/__init__.py, tilelang/tileop/gemm_sp/gemm_sp_base.py, tilelang/tileop/gemm_sp/gemm_sp_mma.py
Add GemmSPBase accessors, GemmSPPy export, and GemmSPMMA implementing infer_layout and lower for ss/sr/rs/rr patterns.
Sparse emitter & layouts
tilelang/intrinsics/mma_sp_macro_generator.py, tilelang/intrinsics/mma_sp_layout.py, tilelang/intrinsics/mma_layout.py, tilelang/intrinsics/mma_sp_layout.py, tilelang/intrinsics/mma_macro_generator.py
Add SparseTensorCoreIntrinEmitter (ldmatrix/loads/mma_sp/stmatrix), many SP layout helpers, and 32x8→16x16 load-layout helpers for non-ldmatrix fallbacks.
Layout API refactor
tilelang/layout/gemm_sp.py, tilelang/layout/__init__.py
Replace make_metadata_layout with make_cutlass_metadata_layout, remove backend arg, add SM90/SM8x creators and arch dispatch; update callers.
TileLang IR / policy
tilelang/ir.py
Add GemmSPWarpPolicy class and compute_warp_partition(..., bits) delegating to FFI.
TileOp typing
tilelang/tileop/gemm/__init__.py
Add type hints to gemm_py_infer_layout and gemm_py_lower signatures.
Utils: tensor & sparse
tilelang/utils/tensor.py, tilelang/utils/sparse.py
Add is_float8_dtype, fp8_remove_negative_zeros_, extend TensorSupplyType; add randint_semi_sparse and dtype-aware compress/randn behavior.
Templates / Debug / Common
src/tl_templates/cuda/debug.h, src/tl_templates/cuda/common.h
Add debug_print_buffer_value<uint16_t> specialization and new make_int4(short...) overload.
Benchmarks / Examples
benchmark/matmul/benchmark_matmul_sp.py, examples/gemm_sp/*, examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py, examples/gemm_sp/example_custom_compress.py
Update matmul_sp signature to accept in_dtype and call T.gemm_sp_v2; switch to make_cutlass_metadata_layout; add example custom compressor and config constants; adjust imports and CLI defaults.
Tests
examples/gemm_sp/test_example_gemm_sp.py, testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py, testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py
Add example tests; add/refactor comprehensive gemm_sp and gemm_sp_v2 tests with dtype-aware input generators and dense-reference comparisons.
Docs / Index
docs/deeplearning_operators/matmul_sparse.md, docs/index.md
Add new matmul_sparse documentation and register it in docs index.
Profiler exports
tilelang/profiler/__init__.py
Import and re-export is_float8_dtype from tilelang.utils.tensor (replace local implementation).

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant TileLang as gemm_sp_v2
    participant TL_Op as GemmSPPy
    participant TileOp as GemmSPMMA
    participant Emitter as SparseTensorCoreIntrinEmitter

    User->>TileLang: call gemm_sp_v2(A_sparse, E, B, C, ...)
    TileLang->>TL_Op: construct GemmSPPy node (buffers, args)
    TL_Op->>TileOp: infer_layout(target, thread_nums)
    TileOp->>Emitter: build emitter for pattern (ss/sr/rs/rr)
    Emitter->>Emitter: ldmatrix / make_mma_load_layout / mma_sp / stmatrix
    Emitter-->>TileOp: fragment/layout map
    TileOp-->>TL_Op: layout_map
    TL_Op->>TileOp: lower(target, thread_nums, thread_var)
    TileOp->>Emitter: emit lowering -> prim_func
    TileOp-->>TL_Op: lowered kernel
    TL_Op-->>User: compiled kernel handle
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

  • Attention points:
    • tilelang/intrinsics/mma_sp_macro_generator.py — large emitter with dense index/thread math and dtype-conditional branches.
    • tilelang/tileop/gemm_sp/gemm_sp_mma.py — multiple kernel variants, warp partitioning, lowering complexity.
    • src/op/gemm_sp_py.cc / src/op/gemm_sp_py.h — FFI deserialization, pointer/stride handling and lowering hooks.
    • tilelang/layout/gemm_sp.py — metadata layout math, arch dispatch, FP8 handling.
    • benchmark/examples/tests — confirm API updates (matmul_sp signature, in_dtype propagation, make_cutlass_metadata_layout, T.gemm_sp_v2).

Possibly related issues

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • chengyupku

Poem

🐇 I hopped through fragments, layouts bright,
I packed int4s and tuned the thread's flight.
gemm_sp_v2 hums, metadata in sight,
Kernels compiled, sparse math feels light.
A rabbit cheers — kernels take flight!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 22.88% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly identifies the main change: adding support for T.gemm_sp_v2 on SM80 and SM89 architectures, which directly corresponds to the primary objective of this PR.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@botbw botbw force-pushed the gemm_sp_v2 branch 3 times, most recently from 016dd1c to 122abb5 Compare October 20, 2025 07:07
@botbw botbw marked this pull request as ready for review November 5, 2025 15:43
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
docs/deeplearning_operators/matmul_sparse.md (1)

39-39: Improve link text for accessibility.

The duplicate [here] links violate MD059 and make the doc harder to navigate with assistive tech. Please replace them with descriptive titles (e.g., [PyTorch sparse kernel], [vLLM sparse kernel]).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 82261e7 and d7ca20e.

📒 Files selected for processing (5)
  • docs/deeplearning_operators/matmul_sparse.md (1 hunks)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1 hunks)
  • tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
  • tilelang/language/experimental/gemm_sp.py (2 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
🧰 Additional context used
🧠 Learnings (3)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • docs/deeplearning_operators/matmul_sparse.md
  • tilelang/intrinsics/mma_sp_macro_generator.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • tilelang/intrinsics/mma_sp_macro_generator.py
🧬 Code graph analysis (4)
tilelang/tileop/gemm_sp/gemm_sp_mma.py (6)
tilelang/tileop/gemm_sp/gemm_sp_base.py (20)
  • GemmSPBase (11-127)
  • infer_layout (14-15)
  • policy (126-127)
  • M (33-34)
  • N (37-38)
  • in_dtype (57-59)
  • e_dtype (53-54)
  • accum_dtype (62-63)
  • trans_A (45-46)
  • trans_B (49-50)
  • K (41-42)
  • is_gemm_ss (20-21)
  • A (66-67)
  • B (74-75)
  • C (78-79)
  • is_gemm_sr (23-24)
  • is_gemm_rs (26-27)
  • is_gemm_rr (29-30)
  • lower (17-18)
  • E (70-71)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/intrinsics/mma_sp_macro_generator.py (6)
  • make_mma_store_layout (789-858)
  • make_mma_load_layout (646-787)
  • ldmatrix_a (290-354)
  • ldmatrix_e (356-418)
  • ldmatrix_b (420-516)
  • mma_sp (518-589)
tilelang/utils/language.py (2)
  • is_shared (25-39)
  • is_fragment (81-91)
tilelang/transform/simplify.py (1)
  • _Simplify (31-49)
tilelang/tileop/gemm_sp/__init__.py (2)
  • infer_layout (56-61)
  • lower (63-69)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (137-159)
tilelang/language/gemm.py (10)
  • legalize_arguments (48-59)
  • legalize_arguments (251-262)
  • retrieve_shape (66-83)
  • retrieve_shape (269-286)
  • retrieve_stride (85-111)
  • retrieve_stride (288-314)
  • retrieve_ptr (140-175)
  • retrieve_ptr (343-378)
  • retrieve_offset (177-195)
  • retrieve_offset (380-398)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (6)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (131-158)
tilelang/utils/tensor.py (2)
  • torch_assert_close (237-329)
  • map_torch_type (37-54)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
  • SparseTensorCoreIntrinEmitter (40-858)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp_v2 (91-307)
tilelang/layout/swizzle.py (1)
  • make_swizzled_layout (10-18)
tilelang/intrinsics/mma_sp_macro_generator.py (5)
tilelang/intrinsics/utils.py (2)
  • mma_store_index_map (81-82)
  • get_ldmatrix_offset (21-63)
tilelang/utils/language.py (1)
  • is_fragment (81-91)
tilelang/intrinsics/mma_sp_layout.py (20)
  • shared_16x16_to_mma_sp_layout_sr_a (14-15)
  • shared_16x16_to_mma_sp_layout_sr_b (18-20)
  • shared_16x32_to_mma_sp_layout_sr_a (23-24)
  • shared_16x32_to_mma_sp_layout_sr_b (27-29)
  • shared_16x64_to_mma_sp_layout_sr_a (32-33)
  • shared_16x64_to_mma_sp_layout_sr_b (36-38)
  • mma_sp_load_a_32x4_to_shared_16x16_layout (41-42)
  • mma_sp_load_a_32x8_to_shared_16x32_layout (45-46)
  • mma_sp_load_a_32x16_to_shared_16x64_layout (49-50)
  • mma_sp_load_b_32x8_to_shared_16x16_layout (53-56)
  • mma_sp_load_b_32x16_to_shared_16x32_layout (59-62)
  • mma_sp_load_b_32x32_to_shared_16x64_layout (65-68)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_32bit (75-80)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_32bit (83-88)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_16bit (91-94)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_16bit (97-100)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_8bit (107-112)
  • metadata_16bit_load_32x2_to_shared_16x4_layout_8bit (115-120)
  • metadata_32bit_load_32x1_to_shared_16x2_layout_8bit (123-129)
  • get_ldmatrix_offset_b (156-190)
tilelang/language/tir/op.py (2)
  • ptx_ldmatrix (1313-1349)
  • ptx_mma_sp (964-1062)
tilelang/layout/fragment.py (2)
  • replicate (147-161)
  • repeat (124-145)
🪛 markdownlint-cli2 (0.18.1)
docs/deeplearning_operators/matmul_sparse.md

39-39: Link text should be descriptive

(MD059, descriptive-link-text)


39-39: Link text should be descriptive

(MD059, descriptive-link-text)


176-176: Link text should be descriptive

(MD059, descriptive-link-text)

🪛 Ruff (0.14.3)
tilelang/tileop/gemm_sp/gemm_sp_mma.py

57-58: Avoid specifying long messages outside the exception class

(TRY003)


232-233: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/experimental/gemm_sp.py

161-162: Prefer TypeError exception for invalid type

(TRY004)


161-162: Avoid specifying long messages outside the exception class

(TRY003)


189-190: Prefer TypeError exception for invalid type

(TRY004)


189-190: Avoid specifying long messages outside the exception class

(TRY003)


253-254: Prefer TypeError exception for invalid type

(TRY004)


253-254: Avoid specifying long messages outside the exception class

(TRY003)


273-274: Prefer TypeError exception for invalid type

(TRY004)


273-274: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_sp_macro_generator.py

52-60: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


62-119: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


121-129: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


178-180: Avoid specifying long messages outside the exception class

(TRY003)


212-212: Avoid specifying long messages outside the exception class

(TRY003)


313-313: Avoid specifying long messages outside the exception class

(TRY003)


381-381: Avoid specifying long messages outside the exception class

(TRY003)


390-390: Avoid specifying long messages outside the exception class

(TRY003)


395-395: Avoid specifying long messages outside the exception class

(TRY003)


397-397: Avoid specifying long messages outside the exception class

(TRY003)


445-445: Avoid specifying long messages outside the exception class

(TRY003)


696-696: Avoid specifying long messages outside the exception class

(TRY003)


713-713: Avoid specifying long messages outside the exception class

(TRY003)


771-771: Avoid specifying long messages outside the exception class

(TRY003)


785-785: Avoid specifying long messages outside the exception class

(TRY003)

Comment on lines +52 to +129
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"float8_e4m3": "e4m3",
"float8_e5m2": "e5m2",
}

E_FACTOR_MAP = { # e_kdim = mma_kdim // e_factor
"float": {
"int16": 8,
"uint16": 8,
},
"float32": {
"int16": 8,
"uint16": 8,
},
"float16": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"bfloat16": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"int8": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"uint8": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"float8_e4m3": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
"float8_e5m2": {
"int8": 8,
"uint8": 8,
"int16": 16,
"uint16": 16,
"int32": 32,
"uint32": 32,
},
}

E_REPLICATE_FACTOR = { # metadata replicate every 4 consecutive threads
"float32": 2,
"float16": 2, # 2 of 4 consecutive threads provides
"bfloat16": 2,
"int8": 1, # 4 of 4 consecutive threads provides
"uint8": 1,
"float8_e4m3": 1,
"float8_e5m2": 1,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Add tf32 dtype metadata to emitter tables.

E_FACTOR_MAP already exposes "float" entries, but dtype_abbrv and E_REPLICATE_FACTOR don’t, so creating the emitter with in_dtype="float" (tf32) raises a KeyError. That breaks the advertised tf32 path. Please add the missing mappings so we can actually emit tf32 kernels.

Apply this diff:

     dtype_abbrv = {
+        "float": "tf32",
         "float16": "fp16",
         "bfloat16": "bf16",
         "float32": "fp32",
@@
     E_REPLICATE_FACTOR = {  # metadata replicate every 4 consecutive threads
+        "float": 2,
         "float32": 2,
         "float16": 2,  # 2 of 4 consecutive threads provides
         "bfloat16": 2,
🧰 Tools
🪛 Ruff (0.14.3)

52-60: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


62-119: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


121-129: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)

🤖 Prompt for AI Agents
In tilelang/intrinsics/mma_sp_macro_generator.py around lines 52 to 129, the
emitter tables are missing entries for the tf32 dtype key "float", causing a
KeyError when in_dtype="float"; add "float": "tf32" to the dtype_abbrv mapping
and add "float": 2 to E_REPLICATE_FACTOR (matching float32 behavior) so the
existing E_FACTOR_MAP "float" entries can be used without error.

@botbw botbw marked this pull request as draft November 10, 2025 04:10
@LeiWang1999 LeiWang1999 marked this pull request as ready for review November 11, 2025 07:38
@LeiWang1999
Copy link
Member

we're good to go if we can resolve the conflict and I think then we can let this pr in.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)

5-11: Scope TF32 configuration to avoid global side effects

Setting torch.backends.cuda.matmul.allow_tf32 = False at module import time affects all CUDA matmul calls in the test process, not just this module. Consider moving this into the specific tests (or a context manager) so other tests that rely on TF32 behavior are not implicitly changed.


81-103: Metadata layout wiring: sm80 is nicely aligned to emitter; please double‑check sm90 dtype choice

For sm80, deriving metadata_dtype (int32 for 8‑bit, otherwise int16) and E_factor from SparseTensorCoreIntrinEmitter.E_FACTOR_MAP, then passing mma_dtype=in_dtype into make_cutlass_metadata_layout, is a good match to the intrinsics’ expectations and the CUTLASS-style layout pattern. This aligns with the established “layout depends on MMA element type” approach. Based on learnings.

For sm90, E/E_shared are hard‑coded as uint8 and make_cutlass_metadata_layout(..., mma_dtype="float16", arch="9.0", block_k=block_K) is used even when in_dtype is int8 or float8_e4m3. If the Hopper metadata layout truly varies with the MMA operand type (as it does on sm80), you may want to mirror the sm80 pattern and drive mma_dtype (and possibly metadata dtype/E_factor) from in_dtype as well.

Please double‑check that:

  • The E_factor and metadata dtype used here for each supported in_dtype match what SparseTensorCoreIntrinEmitter expects on sm90, and
  • mma_dtype="float16" is correct for non‑fp16 inputs, or adjust it to be in_dtype where appropriate.

Also applies to: 124-148

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ada178c and f6dc8c0.

📒 Files selected for processing (1)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (4)
tilelang/utils/sparse.py (2)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (131-158)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
tilelang/utils/tensor.py (2)
  • torch_assert_close (237-329)
  • map_torch_type (37-54)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
  • SparseTensorCoreIntrinEmitter (40-858)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (2)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)

14-44: Dense input generator looks consistent with kernel/transposition contracts

The shape/dtype logic in generate_dense_input matches the later _matmul conventions: after applying trans_A/trans_B, you always get (M, K) @ (K, N), and map_torch_type(in_dtype) keeps A/B consistent across integer and float/float8 cases. This should keep the tests well-aligned with the kernels.


195-202: Centralized input generation and comparison path looks correct and robust

Reusing generate_dense_input in run_gemm_sp and comparing C_sp vs dense C via _matmul plus torch_assert_close on float32 (with explicit upcast for float8/int8) is a solid improvement over ad‑hoc checks. Shapes line up with the kernels, and the tolerances (rtol=atol=1e-3) are reasonable for these dtypes.

Also applies to: 215-226

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (1)

166-178: Remove unused helper functions normalize and calc_diff.

Search results confirm these functions are not called anywhere in the codebase. They appear to be leftover from prior validation logic now handled by torch_assert_close. Remove them to reduce unnecessary code.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f6dc8c0 and 1022c9c.

📒 Files selected for processing (1)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (8 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (4)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (131-158)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
tilelang/utils/tensor.py (2)
  • torch_assert_close (237-329)
  • map_torch_type (37-54)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
  • SparseTensorCoreIntrinEmitter (40-858)
🔇 Additional comments (5)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (5)

5-12: LGTM!

The new imports align with the gemm_sp_v2 updates, and disabling TF32 ensures consistent numerical behavior in tests.


14-44: LGTM!

The input generation logic correctly handles different data types, using appropriate ranges for integer types and generating float types via float32 intermediates to support float8 variants.


124-147: LGTM!

The dynamic determination of metadata_dtype and E_factor from SparseTensorCoreIntrinEmitter.E_FACTOR_MAP, along with passing mma_dtype=in_dtype to the layout helper, provides appropriate flexibility for various input data types on SM80.


195-226: LGTM!

The refactored validation logic using torch_assert_close is cleaner and more maintainable than the previous comparison approach, with appropriate tolerances for sparse matrix multiplication tests.


316-368: LGTM!

The test suite provides comprehensive coverage across architectures (SM80/SM90), data types (float16, float8_e4m3, int8), and configurations (various block sizes, transpositions, and pipeline stages).

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (1)

38-43: Consider passing target dtype directly to avoid redundant conversion.

The randn_semi_sparse function already handles dtype conversion internally (creates in float32, applies sparsity mask, converts to target dtype). Passing dtype=torch.float32 then calling .to(map_torch_type(in_dtype)) adds an extra no-op conversion step when the input is already float32.

Apply this diff to streamline the conversion:

-    A = randn_semi_sparse(
-        M, K, dtype=torch.float32, device='cuda',
-        transposed=trans_A).to(map_torch_type(in_dtype))
-    B = torch.randn(
-        (N, K) if trans_B else (K, N), device='cuda',
-        dtype=torch.float32).to(map_torch_type(in_dtype))
+    A = randn_semi_sparse(
+        M, K, dtype=map_torch_type(in_dtype), device='cuda',
+        transposed=trans_A)
+    B = torch.randn(
+        (N, K) if trans_B else (K, N), device='cuda',
+        dtype=map_torch_type(in_dtype))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1022c9c and c28a687.

📒 Files selected for processing (1)
  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py
🧬 Code graph analysis (1)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (6)
tilelang/utils/sparse.py (3)
  • compress (77-106)
  • randn_semi_sparse (109-128)
  • randint_semi_sparse (131-158)
tilelang/layout/gemm_sp.py (1)
  • make_cutlass_metadata_layout (136-150)
tilelang/utils/tensor.py (2)
  • torch_assert_close (237-329)
  • map_torch_type (37-54)
tilelang/intrinsics/mma_sp_macro_generator.py (1)
  • SparseTensorCoreIntrinEmitter (40-858)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp_v2.py (1)
  • generate_dense_input (136-165)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp (10-87)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (7)
testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (7)

10-10: Good practice: TF32 disabled for test precision.

Disabling TF32 ensures numerical accuracy when comparing sparse GEMM results against dense reference implementations.


84-88: Past issue resolved: mma_dtype now correctly uses in_dtype.

The previous review flagged hardcoded mma_dtype="float16" which would produce incorrect metadata layouts for int8/float8 dtypes. This has been fixed—mma_dtype=in_dtype ensures the layout atoms and BlockK are computed correctly for all supported dtypes.


124-125: Good: Dynamic metadata dtype selection for SM80.

The metadata_dtype is now correctly selected based on input bit width (int32 for 8-bit, int16 for 16-bit), with E_factor computed from the intrinsic emitter's lookup table. This enables proper support for int8 and other 8-bit dtypes.


146-147: Past issue resolved: SM80 metadata layout now uses in_dtype.

Similar to SM90, the hardcoded mma_dtype="float16" has been replaced with mma_dtype=in_dtype, ensuring correct metadata layout generation for all supported dtypes including int8 (tested at lines 366-372).


223-230: Good: Unified validation using torch_assert_close.

Replacing the ad-hoc diff calculation with torch_assert_close provides consistent, configurable comparison across dtypes. The float32 normalization ensures compatibility when comparing results from different accumulation types.


342-344: Good test coverage: Verifies int8/float8 support on SM90.

These tests validate that the fixed mma_dtype=in_dtype correctly handles 8-bit dtypes with appropriate metadata layouts, addressing the past review concerns.


366-372: Good test coverage: Verifies int8 support on SM80.

Comprehensive int8 testing across multiple configurations validates the dynamic metadata_dtype selection and fixed mma_dtype parameter, ensuring correct operation for 8-bit dtypes on SM80 architecture.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tilelang/layout/gemm_sp.py (1)

115-120: Consider extracting exception messages into constants.

Static analysis suggests avoiding long exception messages directly in raise statements for better maintainability.

Example refactor:

+METADATA_BIT_MISMATCH_MSG_16 = "metadata should be 16 bit, got {}"
+METADATA_BIT_MISMATCH_MSG_32 = "metadata should be 32 bit, got {}"
+
 if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
-    raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
+    raise ValueError(METADATA_BIT_MISMATCH_MSG_16.format(buffer.dtype))

 if mma_dtype in ["float8_e4m3", "float8_e5m2", "int8", "uint8"
                 ] and buffer.dtype not in ["uint32", "int32"]:
-    raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
+    raise ValueError(METADATA_BIT_MISMATCH_MSG_32.format(buffer.dtype))

Based on learnings

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c28a687 and a1e734d.

📒 Files selected for processing (1)
  • tilelang/layout/gemm_sp.py (4 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • tilelang/layout/gemm_sp.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • tilelang/layout/gemm_sp.py
🧬 Code graph analysis (1)
tilelang/layout/gemm_sp.py (2)
tilelang/layout/layout.py (1)
  • Layout (13-148)
tilelang/contrib/nvcc.py (2)
  • get_target_compute_version (409-450)
  • parse_compute_version (453-475)
🪛 Ruff (0.14.4)
tilelang/layout/gemm_sp.py

116-116: Avoid specifying long messages outside the exception class

(TRY003)


120-120: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (5)
tilelang/layout/gemm_sp.py (5)

34-34: LGTM! FP8 dtype support correctly added.

The extension to support "float8_e4m3" and "float8_e5m2" aligns with the PR objectives and is correctly implemented with 8-bit mappings.

Also applies to: 45-46


138-151: Clarify the handling of extra_args for SM8x architecture.

The dispatcher passes **extra_args to the SM90 implementation (line 148) but not to the SM8x implementation (line 150). While the SM8x implementation doesn't currently accept block_k, this could cause silent failures if callers provide extra arguments expecting them to be used.

Consider one of the following:

  1. If SM8x truly doesn't need extra_args: Add explicit validation to warn users if extra_args are provided for SM8x:
if compute_version >= (9, 0):
    return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args)
elif compute_version >= (8, 0):
    if extra_args:
        warnings.warn(f"extra_args {extra_args} ignored for SM8x architecture", stacklevel=2)
    return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype)
  1. If SM8x should accept and ignore extra_args for API consistency: Update the SM8x signature:
def make_cutlass_metadata_layout_sm8x(buffer: tvm.tir.Buffer, mma_dtype: str, **extra_args):
    # Ignore extra_args for SM8x
    ...

Which approach aligns better with the intended API design?


21-21: LGTM! Public API exposure provides useful flexibility.

Making make_cutlass_metadata_layout_sm90 and make_cutlass_metadata_layout_sm8x public allows advanced users to directly call architecture-specific implementations when needed, while the dispatcher function provides a convenient high-level API.

Also applies to: 108-108


118-120: LGTM! FP8/int8 validation correctly enforces 32-bit buffer requirement.

The validation ensures that fp8 and int8 sparse metadata uses 32-bit buffers on SM8x, which is the correct requirement for these data types.


122-133: Implementation verified against PyTorch reference—no issues found.

The new ColumnMajorInterleaved function correctly implements the CUTLASS metadata layout algorithm. Comparison with the PyTorch reference confirms:

  • The group/interweave parameters match (bits == 16 → itemsize == 2)
  • The row reordering formula is identical
  • The topright/bottomleft swizzle logic is equivalent (using bitwise arithmetic rather than boolean comparison)
  • The final offset computation matches the reference; returning (offset // k, offset % k) appropriately converts the flat offset to matrix indices

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (6)
src/op/gemm_sp_py.cc (3)

68-68: Use GemmSPWarpPolicy instead of GemmWarpPolicy.

The policy should be constructed as GemmSPWarpPolicy to use sparse-specific warp partition logic, as noted in previous reviews. The header (src/op/gemm_sp_py.h:36) declares policy as GemmWarpPolicy, but the sparse GEMM path requires GemmSPWarpPolicy for correct atom-size adjustments.

Apply this diff:

-  node->policy = GemmWarpPolicy(args[10].as<IntImm>().value()->value);
+  node->policy = GemmSPWarpPolicy(args[10].as<IntImm>().value()->value);

213-224: Add error handling for std::stoi.

If the architecture string contains non-numeric characters after "sm_" (e.g., "sm_abc"), std::stoi will throw std::invalid_argument, potentially causing crashes. This was previously flagged but remains unaddressed.

Apply this diff:

 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
   ICHECK(s.has_value());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
+    try {
+      arch_int = std::stoi(arch.substr(3));
+    } catch (const std::exception& e) {
+      LOG(WARNING) << "Failed to parse architecture from '" << arch << "': " << e.what();
+      arch_int = 0;
+    }
   } else {
     arch_int = 0;
   }
   return arch_int;
 }

226-231: Fix computeWarpPartition call signature.

The call to policy->computeWarpPartition passes a GemmInst enum where a bool use_wgmma is expected, and omits the required bits parameter. This was previously flagged but remains unaddressed.

Apply this diff:

   auto block_size = *as_const_int(T.thread_bounds->extent);
   GemmInst gemm_inst = GetGemmInst(block_size, T.target);
 
+  bool use_wgmma = (gemm_inst == GemmInst::kWGMMA);
   auto [warp_m, warp_n] =
-      policy->computeWarpPartition(M, N, block_size, T.target, gemm_inst);
+      policy->computeWarpPartition(M, N, block_size, T.target, use_wgmma, A->dtype.bits());
src/op/gemm_sp_py.h (1)

36-36: Declare policy as GemmSPWarpPolicy instead of GemmWarpPolicy.

The policy should be declared as mutable GemmSPWarpPolicy to use sparse-specific warp partition logic with atom-size adjustments. This was previously flagged but remains unaddressed.

Apply this diff:

-  mutable GemmWarpPolicy policy;
+  mutable GemmSPWarpPolicy policy;
tilelang/intrinsics/mma_sp_macro_generator.py (2)

52-129: Add tf32 dtype metadata to emitter tables.

E_FACTOR_MAP supports "float" entries (lines 63-66), but dtype_abbrv (lines 52-60) and E_REPLICATE_FACTOR (lines 121-129) don't, causing a KeyError when creating an emitter with in_dtype="float" (tf32). This was previously flagged but remains unaddressed.

Apply this diff:

     dtype_abbrv = {
+        "float": "tf32",
         "float16": "fp16",
         "bfloat16": "bf16",
         "float32": "fp32",
     E_REPLICATE_FACTOR = {  # metadata replicate every 4 consecutive threads
+        "float": 2,
         "float32": 2,
         "float16": 2,  # 2 of 4 consecutive threads provides
         "bfloat16": 2,

423-435: Guard ldmatrix for int8 + transpose correctly.

The comment at line 434 states ldmatrix is unavailable for int8 with transposed layout, but the condition at line 435 disables ldmatrix for the opposite case (int8 + not transposed). This was previously flagged but remains unaddressed.

Apply this diff:

-        ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
+        ldmatrix_available = not (DataType(b_dtype).bits != 16 and b_transposed)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5d5cf85 and befe30e.

📒 Files selected for processing (6)
  • src/op/gemm_sp_py.cc (1 hunks)
  • src/op/gemm_sp_py.h (1 hunks)
  • tilelang/intrinsics/mma_sp_macro_generator.py (1 hunks)
  • tilelang/language/experimental/gemm_sp.py (3 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_base.py (1 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_mma.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/op/gemm_sp_py.cc
  • tilelang/intrinsics/mma_sp_macro_generator.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/op/gemm_sp_py.cc
  • tilelang/intrinsics/mma_sp_macro_generator.py
🧬 Code graph analysis (4)
src/op/gemm_sp_py.h (7)
src/op/gemm_sp.h (2)
  • tl (15-111)
  • RegisterReflection (27-33)
src/op/operator.h (2)
  • TileOperatorNode (55-65)
  • TileOperator (67-71)
src/op/gemm_sp_py.cc (11)
  • CheckWGMMA (145-197)
  • CheckWGMMA (145-145)
  • Lower (226-260)
  • Lower (226-226)
  • InferLayout (262-277)
  • InferLayout (262-263)
  • Clone (94-97)
  • Clone (94-94)
  • GetGemmInst (99-113)
  • GetGemmInst (99-99)
  • GemmSPPy (51-84)
tilelang/tileop/gemm_sp/gemm_sp_base.py (17)
  • A (70-71)
  • E (74-75)
  • B (78-79)
  • C (82-83)
  • trans_A (45-46)
  • trans_B (49-50)
  • trans_E (53-54)
  • M (33-34)
  • N (37-38)
  • K (41-42)
  • stride_A (102-103)
  • stride_B (106-107)
  • offset_A (110-111)
  • offset_B (114-115)
  • clear_accum (118-119)
  • wg_wait (126-127)
  • policy (130-131)
src/op/gemm.h (6)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
  • RegisterReflection (35-41)
  • RegisterReflection (106-128)
tilelang/ir.py (1)
  • GemmWarpPolicy (30-39)
tilelang/tileop/gemm_sp/__init__.py (1)
  • GemmSPPy (29-69)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (162-184)
tilelang/language/gemm.py (1)
  • legalize_arguments (36-47)
src/op/gemm_sp_py.cc (4)
src/op/gemm_sp_py.h (2)
  • GemmSPPy (81-87)
  • RegisterReflection (41-66)
tilelang/tileop/gemm_sp/__init__.py (1)
  • GemmSPPy (29-69)
tilelang/tileop/gemm_sp/gemm_sp_base.py (3)
  • policy (130-131)
  • M (33-34)
  • N (37-38)
src/op/gemm_sp.h (1)
  • RegisterReflection (27-33)
tilelang/intrinsics/mma_sp_macro_generator.py (6)
tilelang/intrinsics/utils.py (2)
  • mma_store_index_map (82-83)
  • get_ldmatrix_offset (22-64)
tilelang/utils/language.py (1)
  • is_fragment (105-116)
tilelang/intrinsics/mma_sp_layout.py (18)
  • shared_16x16_to_mma_sp_layout_sr_a (14-15)
  • shared_16x16_to_mma_sp_layout_sr_b (18-20)
  • shared_16x32_to_mma_sp_layout_sr_a (23-24)
  • shared_16x32_to_mma_sp_layout_sr_b (27-29)
  • shared_16x64_to_mma_sp_layout_sr_a (32-33)
  • shared_16x64_to_mma_sp_layout_sr_b (36-38)
  • mma_sp_load_a_32x4_to_shared_16x16_layout (41-42)
  • mma_sp_load_a_32x8_to_shared_16x32_layout (45-46)
  • mma_sp_load_a_32x16_to_shared_16x64_layout (49-50)
  • mma_sp_load_b_32x8_to_shared_16x16_layout (53-56)
  • mma_sp_load_b_32x16_to_shared_16x32_layout (59-62)
  • mma_sp_load_b_32x32_to_shared_16x64_layout (65-68)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_32bit (75-80)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_32bit (83-88)
  • metadata_8bit_load_32x4_to_shared_16x4_layout_16bit (91-94)
  • metadata_16bit_load_32x2_to_shared_16x2_layout_16bit (97-100)
  • metadata_32bit_load_32x1_to_shared_16x2_layout_8bit (123-129)
  • get_ldmatrix_offset_b (156-190)
tilelang/intrinsics/mma_macro_generator.py (4)
  • _initialize_k_dim (106-109)
  • _initialize_k_dim (849-850)
  • _initialize_abbrev (116-119)
  • _initialize_abbrev (857-860)
tilelang/language/tir/op.py (2)
  • ptx_ldmatrix (1313-1349)
  • ptx_mma_sp (964-1062)
tilelang/layout/fragment.py (2)
  • replicate (147-161)
  • repeat (124-145)
🪛 Ruff (0.14.5)
tilelang/language/experimental/gemm_sp.py

163-164: Prefer TypeError exception for invalid type

(TRY004)


163-164: Avoid specifying long messages outside the exception class

(TRY003)


191-192: Prefer TypeError exception for invalid type

(TRY004)


191-192: Avoid specifying long messages outside the exception class

(TRY003)


255-256: Prefer TypeError exception for invalid type

(TRY004)


255-256: Avoid specifying long messages outside the exception class

(TRY003)


275-276: Prefer TypeError exception for invalid type

(TRY004)


275-276: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/intrinsics/mma_sp_macro_generator.py

52-60: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


62-119: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


121-129: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


180-182: Avoid specifying long messages outside the exception class

(TRY003)


214-214: Avoid specifying long messages outside the exception class

(TRY003)


315-315: Avoid specifying long messages outside the exception class

(TRY003)


384-384: Avoid specifying long messages outside the exception class

(TRY003)


393-393: Avoid specifying long messages outside the exception class

(TRY003)


398-398: Avoid specifying long messages outside the exception class

(TRY003)


400-400: Avoid specifying long messages outside the exception class

(TRY003)


448-448: Avoid specifying long messages outside the exception class

(TRY003)


699-699: Avoid specifying long messages outside the exception class

(TRY003)


716-716: Avoid specifying long messages outside the exception class

(TRY003)


774-774: Avoid specifying long messages outside the exception class

(TRY003)


788-788: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/tileop/gemm_sp/gemm_sp_mma.py

58-59: Avoid specifying long messages outside the exception class

(TRY003)


234-235: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
tilelang/language/experimental/gemm_sp.py (1)

92-310: LGTM!

The gemm_sp_v2 implementation correctly handles input normalization, shape/stride/offset retrieval, and validation. The helper functions appropriately support Buffer, BufferRegion, and BufferLoad types, and the extensive error messages aid debugging.

tilelang/tileop/gemm_sp/gemm_sp_mma.py (1)

1-236: LGTM!

The GemmSPMMA class correctly implements sparse tensor-core MMA tiling with proper layout inference and lowering for all four dataflow patterns (ss, sr, rs, rr). The per-variant kernel generation properly allocates local buffers and orchestrates load/MMA operations.

tilelang/tileop/gemm_sp/gemm_sp_base.py (1)

1-131: LGTM!

The GemmSPBase class provides a clean abstraction with proper delegation to the underlying node. The property accessors correctly expose all necessary fields, and the dtype consistency check ensures A and B have matching types.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (3)
tilelang/layout/gemm_sp.py (1)

115-120: Optional: Consider custom exception class for metadata validation errors.

Static analysis suggests avoiding long messages in exception constructors. However, for this use case, the descriptive error messages are more valuable than strict adherence to TRY003. This is purely a style consideration and can be safely deferred.

tilelang/language/experimental/gemm_sp.py (2)

128-139: Extract duplicated legalize_arguments helper to avoid code duplication.

The legalize_arguments helper is identical to the one defined in the gemm_sp function (lines 46-57). Consider extracting it to module level to follow the DRY principle.

Apply this diff to extract the helper:

+def _legalize_arguments(arg: tir.Buffer | tir.Var):
+    """Convert let-bound variables to their corresponding buffers.
+
+    Args:
+        arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
+
+    Returns:
+        Union[tir.Buffer, tir.Var]: The legalized argument
+    """
+    if isinstance(arg, tir.Var) and T.has_let_value(arg):
+        return T.get_let_value(arg).buffer
+    return arg
+
+
 def gemm_sp(
     A_sparse: tir.Buffer | tir.Var,

Then replace the nested helper definitions with calls to _legalize_arguments in both functions.


234-234: Minor typo in comment.

Apply this diff:

-            # not offset the last two dimension
+            # not offset the last two dimensions
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between befe30e and 34d1af6.

📒 Files selected for processing (2)
  • tilelang/language/experimental/gemm_sp.py (2 hunks)
  • tilelang/layout/gemm_sp.py (5 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • tilelang/layout/gemm_sp.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • tilelang/layout/gemm_sp.py
🧬 Code graph analysis (1)
tilelang/layout/gemm_sp.py (1)
tilelang/contrib/nvcc.py (2)
  • get_target_compute_version (409-450)
  • parse_compute_version (453-475)
🪛 Ruff (0.14.5)
tilelang/language/experimental/gemm_sp.py

162-163: Prefer TypeError exception for invalid type

(TRY004)


162-163: Avoid specifying long messages outside the exception class

(TRY003)


190-191: Prefer TypeError exception for invalid type

(TRY004)


190-191: Avoid specifying long messages outside the exception class

(TRY003)


254-255: Prefer TypeError exception for invalid type

(TRY004)


254-255: Avoid specifying long messages outside the exception class

(TRY003)


274-275: Prefer TypeError exception for invalid type

(TRY004)


274-275: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/layout/gemm_sp.py

116-116: Avoid specifying long messages outside the exception class

(TRY003)


120-120: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (9)
tilelang/layout/gemm_sp.py (5)

108-124: Approve SM8x metadata layout with dtype-specific bit widths.

The function correctly handles different metadata bit widths: 16-bit for float16/bfloat16 and 32-bit for float8/int8 types. The group and interweave parameters adjust accordingly (32/4 for 16-bit, 16/2 for 32-bit). The PyTorch reference is helpful for verification.


138-152: LGTM: Public dispatcher with architecture-aware routing.

The dispatcher correctly routes to SM90 or SM8x implementations based on compute version. The selective passing of **extra_args to SM90 only is appropriate since SM8x doesn't require the block_k parameter. Architecture auto-detection provides good ergonomics.


79-83: I'll help you verify the 8-bit layout calculations in the review comment. Let me start by examining the code context and searching for related test cases and CUTLASS references.
<function_calls>
<invoke_name>shell
<invoke_name>#!/bin/bash

First, let's examine the file structure and find the relevant code

fd -t f "gemm_sp.py" | head -5
</invoke_name>
</function_calls>
<invoke_name>shell
<invoke_name>#!/bin/bash

Read the gemm_sp.py file to understand the context

wc -l tilelang/layout/gemm_sp.py
</invoke_name>
</function_calls>
<function_calls>
<invoke_name>shell
<invoke_name>#!/bin/bash

Read lines around 79-83 to see the full context

sed -n '60,100p' tilelang/layout/gemm_sp.py
</invoke_name>
</function_calls>
<function_calls>
<invoke_name>shell
<invoke_name>#!/bin/bash

Search for test files related to gemm_sp or sparse layout

fd -t f "test.*gemm_sp|gemm_sp.*test"
</invoke_name>
</function_calls>
<function_calls>
<invoke_name>shell
<invoke_name>#!/bin/bash

Search for files that test or use make_cutlass_metadata_layout with 8-bit types

rg -nP --type=py "make_cutlass_metadata_layout.*float8|make_cutlass_metadata_layout.*int8|float8_e4m3fn|int8" --max-count=20 | head -30
</invoke_name>
</function_calls>
<function_calls>
<invoke_name>shell
<invoke_name>#!/bin/bash

Search for any references to CUTLASS sparse metadata or SM90

rg -nP "SM90|sparse.*metadata|cutlass.*sparse" --type=py | head -20
</invoke_name>
</function_calls>
<function_calls>
<invoke_name>shell
<invoke_name>#!/bin/bash

Look for documentation or comments about the 8-bit layout calculation

rg -nP -B5 -A5 "block_k.*8|8.*bit.*layout" tilelang/layout/gemm_sp.py
</invoke_name>
</function_calls>

Let me get those results and continue the investigation.


21-47: Float8 support for CUTLASS SM90 is confirmed and correctly implemented.

The web search confirms that CUTLASS SM90 fully supports float8_e4m3 and float8_e5m2 for sparse GEMM operations. cuSPARSELt includes Hopper (SM 9.0) sparse GEMM kernels for E4M3 and E5M2 inputs/outputs, and CUTLASS provides TensorOp and SpTensorOp kernels targeting SM90. The code changes adding these data types to the supported list are correct and compatible with the underlying CUTLASS kernels.


126-135: Based on my web search and analysis, I can now provide my findings:

Mathematical Verification:
The offset calculation in the code snippet follows the CUTLASS ColumnMajorInterleaved layout pattern. The formula (j // 2) * m * 2 + i * 2 + (j % 2) is consistent with CUTLASS's column-major interleaved mapping for k=2 (stride-2 interleaving), where:

  • (j // 2) * m * 2 = column index × padded row extent
  • i * 2 + (j % 2) = per-element offset components

The topright/bottomleft bit manipulation logic is a standard pattern used in sparse tensor libraries (PyTorch, CUTLASS, cuSPARSELt) to encode position mappings within 4-element windows for 2:4 semi-structured sparsity.

Verification Status:

  • ✓ Offset calculation structure matches known CUTLASS/PyTorch patterns
  • ✓ Bit manipulation approach is standard in sparse tensor encodings
  • ⚠ Full correctness verification requires test cases (unable to access codebase directly)

The algorithm is mathematically sound and follows industry-standard patterns for semi-structured sparse metadata layout. However, manual testing or integration tests would be needed to confirm the implementation works correctly with the specific parameters (group, interweave, m, k) and sparsity patterns used in TileLang.


Verify the rewritten ColumnMajorInterleaved logic against PyTorch and CUTLASS sparse metadata patterns. The offset calculation follows the standard column-major interleaved layout (matching CUTLASS documentation), and the topright/bottomleft bit manipulation is a known pattern in sparse tensor libraries for encoding position indices in 2:4 semi-structured sparsity. The algorithm is mathematically sound based on industry references. Recommend validating with existing test cases to confirm parameter values (group, interweave, m, k) and sparse pattern compatibility.

tilelang/language/experimental/gemm_sp.py (4)

200-215: LGTM: Shape validation logic is sound.

The shape validation correctly handles 2D matrices and higher-order tensors with appropriate checks. The K dimension calculation accounts for the 2:4 sparsity pattern (factor of 2 on line 213) and validates consistency between matrices A and B.


277-282: LGTM: Offset validation ensures proper matrix alignment.

The assertions correctly enforce that only the last dimension can be offset, which aligns with GEMM semantics where the M dimension of A and K dimension of B must be fully utilized.


284-309: LGTM: Intrinsic call properly assembles all parameters.

The intrinsic call correctly passes all computed values including pointers, dimensions, strides, and offsets. The expanded parameter set compared to gemm_sp provides greater flexibility for handling various buffer layouts and access patterns.


146-146: Python version compatibility for list[int] annotations requires project-specific verification.

The review comment's concern is technically sound: the lowercase list[int] syntax requires Python 3.9+ (PEP 585). However, I was unable to access the repository to verify:

  • The project's declared Python version requirement (python_requires in setup.py/pyproject.toml)
  • Current imports in tilelang/language/experimental/gemm_sp.py
  • Whether List from typing is already imported
  • Usage consistency across the codebase

To complete this verification, check:

  1. Project's minimum Python version in pyproject.toml or setup.py
  2. If Python < 3.9, apply the suggested changes (import List from typing and use List[int])
  3. If Python ≥ 3.9, the current list[int] syntax is correct and preferred

Comment on lines +104 to +126
"""Perform a General Matrix Multiplication (GEMM) operation.
This function computes C = A @ B where A and B can optionally be transposed.
The operation supports various warp policies and accumulation modes.
Args:
A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements
E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E
B (Union[tir.Buffer, tir.Var]): Second input matrix
C (Union[tir.Buffer, tir.Var]): Output matrix for results
transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional): Warp group wait count. Defaults to 0.
Returns:
tir.Call: A handle to the GEMM operation
Raises:
AssertionError: If the K dimensions of matrices A and B don't match
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix docstring inaccuracies.

The docstring has two issues:

  1. Line 104 says "General Matrix Multiplication" but this is a sparse GEMM operation—it should say "Sparse General Matrix Multiplication (GEMM-sp)".
  2. The transpose_E parameter (line 98) is missing from the docstring's Args section.

Apply this diff to fix the docstring:

-    """Perform a General Matrix Multiplication (GEMM) operation.
+    """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation.
 
     This function computes C = A @ B where A and B can optionally be transposed.
     The operation supports various warp policies and accumulation modes.
 
     Args:
         A_sparse (Union[tir.Buffer, tir.Var]): First input matrix, contains only non-zero elements
         E (Union[tir.Buffer, tir.Var]): The metadata of A_sparse, noted as E
         B (Union[tir.Buffer, tir.Var]): Second input matrix
         C (Union[tir.Buffer, tir.Var]): Output matrix for results
         transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
         transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_E (bool, optional): Whether to transpose metadata matrix E. Defaults to False.
         policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
🤖 Prompt for AI Agents
In tilelang/language/experimental/gemm_sp.py around lines 104 to 126, the
docstring title and Args are inaccurate: change the short description from
"General Matrix Multiplication" to "Sparse General Matrix Multiplication
(GEMM-sp)" and add the missing transpose_E parameter to the Args list (e.g.,
transpose_E (bool, optional): Whether to transpose metadata E. Defaults to
False.), keeping style and default value consistent with transpose_A/transpose_B
entries.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (5)
tilelang/language/experimental/gemm_sp.py (3)

105-105: Update docstring title to reflect sparse GEMM operation.

The docstring says "General Matrix Multiplication" but this function performs a sparse GEMM operation.

Apply this diff:

-    """Perform a General Matrix Multiplication (GEMM) operation.
+    """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation.

Based on past review comments, this issue was previously flagged but the fix may have been incomplete.


115-120: Document the transpose_E parameter in the docstring.

The function signature includes transpose_E (line 99), but it's missing from the Args section of the docstring.

Apply this diff:

         transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
         transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_E (bool, optional): Whether to transpose metadata matrix E. Defaults to False.
         policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.

258-258: Fix incorrect return type annotation.

The return type is annotated as tir.PrimExpr, but the function returns list[int] at lines 261, 267, and 273.

Apply this diff:

-    def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
+    def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> list[int]:
         """Retrieve the offset of the buffer or buffer region."""

Based on past review comments, this issue was previously identified but not yet addressed.

src/op/gemm_sp.cc (1)

321-322: Return the computed warp partition from the FFI binding.

The lambda calls policy->computeWarpPartition(...) which returns std::pair<int, int> (see lines 22-23 and line 63), but the result is discarded and the lambda returns void. Python callers expect to receive the computed (m_warp, n_warp) values.

Apply this diff:

-        policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);
-        return;
+        return policy->computeWarpPartition(M, N, block_size, target, use_wgmma, bits);

Note: A past review comment indicated this was addressed in commits 354e9af to b2871dd, but the current code still shows the issue. Please verify if the fix was applied to a different branch or if this needs to be corrected.

tilelang/intrinsics/mma_macro_generator.py (1)

424-425: Verify reachability: 16-bit dtype in non-ldmatrix path appears unreachable.

Similar to matrix A, the condition at line 416 ensures ldmatrix_available = True for all 16-bit dtypes (the condition simplifies to bits == 16 or b_transposed). This makes the check at line 424 unreachable.

This issue mirrors the one flagged for matrix A at lines 296-297. If the earlier verification confirms this is intentional for future sparse support, this code should also be documented accordingly.

🧹 Nitpick comments (3)
src/tl_templates/cuda/debug.h (1)

111-119: Specialization is fine; consider trait-based reuse and RTC include check

The uint16_t specialization looks correct and the %u + uint32_t cast is safe. Two small follow‑ups:

  • For consistency with the rest of this file, you might instead specialize PrintTraits<uint16_t> (or alias it to PrintTraits<unsigned short>) and let the generic debug_print_buffer_value call through the trait, so debug_print_var<uint16_t> also prints a meaningful value instead of going through the generic “dtype=unknown” path.
  • This specialization introduces an unconditional use of uint16_t/uint32_t; please double‑check that these typedefs are available in all compile modes (especially under __CUDACC_RTC__, where <cstdint> is currently skipped) so NVRTC builds don’t fail.
tilelang/language/experimental/gemm_sp.py (1)

129-140: Consider extracting the duplicated legalize_arguments helper.

The legalize_arguments function (lines 129-140) is identical to the one in gemm_sp (lines 46-57 in tilelang/language/gemm.py per relevant snippets). Extracting this to a shared utility would reduce duplication.

You could move legalize_arguments to a shared location (e.g., tilelang.utils.language) and import it in both functions, similar to how get_buffer_region_from_load is imported.

tilelang/layout/gemm_sp.py (1)

137-151: Consider validating required parameters for SM90.

The dispatcher passes **extra_args to make_cutlass_metadata_layout_sm90, which requires a block_k parameter. If extra_args doesn't contain block_k, this will raise a TypeError at runtime.

Consider adding parameter validation or making the requirement explicit:

 def make_cutlass_metadata_layout(buffer: tvm.tir.Buffer,
                                  mma_dtype: str = "float16",
                                  arch: str | None = None,
-                                 **extra_args):
+                                 block_k: int | None = None,
+                                 **extra_args):
     if arch is None:
         arch = nvcc.get_target_compute_version()
 
     compute_version = nvcc.parse_compute_version(arch)
 
     if compute_version >= (9, 0):
+        if block_k is None:
+            raise ValueError("block_k is required for SM90 and above")
-        return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, **extra_args)
+        return make_cutlass_metadata_layout_sm90(buffer=buffer, mma_dtype=mma_dtype, block_k=block_k, **extra_args)
     elif compute_version >= (8, 0):
         return make_cutlass_metadata_layout_sm8x(buffer=buffer, mma_dtype=mma_dtype)
     else:
         raise NotImplementedError(f"Unsupported architecture: {arch}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 34d1af6 and ff2cf71.

📒 Files selected for processing (9)
  • src/op/gemm_sp.cc (1 hunks)
  • src/op/gemm_sp.h (1 hunks)
  • src/tl_templates/cuda/debug.h (1 hunks)
  • tilelang/intrinsics/mma_macro_generator.py (3 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/experimental/gemm_sp.py (2 hunks)
  • tilelang/layout/gemm_sp.py (5 hunks)
  • tilelang/profiler/__init__.py (1 hunks)
  • tilelang/utils/tensor.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • tilelang/utils/tensor.py
  • src/op/gemm_sp.h
  • tilelang/profiler/init.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • tilelang/layout/gemm_sp.py
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • tilelang/layout/gemm_sp.py
🧬 Code graph analysis (4)
src/op/gemm_sp.cc (3)
src/op/gemm_sp.h (1)
  • RegisterReflection (27-33)
src/op/gemm.h (2)
  • RegisterReflection (35-41)
  • RegisterReflection (106-128)
tilelang/tileop/gemm_sp/gemm_sp_base.py (3)
  • policy (130-131)
  • M (33-34)
  • N (37-38)
tilelang/language/__init__.py (1)
tilelang/language/experimental/gemm_sp.py (2)
  • gemm_sp (10-88)
  • gemm_sp_v2 (92-310)
tilelang/layout/gemm_sp.py (1)
tilelang/contrib/nvcc.py (2)
  • get_target_compute_version (401-442)
  • parse_compute_version (445-467)
tilelang/language/experimental/gemm_sp.py (4)
tilelang/utils/language.py (1)
  • get_buffer_region_from_load (162-193)
tilelang/tileop/gemm_sp/gemm_sp_base.py (6)
  • E (74-75)
  • B (78-79)
  • C (82-83)
  • policy (130-131)
  • M (33-34)
  • N (37-38)
src/op/gemm.h (4)
  • GemmWarpPolicy (59-83)
  • GemmWarpPolicy (64-68)
  • GemmWarpPolicy (70-74)
  • GemmWarpPolicy (76-82)
tilelang/language/gemm.py (1)
  • legalize_arguments (36-47)
🪛 Ruff (0.14.5)
tilelang/language/__init__.py

54-54: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/layout/gemm_sp.py

115-115: Avoid specifying long messages outside the exception class

(TRY003)


119-119: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/experimental/gemm_sp.py

163-164: Prefer TypeError exception for invalid type

(TRY004)


163-164: Avoid specifying long messages outside the exception class

(TRY003)


191-192: Prefer TypeError exception for invalid type

(TRY004)


191-192: Avoid specifying long messages outside the exception class

(TRY003)


255-256: Prefer TypeError exception for invalid type

(TRY004)


255-256: Avoid specifying long messages outside the exception class

(TRY003)


275-276: Prefer TypeError exception for invalid type

(TRY004)


275-276: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (8)
tilelang/language/__init__.py (1)

54-54: LGTM! Public API export is correct.

The addition of gemm_sp_v2 to the public API export is appropriate and consistent with the pattern used for other experimental features.

The static analysis hint about the unused noqa directive is a false positive—the directive is standard practice for public API modules where imports are re-exported rather than used locally.

src/op/gemm_sp.cc (1)

313-324: LGTM! FFI registration structure is well-organized.

The FFI registration block properly registers both reflection entries and the global function definition, following the established pattern from GemmWarpPolicyNode (see src/op/gemm.h:34-40). The registration structure enables proper exposure of the GemmSPWarpPolicy type and its methods to the Python FFI layer.

tilelang/intrinsics/mma_macro_generator.py (2)

25-25: LGTM! New layout imports follow existing conventions.

The newly imported layout functions follow the established naming pattern and are used in the ldmatrix fallback paths below.

Also applies to: 28-28


296-297: I'm unable to access the repository due to persistent clone failures. Without being able to examine the actual code at lines 280-310 in tilelang/intrinsics/mma_macro_generator.py, I cannot definitively verify whether the 16-bit dtype path is unreachable or confirm the logic of the ldmatrix_available condition at line 288.

To complete this verification, I would need to:

  1. Examine the complete context around lines 280-310
  2. Verify the actual ldmatrix_available logic
  3. Search for sparse tensor constraints or code paths that might affect 16-bit handling

The review comment requires manual verification of the reachability claim.

tilelang/layout/gemm_sp.py (4)

33-46: LGTM! fp8 dtype support correctly added.

The addition of float8_e4m3 and float8_e5m2 to both the validation check and the bits_map is correct and consistent with the 8-bit handling logic.


78-80: LGTM! 8-bit dtype handling correctly parameterized.

The parameterization of shape and stride calculations using block_k for 8-bit types is consistent with how other bit-widths handle block_k (through NumK).


117-119: Verify dtype support consistency across SM8x and SM90—uint8 may be missing from SM90.

Line 117 includes "uint8" in the SM8x validation check. According to CUTLASS architecture documentation, both SM8x (Ampere) and SM90 (Hopper) have hardware support for INT8/UINT8 tensor-core operations. If the SM90 layout function (line 33) omits "uint8" from its dtype list, this represents an inconsistency that should be addressed to ensure feature parity across architectures, unless there is an explicit architectural limitation or design decision not shown in the provided context.


107-134: Based on my verification efforts, I found important information but cannot fully validate the implementation. Here's what I discovered:

PyTorch Reference Implementation (from _calculate_meta_reordering_scatter_offsets):

  • Uses group_x = 64 (fixed) and group_y = 32 or 16 (based on dtype.itemsize)
  • Has a specific row reordering formula with multiple steps
  • Uses interleave = 2 (fixed)
  • Final offset: (cols_maj * m * interleave + dst_rows * interleave + cols_min)

TileLang Implementation (from the review):

  • Uses group = 32 or 16 (not 64)
  • Uses interweave = 4 or 2 (variable)
  • Row calculation: i = i // group * group + (i % 8) * interweave + (i % group) // 8
  • Final offset: (j // 2) * m * 2 + i * 2 + (j % 2)

The algorithms differ substantially from the PyTorch reference. However, I could not find:

  1. Public documentation of TileLang's specific sparse metadata layout algorithm
  2. Tests or specifications validating the implementation
  3. Explanation of why the approach differs from PyTorch's CUTLASS reference

The algorithm used in ColumnMajorInterleaved differs significantly from the PyTorch/CUTLASS reference implementation. Verify this is intentional or a bug by:

  • Checking the original implementation before this change for comparison
  • Running existing sparse GEMM tests to confirm correctness
  • Documenting the reasoning if this is an intentional algorithmic divergence from CUTLASS

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
src/op/gemm_sp_py.cc (1)

216-227: Guard std::stoi in GetArchInt to avoid exceptions on malformed arch

GetArchInt still calls std::stoi(arch.substr(3)) directly. If the arch string is "sm_abc" or otherwise non‑numeric after "sm_", this will throw std::invalid_argument or std::out_of_range and crash instead of falling back to 0.

You can make this more robust by catching conversion errors and logging a warning:

 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
   ICHECK(s.has_value());
   std::string arch = s.value();
   if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
+    try {
+      arch_int = std::stoi(arch.substr(3));
+    } catch (const std::exception &e) {
+      LOG(WARNING) << "Failed to parse architecture number from '" << arch
+                   << "': " << e.what();
+      arch_int = 0;
+    }
   } else {
     arch_int = 0;
   }
   return arch_int;
 }

This keeps the documented behavior (“return 0 if it doesn’t match sm_<num>”) even when the attribute content is unexpected.

🧹 Nitpick comments (2)
tilelang/language/experimental/gemm_sp.py (1)

112-134: Align gemm_sp_v2 docstring with sparse semantics and arguments

The docstring still describes a generic GEMM and doesn’t mention transpose_E, even though this is a sparse GEMM with explicit metadata and a transpose_E parameter.

Consider updating the text to make this clearer, e.g.:

-    """Perform a General Matrix Multiplication (GEMM) operation.
+    """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation.
@@
-        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        transpose_E (bool, optional): Whether to transpose metadata matrix E. Defaults to False.

This keeps the Python API docs in sync with the underlying operator semantics.

src/op/gemm_sp_py.cc (1)

24-45: Update ctor argument layout comment to include trans_E

The implementation reads trans_E from args[6], but the “expected layout” comment only lists trans_A and trans_B. To avoid confusion when maintaining the Python side, consider updating the comment block to reflect the extra flag, e.g.:

- *     [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
- *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
+ *     [Aptr, Eptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool),
+ *      trans_E (Bool),
+ *      M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool),
 *      stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int),
 *      (optional) kPack (Int), (optional) wg_wait (Int)]

The ctor logic itself looks consistent with the Python call site.

Also applies to: 65-85

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 47fc794 and ff22934.

📒 Files selected for processing (4)
  • src/op/gemm_sp_py.cc (1 hunks)
  • src/op/gemm_sp_py.h (1 hunks)
  • tilelang/language/experimental/gemm_sp.py (2 hunks)
  • tilelang/tileop/gemm_sp/gemm_sp_base.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tilelang/tileop/gemm_sp/gemm_sp_base.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.

Applied to files:

  • src/op/gemm_sp_py.cc
📚 Learning: 2025-09-15T10:51:06.985Z
Learnt from: botbw
Repo: tile-ai/tilelang PR: 691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.985Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.

Applied to files:

  • src/op/gemm_sp_py.cc
🧬 Code graph analysis (3)
src/op/gemm_sp_py.h (1)
src/op/gemm_sp.h (2)
  • tvm (13-117)
  • RegisterReflection (27-33)
tilelang/language/experimental/gemm_sp.py (2)
tilelang/utils/language.py (5)
  • to_buffer_region (196-238)
  • retrieve_shape (241-258)
  • retrieve_stride (261-279)
  • retrieve_offset (359-376)
  • prim_expr_equal (390-405)
tilelang/language/utils.py (1)
  • buffer_region_to_tile_region (23-35)
src/op/gemm_sp_py.cc (3)
src/op/gemm_sp_py.h (1)
  • GemmSPPy (83-89)
tilelang/tileop/gemm_sp/__init__.py (1)
  • GemmSPPy (29-69)
tilelang/tileop/gemm_sp/gemm_sp_base.py (3)
  • policy (130-131)
  • M (33-34)
  • N (37-38)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (2)
tilelang/language/experimental/gemm_sp.py (1)

6-14: Imports and existing gemm_sp wiring look consistent

Using the shared helpers (to_buffer_region, retrieve_shape/stride/offset, prim_expr_equal, buffer_region_to_tile_region) in this module keeps the sparse GEMM path aligned with the common utilities; the existing gemm_sp implementation remains consistent and doesn’t introduce new issues in this diff.

Also applies to: 53-95

src/op/gemm_sp_py.h (1)

21-88: GemmSPPy node/operator definition and reflection look consistent

The layout of GemmSPPyNode (buffers, regions, transposes, sizes, strides, offsets, clear_accum, kPack/wg_wait, and policy) matches how the ctor in gemm_sp_py.cc initializes these fields, and RegisterReflection exposes the expected attributes for Python/FFI use. The GemmSPPy wrapper and TVM_FFI macros also follow the existing TileOperator patterns.

No issues from the header side in this diff.

Comment on lines +188 to +198
A_offset = retrieve_offset(A_sparse)
B_offset = retrieve_offset(B)
assert A_offset[-2] == 0, "The offset of the first dimension of A must be 0"
assert B_offset[-2] == 0, "The offset of the first dimension of B must be 0"
offset_a = A_offset[-1]
offset_b = B_offset[-1]

A_arg = buffer_region_to_tile_region(A_region, "r", [r for r in A_shape])
E_arg = buffer_region_to_tile_region(E_region, "r", [r for r in E_shape])
B_arg = buffer_region_to_tile_region(B_region, "r", [r for r in B_shape])
C_arg = buffer_region_to_tile_region(C_region, "rw", [r for r in C_shape])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Use BufferRegion when computing offsets, not the raw Buffer

retrieve_offset(A_sparse) / retrieve_offset(B) operate on the underlying tir.Buffer, which always returns zeros for all dimensions. This means:

  • The assertions on A_offset[-2] / B_offset[-2] will pass even if the actual region has a non‑zero minimum on that dimension.
  • offset_a / offset_b passed into tl.gemm_sp_py will ignore any per‑region minima encoded in A_region / B_region.

Since you already have A_region / B_region as tir.BufferRegions, it’s more accurate to derive offsets from those so sub‑regions are handled correctly:

-    A_offset = retrieve_offset(A_sparse)
-    B_offset = retrieve_offset(B)
+    A_offset = retrieve_offset(A_region)
+    B_offset = retrieve_offset(B_region)

This keeps the offset computation consistent with how retrieve_offset is defined in tilelang.utils.language and ensures non‑zero minima on the last dimension are reflected in offset_a / offset_b.

🤖 Prompt for AI Agents
In tilelang/language/experimental/gemm_sp.py around lines 188-198, the code
currently calls retrieve_offset on the raw Buffers (A_sparse and B) which always
yields zeros; replace those calls to use the corresponding BufferRegion objects
(A_region and B_region) so offsets reflect per-region minima; update the
assertions to check the second-to-last element of the offsets returned from the
regions and set offset_a/offset_b from the last element of those region-derived
offsets before passing them into tl.gemm_sp_py.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (2)
tilelang/utils/sparse.py (2)

109-128: Update docstring to reflect conditional sparsity pattern.

The docstring states "2:4 sparsity" but the implementation now uses 1:2 sparsity for float32. Consider updating the documentation to clarify this behavior.

 def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transposed: bool = False):
     """
-    Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension.
+    Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension
+    (or 1:2 sparsity for float32).
     Args:
         M (int): Number of rows
         K (int): Number of columns

161-184: Update docstring to reflect conditional sparsity pattern.

Same as randn_semi_sparse, the docstring states "2:4 sparsity" but uses 1:2 sparsity for float32.

 def arange_semi_sparse(M: int,
                        K: int,
                        dtype=torch.float16,
                        device='cuda',
                        transposed: bool = False):
     """
-    Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension.
+    Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity
+    along the K dimension (or 1:2 sparsity for float32).
     Args:
         M (int): Number of rows
         K (int): Number of columns
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ff22934 and 8ea0353.

📒 Files selected for processing (3)
  • tilelang/profiler/__init__.py (2 hunks)
  • tilelang/utils/sparse.py (5 hunks)
  • tilelang/utils/tensor.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/utils/sparse.py (1)
tilelang/utils/tensor.py (2)
  • is_float8_dtype (8-14)
  • fp8_remove_negative_zeros_ (17-21)
tilelang/profiler/__init__.py (1)
tilelang/utils/tensor.py (1)
  • is_float8_dtype (8-14)
tilelang/utils/tensor.py (3)
tilelang/language/v2/dtypes.py (5)
  • float8_e5m2 (266-266)
  • float8_e5m2fnuz (273-273)
  • float8_e4m3fn (252-252)
  • float8_e4m3fnuz (259-259)
  • uint8 (182-182)
tilelang/language/proxy.py (1)
  • Tensor (252-253)
tilelang/language/customize.py (1)
  • view (55-66)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
🔇 Additional comments (8)
tilelang/utils/tensor.py (3)

8-14: LGTM!

Clean implementation of float8 dtype detection using set membership. This centralizes the float8 check for reuse across the codebase.


17-21: LGTM!

Correctly uses PyTorch's in-place operation naming convention (trailing underscore). The implementation properly handles negative zeros in float8 by viewing the underlying bytes as uint8 and zeroing out positions where the original value was zero.


24-31: LGTM!

The enum extension provides new tensor supply strategies. Note that Normal (line 27) and Randn (line 28) represent different distributions: Normal uses normal_(-1.0, 1.0) with custom mean/std, while Randn uses the standard normal distribution via torch.randn.

tilelang/profiler/__init__.py (2)

9-14: LGTM!

Good refactoring to import the centralized is_float8_dtype helper instead of maintaining a local implementation. This improves code reusability and consistency.


129-137: LGTM!

Correct handling of float8 dtypes by converting to float32 before comparison, since float8 types have limited operation support in PyTorch.

tilelang/utils/sparse.py (3)

6-6: LGTM!

Appropriate import of float8 utilities to support the compression path enhancements.


131-158: LGTM!

The new randint_semi_sparse function correctly follows the same pattern as randn_semi_sparse. Note that the docstring should also mention the 1:2 sparsity for float32 (though float32 is unusual for integer tensors).


92-103: I'll help you verify the metadata layout consistency concern with transposed inputs. Let me search the codebase for the relevant function definitions and usages.
<function_calls>


#!/bin/bash

First, find the compress_sm80 function definition to understand what E represents

rg -n "def compress_sm80" --type=py -A 15


</function_calls>


#!/bin/bash

Search for the compress function itself to see the full context

rg -n "def compress(" --type=py -A 30 | head -100


</function_calls>


#!/bin/bash

Find all usages of compress() to see how the returned E is used

rg -n "compress\s*(" --type=py -B 2 -A 3 | head -150


</function_calls>

Let me look at the specific file mentioned in the review:
<function_calls>


#!/bin/bash

Read the sparse.py file to understand the full context

wc -l tilelang/utils/sparse.py


</function_calls>

<function_calls>


#!/bin/bash

Read the file to see the compress function and related functions

cat -n tilelang/utils/sparse.py | sed -n '70,130p'


</function_calls>

<function_calls>


#!/bin/bash

Let's also check for compress_sm80 and other compress functions in the file

cat -n tilelang/utils/sparse.py | grep -E "(def compress|compress_sm80)" -A 3 -B 1


</function_calls>

<function_calls>


#!/bin/bash

Find all calls to compress() and compress_sm80() across the codebase to understand usage patterns

rg -n "compress(" --type=py -B 1 -A 2 | grep -E "(compress|sparse|decompress)" | head -80


</function_calls>

<function_calls>


#!/bin/bash

Check what compress_sm80 is - might be a C++ extension

rg -n "compress_sm80" --type=py -B 2 -A 5


</function_calls>

<function_calls>


#!/bin/bash

Look for decompress to understand the symmetry

rg -n "def decompress" --type=py -A 20


</function_calls>

<function_calls>


#!/bin/bash

Search for any test files that might demonstrate usage of compress with transposed flag

find . -name "test.py" -type f | xargs rg -l "compress" 2>/dev/null | head -5


</function_calls>

<function_calls>


#!/bin/bash

Let's search for test files and see how compress is tested

fd -e "test*.py" | xargs rg -n "compress(" -B 3 -A 5 2>/dev/null | head -100


</function_calls>

<function_calls>


#!/bin/bash

Let me view the entire sparse.py file to get full context

cat tilelang/utils/sparse.py


</function_calls>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants