diff --git a/.github/actions/build-triton/action.yml b/.github/actions/build-triton/action.yml index 58d2cc5..8f56446 100644 --- a/.github/actions/build-triton/action.yml +++ b/.github/actions/build-triton/action.yml @@ -35,7 +35,7 @@ runs: SHORT_TRITON_COMMIT_HASH="${TRITON_COMMIT_HASH:0:8}" SYSINFO=$(python ci/probe-sysinfo.py) - INSTALL_DIR="triton-${SHORT_TRITON_COMMIT_HASH}-${SYSINFO}" + INSTALL_DIR="triton-${SHORT_TRITON_COMMIT_HASH}-${SYSINFO}-v2" echo "Triton installation directory: ${INSTALL_DIR}" echo "triton_install_dir=${INSTALL_DIR}" >> ${GITHUB_ENV} echo "triton_install_dir=${INSTALL_DIR}" >> ${GITHUB_OUTPUT} @@ -93,8 +93,11 @@ runs: # Stripping debug symbols reduces the artifact size by an order of magnitude. run: | make install - find build/install -type f -executable | xargs strip --strip-debug - tar czf ../"${{ env.triton_install_dir }}.tar.gz" --transform="s|^build/install|${{ env.triton_install_dir }}|" build/install + find build/install python/triton -type f \( -executable -o -name "*.so" \) -print0 | xargs -0 -r strip --strip-debug || true + tar czf ../"${{ env.triton_install_dir }}.tar.gz" \ + --transform="s|^build/install|${{ env.triton_install_dir }}|" \ + --transform="s|^python|${{ env.triton_install_dir }}/python|" \ + build/install python/triton - name: Upload Triton artifact if: steps.check-artifact.outputs.exists == 'false' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c18fa1b..fc704ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -117,5 +117,13 @@ jobs: make test LLVM_INSTALL_DIR="${{ steps.build-llvm.outputs.llvm_install_dir }}" TRITON_INSTALL_DIR="${{ steps.build-triton.outputs.triton_install_dir }}" + - name: Run plugin tests + env: + PYTHONPATH: ${{ github.workspace }}/${{ steps.build-triton.outputs.triton_install_dir }}/python + LD_LIBRARY_PATH: ${{ github.workspace }}/${{ steps.build-llvm.outputs.llvm_install_dir }}/lib:${{ github.workspace }}/${{ steps.build-triton.outputs.triton_install_dir }}/lib + run: | + source ~/.venv/bin/activate + python -m pytest testing/test_plugins.py + - name: List files run: du -d 3 -h . diff --git a/extensions/utlx/uTLXPlugin.cpp b/extensions/utlx/uTLXPlugin.cpp index 5512d59..72196bd 100644 --- a/extensions/utlx/uTLXPlugin.cpp +++ b/extensions/utlx/uTLXPlugin.cpp @@ -13,6 +13,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Tools/PluginUtils.h" +#include "triton/Version.h" // TLX dialect headers #include "tlx/dialect/include/IR/Dialect.h" @@ -803,6 +804,7 @@ TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { 1, // numDialects ops, 48, // numOps + TRITON_VERSION, }; return &info; } diff --git a/requirements.txt b/requirements.txt index 9463836..0e49ab8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ lit~=18.1 ninja~=1.13 pre_commit~=4.6 +pytest~=8.3 requests~=2.33 diff --git a/testing/test_plugins.py b/testing/test_plugins.py new file mode 100644 index 0000000..db9d9ba --- /dev/null +++ b/testing/test_plugins.py @@ -0,0 +1,185 @@ +"""Plugin integration tests. + +Auto-discovers every plugin declared by a `triton-ext.conf` and exercises +its `lib.so` from a fresh Python interpreter with `TRITON_PLUGIN_PATHS` +set. Each plugin runs in its own subprocess so failures isolate cleanly. + +Tests: + - test_plugin_loads[] -- plugin static-init: `import triton` + succeeds with the .so loaded. + - test_plugin_compiles_kernel[] -- end-to-end: JIT-decorate and lower + a basic kernel through the plugin's + pipeline. + - test__ -- plugin-specific scenarios, gated + with `@pytest.mark.skipif` on the + plugin's .so existence. + +Adding a new plugin: drop a `triton-ext.conf`; both parametrized tests pick +it up. To exempt a plugin from a parametrized test, mark it at parametrize +time with `pytest.param(..., marks=pytest.mark.skip(...))` -- see +`_COMPILE_PLUGINS` for an example. +""" + +from __future__ import annotations + +import os +import subprocess +import sys +import textwrap +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parent.parent +BUILD_DIR = Path(os.environ.get("TRITON_EXT_BUILD_DIR", REPO_ROOT / "build")) +PLUGIN_LIB_DIR = BUILD_DIR / "lib" + + +def _discover_plugins() -> list[pytest.ParameterSet]: + plugins: list[pytest.ParameterSet] = [] + for conf in REPO_ROOT.rglob("triton-ext.conf"): + rel_parts = conf.relative_to(REPO_ROOT).parts + if rel_parts[0].startswith(("triton-", "llvm-", "build")): + continue + text = conf.read_text().strip() + if not text: + continue + # Format is `name;status[;hash]` (CMake list); we only need the name. + name = text.split(";", 1)[0].strip() + if not name: + continue + plugins.append(pytest.param(name, id=name)) + plugins.sort(key=lambda p: p.id) + return plugins + + +PLUGINS = _discover_plugins() + + +def _plugin_path(name: str) -> Path: + return PLUGIN_LIB_DIR / f"lib{name}.so" + + +def _run_with_plugin(plugin_path: Path, script: str) -> subprocess.CompletedProcess: + env = {**os.environ, "TRITON_PLUGIN_PATHS": str(plugin_path)} + return subprocess.run( + [sys.executable, "-c", textwrap.dedent(script)], + env=env, + capture_output=True, + text=True, + check=False, + ) + + +# --------------------------------------------------------------------------- +# Generic per-plugin tests (auto-discovered) +# --------------------------------------------------------------------------- + + +def test_plugins_discovered() -> None: + """Guard against silently testing nothing if discovery breaks.""" + assert PLUGINS, f"No triton-ext.conf files found under {REPO_ROOT}" + + +@pytest.mark.parametrize("name", PLUGINS) +def test_plugin_loads(name: str) -> None: + """Smoke: `import triton` succeeds with the plugin loaded.""" + path = _plugin_path(name) + if not path.is_file(): + pytest.skip(f"Plugin not built at {path} (extension may be disabled)") + result = _run_with_plugin(path, "import triton") + assert result.returncode == 0, ( + f"Loading plugin {name} from {path} failed:\n" + f"--- stdout ---\n{result.stdout}\n" + f"--- stderr ---\n{result.stderr}" + ) + + +# example dialect is scaffolding-only -- its Dialect::initialize() doesn't +# register StringAttr, so kernel compile aborts with an LLVM storage-uniquer +# error. Tag it as skip at parametrize time. +_COMPILE_PLUGINS = [ + pytest.param( + p.values[0], marks=pytest.mark.skip(reason="scaffolding-only dialect"), id=p.id + ) + if p.id == "example" + else p + for p in PLUGINS +] + + +@pytest.mark.parametrize("name", _COMPILE_PLUGINS) +def test_plugin_compiles_kernel(name: str) -> None: + """User scenario: with the plugin loaded, JIT-decorate and lower a basic kernel.""" + path = _plugin_path(name) + if not path.is_file(): + pytest.skip(f"Plugin not built at {path} (extension may be disabled)") + script = """ + import sys + import triton + import triton.language as tl + + @triton.jit + def kernel(in_ptr, out_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + tl.store(out_ptr + offs, tl.load(in_ptr + offs)) + + try: + target = triton.runtime.driver.active.get_current_target() + except Exception as e: + print(f"No target ({type(e).__name__}: {e}); skipping compile.") + sys.exit(0) + src = triton.compiler.ASTSource( + fn=kernel, + signature={"in_ptr": "*fp32", "out_ptr": "*fp32"}, + constexprs={"BLOCK": 128}, + ) + triton.compile(src, target=target) + """ + result = _run_with_plugin(path, script) + assert result.returncode == 0, ( + f"Plugin {name} broke kernel compile:\n" + f"--- stdout ---\n{result.stdout}\n" + f"--- stderr ---\n{result.stderr}" + ) + + +# --------------------------------------------------------------------------- +# Plugin-specific tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _plugin_path("utlx").is_file(), reason="utlx plugin not built") +def test_utlx_registers_tlx_dsl() -> None: + """utlx registers `triton.language.extra.tlx` with local_alloc/view/store/load. + + The Python namespace is set up by `extensions/utlx/python/utlx_plugin/__init__.py` + when imported -- it inserts itself into `sys.modules` as + `triton.language.extra.tlx`. Loading the .so alone is not enough. + """ + plugin_path = _plugin_path("utlx") + utlx_python = REPO_ROOT / "extensions" / "utlx" / "python" + env = { + **os.environ, + "TRITON_PLUGIN_PATHS": str(plugin_path), + "PYTHONPATH": f"{utlx_python}{os.pathsep}{os.environ.get('PYTHONPATH', '')}", + } + script = """ + import triton # noqa: F401 + import utlx_plugin # noqa: F401 (registers triton.language.extra.tlx) + from triton.language.extra import tlx + for n in ("local_alloc", "local_view", "local_store", "local_load"): + assert hasattr(tlx, n), f"missing tlx.{n}" + """ + result = subprocess.run( + [sys.executable, "-c", textwrap.dedent(script)], + env=env, + capture_output=True, + text=True, + check=False, + ) + assert result.returncode == 0, ( + f"utlx tlx-DSL check failed:\n" + f"--- stdout ---\n{result.stdout}\n" + f"--- stderr ---\n{result.stderr}" + )