From b3be8270b963d9ce8b28c8a09af52297f86a923c Mon Sep 17 00:00:00 2001 From: Avish Porwal Date: Mon, 15 Sep 2025 23:14:33 +0530 Subject: [PATCH] Extra validation for registries --- internal/validators/registries/mcpb.go | 11 +++++++- internal/validators/registries/mcpb_test.go | 24 +++++++++++++++++ internal/validators/registries/npm.go | 10 +++++-- internal/validators/registries/nuget.go | 12 ++++++++- internal/validators/registries/nuget_test.go | 26 ++++++++++++++++++- internal/validators/registries/oci.go | 15 ++++++++++- internal/validators/registries/oci_test.go | 24 +++++++++++++++++ internal/validators/registries/pypi.go | 16 +++++++++++- internal/validators/registries/pypi_test.go | 20 ++++++++++++-- .../validators/registries/testutils_test.go | 2 +- 10 files changed, 150 insertions(+), 10 deletions(-) diff --git a/internal/validators/registries/mcpb.go b/internal/validators/registries/mcpb.go index fa45886f..7847a6b2 100644 --- a/internal/validators/registries/mcpb.go +++ b/internal/validators/registries/mcpb.go @@ -12,10 +12,19 @@ import ( "github.com/modelcontextprotocol/registry/pkg/model" ) +var ( + ErrMissingIdentifierForMCPB = fmt.Errorf("package identifier is required for MCPB packages") + ErrMissingFileSHA256ForMCPB = fmt.Errorf("must include a fileSha256 hash for integrity verification") +) + func ValidateMCPB(ctx context.Context, pkg model.Package, _ string) error { // MCPB packages must include a file hash for integrity verification if pkg.FileSHA256 == "" { - return fmt.Errorf("MCPB package must include a fileSha256 hash for integrity verification") + return ErrMissingFileSHA256ForMCPB + } + + if pkg.Identifier == "" { + return ErrMissingIdentifierForMCPB } err := validateMCPBUrl(pkg.Identifier) diff --git a/internal/validators/registries/mcpb_test.go b/internal/validators/registries/mcpb_test.go index 8167f5bf..9db6de0e 100644 --- a/internal/validators/registries/mcpb_test.go +++ b/internal/validators/registries/mcpb_test.go @@ -20,6 +20,30 @@ func TestValidateMCPB(t *testing.T) { expectError bool errorMessage string }{ + { + name: "empty package identifier should fail", + packageName: "", + serverName: "com.example/test", + fileSHA256: "abc123ef4567890abcdef1234567890abcdef1234567890abcdef1234567890", + expectError: true, + errorMessage: "package identifier is required for MCPB packages", + }, + { + name: "empty file SHA256 should fail", + packageName: "https://github.com/example/server/releases/download/v1.0.0/server.mcpb", + serverName: "com.example/test", + fileSHA256: "", + expectError: true, + errorMessage: "must include a fileSha256 hash for integrity verification", + }, + { + name: "both empty identifier and file SHA256 should fail with file SHA256 error first", + packageName: "", + serverName: "com.example/test", + fileSHA256: "", + expectError: true, + errorMessage: "must include a fileSha256 hash for integrity verification", + }, { name: "valid MCPB package should pass", packageName: "https://github.com/domdomegg/airtable-mcp-server/releases/download/v1.7.2/airtable-mcp-server.mcpb", diff --git a/internal/validators/registries/npm.go b/internal/validators/registries/npm.go index 0cff56dd..4a1449ce 100644 --- a/internal/validators/registries/npm.go +++ b/internal/validators/registries/npm.go @@ -3,6 +3,7 @@ package registries import ( "context" "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -11,6 +12,11 @@ import ( "github.com/modelcontextprotocol/registry/pkg/model" ) +var ( + ErrMissingIdentifierForNPM = errors.New("package identifier is required for NPM packages") + ErrMissingVersionForNPM = errors.New("package version is required for NPM packages") +) + // NPMPackageResponse represents the structure returned by the NPM registry API type NPMPackageResponse struct { MCPName string `json:"mcpName"` @@ -24,7 +30,7 @@ func ValidateNPM(ctx context.Context, pkg model.Package, serverName string) erro } if pkg.Identifier == "" { - return fmt.Errorf("package identifier is required for NPM packages") + return ErrMissingIdentifierForNPM } // we need version to look up the package metadata @@ -32,7 +38,7 @@ func ValidateNPM(ctx context.Context, pkg model.Package, serverName string) erro // and we won't be able to validate the mcpName field // against the server name if pkg.Version == "" { - return fmt.Errorf("package version is required for NPM packages") + return ErrMissingVersionForNPM } // Validate that the registry base URL matches NPM exactly diff --git a/internal/validators/registries/nuget.go b/internal/validators/registries/nuget.go index 1e969987..01a2d47b 100644 --- a/internal/validators/registries/nuget.go +++ b/internal/validators/registries/nuget.go @@ -2,6 +2,7 @@ package registries import ( "context" + "errors" "fmt" "io" "net/http" @@ -11,6 +12,11 @@ import ( "github.com/modelcontextprotocol/registry/pkg/model" ) +var ( + ErrMissingIdentifierForNuget = errors.New("package identifier is required for NuGet packages") + ErrMissingVersionForNuget = errors.New("package version is required for NuGet packages") +) + // ValidateNuGet validates that a NuGet package contains the correct MCP server name func ValidateNuGet(ctx context.Context, pkg model.Package, serverName string) error { // Set default registry base URL if empty @@ -18,6 +24,10 @@ func ValidateNuGet(ctx context.Context, pkg model.Package, serverName string) er pkg.RegistryBaseURL = model.RegistryURLNuGet } + if pkg.Identifier == "" { + return ErrMissingIdentifierForNuget + } + // Validate that the registry base URL matches NuGet exactly if pkg.RegistryBaseURL != model.RegistryURLNuGet { return fmt.Errorf("registry type and base URL do not match: '%s' is not valid for registry type '%s'. Expected: %s", @@ -29,7 +39,7 @@ func ValidateNuGet(ctx context.Context, pkg model.Package, serverName string) er lowerID := strings.ToLower(pkg.Identifier) lowerVersion := strings.ToLower(pkg.Version) if lowerVersion == "" { - return fmt.Errorf("NuGet package validation requires a specific version, but none was provided") + return ErrMissingVersionForNuget } // Try to get README from the package diff --git a/internal/validators/registries/nuget_test.go b/internal/validators/registries/nuget_test.go index 9ddb402a..79955c5c 100644 --- a/internal/validators/registries/nuget_test.go +++ b/internal/validators/registries/nuget_test.go @@ -20,6 +20,30 @@ func TestValidateNuGet_RealPackages(t *testing.T) { expectError bool errorMessage string }{ + { + name: "empty package identifier should fail", + packageName: "", + version: "1.0.0", + serverName: "com.example/test", + expectError: true, + errorMessage: "package identifier is required for NuGet packages", + }, + { + name: "empty package version should fail", + packageName: "test-package", + version: "", + serverName: "com.example/test", + expectError: true, + errorMessage: "package version is required for NuGet packages", + }, + { + name: "both empty identifier and version should fail with identifier error first", + packageName: "", + version: "", + serverName: "com.example/test", + expectError: true, + errorMessage: "package identifier is required for NuGet packages", + }, { name: "non-existent package should fail", packageName: generateRandomNuGetPackageName(), @@ -34,7 +58,7 @@ func TestValidateNuGet_RealPackages(t *testing.T) { version: "", // No version provided serverName: "com.example/test", expectError: true, - errorMessage: "requires a specific version", + errorMessage: "package version is required for NuGet packages", }, { name: "real package with non-existent version should fail", diff --git a/internal/validators/registries/oci.go b/internal/validators/registries/oci.go index b45a03e6..ca96f324 100644 --- a/internal/validators/registries/oci.go +++ b/internal/validators/registries/oci.go @@ -13,6 +13,11 @@ import ( "github.com/modelcontextprotocol/registry/pkg/model" ) +var ( + ErrMissingIdentifierForOCI = errors.New("package identifier is required for OCI packages") + ErrMissingVersionForOCI = errors.New("package version is required for OCI packages") +) + const ( dockerIoAPIBaseURL = "https://registry-1.docker.io" ghcrAPIBaseURL = "https://ghcr.io" @@ -80,6 +85,15 @@ func ValidateOCI(ctx context.Context, pkg model.Package, serverName string) erro pkg.RegistryBaseURL = model.RegistryURLDocker } + if pkg.Identifier == "" { + return ErrMissingIdentifierForOCI + } + + // we need version (tag) to look up the image manifest + if pkg.Version == "" { + return ErrMissingVersionForOCI + } + // Validate that the registry base URL is supported if err := validateRegistryURL(pkg.RegistryBaseURL); err != nil { return err @@ -258,7 +272,6 @@ func getRegistryAuthToken(ctx context.Context, client *http.Client, config *Regi return authResp.Token, nil } - // getSpecificManifest retrieves a specific manifest for multi-arch images func getSpecificManifest(ctx context.Context, client *http.Client, registryConfig *RegistryConfig, namespace, repo, digest string) (*OCIManifest, error) { manifestURL := fmt.Sprintf("%s/v2/%s/%s/manifests/%s", registryConfig.APIBaseURL, namespace, repo, digest) diff --git a/internal/validators/registries/oci_test.go b/internal/validators/registries/oci_test.go index 4e31157c..ea611eaf 100644 --- a/internal/validators/registries/oci_test.go +++ b/internal/validators/registries/oci_test.go @@ -21,6 +21,30 @@ func TestValidateOCI_RealPackages(t *testing.T) { errorMessage string registryURL string }{ + { + name: "empty package identifier should fail", + packageName: "", + version: "latest", + serverName: "com.example/test", + expectError: true, + errorMessage: "package identifier is required for OCI packages", + }, + { + name: "empty package version should fail", + packageName: "test-image", + version: "", + serverName: "com.example/test", + expectError: true, + errorMessage: "package version is required for OCI packages", + }, + { + name: "both empty identifier and version should fail with identifier error first", + packageName: "", + version: "", + serverName: "com.example/test", + expectError: true, + errorMessage: "package identifier is required for OCI packages", + }, { name: "non-existent image should fail", packageName: generateRandomImageName(), diff --git a/internal/validators/registries/pypi.go b/internal/validators/registries/pypi.go index 7c4e45cf..7ae98d65 100644 --- a/internal/validators/registries/pypi.go +++ b/internal/validators/registries/pypi.go @@ -3,6 +3,7 @@ package registries import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -11,6 +12,11 @@ import ( "github.com/modelcontextprotocol/registry/pkg/model" ) +var ( + ErrMissingIdentifierForPyPI = errors.New("package identifier is required for PyPI packages") + ErrMissingVersionForPyPi = errors.New("package version is required for PyPI packages") +) + // PyPIPackageResponse represents the structure returned by the PyPI JSON API type PyPIPackageResponse struct { Info struct { @@ -25,6 +31,14 @@ func ValidatePyPI(ctx context.Context, pkg model.Package, serverName string) err pkg.RegistryBaseURL = model.RegistryURLPyPI } + if pkg.Identifier == "" { + return ErrMissingIdentifierForPyPI + } + + if pkg.Version == "" { + return ErrMissingVersionForPyPi + } + // Validate that the registry base URL matches PyPI exactly if pkg.RegistryBaseURL != model.RegistryURLPyPI { return fmt.Errorf("registry type and base URL do not match: '%s' is not valid for registry type '%s'. Expected: %s", @@ -33,7 +47,7 @@ func ValidatePyPI(ctx context.Context, pkg model.Package, serverName string) err client := &http.Client{Timeout: 10 * time.Second} - url := fmt.Sprintf("%s/pypi/%s/json", pkg.RegistryBaseURL, pkg.Identifier) + url := fmt.Sprintf("%s/pypi/%s/%s/json", pkg.RegistryBaseURL, pkg.Identifier, pkg.Version) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return fmt.Errorf("failed to create request: %w", err) diff --git a/internal/validators/registries/pypi_test.go b/internal/validators/registries/pypi_test.go index e88b1ba5..aadb05f1 100644 --- a/internal/validators/registries/pypi_test.go +++ b/internal/validators/registries/pypi_test.go @@ -20,6 +20,22 @@ func TestValidatePyPI_RealPackages(t *testing.T) { expectError bool errorMessage string }{ + { + name: "empty package identifier should fail", + packageName: "", + version: "1.0.0", + serverName: "com.example/test", + expectError: true, + errorMessage: "package identifier is required for PyPI packages", + }, + { + name: "empty package version should fail", + packageName: "mcp-server-example", + version: "", + serverName: "com.example/test", + expectError: true, + errorMessage: "package version is required for PyPI packages", + }, { name: "non-existent package should fail", packageName: generateRandomPackageName(), @@ -47,7 +63,7 @@ func TestValidatePyPI_RealPackages(t *testing.T) { { name: "real package with server name in README should pass", packageName: "time-mcp-pypi", - version: "1.0.0", + version: "1.0.6", serverName: "io.github.domdomegg/time-mcp-pypi", expectError: false, }, @@ -71,4 +87,4 @@ func TestValidatePyPI_RealPackages(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/internal/validators/registries/testutils_test.go b/internal/validators/registries/testutils_test.go index d6e2468b..c1b5806e 100644 --- a/internal/validators/registries/testutils_test.go +++ b/internal/validators/registries/testutils_test.go @@ -29,4 +29,4 @@ func generateRandomImageName() string { return "nonexistent-image-fallback" } return fmt.Sprintf("nonexistent-image-%s", hex.EncodeToString(bytes)[:16]) -} \ No newline at end of file +}