diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 4dbc85954..7343d9558 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -40,6 +40,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDataRaceCheck, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTG, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTGPredicated, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kDisableLoopUnswitching, Bool); DataType cuTensorMapType() { return DataType::UInt(8, 128); } diff --git a/src/transform/lower_ldg_stg.cc b/src/transform/lower_ldg_stg.cc index 12a17d06b..abcdfb6e7 100644 --- a/src/transform/lower_ldg_stg.cc +++ b/src/transform/lower_ldg_stg.cc @@ -174,11 +174,6 @@ class LowerLDGSTGRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const BufferLoadNode *load) final { - // Skip if non-predicated lowering is disabled - if (!enable_non_predicated_) { - return StmtExprMutator::VisitExpr_(load); - } - // Skip loads in async scope (will be lowered to cp.async) if (in_async_scope_) { return StmtExprMutator::VisitExpr_(load); @@ -189,6 +184,17 @@ class LowerLDGSTGRewriter : public StmtExprMutator { return StmtExprMutator::VisitExpr_(load); } + // Check if we're in a predicated context (from IfThenElse store pattern) + // In this case, we need to use predicated load regardless of + // enable_non_predicated_ + bool use_predicated = current_predicate_.defined(); + + // Skip if non-predicated lowering is disabled and we're not in predicated + // context + if (!enable_non_predicated_ && !use_predicated) { + return StmtExprMutator::VisitExpr_(load); + } + // Assume buffer has been flattened by FlattenBuffer pass ICHECK(load->indices.size() == 1) << "Expected flattened buffer with single index, but got " @@ -208,6 +214,10 @@ class LowerLDGSTGRewriter : public StmtExprMutator { // Check for supported vector widths (32/64/128/256 bits) if (total_bits == 32 || total_bits == 64 || total_bits == 128 || total_bits == 256) { + if (use_predicated) { + return LowerToLDGPredicated(load, ramp->base, total_bits, + current_predicate_.value()); + } return LowerToLDG(load, ramp->base, total_bits); } } @@ -216,6 +226,10 @@ class LowerLDGSTGRewriter : public StmtExprMutator { // Single element load (non-Ramp) int bits = load->buffer->dtype.bits(); if (bits == 32 || bits == 64 || bits == 128 || bits == 256) { + if (use_predicated) { + return LowerToLDGPredicated(load, load->indices[0], bits, + current_predicate_.value()); + } return LowerToLDG(load, load->indices[0], bits); } } @@ -302,6 +316,8 @@ class LowerLDGSTGRewriter : public StmtExprMutator { bool in_async_scope_{false}; bool enable_non_predicated_{false}; bool enable_predicated_{true}; + Optional + current_predicate_; // Track predicate context for nested loads // Create access pointer for the buffer at given base offset PrimExpr CreateAccessPtr(const Buffer &buffer, const PrimExpr &base, @@ -426,9 +442,16 @@ class LowerLDGSTGRewriter : public StmtExprMutator { int bits, const PrimExpr &predicate) { PrimExpr ptr = CreateAccessPtr(store->buffer, base, 2); - // Get the value to store + // Set predicate context so that nested loads also use predicated version + Optional old_predicate = current_predicate_; + current_predicate_ = predicate; + + // Get the value to store (loads inside will use predicated version) PrimExpr value = this->VisitExpr(store->value); + // Restore old predicate context + current_predicate_ = old_predicate; + // Reinterpret value to uint32xN if needed DataType store_dtype; const Op *stg_op; diff --git a/testing/python/transform/test_tilelang_transform_lower_ldgstg.py b/testing/python/transform/test_tilelang_transform_lower_ldgstg.py index 4aca28d80..faea4e82f 100644 --- a/testing/python/transform/test_tilelang_transform_lower_ldgstg.py +++ b/testing/python/transform/test_tilelang_transform_lower_ldgstg.py @@ -193,6 +193,30 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), pred: T assert _check_has_intrinsic(mod, "stg128"), "Expected predicated stg128" +def test_predicated_store_with_load(): + """Test that when a predicated store contains a load, the load also gets predicated. + + This tests the pattern: if (pred) { B[i] = A[i] } + Both the store and the load should use predicated versions to avoid + out-of-bounds memory access when pred is false. + """ + + @T.prim_func + def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), pred: T.int32): + for i in T.thread_binding(32, "threadIdx.x"): + for j in T.vectorized(4): + with T.If(pred > 0), T.Then(): + B[i * 4 + j] = A[i * 4 + j] + + mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) + mod = _apply_passes(mod, enable_predicated=True) + print("=== test_predicated_store_with_load ===") + print(mod) + # Both load and store should be predicated + assert _check_has_intrinsic(mod, "ldg128"), "Expected predicated ldg128 for load inside predicated store" + assert _check_has_intrinsic(mod, "stg128"), "Expected predicated stg128" + + def test_predicated_disabled(): """Test that predicated lowering can be disabled."""