diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 84c2efbb740a..39dbd60c5c87 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -569,8 +569,15 @@ void init_triton_llvm(py::module &&m) { .def_property_readonly( "name", [](llvm::Function *fn) { return fn->getName().str(); }) .def("set_calling_conv", &llvm::Function::setCallingConv) - .def("add_fn_attr", [](llvm::Function *fn, std::string &name, - std::string &val) { fn->addFnAttr(name, val); }) + .def( + "add_fn_attr", + [](llvm::Function *fn, std::string &name, std::string &val) { + if (val.empty()) + fn->addFnAttr(name); + else + fn->addFnAttr(name, val); + }, + py::arg("name"), py::arg("val") = "") .def("remove_fn_attr", [](llvm::Function *fn, std::string &name) { fn->removeFnAttr(name); }) .def("add_fn_asan_attr", diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 599ea98f31fa..d3192e478a41 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -41,6 +41,18 @@ def is_consan_supported(arch): return arch in ["gfx1250"] +def _parse_llvm_fn_attrs(attrs): + if not isinstance(attrs, str): + return tuple(attrs) + parsed = [] + for attr in attrs.split(","): + name, sep, value = attr.partition("=") + name = name.strip() + if name: + parsed.append((name, value.strip() if sep else "")) + return tuple(parsed) + + @dataclass(frozen=True) class HIPOptions: num_warps: int = 4 @@ -89,6 +101,11 @@ class HIPOptions: # schedule_hint="attention,memory-bound-attention" schedule_hint: str = 'none' + # Experimental: intended for development and debugging; may change or be removed without notice. + # Comma-separated LLVM function attributes; bare names are emitted as valueless attributes. + # Example: llvm_fn_attrs="amdgpu-sched-strategy=iterative-ilp,noinline" + llvm_fn_attrs: str | Tuple[Tuple[str, str], ...] = "" + def __post_init__(self): gfx_major = int(self.arch[3:-2]) # Drop "gfx" prefix and minor/patch number warp_size = 32 if gfx_major >= 10 else 64 @@ -102,6 +119,8 @@ def __post_init__(self): ) object.__setattr__(self, 'kpack', 1) + object.__setattr__(self, 'llvm_fn_attrs', _parse_llvm_fn_attrs(self.llvm_fn_attrs)) + default_libdir = Path(__file__).parent / 'lib' extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) for lib in ["ocml", "ockl"]: @@ -457,6 +476,9 @@ def make_llir(src, metadata, options): if knobs.compilation.enable_asan: kernel_fn.add_fn_target_feature("+xnack") kernel_fn.add_fn_asan_attr() + for name, value in options.llvm_fn_attrs: + kernel_fn.remove_fn_attr(name) + kernel_fn.add_fn_attr(name, value) # Hint the compiler that we'd like the firmware to set the kernel arguments # to user SGPRs so that the kernel does not need to s_load its arguments diff --git a/third_party/amd/python/test/test_llvm_fn_attrs.py b/third_party/amd/python/test/test_llvm_fn_attrs.py new file mode 100644 index 000000000000..ae242ec50ea4 --- /dev/null +++ b/third_party/amd/python/test/test_llvm_fn_attrs.py @@ -0,0 +1,58 @@ +from pathlib import Path + +import pytest +import triton + +from triton._internal_testing import is_hip + +if not is_hip(): + pytest.skip(allow_module_level=True) + +ATTN_FWD_TTIR = str(Path(__file__).parent / "attn_fwd.ttir") + + +def test_llvm_fn_attrs_set_single_llir_attribute(): + target = triton.runtime.driver.active.get_current_target() + + baseline = triton.compile(ATTN_FWD_TTIR, target=target) + with_attrs = triton.compile( + ATTN_FWD_TTIR, + target=target, + options={"llvm_fn_attrs": "amdgpu-sched-strategy=iterative-ilp"}, + ) + + assert '"amdgpu-sched-strategy"="iterative-ilp"' not in baseline.asm["llir"] + assert '"amdgpu-sched-strategy"="iterative-ilp"' in with_attrs.asm["llir"] + + +def test_llvm_fn_attrs_set_multiple_llir_attributes(): + target = triton.runtime.driver.active.get_current_target() + llvm_fn_attrs = ",".join([ + "amdgpu-sched-strategy=iterative-ilp", + "triton-test-attr=enabled", + "triton-bare-attr", + ]) + + with_attrs = triton.compile( + ATTN_FWD_TTIR, + target=target, + options={"llvm_fn_attrs": llvm_fn_attrs}, + ) + + assert '"amdgpu-sched-strategy"="iterative-ilp"' in with_attrs.asm["llir"] + assert '"triton-test-attr"="enabled"' in with_attrs.asm["llir"] + assert '"triton-bare-attr"' in with_attrs.asm["llir"] + assert '"triton-bare-attr"=' not in with_attrs.asm["llir"] + + +def test_llvm_fn_attrs_change_amdgcn(): + target = triton.runtime.driver.active.get_current_target() + + baseline = triton.compile(ATTN_FWD_TTIR, target=target) + with_sched_attr = triton.compile( + ATTN_FWD_TTIR, + target=target, + options={"llvm_fn_attrs": "amdgpu-sched-strategy=iterative-ilp"}, + ) + + assert baseline.asm["amdgcn"] != with_sched_attr.asm["amdgcn"]