Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions src/transform/inject_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ struct LetWrapper {
PrimExpr value;
};

/*!
 * \brief An IfThenElse peeled off the pipelined loop body whose condition
 * depends on the loop variable, either directly or through a let-bound
 * variable that was itself recorded as loop-dependent.
 *
 * Such a guard is recorded here instead of being re-wrapped around the whole
 * pipeline; the rewriter later re-creates it around the body of each emitted
 * block after substituting the loop variable in the condition.
 */
struct IfWrapper {
/*! \brief Condition of the original IfThenElse (before substitution). */
PrimExpr condition;
/*! \brief Span of the original IfThenElse, reused for the re-created node. */
Span span;
};

/*!
* \brief Collector to find all buffers used in a statement.
*
Expand Down Expand Up @@ -303,16 +308,20 @@ class PipelineRewriter : public StmtExprMutator {
* \param pipeline_loop The original loop to be software pipelined.
* \param pipeline_info The pipeline annotation information.
* \param loop_var_let_wrappers Let wrappers that depend on the loop var.
* \param loop_var_if_wrappers If wrappers with conditions that depend on
* the loop var.
*/
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
const Array<Buffer> &pipeline_allocs,
const Array<Buffer> &local_allocs, const For &pipeline_loop,
const PipelineInfo &pipeline_info,
const std::vector<LetWrapper> &loop_var_let_wrappers)
const std::vector<LetWrapper> &loop_var_let_wrappers,
const std::vector<IfWrapper> &loop_var_if_wrappers)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs), local_allocs_(local_allocs),
pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info),
loop_var_let_wrappers_(loop_var_let_wrappers) {}
loop_var_let_wrappers_(loop_var_let_wrappers),
loop_var_if_wrappers_(loop_var_if_wrappers) {}

Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the
Expand Down Expand Up @@ -860,6 +869,22 @@ class PipelineRewriter : public StmtExprMutator {
n->body = inner;
}

// Similarly, handle If-wrappers whose conditions depend on the
// pipeline loop var.
if (!loop_var_if_wrappers_.empty()) {
BlockNode *n = new_block.CopyOnWrite();
Stmt inner = n->body;
for (auto it = loop_var_if_wrappers_.rbegin();
it != loop_var_if_wrappers_.rend(); ++it) {
const auto &iw = *it;
PrimExpr substituted_condition =
Substitute(iw.condition,
{{pipeline_loop_->loop_var, normalized_access_index}});
inner = IfThenElse(substituted_condition, inner, Stmt(), iw.span);
}
n->body = inner;
}
Comment on lines 872 to 885
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the actual code at the mentioned lines
cat -n src/transform/inject_pipeline.cc | sed -n '850,900p'

Repository: tile-ai/tilelang

Length of output: 2522


🏁 Script executed:

# Also check what's before line 872 to understand let-wrapper handling
cat -n src/transform/inject_pipeline.cc | sed -n '859,885p'

Repository: tile-ai/tilelang

Length of output: 1380


🏁 Script executed:

# Search for the data structures used: loop_var_if_wrappers_, loop_var_let_wrappers_
rg -n "loop_var_if_wrappers_|loop_var_let_wrappers_" src/transform/inject_pipeline.cc | head -30

Repository: tile-ai/tilelang

Length of output: 622


🏁 Script executed:

# Find the test case mentioned
find . -name "test_tilelang_issue_1263.py" -o -name "*issue_1263*"

Repository: tile-ai/tilelang

Length of output: 110


🏁 Script executed:

# Search for dependent_vars logic around line 1115
cat -n src/transform/inject_pipeline.cc | sed -n '1100,1130p'

Repository: tile-ai/tilelang

Length of output: 1783


🏁 Script executed:

# Check if there are any guards or conditions that prevent dependent conditions
rg -n "dependent" src/transform/inject_pipeline.cc -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 2627


🏁 Script executed:

cat -n testing/python/issue/test_tilelang_issue_1263.py

Repository: tile-ai/tilelang

Length of output: 2370


🏁 Script executed:

# Search for where loop_var_if_wrappers and loop_var_let_wrappers are used together
rg -n "loop_var_if_wrappers|loop_var_let_wrappers" src/transform/inject_pipeline.cc -B 1 -A 1 | head -80

Repository: tile-ai/tilelang

Length of output: 2838


🏁 Script executed:

# Check if there's any special handling to ensure condition substitution happens correctly
rg -n "substituted_condition" src/transform/inject_pipeline.cc -B 3 -A 3

Repository: tile-ai/tilelang

Length of output: 526


🏁 Script executed:

# Check LetWrapper and IfWrapper structure definitions
rg -n "struct LetWrapper|struct IfWrapper" src/transform/inject_pipeline.cc -A 5

Repository: tile-ai/tilelang

Length of output: 207


🏁 Script executed:

# Check if there's any post-processing of conditions or variables after substitution
rg -n "UsesVar|HasVar|GetVars" src/transform/inject_pipeline.cc -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 550


🏁 Script executed:

# Check how Substitute is implemented and what it does with nested variables
rg -n "class Substitute|struct Substitute" --type cpp | head -5

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look for any variable renaming or binding adjustments in the pipeline building
rg -n "RenameVar|VarReplacer|VarSubstitute" src/transform/inject_pipeline.cc -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


Fix variable scoping when if-conditions depend on let-bound variables.

The sequential application of let-wrappers (lines 859-870) followed by if-wrappers (lines 872-885) creates incorrect nesting. When a condition depends on let-bound variables (detected at line 1115), the final structure becomes:

IfThenElse(id2 > 1,
  LetStmt(id2, ids2[id],
    LetStmt(id, ids[i], body)))

Here, the condition references `id2`, which is only bound inside the then-branch, and the value of `id2` references `id`, which is bound even deeper. Both therefore become free variables. The rewritten code should preserve the original structure:

LetStmt(id, ids[i],
  LetStmt(id2, ids2[id],
    IfThenElse(id2 > 1, body)))

The issue affects the test case `_test_kernel_if_cond` (line 43), where the condition `if id2 > 1` depends on variables that are transitively derived from the loop variable.

Use a unified wrapper vector of variant type applied in collection order to preserve original nesting:

🔧 Suggested approach
struct CondWrapper {
  std::variant<LetWrapper, IfWrapper> data;
};
std::vector<CondWrapper> loop_var_wrappers_;

// Apply in reverse to preserve original nesting
for (auto it = loop_var_wrappers_.rbegin(); it != loop_var_wrappers_.rend(); ++it) {
  if (auto* lw = std::get_if<LetWrapper>(&it->data)) {
    // apply let
  } else if (auto* iw = std::get_if<IfWrapper>(&it->data)) {
    // apply if
  }
}
🤖 Prompt for AI Agents
In `@src/transform/inject_pipeline.cc` around lines 872 - 885, The current code
applies let-wrappers (loop_var_let_wrappers_) then if-wrappers
(loop_var_if_wrappers_) separately which reverses nesting and frees let-bound
vars used by if conditions; replace these two separate vectors with a single
ordered vector of variant wrappers (e.g., CondWrapper holding either LetWrapper
or IfWrapper, stored in loop_var_wrappers_) and, when injecting, iterate
loop_var_wrappers_ in reverse and apply each element by type (apply let
semantics for LetWrapper and create IfThenElse with Substitute on iw.condition
using pipeline_loop_->loop_var and normalized_access_index for IfWrapper) so the
original source nesting is preserved and conditions see their let-bound vars.
Ensure you update usages of loop_var_if_wrappers_ and loop_var_let_wrappers_ to
the unified loop_var_wrappers_ and maintain the existing Substitute logic and
iw.span handling.


if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index;
Expand Down Expand Up @@ -924,6 +949,7 @@ class PipelineRewriter : public StmtExprMutator {
Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states;
std::vector<LetWrapper> loop_var_let_wrappers_;
std::vector<IfWrapper> loop_var_if_wrappers_;
};

/*!
Expand Down Expand Up @@ -1055,6 +1081,7 @@ class PipelineInjector : private StmtExprMutator {
const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
std::vector<LetWrapper> loop_var_let_wrappers;
std::vector<IfWrapper> loop_var_if_wrappers;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
Any node = attr->node;
String attr_key = attr->attr_key;
Expand All @@ -1077,12 +1104,33 @@ class PipelineInjector : private StmtExprMutator {
ICHECK(!if_then_else->else_case.defined())
<< "InjectSoftwarePipeline: Can't handle the body of the loop "
"because the IfThenElse node has an else branch";
PrimExpr condition = if_then_else->condition;
Span span = if_then_else->span;
rewrap_fns.emplace_back(
[condition = std::move(condition), span](Stmt body) -> Stmt {
return IfThenElse(condition, body, Stmt(), span);

// Check if the condition depends on the loop variable or any
// transitively dependent variables (similar to LetStmt handling)
std::unordered_set<const VarNode *> dependent_vars;
dependent_vars.insert(op->loop_var.get());
for (const auto &lw : loop_var_let_wrappers) {
dependent_vars.insert(lw.var.get());
}
bool condition_depends_on_loop = UsesVar(
if_then_else->condition, [&dependent_vars](const VarNode *vn) {
return dependent_vars.count(vn) > 0;
});

if (condition_depends_on_loop) {
// If condition depends on loop variable, we need to push it inside
// each pipeline stage with proper substitution
loop_var_if_wrappers.push_back(
{if_then_else->condition, if_then_else->span});
} else {
// Otherwise, safe to wrap outside the pipeline
PrimExpr condition = if_then_else->condition;
Span span = if_then_else->span;
rewrap_fns.emplace_back(
[condition = std::move(condition), span](Stmt body) -> Stmt {
return IfThenElse(condition, body, Stmt(), span);
});
}
current = if_then_else->then_case;
continue;
}
Expand Down Expand Up @@ -1235,7 +1283,8 @@ class PipelineInjector : private StmtExprMutator {

PipelineRewriter rewriter(buffer_data_to_buffer_, pipeline_allocs,
local_allocs, tvm::ffi::GetRef<For>(op),
pipeline_info, loop_var_let_wrappers);
pipeline_info, loop_var_let_wrappers,
loop_var_if_wrappers);
Stmt pipeline = rewriter.BuildPipeline();

// Store the buffer remapping for updating outer block alloc_buffers
Expand Down
23 changes: 22 additions & 1 deletion testing/python/issue/test_tilelang_issue_1210.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,35 @@ def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)):
return fwd_main


# Build a prim_func whose pipelined loop contains an `if` whose condition is
# the loop variable itself (`i > 1`). Returns the (uncompiled) prim_func;
# test_make_packed_api_no_free_loop_var compiles it.
def _make_kernel_if_cond(M, N):
dtype = T.bfloat16

@T.prim_func
def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)):
with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype)

# Regression for a bug where InjectSoftwarePipeline left the loop
# variable as a free var, causing MakePackedAPI to fail
for i in T.Pipelined(4, num_stages=1):
# The guard reads the pipeline loop variable directly.
if i > 1:
_id = ids[i]
T.copy(KV[_id, :], A)
T.clear(B)

return fwd_main


def test_make_packed_api_no_free_loop_var():
func = _make_kernel(4, 4)
func, func_if_cond = _make_kernel(4, 4), _make_kernel_if_cond(4, 4)
# Keep warp-specialization/TMA disabled to match the original repro
cfg = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}
tilelang.compile(func, pass_configs=cfg)
tilelang.compile(func_if_cond, pass_configs=cfg)


if __name__ == "__main__":
Expand Down
72 changes: 52 additions & 20 deletions testing/python/issue/test_tilelang_issue_1263.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,63 @@
import tilelang.language as T


def test_issue_1263_pipeline_no_consumer():
def test_kernel(M, N):
dtype = "bfloat16"

@T.prim_func
def fwd_main(
KV: T.Tensor((M, N), dtype),
ids: T.Tensor((4,), "int32"),
ids2: T.Tensor((4,), "int32"),
):
with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype)

for i in T.Pipelined(4, num_stages=1):
id = ids[i]
id2 = ids2[id]
# Build a prim_func with a pipelined loop in which the row copied from KV is
# selected through two chained index loads (`ids[i]`, then `ids2[id]`), so
# both loaded values depend on the loop variable. Returns the prim_func.
def _test_kernel(M, N):
dtype = "bfloat16"

@T.prim_func
def fwd_main(
KV: T.Tensor((M, N), dtype),
ids: T.Tensor((4,), "int32"),
ids2: T.Tensor((4,), "int32"),
):
with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype)

for i in T.Pipelined(4, num_stages=1):
# `id` depends on the loop variable; `id2` depends on it through `id`.
id = ids[i]
id2 = ids2[id]
T.copy(KV[id2, :], A)
T.clear(B)

return fwd_main


# Build a prim_func with a pipelined loop that contains an `if id2 > 1`, where
# `id2` reaches the loop variable only through the chained loads
# `id = ids[i]` and `id2 = ids2[id]`. Returns the prim_func.
def _test_kernel_if_cond(M, N):
dtype = "bfloat16"

@T.prim_func
def fwd_main(
KV: T.Tensor((M, N), dtype),
ids: T.Tensor((4,), "int32"),
ids2: T.Tensor((4,), "int32"),
):
with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype)

for i in T.Pipelined(4, num_stages=1):
id = ids[i]
id2 = ids2[id]
# The guard depends on the loop variable only transitively, via id2.
if id2 > 1:
T.copy(KV[id2, :], A)
T.clear(B)

return fwd_main
return fwd_main

tilelang.compile(test_kernel(1024, 1024))

def test_issue_1263_pipeline_no_consumer():
tilelang.compile(_test_kernel(1024, 1024))
tilelang.compile(
_test_kernel(1024, 1024),
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},
)
tilelang.compile(_test_kernel_if_cond(1024, 1024))
tilelang.compile(
test_kernel(1024, 1024),
_test_kernel_if_cond(1024, 1024),
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
Expand Down
Loading