diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc
index e8a18b004..17c9e1f66 100644
--- a/src/transform/loop_vectorize.cc
+++ b/src/transform/loop_vectorize.cc
@@ -290,6 +290,24 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
   if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
                                0))
     return false;
+
+  // Check if expr is invariant within vector boundaries
+  // We're trying to prove the access expression A[f(var)] depends only on
+  // floor(var/vecsize), not on var%vecsize
+  // Mathematically:
+  //   \forall var, f(floor(var/vecsize)*vecsize + var%vecsize) ==
+  //     f(floor(var/vecsize)*vecsize + 0)
+  // Example: for i in T.vectorized(8):
+  //            A[i] = B[i] * C[i//4]
+  //          if vecsize=4, f(i)=i//4 depends only on i//4
+  //          Therefore A[i] = B[i] * C[i//4] can be vectorized with vecsize=4
+  PrimExpr var_aligned =
+      floordiv(var, target_vectorized_size) * target_vectorized_size;
+  PrimExpr expr_aligned = Substitute(expr, {{var, var_aligned}});
+  if (analyzer->CanProveEqual(expr, expr_aligned)) {
+    return true;
+  }
+
   auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}}));
   // The base offset must be divisible
   if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr),
diff --git a/testing/python/language/test_tilelang_language_vectorize.py b/testing/python/language/test_tilelang_language_vectorize.py
index cee8b5a63..bc2d31446 100644
--- a/testing/python/language/test_tilelang_language_vectorize.py
+++ b/testing/python/language/test_tilelang_language_vectorize.py
@@ -5,7 +5,6 @@
 
 @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
 def vectorize_test(N, M, stride_A, stride_B):
-    assert N % 128 == 0 and M % 128 == 0
 
     @T.prim_func
     def main(
@@ -23,6 +22,7 @@ def main(
 
 
 def run_vectorize(N, M, stride_A, stride_B):
+    assert N % 128 == 0 and M % 128 == 0
     assert stride_A >= N and stride_B >= N
 
     jit_kernel = vectorize_test(N, M, stride_A, stride_B)
@@ -59,5 +59,62 @@ def test_vectorize():
     run_vectorize(N, M, N + 8, N + 16)
 
 
+@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
+def vectorize_test_invariant_index(N, M, K):
+
+    @T.prim_func
+    def main(
+            A: T.Tensor[(N, M), "float32"],  # noqa: F821
+            B: T.Tensor[(N, M), "float32"],  # noqa: F821
+            C: T.Tensor[(N, M // K), "float32"],  # noqa: F821
+    ):
+        with T.Kernel(N // 128, threads=128) as (bx):
+            tx = T.get_thread_binding(0)
+            row = bx * 128 + tx
+
+            for col in T.vectorized(M):
+                B[row, col] = A[row, col] * C[row, col // K]
+
+    return main
+
+
+def run_vectorize_invariant_index(N, M, K):
+    assert N % 128 == 0 and M % K == 0
+
+    jit_kernel = vectorize_test_invariant_index(N, M, K)
+
+    a = torch.randn(N, M, device="cuda", dtype=torch.float32)
+    b = torch.zeros(N, M, device="cuda", dtype=torch.float32)
+    c = torch.randn(N, M // K, device="cuda", dtype=torch.float32)
+
+    jit_kernel(a, b, c)
+
+    indices = torch.arange(a.size(1)) // K
+    ret = a * c[:, indices]
+    torch.testing.assert_close(b, ret, atol=1e-8, rtol=1e-8)
+
+    code = jit_kernel.get_kernel_source()
+
+    vectorize_size = 1
+    while vectorize_size <= 2 and K % (vectorize_size * 2) == 0:
+        vectorize_size *= 2
+
+    if vectorize_size == 4:
+        assert "float4" in code
+    elif vectorize_size == 2:
+        assert "float2" in code
+
+
+def test_vectorize_invariant_index():
+    N, M = 512, 256
+
+    run_vectorize_invariant_index(N, M, 2)
+    run_vectorize_invariant_index(N, M, 4)
+    run_vectorize_invariant_index(N, M * 3, 6)
+    run_vectorize_invariant_index(N, M, 8)
+    run_vectorize_invariant_index(N, M * 3, 12)
+    run_vectorize_invariant_index(N, M * 7, 14)
+
+
 if __name__ == "__main__":
     tilelang.testing.main()