Skip to content

Commit

Permalink
Fix race condition in pkg/apiserver/certificate unit tests (#6004)
Browse files Browse the repository at this point in the history
There was some "interference" between TestSelfSignedCertProviderRotate
and TestSelfSignedCertProviderRun. The root cause is that the
certutil.GenerateSelfSignedCertKey does not support a custom clock
implementation and always calls time.Now() to determine the current
time. It then adds a year to the current time to set the expiration time
of the certificate. This means that when rotateSelfSignedCertificate()
is called as part of TestSelfSignedCertProviderRotate, the new
certificate is already expired, and rotateSelfSignedCertificate() will
be called immediately a second time. By this time however,
TestSelfSignedCertProviderRotate has already exited, and we are already
running the next test, TestSelfSignedCertProviderRun. This creates a
race condition because the next test will overwrite
generateSelfSignedCertKey with a mock version, right as it is called by
the second call to rotateSelfSignedCertificate() from the previous
test's provider.

To avoid this race condition, we make generateSelfSignedCertKey a member
of selfSignedCertProvider.

Fixes #5977

Signed-off-by: Antonin Bas <[email protected]>
Co-authored-by: Quan Tian <[email protected]>
  • Loading branch information
antoninbas and tnqn committed Feb 23, 2024
1 parent 5789889 commit 56f2fb5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 29 deletions.
47 changes: 35 additions & 12 deletions pkg/apiserver/certificate/selfsignedcert_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ import (
"antrea.io/antrea/pkg/util/env"
)

var (
loopbackAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback}
// Declared for unit testing.
generateSelfSignedCertKey = certutil.GenerateSelfSignedCertKey
)
var loopbackAddresses = []net.IP{net.ParseIP("127.0.0.1"), net.IPv6loopback}

// generateSelfSignedCertKeyFn represents a function which can create a self-signed certificate and
// key for the given host.
type generateSelfSignedCertKeyFn func(host string, alternateIPs []net.IP, alternateDNS []string) ([]byte, []byte, error)

type selfSignedCertProvider struct {
client kubernetes.Interface
Expand All @@ -69,23 +69,46 @@ type selfSignedCertProvider struct {
cert []byte
key []byte
verifyOptions *x509.VerifyOptions

// generateSelfSignedCertKey is the function used to generate self-signed certificates and keys.
// We use a struct member for unit testing.
generateSelfSignedCertKey generateSelfSignedCertKeyFn
}

var _ dynamiccertificates.CAContentProvider = &selfSignedCertProvider{}
var _ dynamiccertificates.ControllerRunner = &selfSignedCertProvider{}

func newSelfSignedCertProvider(client kubernetes.Interface, secureServing *options.SecureServingOptionsWithLoopback, caConfig *CAConfig) (*selfSignedCertProvider, error) {
type providerOption func(p *selfSignedCertProvider)

func withGenerateSelfSignedCertKeyFn(fn generateSelfSignedCertKeyFn) providerOption {
return func(p *selfSignedCertProvider) {
p.generateSelfSignedCertKey = fn
}
}

func withClock(clock clockutils.Clock) providerOption {
return func(p *selfSignedCertProvider) {
p.clock = clock
}
}

func newSelfSignedCertProvider(client kubernetes.Interface, secureServing *options.SecureServingOptionsWithLoopback, caConfig *CAConfig, options ...providerOption) (*selfSignedCertProvider, error) {
// Set the CertKey and CertDirectory to generate the certificate files.
secureServing.ServerCert.CertDirectory = caConfig.SelfSignedCertDir
secureServing.ServerCert.CertKey.CertFile = filepath.Join(caConfig.SelfSignedCertDir, caConfig.PairName+".crt")
secureServing.ServerCert.CertKey.KeyFile = filepath.Join(caConfig.SelfSignedCertDir, caConfig.PairName+".key")

provider := &selfSignedCertProvider{
client: client,
secureServing: secureServing,
caConfig: caConfig,
queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "selfSignedCertProvider"),
clock: clockutils.RealClock{},
client: client,
secureServing: secureServing,
caConfig: caConfig,
queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "selfSignedCertProvider"),
clock: clockutils.RealClock{},
generateSelfSignedCertKey: certutil.GenerateSelfSignedCertKey,
}

for _, option := range options {
option(provider)
}

if caConfig.TLSSecretName != "" {
Expand Down Expand Up @@ -233,7 +256,7 @@ func (p *selfSignedCertProvider) rotateSelfSignedCertificate() error {
}
if p.shouldRotateCertificate(cert) {
klog.InfoS("Generating self-signed cert")
if cert, key, err = generateSelfSignedCertKey(p.caConfig.ServiceName, loopbackAddresses, GetAntreaServerNames(p.caConfig.ServiceName)); err != nil {
if cert, key, err = p.generateSelfSignedCertKey(p.caConfig.ServiceName, loopbackAddresses, GetAntreaServerNames(p.caConfig.ServiceName)); err != nil {
return fmt.Errorf("unable to generate self-signed cert: %v", err)
}
// If Secret is specified, we should save the new certificate and key to it.
Expand Down
26 changes: 9 additions & 17 deletions pkg/apiserver/certificate/selfsignedcert_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ var (
testOneYearCert3, testOneYearKey3, _ = certutil.GenerateSelfSignedCertKeyWithFixtures("localhost", loopbackAddresses, nil, "")
)

func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset, tlsSecretName string, minValidDuration time.Duration) *selfSignedCertProvider {
func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset, tlsSecretName string, minValidDuration time.Duration, options ...providerOption) *selfSignedCertProvider {
secureServing := genericoptions.NewSecureServingOptions().WithLoopback()
caConfig := &CAConfig{
TLSSecretName: tlsSecretName,
Expand All @@ -57,7 +57,7 @@ func newTestSelfSignedCertProvider(t *testing.T, client *fakeclientset.Clientset
ServiceName: testServiceName,
PairName: testPairName,
}
p, err := newSelfSignedCertProvider(client, secureServing, caConfig)
p, err := newSelfSignedCertProvider(client, secureServing, caConfig, options...)
require.NoError(t, err)
return p
}
Expand Down Expand Up @@ -107,8 +107,7 @@ func TestSelfSignedCertProviderRotate(t *testing.T) {
defer cancel()
client := fakeclientset.NewSimpleClientset()
fakeClock := clocktesting.NewFakeClock(time.Now())
p := newTestSelfSignedCertProvider(t, client, testSecretName, time.Hour*24*90)
p.clock = fakeClock
p := newTestSelfSignedCertProvider(t, client, testSecretName, time.Hour*24*90, withClock(fakeClock))
certInFile, err := os.ReadFile(p.secureServing.ServerCert.CertKey.CertFile)
require.NoError(t, err)
keyInFile, _ := os.ReadFile(p.secureServing.ServerCert.CertKey.KeyFile)
Expand Down Expand Up @@ -161,7 +160,7 @@ func TestSelfSignedCertProviderRotate(t *testing.T) {
assert.NotEqual(c, map[string][]byte{
corev1.TLSCertKey: testOneYearCert,
corev1.TLSPrivateKeyKey: testOneYearKey,
}, gotSecret.Data, "Secret doesn't match")
}, gotSecret.Data, "Secret should not match")
}, 2*time.Second, 50*time.Millisecond)
}

Expand Down Expand Up @@ -264,15 +263,18 @@ func TestSelfSignedCertProviderRun(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer mockGenerateSelfSignedCertKey(testOneYearCert2, testOneYearKey2)()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var objs []runtime.Object
if tt.existingSecret != nil {
objs = append(objs, tt.existingSecret)
}
client := fakeclientset.NewSimpleClientset(objs...)
p := newTestSelfSignedCertProvider(t, client, tt.tlsSecretName, tt.minValidDuration)
// mock the generateSelfSignedCertKey fuction
generateSelfSignedCertKey := func(_ string, _ []net.IP, _ []string) ([]byte, []byte, error) {
return testOneYearCert2, testOneYearKey2, nil
}
p := newTestSelfSignedCertProvider(t, client, tt.tlsSecretName, tt.minValidDuration, withGenerateSelfSignedCertKeyFn(generateSelfSignedCertKey))
go p.Run(ctx, 1)
if tt.updatedSecret != nil {
client.CoreV1().Secrets(tt.updatedSecret.Namespace).Update(ctx, tt.updatedSecret, metav1.UpdateOptions{})
Expand All @@ -291,13 +293,3 @@ func TestSelfSignedCertProviderRun(t *testing.T) {
})
}
}

func mockGenerateSelfSignedCertKey(cert, key []byte) func() {
originalFn := generateSelfSignedCertKey
generateSelfSignedCertKey = func(_ string, _ []net.IP, _ []string) ([]byte, []byte, error) {
return cert, key, nil
}
return func() {
generateSelfSignedCertKey = originalFn
}
}

0 comments on commit 56f2fb5

Please sign in to comment.