Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pkg/controller/common/certificates/ca_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
package certificates

import (
"context"
"fmt"
"time"

pkgerrors "github.com/pkg/errors"
corev1 "k8s.io/api/core/v1"

ulog "github.com/elastic/cloud-on-k8s/v3/pkg/utils/log"
)

// ParseCustomCASecret checks that mandatory fields are present and returns a CA struct.
Expand All @@ -32,6 +36,24 @@ func ParseCustomCASecret(s corev1.Secret) (*CA, error) {
return parseCAFromSecret(s, keyFileName, crtFileName)
}

// ValidateCustomCA validates the time-bounds of the given CA certificate and checks that the public key matches the
// private one. It returns nil if the CA is valid and an error otherwise.
func ValidateCustomCA(ctx context.Context, ca *CA) error {
now := time.Now()
log := ulog.FromContext(ctx)
switch {
case now.Before(ca.Cert.NotBefore):
return fmt.Errorf("the CA certificate is not yet valid")
case now.After(ca.Cert.NotAfter):
return fmt.Errorf("the CA certificate has expired")
case !PrivateMatchesPublicKey(ctx, ca.Cert.PublicKey, ca.PrivateKey):
return fmt.Errorf("the private key does not match the public one")
case now.After(ca.Cert.NotAfter.Add(-DefaultRotateBefore)):
log.Info("CA cert will expire soon", "subject", ca.Cert.Subject, "expiration", ca.Cert.NotAfter)
}
return nil
}

// parseCAFromSecret internal helper func to retrieve and parse a CA stored at the given keys in a Secret.
func parseCAFromSecret(s corev1.Secret, keyFileName string, crtFileName string) (*CA, error) {
// Validate private key
Expand Down
64 changes: 64 additions & 0 deletions pkg/controller/common/certificates/ca_secret_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
package certificates

import (
cryptorand "crypto/rand"
"crypto/rsa"
"testing"
"time"

"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
)

Expand Down Expand Up @@ -132,3 +136,63 @@ func TestParseCustomCASecret(t *testing.T) {
})
}
}

func TestValidateCustomCA(t *testing.T) {
tests := []struct {
name string
ca func() *CA
wantErr bool
}{
{
name: "valid ca",
ca: func() *CA {
testCa, err := NewSelfSignedCA(CABuilderOptions{})
require.NoError(t, err)
return testCa
},
wantErr: false,
},
{
name: "expired ca",
ca: func() *CA {
testCa, err := NewSelfSignedCA(CABuilderOptions{})
require.NoError(t, err)
testCa.Cert.NotAfter = time.Now().Add(-1 * time.Hour)
return testCa
},
wantErr: true,
},
{
name: "not valid yet ca",
ca: func() *CA {
testCa, err := NewSelfSignedCA(CABuilderOptions{})
require.NoError(t, err)
testCa.Cert.NotBefore = time.Now().Add(1 * time.Hour)
return testCa
},
wantErr: true,
},
{
name: "cert public key & private key mismatch",
ca: func() *CA {
testCa, err := NewSelfSignedCA(CABuilderOptions{})
require.NoError(t, err)
privateKey2, err := rsa.GenerateKey(cryptorand.Reader, 2048)
require.NoError(t, err)
testCa.PrivateKey = privateKey2
return testCa
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateCustomCA(t.Context(), tt.ca())
Comment thread
pkoutsovasilis marked this conversation as resolved.
if tt.wantErr {
require.Error(t, err, "expected error but got none")
} else {
require.NoError(t, err, "expected no err")
}
})
}
}
9 changes: 9 additions & 0 deletions pkg/controller/elasticsearch/certificates/transport/ca.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package transport

import (
"context"
"fmt"

corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -84,6 +85,14 @@ func ReconcileOrRetrieveCA(
return nil, err
}

if err := certificates.ValidateCustomCA(ctx, ca); err != nil {
// Surface validation errors to the user via an event
validationErr := fmt.Errorf("error validating custom CA certificate in %s/%s: %w",
customCASecret.GetNamespace(), customCASecret.GetName(), err)
driver.Recorder().Eventf(&es, corev1.EventTypeWarning, events.EventReasonValidation, validationErr.Error())
return nil, validationErr
}

// Garbage collect the self-signed CA secret which might be left over from an earlier revision on a best effort basis.
err = driver.K8sClient().Delete(ctx, &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Expand Down