Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1185,19 +1185,13 @@ func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Para

// TODO: This part should be removed once the plugin support is added to OSS.
if proxyConfig.AssistEnabled {
// TODO(jakule): Currently assist is disabled when per-session MFA is enabled as this part is not implemented.
authPreference, err := h.cfg.ProxyClient.GetAuthPreference(r.Context())
if err != nil {
return webclient.AuthenticationSettings{}, trace.Wrap(err)
}
mfaRequired := authPreference.GetRequireMFAType() != types.RequireMFAType_OFF
enabled, err := h.cfg.ProxyClient.IsAssistEnabled(r.Context())
if err != nil {
return webclient.AuthenticationSettings{}, trace.Wrap(err)
}

// disable if per-session MFA is enabled and it's ok by the auth
proxyConfig.AssistEnabled = enabled.Enabled && !mfaRequired
// disable if auth doesn't support assist
proxyConfig.AssistEnabled = enabled.Enabled
}

pr, err := h.cfg.ProxyClient.Ping(r.Context())
Expand Down
71 changes: 63 additions & 8 deletions lib/web/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"net/http"
"strings"
"sync"
"time"

"github.com/gogo/protobuf/proto"
Expand All @@ -33,6 +34,7 @@ import (
"github.com/julienschmidt/httprouter"
"github.com/sirupsen/logrus"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/ssh"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/gravitational/teleport"
Expand Down Expand Up @@ -185,6 +187,8 @@ func (h *Handler) executeCommand(

h.log.Debugf("Found %d hosts to run Assist command %q on.", len(hosts), req.Command)

mfaCacheFn := getMFACacheFn()

for _, host := range hosts {
err := func() error {
sessionData, err := h.generateCommandSession(&host, req.Login, clusterName, sessionCtx.cfg.User)
Expand All @@ -206,6 +210,7 @@ func (h *Handler) executeCommand(
Router: h.cfg.Router,
TracerProvider: h.cfg.TracerProvider,
LocalAuthProvider: h.auth.accessPoint,
mfaFuncCache: mfaCacheFn,
}

handler, err := newCommandHandler(ctx, commandHandlerConfig)
Expand Down Expand Up @@ -257,6 +262,27 @@ func (h *Handler) executeCommand(
return nil, nil
}

// getMFACacheFn returns a function that caches the result of the given
// get function. The cache is protected by a mutex, so it is safe to call
// the returned function from multiple goroutines.
func getMFACacheFn() mfaFuncCache {
var mutex sync.Mutex
var authMethods []ssh.AuthMethod

return func(issueMfaAuthFn func() ([]ssh.AuthMethod, error)) ([]ssh.AuthMethod, error) {
mutex.Lock()
defer mutex.Unlock()

if authMethods != nil {
return authMethods, nil
}

var err error
authMethods, err = issueMfaAuthFn()
return authMethods, trace.Wrap(err)
}
}

func newCommandHandler(ctx context.Context, cfg CommandHandlerConfig) (*commandHandler, error) {
err := cfg.CheckAndSetDefaults()
if err != nil {
Expand All @@ -282,6 +308,7 @@ func newCommandHandler(ctx context.Context, cfg CommandHandlerConfig) (*commandH
localAuthProvider: cfg.LocalAuthProvider,
tracer: cfg.tracer,
},
mfaAuthCache: cfg.mfaFuncCache,
}, nil
}

Expand Down Expand Up @@ -310,6 +337,8 @@ type CommandHandlerConfig struct {
LocalAuthProvider agentless.AuthProvider
// tracer is used to create spans
tracer oteltrace.Tracer
// mfaFuncCache is used to cache the MFA auth method
mfaFuncCache mfaFuncCache
}

// CheckAndSetDefaults checks and sets default values.
Expand Down Expand Up @@ -348,11 +377,19 @@ func (t *CommandHandlerConfig) CheckAndSetDefaults() error {
return trace.BadParameter("LocalAuthProvider must be provided")
}

if t.mfaFuncCache == nil {
return trace.BadParameter("mfaFuncCache must be provided")
}

t.tracer = t.TracerProvider.Tracer("webcommand")

return nil
}

// mfaFuncCache is a function type that caches the result of a function that
// returns a list of ssh.AuthMethods.
type mfaFuncCache func(func() ([]ssh.AuthMethod, error)) ([]ssh.AuthMethod, error)

// commandHandler is a handler for executing commands on a remote node.
type commandHandler struct {
sshBaseHandler
Expand All @@ -362,6 +399,11 @@ type commandHandler struct {

// ws a raw websocket connection to the client.
ws WSConn

// mfaAuthCache is a function that caches the result of a function that
// returns a list of ssh.AuthMethods. It is used to cache the result of
// the MFA challenge.
mfaAuthCache mfaFuncCache
}

// sendError sends an error message to the client using the provided websocket.
Expand Down Expand Up @@ -447,14 +489,7 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl
ctx, span := t.tracer.Start(ctx, "commandHandler/streamOutput")
defer span.End()

mfaAuth := func(ctx context.Context, ws WSConn, tc *client.TeleportClient,
accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator,
) (*client.NodeClient, error) {
return nil, trace.NotImplemented("MFA is not supported for command execution")
}

//TODO(jakule): Implement MFA support
nc, err := t.connectToHost(ctx, t.ws, tc, mfaAuth)
nc, err := t.connectToHost(ctx, t.ws, tc, t.connectToNodeWithMFA)
if err != nil {
t.log.WithError(err).Warn("Unable to stream terminal - failure connecting to host")
t.writeError(err)
Expand Down Expand Up @@ -482,6 +517,26 @@ func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportCl
t.log.Debug("Sent close event to web client.")
}

// connectToNodeWithMFA attempts to perform the mfa ceremony and then dial the
// host with the retrieved single use certs.
// If called multiple times, the mfa ceremony will only be performed once.
func (t *commandHandler) connectToNodeWithMFA(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) {
authMethods, err := t.mfaAuthCache(func() ([]ssh.AuthMethod, error) {
// perform mfa ceremony and retrieve new certs
authMethods, err := t.issueSessionMFACerts(ctx, tc, t.stream)
if err != nil {
return nil, trace.Wrap(err)
}

return authMethods, nil
})
if err != nil {
return nil, trace.Wrap(err)
}

return t.connectToNodeWithMFABase(ctx, ws, tc, accessChecker, getAgent, signer, authMethods)
}

// Close is no-op as we never want to close the connection to the client.
// Connection should be closed in the handler when it was created.
func (t *commandHandler) Close() error {
Expand Down
2 changes: 1 addition & 1 deletion lib/web/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func desktopTLSConfig(ctx context.Context, ws *websocket.Conn, pc *client.ProxyC
TLSCert: sessCtx.cfg.Session.GetTLSCert(),
WindowsDesktopCerts: make(map[string][]byte),
},
}, promptMFAChallenge(stream, tdpMFACodec{}))
}, promptMFAChallenge(&stream.WSStream, tdpMFACodec{}))
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
14 changes: 10 additions & 4 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func (t *TerminalHandler) makeClient(ctx context.Context, ws *websocket.Conn) (*
// used to access nodes which require per-session mfa. The ceremony is performed directly
// to make use of the authProvider already established for the session instead of leveraging
// the TeleportClient which would require dialing the auth server a second time.
func (t *TerminalHandler) issueSessionMFACerts(ctx context.Context, tc *client.TeleportClient) ([]ssh.AuthMethod, error) {
func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.TeleportClient, wsStream *WSStream) ([]ssh.AuthMethod, error) {
ctx, span := t.tracer.Start(ctx, "terminal/issueSessionMFACerts")
defer span.End()

Expand Down Expand Up @@ -550,7 +550,7 @@ func (t *TerminalHandler) issueSessionMFACerts(ctx context.Context, tc *client.T
}

span.AddEvent("prompting user with mfa challenge")
assertion, err := promptMFAChallenge(t.stream, protobufMFACodec{})(ctx, tc.WebProxyAddr, challenge)
assertion, err := promptMFAChallenge(wsStream, protobufMFACodec{})(ctx, tc.WebProxyAddr, challenge)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -589,7 +589,7 @@ func (t *TerminalHandler) issueSessionMFACerts(ctx context.Context, tc *client.T
}

func promptMFAChallenge(
stream *TerminalStream,
stream *WSStream,
codec mfaCodec,
) client.PromptMFAChallengeHandler {
return func(ctx context.Context, proxyAddr string, c *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
Expand Down Expand Up @@ -783,11 +783,17 @@ func (t *sshBaseHandler) connectToNode(ctx context.Context, ws WSConn, tc *clien
// host with the retrieved single use certs.
func (t *TerminalHandler) connectToNodeWithMFA(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) {
// perform mfa ceremony and retrieve new certs
authMethods, err := t.issueSessionMFACerts(ctx, tc)
authMethods, err := t.issueSessionMFACerts(ctx, tc, &t.stream.WSStream)
if err != nil {
return nil, trace.Wrap(err)
}

return t.connectToNodeWithMFABase(ctx, ws, tc, accessChecker, getAgent, signer, authMethods)
}

// connectToNodeWithMFABase attempts to dial the host with the provided auth
// methods.
func (t *sshBaseHandler) connectToNodeWithMFABase(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator, authMethods []ssh.AuthMethod) (*client.NodeClient, error) {
sshConfig := &ssh.ClientConfig{
User: tc.HostLogin,
Auth: authMethods,
Expand Down
Loading