diff --git a/pkgs/development/cuda-modules/write-gpu-python-test.nix b/pkgs/development/cuda-modules/write-gpu-python-test.nix deleted file mode 100644 index 5f0d5c6b8fe68..0000000000000 --- a/pkgs/development/cuda-modules/write-gpu-python-test.nix +++ /dev/null @@ -1,29 +0,0 @@ -{ - lib, - writers, - runCommand, -}: -{ - feature ? "cuda", - name ? feature, - libraries ? [ ], -}: -content: - -let - tester = writers.writePython3Bin "tester-${name}" { inherit libraries; } content; - tester' = tester.overrideAttrs (oldAttrs: { - passthru.gpuCheck = - runCommand "test-${name}" - { - nativeBuildInputs = [ tester' ]; - requiredSystemFeatures = [ feature ]; - } - '' - set -e - ${tester.meta.mainProgram or (lib.getName tester')} - touch $out - ''; - }); -in -tester' diff --git a/pkgs/development/cuda-modules/write-gpu-test-python.nix b/pkgs/development/cuda-modules/write-gpu-test-python.nix new file mode 100644 index 0000000000000..23282e9169291 --- /dev/null +++ b/pkgs/development/cuda-modules/write-gpu-test-python.nix @@ -0,0 +1,66 @@ +{ + lib, + runCommand, + python3Packages, + makeWrapper, +}: +{ + feature ? "cuda", + name ? if feature == null then "cpu" else feature, + libraries ? [ ], # [PythonPackage] | (PackageSet -> [PythonPackage]) + ... +}@args: + +let + inherit (builtins) isFunction all; + librariesFun = if isFunction libraries then libraries else (_: libraries); +in + +assert lib.assertMsg ( + isFunction libraries || all (python3Packages.hasPythonModule) libraries +) "writeGpuTestPython was passed `libraries` from the wrong python release"; + +content: + +let + interpreter = python3Packages.python.withPackages librariesFun; + tester = + runCommand "tester-${name}" + ( + lib.removeAttrs args [ + "libraries" + "name" + ] + // { + inherit content; + nativeBuildInputs = args.nativeBuildInputs or [ ] ++ [ makeWrapper ]; + passAsFile = args.passAsFile or [ ] ++ [ "content" ]; + } + ) + '' + mkdir -p "$out"/bin + cat << EOF >"$out/bin/$name" + #!${lib.getExe interpreter} + EOF + cat "$contentPath" >>"$out/bin/$name" + chmod +x "$out/bin/$name" + + if [[ -n "''${makeWrapperArgs+''${makeWrapperArgs[@]}}" ]] ; then + wrapProgram "$out/bin/$name" ''${makeWrapperArgs[@]} + fi + ''; + tester' = tester.overrideAttrs (oldAttrs: { + passthru.gpuCheck = + runCommand "test-${name}" + { + nativeBuildInputs = [ tester' ]; + requiredSystemFeatures = lib.optionals (feature != null) [ feature ]; + } + '' + set -e + ${tester.meta.mainProgram or (lib.getName tester')} + touch $out + ''; + }); +in +tester' diff --git a/pkgs/development/python-modules/torch/mk-runtime-check.nix b/pkgs/development/python-modules/torch/mk-runtime-check.nix index 14560b06f87ce..61180a19aaba5 100644 --- a/pkgs/development/python-modules/torch/mk-runtime-check.nix +++ b/pkgs/development/python-modules/torch/mk-runtime-check.nix @@ -1,14 +1,15 @@ { cudaPackages, feature, - torch, + libraries, versionAttr, + pythonPackages, }: -cudaPackages.writeGpuTestPython +(cudaPackages.writeGpuTestPython.override { python3Packages = pythonPackages; }) { inherit feature; - libraries = [ torch ]; + inherit libraries; name = "${feature}Available"; } '' diff --git a/pkgs/development/python-modules/torch/mk-torch-compile-check.nix b/pkgs/development/python-modules/torch/mk-torch-compile-check.nix new file mode 100644 index 0000000000000..268ed5297da94 --- /dev/null +++ b/pkgs/development/python-modules/torch/mk-torch-compile-check.nix @@ -0,0 +1,38 @@ +{ + cudaPackages, + feature ? null, + lib, + libraries, + name ? if feature == null then "torch-compile-cpu" else "torch-compile-${feature}", + pythonPackages, + stdenv, +}: +let + deviceStr = if feature == null then "" else '', device="cuda"''; +in +(cudaPackages.writeGpuTestPython.override { python3Packages = pythonPackages; }) + { + inherit name feature libraries; + makeWrapperArgs = [ + "--suffix" + "PATH" + ":" + "${lib.getBin stdenv.cc}/bin" + ]; + } + '' + import torch + + + @torch.compile + def opt_foo2(x, y): + a = torch.sin(x) + b = torch.cos(y) + return a + b + + + print( + opt_foo2( + torch.randn(10, 10${deviceStr}), + torch.randn(10, 10${deviceStr}))) + '' diff --git a/pkgs/development/python-modules/torch/tests.nix b/pkgs/development/python-modules/torch/tests.nix index 76b901cbcea91..e3f2ca44ba5a9 100644 --- a/pkgs/development/python-modules/torch/tests.nix +++ b/pkgs/development/python-modules/torch/tests.nix @@ -1,21 +1,31 @@ -{ - callPackage, - torchWithCuda, - torchWithRocm, -}: +{ callPackage }: -{ +rec { # To perform the runtime check use either # `nix run .#python3Packages.torch.tests.tester-cudaAvailable` (outside the sandbox), or # `nix build .#python3Packages.torch.tests.tester-cudaAvailable.gpuCheck` (in a relaxed sandbox) tester-cudaAvailable = callPackage ./mk-runtime-check.nix { feature = "cuda"; versionAttr = "cuda"; - torch = torchWithCuda; + libraries = ps: [ ps.torchWithCuda ]; }; tester-rocmAvailable = callPackage ./mk-runtime-check.nix { feature = "rocm"; versionAttr = "hip"; - torch = torchWithRocm; + libraries = ps: [ ps.torchWithRocm ]; + }; + + compileCpu = tester-compileCpu.gpuCheck; + tester-compileCpu = callPackage ./mk-torch-compile-check.nix { + feature = null; + libraries = ps: [ ps.torch ]; + }; + tester-compileCuda = callPackage ./mk-torch-compile-check.nix { + feature = "cuda"; + libraries = ps: [ ps.torchWithCuda ]; + }; + tester-compileRocm = callPackage ./mk-torch-compile-check.nix { + feature = "rocm"; + libraries = ps: [ ps.torchWithRocm ]; }; } diff --git a/pkgs/top-level/cuda-packages.nix b/pkgs/top-level/cuda-packages.nix index 5540b89f1b983..639fa70446bee 100644 --- a/pkgs/top-level/cuda-packages.nix +++ b/pkgs/top-level/cuda-packages.nix @@ -81,7 +81,7 @@ let nccl = final.callPackage ../development/cuda-modules/nccl { }; nccl-tests = final.callPackage ../development/cuda-modules/nccl-tests { }; - writeGpuTestPython = final.callPackage ../development/cuda-modules/write-gpu-python-test.nix { }; + writeGpuTestPython = final.callPackage ../development/cuda-modules/write-gpu-test-python.nix { }; }); mkVersionedPackageName =