From ed31be7d55b0cc9a35e3e89a72cf9b7aff8a50c7 Mon Sep 17 00:00:00 2001 From: KEKE046 Date: Thu, 8 Jan 2026 06:44:23 +0000 Subject: [PATCH 01/13] [Feat] Add tilelang autodd for delta debugging --- tilelang/__init__.py | 135 +++-- tilelang/autodd.py | 1180 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1261 insertions(+), 54 deletions(-) create mode 100644 tilelang/autodd.py diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 420516173..8148555a7 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -7,6 +7,24 @@ from pathlib import Path +def _is_running_autodd() -> bool: + orig_argv = getattr(sys, "orig_argv", None) + if orig_argv is None: + return False + if "-mtilelang.autodd" in orig_argv: + return True + pos = orig_argv.index("-m") if "-m" in orig_argv else -1 + if pos != -1 and pos + 1 < len(orig_argv): + module_name = orig_argv[pos + 1] + if module_name == "tilelang.autodd" or module_name.startswith("tilelang.autodd."): + return True + return False + + +# check if we are running under AutoDD +_RUNNING_AUTODD = _is_running_autodd() + + def _compute_version() -> str: """Return the package version without being polluted by unrelated installs. @@ -65,7 +83,10 @@ def set_log_level(level): def _init_logger(): """Initialize the logger specific for this module with custom settings and a Tqdm-based handler.""" - from tqdm.auto import tqdm + try: + from tqdm.auto import tqdm + except ImportError: + tqdm = None class TqdmLoggingHandler(logging.Handler): """Custom logging handler that directs log output to tqdm progress bar to avoid interference.""" @@ -78,7 +99,8 @@ def emit(self, record): """Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted.""" try: msg = self.format(record) - tqdm.write(msg) + if tqdm is not None: + tqdm.write(msg) except Exception: self.handleError(record) @@ -93,7 +115,10 @@ def emit(self, record): set_log_level("INFO") -_init_logger() +# Skip logger initialization when running under AutoDD +if not _RUNNING_AUTODD: + _init_logger() + del _init_logger @@ -116,56 +141,58 @@ def lazy_init(self, name, mode=ctypes.DEFAULT_MODE, *args, **kwargs): ctypes.CDLL.__init__ = old_init -with _lazy_load_lib(): - from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401 - from .env import env as env # noqa: F401 - - import tvm - import tvm.base # noqa: F401 - from tvm import DataType # noqa: F401 - - # Setup tvm search path before importing tvm - from . import libinfo - - def _load_tile_lang_lib(): - """Load Tile Lang lib""" - if sys.platform.startswith("win32") and sys.version_info >= (3, 8): - for path in libinfo.get_dll_directories(): - os.add_dll_directory(path) - # pylint: disable=protected-access - lib_name = "tilelang" if tvm.base._RUNTIME_ONLY else "tilelang_module" - # pylint: enable=protected-access - lib_path = libinfo.find_lib_path(lib_name) - return ctypes.CDLL(lib_path), lib_path - - # only load once here - if env.SKIP_LOADING_TILELANG_SO == "0": - _LIB, _LIB_PATH = _load_tile_lang_lib() - - from .jit import jit, lazy_jit, JITKernel, compile, par_compile # noqa: F401 - from .profiler import Profiler # noqa: F401 - from .cache import clear_cache # noqa: F401 - from .utils import ( - TensorSupplyType, # noqa: F401 - deprecated, # noqa: F401 - ) - from .layout import ( - Layout, # noqa: F401 - Fragment, # noqa: F401 - ) - from . 
import ( - analysis, # noqa: F401 - transform, # noqa: F401 - language, # noqa: F401 - engine, # noqa: F401 - tools, # noqa: F401 - ) - from .language.v2 import dtypes # noqa: F401 - from .autotuner import autotune # noqa: F401 - from .transform import PassConfigKey # noqa: F401 - from .engine import lower, register_cuda_postproc, register_hip_postproc, register_c_postproc # noqa: F401 - from .math import * # noqa: F403 - from . import ir # noqa: F401 - from . import tileop # noqa: F401 +# Skip import when running under AutoDD +if not _RUNNING_AUTODD: + with _lazy_load_lib(): + from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401 + from .env import env as env # noqa: F401 + + import tvm + import tvm.base # noqa: F401 + from tvm import DataType # noqa: F401 + + # Setup tvm search path before importing tvm + from . import libinfo + + def _load_tile_lang_lib(): + """Load Tile Lang lib""" + if sys.platform.startswith("win32") and sys.version_info >= (3, 8): + for path in libinfo.get_dll_directories(): + os.add_dll_directory(path) + # pylint: disable=protected-access + lib_name = "tilelang" if tvm.base._RUNTIME_ONLY else "tilelang_module" + # pylint: enable=protected-access + lib_path = libinfo.find_lib_path(lib_name) + return ctypes.CDLL(lib_path), lib_path + + # only load once here + if env.SKIP_LOADING_TILELANG_SO == "0": + _LIB, _LIB_PATH = _load_tile_lang_lib() + + from .jit import jit, lazy_jit, JITKernel, compile, par_compile # noqa: F401 + from .profiler import Profiler # noqa: F401 + from .cache import clear_cache # noqa: F401 + from .utils import ( + TensorSupplyType, # noqa: F401 + deprecated, # noqa: F401 + ) + from .layout import ( + Layout, # noqa: F401 + Fragment, # noqa: F401 + ) + from . import ( + analysis, # noqa: F401 + transform, # noqa: F401 + language, # noqa: F401 + engine, # noqa: F401 + tools, # noqa: F401 + ) + from .language.v2 import dtypes # noqa: F401 + from .autotuner import autotune # noqa: F401 + from .transform import PassConfigKey # noqa: F401 + from .engine import lower, register_cuda_postproc, register_hip_postproc, register_c_postproc # noqa: F401 + from .math import * # noqa: F403 + from . import ir # noqa: F401 + from . 
import tileop # noqa: F401 del _lazy_load_lib diff --git a/tilelang/autodd.py b/tilelang/autodd.py new file mode 100644 index 000000000..348a257c9 --- /dev/null +++ b/tilelang/autodd.py @@ -0,0 +1,1180 @@ +from abc import ABC, abstractmethod +import ast +import asyncio +from collections import Counter +from copy import copy, deepcopy +from dataclasses import dataclass +from pathlib import Path +import shutil +from typing import Callable, Literal, NamedTuple, override +from collections.abc import Sequence +from collections.abc import Iterable +import contextlib +import io +import multiprocessing +import queue +import subprocess +import tempfile +import time +import os +import traceback + + +def ast_replace(node: ast.AST, **changes) -> ast.AST: + node = copy(node) + for field, value in changes.items(): + setattr(node, field, value) + return node + + +def parse_stmts(s: str) -> list[ast.stmt]: + mod = ast.parse(s) + return mod.body + + +def parse_expr(s: str) -> ast.expr: + mod = ast.parse(s, mode="eval") + return mod.body + + +class ASTRewrite(ABC): + @abstractmethod + def get_name(self) -> str: + raise NotImplementedError + + @abstractmethod + def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: + raise NotImplementedError + + @abstractmethod + def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> "ast.AST | list[ast.AST] | None": + raise NotImplementedError + + +@dataclass +class GeneralRemove(ASTRewrite): + name: str + target_type: type[ast.AST] + inside_list: bool = True + replace_with: "ast.AST | list[ast.AST] | None" = None + + @override + def get_name(self) -> str: + return self.name + + @override + def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: + return isinstance(node, self.target_type) and (not self.inside_list or inside_list) + + @override + def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> None: + return deepcopy(self.replace_with) + + +def expr_to_zeros(target: ast.expr) -> ast.expr: + if isinstance(target, ast.Tuple): + zeros = [ast.Constant(value=0) for _ in target.elts] + return ast.Tuple(elts=zeros, ctx=ast.Load()) + else: + return ast.Constant(value=0) + + +class CallFwdArg1(ASTRewrite): + @override + def get_name(self) -> str: + return "call-fwd-arg1" + + @override + def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: + return isinstance(node, ast.Call) and len(node.args) >= 1 + + @override + def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> ast.AST: + assert isinstance(node, ast.Call) + return node.args[0] + + +class AttachFullFuncArgs(ASTRewrite): + @override + def get_name(self) -> str: + return "attach-full-func-args" + + @override + def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: + return isinstance(node, ast.FunctionDef) and (node.args.vararg is None or node.args.kwarg is None) + + @override + def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> ast.AST: + assert isinstance(node, ast.FunctionDef) + node = copy(node) + node.args = copy(node.args) + if node.args.vararg is None: + node.args.vararg = ast.arg(arg="args") + if node.args.kwarg is None: + node.args.kwarg = ast.arg(arg="kwargs") + return node + + +@dataclass +class IntConstApply(ASTRewrite): + matcher: Callable[[int], bool] + apply: Callable[[int], ast.AST] + name: str + + @override + def get_name(self) -> str: + return self.name + + 
@override + def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: + return isinstance(node, ast.Constant) and isinstance(node.value, int) and self.matcher(node.value) + + @override + def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> ast.AST: + assert isinstance(node, ast.Constant) and isinstance(node.value, int) + return ast_replace(node, value=self.apply(node.value)) + + +@dataclass +class BinOpFwdArg(ASTRewrite): + forward: Literal["left", "right"] = "left" + + @override + def get_name(self) -> str: + return f"binop-fwd-arg-{self.forward}" + + @override + def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: + return isinstance(node, ast.BinOp) + + @override + def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> ast.AST: + assert isinstance(node, ast.BinOp) + if self.forward == "left": + return node.left + else: + return node.right + + +class DictCanonicalize(ASTRewrite): + @override + def get_name(self) -> str: + return "dict-canonicalize" + + @override + def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: + return isinstance(node, ast.Dict) + + @override + def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> ast.AST: + assert isinstance(node, ast.Dict) + return ast.Call( + func=ast.Name("dict", ctx=ast.Load()), + args=[ast.List(elts=[ast.Tuple([k, v], ctx=ast.Load()) for k, v in zip(node.keys, node.values)], ctx=ast.Load())], + keywords=[], + ) + + +def _as_expr_placeholder(temp: ast.AST) -> "str | None": + if isinstance(temp, ast.Name): + return temp.id + else: + return None + + +def _as_stmt_placeholder(temp: ast.AST) -> "str | None": + if isinstance(temp, ast.Expr) and isinstance(temp.value, ast.Name): + return temp.value.id + else: + return None + + +def _ast_match(temp: ast.AST, node: ast.expr, placeholders: set[str]): + ph_expr = _as_expr_placeholder(temp) + if ph_expr is not None and ph_expr in placeholders: + return {ph_expr: node} + if type(temp) is not type(node): + return False + result = {} + for field, value in ast.iter_fields(temp): + if isinstance(value, list): + if len(value) == 1: + ph_stmts = _as_stmt_placeholder(value[0]) + if ph_stmts is not None and ph_stmts in placeholders: + result.update({ph_stmts: getattr(node, field)}) + continue + if not isinstance(getattr(node, field), list): + return False + if len(value) != len(getattr(node, field)): + return False + for v1, v2 in zip(value, getattr(node, field)): + sub_result = _ast_match(v1, v2, placeholders) + if sub_result is False: + return False + result.update(sub_result) + elif isinstance(value, ast.AST): + if not isinstance(getattr(node, field), ast.AST): + return False + sub_result = _ast_match(value, getattr(node, field), placeholders) + if sub_result is False: + return False + result.update(sub_result) + else: + if value != getattr(node, field): + return False + return result + + +def _ast_replace(temp: ast.expr, repl: dict[str, ast.AST]) -> ast.expr: + ph_expr = _as_expr_placeholder(temp) + if ph_expr is not None and ph_expr in repl: + return deepcopy(repl[ph_expr]) + ph_stmts = _as_stmt_placeholder(temp) + if ph_stmts is not None and ph_stmts in repl: + return deepcopy(repl[ph_stmts]) + temp = copy(temp) + for field, value in ast.iter_fields(temp): + if isinstance(value, list): + if len(value) == 1: + ph_stmts = _as_stmt_placeholder(value[0]) + if ph_stmts is not None and ph_stmts in repl: + setattr(temp, field, 
deepcopy(repl[ph_stmts])) + continue + new_values = [] + for v in value: + res = _ast_replace(v, repl) + if res is None: + continue + if isinstance(res, ast.AST): + new_values.append(res) + else: + new_values.extend(res) + setattr(temp, field, new_values) + elif isinstance(value, ast.AST): + setattr(temp, field, _ast_replace(value, repl)) + return temp + + +ASTPatKind = Literal["expr", "stmt"] + + +@dataclass +class ASTPat: + tree: "ast.expr | list[ast.stmt]" + placeholders: set[str] + + @classmethod + def from_code(cls, kind: ASTPatKind, code: str, placeholders: set[str]) -> "ASTPat": + if kind == "expr": + tree = parse_expr(code) + elif kind == "stmt": + tree = parse_stmts(code) + if len(tree) == 1: + tree = tree[0] + else: + raise ValueError(f"Unknown AST pattern kind: {kind}") + return cls(tree, placeholders) + + def match_placeholders(self, node: "ast.AST | list[ast.AST]") -> "dict[str, ast.AST] | bool": + return _ast_match(self.tree, node, self.placeholders) + + def match(self, node: ast.AST) -> bool: + return self.match_placeholders(node) is not False + + def replace(self, repl: dict[str, ast.AST]) -> ast.AST: + if isinstance(self.tree, list): + replaced_stmts = [] + for stmt in self.tree: + replaced = _ast_replace(stmt, repl) + if isinstance(replaced, ast.AST): + replaced_stmts.append(replaced) + else: + replaced_stmts.extend(replaced) + return replaced_stmts + else: + return _ast_replace(self.tree, repl) + + +@dataclass +class ASTPatRewrite(ASTRewrite): + name: str + match_pat: ASTPat + rewrite_pat: ASTPat + checker: "Callable[[dict[str, ast.AST]], bool] | dict[str, Callable[[ast.AST], bool]] | None" = None + derived: "dict[str, Callable[[dict[str, ast.AST]], ast.AST]] | None" = None + + @classmethod + def from_code( + cls, + name: str, + kind: ASTPatKind, + match: str, + rewrite: str, + placeholders: set[str], + checker: "Callable[[dict[str, ast.AST]], bool] | dict[str, Callable[[ast.AST], bool]] | None" = None, + derived: "dict[str, Callable[[dict[str, ast.AST]], ast.AST]] | None" = None, + ) -> "ASTPatRewrite": + match_pat = ASTPat.from_code(kind, match, placeholders) + rewrite_pat = ASTPat.from_code(kind, rewrite, placeholders) + return cls(name, match_pat, rewrite_pat, checker, derived) + + @override + def get_name(self) -> str: + return self.name + + def match_placeholders(self, node: ast.AST): + ph = self.match_pat.match_placeholders(node) + if ph is False: + return False + if self.derived is not None: + for k, v in self.derived.items(): + ph[k] = v(ph) + if self.checker is not None: + if isinstance(self.checker, dict): + for k, v in self.checker.items(): + if k not in ph or not v(ph[k]): + return False + else: + return self.checker(ph) + return ph + + @override + def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: + return self.match_placeholders(node) is not False + + def _rewrite(self, node: ast.AST): + # this function is for debugging purpose + repl = self.match_placeholders(node) + assert repl is not False + replaced = self.rewrite_pat.replace(repl) + return replaced + + @override + def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> ast.AST: + return self._rewrite(node) + + +class ASTMutator: + def generic_visit(self, node): + for field, old_value in ast.iter_fields(copy(node)): + if isinstance(old_value, list): + new_values = [] + for value in old_value: + if isinstance(value, ast.AST): + value = self.visit(value, node, field, True) + if value is None: + continue + elif not isinstance(value, 
ast.AST): + new_values.extend(value) + continue + new_values.append(value) + old_value[:] = new_values + elif isinstance(old_value, ast.AST): + new_node = self.visit(old_value, node, field, False) + if new_node is None: + delattr(node, field) + else: + setattr(node, field, new_node) + return node + + def visit(self, node: ast.AST, parent: "ast.AST | None", field: "str | None", inside_list: bool): + return self.generic_visit(node) + + +@dataclass +class LabeledRewrite: + label: int + rewrite: ASTRewrite + + +class RewriteAttacher(ASTMutator): + def __init__(self, rewrites: list[ASTRewrite]): + self.rewrites = rewrites + self.uid_counter = 0 + self.rewrite_counter = 0 + self.rewrite_names = Counter() + + @override + def visit(self, node: ast.AST, parent: "ast.AST | None", field: "str | None", inside_list: bool): + node = copy(node) + node._dd_uid = self.uid_counter + self.uid_counter += 1 + node._dd_rewrites = [] + for r in self.rewrites: + if r.match(node, parent, field, inside_list): + lr = LabeledRewrite(self.rewrite_counter, r) + self.rewrite_counter += 1 + self.rewrite_names[lr.rewrite.get_name()] += 1 + node._dd_rewrites.append(lr) + res = self.generic_visit(node) + return res + + +def attach_rewrites(tree: ast.AST, rewrites: list[ASTRewrite]) -> tuple[ast.AST, int, int]: + attacher = RewriteAttacher(rewrites) + new_tree = attacher.visit(tree, None, None, False) + print("Rewrites:", attacher.rewrite_names) + return new_tree, attacher.uid_counter, attacher.rewrite_counter + + +class RewriteApplier(ASTMutator): + def __init__(self, target_labels: set[int]): + self.target_labels = target_labels + self.applied_rewrites: set[int] = set() + self.visited: set[int] = set() + + @override + def visit(self, node: ast.AST, parent: "ast.AST | None", field: "str | None", inside_list: bool): + orig_uid = getattr(node, "_dd_uid", None) + if orig_uid in self.visited: + return self.generic_visit(node) + self.visited.add(orig_uid) + + node = copy(node) + for lr in getattr(node, "_dd_rewrites", []): + lr: LabeledRewrite + if lr.label in self.target_labels: + node = lr.rewrite.rewrite(node, parent, field, inside_list) + self.applied_rewrites.add(lr.label) + break + + if node is None: + return None + elif isinstance(node, ast.AST): + return self.visit(node, parent, field, inside_list) + else: + new_items = [] + for item in node: + if isinstance(item, ast.AST): + res = self.visit(item, parent, field, inside_list) + if res is None: + continue + elif isinstance(res, ast.AST): + new_items.append(res) + else: + new_items.extend(res) + return new_items + + +def apply_rewrites(tree: ast.AST, target_labels: set[int]) -> tuple[ast.AST, set[int]]: + applier = RewriteApplier(target_labels) + new_tree = applier.visit(deepcopy(tree), None, None, False) + return new_tree, applier.applied_rewrites + + +def test_rewrite(rewrite: ASTRewrite, code: str): + tree = ast.parse(code) + tree, _, num_matched = attach_rewrites(tree, [rewrite]) + tree, _ = apply_rewrites(tree, set(i for i in range(num_matched))) + ast.fix_missing_locations(tree) + return ast.unparse(tree) + + +@dataclass +class Task: + source: str + applied: list[int] + masked: list[int] + + def with_source(self, source: str) -> "Task": + return Task(source, self.applied, self.masked) + + +class PDD: + def __init__(self, all_labels: list[int], init_proba: float = 0.93): + self.all_labels = all_labels + self.probas = {label: init_proba for label in all_labels} + + def apply(self, target_labels: set[int]) -> set[int]: + return target_labels + + @staticmethod + def 
_update_probas(probas: dict[int, float], task: Task, is_interesting: bool): + if is_interesting: + for label in task.applied: + probas[label] = 1.0 + for label in task.masked: + probas[label] = 0.0 + else: + prod = 1.0 + for label in task.applied: + if probas[label] > 0: + prod *= probas[label] + denorm = 1.0 - prod + for label in task.applied: + p = probas[label] + if p >= 1.0: + continue + probas[label] = 1.0 - (1.0 - p) / denorm if denorm > 0.0 else 0.0 + + def generator(self) -> Iterable[Task]: + probas = deepcopy(self.probas) + while True: + choices = sorted(probas.items(), key=lambda x: (x[1], x[0]), reverse=True) + selected = [] + sum, prod = 0.0, 1.0 + for label, p in choices: + if p >= 1.0: + selected.append(label) + continue + if (sum + 1) * prod * p > sum * prod: + selected.append(label) + sum, prod = sum + 1, prod * p + else: + break + applied = self.apply(set(selected)) + masked = set(selected).difference(applied) + task = Task(source=None, applied=list(applied), masked=list(masked)) + if sum * prod == 0 or all(probas[label] >= 1.0 for label in applied): + break + yield deepcopy(task) + self._update_probas(probas, task, is_interesting=False) + + def update(self, task: Task, is_interesting: bool): + self._update_probas(self.probas, task, is_interesting) + + +class TaskManager(ABC): + @abstractmethod + def task_generator(self) -> Iterable[Task]: ... + + @abstractmethod + def task_update(self, task: Task, is_interesting: bool): ... + + @classmethod + @abstractmethod + def from_source(cls, source: str, *args, **kwargs) -> "TaskManager": ... + + +class ASTPDD(TaskManager, PDD): + def __init__(self, tree: ast.AST, rewrites: list[ASTRewrite], init_proba: float = 0.93): + self.tree, _, total_rewrites = attach_rewrites(tree, rewrites) + all_labels = [i for i in range(total_rewrites)] + super().__init__(all_labels, init_proba) + + @override + @classmethod + def from_source(cls, source, *args, **kwargs): + return cls(ast.parse(source), *args, **kwargs) + + def apply(self, target_labels: set[int]) -> set[int]: + _, applied = apply_rewrites(self.tree, target_labels) + return applied + + @override + def task_generator(self) -> Iterable[Task]: + for task in self.generator(): + new_tree, _ = apply_rewrites(self.tree, task.applied) + try: + new_tree = deepcopy(new_tree) + ast.fix_missing_locations(new_tree) + source = ast.unparse(new_tree) + except Exception as _: + continue + yield task.with_source(source) + # self.update(task, is_interesting=False) + + @override + def task_update(self, task: Task, is_interesting: bool): + self.update(task, is_interesting) + + +def ruff_fix_code(code_string: str, fix_lint: bool = True, format_code: bool = True) -> str: + ruff_executable = shutil.which("ruff") + if not ruff_executable: + raise FileNotFoundError("Unable to find ruff") + + with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False, encoding="utf-8") as tmp: + tmp.write(code_string) + tmp_path = tmp.name + + try: + if fix_lint: + print("Running ruff fix on:", tmp_path) + subprocess.run([ruff_executable, "check", "--fix", "--unsafe-fixes", tmp_path], capture_output=True, check=False) + + if format_code: + print("Running ruff format on:", tmp_path) + subprocess.run([ruff_executable, "format", tmp_path], capture_output=True, check=False) + + with open(tmp_path) as f: + fixed_code = f.read() + + return fixed_code + + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +class LinePDD(TaskManager, PDD): + def __init__(self, source: str, init_proba: float = 0.93): + lines = 
[line for line in source.splitlines() if line.strip() != ""] + all_labels = [i for i in range(len(lines))] + super().__init__(all_labels, init_proba) + + @override + @classmethod + def from_source(cls, source, *args, **kwargs): + return cls(source, *args, **kwargs) + + @override + def task_generator(self) -> Iterable[Task]: + for task in self.generator(): + lines = [line for line in self.source.splitlines() if line.strip() != ""] + new_lines = [line for idx, line in enumerate(lines) if idx not in task.applied] + source = "\n".join(new_lines) + try: + ast.parse(source) + except Exception as _: + # self.update(task, is_interesting=False) + continue + yield task.with_source(source) + + @override + def task_update(self, task: Task, is_interesting: bool): + self.update(task, is_interesting) + + +class Ruff(TaskManager): + def __init__(self, source: str, fix_lint: bool = True, format_code: bool = True): + self.source = source + self.fix_lint = fix_lint + self.format_code = format_code + self.finished = False + + @override + @classmethod + def from_source(cls, source: str, *args, **kwargs) -> "Ruff": + return cls(source) + + @override + def task_generator(self): + if self.finished: + return + self.finished = True + try: + fixed_code = ruff_fix_code(self.source, fix_lint=self.fix_lint, format_code=self.format_code) + yield Task(source=fixed_code, applied=[], masked=[]) + except FileNotFoundError as _: + return + + @override + def task_update(self, task: Task, is_interesting: bool): + pass + + +def _worker_loop(input_queue, output_queue): + while True: + try: + task = input_queue.get() + if task is None: + break + + capture_out = io.StringIO() + capture_err = io.StringIO() + success = False + with tempfile.NamedTemporaryFile("w", suffix=".py", delete=True) as f: + f.write(task) + f.flush() + try: + with contextlib.redirect_stdout(capture_out), contextlib.redirect_stderr(capture_err): + code = compile(task, f.name, "exec") + exec(code, {"__builtins__": __builtins__}) + success = True + except SystemExit as e: + capture_err.write(f"SystemExit: Code {e.code}\n") + except Exception: + traceback.print_exc(file=capture_err) + + output_queue.put((capture_out.getvalue(), capture_err.getvalue(), success)) + except KeyboardInterrupt: + break + except Exception as e: + output_queue.put(("", f"Critical: {e}", False)) + + +# This class is written by Gemini +class AsyncPythonRunner: + def __init__(self): + self.process = None + self.input_queue = None + self.output_queue = None + self.lock = asyncio.Lock() + + def start_proc(self): + if self.process and self.process.is_alive(): + return + ctx = multiprocessing.get_context("spawn") + self.input_queue = ctx.Queue() + self.output_queue = ctx.Queue() + self.process = ctx.Process(target=_worker_loop, args=(self.input_queue, self.output_queue), daemon=True) + self.process.start() + + def stop_proc(self): + if self.process: + # Try to send a stop signal. + # Note: if the queue is full or broken, put may block, so wrap it in a try. + with contextlib.suppress(Exception): + self.input_queue.put_nowait(None) + + self.process.join(timeout=0.5) + if self.process.is_alive(): + self.process.terminate() + self.process = None + + def __del__(self): + self.stop_proc() + + async def run(self, code: str, timeout: float = 5.0): + async with self.lock: + if not self.process or not self.process.is_alive(): + self.start_proc() + + try: + self.input_queue.put(code) + except Exception as e: + # Rare case: the pipe is broken. 
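+                # Surface the failure as a normal result tuple so the event loop keeps running.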
+ return "", f"Queue Error: {e}", False + + start_time = time.time() + while True: + # 1. Check whether we timed out. + if time.time() - start_time > timeout: + self._handle_timeout(timeout) + return "", f"TimeoutError: Exceeded {timeout}s", False + + # 2. Check whether the child process is still alive (avoid hanging if it segfaults). + if not self.process.is_alive(): + # Try one last read (in case the result was just written before the process exited). + try: + return self.output_queue.get_nowait() + except queue.Empty: + self.process = None # Mark as needing restart. + return "", "Error: Worker process died unexpectedly", False + + # 3. Try to read results in a non-blocking way. + try: + # get_nowait raises queue.Empty immediately if the queue is empty. + result = self.output_queue.get_nowait() + return result + except queue.Empty: + # No data in the queue yet, sleep briefly and yield control back to the event loop. + # A 0.05s delay is perfectly acceptable for interactive usage. + await asyncio.sleep(0.05) + + def _handle_timeout(self, timeout): + """Handle cleanup logic when a timeout happens.""" + # We must force-terminate because exec may still be stuck in a tight loop. + if self.process and self.process.is_alive(): + self.process.terminate() + # Give the OS a bit of time to reclaim resources. + self.process.join(timeout=0.5) + + # Mark as None so that the next run triggers start_proc and restarts the worker. + self.process = None + self.input_queue = None + self.output_queue = None + + +class SubProcRunner: + def __init__(self): + pass + + def start_proc(self): + pass + + def stop_proc(self): + pass + + async def run(self, code: str, timeout: float = 5.0): + with tempfile.NamedTemporaryFile("w", suffix=".py", delete=True) as f: + f.writelines(code) + f.flush() + + def run_subprocess(args): + try: + proc = subprocess.run( + args, + capture_output=True, + text=True, # Decodes output as strings (Python 3.5+) + timeout=timeout, # Timeout after 5 seconds + check=False, # Do not raise exception for non-zero exit codes + ) + return proc.stdout, proc.stderr, proc.returncode == 0 + except subprocess.CalledProcessError as e: + return e.stdout, e.stderr, False + + return await asyncio.get_running_loop().run_in_executor(None, run_subprocess, ["python3", f.name]) + + +def clean_empty_pass(code: str) -> str: + tree = ast.parse(code) + + class PassRemover(ast.NodeTransformer): + def clean_body(self, body: list[ast.stmt], keep_one=True) -> list[ast.stmt]: + res = [stmt for stmt in body if not isinstance(stmt, ast.Pass)] + if not res and keep_one: + return [ast.Pass()] + return res + + def visit_For(self, node: ast.For) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_If(self, node: ast.If) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + node.orelse = self.clean_body(node.orelse, keep_one=False) + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_ClassDef(self, node): + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_Module(self, node: ast.Module) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_With(self, node: ast.With) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_AsyncWith(self, node: ast.AsyncWith) -> ast.AST: + 
self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_While(self, node: ast.While) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_Try(self, node: ast.Try) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + node.orelse = self.clean_body(node.orelse) + node.finalbody = self.clean_body(node.finalbody) + for handler in node.handlers: + handler.body = self.clean_body(handler.body) + return node + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + def visit_ExceptHandler(self, node: ast.ExceptHandler) -> ast.AST: + self.generic_visit(node) + node.body = self.clean_body(node.body) + return node + + new_tree = PassRemover().visit(tree) + return ast.unparse(new_tree) + + +JobBackend = Literal["subproc", "runner"] + + +@dataclass +class ParTaskManager: + err_msg: str + text: str + output_file: Path + timeout: int = 60 + num_workers: int = 1 + backend: JobBackend = "runner" + allow_larger: bool = False + + def __post_init__(self): + self.worker_tasks: list[asyncio.Task] = [] + self.stopped = False + self.task_manager: TaskManager | None = None + self.generator: Iterable[Task] | None = None + self.condition = asyncio.Condition() + self.waiting_workers = 0 + self.finished = True + self.task_counter = 0 + + @property + def text_len(self): + return len(self.text) + + def reset(self, task_manager: TaskManager): + self.task_manager = task_manager + self.generator = task_manager.task_generator() + self.finished = False + + async def get_next_task(self) -> "Task | None": + async with self.condition: + while True: + if self.stopped: + return None + if self.finished or self.generator is None: + await self.condition.wait() + continue + try: + result = deepcopy(next(self.generator)) + self.task_counter += 1 + if self.task_counter % self.num_workers == 0: + print(f"Dispatched {self.task_counter} tasks") + return result + except StopIteration: + self.waiting_workers += 1 + if self.waiting_workers == self.num_workers: + self.finished = True + self.generator = None + self.condition.notify_all() + await self.condition.wait() + self.waiting_workers -= 1 + + async def submit_result(self, task: Task, is_interested: bool): + async with self.condition: + self.task_manager.task_update(task, is_interested) + if is_interested: + self.generator = self.task_manager.task_generator() + self.condition.notify_all() + text = self.post_proc(task.source) + if len(text) <= self.text_len or self.allow_larger: + print("Accept length", len(text)) + self.text = text + self.output_file.write_text(text) + self.updated = True + + def post_proc(self, text): + return clean_empty_pass(text) + + async def worker(self, wid: int): + runner = AsyncPythonRunner() if self.backend == "runner" else SubProcRunner() + try: + while True: + task = await self.get_next_task() + if task is None: + break + out, err, ok = await runner.run(task.source, timeout=self.timeout) + is_interested = self.err_msg in out or self.err_msg in err + await self.submit_result(task, is_interested) + finally: + if hasattr(runner, "stop_proc"): + runner.stop_proc() + + async def start_workers(self): + if self.worker_tasks: + return + self.stopped = False + self.worker_tasks = 
[asyncio.create_task(self.worker(wid)) for wid in range(self.num_workers)] + + async def stop_workers(self): + if not self.worker_tasks: + return + self.stopped = True + async with self.condition: + self.condition.notify_all() + await asyncio.gather(*self.worker_tasks, return_exceptions=True) + self.worker_tasks = [] + self.generator = None + + async def run_async(self, task_manager: TaskManager): + await self.start_workers() + self.reset(task_manager) + best_length = self.text_len + async with self.condition: + self.condition.notify_all() + while not self.finished: + await self.condition.wait() + return self.text_len < best_length + + async def run_with(self, cls: type[TaskManager], *args, **kwargs): + allow_larger = kwargs.pop("allow_larger", False) + if allow_larger: + self.allow_larger = True + self.updated = False + task_manager = cls.from_source(self.text, *args, **kwargs) + res = await self.run_async(task_manager) + self.allow_larger = False + if allow_larger: + return self.updated + return res + + +class Args(NamedTuple): + source: Path + err_msg: str + output: Path + backend: JobBackend + timeout: int + jobs: int + + +async def main(args: Args): + source = args.source.read_text() + + manager = ParTaskManager( + err_msg=args.err_msg, + text=source, + output_file=args.output, + timeout=args.timeout, + backend=args.backend, + num_workers=args.jobs, + ) + + # remove any statement + + for_bind_0 = ASTPatRewrite.from_code( + name='for-bind-0', + kind='stmt', + match='for VARS in EXPR: BODY', + rewrite='VARS = ZEROS\nBODY', + placeholders={'VARS', 'EXPR', 'BODY', 'ZEROS'}, + derived={ + 'ZEROS': lambda ph: expr_to_zeros(ph['VARS']), + } + ) + + with_bind_0 = ASTPatRewrite.from_code( + name='with-bind-0', + kind='stmt', + match='with EXPR as VARS: BODY', + rewrite='with EXPR:\n VARS = ZEROS\n BODY', + placeholders={'VARS', 'EXPR', 'BODY', 'ZEROS'}, + derived={ + 'ZEROS': lambda ph: expr_to_zeros(ph['VARS']), + } + ) + + assign_rhs_1 = ASTPatRewrite.from_code( + name="assign-rhs-1", + kind="stmt", + match="VAR = EXPR", + rewrite="VAR = 1", + placeholders={"VAR", "EXPR"}, + ) + + if_remover_1 = ASTPatRewrite.from_code( + name="if-remover-1", + kind="stmt", + match="if COND: BODY", + rewrite="BODY", + placeholders={"COND", "BODY"}, + ) + + if_remover_2 = ASTPatRewrite.from_code( + name="if-remover-2", + kind="stmt", + match="if COND: BODY\nelse: ELSE_BODY", + rewrite="BODY", + placeholders={"COND", "BODY", "ELSE_BODY"}, + ) + + if_remover_3 = ASTPatRewrite.from_code( + name="if-remover-3", + kind="stmt", + match="if COND: BODY\nelse: ELSE_BODY", + rewrite="ELSE_BODY", + placeholders={"COND", "BODY", "ELSE_BODY"}, + ) + + # replace all integer constant x with x // 2 + int_reduce = IntConstApply(lambda x: x > 1, lambda x: x // 2, "int-reduce-2") + + # 1. first, we only do statement level fast reductions + fast_reducers = [ + if_remover_1, + if_remover_2, + if_remover_3, + for_bind_0, + GeneralRemove("stmt-remover", ast.stmt, replace_with=ast.Pass()), + ] + + # 2. canonicalizer enables more simplifications + canonicalizers = [ + with_bind_0, + AttachFullFuncArgs(), + ] + + # 3. simplifiers + simplifiers = [ + assign_rhs_1, + CallFwdArg1(), + BinOpFwdArg("left"), + BinOpFwdArg("right"), + GeneralRemove("func-arg-remover", ast.arg), + ] + fast_reducers + + # 4. 
finally apply expr level slow reductions + slow_reducers = [ + GeneralRemove("func-arg-remover", ast.arg), + GeneralRemove("general-expr-remover", ast.expr), + GeneralRemove("general-keyword-remover", ast.keyword), + ] + fast_reducers + + await manager.start_workers() + manager.text = manager.post_proc(manager.text) + try: + while True: + changed = False + while await manager.run_with(ASTPDD, fast_reducers): + changed = True + await manager.run_with(ASTPDD, canonicalizers, allow_larger=True) + while await manager.run_with(ASTPDD, simplifiers): + changed = True + while await manager.run_with(ASTPDD, [int_reduce], allow_larger=True): + changed = True + while await manager.run_with(ASTPDD, slow_reducers): + changed = True + if not changed: + break + finally: + await manager.stop_workers() + + +def cli_main(argv: "Sequence[str] | None" = None) -> None: + from argparse import ArgumentParser + + parser = ArgumentParser( + usage="python autodd.py source --err-msg MSG -o OUTPUT [--backend {runner,subproc}] [--timeout SEC] [-j N]", + description="Delta-debug the provided Python source until the target error message remains reproducible.", + epilog="Author: Kexing Zhou ", + ) + parser.add_argument("source", type=Path, help="Input python source file") + parser.add_argument("--err-msg", type=str, required=True, help="Error message to look for") + parser.add_argument("-o", "--output", type=Path, required=True, help="Output file path") + parser.add_argument( + "--backend", default="runner", choices=["runner", "subproc"], help="Backend for running code: runner is faster, subproc is stable" + ) + parser.add_argument("--timeout", type=int, default=60, help="Timeout for each task in seconds (default: 60)") + parser.add_argument("-j", "--jobs", type=int, default=1, help="Number of parallel jobs (default: 1)") + ns = parser.parse_args(argv) + + if ns.backend == "runner": + backend = JobBackend(AsyncPythonRunner, "runner") + else: + backend = JobBackend(SubProcRunner, "subproc") + + args = Args( + source=ns.source, + err_msg=ns.err_msg, + output=ns.output, + backend=backend, + timeout=ns.timeout, + jobs=ns.jobs, + ) + asyncio.run(main(args)) + + +if __name__ == "__main__": + cli_main() From 6736dae14a07af4e28f20b024ebc9238755710c5 Mon Sep 17 00:00:00 2001 From: KEKE046 Date: Thu, 8 Jan 2026 07:22:48 +0000 Subject: [PATCH 02/13] fix typos --- tilelang/autodd.py | 30 +++--------------------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/tilelang/autodd.py b/tilelang/autodd.py index 348a257c9..38569179c 100644 --- a/tilelang/autodd.py +++ b/tilelang/autodd.py @@ -156,25 +156,6 @@ def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) return node.right -class DictCanonicalize(ASTRewrite): - @override - def get_name(self) -> str: - return "dict-canonicalize" - - @override - def match(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> bool: - return isinstance(node, ast.Dict) - - @override - def rewrite(self, node: ast.AST, parent: ast.AST, field: str, inside_list: bool) -> ast.AST: - assert isinstance(node, ast.Dict) - return ast.Call( - func=ast.Name("dict", ctx=ast.Load()), - args=[ast.List(elts=[ast.Tuple([k, v], ctx=ast.Load()) for k, v in zip(node.keys, node.values)], ctx=ast.Load())], - keywords=[], - ) - - def _as_expr_placeholder(temp: ast.AST) -> "str | None": if isinstance(temp, ast.Name): return temp.id @@ -610,6 +591,7 @@ def ruff_fix_code(code_string: str, fix_lint: bool = True, format_code: bool = T class LinePDD(TaskManager, PDD): 
def __init__(self, source: str, init_proba: float = 0.93): lines = [line for line in source.splitlines() if line.strip() != ""] + self.lines = lines all_labels = [i for i in range(len(lines))] super().__init__(all_labels, init_proba) @@ -621,8 +603,7 @@ def from_source(cls, source, *args, **kwargs): @override def task_generator(self) -> Iterable[Task]: for task in self.generator(): - lines = [line for line in self.source.splitlines() if line.strip() != ""] - new_lines = [line for idx, line in enumerate(lines) if idx not in task.applied] + new_lines = [line for idx, line in enumerate(self.lines) if idx not in task.applied] source = "\n".join(new_lines) try: ast.parse(source) @@ -1160,16 +1141,11 @@ def cli_main(argv: "Sequence[str] | None" = None) -> None: parser.add_argument("-j", "--jobs", type=int, default=1, help="Number of parallel jobs (default: 1)") ns = parser.parse_args(argv) - if ns.backend == "runner": - backend = JobBackend(AsyncPythonRunner, "runner") - else: - backend = JobBackend(SubProcRunner, "subproc") - args = Args( source=ns.source, err_msg=ns.err_msg, output=ns.output, - backend=backend, + backend=ns.backend, timeout=ns.timeout, jobs=ns.jobs, ) From 41428d88aacdf00e99953f0a90a0f35bb45c9979 Mon Sep 17 00:00:00 2001 From: KEKE046 Date: Thu, 8 Jan 2026 07:25:38 +0000 Subject: [PATCH 03/13] fix lint error --- tilelang/autodd.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tilelang/autodd.py b/tilelang/autodd.py index 38569179c..4689f9c0e 100644 --- a/tilelang/autodd.py +++ b/tilelang/autodd.py @@ -1016,25 +1016,25 @@ async def main(args: Args): # remove any statement for_bind_0 = ASTPatRewrite.from_code( - name='for-bind-0', - kind='stmt', - match='for VARS in EXPR: BODY', - rewrite='VARS = ZEROS\nBODY', - placeholders={'VARS', 'EXPR', 'BODY', 'ZEROS'}, + name="for-bind-0", + kind="stmt", + match="for VARS in EXPR: BODY", + rewrite="VARS = ZEROS\nBODY", + placeholders={"VARS", "EXPR", "BODY", "ZEROS"}, derived={ - 'ZEROS': lambda ph: expr_to_zeros(ph['VARS']), - } + "ZEROS": lambda ph: expr_to_zeros(ph["VARS"]), + }, ) with_bind_0 = ASTPatRewrite.from_code( - name='with-bind-0', - kind='stmt', - match='with EXPR as VARS: BODY', - rewrite='with EXPR:\n VARS = ZEROS\n BODY', - placeholders={'VARS', 'EXPR', 'BODY', 'ZEROS'}, + name="with-bind-0", + kind="stmt", + match="with EXPR as VARS: BODY", + rewrite="with EXPR:\n VARS = ZEROS\n BODY", + placeholders={"VARS", "EXPR", "BODY", "ZEROS"}, derived={ - 'ZEROS': lambda ph: expr_to_zeros(ph['VARS']), - } + "ZEROS": lambda ph: expr_to_zeros(ph["VARS"]), + }, ) assign_rhs_1 = ASTPatRewrite.from_code( From 636c3b27995b7a9d546e57d9a13d047b50c7f366 Mon Sep 17 00:00:00 2001 From: KEKE046 Date: Thu, 8 Jan 2026 07:32:26 +0000 Subject: [PATCH 04/13] fix typos --- tilelang/autodd.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tilelang/autodd.py b/tilelang/autodd.py index 4689f9c0e..93de582ff 100644 --- a/tilelang/autodd.py +++ b/tilelang/autodd.py @@ -770,7 +770,7 @@ def stop_proc(self): async def run(self, code: str, timeout: float = 5.0): with tempfile.NamedTemporaryFile("w", suffix=".py", delete=True) as f: - f.writelines(code) + f.write(code) f.flush() def run_subprocess(args): @@ -783,8 +783,8 @@ def run_subprocess(args): check=False, # Do not raise exception for non-zero exit codes ) return proc.stdout, proc.stderr, proc.returncode == 0 - except subprocess.CalledProcessError as e: - return e.stdout, e.stderr, False + except subprocess.TimeoutExpired as 
e: + return "", f"TimeoutError: Exceeded {timeout}s", False return await asyncio.get_running_loop().run_in_executor(None, run_subprocess, ["python3", f.name]) @@ -890,6 +890,7 @@ def __post_init__(self): self.waiting_workers = 0 self.finished = True self.task_counter = 0 + self.updated = False @property def text_len(self): From 0427dfbc9b92cab9d62020bb071b50570e637274 Mon Sep 17 00:00:00 2001 From: KEKE046 Date: Thu, 8 Jan 2026 07:32:43 +0000 Subject: [PATCH 05/13] fix lint error --- tilelang/autodd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/autodd.py b/tilelang/autodd.py index 93de582ff..c12f64f2b 100644 --- a/tilelang/autodd.py +++ b/tilelang/autodd.py @@ -783,7 +783,7 @@ def run_subprocess(args): check=False, # Do not raise exception for non-zero exit codes ) return proc.stdout, proc.stderr, proc.returncode == 0 - except subprocess.TimeoutExpired as e: + except subprocess.TimeoutExpired: return "", f"TimeoutError: Exceeded {timeout}s", False return await asyncio.get_running_loop().run_in_executor(None, run_subprocess, ["python3", f.name]) From 285b3cf60a73dca30e17ef0daee551529d044706 Mon Sep 17 00:00:00 2001 From: KEKE046 Date: Thu, 8 Jan 2026 07:39:03 +0000 Subject: [PATCH 06/13] fix bugs --- tilelang/autodd.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tilelang/autodd.py b/tilelang/autodd.py index c12f64f2b..9e2227917 100644 --- a/tilelang/autodd.py +++ b/tilelang/autodd.py @@ -769,24 +769,26 @@ def stop_proc(self): pass async def run(self, code: str, timeout: float = 5.0): - with tempfile.NamedTemporaryFile("w", suffix=".py", delete=True) as f: + with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f: f.write(code) - f.flush() - def run_subprocess(args): - try: - proc = subprocess.run( - args, - capture_output=True, - text=True, # Decodes output as strings (Python 3.5+) - timeout=timeout, # Timeout after 5 seconds - check=False, # Do not raise exception for non-zero exit codes - ) - return proc.stdout, proc.stderr, proc.returncode == 0 - except subprocess.TimeoutExpired: - return "", f"TimeoutError: Exceeded {timeout}s", False - - return await asyncio.get_running_loop().run_in_executor(None, run_subprocess, ["python3", f.name]) + def run_subprocess(args): + try: + proc = subprocess.run( + args, + capture_output=True, + text=True, # Decodes output as strings (Python 3.5+) + timeout=timeout, # Timeout after 5 seconds + check=False, # Do not raise exception for non-zero exit codes + ) + return proc.stdout, proc.stderr, proc.returncode == 0 + except subprocess.TimeoutExpired: + return "", f"TimeoutError: Exceeded {timeout}s", False + + result = await asyncio.get_running_loop().run_in_executor(None, run_subprocess, ["python3", f.name]) + with contextlib.suppress(OSError): + os.remove(f.name) + return result def clean_empty_pass(code: str) -> str: @@ -794,6 +796,8 @@ def clean_empty_pass(code: str) -> str: class PassRemover(ast.NodeTransformer): def clean_body(self, body: list[ast.stmt], keep_one=True) -> list[ast.stmt]: + if body is None: + return None res = [stmt for stmt in body if not isinstance(stmt, ast.Pass)] if not res and keep_one: return [ast.Pass()] From 889c37fdc1ed035a6daead380641fdc3edb0a76c Mon Sep 17 00:00:00 2001 From: Kexing Zhou Date: Thu, 8 Jan 2026 18:10:55 +0800 Subject: [PATCH 07/13] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tilelang/autodd.py | 23 ++++++++++++++++------- 1 file changed, 16 
insertions(+), 7 deletions(-) diff --git a/tilelang/autodd.py b/tilelang/autodd.py index 9e2227917..fa2aa262f 100644 --- a/tilelang/autodd.py +++ b/tilelang/autodd.py @@ -422,7 +422,9 @@ def visit(self, node: ast.AST, parent: "ast.AST | None", field: "str | None", in if node is None: return None elif isinstance(node, ast.AST): - return self.visit(node, parent, field, inside_list) + # After rewriting this node, traverse its children without + # re-applying rewrite selection logic to the node itself. + return self.generic_visit(node) else: new_items = [] for item in node: @@ -493,20 +495,20 @@ def generator(self) -> Iterable[Task]: while True: choices = sorted(probas.items(), key=lambda x: (x[1], x[0]), reverse=True) selected = [] - sum, prod = 0.0, 1.0 + selected_count, prod = 0.0, 1.0 for label, p in choices: if p >= 1.0: selected.append(label) continue - if (sum + 1) * prod * p > sum * prod: + if (selected_count + 1) * prod * p > selected_count * prod: selected.append(label) - sum, prod = sum + 1, prod * p + selected_count, prod = selected_count + 1, prod * p else: break applied = self.apply(set(selected)) masked = set(selected).difference(applied) task = Task(source=None, applied=list(applied), masked=list(masked)) - if sum * prod == 0 or all(probas[label] >= 1.0 for label in applied): + if selected_count * prod == 0 or all(probas[label] >= 1.0 for label in applied): break yield deepcopy(task) self._update_probas(probas, task, is_interesting=False) @@ -1007,7 +1009,14 @@ class Args(NamedTuple): async def main(args: Args): - source = args.source.read_text() + if not args.source.exists() or not args.source.is_file(): + raise FileNotFoundError(f"Source file '{args.source}' does not exist or is not a regular file.") + if not os.access(args.source, os.R_OK): + raise OSError(f"Source file '{args.source}' is not readable.") + try: + source = args.source.read_text() + except OSError as e: + raise OSError(f"Failed to read source file '{args.source}': {e}") from e manager = ParTaskManager( err_msg=args.err_msg, @@ -1132,7 +1141,7 @@ def cli_main(argv: "Sequence[str] | None" = None) -> None: from argparse import ArgumentParser parser = ArgumentParser( - usage="python autodd.py source --err-msg MSG -o OUTPUT [--backend {runner,subproc}] [--timeout SEC] [-j N]", + usage="python -m tilelang.autodd source --err-msg MSG -o OUTPUT [--backend {runner,subproc}] [--timeout SEC] [-j N]", description="Delta-debug the provided Python source until the target error message remains reproducible.", epilog="Author: Kexing Zhou ", ) From 03a3c76eef082a2db39645c5b98f35984576d0b3 Mon Sep 17 00:00:00 2001 From: KEKE046 Date: Thu, 8 Jan 2026 10:20:00 +0000 Subject: [PATCH 08/13] fix codeview comments --- tilelang/autodd.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tilelang/autodd.py b/tilelang/autodd.py index 9e2227917..c55918efd 100644 --- a/tilelang/autodd.py +++ b/tilelang/autodd.py @@ -704,7 +704,11 @@ def stop_proc(self): self.process.terminate() self.process = None - def __del__(self): + def __enter__(self): + self.start_proc() + return self + + def __exit__(self, exc_type, exc_value, traceback): self.stop_proc() async def run(self, code: str, timeout: float = 5.0): @@ -762,10 +766,10 @@ class SubProcRunner: def __init__(self): pass - def start_proc(self): + def __enter__(self): pass - def stop_proc(self): + def __exit__(self, exc_type, exc_value, traceback): pass async def run(self, code: str, timeout: float = 5.0): @@ -778,7 +782,7 @@ def run_subprocess(args): args, 
capture_output=True, text=True, # Decodes output as strings (Python 3.5+) - timeout=timeout, # Timeout after 5 seconds + timeout=timeout, # Timeout check=False, # Do not raise exception for non-zero exit codes ) return proc.stdout, proc.stderr, proc.returncode == 0 @@ -946,7 +950,7 @@ def post_proc(self, text): async def worker(self, wid: int): runner = AsyncPythonRunner() if self.backend == "runner" else SubProcRunner() - try: + with runner: while True: task = await self.get_next_task() if task is None: @@ -954,9 +958,6 @@ async def worker(self, wid: int): out, err, ok = await runner.run(task.source, timeout=self.timeout) is_interested = self.err_msg in out or self.err_msg in err await self.submit_result(task, is_interested) - finally: - if hasattr(runner, "stop_proc"): - runner.stop_proc() async def start_workers(self): if self.worker_tasks: From a096abb0d97f331e364a6d80e96cbb1691811495 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 13 Jan 2026 16:29:14 +0800 Subject: [PATCH 09/13] [Refactor] Move AutoDD detection to env module and update import logic * Refactor: Relocate the _is_running_autodd function to the env module for better organization and encapsulation. * Update initialization logic to skip logger and heavy imports based on a new light import mode, enhancing flexibility in module usage. * Ensure consistent handling of environment variables across the package, improving overall code clarity and maintainability. --- tilelang/__init__.py | 29 ++++++----------------------- tilelang/env.py | 26 ++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/tilelang/__init__.py b/tilelang/__init__.py index f08177181..9a439fdf7 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -7,24 +7,6 @@ from pathlib import Path -def _is_running_autodd() -> bool: - orig_argv = getattr(sys, "orig_argv", None) - if orig_argv is None: - return False - if "-mtilelang.autodd" in orig_argv: - return True - pos = orig_argv.index("-m") if "-m" in orig_argv else -1 - if pos != -1 and pos + 1 < len(orig_argv): - module_name = orig_argv[pos + 1] - if module_name == "tilelang.autodd" or module_name.startswith("tilelang.autodd."): - return True - return False - - -# check if we are running under AutoDD -_RUNNING_AUTODD = _is_running_autodd() - - def _compute_version() -> str: """Return the package version without being polluted by unrelated installs. 
@@ -115,8 +97,10 @@ def emit(self, record): set_log_level("INFO") -# Skip logger initialization when running under AutoDD -if not _RUNNING_AUTODD: +from .env import env as env # noqa: F401 + +# Skip logger initialization in light import mode +if not env.is_light_import(): _init_logger() del _init_logger @@ -141,11 +125,10 @@ def lazy_init(self, name, mode=ctypes.DEFAULT_MODE, *args, **kwargs): ctypes.CDLL.__init__ = old_init -# Skip import when running under AutoDD -if not _RUNNING_AUTODD: +# Skip heavy imports in light import mode +if not env.is_light_import(): with _lazy_load_lib(): from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401 - from .env import env as env # noqa: F401 import tvm import tvm.base # noqa: F401 diff --git a/tilelang/env.py b/tilelang/env.py index bd49d24b3..2b4868420 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -52,6 +52,21 @@ def _get_package_version(pkg: str) -> str | None: return None +def _is_running_autodd() -> bool: + """Detect if we are running under `python -m tilelang.autodd`.""" + orig_argv = getattr(sys, "orig_argv", None) + if orig_argv is None: + return False + if "-mtilelang.autodd" in orig_argv: + return True + pos = orig_argv.index("-m") if "-m" in orig_argv else -1 + if pos != -1 and pos + 1 < len(orig_argv): + module_name = orig_argv[pos + 1] + if module_name == "tilelang.autodd" or module_name.startswith("tilelang.autodd."): + return True + return False + + def _find_cuda_home() -> str: """Find the CUDA install path. @@ -326,6 +341,17 @@ def get_default_verbose(self) -> bool: """Get default verbose flag from environment.""" return self.TILELANG_DEFAULT_VERBOSE.lower() in ("1", "true", "yes", "on") + def is_running_autodd(self) -> bool: + """Return True if we are running under `python -m tilelang.autodd`.""" + # means we are running under `python -m tilelang.autodd` + return _is_running_autodd() + + def is_light_import(self) -> bool: + """Return True if we are running in light import mode.""" + # means we are running under `python -m tilelang.autodd` or some + # other scripts that only require the minimal environment variables. + return self.is_running_autodd() + # Instantiate as a global configuration object env = Environment() From 5c4ebfa6f99df6f977426f9f1f1664d316494bfb Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 13 Jan 2026 18:46:06 +0800 Subject: [PATCH 10/13] [Documentation] Add AutoDD section to debug_tools_for_tilelang.md * Introduced a comprehensive guide on AutoDD (Automatic Delta Debugging) for isolating bugs in TileLang programs. * Explained Delta Debugging methodology, usage, parameters, and provided examples for clarity. * Highlighted the benefits of using AutoDD for large codebases and hard-to-locate errors, emphasizing time-saving aspects. * Included tips for effective usage and a reference to a complete example in the documentation. 
---
 docs/tutorials/debug_tools_for_tilelang.md | 105 ++++++++
 examples/autodd/README.md | 126 ++++++++++
 examples/autodd/tilelang_buggy.py | 232 ++++++++++++++++++
 .../autodd/tilelang_minimized_expected.py | 51 ++++
 4 files changed, 514 insertions(+)
 create mode 100644 examples/autodd/README.md
 create mode 100644 examples/autodd/tilelang_buggy.py
 create mode 100644 examples/autodd/tilelang_minimized_expected.py

diff --git a/docs/tutorials/debug_tools_for_tilelang.md b/docs/tutorials/debug_tools_for_tilelang.md
index d98d4cb5e..078440f34 100644
--- a/docs/tutorials/debug_tools_for_tilelang.md
+++ b/docs/tutorials/debug_tools_for_tilelang.md
@@ -194,8 +194,113 @@ C_local inferenced layout: Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
 ```
 
+## AutoDD: Automatic Delta Debugging
+
+When dealing with complex TileLang programs that produce errors, manually isolating the bug can be tedious. **AutoDD** (Automatic Delta Debugging) is a built-in tool that automatically simplifies your program to the minimal code needed to reproduce a specific error.
+
+### What is Delta Debugging?
+
+Delta Debugging is an automated debugging technique that:
+1. Takes a program that triggers a bug
+2. Systematically removes code fragments
+3. Checks if the simplified program still triggers the same bug
+4. Produces the minimal code that reproduces the bug
+
+AutoDD uses a Probability Distribution Driven Delta Debugging (PDD) algorithm for efficient minimization.
+
+### Why Use AutoDD?
+
+- **Large codebases**: Real projects often have hundreds of lines of configuration, helper functions, and logging
+- **Hard-to-locate errors**: Error messages may point to TVM/CUDA internals rather than your TileLang code
+- **Time-saving**: Manually deleting code to isolate bugs is very time-consuming
+
+AutoDD can reduce a 200+ line program to just 30 lines, directly exposing the root cause.
+
+### Basic Usage
+
+```bash
+python -m tilelang.autodd <source.py> --err-msg "<error message>" -o <minimized.py>
+```
+
+### Parameters
+
+| Parameter | Description |
+|-----------|-------------|
+| `source` | Path to the input Python source file |
+| `--err-msg` | Error message to match (searched in stdout or stderr) |
+| `-o, --output` | Path to the minimized output file |
+| `--backend` | Execution backend: `runner` (faster) or `subproc` (more stable), default `runner` |
+| `--timeout` | Timeout for each task in seconds, default 60 |
+| `-j, --jobs` | Number of parallel jobs, default 1 |
+
+### Example
+
+Suppose you have a complex TileLang program with a GEMM shape mismatch bug:
+
+```python
+# buggy_matmul.py (200+ lines)
+@tilelang.jit
+def buggy_matmul(M, N, K, block_M, block_N, block_K, ...):
+    @T.prim_func
+    def matmul_kernel(...):
+        with T.Kernel(...) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_M, block_N), dtype)  # Bug: should be (block_K, block_N)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            # ... lots of other code ...
+            T.gemm(A_shared, B_shared, C_local)  # Error here
+    return matmul_kernel
+```
+
+Run AutoDD to minimize:
+
+```bash
+python -m tilelang.autodd buggy_matmul.py --err-msg "Dimension mismatch" -o minimized.py -j 4
+```
+
+AutoDD will produce a minimal reproduction:
+
+```python
+# minimized.py (~30 lines)
+import tilelang.language as T
+
+def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs):
+    @T.prim_func
+    def matmul_kernel():
+        with T.Kernel():
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_M, block_N), dtype)  # Bug exposed!
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.gemm(A_shared, B_shared, C_local)
+```
+
+### How AutoDD Works
+
+AutoDD uses AST (Abstract Syntax Tree) analysis with multiple rewrite rules:
+
+1. **Fast Reducers**: Remove statements, simplify `if`/`for` constructs
+2. **Canonicalizers**: Expand `with` statements, add `*args, **kwargs` for compatibility
+3. **Simplifiers**: Replace expressions with constants, simplify function calls
+4. **Slow Reducers**: Remove arbitrary expressions, reduce integer constants
+
+### Tips
+
+- **Error message matching**: Use a unique substring from the error output
+- **Timeout**: Increase `--timeout` for programs with long compilation times
+- **Parallel jobs**: Use `-j 4` or higher to speed up minimization
+- **Backend**: Try `--backend subproc` if `runner` is unstable
+
+### Complete Example
+
+A complete example is available in `examples/autodd/`:
+- `tilelang_buggy.py`: A complex program with a bug (~200 lines)
+- `tilelang_minimized_expected.py`: Expected output after AutoDD (~30 lines)
+- `README.md`: Detailed documentation
+
 ## Conclusion
 
 By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs.
 
+For complex programs where manual debugging is tedious, **AutoDD** provides automated delta debugging to quickly isolate the minimal code that reproduces a bug.
+
 For advanced performance tuning (e.g., analyzing memory bandwidth or occupancy), more specialized profiling tools such as **Nsight Compute**, **rocProf**, or vendor-specific profilers may be required. Those aspects will be covered in future documents.
diff --git a/examples/autodd/README.md b/examples/autodd/README.md
new file mode 100644
index 000000000..9ae9f9816
--- /dev/null
+++ b/examples/autodd/README.md
@@ -0,0 +1,126 @@
+# AutoDD - Automatic Delta Debugging for TileLang
+
+AutoDD (Automatic Delta Debugging) is a built-in debugging tool for TileLang that automatically simplifies complex Python programs to the minimal code needed to reproduce a specific error. This is extremely useful for debugging large, complex TileLang programs.
+
+## What is Delta Debugging?
+
+Delta Debugging is an automated debugging technique whose core idea is to:
+1. Start from a program that triggers a bug
+2. Systematically remove code fragments from the program
+3. Check if the simplified program still triggers the same bug
+4. Eventually obtain the minimal code that still triggers the bug
+
+AutoDD uses a Probability Distribution Driven Delta Debugging (PDD) algorithm to search for the minimized code efficiently.
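+
+To make the idea concrete, below is a toy sketch of a classic delta-debugging loop over a list of source lines. It illustrates the general technique only, not AutoDD's actual PDD implementation; `minimize` and `still_fails` are hypothetical names, and `still_fails` stands in for "run the candidate program and check whether the target error message still appears":
+
+```python
+# Toy delta-debugging loop (illustration only, not AutoDD's PDD algorithm).
+def minimize(lines, still_fails):
+    """Greedily drop chunks of `lines` while the bug still reproduces."""
+    chunk = max(1, len(lines) // 2)
+    while chunk >= 1:
+        i = 0
+        while i < len(lines):
+            candidate = lines[:i] + lines[i + chunk:]  # try removing one chunk
+            if still_fails(candidate):
+                lines = candidate  # bug still reproduces: keep the smaller program
+            else:
+                i += chunk  # this chunk is needed: keep it and move on
+        chunk //= 2  # retry with finer granularity
+    return lines
+```
+
+AutoDD follows the same keep-it-only-if-it-still-fails loop, but it operates on the Python AST instead of raw lines and uses the PDD probability model to prioritize which fragments to try.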
+
+## Why AutoDD?
+
+When developing TileLang programs, bugs are often hidden in complex code:
+
+- **Lots of irrelevant code**: Real projects may have hundreds of lines of configuration, helper functions, logging, etc.
+- **Hard to locate**: Error messages may point to underlying TVM/CUDA rather than TileLang code
+- **Tedious debugging**: Manually deleting code to locate bugs is very time-consuming
+
+AutoDD automates this process, reducing hundreds of lines of code to just a few dozen, directly exposing the root cause of the problem.
+
+## Usage
+
+### Basic Usage
+
+```bash
+python -m tilelang.autodd <source> --err-msg "<error message>" -o <output>
+```
+
+### Parameters
+
+| Parameter | Description |
+|-----------|-------------|
+| `source` | Path to the input Python source file |
+| `--err-msg` | Error message to match (searched in stdout or stderr) |
+| `-o, --output` | Path to the minimized output file |
+| `--backend` | Execution backend: `runner` (faster) or `subproc` (more stable), default `runner` |
+| `--timeout` | Timeout for each task in seconds, default 60 |
+| `-j, --jobs` | Number of parallel jobs, default 1 |
+
+### Example
+
+Run AutoDD on `tilelang_buggy.py` in this directory:
+
+```bash
+# Use 4 parallel jobs, search for "Dimension mismatch" error
+python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py -j 4
+
+# Or use subprocess backend (more stable but slower)
+python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py --backend subproc
+```
+
+## Example Files
+
+### `tilelang_buggy.py`
+
+A complex TileLang program with a bug (~200 lines), containing:
+- Multiple useless helper functions (`calculate_optimal_block_size`, `get_memory_requirements`, etc.)
+- A complex configuration class (`MatmulConfig`)
+- Unused benchmark code (`benchmark_pytorch`)
+- **A GEMM shape mismatch bug**
+
+The bug is on line 124:
+```python
+B_shared = T.alloc_shared((block_M, block_N), dtype)  # Wrong! Should be (block_K, block_N)
+```
+
+### `tilelang_minimized_expected.py`
+
+The expected output after AutoDD simplification (~30 lines). The simplified code clearly shows the root cause of the bug:
+
+```python
+def buggy_matmul(...):
+    @T.prim_func
+    def matmul_kernel():
+        with T.Kernel():
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_M, block_N), dtype)  # Bug!
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.gemm(A_shared, B_shared, C_local)  # Error occurs here
+```
+
+## How AutoDD Works
+
+AutoDD uses AST (Abstract Syntax Tree) analysis and multiple rewrite rules to simplify code:
+
+### 1. Fast Reducers
+- **Statement removal**: Directly remove statements that don't affect bug reproduction
+- **If statement simplification**: Simplify `if cond: body` to `body`
+- **For loop simplification**: Bind loop variables to constants
+
+### 2. Canonicalizers
+- **With statement expansion**: Convert `with expr as var` to explicit assignment
+- **Function argument extension**: Add `*args, **kwargs` for compatibility
+
+### 3. Simplifiers
+- **Assignment simplification**: Replace complex expressions with constants
+- **Function call simplification**: Simplify `f(x)` to `x`
+- **Binary operation simplification**: Simplify `a + b` to `a` or `b`
+
+### 4. Slow Reducers
+- **Expression removal**: Remove arbitrary expressions
+- **Argument removal**: Remove function arguments
+- **Integer reduction**: Gradually reduce large integers
+
+## Use Cases
+
+1. 
**TileLang kernel debugging**: Simplify complex TileLang programs to locate bugs
+2. **Bug report submission**: Generate minimal reproduction code for easier issue tracking
+3. **Understanding errors**: With irrelevant code removed, the nature of an error is much easier to understand
+4. **Regression testing**: Simplified code can serve as regression test cases
+
+## Notes
+
+1. **Error message matching**: The `--err-msg` string must appear verbatim as a substring of the program's stdout or stderr
+2. **Timeout setting**: For programs with long compilation times, you may need to increase `--timeout`
+3. **Parallel jobs**: Increasing `-j` can speed up the simplification process but consumes more resources
+4. **Backend selection**: If the `runner` backend is unstable, try the `subproc` backend
+
+## References
+
+- [Delta Debugging Paper](https://www.st.cs.uni-saarland.de/papers/tse2002/)
+- [TileLang Documentation](https://github.com/tile-ai/tilelang)
diff --git a/examples/autodd/tilelang_buggy.py b/examples/autodd/tilelang_buggy.py
new file mode 100644
index 000000000..47e71fe50
--- /dev/null
+++ b/examples/autodd/tilelang_buggy.py
@@ -0,0 +1,232 @@
+"""
+A complex TileLang program with lots of redundant code and a bug that triggers an error.
+AutoDD will simplify it to the minimal code needed to reproduce the error.
+
+This example demonstrates how AutoDD can help developers quickly isolate bugs
+in complex TileLang programs by automatically removing irrelevant code.
+
+To run AutoDD on this file:
+    python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py -j 4
+
+The bug in this file: B_shared has shape (block_M, block_N) instead of (block_K, block_N),
+causing a GEMM dimension mismatch error.
+"""
+
+import tilelang
+import tilelang.language as T
+import torch
+
+
+# Useless helper function - will be removed by AutoDD
+def calculate_optimal_block_size(M, N, K):
+    """Calculate optimal block size - this function is completely useless"""
+    options = [32, 64, 128, 256]
+    best = 128
+    for opt in options:
+        if M % opt == 0 and N % opt == 0:
+            best = opt
+            break
+    return best, best, 32
+
+
+def get_memory_requirements(M, N, K, block_M, block_N, block_K, dtype_size=2):
+    """Calculate memory requirements - completely useless"""
+    shared_mem_a = block_M * block_K * dtype_size
+    shared_mem_b = block_K * block_N * dtype_size
+    total_shared = shared_mem_a + shared_mem_b
+    return total_shared
+
+
+def validate_parameters(M, N, K, block_M, block_N, block_K):
+    """Validate parameters - redundant check"""
+    if M <= 0 or N <= 0 or K <= 0:
+        raise ValueError("Matrix dimensions must be positive")
+    if block_M <= 0 or block_N <= 0 or block_K <= 0:
+        raise ValueError("Block sizes must be positive")
+    if M % block_M != 0:
+        print(f"Warning: M ({M}) not divisible by block_M ({block_M})")
+    if N % block_N != 0:
+        print(f"Warning: N ({N}) not divisible by block_N ({block_N})")
+    if K % block_K != 0:
+        print(f"Warning: K ({K}) not divisible by block_K ({block_K})")
+    return True
+
+
+class MatmulConfig:
+    """Configuration class - increases code complexity but is actually useless"""
+
+    def __init__(self, M, N, K):
+        self.M = M
+        self.N = N
+        self.K = K
+        self.block_M = 128
+        self.block_N = 128
+        self.block_K = 32
+        self.num_stages = 3
+        self.threads = 128
+        self.dtype = "float16"
+        self.accum_dtype = "float32"
+
+    def get_grid_size(self):
+        grid_x = (self.N + self.block_N - 1) // self.block_N
+        grid_y = (self.M + self.block_M - 1) // self.block_M
+        return grid_x, grid_y
+
+    def 
get_shared_memory_size(self): + return get_memory_requirements( + self.M, self.N, self.K, self.block_M, self.block_N, self.block_K + ) + + def validate(self): + return validate_parameters( + self.M, self.N, self.K, self.block_M, self.block_N, self.block_K + ) + + +def create_reference_output(a, b, activation="relu"): + """Create reference output - not actually used in verification""" + result = a @ b + if activation == "relu": + result = torch.relu(result) + elif activation == "gelu": + result = torch.nn.functional.gelu(result) + elif activation == "sigmoid": + result = torch.sigmoid(result) + return result + + +def benchmark_pytorch(M, N, K, num_iters=10, warmup=5): + """PyTorch benchmark - not used""" + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Warmup + for _ in range(warmup): + _ = a @ b + torch.cuda.synchronize() + + # Benchmark + import time + start = time.time() + for _ in range(num_iters): + _ = a @ b + torch.cuda.synchronize() + end = time.time() + + return (end - start) / num_iters * 1000 # ms + + +# Main TileLang kernel - contains a BUG: GEMM shape mismatch! +@tilelang.jit +def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + # Allocate shared memory + A_shared = T.alloc_shared((block_M, block_K), dtype) + # BUG: the first dimension of B_shared should be block_K, but block_M is used here! + B_shared = T.alloc_shared((block_M, block_N), dtype) # Wrong shape! + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Allocate some useless temp variables + temp_buffer = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Zero out + T.clear(C_local) + T.clear(temp_buffer) + + # Main loop + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy a tile of A + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy a tile of B - shape can mismatch here too + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # GEMM computation - shape mismatch will cause an error + # A_shared: (block_M, block_K) + # B_shared: (block_M, block_N) <- should be (block_K, block_N) + T.gemm(A_shared, B_shared, C_local) + + # ReLU activation + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Some useless postprocessing + for i, j in T.Parallel(block_M, block_N): + if temp_buffer[i, j] > 0: + C_local[i, j] = C_local[i, j] + 0.0 + + # Write back result + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_kernel + + +def run_kernel(config): + """Run kernel - includes extra redundant logic""" + # Validate parameters + config.validate() + + # Get config + M, N, K = config.M, config.N, config.K + block_M, block_N, block_K = config.block_M, config.block_N, config.block_K + + # Calculate some useless statistics + grid_size = config.get_grid_size() + shared_mem = config.get_shared_memory_size() + print(f"Grid size: {grid_size}") + print(f"Shared memory: {shared_mem} bytes") + + # Create test data + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + c = torch.empty(M, N, device="cuda", dtype=torch.float16) + + # Compile and run kernel - will trigger the BUG here + kernel = buggy_matmul(M, N, K, block_M, block_N, block_K) + 
kernel(a, b, c) + + # Validate results (if it can get here) + ref_c = torch.relu(a @ b) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + return c + + +def main(): + # Useless printing + print("=" * 60) + print("TileLang Matmul Kernel Test") + print("=" * 60) + + # Create config + M, N, K = 512, 512, 512 + config = MatmulConfig(M, N, K) + + # Calculate some useless values + optimal_block = calculate_optimal_block_size(M, N, K) + print(f"Optimal block size: {optimal_block}") + + # Run PyTorch benchmark - result is not used + # pytorch_time = benchmark_pytorch(M, N, K) + # print(f"PyTorch time: {pytorch_time:.3f} ms") + + # Run our kernel - will trigger the error here + try: + result = run_kernel(config) + print(f"Result shape: {result.shape}") + except Exception as e: + print(f"Error: {e}") + raise + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/examples/autodd/tilelang_minimized_expected.py b/examples/autodd/tilelang_minimized_expected.py new file mode 100644 index 000000000..2135f6fce --- /dev/null +++ b/examples/autodd/tilelang_minimized_expected.py @@ -0,0 +1,51 @@ +""" +This is the expected output after running AutoDD on tilelang_buggy.py. +AutoDD automatically simplified the 200+ line buggy program to ~30 lines +while preserving the ability to reproduce the error. + +The minimized code clearly shows the root cause of the bug: +- A_shared has shape (block_M, block_K) +- B_shared has shape (block_M, block_N) - should be (block_K, block_N) +- This causes a dimension mismatch in T.gemm() +""" + +import tilelang.language as T + + +class MatmulConfig: + + def __init__(self, *args, **kwargs): + self.M = 1 + self.N = 1 + self.K = 1 + self.block_M = 2 + self.block_N = 1 + self.block_K = 1 + + +def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs): + + @T.prim_func + def matmul_kernel(): + with T.Kernel(): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) # Bug: should be (block_K, block_N) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.gemm(A_shared, B_shared, C_local) + + +def run_kernel(config, *args, **kwargs): + M, N, K = (config.M, config.N, config.K) + block_M, block_N, block_K = (config.block_M, config.block_N, config.block_K) + buggy_matmul(M, N, K, block_M, block_N, block_K) + + +def main(*args, **kwargs): + config = MatmulConfig() + try: + run_kernel(config) + except Exception as e: + print(f'{e}') + + +main() From 076257be4b487a4be384aaab877a6db349ef8737 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 14 Jan 2026 17:23:29 +0800 Subject: [PATCH 11/13] [Refactor] Update flash attention implementation in TileLang examples * Refactor: Simplify the flash attention function signatures in example scripts to accept parameters directly, enhancing clarity and usability. * Update the kernel invocation logic in SparseFlashAttn class to align with the new function signatures. * Remove redundant code and improve the organization of dynamic parameters for better maintainability. * Enhance the handling of cache sequence lengths and block sizes in the regression performance tests, ensuring consistency across examples. * Clean up unused imports and streamline the code for improved readability and performance. 
--- ...xample_tilelang_sparse_gqa_decode_paged.py | 376 +++++++++--------- ...ilelang_sparse_gqa_decode_varlen_indice.py | 337 ++++++++-------- ..._tilelang_sparse_gqa_decode_varlen_mask.py | 337 ++++++++-------- src/op/atomic_add.cc | 7 +- src/op/copy.cc | 8 +- .../common/loop_parallel_transform_utils.h | 170 -------- src/transform/layout_inference.cc | 2 - 7 files changed, 518 insertions(+), 719 deletions(-) delete mode 100644 src/transform/common/loop_parallel_transform_utils.h diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 6e7321452..a93e4de13 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -2,7 +2,6 @@ import torch import torch.nn.functional as F import tilelang -from tilelang.autotuner import * import tilelang.language as T from einops import rearrange, einsum import argparse @@ -13,160 +12,159 @@ from heuristic import num_splits_heuristic -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, page_block_size, num_stages, threads, num_pages): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func( - block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks + num_split = T.dynamic("num_split") + max_num_blocks_per_seq = T.dynamic("max_num_blocks_per_seq") + max_selected_blocks = T.dynamic("max_selected_blocks") + + shape_q = [batch, heads, dim] + shape_k = [num_pages, page_block_size, heads_kv, dim] + shape_v = [num_pages, page_block_size, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_block_table = [batch, max_num_blocks_per_seq] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + assert block_N <= page_block_size and page_block_size % block_N == 0 + block_ratio = page_block_size // block_N + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): - shape_q = [batch, heads, dim] - shape_k = [num_pages, page_block_size, heads_kv, dim] - shape_v = [num_pages, page_block_size, heads_kv, dim_v] - shape_indices = [batch, heads_kv, max_selected_blocks] - shape_block_table = [batch, max_num_blocks_per_seq] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - assert block_N <= page_block_size and page_block_size % block_N == 0 - block_ratio = page_block_size // block_N - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, T.int32), - cache_seqlens: T.Tensor([batch], 
T.int32), - block_table: T.Tensor(shape_block_table, T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - # flash_attn_split - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - num_blocks = max_selected_blocks - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - for k in T.Pipelined(loop_range, num_stages=num_stages): - logical_block_idx = block_indices[bid, cur_kv_head, start + k] - if logical_block_idx >= 0: - has_valid_block = True - block_table_idx = T.floordiv(logical_block_idx, block_ratio) - block_tile_idx = T.floormod(logical_block_idx, block_ratio) - physical_block_idx = block_table[bid, block_table_idx] - T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] - ) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + # flash_attn_split + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // 
(kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + logical_block_idx = block_indices[bid, cur_kv_head, start + k] + if logical_block_idx >= 0: + has_valid_block = True + block_table_idx = T.floordiv(logical_block_idx, block_ratio) + block_tile_idx = T.floormod(logical_block_idx, block_ratio) + physical_block_idx = block_table[bid, block_table_idx] + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] - + acc_s[i, j] = T.if_then_else( + logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - # combine - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - max_split = T.alloc_var(T.int32) - - T.clear(lse_logsum_local) - 
T.clear(o_accum_local)
-                lse_max_local = -T.infinity(accum_dtype)
-                for k in T.serial(num_split):
+                    acc_o[i, j] /= logsum[i]
+            for i in T.Parallel(block_H):
+                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
+
+            # TODO(lei): Support T.Parallel(valid_block_H)
+            for i in T.Parallel(block_H):
+                if i < valid_block_H:
+                    glse[bid, hid * valid_block_H + i, sid] = logsum[i]
+            for i, j in T.Parallel(block_H, dim_v):
+                if i < valid_block_H:
+                    Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j]
+
+        # combine
+        with T.Kernel(heads, batch, threads=128) as (by, bz):
+            po_local = T.alloc_fragment([dim_v], accum_dtype)
+            o_accum_local = T.alloc_fragment([dim_v], accum_dtype)
+            lse_local_split = T.alloc_var(accum_dtype)
+            lse_logsum_local = T.alloc_var(accum_dtype)
+            lse_max_local = T.alloc_var(accum_dtype)
+            scale_local = T.alloc_var(accum_dtype)
+            max_split = T.alloc_var(T.int32)
+
+            T.clear(lse_logsum_local)
+            T.clear(o_accum_local)
+            lse_max_local = -T.infinity(accum_dtype)
+            for k in T.serial(num_split):
+                lse_local_split = glse[bz, by, k]
+                if lse_local_split != 0:
+                    max_split = k
+                    lse_max_local = T.max(lse_max_local, glse[bz, by, k])
+
+            for k in T.Pipelined(num_split, num_stages=1):
+                if k <= max_split:
                     lse_local_split = glse[bz, by, k]
-                    if lse_local_split != 0:
-                        max_split = k
-                        lse_max_local = T.max(lse_max_local, glse[bz, by, k])
-
-                for k in T.Pipelined(num_split, num_stages=1):
-                    if k <= max_split:
-                        lse_local_split = glse[bz, by, k]
-                        lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
-                lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
-                for k in T.serial(num_split):
-                    if k <= max_split:
-                        for i in T.Parallel(dim_v):
-                            po_local[i] = Output_partial[bz, by, k, i]
-                        lse_local_split = glse[bz, by, k]
-                        scale_local = T.exp2(lse_local_split - lse_logsum_local)
-                        for i in T.Parallel(dim_v):
-                            o_accum_local[i] += po_local[i] * scale_local
-                for i in T.Parallel(dim_v):
-                    Output[bz, by, i] = o_accum_local[i]
-
-        return main
-
-    return kernel_func
+                    lse_logsum_local += T.exp2(lse_local_split - lse_max_local)
+            lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local
+            for k in T.serial(num_split):
+                if k <= max_split:
+                    for i in T.Parallel(dim_v):
+                        po_local[i] = Output_partial[bz, by, k, i]
+                    lse_local_split = glse[bz, by, k]
+                    scale_local = T.exp2(lse_local_split - lse_logsum_local)
+                    for i in T.Parallel(dim_v):
+                        o_accum_local[i] += po_local[i] * scale_local
+            for i in T.Parallel(dim_v):
+                Output[bz, by, i] = o_accum_local[i]
+
+    return main
 
 
 class SparseFlashAttn(torch.nn.Module):
@@ -181,19 +179,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N,
         self.page_block_size = page_block_size
         self.num_pages = num_pages
         self.block_H = 64
-
-        self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)(
-            block_N=block_N,
-            block_H=self.block_H,
-            page_block_size=page_block_size,
-            num_split=T.dynamic("num_split"),
-            num_stages=2,
-            threads=128,
-            num_pages=num_pages,
-            max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"),
-            max_selected_blocks=T.dynamic("max_selected_blocks"),
-        )
-
         props = torch.cuda.get_device_properties(torch.device("cuda:0"))
         self.num_sm = props.multi_processor_count
@@ -221,16 +206,19 @@ def forward(self, query, key, value, block_indices, cache_seqlens, block_table):
         glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda")
         output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
 
-        output = self.kernel(
-            query,
-            key,
-            
value, - block_indices, - cache_seqlens, - block_table, - glse, - output_partial, - ) + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + page_block_size=self.page_block_size, + num_stages=2, + threads=128, + num_pages=self.num_pages, + )(query, key, value, block_indices, cache_seqlens, block_table, glse, output_partial) return output @@ -513,6 +501,8 @@ def main(args): def run_regression_perf(args): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( args.batch, args.heads, @@ -524,15 +514,15 @@ def run_regression_perf(args): sparse_ratio = args.sparse_ratio block_N = args.block_N page_block_size = args.page_block_size - num_blocks = args.num_pages + num_pages = args.num_pages max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) dtype = torch.float16 Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") - K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") - V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") + K_cache = torch.zeros((num_pages, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_pages, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") @@ -596,22 +586,20 @@ def run_regression_perf(args): for i in range(len(selected_blocks), max_selected_blocks): block_indices[seq_idx, head_idx, i] = -1 - sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) - kernel = sparse_attn.kernel - batch = sparse_attn.batch - heads = sparse_attn.heads - heads_kv = sparse_attn.heads_kv - dim_v = sparse_attn.dim_v - dim = sparse_attn.dim - block_size = sparse_attn.block_N + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_N max_selected_blocks = block_indices.shape[-1] - num_m_blocks = 1 * (heads // heads_kv + sparse_attn.block_H - 1) // sparse_attn.block_H + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - - num_sm = sparse_attn.num_sm + num_sm = sparse_kernel.num_sm num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 @@ -619,18 +607,22 @@ def run_regression_perf(args): glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + 
heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=sparse_kernel.block_H, + page_block_size=sparse_kernel.page_block_size, + num_stages=2, + threads=128, + num_pages=sparse_kernel.num_pages, + ) def run_kernel_only(): - kernel( - Q, - K_cache, - V_cache, - block_indices, - cache_seqlens, - block_table, - glse, - output_partial, - ) + kernel(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, glse, output_partial) return do_bench(run_kernel_only, backend="cupti") diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index d6cf7d917..f432fe0fa 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -10,153 +10,150 @@ from tilelang.profiler import do_bench -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, num_stages, threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): - shape_q = [batch, heads, dim] - shape_k = [batch, max_cache_seqlen, heads_kv, dim] - shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] - shape_indices = [batch, heads_kv, max_selected_blocks] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, T.int32), - cache_seqlens: T.Tensor([batch], T.int32), - # actual_num_blocks: T.Tensor([batch], T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) - # flash_attn_split - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - num_blocks = max_selected_blocks - blocks_per_split = T.floordiv(num_blocks, 
num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - - for k in T.Pipelined(loop_range, num_stages=num_stages): - i_s = block_indices[bid, cur_kv_head, start + k] - if i_s >= 0: - has_valid_block = True - T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + num_split = T.dynamic("num_split") + max_cache_seqlen = T.dynamic("max_cache_seqlen") + max_selected_blocks = T.dynamic("max_selected_blocks") + + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + + for k in T.Pipelined(loop_range, num_stages=num_stages): + i_s = block_indices[bid, cur_kv_head, start + k] + if i_s >= 0: + has_valid_block = True + T.copy(K[bid, i_s * block_N : (i_s + 1) * 
block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] - + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - # combine - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - max_split = T.alloc_var(T.int32) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local = -T.infinity(accum_dtype) - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + 
if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] lse_local_split = glse[bz, by, k] - if lse_local_split != 0: - max_split = k - lse_max_local = T.max(lse_max_local, glse[bz, by, k]) - - for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split: - lse_local_split = glse[bz, by, k] - lse_logsum_local += T.exp2(lse_local_split - lse_max_local) - lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local - for k in T.serial(num_split): - if k <= max_split: - for i in T.Parallel(dim_v): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split = glse[bz, by, k] - scale_local = T.exp2(lse_local_split - lse_logsum_local) - for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - return main - - return kernel_func + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + return main class SparseFlashAttn(torch.nn.Module): @@ -168,19 +165,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.dim = dim self.dim_v = dim_v self.block_size = block_size - self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_size, - block_H=self.block_H, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -208,7 +193,17 @@ def forward(self, query, key, value, block_indices, cache_seqlens): glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + num_stages=2, + threads=128, + )(query, key, value, block_indices, cache_seqlens, glse, output_partial) return output @@ -252,14 +247,16 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, block_N=block_size, block_H=block_H, - num_split=T.dynamic("num_split"), num_stages=2, threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks"), ) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) @@ -311,7 +308,7 @@ def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_se return output -def debug(name, expect, actual, atol=1e-3, rtol=1e-3): +def 
assert_close(name, expect, actual, atol=1e-3, rtol=1e-3): all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) print(name + " all_close={}".format(all_close)) if not all_close: @@ -324,29 +321,17 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + dtype = torch.float16 sparse_ratio = sparse_ratio block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) - print("max_selected_blocks: ", max_selected_blocks) - dtype = torch.float16 Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") - # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') - # # Ensure at least one element equals cache_seqlen - # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - # # cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence - - print("cache_seqlens: ", cache_seqlens) - max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() - print("max_valid_num_blocks: ", max_valid_num_blocks) - # Initialize block_indices with -1 (for padding blocks) block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") - # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) - # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -354,27 +339,17 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] - # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] block_indices[b, h, : len(valid_indices)] = valid_indices - # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) - # print("block_indices: ", block_indices) - actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] - print("actual_num_blocks: ", actual_num_blocks) - # print(block_indices.shape, actual_num_blocks.shape) - max_num_blocks = torch.max(max_valid_num_blocks).item() - print("max_num_blocks: ", max_num_blocks) # parity reference ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) - debug("output", ref, out, atol=1e-3, rtol=1e-3) - - import flash_attn # noqa: F401 + assert_close("output", ref, out, atol=1e-3, rtol=1e-3) ## latency reference for _ in range(10): @@ -387,12 +362,10 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 print("dense time: ", (time.time() - start) / 100 * 1000) 
for _ in range(10): - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() start = time.time() for _ in range(100): - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() print("sparse time: ", (time.time() - start) / 100 * 1000) diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index e48428fb8..75c6ed46d 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -1,7 +1,6 @@ import torch import torch.nn.functional as F import tilelang -from tilelang.autotuner import * import tilelang.language as T from einops import rearrange, einsum import argparse @@ -11,137 +10,146 @@ from tilelang.profiler import do_bench -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, num_stages, threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): - shape_q = [batch, heads, dim] - shape_k = [batch, max_cache_seqlen, heads_kv, dim] - shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] - shape_mask = [batch, heads_kv, num_blocks] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, T.bool), - cache_seqlens: T.Tensor([batch], T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = 
blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - for k in T.Pipelined(loop_range, num_stages=num_stages): - if block_mask[bid, hid, start + k]: - has_valid_block = True - T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] - ) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] + num_split = T.dynamic("num_split") + max_cache_seqlen = T.dynamic("max_cache_seqlen") + num_blocks = T.dynamic("num_blocks") + + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_mask = [batch, heads_kv, num_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 
0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[bid, hid, start + k]: + has_valid_block = True + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else( + (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local = -T.infinity(accum_dtype) - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k lse_max_local = T.max(lse_max_local, glse[bz, by, k]) - for k in T.Pipelined(num_split, num_stages=1): + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: lse_local_split = glse[bz, by, k] lse_logsum_local += T.exp2(lse_local_split - lse_max_local) - lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local - for k in 
T.serial(num_split): + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] lse_local_split = glse[bz, by, k] scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): o_accum_local[i] += po_local[i] * scale_local - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] - return main - - return kernel_func + return main class SparseFlashAttn(torch.nn.Module): @@ -153,19 +161,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.dim = dim self.dim_v = dim_v self.block_size = block_size - self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_size, - block_H=self.block_H, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -176,27 +172,35 @@ def forward(self, query, key, value, block_mask, cache_seqlens): dim_v = self.dim_v dim = self.dim block_size = self.block_size - block_H = self.block_H max_cache_seqlen = key.shape[1] # get num_split max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size - num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - # num_sm = 132 num_sm = self.num_sm num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) - # print("num_split: ", num_split) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + num_stages=2, + threads=128, + )(query, key, value, block_mask, cache_seqlens, glse, output_partial) return output - def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_size): """ Args: @@ -233,21 +237,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) - kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, block_N=block_size, block_H=block_H, - num_split=T.dynamic("num_split"), num_stages=2, threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks"), ) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = torch.empty((batch, heads, 
num_split, dim_v), dtype=torch.float32, device="cuda") - # print(kernel.get_kernel_source()) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) - return output @@ -297,12 +301,10 @@ def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_se return output -def debug(name, expect, actual, atol=1e-3, rtol=1e-3): +def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3): all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) print(name + " all_close={}".format(all_close)) if not all_close: - # print(expect[3, 28]) - # print(actual[3, 28]) diff = (expect - actual).abs() print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) @@ -353,7 +355,7 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = model(Q, K, V, block_mask, cache_seqlens) - debug("output", ref, out, atol=1e-3, rtol=1e-3) + assert_close("output", ref, out, atol=1e-3, rtol=1e-3) import flash_attn # noqa: F401 @@ -381,12 +383,13 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") @@ -408,31 +411,41 @@ def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, di perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True - model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) - batch = model.batch - heads = model.heads - heads_kv = model.heads_kv - dim_v = model.dim_v - dim = model.dim - block_size = model.block_size - block_H = model.block_H - max_cache_seqlen = K.shape[1] + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_size max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size - num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H + num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - num_sm = model.num_sm + num_sm = sparse_kernel.num_sm + num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = 
torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
-    kernel = model.kernel
+    output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda")
+    kernel = flashattn(
+        batch,
+        heads,
+        heads_kv,
+        dim,
+        dim_v,
+        block_N=block_size,
+        block_H=sparse_kernel.block_H,
+        num_stages=2,
+        threads=128,
+    )
 
     def run_kernel_only():
-        kernel(Q, K, V, block_mask, cache_seqlens, glse, Output_partial)
+        kernel(Q, K, V, block_mask, cache_seqlens, glse, output_partial)
 
     return do_bench(run_kernel_only, backend="cupti")
diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc
index e918b4ed7..36a302fd9 100644
--- a/src/op/atomic_add.cc
+++ b/src/op/atomic_add.cc
@@ -14,7 +14,6 @@
 #include "../target/utils.h"
 #include "../transform/atomicadd_vectorize.h"
 #include "../transform/common/loop_fusion_utils.h"
-#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "builtin.h"
 
@@ -398,8 +397,6 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
-  auto transformed_loop =
-      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
 
   auto GetArchInt = [&](const Target &tgt) -> int {
     int arch_int = 0;
@@ -525,12 +522,12 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return {loop_layout, pred};
   };
 
-  auto ret = AtomicAddInferLayout(transformed_loop,
+  auto ret = AtomicAddInferLayout(fused_loop,
                                   {T.target, T.thread_bounds, T.layout_map,
                                    analyzer, false, T.buffer_remap});
   Fragment loop_layout = ret.loop_layout;
 
   auto thread_loop =
-      PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
+      PartitionLoop(fused_loop, T.thread_var, analyzer, loop_layout);
   auto vectorized_thread_loop =
       VectorizeAtomicAdd(thread_loop, GetArchInt(target));
 
   return vectorized_thread_loop;
diff --git a/src/op/copy.cc b/src/op/copy.cc
index 070df4305..71e380591 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -12,7 +12,6 @@
 #include "../layout/tcgen05_layout.h"
 #include "../target/utils.h"
 #include "../transform/common/loop_fusion_utils.h"
-#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "../transform/loop_vectorize.h"
 #include "utils.h"
@@ -787,11 +786,8 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
-  auto transformed_loop =
-      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
-
   For vectorized_thread_loop;
-  auto par_op = ParallelOp(transformed_loop);
+  auto par_op = ParallelOp(fused_loop);
 
   if (is_cpu_target || IsLocalBuffer(src) || IsLocalBuffer(dst)) {
     if (IsLocalBuffer(src) && !IsLocalBuffer(dst)) {
                  << dst.scope() << " buffer `" << dst->name
                  << "` may cause conflicted write.";
     }
-    vectorized_thread_loop = VectorizeLoop(transformed_loop);
+    vectorized_thread_loop = VectorizeLoop(fused_loop);
     return vectorized_thread_loop;
   } else {
     std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
diff --git a/src/transform/common/loop_parallel_transform_utils.h b/src/transform/common/loop_parallel_transform_utils.h
deleted file mode 100644
index 52a5a9b97..000000000
--- a/src/transform/common/loop_parallel_transform_utils.h
+++ /dev/null
@@ -1,170 +0,0 @@
-/*!
- * \file common.h
- * \brief Common utilities for TL transforms
- */
-
-#include
-#include
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "arith/ir_mutator_with_analyzer.h"
-#include "arith/ir_visitor_with_analyzer.h"
-#include
-
-#include "../../op/utils.h"
-
-namespace tvm {
-namespace tl {
-
-using namespace tir;
-using arith::IRMutatorWithAnalyzer;
-using arith::IRVisitorWithAnalyzer;
-
-class ParallelLoopTransformer : public IRMutatorWithAnalyzer {
-public:
-  static Stmt Substitute(const Stmt &stmt, bool skip_thread_partition = false) {
-    arith::Analyzer analyzer;
-    ParallelLoopTransformer transformer(&analyzer);
-    return transformer.VisitStmt(stmt);
-  }
-
-  ParallelLoopTransformer(arith::Analyzer *analyzer)
-      : IRMutatorWithAnalyzer(analyzer) {}
-
-  Stmt VisitStmt_(const ForNode *op) final {
-
-    if (op->kind != ForKind::kParallel)
-      return StmtMutator::VisitStmt_(op);
-
-    // Collect loop variables and ranges
-    auto for_node = tvm::ffi::GetRef<For>(op);
-    Array<Var> loop_vars;
-    Array<PrimExpr> loop_extents;
-    Stmt body = op->body;
-
-    // Bind the range of outer loop variables
-    analyzer_->Bind(op->loop_var, Range::FromMinExtent(0, op->extent));
-    loop_vars.push_back(op->loop_var);
-    loop_extents.push_back(op->extent);
-
-    // If there are inner loops, bind their ranges as well
-    while (const ForNode *inner = body.as<ForNode>()) {
-      analyzer_->Bind(inner->loop_var, Range::FromMinExtent(0, inner->extent));
-      loop_vars.push_back(inner->loop_var);
-      loop_extents.push_back(inner->extent);
-      body = inner->body;
-    }
-
-    ICHECK(loop_vars.size() == loop_extents.size())
-        << "loop_vars and loop_extents size mismatch";
-
-    // Collect buffer access information
-    BufferAccessCollector collector;
-    collector(op->body);
-
-    PrimExpr condition;
-
-    for (const auto &[buffer, indices] : collector.buffer_indices) {
-      ICHECK(indices.size() == buffer->shape.size())
-          << "indices size mismatch with buffer shape";
-
-      for (size_t i = 0; i < indices.size(); ++i) {
-        auto index = indices[i];
-        auto bound = analyzer_->const_int_bound(index);
-
-        // Collect the variables that are used in the index
-        std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_vars;
-        // post order visit the index
-        PostOrderVisit(index, [&](const ObjectRef &obj) {
-          if (const VarNode *v = obj.as<VarNode>()) {
-            used_vars.insert(tvm::ffi::GetRef<Var>(v));
-          }
-        });
-        if (used_vars.empty()) {
-          continue;
-        }
-
-        // find related loop vars
-        Array<Var> related_loop_vars;
-        for (size_t j = 0; j < loop_vars.size(); ++j) {
-          auto loop_var = loop_vars[j];
-          // if find related, pop the loop_vars and loop_extents
-          if (used_vars.count(loop_var)) {
-            related_loop_vars.push_back(loop_var);
-          }
-          if (related_loop_vars.size() > 1) {
-            // Only one related loop var is supported transformation currently.
-            return for_node;
-          }
-
-          auto bound = analyzer_->const_int_bound(index);
-          int64_t upper_bound = bound->max_value + 1;
-          int64_t shape = Downcast<IntImm>(buffer->shape[i])->value;
-          if (upper_bound < shape) {
-            PrimExpr predicate = LT(index, IntImm(index.dtype(), upper_bound));
-            condition =
-                condition.defined() ? And(condition, predicate) : predicate;
-          }
-        }
-      }
-    }
-
-    if (condition.defined()) {
-      body = IfThenElse(condition, body);
-
-      for (int j = loop_vars.size() - 1; j >= 0; --j) {
-        auto loop_var = loop_vars[j];
-        auto loop_extent = loop_extents[j];
-        body = For(loop_var, 0, loop_extent, ForKind::kParallel, body);
-      }
-
-      return Downcast<For>(body);
-    }
-
-    // Only traverse the outer loop
-    return for_node;
-  }
-
-  // Helper class for collecting buffer access information, only counts fragment
-  // buffer access
-  class BufferAccessCollector : public StmtExprVisitor {
-  public:
-    void VisitExpr_(const BufferLoadNode *op) final {
-      if (IsFragmentBuffer(op->buffer)) {
-        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
-          buffer_indices[op->buffer] = op->indices;
-        } else {
-          // check equal
-          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
-              << "indices mismatch for buffer: " << op->buffer;
-        }
-      }
-      StmtExprVisitor::VisitExpr_(op);
-    }
-
-    void VisitStmt_(const BufferStoreNode *op) final {
-      if (IsFragmentBuffer(op->buffer)) {
-        if (buffer_indices.find(op->buffer) == buffer_indices.end()) {
-          buffer_indices[op->buffer] = op->indices;
-        } else {
-          // check equal
-          ICHECK(StructuralEqual()(buffer_indices[op->buffer], op->indices))
-              << "indices mismatch for buffer: " << op->buffer;
-        }
-      }
-      StmtExprVisitor::VisitStmt_(op);
-    }
-
-    std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
-        buffer_indices;
-  };
-};
-
-} // namespace tl
-} // namespace tvm
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index 15d7f71e2..a622d71f4 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -26,7 +26,6 @@
 #include "arith/ir_mutator_with_analyzer.h"
 #include "arith/ir_visitor_with_analyzer.h"
 #include "common/loop_fusion_utils.h"
-#include "common/loop_parallel_transform_utils.h"
 #include "common/union_find.h"
 #include "layout_reducer.h"
 #include "parallel_loop_layout_validator.h"
@@ -1253,7 +1252,6 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
 tvm::transform::Pass LayoutInference() {
   using namespace tir::transform;
   auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
-    f.CopyOnWrite()->body = ParallelLoopTransformer::Substitute(f->body);
     ThreadBindingCollector collector;
     collector(f->body);
     bool has_thread_binding = !collector.thread_binding_.empty();

From 6581b5fe14a9cbc5f857ac33f77b4a04e9549083 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 14 Jan 2026 17:29:16 +0800
Subject: [PATCH 12/13] [Refactor] Clean up code formatting and improve
 readability in TileLang examples

* Consolidate multi-line function calls onto single lines for better
  clarity in `tilelang_buggy.py` and `tilelang_minimized_expected.py`.
* Remove unnecessary blank lines and streamline print-statement formatting
  for consistency.
* Reflow the masked-score update in
  `example_tilelang_sparse_gqa_decode_varlen_mask.py` and the
  `AtomicAddInferLayout` call in `src/op/atomic_add.cc`.
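
For example, the shared-memory helper in `examples/autodd/tilelang_buggy.py`
now passes its arguments on one line (excerpt of the formatted result, taken
from the hunks below):

    def get_shared_memory_size(self):
        return get_memory_requirements(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)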
--- examples/autodd/tilelang_buggy.py | 9 +++------ examples/autodd/tilelang_minimized_expected.py | 4 +--- .../example_tilelang_sparse_gqa_decode_varlen_mask.py | 5 ++--- src/op/atomic_add.cc | 6 +++--- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/examples/autodd/tilelang_buggy.py b/examples/autodd/tilelang_buggy.py index 47e71fe50..d2c5469bb 100644 --- a/examples/autodd/tilelang_buggy.py +++ b/examples/autodd/tilelang_buggy.py @@ -73,14 +73,10 @@ def get_grid_size(self): return grid_x, grid_y def get_shared_memory_size(self): - return get_memory_requirements( - self.M, self.N, self.K, self.block_M, self.block_N, self.block_K - ) + return get_memory_requirements(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K) def validate(self): - return validate_parameters( - self.M, self.N, self.K, self.block_M, self.block_N, self.block_K - ) + return validate_parameters(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K) def create_reference_output(a, b, activation="relu"): @@ -107,6 +103,7 @@ def benchmark_pytorch(M, N, K, num_iters=10, warmup=5): # Benchmark import time + start = time.time() for _ in range(num_iters): _ = a @ b diff --git a/examples/autodd/tilelang_minimized_expected.py b/examples/autodd/tilelang_minimized_expected.py index 2135f6fce..3dc88f992 100644 --- a/examples/autodd/tilelang_minimized_expected.py +++ b/examples/autodd/tilelang_minimized_expected.py @@ -13,7 +13,6 @@ class MatmulConfig: - def __init__(self, *args, **kwargs): self.M = 1 self.N = 1 @@ -24,7 +23,6 @@ def __init__(self, *args, **kwargs): def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs): - @T.prim_func def matmul_kernel(): with T.Kernel(): @@ -45,7 +43,7 @@ def main(*args, **kwargs): try: run_kernel(config) except Exception as e: - print(f'{e}') + print(f"{e}") main() diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index 75c6ed46d..e588ec54c 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -81,9 +81,7 @@ def main( T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] - ) + acc_s[i, j] = T.if_then_else((start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) @@ -201,6 +199,7 @@ def forward(self, query, key, value, block_mask, cache_seqlens): )(query, key, value, block_mask, cache_seqlens, glse, output_partial) return output + def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, block_size): """ Args: diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 36a302fd9..47c656f61 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -522,9 +522,9 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return {loop_layout, pred}; }; - auto ret = AtomicAddInferLayout(fused_loop, - {T.target, T.thread_bounds, T.layout_map, - analyzer, false, T.buffer_remap}); + auto ret = + AtomicAddInferLayout(fused_loop, {T.target, 
T.thread_bounds, T.layout_map,
+                                        analyzer, false, T.buffer_remap});
   Fragment loop_layout = ret.loop_layout;
 
   auto thread_loop =
       PartitionLoop(fused_loop, T.thread_var, analyzer, loop_layout);

From 8c7d3017d01d57347d92e4a181fcdcb227f75e9f Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Thu, 15 Jan 2026 18:37:17 +0800
Subject: [PATCH 13/13] [Refactor] Remove unused include in fill.cc

* Clean up fill.cc by removing the unused include of
  loop_parallel_transform_utils.h; the header itself was deleted earlier
  in this series.

---
 src/op/fill.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/op/fill.cc b/src/op/fill.cc
index 02962d242..6a1768668 100644
--- a/src/op/fill.cc
+++ b/src/op/fill.cc
@@ -13,7 +13,6 @@
 #include "../layout/tcgen05_layout.h"
 #include "../target/utils.h"
 #include "../transform/common/loop_fusion_utils.h"
-#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "../transform/loop_vectorize.h"
 #include "builtin.h"
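
A note on the `assert_close` helper renamed earlier in this series: it prints
a mismatch summary but never raises, so a regression cannot fail a scripted
run. A stricter variant is sketched below; it is illustrative only and not
part of these patches. `torch.testing.assert_close` is the upstream PyTorch
API; everything else mirrors the existing helper:

    import torch

    def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3):
        # Keep the print-based report for parity with the existing helper.
        all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
        print(f"{name} all_close={all_close}")
        if not all_close:
            diff = (expect - actual).abs()
            print(f"max={diff.max().item()}, mean={diff.mean().item()}")
            # Fail hard with a detailed element-wise report.
            torch.testing.assert_close(actual, expect, atol=atol, rtol=rtol)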