diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 3db3f9aa4..55f9252cf 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -998,11 +998,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { bool has_non_local = false; PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { if (const auto *load = obj.as()) { - if (!IsLocalBuffer(load->buffer) && !IsFragmentBuffer(load->buffer)) { + if (!IsLocalBuffer(load->buffer, /*allow_var*/ true) && + !IsFragmentBuffer(load->buffer)) { has_non_local = true; } } else if (const auto *store = obj.as()) { - if (!IsLocalBuffer(store->buffer) && !IsFragmentBuffer(store->buffer)) { + if (!IsLocalBuffer(store->buffer, /*allow_var*/ true) && + !IsFragmentBuffer(store->buffer)) { has_non_local = true; } } diff --git a/testing/python/language/test_tilelang_language_parallel.py b/testing/python/language/test_tilelang_language_parallel.py index a392e70b6..6c0aa8032 100644 --- a/testing/python/language/test_tilelang_language_parallel.py +++ b/testing/python/language/test_tilelang_language_parallel.py @@ -66,5 +66,21 @@ def test_parallel_dynamic_extent(): torch.testing.assert_close(out, reference, atol=1e-5, rtol=1e-5) +@tilelang.jit +def _parallel_vectorize_local_and_var(): + with T.Kernel(1) as _: + x = T.alloc_fragment([256], T.float32) + y = T.alloc_fragment([256], T.float32) + z = T.alloc_var(T.float32) + for i in T.parallel(256): + y[i] = x[i] * z + + +def test_parallel_vectorize_var(): + source = _parallel_vectorize_local_and_var.get_kernel_source() + # do not vectorize if the loop only contains local/fragment and var buffer access + assert "float2" not in source + + if __name__ == "__main__": tilelang.testing.main()