def get_stack_str(msg: str, stacklevel: int = 1) -> str:
    """Append the builder's recorded file/line stack to an assert message.

    Args:
        msg: Message to prefix the stack with; may be empty.
        stacklevel: Forwarded unchanged to ``Builder.get_fileline_stack``.

    Returns:
        ``msg``, a newline, then one ``" at <file>:<line> in <name>"`` line
        per entry returned by the builder, in the order returned.
    """
    stack = Builder.current().get_fileline_stack(stacklevel)
    # Render all frames in one pass rather than growing the string with `+=`.
    frames = "".join(f" at {fileline}:{lineno} in {macro_name}\n" for fileline, lineno, macro_name in stack)
    return msg + "\n" + frames


@macro
def device_assert(condition: tir.PrimExpr, msg: str = "", no_stack_info: bool = False):
    """
    Device-side assert emulation.

    Emits a device-side assert call on CUDA targets when CUDA is available.
    The assert is always enabled and cannot be disabled at runtime.
    Nothing is emitted when ``_IS_CUDA_AVAILABLE`` is false.

    Args:
        condition: Expression handed to the assert intrinsic.
        msg: Optional message attached to the assert.
        no_stack_info: When False (the default) the assert is always emitted
            through ``tl.device_assert_with_msg`` and the message is extended
            with the file/line stack from ``get_stack_str``. When True no
            stack is attached: an empty ``msg`` uses the message-less
            ``tl.device_assert`` intrinsic, and a non-empty ``msg`` is passed
            through verbatim after a warning about the possible slowdown.
    """
    # NOTE(review): this body is processed by the DSL macro machinery, so the
    # branch structure is intentionally left exactly as written.
    if _IS_CUDA_AVAILABLE:
        if no_stack_info:
            if msg == "":
                T.call_intrin("void", tir.op.Op.get("tl.device_assert"), condition)
            else:
                warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
                T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, msg)
        else:
            # stacklevel=2 makes the builder drop its innermost entry --
            # presumably the position inside this macro itself; confirm.
            T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, get_stack_str(msg, stacklevel=2))
class SpanAttacher(ast.NodeTransformer):
    """Insert a ``__tb.set_fileline(...)`` call in front of every statement.

    The generated call passes, in order: the variable named by
    ``filename_var``, the statement's own ``lineno`` as a literal, and the
    variable named by ``func_name_var``. Both variable names are emitted as
    bare identifiers, so they must be bound in the scope where the
    transformed code runs.
    """

    def __init__(self, filename_var: str, func_name_var: str):
        self.func_name_var = func_name_var
        self.filename_var = filename_var

    def visit(self, node: ast.AST):
        """Transform children first, then wrap ``node`` if it is a statement.

        Returns the visited node unchanged for anything that is not an
        ``ast.stmt`` carrying a ``lineno``; otherwise returns the quoted
        ``set_fileline`` statement(s) followed by the node itself.
        """
        visited = self.generic_visit(node)
        is_located_stmt = isinstance(visited, ast.stmt) and hasattr(visited, "lineno")
        if not is_located_stmt:
            return visited
        marker = quote(f"__tb.set_fileline({self.filename_var}, {visited.lineno}, {self.func_name_var})")
        return marker + [visited]
def set_fileline(self, filename: str, lineno: int, name: str):
    """Record the source position currently being processed.

    Args:
        filename: Stored as ``current_file``.
        lineno: Stored as ``current_line``.
        name: Stored as ``current_macro_name``.
    """
    self.current_file = filename
    self.current_line = lineno
    self.current_macro_name = name


def get_fileline_stack(self, stacklevel: int = 1) -> list[tuple[str, int, str]]:
    """Return the recorded ``(file, line, name)`` stack, outermost first.

    The full stack is ``macro_fileline_stack`` followed by the current
    position. ``stacklevel`` trims entries from the innermost end.

    Args:
        stacklevel: 1 keeps every entry, 2 drops the innermost one, and so
            on. Values below 1 behave like 1.

    Returns:
        A new list; ``macro_fileline_stack`` is not modified. The list is
        empty when ``stacklevel`` exceeds the stack depth.
    """
    stack = self.macro_fileline_stack + [(self.current_file, self.current_line, self.current_macro_name)]
    # Clamp at 0: a negative slice bound would wrap around and return frames
    # instead of the empty list an over-large stacklevel should produce.
    keep = max(len(stack) - stacklevel + 1, 0)
    return stack[:keep]