diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 80e527194b8a2..42240bb442d83 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -1185,19 +1185,13 @@ func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Para // TODO: This part should be removed once the plugin support is added to OSS. if proxyConfig.AssistEnabled { - // TODO(jakule): Currently assist is disabled when per-session MFA is enabled as this part is not implemented. - authPreference, err := h.cfg.ProxyClient.GetAuthPreference(r.Context()) - if err != nil { - return webclient.AuthenticationSettings{}, trace.Wrap(err) - } - mfaRequired := authPreference.GetRequireMFAType() != types.RequireMFAType_OFF enabled, err := h.cfg.ProxyClient.IsAssistEnabled(r.Context()) if err != nil { return webclient.AuthenticationSettings{}, trace.Wrap(err) } - // disable if per-session MFA is enabled and it's ok by the auth - proxyConfig.AssistEnabled = enabled.Enabled && !mfaRequired + // disable if auth doesn't support assist + proxyConfig.AssistEnabled = enabled.Enabled } pr, err := h.cfg.ProxyClient.Ping(r.Context()) diff --git a/lib/web/command.go b/lib/web/command.go index 07d1cebb302b2..ba47b4dac654f 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -25,6 +25,7 @@ import ( "fmt" "net/http" "strings" + "sync" "time" "github.com/gogo/protobuf/proto" @@ -33,6 +34,7 @@ import ( "github.com/julienschmidt/httprouter" "github.com/sirupsen/logrus" oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/crypto/ssh" "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport" @@ -185,6 +187,8 @@ func (h *Handler) executeCommand( h.log.Debugf("Found %d hosts to run Assist command %q on.", len(hosts), req.Command) + mfaCacheFn := getMFACacheFn() + for _, host := range hosts { err := func() error { sessionData, err := h.generateCommandSession(&host, req.Login, clusterName, sessionCtx.cfg.User) @@ -206,6 +210,7 @@ func (h *Handler) executeCommand( Router: h.cfg.Router, TracerProvider: h.cfg.TracerProvider, LocalAuthProvider: h.auth.accessPoint, + mfaFuncCache: mfaCacheFn, } handler, err := newCommandHandler(ctx, commandHandlerConfig) @@ -257,6 +262,27 @@ func (h *Handler) executeCommand( return nil, nil } +// getMFACacheFn returns a function that caches the result of the given +// get function. The cache is protected by a mutex, so it is safe to call +// the returned function from multiple goroutines. +func getMFACacheFn() mfaFuncCache { + var mutex sync.Mutex + var authMethods []ssh.AuthMethod + + return func(issueMfaAuthFn func() ([]ssh.AuthMethod, error)) ([]ssh.AuthMethod, error) { + mutex.Lock() + defer mutex.Unlock() + + if authMethods != nil { + return authMethods, nil + } + + var err error + authMethods, err = issueMfaAuthFn() + return authMethods, trace.Wrap(err) + } +} + func newCommandHandler(ctx context.Context, cfg CommandHandlerConfig) (*commandHandler, error) { err := cfg.CheckAndSetDefaults() if err != nil { @@ -282,6 +308,7 @@ func newCommandHandler(ctx context.Context, cfg CommandHandlerConfig) (*commandH localAuthProvider: cfg.LocalAuthProvider, tracer: cfg.tracer, }, + mfaAuthCache: cfg.mfaFuncCache, }, nil } @@ -310,6 +337,8 @@ type CommandHandlerConfig struct { LocalAuthProvider agentless.AuthProvider // tracer is used to create spans tracer oteltrace.Tracer + // mfaFuncCache is used to cache the MFA auth method + mfaFuncCache mfaFuncCache } // CheckAndSetDefaults checks and sets default values. @@ -348,11 +377,19 @@ func (t *CommandHandlerConfig) CheckAndSetDefaults() error { return trace.BadParameter("LocalAuthProvider must be provided") } + if t.mfaFuncCache == nil { + return trace.BadParameter("mfaFuncCache must be provided") + } + t.tracer = t.TracerProvider.Tracer("webcommand") return nil } +// mfaFuncCache is a function type that caches the result of a function that +// returns a list of ssh.AuthMethods. +type mfaFuncCache func(func() ([]ssh.AuthMethod, error)) ([]ssh.AuthMethod, error) + // commandHandler is a handler for executing commands on a remote node. type commandHandler struct { sshBaseHandler @@ -362,6 +399,11 @@ type commandHandler struct { // ws a raw websocket connection to the client. ws WSConn + + // mfaAuthCache is a function that caches the result of a function that + // returns a list of ssh.AuthMethods. It is used to cache the result of + // the MFA challenge. + mfaAuthCache mfaFuncCache } // sendError sends an error message to the client using the provided websocket. @@ -447,14 +489,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl ctx, span := t.tracer.Start(ctx, "commandHandler/streamOutput") defer span.End() - mfaAuth := func(ctx context.Context, ws WSConn, tc *client.TeleportClient, - accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator, - ) (*client.NodeClient, error) { - return nil, trace.NotImplemented("MFA is not supported for command execution") - } - - //TODO(jakule): Implement MFA support - nc, err := t.connectToHost(ctx, t.ws, tc, mfaAuth) + nc, err := t.connectToHost(ctx, t.ws, tc, t.connectToNodeWithMFA) if err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failure connecting to host") t.writeError(err) @@ -482,6 +517,26 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl t.log.Debug("Sent close event to web client.") } +// connectToNodeWithMFA attempts to perform the mfa ceremony and then dial the +// host with the retrieved single use certs. +// If called multiple times, the mfa ceremony will only be performed once. +func (t *commandHandler) connectToNodeWithMFA(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { + authMethods, err := t.mfaAuthCache(func() ([]ssh.AuthMethod, error) { + // perform mfa ceremony and retrieve new certs + authMethods, err := t.issueSessionMFACerts(ctx, tc, t.stream) + if err != nil { + return nil, trace.Wrap(err) + } + + return authMethods, nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + + return t.connectToNodeWithMFABase(ctx, ws, tc, accessChecker, getAgent, signer, authMethods) +} + // Close is no-op as we never want to close the connection to the client. // Connection should be closed in the handler when it was created. func (t *commandHandler) Close() error { diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 66de7829fa3c2..69afce3bc3eb4 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -264,7 +264,7 @@ func desktopTLSConfig(ctx context.Context, ws *websocket.Conn, pc *client.ProxyC TLSCert: sessCtx.cfg.Session.GetTLSCert(), WindowsDesktopCerts: make(map[string][]byte), }, - }, promptMFAChallenge(stream, tdpMFACodec{})) + }, promptMFAChallenge(&stream.WSStream, tdpMFACodec{})) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 9a334e5361f62..d87f898d24acb 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -467,7 +467,7 @@ func (t *TerminalHandler) makeClient(ctx context.Context, ws *websocket.Conn) (* // used to access nodes which require per-session mfa. The ceremony is performed directly // to make use of the authProvider already established for the session instead of leveraging // the TeleportClient which would require dialing the auth server a second time. -func (t *TerminalHandler) issueSessionMFACerts(ctx context.Context, tc *client.TeleportClient) ([]ssh.AuthMethod, error) { +func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.TeleportClient, wsStream *WSStream) ([]ssh.AuthMethod, error) { ctx, span := t.tracer.Start(ctx, "terminal/issueSessionMFACerts") defer span.End() @@ -550,7 +550,7 @@ func (t *TerminalHandler) issueSessionMFACerts(ctx context.Context, tc *client.T } span.AddEvent("prompting user with mfa challenge") - assertion, err := promptMFAChallenge(t.stream, protobufMFACodec{})(ctx, tc.WebProxyAddr, challenge) + assertion, err := promptMFAChallenge(wsStream, protobufMFACodec{})(ctx, tc.WebProxyAddr, challenge) if err != nil { return nil, trace.Wrap(err) } @@ -589,7 +589,7 @@ func (t *TerminalHandler) issueSessionMFACerts(ctx context.Context, tc *client.T } func promptMFAChallenge( - stream *TerminalStream, + stream *WSStream, codec mfaCodec, ) client.PromptMFAChallengeHandler { return func(ctx context.Context, proxyAddr string, c *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { @@ -783,11 +783,17 @@ func (t *sshBaseHandler) connectToNode(ctx context.Context, ws WSConn, tc *clien // host with the retrieved single use certs. func (t *TerminalHandler) connectToNodeWithMFA(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { // perform mfa ceremony and retrieve new certs - authMethods, err := t.issueSessionMFACerts(ctx, tc) + authMethods, err := t.issueSessionMFACerts(ctx, tc, &t.stream.WSStream) if err != nil { return nil, trace.Wrap(err) } + return t.connectToNodeWithMFABase(ctx, ws, tc, accessChecker, getAgent, signer, authMethods) +} + +// connectToNodeWithMFABase attempts to dial the host with the provided auth +// methods. +func (t *sshBaseHandler) connectToNodeWithMFABase(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator, authMethods []ssh.AuthMethod) (*client.NodeClient, error) { sshConfig := &ssh.ClientConfig{ User: tc.HostLogin, Auth: authMethods, diff --git a/web/packages/teleport/src/Assist/Chat/ChatItem/Action/RunAction.tsx b/web/packages/teleport/src/Assist/Chat/ChatItem/Action/RunAction.tsx index fc0b5ff3adf07..f12a2fbaadd1d 100644 --- a/web/packages/teleport/src/Assist/Chat/ChatItem/Action/RunAction.tsx +++ b/web/packages/teleport/src/Assist/Chat/ChatItem/Action/RunAction.tsx @@ -15,7 +15,7 @@ * limitations under the License. */ -import React, { useEffect, useRef, useState } from 'react'; +import React, { useCallback, useState } from 'react'; import styled from 'styled-components'; import { useParams } from 'react-router'; @@ -27,6 +27,11 @@ import { ExecuteRemoteCommandContent } from 'teleport/Assist/services/messages'; import { MessageTypeEnum, Protobuf } from 'teleport/lib/term/protobuf'; import { Dots } from 'teleport/Assist/Dots'; import cfg from 'teleport/config'; +import { WebauthnAssertionResponse } from 'teleport/services/auth'; +import useWebAuthn from 'teleport/lib/useWebAuthn'; +import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import AuthnDialog from 'teleport/components/AuthnDialog'; +import { TermEvent } from 'teleport/lib/term/enums'; interface RunCommandProps { actions: ExecuteRemoteCommandContent; @@ -70,6 +75,82 @@ interface RawPayload { payload: string; } +class assistClient extends EventEmitterWebAuthnSender { + private readonly ws: WebSocket; + readonly proto: Protobuf = new Protobuf(); + readonly encoder = new window.TextEncoder(); + + constructor( + url: string, + setState: React.Dispatch> + ) { + super(); + + this.ws = new WebSocket(url); + this.ws.binaryType = 'arraybuffer'; + + this.ws.onmessage = event => { + const uintArray = new Uint8Array(event.data); + const msg = this.proto.decode(uintArray); + + switch (msg.type) { + case MessageTypeEnum.RAW: + const data = JSON.parse(msg.payload) as RawPayload; + const payload = atob(data.payload); + + setState(state => { + const results = state.find(node => node.nodeId == data.node_id); + if (!results) { + state.push({ + nodeId: data.node_id, + status: RunActionStatus.Connecting, + }); + } + + return state.map(item => { + if (item.nodeId === data.node_id) { + if (!item.stdout) { + item.stdout = ''; + } + return { + ...item, + status: RunActionStatus.Finished, + stdout: item.stdout + payload, + }; + } + + return item; + }); + }); + + break; + case MessageTypeEnum.ERROR: + console.error(msg.payload); + break; + case MessageTypeEnum.WEBAUTHN_CHALLENGE: + this.emit(TermEvent.WEBAUTHN_CHALLENGE, msg.payload); + break; + } + }; + } + + sendWebAuthn(data: WebauthnAssertionResponse) { + const msg = this.encoder.encode(JSON.stringify(data)); + this.send(msg); + } + + send(data) { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN || !data) { + console.warn('websocket unavailable', this.ws, data); + return; + } + + const msg = this.proto.encodeRawMessage(data); + const bytearray = new Uint8Array(msg); + this.ws.send(bytearray.buffer); + } +} + export function RunCommand(props: RunCommandProps) { const { clusterId } = useStickyClusterId(); const urlParams = useParams<{ conversationId: string }>(); @@ -91,65 +172,34 @@ export function RunCommand(props: RunCommandProps) { execParams ); - const websocket = useRef(null); - const protoRef = useRef(null); - - useEffect(() => { - if (!websocket.current) { - const proto = new Protobuf(); - const ws = new WebSocket(url); - ws.binaryType = 'arraybuffer'; - - ws.onmessage = event => { - const uintArray = new Uint8Array(event.data); - const msg = proto.decode(uintArray); - - switch (msg.type) { - case MessageTypeEnum.RAW: - const data = JSON.parse(msg.payload) as RawPayload; - const payload = atob(data.payload); - - setState(state => { - const results = state.find(node => node.nodeId == data.node_id); - if (!results) { - state.push({ - nodeId: data.node_id, - status: RunActionStatus.Connecting, - }); - } - - const s = state.map(item => { - if (item.nodeId === data.node_id) { - if (!item.stdout) { - item.stdout = ''; - } - return { - ...item, - status: RunActionStatus.Finished, - stdout: item.stdout + payload, - }; - } + const [assistClt] = useState(() => new assistClient(url, setState)); + const webauthn = useWebAuthn(assistClt); - return item; - }); - - return s; - }); - - break; - } + const cancelCallback = useCallback(() => { + webauthn.setState(prevState => { + return { + ...prevState, + requested: false, }; - - protoRef.current = proto; - websocket.current = ws; - } - }, []); + }); + }, [webauthn]); const nodes = state.map((item, index) => ( )); - return
{nodes}
; + return ( + <> + {webauthn.requested && ( + + )} +
{nodes}
+ + ); } interface NodeOutputProps {