Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import (
userpreferencespb "github.com/gravitational/teleport/api/gen/proto/go/userpreferences/v1"
"github.com/gravitational/teleport/api/internalutils/stream"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
Expand Down Expand Up @@ -630,9 +631,9 @@ type Config struct {
// PROXYHeaderGetter returns signed PROXY header that is sent to allow Proxy to propagate client's real IP to the
// auth server from the Proxy's web server, when we create user's client for the web session.
PROXYHeaderGetter PROXYHeaderGetter
// PromptAdminRequestMFA is used to prompt the user for MFA on admin requests when needed.
// MFAPromptConstructor is used to create MFA prompts when needed.
// If nil, the client will not prompt for MFA.
PromptAdminRequestMFA func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
MFAPromptConstructor mfa.PromptConstructor
}

// CheckAndSetDefaults checks and sets default config values.
Expand Down Expand Up @@ -694,6 +695,11 @@ func (c *Client) GetConnection() *grpc.ClientConn {
return c.conn
}

// SetMFAPromptConstructor sets the MFA prompt constructor for this client.
func (c *Client) SetMFAPromptConstructor(pc mfa.PromptConstructor) {
c.c.MFAPromptConstructor = pc
}

// Close closes the Client connection to the auth server.
func (c *Client) Close() error {
if c.setClosed() && c.conn != nil {
Expand Down
7 changes: 4 additions & 3 deletions api/client/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ import (
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/mfa"
)

// performMFACeremony retrieves an MFA challenge from the server, prompts the
// user to answer the challenge, and returns the resulting MFA response.
func (c *Client) performMFACeremony(ctx context.Context) (*proto.MFAAuthenticateResponse, error) {
if c.c.PromptAdminRequestMFA == nil {
func (c *Client) performMFACeremony(ctx context.Context, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
if c.c.MFAPromptConstructor == nil {
return nil, trace.BadParameter("missing PromptAdminRequestMFA field, client cannot perform MFA ceremony")
}

Expand All @@ -38,7 +39,7 @@ func (c *Client) performMFACeremony(ctx context.Context) (*proto.MFAAuthenticate
return nil, trace.Wrap(err)
}

resp, err := c.c.PromptAdminRequestMFA(ctx, chal)
resp, err := c.c.MFAPromptConstructor(promptOpts...).Run(ctx, chal)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
19 changes: 14 additions & 5 deletions api/client/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/mfa"
)

const (
Expand Down Expand Up @@ -62,11 +63,8 @@ func TestPerformMFACeremony(t *testing.T) {
}

cfg := server.clientCfg()
cfg.PromptAdminRequestMFA = func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
if chal.TOTP != nil {
return mfaTestResp, nil
}
return nil, trace.BadParameter("expected TOTP challenge")
cfg.MFAPromptConstructor = func(opts ...mfa.PromptOpt) mfa.Prompt {
return &fakeMFAPrompt{mfaTestResp}
}

clt, err := New(ctx, cfg)
Expand All @@ -76,3 +74,14 @@ func TestPerformMFACeremony(t *testing.T) {
require.NoError(t, err)
require.Equal(t, mfaTestResp.Response, resp.Response)
}

type fakeMFAPrompt struct {
totpResp *proto.MFAAuthenticateResponse
}

func (p *fakeMFAPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
if chal.TOTP != nil {
return p.totpResp, nil
}
return nil, trace.BadParameter("expected TOTP challenge")
}
94 changes: 94 additions & 0 deletions api/mfa/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
Copyright 2023 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mfa

import (
"context"
"fmt"

"github.com/gravitational/teleport/api/client/proto"
)

// Prompt is an MFA prompt.
type Prompt interface {
// Run prompts the user to complete an MFA authentication challenge.
Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
}

// PromptFunc is a function wrapper that implements the Prompt interface.
type PromptFunc func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)

// Run prompts the user to complete an MFA authentication challenge.
func (f PromptFunc) Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return f(ctx, chal)
}

// PromptConstructor is a function that creates a new MFA prompt.
type PromptConstructor func(...PromptOpt) Prompt

// PromptConfig contains common mfa prompt config options.
type PromptConfig struct {
// PromptReason is an optional message to share with the user before an MFA Prompt.
// It is intended to provide context about why the user is being prompted where it may
// not be obvious, such as for admin actions or per-session MFA.
PromptReason string
// DeviceType is an optional device description to emphasize during the prompt.
DeviceType DeviceDescriptor
// Quiet suppresses users prompts.
Quiet bool
}

// DeviceDescriptor is a descriptor for a device, such as "registered".
type DeviceDescriptor string

// DeviceDescriptorRegistered is a registered device.
const DeviceDescriptorRegistered = "registered"

// PromptOpt applies configuration options to a prompt.
type PromptOpt func(*PromptConfig)

// WithQuiet sets the prompt's Quiet field.
func WithQuiet() PromptOpt {
return func(cfg *PromptConfig) {
cfg.Quiet = true
}
}

// WithPromptReason sets the prompt's PromptReason field.
func WithPromptReason(hint string) PromptOpt {
return func(cfg *PromptConfig) {
cfg.PromptReason = hint
}
}

// WithPromptReasonAdminAction sets the prompt's PromptReason field to a standard admin action message.
func WithPromptReasonAdminAction(actionName string) PromptOpt {
adminMFAPromptReason := fmt.Sprintf("MFA is required for admin-level API request: %q", actionName)
return WithPromptReason(adminMFAPromptReason)
}

// WithPromptReasonSessionMFA sets the prompt's PromptReason field to a standard session mfa message.
func WithPromptReasonSessionMFA(serviceType, serviceName string) PromptOpt {
return WithPromptReason(fmt.Sprintf("MFA is required to access %s %q", serviceType, serviceName))
}

// WithPromptDeviceType sets the prompt's DeviceType field.
func WithPromptDeviceType(deviceType DeviceDescriptor) PromptOpt {
return func(cfg *PromptConfig) {
cfg.DeviceType = deviceType
}
}
4 changes: 2 additions & 2 deletions api/utils/grpc/interceptors/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ import (

// RetryWithMFAUnaryInterceptor intercepts a GRPC client unary call to check if the
// error indicates that the client should retry with MFA verification.
func RetryWithMFAUnaryInterceptor(mfaCeremony func(ctx context.Context) (*proto.MFAAuthenticateResponse, error)) grpc.UnaryClientInterceptor {
func RetryWithMFAUnaryInterceptor(mfaCeremony func(ctx context.Context, opts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error)) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
err := invoker(ctx, method, req, reply, cc, opts...)
if !errors.Is(trail.FromGRPC(err), &mfa.ErrAdminActionMFARequired) {
return err
}

mfaResp, ceremonyErr := mfaCeremony(ctx)
mfaResp, ceremonyErr := mfaCeremony(ctx, mfa.WithPromptReasonAdminAction(method))
if ceremonyErr != nil {
return trace.NewAggregate(trail.FromGRPC(err), ceremonyErr)
}
Expand Down
4 changes: 2 additions & 2 deletions api/utils/grpc/interceptors/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestRetryWithMFA(t *testing.T) {

t.Run("with interceptor", func(t *testing.T) {
t.Run("ok mfa ceremony", func(t *testing.T) {
okMFACeremony := func(ctx context.Context) (*proto.MFAAuthenticateResponse, error) {
okMFACeremony := func(ctx context.Context, opts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
return &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{
Expand All @@ -125,7 +125,7 @@ func TestRetryWithMFA(t *testing.T) {

t.Run("nok mfa ceremony", func(t *testing.T) {
mfaCeremonyErr := trace.BadParameter("client does not support mfa")
nokMFACeremony := func(ctx context.Context) (*proto.MFAAuthenticateResponse, error) {
nokMFACeremony := func(ctx context.Context, opts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
return nil, mfaCeremonyErr
}
conn, err := grpc.Dial(
Expand Down
16 changes: 8 additions & 8 deletions lib/auth/authclient/authclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (

"github.com/gravitational/teleport/api/breaker"
apiclient "github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/client/webclient"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/utils"
Expand All @@ -50,14 +50,14 @@ type Config struct {
CircuitBreakerConfig breaker.Config
// DialTimeout determines how long to wait for dialing to succeed before aborting.
DialTimeout time.Duration
// PromptAdminRequestMFA is used to prompt the user for MFA on admin requests when needed.
// MFAPromptConstructor is used to create MFA prompts when needed.
// If nil, the client will not prompt for MFA.
PromptAdminRequestMFA func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
MFAPromptConstructor mfa.PromptConstructor
}

// Connect creates a valid client connection to the auth service. It may
// connect directly to the auth server, or tunnel through the proxy.
func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) {
func Connect(ctx context.Context, cfg *Config) (*auth.Client, error) {
cfg.Log.Debugf("Connecting to: %v.", cfg.AuthServers)

directClient, err := connectViaAuthDirect(ctx, cfg)
Expand All @@ -83,7 +83,7 @@ func Connect(ctx context.Context, cfg *Config) (auth.ClientI, error) {
)
}

func connectViaAuthDirect(ctx context.Context, cfg *Config) (auth.ClientI, error) {
func connectViaAuthDirect(ctx context.Context, cfg *Config) (*auth.Client, error) {
// Try connecting to the auth server directly over TLS.
directClient, err := auth.NewClient(apiclient.Config{
Addrs: utils.NetAddrsToStrings(cfg.AuthServers),
Expand All @@ -93,7 +93,7 @@ func connectViaAuthDirect(ctx context.Context, cfg *Config) (auth.ClientI, error
CircuitBreakerConfig: cfg.CircuitBreakerConfig,
InsecureAddressDiscovery: cfg.TLS.InsecureSkipVerify,
DialTimeout: cfg.DialTimeout,
PromptAdminRequestMFA: cfg.PromptAdminRequestMFA,
MFAPromptConstructor: cfg.MFAPromptConstructor,
})
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -109,7 +109,7 @@ func connectViaAuthDirect(ctx context.Context, cfg *Config) (auth.ClientI, error
return directClient, nil
}

func connectViaProxyTunnel(ctx context.Context, cfg *Config) (auth.ClientI, error) {
func connectViaProxyTunnel(ctx context.Context, cfg *Config) (*auth.Client, error) {
// If direct dial failed, we may have a proxy address in
// cfg.AuthServers. Try connecting to the reverse tunnel
// endpoint and make a client over that.
Expand Down Expand Up @@ -145,7 +145,7 @@ func connectViaProxyTunnel(ctx context.Context, cfg *Config) (auth.ClientI, erro
Credentials: []apiclient.Credentials{
apiclient.LoadTLS(cfg.TLS),
},
PromptAdminRequestMFA: cfg.PromptAdminRequestMFA,
MFAPromptConstructor: cfg.MFAPromptConstructor,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
16 changes: 6 additions & 10 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import (
apidefaults "github.com/gravitational/teleport/api/defaults"
devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
kubeproto "github.com/gravitational/teleport/api/gen/proto/go/teleport/kube/v1"
"github.com/gravitational/teleport/api/mfa"
apitracing "github.com/gravitational/teleport/api/observability/tracing"
tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
"github.com/gravitational/teleport/api/profile"
Expand All @@ -69,7 +70,6 @@ import (
"github.com/gravitational/teleport/lib/auth/touchid"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/client/mfa"
"github.com/gravitational/teleport/lib/client/terminal"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/devicetrust"
Expand Down Expand Up @@ -465,10 +465,6 @@ type Config struct {
// Defaults to [dtenroll.AutoEnroll].
dtAutoEnroll dtAutoEnrollFunc

// PromptMFAFunc allows tests to override the default MFA prompt function.
// Defaults to [mfa.NewPrompt().Run].
PromptMFAFunc PromptMFAFunc

// WebauthnLogin allows tests to override the Webauthn Login func.
// Defaults to [wancli.Login].
WebauthnLogin WebauthnLoginFunc
Expand Down Expand Up @@ -2729,7 +2725,7 @@ func (tc *TeleportClient) runCommandOnNodes(ctx context.Context, clt *ClusterCli
}

if tc.SSHLogDir != "" {
if err := os.MkdirAll(tc.SSHLogDir, 0700); err != nil {
if err := os.MkdirAll(tc.SSHLogDir, 0o700); err != nil {
return trace.ConvertSystemError(err)
}
}
Expand Down Expand Up @@ -2963,7 +2959,7 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (*ClusterClient,
}

authClientCfg := pclt.ClientConfig(ctx, cluster)
authClientCfg.PromptAdminRequestMFA = tc.NewMFAPrompt(mfa.WithHintBeforePrompt(mfa.AdminMFAHintBeforePrompt))
authClientCfg.MFAPromptConstructor = tc.NewMFAPrompt
authClient, err := auth.NewClient(authClientCfg)
if err != nil {
return nil, trace.NewAggregate(err, pclt.Close())
Expand Down Expand Up @@ -3824,7 +3820,7 @@ func (tc *TeleportClient) mfaLocalLoginWeb(ctx context.Context, priv *keys.Priva
SSHLogin: sshLogin,
User: tc.Username,
Password: password,
PromptMFA: tc.PromptMFA,
PromptMFA: tc.NewMFAPrompt(),
})
return clt, session, trace.Wrap(err)
}
Expand Down Expand Up @@ -4123,7 +4119,7 @@ func (tc *TeleportClient) mfaLocalLogin(ctx context.Context, priv *keys.PrivateK
SSHLogin: sshLogin,
User: tc.Username,
Password: password,
PromptMFA: tc.PromptMFA,
PromptMFA: tc.NewMFAPrompt(),
})

return response, trace.Wrap(err)
Expand Down Expand Up @@ -5174,7 +5170,7 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste
},
ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired,
InsecureAddressDiscovery: tc.InsecureSkipVerify,
PromptAdminRequestMFA: tc.NewMFAPrompt(mfa.WithHintBeforePrompt(mfa.AdminMFAHintBeforePrompt)),
MFAPromptConstructor: tc.NewMFAPrompt,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
Loading