Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions pkgs/development/cuda-modules/write-gpu-python-test.nix

This file was deleted.

66 changes: 66 additions & 0 deletions pkgs/development/cuda-modules/write-gpu-test-python.nix
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I switched from writers.writePython3Bin to a simple runCommand because it was easier to override python3Packages this way. Either way I'll keep editing this thing

Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
lib,
runCommand,
python3Packages,
makeWrapper,
}:
{
feature ? "cuda",
name ? if feature == null then "cpu" else feature,
libraries ? [ ], # [PythonPackage] | (PackageSet -> [PythonPackage])
...
}@args:

let
inherit (builtins) isFunction all;
librariesFun = if isFunction libraries then libraries else (_: libraries);
in

assert lib.assertMsg (
isFunction libraries || all (python3Packages.hasPythonModule) libraries
) "writeGpuTestPython was passed `libraries` from the wrong python release";

content:

let
interpreter = python3Packages.python.withPackages librariesFun;
tester =
runCommand "tester-${name}"
(
lib.removeAttrs args [
"libraries"
"name"
]
// {
inherit content;
nativeBuildInputs = args.nativeBuildInputs or [ ] ++ [ makeWrapper ];
passAsFile = args.passAsFile or [ ] ++ [ "content" ];
}
)
''
mkdir -p "$out"/bin
cat << EOF >"$out/bin/$name"
#!${lib.getExe interpreter}
EOF
cat "$contentPath" >>"$out/bin/$name"
chmod +x "$out/bin/$name"

if [[ -n "''${makeWrapperArgs+''${makeWrapperArgs[@]}}" ]] ; then
wrapProgram "$out/bin/$name" ''${makeWrapperArgs[@]}
fi
'';
tester' = tester.overrideAttrs (oldAttrs: {
passthru.gpuCheck =
runCommand "test-${name}"
{
nativeBuildInputs = [ tester' ];
requiredSystemFeatures = lib.optionals (feature != null) [ feature ];
}
''
set -e
${tester.meta.mainProgram or (lib.getName tester')}
touch $out
'';
});
in
tester'
7 changes: 4 additions & 3 deletions pkgs/development/python-modules/torch/mk-runtime-check.nix
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
{
cudaPackages,
feature,
torch,
libraries,
versionAttr,
pythonPackages,
}:

cudaPackages.writeGpuTestPython
(cudaPackages.writeGpuTestPython.override { python3Packages = pythonPackages; })
{
inherit feature;
libraries = [ torch ];
inherit libraries;
name = "${feature}Available";
}
''
Expand Down
38 changes: 38 additions & 0 deletions pkgs/development/python-modules/torch/mk-torch-compile-check.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
cudaPackages,
feature ? null,
lib,
libraries,
name ? if feature == null then "torch-compile-cpu" else "torch-compile-${feature}",
pythonPackages,
stdenv,
}:
let
deviceStr = if feature == null then "" else '', device="cuda"'';
in
(cudaPackages.writeGpuTestPython.override { python3Packages = pythonPackages; })
{
inherit name feature libraries;
makeWrapperArgs = [
"--suffix"
"PATH"
":"
"${lib.getBin stdenv.cc}/bin"
];
}
''
import torch


@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b


print(
opt_foo2(
torch.randn(10, 10${deviceStr}),
torch.randn(10, 10${deviceStr})))
''
26 changes: 18 additions & 8 deletions pkgs/development/python-modules/torch/tests.nix
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
{
callPackage,
torchWithCuda,
torchWithRocm,
}:
{ callPackage }:

{
rec {
# To perform the runtime check use either
# `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or
# `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox)
tester-cudaAvailable = callPackage ./mk-runtime-check.nix {
feature = "cuda";
versionAttr = "cuda";
torch = torchWithCuda;
libraries = ps: [ ps.torchWithCuda ];
};
tester-rocmAvailable = callPackage ./mk-runtime-check.nix {
feature = "rocm";
versionAttr = "hip";
torch = torchWithRocm;
libraries = ps: [ ps.torchWithRocm ];
};

compileCpu = tester-compileCpu.gpuCheck;
tester-compileCpu = callPackage ./mk-torch-compile-check.nix {
feature = null;
libraries = ps: [ ps.torch ];
};
tester-compileCuda = callPackage ./mk-torch-compile-check.nix {
feature = "cuda";
libraries = ps: [ ps.torchWithCuda ];
};
tester-compileRocm = callPackage ./mk-torch-compile-check.nix {
feature = "rocm";
libraries = ps: [ ps.torchWithRocm ];
};
}
2 changes: 1 addition & 1 deletion pkgs/top-level/cuda-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ let
nccl = final.callPackage ../development/cuda-modules/nccl { };
nccl-tests = final.callPackage ../development/cuda-modules/nccl-tests { };

writeGpuTestPython = final.callPackage ../development/cuda-modules/write-gpu-python-test.nix { };
writeGpuTestPython = final.callPackage ../development/cuda-modules/write-gpu-test-python.nix { };
});

mkVersionedPackageName =
Expand Down