diff --git a/api/client/client.go b/api/client/client.go index 18b65fddd5918..23628066534c5 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -71,6 +71,7 @@ import ( userpreferencespb "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1" "github.com/gravitational/teleport/api/internalutils/stream" "github.com/gravitational/teleport/api/metadata" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" @@ -630,9 +631,9 @@ type Config struct { // PROXYHeaderGetter returns signed PROXY header that is sent to allow Proxy to propagate client's real IP to the // auth server from the Proxy's web server, when we create user's client for the web session. PROXYHeaderGetter PROXYHeaderGetter - // PromptAdminRequestMFA is used to prompt the user for MFA on admin requests when needed. + // MFAPromptConstructor is used to create MFA prompts when needed. // If nil, the client will not prompt for MFA. - PromptAdminRequestMFA func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) + MFAPromptConstructor mfa.PromptConstructor } // CheckAndSetDefaults checks and sets default config values. @@ -694,6 +695,11 @@ func (c *Client) GetConnection() *grpc.ClientConn { return c.conn } +// SetMFAPromptConstructor sets the MFA prompt constructor for this client. +func (c *Client) SetMFAPromptConstructor(pc mfa.PromptConstructor) { + c.c.MFAPromptConstructor = pc +} + // Close closes the Client connection to the auth server. func (c *Client) Close() error { if c.setClosed() && c.conn != nil { diff --git a/api/client/mfa.go b/api/client/mfa.go index 2f2f39cc894a3..cc81ca66405d6 100644 --- a/api/client/mfa.go +++ b/api/client/mfa.go @@ -22,12 +22,13 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/mfa" ) // performMFACeremony retrieves an MFA challenge from the server, prompts the // user to answer the challenge, and returns the resulting MFA response. -func (c *Client) performMFACeremony(ctx context.Context) (*proto.MFAAuthenticateResponse, error) { - if c.c.PromptAdminRequestMFA == nil { +func (c *Client) performMFACeremony(ctx context.Context, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { + if c.c.MFAPromptConstructor == nil { return nil, trace.BadParameter("missing PromptAdminRequestMFA field, client cannot perform MFA ceremony") } @@ -38,7 +39,7 @@ func (c *Client) performMFACeremony(ctx context.Context) (*proto.MFAAuthenticate return nil, trace.Wrap(err) } - resp, err := c.c.PromptAdminRequestMFA(ctx, chal) + resp, err := c.c.MFAPromptConstructor(promptOpts...).Run(ctx, chal) if err != nil { return nil, trace.Wrap(err) } diff --git a/api/client/mfa_test.go b/api/client/mfa_test.go index d97b2d6f2da98..22ae22d69cd40 100644 --- a/api/client/mfa_test.go +++ b/api/client/mfa_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/mfa" ) const ( @@ -62,11 +63,8 @@ func TestPerformMFACeremony(t *testing.T) { } cfg := server.clientCfg() - cfg.PromptAdminRequestMFA = func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - if chal.TOTP != nil { - return mfaTestResp, nil - } - return nil, trace.BadParameter("expected TOTP challenge") + cfg.MFAPromptConstructor = func(opts ...mfa.PromptOpt) mfa.Prompt { + return &fakeMFAPrompt{mfaTestResp} } clt, err := New(ctx, cfg) @@ -76,3 +74,14 @@ func TestPerformMFACeremony(t *testing.T) { require.NoError(t, err) require.Equal(t, mfaTestResp.Response, resp.Response) } + +type fakeMFAPrompt struct { + totpResp *proto.MFAAuthenticateResponse +} + +func (p *fakeMFAPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + if chal.TOTP != nil { + return p.totpResp, nil + } + return nil, trace.BadParameter("expected TOTP challenge") +} diff --git a/api/mfa/prompt.go b/api/mfa/prompt.go new file mode 100644 index 0000000000000..7b8a0121d7dd9 --- /dev/null +++ b/api/mfa/prompt.go @@ -0,0 +1,94 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mfa + +import ( + "context" + "fmt" + + "github.com/gravitational/teleport/api/client/proto" +) + +// Prompt is an MFA prompt. +type Prompt interface { + // Run prompts the user to complete an MFA authentication challenge. + Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) +} + +// PromptFunc is a function wrapper that implements the Prompt interface. +type PromptFunc func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) + +// Run prompts the user to complete an MFA authentication challenge. +func (f PromptFunc) Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + return f(ctx, chal) +} + +// PromptConstructor is a function that creates a new MFA prompt. +type PromptConstructor func(...PromptOpt) Prompt + +// PromptConfig contains common mfa prompt config options. +type PromptConfig struct { + // PromptReason is an optional message to share with the user before an MFA Prompt. + // It is intended to provide context about why the user is being prompted where it may + // not be obvious, such as for admin actions or per-session MFA. + PromptReason string + // DeviceType is an optional device description to emphasize during the prompt. + DeviceType DeviceDescriptor + // Quiet suppresses users prompts. + Quiet bool +} + +// DeviceDescriptor is a descriptor for a device, such as "registered". +type DeviceDescriptor string + +// DeviceDescriptorRegistered is a registered device. +const DeviceDescriptorRegistered = "registered" + +// PromptOpt applies configuration options to a prompt. +type PromptOpt func(*PromptConfig) + +// WithQuiet sets the prompt's Quiet field. +func WithQuiet() PromptOpt { + return func(cfg *PromptConfig) { + cfg.Quiet = true + } +} + +// WithPromptReason sets the prompt's PromptReason field. +func WithPromptReason(hint string) PromptOpt { + return func(cfg *PromptConfig) { + cfg.PromptReason = hint + } +} + +// WithPromptReasonAdminAction sets the prompt's PromptReason field to a standard admin action message. +func WithPromptReasonAdminAction(actionName string) PromptOpt { + adminMFAPromptReason := fmt.Sprintf("MFA is required for admin-level API request: %q", actionName) + return WithPromptReason(adminMFAPromptReason) +} + +// WithPromptReasonSessionMFA sets the prompt's PromptReason field to a standard session mfa message. +func WithPromptReasonSessionMFA(serviceType, serviceName string) PromptOpt { + return WithPromptReason(fmt.Sprintf("MFA is required to access %s %q", serviceType, serviceName)) +} + +// WithPromptDeviceType sets the prompt's DeviceType field. +func WithPromptDeviceType(deviceType DeviceDescriptor) PromptOpt { + return func(cfg *PromptConfig) { + cfg.DeviceType = deviceType + } +} diff --git a/api/utils/grpc/interceptors/mfa.go b/api/utils/grpc/interceptors/mfa.go index 546054e2551f0..7bc3c3d6d8bb6 100644 --- a/api/utils/grpc/interceptors/mfa.go +++ b/api/utils/grpc/interceptors/mfa.go @@ -28,14 +28,14 @@ import ( // RetryWithMFAUnaryInterceptor intercepts a GRPC client unary call to check if the // error indicates that the client should retry with MFA verification. -func RetryWithMFAUnaryInterceptor(mfaCeremony func(ctx context.Context) (*proto.MFAAuthenticateResponse, error)) grpc.UnaryClientInterceptor { +func RetryWithMFAUnaryInterceptor(mfaCeremony func(ctx context.Context, opts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error)) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { err := invoker(ctx, method, req, reply, cc, opts...) if !errors.Is(trail.FromGRPC(err), &mfa.ErrAdminActionMFARequired) { return err } - mfaResp, ceremonyErr := mfaCeremony(ctx) + mfaResp, ceremonyErr := mfaCeremony(ctx, mfa.WithPromptReasonAdminAction(method)) if ceremonyErr != nil { return trace.NewAggregate(trail.FromGRPC(err), ceremonyErr) } diff --git a/api/utils/grpc/interceptors/mfa_test.go b/api/utils/grpc/interceptors/mfa_test.go index 068fb78fe7972..aaf77923fc5cf 100644 --- a/api/utils/grpc/interceptors/mfa_test.go +++ b/api/utils/grpc/interceptors/mfa_test.go @@ -98,7 +98,7 @@ func TestRetryWithMFA(t *testing.T) { t.Run("with interceptor", func(t *testing.T) { t.Run("ok mfa ceremony", func(t *testing.T) { - okMFACeremony := func(ctx context.Context) (*proto.MFAAuthenticateResponse, error) { + okMFACeremony := func(ctx context.Context, opts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { return &proto.MFAAuthenticateResponse{ Response: &proto.MFAAuthenticateResponse_TOTP{ TOTP: &proto.TOTPResponse{ @@ -125,7 +125,7 @@ func TestRetryWithMFA(t *testing.T) { t.Run("nok mfa ceremony", func(t *testing.T) { mfaCeremonyErr := trace.BadParameter("client does not support mfa") - nokMFACeremony := func(ctx context.Context) (*proto.MFAAuthenticateResponse, error) { + nokMFACeremony := func(ctx context.Context, opts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) { return nil, mfaCeremonyErr } conn, err := grpc.Dial( diff --git a/lib/auth/authclient/authclient.go b/lib/auth/authclient/authclient.go index 4ae847f667b7a..6c15b756256ba 100644 --- a/lib/auth/authclient/authclient.go +++ b/lib/auth/authclient/authclient.go @@ -29,8 +29,8 @@ import ( "github.com/gravitational/teleport/api/breaker" apiclient "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/utils" @@ -50,14 +50,14 @@ type Config struct { CircuitBreakerConfig breaker.Config // DialTimeout determines how long to wait for dialing to succeed before aborting. DialTimeout time.Duration - // PromptAdminRequestMFA is used to prompt the user for MFA on admin requests when needed. + // MFAPromptConstructor is used to create MFA prompts when needed. // If nil, the client will not prompt for MFA. - PromptAdminRequestMFA func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) + MFAPromptConstructor mfa.PromptConstructor } // Connect creates a valid client connection to the auth service. It may // connect directly to the auth server, or tunnel through the proxy. -func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) { +func Connect(ctx context.Context, cfg *Config) (*auth.Client, error) { cfg.Log.Debugf("Connecting to: %v.", cfg.AuthServers) directClient, err := connectViaAuthDirect(ctx, cfg) @@ -83,7 +83,7 @@ func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) { ) } -func connectViaAuthDirect(ctx context.Context, cfg *Config) (auth.ClientI, error) { +func connectViaAuthDirect(ctx context.Context, cfg *Config) (*auth.Client, error) { // Try connecting to the auth server directly over TLS. directClient, err := auth.NewClient(apiclient.Config{ Addrs: utils.NetAddrsToStrings(cfg.AuthServers), @@ -93,7 +93,7 @@ func connectViaAuthDirect(ctx context.Context, cfg *Config) (auth.ClientI, error CircuitBreakerConfig: cfg.CircuitBreakerConfig, InsecureAddressDiscovery: cfg.TLS.InsecureSkipVerify, DialTimeout: cfg.DialTimeout, - PromptAdminRequestMFA: cfg.PromptAdminRequestMFA, + MFAPromptConstructor: cfg.MFAPromptConstructor, }) if err != nil { return nil, trace.Wrap(err) @@ -109,7 +109,7 @@ func connectViaAuthDirect(ctx context.Context, cfg *Config) (auth.ClientI, error return directClient, nil } -func connectViaProxyTunnel(ctx context.Context, cfg *Config) (auth.ClientI, error) { +func connectViaProxyTunnel(ctx context.Context, cfg *Config) (*auth.Client, error) { // If direct dial failed, we may have a proxy address in // cfg.AuthServers. Try connecting to the reverse tunnel // endpoint and make a client over that. @@ -145,7 +145,7 @@ func connectViaProxyTunnel(ctx context.Context, cfg *Config) (auth.ClientI, erro Credentials: []apiclient.Credentials{ apiclient.LoadTLS(cfg.TLS), }, - PromptAdminRequestMFA: cfg.PromptAdminRequestMFA, + MFAPromptConstructor: cfg.MFAPromptConstructor, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/api.go b/lib/client/api.go index a2a49b57a417a..e347a36d22874 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -55,6 +55,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" kubeproto "github.com/gravitational/teleport/api/gen/proto/go/teleport/kube/v1" + "github.com/gravitational/teleport/api/mfa" apitracing "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/profile" @@ -69,7 +70,6 @@ import ( "github.com/gravitational/teleport/lib/auth/touchid" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" "github.com/gravitational/teleport/lib/authz" - "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/client/terminal" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/devicetrust" @@ -465,10 +465,6 @@ type Config struct { // Defaults to [dtenroll.AutoEnroll]. dtAutoEnroll dtAutoEnrollFunc - // PromptMFAFunc allows tests to override the default MFA prompt function. - // Defaults to [mfa.NewPrompt().Run]. - PromptMFAFunc PromptMFAFunc - // WebauthnLogin allows tests to override the Webauthn Login func. // Defaults to [wancli.Login]. WebauthnLogin WebauthnLoginFunc @@ -2729,7 +2725,7 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli } if tc.SSHLogDir != "" { - if err := os.MkdirAll(tc.SSHLogDir, 0700); err != nil { + if err := os.MkdirAll(tc.SSHLogDir, 0o700); err != nil { return trace.ConvertSystemError(err) } } @@ -2963,7 +2959,7 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (*ClusterClient, } authClientCfg := pclt.ClientConfig(ctx, cluster) - authClientCfg.PromptAdminRequestMFA = tc.NewMFAPrompt(mfa.WithHintBeforePrompt(mfa.AdminMFAHintBeforePrompt)) + authClientCfg.MFAPromptConstructor = tc.NewMFAPrompt authClient, err := auth.NewClient(authClientCfg) if err != nil { return nil, trace.NewAggregate(err, pclt.Close()) @@ -3824,7 +3820,7 @@ func (tc *TeleportClient) mfaLocalLoginWeb(ctx context.Context, priv *keys.Priva SSHLogin: sshLogin, User: tc.Username, Password: password, - PromptMFA: tc.PromptMFA, + PromptMFA: tc.NewMFAPrompt(), }) return clt, session, trace.Wrap(err) } @@ -4123,7 +4119,7 @@ func (tc *TeleportClient) mfaLocalLogin(ctx context.Context, priv *keys.PrivateK SSHLogin: sshLogin, User: tc.Username, Password: password, - PromptMFA: tc.PromptMFA, + PromptMFA: tc.NewMFAPrompt(), }) return response, trace.Wrap(err) @@ -5174,7 +5170,7 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste }, ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired, InsecureAddressDiscovery: tc.InsecureSkipVerify, - PromptAdminRequestMFA: tc.NewMFAPrompt(mfa.WithHintBeforePrompt(mfa.AdminMFAHintBeforePrompt)), + MFAPromptConstructor: tc.NewMFAPrompt, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/client.go b/lib/client/client.go index 5c1b7ff33bd2d..cdb57eb203f8f 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -45,13 +45,13 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/sshutils" @@ -443,7 +443,7 @@ func WithMFARequired(mfaRequired *bool) IssueUserCertsOpt { } // IssueUserCertsWithMFA generates a single-use certificate for the user. -func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params ReissueParams, promptMFA PromptMFAFunc, applyOpts ...IssueUserCertsOpt) (*Key, error) { +func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params ReissueParams, mfaPrompt mfa.Prompt, applyOpts ...IssueUserCertsOpt) (*Key, error) { ctx, span := proxy.Tracer.Start( ctx, "proxyClient/IssueUserCertsWithMFA", @@ -558,7 +558,7 @@ func (proxy *ProxyClient) IssueUserCertsWithMFA(ctx context.Context, params Reis key, _, err = PerformMFACeremony(ctx, PerformMFACeremonyParams{ CurrentAuthClient: proxy.currentCluster, RootAuthClient: clt, - PromptMFA: promptMFA, + MFAPrompt: mfaPrompt, MFAAgainstRoot: params.RouteToCluster == rootClusterName, MFARequiredReq: nil, // No need to check if we got this far. CertsReq: certsReq, @@ -1076,7 +1076,7 @@ func (proxy *ProxyClient) ConnectToAuthServiceThroughALPNSNIProxy(ctx context.Co ALPNConnUpgradeRequired: proxy.teleportClient.IsALPNConnUpgradeRequiredForWebProxy(ctx, proxyAddr), PROXYHeaderGetter: CreatePROXYHeaderGetter(ctx, proxy.teleportClient.PROXYSigner), InsecureAddressDiscovery: proxy.teleportClient.InsecureSkipVerify, - PromptAdminRequestMFA: proxy.teleportClient.NewMFAPrompt(mfa.WithHintBeforePrompt(mfa.AdminMFAHintBeforePrompt)), + MFAPromptConstructor: proxy.teleportClient.NewMFAPrompt, }) if err != nil { return nil, trace.Wrap(err) @@ -1154,8 +1154,8 @@ func (proxy *ProxyClient) ConnectToCluster(ctx context.Context, clusterName stri Credentials: []client.Credentials{ client.LoadTLS(tlsConfig), }, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - PromptAdminRequestMFA: proxy.teleportClient.NewMFAPrompt(mfa.WithHintBeforePrompt(mfa.AdminMFAHintBeforePrompt)), + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + MFAPromptConstructor: proxy.teleportClient.NewMFAPrompt, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/cluster_client.go b/lib/client/cluster_client.go index d9952c16ac35b..33152f6b2c142 100644 --- a/lib/client/cluster_client.go +++ b/lib/client/cluster_client.go @@ -24,6 +24,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" proxyclient "github.com/gravitational/teleport/api/client/proxy" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/services" ) @@ -204,7 +205,7 @@ func (c *ClusterClient) performMFACeremony(ctx context.Context, rootClient *Clus key, _, err = PerformMFACeremony(ctx, PerformMFACeremonyParams{ CurrentAuthClient: c.AuthClient, RootAuthClient: rootClient.AuthClient, - PromptMFA: c.tc.PromptMFA, + MFAPrompt: c.tc.NewMFAPrompt(), MFAAgainstRoot: c.cluster == rootClient.cluster, MFARequiredReq: params.isMFARequiredRequest(c.tc.HostLogin), CertsReq: certsReq, @@ -235,8 +236,8 @@ type PerformMFACeremonyParams struct { // This is the client used to acquire the authn challenge and issue the user // certificates. RootAuthClient PerformMFARootClient - // PromptMFA is used to prompt the user for an MFA solution. - PromptMFA PromptMFAFunc + // MFAPrompt is used to prompt the user for an MFA solution. + MFAPrompt mfa.Prompt // MFAAgainstRoot tells whether to run the MFA required check against root or // current cluster. @@ -304,7 +305,7 @@ func PerformMFACeremony(ctx context.Context, params PerformMFACeremonyParams) (* } // Prompt user for solution (eg, security key touch). - authnSolved, err := params.PromptMFA(ctx, authnChal) + authnSolved, err := params.MFAPrompt.Run(ctx, authnChal) if err != nil { return nil, nil, trace.Wrap(err) } diff --git a/lib/client/kubesession.go b/lib/client/kubesession.go index ea6217402a106..bc4d3b47a4030 100644 --- a/lib/client/kubesession.go +++ b/lib/client/kubesession.go @@ -30,8 +30,8 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/client/terminal" "github.com/gravitational/teleport/lib/kube/proxy/streamproto" "github.com/gravitational/teleport/lib/utils" diff --git a/lib/client/local_proxy_middleware.go b/lib/client/local_proxy_middleware.go index 1466b556260fa..94acc39316993 100644 --- a/lib/client/local_proxy_middleware.go +++ b/lib/client/local_proxy_middleware.go @@ -19,7 +19,6 @@ package client import ( "context" "crypto/tls" - "fmt" "net" "time" @@ -27,8 +26,8 @@ import ( "github.com/jonboulle/clockwork" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -89,7 +88,6 @@ func (c *DBCertChecker) renewCerts(ctx context.Context, lp *alpnproxy.LocalProxy accessRequests = profile.ActiveRequests.AccessRequests } - hint := fmt.Sprintf("MFA is required to access database %q", c.dbRoute.ServiceName) var key *Key if err := RetryWithRelogin(ctx, c.tc, func() error { newKey, err := c.tc.IssueUserCertsWithMFA(ctx, ReissueParams{ @@ -102,7 +100,7 @@ func (c *DBCertChecker) renewCerts(ctx context.Context, lp *alpnproxy.LocalProxy }, AccessRequests: accessRequests, RequesterName: proto.UserCertsRequest_TSH_DB_LOCAL_PROXY_TUNNEL, - }, mfa.WithHintBeforePrompt(hint)) + }, mfa.WithPromptReasonSessionMFA("database", c.dbRoute.ServiceName)) key = newKey return trace.Wrap(err) }); err != nil { diff --git a/lib/client/mfa.go b/lib/client/mfa.go index 03aff3b690316..f4734b90a06bc 100644 --- a/lib/client/mfa.go +++ b/lib/client/mfa.go @@ -20,42 +20,36 @@ import ( "context" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/mfa" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" - "github.com/gravitational/teleport/lib/client/mfa" + libmfa "github.com/gravitational/teleport/lib/client/mfa" ) -// PromptMFAFunc matches the signature of [mfa.Prompt.Run]. -type PromptMFAFunc func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) - // WebauthnLoginFunc matches the signature of [wancli.Login]. type WebauthnLoginFunc func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) // NewMFAPrompt creates a new MFA prompt from client settings. -func (tc *TeleportClient) NewMFAPrompt(opts ...mfa.PromptOpt) PromptMFAFunc { - if tc.PromptMFAFunc != nil { - return tc.PromptMFAFunc - } +func (tc *TeleportClient) NewMFAPrompt(opts ...mfa.PromptOpt) mfa.Prompt { + cfg := tc.newPromptConfig(opts...) + var prompt mfa.Prompt = libmfa.NewCLIPrompt(cfg, tc.Stderr) + return prompt +} - prompt := mfa.NewPrompt(tc.WebProxyAddr) - prompt.AuthenticatorAttachment = tc.AuthenticatorAttachment - prompt.PreferOTP = tc.PreferOTP - prompt.AllowStdinHijack = tc.AllowStdinHijack +// PromptMFA runs a standard MFA prompt from client settings. +func (tc *TeleportClient) PromptMFA(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + return tc.NewMFAPrompt().Run(ctx, chal) +} - if tc.WebauthnLogin != nil { - prompt.WebauthnLogin = tc.WebauthnLogin - prompt.WebauthnSupported = true - } +func (tc *TeleportClient) newPromptConfig(opts ...mfa.PromptOpt) *libmfa.PromptConfig { + cfg := libmfa.NewPromptConfig(tc.WebProxyAddr, opts...) + cfg.AuthenticatorAttachment = tc.AuthenticatorAttachment + cfg.PreferOTP = tc.PreferOTP + cfg.AllowStdinHijack = tc.AllowStdinHijack - for _, opt := range opts { - opt(prompt) + if tc.WebauthnLogin != nil { + cfg.WebauthnLoginFunc = tc.WebauthnLogin + cfg.WebauthnSupported = true } - - return prompt.Run -} - -// PromptMFA prompts for MFA for the given challenge using the clients standard settings. -// Use [NewMFAPrompt] to create a prompt with customizable settings. -func (tc *TeleportClient) PromptMFA(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return tc.NewMFAPrompt()(ctx, chal) + return cfg } diff --git a/lib/client/mfa/cli.go b/lib/client/mfa/cli.go new file mode 100644 index 0000000000000..b3ad1af8cb52a --- /dev/null +++ b/lib/client/mfa/cli.go @@ -0,0 +1,213 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mfa + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/utils/prompt" + wancli "github.com/gravitational/teleport/lib/auth/webauthncli" + wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" + "github.com/gravitational/teleport/lib/auth/webauthnwin" +) + +// CLIPrompt is the default CLI mfa prompt implementation. +type CLIPrompt struct { + cfg PromptConfig + writer io.Writer +} + +// NewCLIPrompt returns a new CLI mfa prompt with the config and writer. +func NewCLIPrompt(cfg *PromptConfig, writer io.Writer) *CLIPrompt { + return &CLIPrompt{ + cfg: *cfg, + writer: writer, + } +} + +// Run prompts the user to complete an MFA authentication challenge. +func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + var wg sync.WaitGroup + runOpts, err := c.cfg.getRunOptions(ctx, chal) + if err != nil { + return nil, trace.Wrap(err) + } + + type response struct { + kind string + resp *proto.MFAAuthenticateResponse + err error + } + respC := make(chan response, 2) + + ctx, cancel := context.WithCancel(ctx) + defer func() { + cancel() + // Wait for all goroutines to complete to ensure there are no leaks. + wg.Wait() + }() + + // Use variables below to cancel OTP reads and make sure the goroutine exited. + otpCtx, otpCancel := context.WithCancel(ctx) + defer otpCancel() + otpDone := make(chan struct{}) + otpCancelAndWait := func() { + otpCancel() + <-otpDone + } + + // Fire TOTP goroutine. + if runOpts.promptTOTP { + wg.Add(1) + go func() { + defer wg.Done() + defer close(otpDone) + + // Let Webauthn take the prompt below if applicable. + quiet := c.cfg.Quiet || runOpts.promptWebauthn + + resp, err := c.promptTOTP(otpCtx, chal, quiet) + respC <- response{kind: "TOTP", resp: resp, err: err} + }() + } + + // Fire Webauthn goroutine. + if runOpts.promptWebauthn { + wg.Add(1) + go func() { + defer wg.Done() + + // get webauthn prompt and wrap with otp context handler. + prompt := &webauthnPromptWithOTP{ + LoginPrompt: c.getWebauthnPrompt(ctx, runOpts.promptTOTP), + otpCancelAndWait: otpCancelAndWait, + } + + resp, err := c.promptWebauthn(ctx, chal, prompt) + respC <- response{kind: "WEBAUTHN", resp: resp, err: err} + }() + } + + // Wait for the 1-2 authn goroutines above to complete, then close respC. + go func() { + wg.Wait() + close(respC) + }() + + // Wait for a successful response, or terminating error, from the 1-2 authn goroutines. + // The goroutine above will ensure the response channel is closed once all goroutines are done. + for resp := range respC { + switch err := resp.err; { + case errors.Is(err, wancli.ErrUsingNonRegisteredDevice): + // Surface error immediately. + return nil, trace.Wrap(resp.err) + case err != nil: + c.cfg.Log.WithError(err).Debugf("%s authentication failed", resp.kind) + // Continue to give the other authn goroutine a chance to succeed. + // If both have failed, this will exit the loop. + continue + } + + // Return successful response. + return resp.resp, nil + } + + // If no successful response is returned, this means the authn goroutines were unsuccessful. + // This usually occurs when the prompt times out or no devices are available to prompt. + // Return a user readable error message. + return nil, trace.BadParameter("failed to authenticate using available MFA devices, rerun the command with '-d' to see error details for each device") +} + +func (c *CLIPrompt) promptTOTP(ctx context.Context, chal *proto.MFAAuthenticateChallenge, quiet bool) (*proto.MFAAuthenticateResponse, error) { + var msg string + if !quiet { + msg = fmt.Sprintf("Enter an OTP code from a %sdevice", c.promptDevicePrefix()) + } + + otp, err := prompt.Password(ctx, c.writer, prompt.Stdin(), msg) + if err != nil { + return nil, trace.Wrap(err) + } + + return &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_TOTP{ + TOTP: &proto.TOTPResponse{Code: otp}, + }, + }, nil +} + +func (c *CLIPrompt) getWebauthnPrompt(ctx context.Context, withTOTP bool) wancli.LoginPrompt { + writer := c.writer + if c.cfg.Quiet { + writer = io.Discard + } + + prompt := wancli.NewDefaultPrompt(ctx, writer) + prompt.SecondTouchMessage = fmt.Sprintf("Tap your %ssecurity key to complete login", c.promptDevicePrefix()) + prompt.FirstTouchMessage = fmt.Sprintf("Tap any %ssecurity key", c.promptDevicePrefix()) + + if withTOTP { + prompt.FirstTouchMessage = fmt.Sprintf("Tap any %ssecurity key or enter a code from a %sOTP device", c.promptDevicePrefix(), c.promptDevicePrefix()) + + // Customize Windows prompt directly. + // Note that the platform popup is a modal and will only go away if canceled. + webauthnwin.PromptPlatformMessage = "Follow the OS dialogs for platform authentication, or enter an OTP code here:" + defer webauthnwin.ResetPromptPlatformMessage() + } + + return prompt +} + +func (c *CLIPrompt) promptWebauthn(ctx context.Context, chal *proto.MFAAuthenticateChallenge, prompt wancli.LoginPrompt) (*proto.MFAAuthenticateResponse, error) { + opts := &wancli.LoginOpts{AuthenticatorAttachment: c.cfg.AuthenticatorAttachment} + resp, _, err := c.cfg.WebauthnLoginFunc(ctx, c.cfg.getWebauthnOrigin(), wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge), prompt, opts) + if err != nil { + return nil, trace.Wrap(err) + } + + return resp, nil +} + +func (c *CLIPrompt) promptDevicePrefix() string { + if c.cfg.DeviceType != "" { + return fmt.Sprintf("*%s* ", c.cfg.DeviceType) + } + return "" +} + +// webauthnPromptWithOTP implements wancli.LoginPrompt for MFA logins. +// In most cases authenticators shouldn't require PINs or additional touches for +// MFA, but the implementation exists in case we find some unusual +// authenticators out there. +type webauthnPromptWithOTP struct { + wancli.LoginPrompt + otpCancelAndWait func() +} + +func (w *webauthnPromptWithOTP) PromptPIN() (string, error) { + // If we get to this stage, Webauthn PIN verification is underway. + // Cancel otp goroutine so that it doesn't capture the PIN from stdin. + w.otpCancelAndWait() + return w.LoginPrompt.PromptPIN() +} diff --git a/lib/client/mfa/prompt.go b/lib/client/mfa/prompt.go index d270f296b7a8a..0a67eaf5dc4de 100644 --- a/lib/client/mfa/prompt.go +++ b/lib/client/mfa/prompt.go @@ -18,48 +18,25 @@ package mfa import ( "context" - "errors" - "fmt" - "os" "strings" - "sync" "github.com/gravitational/trace" "github.com/sirupsen/logrus" - oteltrace "go.opentelemetry.io/otel/trace" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/observability/tracing" - "github.com/gravitational/teleport/api/utils/prompt" + "github.com/gravitational/teleport/api/mfa" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" - "github.com/gravitational/teleport/lib/auth/webauthnwin" ) -// AdminMFAHintBeforePrompt is a hint used for MFA prompts for admin-level API requests. -const AdminMFAHintBeforePrompt = "MFA is required for admin-level API request." - -var log = logrus.WithFields(logrus.Fields{ - trace.Component: teleport.ComponentClient, -}) - -// Prompt is an MFA prompt. -type Prompt struct { - // WebauthnLogin performs client-side Webauthn login. - WebauthnLogin func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) +// PromptConfig contains common mfa prompt config options. +type PromptConfig struct { + mfa.PromptConfig // ProxyAddress is the address of the authenticating proxy. required. ProxyAddress string - // HintBeforePrompt is an optional hint message to print before an MFA prompt. - // It is used to provide context about why the user is being prompted where it may - // not be obvious. - HintBeforePrompt string - // PromptDevicePrefix is an optional prefix printed before "security key" or - // "device". It is used to emphasize between different kinds of devices, like - // registered vs new. - PromptDevicePrefix string - // Quiet suppresses users prompts. - Quiet bool + // WebauthnLoginFunc performs client-side Webauthn login. + WebauthnLoginFunc func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) // AllowStdinHijack allows stdin hijack during MFA prompts. // Stdin hijack provides a better login UX, but it can be difficult to reason // about and is often a source of bugs. @@ -73,228 +50,70 @@ type Prompt struct { PreferOTP bool // WebauthnSupported indicates whether Webauthn is supported. WebauthnSupported bool + // Log is a logging entry. + Log *logrus.Entry } -// PromptOpt applies configuration options to a prompt. -type PromptOpt func(*Prompt) - -// WithQuiet sets the prompt's Quiet field. -func WithQuiet() PromptOpt { - return func(p *Prompt) { - p.Quiet = true +// NewPromptConfig returns a prompt config that will induce default behavior. +func NewPromptConfig(proxyAddr string, opts ...mfa.PromptOpt) *PromptConfig { + cfg := &PromptConfig{ + ProxyAddress: proxyAddr, + WebauthnLoginFunc: wancli.Login, + WebauthnSupported: wancli.HasPlatformSupport(), + Log: logrus.WithFields(logrus.Fields{ + trace.Component: teleport.ComponentClient, + }), } -} -// WithHintBeforePrompt sets the prompt's HintBeforePrompt field. -func WithHintBeforePrompt(hint string) PromptOpt { - return func(p *Prompt) { - p.HintBeforePrompt = hint + for _, opt := range opts { + opt(&cfg.PromptConfig) } -} -// WithPromptDevicePrefix sets the prompt's PromptDevicePrefix field. -func WithPromptDevicePrefix(prefix string) PromptOpt { - return func(p *Prompt) { - p.PromptDevicePrefix = prefix - } + return cfg } -// NewPrompt creates a new prompt with standard behavior. -// If you want to customize [Prompt], for example for testing purposes, you may -// create or configure an instance directly, without calling this method. -func NewPrompt(proxyAddr string) *Prompt { - return &Prompt{ - WebauthnLogin: wancli.Login, - ProxyAddress: proxyAddr, - WebauthnSupported: wancli.HasPlatformSupport(), - } +// runOpts are mfa prompt run options. +type runOpts struct { + promptTOTP bool + promptWebauthn bool } -// Run prompts the user to complete MFA authentication challenges according to the prompt's configuration. -func (p *Prompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - ctx, span := tracing.NewTracer("MFACeremony").Start( - ctx, - "Run", - oteltrace.WithSpanKind(oteltrace.SpanKindClient), - ) - defer span.End() - - // Is there a challenge present? - if chal.TOTP == nil && chal.WebauthnChallenge == nil { - return &proto.MFAAuthenticateResponse{}, nil - } - - writer := os.Stderr - if p.HintBeforePrompt != "" { - fmt.Fprintln(writer, p.HintBeforePrompt) - } +// getRunOptions gets mfa prompt run options by cross referencing the mfa challenge with prompt configuration. +func (c PromptConfig) getRunOptions(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (runOpts, error) { + promptTOTP := chal.TOTP != nil + promptWebauthn := chal.WebauthnChallenge != nil - promptDevicePrefix := p.PromptDevicePrefix - if promptDevicePrefix != "" { - promptDevicePrefix += " " + if !promptTOTP && !promptWebauthn { + return runOpts{}, trace.BadParameter("mfa challenge is empty") } - quiet := p.Quiet - - hasTOTP := chal.TOTP != nil - hasWebauthn := chal.WebauthnChallenge != nil - // Does the current platform support hardware MFA? Adjust accordingly. switch { - case !hasTOTP && !p.WebauthnSupported: - return nil, trace.BadParameter("hardware device MFA not supported by your platform, please register an OTP device") - case !p.WebauthnSupported: + case !promptTOTP && !c.WebauthnSupported: + return runOpts{}, trace.BadParameter("hardware device MFA not supported by your platform, please register an OTP device") + case !c.WebauthnSupported: // Do not prompt for hardware devices, it won't work. - hasWebauthn = false + promptWebauthn = false } // Tweak enabled/disabled methods according to opts. switch { - case hasTOTP && p.PreferOTP: - hasWebauthn = false - case hasWebauthn && p.AuthenticatorAttachment != wancli.AttachmentAuto: + case promptTOTP && c.PreferOTP: + promptWebauthn = false + case promptWebauthn && c.AuthenticatorAttachment != wancli.AttachmentAuto: // Prefer Webauthn if an specific attachment was requested. - hasTOTP = false - case hasWebauthn && !p.AllowStdinHijack: + promptTOTP = false + case promptWebauthn && !c.AllowStdinHijack: // Use strongest auth if hijack is not allowed. - hasTOTP = false + promptTOTP = false } - var numGoroutines int - if hasTOTP && hasWebauthn { - numGoroutines = 2 - } else { - numGoroutines = 1 - } - - type response struct { - kind string - resp *proto.MFAAuthenticateResponse - err error - } - respC := make(chan response, numGoroutines) - - // Use ctx and wg to clean up after ourselves. - ctx, cancel := context.WithCancel(ctx) - defer cancel() - var wg sync.WaitGroup - cancelAndWait := func() { - cancel() - wg.Wait() - } - - // Use variables below to cancel OTP reads and make sure the goroutine exited. - otpWait := &sync.WaitGroup{} - otpCtx, otpCancel := context.WithCancel(ctx) - defer otpCancel() - - // Fire TOTP goroutine. - if hasTOTP { - otpWait.Add(1) - wg.Add(1) - go func() { - defer otpWait.Done() - defer wg.Done() - const kind = "TOTP" - - // Let Webauthn take the prompt, it knows better if it's necessary. - var msg string - if !quiet && !hasWebauthn { - msg = fmt.Sprintf("Enter an OTP code from a %sdevice", promptDevicePrefix) - } - - otp, err := prompt.Password(otpCtx, writer, prompt.Stdin(), msg) - if err != nil { - respC <- response{kind: kind, err: err} - return - } - respC <- response{ - kind: kind, - resp: &proto.MFAAuthenticateResponse{ - Response: &proto.MFAAuthenticateResponse_TOTP{ - TOTP: &proto.TOTPResponse{Code: otp}, - }, - }, - } - }() - } - - // Fire Webauthn goroutine. - if hasWebauthn { - origin := p.ProxyAddress - if !strings.HasPrefix(origin, "https://") { - origin = "https://" + origin - } - wg.Add(1) - go func() { - defer wg.Done() - log.Debugf("WebAuthn: prompting devices with origin %q", origin) - - prompt := wancli.NewDefaultPrompt(ctx, writer) - prompt.SecondTouchMessage = fmt.Sprintf("Tap your %ssecurity key to complete login", promptDevicePrefix) - switch { - case quiet: - // Do not prompt. - prompt.FirstTouchMessage = "" - prompt.SecondTouchMessage = "" - case hasTOTP: // Webauthn + OTP - prompt.FirstTouchMessage = fmt.Sprintf("Tap any %ssecurity key or enter a code from a %sOTP device", promptDevicePrefix, promptDevicePrefix) - - // Customize Windows prompt directly. - // Note that the platform popup is a modal and will only go away if - // canceled. - webauthnwin.PromptPlatformMessage = "Follow the OS dialogs for platform authentication, or enter an OTP code here:" - defer webauthnwin.ResetPromptPlatformMessage() - - default: // Webauthn only - prompt.FirstTouchMessage = fmt.Sprintf("Tap any %ssecurity key", promptDevicePrefix) - } - mfaPrompt := &mfaPrompt{LoginPrompt: prompt, otpCancelAndWait: func() { - otpCancel() - otpWait.Wait() - }} - - resp, _, err := p.WebauthnLogin(ctx, origin, wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge), mfaPrompt, &wancli.LoginOpts{ - AuthenticatorAttachment: p.AuthenticatorAttachment, - }) - respC <- response{kind: "WEBAUTHN", resp: resp, err: err} - }() - } - - for i := 0; i < numGoroutines; i++ { - select { - case resp := <-respC: - switch err := resp.err; { - case errors.Is(err, wancli.ErrUsingNonRegisteredDevice): - // Surface error immediately. - case err != nil: - log.WithError(err).Debugf("%s authentication failed", resp.kind) - continue - } - - // Cleanup in-flight goroutines. - cancelAndWait() - return resp.resp, trace.Wrap(resp.err) - case <-ctx.Done(): - cancelAndWait() - return nil, trace.Wrap(ctx.Err()) - } - } - cancelAndWait() - return nil, trace.BadParameter( - "failed to authenticate using all MFA devices, rerun the command with '-d' to see error details for each device") -} - -// mfaPrompt implements wancli.LoginPrompt for MFA logins. -// In most cases authenticators shouldn't require PINs or additional touches for -// MFA, but the implementation exists in case we find some unusual -// authenticators out there. -type mfaPrompt struct { - wancli.LoginPrompt - otpCancelAndWait func() + return runOpts{promptTOTP, promptWebauthn}, nil } -func (p *mfaPrompt) PromptPIN() (string, error) { - p.otpCancelAndWait() - return p.LoginPrompt.PromptPIN() +func (c PromptConfig) getWebauthnOrigin() string { + if !strings.HasPrefix(c.ProxyAddress, "https://") { + return "https://" + c.ProxyAddress + } + return c.ProxyAddress } diff --git a/lib/client/mfa_test.go b/lib/client/mfa_test.go index 97819109ed04d..99faf439a3a8f 100644 --- a/lib/client/mfa_test.go +++ b/lib/client/mfa_test.go @@ -17,6 +17,7 @@ package client_test import ( "context" "errors" + "os" "testing" "time" @@ -68,7 +69,7 @@ func TestPromptMFAChallenge_usingNonRegisteredDevice(t *testing.T) { tests := []struct { name string challenge *proto.MFAAuthenticateChallenge - customizePrompt func(p *mfa.Prompt) + customizePrompt func(p *mfa.PromptConfig) }{ { name: "webauthn only", @@ -77,7 +78,7 @@ func TestPromptMFAChallenge_usingNonRegisteredDevice(t *testing.T) { { name: "webauthn and OTP", challenge: challengeWebauthnOTP, - customizePrompt: func(p *mfa.Prompt) { + customizePrompt: func(p *mfa.PromptConfig) { p.AllowStdinHijack = true // required for OTP+WebAuthn prompt. }, }, @@ -98,19 +99,17 @@ func TestPromptMFAChallenge_usingNonRegisteredDevice(t *testing.T) { return "", ctx.Err() })) - promptMFA := &mfa.Prompt{ - ProxyAddress: proxyAddr, - WebauthnSupported: true, - WebauthnLogin: func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { - return nil, "", wancli.ErrUsingNonRegisteredDevice - }, + promptConfig := mfa.NewPromptConfig(proxyAddr) + promptConfig.WebauthnSupported = true + promptConfig.WebauthnLoginFunc = func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { + return nil, "", wancli.ErrUsingNonRegisteredDevice } if test.customizePrompt != nil { - test.customizePrompt(promptMFA) + test.customizePrompt(promptConfig) } - _, err := promptMFA.Run(ctx, test.challenge) + _, err := mfa.NewCLIPrompt(promptConfig, os.Stderr).Run(ctx, test.challenge) if !errors.Is(err, wancli.ErrUsingNonRegisteredDevice) { t.Errorf("PromptMFAChallenge returned err=%q, want %q", err, wancli.ErrUsingNonRegisteredDevice) } diff --git a/lib/client/presence.go b/lib/client/presence.go index f7f7c5254e7dd..cdb76f23e6fbe 100644 --- a/lib/client/presence.go +++ b/lib/client/presence.go @@ -26,6 +26,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/mfa" ) // PresenceMaintainer allows maintaining presence with the Auth service. @@ -53,7 +54,7 @@ func WithPresenceClock(clock clockwork.Clock) PresenceOption { // RunPresenceTask periodically performs and MFA ceremony to detect that a user is // still present and attentive. -func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMaintainer, sessionID string, promptMFA PromptMFAFunc, opts ...PresenceOption) error { +func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...PresenceOption) error { fmt.Fprintf(term, "\r\nTeleport > MFA presence enabled\r\n") o := &presenceOptions{ @@ -98,7 +99,7 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai // We don't support TOTP for live presence. challenge.TOTP = nil - solution, err := promptMFA(ctx, challenge) + solution, err := mfaPrompt.Run(ctx, challenge) if err != nil { fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err) return trace.Wrap(err) diff --git a/lib/client/weblogin.go b/lib/client/weblogin.go index 6c99fec13b993..7c3ce30a5d5a9 100644 --- a/lib/client/weblogin.go +++ b/lib/client/weblogin.go @@ -43,12 +43,13 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" - "github.com/gravitational/teleport/lib/client/mfa" + libmfa "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/httplib/csrf" @@ -268,7 +269,7 @@ type SSHLoginMFA struct { SSHLogin // PromptMFA is a customizable MFA prompt function. // Defaults to [mfa.NewPrompt().Run] - PromptMFA PromptMFAFunc + PromptMFA mfa.Prompt // User is the login username. User string // Password is the login password. @@ -618,10 +619,10 @@ func SSHAgentMFALogin(ctx context.Context, login SSHLoginMFA) (*auth.SSHLoginRes promptMFA := login.PromptMFA if promptMFA == nil { - promptMFA = mfa.NewPrompt(login.ProxyAddr).Run + promptMFA = libmfa.NewCLIPrompt(libmfa.NewPromptConfig(login.ProxyAddr), os.Stderr) } - respPB, err := promptMFA(ctx, chal) + respPB, err := promptMFA.Run(ctx, chal) if err != nil { return nil, trace.Wrap(err) } @@ -816,10 +817,10 @@ func SSHAgentMFAWebSessionLogin(ctx context.Context, login SSHLoginMFA) (*WebCli promptMFA := login.PromptMFA if promptMFA == nil { - promptMFA = mfa.NewPrompt(login.ProxyAddr).Run + promptMFA = libmfa.NewCLIPrompt(libmfa.NewPromptConfig(login.ProxyAddr), os.Stderr) } - respPB, err := promptMFA(ctx, chal) + respPB, err := promptMFA.Run(ctx, chal) if err != nil { return nil, nil, trace.Wrap(err) } diff --git a/lib/teleterm/clusters/cluster_auth.go b/lib/teleterm/clusters/cluster_auth.go index 55b7e8be75cd6..4b88d07051116 100644 --- a/lib/teleterm/clusters/cluster_auth.go +++ b/lib/teleterm/clusters/cluster_auth.go @@ -226,7 +226,7 @@ func (c *Cluster) localMFALogin(user, password string) client.SSHLoginFunc { }, User: user, Password: password, - PromptMFA: c.clusterClient.PromptMFA, + PromptMFA: c.clusterClient.NewMFAPrompt(), }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 06c544925f04b..f7ba5226fad39 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -60,6 +60,7 @@ import ( "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/mfa" apitracing "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" @@ -178,7 +179,7 @@ type proxySettingsGetter interface { // PresenceChecker is a function that executes an mfa prompt to enforce // that a user is present. -type PresenceChecker = func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, promptMFA client.PromptMFAFunc, opts ...client.PresenceOption) error +type PresenceChecker = func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...client.PresenceOption) error // Config represents web handler configuration parameters type Config struct { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 79aed118af900..7b9a58e2e4943 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -89,6 +89,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" kubeproto "github.com/gravitational/teleport/api/gen/proto/go/teleport/kube/v1" transportpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/transport/v1" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" @@ -9354,8 +9355,8 @@ func TestModeratedSessionWithMFA(t *testing.T) { RPID: RPID, }, }, - presenceChecker: func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, promptMFA client.PromptMFAFunc, opts ...client.PresenceOption) error { - return trace.Wrap(client.RunPresenceTask(ctx, term, maintainer, sessionID, promptMFA, client.WithPresenceClock(presenceClock))) + presenceChecker: func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...client.PresenceOption) error { + return trace.Wrap(client.RunPresenceTask(ctx, term, maintainer, sessionID, mfaPrompt, client.WithPresenceClock(presenceClock))) }, }) diff --git a/lib/web/desktop.go b/lib/web/desktop.go index c9b52d5401b76..a0f93129d49c8 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -41,6 +41,7 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" @@ -311,7 +312,7 @@ func (h *Handler) performMFACeremony(ctx context.Context, authClient auth.Client span.End() }() - promptMFA := func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + promptMFA := mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { codec := tdpMFACodec{} // Send the challenge over the socket. @@ -346,12 +347,12 @@ func (h *Handler) performMFACeremony(ctx context.Context, authClient auth.Client span.AddEvent("mfa ceremony completed") return assertion, nil - } + }) _, newCerts, err := client.PerformMFACeremony(ctx, client.PerformMFACeremonyParams{ CurrentAuthClient: nil, // Only RootAuthClient is used. RootAuthClient: authClient, - PromptMFA: promptMFA, + MFAPrompt: promptMFA, MFAAgainstRoot: true, MFARequiredReq: nil, // No need to verify. CertsReq: certsReq, @@ -589,7 +590,6 @@ func (h *Handler) desktopAccessScriptConfigureHandle(w http.ResponseWriter, r *h types.CertAuthID{Type: types.UserCA, DomainName: clusterName}, false, ) - if err != nil { return nil, trace.Wrap(err) } @@ -624,7 +624,6 @@ func (h *Handler) desktopAccessScriptConfigureHandle(w http.ResponseWriter, r *h }) return nil, trace.Wrap(err) - } func (h *Handler) desktopAccessScriptInstallADDSHandle(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 8f90f15ef4b69..b8950268925be 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -42,6 +42,7 @@ import ( "github.com/gravitational/teleport" authproto "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" @@ -536,12 +537,12 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te key, _, err = client.PerformMFACeremony(ctx, client.PerformMFACeremonyParams{ CurrentAuthClient: t.authProvider, RootAuthClient: t.ctx.cfg.RootClient, - PromptMFA: func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { + MFAPrompt: mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { span.AddEvent("prompting user with mfa challenge") - assertion, err := promptMFAChallenge(wsStream, protobufMFACodec{})(ctx, chal) + assertion, err := promptMFAChallenge(wsStream, protobufMFACodec{}).Run(ctx, chal) span.AddEvent("user completed mfa challenge") return assertion, trace.Wrap(err) - }, + }), MFAAgainstRoot: t.ctx.cfg.RootClusterName == tc.SiteName, MFARequiredReq: mfaRequiredReq, CertsReq: certsReq, @@ -560,8 +561,8 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te return []ssh.AuthMethod{am}, nil } -func promptMFAChallenge(stream *WSStream, codec mfaCodec) client.PromptMFAFunc { - return func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { +func promptMFAChallenge(stream *WSStream, codec mfaCodec) mfa.Prompt { + return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { var challenge *client.MFAAuthenticateChallenge // Convert from proto to JSON types. @@ -580,7 +581,7 @@ func promptMFAChallenge(stream *WSStream, codec mfaCodec) client.PromptMFAFunc { resp, err := stream.readChallengeResponse(codec) return resp, trace.Wrap(err) - } + }) } type connectWithMFAFn = func(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) diff --git a/tool/tctl/common/helpers_test.go b/tool/tctl/common/helpers_test.go index 1670bda326a90..5dc2efbb93bd5 100644 --- a/tool/tctl/common/helpers_test.go +++ b/tool/tctl/common/helpers_test.go @@ -93,9 +93,7 @@ func getAuthClient(ctx context.Context, t *testing.T, fc *config.FileConfig, opt require.NoError(t, err) t.Cleanup(func() { - if closer, ok := client.(io.Closer); ok { - closer.Close() - } + client.Close() }) return client @@ -200,6 +198,7 @@ func mustDecodeYAML[T any](t *testing.T, r io.Reader) T { require.NoError(t, err) return out } + func mustGetBase64EncFileConfig(t *testing.T, fc *config.FileConfig) string { configYamlContent, err := yaml.Marshal(fc) require.NoError(t, err) @@ -210,7 +209,7 @@ func mustWriteFileConfig(t *testing.T, fc *config.FileConfig) string { fileConfPath := filepath.Join(t.TempDir(), "teleport.yaml") fileConfYAML, err := yaml.Marshal(fc) require.NoError(t, err) - err = os.WriteFile(fileConfPath, fileConfYAML, 0600) + err = os.WriteFile(fileConfPath, fileConfYAML, 0o600) require.NoError(t, err) return fileConfPath } diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index 4cbd282e95563..2e36b89b8514d 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -33,12 +33,13 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/identityfile" - "github.com/gravitational/teleport/lib/client/mfa" + libmfa "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/config" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/modules" @@ -195,11 +196,6 @@ func TryRun(commands []CLICommand, args []string) error { } ctx := context.Background() - - mfaPrompt := mfa.NewPrompt("") - mfaPrompt.HintBeforePrompt = mfa.AdminMFAHintBeforePrompt - clientConfig.PromptAdminRequestMFA = mfaPrompt.Run - client, err := authclient.Connect(ctx, clientConfig) if err != nil { if utils.IsUntrustedCertErr(err) { @@ -211,12 +207,17 @@ func TryRun(commands []CLICommand, args []string) error { return trace.NewAggregate(&common.ExitCodeError{Code: 1}, err) } - // Set proxy address for the MFA prompt from the ping response. + // Get the proxy address and set the MFA prompt constructor. resp, err := client.Ping(ctx) if err != nil { return trace.Wrap(err) } - mfaPrompt.ProxyAddress = resp.ProxyPublicAddr + + proxyAddr := resp.ProxyPublicAddr + client.SetMFAPromptConstructor(func(opts ...mfa.PromptOpt) mfa.Prompt { + promptCfg := libmfa.NewPromptConfig(proxyAddr, opts...) + return libmfa.NewCLIPrompt(promptCfg, os.Stderr) + }) // execute whatever is selected: var match bool diff --git a/tool/tsh/common/kube_proxy.go b/tool/tsh/common/kube_proxy.go index f0f7eda760413..55c16095f2d98 100644 --- a/tool/tsh/common/kube_proxy.go +++ b/tool/tsh/common/kube_proxy.go @@ -35,12 +35,12 @@ import ( "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/asciitable" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/kube/kubeconfig" "github.com/gravitational/teleport/lib/srv/alpnproxy" "github.com/gravitational/teleport/lib/utils" @@ -567,7 +567,6 @@ func issueKubeCert(ctx context.Context, tc *client.TeleportClient, proxy *client requesterName = proto.UserCertsRequest_TSH_KUBE_LOCAL_PROXY_HEADLESS } - hint := fmt.Sprintf("MFA is required to access Kubernetes cluster %q", kubeCluster) key, err := proxy.IssueUserCertsWithMFA( ctx, client.ReissueParams{ @@ -575,7 +574,7 @@ func issueKubeCert(ctx context.Context, tc *client.TeleportClient, proxy *client KubernetesCluster: kubeCluster, RequesterName: requesterName, }, - tc.NewMFAPrompt(mfa.WithHintBeforePrompt(hint)), + tc.NewMFAPrompt(mfa.WithPromptReasonSessionMFA("Kubernetes cluster", kubeCluster)), client.WithMFARequired(&mfaRequired), ) if err != nil { diff --git a/tool/tsh/common/mfa.go b/tool/tsh/common/mfa.go index 06c4a7c903319..03e8287eb3762 100644 --- a/tool/tsh/common/mfa.go +++ b/tool/tsh/common/mfa.go @@ -34,6 +34,7 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/prompt" "github.com/gravitational/teleport/lib/asciitable" @@ -42,7 +43,6 @@ import ( wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" wanwin "github.com/gravitational/teleport/lib/auth/webauthnwin" "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" @@ -349,7 +349,7 @@ func (c *mfaAddCommand) addDeviceRPC(ctx context.Context, tc *client.TeleportCli // Prompt for authentication. // Does nothing if no challenges were issued (aka user has no devices). - authnResp, err := tc.NewMFAPrompt(mfa.WithPromptDevicePrefix("*registered*"))(ctx, authChallenge) + authnResp, err := tc.NewMFAPrompt(mfa.WithPromptDeviceType(mfa.DeviceDescriptorRegistered)).Run(ctx, authChallenge) if err != nil { return trace.Wrap(err) }