Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion integrations/event-handler/fake_fluentd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func NewFakeFluentd(t *testing.T) *FakeFluentd {

// writeCerts generates and writes temporary mTLS keys
func (f *FakeFluentd) writeCerts() error {
g, err := GenerateMTLSCerts([]string{"localhost"}, []string{}, time.Hour, 1024)
g, err := GenerateMTLSCerts([]string{"localhost"}, []string{"127.0.0.1"}, time.Hour, 1024)
if err != nil {
return trace.Wrap(err)
}
Expand Down
92 changes: 22 additions & 70 deletions integrations/event-handler/mtls_certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package main

import (
"cmp"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
Expand All @@ -29,7 +30,7 @@ import (
"github.com/gravitational/trace"
)

// MTLSCerts is the result for mTLS struct generator
// MTLSCerts is the result for mTLS certificate generator.
type MTLSCerts struct {
// caCert is a CA certificate struct used to generate mTLS CA cert and private key
caCert x509.Certificate
Expand All @@ -47,26 +48,18 @@ type MTLSCerts struct {

// keyPair is the pair of certificate and private key
type keyPair struct {
// PrivateKey represents certificate private key
PrivateKey *rsa.PrivateKey
// Certificate represents certificate
PrivateKey *rsa.PrivateKey
Certificate []byte
}

// GenerateMTLSCerts creates new MTLS certificate generator
// GenerateMTLSCerts generates server and client TLS certificates.
func GenerateMTLSCerts(dnsNames []string, ips []string, ttl time.Duration, length int) (*MTLSCerts, error) {
notBefore := time.Now()
notAfter := notBefore.Add(ttl)

caDistinguishedName := pkix.Name{
CommonName: "CA",
}
serverDistinguishedName := pkix.Name{
CommonName: "Server",
}
clientDistinguishedName := pkix.Name{
CommonName: "Client",
}
caDistinguishedName := pkix.Name{CommonName: "CA"}
serverDistinguishedName := pkix.Name{CommonName: "Server"}
clientDistinguishedName := pkix.Name{CommonName: "Client"}

c := &MTLSCerts{
caCert: x509.Certificate{
Expand Down Expand Up @@ -98,66 +91,33 @@ func GenerateMTLSCerts(dnsNames []string, ips []string, ttl time.Duration, lengt
}

// Generate and assign serial numbers
sn, err := rand.Int(rand.Reader, maxBigInt)
if err != nil {
return nil, trace.Wrap(err)
}

c.caCert.SerialNumber = sn

sn, err = rand.Int(rand.Reader, maxBigInt)
if err != nil {
return nil, trace.Wrap(err)
}

c.clientCert.SerialNumber = sn

sn, err = rand.Int(rand.Reader, maxBigInt)
if err != nil {
return nil, trace.Wrap(err)
for _, cert := range []*x509.Certificate{&c.caCert, &c.clientCert, &c.serverCert} {
sn, err := rand.Int(rand.Reader, maxBigInt)
if err != nil {
return nil, trace.Wrap(err)
}
cert.SerialNumber = sn
}

c.serverCert.SerialNumber = sn

// Append SANs and IPs to Server and Client certs
if err := c.appendSANs(&c.serverCert, dnsNames, ips); err != nil {
return nil, trace.Wrap(err)
}
if err := c.appendSANs(&c.clientCert, dnsNames, ips); err != nil {
return nil, trace.Wrap(err)
}
c.appendSANs(&c.serverCert, dnsNames, ips)
c.appendSANs(&c.clientCert, dnsNames, ips)

// Run the generator
err = c.generate(length)
if err != nil {
return c, err
if err := c.generate(length); err != nil {
return c, trace.Wrap(err)
}

return c, nil
}

// appendSANs appends subjectAltName hosts and IPs
func (c MTLSCerts) appendSANs(cert *x509.Certificate, dnsNames []string, ips []string) error {
func (MTLSCerts) appendSANs(cert *x509.Certificate, dnsNames []string, ips []string) {
cert.DNSNames = dnsNames

if len(ips) == 0 {
for _, name := range dnsNames {
ips, err := net.LookupIP(name)
if err != nil {
return trace.Wrap(err)
}

if ips != nil {
cert.IPAddresses = append(cert.IPAddresses, ips...)
}
}
} else {
for _, ip := range ips {
cert.IPAddresses = append(cert.IPAddresses, net.ParseIP(ip))
}
for _, ip := range ips {
cert.IPAddresses = append(cert.IPAddresses, net.ParseIP(ip))
}

return nil
}

// Generate generates CA, server and client certificates
Expand Down Expand Up @@ -193,16 +153,8 @@ func (c *MTLSCerts) genCertAndPK(length int, cert *x509.Certificate, parent *x50
}

// Check if it's self-signed, assign signer and parent to self
s := signer
p := parent

if s == nil {
s = pk
}

if p == nil {
p = cert
}
s := cmp.Or(signer, pk)
p := cmp.Or(parent, cert)

// Generate and sign cert
certBytes, err := x509.CreateCertificate(rand.Reader, cert, p, &pk.PublicKey, s)
Expand Down
10 changes: 6 additions & 4 deletions integrations/event-handler/mtls_certs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"crypto/x509"
"encoding/pem"
"io"
"net"
"os"
"path/filepath"
"testing"
Expand All @@ -34,11 +35,11 @@ func TestGenerateClientCertFile(t *testing.T) {
kp := "client.key"

// Generate certs in memory
certs, err := GenerateMTLSCerts([]string{"localhost"}, nil, time.Second, 1024)
certs, err := GenerateMTLSCerts([]string{"localhost"}, []string{"127.0.0.1"}, time.Second, 1024)
require.NoError(t, err)
require.NotNil(t, certs.caCert.Issuer)
require.NotNil(t, certs.clientCert.Issuer)
require.NotNil(t, certs.serverCert.Issuer)
require.NotZero(t, certs.caCert.Issuer)
require.NotZero(t, certs.clientCert.Issuer)
require.NotZero(t, certs.serverCert.Issuer)
// don't be self-signed
require.NotEqual(t, certs.serverCert.Issuer, certs.serverCert.Subject)
require.NotEqual(t, certs.clientCert.Issuer, certs.clientCert.Subject)
Expand All @@ -58,6 +59,7 @@ func TestGenerateClientCertFile(t *testing.T) {
require.NotEmpty(t, certs.clientCert.DNSNames)
// server leaf cert should have SAN DNS:localhost
require.Equal(t, "localhost", certs.serverCert.DNSNames[0])
require.Equal(t, net.ParseIP("127.0.0.1"), certs.serverCert.IPAddresses[0])

// Write the cert to the tempdir
err = certs.ClientCert.WriteFile(filepath.Join(td, cp), filepath.Join(td, kp), ".")
Expand Down
Loading