Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array<ffi::String>);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDataRaceCheck, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTG, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLowerLDGSTGPredicated, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableLoopUnswitching, Bool);

/*!
 * \brief Data type used for tensor-map values.
 * \return The type produced by DataType::UInt(8, 128).
 *
 * NOTE(review): presumably this is an 8-bit unsigned element with 128 lanes
 * (i.e. 128 bytes, matching the size of a CUDA tensor map descriptor) --
 * confirm against the DataType::UInt signature and the CUDA driver API.
 */
DataType cuTensorMapType() {
  return DataType::UInt(8, 128);
}

Expand Down
35 changes: 29 additions & 6 deletions src/transform/lower_ldg_stg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,6 @@ class LowerLDGSTGRewriter : public StmtExprMutator {
}

PrimExpr VisitExpr_(const BufferLoadNode *load) final {
// Skip if non-predicated lowering is disabled
if (!enable_non_predicated_) {
return StmtExprMutator::VisitExpr_(load);
}

// Skip loads in async scope (will be lowered to cp.async)
if (in_async_scope_) {
return StmtExprMutator::VisitExpr_(load);
Expand All @@ -189,6 +184,17 @@ class LowerLDGSTGRewriter : public StmtExprMutator {
return StmtExprMutator::VisitExpr_(load);
}

// Check if we're in a predicated context (from IfThenElse store pattern)
// In this case, we need to use predicated load regardless of
// enable_non_predicated_
bool use_predicated = current_predicate_.defined();

// Skip if non-predicated lowering is disabled and we're not in predicated
// context
if (!enable_non_predicated_ && !use_predicated) {
return StmtExprMutator::VisitExpr_(load);
}

// Assume buffer has been flattened by FlattenBuffer pass
ICHECK(load->indices.size() == 1)
<< "Expected flattened buffer with single index, but got "
Expand All @@ -208,6 +214,10 @@ class LowerLDGSTGRewriter : public StmtExprMutator {
// Check for supported vector widths (32/64/128/256 bits)
if (total_bits == 32 || total_bits == 64 || total_bits == 128 ||
total_bits == 256) {
if (use_predicated) {
return LowerToLDGPredicated(load, ramp->base, total_bits,
current_predicate_.value());
}
return LowerToLDG(load, ramp->base, total_bits);
}
}
Expand All @@ -216,6 +226,10 @@ class LowerLDGSTGRewriter : public StmtExprMutator {
// Single element load (non-Ramp)
int bits = load->buffer->dtype.bits();
if (bits == 32 || bits == 64 || bits == 128 || bits == 256) {
if (use_predicated) {
return LowerToLDGPredicated(load, load->indices[0], bits,
current_predicate_.value());
}
return LowerToLDG(load, load->indices[0], bits);
}
}
Expand Down Expand Up @@ -302,6 +316,8 @@ class LowerLDGSTGRewriter : public StmtExprMutator {
bool in_async_scope_{false};
bool enable_non_predicated_{false};
bool enable_predicated_{true};
Optional<PrimExpr>
current_predicate_; // Track predicate context for nested loads

// Create access pointer for the buffer at given base offset
PrimExpr CreateAccessPtr(const Buffer &buffer, const PrimExpr &base,
Expand Down Expand Up @@ -426,9 +442,16 @@ class LowerLDGSTGRewriter : public StmtExprMutator {
int bits, const PrimExpr &predicate) {
PrimExpr ptr = CreateAccessPtr(store->buffer, base, 2);

// Get the value to store
// Set predicate context so that nested loads also use predicated version
Optional<PrimExpr> old_predicate = current_predicate_;
current_predicate_ = predicate;

// Get the value to store (loads inside will use predicated version)
PrimExpr value = this->VisitExpr(store->value);

// Restore old predicate context
current_predicate_ = old_predicate;

// Reinterpret value to uint32xN if needed
DataType store_dtype;
const Op *stg_op;
Expand Down
24 changes: 24 additions & 0 deletions testing/python/transform/test_tilelang_transform_lower_ldgstg.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,30 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), pred: T
assert _check_has_intrinsic(mod, "stg128"), "Expected predicated stg128"


def test_predicated_store_with_load():
    """A load nested inside a predicated store must be lowered alongside the store.

    Pattern under test: if (pred) { B[i] = A[i] }
    The intent is that both the store and the load are emitted in their
    predicated forms, so that no out-of-bounds memory access happens when
    pred is false.
    """

    # The kernel is written exactly as TVMScript expects; each thread copies a
    # 4-element vectorized slice from A to B, guarded by `pred > 0`.
    @T.prim_func
    def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"), pred: T.int32):
        for i in T.thread_binding(32, "threadIdx.x"):
            for j in T.vectorized(4):
                with T.If(pred > 0), T.Then():
                    B[i * 4 + j] = A[i * 4 + j]

    lowered = _apply_passes(
        tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")),
        enable_predicated=True,
    )

    # Dump the lowered module so a failing run shows what was generated.
    print("=== test_predicated_store_with_load ===")
    print(lowered)

    # Both load and store should be predicated.
    # NOTE(review): the check only looks for the names "ldg128"/"stg128";
    # whether _check_has_intrinsic distinguishes predicated from non-predicated
    # variants is not visible from this test -- confirm in the helper.
    expectations = (
        ("ldg128", "Expected predicated ldg128 for load inside predicated store"),
        ("stg128", "Expected predicated stg128"),
    )
    for intrinsic, failure_message in expectations:
        assert _check_has_intrinsic(lowered, intrinsic), failure_message


def test_predicated_disabled():
"""Test that predicated lowering can be disabled."""

Expand Down
Loading