diff --git a/pkg/config/compatibility.go b/pkg/config/compatibility.go index a449561701..b6971431e5 100644 --- a/pkg/config/compatibility.go +++ b/pkg/config/compatibility.go @@ -111,7 +111,7 @@ func init() { filteredTorchCompatibilityMatrix := []TorchCompatibility{} for _, compat := range torchCompatibilityMatrix { for _, cudaBaseImage := range CUDABaseImages { - if compat.CUDA == nil || strings.HasPrefix(cudaBaseImage.CUDA, *compat.CUDA) { + if compat.CUDA != nil && version.Matches(*compat.CUDA, cudaBaseImage.CUDA) { filteredTorchCompatibilityMatrix = append(filteredTorchCompatibilityMatrix, compat) break } @@ -192,7 +192,7 @@ func resolveMinorToPatch(minor string) (string, error) { func latestCuDNNForCUDA(cuda string) (string, error) { cuDNNs := []string{} for _, image := range CUDABaseImages { - if version.Equal(image.CUDA, cuda) { + if version.Matches(cuda, image.CUDA) { cuDNNs = append(cuDNNs, image.CuDNN) } } @@ -241,7 +241,7 @@ func versionGreater(a string, b string) (bool, error) { func CUDABaseImageFor(cuda string, cuDNN string) (string, error) { for _, image := range CUDABaseImages { - if version.Equal(image.CUDA, cuda) && image.CuDNN == cuDNN { + if version.Matches(cuda, image.CUDA) && image.CuDNN == cuDNN { return image.ImageTag(), nil } } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 5d1355ffbe..9abf6596f4 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -146,29 +146,24 @@ func TestValidateAndCompleteCUDAForAllTF(t *testing.T) { } } -// func TestValidateAndCompleteCUDAForAllTorch(t *testing.T) { -// // test that all torch versions fill out cuda -// for _, compat := range TorchCompatibilityMatrix { -// config := &Config{ -// Build: &Build{ -// GPU: true, -// PythonVersion: "3.8", -// PythonPackages: []string{ -// "torch==" + compat.TorchVersion(), -// }, -// }, -// } - -// for _, cudaBaseImage := range CUDABaseImages { -// if compat.CUDA == nil || strings.HasPrefix(cudaBaseImage.CUDA, *compat.CUDA) { -// err := config.ValidateAndComplete("") -// require.NoError(t, err) -// require.NotEqual(t, "", config.Build.CUDA) -// require.NotEqual(t, "", config.Build.CuDNN) -// } -// } -// } -// } +func TestValidateAndCompleteCUDAForAllTorch(t *testing.T) { + for _, compat := range TorchCompatibilityMatrix { + config := &Config{ + Build: &Build{ + GPU: compat.CUDA != nil, + PythonVersion: "3.8", + PythonPackages: []string{ + "torch==" + compat.TorchVersion(), + }, + }, + } + + err := config.ValidateAndComplete("") + require.NoError(t, err) + require.NotEqual(t, "", config.Build.CUDA) + require.NotEqual(t, "", config.Build.CuDNN) + } +} func TestValidateAndCompleteCUDAForSelectedTorch(t *testing.T) { for _, tt := range []struct { @@ -317,9 +312,8 @@ func TestPythonPackagesForArchTorchCPU(t *testing.T) { requirements, err := config.PythonRequirementsForArch("", "") require.NoError(t, err) - expected := `--find-links https://download.pytorch.org/whl/torch_stable.html -torch==1.7.1+cpu -torchvision==0.8.2+cpu + expected := `torch==1.7.1 +torchvision==0.8.2 torchaudio==0.7.2 foo==1.0.0` require.Equal(t, expected, requirements) diff --git a/pkg/config/data/config_schema_v1.0.json b/pkg/config/data/config_schema_v1.0.json index 0e73702dda..7c7a48468f 100644 --- a/pkg/config/data/config_schema_v1.0.json +++ b/pkg/config/data/config_schema_v1.0.json @@ -14,6 +14,11 @@ "type": "string", "description": "Cog automatically picks the correct version of CUDA to install, but this lets you override it for whatever reason." }, + "cudnn": { + "$id": "#/properties/build/properties/cudnn", + "type": "string", + "description": "Cog automatically picks the correct version of cuDNN to install, but this lets you override it for whatever reason." + }, "gpu": { "$id": "#/properties/build/properties/gpu", "type": "boolean", diff --git a/pkg/dockerfile/generator_test.go b/pkg/dockerfile/generator_test.go index 6e11acf869..827b88ef3e 100644 --- a/pkg/dockerfile/generator_test.go +++ b/pkg/dockerfile/generator_test.go @@ -166,8 +166,7 @@ COPY . /src` requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt")) require.NoError(t, err) - require.Equal(t, `--find-links https://download.pytorch.org/whl/torch_stable.html -torch==1.5.1+cpu + require.Equal(t, `torch==1.5.1 pandas==1.2.0.12`, string(requirements)) } diff --git a/pkg/util/version/version.go b/pkg/util/version/version.go index 09c1338344..1f21dc4532 100644 --- a/pkg/util/version/version.go +++ b/pkg/util/version/version.go @@ -85,3 +85,20 @@ func EqualMinor(v1 string, v2 string) bool { func Greater(v1 string, v2 string) bool { return MustVersion(v1).Greater(MustVersion(v2)) } + +func (v *Version) Matches(other *Version) bool { + switch { + case v.Major != other.Major: + return false + case v.Minor != other.Minor: + return false + case v.Patch != 0 && v.Patch != other.Patch: + return false + default: + return true + } +} + +func Matches(v1 string, v2 string) bool { + return MustVersion(v1).Matches(MustVersion(v2)) +}