Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions lib/web/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"errors"
"log/slog"
"net/http"
"net/url"

"github.com/google/uuid"
"github.com/gorilla/websocket"
Expand All @@ -40,6 +41,7 @@ import (
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/sso"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/desktop"
"github.com/gravitational/teleport/lib/reversetunnelclient"
Expand Down Expand Up @@ -384,21 +386,42 @@ func (h *Handler) performSessionMFACeremony(
span.End()
}()

// channelID is used by the front end to differentiate between separate ongoing SSO challenges.
channelID := uuid.NewString()

mfaCeremony := &mfa.Ceremony{
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge,
SSOMFACeremonyConstructor: func(_ context.Context) (mfa.SSOMFACeremony, error) {
u, err := url.Parse(sso.WebMFARedirect)
if err != nil {
return nil, trace.Wrap(err)
}
u.RawQuery = url.Values{"channel_id": {channelID}}.Encode()
return &sso.MFACeremony{
ClientCallbackURL: u.String(),
ProxyAddress: h.PublicProxyAddr(),
}, nil
},
PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
codec := tdpMFACodec{}
// Convert from proto to JSON types.
var challenge client.MFAAuthenticateChallenge
if chal.WebauthnChallenge != nil {
challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge)
}

if chal.WebauthnChallenge == nil {
return nil, trace.AccessDenied("Desktop access requires WebAuthn MFA, please register a WebAuthn device to connect")
if chal.SSOChallenge != nil {
challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge)
challenge.SSOChallenge.ChannelID = channelID
}

if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil {
return nil, trace.AccessDenied("Only WebAuthn and SSO MFA methods are supported on the web, please register a supported MFA method to connect to this desktop")
}

// Send the challenge over the socket.
msg, err := codec.Encode(
&client.MFAAuthenticateChallenge{
WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge),
},
defaults.WebsocketMFAChallenge,
)
var codec tdpMFACodec
msg, err := codec.Encode(&challenge, defaults.WebsocketMFAChallenge)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -457,7 +480,6 @@ func (h *Handler) performSessionMFACeremony(
return assertion, nil
})
},
CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge,
}

result, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{
Expand Down
7 changes: 1 addition & 6 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,16 +624,11 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te

func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc, proxyAddr string) *mfa.Ceremony {
// channelID is used by the front end to differentiate between separate ongoing SSO challenges.
var channelID string
channelID := uuid.NewString()

return &mfa.Ceremony{
CreateAuthenticateChallenge: createAuthenticateChallenge,
SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) {
id, err := uuid.NewRandom()
if err != nil {
return nil, trace.Wrap(err)
}
channelID = id.String()

u, err := url.Parse(sso.WebMFARedirect)
if err != nil {
Expand Down
Loading