[feat] support gemm_sp for ampere arch #691
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @botbw, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces support for sparse General Matrix Multiply (GEMM) operations on NVIDIA Ampere architecture GPUs. It integrates new CUTLASS-based kernels and updates the system's layout inference and metadata generation to handle Ampere-specific requirements.
Highlights
- Ampere Architecture Support: Added support for sparse General Matrix Multiply (GEMM) operations on NVIDIA Ampere (SM80/SM89) GPUs.
- CUTLASS Integration for Sparse GEMM: Implemented the core sparse GEMM kernel for Ampere using CUTLASS, including instruction shapes and shared memory layouts for various data types.
- Metadata Layout Generation: Extended the layout generation logic to create metadata layouts compatible with CUTLASS SM8x sparse kernels, handling 16-bit metadata types and column-major interleaved layouts.
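To make the column-major interleaved metadata layout concrete, here is a small hedged sketch. The function name, argument names, and the 2-way interleave factor are illustrative assumptions, not TileLang's actual implementation:

```python
def column_major_interleaved_offset(i: int, j: int, rows: int, k_interleave: int = 2) -> int:
    """Map a (row, col) metadata coordinate to a linear offset in a
    column-major layout whose columns are interleaved in groups of
    `k_interleave`, the SM8x CUTLASS convention described above."""
    stride = rows * k_interleave  # elements per interleaved column group
    column_major, column_minor = divmod(j, k_interleave)
    return column_major * stride + i * k_interleave + column_minor
```

For a `rows x cols` buffer, this mapping is a bijection onto `[0, rows * cols)`, which is what a layout function must guarantee.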
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes

1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution. ↩
Code Review
This pull request adds support for sparse GEMM on NVIDIA's Ampere architecture (sm80). The changes include new C++ CUDA kernel implementations using CUTLASS, updates to the operator logic in C++ to handle Ampere targets, and new Python layout functions for metadata. My review focuses on improving code maintainability by reducing duplication, enhancing debuggability with better error messages, and adhering to language-specific best practices and style guides (C++ casts, Python PEP 8 naming).
Caution: Review failed. The pull request is closed.

Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds SM80/SM90 structured-sparse GEMM support and arch-aware dispatch across Python and C++: new CUDA templates, layouts, GemmSP warp-policy, compress/annotate_layout dispatch, tests split per-arch, and benchmark/example updates (B layout, E tiling, accum dtype, CLI flags, optional torch-sparse).

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Bench as benchmark_matmul_sp.py
    participant NVCC as nvcc
    participant TL as tilelang.layout & utils
    participant Kernel as CUDA kernels (SM80/SM90)
    User->>Bench: run (M,N,K, --accum_dtype, --bench_torch_sparse)
    Bench->>NVCC: get_target_compute_version()
    NVCC-->>Bench: arch_version
    Bench->>TL: make_metadata_layout(..., arch=None, backend="cutlass")
    TL->>NVCC: (if arch None) get_target_compute_version()
    NVCC-->>TL: compute_version
    TL-->>Bench: metadata layout (SM80 or SM90)
    Bench->>Kernel: launch gemm_sp (B:(K,N), E tiled per-arch, accum dtype)
    Kernel-->>Bench: C result + timing
    alt bench_torch_sparse requested
        Bench->>Kernel: run torch-sparse path
        Kernel-->>Bench: torch sparse TFlops
    end
```

```mermaid
sequenceDiagram
    autonumber
    actor PyTest
    participant Test as test_tilelibrary_gemm_sp.py
    participant Utils as tilelang.utils.sparse
    participant NVCC as nvcc
    participant Kernel as SM80/SM90 kernels
    PyTest->>Test: run_gemm_sp_sm80/sm90(...)
    Test->>Utils: compress(A, transposed, arch=None)
    Utils->>NVCC: get_target_compute_version()
    NVCC-->>Utils: compute_version
    alt compute_version >= 9.0
        Utils-->>Test: compress_sm90(A,...)
    else compute_version >= 8.0
        Utils-->>Test: compress_sm80(A,...)
    end
    Test->>Kernel: launch arch-specific kernel
    Kernel-->>Test: output & validate
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)

📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)

Tip: 👮 Agentic pre-merge checks are now available in preview! Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example:

```yaml
reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).
```

Please share your feedback with us on this Discord post. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
This could be a missing check related to the CUTLASS 2.x implementation. The CUTLASS example can't trigger it because it's bypassed by padding the input. For now I think it might be safe to increase the atom size and investigate in the future.
Actionable comments posted: 0

♻️ Duplicate comments (1)

benchmark/matmul/benchmark_matmul_sp.py (1)

165-173: Update the e_factor/e_dtype selection to use the new arch config function. Replace the dictionary lookup with the proper compute version parsing to prevent runtime failures.

Apply this diff:

```diff
-    e_factor, e_dtype = ARCH_INFO[arch]
+    e_factor, e_dtype = get_arch_config(arch)
```
🧹 Nitpick comments (3)

benchmark/matmul/benchmark_matmul_sp.py (3)
89-89: Add accum_dtype parameter documentation. The function signature was updated to include `accum_dtype`, but the docstring doesn't document this new parameter. Add the missing parameter documentation:

```diff
     Parameters
     ----------
     M : int
         The dimension M of the matrix multiplication.
     N : int
         The dimension N of the matrix multiplication.
     K : int
         The dimension K of the matrix multiplication.
+    accum_dtype : str
+        The accumulation data type ("float" or "float16").
```
278-284: Handle potential import failures for torch sparse functionality. The torch sparse import and functionality may not be available in all environments or PyTorch versions. Add error handling for the torch sparse functionality:

```diff
     if args.bench_torch_sparse is not None:
-        from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
-        if args.bench_torch_sparse == 'cutlass':
-            SparseSemiStructuredTensor._FORCE_CUTLASS = True
-        A_sp = to_sparse_semi_structured(A, transposed=False)
-        torch_sparse_latency = do_bench(lambda: A_sp @ B)
+        try:
+            from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
+            if args.bench_torch_sparse == 'cutlass':
+                SparseSemiStructuredTensor._FORCE_CUTLASS = True
+            A_sp = to_sparse_semi_structured(A, transposed=False)
+            torch_sparse_latency = do_bench(lambda: A_sp @ B)
+        except ImportError as e:
+            print(f"Warning: Torch sparse functionality not available: {e}")
+            args.bench_torch_sparse = None
```
290-293: Handle the torch_sparse_latency reference when benchmarking fails. If torch sparse benchmarking fails or is disabled, `torch_sparse_latency` will be undefined when referenced. Fix the undefined variable reference:

```diff
     if args.bench_torch_sparse is not None:
+        try:
             from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
             if args.bench_torch_sparse == 'cutlass':
                 SparseSemiStructuredTensor._FORCE_CUTLASS = True
             A_sp = to_sparse_semi_structured(A, transposed=False)
             torch_sparse_latency = do_bench(lambda: A_sp @ B)
+            print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
+        except Exception as e:
+            print(f"Warning: Torch sparse benchmarking failed: {e}")
-    if args.bench_torch_sparse is not None:
-        print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
- benchmark/matmul/benchmark_matmul_sp.py (6 hunks)
🧰 Additional context used

🧬 Code graph analysis (1)

benchmark/matmul/benchmark_matmul_sp.py (5)
- tilelang/layout/gemm_sp.py (1): make_metadata_layout (136-158)
- tilelang/contrib/nvcc.py (1): get_target_compute_version (257-298)
- examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (3): matmul_sp (9-61), main (30-59), main (126-127)
- tilelang/language/experimental/gemm_sp.py (1): gemm_sp (9-86)
- tilelang/env.py (1): disable_cache (232-233)
🪛 GitHub Actions: CI Test on AMD
benchmark/matmul/benchmark_matmul_sp.py
[error] 1-1: Clang-format reformatted this file. Changes not staged for commit.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: format-check
- GitHub Check: bot-task
🔇 Additional comments (4)

benchmark/matmul/benchmark_matmul_sp.py (4)

93-94: Update matrix shape documentation to reflect layout changes. The documentation incorrectly stated B's shape, but based on the code changes B now uses transpose_B=False with shape (K, N). The documentation and implementation are now consistent with the new B matrix layout.

246-264: LGTM! Good CLI argument additions. The new command-line arguments enhance the benchmark's flexibility:
- `--disable_cache` allows bypassing TileLang caching for development
- `--accum_dtype` enables testing different accumulation precisions
- `--bench_torch_sparse` provides comparative benchmarking against PyTorch sparse implementations

228-228: Verify transpose_B=False is correct for the new layout. The change from the default (likely transpose_B=True) to `transpose_B=False` aligns with the B matrix layout change from (N, K) to (K, N). This change correctly reflects the new B matrix layout where B is now shaped as (K, N) instead of (N, K).
208-212: Resolved — make_metadata_layout auto-detects arch and only forwards block_k for SM90.make_metadata_layout sets arch via nvcc.get_target_compute_version() when arch is None, parses the compute version, and forwards **extra_args (including block_k) only in the compute_version >= (9, 0) branch; the SM8x branch does not forward extra_args. Current calls that omit arch but pass block_k are safe.
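The auto-detection and kwarg-forwarding behavior described above can be sketched abstractly. This is a hedged stand-in, not TileLang's actual `make_metadata_layout`; `detect` substitutes for `nvcc.get_target_compute_version()`, and the returned tuples merely label which branch was taken:

```python
import warnings

def metadata_layout_dispatch(arch=None, detect=lambda: "8.0", **extra_args):
    """Auto-detect the compute version when arch is None, and forward
    extra kwargs (e.g. block_k) only on the SM90 branch; the SM8x
    branch drops them, matching the behavior described above."""
    if arch is None:
        arch = detect()
    major, minor = (int(x) for x in arch.split(".")[:2])
    if (major, minor) >= (9, 0):
        return ("sm90", extra_args)   # SM90 branch consumes block_k etc.
    if (major, minor) >= (8, 0):
        if extra_args:
            warnings.warn("extra args ignored on the SM8x path")
        return ("sm80", {})           # SM8x branch forwards nothing
    raise NotImplementedError(f"unsupported compute capability: {arch}")
```

This shows why calls that omit `arch` but pass `block_k` are safe on SM8x hardware: the extra kwargs simply never reach the SM8x layout builder. (Version suffixes like "9.0a" would need stripping before `int()` in a real implementation.)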
Actionable comments posted: 0

♻️ Duplicate comments (2)

tilelang/layout/gemm_sp.py (1)

125-133: PEP 8 naming for locals. Use CONSTANT_CASE for constants and snake_case for callables.

Apply:

```diff
-    kInterleaved = 2
-    stride = buffer.shape[0] * kInterleaved
+    K_INTERLEAVED = 2
+    stride = buffer.shape[0] * K_INTERLEAVED
@@
-    def ColumnMajorInterleaved(i: int, j: int) -> int:
-        column_major = j // kInterleaved
-        column_minor = j % kInterleaved
-        return column_major * stride + i * kInterleaved + column_minor
+    def column_major_interleaved(i: int, j: int) -> int:
+        column_major = j // K_INTERLEAVED
+        column_minor = j % K_INTERLEAVED
+        return column_major * stride + i * K_INTERLEAVED + column_minor
@@
-    return T.Layout(buffer.shape, ColumnMajorInterleaved)
+    return T.Layout(buffer.shape, column_major_interleaved)
```

benchmark/matmul/benchmark_matmul_sp.py (1)
18-21: Remove the fragile ARCH_INFO mapping; derive by compute version. Hardcoding will KeyError for 8.6/8.7/8.9/9.0a and future arches.

Apply:

```diff
-arch = nvcc.get_target_compute_version()
-
-ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
+arch = nvcc.get_target_compute_version()
```

And compute e_factor/e_dtype at the use site (see next comment).
🧹 Nitpick comments (12)

tilelang/layout/gemm_sp.py (4)
31-35: Fix the clamped-warning to show the original value. You mutate block_k before warning, so the message is misleading.

Apply:

```diff
-    if block_k > 128:
-        block_k = 128
-        warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2)
+    if block_k > 128:
+        orig_block_k = block_k
+        block_k = 128
+        warnings.warn(
+            f"block_k {orig_block_k} is too large; clamped to 128 for {mma_dtype}.",
+            stacklevel=2,
+        )
```
93-96: Typo in variable name (rep_k_stirde); rename for clarity. Pure readability; keeps future edits sane.

Apply:

```diff
-    rep_k_stirde = prod(shape_i + shape_k)
+    rep_k_stride = prod(shape_i + shape_k)
     shape_k.append(rep_k)
-    stride_k.append(rep_k_stirde)
+    stride_k.append(rep_k_stride)
```
97-103: Avoid shadowing the parameter k inside comprehensions. Shadowing isn't a bug here but hurts readability.

Apply:

```diff
-    i_offset = sum(i_decomposed[k] * stride_i[k] for k in range(len(i_decomposed)))
-    k_offset = sum(k_decomposed[k] * stride_k[k] for k in range(len(k_decomposed)))
+    i_offset = sum(i_decomposed[idx] * stride_i[idx] for idx in range(len(i_decomposed)))
+    k_offset = sum(k_decomposed[idx] * stride_k[idx] for idx in range(len(k_decomposed)))
```
119-124: Guard unsupported mma_dtype early. If mma_dtype isn't one of the known sets, both checks are skipped with no error. Fail fast.

Apply:

```diff
+    supported = {"float16", "bfloat16", "float8", "int8", "uint8"}
+    if mma_dtype not in supported:
+        raise NotImplementedError(f"Unsupported mma_dtype for sm8x: {mma_dtype}")
```

benchmark/matmul/benchmark_matmul_sp.py (4)
165-173: Select e_factor/e_dtype programmatically (8.x vs 9.x). Prevents runtime KeyError and future-proofs.

Apply:

```diff
-    e_factor, e_dtype = ARCH_INFO[arch]
+    major_minor = nvcc.parse_compute_version(arch)
+    if major_minor >= (9, 0):
+        e_factor, e_dtype = 8, "uint8"
+    elif major_minor >= (8, 0):
+        e_factor, e_dtype = 16, "int16"
+    else:
+        raise NotImplementedError(f"Unsupported compute capability: {arch}")
```
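The version-based selection suggested above can also be factored into a helper. This is a hedged sketch; `get_arch_config` is the reviewer's proposed name, not an existing TileLang API, and the version-string parsing here is an assumption:

```python
def get_arch_config(compute_version: str):
    """Return (e_factor, e_dtype) for a compute capability string such
    as "8.0", "8.6", "9.0", or "9.0a", matching the thresholds in the
    suggested diff: SM90+ uses 8/"uint8", SM80+ uses 16/"int16"."""
    major, minor = (int(x.rstrip("a")) for x in compute_version.split(".")[:2])
    if (major, minor) >= (9, 0):
        return 8, "uint8"
    if (major, minor) >= (8, 0):
        return 16, "int16"
    raise NotImplementedError(f"Unsupported compute capability: {compute_version}")
```

Range comparison on the parsed `(major, minor)` tuple is what makes intermediate capabilities like 8.6 or 8.7 work without enumerating them in a dictionary.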
23-39: Docstring drift: B shape and result expression. The code uses B with shape (K, N) and returns A @ B (not A @ B.T). Update the text.

Apply:

```diff
-    B : numpy.ndarray
-        The matrix with shape (N, K).
+    B : numpy.ndarray
+        The matrix with shape (K, N).
@@
-        The result of A @ B.T, shape (M, N).
+        The result of A @ B, shape (M, N).
```
140-161: Inline kernel docstring: fix "A @ B^T". The kernel computes C = A @ B. Adjust the wording.

Apply:

```diff
-        The actual kernel to compute C = A @ B^T.
+        The actual kernel to compute C = A @ B.
```
278-285: Gate the torch sparse benchmark to SM80+. The PyTorch 2:4 path is SM80-specific; fail fast on older CCs.

Apply:

```diff
-    if args.bench_torch_sparse is not None:
+    if args.bench_torch_sparse is not None:
+        if nvcc.parse_compute_version(arch) < (8, 0):
+            raise NotImplementedError("--bench_torch_sparse requires SM80+")
         from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
```

src/tl_templates/cuda/gemm_sp_sm80.h (2)
220-228: Replace C-style pointer casts with reinterpret_cast. Improves clarity and type safety.

Apply:

```diff
-  const TensorRefA ref_A(
-      (ElementA *)pA,
+  const TensorRefA ref_A(
+      reinterpret_cast<ElementA *>(pA),
       MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}));
-  const TensorRefE ref_E(
-      (ElementE *)pE,
+  const TensorRefE ref_E(
+      reinterpret_cast<ElementE *>(pE),
       MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}));
-  const TensorRefB ref_B(
-      (ElementB *)pB,
+  const TensorRefB ref_B(
+      reinterpret_cast<ElementB *>(pB),
       MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}));
```
260-264: Avoid the C-style cast on accum. Use reinterpret_cast.

Apply:

```diff
-  MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
+  MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id % num_warp_m,
             warp_id / num_warp_m, lane_id);
```

tilelang/utils/sparse.py (2)
40-41: Don't hardcode -arch=sm_90. Rely on TORCH_CUDA_ARCH_LIST set by env._initialize_torch_cuda_arch_flags() to support 9.0/9.0a and future CCs.

Apply:

```diff
     f'-I{env.CUTLASS_INCLUDE_DIR}/../tools/util/include',
-    '-arch=sm_90',
```
48-53: Fix the clamped-warning to show the original block_k. Same pattern as the SM90 layout builder.

Apply:

```diff
-    if block_k > 128:
-        block_k = 128
-        warnings.warn(
-            f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2)
+    if block_k > 128:
+        orig_block_k = block_k
+        block_k = 128
+        warnings.warn(
+            f"block_k {orig_block_k} is too large; clamped to 128 for sm90 compression.",
+            stacklevel=2)
```
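Both clamped-warning comments converge on the same fix, which can be isolated as a tiny standalone sketch (the function name is illustrative, not TileLang's code):

```python
import warnings

def clamp_block_k(block_k: int, limit: int = 128) -> int:
    """Capture the original value before clamping so the warning reports
    what the caller actually passed, not the already-clamped value."""
    if block_k > limit:
        orig_block_k = block_k
        block_k = limit
        warnings.warn(
            f"block_k {orig_block_k} is too large; clamped to {limit}.",
            stacklevel=2)
    return block_k
```

The bug in the original code is an ordering one: assigning `block_k = 128` before interpolating `block_k` into the message makes the warning claim that 128 itself was "too large".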
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (8)
- benchmark/matmul/benchmark_matmul_sp.py (6 hunks)
- src/tl_templates/cpp/half.hpp (4 hunks)
- src/tl_templates/cuda/common.h (1 hunks)
- src/tl_templates/cuda/gemm_sp_sm80.h (1 hunks)
- testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (6 hunks)
- tilelang/language/builtin.py (1 hunks)
- tilelang/layout/gemm_sp.py (3 hunks)
- tilelang/utils/sparse.py (2 hunks)
✅ Files skipped from review due to trivial changes (3)
- tilelang/language/builtin.py
- src/tl_templates/cpp/half.hpp
- src/tl_templates/cuda/common.h
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.975Z
Learnt from: botbw
PR: tile-ai/tilelang#691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.975Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
- src/tl_templates/cuda/gemm_sp_sm80.h
- tilelang/layout/gemm_sp.py
📚 Learning: 2025-09-15T10:51:06.975Z
Learnt from: botbw
PR: tile-ai/tilelang#691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.975Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
- src/tl_templates/cuda/gemm_sp_sm80.h
- tilelang/layout/gemm_sp.py
🧬 Code graph analysis (5)

src/tl_templates/cuda/gemm_sp_sm80.h (2)
- src/tl_templates/cuda/common.h (3): tl (299-335), int (92-95), int (130-137)
- src/tl_templates/cuda/gemm_sp_sm90.h (1): gemm_sp_ss (224-231)

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (3)
- tilelang/utils/sparse.py (1): compress (76-94)
- tilelang/layout/gemm_sp.py (1): make_metadata_layout (136-158)
- tilelang/env.py (1): disable_cache (232-233)

benchmark/matmul/benchmark_matmul_sp.py (4)
- tilelang/layout/gemm_sp.py (1): make_metadata_layout (136-158)
- tilelang/contrib/nvcc.py (1): get_target_compute_version (257-298)
- examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py (3): matmul_sp (9-61), main (30-59), main (126-127)
- tilelang/env.py (1): disable_cache (232-233)

tilelang/layout/gemm_sp.py (1)
- tilelang/contrib/nvcc.py (1): get_target_compute_version (257-298)

tilelang/utils/sparse.py (2)
- tilelang/contrib/nvcc.py (2): get_target_compute_version (257-298), parse_compute_version (301-323)
- src/tl_templates/cuda/compress_sm90.cu (2): compress_sm90 (156-159), compress_sm90 (156-156)
🪛 GitHub Actions: CI Test on AMD
src/tl_templates/cuda/gemm_sp_sm80.h
[error] 1-1: Tile-lang formatter reformatted this file. Changes not staged for commit.
🪛 GitHub Actions: CI
src/tl_templates/cuda/gemm_sp_sm80.h
[error] 1-1: clang-format reformatted this file. Changes not staged for commit. Please review and stage the changes.
🪛 Ruff (0.12.2)
tilelang/layout/gemm_sp.py
120-120: Avoid specifying long messages outside the exception class
(TRY003)
123-123: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/utils/sparse.py
63-64: Avoid specifying long messages outside the exception class
(TRY003)
93-94: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (5)
src/tl_templates/cuda/gemm_sp_sm80.h (3)
83-87: Crosswise uses sizeof(T): correct per CUTLASS.Matches CUTLASS’s SM80 sparse patterns (Crosswise in bytes; layout template in bits). Nothing to change.
97-101: Crosswise uses sizeof(T): correct per CUTLASS.Same note for B-side.
1-1: Run formatters and commit the reformat (clang-format & tile-lang).
The verification run matched no files because fd treated 'src/tl_templates/cuda' as a pattern; run eithermake fmtorpython -m tilelang.format, or run clang-format manually, e.g.:
fd . src/tl_templates/cuda -t f -e h -e cu | xargs -r -I{} clang-format -i {}
Then git add/commit. Affected: src/tl_templates/cuda/gemm_sp_sm80.h.testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py (2)
367-369: Good: re-enabled SM80 gating.Prevents running Ampere kernels on unsupported GPUs.
119-139: SM80 E dtypes and factors look correct.int8/float8 -> 32 (int32 metadata), fp16/bf16 -> 16 (int16 metadata). Matches layout checks.
If you want auto-detect instead of hard-coding arch="8.0" in make_metadata_layout, you can drop arch=… and rely on autodetect, keeping tests arch-agnostic.
Actionable comments posted: 0

🧹 Nitpick comments (2)

src/target/codegen_webgpu.cc (2)

727-731: Keep this copy-elision; also make the output order deterministic. The `const auto &` avoids copies, which is good. Consider emitting `smap_` in a stable order to ensure reproducible artifacts across runs.

Apply this diff:

```diff
-  for (const auto &kv : smap_) {
-    os << kv.second;
-  }
+  // Emit in a deterministic order to ensure reproducible binaries.
+  std::vector<std::string> keys;
+  keys.reserve(smap_.size());
+  for (const auto &kv : smap_) keys.push_back(kv.first);
+  std::sort(keys.begin(), keys.end());
+  for (const auto &k : keys) {
+    os << smap_.at(k);
+  }
```
754-771: Use a const reference here as well to avoid copies. Mirror the pattern used above to prevent copying function map entries.

Apply this diff:

```diff
-  for (auto kv : mod->functions) {
+  for (const auto &kv : mod->functions) {
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
- src/target/codegen_webgpu.cc (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: format-check
- GitHub Check: bot-task
🔇 Additional comments (1)

src/target/codegen_webgpu.cc (1)

80-98: Idiomatic empty() check; a good change. Switching to `empty()` is clearer and avoids redundant size computation. No behavior change.
…ommon.h, and gemm_sp_sm80.h
Actionable comments posted: 3

♻️ Duplicate comments (3)

src/tl_templates/cuda/gemm_sp_sm80.h (3)

83-85: Layout computation correctly uses sizeof(T) for the crosswise calculation. The use of `sizeof(T)` (bytes) for computing crosswise constants aligns with the official CUTLASS implementation pattern, as documented in the retrieved learnings. This is correct and should not be changed to `cutlass::sizeof_bits<T>::value`.
224-232: Replace C-style casts with C++ reinterpret_cast for type safety. Using C-style casts for pointer conversions reduces type safety. Modern C++ prefers an explicit `reinterpret_cast` to make the type-punning operation clear and intentional.

```diff
   const TensorRefA ref_A(
-      (ElementA *)pA,
+      reinterpret_cast<ElementA *>(pA),
       MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}));
   const TensorRefE ref_E(
-      (ElementE *)pE,
+      reinterpret_cast<ElementE *>(pE),
       MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}));
   const TensorRefB ref_B(
-      (ElementB *)pB,
+      reinterpret_cast<ElementB *>(pB),
       MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}));
```
266-267: Replace the C-style cast with reinterpret_cast for consistency. For consistency with modern C++ practices and the previous suggestions, replace the C-style cast with `reinterpret_cast`.

```diff
-  MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
+  MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id % num_warp_m,
             warp_id / num_warp_m, lane_id);
```
🧹 Nitpick comments (2)

src/tl_templates/cuda/gemm_sp_sm80.h (2)

22-24: Consider documenting the rationale for the int8_t shape constraints. The int8_t and uint8_t types require K to be divisible by 64 (vs 32 for fp16), likely due to tensor core instruction requirements for integer types on SM80. Add a brief comment explaining the K dimension requirement:

```diff
 template <typename Shape> struct ShapeCheck<int8_t, Shape> {
+  // Int8 tensor cores on SM80 require the K dimension to be a multiple of 64
   static constexpr bool value =
       (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0);
 };
```
255-258: Consider adding documentation for the gemm_sp_ss interface. The template parameters and their constraints would benefit from documentation to help users understand the requirements and expected behavior. Add a documentation comment before the function:

```diff
+/// Sparse GEMM operation for SM80 architecture with 2:4 structured sparsity
+/// @tparam M, N, K - Matrix dimensions
+/// @tparam num_warp_m, num_warp_n - Warp tiling configuration
+/// @tparam trans_A, trans_B - Whether to transpose A and B matrices
+/// @tparam clear_accum - Whether to clear accumulator before operation
+/// @param pA - Pointer to sparse matrix A (2:4 structured)
+/// @param pB - Pointer to dense matrix B
+/// @param accum - Accumulator for result C
+/// @param pE - Pointer to sparsity metadata for matrix A
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, typename A_type,
           typename B_type, typename C_type, typename E_type>
 TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
- src/tl_templates/cuda/gemm_sp_sm80.h (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-09-15T10:51:06.975Z
Learnt from: botbw
PR: tile-ai/tilelang#691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.975Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). However, the layout template parameter should use sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h where Crosswise uses sizeof(ElementA) but the layout template uses sizeof_bits<ElementA>::value.
Applied to files:
src/tl_templates/cuda/gemm_sp_sm80.h
📚 Learning: 2025-09-15T10:51:06.975Z
Learnt from: botbw
PR: tile-ai/tilelang#691
File: src/tl_templates/cuda/gemm_sp_sm80.h:81-85
Timestamp: 2025-09-15T10:51:06.975Z
Learning: In CUTLASS tensor operation layouts, crosswise constants should be computed using sizeof(T) (bytes), not cutlass::sizeof_bits<T>::value (bits). This is the established pattern in the official CUTLASS codebase, as seen in default_mma_core_sparse_sm80.h.
Applied to files:
src/tl_templates/cuda/gemm_sp_sm80.h
🧬 Code graph analysis (1)
src/tl_templates/cuda/gemm_sp_sm80.h (2)
- src/tl_templates/cuda/common.h (3): tl (301-337), int (92-95), int (130-137)
- src/tl_templates/cuda/gemm_sp_sm90.h (1): gemm_sp_ss (224-231)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-amd
🔇 Additional comments (2)

src/tl_templates/cuda/gemm_sp_sm80.h (2)

143-146: Comprehensive input validation with clear error messages. The static assertions correctly validate type compatibility and shape requirements. This ensures compile-time safety for the sparse GEMM operations.

94-96: Verify int8/uint8/float8 row-major support for matrix B. The static assert at src/tl_templates/cuda/gemm_sp_sm80.h:94-96 tests `cutlass::sizeof_bits<T>::value != 8` while this specialization is for transpose = false (row-major). Repository search produced no output, so it cannot be confirmed whether row-major B is actually unsupported for 8-bit element types. Confirm the intended behavior and take one of the actions below:

- If row-major B must be disallowed for 8-bit types: remove this specialization or make the assert explicitly target concrete types (int8_t/uint8_t/float8_t) and update the message.
- If row-major B is allowed: relax/remove the assert or adjust it to check bytes vs bits (sizeof(T) vs cutlass::sizeof_bits<T>::value) and clarify the message.
```diff
@@ -0,0 +1,270 @@
+#include <cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h>
+#include <stdio.h>
```
🛠️ Refactor suggestion

Remove the unused <stdio.h> header. The `<stdio.h>` header is included but never used in this file.

```diff
 #include <cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h>
-#include <stdio.h>

 namespace tl {
```
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
#include <cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h>

namespace tl {
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_sp_sm80.h around lines 2 to 2, remove the unused
#include <stdio.h> directive because the header is not referenced anywhere in
the file; simply delete that include line to avoid unnecessary dependency and
potential warnings.
cutlass::gemm::warp::SparseMmaTensorOp<WarpShape, ElementA, SmemLayoutA,
                                       ElementB, SmemLayoutB, ElementC,
                                       LayoutC, Policy>;
static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse");
🛠️ Refactor suggestion
Remove duplicate static assertion.
The same assertion `kSparse == MmaWarp::kSparse` appears on both line 184 and line 206. The second occurrence is redundant.
static int const kElementsPerElementE = MmaWarp::kElementsPerElementE;
- static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse");
using ShapeA = cutlass::MatrixShape<Shape::kM, Shape::kK / kSparse>;

Also applies to: 206-206
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_sp_sm80.h around lines 184 and 206, there are
duplicate static_asserts checking kSparse == MmaWarp::kSparse; keep the first
occurrence (line 184) and remove the redundant second occurrence at line 206 to
avoid duplication; ensure no other dependent code expects the second assert and
run a quick build to confirm no compilation regressions.
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
          warp_id / num_warp_m, lane_id);
Incorrect warp index mapping for MMA operation.
The warp indices passed to MMA::body are swapped. The current code passes warp_id % num_warp_m as warp_idx_m and warp_id / num_warp_m as warp_idx_n, but this is incorrect for the standard warp tiling layout.
In CUTLASS and standard GEMM tiling, warps are typically arranged in row-major order within a threadblock, so:
- `warp_idx_m = warp_id / num_warp_n` (row index)
- `warp_idx_n = warp_id % num_warp_n` (column index)
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
- MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id % num_warp_m,
- warp_id / num_warp_m, lane_id);
+ MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id / num_warp_n,
+            warp_id % num_warp_n, lane_id);

📝 Committable suggestion
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
MMA::body(pA, pE, pB, *reinterpret_cast<FragmentC *>(accum), warp_id / num_warp_n,
          warp_id % num_warp_n, lane_id);
🤖 Prompt for AI Agents
In src/tl_templates/cuda/gemm_sp_sm80.h around lines 264 to 267, the warp index
arguments to MMA::body are swapped; change the mapping so warp_idx_m = warp_id /
num_warp_n and warp_idx_n = warp_id % num_warp_n (i.e., pass warp_id /
num_warp_n as the first warp index and warp_id % num_warp_n as the second) to
match the standard row-major warp tiling used by CUTLASS/GEMM.
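The two mappings the review contrasts can be checked with a few lines of plain Python (the 2x4 warp grid below is hypothetical). Both are bijections over the grid, but they assign different `(m, n)` coordinates to the same `warp_id` whenever the grid is not square, which is why swapping them corrupts the output tile placement:

```python
def warp_coords_row_major(warp_id, num_warp_n):
    # Row-major arrangement: consecutive warp ids walk along N first.
    return warp_id // num_warp_n, warp_id % num_warp_n

def warp_coords_flagged(warp_id, num_warp_m):
    # The mapping the review flags: m from the remainder, n from the quotient.
    return warp_id % num_warp_m, warp_id // num_warp_m

num_warp_m, num_warp_n = 2, 4  # hypothetical 2x4 warp grid
for wid in range(num_warp_m * num_warp_n):
    print(wid, warp_coords_row_major(wid, num_warp_n),
          warp_coords_flagged(wid, num_warp_m))
```

For example, warp 5 lands at `(1, 1)` under the row-major mapping but at `(1, 2)` under the flagged one.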
…rk cutlass subproject as dirty
* [feat] add an example mma atom
* [fix] fix typo naming
* [feat] add a template to enable compilation
* [feat] add print util
* [WIP] pass on single block tile
* [feat] add sm80 metadata layout
* [chore] clean codebase
* [CI] format.sh
* [feat] add sm80 compress utils
* [bugfix] fix C fragment layout
* [refactor] use nvcc version instead of str
* [test] add test cases
* [chore] add a param check
* [chore] format a bit
* [chore] rename func to satisfy PEP 8 and appease gemini
* [chore] add check
* [feat] support sm75 layout && add assertion && chore
* [bug] fix illegal memory access when using two warps over N=32 — This could be a missing check related to the cutlass 2.x implementation. Using the cutlass example can't trigger this because it's bypassed by padding the input. For now I think it might be safe to increase the atom size and investigate in the future.
* [chore] add example
* [chore] format
* [example] update benchmark
* [bugfix] fix namespace and format
* [bugfix] fix incorrect param passing
* [refactor] update variable declaration for clarity in gemm_layouts and gemm_sp
* [Cleanup] Remove unnecessary blank lines in metadata layout functions in gemm_sp.py
* [CI] fix arch
* [example] add torch sparse benchmark
* [misc] polish && add reference && apply review suggestions && format
* [CI] format with clang-tidy
* [Cleanup] Format and align template struct definitions in half.hpp, common.h, and gemm_sp_sm80.h
* [Update] Modify CUDA version requirements in test_gemm_sp_sm80 and mark cutlass subproject as dirty

---------

Co-authored-by: LeiWang1999 <[email protected]>
As title
Benchmark file: `benchmark/matmul/benchmark_matmul_sp.py`

Benchmark result on 4090:
(2 experiments)
(Torch, CUTLASS backend)
(Torch, CUSPARSELT backend)
(TileLang Sparse, fp32 accum)
(TileLang Sparse, fp16 accum)
Theoretical 4090 FP16 Sparse Tensor Core TFLOPS:
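When comparing the measured latencies above against the theoretical peak, achieved throughput is conventionally computed from the dense flop count — 2:4 sparsity does not change the nominal 2·M·N·K credited to the kernel. A small sketch (the shape and latency below are hypothetical, not taken from the benchmark):

```python
def gemm_tflops(M, N, K, latency_ms):
    """Achieved TFLOPS for an M x N x K GEMM, counting 2*M*N*K flops.

    Sparse kernels are credited with the same flop count as their
    dense equivalent, so speedups show up directly as higher TFLOPS.
    """
    flops = 2 * M * N * K
    return flops / (latency_ms * 1e-3) / 1e12

# Hypothetical shape and latency for illustration only.
print(round(gemm_tflops(8192, 8192, 8192, 10.0), 1))  # 110.0
```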
Summary by CodeRabbit
New Features
Documentation
Tests
Chores