diff --git a/transport/tlscommon/tls_config.go b/transport/tlscommon/tls_config.go index 161da1f6..98a1b320 100644 --- a/transport/tlscommon/tls_config.go +++ b/transport/tlscommon/tls_config.go @@ -183,17 +183,24 @@ func trustRootCA(cfg *TLSConfig, peerCerts []*x509.Certificate) error { return fmt.Errorf("decode 'ca_trusted_fingerprint': %w", err) } + foundCADigests := []string{} + for _, cert := range peerCerts { + // Compute digest for each certificate. digest := sha256.Sum256(cert.Raw) + if cert.IsCA { + foundCADigests = append(foundCADigests, hex.EncodeToString(digest[:])) + } + if !bytes.Equal(digest[0:], fingerprint) { continue } // Make sure the fingerprint matches a CA certificate if !cert.IsCA { - logger.Info("Certificate matching 'ca_trusted_fingerprint' found, but is not a CA certificate") + logger.Warn("Certificate matching 'ca_trusted_fingerprint' found, but it is not a CA certificate. 'ca_trusted_fingerprint' can only be used to trust CA certificates.") continue } @@ -206,7 +213,13 @@ func trustRootCA(cfg *TLSConfig, peerCerts []*x509.Certificate) error { return nil } - logger.Warn("no CA certificate matching the fingerprint") + // if we are here, we didn't find any CA certificate matching the fingerprint + if len(foundCADigests) == 0 { + logger.Warn("The remote server's certificate is presented without its certificate chain. Using 'ca_trusted_fingerprint' requires that the server presents a certificate chain that includes the certificate's issuing certificate authority.") + } else { + logger.Warnf("The provided 'ca_trusted_fingerprint': '%s' does not match the fingerprint of any Certificate Authority present in the server's certificate chain. Found the following CA fingerprints instead: %v", cfg.CATrustedFingerprint, foundCADigests) + } + return nil } diff --git a/transport/tlscommon/tls_config_test.go b/transport/tlscommon/tls_config_test.go index b6a78261..d4da9cd2 100644 --- a/transport/tlscommon/tls_config_test.go +++ b/transport/tlscommon/tls_config_test.go @@ -18,8 +18,10 @@ package tlscommon import ( + "crypto/sha256" "crypto/tls" "crypto/x509" + "encoding/hex" "errors" "net" "net/http" @@ -29,6 +31,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/transport/tlscommontest" ) @@ -191,37 +194,60 @@ func TestTrustRootCA(t *testing.T) { nonEmptyCertPool.AddCert(certs["wildcard"]) nonEmptyCertPool.AddCert(certs["unknown_authority"]) - fingerprint := tlscommontest.GetCertFingerprint(certs["ca"]) + certfingerprint := tlscommontest.GetCertFingerprint(certs["correct"]) + cafingerprint := tlscommontest.GetCertFingerprint(certs["ca"]) + + unknownAuthorityDigest := sha256.Sum256(certs["unknown_authority"].Raw) + unknownAuthoritySha256 := hex.EncodeToString(unknownAuthorityDigest[:]) testCases := []struct { name string rootCAs *x509.CertPool caTrustedFingerprint string peerCerts []*x509.Certificate + expectingWarnings []string expectingError bool expectedRootCAsLen int }{ { name: "RootCA cert matches the fingerprint and is added to cfg.RootCAs", - caTrustedFingerprint: fingerprint, + caTrustedFingerprint: cafingerprint, peerCerts: []*x509.Certificate{certs["correct"], certs["ca"]}, expectedRootCAsLen: 1, }, { - name: "RootCA cert doesn not matche the fingerprint and is not added to cfg.RootCAs", - caTrustedFingerprint: fingerprint, - peerCerts: []*x509.Certificate{certs["correct"], certs["ca"]}, + name: "RootCA cert doesn't match the fingerprint and is not added to cfg.RootCAs", + caTrustedFingerprint: cafingerprint, + peerCerts: []*x509.Certificate{certs["correct"], certs["unknown_authority"]}, + expectingWarnings: []string{"The provided 'ca_trusted_fingerprint': '" + cafingerprint + "' does not match the fingerprint of any Certificate Authority present in the server's certificate chain. Found the following CA fingerprints instead: [" + unknownAuthoritySha256 + "]"}, + expectedRootCAsLen: 0, + }, + { + name: "Peer cert does not include a CA Certificate and is not added to cfg.RootCAs", + caTrustedFingerprint: cafingerprint, + peerCerts: []*x509.Certificate{certs["correct"]}, + expectingWarnings: []string{"The remote server's certificate is presented without its certificate chain. Using 'ca_trusted_fingerprint' requires that the server presents a certificate chain that includes the certificate's issuing certificate authority."}, expectedRootCAsLen: 0, }, + { + name: "fingerprint matches peer cert instead of the CA Certificate and is not added to cfg.RootCAs", + caTrustedFingerprint: certfingerprint, + peerCerts: []*x509.Certificate{certs["correct"]}, + expectingWarnings: []string{ + "Certificate matching 'ca_trusted_fingerprint' found, but it is not a CA certificate. 'ca_trusted_fingerprint' can only be used to trust CA certificates.", + "The remote server's certificate is presented without its certificate chain. Using 'ca_trusted_fingerprint' requires that the server presents a certificate chain that includes the certificate's issuing certificate authority.", + }, + expectedRootCAsLen: 0, + }, { name: "non empty CertPool has the RootCA added", rootCAs: nonEmptyCertPool, - caTrustedFingerprint: fingerprint, + caTrustedFingerprint: cafingerprint, peerCerts: []*x509.Certificate{certs["correct"], certs["ca"]}, expectedRootCAsLen: 3, }, { - name: "invalis HEX encoding", + name: "invalid HEX encoding", caTrustedFingerprint: "INVALID ENCODING", expectedRootCAsLen: 0, expectingError: true, @@ -234,6 +260,10 @@ func TestTrustRootCA(t *testing.T) { RootCAs: tc.rootCAs, CATrustedFingerprint: tc.caTrustedFingerprint, } + + // Capture the logs + _ = logp.DevelopmentSetup(logp.ToObserverOutput()) + err := trustRootCA(&cfg, tc.peerCerts) if tc.expectingError && err == nil { t.Fatal("expecting an error when calling trustRootCA") @@ -243,9 +273,29 @@ func TestTrustRootCA(t *testing.T) { t.Fatalf("did not expect an error calling trustRootCA: %v", err) } - if tc.expectedRootCAsLen != 0 { + if len(tc.expectingWarnings) > 0 { + warnings := logp.ObserverLogs().FilterLevelExact(logp.WarnLevel.ZapLevel()).TakeAll() + if len(warnings) == 0 { + t.Fatal("expecting a warning message") + } + if len(warnings) != len(tc.expectingWarnings) { + t.Fatalf("expecting %d warning messages, got %d", len(tc.expectingWarnings), len(warnings)) + } + + for i, expectedWarning := range tc.expectingWarnings { + if got := warnings[i].Message; got != expectedWarning { + t.Fatalf("expecting warning message to be '%s', got '%s'", expectedWarning, got) + } + } + } + + if tc.expectedRootCAsLen == 0 { + if cfg.RootCAs != nil { + t.Fatal("cfg.RootCAs should be nil") + } + } else { if cfg.RootCAs == nil { - t.Fatal("cfg.RootCAs cannot be nil") + t.Fatal("cfg.RootCAs should not be nil") } // we want to know the number of certificates in the CertPool (RootCAs), as it is not diff --git a/transport/tlscommontest/test_helper.go b/transport/tlscommontest/test_helper.go index 59ef83b4..918e5ecb 100644 --- a/transport/tlscommontest/test_helper.go +++ b/transport/tlscommontest/test_helper.go @@ -88,7 +88,7 @@ func GenTestCerts(t *testing.T) map[string]*x509.Certificate { "unknown_authority": { ca: unknownCA, keyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, - isCA: false, + isCA: true, dnsNames: []string{"localhost"}, // IPV4 and IPV6 ips: []net.IP{{127, 0, 0, 1}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},