192 changes: 97 additions & 95 deletions src/transform/layout_inference.cc
@@ -1090,112 +1090,114 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
reducer_info = op->annotations.Get(attr::kReducerInfo)
->as<Map<Var, ReducerInfo>>()
.value();

    if (!result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
      return IRMutatorWithAnalyzer::VisitStmt_(op);
    }
    // the analyzer will be modified in PartitionLoop and VectorizeLoop
    // we need to save its state to prevent conflicted bindings
    auto saved_analyzer = analyzer_->Clone();
    For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
    auto root = tvm::ffi::GetRef<For>(op);
    // This check is a workaround to support T.Parallel for local buffers.
    // For example:
    //   for i in T.Parallel(1024):
    //       A_local[i] = A_global[i]
    // Here, A_local is a register-local buffer held independently by each
    // thread, so explicit thread binding is not required.
    bool store_into_local = false;
    PostOrderVisit(root, [&](const ObjectRef &obj) {
      if (const auto *store = obj.as<BufferStoreNode>()) {
        if (store->buffer.scope() == "local") {
          store_into_local = true;
        }
        // In a case like:
        //   for i in T.Parallel(1024):
        //       A_local[i] = B_global[i]
        //       A_frag[i] = A_global[i]
        // an exception will be raised in Parallel::LayoutInference
      }
    });
    // This check is for loops that only manipulate "local" buffers, e.g.
    //   for i in T.Parallel(1024):
    //       A_local[i] = B_local[i]
    // Though this might be illegal, we use PostOrderVisit to detect whether
    // the loop only manipulates "local" buffers, which indicates register
    // usage and justifies skipping thread binding.
    bool local_register_only = true;
    PostOrderVisit(root, [&](const ObjectRef &obj) {
      if (const auto *store = obj.as<BufferStoreNode>()) {
        if (store->buffer.scope() != "local") {
          local_register_only = false;
        }
      } else if (const auto *load = obj.as<BufferLoadNode>()) {
        if (load->buffer.scope() != "local") {
          local_register_only = false;
        }
      }
    });

    auto loop_layout = result_.for_map[root];
    // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
    // NOTE(lei): a bit ugly, we should rethink this part in the future.
    bool parallel_loop =
        !skip_thread_partition_ && !local_register_only && !store_into_local;

    if (parallel_loop) {
      for_node =
          PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
    }
    // If no thread bindings are provided, partition the loop
    bool has_non_local = false;
    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
      if (const auto *load = obj.as<BufferLoadNode>()) {
        String scope = load->buffer.scope();
        if (scope != "local" && scope != "local.fragment") {
          has_non_local = true;
        }
      } else if (const auto *store = obj.as<BufferStoreNode>()) {
        String scope = store->buffer.scope();
        if (scope != "local" && scope != "local.fragment") {
          has_non_local = true;
        }
      }
    });
    // Workaround: if a reducer is present, don't vectorize the loop.
    // The best solution would be to isolate the reduction axis from
    // vectorization.
    bool has_reducer = false;
    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
      if (!has_reducer)
        if (const auto *store = obj.as<BufferStoreNode>()) {
          has_reducer = reducer_info.count(store->buffer->data) != 0;
        }
    });

    // If a cast operation exists, vectorization may still be required
    bool has_cast_operations = false;
    PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
      if (const auto *cast = obj.as<CastNode>()) {
        // Check if this is a non-reducer store with a Cast operation
        DataType src_type = cast->value.dtype();
        DataType dst_type = cast->dtype;
        bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
                      src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
        bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
                      dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
        if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
          has_cast_operations = true;
        }
      }
    });
Comment on lines +1175 to +1190

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Search for TargetIsCuda implementation
rg -n "TargetIsCuda" --type=cpp -B2 -A3

Repository: tile-ai/tilelang

Length of output: 5825


🏁 Script executed:

#!/bin/bash
# Search for TargetIsCuda in the codebase
echo "=== Searching for TargetIsCuda implementation ===" 
rg -n "bool TargetIsCuda|TargetIsCuda\(" --type=cpp -B2 -A5 | head -100

echo -e "\n=== Searching for PostOrderVisit calls in layout_inference.cc ===" 
rg -n "PostOrderVisit" src/transform/layout_inference.cc -B2 -A5

echo -e "\n=== Checking if Target context is set before the function ===" 
rg -n "EnterWithScope|ExitWithScope|Target::" src/transform/layout_inference.cc -B1 -A1

Repository: tile-ai/tilelang

Length of output: 7736


🏁 Script executed:

#!/bin/bash
# Check if Target context is explicitly set in this function or nearby
echo "=== Checking function context and Target setup ===" 
rg -n "Target\(\)" src/transform/layout_inference.cc -B5 -A5 | head -80

# Look for the function signature
rg -n "bool LayoutInferencePass|AnalyzeVectorization" src/transform/layout_inference.cc -B2 -A15 | head -100

Repository: tile-ai/tilelang

Length of output: 110


Check Target context is available; consolidate consecutive tree traversals.

Line 1186 calls TargetIsCuda(Target::Current()) without verifying the Target context stack is populated. Target::Current() with default parameters returns an undefined Target if the context stack is empty, which could cause GetTargetDeviceType() to return an unexpected value. If this pass doesn't explicitly establish a Target scope, vectorization decisions may be made based on undefined target information.

Additionally, lines 1152, 1168, and 1177 perform three separate PostOrderVisit traversals on for_node->body. These can be merged into a single traversal to collect all conditions (has_non_local, has_reducer, has_cast_operations) in one pass.

🤖 Prompt for AI Agents
In src/transform/layout_inference.cc around lines 1175-1190, avoid calling
TargetIsCuda(Target::Current()) when no Target is on the context stack and
consolidate the three PostOrderVisit traversals into one: fetch the current
Target into a local variable and only call TargetIsCuda if that Target is
defined (treat undefined as non-CUDA), then perform a single PostOrderVisit over
for_node->body that sets has_non_local, has_reducer and has_cast_operations
within the same traversal (check cast src/dst types and use the guarded
TargetIsCuda check when deciding has_cast_operations).
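
A minimal sketch of the fix this comment describes, under the assumptions stated here and not the merged implementation: fetch the current Target once (treating an undefined Target as non-CUDA) and collect has_non_local, has_reducer, and has_cast_operations in a single PostOrderVisit. The names LoopTraits and CollectLoopTraits are hypothetical; ReducerInfo, TargetIsCuda, and the DataType helpers are assumed to be the same ones used in the diff above.

#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

using namespace tvm;
using namespace tvm::tir;

// Hypothetical holder for the three flags the pass derives from the loop body.
struct LoopTraits {
  bool has_non_local = false;
  bool has_reducer = false;
  bool has_cast_operations = false;
};

// `reducer_info` is the map the surrounding code builds from attr::kReducerInfo;
// ReducerInfo and TargetIsCuda are the tilelang definitions assumed to be in scope.
LoopTraits CollectLoopTraits(const Stmt &body,
                             const Map<Var, ReducerInfo> &reducer_info) {
  // Guard the Target lookup: with an empty context stack, Target::Current()
  // yields an undefined Target, which is conservatively treated as "not CUDA".
  Target target = Target::Current(/*allow_not_defined=*/true);
  bool target_is_cuda = target.defined() && TargetIsCuda(target);

  LoopTraits traits;
  PostOrderVisit(body, [&](const ObjectRef &obj) {
    if (const auto *load = obj.as<BufferLoadNode>()) {
      String scope = load->buffer.scope();
      if (scope != "local" && scope != "local.fragment")
        traits.has_non_local = true;
    } else if (const auto *store = obj.as<BufferStoreNode>()) {
      String scope = store->buffer.scope();
      if (scope != "local" && scope != "local.fragment")
        traits.has_non_local = true;
      if (reducer_info.count(store->buffer->data))
        traits.has_reducer = true;
    } else if (const auto *cast = obj.as<CastNode>()) {
      // Same dtype predicate as the diff: float, bfloat, or fp8 on both sides.
      auto vec_friendly = [](DataType t) {
        return t.is_float() || t.is_bfloat() || t.is_float8_e4m3() ||
               t.is_float8_e5m2();
      };
      if (vec_friendly(cast->value.dtype()) && vec_friendly(cast->dtype) &&
          target_is_cuda)
        traits.has_cast_operations = true;
    }
  });
  return traits;
}

In the pass, a single call such as auto traits = CollectLoopTraits(for_node->body, reducer_info); would replace the three separate traversals, and the vectorization decision becomes (traits.has_non_local || traits.has_cast_operations) && !traits.has_reducer. Whether an undefined Target should disable the cast-driven vectorization or instead require callers to establish a Target scope is a design choice left to the maintainers.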


    if ((has_non_local || has_cast_operations) && !has_reducer) {
      for_node = VectorizeLoop(for_node, saved_analyzer.get());
    }

    if (result_.predicate_map.count(root) && parallel_loop) {
      return IfThenElse(result_.predicate_map[root], for_node);
    } else {
      return for_node;
    }
  }

Stmt VisitStmt_(const AttrStmtNode *op) final {