Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions pkgs/by-name/tr/triton-llvm/package.nix
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
lib,
stdenv,
fetchFromGitHub,
fetchpatch,
pkgsBuildBuild,
pkg-config,
cmake,
Expand Down Expand Up @@ -65,7 +64,7 @@ let
in
stdenv.mkDerivation (finalAttrs: {
pname = "triton-llvm";
version = "19.1.0-rc1"; # One of the tags at https://github.com/llvm/llvm-project/commit/10dc3a8e916d73291269e5e2b82dd22681489aa1
version = "21.0.0-git"; # See https://github.com/llvm/llvm-project/blob/main/cmake/Modules/LLVMVersion.cmake

outputs =
[
Expand All @@ -78,23 +77,13 @@ stdenv.mkDerivation (finalAttrs: {
"man"
];

# See https://github.com/triton-lang/triton/blob/main/python/setup.py
# and https://github.com/ptillet/triton-llvm-releases/releases
# See https://github.com/triton-lang/triton/blob/main/cmake/llvm-hash.txt
src = fetchFromGitHub {
owner = "llvm";
repo = "llvm-project";
rev = "10dc3a8e916d73291269e5e2b82dd22681489aa1";
hash = "sha256-9DPvcFmhzw6MipQeCQnr35LktW0uxtEL8axMMPXIfWw=";
rev = "a66376b0dc3b2ea8a84fda26faca287980986f78";
hash = "sha256-7xUPozRerxt38UeJxA8kYYxOQ4+WzDREndD2+K0BYkU=";
};
patches = [
# glibc-2.40 support
# [llvm-exegesis] Use correct rseq struct size #100804
# https://github.com/llvm/llvm-project/issues/100791
(fetchpatch {
url = "https://github.com/llvm/llvm-project//commit/84837e3cc1cf17ed71580e3ea38299ed2bfaa5f6.patch";
hash = "sha256-QKa+kyXjjGXwTQTEpmKZx5yYjOyBX8A8NQoIYUaGcIw=";
})
];

nativeBuildInputs =
[
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
From 2751c5de5c61c90b56e3e392a41847f4c47258fd Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:16:48 +0000
Subject: [PATCH 1/3] _build: allow extra cc flags

---
python/triton/runtime/build.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
index d7baeb286..d334dce77 100644
index 1b76548d4..2756dccdb 100644
--- a/python/triton/runtime/build.py
+++ b/python/triton/runtime/build.py
@@ -42,9 +42,17 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
include_dirs = include_dirs + [srcdir, py_include_dir]
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
@@ -33,5 +33,13 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
+
+ # Nixpkgs support branch
+ # Allows passing e.g. extra -Wl,-rpath
Expand All @@ -23,13 +14,5 @@ index d7baeb286..d334dce77 100644
+ import shlex
+ cc_cmd.extend(shlex.split(cc_cmd_extra_flags))
+
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
- cc_cmd += [f"-I{dir}" for dir in include_dirs]
+ cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
--
2.46.0

subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
return so

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,61 +1,22 @@
From 7407cb03eec82768e333909d87b7668b633bfe86 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:28:48 +0000
Subject: [PATCH 2/3] {nvidia,amd}/driver: short-circuit before ldconfig

---
python/triton/runtime/build.py | 6 +++---
third_party/amd/backend/driver.py | 7 +++++++
third_party/nvidia/backend/driver.py | 3 +++
3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
index d334dce77..a64e98da0 100644
--- a/python/triton/runtime/build.py
+++ b/python/triton/runtime/build.py
@@ -42,6 +42,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
include_dirs = include_dirs + [srcdir, py_include_dir]
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
+ cc_cmd += [f'-l{lib}' for lib in libraries]
+ cc_cmd += [f"-L{dir}" for dir in library_dirs]
+ cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]

# Nixpkgs support branch
# Allows passing e.g. extra -Wl,-rpath
@@ -50,9 +53,6 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
import shlex
cc_cmd.extend(shlex.split(cc_cmd_extra_flags))

- cc_cmd += [f'-l{lib}' for lib in libraries]
- cc_cmd += [f"-L{dir}" for dir in library_dirs]
- cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
index 0a8cd7bed..aab8805f6 100644
index ca712f904..0961d2dda 100644
--- a/third_party/amd/backend/driver.py
+++ b/third_party/amd/backend/driver.py
@@ -24,6 +24,13 @@ def _get_path_to_hip_runtime_dylib():
return env_libhip_path
raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
@@ -79,6 +79,9 @@ def _get_path_to_hip_runtime_dylib():
return mmapped_path
raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")

+ # ...on release/3.1.x:
+ # return mmapped_path
+ # raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")
+
+ if os.path.isdir("@libhipDir@"):
+ return ["@libhipDir@"]
+
paths = []

import site
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 90f71138b..30fbadb2a 100644
index d088ec092..625de2db8 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -21,6 +21,9 @@ def libcuda_dirs():
@@ -23,6 +23,9 @@ def libcuda_dirs():
if env_libcuda_path:
return [env_libcuda_path]

Expand All @@ -65,6 +26,3 @@ index 90f71138b..30fbadb2a 100644
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
# each line looks like the following:
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
--
2.46.0

Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
From e503e572b6d444cd27f1cdf124aaf553aa3a8665 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Mon, 14 Oct 2024 00:12:05 +0000
Subject: [PATCH 4/4] nvidia: allow static ptxas path

---
third_party/nvidia/backend/compiler.py | 3 +++
1 file changed, 3 insertions(+)

diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 6d7994923..6720e8f97 100644
index 960334744..269e22e6e 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -20,6 +20,9 @@ def _path_to_binary(binary: str):
@@ -38,6 +38,9 @@ def _path_to_binary(binary: str):
os.path.join(os.path.dirname(__file__), "bin", binary),
]

+ import shlex
+ paths.extend(shlex.split("@nixpkgsExtraBinaryPaths@"))
+
for bin in paths:
if os.path.exists(bin) and os.path.isfile(bin):
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
--
2.46.0

for path in paths:
if os.path.exists(path) and os.path.isfile(path):
result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
45 changes: 19 additions & 26 deletions pkgs/development/python-modules/triton/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
pybind11,
python,
pytestCheckHook,
writableTmpDirAsHomeHook,
stdenv,
replaceVars,
setuptools,
Expand All @@ -27,22 +28,21 @@
triton,
}:

buildPythonPackage {
buildPythonPackage rec {
pname = "triton";
version = "3.1.0";
version = "3.3.1";
pyproject = true;

# Remember to bump triton-llvm as well!
src = fetchFromGitHub {
owner = "triton-lang";
repo = "triton";
# latest branch commit from https://github.com/triton-lang/triton/commits/release/3.1.x/
rev = "cf34004b8a67d290a962da166f5aa2fc66751326";
hash = "sha256-233fpuR7XXOaSKN+slhJbE/CMFzAqCRCE4V4rIoJZrk=";
tag = "v${version}";
hash = "sha256-XLw7s5K0j4mfIvNMumlHkUpklSzVSTRyfGazZ4lLpn0=";
};

patches =
[
./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
(replaceVars ./0001-_build-allow-extra-cc-flags.patch {
ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
})
Expand All @@ -65,24 +65,15 @@ buildPythonPackage {
# Use our `cmakeFlags` instead and avoid downloading dependencies
# remove any downloads
substituteInPlace python/setup.py \
--replace-fail "get_json_package_info(), get_pybind11_package_info()" ""\
--replace-fail "get_pybind11_package_info(), get_llvm_package_info()" ""\
--replace-fail 'packages += ["triton/profiler"]' ""\
--replace-fail "curr_version != version" "False"
--replace-fail "[get_json_package_info()]" "[]"\
--replace-fail "[get_llvm_package_info()]" "[]"\
--replace-fail 'packages += ["triton/profiler"]' "pass"\
--replace-fail "curr_version.group(1) != version" "False"

# Don't fetch googletest
substituteInPlace unittest/CMakeLists.txt \
--replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
substituteInPlace cmake/AddTritonUnitTest.cmake \
--replace-fail "include(\''${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)" ""\
--replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"

# Patch the source code to make sure it doesn't specify a non-existent PTXAS version.
# CUDA 12.6 (the current default/max) tops out at PTXAS version 8.5.
# NOTE: This is fixed in `master`:
# https://github.com/triton-lang/triton/commit/f48dbc1b106c93144c198fbf3c4f30b2aab9d242
substituteInPlace "$NIX_BUILD_TOP/$sourceRoot/third_party/nvidia/backend/compiler.py" \
--replace-fail \
'return 80 + minor' \
'return 80 + min(minor, 5)'
'';

build-system = [ setuptools ];
Expand All @@ -97,6 +88,9 @@ buildPythonPackage {
# because we only support cudaPackages on x86_64-linux atm
lit
llvm

# Upstream's setup.py tries to write cache somewhere in ~/
writableTmpDirAsHomeHook
];

buildInputs = [
Expand Down Expand Up @@ -125,9 +119,6 @@ buildPythonPackage {
# Ensure that the build process uses the requested number of cores
export MAX_JOBS="$NIX_BUILD_CORES"

# Upstream's setup.py tries to write cache somewhere in ~/
export HOME=$(mktemp -d)

# Upstream's github actions patch setup.cfg to write base-dir. May be redundant
echo "
[build_ext]
Expand Down Expand Up @@ -193,13 +184,15 @@ buildPythonPackage {
];

dontBuild = true;
nativeCheckInputs = [ pytestCheckHook ];
nativeCheckInputs = [
pytestCheckHook
writableTmpDirAsHomeHook
];

doCheck = true;

preCheck = ''
cd python/test/unit
export HOME=$TMPDIR
'';
checkPhase = "pytestCheckPhase";

Expand Down