Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kASTPrintEnable, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String);
TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array<ffi::String>);
Expand Down
1 change: 1 addition & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
static constexpr const char *kStorageRewriteDetectInplace =
"tl.storage_rewrite_detect_inplace";
static constexpr const char *kASTPrintEnable = "tl.ast_print_enable";
static constexpr const char *kLayoutVisualizationEnable =
"tl.layout_visualization_enable";
static constexpr const char *kLayoutVisualizationFormats =
Expand Down
101 changes: 90 additions & 11 deletions tilelang/analysis/ast_printer.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,102 @@
from tvm import tir
from tvm.tir import PrimFunc
from tvm.tir import PyStmtExprVisitor, PrimFunc, Stmt

from tvm.tir.transform import prim_func_pass
from tvm.tir.stmt_functor import ir_transform


_child_fields = ["body", "block", "seq"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Traverse missing statement children in AST printer

The new tree printer only recurses into fields listed in _child_fields (currently just body, block, seq). TIR statement nodes like IfThenElse use then_case/else_case, and Block can have init, which are statements but won’t be visited or expanded, so whole subtrees disappear from the tree output. This is a regression from the previous ir_transform traversal, and it can mislead debugging whenever kernels include conditionals or other statement children not named body/block/seq. Consider expanding the child field list or delegating traversal to PyStmtExprVisitor for all statement children.

Useful? React with 👍 / 👎.


_stmt_line_limit = 140
_middle_connector = "├── "
_last_connector = "└── "

_normal_indent = " " * 4
_seq_middle_indent = "|" + " " * 3


@tir.functor.visitor
class _ASTPrintVisitor(PyStmtExprVisitor):
def __init__(self) -> None:
super().__init__()
self.indent: list[str] = []

def print_with_clip(self, s: str) -> None:
if len(s) > _stmt_line_limit:
s = s[:_stmt_line_limit] + "..."
print("".join(self.indent) + s)

def print_stmt_brief(self, stmt: Stmt, prefix: str) -> None:
stmt_script = repr(stmt).splitlines()[0].split(" ")[0].strip()
self.print_with_clip(prefix + f"{stmt.__class__.__name__}: " + stmt_script)

def visit_stmt(self, stmt: Stmt) -> None:
child_field_name: str = ""

field_keys = stmt.__class__.__dict__.keys()
# Filter out private/built-in fields.
field_keys = [key for key in field_keys if not key.startswith("_")]

for idx, key in enumerate(field_keys):
# For child fields, we'll handle them specially below instead of printing them in current line.
if key in _child_fields:
child_field_name = key
continue

value = getattr(stmt, key, None)
if value is None:
continue
# Try to get its script representation.
value = repr(value)

is_last_child = idx == len(field_keys) - 1 and not child_field_name
# Add tree-like connector
connector = _last_connector if is_last_child else _middle_connector

# Every member
self.print_with_clip(connector + f"{key}: {value}")

# Handle child fields
if child_field_name and hasattr(stmt, child_field_name):
child = getattr(stmt, child_field_name)

if child_field_name != "seq":
prefix = _last_connector + f"{child_field_name}: "
self.print_stmt_brief(child, prefix)
self.indent.append(_normal_indent)
self.visit_stmt(child)
self.indent.pop()
else:
# Special output format for SeqStmt
for i, child_node in enumerate(child):
is_last_child = i == len(child) - 1
prefix = (_last_connector if is_last_child else _middle_connector) + f"seq{i}: "
self.print_stmt_brief(child_node, prefix)
self.indent.append(_normal_indent if is_last_child else _seq_middle_indent)
self.visit_stmt(child_node)
self.indent.pop()


def ASTPrinter():
"""
Print the AST of a given tilelang module for debugging.
"""
A visitor pass that renders the TileLang AST hierarchy in a visual tree format.

def pre_visit(statement: tir.Stmt) -> None:
"""
Pre-order visitor to print all visited statements.
"""
Comparing with TL script, this printer is more suitable for debugging
and understanding the internal structure of TensorIR, like the class structure of
each node and their connections.

print(f"Visiting statement: {type(statement)}, {statement}")
This printer generates a human-readable, tree-structured representation of the
Abstract Syntax Tree (AST). It uses ASCII/Unicode connectors to visualize
parent-child relationships, making it easier to inspect nested structures
(e.g., loops, blocks, scopes) and verify compiler transformations.
"""

def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc:
new_body = ir_transform(func.body, pre_visit, None)
return func.with_body(new_body)
print(f"PrimFunc(params={func.params}, ret_type={func.ret_type}, buffer_map={func.buffer_map}, attrs={func.attrs})")
func_body_prefix = _last_connector + "body="
visitor = _ASTPrintVisitor()
visitor.print_stmt_brief(func.body, func_body_prefix)
visitor.visit_stmt(func.body)
visitor.indent.append(_normal_indent)
return func

return prim_func_pass(pass_fn, opt_level=0)
11 changes: 9 additions & 2 deletions tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool:
return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))


def should_enable_ast_print(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_AST_PRINT_ENABLE, False))


def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
Expand Down Expand Up @@ -112,8 +118,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:
Note: This is a validation-only pipeline of passes and does not modify or return the module.
"""

# Debug
# tilelang.analysis.ASTPrinter()(mod)
# Print AST for debugging purpose
if should_enable_ast_print():
tilelang.analysis.ASTPrinter()(mod)
# Check if there are any invalid nested loops.
tilelang.analysis.NestedLoopChecker()(mod)
# Check if there are any invalid symbolic T.Parallel + fragment access.
Expand Down
3 changes: 3 additions & 0 deletions tilelang/transform/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class PassConfigKey(str, Enum):
TL_FORCE_LET_INLINE = "tl.force_let_inline"
"""Force TileLang to inline let bindings during simplification. Default: False"""

TL_AST_PRINT_ENABLE = "tl.ast_print_enable"
"""Enable TIR AST printing for debugging purposes. Default: False"""

TL_LAYOUT_VISUALIZATION_ENABLE = "tl.layout_visualization_enable"
"""Enable layout inference visualization. Default: False"""

Expand Down
Loading