diff --git a/lib/join/azurejoin/azure.go b/lib/join/azurejoin/azure.go index 7d2f5e7957901..798a30c672d87 100644 --- a/lib/join/azurejoin/azure.go +++ b/lib/join/azurejoin/azure.go @@ -220,7 +220,14 @@ func parseAndVerifyAttestedData( if len(p7.Certificates) == 0 { return "", "", trace.AccessDenied("no certificates for signature") } - fixAzureSigningAlgorithm(p7) + signingCert := p7.Certificates[0] + + if !isAllowedDomain(signingCert.Subject.CommonName, allowedAzureCommonNames) { + return "", "", trace.AccessDenied( + "certificate common name does not match allow-list (%s)", + signingCert.Subject.CommonName, + ) + } if len(intermediates) > 0 { // Client explicitly sent intermediate CAs, included them. @@ -245,6 +252,7 @@ func parseAndVerifyAttestedData( pool.AddCert(cert) } + fixAzureSigningAlgorithm(p7) if err := p7.VerifyWithChain(pool); err != nil { return "", "", trace.Wrap(err) } diff --git a/lib/join/azurejoin/azure_certs.go b/lib/join/azurejoin/azure_certs.go index 431e7bcecb35b..5ce76b77b3478 100644 --- a/lib/join/azurejoin/azure_certs.go +++ b/lib/join/azurejoin/azure_certs.go @@ -23,6 +23,7 @@ import ( "crypto/x509" "encoding/pem" "net/http" + "net/url" "strings" "github.com/gravitational/trace" @@ -47,40 +48,89 @@ func isAllowedDomain(cn string, domains []string) bool { return false } +func validateAzureCertIssuerURL(issuerURLString string) (string, error) { + // All active issuing certs are listed here + // https://www.microsoft.com/pkiops/docs/repository.htm + // + // The cert path always looks like the following, although this does not + // appear to be guaranteed by Microsoft. + // url: http://www.microsoft.com/pkiops/certs/.crt + // + // This code path is only used by the legacy join service which will be + // removed in v20, v18+ agents use the new join service where the joining + // client sends the intermediate CAs along with the request. + const ( + allowedHost = "www.microsoft.com" + allowedPathPrefix = "/pkiops/certs/" + allowedPathSuffix = ".crt" + ) + + issuerURL, err := url.Parse(issuerURLString) + if err != nil { + return "", trace.AccessDenied("url failed to parse") + } + + switch issuerURL.Scheme { + case "http", "https": + default: + return "", trace.AccessDenied("invalid url scheme %q", issuerURL.Scheme) + } + + if issuerURL.Host != allowedHost { + return "", trace.AccessDenied("invalid host %q", issuerURL.Host) + } + + if !strings.HasPrefix(issuerURL.Path, allowedPathPrefix) || + !strings.HasSuffix(issuerURL.Path, allowedPathSuffix) { + return "", trace.AccessDenied("invalid path, must match %s%s", + allowedPathPrefix, allowedPathSuffix) + } + + // Construct a new URL with only the scheme, host, and path to strip any + // possible extra fields like query params or fragments. + sanitizedURL := url.URL{ + Scheme: issuerURL.Scheme, + Host: allowedHost, + Path: issuerURL.Path, + } + return sanitizedURL.String(), nil +} + // getAzureIssuerCert fetches a x509 certificate's issuing certificate. func getAzureIssuerCert(ctx context.Context, cert *x509.Certificate, httpClient utils.HTTPDoClient) (*x509.Certificate, error) { if len(cert.IssuingCertificateURL) == 0 { - return nil, nil + return nil, trace.BadParameter("certificate has no issuing certificate URL") } // Azure sends only one issuing cert. issuerURL := cert.IssuingCertificateURL[0] - commonName := cert.Subject.CommonName - if !isAllowedDomain(commonName, allowedAzureCommonNames) { - return nil, trace.AccessDenied( - "certificate common name does not match allow-list (%s)", - commonName, - ) + sanitizedIssuerURL, err := validateAzureCertIssuerURL(issuerURL) + if err != nil { + return nil, trace.Wrap(err, "validating issuing certificate URL") } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, issuerURL, nil) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, sanitizedIssuerURL, nil /*body*/) if err != nil { return nil, trace.Wrap(err) } + resp, err := httpClient.Do(req) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "fetching issuing certificate") } defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, trace.AccessDenied("failed to fetch issuing cert, got HTTP status code %d", resp.StatusCode) + } + body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) if err != nil { - return nil, trace.Wrap(err) - } - if resp.StatusCode != http.StatusOK { - return nil, trace.ReadError(resp.StatusCode, body) + return nil, trace.Wrap(err, "reading HTTP response body") } issuerCert, err := x509.ParseCertificate(body) - return issuerCert, trace.Wrap(err) + return issuerCert, trace.Wrap(err, "parsing issuing certificate") } func getAzureRootCerts() ([]*x509.Certificate, error) { diff --git a/lib/join/azurejoin/join_azure_test.go b/lib/join/azurejoin/join_azure_test.go index 9752766d8d769..631ca29973565 100644 --- a/lib/join/azurejoin/join_azure_test.go +++ b/lib/join/azurejoin/join_azure_test.go @@ -925,6 +925,163 @@ func TestJoinAzureClaims(t *testing.T) { } } +func TestAzureIssuerCert(t *testing.T) { + server, err := authtest.NewTestServer(authtest.ServerConfig{ + Auth: authtest.AuthServerConfig{ + Dir: t.TempDir(), + }}) + require.NoError(t, err) + a := server.Auth() + + nopClient, err := server.NewClient(authtest.TestNop()) + require.NoError(t, err) + + token, err := types.NewProvisionTokenFromSpec("testtoken", time.Now().Add(time.Minute), types.ProvisionTokenSpecV2{ + JoinMethod: types.JoinMethodAzure, + Roles: types.SystemRoles{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: "testsubscription", + }, + }, + }, + }) + require.NoError(t, err) + require.NoError(t, a.UpsertToken(t.Context(), token)) + + caChain := newFakeAzureCAChain(t) + + instanceKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + instanceID := vmResourceID("testsubscription", "testgroup", "testid") + + accessToken, err := makeToken(instanceID, instanceID, a.GetClock().Now()) + require.NoError(t, err) + + defaultSubscription := uuid.NewString() + defaultResourceGroup := "my-resource-group" + defaultVMName := "test-vm" + defaultVMID := "my-vm-id" + defaultVMResourceID := vmResourceID(defaultSubscription, defaultResourceGroup, defaultVMName) + vmClient := &mockAzureVMClient{ + vms: map[string]*azure.VirtualMachine{ + defaultVMResourceID: { + ID: defaultVMResourceID, + Name: defaultVMName, + Subscription: defaultSubscription, + ResourceGroup: defaultResourceGroup, + VMID: defaultVMID, + }, + }, + } + getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ + defaultSubscription: vmClient, + }) + + isAccessDenied := func(t require.TestingT, err error, msgAndArgs ...any) { + require.ErrorAs(t, err, new(*trace.AccessDeniedError), msgAndArgs...) + } + for _, tc := range []struct { + desc string + commonName string + issuerURL string + errorAssertion require.ErrorAssertionFunc + expecteRequestedIssuingCA bool + }{ + { + desc: "passing", + commonName: "instance.metadata.azure.com", + issuerURL: "http://www.microsoft.com/pkiops/certs/testca.crt", + errorAssertion: require.NoError, + expecteRequestedIssuingCA: true, + }, + { + desc: "bad common name", + commonName: "instance.metadata.bad.example.com", + issuerURL: "http://www.microsoft.com/pkiops/certs/testca.crt", + errorAssertion: func(t require.TestingT, err error, msgAndArgs ...any) { + isAccessDenied(t, err, msgAndArgs...) + require.ErrorContains(t, err, "certificate common name does not match allow-list") + }, + }, + { + desc: "bad issuer host", + commonName: "instance.metadata.azure.com", + issuerURL: "http://www.bad.example.com/pkiops/certs/testca.crt", + errorAssertion: func(t require.TestingT, err error, msgAndArgs ...any) { + isAccessDenied(t, err, msgAndArgs...) + require.ErrorContains(t, err, "validating issuing certificate URL") + require.ErrorContains(t, err, "invalid host") + }, + }, + { + desc: "bad cert path", + commonName: "instance.metadata.azure.com", + issuerURL: "http://www.microsoft.com/pkiops/badcerts/badca.crt", + errorAssertion: func(t require.TestingT, err error, msgAndArgs ...any) { + isAccessDenied(t, err, msgAndArgs...) + require.ErrorContains(t, err, "validating issuing certificate URL") + require.ErrorContains(t, err, "invalid path") + }, + }, + { + desc: "bad url scheme", + commonName: "instance.metadata.azure.com", + issuerURL: "bad://www.microsoft.com/pkiops/certs/testca.crt", + errorAssertion: func(t require.TestingT, err error, msgAndArgs ...any) { + isAccessDenied(t, err, msgAndArgs...) + require.ErrorContains(t, err, "validating issuing certificate URL") + require.ErrorContains(t, err, "invalid url scheme") + }, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + // Generate an azure instance cert with the common name and issuer URL for the testcase. + instanceCert := caChain.issueLeafCert(t, instanceKey.Public(), tc.commonName, tc.issuerURL) + imdsClient := &fakeIMDSClient{ + accessToken: accessToken, + signingCert: instanceCert, + signingKey: instanceKey, + subscription: "testsubscription", + vmID: instanceID, + } + + // Fake the HTTP client used to fetch the issuer CA. + httpClient := newFakeAzureIssuerHTTPClient(caChain.intermediateCertDER) + a.SetAzureJoinConfig(&azurejoin.AzureJoinConfig{ + CertificateAuthorities: []*x509.Certificate{caChain.rootCert}, + Verify: mockVerifyToken(nil), + GetVMClient: getVMClient, + IssuerHTTPClient: httpClient, + }) + + // Join via the legacy join service. + _, err = joinclient.LegacyJoin(t.Context(), joinclient.JoinParams{ + Token: "testtoken", + JoinMethod: types.JoinMethodAzure, + ID: state.IdentityID{ + Role: types.RoleInstance, + HostUUID: "testuuid", + }, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: instanceID, + IMDSClient: imdsClient, + }, + }) + tc.errorAssertion(t, err) + + if tc.expecteRequestedIssuingCA { + require.Equal(t, 1, httpClient.called, "expected issuing CA to be requested once") + } else { + require.Equal(t, 0, httpClient.called, "expected issuing CA not to be requested") + } + }) + } +} + type fakeIMDSClient struct { accessToken string accessTokenErr error