From 030575ecb20dc23ba3b0fcdc7e53bf746d3f06a3 Mon Sep 17 00:00:00 2001 From: Connor Baker Date: Fri, 9 May 2025 20:14:26 +0000 Subject: [PATCH] cuda-modules: fix and clean up multiplex builder package selection logic Signed-off-by: Connor Baker --- .../generic-builders/multiplex.nix | 104 +++++++++--------- 1 file changed, 50 insertions(+), 54 deletions(-) diff --git a/pkgs/development/cuda-modules/generic-builders/multiplex.nix b/pkgs/development/cuda-modules/generic-builders/multiplex.nix index 187da29e2ee19..3e2958449dc0a 100644 --- a/pkgs/development/cuda-modules/generic-builders/multiplex.nix +++ b/pkgs/development/cuda-modules/generic-builders/multiplex.nix @@ -30,16 +30,7 @@ shimsFn ? (throw "shimsFn must be provided"), }: let - inherit (lib) - attrsets - lists - modules - strings - ; - - inherit (stdenv) hostPlatform; - - evaluatedModules = modules.evalModules { + evaluatedModules = lib.modules.evalModules { modules = [ ../modules releasesModule @@ -50,49 +41,55 @@ let # - Releases: ../modules/${pname}/releases/releases.nix # - Package: ../modules/${pname}/releases/package.nix - # FIXME: do this at the module system level - propagatePlatforms = lib.mapAttrs ( - redistArch: packages: map (p: { inherit redistArch; } // p) packages - ); + # redistArch :: String + # Value is `"unsupported"` if the platform is not supported. + redistArch = flags.getRedistArch stdenv.hostPlatform.system; - # All releases across all platforms + # Check whether a package supports our CUDA version. + # satisfiesCudaVersion :: Package -> Bool + satisfiesCudaVersion = + package: + lib.versionAtLeast cudaMajorMinorVersion package.minCudaVersion + && lib.versionAtLeast package.maxCudaVersion cudaMajorMinorVersion; + + # Releases for our platform and CUDA version. # See ../modules/${pname}/releases/releases.nix - releaseSets = propagatePlatforms evaluatedModules.config.${pname}.releases; + # allPackages :: List Package + allPackages = lib.filter satisfiesCudaVersion ( + evaluatedModules.config.${pname}.releases.${redistArch} or [ ] + ); # Compute versioned attribute name to be used in this package set # Patch version changes should not break the build, so we only use major and minor # computeName :: Package -> String - computeName = { version, ... }: mkVersionedPackageName pname version; - - # Check whether a package supports our CUDA version and platform. - # isSupported :: Package -> Bool - isSupported = - package: - redistArch == package.redistArch - && strings.versionAtLeast cudaMajorMinorVersion package.minCudaVersion - && strings.versionAtLeast package.maxCudaVersion cudaMajorMinorVersion; + computeName = package: mkVersionedPackageName pname package.version; - # Get all of the packages for our given platform. - # redistArch :: String - # Value is `"unsupported"` if the platform is not supported. - redistArch = flags.getRedistArch hostPlatform.system; - - preferable = - p1: p2: (isSupported p2 -> isSupported p1) && (strings.versionOlder p2.version p1.version); - - # All the supported packages we can build for our platform. - # perSystemReleases :: List Package - allReleases = lib.pipe releaseSets [ - (lib.attrValues) - (lists.flatten) - (lib.groupBy (p: lib.versions.majorMinor p.version)) - (lib.mapAttrs (_: builtins.sort preferable)) - (lib.mapAttrs (_: lib.take 1)) - (lib.attrValues) - (lib.concatMap lib.trivial.id) - ]; - - newest = builtins.head (builtins.sort preferable allReleases); + # The newest package for each major-minor version, with newest first. + # newestPackages :: List Package + newestPackages = + let + newestForEachMajorMinorVersion = lib.foldl' ( + newestPackages: package: + let + majorMinorVersion = lib.versions.majorMinor package.version; + existingPackage = newestPackages.${majorMinorVersion} or null; + in + newestPackages + // { + ${majorMinorVersion} = + # Only keep the existing package if it is newer than the one we are considering. + if existingPackage != null && lib.versionOlder package.version existingPackage.version then + existingPackage + else + package; + } + ) { } allPackages; + in + # Sort the packages by version so the newest is first. + # NOTE: builtins.sort requires a strict weak ordering, so we must use versionOlder rather than versionAtLeast. + lib.sort (p1: p2: lib.versionOlder p2.version p1.version) ( + lib.attrValues newestForEachMajorMinorVersion + ); extension = final: _: @@ -102,25 +99,24 @@ let buildPackage = package: let - shims = final.callPackage shimsFn { - inherit package; - inherit (package) redistArch; - }; + shims = final.callPackage shimsFn { inherit package redistArch; }; name = computeName package; drv = final.callPackage ./manifest.nix { inherit pname redistName; inherit (shims) redistribRelease featureRelease; }; in - attrsets.nameValuePair name drv; + lib.nameValuePair name drv; # versionedDerivations :: AttrSet Derivation - versionedDerivations = builtins.listToAttrs (lists.map buildPackage allReleases); + versionedDerivations = builtins.listToAttrs (lib.map buildPackage newestPackages); defaultDerivation = { - ${pname} = (buildPackage newest).value; + ${pname} = (buildPackage (lib.head newestPackages)).value; }; in - versionedDerivations // defaultDerivation; + # NOTE: Must condition on the length of newestPackages to avoid non-total function lib.head aborting if + # newestPackages is empty. + lib.optionalAttrs (lib.length newestPackages > 0) (versionedDerivations // defaultDerivation); in extension