[Bugfix][Enhancement] Fix a bug in previous commit and enhance cuda backend #887
Conversation
* update sm100 related utcmma, tmem, ld/st256 in src
* update sm100 related utcmma, tmem, ld/st256 in tilelang
* Remove deprecated GEMM examples and related README documentation for SM100 architecture support
* Update GEMM implementation to replace UTCMMA with TCGEN5MMA across relevant files
* Remove gemm_umma.py example and update README to reflect TCGEN5MMA terminology changes
* Update README.md for gemm_sm100 example by removing outdated API sections and streamlining documentation
* Update README and source files to reflect TCGEN5.MMA terminology changes
* Refactor CUDA GEMM header for improved readability
Walkthrough
Adds SM100/TCGEN5MMA and TMEM support across layouts, CUDA templates, codegen, and transforms. Introduces new builtins, target utilities, vectorization config, and lowering passes (LowerSharedTmem, pipeline planning updates). Extends GEMM with a TCGEN5MMA path and WarpPolicy API. Adds examples, tests, and Python API entries and gating for Hopper/SM100 behaviors.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant User
participant TileLang as TileLang Python
participant Lowering as Lowering Passes
participant Codegen as CUDA Codegen
participant Device as GPU (SM100)
User->>TileLang: tl.gemm(..., mbar=opt, alloc_tmem(...))
TileLang->>Lowering: OptimizeForTarget(target)
Note over Lowering: If Hopper: RewriterWgmmaSync<br/>Always: LowerSharedBarrier -> LowerSharedTmem
Lowering->>Lowering: LowerSharedTmem (init/dealloc TMEM, remap)
Lowering->>Lowering: Pipeline planning (async dep chain)
Lowering->>Lowering: GEMM InferLayout (TCGEN5MMA/WGMMA/MMA)
Lowering->>Codegen: Emit CUDA (vec ld/st up to 256-bit)
Codegen-->>Device: Kernel launch + TMEM ops + barriers
Device-->>User: Results
sequenceDiagram
autonumber
participant GEMM as GemmNode
participant Policy as WarpPolicy
participant Target as Target Utils
participant Layout as Layout Inference
GEMM->>Target: Get target features (Sm100/Hopper, HasTmem)
GEMM->>GEMM: GetGemmInst (kTCGEN5MMA/kWGMMA/kMMA)
GEMM->>Policy: ComputeWarpPartition(M,N,block_size,target,gemm_inst)
GEMM->>Layout: Infer A/B/C layouts (incl. SM100 Swizzles)
GEMM-->>GEMM: Lower specialized path (TCGEN5MMA uses mbar/tmem)
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~150 minutes
Possibly related PRs
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required formatting checks before submitting. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 25
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
src/tl_templates/cuda/gemm.h (1)
3-17: Revert to `__CUDA_ARCH__`-based guards to avoid pulling the wrong header.
`__CUDA_ARCH_LIST__` expands to a comma-separated list of all targets on the command line (or host-side preprocessing), not the architecture currently being compiled. Comparing it to a number collapses to the last element via the comma operator, i.e., the highest arch in the list, so a multi-arch build (sm75 + sm100) will include `gemm_sm100.h` even when compiling the sm75 pass. That will pull in instructions unsupported on sm75 and can break codegen. Please keep the per-phase checks on `__CUDA_ARCH__` (and, if you need a host-side path, add a separate helper that parses `__CUDA_ARCH_LIST__` instead of using it directly in these comparisons). (docs.nvidia.com)
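A minimal host-side sketch of the collapse described above; the macro below is only a stand-in for `__CUDA_ARCH_LIST__`, with an assumed value for an sm_75 + sm_90 + sm_100 build:

#include <iostream>

// Stand-in for __CUDA_ARCH_LIST__ in a hypothetical multi-arch build.
#define FAKE_ARCH_LIST 750, 900, 1000

int main() {
  // Expands to (750, 900, 1000 >= 1000). The comma operator discards the
  // left operands, so only the last (highest) element is compared -- the
  // guard is satisfied even while the sm_75 pass is being compiled.
  bool would_include_sm100_header = (FAKE_ARCH_LIST >= 1000);
  std::cout << std::boolalpha << would_include_sm100_header << "\n";  // prints: true
  return 0;
}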
Apply this diff to restore the original guard:
-#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200))
 #include "gemm_sm120.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1000))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
 #include "gemm_sm100.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 #include "gemm_sm90.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
 #include "gemm_sm89.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
 #include "gemm_sm80.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 700))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
 #include "gemm_sm70.h"
 #else
 // No matching architecture found
 #endif

src/op/gemm_py.cc (1)
95-109: TCGEN5 path never selected in Python GEMM lowering
`GemmPyNode::GetGemmInst` still only returns `{kWGMMA, kMFMA, kMMA}`. In `src/op/gemm.cc` the native lowering path was updated to return `GemmInst::kTCGEN5MMA` when `AllowTCGEN5MMA(target)` succeeds, but the Python binding never produces that enum. As a result, on SM100 targets we always fall back to `kMMA`, so the new TCGEN5 warp policy and lowering you just added remain unreachable whenever kernels are created through `tilelang.op.gemm`. Please align this helper with the C++ implementation so it can emit `kTCGEN5MMA` (e.g., reuse the shared `AllowTCGEN5MMA` logic) and unlock the intended SM100 flow.
Apply this adjustment:
 GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
   int warp_size = TargetGetWarpSize(target);
   int num_warps = block_size / warp_size;
-  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
-                     (num_warps % 4 == 0) && CheckWGMMA();
-  if (allow_wgmma) {
+  bool allow_tcgen5mma = AllowTCGEN5MMA(target);
+  bool allow_wgmma = AllowWGMMA(block_size, target);
+  if (allow_tcgen5mma) {
+    return GemmInst::kTCGEN5MMA;
+  } else if (allow_wgmma) {
     return GemmInst::kWGMMA;

src/op/gemm.h (1)
115-182: Account for the new GEMM metadata in structural equality/hashing.
We now carry `mbarptr`, `mbar`, and `C_coords` on `GemmNode`, but `SEqualReduce`/`SHashReduce` still ignore them. That means two nodes that differ only in their barrier pointer or TCGEN5 MMA coordinates collapse to the same key, letting memoized rewrites and caches return the wrong instance. Please include these fields in both equality and hashing.
Apply this diff:
@@
-        equal(offset_A, other->offset_A) &&
-        equal(offset_B, other->offset_B) &&
-        equal(clear_accum, other->clear_accum) &&
-        equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
-        equal(policy, other->policy);
+        equal(offset_A, other->offset_A) &&
+        equal(offset_B, other->offset_B) &&
+        equal(clear_accum, other->clear_accum) &&
+        equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
+        equal(mbarptr, other->mbarptr) && equal(mbar, other->mbar) &&
+        equal(C_coords, other->C_coords) && equal(policy, other->policy);
@@
-    hash_reduce(clear_accum);
-    hash_reduce(kPack);
-    hash_reduce(wg_wait);
-    hash_reduce(policy);
+    hash_reduce(clear_accum);
+    hash_reduce(kPack);
+    hash_reduce(wg_wait);
+    hash_reduce(mbarptr);
+    hash_reduce(mbar);
+    hash_reduce(C_coords);
+    hash_reduce(policy);
🧹 Nitpick comments (11)
src/layout/tcgen05_layout.h (1)
7-31: Add the standard library includes this header relies on.
This header uses `std::string` and `std::tuple`, but it doesn't include `<string>` or `<tuple>`. Please include them here so the header stays self-contained instead of depending on transitive includes from other headers.

testing/python/kernel/test_tilelang_kernel_gemm.py (1)
85-85: Remove unconditional kernel-source print from the test.
Dumping every generated kernel into STDOUT makes the test logs huge (each invocation prints hundreds of lines, and this test runs a dozen variants), which slows CI and buries real failures. Please drop the print or guard it behind an explicit debug flag/environment check before merging.
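If the dump is still useful locally, a small opt-in guard along these lines would keep CI logs quiet (the environment-variable name is hypothetical, not an existing TileLang flag):

import os

def maybe_dump_kernel_source(jit_kernel):
    # Opt-in debug dump; prints nothing unless the developer explicitly asks.
    if os.environ.get("TILELANG_DUMP_KERNEL_SOURCE") == "1":
        print(jit_kernel.get_kernel_source())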
src/op/gemm.cc (1)
21-86: Consider adding template specializations for common atom configurations.
The `GetTCGEN5MMAMeta` function has repetitive patterns for atom_m (128, 64, 32) that could benefit from template specialization or a more data-driven approach to reduce code duplication. Additionally, the function returns a pair where the boolean and struct are always consistent - consider returning `std::optional<TCGEN5MMAMeta>` instead.
Apply this refactor to simplify the meta computation:
-// Return {is_success, meta}
-static inline std::pair<bool, TCGEN5MMAMeta>
-GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
-// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
-#define FAIL \
-  return { \
-    false, TCGEN5MMAMeta { 0, 0, 0 } \
-  }
-#define SUCCESS(atom_m, atom_n, atom_k) \
-  return { \
-    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
-  }
+static inline std::optional<TCGEN5MMAMeta>
+GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
+  // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
   std::vector<int> ws_valid_atom_ns = {256, 128, 64};
+
+  int required_k_alignment = 0;
   if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
       (c_dtype.is_float() && c_dtype.bits() == 32)) {
-    if (K % 16 != 0)
-      FAIL;
-    if (M % 128 == 0) {
-      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
-        if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 16);
-      FAIL;
-    } else if (M % 64 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 16);
-      FAIL;
-    } else if (M % 32 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 16);
-      FAIL;
-    } else {
-      FAIL;
-    }
+    required_k_alignment = 16;
   } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
              (c_dtype.is_float() && c_dtype.bits() == 32)) {
-    if (K % 32 != 0)
-      FAIL;
-    if (M % 128 == 0) {
-      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
-        if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 32);
-      FAIL;
-    } else if (M % 64 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 32);
-      FAIL;
-    } else if (M % 32 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 32);
-      FAIL;
-    } else {
-      FAIL;
-    }
+    required_k_alignment = 32;
+  } else {
+    return std::nullopt;
   }
-  FAIL;
-#undef FAIL
-#undef SUCCESS
+
+  if (K % required_k_alignment != 0)
+    return std::nullopt;
+
+  // Try atom_m values in descending order
+  for (int atom_m : {128, 64, 32}) {
+    if (M % atom_m != 0) continue;
+
+    if (atom_m == 128) {
+      // For atom_m=128, try all atom_n from 256 down to 16
+      for (int atom_n = 256; atom_n >= 16; atom_n -= 16) {
+        if (N % atom_n == 0)
+          return TCGEN5MMAMeta{atom_m, atom_n, required_k_alignment};
+      }
+    } else {
+      // For atom_m=64 and 32, use predefined valid atom_n values
+      for (int atom_n : ws_valid_atom_ns) {
+        if (N % atom_n == 0)
+          return TCGEN5MMAMeta{atom_m, atom_n, required_k_alignment};
+      }
+    }
+  }
+
+  return std::nullopt;
 }

src/tl_templates/cuda/tcgen_05_ld.h (3)
17-19: Power-of-2 validation can be simplified.
The static assertion `(N & (N - 1)) == 0` is a good check for power of 2, but consider using a more readable helper or standard library function if available.
Consider extracting the power-of-2 check into a constexpr helper for better readability:
+  template<int N>
+  static constexpr bool is_power_of_two() {
+    return N > 0 && (N & (N - 1)) == 0;
+  }
+
   template <int N>
   static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) {
-    static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128,
+    static_assert(is_power_of_two<N>() && N <= 128,
                   "N must be a power of 2 and lies between 1 ~ 128");
179-180: Trap on invalid template parameters lacks diagnostic information.
Using `asm volatile("trap")` for the else case provides no diagnostic information. While this should never be reached due to static_assert, having a more informative error would help debugging.
Consider adding a compile-time error message:
   } else {
-    asm volatile("trap");
+    static_assert(N <= 128 && (N & (N - 1)) == 0, "Invalid N value for tmem_ld_32dp32bNx::copy");
+    __builtin_unreachable();
   }
683-711: 32dp wrapper classes use magic number for address offset.
The 32-datapath wrapper classes use `(16 << 16)` as an address offset, which appears to be encoding datapath lane information in the upper bits. This magic number should be documented or defined as a named constant.
Define the magic number as a named constant with documentation:
+// TMEM address encoding: bits [31:16] represent the datapath lane offset
+// For 32dp operations, we need to access both 16-lane groups
+constexpr uint32_t TMEM_LANE_OFFSET = 16 << 16;
+
 // 32 data path lanes, 64-bit pattern, repeated N times
 // (conducted with 2x16dp64bNx)
 class tmem_ld_32dp64bNx {
 public:
   template <int N>
   static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) {
     tmem_ld_16dp64bNx::copy<N>(src_addr, dst_ptr);
-    tmem_ld_16dp64bNx::copy<N>(src_addr + (16 << 16), dst_ptr + N);
+    tmem_ld_16dp64bNx::copy<N>(src_addr + TMEM_LANE_OFFSET, dst_ptr + N);
   }
 };

src/tl_templates/cuda/copy_sm100.h (1)
76-83: Recursive template instantiation depth could be problematic.
The `get_floor_log2` template uses unbounded recursion which could hit template instantiation depth limits for large N values. While the current use cases seem bounded, consider adding a depth limit check.
Add a recursion depth limit:
 template <int N, int K = 0>
 __device__ __forceinline__ constexpr int get_floor_log2() {
   static_assert(N > 0);
+  static_assert(K < 32, "Recursion depth exceeded in get_floor_log2");
   if constexpr ((1 << (K + 1)) > N)
     return K;
   else
     return get_floor_log2<N, K + 1>();
 }

src/tl_templates/cuda/gemm_sm100.h (4)
160-163: Remove redundant static assertion
Lines 162-163 duplicate the validation already performed at lines 149-151. The first assertion is more informative as it names the specific type.
Apply this diff to remove the redundant assertion:
 using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;
-  static_assert(sizeof_bits_v<ValTypeA> <= sizeof_bits_v<uint8_t> &&
-                sizeof_bits_v<ValTypeB> <= sizeof_bits_v<uint8_t>);
-
 // Logical shape-K is always 256bits, transform to units of elements
 constexpr static int K = 32;
325-330: Address the TODO comment
The TODO comment indicates that the implementation is using the `.ws` variant as a workaround. This should be tracked for future optimization.
The TODO comment suggests there might be a performance impact. Would you like me to create an issue to track implementing proper support for the non-.ws variant when M == 64, or would you prefer to document the performance implications?
365-370: Reminder: Implement gemm_ts function
The TODO comment indicates a missing `gemm_ts` implementation. Please ensure this is tracked for completion.
Do you want me to generate the `gemm_ts` implementation or open a new issue to track this task?
248-253: Inconsistent specialization pattern
For fp8 types with M==128, the code uses the `MMA_Traits<SM100_MMA_F8F6F4_SS, ...>` wrapper, while for bf16/half types with M==128, it directly uses the MMA type (e.g., `SM100_MMA_F16BF16_SS`). Consider using a consistent pattern.
Either wrap all M==128 cases in `MMA_Traits` or use the direct type for all. The current mixed approach may confuse maintainers.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (49)
- .clang-tidy (1 hunks)
- examples/gemm_sm100/README.md (1 hunks)
- examples/gemm_sm100/gemm_mma.py (1 hunks)
- examples/gemm_sm100/gemm_tcgen5mma.py (1 hunks)
- src/layout/gemm_layouts.cc (2 hunks)
- src/layout/layout.h (2 hunks)
- src/layout/tcgen05_layout.cc (1 hunks)
- src/layout/tcgen05_layout.h (1 hunks)
- src/op/builtin.cc (3 hunks)
- src/op/builtin.h (2 hunks)
- src/op/fill.cc (1 hunks)
- src/op/finalize_reducer.cc (1 hunks)
- src/op/gemm.cc (11 hunks)
- src/op/gemm.h (5 hunks)
- src/op/gemm_py.cc (2 hunks)
- src/op/gemm_py.h (0 hunks)
- src/op/gemm_sp.cc (1 hunks)
- src/op/reduce.cc (1 hunks)
- src/runtime/runtime.cc (3 hunks)
- src/target/codegen_cpp.cc (0 hunks)
- src/target/codegen_cuda.cc (13 hunks)
- src/target/codegen_cuda.h (1 hunks)
- src/target/utils.cc (2 hunks)
- src/target/utils.h (1 hunks)
- src/tl_templates/cuda/copy.h (1 hunks)
- src/tl_templates/cuda/copy_sm100.h (1 hunks)
- src/tl_templates/cuda/cuda_fp8.h (3 hunks)
- src/tl_templates/cuda/debug.h (2 hunks)
- src/tl_templates/cuda/gemm.h (2 hunks)
- src/tl_templates/cuda/gemm_sm100.h (1 hunks)
- src/tl_templates/cuda/tcgen_05.h (1 hunks)
- src/tl_templates/cuda/tcgen_05_ld.h (1 hunks)
- src/transform/loop_vectorize.cc (4 hunks)
- src/transform/lower_shared_tmem.cc (1 hunks)
- src/transform/lower_tile_op.cc (6 hunks)
- src/transform/pipeline_planning.cc (7 hunks)
- testing/python/cpu/test_tilelang_cpu_gemm.py (2 hunks)
- testing/python/kernel/test_tilelang_kernel_gemm.py (1 hunks)
- testing/python/transform/test_tilelang_transform_layout_inference.py (1 hunks)
- testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py (1 hunks)
- testing/python/webgpu/test_webgpu_codegen.py (1 hunks)
- tilelang/contrib/nvcc.py (1 hunks)
- tilelang/engine/phase.py (3 hunks)
- tilelang/language/__init__.py (1 hunks)
- tilelang/language/allocate.py (1 hunks)
- tilelang/language/gemm.py (5 hunks)
- tilelang/transform/__init__.py (2 hunks)
- tilelang/transform/pass_config.py (1 hunks)
- tilelang/utils/target.py (1 hunks)
💤 Files with no reviewable changes (2)
- src/target/codegen_cpp.cc
- src/op/gemm_py.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
PR: tile-ai/tilelang#794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.
Applied to files:
src/transform/pipeline_planning.cc
🧬 Code graph analysis (32)
tilelang/language/allocate.py (1)
tilelang/language/ast/ir.py (1)
alloc_buffer(441-508)
src/target/utils.h (1)
src/target/utils.cc (4)
TargetIsSm100(56-61), TargetIsSm100(56-56), TargetHasTmem(114-118), TargetHasTmem(114-114)
tilelang/language/__init__.py (1)
tilelang/language/allocate.py (1)
alloc_tmem(92-118)
tilelang/transform/__init__.py (1)
src/transform/lower_shared_tmem.cc (4)
LowerSharedTmem(285-291), LowerSharedTmem(285-285), LowerSharedTmem(296-301), LowerSharedTmem(296-296)
testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py (1)
tilelang/transform/__init__.py (1)
LegalizeVectorizedLoop(241-249)
testing/python/cpu/test_tilelang_cpu_gemm.py (2)
tilelang/engine/lower.py (1)
lower(190-242)
tilelang/jit/__init__.py (1)
compile(33-86)
testing/python/webgpu/test_webgpu_codegen.py (1)
tilelang/language/ast/ir.py (1)
target(1682-1713)
src/transform/lower_shared_tmem.cc (5)
src/target/codegen_cuda.cc (26)
op(217-232), op(217-217), op(1611-1613), op(1611-1611), op(1614-1616), op(1614-1614), VisitStmt_(296-315), VisitStmt_(296-296), VisitStmt_(1934-1973), VisitStmt_(1934-1934), VisitStmt_(1975-2035), VisitStmt_(1975-1975), VisitStmt_(2037-2053), VisitStmt_(2037-2037), VisitExpr_(886-971), VisitExpr_(886-886), VisitExpr_(1176-1932), VisitExpr_(1176-1176), VisitExpr_(2055-2069), VisitExpr_(2055-2055), VisitExpr_(2071-2139), VisitExpr_(2071-2072), VisitExpr_(2141-2289), VisitExpr_(2141-2142), VisitExpr_(2345-2348), VisitExpr_(2345-2346)
src/layout/layout.h (1)
Array(34-66)
src/target/utils.cc (2)
TargetGetWarpSize(127-132), TargetGetWarpSize(127-127)
tilelang/language/tir/op.py (1)
tvm_access_ptr(650-675)
tilelang/transform/__init__.py (1)
LowerSharedTmem(443-446)
tilelang/engine/phase.py (4)
tilelang/contrib/nvcc.py (2)
have_tma(434-449), is_hopper(452-457)
src/transform/lower_shared_tmem.cc (4)
LowerSharedTmem(285-291), LowerSharedTmem(285-285), LowerSharedTmem(296-301), LowerSharedTmem(296-296)
tilelang/transform/__init__.py (2)
LowerSharedTmem(443-446), RewriteWgmmaSync(117-125)
src/transform/wgmma_sync_rewriter.cc (2)
RewriteWgmmaSync(262-267), RewriteWgmmaSync(262-262)
src/op/reduce.cc (1)
src/target/utils.cc (4)
TargetIsHopper(49-54), TargetIsHopper(49-49), TargetIsSm100(56-61), TargetIsSm100(56-56)
src/layout/tcgen05_layout.h (1)
src/layout/tcgen05_layout.cc (10)
getTcgen05Meta_32dp32b(22-30), getTcgen05Meta_32dp32b(22-22), getTcgen05Meta_32dp64b(32-44), getTcgen05Meta_32dp64b(32-32), getTcgen05Meta_32dp128b(46-57), getTcgen05Meta_32dp128b(46-46), getTcgen05Meta_32dp256b(59-72), getTcgen05Meta_32dp256b(59-59), expandTcgen05Layout(74-108), expandTcgen05Layout(75-76)
src/target/codegen_cuda.cc (2)
src/target/codegen_cpp.cc (2)
PrintType(95-164), PrintType(95-95)
src/target/codegen_hip.cc (2)
PrintType(178-421), PrintType(178-178)
testing/python/kernel/test_tilelang_kernel_gemm.py (5)
examples/gemv/example_gemv.py (1)
kernel(236-285)
tilelang/jit/adapter/ctypes/adapter.py (1)
get_kernel_source(290-296)
tilelang/jit/adapter/cython/adapter.py (1)
get_kernel_source(516-522)
tilelang/jit/kernel.py (1)
get_kernel_source(378-389)
tilelang/jit/adapter/base.py (1)
get_kernel_source(51-52)
src/target/codegen_cuda.h (1)
src/target/codegen_cuda.cc (12)
GetVecLoad(1097-1118), GetVecLoad(1097-1099), t(25-63), t(25-25), t(67-74), t(67-67), t(78-94), t(78-78), t(98-106), t(98-99), PrintVecStore(1120-1142), PrintVecStore(1120-1122)
examples/gemm_sm100/gemm_mma.py (11)
examples/gemm_sm100/gemm_tcgen5mma.py (2)
matmul(8-62), main(29-60)
testing/python/kernel/test_tilelang_kernel_gemm.py (4)
matmul(5-50), main(28-48), main(323-346), main(443-466)
tilelang/language/allocate.py (2)
alloc_shared(21-36), alloc_fragment(53-64)
tilelang/language/fill.py (1)
clear(24-48)
tilelang/language/pipeline.py (1)
Pipelined(9-46)
tilelang/language/copy.py (1)
copy(84-152)
tilelang/language/gemm.py (1)
gemm(10-212)
tilelang/jit/__init__.py (1)
compile(33-86)
tilelang/jit/kernel.py (2)
out_idx(446-447), get_profiler(360-376)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-101)
tilelang/utils/tensor.py (1)
TensorSupplyType(11-18)
src/layout/tcgen05_layout.cc (1)
src/layout/gemm_layouts.cc (2)
make_itervar(16-19), make_itervar(16-16)
tilelang/language/gemm.py (2)
tilelang/tileop/gemm/gemm_base.py (8)
C(74-75), M(33-34), N(37-38), K(41-42), policy(118-119), clear_accum(106-107), k_pack(110-111), wg_wait(114-115)
tilelang/language/tir/op.py (1)
call_intrin(119-144)
src/op/gemm.h (2)
tilelang/tileop/gemm/gemm_base.py (1)
offset_A(98-99)
src/op/gemm.cc (4)
AllowTCGEN5MMA(169-176), AllowTCGEN5MMA(169-169), AllowWGMMA(178-186), AllowWGMMA(178-178)
src/tl_templates/cuda/copy_sm100.h (2)
src/tl_templates/cuda/tcgen_05.h (3)
tl(10-60), __device__(28-30), __device__(32-34)
src/tl_templates/cuda/tcgen_05_ld.h (5)
tl(10-713), tmem_ld_32dp32bNx(13-182), tmem_ld_32dp64bNx(684-691), tmem_ld_32dp128bNx(694-701), tmem_ld_32dp256bNx(704-711)
src/target/utils.cc (1)
src/op/gemm.cc (2)
GetArchInt(476-487), GetArchInt(476-476)
src/op/gemm.cc (5)
src/op/operator.cc (2)
GetVarFromAccessPtr(75-82), GetVarFromAccessPtr(75-75)
src/target/utils.cc (8)
TargetIsSm100(56-61), TargetIsSm100(56-56), TargetGetWarpSize(127-132), TargetGetWarpSize(127-127), TargetIsHopper(49-54), TargetIsHopper(49-49), TargetIsCDNA(70-79), TargetIsCDNA(70-70)
src/op/gemm_py.cc (4)
CheckWGMMA(141-191), CheckWGMMA(141-141), GetGemmInst(95-109), GetGemmInst(95-95)
src/layout/gemm_layouts.cc (4)
makeGemmABLayoutSm100(767-787), makeGemmABLayoutSm100(767-768), make_itervar(16-19), make_itervar(16-16)
src/layout/tcgen05_layout.cc (2)
make_itervar(17-20), make_itervar(17-17)
src/layout/layout.h (2)
src/layout/gemm_layouts.cc (4)
make_itervar(16-19), make_itervar(16-16), makeGemmABLayoutSm100(767-787), makeGemmABLayoutSm100(767-768)
src/layout/tcgen05_layout.cc (2)
make_itervar(17-20), make_itervar(17-17)
src/transform/loop_vectorize.cc (2)
src/transform/loop_vectorize_dynamic.cc (18)
node(80-85), node(80-80), node(92-96), node(92-92), node(98-112), node(98-98), node(114-120), node(114-114), node(122-125), node(122-122), node(127-135), node(127-127), node(263-266), node(263-263), node(280-283), node(280-280), indices(141-194), indices(141-141)
src/target/utils.cc (2)
TargetIsSm100(56-61), TargetIsSm100(56-56)
testing/python/transform/test_tilelang_transform_layout_inference.py (1)
tilelang/transform/__init__.py (1)
LayoutInference(39-47)
src/tl_templates/cuda/tcgen_05_ld.h (2)
src/tl_templates/cuda/copy_sm100.h (2)
tl(6-134), int(77-83)
src/tl_templates/cuda/tcgen_05.h (1)
tl(10-60)
examples/gemm_sm100/gemm_tcgen5mma.py (6)
tilelang/env.py (1)
disable_cache(232-233)
tilelang/language/allocate.py (3)
alloc_tmem(92-118), alloc_barrier(80-89), alloc_fragment(53-64)
tilelang/language/copy.py (1)
copy(84-152)
tilelang/language/gemm.py (1)
gemm(10-212)
tilelang/language/builtin.py (1)
mbarrier_wait_parity(172-219)
tilelang/jit/__init__.py (1)
compile(33-86)
src/op/finalize_reducer.cc (1)
src/target/utils.cc (4)
TargetIsHopper(49-54), TargetIsHopper(49-49), TargetIsSm100(56-61), TargetIsSm100(56-56)
src/transform/pipeline_planning.cc (3)
src/transform/layout_inference.cc (4)
op(42-48), op(42-42), expr(379-391), expr(379-379)
src/transform/inject_pipeline.cc (18)
op(165-181), op(165-165), op(183-196), op(183-183), op(198-211), op(198-198), op(213-219), op(213-213), op(839-945), op(839-839), op(947-964), op(947-947), op(966-983), op(966-966), buf(485-487), buf(485-485), call(133-163), call(133-134)
tilelang/language/builtin.py (1)
mbarrier_wait_parity(172-219)
src/tl_templates/cuda/gemm_sm100.h (4)
src/tl_templates/cuda/tcgen_05.h (4)
void(12-18), void(20-26), void(62-68), tl(10-60)
src/tl_templates/cuda/common.h (1)
uint32_t(115-117)
tilelang/language/gemm.py (1)
gemm(10-212)
src/tl_templates/cuda/tcgen_05_ld.h (1)
tl(10-713)
src/transform/lower_tile_op.cc (2)
src/transform/loop_partition.cc (6)
op(128-154), op(128-128), op(157-162), op(157-157), op(164-169), op(164-164)
src/target/codegen_cuda.cc (8)
VisitStmt_(296-315), VisitStmt_(296-296), VisitStmt_(1934-1973), VisitStmt_(1934-1934), VisitStmt_(1975-2035), VisitStmt_(1975-1975), VisitStmt_(2037-2053), VisitStmt_(2037-2037)
src/op/gemm_py.cc (1)
src/op/gemm.cc (2)
GetGemmInst(188-204), GetGemmInst(188-188)
src/tl_templates/cuda/tcgen_05.h (2)
src/tl_templates/cuda/copy.h (2)
tl(14-80), void(16-18)
src/tl_templates/cuda/tcgen_05_ld.h (1)
tl(10-713)
🪛 Ruff (0.13.1)
tilelang/language/__init__.py
45-45: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
examples/gemm_sm100/gemm_tcgen5mma.py
20-20: Unused function argument: num_stages
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (16)
src/runtime/runtime.cc (2)
75-86: LGTM on newline swap
Replacing `std::endl` with `'\n'` keeps the formatting while avoiding the extra flush cost, which is consistent with our logging cleanup.
102-104: LGTM on fatal log
Retains the newline while eliminating the flush, so the fatal path still prints the descriptor dump without incurring the extra flush penalty.
src/tl_templates/cuda/tcgen_05.h (1)
36-60: Good defensive programming with static assertions
The static assertions properly validate the supported M/N dimensions for the TCGEN5MMA operations. The error messages are clear and informative.
src/transform/loop_vectorize.cc (3)
48-71: Well-structured visitor pattern implementation
The `VectorizeFindGlobalAccess` class properly encapsulates global access detection with clear visitor methods and a simple public API.
151-155: Proper handling of Cast node vectorization
The addition of `VisitExpr_` for `CastNode` correctly adjusts vector size based on the target data type width using ZeroAwareGCD.
191-191: State management change for runtime configuration
Changing `vector_load_bits_max_` from const to non-const allows runtime configuration based on target. This is a necessary change for the SM100 256-bit vectorization feature.

src/target/codegen_cuda.cc (6)
360-363: Good extension for wider FP16 vectors
The extension properly handles FP16 vectors with lanes up to 16 by using `ulonglong` types. The error messages clearly indicate the lane requirements.
508-511: Support for int8x32 vectors
Successfully adds support for 32-lane int8 vectors using the `longlong4` type, which aligns with the 256-bit wide path requirements.
1097-1142: Well-implemented wide vector load/store paths
The `GetVecLoad` and `PrintVecStore` methods properly check for 256-bit vector operations on global memory and use the appropriate TL helpers (`tl::ld_global_256`/`tl::st_global_256`). The assertion ensures only 256-bit operations use this path.
1289-1290: Good addition of fence barrier init support
Adding support for `tl::ptx_fence_barrier_init` properly extends the barrier initialization capabilities.
2157-2171: Proper handling of int8x32 broadcast
The implementation correctly handles broadcasting for both 4-lane and 32-lane int8 vectors with appropriate bit manipulation and type construction.
577-585: No int64 vectors beyond 4 lanes in the codebase, so no breakage
A search for `longlong5`-`longlong9` and `longlong10`-`longlong16` across all `.cc`, `.cpp`, and `.h` files returned no hits, indicating no prior support for int64 vectors with lanes > 4. The new error path aligns with existing capabilities.

src/tl_templates/cuda/copy_sm100.h (1)
98-132: LGTM! Well-structured template wrappers for TMEM loading.
The tcgen05_ld wrapper functions provide a clean interface for different data path widths, properly calling the fence after loading. The MAX_LOGN parameters (7, 7, 6, 5) correctly match the maximum supported N values for each variant.
src/op/gemm.cc (3)
765-823: Add tests to validate TCGEN5MMA D/E/G layout computation
No existing coverage for Layout D/E/G (src/op/gemm.cc:789–822). Cross-check the FloorDiv/FloorMod formulas against NVIDIA’s PTX “tcgen05 data-path” docs and the reference in src/layout/tcgen05_layout.cc, and add unit tests for meta.atom_m = 128, 64, 32 cases to ensure correctness.
511-576: Clarify multi-warp TCGEN5MMA execution
The lowering in src/op/gemm.cc lines 571–574 only emits the TCGEN5MMA call when
`FloorDiv(thread_var, 32) == FloorDiv(thread_bounds.min, 32)` (i.e. warp 0), skipping all other warps. We found no comments, tests, or external docs covering multi-warp invocation (the sm100 note at lines 609–612 only addresses single-thread issuance within a warp). Confirm whether your schedule intends only the first warp to issue the MMA instruction or adjust the logic to invoke it on each warp as needed.
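For reference, the device-side shape of that predicate is roughly the following; this is an illustrative CUDA sketch, not the generated code, and `block_thread_min` merely stands in for `thread_bounds.min`:

__device__ void maybe_issue_tcgen5mma(int block_thread_min) {
  // Mirrors FloorDiv(thread_var, 32) == FloorDiv(thread_bounds.min, 32):
  // only the first warp of the thread range reaches the MMA issue point.
  int warp_id = threadIdx.x / 32;
  int issuing_warp = block_thread_min / 32;
  if (warp_id == issuing_warp) {
    // The TCGEN5MMA instruction would be issued here (a single thread of
    // this warp elects itself in the real lowering).
  }
}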
169-176: AllowTCGEN5MMA correctly gates SM100 (sm_100–sm_103)
`TargetIsSm100(target)` returns true for arch 100–103, and no per-variant TCGEN5MMA restrictions were found.
  -cppcoreguidelines-pro-bounds-pointer-arithmetic,
  -cppcoreguidelines-pro-bounds-array-to-pointer-decay,
  -clang-analyzer-deadcode.DeadStores,
  -clang-analyzer-optin.cplusplus.VirtualCall,
Undo the blanket disable of the bounds and virtual-call checks
These three checks were part of the “Retained categories” for a reason—they regularly flag genuine memory-safety and lifecycle bugs across our CUDA backends. Disabling them globally trades concrete signal for convenience and will let regressions slip through (e.g., silent out-of-bounds pointer arithmetic or virtual dispatch in constructors/destructors). If there are nuisance warnings in new SM100 codepaths, please suppress them locally with targeted NOLINT or refactor the offending spots instead of removing the protections repo-wide.
- -cppcoreguidelines-pro-bounds-pointer-arithmetic,
- -cppcoreguidelines-pro-bounds-array-to-pointer-decay,
- -clang-analyzer-optin.cplusplus.VirtualCall,
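For reference, a targeted suppression on a single offending line would look roughly like this; the function below is a hypothetical illustration, not code from this PR:

#include <cstddef>

// Hypothetical SM100 helper: silence one flagged line instead of disabling
// the check for the whole repository.
static float *lane_slice(float *smem_base, int lane_id, int stride) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic)
  return smem_base + static_cast<std::ptrdiff_t>(lane_id) * stride;
}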
🤖 Prompt for AI Agents
In .clang-tidy around lines 45 to 48, the rules
cppcoreguidelines-pro-bounds-pointer-arithmetic,
cppcoreguidelines-pro-bounds-array-to-pointer-decay and
clang-analyzer-optin.cplusplus.VirtualCall were removed; restore these checks by
removing them from the negative list so they run globally, and instead address
any false positives in SM100 or other new code by adding targeted NOLINT
comments at the exact offending lines or refactoring the code to eliminate the
warning (e.g., avoid pointer-arithmetic that can overflow, use safe indexing,
and avoid virtual calls in constructors/destructors).
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
    T.copy(A[by * block_M, k * block_K], A_shared)
    T.copy(B[bx * block_N, k * block_K], B_shared)
Respect the num_stages argument
matmul advertises a tunable num_stages, but the pipelined loop hardcodes num_stages=1, so callers (including the example, which passes 0) silently get a different staging depth than requested. That breaks tuning knobs and can desynchronize the mbarrier usage. Please plumb the parameter through.
- for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
+ for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
🤖 Prompt for AI Agents
In examples/gemm_sm100/gemm_tcgen5mma.py around lines 42 to 44, the pipelined
loop currently hardcodes num_stages=1 which ignores the matmul’s tunable
num_stages (callers may pass 0 or other values) and can desynchronize mbarrier
usage; replace the hardcoded 1 with the function/local parameter that holds the
requested staging depth (e.g., num_stages) so the call reads
T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages), ensuring the variable
is in scope and propagated from the matmul signature.
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
    A_shared,
    B_shared,
    C_tmem,
    trans_A,
    trans_B,
    mbar=mbar,
    wg_wait=-1,
    clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
Fix TMEM tile loads for transpose flags
These two T.copy statements ignore trans_A/trans_B and always fetch tiles as if A were row-major and B were transposed. For trans_A=True, A_shape is (K, M) yet we still step the first dimension with by * block_M, which walks past the K extent; for trans_B=False, B_shape is (K, N) but we index its first dimension with bx * block_N. Anything outside the single (False, True) combination used in the demo will read the wrong region or fall off the tensor. The unit-test kernel in testing/python/kernel/test_tilelang_kernel_gemm.py fixes this by branching on the transpose flags—please mirror that logic here.
- T.copy(A[by * block_M, k * block_K], A_shared)
- T.copy(B[bx * block_N, k * block_K], B_shared)
+ if trans_A:
+ T.copy(A[k * block_K, by * block_M], A_shared)
+ else:
+ T.copy(A[by * block_M, k * block_K], A_shared)
+ if trans_B:
+ T.copy(B[bx * block_N, k * block_K], B_shared)
+ else:
+ T.copy(B[k * block_K, bx * block_N], B_shared)
🤖 Prompt for AI Agents
In examples/gemm_sm100/gemm_tcgen5mma.py around lines 43 to 54, the two T.copy
calls always index A and B as if A is (M,K) and B is (N,K) and ignore
trans_A/trans_B; this causes out-of-bounds or wrong tiles when transposed. Fix
by branching on trans_A and trans_B like the unit test: when trans_A is False
copy A[by * block_M, k * block_K] (tile dims M x K), but when trans_A is True
copy A[k * block_K, by * block_M] (tile dims K x M); similarly, when trans_B is
True copy B[bx * block_N, k * block_K] (tile dims N x K), else when trans_B is
False copy B[k * block_K, bx * block_N] (tile dims K x N). Mirror the
index-swapping and tile extents from
testing/python/kernel/test_tilelang_kernel_gemm.py so TMEM loads match the
declared shapes.
# Compile kernel
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,  # Required
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,  # Required
})

# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)

# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")
Fix the example to compile main rather than an undefined symbol.
Lines 86-93 invoke tilelang.compile(func, ...), but the snippet only defines main. Copying the example as written will raise a NameError. Update the call (and any downstream references) to use the defined prim func so users can execute the example successfully.
-jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
+jit_kernel = tilelang.compile(main, out_idx=[2], target="cuda", pass_configs={
🤖 Prompt for AI Agents
In examples/gemm_sm100/README.md around lines 85 to 104, the code calls
tilelang.compile(func, ...) but only defines a prim func named main; change the
compile invocation to tilelang.compile(main, ...) (and any downstream references
expecting that compiled object remain the same) so the example compiles the
defined symbol instead of the undefined name func; keep the existing
pass_configs, inputs, profiling and verification lines unchanged.
  if (mat_continuous % (vector_size * 8) == 0)
    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 4) == 0)
    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 2) == 0)
    return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                        element_size);
  else if (mat_continuous % (vector_size * 8) == 0)
  else if (mat_continuous % vector_size == 0)
    return makeGemmLayoutLinear(mat_stride, mat_continuous);
  else
    ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
              << ", continuous=" << mat_continuous
              << ", element_size=" << element_size << ", kfactor=" << kfactor;
}
Regression: Hopper layout now aborts for valid widths
For non power-of-two tiles (e.g., element_size == 8, mat_continuous == 48) we used to fall back to makeGemmABLayoutPadded, so Hopper kernels kept working. The new branch ends with ICHECK(0), which now fatals for those same shapes. That’s a correctness regression that will abort user programs. Please restore a padded/linear fallback instead of hard failing.
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
- else if (mat_continuous % vector_size == 0)
- return makeGemmLayoutLinear(mat_stride, mat_continuous);
- else
- ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
- << ", continuous=" << mat_continuous
- << ", element_size=" << element_size << ", kfactor=" << kfactor;
+ else if (mat_continuous % vector_size == 0)
+ return makeGemmLayoutLinear(mat_stride, mat_continuous);
+ return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
}
🤖 Prompt for AI Agents
In src/layout/gemm_layouts.cc around lines 752-765, the final ICHECK(0) causes
Hopper to abort for valid non-power-of-two widths (e.g., element_size==8,
mat_continuous==48); restore the previous padded/linear fallback instead of
fatally asserting by returning the padded GEMM A/B layout (same call used
before, e.g. makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size,
kfactor) or, if that API isn’t available here, fall back to
makeGemmLayoutLinear/makeGemmABLayoutPadded as appropriate) so the code returns
a valid layout for those shapes rather than calling ICHECK(0).
Stmt VisitStmt_(const BlockNode *op) final {
  auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
  if (op->annotations.count(attr::kLayoutMap)) {
    block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, layout_remap_);
  }
  return block;
}
Do not attach the global layout_remap_ to every block
layout_remap_ is populated while we traverse the entire PrimFunc, so by the time we call LayoutRemapRewriter::Substitute it contains the union of all remapped buffers. Overwriting each block’s attr::kLayoutMap with that global map means every annotated block now advertises layouts for buffers it does not allocate or match. That breaks the per-block contract that other passes (including this one, in VisitStmt_(const BlockNode*)) rely on, and it quickly leads to incorrect lowering or hard asserts once a downstream visitor iterates the annotation and cannot find the corresponding buffers in the block. Please filter the annotation so that we only write back the layouts for buffers owned by the current block.
Apply this diff to keep the annotation block-local:
auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
if (op->annotations.count(attr::kLayoutMap)) {
- block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, layout_remap_);
+ Map<Buffer, Layout> filtered;
+ for (const Buffer &buffer : block->alloc_buffers) {
+ if (layout_remap_.count(buffer)) {
+ filtered.Set(buffer, layout_remap_[buffer]);
+ }
+ }
+ for (const MatchBufferRegion &match_buffer : block->match_buffers) {
+ const Buffer &buffer = match_buffer->buffer;
+ if (layout_remap_.count(buffer)) {
+ filtered.Set(buffer, layout_remap_[buffer]);
+ }
+ }
+ ICHECK(!filtered.empty())
+ << "Expected layout remap info for block-local buffers";
+ block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, filtered);
  }
🤖 Prompt for AI Agents
In src/transform/lower_tile_op.cc around lines 94-100, the code attaches the
global layout_remap_ to every block which leaks layouts for buffers the block
doesn't own; instead build a block-local map by iterating the block's owned
buffers (alloc_buffers and match_buffers / any BlockNode fields that list
buffers the block allocates or matches), copy entries from layout_remap_ only
for those buffer vars, and write that filtered map into the block's
attr::kLayoutMap so the annotation only contains layouts for buffers owned by
the current block.
    Buffer buf = op->buffer;
    if (buf->data->type_annotation.as<PointerTypeNode>()->storage_scope ==
        "shared") {
      // We only care about shared.tmem buffers
      ICHECK(!result.defined())
          << "TmemLoadCollector: More than one shared buffer visited";
      result = buf;
    }
  }
};

/*!
 * \brief Build the dependency chain between async operations and their
 * corresponding buffers & synchronizations.
 *
 * Example:
 * If we encounter the following pattern:
 *
 *   tcgen5mma_gemm_ts(..., mbar, ...)
 *   mbarrier_wait_parity(mbar)
 *
 * The builder will link the mbarrier to the buffers used in the
 * TCGEN5MMA
 */
class AsyncDependencyChainBuilder : public StmtExprVisitor {
public:
  AsyncDependencyChainBuilder(Map<Var, Buffer> buffer_data_to_buffer)
      : buffer_data_to_buffer_(buffer_data_to_buffer) {}

  std::unordered_map<const BufferNode *, Array<BufferRegion>>
      mbar_to_buffer_reads_;

  std::unordered_map<const BufferNode *, Array<BufferRegion>>
      mbar_to_buffer_writes_;

private:
  Map<Var, Buffer> buffer_data_to_buffer_;

  void VisitExpr_(const CallNode *op) final {
    auto args = op->args;
    if (op->op.same_as(builtin::call_extern())) {
      std::string func_name_with_template = args[0].as<StringImmNode>()->value;
      std::size_t le_pos = func_name_with_template.find_first_of('<');
      std::string func_name = le_pos == std::string::npos
                                  ? func_name_with_template
                                  : func_name_with_template.substr(0, le_pos);
      if (func_name == "tl::utcmma_gemm_ts" ||
          func_name == "tl::utcmma_gemm_ss") {
        // TCGEN5MMA
        auto get_buf_from_access_ptr_call =
            [&](const PrimExpr &expr) -> Buffer {
          auto call = expr.as<CallNode>();
          ICHECK(call);
          ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
          auto var = call->args[1].as<VarNode>();
          ICHECK(var);
          auto it = buffer_data_to_buffer_.find(GetRef<Var>(var));
          ICHECK(it != buffer_data_to_buffer_.end());
          return (*it).second;
        };
        Buffer a_buf = get_buf_from_access_ptr_call(args[1]);
        Buffer b_buf = get_buf_from_access_ptr_call(args[2]);
        Buffer mbar_buf = get_buf_from_access_ptr_call(args[4]);

        TmemLoadCollector tmem_collector;
        tmem_collector(args[3]);
        ICHECK(tmem_collector.result.defined())
            << "TmemLoadCollector: No tmem buffer load found in the TCGEN5MMA "
               "call";
        Buffer c_buf = tmem_collector.result;

        PrimExpr clear_accum = args[5];
Fix TMEM detection to avoid crashes and false negatives.
ptr_type is dereferenced before we confirm it exists, so a missing type annotation segfaults before the ICHECK fires. We also look for storage scope "shared", but TCGEN5 writes land in "shared.tmem", which leaves result undefined and trips the later ICHECK. Please guard the pointer first and match the correct TMEM scope.
Apply this diff:
- const auto *ptr_type =
- buf->data->type_annotation.as<PointerTypeNode>();
- auto storage_scope = ptr_type->storage_scope;
- ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType";
- if (storage_scope == "shared") {
+ const auto *ptr_type =
+ buf->data->type_annotation.as<PointerTypeNode>();
+ ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType";
+ if (ptr_type->storage_scope == "shared.tmem") {
// We only care about shared.tmem buffers
ICHECK(!result.defined())
- << "TmemLoadCollector: More than one shared buffer visited";
+ << "TmemLoadCollector: More than one shared.tmem buffer visited";
result = buf;
        }
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/transform/pipeline_planning.cc around lines 47 to 118, the
TmemLoadCollector dereferences the pointer-type annotation without null-checking
and compares storage_scope to "shared" (causing false negatives for
"shared.tmem" and crashes if type annotation is missing). Change the logic to
first verify the type annotation and that as<PointerTypeNode>() returns non-null
before accessing storage_scope, then match the TMEM scope against "shared.tmem"
(or use a starts_with/substring check for scopes that begin with "shared.tmem")
and only then set result; keep the existing ICHECK to ensure only one
shared.tmem buffer is collected.
def alloc_tmem(shape, dtype):
    """
    Allocate a Tensor Memory (TMEM) buffer for use with 5th generation Tensor Core operations (e.g., TCGEN5.MMA).
    TMEM is a dedicated on-chip memory introduced in Hopper GPUs, designed to reduce register pressure and enable asynchronous, single-threaded MMA operations. It is organized as a 2D array of 512 columns by 128 rows (lanes), with each cell being 32 bits. Allocation is performed in units of columns, and every lane of a column is allocated together.
    Key properties and requirements:
    - The number of columns allocated must be a power of 2 and at least 32.
    - TMEM allocations are dynamic and must be explicitly deallocated.
    - Both allocation and deallocation must be performed by the same warp.
    - The base address of the TMEM allocation is stored in shared memory and used as the offset for TCGEN5.MMA accumulator tensors.
    - Only TCGEN5.MMA and specific TMEM load/store instructions can access TMEM; all pre-processing must occur before data is loaded into TMEM, and all post-processing after data is retrieved.
    - The number of columns allocated should not increase between any two allocations in the execution order within the CTA.
    Args:
        num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512.
    Returns:
        T.Buffer: A TVM buffer object allocated in TMEM scope, suitable for use as an accumulator or operand in TCGEN5.MMA operations.
    Note:
        - TMEM is only available on supported architectures (e.g., Hopper and later).
        - The buffer returned should be used according to TMEM access restrictions and deallocated appropriately.
    """

    assert len(shape) == 2, "shape must be a 2D tensor for TMEM allocation"
    return T.alloc_buffer(shape, dtype, scope="shared.tmem")
Fix the docstring arg description.
The docstring still references a num_cols parameter, but the function actually accepts (shape, dtype). This mismatch is misleading for users reading the API docs—please update the text (or the signature) so they line up.
🤖 Prompt for AI Agents
In tilelang/language/allocate.py around lines 92 to 118, the docstring
incorrectly documents a non-existent num_cols parameter while the function
signature is alloc_tmem(shape, dtype); update the docstring to describe the
actual parameters (shape: 2-tuple for rows/columns and dtype: element type),
remove or replace references to num_cols, and adjust the description of TMEM
allocation behavior to explain that shape specifies the 2D tensor dimensions
(including any constraints on columns expressed in terms of shape[1] if needed);
keep the usage notes about TMEM scope and deallocation but ensure all parameter
names and examples match the function signature.
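As a usage illustration of the allocator discussed above, here is a minimal sketch; it assumes alloc_tmem is re-exported through the T namespace like the other allocators and that the kernel targets an SM100-class device.

import tilelang.language as T

# Minimal sketch (assumption: T.alloc_tmem is exposed alongside T.alloc_shared).
@T.prim_func
def kernel(A: T.Tensor((128, 64), "float16")):
    with T.Kernel(1, threads=128) as bx:
        A_shared = T.alloc_shared((128, 64), "float16")
        # Accumulator lives in the "shared.tmem" scope; the column count must
        # respect the power-of-two / >= 32 constraints listed in the docstring.
        C_tmem = T.alloc_tmem((128, 128), "float32")
        T.copy(A, A_shared)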
    C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0]
    return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), Aptr, Bptr, Cptr, transpose_A,
                           transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
                           offset_b, k_pack, wg_wait, mbarptr, C_coords[0], C_coords[1])
Handle BufferLoad when computing C tile coordinates
gemm already accepts tir.BufferLoad for C (see the existing retrieve_ptr / retrieve_shape helpers), but the new C_coords branch only extracts offsets for tir.BufferRegion. For BufferLoad call sites—which we use when feeding tile slices of C—this now hardcodes [0, 0], so the backend sees every tile at the origin and the TCGEN5MMA barrier wiring becomes wrong. Please plumb the real minima for BufferLoad cases as well.
-    C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0]
+    if isinstance(C, tir.BufferRegion):
+        C_coords = [r.min for r in C.region]
+    elif isinstance(C, tir.BufferLoad):
+        region = get_buffer_region_from_load(C).region
+        C_coords = [r.min for r in region]
+    else:
+        C_coords = [0, 0]

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if isinstance(C, tir.BufferRegion):
        C_coords = [r.min for r in C.region]
    elif isinstance(C, tir.BufferLoad):
        region = get_buffer_region_from_load(C).region
        C_coords = [r.min for r in region]
    else:
        C_coords = [0, 0]
    return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), Aptr, Bptr, Cptr, transpose_A,
                           transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
                           offset_b, k_pack, wg_wait, mbarptr, C_coords[0], C_coords[1])
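To make the failure mode concrete, the following TVM-level sketch (hypothetical shape and tile offsets) shows how the tile origin is represented in each call form: a BufferRegion exposes it through the range minima, while a BufferLoad carries the same origin in its indices, so defaulting to [0, 0] silently drops the tile position.

from tvm import tir
from tvm.ir import Range

# Hypothetical 256x256 accumulator; we look at the tile starting at (128, 0).
C = tir.decl_buffer((256, 256), "float32", name="C")

# BufferRegion form: the origin is the `min` of each range.
region = tir.BufferRegion(C, [Range.from_min_extent(128, 128),
                              Range.from_min_extent(0, 128)])
print([r.min for r in region.region])  # tile origin (128, 0)

# BufferLoad form: the same origin arrives as the load indices instead.
load = tir.BufferLoad(C, [128, 0])
print(list(load.indices))              # tile origin (128, 0)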
        target = tvm.target.Target.current(allow_none=True)
        if target is not None:
            return target
Preserve return_object contract when reusing the current TVM target
determine_target() historically returned a string backend key whenever return_object is left as the default False. With this early return we now hand back a tvm.target.Target instance whenever a target context is active, bypassing the return_object handling below. Any caller that compares the result to literals like "cuda"/"hip" (there are several today) will start failing as soon as a target context is installed. Please keep the old contract: only surface the Target instance when the caller explicitly requests it, and otherwise downgrade to the target kind string.
-    if target == "auto":
-        target = tvm.target.Target.current(allow_none=True)
-        if target is not None:
-            return target
+    if target == "auto":
+        current_target = tvm.target.Target.current(allow_none=True)
+        if current_target is not None:
+            if return_object:
+                return current_target
+            return current_target.kind.name

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if target == "auto":
        current_target = tvm.target.Target.current(allow_none=True)
        if current_target is not None:
            if return_object:
                return current_target
            return current_target.kind.name
🤖 Prompt for AI Agents
In tilelang/utils/target.py around lines 65-67, the early return hands back a
tvm.target.Target instance unconditionally when a target context is active,
breaking callers that expect a backend key string when return_object is False;
change the logic to respect the return_object parameter: if target is not None
and return_object is True return the Target instance, otherwise when
return_object is False return a string backend key (use target.kind if
available, falling back to str(target)), and preserve existing behavior with
allow_none=True (i.e., return None when no current target).
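For reference, here is a small sketch of the contract the comment asks to preserve once the suggested change is in place; it assumes the signature determine_target(target="auto", return_object=False).

import tvm
from tilelang.utils.target import determine_target

# Inside an active target context, callers that left return_object=False
# should still receive the backend key string they compare against "cuda"/"hip".
with tvm.target.Target("cuda"):
    backend = determine_target()
    assert backend == "cuda"
    # Only an explicit opt-in should surface the Target object itself.
    tgt = determine_target(return_object=True)
    assert isinstance(tgt, tvm.target.Target)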
Thank you for this great project. Our team (from IC) has extended the SM90 implementation with SM100-related features, such as support for tcgen05-related components and ldg256. We also fixed a bug introduced by a previous commit.
Summary by CodeRabbit
New Features
Documentation
Tests