diff --git a/lib/auth/webauthnwin/api.go b/lib/auth/webauthnwin/api.go index 75f1100df1b41..a9597c68d3a92 100644 --- a/lib/auth/webauthnwin/api.go +++ b/lib/auth/webauthnwin/api.go @@ -73,7 +73,7 @@ type makeCredentialRequest struct { } // Login implements Login for Windows Webauthn API. -func Login(ctx context.Context, origin string, assertion *wanlib.CredentialAssertion, loginOpts *LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { +func Login(_ context.Context, origin string, assertion *wanlib.CredentialAssertion, loginOpts *LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { if origin == "" { return nil, "", trace.BadParameter("origin required") } @@ -112,10 +112,7 @@ func Login(ctx context.Context, origin string, assertion *wanlib.CredentialAsser } // Register implements Register for Windows Webauthn API. -func Register( - ctx context.Context, - origin string, cc *wanlib.CredentialCreation, -) (*proto.MFARegisterResponse, error) { +func Register(_ context.Context, origin string, cc *wanlib.CredentialCreation) (*proto.MFARegisterResponse, error) { if origin == "" { return nil, trace.BadParameter("origin required") } @@ -163,13 +160,21 @@ func Register( }, nil } +const defaultPromptMessage = "Using platform authenticator, follow the OS dialogs" + var ( - // PromptPlatformMessage is the message shown before Touch ID prompts. - PromptPlatformMessage = "Using platform authenticator, follow the OS dialogs" + // PromptPlatformMessage is the message shown before system prompts. + PromptPlatformMessage = defaultPromptMessage + // PromptWriter is the writer used for prompt messages. PromptWriter io.Writer = os.Stderr ) +// ResetPromptPlatformMessage resets [PromptPlatformMessage] to its original state. +func ResetPromptPlatformMessage() { + PromptPlatformMessage = defaultPromptMessage +} + func promptPlatform() { if PromptPlatformMessage != "" { fmt.Fprintln(PromptWriter, PromptPlatformMessage) @@ -210,7 +215,7 @@ type DiagResult struct { // Diag runs a few diagnostic commands and returns the result. // User interaction is required. -func Diag(ctx context.Context, promptOut io.Writer) (*DiagResult, error) { +func Diag(ctx context.Context) (*DiagResult, error) { res := &DiagResult{} if !IsAvailable() { return res, nil diff --git a/tool/tsh/mfa.go b/tool/tsh/mfa.go index 41806b54dec9f..6871603f0d96e 100644 --- a/tool/tsh/mfa.go +++ b/tool/tsh/mfa.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/auth/touchid" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" + "github.com/gravitational/teleport/lib/auth/webauthnwin" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" @@ -350,6 +351,16 @@ func (c *mfaAddCommand) addDeviceRPC(ctx context.Context, tc *client.TeleportCli if authChallenge == nil { return trace.BadParameter("server bug: server sent %T when client expected AddMFADeviceResponse_ExistingMFAChallenge", resp.Response) } + + // Tweak Windows platform messages so it's clear we whether we are prompting + // for the *registered* or *new* device. + // We do it here, preemptively, because it's the simpler solution (instead + // of finding out whether it is a Windows prompt or not). + const registeredMsg = "Using platform authentication for *registered* device, follow the OS dialogs" + const newMsg = "Using platform authentication for *new* device, follow the OS dialogs" + defer webauthnwin.ResetPromptPlatformMessage() + webauthnwin.PromptPlatformMessage = registeredMsg + authResp, err := tc.PromptMFAChallenge(ctx, "" /* proxyAddr */, authChallenge, func(opts *client.PromptMFAChallengeOpts) { opts.PromptDevicePrefix = "*registered* " }) @@ -371,6 +382,8 @@ func (c *mfaAddCommand) addDeviceRPC(ctx context.Context, tc *client.TeleportCli if regChallenge == nil { return trace.BadParameter("server bug: server sent %T when client expected AddMFADeviceResponse_NewMFARegisterChallenge", resp.Response) } + + webauthnwin.PromptPlatformMessage = newMsg regResp, regCallback, err := promptRegisterChallenge(ctx, tc.WebProxyAddr, c.devType, regChallenge) if err != nil { return trace.Wrap(err) diff --git a/tool/tsh/winwebauthn.go b/tool/tsh/winwebauthn.go index c1cb7d7ba809b..24b1c136191d3 100644 --- a/tool/tsh/winwebauthn.go +++ b/tool/tsh/winwebauthn.go @@ -58,7 +58,12 @@ func (w *webauthnwinDiagCommand) run(cf *CLIConf) error { if !diag.IsAvailable { return nil } - resp, err := webauthnwin.Diag(cf.Context, os.Stdout) + + promptBefore := webauthnwin.PromptWriter + defer func() { webauthnwin.PromptWriter = promptBefore }() + webauthnwin.PromptWriter = os.Stderr + + resp, err := webauthnwin.Diag(cf.Context) // Abort if we got a nil diagnostic, otherwise print as much as we can. if resp == nil { return trace.Wrap(err)