diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 1e065ed02..df0b1abc3 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -34,6 +34,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kASTPrintEnable, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String); TVM_REGISTER_PASS_CONFIG_OPTION(kDeviceCompileFlags, ffi::Array); diff --git a/src/op/builtin.h b/src/op/builtin.h index fd0bb22e2..ea0f5d985 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -55,6 +55,7 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma"; static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; static constexpr const char *kStorageRewriteDetectInplace = "tl.storage_rewrite_detect_inplace"; +static constexpr const char *kASTPrintEnable = "tl.ast_print_enable"; static constexpr const char *kLayoutVisualizationEnable = "tl.layout_visualization_enable"; static constexpr const char *kLayoutVisualizationFormats = diff --git a/tilelang/analysis/ast_printer.py b/tilelang/analysis/ast_printer.py index e634e0271..fe94505a5 100644 --- a/tilelang/analysis/ast_printer.py +++ b/tilelang/analysis/ast_printer.py @@ -1,23 +1,102 @@ from tvm import tir -from tvm.tir import PrimFunc +from tvm.tir import PyStmtExprVisitor, PrimFunc, Stmt + from tvm.tir.transform import prim_func_pass -from tvm.tir.stmt_functor import ir_transform + + +_child_fields = ["body", "block", "seq"] + +_stmt_line_limit = 140 +_middle_connector = "├── " +_last_connector = "└── " + +_normal_indent = " " * 4 +_seq_middle_indent = "|" + " " * 3 + + +@tir.functor.visitor +class _ASTPrintVisitor(PyStmtExprVisitor): + def __init__(self) -> None: + super().__init__() + self.indent: list[str] = [] + + def print_with_clip(self, s: str) -> None: + if len(s) > _stmt_line_limit: + s = s[:_stmt_line_limit] + "..." + print("".join(self.indent) + s) + + def print_stmt_brief(self, stmt: Stmt, prefix: str) -> None: + stmt_script = repr(stmt).splitlines()[0].split(" ")[0].strip() + self.print_with_clip(prefix + f"{stmt.__class__.__name__}: " + stmt_script) + + def visit_stmt(self, stmt: Stmt) -> None: + child_field_name: str = "" + + field_keys = stmt.__class__.__dict__.keys() + # Filter out private/built-in fields. + field_keys = [key for key in field_keys if not key.startswith("_")] + + for idx, key in enumerate(field_keys): + # For child fields, we'll handle them specially below instead of printing them in current line. + if key in _child_fields: + child_field_name = key + continue + + value = getattr(stmt, key, None) + if value is None: + continue + # Try to get its script representation. + value = repr(value) + + is_last_child = idx == len(field_keys) - 1 and not child_field_name + # Add tree-like connector + connector = _last_connector if is_last_child else _middle_connector + + # Every member + self.print_with_clip(connector + f"{key}: {value}") + + # Handle child fields + if child_field_name and hasattr(stmt, child_field_name): + child = getattr(stmt, child_field_name) + + if child_field_name != "seq": + prefix = _last_connector + f"{child_field_name}: " + self.print_stmt_brief(child, prefix) + self.indent.append(_normal_indent) + self.visit_stmt(child) + self.indent.pop() + else: + # Special output format for SeqStmt + for i, child_node in enumerate(child): + is_last_child = i == len(child) - 1 + prefix = (_last_connector if is_last_child else _middle_connector) + f"seq{i}: " + self.print_stmt_brief(child_node, prefix) + self.indent.append(_normal_indent if is_last_child else _seq_middle_indent) + self.visit_stmt(child_node) + self.indent.pop() def ASTPrinter(): """ - Print the AST of a given tilelang module for debugging. - """ + A visitor pass that renders the TileLang AST hierarchy in a visual tree format. - def pre_visit(statement: tir.Stmt) -> None: - """ - Pre-order visitor to print all visited statements. - """ + Comparing with TL script, this printer is more suitable for debugging + and understanding the internal structure of TensorIR, like the class structure of + each node and their connections. - print(f"Visiting statement: {type(statement)}, {statement}") + This printer generates a human-readable, tree-structured representation of the + Abstract Syntax Tree (AST). It uses ASCII/Unicode connectors to visualize + parent-child relationships, making it easier to inspect nested structures + (e.g., loops, blocks, scopes) and verify compiler transformations. + """ def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc: - new_body = ir_transform(func.body, pre_visit, None) - return func.with_body(new_body) + print(f"PrimFunc(params={func.params}, ret_type={func.ret_type}, buffer_map={func.buffer_map}, attrs={func.attrs})") + func_body_prefix = _last_connector + "body=" + visitor = _ASTPrintVisitor() + visitor.print_stmt_brief(func.body, func_body_prefix) + visitor.visit_stmt(func.body) + visitor.indent.append(_normal_indent) + return func return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 0e72c837e..17b351d86 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -63,6 +63,12 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False)) +def should_enable_ast_print(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_AST_PRINT_ENABLE, False)) + + def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() @@ -112,8 +118,9 @@ def PreLowerSemanticCheck(mod: IRModule) -> None: Note: This is a validation-only pipeline of passes and does not modify or return the module. """ - # Debug - # tilelang.analysis.ASTPrinter()(mod) + # Print AST for debugging purpose + if should_enable_ast_print(): + tilelang.analysis.ASTPrinter()(mod) # Check if there are any invalid nested loops. tilelang.analysis.NestedLoopChecker()(mod) # Check if there are any invalid symbolic T.Parallel + fragment access. diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index b42ccd7ed..29fd6e536 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -84,6 +84,9 @@ class PassConfigKey(str, Enum): TL_FORCE_LET_INLINE = "tl.force_let_inline" """Force TileLang to inline let bindings during simplification. Default: False""" + TL_AST_PRINT_ENABLE = "tl.ast_print_enable" + """Enable TIR AST printing for debugging purposes. Default: False""" + TL_LAYOUT_VISUALIZATION_ENABLE = "tl.layout_visualization_enable" """Enable layout inference visualization. Default: False"""