diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 6feb93369..dcb9570fb 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -738,7 +738,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { bool store_into_local = false; PostOrderVisit(root, [&](const ObjectRef &obj) { if (const auto *store = obj.as()) { - if (IsLocalBuffer(store->buffer, true)) { + if (IsLocalBuffer(store->buffer)) { store_into_local = true; } } @@ -751,11 +751,11 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { bool local_register_only = true; PostOrderVisit(root, [&](const ObjectRef &obj) { if (const auto *store = obj.as()) { - if (!IsLocalBuffer(store->buffer, true)) { + if (!IsLocalBuffer(store->buffer)) { local_register_only = false; } } else if (const auto *load = obj.as()) { - if (!IsLocalBuffer(load->buffer, true)) { + if (!IsLocalBuffer(load->buffer)) { local_register_only = false; } } @@ -769,13 +769,11 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { bool has_non_local = false; PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { if (const auto *load = obj.as()) { - if (!IsLocalBuffer(load->buffer, true) && - !IsFragmentBuffer(load->buffer)) { + if (!IsLocalBuffer(load->buffer) && !IsFragmentBuffer(load->buffer)) { has_non_local = true; } } else if (const auto *store = obj.as()) { - if (!IsLocalBuffer(store->buffer, true) && - !IsFragmentBuffer(store->buffer)) { + if (!IsLocalBuffer(store->buffer) && !IsFragmentBuffer(store->buffer)) { has_non_local = true; } } diff --git a/testing/python/issue/test_tilelang_issue_1549.py b/testing/python/issue/test_tilelang_issue_1549.py index 1becf5a74..d23659e37 100644 --- a/testing/python/issue/test_tilelang_issue_1549.py +++ b/testing/python/issue/test_tilelang_issue_1549.py @@ -28,6 +28,15 @@ def main( kernel = get_wrong_kernel(M) data = torch.randint(0, 100, (M,), dtype=torch.int32, device="cuda") kernel(data) + code = kernel.get_kernel_source() + print(code) + assert ( + """for (int i = 0; i < 32; ++i) { + idx = ((i * 64) + ((int)threadIdx.x)); + Data[((i * 64) + ((int)threadIdx.x))] = idx; + }""" + in code + ) if __name__ == "__main__":