diff --git a/integrations/event-handler/fake_fluentd_test.go b/integrations/event-handler/fake_fluentd_test.go index 72a363468ba15..64704847f2f35 100644 --- a/integrations/event-handler/fake_fluentd_test.go +++ b/integrations/event-handler/fake_fluentd_test.go @@ -66,7 +66,7 @@ func NewFakeFluentd(t *testing.T) *FakeFluentd { // writeCerts generates and writes temporary mTLS keys func (f *FakeFluentd) writeCerts() error { - g, err := GenerateMTLSCerts([]string{"localhost"}, []string{}, time.Hour, 1024) + g, err := GenerateMTLSCerts([]string{"localhost"}, []string{"127.0.0.1"}, time.Hour, 1024) if err != nil { return trace.Wrap(err) } diff --git a/integrations/event-handler/mtls_certs.go b/integrations/event-handler/mtls_certs.go index 553acf2c0ef01..b22c3fccc6731 100644 --- a/integrations/event-handler/mtls_certs.go +++ b/integrations/event-handler/mtls_certs.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "cmp" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -29,7 +30,7 @@ import ( "github.com/gravitational/trace" ) -// MTLSCerts is the result for mTLS struct generator +// MTLSCerts is the result for mTLS certificate generator. type MTLSCerts struct { // caCert is a CA certificate struct used to generate mTLS CA cert and private key caCert x509.Certificate @@ -47,26 +48,18 @@ type MTLSCerts struct { // keyPair is the pair of certificate and private key type keyPair struct { - // PrivateKey represents certificate private key - PrivateKey *rsa.PrivateKey - // Certificate represents certificate + PrivateKey *rsa.PrivateKey Certificate []byte } -// GenerateMTLSCerts creates new MTLS certificate generator +// GenerateMTLSCerts generates server and client TLS certificates. func GenerateMTLSCerts(dnsNames []string, ips []string, ttl time.Duration, length int) (*MTLSCerts, error) { notBefore := time.Now() notAfter := notBefore.Add(ttl) - caDistinguishedName := pkix.Name{ - CommonName: "CA", - } - serverDistinguishedName := pkix.Name{ - CommonName: "Server", - } - clientDistinguishedName := pkix.Name{ - CommonName: "Client", - } + caDistinguishedName := pkix.Name{CommonName: "CA"} + serverDistinguishedName := pkix.Name{CommonName: "Server"} + clientDistinguishedName := pkix.Name{CommonName: "Client"} c := &MTLSCerts{ caCert: x509.Certificate{ @@ -98,66 +91,33 @@ func GenerateMTLSCerts(dnsNames []string, ips []string, ttl time.Duration, lengt } // Generate and assign serial numbers - sn, err := rand.Int(rand.Reader, maxBigInt) - if err != nil { - return nil, trace.Wrap(err) - } - - c.caCert.SerialNumber = sn - - sn, err = rand.Int(rand.Reader, maxBigInt) - if err != nil { - return nil, trace.Wrap(err) - } - - c.clientCert.SerialNumber = sn - - sn, err = rand.Int(rand.Reader, maxBigInt) - if err != nil { - return nil, trace.Wrap(err) + for _, cert := range []*x509.Certificate{&c.caCert, &c.clientCert, &c.serverCert} { + sn, err := rand.Int(rand.Reader, maxBigInt) + if err != nil { + return nil, trace.Wrap(err) + } + cert.SerialNumber = sn } - c.serverCert.SerialNumber = sn - // Append SANs and IPs to Server and Client certs - if err := c.appendSANs(&c.serverCert, dnsNames, ips); err != nil { - return nil, trace.Wrap(err) - } - if err := c.appendSANs(&c.clientCert, dnsNames, ips); err != nil { - return nil, trace.Wrap(err) - } + c.appendSANs(&c.serverCert, dnsNames, ips) + c.appendSANs(&c.clientCert, dnsNames, ips) // Run the generator - err = c.generate(length) - if err != nil { - return c, err + if err := c.generate(length); err != nil { + return c, trace.Wrap(err) } return c, nil } // appendSANs appends subjectAltName hosts and IPs -func (c MTLSCerts) appendSANs(cert *x509.Certificate, dnsNames []string, ips []string) error { +func (MTLSCerts) appendSANs(cert *x509.Certificate, dnsNames []string, ips []string) { cert.DNSNames = dnsNames - if len(ips) == 0 { - for _, name := range dnsNames { - ips, err := net.LookupIP(name) - if err != nil { - return trace.Wrap(err) - } - - if ips != nil { - cert.IPAddresses = append(cert.IPAddresses, ips...) - } - } - } else { - for _, ip := range ips { - cert.IPAddresses = append(cert.IPAddresses, net.ParseIP(ip)) - } + for _, ip := range ips { + cert.IPAddresses = append(cert.IPAddresses, net.ParseIP(ip)) } - - return nil } // Generate generates CA, server and client certificates @@ -193,16 +153,8 @@ func (c *MTLSCerts) genCertAndPK(length int, cert *x509.Certificate, parent *x50 } // Check if it's self-signed, assign signer and parent to self - s := signer - p := parent - - if s == nil { - s = pk - } - - if p == nil { - p = cert - } + s := cmp.Or(signer, pk) + p := cmp.Or(parent, cert) // Generate and sign cert certBytes, err := x509.CreateCertificate(rand.Reader, cert, p, &pk.PublicKey, s) diff --git a/integrations/event-handler/mtls_certs_test.go b/integrations/event-handler/mtls_certs_test.go index cfe725fd0c70f..693c0edfddc50 100644 --- a/integrations/event-handler/mtls_certs_test.go +++ b/integrations/event-handler/mtls_certs_test.go @@ -20,6 +20,7 @@ import ( "crypto/x509" "encoding/pem" "io" + "net" "os" "path/filepath" "testing" @@ -34,11 +35,11 @@ func TestGenerateClientCertFile(t *testing.T) { kp := "client.key" // Generate certs in memory - certs, err := GenerateMTLSCerts([]string{"localhost"}, nil, time.Second, 1024) + certs, err := GenerateMTLSCerts([]string{"localhost"}, []string{"127.0.0.1"}, time.Second, 1024) require.NoError(t, err) - require.NotNil(t, certs.caCert.Issuer) - require.NotNil(t, certs.clientCert.Issuer) - require.NotNil(t, certs.serverCert.Issuer) + require.NotZero(t, certs.caCert.Issuer) + require.NotZero(t, certs.clientCert.Issuer) + require.NotZero(t, certs.serverCert.Issuer) // don't be self-signed require.NotEqual(t, certs.serverCert.Issuer, certs.serverCert.Subject) require.NotEqual(t, certs.clientCert.Issuer, certs.clientCert.Subject) @@ -58,6 +59,7 @@ func TestGenerateClientCertFile(t *testing.T) { require.NotEmpty(t, certs.clientCert.DNSNames) // server leaf cert should have SAN DNS:localhost require.Equal(t, "localhost", certs.serverCert.DNSNames[0]) + require.Equal(t, net.ParseIP("127.0.0.1"), certs.serverCert.IPAddresses[0]) // Write the cert to the tempdir err = certs.ClientCert.WriteFile(filepath.Join(td, cp), filepath.Join(td, kp), ".")