Skip to content

[Bugfix] Support loop-dependent conditions in IfThenElse within T.Pipelined#1799

Merged
SiriusNEO merged 3 commits intotile-ai:mainfrom
ljwljwljwljw:fix-pipelined-loop-cond
Feb 6, 2026
Merged

[Bugfix] Support loop-dependent conditions in IfThenElse within T.Pipelined#1799
SiriusNEO merged 3 commits intotile-ai:mainfrom
ljwljwljwljw:fix-pipelined-loop-cond

Conversation

@ljwljwljwljw
Copy link
Contributor

@ljwljwljwljw ljwljwljwljw commented Feb 5, 2026

This PR extends the pipeline injection transform to handle IfThenElse statements whose conditions depend on the pipeline loop variable, fixing a regression similar to #1210 but for conditional statements.

Problem

When an if statement with a condition depending on the pipeline loop variable (e.g., if i > 1:) was placed inside a T.Pipelined block, the condition may be incorrectly hoisted outside the pipeline loop, causing the loop variable to become undefined and resulting in the error:

InternalError: variables [i] are used, but are not passed in as API arguments

The IfThenElse handling in inject_pipeline.cc always added conditions to rewrap_fns (wrapping them outside the pipeline), without checking if the condition uses the loop variable. This is the same bug that was fixed for LetStmt in PR #1212, but was not addressed for conditional statements

Example:

def _make_kernel_if_cond(M, N):
    dtype = T.bfloat16

    @T.prim_func
    def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)):
        with T.Kernel(4, threads=1):
            A = T.alloc_shared([N], dtype)
            B = T.alloc_shared([N], dtype)

            for i in T.Pipelined(4, num_stages=1):
                if i > 1: # <-this line!!!
                    _id = ids[i]
                    T.copy(KV[_id, :], A)
                    T.clear(B)
    return fwd_main

Before this PR, above code would be transformed as:

// WRONG: Condition evaluated outside, i is undefined
if (i > 1)
  pipeline_loop:
    for i in [0, 1, 2, 3]:
      body

Solution

This PR applies the same strategy used for LetStmt to IfThenElse conditions:

  1. Introduced IfWrapper struct to track if conditions that depend on the loop variable
  2. Added dependency detection that checks if an if condition uses:
  • The pipeline loop variable directly, OR
  • Any variable transitively dependent on the loop variable
  1. Per-iteration substitution: Loop-dependent conditions are pushed inside each pipeline stage with the loop variable substituted for that iteration

Key Changes

inject_pipeline.cc:

  • Added IfWrapper struct to store if conditions and spans
  • Extended PipelineRewriter to accept and process loop_var_if_wrappers
  • Updated IfThenElse detection logic to check for loop variable dependencies
  • Apply if wrappers inside each rewritten pipeline block with proper substitution

Related Issues/PRs:

#1210
#1212
#1263

Summary by CodeRabbit

  • New Features

    • Pipeline loop compilation now wraps and stages conditional branches that depend on the pipeline loop variable, preserving correct per-stage semantics and enabling conditional control flow inside pipelined regions.
  • Tests

    • Added and expanded tests exercising conditional execution inside pipelined loops across multiple compilation configurations.

This commit applies the same strategy used for LetStmt to IfThenElse conditions:

1.Introduced IfWrapper struct to track if conditions that depend on the loop variable
2.Added dependency detection that checks if an if condition uses:
- The pipeline loop variable directly, OR
- Any variable transitively dependent on the loop variable
3.Loop-dependent conditions are pushed inside each pipeline stage with the
loop variable properly substituted for that iteration
@github-actions
Copy link

github-actions bot commented Feb 5, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

Walkthrough

Adds IfWrapper support to the pipeline injection flow: collects if-conditions depending on the pipeline loop variable, threads them through PipelineInjector → PipelineRewriter, and applies per-stage conditional substitutions during pipeline construction alongside existing LetWrapper handling.

Changes

Cohort / File(s) Summary
Pipeline injection core
src/transform/inject_pipeline.cc
Add IfWrapper struct; extend PipelineRewriter and PipelineInjector to store and propagate loop_var_if_wrappers; update constructor signatures and members; apply per-stage IfThenElse condition substitutions during loop-body rewrite and BuildPipeline.
Tests exercising conditionals
testing/python/issue/test_tilelang_issue_1210.py, testing/python/issue/test_tilelang_issue_1263.py
Add helper kernels with loop-var-dependent conditionals (_make_kernel_if_cond, _test_kernel_if_cond) and update tests to compile/run these variants to cover the new IfWrapper behavior.

Sequence Diagram(s)

sequenceDiagram
    participant Inj as PipelineInjector
    participant Coll as TraversalLogic
    participant Rw as PipelineRewriter
    participant Build as BuildPipeline

    Inj->>Coll: Analyze For-node body (collect LetWrapper & IfWrapper)
    Coll->>Coll: Identify IfThenElse nodes whose conditions depend on loop var
    Coll->>Coll: Track transitive LetStmt deps and group wrappers
    Coll->>Inj: Return loop_var_if_wrappers vector

    Inj->>Rw: Construct PipelineRewriter with loop_var_if_wrappers
    Rw->>Rw: Store loop_var_if_wrappers_ member

    Rw->>Build: Rewrite loop body per stage
    Build->>Build: Substitute IfWrapper conditions per stage context
    Build->>Build: Apply LetWrapper substitutions per stage
    Build->>Rw: Return transformed stages

    Rw->>Inj: Return rewritten pipeline
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Poem

🐰 Hopping through loops, I bind each if with care,

I trace the loop-var whispers hidden in the air,
Per-stage I tuck conditions tight,
So pipeline stages know when to light,
A rabbit's wrap — neat, conditional flair.

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 15.38% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and specifically describes the main bugfix: adding support for loop-dependent conditions in IfThenElse statements within T.Pipelined.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@src/transform/inject_pipeline.cc`:
- Around line 872-885: The current code applies let-wrappers
(loop_var_let_wrappers_) then if-wrappers (loop_var_if_wrappers_) separately
which reverses nesting and frees let-bound vars used by if conditions; replace
these two separate vectors with a single ordered vector of variant wrappers
(e.g., CondWrapper holding either LetWrapper or IfWrapper, stored in
loop_var_wrappers_) and, when injecting, iterate loop_var_wrappers_ in reverse
and apply each element by type (apply let semantics for LetWrapper and create
IfThenElse with Substitute on iw.condition using pipeline_loop_->loop_var and
normalized_access_index for IfWrapper) so the original source nesting is
preserved and conditions see their let-bound vars. Ensure you update usages of
loop_var_if_wrappers_ and loop_var_let_wrappers_ to the unified
loop_var_wrappers_ and maintain the existing Substitute logic and iw.span
handling.
🧹 Nitpick comments (1)
src/transform/inject_pipeline.cc (1)

1106-1131: Transitive dependency detection looks correct.

The logic properly detects conditions that depend on the loop variable either directly or transitively through let-bound variables. This mirrors the existing LetStmt handling.

Consider extracting the duplicate dependent_vars construction (also at lines 1144-1148) into a helper lambda for maintainability:

♻️ Optional refactor
+    auto build_dependent_vars = [&]() {
+      std::unordered_set<const VarNode *> vars;
+      vars.insert(op->loop_var.get());
+      for (const auto &lw : loop_var_let_wrappers) {
+        vars.insert(lw.var.get());
+      }
+      return vars;
+    };
+
     // In IfThenElse handling:
-    std::unordered_set<const VarNode *> dependent_vars;
-    dependent_vars.insert(op->loop_var.get());
-    for (const auto &lw : loop_var_let_wrappers) {
-      dependent_vars.insert(lw.var.get());
-    }
+    auto dependent_vars = build_dependent_vars();

Comment on lines 872 to 885
// Similarly, handle If-wrappers whose conditions depend on the
// pipeline loop var.
if (!loop_var_if_wrappers_.empty()) {
BlockNode *n = new_block.CopyOnWrite();
Stmt inner = n->body;
for (auto it = loop_var_if_wrappers_.rbegin();
it != loop_var_if_wrappers_.rend(); ++it) {
const auto &iw = *it;
PrimExpr substituted_condition = Substitute(
iw.condition, {{pipeline_loop_->loop_var, normalized_access_index}});
inner = IfThenElse(substituted_condition, inner, Stmt(), iw.span);
}
n->body = inner;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the actual code at the mentioned lines
cat -n src/transform/inject_pipeline.cc | sed -n '850,900p'

Repository: tile-ai/tilelang

Length of output: 2522


🏁 Script executed:

# Also check what's before line 872 to understand let-wrapper handling
cat -n src/transform/inject_pipeline.cc | sed -n '859,885p'

Repository: tile-ai/tilelang

Length of output: 1380


🏁 Script executed:

# Search for the data structures used: loop_var_if_wrappers_, loop_var_let_wrappers_
rg -n "loop_var_if_wrappers_|loop_var_let_wrappers_" src/transform/inject_pipeline.cc | head -30

Repository: tile-ai/tilelang

Length of output: 622


🏁 Script executed:

# Find the test case mentioned
find . -name "test_tilelang_issue_1263.py" -o -name "*issue_1263*"

Repository: tile-ai/tilelang

Length of output: 110


🏁 Script executed:

# Search for dependent_vars logic around line 1115
cat -n src/transform/inject_pipeline.cc | sed -n '1100,1130p'

Repository: tile-ai/tilelang

Length of output: 1783


🏁 Script executed:

# Check if there are any guards or conditions that prevent dependent conditions
rg -n "dependent" src/transform/inject_pipeline.cc -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 2627


🏁 Script executed:

cat -n testing/python/issue/test_tilelang_issue_1263.py

Repository: tile-ai/tilelang

Length of output: 2370


🏁 Script executed:

# Search for where loop_var_if_wrappers and loop_var_let_wrappers are used together
rg -n "loop_var_if_wrappers|loop_var_let_wrappers" src/transform/inject_pipeline.cc -B 1 -A 1 | head -80

Repository: tile-ai/tilelang

Length of output: 2838


🏁 Script executed:

# Check if there's any special handling to ensure condition substitution happens correctly
rg -n "substituted_condition" src/transform/inject_pipeline.cc -B 3 -A 3

Repository: tile-ai/tilelang

Length of output: 526


🏁 Script executed:

# Check LetWrapper and IfWrapper structure definitions
rg -n "struct LetWrapper|struct IfWrapper" src/transform/inject_pipeline.cc -A 5

Repository: tile-ai/tilelang

Length of output: 207


🏁 Script executed:

# Check if there's any post-processing of conditions or variables after substitution
rg -n "UsesVar|HasVar|GetVars" src/transform/inject_pipeline.cc -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 550


🏁 Script executed:

# Check how Substitute is implemented and what it does with nested variables
rg -n "class Substitute|struct Substitute" --type cpp | head -5

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look for any variable renaming or binding adjustments in the pipeline building
rg -n "RenameVar|VarReplacer|VarSubstitute" src/transform/inject_pipeline.cc -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 42


Fix variable scoping when if-conditions depend on let-bound variables.

The sequential application of let-wrappers (lines 859-870) followed by if-wrappers (lines 872-885) creates incorrect nesting. When a condition depends on let-bound variables (detected at line 1115), the final structure becomes:

IfThenElse(id2 > 1,
  LetStmt(id2, ids2[id],
    LetStmt(id, ids[i], body)))

Here, the condition references id2 which is bound inside the then-branch, and id2's value references id which is bound even deeper. Both become free variables. The original structure should preserve:

LetStmt(id, ids[i],
  LetStmt(id2, ids2[id],
    IfThenElse(id2 > 1, body)))

The issue affects test case _test_kernel_if_cond at line 43 where if id2 > 1 depends on transitively loop-dependent variables.

Use a unified wrapper vector of variant type applied in collection order to preserve original nesting:

🔧 Suggested approach
struct CondWrapper {
  std::variant<LetWrapper, IfWrapper> data;
};
std::vector<CondWrapper> loop_var_wrappers_;

// Apply in reverse to preserve original nesting
for (auto it = loop_var_wrappers_.rbegin(); it != loop_var_wrappers_.rend(); ++it) {
  if (auto* lw = std::get_if<LetWrapper>(&it->data)) {
    // apply let
  } else if (auto* iw = std::get_if<IfWrapper>(&it->data)) {
    // apply if
  }
}
🤖 Prompt for AI Agents
In `@src/transform/inject_pipeline.cc` around lines 872 - 885, The current code
applies let-wrappers (loop_var_let_wrappers_) then if-wrappers
(loop_var_if_wrappers_) separately which reverses nesting and frees let-bound
vars used by if conditions; replace these two separate vectors with a single
ordered vector of variant wrappers (e.g., CondWrapper holding either LetWrapper
or IfWrapper, stored in loop_var_wrappers_) and, when injecting, iterate
loop_var_wrappers_ in reverse and apply each element by type (apply let
semantics for LetWrapper and create IfThenElse with Substitute on iw.condition
using pipeline_loop_->loop_var and normalized_access_index for IfWrapper) so the
original source nesting is preserved and conditions see their let-bound vars.
Ensure you update usages of loop_var_if_wrappers_ and loop_var_let_wrappers_ to
the unified loop_var_wrappers_ and maintain the existing Substitute logic and
iw.span handling.

@SiriusNEO SiriusNEO merged commit af30ac2 into tile-ai:main Feb 6, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants