Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions integration/hsm/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -60,9 +61,6 @@ func newTeleportService(t *testing.T, config *servicecfg.Config, name string) *t
serviceChannel: make(chan *service.TeleportProcess, 1),
errorChannel: make(chan error, 1),
}
t.Cleanup(func() {
require.NoError(t, s.close(), "error while closing %s during test cleanup", name)
})
return s
}

Expand Down Expand Up @@ -111,17 +109,43 @@ func (t *teleportService) waitForNewProcess(ctx context.Context) error {
return nil
}

func (t *teleportService) waitForEvent(ctx context.Context, event string) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
waitForEventErr := make(chan error)
go func() {
_, err := t.process.WaitForEvent(ctx, event)
select {
case waitForEventErr <- err:
case <-ctx.Done():
}
}()
select {
case err := <-waitForEventErr:
return trace.Wrap(err)
case err := <-t.errorChannel:
if err != nil {
return trace.Wrap(err, "process unexpectedly exited while waiting for event %s", event)
}
return trace.Errorf("process unexpectedly exited while waiting for event %s", event)
case <-t.serviceChannel:
return trace.Errorf("process unexpectedly reloaded while waiting for event %s", event)
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}

func (t *teleportService) waitForReady(ctx context.Context) error {
t.log.Debugf("%s gen %d: waiting for TeleportReadyEvent", t.name, t.processGeneration)
if _, err := t.process.WaitForEvent(ctx, service.TeleportReadyEvent); err != nil {
return trace.Wrap(err, "timed out waiting for %s gen %d to be ready", t.name, t.processGeneration)
if err := t.waitForEvent(ctx, service.TeleportReadyEvent); err != nil {
return trace.Wrap(err, "waiting for %s gen %d to be ready", t.name, t.processGeneration)
}
t.log.Debugf("%s gen %d: got TeleportReadyEvent", t.name, t.processGeneration)
// If this is an Auth server, also wait for AuthIdentityEvent so that we
// can safely read the admin credentials and create a test client.
if t.process.GetAuthServer() != nil {
t.log.Debugf("%s gen %d: waiting for AuthIdentityEvent", t.name, t.processGeneration)
if _, err := t.process.WaitForEvent(ctx, service.AuthIdentityEvent); err != nil {
if err := t.waitForEvent(ctx, service.AuthIdentityEvent); err != nil {
return trace.Wrap(err, "%s gen %d: timed out waiting AuthIdentityEvent", t.name, t.processGeneration)
}
t.log.Debugf("%s gen %d: got AuthIdentityEvent", t.name, t.processGeneration)
Expand Down Expand Up @@ -170,7 +194,7 @@ func (t *teleportService) waitForLocalAdditionalKeys(ctx context.Context) error
if err != nil {
return trace.Wrap(err)
}
if usableKeysResult.CAHasUsableKeys {
if usableKeysResult.CAHasPreferredKeyType {
break
}
}
Expand All @@ -180,7 +204,7 @@ func (t *teleportService) waitForLocalAdditionalKeys(ctx context.Context) error

func (t *teleportService) waitForPhaseChange(ctx context.Context) error {
t.log.Debugf("%s gen %d: waiting for phase change", t.name, t.processGeneration)
if _, err := t.process.WaitForEvent(ctx, service.TeleportPhaseChangeEvent); err != nil {
if err := t.waitForEvent(ctx, service.TeleportPhaseChangeEvent); err != nil {
return trace.Wrap(err, "%s gen %d: timed out waiting for phase change", t.name, t.processGeneration)
}
t.log.Debugf("%s gen %d: changed phase", t.name, t.processGeneration)
Expand Down Expand Up @@ -237,6 +261,7 @@ func newAuthConfig(t *testing.T, log utils.Logger) *servicecfg.Config {
config.InstanceMetadataClient = cloud.NewDisabledIMDSClient()
config.MaxRetryPeriod = 25 * time.Millisecond
config.PollingPeriod = 2 * time.Second
config.Clock = fastClock(t)

config.Auth.Enabled = true
config.Auth.NoAudit = true
Expand Down Expand Up @@ -268,6 +293,7 @@ func newAuthConfig(t *testing.T, log utils.Logger) *servicecfg.Config {

func newProxyConfig(t *testing.T, authAddr utils.NetAddr, log utils.Logger) *servicecfg.Config {
config := servicecfg.MakeDefaultConfig()
config.Version = defaults.TeleportConfigVersionV3
config.DataDir = t.TempDir()
config.CachePolicy.Enabled = true
config.Auth.Enabled = false
Expand All @@ -278,6 +304,7 @@ func newProxyConfig(t *testing.T, authAddr utils.NetAddr, log utils.Logger) *ser
config.InstanceMetadataClient = cloud.NewDisabledIMDSClient()
config.MaxRetryPeriod = 25 * time.Millisecond
config.PollingPeriod = 2 * time.Second
config.Clock = fastClock(t)

config.Proxy.Enabled = true
config.Proxy.DisableWebInterface = true
Expand All @@ -288,3 +315,24 @@ func newProxyConfig(t *testing.T, authAddr utils.NetAddr, log utils.Logger) *ser

return config
}

// fastClock returns a clock that runs at ~20x realtime.
func fastClock(t *testing.T) clockwork.FakeClock {
// Start in the past to avoid cert not yet valid errors
clock := clockwork.NewFakeClockAt(time.Now().Add(-12 * time.Hour))
done := make(chan struct{})
t.Cleanup(func() { close(done) })
go func() {
for {
select {
case <-done:
return
default:
}
clock.BlockUntil(1)
clock.Advance(time.Second)
time.Sleep(50 * time.Millisecond)
}
}()
return clock
}
Loading