Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion integration/proxy/teleterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer
CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) {
return grpc.WithTransportCredentials(insecure.NewCredentials()), nil
},
ClientCache: clientcache.NewNoCache(storage),
CreateClientCacheFunc: func(resolveCluster daemon.ResolveClusterFunc) daemon.ClientCache {
return clientcache.NewNoCache(clientcache.ResolveClusterFunc(resolveCluster))
},
KubeconfigsDir: t.TempDir(),
AgentsDir: t.TempDir(),
})
Expand Down
151 changes: 147 additions & 4 deletions integration/teleterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,18 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils"
api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1"
dbhelpers "github.com/gravitational/teleport/integration/db"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/mocku2f"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/service/servicecfg"
Expand Down Expand Up @@ -105,7 +110,7 @@ func TestTeleterm(t *testing.T) {

t.Run("CreateConnectMyComputerToken", func(t *testing.T) {
t.Parallel()
testCreateConnectMyComputerToken(t, pack)
testCreateConnectMyComputerToken(t, pack, nil /* setupUserMFA */)
})

t.Run("WaitForConnectMyComputerNodeJoin", func(t *testing.T) {
Expand All @@ -123,6 +128,115 @@ func TestTeleterm(t *testing.T) {

testClientCache(t, pack, creds)
})

t.Run("with MFA", func(t *testing.T) {
authServer := pack.Root.Cluster.Process.GetAuthServer()
rpID, _, err := net.SplitHostPort(pack.Root.Cluster.Web)
require.NoError(t, err)

// Enforce MFA
_, err = authServer.UpsertAuthPreference(context.Background(), &types.AuthPreferenceV2{
Spec: types.AuthPreferenceSpecV2{
Type: constants.Local,
SecondFactor: constants.SecondFactorWebauthn,
Webauthn: &types.Webauthn{
RPID: rpID,
},
},
})
require.NoError(t, err)

// Remove MFA enforcement on cleanup.
t.Cleanup(func() {
_, err := authServer.UpsertAuthPreference(context.Background(), &types.AuthPreferenceV2{
Spec: types.AuthPreferenceSpecV2{
Type: constants.Local,
SecondFactor: constants.SecondFactorOff,
},
})
require.NoError(t, err)
})

setupUserMFA := func(t *testing.T, userName string, tshdEventsService *mockTSHDEventsService) client.WebauthnLoginFunc {
// Configure user account with an MFA device.
origin := fmt.Sprintf("https://%s", rpID)
device, err := mocku2f.Create()
require.NoError(t, err)
device.SetPasswordless()

token, err := authServer.CreateResetPasswordToken(context.Background(), auth.CreateUserTokenRequest{
Name: userName,
})
require.NoError(t, err)

tokenID := token.GetName()
res, err := authServer.CreateRegisterChallenge(context.Background(), &proto.CreateRegisterChallengeRequest{
TokenID: tokenID,
DeviceType: proto.DeviceType_DEVICE_TYPE_WEBAUTHN,
DeviceUsage: proto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS,
})
require.NoError(t, err)
cc := wantypes.CredentialCreationFromProto(res.GetWebauthn())

ccr, err := device.SignCredentialCreation(origin, cc)
require.NoError(t, err)
_, err = authServer.ChangeUserAuthentication(context.Background(), &proto.ChangeUserAuthenticationRequest{
TokenID: tokenID,
NewMFARegisterResponse: &proto.MFARegisterResponse{
Response: &proto.MFARegisterResponse_Webauthn{
Webauthn: wantypes.CredentialCreationResponseToProto(ccr),
},
},
})
require.NoError(t, err)

// Prepare a function which simulates key tap.
var webauthLoginCallCount atomic.Uint32
webauthnLogin := func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) {
t.Helper()
updatedWebauthnLoginCallCount := webauthLoginCallCount.Add(1)

// When daemon.mfaPrompt prompts for MFA, it spawns two goroutines. One calls PromptMFA on
// tshdEventService and expects OTP in response (if available). Another calls this function.
// Whichever returns a non-error response first wins.
//
// Since in this test we use Webauthn, this function can return ASAP without giving a chance
// to the other to call PromptMFA. This would cause race conditions, as we might want to
// verify later in the test that PromptMFA has indeed been called.
//
// To ensure that, this function waits until PromptMFA has been called before proceeding.
// This also simulates a flow where the user was notified about the need to tap the key
// through the UI and then taps the key.
assert.EventuallyWithT(t, func(t *assert.CollectT) {
// Each call to webauthnLogin should have an equivalent call to PromptMFA and there should
// be no multiple concurrent calls.
assert.Equal(t, updatedWebauthnLoginCallCount, tshdEventsService.promptMFACallCount.Load(),
"Expected each call to webauthnLogin to have an equivalent call to PromptMFA")
}, 5*time.Second, 50*time.Millisecond)

car, err := device.SignAssertion(origin, assertion)
if err != nil {
return nil, "", err
}

carProto := wantypes.CredentialAssertionResponseToProto(car)

return &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: carProto,
},
}, "", nil
}

return webauthnLogin
}

t.Run("CreateConnectMyComputerToken", func(t *testing.T) {
t.Parallel()

testCreateConnectMyComputerToken(t, pack, setupUserMFA)
})
})
}

func testAddingRootCluster(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) {
Expand Down Expand Up @@ -312,7 +426,6 @@ func testHeadlessWatcher(t *testing.T, pack *dbhelpers.DatabasePack, creds *help
ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_PENDING

// Start the tshd event service and connect the daemon to it.

tshdEventsService, addr := newMockTSHDEventsServiceServer(t)
err = daemonService.UpdateAndDialTshdEventsServerAddress(addr)
require.NoError(t, err)
Expand Down Expand Up @@ -698,7 +811,7 @@ func testCreateConnectMyComputerRole(t *testing.T, pack *dbhelpers.DatabasePack)
}
}

func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack) {
func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack, setupUserMFA setupUserMFAFunc) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

Expand All @@ -721,6 +834,12 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack
_, err = auth.CreateUser(ctx, authServer, userName, userRoles...)
require.NoError(t, err)

tshdEventsService, addr := newMockTSHDEventsServiceServer(t)
var webauthnLogin client.WebauthnLoginFunc
if setupUserMFA != nil {
webauthnLogin = setupUserMFA(t, userName, tshdEventsService)
}

// Log in as the new user.
creds, err := helpers.GenerateUserCreds(helpers.UserCredsRequest{
Process: pack.Root.Cluster.Process,
Expand All @@ -736,6 +855,7 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
Clock: fakeClock,
WebauthnLogin: webauthnLogin,
})
require.NoError(t, err)

Expand All @@ -744,6 +864,9 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack
Storage: storage,
KubeconfigsDir: t.TempDir(),
AgentsDir: t.TempDir(),
CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) {
return grpc.WithTransportCredentials(insecure.NewCredentials()), nil
},
})
require.NoError(t, err)
t.Cleanup(func() {
Expand All @@ -756,6 +879,9 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack
)
require.NoError(t, err)

err = daemonService.UpdateAndDialTshdEventsServerAddress(addr)
require.NoError(t, err)

// Call CreateConnectMyComputerNodeToken.
rootClusterName, _, err := net.SplitHostPort(pack.Root.Cluster.Web)
require.NoError(t, err)
Expand All @@ -774,6 +900,11 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack
require.Equal(t, types.SystemRoles{types.RoleNode}, tokenFromAuthServer.GetRoles())
// ...and is valid for no longer than 5 minutes.
require.LessOrEqual(t, tokenFromAuthServer.Expiry(), requestCreatedAt.Add(5*time.Minute))

if setupUserMFA != nil {
require.Equal(t, uint32(1), tshdEventsService.promptMFACallCount.Load(),
"Unexpected number of calls to TSHDEventsClient.PromptMFA")
}
}

func testWaitForConnectMyComputerNodeJoin(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) {
Expand Down Expand Up @@ -943,12 +1074,16 @@ func mustLogin(t *testing.T, userName string, pack *dbhelpers.DatabasePack, cred
return tc
}

type setupUserMFAFunc func(t *testing.T, userName string, tshdEventsService *mockTSHDEventsService) client.WebauthnLoginFunc

type mockTSHDEventsService struct {
*api.UnimplementedTshdEventsServiceServer
api.UnimplementedTshdEventsServiceServer
sendPendingHeadlessAuthenticationCount atomic.Uint32
promptMFACallCount atomic.Uint32
}

func newMockTSHDEventsServiceServer(t *testing.T) (service *mockTSHDEventsService, addr string) {
t.Helper()
tshdEventsService := &mockTSHDEventsService{}

ls, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -981,3 +1116,11 @@ func (c *mockTSHDEventsService) SendPendingHeadlessAuthentication(context.Contex
c.sendPendingHeadlessAuthenticationCount.Add(1)
return &api.SendPendingHeadlessAuthenticationResponse{}, nil
}

func (c *mockTSHDEventsService) PromptMFA(context.Context, *api.PromptMFARequest) (*api.PromptMFAResponse, error) {
c.promptMFACallCount.Add(1)

// PromptMFAResponse returns the TOTP code, so PromptMFA itself
// needs to be implemented only once we need to test TOTP MFA.
return nil, trace.NotImplemented("mockTSHDEventsService does not implement PromptMFA")
}
17 changes: 11 additions & 6 deletions lib/teleterm/daemon/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,12 @@ type Config struct {
ConnectMyComputerNodeDelete *connectmycomputer.NodeDelete
ConnectMyComputerNodeName *connectmycomputer.NodeName

ClientCache ClientCache
CreateClientCacheFunc func(resolver ResolveClusterFunc) ClientCache
}

// ResolveClusterFunc returns a cluster by URI.
type ResolveClusterFunc func(uri uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error)

// ClientCache stores clients keyed by cluster URI.
type ClientCache interface {
// Get returns a client from the cache if there is one,
Expand Down Expand Up @@ -157,11 +160,13 @@ func (c *Config) CheckAndSetDefaults() error {
c.ConnectMyComputerNodeName = nodeName
}

if c.ClientCache == nil {
c.ClientCache = clientcache.New(clientcache.Config{
Log: c.Log,
Resolver: c.Storage,
})
if c.CreateClientCacheFunc == nil {
c.CreateClientCacheFunc = func(resolver ResolveClusterFunc) ClientCache {
return clientcache.New(clientcache.Config{
Log: c.Log,
ResolveClusterFunc: clientcache.ResolveClusterFunc(resolver),
})
}
}

return nil
Expand Down
21 changes: 16 additions & 5 deletions lib/teleterm/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,24 @@ func New(cfg Config) (*Service, error) {

go connectUsageReporter.Run(closeContext)

return &Service{
service := &Service{
cfg: &cfg,
closeContext: closeContext,
cancel: cancel,
gateways: make(map[string]gateway.Gateway),
usageReporter: connectUsageReporter,
headlessWatcherClosers: make(map[string]context.CancelFunc),
}, nil
}

// TODO(gzdunek): The client cache should be created outside of daemon.New.
// Unfortunately, we have to do it here, because we need to pass
// Daemon.ResolveClusterURI as a cluster resolver.
// Why can't we pass Storage.GetByResourceURI?
// That's because Daemon.ResolveClusterURI sets a custom MFAPromptConstructor that
// shows an MFA prompt in Connect.
// At the level of Storage.ResolveClusterFunc we don't have access to it.
service.clientCache = cfg.CreateClientCacheFunc(service.ResolveClusterURI)
return service, nil
}

// relogin makes the Electron app display a login modal to trigger re-login.
Expand Down Expand Up @@ -802,7 +812,7 @@ func (s *Service) Stop() {

s.StopHeadlessWatchers()

if err := s.cfg.ClientCache.Clear(); err != nil {
if err := s.clientCache.Clear(); err != nil {
s.cfg.Log.WithError(err).Error("Failed to close remote clients")
}

Expand Down Expand Up @@ -1084,14 +1094,14 @@ func (s *Service) findGatewayByTargetURI(targetURI uri.ResourceURI) (gateway.Gat
// GetCachedClient returns a client from the cache if it exists,
// otherwise it dials the remote server.
func (s *Service) GetCachedClient(ctx context.Context, clusterURI uri.ResourceURI) (*client.ProxyClient, error) {
clt, err := s.cfg.ClientCache.Get(ctx, clusterURI)
clt, err := s.clientCache.Get(ctx, clusterURI)
return clt, trace.Wrap(err)
}

// ClearCachedClientsForRoot closes and removes clients from the cache
// for the root cluster and its leaf clusters.
func (s *Service) ClearCachedClientsForRoot(clusterURI uri.ResourceURI) error {
return trace.Wrap(s.cfg.ClientCache.ClearForRoot(clusterURI))
return trace.Wrap(s.clientCache.ClearForRoot(clusterURI))
}

// Service is the daemon service
Expand Down Expand Up @@ -1126,6 +1136,7 @@ type Service struct {
// headlessWatcherClosers holds a map of root cluster URIs to headless watchers.
headlessWatcherClosers map[string]context.CancelFunc
headlessWatcherClosersMu sync.Mutex
clientCache ClientCache
}

type CreateGatewayParams struct {
Expand Down
16 changes: 12 additions & 4 deletions lib/teleterm/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,9 @@ func TestGatewayCRUD(t *testing.T) {
GatewayCreator: mockGatewayCreator,
KubeconfigsDir: t.TempDir(),
AgentsDir: t.TempDir(),
ClientCache: fakeClientCache{},
CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache {
return fakeClientCache{}
},
})
require.NoError(t, err)

Expand Down Expand Up @@ -451,7 +453,9 @@ func TestRetryWithRelogin(t *testing.T) {
},
KubeconfigsDir: t.TempDir(),
AgentsDir: t.TempDir(),
ClientCache: fakeClientCache{},
CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache {
return fakeClientCache{}
},
})
require.NoError(t, err)

Expand Down Expand Up @@ -502,7 +506,9 @@ func TestImportantModalSemaphore(t *testing.T) {
},
KubeconfigsDir: t.TempDir(),
AgentsDir: t.TempDir(),
ClientCache: fakeClientCache{},
CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache {
return fakeClientCache{}
},
})
require.NoError(t, err)

Expand Down Expand Up @@ -651,7 +657,9 @@ func TestGetGatewayCLICommand(t *testing.T) {
},
KubeconfigsDir: t.TempDir(),
AgentsDir: t.TempDir(),
ClientCache: fakeClientCache{},
CreateClientCacheFunc: func(resolver ResolveClusterFunc) ClientCache {
return fakeClientCache{}
},
})
require.NoError(t, err)

Expand Down
Loading