Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from e47e76 to 001022
22 changes: 17 additions & 5 deletions tilelang/language/print_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
It includes functionality to print variables, print values in buffers, conditionally execute debug prints and assert.
"""

from tilelang.language.v2.builder import Builder
from tvm import tir
from typing import Any
import tilelang.language as T
Expand Down Expand Up @@ -123,19 +124,30 @@ def print_local_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffe
_IS_CUDA_AVAILABLE = check_cuda_availability()


def get_stack_str(msg, stacklevel=1):
stack = Builder.current().get_fileline_stack(stacklevel)
msg = msg + "\n"
for fileline, lineno, macro_name in stack:
msg += f" at {fileline}:{lineno} in {macro_name}\n"
return msg


@macro
def device_assert(condition: tir.PrimExpr, msg: str = ""):
def device_assert(condition: tir.PrimExpr, msg: str = "", no_stack_info=False):
"""
Device-side assert emulation.
Emits a device-side assert call on CUDA targets when CUDA is available.
The assert is always enabled and cannot be disabled at runtime.
"""
if _IS_CUDA_AVAILABLE:
if msg == "":
T.call_intrin("void", tir.op.Op.get("tl.device_assert"), condition)
if no_stack_info:
if msg == "":
T.call_intrin("void", tir.op.Op.get("tl.device_assert"), condition)
else:
warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, msg)
else:
warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, msg)
T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, get_stack_str(msg, stacklevel=2))


def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
Expand Down
29 changes: 23 additions & 6 deletions tilelang/language/v2/ast.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import ast
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Generic, Any, Literal, TypeVar
from contextlib import AbstractContextManager
from collections.abc import Iterable
Expand Down Expand Up @@ -249,11 +250,12 @@ def override(self, name: str):


class DSLMutator(ast.NodeTransformer):
def __init__(self, nonlocals: dict[str, Any], globals: dict[str, Any]):
def __init__(self, nonlocals: dict[str, Any], globals: dict[str, Any], filename: str):
self.tmp_counter = 0
self.nonlocals = nonlocals
self.globals = globals
self.extra_type_hints: dict[str, Any] = {}
self.filename = filename

def get_tmp(self) -> str:
name = f"__{self.tmp_counter}"
Expand Down Expand Up @@ -469,13 +471,17 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
node = self.generic_visit(node)
node.body = stmts + node.body
node.decorator_list.clear()
name = node.name
node = SpanAttacher("__tb_fl", "__tb_fn").visit(node)
return quote1(
f"def make_closure({', '.join(self.nonlocals.keys())}):\n"
f" def {node.name}(__tb):\n"
f" def {name}(__tb):\n"
f" __tb_fl = '{self.filename}'\n"
f" __tb_fn = '{name}'\n"
" range = __tb.override('range')\n"
" pass\n"
f" return {node.name}\n"
f" return {node.name}",
f" return {name}\n"
f" return {name}",
passes=[node],
)

Expand Down Expand Up @@ -573,6 +579,18 @@ def visit_Name(self, node: ast.Name):
return node


class SpanAttacher(ast.NodeTransformer):
def __init__(self, filename_var: str, func_name_var: str):
self.filename_var = filename_var
self.func_name_var = func_name_var

def visit(self, node: ast.AST):
node = self.generic_visit(node)
if isinstance(node, ast.stmt) and hasattr(node, "lineno"):
return quote(f"__tb.set_fileline({self.filename_var}, {node.lineno}, {self.func_name_var})") + [node]
return node


_P = ParamSpec("_P")


Expand Down Expand Up @@ -627,9 +645,8 @@ def mutate(func: Callable[_P, _T]) -> IRGenerator[_P, _T]:
# def bar(): x
# return bar
# ```
mut = DSLMutator(nonlocals, func.__globals__)
mut = DSLMutator(nonlocals, func.__globals__, Path(filename).name)
tree = mut.visit(tree)

make_closure = utils.get_compiled_object(
tree,
"make_closure",
Expand Down
16 changes: 16 additions & 0 deletions tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ def __init__(self):
self.out_tensor_cnt = 0
self.constexpr_var = set()
self.lazy_jit = False
self.current_file = "<unknown>"
self.current_line = 0
self.current_macro_name = "<unknown-macro>"
# stack to record caller fileline, not callee fileline
self.macro_fileline_stack: list[tuple[str, int, str]] = []

@classmethod
def current(cls) -> Self:
Expand Down Expand Up @@ -212,9 +217,11 @@ def macro(self, name=None, annotations=None):
# def bar():
# c = foo(1) # macro generates let y = x + 1
# d = c # d = c should lay inside frame of `let y = x + 1`
self.macro_fileline_stack.append((self.current_file, self.current_line, self.current_macro_name))
self.frames.append(MacroFrame())
yield
self.frames[pos] = ExitedMacroFrame()
self.macro_fileline_stack.pop()
self.name_inside_frame, self.macro_arg_annot = save

def get(self) -> PrimFunc:
Expand Down Expand Up @@ -675,6 +682,15 @@ def constexpr(self, name: str, dtype: str = "int32") -> Var:
self.constexpr_var.add(var)
return var

def set_fileline(self, filename: str, lineno: int, name: str):
self.current_file = filename
self.current_line = lineno
self.current_macro_name = name

def get_fileline_stack(self, stacklevel=1):
stack = self.macro_fileline_stack + [(self.current_file, self.current_line, self.current_macro_name)]
return stack[: len(stack) - stacklevel + 1]


_P = ParamSpec("_P")
_T = TypeVar("_T")
Expand Down
Loading