-
Notifications
You must be signed in to change notification settings - Fork 331
[Feature] Add 1D TMA support #761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
8af8b7e
31c66b8
3ef6139
4eff7bd
87a1de6
d4df7c7
a924ccf
50df72b
616e8e9
f611c9f
3c187bc
fa4fad3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import argparse | ||
| import tilelang | ||
| import tilelang.language as T | ||
| import torch | ||
|
|
||
| tilelang.disable_cache() | ||
|
|
||
|
|
||
| def ref_program(x, y): | ||
| return x + y | ||
|
|
||
|
|
||
| @tilelang.jit(out_idx=[-1]) | ||
| def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): | ||
|
|
||
| @T.prim_func | ||
| def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor( | ||
| (M, N), out_dtype)): | ||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): | ||
| A_shared = T.alloc_shared((block_M, block_N), in_dtype) | ||
| B_shared = T.alloc_shared((block_M, block_N), in_dtype) | ||
| C_local = T.alloc_fragment((block_M, block_N), out_dtype) | ||
| C_shared = T.alloc_shared((block_M, block_N), out_dtype) | ||
|
|
||
| T.copy(A[by * block_M, bx * block_N], A_shared) | ||
| T.copy(B[by * block_M, bx * block_N], B_shared) | ||
| for (local_y, local_x) in T.Parallel(block_M, block_N): | ||
| C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] | ||
| T.copy(C_local, C_shared) | ||
| T.copy(C_shared, C[by * block_M, bx * block_N]) | ||
|
|
||
| return elem_add | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument("--m", type=int, default=128) | ||
| parser.add_argument("--n", type=int, default=128) | ||
| args, _ = parser.parse_known_args() | ||
| M, N = args.m, args.n | ||
|
|
||
| a = torch.randn(M, N, dtype=torch.float32, device="cuda") | ||
| b = torch.randn(M, N, dtype=torch.float32, device="cuda") | ||
|
|
||
| # Default config | ||
| config = {"block_M": 128, "block_N": 128, "threads": 128} | ||
| kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") | ||
|
|
||
| out = kernel(a, b) | ||
| torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) | ||
| print("All passed!") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,15 @@ | ||
| import tilelang.testing | ||
| import example_elementwise_add | ||
| import example_elementwise_add_tma_1d | ||
|
|
||
|
|
||
| def test_example_elementwise_add(): | ||
| example_elementwise_add.main() | ||
|
|
||
|
|
||
| def test_example_elementwise_add_tma_1d(): | ||
| example_elementwise_add_tma_1d.main() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| tilelang.testing.main() |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -772,19 +772,138 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stride *= s; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Array<PrimExpr> global_indices; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (auto r : global_range) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_indices.push_back(r->min); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<PrimExpr> global_strides; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr global_stride = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < global_tensor->shape.size(); i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto s = global_tensor->shape[global_tensor->shape.size() - i - 1]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_strides.insert(global_strides.begin(), global_stride); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_stride *= s; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(strides.size() == indices.size()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << "strides.size() != indices.size()" << strides.size() << " " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| << indices.size(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr offset = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < indices.size(); i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| offset += indices[i] * strides[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr global_offset = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < global_indices.size(); i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_offset += global_indices[i] * global_strides[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto shared_tensor_before_remap = shared_tensor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Layout shared_layout; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (T.layout_map.count(shared_tensor)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_layout = T.layout_map[shared_tensor]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_tensor = T.buffer_remap[shared_tensor]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Add 1D TMA copy when the global and shared memory is contiguous | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Check if shared_tensor->name is present in T.buffer_var_gemm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool shared_is_contiguous = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (const auto &v : T.buffer_var_gemm) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (v->name_hint == shared_tensor->name) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_is_contiguous = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool shared_not_full_dim_encounter = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+805
to
+816
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid name-based GEMM detection; use Var identity and skip 1D path when shared has a layout map
Apply: - // Add 1D TMA copy when the global and shared memory is contiguous
- {
+ // Add 1D TMA copy when the global and shared memory is contiguous
+ // Only when shared has no remapped/swizzled layout.
+ if (!T.layout_map.count(shared_tensor_before_remap)) {
// Check if shared_tensor->name is present in T.buffer_var_gemm
// (Array<PrimExpr>) to avoid use 1D TMA copy for swizzled layout
- bool shared_is_contiguous = true;
- for (const auto &v : T.buffer_var_gemm) {
- if (v->name_hint == shared_tensor->name) {
- shared_is_contiguous = false;
- break;
- }
- }
+ bool shared_is_contiguous = true;
+ for (const auto &v : T.buffer_var_gemm) {
+ if (v.same_as(shared_tensor_before_remap->data)) {
+ shared_is_contiguous = false;
+ break;
+ }
+ }📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (ssize_t i = shared_range.size() - 1; i >= 0; --i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!shared_not_full_dim_encounter) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!analyzer->CanProve(shared_range[i]->extent == | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_tensor_before_remap->shape[i] && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_range[i]->min == 0)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_not_full_dim_encounter = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!analyzer->CanProve(shared_range[i]->extent == 1)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_is_contiguous = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Currently we check the empty stride of global tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool global_is_contiguous = !global_tensor->strides.empty(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool global_not_full_dim_encounter = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (ssize_t i = global_range.size() - 1; i >= 0; --i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!global_not_full_dim_encounter) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!analyzer->CanProve(global_range[i]->extent == | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_tensor->shape[i] && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_range[i]->min == 0)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_not_full_dim_encounter = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!analyzer->CanProve(global_range[i]->extent == 1)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_is_contiguous = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+832
to
+847
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. global_is_contiguous initialization bug
Apply: - // Currently we check the empty stride of global tensor
- bool global_is_contiguous = !global_tensor->strides.empty();
+ // Start optimistic; invalidate below if pattern breaks contiguity
+ bool global_is_contiguous = true;Optionally, add an extra guard to verify explicit strides match compact layout when 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Ensure there is element match and no OOB | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr shared_elements = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < shared_range.size(); i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_elements *= shared_range[i]->extent; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr global_elements = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < global_range.size(); i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_elements *= global_range[i]->extent; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool element_match = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| analyzer->CanProveEqual(shared_elements, global_elements); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bool no_oob = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < shared_range.size(); i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!analyzer->CanProve(shared_range[i]->min + shared_range[i]->extent <= | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shared_tensor_before_remap->shape[i])) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| no_oob = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (size_t i = 0; i < global_range.size(); i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!analyzer->CanProve(global_range[i]->min + global_range[i]->extent <= | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global_tensor->shape[i])) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| no_oob = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Add 1D TMA copy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // LOG(INFO) << "shared_is_contiguous: " << shared_is_contiguous; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // LOG(INFO) << "global_is_contiguous: " << global_is_contiguous; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // LOG(INFO) << "element_match: " << element_match; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // LOG(INFO) << "no_oob: " << no_oob; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (shared_is_contiguous && global_is_contiguous && element_match && | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| no_oob) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // LOG(INFO) << "TMA 1D bulk copy is supported"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr elements = analyzer->Simplify(shared_elements); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr shared_addr = shared_tensor_before_remap.access_ptr( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_load ? 2 : 1, DataType::Handle(), 1, offset, elements); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr global_addr = global_tensor.access_ptr( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_load ? 1 : 2, DataType::Handle(), 1, global_offset, elements); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Stmt tma_copy; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (is_load) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // the zero is a placeholder for mbarrier id | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tma_copy = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Evaluate(Call(DataType::Handle(), tma_load(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {shared_addr, global_addr, 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elements * shared_tensor_before_remap->dtype.bytes(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| this->eviction_policy})); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tma_copy = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Evaluate(Call(DataType::Handle(), tma_store(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| {global_addr, shared_addr, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elements * shared_tensor_before_remap->dtype.bytes(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| this->eviction_policy})); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return tma_copy; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TMADesc desc; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Verify copy rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| desc.rank = global_tensor->shape.size(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1221,10 +1340,11 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Register the Copy operation with TVM's TIR system | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // This makes the copy operation available for use in TVM programs | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // - Takes 4 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // eviction_policy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // - Marked as opaque since it has side effects (memory writes) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TIR_REGISTER_TL_OP(Copy, copy) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .set_num_inputs(4) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .set_num_inputs(5) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .set_attr<TCallEffectKind>("TCallEffectKind", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Integer(CallEffectKind::kOpaque)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -171,6 +171,16 @@ tma_load_im2col(const CUtensorMap &descriptor, BarrierType &smem_mbar, | |||||||||||||||||||||||||||||||||||||
| : "memory"); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL> | ||||||||||||||||||||||||||||||||||||||
| TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) { | ||||||||||||||||||||||||||||||||||||||
| uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr); | ||||||||||||||||||||||||||||||||||||||
| asm volatile("cp.async.bulk.global.shared::cta.bulk_group" | ||||||||||||||||||||||||||||||||||||||
| ".L2::cache_hint [%0], [%1], %2, %3;" | ||||||||||||||||||||||||||||||||||||||
| : | ||||||||||||||||||||||||||||||||||||||
| : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint) | ||||||||||||||||||||||||||||||||||||||
| :); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+174
to
+182
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainAdd memory clobber and confirm cp.async.bulk syntax (consider adding ::bytes).
Suggested diff: TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
- ".L2::cache_hint [%0], [%1], %2, %3;"
- :
- : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
- :);
+ asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
+ "::bytes.L2::cache_hint [%0], [%1], %2, %3;"
+ :
+ : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
+ : "memory");
}If your PTX toolchain already accepts the current mnemonic, at minimum add the "memory" clobber. I can help run a quick compile check to validate the opcode variant you intend to use. 🏁 Script executed: #!/bin/bash
set -e
# Locate the copy_sm90.h file in the repository
file=$(fd copy_sm90.h | head -n1)
echo "Processing file: $file"
# 1. List all tma_store overloads in this file
echo -e "\n=== tma_store overloads ==="
rg -nP '\btemplate.*tma_store' -n "$file"
# 2. Show all inline asm usages of cp.async in this file
echo -e "\n=== cp.async usages ==="
rg -n 'cp.async' -n "$file"
# 3. Check for any descriptor-based tma_store overload (e.g., taking a TmaDescriptor)
echo -e "\n=== Descriptor-based tma_store (TmaDescriptor) ==="
rg -n 'TmaDescriptor' -n .
# 4. For each tma_store overload, inspect whether the asm has a "memory" clobber
echo -e "\n=== Check for memory clobber in tma_store asm ==="
rg -nP 'asm volatile\("cp.async.*\:.*\:.*"[^"]*\)\s*;\s*:' -n "$file" -A2 | grep -C2 clobber || echo "No explicit memory clobber found in above asm blocks"Length of output: 331 🏁 Script executed: #!/bin/bash
set -e
# File path
file="src/tl_templates/cuda/copy_sm90.h"
echo "1) tma_store occurrences in $file"
grep -n "tma_store" "$file" || echo "No tma_store found"
echo -e "\n2) TmaDescriptor usage in repository"
grep -R -n "TmaDescriptor" . || echo "No TmaDescriptor references"
echo -e "\n3) cp.async usages in $file with context"
grep -n -C2 "cp.async" "$file" || echo "No cp.async found"
echo -e "\n4) Snippet around lines 160-182 in $file"
sed -n '160,182p' "$file"Length of output: 6357 🏁 Script executed: #!/bin/bash
set -e
# Display the first descriptor-based tma_store overload to verify clobber and opcode syntax
echo "=== Descriptor-based tma_store (1d) snippet ==="
sed -n '185,200p' src/tl_templates/cuda/copy_sm90.hLength of output: 938 Add missing “memory” clobber (and optionally The size-based store at src/tl_templates/cuda/copy_sm90.h (around lines 175–182) currently omits a clobber list and does not annotate the opcode with Suggested diff: template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
TL_DEVICE void tma_store(void *gmem_ptr, void *smem_ptr, uint32_t size) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
- asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
- ".L2::cache_hint [%0], [%1], %2, %3;"
- :
- : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
- :);
+ asm volatile("cp.async.bulk.global.shared::cta.bulk_group"
+ "::bytes.L2::cache_hint [%0], [%1], %2, %3;"
+ :
+ : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(size), "l"(cache_hint)
+ : "memory");
}• 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL> | ||||||||||||||||||||||||||||||||||||||
| TL_DEVICE void tma_store(const CUtensorMap &descriptor, | ||||||||||||||||||||||||||||||||||||||
| void const *const smem_ptr, int32_t const &crd0) { | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -62,10 +62,17 @@ class TmaTraitsCollector : public StmtExprVisitor { | |||||||||||||||||||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||||||||||||||||||||
| void VisitExpr_(const CallNode *call) final { | ||||||||||||||||||||||||||||||||||||||||||||||
| if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { | ||||||||||||||||||||||||||||||||||||||||||||||
| Call access_ptr = Downcast<Call>(call->args[2]); | ||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); | ||||||||||||||||||||||||||||||||||||||||||||||
| int type_bytes = access_ptr->args[0]->dtype.bytes(); | ||||||||||||||||||||||||||||||||||||||||||||||
| bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes; | ||||||||||||||||||||||||||||||||||||||||||||||
| auto arg0 = call->args[0].as<Call>(); | ||||||||||||||||||||||||||||||||||||||||||||||
| if (call->op.same_as(tma_load()) && arg0 && | ||||||||||||||||||||||||||||||||||||||||||||||
| !arg0.value()->op.same_as(create_tma_descriptor())) { | ||||||||||||||||||||||||||||||||||||||||||||||
| // 1D TMA load has tvm_access_ptr of shared tensor in its args[0] | ||||||||||||||||||||||||||||||||||||||||||||||
| bulk_copy_bytes = call->args[3] * loop_extents; | ||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||
| Call access_ptr = Downcast<Call>(call->args[2]); | ||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(access_ptr->op.same_as(builtin::tvm_access_ptr())); | ||||||||||||||||||||||||||||||||||||||||||||||
| int type_bytes = access_ptr->args[0]->dtype.bytes(); | ||||||||||||||||||||||||||||||||||||||||||||||
| bulk_copy_bytes += access_ptr->args[3] * loop_extents * type_bytes; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+65
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: bulk_copy_bytes overwritten for 1D loads; should accumulate (+=), not assign (=). For multiple 1D tma_load ops inside the same block/then-case, using '=' drops previously accounted bytes, leading to under-issuing mbarrier_expect_tx. Match the 2D path’s accumulation. Apply this fix: - // 1D TMA load has tvm_access_ptr of shared tensor in its args[0]
- bulk_copy_bytes = call->args[3] * loop_extents;
+ // 1D TMA load has tvm_access_ptr of shared tensor in its args[0]
+ bulk_copy_bytes += call->args[3] * loop_extents;📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| StmtExprVisitor::VisitExpr_(call); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -155,10 +162,15 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer { | |||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr VisitExpr_(const CallNode *op) { | ||||||||||||||||||||||||||||||||||||||||||||||
| if (op->op.same_as(tma_load())) { | ||||||||||||||||||||||||||||||||||||||||||||||
| auto arg0 = op->args[0].as<Call>(); | ||||||||||||||||||||||||||||||||||||||||||||||
| bool is_1d_tma_load = | ||||||||||||||||||||||||||||||||||||||||||||||
| arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) && | ||||||||||||||||||||||||||||||||||||||||||||||
| op->op.same_as(tma_load()); | ||||||||||||||||||||||||||||||||||||||||||||||
| visited_tma_load_ = true; | ||||||||||||||||||||||||||||||||||||||||||||||
| Array<PrimExpr> new_args = op->args; | ||||||||||||||||||||||||||||||||||||||||||||||
| new_args.Set(1, Call(DataType::Handle(), get_mbarrier(), | ||||||||||||||||||||||||||||||||||||||||||||||
| {IntImm(DataType::Int(32), 0)})); | ||||||||||||||||||||||||||||||||||||||||||||||
| new_args.Set(is_1d_tma_load ? 2 : 1, | ||||||||||||||||||||||||||||||||||||||||||||||
| Call(DataType::Handle(), get_mbarrier(), | ||||||||||||||||||||||||||||||||||||||||||||||
| {IntImm(DataType::Int(32), 0)})); | ||||||||||||||||||||||||||||||||||||||||||||||
| return Call(op->dtype, op->op, new_args); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| return IRMutatorWithAnalyzer::VisitExpr_(op); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -443,7 +455,14 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { | |||||||||||||||||||||||||||||||||||||||||||||
| << "tma_load must be in the tma_op_to_barrier_id_"; | ||||||||||||||||||||||||||||||||||||||||||||||
| auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)]; | ||||||||||||||||||||||||||||||||||||||||||||||
| auto new_args = op->args; | ||||||||||||||||||||||||||||||||||||||||||||||
| new_args.Set(1, barrier_id); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto arg0 = op->args[0].as<Call>(); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto is_1d_tma_load = | ||||||||||||||||||||||||||||||||||||||||||||||
| arg0 && !arg0.value()->op.same_as(create_tma_descriptor()); | ||||||||||||||||||||||||||||||||||||||||||||||
| if (is_1d_tma_load) { | ||||||||||||||||||||||||||||||||||||||||||||||
| new_args.Set(2, barrier_id); | ||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||
| new_args.Set(1, barrier_id); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+458
to
+465
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic to detect a 1D TMA load is duplicated across multiple places in this file (e.g., // In a common header or utility file
inline bool Is1DTmaLoad(const CallNode* call) {
if (!call->op.same_as(tma_load())) return false;
auto arg0 = call->args[0].as<Call>();
return arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
}This would improve code reuse and make the logic easier to maintain. |
||||||||||||||||||||||||||||||||||||||||||||||
| return Call(op->dtype, op->op, new_args); | ||||||||||||||||||||||||||||||||||||||||||||||
| } else if (op->op.same_as(mbarrier_expect_tx())) { | ||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op))) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Don’t assume compact layout for address math; use buffer strides when available.
The 1D path computes
global_strides(and earlier, sharedstrides) from shapes, implicitly assuming a compact/row-major layout. IfBuffer.stridesis present (non-compact tensors, slices, or external buffers), the current offset calculations forglobal_offset/offsetwill be wrong.global_tensor->strideswhen available; otherwise, fall back to the current shape-derived logic.shared_tensor_before_remapas well to future-proof shared slices with explicit strides.Apply this diff within the changed hunk to honor
global_tensor->strides:Additionally (outside the changed hunk), consider the analogous update for the shared side:
If you’d like, I can draft the corresponding update for the shared strides block and thread through unit tests that cover non-compact cases (e.g., strided views and sliced buffers).
Also applies to: 794-799
🤖 Prompt for AI Agents