
Conversation

Contributor

@Hamerlate Hamerlate commented Sep 28, 2025

Thank you for this great project. Our team (from IC) has extended the SM90 implementation with SM100-related features, such as support for tcgen05-related components and ldg256. We also fixed a bug from previous commits.

Summary by CodeRabbit

  • New Features

    • Preview support for NVIDIA SM100, including TCGEN5MMA, TMEM workflows, and 256-bit global vector loads/stores.
    • Expanded FP8/BF16 vectorization and CUDA codegen paths.
    • New API: allocate TMEM buffers; GEMM supports an optional synchronization barrier.
    • New transform pass to lower shared TMEM; pass config to disable 256-bit vectorization.
    • New GEMM examples and profiling utilities.
  • Documentation

    • Added README covering SM100 usage, setup, and limitations.
  • Tests

    • Updated tests to use target contexts (CUDA/WebGPU/CPU); added kernel source printing for debugging.
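As a concrete illustration of the target-context change noted in the Tests bullet above, here is a minimal sketch; it assumes a CUDA toolchain, and the helper name and parameters are illustrative rather than copied from the test suite.

import tvm
import tilelang

def compile_under_cuda_target(prim_func):
    # The updated tests wrap compilation in an explicit TVM target context.
    # With the target-resolution change in tilelang/utils/target.py, the
    # default "auto" target now honors the currently entered target instead
    # of probing the local machine.
    with tvm.target.Target("cuda"):
        kernel = tilelang.compile(prim_func)
        # The updated tests also dump the generated kernel source for debugging.
        print(kernel.get_kernel_source())
    return kernel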

* update sm100 related utcmma, tmem, ld/st256 in src
* update sm100 related utcmma, tmem, ld/st256 in tilelang
* Remove deprecated GEMM examples and related README documentation for SM100 architecture support
* Update GEMM implementation to replace UTCMMA with TCGEN5MMA across relevant files
* Remove gemm_umma.py example and update README to reflect TCGEN5MMA terminology changes
* Update README.md for gemm_sm100 example by removing outdated API sections and streamlining documentation
* Update README and source files to reflect TCGEN5.MMA terminology changes
* Refactor CUDA GEMM header for improved readability
Contributor

coderabbitai bot commented Sep 28, 2025

Walkthrough

Adds SM100/TCGEN5MMA and TMEM support across layouts, CUDA templates, codegen, and transforms. Introduces new builtins, target utilities, vectorization config, and lowering passes (LowerSharedTmem, pipeline planning updates). Extends GEMM with TCGEN5MMA path and WarpPolicy API. Adds examples, tests, and Python API entries and gating for Hopper/SM100 behaviors.

Changes

Cohort / File(s) Summary
Config
./.clang-tidy
Disable three additional clang-tidy checks.
Examples (SM100 GEMM)
examples/gemm_sm100/*
New README and two Python examples demonstrating SM100 matmul via MMA and TCGEN5MMA, with correctness and profiling.
Layouts
src/layout/layout.h, src/layout/gemm_layouts.cc, src/layout/tcgen05_layout.{h,cc}
Public make_itervar; new makeGemmABLayoutSm100; expanded Hopper layout selection; new Tcgen05 meta factories and layout expansion.
Builtins and attributes
src/op/builtin.{h,cc}, src/op/fill.cc
Add the kDisableVectorize256 pass config, the ptx_fence_barrier_init intrinsic, and TMEM init/deallocate intrinsics; include tcgen05_layout in fill.
Reducers (Hopper/Sm100 path)
src/op/finalize_reducer.cc, src/op/reduce.cc
Route Sm100 through Hopper AllReduce path.
GEMM core and bindings
src/op/gemm.{h,cc}, src/op/gemm_py.{h,cc}, src/op/gemm_sp.cc
Add GemmInst enum, TCGEN5MMA capability checks, extended lowering/layout paths, WarpPolicy API change to take GemmInst; update call sites.
CUDA codegen interface
src/target/codegen_cuda.{h,cc}
Add GetVecLoad/PrintVecStore for 256-bit global loads/stores; broaden type/lane handling, casting, broadcast packing; handle fence_barrier_init call.
CUDA minor
src/target/codegen_cpp.cc
Remove unused variable.
Runtime logging
src/runtime/runtime.cc
Replace std::endl with '\n'.
Target utilities
src/target/utils.{h,cc}
Add TargetIsSm100 and TargetHasTmem.
TL CUDA templates: copy/tcgen05/fp8/gemm
src/tl_templates/cuda/copy.h, .../copy_sm100.h, .../tcgen_05.h, .../tcgen_05_ld.h, .../cuda_fp8.h, .../gemm.h, .../gemm_sm100.h, .../debug.h
Add SM100 includes; new 256-bit ld/st, packing, TMEM load helpers, UMMA/tcgen05 device APIs, FP8 32-lane types, SM100 TCGEN5MMA GEMM templates.
Transforms: vectorize/tmem/tile-op/pipeline
src/transform/loop_vectorize.cc, src/transform/lower_shared_tmem.cc, src/transform/lower_tile_op.cc, src/transform/pipeline_planning.cc
Add 256-bit vectorize planning with global-access check; implement LowerSharedTmem pass (init/dealloc, remaps); add layout remap rewriter; integrate async dependency chain for TMEM/barrier.
Testing updates
testing/python/*
Use Target context managers; add kernel source prints; set webgpu target in context.
Python: target/NVCC helpers
tilelang/contrib/nvcc.py, tilelang/utils/target.py
Add is_hopper; improve auto target by honoring current TVM target.
Python: engine pipeline
tilelang/engine/phase.py
Gate WgmmaSync rewrite on Hopper; insert LowerSharedTmem after LowerSharedBarrier.
Python: language API
tilelang/language/__init__.py, tilelang/language/allocate.py, tilelang/language/gemm.py
Export and implement alloc_tmem (shared.tmem); gemm gains optional mbar and forwards mbarptr/C coords.
Python: transform API/config
tilelang/transform/__init__.py, tilelang/transform/pass_config.py
Expose FrontendLegalize and LowerSharedTmem; add TL_DISABLE_VECTORIZE_256 key.
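To ground the Python-facing rows above (alloc_tmem, the optional mbar argument to gemm, and the new TL_DISABLE_VECTORIZE_256 key), the sketch below condenses the examples/gemm_sm100 and README snippets quoted later in this review; the allocation signatures, tile shapes, and thread count are illustrative and may differ from the final API in detail.

import tilelang
import tilelang.language as T

M = N = K = 4096
block_M, block_N, block_K = 128, 256, 64
dtype, accum_dtype = "bfloat16", "float32"

@T.prim_func
def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_N, block_K), dtype)
        # New in this PR: the accumulator lives in tensor memory (shared.tmem)
        # and TCGEN5MMA completion is signalled through an mbarrier.
        C_tmem = T.alloc_tmem((block_M, block_N), accum_dtype)  # signature illustrative
        mbar = T.alloc_barrier(1)                               # signature illustrative
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[bx * block_N, k * block_K], B_shared)
            # A not transposed, B stored as (N, K); mbar signals MMA completion.
            T.gemm(A_shared, B_shared, C_tmem, False, True,
                   mbar=mbar, wg_wait=-1, clear_accum=k == 0)
            T.mbarrier_wait_parity(mbar, k % 2)

        # Move the TMEM accumulator back through a fragment to global memory.
        T.copy(C_tmem, C_local)
        T.copy(C_local, C[by * block_M, bx * block_N])

# The README marks these pass configs as required for the TCGEN5MMA path;
# TL_DISABLE_VECTORIZE_256 is the new opt-out for 256-bit vectorization.
jit_kernel = tilelang.compile(main, out_idx=[2], target="cuda", pass_configs={
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})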

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant TileLang as TileLang Python
  participant Lowering as Lowering Passes
  participant Codegen as CUDA Codegen
  participant Device as GPU (SM100)

  User->>TileLang: tl.gemm(..., mbar=opt, alloc_tmem(...))
  TileLang->>Lowering: OptimizeForTarget(target)
  Note over Lowering: If Hopper: RewriteWgmmaSync<br/>Always: LowerSharedBarrier -> LowerSharedTmem
  Lowering->>Lowering: LowerSharedTmem (init/dealloc TMEM, remap)
  Lowering->>Lowering: Pipeline planning (async dep chain)
  Lowering->>Lowering: GEMM InferLayout (TCGEN5MMA/WGMMA/MMA)
  Lowering->>Codegen: Emit CUDA (vec ld/st up to 256-bit)
  Codegen-->>Device: Kernel launch + TMEM ops + barriers
  Device-->>User: Results
sequenceDiagram
  autonumber
  participant GEMM as GemmNode
  participant Policy as WarpPolicy
  participant Target as Target Utils
  participant Layout as Layout Inference

  GEMM->>Target: Get target features (Sm100/Hopper, HasTmem)
  GEMM->>GEMM: GetGemmInst (kTCGEN5MMA/kWGMMA/kMMA)
  GEMM->>Policy: ComputeWarpPartition(M,N,block_size,target,gemm_inst)
  GEMM->>Layout: Infer A/B/C layouts (incl. SM100 Swizzles)
  GEMM-->>GEMM: Lower specialized path (TCGEN5MMA uses mbar/tmem)

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~150 minutes

Possibly related PRs

Suggested labels

enhancement

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

In tunnels of TMEM, I stash my hay,
256-bit hops, I bound and sway.
New warrens for GEMM, SM100’s gleam,
Barriers fenced—syncs in a stream.
I twitch my nose at layouts grand,
TCGEN5MMA, paw-in-hand.
Thump! The kernels race across the land. 🐇⚙️

Pre-merge checks

❌ Failed checks (2 warnings)
  • Title Check ⚠️ Warning: The provided title is overly generic and does not capture the primary focus of the extensive changes, which include adding SM100 TCGEN5MMA and 256-bit vector load/store support, new TileLang shared TMEM lowering passes, multiple new API functions, and layout extensions for SM100; it merely mentions a bugfix and “enhance cuda backend” without reflecting the scope or key features introduced. Resolution: please choose a concise, descriptive title that highlights the core change—e.g. “Add SM100 TCGEN5MMA and 256-bit vector support to CUDA backend and fix static linkage bug”—so reviewers can immediately understand the main enhancements.
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 20.48%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 25

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
src/tl_templates/cuda/gemm.h (1)

3-17: Revert to __CUDA_ARCH__-based guards to avoid pulling the wrong header.

__CUDA_ARCH_LIST__ expands to a comma-separated list of all targets on the command line (or host-side preprocessing), not the architecture currently being compiled. Comparing it to a number collapses to the last element via the comma operator, i.e., the highest arch in the list, so a multi-arch build (sm75 + sm100) will include gemm_sm100.h even when compiling the sm75 pass. That will pull in instructions unsupported on sm75 and can break codegen. Please keep the per-phase checks on __CUDA_ARCH__ (and, if you need a host-side path, add a separate helper that parses __CUDA_ARCH_LIST__ instead of using it directly in these comparisons). (docs.nvidia.com)

Apply this diff to restore the original guard:

-#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200))
 #include "gemm_sm120.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1000))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
 #include "gemm_sm100.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 #include "gemm_sm90.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890))
 #include "gemm_sm89.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
 #include "gemm_sm80.h"
-#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 700))
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700))
 #include "gemm_sm70.h"
 #else
 // No matching architecture found
 #endif
src/op/gemm_py.cc (1)

95-109: TCGEN5 path never selected in Python GEMM lowering

GemmPyNode::GetGemmInst still only returns {kWGMMA, kMFMA, kMMA}. In src/op/gemm.cc the native lowering path was updated to return GemmInst::kTCGEN5MMA when AllowTCGEN5MMA(target) succeeds, but the Python binding never produces that enum. As a result, on SM100 targets we always fall back to kMMA, so the new TCGEN5 warp policy and lowering you just added remain unreachable whenever kernels are created through tilelang.op.gemm. Please align this helper with the C++ implementation so it can emit kTCGEN5MMA (e.g., reuse the shared AllowTCGEN5MMA logic) and unlock the intended SM100 flow.

Apply this adjustment:

 GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
   int warp_size = TargetGetWarpSize(target);
   int num_warps = block_size / warp_size;
-  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
-                     (num_warps % 4 == 0) && CheckWGMMA();
-  if (allow_wgmma) {
+  bool allow_tcgen5mma = AllowTCGEN5MMA(target);
+  bool allow_wgmma = AllowWGMMA(block_size, target);
+  if (allow_tcgen5mma) {
+    return GemmInst::kTCGEN5MMA;
+  } else if (allow_wgmma) {
     return GemmInst::kWGMMA;
src/op/gemm.h (1)

115-182: Account for the new GEMM metadata in structural equality/hashing.

We now carry mbarptr, mbar, and C_coords on GemmNode, but SEqualReduce/SHashReduce still ignore them. That means two nodes that differ only in their barrier pointer or TCGEN5 MMA coordinates collapse to the same key, letting memoized rewrites and caches return the wrong instance. Please include these fields in both equality and hashing.

Apply this diff:

@@
-           equal(offset_A, other->offset_A) &&
-           equal(offset_B, other->offset_B) &&
-           equal(clear_accum, other->clear_accum) &&
-           equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
-           equal(policy, other->policy);
+           equal(offset_A, other->offset_A) &&
+           equal(offset_B, other->offset_B) &&
+           equal(clear_accum, other->clear_accum) &&
+           equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
+           equal(mbarptr, other->mbarptr) && equal(mbar, other->mbar) &&
+           equal(C_coords, other->C_coords) && equal(policy, other->policy);
@@
-    hash_reduce(clear_accum);
-    hash_reduce(kPack);
-    hash_reduce(wg_wait);
-    hash_reduce(policy);
+    hash_reduce(clear_accum);
+    hash_reduce(kPack);
+    hash_reduce(wg_wait);
+    hash_reduce(mbarptr);
+    hash_reduce(mbar);
+    hash_reduce(C_coords);
+    hash_reduce(policy);
🧹 Nitpick comments (11)
src/layout/tcgen05_layout.h (1)

7-31: Add the standard library includes this header relies on.

This header uses std::string and std::tuple, but it doesn't include <string> or <tuple>. Please include them here so the header stays self-contained instead of depending on transitive includes from other headers.

testing/python/kernel/test_tilelang_kernel_gemm.py (1)

85-85: Remove unconditional kernel-source print from the test.

Dumping every generated kernel into STDOUT makes the test logs huge (each invocation prints hundreds of lines, and this test runs a dozen variants), which slows CI and buries real failures. Please drop the print or guard it behind an explicit debug flag/environment check before merging.
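
If the dump is useful during development, one option is to gate it behind an environment variable. A minimal sketch (the variable name is illustrative, and kernel stands for the compiled JIT kernel already present in the test):

import os

# Print the generated CUDA source only when explicitly requested, e.g.
#   TILELANG_PRINT_KERNEL=1 python testing/python/kernel/test_tilelang_kernel_gemm.py
if os.environ.get("TILELANG_PRINT_KERNEL"):
    print(kernel.get_kernel_source())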

src/op/gemm.cc (1)

21-86: Consider adding template specializations for common atom configurations.

The GetTCGEN5MMAMeta function has repetitive patterns for atom_m (128, 64, 32) that could benefit from template specialization or a more data-driven approach to reduce code duplication. Additionally, the function returns a pair where the boolean and struct are always consistent - consider returning std::optional<TCGEN5MMAMeta> instead.

Apply this refactor to simplify the meta computation:

-// Return {is_success, meta}
-static inline std::pair<bool, TCGEN5MMAMeta>
-GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
-// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
-#define FAIL                                                                   \
-  return {                                                                     \
-    false, TCGEN5MMAMeta { 0, 0, 0 }                                           \
-  }
-#define SUCCESS(atom_m, atom_n, atom_k)                                        \
-  return {                                                                     \
-    true, TCGEN5MMAMeta { atom_m, atom_n, atom_k }                             \
-  }
+static inline std::optional<TCGEN5MMAMeta>
+GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
+  // TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
   std::vector<int> ws_valid_atom_ns = {256, 128, 64};
+  
+  int required_k_alignment = 0;
   if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
       (c_dtype.is_float() && c_dtype.bits() == 32)) {
-    if (K % 16 != 0)
-      FAIL;
-    if (M % 128 == 0) {
-      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
-        if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 16);
-      FAIL;
-    } else if (M % 64 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 16);
-      FAIL;
-    } else if (M % 32 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 16);
-      FAIL;
-    } else {
-      FAIL;
-    }
+    required_k_alignment = 16;
   } else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
              (c_dtype.is_float() && c_dtype.bits() == 32)) {
-    if (K % 32 != 0)
-      FAIL;
-    if (M % 128 == 0) {
-      for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
-        if (N % atom_n == 0)
-          SUCCESS(128, atom_n, 32);
-      FAIL;
-    } else if (M % 64 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(64, atom_n, 32);
-      FAIL;
-    } else if (M % 32 == 0) {
-      for (int atom_n : ws_valid_atom_ns)
-        if (N % atom_n == 0)
-          SUCCESS(32, atom_n, 32);
-      FAIL;
-    } else {
-      FAIL;
-    }
+    required_k_alignment = 32;
+  } else {
+    return std::nullopt;
   }
-  FAIL;
-#undef FAIL
-#undef SUCCESS
+  
+  if (K % required_k_alignment != 0)
+    return std::nullopt;
+  
+  // Try atom_m values in descending order
+  for (int atom_m : {128, 64, 32}) {
+    if (M % atom_m != 0) continue;
+    
+    if (atom_m == 128) {
+      // For atom_m=128, try all atom_n from 256 down to 16
+      for (int atom_n = 256; atom_n >= 16; atom_n -= 16) {
+        if (N % atom_n == 0)
+          return TCGEN5MMAMeta{atom_m, atom_n, required_k_alignment};
+      }
+    } else {
+      // For atom_m=64 and 32, use predefined valid atom_n values
+      for (int atom_n : ws_valid_atom_ns) {
+        if (N % atom_n == 0)
+          return TCGEN5MMAMeta{atom_m, atom_n, required_k_alignment};
+      }
+    }
+  }
+  
+  return std::nullopt;
 }
src/tl_templates/cuda/tcgen_05_ld.h (3)

17-19: Power-of-2 validation can be simplified.

The static assertion (N & (N - 1)) == 0 is a good check for power of 2, but consider using a more readable helper or standard library function if available.

Consider extracting the power-of-2 check into a constexpr helper for better readability:

+  template<int N>
+  static constexpr bool is_power_of_two() {
+    return N > 0 && (N & (N - 1)) == 0;
+  }
+
   template <int N>
   static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) {
-    static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128,
+    static_assert(is_power_of_two<N>() && N <= 128,
                   "N must be a power of 2 and lies between 1 ~ 128");

179-180: Trap on invalid template parameters lacks diagnostic information.

Using asm volatile("trap") for the else case provides no diagnostic information. While this should never be reached due to static_assert, having a more informative error would help debugging.

Consider adding a compile-time error message:

     } else {
-      asm volatile("trap");
+      static_assert(N <= 128 && (N & (N - 1)) == 0, "Invalid N value for tmem_ld_32dp32bNx::copy");
+      __builtin_unreachable();
     }

683-711: 32dp wrapper classes use magic number for address offset.

The 32-datapath wrapper classes use (16 << 16) as an address offset, which appears to be encoding datapath lane information in the upper bits. This magic number should be documented or defined as a named constant.

Define the magic number as a named constant with documentation:

+// TMEM address encoding: bits [31:16] represent the datapath lane offset
+// For 32dp operations, we need to access both 16-lane groups
+constexpr uint32_t TMEM_LANE_OFFSET = 16 << 16;
+
 // 32 data path lanes, 64-bit pattern, repeated N times
 // (conducted with 2x16dp64bNx)
 class tmem_ld_32dp64bNx {
 public:
   template <int N>
   static TL_DEVICE void copy(uint32_t const &src_addr, uint32_t *dst_ptr) {
     tmem_ld_16dp64bNx::copy<N>(src_addr, dst_ptr);
-    tmem_ld_16dp64bNx::copy<N>(src_addr + (16 << 16), dst_ptr + N);
+    tmem_ld_16dp64bNx::copy<N>(src_addr + TMEM_LANE_OFFSET, dst_ptr + N);
   }
 };
src/tl_templates/cuda/copy_sm100.h (1)

76-83: Recursive template instantiation depth could be problematic.

The get_floor_log2 template uses unbounded recursion which could hit template instantiation depth limits for large N values. While the current use cases seem bounded, consider adding a depth limit check.

Add a recursion depth limit:

 template <int N, int K = 0>
 __device__ __forceinline__ constexpr int get_floor_log2() {
   static_assert(N > 0);
+  static_assert(K < 32, "Recursion depth exceeded in get_floor_log2");
   if constexpr ((1 << (K + 1)) > N)
     return K;
   else
     return get_floor_log2<N, K + 1>();
 }
src/tl_templates/cuda/gemm_sm100.h (4)

160-163: Remove redundant static assertion

Lines 162-163 duplicate the validation already performed at lines 149-151. The first assertion is more informative as it names the specific type.

Apply this diff to remove the redundant assertion:

   using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;

-  static_assert(sizeof_bits_v<ValTypeA> <= sizeof_bits_v<uint8_t> &&
-                sizeof_bits_v<ValTypeB> <= sizeof_bits_v<uint8_t>);
-
   // Logical shape-K is always 256bits, transform to units of elements
   constexpr static int K = 32;

325-330: Address the TODO comment

The TODO comment indicates that the implementation is using the .ws variant as a workaround. This should be tracked for future optimization.

The TODO comment suggests there might be a performance impact. Would you like me to create an issue to track implementing proper support for the non-.ws variant when M == 64, or would you prefer to document the performance implications?


365-370: Reminder: Implement gemm_ts function

The TODO comment indicates a missing gemm_ts implementation. Please ensure this is tracked for completion.

Do you want me to generate the gemm_ts implementation or open a new issue to track this task?


248-253: Inconsistent specialization pattern

For fp8 types with M==128, the code uses MMA_Traits<SM100_MMA_F8F6F4_SS, ...> wrapper, while for bf16/half types with M==128, it directly uses the MMA type (e.g., SM100_MMA_F16BF16_SS). Consider using a consistent pattern.

Either wrap all M==128 cases in MMA_Traits or use the direct type for all. The current mixed approach may confuse maintainers.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c382dcb and fa65e4b.

📒 Files selected for processing (49)
  • .clang-tidy (1 hunks)
  • examples/gemm_sm100/README.md (1 hunks)
  • examples/gemm_sm100/gemm_mma.py (1 hunks)
  • examples/gemm_sm100/gemm_tcgen5mma.py (1 hunks)
  • src/layout/gemm_layouts.cc (2 hunks)
  • src/layout/layout.h (2 hunks)
  • src/layout/tcgen05_layout.cc (1 hunks)
  • src/layout/tcgen05_layout.h (1 hunks)
  • src/op/builtin.cc (3 hunks)
  • src/op/builtin.h (2 hunks)
  • src/op/fill.cc (1 hunks)
  • src/op/finalize_reducer.cc (1 hunks)
  • src/op/gemm.cc (11 hunks)
  • src/op/gemm.h (5 hunks)
  • src/op/gemm_py.cc (2 hunks)
  • src/op/gemm_py.h (0 hunks)
  • src/op/gemm_sp.cc (1 hunks)
  • src/op/reduce.cc (1 hunks)
  • src/runtime/runtime.cc (3 hunks)
  • src/target/codegen_cpp.cc (0 hunks)
  • src/target/codegen_cuda.cc (13 hunks)
  • src/target/codegen_cuda.h (1 hunks)
  • src/target/utils.cc (2 hunks)
  • src/target/utils.h (1 hunks)
  • src/tl_templates/cuda/copy.h (1 hunks)
  • src/tl_templates/cuda/copy_sm100.h (1 hunks)
  • src/tl_templates/cuda/cuda_fp8.h (3 hunks)
  • src/tl_templates/cuda/debug.h (2 hunks)
  • src/tl_templates/cuda/gemm.h (2 hunks)
  • src/tl_templates/cuda/gemm_sm100.h (1 hunks)
  • src/tl_templates/cuda/tcgen_05.h (1 hunks)
  • src/tl_templates/cuda/tcgen_05_ld.h (1 hunks)
  • src/transform/loop_vectorize.cc (4 hunks)
  • src/transform/lower_shared_tmem.cc (1 hunks)
  • src/transform/lower_tile_op.cc (6 hunks)
  • src/transform/pipeline_planning.cc (7 hunks)
  • testing/python/cpu/test_tilelang_cpu_gemm.py (2 hunks)
  • testing/python/kernel/test_tilelang_kernel_gemm.py (1 hunks)
  • testing/python/transform/test_tilelang_transform_layout_inference.py (1 hunks)
  • testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py (1 hunks)
  • testing/python/webgpu/test_webgpu_codegen.py (1 hunks)
  • tilelang/contrib/nvcc.py (1 hunks)
  • tilelang/engine/phase.py (3 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/allocate.py (1 hunks)
  • tilelang/language/gemm.py (5 hunks)
  • tilelang/transform/__init__.py (2 hunks)
  • tilelang/transform/pass_config.py (1 hunks)
  • tilelang/utils/target.py (1 hunks)
💤 Files with no reviewable changes (2)
  • src/target/codegen_cpp.cc
  • src/op/gemm_py.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-12T09:47:46.474Z
Learnt from: kurisu6912
PR: tile-ai/tilelang#794
File: tilelang/transform/add_bufstore_wrapper.py:30-33
Timestamp: 2025-09-12T09:47:46.474Z
Learning: In TVM's PyStmtExprMutator, visit_block_ methods typically call super().visit_block_(op) to process child nodes and update internal state, but return the original op when the block itself doesn't need transformation. The pattern `return op` is correct for blocks that serve as containers where mutations happen at deeper levels.

Applied to files:

  • src/transform/pipeline_planning.cc
🧬 Code graph analysis (32)
tilelang/language/allocate.py (1)
tilelang/language/ast/ir.py (1)
  • alloc_buffer (441-508)
src/target/utils.h (1)
src/target/utils.cc (4)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
  • TargetHasTmem (114-118)
  • TargetHasTmem (114-114)
tilelang/language/__init__.py (1)
tilelang/language/allocate.py (1)
  • alloc_tmem (92-118)
tilelang/transform/__init__.py (1)
src/transform/lower_shared_tmem.cc (4)
  • LowerSharedTmem (285-291)
  • LowerSharedTmem (285-285)
  • LowerSharedTmem (296-301)
  • LowerSharedTmem (296-296)
testing/python/transform/test_tilelang_transform_legalize_vectorized_loop.py (1)
tilelang/transform/__init__.py (1)
  • LegalizeVectorizedLoop (241-249)
testing/python/cpu/test_tilelang_cpu_gemm.py (2)
tilelang/engine/lower.py (1)
  • lower (190-242)
tilelang/jit/__init__.py (1)
  • compile (33-86)
testing/python/webgpu/test_webgpu_codegen.py (1)
tilelang/language/ast/ir.py (1)
  • target (1682-1713)
src/transform/lower_shared_tmem.cc (5)
src/target/codegen_cuda.cc (26)
  • op (217-232)
  • op (217-217)
  • op (1611-1613)
  • op (1611-1611)
  • op (1614-1616)
  • op (1614-1614)
  • VisitStmt_ (296-315)
  • VisitStmt_ (296-296)
  • VisitStmt_ (1934-1973)
  • VisitStmt_ (1934-1934)
  • VisitStmt_ (1975-2035)
  • VisitStmt_ (1975-1975)
  • VisitStmt_ (2037-2053)
  • VisitStmt_ (2037-2037)
  • VisitExpr_ (886-971)
  • VisitExpr_ (886-886)
  • VisitExpr_ (1176-1932)
  • VisitExpr_ (1176-1176)
  • VisitExpr_ (2055-2069)
  • VisitExpr_ (2055-2055)
  • VisitExpr_ (2071-2139)
  • VisitExpr_ (2071-2072)
  • VisitExpr_ (2141-2289)
  • VisitExpr_ (2141-2142)
  • VisitExpr_ (2345-2348)
  • VisitExpr_ (2345-2346)
src/layout/layout.h (1)
  • Array (34-66)
src/target/utils.cc (2)
  • TargetGetWarpSize (127-132)
  • TargetGetWarpSize (127-127)
tilelang/language/tir/op.py (1)
  • tvm_access_ptr (650-675)
tilelang/transform/__init__.py (1)
  • LowerSharedTmem (443-446)
tilelang/engine/phase.py (4)
tilelang/contrib/nvcc.py (2)
  • have_tma (434-449)
  • is_hopper (452-457)
src/transform/lower_shared_tmem.cc (4)
  • LowerSharedTmem (285-291)
  • LowerSharedTmem (285-285)
  • LowerSharedTmem (296-301)
  • LowerSharedTmem (296-296)
tilelang/transform/__init__.py (2)
  • LowerSharedTmem (443-446)
  • RewriteWgmmaSync (117-125)
src/transform/wgmma_sync_rewriter.cc (2)
  • RewriteWgmmaSync (262-267)
  • RewriteWgmmaSync (262-262)
src/op/reduce.cc (1)
src/target/utils.cc (4)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
src/layout/tcgen05_layout.h (1)
src/layout/tcgen05_layout.cc (10)
  • getTcgen05Meta_32dp32b (22-30)
  • getTcgen05Meta_32dp32b (22-22)
  • getTcgen05Meta_32dp64b (32-44)
  • getTcgen05Meta_32dp64b (32-32)
  • getTcgen05Meta_32dp128b (46-57)
  • getTcgen05Meta_32dp128b (46-46)
  • getTcgen05Meta_32dp256b (59-72)
  • getTcgen05Meta_32dp256b (59-59)
  • expandTcgen05Layout (74-108)
  • expandTcgen05Layout (75-76)
src/target/codegen_cuda.cc (2)
src/target/codegen_cpp.cc (2)
  • PrintType (95-164)
  • PrintType (95-95)
src/target/codegen_hip.cc (2)
  • PrintType (178-421)
  • PrintType (178-178)
testing/python/kernel/test_tilelang_kernel_gemm.py (5)
examples/gemv/example_gemv.py (1)
  • kernel (236-285)
tilelang/jit/adapter/ctypes/adapter.py (1)
  • get_kernel_source (290-296)
tilelang/jit/adapter/cython/adapter.py (1)
  • get_kernel_source (516-522)
tilelang/jit/kernel.py (1)
  • get_kernel_source (378-389)
tilelang/jit/adapter/base.py (1)
  • get_kernel_source (51-52)
src/target/codegen_cuda.h (1)
src/target/codegen_cuda.cc (12)
  • GetVecLoad (1097-1118)
  • GetVecLoad (1097-1099)
  • t (25-63)
  • t (25-25)
  • t (67-74)
  • t (67-67)
  • t (78-94)
  • t (78-78)
  • t (98-106)
  • t (98-99)
  • PrintVecStore (1120-1142)
  • PrintVecStore (1120-1122)
examples/gemm_sm100/gemm_mma.py (11)
examples/gemm_sm100/gemm_tcgen5mma.py (2)
  • matmul (8-62)
  • main (29-60)
testing/python/kernel/test_tilelang_kernel_gemm.py (4)
  • matmul (5-50)
  • main (28-48)
  • main (323-346)
  • main (443-466)
tilelang/language/allocate.py (2)
  • alloc_shared (21-36)
  • alloc_fragment (53-64)
tilelang/language/fill.py (1)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/gemm.py (1)
  • gemm (10-212)
tilelang/jit/__init__.py (1)
  • compile (33-86)
tilelang/jit/kernel.py (2)
  • out_idx (446-447)
  • get_profiler (360-376)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-101)
tilelang/utils/tensor.py (1)
  • TensorSupplyType (11-18)
src/layout/tcgen05_layout.cc (1)
src/layout/gemm_layouts.cc (2)
  • make_itervar (16-19)
  • make_itervar (16-16)
tilelang/language/gemm.py (2)
tilelang/tileop/gemm/gemm_base.py (8)
  • C (74-75)
  • M (33-34)
  • N (37-38)
  • K (41-42)
  • policy (118-119)
  • clear_accum (106-107)
  • k_pack (110-111)
  • wg_wait (114-115)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
src/op/gemm.h (2)
tilelang/tileop/gemm/gemm_base.py (1)
  • offset_A (98-99)
src/op/gemm.cc (4)
  • AllowTCGEN5MMA (169-176)
  • AllowTCGEN5MMA (169-169)
  • AllowWGMMA (178-186)
  • AllowWGMMA (178-178)
src/tl_templates/cuda/copy_sm100.h (2)
src/tl_templates/cuda/tcgen_05.h (3)
  • tl (10-60)
  • __device__ (28-30)
  • __device__ (32-34)
src/tl_templates/cuda/tcgen_05_ld.h (5)
  • tl (10-713)
  • tmem_ld_32dp32bNx (13-182)
  • tmem_ld_32dp64bNx (684-691)
  • tmem_ld_32dp128bNx (694-701)
  • tmem_ld_32dp256bNx (704-711)
src/target/utils.cc (1)
src/op/gemm.cc (2)
  • GetArchInt (476-487)
  • GetArchInt (476-476)
src/op/gemm.cc (5)
src/op/operator.cc (2)
  • GetVarFromAccessPtr (75-82)
  • GetVarFromAccessPtr (75-75)
src/target/utils.cc (8)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
  • TargetGetWarpSize (127-132)
  • TargetGetWarpSize (127-127)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsCDNA (70-79)
  • TargetIsCDNA (70-70)
src/op/gemm_py.cc (4)
  • CheckWGMMA (141-191)
  • CheckWGMMA (141-141)
  • GetGemmInst (95-109)
  • GetGemmInst (95-95)
src/layout/gemm_layouts.cc (4)
  • makeGemmABLayoutSm100 (767-787)
  • makeGemmABLayoutSm100 (767-768)
  • make_itervar (16-19)
  • make_itervar (16-16)
src/layout/tcgen05_layout.cc (2)
  • make_itervar (17-20)
  • make_itervar (17-17)
src/layout/layout.h (2)
src/layout/gemm_layouts.cc (4)
  • make_itervar (16-19)
  • make_itervar (16-16)
  • makeGemmABLayoutSm100 (767-787)
  • makeGemmABLayoutSm100 (767-768)
src/layout/tcgen05_layout.cc (2)
  • make_itervar (17-20)
  • make_itervar (17-17)
src/transform/loop_vectorize.cc (2)
src/transform/loop_vectorize_dynamic.cc (18)
  • node (80-85)
  • node (80-80)
  • node (92-96)
  • node (92-92)
  • node (98-112)
  • node (98-98)
  • node (114-120)
  • node (114-114)
  • node (122-125)
  • node (122-122)
  • node (127-135)
  • node (127-127)
  • node (263-266)
  • node (263-263)
  • node (280-283)
  • node (280-280)
  • indices (141-194)
  • indices (141-141)
src/target/utils.cc (2)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
testing/python/transform/test_tilelang_transform_layout_inference.py (1)
tilelang/transform/__init__.py (1)
  • LayoutInference (39-47)
src/tl_templates/cuda/tcgen_05_ld.h (2)
src/tl_templates/cuda/copy_sm100.h (2)
  • tl (6-134)
  • int (77-83)
src/tl_templates/cuda/tcgen_05.h (1)
  • tl (10-60)
examples/gemm_sm100/gemm_tcgen5mma.py (6)
tilelang/env.py (1)
  • disable_cache (232-233)
tilelang/language/allocate.py (3)
  • alloc_tmem (92-118)
  • alloc_barrier (80-89)
  • alloc_fragment (53-64)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/gemm.py (1)
  • gemm (10-212)
tilelang/language/builtin.py (1)
  • mbarrier_wait_parity (172-219)
tilelang/jit/__init__.py (1)
  • compile (33-86)
src/op/finalize_reducer.cc (1)
src/target/utils.cc (4)
  • TargetIsHopper (49-54)
  • TargetIsHopper (49-49)
  • TargetIsSm100 (56-61)
  • TargetIsSm100 (56-56)
src/transform/pipeline_planning.cc (3)
src/transform/layout_inference.cc (4)
  • op (42-48)
  • op (42-42)
  • expr (379-391)
  • expr (379-379)
src/transform/inject_pipeline.cc (18)
  • op (165-181)
  • op (165-165)
  • op (183-196)
  • op (183-183)
  • op (198-211)
  • op (198-198)
  • op (213-219)
  • op (213-213)
  • op (839-945)
  • op (839-839)
  • op (947-964)
  • op (947-947)
  • op (966-983)
  • op (966-966)
  • buf (485-487)
  • buf (485-485)
  • call (133-163)
  • call (133-134)
tilelang/language/builtin.py (1)
  • mbarrier_wait_parity (172-219)
src/tl_templates/cuda/gemm_sm100.h (4)
src/tl_templates/cuda/tcgen_05.h (4)
  • void (12-18)
  • void (20-26)
  • void (62-68)
  • tl (10-60)
src/tl_templates/cuda/common.h (1)
  • uint32_t (115-117)
tilelang/language/gemm.py (1)
  • gemm (10-212)
src/tl_templates/cuda/tcgen_05_ld.h (1)
  • tl (10-713)
src/transform/lower_tile_op.cc (2)
src/transform/loop_partition.cc (6)
  • op (128-154)
  • op (128-128)
  • op (157-162)
  • op (157-157)
  • op (164-169)
  • op (164-164)
src/target/codegen_cuda.cc (8)
  • VisitStmt_ (296-315)
  • VisitStmt_ (296-296)
  • VisitStmt_ (1934-1973)
  • VisitStmt_ (1934-1934)
  • VisitStmt_ (1975-2035)
  • VisitStmt_ (1975-1975)
  • VisitStmt_ (2037-2053)
  • VisitStmt_ (2037-2037)
src/op/gemm_py.cc (1)
src/op/gemm.cc (2)
  • GetGemmInst (188-204)
  • GetGemmInst (188-188)
src/tl_templates/cuda/tcgen_05.h (2)
src/tl_templates/cuda/copy.h (2)
  • tl (14-80)
  • void (16-18)
src/tl_templates/cuda/tcgen_05_ld.h (1)
  • tl (10-713)
🪛 Ruff (0.13.1)
tilelang/language/__init__.py

45-45: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

examples/gemm_sm100/gemm_tcgen5mma.py

20-20: Unused function argument: num_stages

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: format-check
🔇 Additional comments (16)
src/runtime/runtime.cc (2)

75-86: LGTM on newline swap

Replacing std::endl with '\n' keeps the formatting while avoiding the extra flush cost, which is consistent with our logging cleanup.


102-104: LGTM on fatal log

Retains the newline while eliminating the flush, so the fatal path still prints the descriptor dump without incurring the extra flush penalty.

src/tl_templates/cuda/tcgen_05.h (1)

36-60: Good defensive programming with static assertions

The static assertions properly validate the supported M/N dimensions for the TCGEN5MMA operations. The error messages are clear and informative.

src/transform/loop_vectorize.cc (3)

48-71: Well-structured visitor pattern implementation

The VectorizeFindGlobalAccess class properly encapsulates global access detection with clear visitor methods and a simple public API.


151-155: Proper handling of Cast node vectorization

The addition of VisitExpr_ for CastNode correctly adjusts vector size based on the target data type width using ZeroAwareGCD.


191-191: State management change for runtime configuration

Changing vector_load_bits_max_ from const to non-const allows runtime configuration based on target. This is a necessary change for the SM100 256-bit vectorization feature.

src/target/codegen_cuda.cc (6)

360-363: Good extension for wider FP16 vectors

The extension properly handles FP16 vectors with lanes up to 16 by using ulonglong types. The error messages clearly indicate the lane requirements.


508-511: Support for int8x32 vectors

Successfully adds support for 32-lane int8 vectors using longlong4 type, which aligns with the 256-bit wide path requirements.


1097-1142: Well-implemented wide vector load/store paths

The GetVecLoad and PrintVecStore methods properly check for 256-bit vector operations on global memory and use the appropriate TL helpers (tl::ld_global_256/tl::st_global_256). The assertion ensures only 256-bit operations use this path.


1289-1290: Good addition of fence barrier init support

Adding support for tl::ptx_fence_barrier_init properly extends the barrier initialization capabilities.


2157-2171: Proper handling of int8x32 broadcast

The implementation correctly handles broadcasting for both 4-lane and 32-lane int8 vectors with appropriate bit manipulation and type construction.


577-585: No int64 vectors beyond 4 lanes in codebase—no breakage
Search for longlong5-9 and longlong10-16 across all .cc, .cpp, and .h files returned no hits, indicating no prior support for int64 vectors with lanes > 4. The new error path aligns with existing capabilities.

src/tl_templates/cuda/copy_sm100.h (1)

98-132: LGTM! Well-structured template wrappers for TMEM loading.

The tcgen05_ld wrapper functions provide a clean interface for different data path widths, properly calling the fence after loading. The MAX_LOGN parameters (7, 7, 6, 5) correctly match the maximum supported N values for each variant.

src/op/gemm.cc (3)

765-823: Add tests to validate TCGEN5MMA D/E/G layout computation
No existing coverage for Layout D/E/G (src/op/gemm.cc:789–822). Cross-check the FloorDiv/FloorMod formulas against NVIDIA’s PTX “tcgen05 data-path” docs and the reference in src/layout/tcgen05_layout.cc, and add unit tests for meta.atom_m = 128, 64, 32 cases to ensure correctness.


511-576: Clarify multi-warp TCGEN5MMA execution
The lowering in src/op/gemm.cc lines 571–574 only emits the TCGEN5MMA call when
FloorDiv(thread_var, 32) == FloorDiv(thread_bounds.min, 32) (i.e. warp 0), skipping all other warps. We found no comments, tests, or external docs covering multi-warp invocation (the sm100 note at lines 609–612 only addresses single-thread issuance within a warp). Confirm whether your schedule intends only the first warp to issue the MMA instruction or adjust the logic to invoke it on each warp as needed.


169-176: AllowTCGEN5MMA correctly gates SM100 (sm_100–sm_103)
TargetIsSm100(target) returns true for arch 100–103, and no per‐variant TCGEN5MMA restrictions were found.

Comment on lines +45 to +48
-cppcoreguidelines-pro-bounds-pointer-arithmetic,
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
-clang-analyzer-deadcode.DeadStores,
-clang-analyzer-optin.cplusplus.VirtualCall,

⚠️ Potential issue | 🟠 Major

Undo the blanket disable of the bounds and virtual-call checks

These three checks were part of the “Retained categories” for a reason—they regularly flag genuine memory-safety and lifecycle bugs across our CUDA backends. Disabling them globally trades concrete signal for convenience and will let regressions slip through (e.g., silent out-of-bounds pointer arithmetic or virtual dispatch in constructors/destructors). If there are nuisance warnings in new SM100 codepaths, please suppress them locally with targeted NOLINT or refactor the offending spots instead of removing the protections repo-wide.

-  -cppcoreguidelines-pro-bounds-pointer-arithmetic,
-  -cppcoreguidelines-pro-bounds-array-to-pointer-decay,
-  -clang-analyzer-optin.cplusplus.VirtualCall,
📝 Committable suggestion


Suggested change
-cppcoreguidelines-pro-bounds-pointer-arithmetic,
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
-clang-analyzer-deadcode.DeadStores,
-clang-analyzer-optin.cplusplus.VirtualCall,
-clang-analyzer-deadcode.DeadStores,
🤖 Prompt for AI Agents
In .clang-tidy around lines 45 to 48, the rules
cppcoreguidelines-pro-bounds-pointer-arithmetic,
cppcoreguidelines-pro-bounds-array-to-pointer-decay and
clang-analyzer-optin.cplusplus.VirtualCall were removed; restore these checks by
removing them from the negative list so they run globally, and instead address
any false positives in SM100 or other new code by adding targeted NOLINT
comments at the exact offending lines or refactoring the code to eliminate the
warning (e.g., avoid pointer-arithmetic that can overflow, use safe indexing,
and avoid virtual calls in constructors/destructors).

Comment on lines +42 to +44
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)

⚠️ Potential issue | 🟠 Major

Respect the num_stages argument

matmul advertises a tunable num_stages, but the pipelined loop hardcodes num_stages=1, so callers (including the example, which passes 0) silently get a different staging depth than requested. That breaks tuning knobs and can desynchronize the mbarrier usage. Please plumb the parameter through.

-            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
📝 Committable suggestion


Suggested change
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
🤖 Prompt for AI Agents
In examples/gemm_sm100/gemm_tcgen5mma.py around lines 42 to 44, the pipelined
loop currently hardcodes num_stages=1 which ignores the matmul’s tunable
num_stages (callers may pass 0 or other values) and can desynchronize mbarrier
usage; replace the hardcoded 1 with the function/local parameter that holds the
requested staging depth (e.g., num_stages) so the call reads
T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages), ensuring the variable
is in scope and propagated from the matmul signature.

Comment on lines +43 to +54
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)

⚠️ Potential issue | 🔴 Critical

Fix TMEM tile loads for transpose flags

These two T.copy statements ignore trans_A/trans_B and always fetch tiles as if A were row-major and B were transposed. For trans_A=True, A_shape is (K, M) yet we still step the first dimension with by * block_M, which walks past the K extent; for trans_B=False, B_shape is (K, N) but we index its first dimension with bx * block_N. Anything outside the single (False, True) combination used in the demo will read the wrong region or fall off the tensor. The unit-test kernel in testing/python/kernel/test_tilelang_kernel_gemm.py fixes this by branching on the transpose flags—please mirror that logic here.

-                T.copy(A[by * block_M, k * block_K], A_shared)
-                T.copy(B[bx * block_N, k * block_K], B_shared)
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
📝 Committable suggestion


Suggested change
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(
A_shared,
B_shared,
C_tmem,
trans_A,
trans_B,
mbar=mbar,
wg_wait=-1,
clear_accum=k == 0)
T.mbarrier_wait_parity(mbar, k % 2)
🤖 Prompt for AI Agents
In examples/gemm_sm100/gemm_tcgen5mma.py around lines 43 to 54, the two T.copy
calls always index A and B as if A is (M,K) and B is (N,K) and ignore
trans_A/trans_B; this causes out-of-bounds or wrong tiles when transposed. Fix
by branching on trans_A and trans_B like the unit test: when trans_A is False
copy A[by * block_M, k * block_K] (tile dims M x K), but when trans_A is True
copy A[k * block_K, by * block_M] (tile dims K x M); similarly, when trans_B is
True copy B[bx * block_N, k * block_K] (tile dims N x K), else when trans_B is
False copy B[k * block_K, bx * block_N] (tile dims K x N). Mirror the
index-swapping and tile extents from
testing/python/kernel/test_tilelang_kernel_gemm.py so TMEM loads match the
declared shapes.

Comment on lines +85 to +104
# Compile kernel
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required
})

# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)

# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")

⚠️ Potential issue | 🟠 Major

Fix the example to compile main rather than an undefined symbol.

Lines 86-93 invoke tilelang.compile(func, ...), but the snippet only defines main. Copying the example as written will raise a NameError. Update the call (and any downstream references) to use the defined prim func so users can execute the example successfully.

-jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
+jit_kernel = tilelang.compile(main, out_idx=[2], target="cuda", pass_configs={
📝 Committable suggestion


Suggested change
# Compile kernel
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required
})
# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")
# Compile kernel
jit_kernel = tilelang.compile(main, out_idx=[2], target="cuda", pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required
})
# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")
🤖 Prompt for AI Agents
In examples/gemm_sm100/README.md around lines 85 to 104, the code calls
tilelang.compile(func, ...) but only defines a prim func named main; change the
compile invocation to tilelang.compile(main, ...) (and any downstream references
expecting that compiled object remain the same) so the example compiles the
defined symbol instead of the undefined name func; keep the existing
pass_configs, inputs, profiling and verification lines unchanged.

Comment on lines +752 to +765
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % vector_size == 0)
return makeGemmLayoutLinear(mat_stride, mat_continuous);
else
ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
<< ", continuous=" << mat_continuous
<< ", element_size=" << element_size << ", kfactor=" << kfactor;
}

⚠️ Potential issue | 🔴 Critical

Regression: Hopper layout now aborts for valid widths

For non power-of-two tiles (e.g., element_size == 8, mat_continuous == 48) we used to fall back to makeGemmABLayoutPadded, so Hopper kernels kept working. The new branch ends with ICHECK(0), which now fatals for those same shapes. That’s a correctness regression that will abort user programs. Please restore a padded/linear fallback instead of hard failing.

   else if (mat_continuous % (vector_size * 2) == 0)
     return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
                                         element_size);
-  else if (mat_continuous % vector_size == 0)
-    return makeGemmLayoutLinear(mat_stride, mat_continuous);
-  else
-    ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
-              << ", continuous=" << mat_continuous
-              << ", element_size=" << element_size << ", kfactor=" << kfactor;
+  else if (mat_continuous % vector_size == 0)
+    return makeGemmLayoutLinear(mat_stride, mat_continuous);
+  return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
 }
📝 Committable suggestion


Suggested change
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % vector_size == 0)
return makeGemmLayoutLinear(mat_stride, mat_continuous);
else
ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
<< ", continuous=" << mat_continuous
<< ", element_size=" << element_size << ", kfactor=" << kfactor;
}
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % vector_size == 0)
return makeGemmLayoutLinear(mat_stride, mat_continuous);
return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
}
🤖 Prompt for AI Agents
In src/layout/gemm_layouts.cc around lines 752-765, the final ICHECK(0) causes
Hopper to abort for valid non-power-of-two widths (e.g., element_size==8,
mat_continuous==48); restore the previous padded/linear fallback instead of
fatally asserting by returning the padded GEMM A/B layout (same call used
before, e.g. makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size,
kfactor) or, if that API isn’t available here, fall back to
makeGemmLayoutLinear/makeGemmABLayoutPadded as appropriate) so the code returns
a valid layout for those shapes rather than calling ICHECK(0).

Comment on lines +94 to +100
Stmt VisitStmt_(const BlockNode *op) final {
auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
if (op->annotations.count(attr::kLayoutMap)) {
block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, layout_remap_);
}
return block;
}

⚠️ Potential issue | 🟠 Major

Do not attach the global layout_remap_ to every block

layout_remap_ is populated while we traverse the entire PrimFunc, so by the time we call LayoutRemapRewriter::Substitute it contains the union of all remapped buffers. Overwriting each block’s attr::kLayoutMap with that global map means every annotated block now advertises layouts for buffers it does not allocate or match. That breaks the per-block contract that other passes (including this one, in VisitStmt_(const BlockNode*)) rely on, and it quickly leads to incorrect lowering or hard asserts once a downstream visitor iterates the annotation and cannot find the corresponding buffers in the block. Please filter the annotation so that we only write back the layouts for buffers owned by the current block.

Apply this diff to keep the annotation block-local:

     auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
     if (op->annotations.count(attr::kLayoutMap)) {
-      block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, layout_remap_);
+      Map<Buffer, Layout> filtered;
+      for (const Buffer &buffer : block->alloc_buffers) {
+        if (layout_remap_.count(buffer)) {
+          filtered.Set(buffer, layout_remap_[buffer]);
+        }
+      }
+      for (const MatchBufferRegion &match_buffer : block->match_buffers) {
+        const Buffer &buffer = match_buffer->buffer;
+        if (layout_remap_.count(buffer)) {
+          filtered.Set(buffer, layout_remap_[buffer]);
+        }
+      }
+      ICHECK(!filtered.empty())
+          << "Expected layout remap info for block-local buffers";
+      block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, filtered);
     }
📝 Committable suggestion


Suggested change
Stmt VisitStmt_(const BlockNode *op) final {
auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
if (op->annotations.count(attr::kLayoutMap)) {
block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, layout_remap_);
}
return block;
}
Stmt VisitStmt_(const BlockNode *op) final {
auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
if (op->annotations.count(attr::kLayoutMap)) {
// Only propagate remap entries for buffers that this block actually allocates or matches
Map<Buffer, Layout> filtered;
for (const Buffer &buffer : block->alloc_buffers) {
if (layout_remap_.count(buffer)) {
filtered.Set(buffer, layout_remap_[buffer]);
}
}
for (const MatchBufferRegion &match_buffer : block->match_buffers) {
const Buffer &buffer = match_buffer->buffer;
if (layout_remap_.count(buffer)) {
filtered.Set(buffer, layout_remap_[buffer]);
}
}
ICHECK(!filtered.empty())
<< "Expected layout remap info for block-local buffers";
block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, filtered);
}
return block;
}
🤖 Prompt for AI Agents
In src/transform/lower_tile_op.cc around lines 94-100, the code attaches the
global layout_remap_ to every block, which leaks layouts for buffers the block
does not own. Instead, build a block-local map by iterating the buffers the
block allocates or matches (alloc_buffers and match_buffers), copy entries from
layout_remap_ only for those buffers, and write that filtered map into the
block's attr::kLayoutMap so the annotation only contains layouts for buffers
owned by the current block.
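
The filtering rule itself is small; the sketch below restates it in plain Python with made-up buffer names (the actual pass operates on TIR Buffer objects in C++).

def filter_layout_map(global_remap, alloc_buffers, match_buffers):
    # Keep only remap entries for buffers this block allocates or matches.
    owned = set(alloc_buffers) | set(match_buffers)
    return {buf: layout for buf, layout in global_remap.items() if buf in owned}

global_remap = {"A_shared": "swizzled_128B", "B_shared": "swizzled_128B", "C_tmem": "tmem_2d"}
block_local = filter_layout_map(global_remap, alloc_buffers=["A_shared"], match_buffers=[])
assert block_local == {"A_shared": "swizzled_128B"}  # B_shared / C_tmem stay out of this block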

Comment on lines +47 to +118
Buffer buf = op->buffer;
if (buf->data->type_annotation.as<PointerTypeNode>()->storage_scope ==
"shared") {
// We only care about shared.tmem buffers
ICHECK(!result.defined())
<< "TmemLoadCollector: More than one shared buffer visited";
result = buf;
}
}
};

/*!
* \brief Build the dependency chain between async operations and their
* corresponding buffers & synchronizations.
*
* Example:
* If we encounter the following pattern:
*
* tcgen5mma_gemm_ts(..., mbar, ...)
* mbarrier_wait_parity(mbar)
*
* The builder will link the mbarrier to the buffers used in the
* TCGEN5MMA
*/
class AsyncDependencyChainBuilder : public StmtExprVisitor {
public:
AsyncDependencyChainBuilder(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}

std::unordered_map<const BufferNode *, Array<BufferRegion>>
mbar_to_buffer_reads_;

std::unordered_map<const BufferNode *, Array<BufferRegion>>
mbar_to_buffer_writes_;

private:
Map<Var, Buffer> buffer_data_to_buffer_;

void VisitExpr_(const CallNode *op) final {
auto args = op->args;
if (op->op.same_as(builtin::call_extern())) {
std::string func_name_with_template = args[0].as<StringImmNode>()->value;
std::size_t le_pos = func_name_with_template.find_first_of('<');
std::string func_name = le_pos == std::string::npos
? func_name_with_template
: func_name_with_template.substr(0, le_pos);
if (func_name == "tl::utcmma_gemm_ts" ||
func_name == "tl::utcmma_gemm_ss") {
// TCGEN5MMA
auto get_buf_from_access_ptr_call =
[&](const PrimExpr &expr) -> Buffer {
auto call = expr.as<CallNode>();
ICHECK(call);
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>();
ICHECK(var);
auto it = buffer_data_to_buffer_.find(GetRef<Var>(var));
ICHECK(it != buffer_data_to_buffer_.end());
return (*it).second;
};
Buffer a_buf = get_buf_from_access_ptr_call(args[1]);
Buffer b_buf = get_buf_from_access_ptr_call(args[2]);
Buffer mbar_buf = get_buf_from_access_ptr_call(args[4]);

TmemLoadCollector tmem_collector;
tmem_collector(args[3]);
ICHECK(tmem_collector.result.defined())
<< "TmemLoadCollector: No tmem buffer load found in the TCGEN5MMA "
"call";
Buffer c_buf = tmem_collector.result;

PrimExpr clear_accum = args[5];

⚠️ Potential issue | 🔴 Critical

Fix TMEM detection to avoid crashes and false negatives.

ptr_type is dereferenced before we confirm it exists, so a missing type annotation segfaults before the ICHECK fires. We also look for storage scope "shared", but TCGEN5 writes land in "shared.tmem", which leaves result undefined and trips the later ICHECK. Please guard the pointer first and match the correct TMEM scope.

Apply this diff:

-    const auto *ptr_type =
-        buf->data->type_annotation.as<PointerTypeNode>();
-    auto storage_scope = ptr_type->storage_scope;
-    ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType";
-    if (storage_scope == "shared") {
+    const auto *ptr_type =
+        buf->data->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType";
+    if (ptr_type->storage_scope == "shared.tmem") {
       // We only care about shared.tmem buffers
       ICHECK(!result.defined())
-          << "TmemLoadCollector: More than one shared buffer visited";
+          << "TmemLoadCollector: More than one shared.tmem buffer visited";
       result = buf;
     }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/transform/pipeline_planning.cc around lines 47 to 118, the
TmemLoadCollector dereferences the pointer-type annotation without null-checking
and compares storage_scope to "shared" (causing false negatives for
"shared.tmem" and crashes if type annotation is missing). Change the logic to
first verify the type annotation and that as<PointerTypeNode>() returns non-null
before accessing storage_scope, then match the TMEM scope against "shared.tmem"
(or use a starts_with/substring check for scopes that begin with "shared.tmem")
and only then set result; keep the existing ICHECK to ensure only one
shared.tmem buffer is collected.
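
For reference, the scope test amounts to a plain string check, since TVM storage scopes are strings; the helper below is an illustration, not code from the PR.

def is_tmem_scope(storage_scope: str) -> bool:
    # Exact "shared.tmem", or a sub-scope such as "shared.tmem.x" if one ever appears.
    return storage_scope == "shared.tmem" or storage_scope.startswith("shared.tmem.")

assert is_tmem_scope("shared.tmem")
assert not is_tmem_scope("shared")       # plain shared memory must not match
assert not is_tmem_scope("shared.dyn")   # dynamic shared memory must not match either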

Comment on lines +92 to +118
def alloc_tmem(shape, dtype):
"""
Allocate a Tensor Memory (TMEM) buffer for use with 5th generation Tensor Core operations (e.g., TCGEN5.MMA).
TMEM is a dedicated on-chip memory introduced in Hopper GPUs, designed to reduce register pressure and enable asynchronous, single-threaded MMA operations. It is organized as a 2D array of 512 columns by 128 rows (lanes), with each cell being 32 bits. Allocation is performed in units of columns, and every lane of a column is allocated together.
Key properties and requirements:
- The number of columns allocated must be a power of 2 and at least 32.
- TMEM allocations are dynamic and must be explicitly deallocated.
- Both allocation and deallocation must be performed by the same warp.
- The base address of the TMEM allocation is stored in shared memory and used as the offset for TCGEN5.MMA accumulator tensors.
- Only TCGEN5.MMA and specific TMEM load/store instructions can access TMEM; all pre-processing must occur before data is loaded into TMEM, and all post-processing after data is retrieved.
- The number of columns allocated should not increase between any two allocations in the execution order within the CTA.
Args:
num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512.
Returns:
T.Buffer: A TVM buffer object allocated in TMEM scope, suitable for use as an accumulator or operand in TCGEN5.MMA operations.
Note:
- TMEM is only available on supported architectures (e.g., Hopper and later).
- The buffer returned should be used according to TMEM access restrictions and deallocated appropriately.
"""

assert len(shape) == 2, "shape must be a 2D tensor for TMEM allocation"
return T.alloc_buffer(shape, dtype, scope="shared.tmem")

⚠️ Potential issue | 🟡 Minor

Fix the docstring arg description.

The docstring still references a num_cols parameter, but the function actually accepts (shape, dtype). This mismatch is misleading for users reading the API docs—please update the text (or the signature) so they line up.

🤖 Prompt for AI Agents
In tilelang/language/allocate.py around lines 92 to 118, the docstring
incorrectly documents a non-existent num_cols parameter while the function
signature is alloc_tmem(shape, dtype); update the docstring to describe the
actual parameters (shape: 2-tuple for rows/columns and dtype: element type),
remove or replace references to num_cols, and adjust the description of TMEM
allocation behavior to explain that shape specifies the 2D tensor dimensions
(including any constraints on columns expressed in terms of shape[1] if needed);
keep the usage notes about TMEM scope and deallocation but ensure all parameter
names and examples match the function signature.
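
The column constraints spelled out in the docstring (power of two, at least 32, at most 512) are easy to state as a check. The helper below is hypothetical and only mirrors what the docstring documents; it is not part of the tilelang API.

def check_tmem_columns(num_cols: int) -> None:
    if num_cols < 32 or num_cols > 512:
        raise ValueError(f"TMEM columns must be in [32, 512], got {num_cols}")
    if num_cols & (num_cols - 1) != 0:
        raise ValueError(f"TMEM columns must be a power of 2, got {num_cols}")

for cols in (32, 64, 128, 256, 512):
    check_tmem_columns(cols)   # all valid column counts
# check_tmem_columns(48)       # would raise: 48 is not a power of 2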

Comment on lines +209 to +212
C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0]
return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), Aptr, Bptr, Cptr, transpose_A,
transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
offset_b, k_pack, wg_wait, mbarptr, C_coords[0], C_coords[1])

⚠️ Potential issue | 🟠 Major

Handle BufferLoad when computing C tile coordinates

gemm already accepts tir.BufferLoad for C (see the existing retrieve_ptr / retrieve_shape helpers), but the new C_coords branch only extracts offsets for tir.BufferRegion. For BufferLoad call sites—which we use when feeding tile slices of C—this now hardcodes [0, 0], so the backend sees every tile at the origin and the TCGEN5MMA barrier wiring becomes wrong. Please plumb the real minima for BufferLoad cases as well.

-    C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0]
+    if isinstance(C, tir.BufferRegion):
+        C_coords = [r.min for r in C.region]
+    elif isinstance(C, tir.BufferLoad):
+        region = get_buffer_region_from_load(C).region
+        C_coords = [r.min for r in region]
+    else:
+        C_coords = [0, 0]
📝 Committable suggestion


Suggested change
C_coords = [r.min for r in C.region] if isinstance(C, tir.BufferRegion) else [0, 0]
return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), Aptr, Bptr, Cptr, transpose_A,
transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
offset_b, k_pack, wg_wait, mbarptr, C_coords[0], C_coords[1])
if isinstance(C, tir.BufferRegion):
C_coords = [r.min for r in C.region]
elif isinstance(C, tir.BufferLoad):
region = get_buffer_region_from_load(C).region
C_coords = [r.min for r in region]
else:
C_coords = [0, 0]
return tir.call_intrin("handle", tir.op.Op.get("tl.gemm"), Aptr, Bptr, Cptr, transpose_A,
transpose_B, M, N, K, policy, clear_accum, stride_a, stride_b, offset_a,
offset_b, k_pack, wg_wait, mbarptr, C_coords[0], C_coords[1])
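
For the BufferRegion case the minima are readily available; the snippet below shows where they come from, with made-up shapes and offsets. The BufferLoad case relies on the get_buffer_region_from_load helper referenced in the diff.

import tvm
from tvm import tir

C = tir.decl_buffer((1024, 1024), "float32", name="C")
# The tile C[128:256, 256:384] expressed as a BufferRegion:
region = tir.BufferRegion(C, [tvm.ir.Range.from_min_extent(128, 128),
                              tvm.ir.Range.from_min_extent(256, 128)])
coords = [r.min for r in region.region]
print(coords)  # the minima 128 and 256, not the [0, 0] the old branch hard-coded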

Comment on lines +65 to +67
target = tvm.target.Target.current(allow_none=True)
if target is not None:
return target

⚠️ Potential issue | 🟠 Major

Preserve return_object contract when reusing the current TVM target

determine_target() historically returned a string backend key whenever return_object is left as the default False. With this early return we now hand back a tvm.target.Target instance whenever a target context is active, bypassing the return_object handling below. Any caller that compares the result to literals like "cuda"/"hip" (there are several today) will start failing as soon as a target context is installed. Please keep the old contract: only surface the Target instance when the caller explicitly requests it, and otherwise downgrade to the target kind string.

-    if target == "auto":
-        target = tvm.target.Target.current(allow_none=True)
-        if target is not None:
-            return target
+    if target == "auto":
+        current_target = tvm.target.Target.current(allow_none=True)
+        if current_target is not None:
+            if return_object:
+                return current_target
+            return current_target.kind.name
📝 Committable suggestion


Suggested change
target = tvm.target.Target.current(allow_none=True)
if target is not None:
return target
if target == "auto":
current_target = tvm.target.Target.current(allow_none=True)
if current_target is not None:
if return_object:
return current_target
return current_target.kind.name
🤖 Prompt for AI Agents
In tilelang/utils/target.py around lines 65-67, the early return hands back a
tvm.target.Target instance unconditionally when a target context is active,
breaking callers that expect a backend key string when return_object is False;
change the logic to respect the return_object parameter: if target is not None
and return_object is True return the Target instance, otherwise when
return_object is False return a string backend key (use target.kind if
available, falling back to str(target)), and preserve existing behavior with
allow_none=True (i.e., return None when no current target).
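
A condensed sketch of the contract being asked for is below; the function name is illustrative and not the actual tilelang API.

import tvm

def resolve_auto_target(return_object: bool = False):
    current = tvm.target.Target.current(allow_none=True)
    if current is None:
        return None  # let the existing auto-detection logic take over
    return current if return_object else current.kind.name

with tvm.target.Target("cuda"):
    assert resolve_auto_target() == "cuda"                            # legacy callers keep their string key
    assert isinstance(resolve_auto_target(True), tvm.target.Target)   # Target object only on request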

@chengyupku chengyupku merged commit f58bcd4 into tile-ai:main Sep 28, 2025
6 of 7 checks passed
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
* update sm100 related utcmma, tmem, ld/st256 in src
* update sm100 related utcmma, tmem, ld/st256 in tilelang
* Remove deprecated GEMM examples and related README documentation for SM100 architecture support
* Update GEMM implementation to replace UTCMMA with TCGEN5MMA across relevant files
* Remove gemm_umma.py example and update README to reflect TCGEN5MMA terminology changes
* Update README.md for gemm_sm100 example by removing outdated API sections and streamlining documentation
* Update README and source files to reflect TCGEN5.MMA terminology changes
* Refactor CUDA GEMM header for improved readability
