diff --git a/pkgs/development/python-modules/torch/darwin-check.py b/pkgs/development/python-modules/torch/darwin-check.py new file mode 100644 index 00000000000000..8a9da0d7f277d2 --- /dev/null +++ b/pkgs/development/python-modules/torch/darwin-check.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +import torch + + +def opt_foo2(x, y): + a = torch.sin(x) + b = torch.cos(y) + return a + b + + +print("Testing CPU") +print( + opt_foo2( + torch.randn(10, 10), + torch.randn(10, 10))) + +print("Testing MPS") +assert torch.backends.mps.is_built(), "PyTorch not built with MPS enabled" +if not torch.backends.mps.is_available(): + print("MPS not available because the current MacOS version is not 12.3+ " + "and/or you do not have an MPS-enabled device on this machine.") + +else: + print( + opt_foo2( + torch.randn(10, 10, device="mps"), + torch.randn(10, 10, device="mps"))) diff --git a/pkgs/development/python-modules/torch/mk-darwin-check.nix b/pkgs/development/python-modules/torch/mk-darwin-check.nix new file mode 100644 index 00000000000000..e48df51cdd31f6 --- /dev/null +++ b/pkgs/development/python-modules/torch/mk-darwin-check.nix @@ -0,0 +1,12 @@ +{ + libraries, + python, + writeShellApplication, +}: +writeShellApplication { + name = "test-torch-darwin"; + runtimeInputs = [ (python.withPackages libraries) ]; + text = '' + python ${./darwin-check.py} + ''; +} diff --git a/pkgs/development/python-modules/torch/tests.nix b/pkgs/development/python-modules/torch/tests.nix index e3f2ca44ba5a94..c16b6502a87308 100644 --- a/pkgs/development/python-modules/torch/tests.nix +++ b/pkgs/development/python-modules/torch/tests.nix @@ -28,4 +28,7 @@ rec { feature = "rocm"; libraries = ps: [ ps.torchWithRocm ]; }; + tester-darwin = callPackage ./mk-darwin-check.nix { + libraries = ps: [ ps.torch ]; + }; }