Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions examples/elementwise/example_elementwise_add_tma_1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import argparse
import tilelang
import tilelang.language as T
import torch

# Switch off tilelang's cache at import time, before any kernel in this file is
# built (presumably so the example is always recompiled — TODO confirm).
tilelang.disable_cache()


def ref_program(x, y):
    """Reference result used to validate the kernel output: ``x + y``."""
    expected = x + y
    return expected


@tilelang.jit(out_idx=[-1])
def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
    """Build the tiled kernel that writes ``A + B`` into ``C`` for ``(M, N)`` tensors.

    ``A`` and ``B`` have dtype ``in_dtype`` and ``C`` has dtype ``out_dtype``.
    The launch grid is ``ceildiv(N, block_N)`` by ``ceildiv(M, block_M)`` with
    ``threads`` threads; each grid cell processes one ``(block_M, block_N)`` tile.
    Returns the ``T.prim_func`` describing the kernel.
    """

    @T.prim_func
    def elem_add(
            A: T.Tensor((M, N), in_dtype),
            B: T.Tensor((M, N), in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            # Stage this tile of both inputs in shared buffers.
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)

            # Sum the staged tiles element by element into the fragment.
            for row, col in T.Parallel(block_M, block_N):
                C_local[row, col] = A_shared[row, col] + B_shared[row, col]

            # Write back through a shared buffer rather than straight from the
            # fragment, so the final copy goes shared -> C.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return elem_add


def main():
    """Build the kernel for the requested ``--m``/``--n`` sizes and verify it.

    Both sizes default to 128. Unknown command-line arguments are ignored.
    Two random float32 CUDA tensors are added by the kernel and the result is
    compared against ``ref_program``; ``torch.testing.assert_close`` raises on
    a mismatch.
    """
    arg_parser = argparse.ArgumentParser()
    for flag in ("--m", "--n"):
        arg_parser.add_argument(flag, type=int, default=128)
    # parse_known_args tolerates extra flags instead of erroring on them.
    parsed, _ = arg_parser.parse_known_args()
    M = parsed.m
    N = parsed.n

    a = torch.randn(M, N, dtype=torch.float32, device="cuda")
    b = torch.randn(M, N, dtype=torch.float32, device="cuda")

    # Default config
    kernel = elementwise_add(
        M,
        N,
        block_M=128,
        block_N=128,
        threads=128,
        in_dtype="float32",
        out_dtype="float32",
    )

    out = kernel(a, b)
    expected = ref_program(a, b)
    torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)
    print("All passed!")


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
5 changes: 5 additions & 0 deletions examples/elementwise/test_example_elementwise.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import tilelang.testing
import example_elementwise_add
import example_elementwise_add_tma_1d


def test_example_elementwise_add():
    """Run ``example_elementwise_add.main()``; any exception it raises fails the test."""
    example_elementwise_add.main()


def test_example_elementwise_add_tma_1d():
    """Run ``example_elementwise_add_tma_1d.main()``; any exception it raises fails the test."""
    example_elementwise_add_tma_1d.main()


# When executed directly, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
124 changes: 122 additions & 2 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -772,19 +772,138 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer,
stride *= s;
}

Array<PrimExpr> global_indices;
for (auto r : global_range) {
global_indices.push_back(r->min);
}
std::vector<PrimExpr> global_strides;
PrimExpr global_stride = 1;
for (size_t i = 0; i < global_tensor->shape.size(); i++) {
auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
global_strides.insert(global_strides.begin(), global_stride);
global_stride *= s;
}

Comment on lines +775 to +786
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Don’t assume compact layout for address math; use buffer strides when available.

The 1D path computes global_strides (and earlier, shared strides) from shapes, implicitly assuming a compact/row-major layout. If Buffer.strides is present (non-compact tensors, slices, or external buffers), the current offset calculations for global_offset/offset will be wrong.

  • Use global_tensor->strides when available; otherwise, fall back to the current shape-derived logic.
  • Consider mirroring this for shared_tensor_before_remap as well to future-proof shared slices with explicit strides.

Apply this diff within the changed hunk to honor global_tensor->strides:

-  std::vector<PrimExpr> global_strides;
-  PrimExpr global_stride = 1;
-  for (size_t i = 0; i < global_tensor->shape.size(); i++) {
-    auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
-    global_strides.insert(global_strides.begin(), global_stride);
-    global_stride *= s;
-  }
+  std::vector<PrimExpr> global_strides;
+  if (!global_tensor->strides.empty()) {
+    // Respect explicit (possibly non-compact) strides
+    global_strides = std::vector<PrimExpr>(global_tensor->strides.begin(),
+                                           global_tensor->strides.end());
+  } else {
+    // Derive element strides from shape (compact layout)
+    PrimExpr global_stride = 1;
+    for (size_t i = 0; i < global_tensor->shape.size(); i++) {
+      auto s = global_tensor->shape[global_tensor->shape.size() - i - 1];
+      global_strides.insert(global_strides.begin(), global_stride);
+      global_stride *= s;
+    }
+  }

Additionally (outside the changed hunk), consider the analogous update for the shared side:

// If shared_tensor_before_remap->strides not empty, form `strides` from it;
// otherwise keep the existing shape-derived logic.

If you’d like, I can draft the corresponding update for the shared strides block and thread through unit tests that cover non-compact cases (e.g., strided views and sliced buffers).

Also applies to: 794-799

🤖 Prompt for AI Agents
In src/op/copy.cc around lines 775-786 (and similarly 794-799), the code
currently derives global_strides from shapes assuming compact layout; change it
to check if global_tensor->strides is non-empty and, if so, build global_strides
from that vector (preserving order consistent with existing index arithmetic),
otherwise fall back to the current shape-derived logic; also mirror the same
pattern for the shared_tensor_before_remap strides block elsewhere so shared
buffers with explicit strides are honored; ensure the rest of the offset math
uses these stride vectors and that no assumptions of row-major compactness
remain.

ICHECK(strides.size() == indices.size())
<< "strides.size() != indices.size()" << strides.size() << " "
<< indices.size();
PrimExpr offset = 0;
for (size_t i = 0; i < indices.size(); i++) {
offset += indices[i] * strides[i];
}
PrimExpr global_offset = 0;
for (size_t i = 0; i < global_indices.size(); i++) {
global_offset += global_indices[i] * global_strides[i];
}
auto shared_tensor_before_remap = shared_tensor;
Layout shared_layout;
if (T.layout_map.count(shared_tensor)) {
shared_layout = T.layout_map[shared_tensor];
shared_tensor = T.buffer_remap[shared_tensor];
}

// Add 1D TMA copy when the global and shared memory is contiguous
{
// Check if shared_tensor->name is present in T.buffer_var_gemm
// (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
bool shared_is_contiguous = true;
for (const auto &v : T.buffer_var_gemm) {
if (v->name_hint == shared_tensor->name) {
shared_is_contiguous = false;
break;
}
}
bool shared_not_full_dim_encounter = false;
Comment on lines +805 to +816
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid name-based GEMM detection; use Var identity and skip 1D path when shared has a layout map

  • Comparing v->name_hint to shared_tensor->name can miss due to buffer remap; compare the Var to shared_tensor_before_remap->data.
  • If T.layout_map contains the shared buffer, 1D TMA path will produce a tvm_access_ptr that later remaps only the Var (not the offset), yielding wrong addressing.

Apply:

-  // Add 1D TMA copy when the global and shared memory is contiguous
-  {
+  // Add 1D TMA copy when the global and shared memory is contiguous
+  // Only when shared has no remapped/swizzled layout.
+  if (!T.layout_map.count(shared_tensor_before_remap)) {
     // Check if shared_tensor->name is present in T.buffer_var_gemm
     // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
-    bool shared_is_contiguous = true;
-    for (const auto &v : T.buffer_var_gemm) {
-      if (v->name_hint == shared_tensor->name) {
-        shared_is_contiguous = false;
-        break;
-      }
-    }
+    bool shared_is_contiguous = true;
+    for (const auto &v : T.buffer_var_gemm) {
+      if (v.same_as(shared_tensor_before_remap->data)) {
+        shared_is_contiguous = false;
+        break;
+      }
+    }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Add 1D TMA copy when the global and shared memory is contiguous
{
// Check if shared_tensor->name is present in T.buffer_var_gemm
// (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
bool shared_is_contiguous = true;
for (const auto &v : T.buffer_var_gemm) {
if (v->name_hint == shared_tensor->name) {
shared_is_contiguous = false;
break;
}
}
bool shared_not_full_dim_encounter = false;
// Add 1D TMA copy when the global and shared memory is contiguous
// Only when shared has no remapped/swizzled layout.
if (!T.layout_map.count(shared_tensor_before_remap)) {
// Check if shared_tensor->name is present in T.buffer_var_gemm
// (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
bool shared_is_contiguous = true;
for (const auto &v : T.buffer_var_gemm) {
if (v.same_as(shared_tensor_before_remap->data)) {
shared_is_contiguous = false;
break;
}
}
bool shared_not_full_dim_encounter = false;
}

for (ssize_t i = shared_range.size() - 1; i >= 0; --i) {
if (!shared_not_full_dim_encounter) {
if (!analyzer->CanProve(shared_range[i]->extent ==
shared_tensor_before_remap->shape[i] &&
shared_range[i]->min == 0)) {
shared_not_full_dim_encounter = true;
}
} else {
if (!analyzer->CanProve(shared_range[i]->extent == 1)) {
shared_is_contiguous = false;
break;
}
}
}
// Currently we check the empty stride of global tensor
bool global_is_contiguous = !global_tensor->strides.empty();
bool global_not_full_dim_encounter = false;
for (ssize_t i = global_range.size() - 1; i >= 0; --i) {
if (!global_not_full_dim_encounter) {
if (!analyzer->CanProve(global_range[i]->extent ==
global_tensor->shape[i] &&
global_range[i]->min == 0)) {
global_not_full_dim_encounter = true;
}
} else {
if (!analyzer->CanProve(global_range[i]->extent == 1)) {
global_is_contiguous = false;
break;
}
}
}
Comment on lines +832 to +847
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

global_is_contiguous initialization bug

bool global_is_contiguous = !global_tensor->strides.empty(); incorrectly marks arbitrary strided tensors as contiguous. Initialize to true and let the checks invalidate it.

Apply:

-    // Currently we check the empty stride of global tensor
-    bool global_is_contiguous = !global_tensor->strides.empty();
+    // Start optimistic; invalidate below if pattern breaks contiguity
+    bool global_is_contiguous = true;

Optionally, add an extra guard to verify explicit strides match compact layout when global_tensor->strides is provided.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
bool global_is_contiguous = !global_tensor->strides.empty();
bool global_not_full_dim_encounter = false;
for (ssize_t i = global_range.size() - 1; i >= 0; --i) {
if (!global_not_full_dim_encounter) {
if (!analyzer->CanProve(global_range[i]->extent ==
global_tensor->shape[i] &&
global_range[i]->min == 0)) {
global_not_full_dim_encounter = true;
}
} else {
if (!analyzer->CanProve(global_range[i]->extent == 1)) {
global_is_contiguous = false;
break;
}
}
}
// Start optimistic; invalidate below if pattern breaks contiguity
bool global_is_contiguous = true;
bool global_not_full_dim_encounter = false;
for (ssize_t i = global_range.size() - 1; i >= 0; --i) {
if (!global_not_full_dim_encounter) {
if (!analyzer->CanProve(global_range[i]->extent ==
global_tensor->shape[i] &&
global_range[i]->min == 0)) {
global_not_full_dim_encounter = true;
}
} else {
if (!analyzer->CanProve(global_range[i]->extent == 1)) {
global_is_contiguous = false;
break;
}
}
}
🤖 Prompt for AI Agents
In src/op/copy.cc around lines 832 to 847, the variable global_is_contiguous is
incorrectly initialized from !global_tensor->strides.empty(), which treats
tensors with arbitrary strides as contiguous; change its initialization to true
and let the subsequent loop set it false when non-contiguous conditions are
found. Also (optionally) add a guard: when global_tensor->strides is non-empty,
verify the explicit strides correspond to a compact (row-major) layout for the
given shape and set global_is_contiguous=false if they don’t match.

// Ensure there is element match and no OOB
PrimExpr shared_elements = 1;
for (size_t i = 0; i < shared_range.size(); i++) {
shared_elements *= shared_range[i]->extent;
}
PrimExpr global_elements = 1;
for (size_t i = 0; i < global_range.size(); i++) {
global_elements *= global_range[i]->extent;
}
bool element_match =
analyzer->CanProveEqual(shared_elements, global_elements);
bool no_oob = true;
for (size_t i = 0; i < shared_range.size(); i++) {
if (!analyzer->CanProve(shared_range[i]->min + shared_range[i]->extent <=
shared_tensor_before_remap->shape[i])) {
no_oob = false;
break;
}
}
for (size_t i = 0; i < global_range.size(); i++) {
if (!analyzer->CanProve(global_range[i]->min + global_range[i]->extent <=
global_tensor->shape[i])) {
no_oob = false;
break;
}
}
// Add 1D TMA copy
// LOG(INFO) << "shared_is_contiguous: " << shared_is_contiguous;
// LOG(INFO) << "global_is_contiguous: " << global_is_contiguous;
// LOG(INFO) << "element_match: " << element_match;
// LOG(INFO) << "no_oob: " << no_oob;
if (shared_is_contiguous && global_is_contiguous && element_match &&
no_oob) {
// LOG(INFO) << "TMA 1D bulk copy is supported";
PrimExpr elements = analyzer->Simplify(shared_elements);
PrimExpr shared_addr = shared_tensor_before_remap.access_ptr(
is_load ? 2 : 1, DataType::Handle(), 1, offset, elements);
PrimExpr global_addr = global_tensor.access_ptr(
is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements);
Stmt tma_copy;
if (is_load) {
// the zero is a placeholder for mbarrier id
tma_copy =
Evaluate(Call(DataType::Handle(), tma_load(),
{shared_addr, global_addr, 0,
elements * shared_tensor_before_remap->dtype.bytes(),
this->eviction_policy}));
} else {
tma_copy =
Evaluate(Call(DataType::Handle(), tma_store(),
{global_addr, shared_addr,
elements * shared_tensor_before_remap->dtype.bytes(),
this->eviction_policy}));
}
tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy);
return tma_copy;
}
}

TMADesc desc;
// Verify copy rank
desc.rank = global_tensor->shape.size();
Expand Down Expand Up @@ -1221,10 +1340,11 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {

// Register the Copy operation with TVM's TIR system
// This makes the copy operation available for use in TVM programs
// - Takes 4 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
// eviction_policy
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Copy, copy)
.set_num_inputs(4)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

Expand Down
1 change: 1 addition & 0 deletions src/op/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
Array<Var> buffer_var_gemm;
};

struct LayoutInferArgs {
Expand Down
10 changes: 10 additions & 0 deletions src/tl_templates/cuda/copy_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar,
: "memory");
}

// Bulk asynchronous copy of `size` bytes from shared memory (`smem_ptr`) to
// global memory (`gmem_ptr`), issued without a tensor-map descriptor.
//
// `cache_hint` is passed to the instruction as its L2 cache-policy operand.
// The copy is attached to a bulk async-group (`.bulk_group`); this function
// only issues it and does not commit or wait for that group.
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
  // The instruction addresses shared memory through a 32-bit shared-space
  // address, hence the conversion from the generic pointer.
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  // The "memory" clobber stops the compiler from moving or caching ordinary
  // loads/stores across this statement: the asm reads shared memory and
  // writes global memory, neither of which is visible through its operands.
  asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
               ".L2::cache_hint [%0], [%1], %2, %3;"
               :
               : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
               : "memory");
}
Comment on lines +174 to +182
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Add memory clobber and confirm cp.async.bulk syntax (consider adding ::bytes).

  • The inline asm for the new size-based tma_store lacks a "memory" clobber. Given it performs a GMEM write, compilers may reorder surrounding memory ops. The descriptor-based tma_store overloads use "memory"; mirror that here.
  • Also, please double-check the PTX mnemonic. The non-descriptor tma_load variant explicitly includes "::bytes" in the opcode. For symmetry and to match Hopper docs, the non-descriptor store path typically uses the "::bytes" variant as well. If required, add it to prevent PTX parsing issues.

Suggested diff:

 TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
-               ".L2::cache_hint [%0], [%1], %2, %3;"
-               :
-               : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
-               :);
+  asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
+               "::bytes.L2::cache_hint [%0], [%1], %2, %3;"
+               :
+               : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
+               : "memory");
 }

If your PTX toolchain already accepts the current mnemonic, at minimum add the "memory" clobber. I can help run a quick compile check to validate the opcode variant you intend to use.


🏁 Script executed:

#!/bin/bash
set -e

# Locate the copy_sm90.h file in the repository
file=$(fd copy_sm90.h | head -n1)
echo "Processing file: $file"

# 1. List all tma_store overloads in this file
echo -e "\n=== tma_store overloads ==="
rg -nP '\btemplate.*tma_store' -n "$file"

# 2. Show all inline asm usages of cp.async in this file
echo -e "\n=== cp.async usages ==="
rg -n 'cp.async' -n "$file"

# 3. Check for any descriptor-based tma_store overload (e.g., taking a TmaDescriptor)
echo -e "\n=== Descriptor-based tma_store (TmaDescriptor) ==="
rg -n 'TmaDescriptor' -n .

# 4. For each tma_store overload, inspect whether the asm has a "memory" clobber
echo -e "\n=== Check for memory clobber in tma_store asm ==="
rg -nP 'asm volatile\("cp.async.*\:.*\:.*"[^"]*\)\s*;\s*:' -n "$file" -A2 | grep -C2 clobber || echo "No explicit memory clobber found in above asm blocks"

Length of output: 331


🏁 Script executed:

#!/bin/bash
set -e

# File path
file="src/tl_templates/cuda/copy_sm90.h"
echo "1) tma_store occurrences in $file"
grep -n "tma_store" "$file" || echo "No tma_store found"

echo -e "\n2) TmaDescriptor usage in repository"
grep -R -n "TmaDescriptor" . || echo "No TmaDescriptor references"

echo -e "\n3) cp.async usages in $file with context"
grep -n -C2 "cp.async" "$file" || echo "No cp.async found"

echo -e "\n4) Snippet around lines 160-182 in $file"
sed -n '160,182p' "$file"

Length of output: 6357


🏁 Script executed:

#!/bin/bash
set -e

# Display the first descriptor-based tma_store overload to verify clobber and opcode syntax
echo "=== Descriptor-based tma_store (1d) snippet ==="
sed -n '185,200p' src/tl_templates/cuda/copy_sm90.h

Length of output: 938


Add missing “memory” clobber (and optionally ::bytes) to the size‐based tma_store

The size-based store at src/tl_templates/cuda/copy_sm90.h (around lines 175–182) currently omits a clobber list and does not annotate the opcode with ::bytes. All descriptor-based tma_store overloads include a "memory" clobber to prevent the compiler from reordering surrounding memory operations. At minimum, mirror their clobber list here. If you also want symmetry with the load path and Hopper documentation, insert ::bytes between bulk_group and .L2.

Suggested diff:

 template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
 TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
-  asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
-               ".L2::cache_hint [%0], [%1], %2, %3;"
-               :
-               : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
-               :);
+  asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
+               "::bytes.L2::cache_hint [%0], [%1], %2, %3;"
+               :
+               : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
+               : "memory");
 }

"memory" clobber: required to ensure correct ordering of adjacent loads/stores.
::bytes: optional if your PTX assembler already accepts the unannotated mnemonic; recommended for consistency with the load path.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
".L2::cache_hint [%0], [%1], %2, %3;"
:
: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
:);
}
template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
"::bytes.L2::cache_hint [%0], [%1], %2, %3;"
:
: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
: "memory");
}
🤖 Prompt for AI Agents
In src/tl_templates/cuda/copy_sm90.h around lines 174–182, the size-based
tma_store inline asm omits the required "memory" clobber (and lacks the optional
::bytes opcode annotation), which can allow the compiler to reorder surrounding
memory operations; update the asm template to include "memory" in the clobber
list (matching the descriptor-based overloads) and, for symmetry with the load
path and Hopper docs, insert ::bytes between bulk_group and .L2 in the opcode
string if your assembler accepts it.


template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_store(const CUtensorMap &descriptor,
void const *const smem_ptr, int32_t const &crd0) {
Expand Down
33 changes: 26 additions & 7 deletions src/transform/inject_tma_barrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,17 @@ class TmaTraitsCollector : public StmtExprVisitor {
private:
void VisitExpr_(const CallNode *call) final {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
auto arg0 = call->args[0].as<Call>();
if (call->op.same_as(tma_load()) && arg0 &&
!arg0.value()->op.same_as(create_tma_descriptor())) {
// 1D TMA load has tvm_access_ptr of shared tensor in its args[0]
bulk_copy_bytes = call->args[3] * loop_extents;
} else {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
}
Comment on lines +65 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Bug: bulk_copy_bytes overwritten for 1D loads; should accumulate (+=), not assign (=).

For multiple 1D tma_load ops inside the same block/then-case, using '=' discards the bytes already accounted for, so mbarrier_expect_tx is issued with a byte count smaller than what the loads actually transfer. Accumulate instead, matching the 2D path.

Apply this fix:

-        // 1D TMA load has tvm_access_ptr of shared tensor in its args[0]
-        bulk_copy_bytes = call->args[3] * loop_extents;
+        // 1D TMA load has tvm_access_ptr of shared tensor in its args[0]
+        bulk_copy_bytes += call->args[3] * loop_extents;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
auto arg0 = call->args[0].as<Call>();
if (call->op.same_as(tma_load()) && arg0 &&
!arg0.value()->op.same_as(create_tma_descriptor())) {
// 1D TMA load has tvm_access_ptr of shared tensor in its args[0]
bulk_copy_bytes = call->args[3] * loop_extents;
} else {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
}
auto arg0 = call->args[0].as<Call>();
if (call->op.same_as(tma_load()) && arg0 &&
!arg0.value()->op.same_as(create_tma_descriptor())) {
// 1D TMA load has tvm_access_ptr of shared tensor in its args[0]
bulk_copy_bytes += call->args[3] * loop_extents;
} else {
Call access_ptr = Downcast<Call>(call->args[2]);
ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr()));
int type_bytes = access_ptr->args[0]->dtype.bytes();
bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes;
}
🤖 Prompt for AI Agents
In src/transform/inject_tma_barrier.cc around lines 65 to 75, the 1D tma_load
branch uses assignment to set bulk_copy_bytes which overwrites prior counts;
change the logic to accumulate into bulk_copy_bytes (use += semantics) like the
2D path so multiple 1D loads add their byte counts instead of replacing them,
ensuring you multiply by loop_extents as before and keep types/expressions
identical except using addition.

}
StmtExprVisitor::VisitExpr_(call);
}
Expand Down Expand Up @@ -155,10 +162,15 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer {

PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
auto arg0 = op->args[0].as<Call>();
bool is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
op->op.same_as(tma_load());
visited_tma_load_ = true;
Array<PrimExpr> new_args = op->args;
new_args.Set(1, Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32), 0)}));
new_args.Set(is_1d_tma_load ? 2 : 1,
Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32), 0)}));
return Call(op->dtype, op->op, new_args);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
Expand Down Expand Up @@ -443,7 +455,14 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
<< "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto new_args = op->args;
new_args.Set(1, barrier_id);
auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
if (is_1d_tma_load) {
new_args.Set(2, barrier_id);
} else {
new_args.Set(1, barrier_id);
}
Comment on lines +458 to +465
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic to detect a 1D TMA load is duplicated across multiple places in this file (e.g., TmaExpectTxRewriter::VisitExpr_) and also in warp_specialized_rewriter.cc. Consider creating a helper function to encapsulate this check, for example:

// In a common header or utility file
inline bool Is1DTmaLoad(const CallNode* call) {
  if (!call->op.same_as(tma_load())) return false;
  auto arg0 = call->args[0].as<Call>();
  return arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
}

This would improve code reuse and make the logic easier to maintain.

return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
Expand Down
Loading
Loading