diff --git a/lib/web/desktop.go b/lib/web/desktop.go index cb3408fc2885f..8e097ab42e197 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -24,6 +24,7 @@ import ( "crypto/tls" "errors" "net/http" + "net/url" "github.com/google/uuid" "github.com/gorilla/websocket" @@ -40,6 +41,7 @@ import ( wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/client/sso" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/desktop" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -343,21 +345,42 @@ func (h *Handler) performSessionMFACeremony( span.End() }() + // channelID is used by the front end to differentiate between separate ongoing SSO challenges. + channelID := uuid.NewString() + mfaCeremony := &mfa.Ceremony{ - PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt { + CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge, + SSOMFACeremonyConstructor: func(_ context.Context) (mfa.SSOMFACeremony, error) { + u, err := url.Parse(sso.WebMFARedirect) + if err != nil { + return nil, trace.Wrap(err) + } + u.RawQuery = url.Values{"channel_id": {channelID}}.Encode() + return &sso.MFACeremony{ + ClientCallbackURL: u.String(), + ProxyAddress: h.PublicProxyAddr(), + }, nil + }, + PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt { return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - codec := tdpMFACodec{} + // Convert from proto to JSON types. + var challenge client.MFAAuthenticateChallenge + if chal.WebauthnChallenge != nil { + challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge) + } - if chal.WebauthnChallenge == nil { - return nil, trace.AccessDenied("Desktop access requires WebAuthn MFA, please register a WebAuthn device to connect") + if chal.SSOChallenge != nil { + challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge) + challenge.SSOChallenge.ChannelID = channelID } + + if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil { + return nil, trace.AccessDenied("Only WebAuthn and SSO MFA methods are supported on the web, please register a supported MFA method to connect to this desktop") + } + // Send the challenge over the socket. - msg, err := codec.Encode( - &client.MFAAuthenticateChallenge{ - WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge), - }, - defaults.WebsocketMFAChallenge, - ) + var codec tdpMFACodec + msg, err := codec.Encode(&challenge, defaults.WebsocketMFAChallenge) if err != nil { return nil, trace.Wrap(err) } @@ -407,7 +430,6 @@ func (h *Handler) performSessionMFACeremony( return assertion, nil }) }, - CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge, } _, newCerts, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{ diff --git a/lib/web/terminal.go b/lib/web/terminal.go index c53d0c817c47b..b8aeb5fc9b26b 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -628,16 +628,11 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc, proxyAddr string) *mfa.Ceremony { // channelID is used by the front end to differentiate between separate ongoing SSO challenges. - var channelID string + channelID := uuid.NewString() return &mfa.Ceremony{ CreateAuthenticateChallenge: createAuthenticateChallenge, SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) { - id, err := uuid.NewRandom() - if err != nil { - return nil, trace.Wrap(err) - } - channelID = id.String() u, err := url.Parse(sso.WebMFARedirect) if err != nil {