diff --git a/pkgs/top-level/python-packages.nix b/pkgs/top-level/python-packages.nix index a8fa9c95307c3b1..93ff3668d70e38e 100644 --- a/pkgs/top-level/python-packages.nix +++ b/pkgs/top-level/python-packages.nix @@ -8360,9 +8360,18 @@ in { pytools = callPackage ../development/python-modules/pytools { }; - pytorch = callPackage ../development/python-modules/pytorch { - cudaSupport = pkgs.config.cudaSupport or false; - }; + pytorch = + let + cudnn = pkgs.cudnn_8_3_cudatoolkit_11; + cudatoolkit = cudnn.cudatoolkit; + magma = pkgs.magma.override { inherit cudatoolkit; }; + nccl = pkgs.nccl.override { inherit cudatoolkit; }; + in + callPackage ../development/python-modules/pytorch + { + cudaSupport = pkgs.config.cudaSupport or false; + inherit cudnn cudatoolkit magma nccl; + }; pytorch-bin = callPackage ../development/python-modules/pytorch/bin.nix { };