Skip to content

Commit

Permalink
python3Packages.torch.tests: tests for darwin
Browse files Browse the repository at this point in the history
Run nixfmt for mk-torch-compile-check.nix
On darwin, both CPU and MPS are tested now.

Signed-off-by: Mika Tammi <[email protected]>
  • Loading branch information
mikatammi committed Oct 28, 2024
1 parent cae0fd4 commit 58c970d
Showing 1 changed file with 30 additions and 34 deletions.
64 changes: 30 additions & 34 deletions pkgs/development/python-modules/torch/mk-torch-compile-check.nix
Original file line number Diff line number Diff line change
@@ -1,38 +1,34 @@
{
cudaPackages,
feature ? null,
lib,
libraries,
name ? if feature == null then "torch-compile-cpu" else "torch-compile-${feature}",
pythonPackages,
stdenv,
}:
let
deviceStr = if feature == null then "" else '', device="cuda"'';
in
(cudaPackages.writeGpuTestPython.override { python3Packages = pythonPackages; })
{
inherit name feature libraries;
makeWrapperArgs = [
"--suffix"
"PATH"
":"
"${lib.getBin stdenv.cc}/bin"
];
}
''
import torch
{ cudaPackages, feature ? null, lib, libraries, name ?
if feature == null then "torch-compile-cpu" else "torch-compile-${feature}"
, pythonPackages, stdenv, }:
let deviceStr = if feature == null then "" else '', device="${feature}"'';
in (cudaPackages.writeGpuTestPython.override {
python3Packages = pythonPackages;
}) {
inherit name feature libraries;
makeWrapperArgs = [ "--suffix" "PATH" ":" "${lib.getBin stdenv.cc}/bin" ];
} (''
import torch
@torch.compile
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
'' + lib.optionalString (!stdenv.hostPlatform.isDarwin) ''
# torch.compile requires OpenMP which is not available on Darwin
@torch.compile
'' + ''
def opt_foo2(x, y):
a = torch.sin(x)
b = torch.cos(y)
return a + b
print(
opt_foo2(
torch.randn(10, 10${deviceStr}),
torch.randn(10, 10${deviceStr})))
''
print(
opt_foo2(
torch.randn(10, 10${deviceStr}),
torch.randn(10, 10${deviceStr})))
'' + lib.optionalString stdenv.hostPlatform.isDarwin ''
# MPS is built by default on Darwin, so test it.
print(
opt_foo2(
torch.randn(10, 10, device="mps"),
torch.randn(10, 10, device="mps")))
'')

0 comments on commit 58c970d

Please sign in to comment.