Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,15 @@ void init_triton_llvm(py::module &&m) {
.def_property_readonly(
"name", [](llvm::Function *fn) { return fn->getName().str(); })
.def("set_calling_conv", &llvm::Function::setCallingConv)
.def("add_fn_attr", [](llvm::Function *fn, std::string &name,
std::string &val) { fn->addFnAttr(name, val); })
.def(
"add_fn_attr",
[](llvm::Function *fn, std::string &name, std::string &val) {
if (val.empty())
fn->addFnAttr(name);
else
fn->addFnAttr(name, val);
},
py::arg("name"), py::arg("val") = "")
.def("remove_fn_attr", [](llvm::Function *fn,
std::string &name) { fn->removeFnAttr(name); })
.def("add_fn_asan_attr",
Expand Down
22 changes: 22 additions & 0 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ def is_consan_supported(arch):
return arch in ["gfx1250"]


def _parse_llvm_fn_attrs(attrs):
if not isinstance(attrs, str):
return tuple(attrs)
parsed = []
for attr in attrs.split(","):
name, sep, value = attr.partition("=")
name = name.strip()
if name:
parsed.append((name, value.strip() if sep else ""))
return tuple(parsed)


@dataclass(frozen=True)
class HIPOptions:
num_warps: int = 4
Expand Down Expand Up @@ -89,6 +101,11 @@ class HIPOptions:
# schedule_hint="attention,memory-bound-attention"
schedule_hint: str = 'none'

# Experimental: intended for development and debugging; may change or be removed without notice.
# Comma-separated LLVM function attributes; bare names are emitted as valueless attributes.
# Example: llvm_fn_attrs="amdgpu-sched-strategy=iterative-ilp,noinline"
llvm_fn_attrs: str | Tuple[Tuple[str, str], ...] = ""

def __post_init__(self):
gfx_major = int(self.arch[3:-2]) # Drop "gfx" prefix and minor/patch number
warp_size = 32 if gfx_major >= 10 else 64
Expand All @@ -102,6 +119,8 @@ def __post_init__(self):
)
object.__setattr__(self, 'kpack', 1)

object.__setattr__(self, 'llvm_fn_attrs', _parse_llvm_fn_attrs(self.llvm_fn_attrs))

default_libdir = Path(__file__).parent / 'lib'
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
for lib in ["ocml", "ockl"]:
Expand Down Expand Up @@ -457,6 +476,9 @@ def make_llir(src, metadata, options):
if knobs.compilation.enable_asan:
kernel_fn.add_fn_target_feature("+xnack")
kernel_fn.add_fn_asan_attr()
for name, value in options.llvm_fn_attrs:
kernel_fn.remove_fn_attr(name)
kernel_fn.add_fn_attr(name, value)

# Hint the compiler that we'd like the firmware to set the kernel arguments
# to user SGPRs so that the kernel does not need to s_load its arguments
Expand Down
58 changes: 58 additions & 0 deletions third_party/amd/python/test/test_llvm_fn_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from pathlib import Path

import pytest
import triton

from triton._internal_testing import is_hip

if not is_hip():
pytest.skip(allow_module_level=True)

ATTN_FWD_TTIR = str(Path(__file__).parent / "attn_fwd.ttir")


def test_llvm_fn_attrs_set_single_llir_attribute():
target = triton.runtime.driver.active.get_current_target()

baseline = triton.compile(ATTN_FWD_TTIR, target=target)
with_attrs = triton.compile(
ATTN_FWD_TTIR,
target=target,
options={"llvm_fn_attrs": "amdgpu-sched-strategy=iterative-ilp"},
)

assert '"amdgpu-sched-strategy"="iterative-ilp"' not in baseline.asm["llir"]
assert '"amdgpu-sched-strategy"="iterative-ilp"' in with_attrs.asm["llir"]


def test_llvm_fn_attrs_set_multiple_llir_attributes():
target = triton.runtime.driver.active.get_current_target()
llvm_fn_attrs = ",".join([
"amdgpu-sched-strategy=iterative-ilp",
"triton-test-attr=enabled",
"triton-bare-attr",
])

with_attrs = triton.compile(
ATTN_FWD_TTIR,
target=target,
options={"llvm_fn_attrs": llvm_fn_attrs},
)

assert '"amdgpu-sched-strategy"="iterative-ilp"' in with_attrs.asm["llir"]
assert '"triton-test-attr"="enabled"' in with_attrs.asm["llir"]
assert '"triton-bare-attr"' in with_attrs.asm["llir"]
assert '"triton-bare-attr"=' not in with_attrs.asm["llir"]


def test_llvm_fn_attrs_change_amdgcn():
target = triton.runtime.driver.active.get_current_target()

baseline = triton.compile(ATTN_FWD_TTIR, target=target)
with_sched_attr = triton.compile(
ATTN_FWD_TTIR,
target=target,
options={"llvm_fn_attrs": "amdgpu-sched-strategy=iterative-ilp"},
)

assert baseline.asm["amdgcn"] != with_sched_attr.asm["amdgcn"]
Loading