Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions .github/actions/build-triton/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ description: |
- if the artifact does not exist, check out Triton and build it, using the given LLVM installation directory; then,
upload it as an artifact
- if the artifact does exist, download and extract it
- return the path the Triton installation directory as well as the build directory inside of it (e.g.,
`build/cmake.linux-x86_64-cpython-3.14`).
- return the path to the Triton installation directory, which contains the installable C++ tree
(bin, include, lib) alongside the importable Python package under `python/`.

This action expects `sccache` to be set up by the caller.

Expand Down Expand Up @@ -94,7 +94,11 @@ runs:
run: |
make install
find build/install -type f -executable | xargs strip --strip-debug
tar czf ../"${{ env.triton_install_dir }}.tar.gz" --transform="s|^build/install|${{ env.triton_install_dir }}|" build/install
find python -type d -name '__pycache__' -prune -exec rm -rf {} + 2>/dev/null || true
tar czhf ../"${{ env.triton_install_dir }}.tar.gz" \
--transform="s|^build/install|${{ env.triton_install_dir }}|" \
--transform="s|^python|${{ env.triton_install_dir }}/python|" \
build/install python

- name: Upload Triton artifact
if: steps.check-artifact.outputs.exists == 'false'
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ jobs:
EXTRA_CMAKE_ARGS="-DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache"

- name: Run tests
env:
PYTHONPATH: ${{ github.workspace }}/${{ steps.build-triton.outputs.triton_install_dir }}/python
run: >
make test LLVM_INSTALL_DIR="${{ steps.build-llvm.outputs.llvm_install_dir }}"
TRITON_INSTALL_DIR="${{ steps.build-triton.outputs.triton_install_dir }}"
Expand Down
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,16 @@ build: configure
cmake --build ${BUILD_DIR}

.PHONY: test
test:
test: test-lit test-unit

.PHONY: test-lit
test-lit:
ninja -C ${BUILD_DIR} check-lit-tests

.PHONY: test-unit
test-unit:
TRITON_EXT_BUILD_DIR="${BUILD_DIR}" python -m pytest testing/ -v

.PHONY: clean
clean:
rm -rf ${BUILD_DIR}
Expand Down
2 changes: 1 addition & 1 deletion ci/triton-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4768da5e8228dfbda8e0b7a61101f87d953341bd
7a5d6a3dec31f865d0e6a6ce751fed1cc10a5f5a
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
lit~=18.1
ninja~=1.13
pre_commit~=4.6
pytest~=8.3
requests~=2.33
42 changes: 42 additions & 0 deletions testing/scripts/compile_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
"""Lower a basic kernel through whichever plugin is loaded via TRITON_PLUGIN_PATHS.

Run manually to debug a plugin's compile pipeline:

TRITON_PLUGIN_PATHS=/path/to/lib<name>.so python testing/scripts/compile_kernel.py

Exits 0 on success (or if no target is available, e.g. no GPU present).
"""

import sys

import triton
import triton.language as tl


@triton.jit
def kernel(in_ptr, out_ptr, BLOCK: tl.constexpr):
offs = tl.arange(0, BLOCK)
tl.store(out_ptr + offs, tl.load(in_ptr + offs))


def main() -> int:
try:
target = triton.runtime.driver.active.get_current_target()
except Exception as e:
print(f"No target ({type(e).__name__}: {e}); skipping compile.")
return 0
src = triton.compiler.ASTSource(
fn=kernel,
signature={
"in_ptr": "*fp32",
"out_ptr": "*fp32"
},
constexprs={"BLOCK": 128},
)
triton.compile(src, target=target)
return 0


if __name__ == "__main__":
sys.exit(main())
26 changes: 26 additions & 0 deletions testing/scripts/load_tlx_dsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
"""Check that utlx registers the `triton.language.extra.tlx` DSL.

Loading the utlx .so alone is not enough: importing `utlx_plugin` is what
inserts itself into `sys.modules` as `triton.language.extra.tlx`. Run manually:

TRITON_PLUGIN_PATHS=/path/to/libutlx.so \
PYTHONPATH=extensions/utlx/python \
python testing/scripts/load_tlx_dsl.py
"""

import sys

import triton # noqa: F401
import utlx_plugin # noqa: F401 (registers triton.language.extra.tlx)
from triton.language.extra import tlx


def main() -> int:
for n in ("local_alloc", "local_view", "local_store", "local_load"):
assert hasattr(tlx, n), f"missing tlx.{n}"
return 0


if __name__ == "__main__":
sys.exit(main())
168 changes: 168 additions & 0 deletions testing/test_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""Plugin integration tests.

Auto-discovers every plugin declared by a `triton-ext.conf` and exercises
its `lib<name>.so` from a fresh Python interpreter with `TRITON_PLUGIN_PATHS`
set. Each plugin runs in its own subprocess so failures isolate cleanly.

Tests:
- test_plugin_loads[<name>] -- plugin static-init: `import triton`
succeeds with the .so loaded.
- test_plugin_compiles_kernel[<name>] -- end-to-end: JIT-decorate and lower
a basic kernel through the plugin's
pipeline.
- test_<plugin>_<feature> -- plugin-specific scenarios, gated
with `@pytest.mark.skipif` on the
plugin's .so existence.

Adding a new plugin: drop a `triton-ext.conf`; both parametrized tests pick
it up. To exempt a plugin from a parametrized test, mark it at parametrize
time with `pytest.param(..., marks=pytest.mark.skip(...))` -- see
`_COMPILE_PLUGINS` for an example.

The kernel-compile and tlx-DSL scenarios live as standalone scripts under
`testing/scripts/` so they can be run by hand to debug a plugin, e.g.
`TRITON_PLUGIN_PATHS=build/lib/lib<name>.so python testing/scripts/compile_kernel.py`.
On failure each test prints the exact command to reproduce it.
"""

from __future__ import annotations

import os
import shlex
import subprocess
import sys
from pathlib import Path

import pytest

REPO_ROOT = Path(__file__).resolve().parent.parent
BUILD_DIR = Path(os.environ.get("TRITON_EXT_BUILD_DIR", REPO_ROOT / "build"))
PLUGIN_LIB_DIR = BUILD_DIR / "lib"
SCRIPTS_DIR = Path(__file__).resolve().parent / "scripts"


def _discover_plugins() -> list[pytest.ParameterSet]:
plugins: list[pytest.ParameterSet] = []
for conf in REPO_ROOT.rglob("triton-ext.conf"):
rel_parts = conf.relative_to(REPO_ROOT).parts
if rel_parts[0].startswith(("triton-", "llvm-", "build")):
continue
text = conf.read_text().strip()
if not text:
continue
# Format is `name;status[;hash]` (CMake list); we only need the name.
name = text.split(";", 1)[0].strip()
if not name:
continue
plugins.append(pytest.param(name, id=name))
plugins.sort(key=lambda p: p.id)
return plugins


PLUGINS = _discover_plugins()


def _plugin_path(name: str) -> Path:
return PLUGIN_LIB_DIR / f"lib{name}.so"


def _format_command(env_overrides: dict[str, str], args: list[str]) -> str:
"""Render a copy-pasteable shell command for manual debugging."""
prefix = " ".join(f"{k}={shlex.quote(v)}"
for k, v in env_overrides.items())
cmd = " ".join(shlex.quote(a) for a in args)
return f"{prefix} {cmd}".strip()


def _run(env_overrides: dict[str, str],
args: list[str]) -> tuple[subprocess.CompletedProcess, str]:
"""Run a subprocess and return it along with its debug command string."""
env = {**os.environ, **env_overrides}
command = _format_command(env_overrides, args)
result = subprocess.run(
args,
env=env,
capture_output=True,
text=True,
check=False,
)
return result, command


# ---------------------------------------------------------------------------
# Generic per-plugin tests (auto-discovered)
# ---------------------------------------------------------------------------


def test_plugins_discovered() -> None:
"""Guard against silently testing nothing if discovery breaks."""
assert PLUGINS, f"No triton-ext.conf files found under {REPO_ROOT}"


@pytest.mark.parametrize("name", PLUGINS)
def test_plugin_loads(name: str) -> None:
"""Smoke: `import triton` succeeds with the plugin loaded."""
path = _plugin_path(name)
if not path.is_file():
pytest.skip(f"Plugin not built at {path} (extension may be disabled)")
result, command = _run({"TRITON_PLUGIN_PATHS": str(path)},
[sys.executable, "-c", "import triton"])
assert result.returncode == 0, (
f"Loading plugin {name} failed. Reproduce with:\n {command}\n"
f"--- stdout ---\n{result.stdout}\n"
f"--- stderr ---\n{result.stderr}")


# example dialect is scaffolding-only -- its Dialect::initialize() doesn't
# register StringAttr, so kernel compile aborts with an LLVM storage-uniquer
# error. Tag it as skip at parametrize time.
_COMPILE_PLUGINS = [
pytest.param(p.values[0],
marks=pytest.mark.skip(reason="scaffolding-only dialect"),
id=p.id) if p.id == "example" else p for p in PLUGINS
]


@pytest.mark.parametrize("name", _COMPILE_PLUGINS)
def test_plugin_compiles_kernel(name: str) -> None:
"""User scenario: with the plugin loaded, JIT-decorate and lower a basic kernel."""
path = _plugin_path(name)
if not path.is_file():
pytest.skip(f"Plugin not built at {path} (extension may be disabled)")
result, command = _run(
{"TRITON_PLUGIN_PATHS": str(path)},
[sys.executable,
str(SCRIPTS_DIR / "compile_kernel.py")])
assert result.returncode == 0, (
f"Plugin {name} broke kernel compile. Reproduce with:\n {command}\n"
f"--- stdout ---\n{result.stdout}\n"
f"--- stderr ---\n{result.stderr}")


# ---------------------------------------------------------------------------
# Plugin-specific tests
# ---------------------------------------------------------------------------


@pytest.mark.skipif(not _plugin_path("utlx").is_file(),
reason="utlx plugin not built")
def test_utlx_registers_tlx_dsl() -> None:
"""utlx registers `triton.language.extra.tlx` with local_alloc/view/store/load.

The Python namespace is set up by `extensions/utlx/python/utlx_plugin/__init__.py`
when imported -- it inserts itself into `sys.modules` as
`triton.language.extra.tlx`. Loading the .so alone is not enough.
"""
plugin_path = _plugin_path("utlx")
utlx_python = REPO_ROOT / "extensions" / "utlx" / "python"
pythonpath = f"{utlx_python}{os.pathsep}{os.environ.get('PYTHONPATH', '')}"
result, command = _run(
{
"TRITON_PLUGIN_PATHS": str(plugin_path),
"PYTHONPATH": pythonpath,
},
[sys.executable, str(SCRIPTS_DIR / "load_tlx_dsl.py")])
assert result.returncode == 0, (
f"utlx tlx-DSL check failed. Reproduce with:\n {command}\n"
f"--- stdout ---\n{result.stdout}\n"
f"--- stderr ---\n{result.stderr}")
Loading