Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/transform/loop_vectorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,24 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
0))
return false;

// Check if expr is invariant within vector boundaries
// We're trying to prove the access expression A[f(var)] depends only on
// floor(var/vecsize), not on var%vecsize
// Mathematically:
// \forall var, f(floor(var/vecsize)*vecsize + var%vecsize) ==
// f(floor(var/vecsize)*vecsize + 0)
// Example: for i in T.vectorized(8):
// A[i] = B[i] * C[i//4]
// if vecsize=4, f(i)=i//4 depends only on i//4
// Therefore A[i] = B[i] * C[i//4] can be vectorized with vecsize=4
PrimExpr var_aligned =
floordiv(var, target_vectorized_size) * target_vectorized_size;
PrimExpr expr_aligned = Substitute(expr, {{var, var_aligned}});
if (analyzer->CanProveEqual(expr, expr_aligned)) {
return true;
}
Comment on lines 293 to 309
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix syntax error and function name casing.

The map initialization is missing an opening brace, and floordiv should be FloorDiv to match the naming convention used elsewhere in the file (e.g., lines 290, 311).

Apply this diff to fix both issues:

   // Check if expr is invariant within vector boundaries
   PrimExpr expr_aligned = Substitute(expr,
-      {var, floordiv(var, target_vectorized_size) * target_vectorized_size}});
+      {{var, FloorDiv(var, target_vectorized_size) * target_vectorized_size}});
   if (analyzer->CanProveEqual(expr, expr_aligned)) {
     return true;
   }

Based on the static analysis hint reporting "Unmatched '('" at line 295.

🧰 Tools
🪛 Cppcheck (2.18.0)

[error] 295-295: Unmatched '('. Configuration

(syntaxError)

🤖 Prompt for AI Agents
In src/transform/loop_vectorize.cc around lines 293 to 299, the map initializer
passed to Substitute is missing the opening brace and the function name is
incorrectly cased; change the map to start with a double opening brace and rename
floordiv to FloorDiv so the call reads Substitute(expr, {{var, FloorDiv(var,
target_vectorized_size) * target_vectorized_size}}); ensure the braces and
parentheses are balanced.


auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}}));
// The base offset must be divisible
if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr),
Expand Down
59 changes: 58 additions & 1 deletion testing/python/language/test_tilelang_language_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test(N, M, stride_A, stride_B):
assert N % 128 == 0 and M % 128 == 0

@T.prim_func
def main(
Expand All @@ -23,6 +22,7 @@ def main(


def run_vectorize(N, M, stride_A, stride_B):
assert N % 128 == 0 and M % 128 == 0
assert stride_A >= N and stride_B >= N

jit_kernel = vectorize_test(N, M, stride_A, stride_B)
Expand Down Expand Up @@ -59,5 +59,62 @@ def test_vectorize():
run_vectorize(N, M, N + 8, N + 16)


@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test_invariant_index(N, M, K):
    """Build a kernel computing ``B[row, col] = A[row, col] * C[row, col // K]``.

    ``A`` and ``B`` are declared as ``(N, M)`` float32 tensors and ``C`` as an
    ``(N, M // K)`` float32 tensor. The index into ``C`` is ``col // K``, so it
    takes the same value for every run of ``K`` consecutive columns; the loop
    over ``col`` is marked ``T.vectorized`` so the compiler has to decide how
    wide a vector it can use with such an index.

    Args:
        N: Number of rows. The kernel is created with ``N // 128`` as the
            ``T.Kernel`` extent and 128 threads, one row per thread.
        M: Number of columns, i.e. the extent of the vectorized loop.
        K: Number of consecutive columns that share one element of ``C``.

    Returns:
        The ``T.prim_func`` named ``main`` (wrapped by ``tilelang.jit``).
    """

    @T.prim_func
    def main(
            A: T.Tensor[(N, M), "float32"], # noqa: F821
            B: T.Tensor[(N, M), "float32"], # noqa: F821
            C: T.Tensor[(N, M // K), "float32"], # noqa: F821
    ):
        with T.Kernel(N // 128, threads=128) as (bx):
            tx = T.get_thread_binding(0)
            # Global row index: 128 rows per block, one row per thread.
            row = bx * 128 + tx

            # C is indexed by col // K, which does not change inside a group
            # of K consecutive columns.
            for col in T.vectorized(M):
                B[row, col] = A[row, col] * C[row, col // K]

    return main


def run_vectorize_invariant_index(N, M, K):
    """Run the invariant-index kernel on random data and validate it.

    Checks two things:
      1. the numerical result matches ``a * c[:, col // K]`` computed by torch;
      2. the generated kernel source contains the vector type that matches the
         largest power of two (capped at 4) dividing ``K``.

    Args:
        N: Number of rows; must be a multiple of 128.
        M: Number of columns; must be a multiple of ``K``.
        K: Group size, i.e. how many consecutive columns share one ``c`` value.
    """
    assert N % 128 == 0 and M % K == 0

    kernel = vectorize_test_invariant_index(N, M, K)

    # Keep the creation order (randn, zeros, randn) so the random stream is
    # consumed exactly as before.
    lhs = torch.randn(N, M, device="cuda", dtype=torch.float32)
    out = torch.zeros(N, M, device="cuda", dtype=torch.float32)
    scale = torch.randn(N, M // K, device="cuda", dtype=torch.float32)

    kernel(lhs, out, scale)

    # Reference: column j of the output uses column j // K of `scale`.
    group_of_col = torch.arange(lhs.size(1)) // K
    expected = lhs * scale[:, group_of_col]
    torch.testing.assert_close(out, expected, atol=1e-8, rtol=1e-8)

    source = kernel.get_kernel_source()

    # Largest power of two that divides K, limited to at most 4.
    if K % 4 == 0:
        width = 4
    elif K % 2 == 0:
        width = 2
    else:
        width = 1

    # Only widths 4 and 2 have a vector type to look for; width 1 checks nothing.
    vector_type = {4: "float4", 2: "float2"}.get(width)
    if vector_type is not None:
        assert vector_type in source


def test_vectorize_invariant_index():
    """Exercise the invariant-index kernel for several group sizes ``K``.

    Each case is ``(column multiplier, K)``; the column count is ``256`` times
    the multiplier so that it stays divisible by ``K``.
    """
    N, M = 512, 256

    cases = (
        (1, 2),
        (1, 4),
        (3, 6),
        (1, 8),
        (3, 12),
        (7, 14),
    )
    for col_factor, K in cases:
        run_vectorize_invariant_index(N, M * col_factor, K)


# When executed as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()
Loading