diff --git a/src/transform/storage_rewrite.cc b/src/transform/storage_rewrite.cc index 40973f39a..353523319 100644 --- a/src/transform/storage_rewrite.cc +++ b/src/transform/storage_rewrite.cc @@ -1109,6 +1109,9 @@ class StoragePlanRewriter : public StmtExprMutator { // when not divided, no reuse, eg, float4 vs float3 if (e->bits_offset % op_elem_bits != 0) continue; + // must check element type to avoid type mismatch in codegen + if (e->elem_type != op->dtype.element_of()) + continue; if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) { continue; } @@ -1962,21 +1965,12 @@ Pass StorageRewrite() { ctx->GetConfig(kStorageRewriteDetectInplace, Bool(false)).value(); bool enable_reuse = true; bool reuse_require_exact_matched_dtype = false; - bool merge_static_smem = - ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); AllocateCollector collector; collector(f->body); - bool has_dynamic = collector.dyn_shmem_allocs_.size() > 1; - if (has_dynamic || merge_static_smem) { - // For IRModule utilizing dynamic shared memory, reuse is not enabled - // Because dynamic doesn't require maintaining the readability and - // it benefits from a more optimized allocation strategy through the - // Pass `MergeSharedMemoryAllocations`. - // When `merge_static_smem` is true, we will reuse and merge shared - // memory in a dedicated pass `MergeSharedMemoryAllocations`. - // And so we don't enable reuse in this pass. - enable_reuse = false; - } + // Always disable reuse for now: shared memory reuse is handled by the + // MergeSharedMemoryAllocations pass, and register reuse is left to nvcc + // or whichever downstream compiler is used. 
+ enable_reuse = false; Optional target = f->GetAttr("target"); if (target.defined() && (target.value()->kind->name == "vulkan" || diff --git a/testing/python/issue/test_tilelang_issue_1678.py b/testing/python/issue/test_tilelang_issue_1678.py new file mode 100644 index 000000000..d22cc414a --- /dev/null +++ b/testing/python/issue/test_tilelang_issue_1678.py @@ -0,0 +1,26 @@ +# ruff: noqa +import tilelang +import tilelang.testing +import tilelang.language as T + + +def test_issue_1678(): + @tilelang.jit + def qwq(): + @T.prim_func + def qwq_kernel(): + with T.Kernel(4096, 1, threads=1) as (pid_y, pid_x): + i = T.alloc_var("int32") + i = 1 + tmp_row = T.alloc_local((4,), "float32") + amax_local = T.alloc_var("float32") + j = 0 + amax_local = T.max(amax_local, tmp_row[j]) + + return qwq_kernel + + kernel = qwq() + + +if __name__ == "__main__": + tilelang.testing.main()