# Patch summary (commentary only; the diff below is unchanged).
#
# 1. src/transform/storage_rewrite.cc, inside class StoragePlanRewriter:
#    adds an early `return NewAlloc(op, attach_scope, scope, const_nbits);`
#    when `scope.tag == ".var"`. The new check is placed directly before the
#    `if (is_known_size)` branch, so ".var" allocations return before the
#    const_free_map_ lookup that follows in that branch.
#    NOTE(review): presumably this keeps alloc_var storage of different
#    dtypes from being merged with previously freed storage -- confirm
#    against issue 1678.
#
# 2. testing/python/issue/test_tilelang_issue_1678.py (new file):
#    defines a @tilelang.jit kernel that declares an int32 alloc_var `i`,
#    a float32 alloc_local `tmp_row` of shape (4,), and a float32 alloc_var
#    `amax_local`, then reads `tmp_row[j]` with `j = i`. The test is gated by
#    @tilelang.testing.requires_cuda and asserts that the generated kernel
#    source contains both "int i" and "float amax_local".
#
# NOTE(review): the whole patch is stored on a single line here; line breaks
# and indentation would have to be restored before it could be applied.
diff --git a/src/transform/storage_rewrite.cc b/src/transform/storage_rewrite.cc index 40973f39a..91ca0d347 100644 --- a/src/transform/storage_rewrite.cc +++ b/src/transform/storage_rewrite.cc @@ -1094,6 +1094,10 @@ class StoragePlanRewriter : public StmtExprMutator { return NewAlloc(op, attach_scope, scope, const_nbits); } + if (scope.tag == ".var") { + return NewAlloc(op, attach_scope, scope, const_nbits); + } + if (is_known_size) { // constant allocation. auto begin = const_free_map_.lower_bound(const_nbits / match_range); diff --git a/testing/python/issue/test_tilelang_issue_1678.py b/testing/python/issue/test_tilelang_issue_1678.py new file mode 100644 index 000000000..236800a13 --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1678.py @@ -0,0 +1,30 @@ +import tilelang +import tilelang.testing +from tilelang import language as T + + +@tilelang.jit +def _alloc_var_mixed_dtype_kernel(): + @T.prim_func + def kernel(): + with T.Kernel(1, 1, threads=1) as (_, _): + i = T.alloc_var(T.int32) + i = 1 + tmp_row = T.alloc_local((4,), T.float32) + amax_local = T.alloc_var(T.float32) + j = i + amax_local = T.max(amax_local, tmp_row[j]) + + return kernel + + +@tilelang.testing.requires_cuda +def test_alloc_var_mixed_dtype_codegen(): + kernel = _alloc_var_mixed_dtype_kernel() + source = kernel.get_kernel_source() + assert "int i" in source + assert "float amax_local" in source + + +if __name__ == "__main__": + tilelang.testing.main()