Merged
Changes from 15 commits
Commits
21 commits
fd51304
Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb462…
LeiWang1999 Aug 12, 2025
c6cf3de
Merge branch 'main' of https://github.com/tile-ai/tilelang
LeiWang1999 Aug 12, 2025
9ae1f7f
Merge branch 'main' of https://github.com/tile-ai/tilelang
LeiWang1999 Aug 14, 2025
c1a456d
Merge branch 'main' of https://github.com/tile-ai/tilelang
LeiWang1999 Aug 14, 2025
b68882b
Merge branch 'main' of https://github.com/tile-ai/tilelang
LeiWang1999 Aug 14, 2025
1f811d5
Merge branch 'main' of https://github.com/tile-ai/tilelang
LeiWang1999 Aug 15, 2025
de8c4a5
Support strided tensors
LyricZhao Aug 15, 2025
1599ff1
Refactor target attribute helper functions for improved clarity
LeiWang1999 Aug 15, 2025
45f3be6
No code changes made in proxy.py and setup.py
LeiWang1999 Aug 15, 2025
f3a92a0
lint fix
LeiWang1999 Aug 15, 2025
c8e1a1b
lint fix via gemini
LeiWang1999 Aug 15, 2025
0cfe1f2
lint fix
LeiWang1999 Aug 15, 2025
6ed0611
test fix
LeiWang1999 Aug 15, 2025
056c6a3
test fix
LeiWang1999 Aug 16, 2025
a12dfad
lint fix
LeiWang1999 Aug 16, 2025
b91279e
Update wrapper.py
LeiWang1999 Aug 17, 2025
48d9a8f
test fix
LeiWang1999 Aug 17, 2025
ebda917
Merge branch 'main' of https://github.com/tile-ai/tilelang into strid…
LeiWang1999 Aug 17, 2025
c5b1a10
Merge branch 'strided_0815' of https://github.com/LeiWang1999/tilelan…
LeiWang1999 Aug 17, 2025
0a9d50f
Enhance test for InjectSoftwarePipeline by adding LowerOpaqueBlock tr…
LeiWang1999 Aug 17, 2025
9f26dbf
lint fix
LeiWang1999 Aug 17, 2025
2 changes: 1 addition & 1 deletion benchmark/matmul/benchmark_matmul.py
Expand Up @@ -53,7 +53,7 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch

arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
Contributor


💡 Verification agent

🧩 Analysis chain

Inverted arch selection is consistent; harden HIP detection and verify repo-wide consistency

The swapped mapping (CUDA when torch.version.hip is None, CDNA otherwise) matches the other files in this PR. To avoid an AttributeError on torch builds that lack torch.version.hip, use getattr:

-        arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
+        arch = CUDA("cuda") if getattr(torch.version, "hip", None) is None else CDNA("hip")

Additionally, to ensure there are no missed spots still using the old mapping, you can scan the repo:

If any legacy patterns appear in contexts that should be updated, align them with this PR’s policy. Centralizing selection via a helper will also simplify future changes.


🏁 Script executed:

#!/bin/bash
# Find all inverted/new mapping occurrences
echo "Occurrences of new mapping (CUDA on non-HIP, CDNA on HIP):"
rg -nP '\bCUDA\("cuda"\)\s*if\s*getattr\(torch\.version,\s*"hip",\s*None\)\s*is\s*None\s*else\s*CDNA\("hip"\)' -C2 || true
rg -nP '\bCUDA\("cuda"\)\s*if\s*torch\.version\.hip\s*is\s*None\s*else\s*CDNA\("hip"\)' -C2 || true

# Find any remaining old mapping instances
echo
echo "Occurrences of old mapping (CDNA on non-HIP, CUDA on HIP):"
rg -nP '\bCDNA\("cuda"\)\s*if\s*torch\.version\.hip\s*is\s*None\s*else\s*CUDA\("hip"\)' -C2 || true

Length of output: 2545


Harden HIP detection by using getattr across all occurrences

To avoid AttributeError on torch builds that lack torch.version.hip, please replace raw attribute access with a safe getattr check in every spot where we currently do:

arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")

Please update the following locations:

• examples/gemm/example_gemm_autotune.py:19
• examples/analyze/example_conv_analyze.py:99
• examples/analyze/example_gemm_analyze.py:52
• benchmark/matmul/benchmark_matmul.py:56
• benchmark/matmul/benchmark_matmul_intrinsic.py:190

Example diff:

-        arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
+        arch = CUDA("cuda") if getattr(torch.version, "hip", None) is None else CDNA("hip")

For future maintainability, consider centralizing this logic into a small helper (e.g. def select_device(): ...) so any further changes only touch one place.

🤖 Prompt for AI Agents
In benchmark/matmul/benchmark_matmul.py around line 56 (and also update
examples/gemm/example_gemm_autotune.py:19,
examples/analyze/example_conv_analyze.py:99,
examples/analyze/example_gemm_analyze.py:52,
benchmark/matmul/benchmark_matmul_intrinsic.py:190), the code directly accesses
torch.version.hip which can raise AttributeError on some torch builds; replace
the raw attribute access with a safe getattr check (e.g. use
getattr(torch.version, "hip", None) is None to choose CUDA vs CDNA) in each of
those locations, and optionally factor this into a small helper function like
select_device() and call that helper from all these sites for maintainability.
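
The helper suggested in this comment could look roughly like the sketch below. It is not code from the PR: is_hip_build and select_arch_name are hypothetical names, and SimpleNamespace objects stand in for torch.version so the example runs without torch or tilelang installed. In the real call sites the returned name would be passed to CUDA(...) or CDNA(...).

```python
from types import SimpleNamespace


def is_hip_build(version_module) -> bool:
    """Return True when a torch.version-like object reports a HIP (ROCm) build.

    getattr with a default avoids AttributeError when the attribute is absent.
    """
    return getattr(version_module, "hip", None) is not None


def select_arch_name(version_module) -> str:
    """Hypothetical helper: choose the backend name the way the PR's one-liner does."""
    return "hip" if is_hip_build(version_module) else "cuda"


# Stand-ins for torch.version on three kinds of builds (illustrative values only).
cuda_build = SimpleNamespace(cuda="12.4", hip=None)
rocm_build = SimpleNamespace(cuda=None, hip="6.2.0")
no_hip_attr = SimpleNamespace(cuda="12.4")  # the attribute is missing entirely

for build in (cuda_build, rocm_build, no_hip_attr):
    print(select_arch_name(build))
```

Keeping the decision in one function means a later change to the detection rule touches a single place instead of five call sites.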

topk = 10

carve_template = MatmulTemplate(
Expand Down
2 changes: 1 addition & 1 deletion benchmark/matmul/benchmark_matmul_intrinsic.py
Expand Up @@ -187,7 +187,7 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch

arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
topk = 10

carve_template = MatmulTemplate(
Expand Down
2 changes: 1 addition & 1 deletion examples/analyze/example_conv_analyze.py
Expand Up @@ -96,7 +96,7 @@ def conv(

def main():
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)
print(result)
print(f"Analyzed FLOPs: {result.total_flops}")
Expand Down
2 changes: 1 addition & 1 deletion examples/analyze/example_gemm_analyze.py
Expand Up @@ -49,7 +49,7 @@ def matmul(
def main():
my_func = kernel(128, 128, 32, 3, 128, True)

cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)

print(f"Analyzed FLOPs: {result.total_flops}")
Expand Down
2 changes: 0 additions & 2 deletions examples/fusedmoe/example_fusedmoe_tilelang.py
Expand Up @@ -7,8 +7,6 @@
from tilelang.autotuner import *
from example_fusedmoe_torch import *

# tilelang.disable_cache()


@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(d_hidden,
Expand Down
2 changes: 1 addition & 1 deletion examples/gemm/example_gemm_autotune.py
Expand Up @@ -16,7 +16,7 @@ def ref_program(A, B):

def get_configs(M, N, K, with_roller=False, topk=20):
if with_roller:
arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
carve_template = MatmulTemplate(
M=M,
N=N,
Expand Down
34 changes: 5 additions & 29 deletions examples/warp_specialize/example_warp_specialize_flashmla.py
Expand Up @@ -145,20 +145,10 @@ def flash_attn(
clear_accum=True,
wg_wait=-1)
T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
T.gemm(
Q_shared_r,
KV_shared_0_r,
acc_s_0,
transpose_B=True,
wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True, wg_wait=-1)

T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
T.gemm(
Q_pe_local_0,
K_pe_shared_0,
acc_s_0,
transpose_B=True,
wg_wait=-1)
T.gemm(Q_pe_local_0, K_pe_shared_0, acc_s_0, transpose_B=True, wg_wait=-1)

T.wait_wgmma(0)

Expand Down Expand Up @@ -261,20 +251,10 @@ def flash_attn(
wg_wait=-1)

T.barrier_wait(kv_shared_1_r_is_ready, k % 2)
T.gemm(
Q_shared_r,
KV_shared_1_r,
acc_s_1,
transpose_B=True,
wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True, wg_wait=-1)

T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
T.gemm(
Q_pe_local_1,
K_pe_shared_1,
acc_s_1,
transpose_B=True,
wg_wait=-1)
T.gemm(Q_pe_local_1, K_pe_shared_1, acc_s_1, transpose_B=True, wg_wait=-1)

T.wait_wgmma(0)

Expand Down Expand Up @@ -308,11 +288,7 @@ def flash_attn(

# Step 10. compute O1 with KV_shared_1_rd
T.copy(acc_s_1, acc_s_1_cast)
T.gemm(
acc_s_1_cast,
KV_shared_1_r,
acc_o_r,
wg_wait=-1)
T.gemm(acc_s_1_cast, KV_shared_1_r, acc_o_r, wg_wait=-1)
T.copy(acc_s_1_cast, SP1_shared)
T.barrier_arrive(s_shared_ready_barrier)

Expand Down
18 changes: 9 additions & 9 deletions setup.py
@@ -1,3 +1,6 @@
import fcntl
import functools
import hashlib
import io
import subprocess
import shutil
Expand All @@ -12,17 +15,14 @@
import os
import sys
import site
import hashlib
import sysconfig
import functools
import urllib.request
from packaging.version import Version
import platform
import multiprocessing
from setuptools.command.build_ext import build_ext
import importlib
import logging
import fcntl

# Configure logging with basic settings
logging.basicConfig(
Expand Down Expand Up @@ -692,15 +692,15 @@ def build_cython(self, ext):
with open(md5_path, "r") as f:
cached_hash = f.read().strip()
if cached_hash == code_hash:
logger.info("Cython jit adapter is up to date, no need to compile...")
logger.info("Cython JIT adapter is up to date, no need to compile...")
need_compile = False
else:
logger.info("Cython jit adapter is out of date, need to recompile...")
logger.info("Cython JIT adapter is out of date, need to recompile...")
else:
logger.info("No cached version found for cython jit adapter, need to compile...")
logger.info("No cached version found for Cython JIT adapter, need to compile...")

if need_compile:
logger.info("Waiting for lock to compile cython jit adapter...")
logger.info("Waiting for lock to compile Cython JIT adapter...")
with open(lock_file, 'w') as lock:
fcntl.flock(lock.fileno(), fcntl.LOCK_EX)
try:
Expand All @@ -715,7 +715,7 @@ def build_cython(self, ext):
need_compile = False

if need_compile:
logger.info("Compiling cython jit adapter...")
logger.info("Compiling Cython JIT adapter...")
temp_path = cache_dir / f"temp_{code_hash}.so"

with open(md5_path, "w") as f:
Expand All @@ -736,7 +736,7 @@ def build_cython(self, ext):
except Exception as e:
if 'temp_path' in locals() and temp_path.exists():
temp_path.unlink()
raise Exception(f"Failed to compile cython jit adapter: {e}") from e
raise Exception(f"Failed to compile Cython JIT adapter: {e}") from e
finally:
if lock_file.exists():
lock_file.unlink()
Expand Down
70 changes: 70 additions & 0 deletions src/target/codegen_cuda.cc
Expand Up @@ -1689,6 +1689,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
os << "))";
}

void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
std::ostream &os) { // NOLINT(*)
ICHECK_EQ(op->indices.size(), 1)
<< "Load from non-flat memory not supported.";
ICHECK(!op->predicate.defined())
<< "Predicated buffer load is not supported.";

DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;

int lanes = op->dtype.lanes();
// declare type.
if (value_dtype.lanes() == element_dtype.lanes()) {
std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index);
HandleVolatileLoads(ref, op, os);
} else {
bool can_vector_load = false;
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
const RampNode *ramp = index.as<RampNode>();
ICHECK(ramp);
can_vector_load = true;
// arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
// The condition: {k * coeff + base} divisible by the alignment for any k
// if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes()
// == 0) {
// can_vector_load = true;
// }
Comment on lines +1729 to +1734
Contributor


medium

This block of commented-out code seems to be an alternative implementation for checking if a vector load is possible. If this code is no longer needed, it should be removed to improve code clarity. Leaving dead code can be confusing for future maintenance.

Member Author


skip for now for dynamic strides.
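
For readers following the thread, the condition in the commented-out lines can be illustrated with plain integers. This is only a sketch of the arithmetic, not the TVM API: it assumes the ramp base is described as coeff * k + base for an arbitrary integer k (as the comment in the code says), and both function names are hypothetical. With symbolic strides, coeff and base may not be known constants, which is presumably why the check is skipped here.

```python
def ramp_base_is_aligned(coeff: int, base: int, lanes: int) -> bool:
    """Closed-form test from the commented-out code: both parts divisible by lanes."""
    return coeff % lanes == 0 and base % lanes == 0


def brute_force_aligned(coeff: int, base: int, lanes: int, trials: int = 64) -> bool:
    """Check divisibility of coeff * k + base directly for k = 0 .. trials - 1."""
    return all((k * coeff + base) % lanes == 0 for k in range(trials))


# The closed form and the brute-force check agree on every small case tried here.
mismatches = [
    (coeff, base, lanes)
    for lanes in (2, 4, 8)
    for coeff in range(0, 17)
    for base in range(0, 17)
    if ramp_base_is_aligned(coeff, base, lanes) != brute_force_aligned(coeff, base, lanes)
]
print(mismatches)
```

The two agree because k = 0 forces base to be divisible by lanes, and k = 1 then forces coeff to be divisible as well.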

}

if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
// A float4_e2m1fn element has 4 bits, which is an incomplete byte.
// So we cannot vector load it.
can_vector_load = false;
}
if (can_vector_load) {
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
} else {
Comment on lines +1743 to +1745
Contributor


💡 Verification agent

🧩 Analysis chain

Confirm GetVecLoad availability

The vector path relies on GetVecLoad(dtype, buffer, base). Please confirm this overload is declared/defined for CUDA and matches the expected semantics (alignment, address space handling).

Run the following to verify declaration in the header and definitions in the codebase:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Searching for VisitExpr_ BufferLoad declaration in codegen_cuda.h"
rg -n -C2 -P 'class\s+CodeGenTileLangCUDA\b(?s).*?\bVisitExpr_\s*\(\s*const\s+BufferLoadNode\s*\*\s*op\s*,\s*std::ostream\s*&\s*os\)' -- src

echo
echo "Searching for GetVecLoad declarations/definitions"
rg -n -C3 -P '\bGetVecLoad\s*\(' -- src

Length of output: 303


To be sure the GetVecLoad overload exists (and isn’t just a false negative from the previous search), let’s search for every occurrence of GetVecLoad in all codegen files:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Searching for all GetVecLoad occurrences in codegen sources"
rg -n 'GetVecLoad' -g 'src/**/codegen_*.cc' -g 'src/**/codegen_*.h'

Length of output: 326


Let’s search for any GetVecLoad declarations or definitions across the entire src/ tree to ensure it’s not defined in another module:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "Searching for explicit GetVecLoad definitions (e.g., std::string GetVecLoad(...))"
rg -nP 'std::string\s+GetVecLoad\s*\(' -- src

echo
echo "Searching for all GetVecLoad occurrences"
rg -n 'GetVecLoad' -- src

Length of output: 240


Missing GetVecLoad implementation in CUDA codegen
I searched the entire src/ tree and only found the call in codegen_cuda.cc (line 1743). There is no declaration or definition of any GetVecLoad overload for CUDA, so this reference will fail to compile.

• Add a declaration in src/target/codegen_cuda.h (e.g.:
std::string GetVecLoad(DataType dtype, const BufferNode* buffer, Expr base);)
• Provide the corresponding definition in src/target/codegen_cuda.cc, handling alignment and CUDA address spaces as needed.

This must be implemented before merging.

🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 1743 to 1745, the call to
GetVecLoad(...) is unresolved because there is no CUDA-specific
declaration/definition; add a declaration in src/target/codegen_cuda.h with the
signature std::string GetVecLoad(DataType dtype, const BufferNode* buffer, Expr
base); and implement the matching definition in src/target/codegen_cuda.cc:
accept (DataType dtype, const BufferNode* buffer, Expr base), compute lane
width/alignment from dtype and buffer->dtype/strides, generate CUDA-appropriate
load code that respects buffer address spaces (global/shared/const) and
alignment/predication for unaligned or tail lanes, reuse the existing
GetVecStore/load patterns from other backends as a template, include the header
where needed, and run a build to ensure the symbol is resolved.

std::ostringstream svalue_expr;
std::string sindex = SSAGetID(PrintExpr(index), index.dtype());
std::string vid = GetVarID(buffer_var.get());
DataType elem_type = op->dtype.element_of();
for (int i = 0; i < lanes; ++i) {
std::ostringstream value_temp;
if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
value_temp << "((";
if (buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
}
}
PrintType(elem_type, value_temp);
value_temp << "*)" << vid << ')';
Comment on lines +1753 to +1761
Contributor


🛠️ Refactor suggestion

Avoid printing storage scope in C-style casts unless it’s part of the type

In the per-lane fallback, the cast includes the storage scope unconditionally. Elsewhere (e.g., GetBufferRef) this is guarded by IsScopePartOfType() to avoid generating invalid types like (__shared__ float*). Mirror that behavior here to prevent malformed code on targets where storage scope is not encoded in the type.

-        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
-          value_temp << "((";
-          if (buffer_var.get()->dtype.is_handle()) {
-            auto it = alloc_storage_scope_.find(buffer_var.get());
-            if (it != alloc_storage_scope_.end()) {
-              PrintStorageScope(it->second, value_temp);
-            }
-          }
-          PrintType(elem_type, value_temp);
-          value_temp << "*)" << vid << ')';
-        } else {
-          value_temp << vid;
-        }
+        if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
+          value_temp << "((";
+          if (buffer_var.get()->dtype.is_handle()) {
+            auto it = alloc_storage_scope_.find(buffer_var.get());
+            if (it != alloc_storage_scope_.end() && IsScopePartOfType()) {
+              PrintStorageScope(it->second, value_temp);
+            }
+          }
+          PrintType(elem_type, value_temp);
+          value_temp << "*)" << vid << ')';
+        } else {
+          value_temp << vid;
+        }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
value_temp << "((";
if (buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
}
}
PrintType(elem_type, value_temp);
value_temp << "*)" << vid << ')';
if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
value_temp << "((";
if (buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(buffer_var.get());
if (it != alloc_storage_scope_.end() && IsScopePartOfType()) {
PrintStorageScope(it->second, value_temp);
}
}
PrintType(elem_type, value_temp);
value_temp << "*)" << vid << ')';
} else {
value_temp << vid;
}
🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 1753 to 1761, the C-style cast in the
per-lane fallback unconditionally prints the storage scope which can produce
invalid types like '(__shared__ float*)'; change the logic so you only emit the
storage scope when it is actually part of the buffer's type by adding a guard
that calls IsScopePartOfType(buffer_var->dtype) (or equivalent) before invoking
PrintStorageScope on the alloc_storage_scope_ entry; keep the existing
alloc_storage_scope_ lookup and only call PrintStorageScope if the entry exists
AND the scope is part of the type, otherwise skip printing the scope so the
resulting cast is valid.

} else {
value_temp << vid;
}
value_temp << '[';
PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
value_temp << ']';
PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
}
os << svalue_expr.str();
}
}
}

void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
Expand Down
1 change: 1 addition & 0 deletions src/target/codegen_cuda.h
Expand Up @@ -50,6 +50,7 @@ class CodeGenTileLangCUDA final : public CodeGenC {
void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
void VisitExpr_(const BufferLoadNode *op, std::ostream &os) final;

// Override this as a work around for __grid_constant__ parameter
void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
Expand Down
3 changes: 2 additions & 1 deletion src/tl_templates/hip/reduce.h
Expand Up @@ -22,7 +22,8 @@ struct MinOp {
}
};

template <class Reducer, int threads, int scale, int thread_offset = 0> struct AllReduce {
template <class Reducer, int threads, int scale, int thread_offset = 0>
struct AllReduce {
static_assert(threads == 1024 || threads == 512 || threads == 256 ||
threads == 128 || threads == 64 || threads == 32 ||
threads == 16 || threads == 8 || threads == 4 || threads == 2);
Expand Down
34 changes: 28 additions & 6 deletions src/transform/loop_vectorize.cc
Expand Up @@ -136,11 +136,23 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

// Generate strides if not existed
auto strides = buffer->strides;
if (buffer->strides.size() == 0) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}

// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0;
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
elem_offset = elem_offset + indices[i] * stride;
stride = stride * buffer->shape[i];
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
Comment on lines +140 to 156
Contributor


🛠️ Refactor suggestion

⚠️ Potential issue

Fix stride reversal and avoid constructing tvm::Array from iterators (compile-time bug).

Array<PrimExpr>{strides.rbegin(), strides.rend()} constructs an initializer_list from two iterator objects, not a reversed Array. This won’t compile and will break vectorization on strided buffers. Also, you don’t need to materialize/reverse strides when explicit strides are absent; you can compute elem_offset directly.

Refactor to compute elem_offset robustly for both explicit and implicit strides, and drop the invalid reverse:

Apply this diff:

-      // Generate strides if not existed
-      auto strides = buffer->strides;
-      if (buffer->strides.size() == 0) {
-        PrimExpr stride = 1;
-        for (int i = indices.size() - 1; i >= 0; --i) {
-          strides.push_back(stride);
-          stride = stride * buffer->shape[i];
-        }
-        strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
-      }
-
-      // Generate and check element offset expression
-      ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
-      PrimExpr elem_offset = 0;
-      for (int i = 0; i < indices.size(); ++i) {
-        elem_offset += indices[i] * strides[i];
-      }
+      // Compute element offset (supports both explicit and implicit strides)
+      PrimExpr elem_offset = 0;
+      if (buffer->strides.size() == 0) {
+        PrimExpr stride = 1;
+        for (int i = indices.size() - 1; i >= 0; --i) {
+          elem_offset += indices[i] * stride;
+          stride = stride * buffer->shape[i];
+        }
+      } else {
+        ICHECK(indices.size() == buffer->strides.size()) << "Invalid indices and strides";
+        for (int i = 0; i < indices.size(); ++i) {
+          elem_offset += indices[i] * buffer->strides[i];
+        }
+      }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Generate strides if not existed
auto strides = buffer->strides;
if (buffer->strides.size() == 0) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
strides.push_back(stride);
stride = stride * buffer->shape[i];
}
strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
}
// Generate and check element offset expression
ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
PrimExpr elem_offset = 0;
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
elem_offset = elem_offset + indices[i] * stride;
stride = stride * buffer->shape[i];
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}
// Compute element offset (supports both explicit and implicit strides)
PrimExpr elem_offset = 0;
if (buffer->strides.size() == 0) {
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
elem_offset += indices[i] * stride;
stride = stride * buffer->shape[i];
}
} else {
ICHECK(indices.size() == buffer->strides.size()) << "Invalid indices and strides";
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * buffer->strides[i];
}
}
🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 140 to 156, the code attempts to
build reversed strides via Array<PrimExpr>{strides.rbegin(), strides.rend()}
which is invalid and unnecessary; instead handle two cases when computing
elem_offset: if buffer->strides is non-empty use those strides directly,
otherwise compute the running stride on the fly from the buffer->shape (start
with stride=1 and process dimensions from last to first) and accumulate
indices[i]*current_stride into elem_offset without creating a temporary reversed
array. Keep the ICHECK that indices and strides align for the explicit-stride
path, and ensure the implicit-stride path computes elem_offset robustly using
shape-derived strides.
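
The offset computation under discussion can be sketched in plain Python with integers in place of PrimExpr. This mirrors the logic described in the diff and the comment but is not the TVM implementation; row_major_strides and elem_offset are hypothetical names used only for illustration.

```python
from typing import Optional, Sequence


def row_major_strides(shape: Sequence[int]) -> list:
    """Derive contiguous (row-major) strides from a shape, innermost stride = 1."""
    strides = []
    stride = 1
    for extent in reversed(shape):
        strides.append(stride)
        stride *= extent
    strides.reverse()  # built innermost-first, so flip to match dimension order
    return strides


def elem_offset(indices: Sequence[int],
                shape: Sequence[int],
                strides: Optional[Sequence[int]] = None) -> int:
    """Flat element offset; falls back to shape-derived strides when none are given."""
    if not strides:
        strides = row_major_strides(shape)
    if len(indices) != len(strides):
        raise ValueError("Invalid indices and strides")
    return sum(i * s for i, s in zip(indices, strides))


# A 4 x 8 buffer: contiguous layout versus a view into rows that are 16 elements apart.
print(row_major_strides((4, 8)))
print(elem_offset((2, 3), (4, 8)))
print(elem_offset((2, 3), (4, 8), strides=(16, 1)))
```

The third call corresponds to the strided case this PR adds: the logical shape is unchanged, but each row starts 16 elements after the previous one.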

while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
Expand Down Expand Up @@ -229,10 +241,19 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
ICHECK(target_vectorized_size >= 1);
if (target_vectorized_size == 1)
return true;
// bind thread range

// Extent must be divisible
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
0))
return false;

// The base offset must be divisible
if (!analyzer->CanProveEqual(
FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) {
return false;
}

// Bind thread range
Var v0("v0"), v1("v1");
analyzer->Bind(v0, Range(0, target_vectorized_size));
analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
Expand All @@ -241,7 +262,8 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
// This simplify is necessary for thread region specifiled

// This simplify is necessary for thread region specified
// optimizations.
expr_vectorized = analyzer->Simplify(expr_vectorized);
auto ramp_node = expr_vectorized.as<RampNode>();
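
The checks added to IndiceCanVectorize can be mimicked with concrete integers, as in the sketch below. It is a brute-force illustration, not the pass itself: the real code proves these properties symbolically with the arithmetic analyzer, and the final contiguity test here is an assumption about what the ramp check that follows is for. The name can_vectorize is hypothetical.

```python
from typing import Callable


def can_vectorize(index_fn: Callable[[int], int], extent: int, vec: int) -> bool:
    """Brute-force version of the three conditions for a loop of `extent` iterations."""
    if vec == 1:
        return True
    # Extent must be divisible by the vector size.
    if extent % vec != 0:
        return False
    # The base offset (loop variable = 0) must be divisible by the vector size.
    if index_fn(0) % vec != 0:
        return False
    # Each group of `vec` iterations must touch consecutive elements (assumed check).
    for outer in range(extent // vec):
        first = index_fn(outer * vec)
        for lane in range(vec):
            if index_fn(outer * vec + lane) != first + lane:
                return False
    return True


print(can_vectorize(lambda i: i, 128, 4))      # contiguous, aligned base
print(can_vectorize(lambda i: i + 2, 128, 4))  # base offset 2 is not a multiple of 4
print(can_vectorize(lambda i: 2 * i, 128, 4))  # stride 2, lanes are not consecutive
print(can_vectorize(lambda i: i, 6, 4))        # extent 6 is not a multiple of 4
```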
Expand Down
48 changes: 46 additions & 2 deletions testing/python/language/test_tilelang_language_copy.py
Expand Up @@ -28,8 +28,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
out_idx=[1],
target="cuda",
pass_configs={
"tl.disable_warp_specialized": True,
"tl.disable_tma_lower": True
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True
})
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
Expand All @@ -42,5 +42,49 @@ def test_tilelang_copy():
run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float")


def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):

@T.prim_func
def main(
A: T.StridedTensor((M, N), (NN, 1), dtype),
B: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j]

return main


def run_tilelang_copy_with_stride(M=1024,
N=1024,
NN=2048,
block_M=128,
block_N=128,
dtype="float16"):
if isinstance(NN, int):
assert NN > N, "NN must be greater than N"
program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
kernel = tilelang.compile(
program,
out_idx=[1],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
})
if isinstance(NN, T.Var):
NN = N * 2
a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a[:, :N])
torch.testing.assert_close(b, a[:, :N], rtol=1e-2, atol=1e-2)

Comment on lines +66 to +82
Contributor


💡 Verification agent

🧩 Analysis chain

Fix TypeError in isinstance check and allow NN == N

isinstance(NN, T.Var) will raise TypeError because T.Var is a function, not a type. Also, the assertion can safely allow NN == N.

Apply this diff:

-    if isinstance(NN, int):
-        assert NN > N, "NN must be greater than N"
+    if isinstance(NN, int):
+        assert NN >= N, "NN must be greater than or equal to N"
@@
-    if isinstance(NN, T.Var):
-        NN = N * 2
+    # If NN is symbolic (e.g., tvm.tir.Var), substitute a concrete allocation size.
+    if not isinstance(NN, int):
+        NN = N * 2

Optional: If you prefer an explicit type check, import tvm.tir and use isinstance(NN, tvm.tir.PrimExpr).

Run this quick scan to ensure similar mistaken isinstance(..., T.Var) patterns don’t exist elsewhere:


🏁 Script executed:

#!/bin/bash
rg -nP -C2 'isinstance\s*\(\s*[^,]+,\s*T\.Var\s*\)'

Length of output: 2709


🏁 Script executed:

#!/bin/bash
# Display import aliases for T in test_tilelang_language_copy.py
sed -n '1,50p' testing/python/language/test_tilelang_language_copy.py | sed -n '1,50p'

Length of output: 1713


🏁 Script executed:

#!/bin/bash
# Display import aliases for T in the relevant test files
sed -n '1,20p' testing/python/jit/test_tilelang_jit_gemm_ctypes.py
sed -n '1,20p' testing/python/jit/test_tilelang_jit_gemm_cython.py

Length of output: 753


Fix symbolic-dimension checks and allow NN == N

The isinstance(..., T.Var) guard will fail at runtime because T.Var is a constructor function, not a type, and isinstance requires a type (or a tuple of types) as its second argument. Also, an integer stride equal to N is valid: it describes a contiguous tensor. Update all three test files to:

  • Use assert ... >= N instead of > N.
  • Replace isinstance(x, T.Var) with not isinstance(x, int) to catch symbolic cases.

Affected files:

  • testing/python/language/test_tilelang_language_copy.py (lines 66–82)
  • testing/python/jit/test_tilelang_jit_gemm_ctypes.py (lines 370–375)
  • testing/python/jit/test_tilelang_jit_gemm_cython.py (lines 379–384 and 448–453)
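
To illustrate why a row stride equal to N is acceptable, here is a small pure-Python sketch of the flat offsets touched by the view a[:, :N] of a row-major (M, NN) buffer. The helper names `view_offsets` and `is_contiguous` are hypothetical and exist only for this illustration; the sketch assumes a row-major layout with unit column stride, as in the test's `a[:, :N]` slice:

```python
def view_offsets(M, N, NN):
    """Flat offsets read by the view a[:, :N] of a row-major (M, NN) buffer."""
    assert NN >= N, "row stride NN must be at least the view width N"
    return [i * NN + j for i in range(M) for j in range(N)]


def is_contiguous(M, N, NN):
    """True when the view covers the buffer with no gaps between rows."""
    return view_offsets(M, N, NN) == list(range(M * N))
```

With NN > N each row skips NN - N padding elements; with NN == N the padding is zero and the view degenerates to an ordinary contiguous tensor, which is why the assertion can use >= rather than >.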

Apply these diffs:

--- a/testing/python/language/test_tilelang_language_copy.py
+++ b/testing/python/language/test_tilelang_language_copy.py
@@ -66,7 +66,7 @@ def test_tilelang_copy_with_stride():
     if isinstance(NN, int):
-        assert NN > N, "NN must be greater than N"
+        assert NN >= N, "NN must be greater than or equal to N"

@@ -75,7 +75,8 @@ def test_tilelang_copy_with_stride():
-    if isinstance(NN, T.Var):
-        NN = N * 2
+    # For any non-int (e.g., tvm.tir.Var), pick a concrete size
+    if not isinstance(NN, int):
+        NN = N * 2
--- a/testing/python/jit/test_tilelang_jit_gemm_ctypes.py
+++ b/testing/python/jit/test_tilelang_jit_gemm_ctypes.py
@@ -368,12 +368,12 @@ def test_matmul_ctypes():
-    if isinstance(M, T.Var):
-        M = 1024
+    if not isinstance(M, int):
+        M = 1024

-    if isinstance(N, T.Var):
-        N = 1024
+    if not isinstance(N, int):
+        N = 1024

-    if isinstance(K, T.Var):
-        K = 768
+    if not isinstance(K, int):
+        K = 768
--- a/testing/python/jit/test_tilelang_jit_gemm_cython.py
+++ b/testing/python/jit/test_tilelang_jit_gemm_cython.py
@@ -377,12 +377,12 @@ def test_matmul_cython():
-    if isinstance(M, T.Var):
-        M = 1024
+    if not isinstance(M, int):
+        M = 1024

-    if isinstance(N, T.Var):
-        N = 1024
+    if not isinstance(N, int):
+        N = 1024

-    if isinstance(K, T.Var):
-        K = 768
+    if not isinstance(K, int):
+        K = 768
@@ -446,12 +446,12 @@ def test_matmul_cython_outidx_minus1():
-    if isinstance(M, T.Var):
-        M = 1024
+    if not isinstance(M, int):
+        M = 1024

-    if isinstance(N, T.Var):
-        N = 1024
+    if not isinstance(N, int):
+        N = 1024

-    if isinstance(K, T.Var):
-        K = 768
+    if not isinstance(K, int):
+        K = 768
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if isinstance(NN, int):
-        assert NN > N, "NN must be greater than N"
-    program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
-    kernel = tilelang.compile(
-        program,
-        out_idx=[1],
-        target="cuda",
-        pass_configs={
-            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-        })
-    if isinstance(NN, T.Var):
-        NN = N * 2
-    a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
-    b = kernel(a[:, :N])
-    torch.testing.assert_close(b, a[:, :N], rtol=1e-2, atol=1e-2)
+    if isinstance(NN, int):
+        assert NN >= N, "NN must be greater than or equal to N"
+    program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
+    kernel = tilelang.compile(
+        program,
+        out_idx=[1],
+        target="cuda",
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        })
+    # For any non-int (e.g., tvm.tir.Var), pick a concrete size
+    if not isinstance(NN, int):
+        NN = N * 2
+    a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
+    b = kernel(a[:, :N])
+    torch.testing.assert_close(b, a[:, :N], rtol=1e-2, atol=1e-2)
🤖 Prompt for AI Agents
testing/python/language/test_tilelang_language_copy.py around lines 66-82: the
code currently uses isinstance(NN, T.Var) (which is invalid because T.Var is a
constructor) and requires NN > N even though NN == N is valid; change the
integer check to assert NN >= N and replace the symbolic detection with if not
isinstance(NN, int) to catch non-integer (symbolic) cases; apply the same two
fixes to the other affected tests:
testing/python/jit/test_tilelang_jit_gemm_ctypes.py (lines 370-375) and
testing/python/jit/test_tilelang_jit_gemm_cython.py (lines 379-384 and 448-453).


def test_tilelang_copy_with_stride():
    run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128)
    run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128)


if __name__ == "__main__":
    tilelang.testing.main()