diff --git a/.github/actions/build-triton/action.yml b/.github/actions/build-triton/action.yml index 5301959..c800df8 100644 --- a/.github/actions/build-triton/action.yml +++ b/.github/actions/build-triton/action.yml @@ -5,8 +5,8 @@ description: | - if the artifact does not exist, check out Triton and build it, using the given LLVM installation directory; then, upload it as an artifact - if the artifact does exist, download and extract it - - return the path the Triton installation directory as well as the build directory inside of it (e.g., - `build/cmake.linux-x86_64-cpython-3.14`). + - return the path to the Triton installation directory, which contains the installable C++ tree + (bin, include, lib) alongside the importable Python package under `python/`. This action expects `sccache` to be set up by the caller. @@ -94,7 +94,11 @@ runs: run: | make install find build/install -type f -executable | xargs strip --strip-debug - tar czf ../"${{ env.triton_install_dir }}.tar.gz" --transform="s|^build/install|${{ env.triton_install_dir }}|" build/install + find python -type d -name '__pycache__' -prune -exec rm -rf {} + 2>/dev/null || true + tar czhf ../"${{ env.triton_install_dir }}.tar.gz" \ + --transform="s|^build/install|${{ env.triton_install_dir }}|" \ + --transform="s|^python|${{ env.triton_install_dir }}/python|" \ + build/install python - name: Upload Triton artifact if: steps.check-artifact.outputs.exists == 'false' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c64896..bb75636 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -113,6 +113,8 @@ jobs: EXTRA_CMAKE_ARGS="-DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache" - name: Run tests + env: + PYTHONPATH: ${{ github.workspace }}/${{ steps.build-triton.outputs.triton_install_dir }}/python run: > make test LLVM_INSTALL_DIR="${{ steps.build-llvm.outputs.llvm_install_dir }}" TRITON_INSTALL_DIR="${{ steps.build-triton.outputs.triton_install_dir }}" diff --git a/Makefile b/Makefile index 5d79d40..b08bc55 100644 --- a/Makefile +++ b/Makefile @@ -20,9 +20,16 @@ build: configure cmake --build ${BUILD_DIR} .PHONY: test -test: +test: test-lit test-unit + +.PHONY: test-lit +test-lit: ninja -C ${BUILD_DIR} check-lit-tests +.PHONY: test-unit +test-unit: + TRITON_EXT_BUILD_DIR="${BUILD_DIR}" python -m pytest testing/ -v + .PHONY: clean clean: rm -rf ${BUILD_DIR} diff --git a/ci/triton-hash.txt b/ci/triton-hash.txt index e5c32b1..5c0215a 100644 --- a/ci/triton-hash.txt +++ b/ci/triton-hash.txt @@ -1 +1 @@ -4768da5e8228dfbda8e0b7a61101f87d953341bd +7a5d6a3dec31f865d0e6a6ce751fed1cc10a5f5a diff --git a/requirements.txt b/requirements.txt index 9463836..0e49ab8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ lit~=18.1 ninja~=1.13 pre_commit~=4.6 +pytest~=8.3 requests~=2.33 diff --git a/testing/scripts/compile_kernel.py b/testing/scripts/compile_kernel.py new file mode 100755 index 0000000..e3cb220 --- /dev/null +++ b/testing/scripts/compile_kernel.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +"""Lower a basic kernel through whichever plugin is loaded via TRITON_PLUGIN_PATHS. + +Run manually to debug a plugin's compile pipeline: + + TRITON_PLUGIN_PATHS=/path/to/lib.so python testing/scripts/compile_kernel.py + +Exits 0 on success (or if no target is available, e.g. no GPU present). +""" + +import sys + +import triton +import triton.language as tl + + +@triton.jit +def kernel(in_ptr, out_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + tl.store(out_ptr + offs, tl.load(in_ptr + offs)) + + +def main() -> int: + try: + target = triton.runtime.driver.active.get_current_target() + except Exception as e: + print(f"No target ({type(e).__name__}: {e}); skipping compile.") + return 0 + src = triton.compiler.ASTSource( + fn=kernel, + signature={ + "in_ptr": "*fp32", + "out_ptr": "*fp32" + }, + constexprs={"BLOCK": 128}, + ) + triton.compile(src, target=target) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/testing/scripts/load_tlx_dsl.py b/testing/scripts/load_tlx_dsl.py new file mode 100755 index 0000000..6810264 --- /dev/null +++ b/testing/scripts/load_tlx_dsl.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +"""Check that utlx registers the `triton.language.extra.tlx` DSL. + +Loading the utlx .so alone is not enough: importing `utlx_plugin` is what +inserts itself into `sys.modules` as `triton.language.extra.tlx`. Run manually: + + TRITON_PLUGIN_PATHS=/path/to/libutlx.so \ + PYTHONPATH=extensions/utlx/python \ + python testing/scripts/load_tlx_dsl.py +""" + +import sys + +import triton # noqa: F401 +import utlx_plugin # noqa: F401 (registers triton.language.extra.tlx) +from triton.language.extra import tlx + + +def main() -> int: + for n in ("local_alloc", "local_view", "local_store", "local_load"): + assert hasattr(tlx, n), f"missing tlx.{n}" + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/testing/test_plugins.py b/testing/test_plugins.py new file mode 100644 index 0000000..152757b --- /dev/null +++ b/testing/test_plugins.py @@ -0,0 +1,168 @@ +"""Plugin integration tests. + +Auto-discovers every plugin declared by a `triton-ext.conf` and exercises +its `lib.so` from a fresh Python interpreter with `TRITON_PLUGIN_PATHS` +set. Each plugin runs in its own subprocess so failures isolate cleanly. + +Tests: + - test_plugin_loads[] -- plugin static-init: `import triton` + succeeds with the .so loaded. + - test_plugin_compiles_kernel[] -- end-to-end: JIT-decorate and lower + a basic kernel through the plugin's + pipeline. + - test__ -- plugin-specific scenarios, gated + with `@pytest.mark.skipif` on the + plugin's .so existence. + +Adding a new plugin: drop a `triton-ext.conf`; both parametrized tests pick +it up. To exempt a plugin from a parametrized test, mark it at parametrize +time with `pytest.param(..., marks=pytest.mark.skip(...))` -- see +`_COMPILE_PLUGINS` for an example. + +The kernel-compile and tlx-DSL scenarios live as standalone scripts under +`testing/scripts/` so they can be run by hand to debug a plugin, e.g. +`TRITON_PLUGIN_PATHS=build/lib/lib.so python testing/scripts/compile_kernel.py`. +On failure each test prints the exact command to reproduce it. +""" + +from __future__ import annotations + +import os +import shlex +import subprocess +import sys +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parent.parent +BUILD_DIR = Path(os.environ.get("TRITON_EXT_BUILD_DIR", REPO_ROOT / "build")) +PLUGIN_LIB_DIR = BUILD_DIR / "lib" +SCRIPTS_DIR = Path(__file__).resolve().parent / "scripts" + + +def _discover_plugins() -> list[pytest.ParameterSet]: + plugins: list[pytest.ParameterSet] = [] + for conf in REPO_ROOT.rglob("triton-ext.conf"): + rel_parts = conf.relative_to(REPO_ROOT).parts + if rel_parts[0].startswith(("triton-", "llvm-", "build")): + continue + text = conf.read_text().strip() + if not text: + continue + # Format is `name;status[;hash]` (CMake list); we only need the name. + name = text.split(";", 1)[0].strip() + if not name: + continue + plugins.append(pytest.param(name, id=name)) + plugins.sort(key=lambda p: p.id) + return plugins + + +PLUGINS = _discover_plugins() + + +def _plugin_path(name: str) -> Path: + return PLUGIN_LIB_DIR / f"lib{name}.so" + + +def _format_command(env_overrides: dict[str, str], args: list[str]) -> str: + """Render a copy-pasteable shell command for manual debugging.""" + prefix = " ".join(f"{k}={shlex.quote(v)}" + for k, v in env_overrides.items()) + cmd = " ".join(shlex.quote(a) for a in args) + return f"{prefix} {cmd}".strip() + + +def _run(env_overrides: dict[str, str], + args: list[str]) -> tuple[subprocess.CompletedProcess, str]: + """Run a subprocess and return it along with its debug command string.""" + env = {**os.environ, **env_overrides} + command = _format_command(env_overrides, args) + result = subprocess.run( + args, + env=env, + capture_output=True, + text=True, + check=False, + ) + return result, command + + +# --------------------------------------------------------------------------- +# Generic per-plugin tests (auto-discovered) +# --------------------------------------------------------------------------- + + +def test_plugins_discovered() -> None: + """Guard against silently testing nothing if discovery breaks.""" + assert PLUGINS, f"No triton-ext.conf files found under {REPO_ROOT}" + + +@pytest.mark.parametrize("name", PLUGINS) +def test_plugin_loads(name: str) -> None: + """Smoke: `import triton` succeeds with the plugin loaded.""" + path = _plugin_path(name) + if not path.is_file(): + pytest.skip(f"Plugin not built at {path} (extension may be disabled)") + result, command = _run({"TRITON_PLUGIN_PATHS": str(path)}, + [sys.executable, "-c", "import triton"]) + assert result.returncode == 0, ( + f"Loading plugin {name} failed. Reproduce with:\n {command}\n" + f"--- stdout ---\n{result.stdout}\n" + f"--- stderr ---\n{result.stderr}") + + +# example dialect is scaffolding-only -- its Dialect::initialize() doesn't +# register StringAttr, so kernel compile aborts with an LLVM storage-uniquer +# error. Tag it as skip at parametrize time. +_COMPILE_PLUGINS = [ + pytest.param(p.values[0], + marks=pytest.mark.skip(reason="scaffolding-only dialect"), + id=p.id) if p.id == "example" else p for p in PLUGINS +] + + +@pytest.mark.parametrize("name", _COMPILE_PLUGINS) +def test_plugin_compiles_kernel(name: str) -> None: + """User scenario: with the plugin loaded, JIT-decorate and lower a basic kernel.""" + path = _plugin_path(name) + if not path.is_file(): + pytest.skip(f"Plugin not built at {path} (extension may be disabled)") + result, command = _run( + {"TRITON_PLUGIN_PATHS": str(path)}, + [sys.executable, + str(SCRIPTS_DIR / "compile_kernel.py")]) + assert result.returncode == 0, ( + f"Plugin {name} broke kernel compile. Reproduce with:\n {command}\n" + f"--- stdout ---\n{result.stdout}\n" + f"--- stderr ---\n{result.stderr}") + + +# --------------------------------------------------------------------------- +# Plugin-specific tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _plugin_path("utlx").is_file(), + reason="utlx plugin not built") +def test_utlx_registers_tlx_dsl() -> None: + """utlx registers `triton.language.extra.tlx` with local_alloc/view/store/load. + + The Python namespace is set up by `extensions/utlx/python/utlx_plugin/__init__.py` + when imported -- it inserts itself into `sys.modules` as + `triton.language.extra.tlx`. Loading the .so alone is not enough. + """ + plugin_path = _plugin_path("utlx") + utlx_python = REPO_ROOT / "extensions" / "utlx" / "python" + pythonpath = f"{utlx_python}{os.pathsep}{os.environ.get('PYTHONPATH', '')}" + result, command = _run( + { + "TRITON_PLUGIN_PATHS": str(plugin_path), + "PYTHONPATH": pythonpath, + }, + [sys.executable, str(SCRIPTS_DIR / "load_tlx_dsl.py")]) + assert result.returncode == 0, ( + f"utlx tlx-DSL check failed. Reproduce with:\n {command}\n" + f"--- stdout ---\n{result.stdout}\n" + f"--- stderr ---\n{result.stderr}")