diff --git a/pkgs/development/python-modules/torchvision/default.nix b/pkgs/development/python-modules/torchvision/default.nix index d36beb6575e09..a9c5defbe38e6 100644 --- a/pkgs/development/python-modules/torchvision/default.nix +++ b/pkgs/development/python-modules/torchvision/default.nix @@ -1,30 +1,50 @@ -{ lib -, symlinkJoin -, buildPythonPackage +{ buildPythonPackage +, cudaSupport ? torch.cudaSupport or false # by default uses the value from torch , fetchFromGitHub -, ninja -, which +, lib , libjpeg_turbo , libpng +, ninja , numpy -, scipy , pillow -, torch , pytest -, cudaSupport ? torch.cudaSupport or false # by default uses the value from torch +, scipy +, symlinkJoin +, torch +, which }: let - inherit (torch) gpuTargetString; - inherit (torch.cudaPackages) cudatoolkit cudnn; + inherit (torch) cudaPackages gpuTargetString; + inherit (cudaPackages) cudatoolkit cudaFlags cudaVersion; + + # NOTE: torchvision doesn't use cudnn; torch does! + # For this reason it is not included. + cuda-common-redist = with cudaPackages; [ + cuda_cccl # + libcublas # cublas_v2.h + libcusolver # cusolverDn.h + libcusparse # cusparse.h + ]; - cudatoolkit_joined = symlinkJoin { - name = "${cudatoolkit.name}-unsplit"; - paths = [ cudatoolkit.out cudatoolkit.lib ]; + cuda-native-redist = symlinkJoin { + name = "cuda-native-redist-${cudaVersion}"; + paths = with cudaPackages; [ + cuda_cudart # cuda_runtime.h + cuda_nvcc + ] ++ cuda-common-redist; }; -in buildPythonPackage rec { + + cuda-redist = symlinkJoin { + name = "cuda-redist-${cudaVersion}"; + paths = cuda-common-redist; + }; + pname = "torchvision"; version = "0.14.1"; +in +buildPythonPackage { + inherit pname version; src = fetchFromGitHub { owner = "pytorch"; @@ -33,18 +53,22 @@ in buildPythonPackage rec { hash = "sha256-lKkEJolJQaLr1TVm44CizbJQedGa1wyy0cFWg2LTJN0="; }; - nativeBuildInputs = [ libpng ninja which ] - ++ lib.optionals cudaSupport [ cudatoolkit_joined ]; - - TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/"; - TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/"; + nativeBuildInputs = [ libpng ninja which ] ++ lib.optionals cudaSupport [ cuda-native-redist ]; - buildInputs = [ libjpeg_turbo libpng ] - ++ lib.optionals cudaSupport [ cudnn ]; + buildInputs = [ libjpeg_turbo libpng ] ++ lib.optionals cudaSupport [ cuda-redist ]; propagatedBuildInputs = [ numpy pillow torch scipy ]; - preBuild = lib.optionalString cudaSupport '' + preConfigure = '' + export TORCHVISION_INCLUDE="${libjpeg_turbo.dev}/include/" + export TORCHVISION_LIBRARY="${libjpeg_turbo}/lib/" + '' + # NOTE: We essentially override the compilers provided by stdenv because we don't have a hook + # for cudaPackages to swap in compilers supported by NVCC. + + lib.optionalString cudaSupport '' + export CC=${cudatoolkit.cc}/bin/cc + export CXX=${cudatoolkit.cc}/bin/c++ + export CUDAHOSTCXX=${cudatoolkit.cc}/bin/c++ export TORCH_CUDA_ARCH_LIST="${gpuTargetString}" export FORCE_CUDA=1 ''; @@ -52,6 +76,7 @@ in buildPythonPackage rec { # tries to download many datasets for tests doCheck = false; + pythonImportsCheck = [ "torchvision" ]; checkPhase = '' HOME=$TMPDIR py.test test --ignore=test/test_datasets_download.py '';