Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion api/utils/grpc/interceptors/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ func WithMFAUnaryInterceptor(mfaCeremony mfa.MFACeremony) grpc.UnaryClientInterc
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// Check for MFA response passed through the context.
if mfaResp, err := mfa.MFAResponseFromContext(ctx); err == nil {
return invoker(ctx, method, req, reply, cc, append(opts, mfa.WithCredentials(mfaResp))...)
// If we find an MFA response passed through the context, attach it to the
// request. Note: this may still fail if the MFA response allows reuse and
// the specified endpoint doesn't allow reuse. In this case, the client
// prompts for MFA again below.
opts = append(opts, mfa.WithCredentials(mfaResp))
} else if !trace.IsNotFound(err) {
return trace.Wrap(err)
}
Expand Down
124 changes: 117 additions & 7 deletions api/utils/grpc/interceptors/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,24 @@ import (
"github.com/gravitational/teleport/api/utils/grpc/interceptors"
)

const otpTestCode = "otp-test-code"
const (
otpTestCode = "otp-test-code"
otpTestCodeReusable = "otp-test-code-reusable"
)

type mfaService struct {
allowReuse bool
proto.UnimplementedAuthServiceServer
}

func (s *mfaService) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) {
if err := verifyMFAFromContext(ctx); err != nil {
if err := s.verifyMFAFromContext(ctx); err != nil {
return nil, trace.Wrap(err)
}
return &proto.PingResponse{}, nil
}

func verifyMFAFromContext(ctx context.Context) error {
func (s *mfaService) verifyMFAFromContext(ctx context.Context) error {
mfaResp, err := mfa.CredentialsFromContext(ctx)
if err != nil {
// (In production consider logging err, so we don't swallow it silently.)
Expand All @@ -53,14 +57,20 @@ func verifyMFAFromContext(ctx context.Context) error {

switch r := mfaResp.Response.(type) {
case *proto.MFAAuthenticateResponse_TOTP:
if r.TOTP.Code != otpTestCode {
return trace.AccessDenied("failed MFA verification")
switch r.TOTP.Code {
case otpTestCode:
return nil
case otpTestCodeReusable:
if s.allowReuse {
return nil
}
fallthrough
default:
return trace.Wrap(&mfa.ErrAdminActionMFARequired)
}
default:
return trace.BadParameter("unexpected mfa response type %T", r)
}

return nil
}

// TestGRPCErrorWrapping tests the error wrapping capability of the client
Expand Down Expand Up @@ -169,3 +179,103 @@ func TestRetryWithMFA(t *testing.T) {
})
})
}

func TestRetryWithMFA_Reuse(t *testing.T) {
t.Parallel()
ctx := context.Background()

mtlsConfig := mtls.NewConfig(t)
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)

mfaService := &mfaService{}
server := grpc.NewServer(
grpc.Creds(credentials.NewTLS(mtlsConfig.ServerTLS)),
grpc.ChainUnaryInterceptor(interceptors.GRPCServerUnaryErrorInterceptor),
)
proto.RegisterAuthServiceServer(server, mfaService)
go func() {
server.Serve(listener)
}()
defer server.Stop()

okMFACeremony := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
return &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{
Code: otpTestCode,
},
},
}, nil
}

okMFACeremonyAllowReuse := func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
return &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{
Code: otpTestCodeReusable,
},
},
}, nil
}

t.Run("ok allow reuse", func(t *testing.T) {
mfaService.allowReuse = true
conn, err := grpc.Dial(
listener.Addr().String(),
grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
grpc.WithChainUnaryInterceptor(
interceptors.WithMFAUnaryInterceptor(okMFACeremonyAllowReuse),
interceptors.GRPCClientUnaryErrorInterceptor,
),
)
require.NoError(t, err)
defer conn.Close()

client := proto.NewAuthServiceClient(conn)
_, err = client.Ping(ctx, &proto.PingRequest{})
assert.NoError(t, err)
})

t.Run("nok disallow reuse", func(t *testing.T) {
mfaService.allowReuse = false
conn, err := grpc.Dial(
listener.Addr().String(),
grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
grpc.WithChainUnaryInterceptor(
interceptors.WithMFAUnaryInterceptor(okMFACeremonyAllowReuse),
interceptors.GRPCClientUnaryErrorInterceptor,
),
)
require.NoError(t, err)
defer conn.Close()

client := proto.NewAuthServiceClient(conn)
_, err = client.Ping(ctx, &proto.PingRequest{})
assert.ErrorIs(t, err, &mfa.ErrAdminActionMFARequired, "Ping error mismatch")
})

t.Run("ok disallow reuse, retry with one-shot mfa", func(t *testing.T) {
mfaService.allowReuse = false
conn, err := grpc.Dial(
listener.Addr().String(),
grpc.WithTransportCredentials(credentials.NewTLS(mtlsConfig.ClientTLS)),
grpc.WithChainUnaryInterceptor(
interceptors.WithMFAUnaryInterceptor(okMFACeremony),
interceptors.GRPCClientUnaryErrorInterceptor,
),
)
require.NoError(t, err)
defer conn.Close()

// Pass reusable MFA through the context. The interceptor should
// catch the resulting ErrAdminActionMFARequired and retry with
// a one-shot mfa challenge.
mfaResp, _ := okMFACeremony(ctx, nil)
ctx := mfa.ContextWithMFAResponse(ctx, mfaResp)

client := proto.NewAuthServiceClient(conn)
_, err = client.Ping(ctx, &proto.PingRequest{})
assert.NoError(t, err)
})
}
4 changes: 2 additions & 2 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -2060,7 +2060,7 @@ func (a *ServerWithRoles) GetTokens(ctx context.Context) ([]types.ProvisionToken
return nil, trace.Wrap(err)
}

if err := a.context.AuthorizeAdminAction(); err != nil {
if err := a.context.AuthorizeAdminActionAllowReusedMFA(); err != nil {
return nil, trace.Wrap(err)
}

Expand All @@ -2076,7 +2076,7 @@ func (a *ServerWithRoles) GetToken(ctx context.Context, token string) (types.Pro
}
}

if err := a.context.AuthorizeAdminAction(); err != nil {
if err := a.context.AuthorizeAdminActionAllowReusedMFA(); err != nil {
return nil, trace.Wrap(err)
}

Expand Down