10 changes: 8 additions & 2 deletions src/target/codegen_cuda.cc
@@ -2830,8 +2830,14 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {

void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
  int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
-  CHECK_LE(lanes, 4) << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
-                     << " with " << lanes << " lanes is not allowed.";
+  // TODO(chaofan): Comment the ramp lanes limit for now since we have
+  // LegalizeVectorizedLoop to automatically legalize vectorized loop whose
+  // width exceeds the limit. But we should add check here for safety in the
+  // future. The check should be aligned to certain bit width like 128bits or
+  // 256bits.
+
+  // CHECK_LE(lanes, 8) << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
+  //                    << "error: " << lanes << " exceeds max ramp lanes 8.";
Comment on lines +2833 to +2840

Copilot AI Dec 25, 2025

This TODO comment indicates that the commented-out CHECK is temporarily removed, which creates a safety issue. While the LegalizeVectorizedLoop pass may handle oversized vectorized loops, removing runtime checks without proper validation could allow invalid code to pass through. The TODO should either be addressed in this PR or tracked in a separate issue with specific bit-width constraints (128/256 bits) documented.

Suggested change
-  // TODO(chaofan): Comment the ramp lanes limit for now since we have
-  // LegalizeVectorizedLoop to automatically legalize vectorized loop whose
-  // width exceeds the limit. But we should add check here for safety in the
-  // future. The check should be aligned to certain bit width like 128bits or
-  // 256bits.
-  // CHECK_LE(lanes, 8) << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
-  //                    << "error: " << lanes << " exceeds max ramp lanes 8.";
+  // Enforce a maximum total vector width for safety. Even though
+  // LegalizeVectorizedLoop is expected to legalize oversized vectorized loops,
+  // we keep this runtime check to prevent emitting vectors wider than the
+  // supported bit-width.
+  //
+  // Here we conservatively cap the total vector width at 256 bits.
+  int max_vector_bits = 256;
+  int dtype_bits = op->dtype.bits();
+  if (dtype_bits > 0) {
+    int max_lanes = max_vector_bits / dtype_bits;
+    CHECK_LE(lanes, max_lanes)
+        << "Translate Ramp Node " << tvm::ffi::GetRef<Ramp>(op)
+        << " error: " << lanes << " exceeds max ramp lanes " << max_lanes
+        << " for element bit-width " << dtype_bits
+        << " (max total vector width " << max_vector_bits << " bits).";
+  }

os << "(make_";
PrintType(op->dtype, os);
os << "(";
43 changes: 43 additions & 0 deletions src/tl_templates/cuda/common.h
@@ -137,6 +137,49 @@ TL_DEVICE int4_t make_int4(short x0, short x1, short y0, short y1, short z0,
  return result;
}

// Pack four char values.
TL_DEVICE unsigned int make_uint(unsigned char x0, unsigned char x1,
                                 unsigned char x2, unsigned char x3) {
  return (x3 << 24) | (x2 << 16) | (x1 << 8) | x0;
Comment on lines +141 to +143

P1: Prevent UB when packing uint8 bytes into uint

The shifts in make_uint operate on int because unsigned char is promoted, so when x3 is ≥128 the x3 << 24 shift overflows a signed 32-bit int, which is undefined behavior in C++. That can mispack vectorized uint8 data on device for values with the high bit set. Cast each operand to unsigned int (or uint32_t) before shifting to make the packing defined.
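A minimal sketch of the cast-based fix described above (illustrative only, not part of the merged diff; the name make_uint_packed is hypothetical):

// Sketch: same packing as make_uint above (x0 in the low byte), but each byte
// is widened to unsigned int before shifting, so x3 >= 128 cannot overflow int.
TL_DEVICE unsigned int make_uint_packed(unsigned char x0, unsigned char x1,
                                        unsigned char x2, unsigned char x3) {
  return (static_cast<unsigned int>(x3) << 24) |
         (static_cast<unsigned int>(x2) << 16) |
         (static_cast<unsigned int>(x1) << 8) |
         static_cast<unsigned int>(x0);
}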


}

// Pack eight char values.
TL_DEVICE uint2 make_uint2(unsigned char x0, unsigned char x1, unsigned char x2,
                           unsigned char x3, unsigned char y0, unsigned char y1,
                           unsigned char y2, unsigned char y3) {
  uint2 result;
  result.x = make_uint(x0, x1, x2, x3);
  result.y = make_uint(y0, y1, y2, y3);
  return result;
}

// Pack sixteen char values.
TL_DEVICE uint4 make_uint4(unsigned char x0, unsigned char x1, unsigned char x2,
                           unsigned char x3, unsigned char y0, unsigned char y1,
                           unsigned char y2, unsigned char y3, unsigned char z0,
                           unsigned char z1, unsigned char z2, unsigned char z3,
                           unsigned char w0, unsigned char w1, unsigned char w2,
                           unsigned char w3) {
  uint4 result;
  result.x = make_uint(x0, x1, x2, x3);
  result.y = make_uint(y0, y1, y2, y3);
  result.z = make_uint(z0, z1, z2, z3);
  result.w = make_uint(w0, w1, w2, w3);
  return result;
}

TL_DEVICE uint4 make_uint4(unsigned short x0, unsigned short x1,
                           unsigned short y0, unsigned short y1,
                           unsigned short z0, unsigned short z1,
                           unsigned short w0, unsigned short w1) {
  uint4 result;
  *((ushort2 *)&result.x) = make_ushort2(x0, x1);
  *((ushort2 *)&result.y) = make_ushort2(y0, y1);
  *((ushort2 *)&result.z) = make_ushort2(z0, z1);
  *((ushort2 *)&result.w) = make_ushort2(w0, w1);
  return result;
}

// Pack eight int values.
TL_DEVICE longlong4 make_longlong4(int x0, int x1, int y0, int y1, int z0,
                                   int z1, int w0, int w1) {
35 changes: 35 additions & 0 deletions testing/python/language/test_tilelang_language_vectorize.py
@@ -1,6 +1,7 @@
import torch
import tilelang.testing
import tilelang.language as T
import pytest


@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
@@ -112,5 +113,39 @@ def test_vectorize_invariant_index():
    run_vectorize_invariant_index(N, M * 7, 14)


@tilelang.jit
def vectorize_test_all_dtypes(dtype, vec_num):
    @T.prim_func
    def main(A: T.Tensor[(64,), dtype]):
        with T.Kernel(1, threads=256):
            for i in T.vectorized(vec_num):
                A[i] = T.cast(i + 1, dtype)

    return main


@pytest.mark.parametrize(
"dtype",
[
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e8m0fnu,
],
)
@pytest.mark.parametrize("vec_num", [1, 2, 4, 8])
def test_vectorize_all_dtypes(dtype, vec_num):
    x = torch.empty((64,), dtype=dtype, device="cuda")
    kernel = vectorize_test_all_dtypes(dtype, vec_num)
    kernel(x)
Comment on lines +144 to +147

Copilot AI Dec 25, 2025

The test doesn't verify that the kernel executed correctly. It should check that the tensor values match the expected output (i + 1 for each index i). Currently, the test only verifies that the kernel runs without error, but doesn't validate the results.


Comment on lines +116 to +148
Contributor

🛠️ Refactor suggestion | 🟠 Major

Add output verification to the vectorization test.

The test currently only verifies that the kernel executes without crashing, but doesn't validate the correctness of the vectorized memory operations. This significantly reduces the test's effectiveness.

🔎 Proposed enhancement to verify output correctness
 @pytest.mark.parametrize("vec_num", [1, 2, 4, 8])
 def test_vectorize_all_dtypes(dtype, vec_num):
     x = torch.empty((64,), dtype=dtype, device="cuda")
     kernel = vectorize_test_all_dtypes(dtype, vec_num)
     kernel(x)
+    
+    # Verify the kernel wrote the expected values
+    expected = torch.arange(1, vec_num + 1, dtype=dtype, device="cuda")
+    torch.testing.assert_close(x[:vec_num], expected, atol=0, rtol=0)

This ensures the vectorized writes are actually working correctly for all data types and vectorization widths.

🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorize.py around lines
116-148, the test calls the kernel but does not assert that the buffer contents
are correct; after kernel(x) runs, construct the expected tensor (values 1..64
cast to dtype on CUDA), then verify x matches expected: for integer dtypes use
exact equality, for floating dtypes use torch.testing.assert_close with small
atol/rtol; ensure comparisons run on the same device and handle any
dtype-specific quirks (use torch.arange(1,65, dtype=dtype, device='cuda') and
cast as needed) so the test fails if vectorized writes produced incorrect
results.
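
For illustration, a sketch of the verification this prompt describes (not part of the PR; the helper name is hypothetical, and it assumes, per the kernel body above, that only the first vec_num elements of the buffer are written):

# Illustrative helper, called after kernel(x) in test_vectorize_all_dtypes.
def verify_vectorized_write(x, dtype, vec_num):
    actual = x[:vec_num]
    if dtype.is_floating_point:
        # Build the reference in float32 and round-trip it through the target dtype
        # so narrow float8 formats quantize on both sides before comparing; the
        # exponent-only float8_e8m0fnu may still need a looser tolerance if its
        # rounding differs between the kernel and PyTorch.
        ref = torch.arange(1, vec_num + 1, device="cuda", dtype=torch.float32)
        expected = ref.to(dtype)
        torch.testing.assert_close(actual.float(), expected.float())
    else:
        # Exact comparison for integer dtypes; cast to int64 first to sidestep the
        # limited operator coverage of the barebones uint16/uint32/uint64 dtypes.
        expected = torch.arange(1, vec_num + 1, device="cuda", dtype=torch.int64)
        assert torch.equal(actual.to(torch.int64), expected)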


if __name__ == "__main__":
tilelang.testing.main()