[Language] Add Correctness and performance check scripts for V2 #1174
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's format and lint checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Caution: Review failed. The pull request is closed.

Walkthrough
Adds comprehensive GEMM testing and profiling modules, introduces gemm_v1 alongside gemm_v2 (with the alias gemm → gemm_v1), updates CUDA codegen type handling and MMA dispatch templates, extends MMA TF32 specializations, and removes a debug log and one outdated test.
Sequence Diagram(s)
```mermaid
sequenceDiagram
participant Test as Test Runner
participant API as tilelang.language
participant GEMMv1 as gemm_v1
participant GEMMv2 as gemm_v2
participant Codegen as CUDA Codegen
participant Kernel as CUDA Kernel
participant Ref as PyTorch Ref
Test->>API: call run_gemm / run_gemm_* (params)
API->>GEMMv2: build program (or GEMMv1 if used)
GEMMv2->>Codegen: emit CUDA (resolve AType/BType, scale_out)
Codegen-->>GEMMv2: CUDA source / binary
GEMMv2->>Kernel: launch kernel
Kernel-->>Test: outputs
Test->>Ref: compute reference (torch)
Test->>Test: compare outputs vs Ref (assert close)
    Note over Codegen: MMA dispatch may choose TF32 paths\n(templates use derived AType/BType)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
Actionable comments posted: 1
🧹 Nitpick comments (1)
tilelang/language/__init__.py (1)
54-54: Drop the redundant noqa suppressor
`gemm`, `gemm_v1`, and `gemm_v2` are all re-exported, so Ruff now flags the `# noqa: F401` as unused. Please remove the directive to keep lint clean.
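A minimal sketch of the cleaned-up re-export; the exact import line in `tilelang/language/__init__.py` is an assumption here, the only point is dropping the unused directive:

```python
# Hypothetical excerpt of tilelang/language/__init__.py; the import form and
# __all__ entries are assumptions used only to illustrate removing the suppressor.

# Before (RUF100 fires because F401 is not an enabled rule):
#   from .gemm import gemm, gemm_v1, gemm_v2  # noqa: F401

# After (same re-export, no unused noqa directive):
from .gemm import gemm, gemm_v1, gemm_v2

__all__ = ["gemm", "gemm_v1", "gemm_v2"]
```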
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- 3rdparty/tvm (1 hunks)
- maint/gemm_v2/correctness_evaluation.py (1 hunks)
- maint/gemm_v2/latency.py (1 hunks)
- src/op/gemm.cc (0 hunks)
- src/target/codegen_cuda.cc (4 hunks)
- tilelang/language/__init__.py (1 hunks)
- tilelang/language/gemm.py (2 hunks)
💤 Files with no reviewable changes (1)
- src/op/gemm.cc
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/language/gemm.py (2)
tilelang/primitives/gemm/__init__.py (1)
- gemm (10-46)
examples/gemm/example_gemm.py (1)
- gemm (9-25)
maint/gemm_v2/latency.py (9)
tilelang/jit/__init__.py (1)
- jit (233-306)
tilelang/language/allocate.py (2)
- alloc_shared (27-42)
- alloc_fragment (59-70)
tilelang/language/fill.py (1)
- clear (24-48)
tilelang/language/pipeline.py (1)
- Pipelined (9-46)
tilelang/language/copy.py (1)
- copy (11-87)
tilelang/language/gemm.py (2)
- gemm_v2 (215-434)
- gemm_v1 (10-211)
tilelang/language/parallel.py (1)
- Parallel (9-29)
tilelang/jit/kernel.py (1)
- get_profiler (367-383)
tilelang/utils/tensor.py (1)
- TensorSupplyType (11-18)
maint/gemm_v2/correctness_evaluation.py (11)
tilelang/language/kernel.py (2)
- threads (215-219)
- num_threads (222-226)
tilelang/language/allocate.py (2)
- alloc_shared (27-42)
- alloc_fragment (59-70)
tilelang/language/pipeline.py (1)
- Pipelined (9-46)
tilelang/language/copy.py (1)
- copy (11-87)
tilelang/language/gemm.py (1)
- gemm_v2 (215-434)
tilelang/jit/__init__.py (1)
- compile (30-79)
tilelang/jit/kernel.py (2)
- out_idx (453-454)
- get_profiler (367-383)
tilelang/transform/pass_config.py (1)
- PassConfigKey (6-144)
tilelang/profiler/__init__.py (1)
- assert_allclose (77-146)
tilelang/env.py (1)
- disable_cache (271-272)
tilelang/testing/__init__.py (1)
- set_random_seed (30-35)
tilelang/language/__init__.py (2)
tilelang/primitives/gemm/__init__.py (1)
- gemm (10-46)
tilelang/language/gemm.py (2)
- gemm_v1 (10-211)
- gemm_v2 (215-434)
🪛 Ruff (0.14.2)
tilelang/language/__init__.py
54-54: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (3)
src/target/codegen_cuda.cc (2)
1841-1841: LGTM: `scale_out` now handled as runtime expression
The change from boolean extraction to expression string is correct. This allows `scale_out` to be a runtime-evaluated expression rather than a compile-time constant, which provides more flexibility.
Note: Unlike `scale_in_a` and `scale_in_b` (which are still booleans converted to "1"/"-1" strings at lines 1866-1867), `scale_out` is now directly substituted as an arbitrary expression into the template. Ensure the underlying WGMMA instruction template can handle expression strings for the `scale_out` parameter.
Also applies to: 1873-1873
1903-1903: LGTM: Consistent `scale_out` handling across WGMMA variants
The changes mirror the pattern in `ptx_wgmma_ss` (lines 1841, 1873), ensuring consistent handling of `scale_out` as a runtime expression across both the `ptx_wgmma_ss` and `ptx_wgmma_rs` handlers.
Also applies to: 1942-1942
3rdparty/tvm (1)
1-1: Manually verify the TVM submodule commit references.
The sandbox environment cannot access the upstream TVM repository to verify these commit hashes. Confirm that both commit hashes exist in the official TVM repository and review the changes between them to ensure they align with the GEMM v2 implementation goals and introduce no unexpected breaking changes.
Excerpt under review from `maint/gemm_v2/correctness_evaluation.py`:

```python
import torch

if trans_A:
    A = A.T
if trans_B:
    B = B.T
if in_dtype == "float32":
    A = (A.view(torch.int32) - 0x1000).view(torch.float32)
    B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
```
Fix the dtype reinterpretation code path
This block tries to reinterpret tensor bits with a NumPy-style view(dtype), but PyTorch’s Tensor.view only accepts shape arguments. As soon as we exercise the float32 path, this raises TypeError: view(): argument 'size' must be tuple of ints, but found type torch.dtype. Please switch to an explicit dtype conversion before the bias adjustment:
```diff
- if in_dtype == "float32":
-     A = (A.view(torch.int32) - 0x1000).view(torch.float32)
-     B = (B.view(torch.int32) - 0x1000).view(torch.float32)
+ if in_dtype == "float32":
+     A = (A.to(torch.int32) - 0x1000).to(torch.float32)
+     B = (B.to(torch.int32) - 0x1000).to(torch.float32)
```
That keeps the intended shift while avoiding the runtime failure.
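For context, here is a self-contained sketch of how the routine reads with the reviewer's suggested replacement applied. The function name, signature, and test values are assumptions, since only the excerpt above is quoted from `maint/gemm_v2/correctness_evaluation.py`:

```python
import torch


def ref_program_sketch(A, B, trans_A=False, trans_B=False,
                       in_dtype="float16", out_dtype="float16"):
    """Hypothetical wrapper around the excerpt above, with the reviewer's
    cast-based adjustment in place of the dtype views."""
    if trans_A:
        A = A.T
    if trans_B:
        B = B.T
    if in_dtype == "float32":
        # Reviewer-suggested form: explicit casts instead of dtype reinterpretation.
        A = (A.to(torch.int32) - 0x1000).to(torch.float32)
        B = (B.to(torch.int32) - 0x1000).to(torch.float32)
    # Accumulate in float32, then cast to the requested output dtype.
    C = torch.matmul(A.to(torch.float), B.to(torch.float))
    return C.to(getattr(torch, out_dtype))


if __name__ == "__main__":
    a = torch.randn(64, 32, dtype=torch.float16)
    b = torch.randn(32, 48, dtype=torch.float16)
    print(ref_program_sketch(a, b).shape)  # torch.Size([64, 48])
```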
🤖 Prompt for AI Agents
In maint/gemm_v2/correctness_evaluation.py around lines 77 to 88, the code
incorrectly uses Tensor.view(torch.int32) to reinterpret dtypes (PyTorch view
expects shape tuples), causing a TypeError; change the float32 reinterpretation
path to explicitly cast the tensor to an integer dtype (e.g., A =
A.to(torch.int32)), perform the subtraction (A = A - 0x1000), then cast back to
float (A = A.to(torch.float32)); do the same for B, then proceed with matmul and
conversion to out_dtype as before.
Commit messages referencing this pull request:

…-ai#1174)
* fix
* lint fix
* fix
* lint fix
* fix
* upd

* [Test] Add cp async to avoid register spill
* [BugFix] GQA fwd and bwd
  - Fix the undefined behavior of -inf in acc_s
  - Fix the causal loop range in varlen scenario
* [TMA] Move on to TMA and locate the register spill issue
* [Debug] Not the reason of zero-assignment. Probably the combination of Parallel op & conditional qkT
* [Debug] The SIMT copy in producer occupies too many registers
* [BugFix] Use 3D lse and delta to avoid illegal instruction
* [Perf] Relaxed order for dQ and SIMT store for dKdV
* [Feat] For atomic add version
* [Lint]
* [Bugfix] Enable code lowering with producer‑copy‑only program (#1168)
* bugfix
* lint fix
* Enhance warp group register allocation to handle missing consumer bodies gracefully. Updated logic to annotate producer side when consumer is absent, ensuring robustness in degenerate warp-specialized patterns.
* Refactor VisitExpr_ method in inject_tma_barrier.cc for improved readability. Adjusted formatting and spacing for clarity in barrier handling logic.
* Update barrier handling in inject_tma_barrier.cc to accommodate newly appended entries. Adjusted the size of the replace vector to ensure it covers the full needed length, and modified the logic for appending barriers based on the updated replace conditions.
* [Bugfix] Support 16bits shfl_sync (#1169)
* Add type-safe warp shuffle helpers for 16-bit float types in common.h
  - Introduced generic passthrough functions for warp shuffle operations: `shfl_xor_sync`, `shfl_down_sync`, `shfl_up_sync`, and `shfl_sync`.
  - Added specializations for `cutlass::half_t` and `cutlass::bfloat16_t` to ensure type safety during shuffle operations.
  - Updated `reduce.h` to utilize the new shuffle functions, enhancing code clarity and maintainability.
* lint fix
* [Testing] Move TMA 1D and test for its functionality (#1167)
* [Testing] Move TMA 1D and test for its functionality
* [Lint]
* [Refactor]: Change the params in pytest to avoid oom error during ci (#1170)
* [Refactor]: Change the params in pytest to avoid oom error during ci
* format
* fix
* Update test_example_cast.py
* Update parameters in test_example_cast
* Update test_example_flash_attention.py
* update
* format
* fix
* fix
* format
* [Bugfix] Fix tvm import path for editable build (#1172)
* [Language] Expose `T.warpgroup_fence_operand` for nvcc code motion (#986)
* remove debug print
* pipeline fix
* use the correct buffer access scope
* rs support
* warp warpgroup_fence_operand
* fix
* fp8 dtype ptx enhance
* mma fix
* TCGEN05 Interface
* tcgen05 support
* rebase
* update
* Enhance TCGEN05 support by adding new intrinsic operations and descriptors. Introduced `ptx_tcgen05_mma_ts` for tensor-memory to shared-memory instructions and `tcgen05_mma_arrive` for signaling barrier completion. Updated existing descriptors and code generation logic to accommodate these changes, ensuring compatibility with new instruction sets. Refactored related allocation functions and improved handling of shared memory descriptors.
* lint fix
* Refactor buffer reference handling in CUDA code generation and update test execution in tilelang. Ensure default annotations for unrolling are set correctly in TIR IR module.
* wgmma fix
---------
Co-authored-by: Zhiwen Mo <[email protected]>
* [Language] Add Correctness and performance check scripts for V2 (#1174)
* fix
* lint fix
* fix
* lint fix
* fix
* upd
* [Bugfix] Legalize Datatype for mma intrinisc codegen (#1179)
* fix
* lint fix
* Enhance CUDA code generation by updating register type handling for float data types. Introduced a workaround for TF32 type compatibility and improved the registration of MMA register types for A and B operands.
* [Perf] Add layout and use_tma to boost performance
* [Lint]
* [Note]
---------
Co-authored-by: Lei Wang <[email protected]>
Co-authored-by: Yuqi Dong <[email protected]>
Co-authored-by: Zhiwen Mo <[email protected]>
This pull request introduces a new version of the GEMM (matrix multiplication) kernel, adds a usage example and benchmarking script, and refactors how GEMM functions are exposed in the `tilelang` package. It also updates the CUDA code generation logic to treat the `scale_out` argument as a runtime value rather than a compile-time boolean, improving flexibility. The most important changes are grouped below.

GEMM Kernel Improvements and Refactoring
- Added `gemm_v2` and refactored the original kernel to `gemm_v1`. The default `gemm` now points to `gemm_v1`, allowing users to select between versions. (`tilelang/language/gemm.py`, `tilelang/language/__init__.py`) [1] [2] [3]
- Updated the TVM submodule pin (`3rdparty/tvm`).

Usage Example and Benchmarking
- Added a script, `latency.py`, that demonstrates how to use the new kernel, validates correctness against PyTorch, and profiles latency. It supports toggling between `gemm_v1` and `gemm_v2` via a command-line flag; a rough sketch of this pattern follows below. (`maint/gemm_v2/latency.py`)
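To make the workflow concrete, here is a rough sketch of the general pattern such a benchmark follows. It is illustrative only, not the contents of `maint/gemm_v2/latency.py`; the block sizes, `--use_v2` flag name, problem sizes, and tolerances are assumptions.

```python
# Illustrative benchmark sketch in the tilelang style (not the actual maint script).
import argparse

import tilelang
import tilelang.language as T


def matmul(M, N, K, block_M=128, block_N=128, block_K=32,
           dtype="float16", accum_dtype="float", use_v2=False):
    # Pick the GEMM lowering under test when the program is traced.
    gemm_op = T.gemm_v2 if use_v2 else T.gemm_v1

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                gemm_op(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_v2", action="store_true")  # toggle gemm_v1 / gemm_v2
    args = parser.parse_args()

    kernel = tilelang.compile(matmul(4096, 4096, 4096, use_v2=args.use_v2), out_idx=[2])
    profiler = kernel.get_profiler()
    # Correctness against a PyTorch reference, then latency.
    profiler.assert_allclose(lambda a, b: (a.float() @ b.float()).half(), rtol=1e-2, atol=1e-2)
    print(f"latency: {profiler.do_bench():.3f} ms")
```

The real script presumably builds on the same compile → `get_profiler` → `assert_allclose` / `do_bench` flow referenced in the code graph analysis above.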
Code Generation and Runtime Flexibility

- Changed the handling of the `scale_out` argument in CUDA code generation to treat it as a runtime value rather than a compile-time boolean, both in assembly generation and rule replacement. This enables more flexible kernel invocation. (`src/target/codegen_cuda.cc`) [1] [2] [3] [4]

Logging and Debugging
- Removed a debug logging statement. (`src/op/gemm.cc`)

Summary by CodeRabbit
New Features
Tests
Bug Fixes