Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,16 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
"NoCheck; symbolic dims: "
<< symbolic_dims;
}
arith::IterMapResult res =
arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);

arith::IterMapResult res;
res = arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
// We do not search for the exact level of the current layout, because a
// non-bijective layout leaves the boundary to be handled by a loop guard.
// So just set the level flag and use the bijective result to generate the
// correct inverse layout.
if (!res->errors.empty()) {
level = arith::IterMapLevel::NoCheck;
}

if (!res->errors.empty()) {
std::ostringstream msg;
msg << "Layout " << DebugOutput() << " has errors: " << res->errors;
Expand Down Expand Up @@ -607,7 +615,16 @@ arith::IterMapResult FragmentNode::DetectInjective() const {
<< "NoCheck; symbolic dims: " << symbolic_dims;
}

return arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer);
arith::IterMapResult res;
while (true) {
res = arith::DetectIterMap(indices, getVarMap(), 1, level, &analyzer);
if (res->errors.empty())
break;
if (level == arith::IterMapLevel::NoCheck)
break;
level = static_cast<arith::IterMapLevel>((static_cast<int>(level) * 2) + 1);
}
return res;
}

PrimExpr FragmentNode::ThreadExtent() const {
Expand Down
23 changes: 23 additions & 0 deletions src/op/parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,29 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
}
}

// Keep the fragment layout from exceeding the loop extent's boundary
if (loop_layout_.defined()) {
auto inv_layout = loop_layout_->Inverse();
Array<PrimExpr> thread_indices;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++) {
thread_indices.push_back(0);
}
thread_indices.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto logical_indices = inv_layout->Forward(thread_indices);

// Check if the logical indices exceed the original loop_vars_ extent
for (size_t i = 0; i < loop_vars_.size(); ++i) {
PrimExpr logical_i = logical_indices[i];
PrimExpr original_extent = loop_vars_[i]->dom->extent;

// If the logical index cannot be proved to be less than the original
// extent, add a predicate
if (!analyzer_.CanProve(LT(logical_i, original_extent))) {
AddPredicate(LT(logical_i, original_extent));
}
}
}

if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) {
AddPredicate(
LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min));
Expand Down
12 changes: 10 additions & 2 deletions src/transform/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
Array<PrimExpr> loop_extents;
auto inverse_info = loop_layout->InverseWithLevel();
auto inv_loop = inverse_info.first;
// Must check the guard if the layout can not be proved as bijective
bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;
auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));

// Normalize thread var once so we can reuse the same substitution later.
Map<Var, PrimExpr> thread_offset_map;
bool has_thread_offset = false;
Expand All @@ -94,15 +93,24 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
thread_offset_map.Set(thread_var, thread_var - range->min);
has_thread_offset = true;
}
bool bounds_match = true;
for (int i = 0; i < old_loop_depth; i++) {
const ForNode *loop = body.as<ForNode>();
ICHECK(loop != nullptr)
<< "No extra statements are allowed between nested parallel loops.";
if (!analyzer->CanProve(loop->extent == loop_layout->InputShape()[i]) ||
!analyzer->CanProve(loop->min == 0)) {
bounds_match = false;
}
vmap.Set(loop->loop_var, indices[i]);
loop_mins.push_back(loop->min);
loop_extents.push_back(loop->extent);
body = loop->body;
}
// Must check the guard if the layout can not be proved as bijective
bool need_guard =
inverse_info.second != arith::IterMapLevel::Bijective || !bounds_match;

// substitute and re-construct the serial loop
body = Substitute(body, vmap);
// Guard executes the recovered loop body only if each inverse-mapped iterator
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang import tvm as tvm
from tilelang.utils.target import determine_target


def _tilelang_transform_loop_partition_boundary():
    """Build the input/reference fixture pair for the loop-partition boundary test.

    Both kernels launch with ``threads=128`` and run a ``T.Parallel(4)`` loop
    nested in a ``T.serial(64)`` loop that writes ``S_fragment[i]`` into
    ``D_shared[i, k]``.

    Returns:
        A ``(before, after)`` tuple of ``tvm.IRModule`` objects, each holding a
        single ``"main"`` prim_func. ``after`` is ``before`` plus an explicit
        ``tx < 4`` guard around the parallel loop.
    """

    # Input kernel: the 4-iteration parallel loop carries no explicit guard.
    def before():
        @T.prim_func
        def main(
            S: T.Tensor((8), T.bfloat16),
            D: T.Tensor((4, 64), T.bfloat16),
        ):
            with T.Kernel(1, threads=128):
                S_shared = T.alloc_shared((8), T.bfloat16)
                S_fragment = T.alloc_fragment((8), T.float32)
                D_shared = T.alloc_shared((4, 64), T.bfloat16)

                T.copy(S, S_shared)
                T.copy(S_shared, S_fragment)
                for k in T.serial(64):
                    for i in T.Parallel(4):
                        D_shared[i, k] = S_fragment[i]
                T.copy(D_shared, D)

        return main

    # Reference kernel: identical to before() except for the thread binding
    # `tx` and the `if tx < 4` guard wrapped around the parallel loop.
    def after():
        @T.prim_func
        def main(
            S: T.Tensor((8), T.bfloat16),
            D: T.Tensor((4, 64), T.bfloat16),
        ):
            with T.Kernel(1, threads=128):
                tx = T.get_thread_binding(0)
                S_shared = T.alloc_shared((8), T.bfloat16)
                S_fragment = T.alloc_fragment((8), T.float32)
                D_shared = T.alloc_shared((4, 64), T.bfloat16)

                T.copy(S, S_shared)
                T.copy(S_shared, S_fragment)
                for k in T.serial(64):
                    if tx < 4:
                        for i in T.Parallel(4):
                            D_shared[i, k] = S_fragment[i]
                T.copy(D_shared, D)

        return main

    return tvm.IRModule({"main": before()}), tvm.IRModule({"main": after()})


def boundary_check():
    """Lower both fixture modules through the same passes and compare scripts.

    The ``before`` and ``after`` modules from
    ``_tilelang_transform_loop_partition_boundary`` are each run through
    BindTarget -> LayoutInference -> LowerTileOp -> Simplify under the
    auto-detected target. The check is a textual comparison of the printed
    ``main`` functions, not a structural-equality check.

    Raises:
        AssertionError: if the two printed scripts differ; the message contains
            both scripts so the mismatch can be inspected directly.
    """
    before, after = _tilelang_transform_loop_partition_boundary()
    target = tvm.target.Target(determine_target("auto"))

    def _lower(mod):
        # Each module gets its own PassContext so both runs see the same
        # pass configuration.
        with tvm.transform.PassContext():
            mod = tvm.tir.transform.BindTarget(target)(mod)
            mod = tilelang.transform.LayoutInference()(mod)
            mod = tilelang.transform.LowerTileOp()(mod)
            return tvm.tir.transform.Simplify()(mod)

    with target:
        mod = _lower(before)
        ref_mod = _lower(after)

    mod_script = mod["main"].script()
    ref_script = ref_mod["main"].script()
    assert mod_script == ref_script, (
        "lowered scripts of mod and ref_mod differ\n"
        f"--- mod ---\n{mod_script}\n"
        f"--- ref_mod ---\n{ref_script}"
    )


def test_tilelang_transform_loop_partition_boundary():
    """Test entry point; delegates to ``boundary_check()``."""
    boundary_check()


# When executed as a script, hand control to tilelang's test runner.
if __name__ == "__main__":
    tilelang.testing.main()
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ def main(

return tvm.IRModule({"main": main})

before_mod, after_mod = before(), after()
with tvm.transform.PassContext():
mod = tvm.tir.transform.BindTarget(auto_target)(before())
mod = tvm.tir.transform.BindTarget(auto_target)(before_mod)
mod = tl.transform.LowerTileOp()(mod)
mod = tvm.tir.transform.Simplify()(mod)
ref_mod = tvm.tir.transform.BindTarget(auto_target)(after())
ref_mod = tvm.tir.transform.Simplify()(ref_mod)
with tvm.transform.PassContext():
ref_mod = tvm.tir.transform.BindTarget(auto_target)(after_mod)
ref_mod = tvm.tir.transform.Simplify()(ref_mod)
# Note(tzj): The structures are equal except the argument in "T.reads" function.
# The difference is just between the first index and the indices range, which is totally equivalent
tvm.ir.structural_equal(mod, ref_mod)
Expand Down
Loading