Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ func TestListResources(t *testing.T) {
ResourceType: test.resourceType,
})
require.Error(t, err)
require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err))
})
}

Expand All @@ -523,7 +523,7 @@ func testGetResources[T types.ResourceWithLabels](t *testing.T, clt *Client, kin
ResourceType: kind,
})
require.Error(t, err)
require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err))

// Test getting a page of resources
page, err := GetResourcePage[T](ctx, clt, &proto.ListResourcesRequest{
Expand Down Expand Up @@ -631,7 +631,7 @@ func TestGetResourcesWithFilters(t *testing.T) {
ResourceType: test.resourceType,
})
require.Error(t, err)
require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err))

// Test getting all resources by chunks to handle limit exceeded.
resources, err := GetResourcesWithFilters(ctx, clt, proto.ListResourcesRequest{
Expand Down
43 changes: 39 additions & 4 deletions api/utils/grpc/interceptors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package interceptors

import (
"context"
"errors"
"io"

"github.com/gravitational/trace"
"github.com/gravitational/trace/trail"
Expand Down Expand Up @@ -48,12 +50,24 @@ type grpcClientStreamWrapper struct {

// SendMsg wraps around ClientStream.SendMsg
func (s *grpcClientStreamWrapper) SendMsg(m interface{}) error {
return trace.Unwrap(trail.FromGRPC(s.ClientStream.SendMsg(m)))
return wrapStreamErr(s.ClientStream.SendMsg(m))
}

// RecvMsg wraps around ClientStream.RecvMsg
func (s *grpcClientStreamWrapper) RecvMsg(m interface{}) error {
return trace.Unwrap(trail.FromGRPC(s.ClientStream.RecvMsg(m)))
return wrapStreamErr(s.ClientStream.RecvMsg(m))
}

func wrapStreamErr(err error) error {
switch {
case err == nil:
return nil
case errors.Is(err, io.EOF):
// Do not wrap io.EOF errors, they are often used as stop guards for streams.
return err
default:
return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(err))}
}
}

// GRPCServerUnaryErrorInterceptor is a gRPC unary server interceptor that
Expand All @@ -63,10 +77,31 @@ func GRPCServerUnaryErrorInterceptor(ctx context.Context, req interface{}, info
return resp, trace.Unwrap(trail.ToGRPC(err))
}

// RemoteError annotates server-side errors translated into trace by
// [GRPCClientUnaryErrorInterceptor] or [GRPCClientStreamErrorInterceptor].
type RemoteError struct {
// Err is the underlying error.
Err error
}

func (e *RemoteError) Error() string {
if e.Err == nil {
return ""
}
return e.Err.Error()
}

func (e *RemoteError) Unwrap() error {
return e.Err
}

// GRPCClientUnaryErrorInterceptor is a gRPC unary client interceptor that
// handles converting errors to the appropriate grpc status error.
func GRPCClientUnaryErrorInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return trace.Unwrap(trail.FromGRPC(invoker(ctx, method, req, reply, cc, opts...)))
if err := invoker(ctx, method, req, reply, cc, opts...); err != nil {
return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(err))}
}
return nil
}

// GRPCServerStreamErrorInterceptor is a gRPC server stream interceptor that
Expand All @@ -81,7 +116,7 @@ func GRPCServerStreamErrorInterceptor(srv interface{}, ss grpc.ServerStream, inf
func GRPCClientStreamErrorInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
s, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
return nil, trace.Unwrap(trail.ToGRPC(err))
return nil, &RemoteError{Err: trace.Unwrap(trail.ToGRPC(err))}
}
return &grpcClientStreamWrapper{s}, nil
}
37 changes: 24 additions & 13 deletions api/utils/grpc/interceptors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ package interceptors

import (
"context"
"errors"
"io"
"net"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
Expand Down Expand Up @@ -73,31 +74,41 @@ func TestGRPCErrorWrapping(t *testing.T) {

t.Run("unary interceptor", func(t *testing.T) {
resp, err := client.Ping(context.Background(), &proto.PingRequest{})
require.Nil(t, resp)
require.True(t, trace.IsNotFound(err))
require.Equal(t, err.Error(), "not found")
assert.Nil(t, resp, "resp is non-nil")
assert.True(t, trace.IsNotFound(err), "trace.IsNotFound failed: err=%v (%T)", err, trace.Unwrap(err))
assert.Equal(t, "not found", err.Error())
_, ok := err.(*trace.TraceErr)
require.False(t, ok, "client error should not include traces originating in the middleware")
assert.False(t, ok, "client error should not include traces originating in the middleware")
var remoteErr *RemoteError
assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError")
})

t.Run("stream interceptor", func(t *testing.T) {
stream, err := client.AddMFADevice(context.Background())
require.NoError(t, err)

// Give the server time to close the stream. This allows us to more
// consistently hit the io.EOF error.
time.Sleep(100 * time.Millisecond)

//nolint:staticcheck // SA1019. The specific stream used here doesn't matter.
sendErr := stream.Send(&proto.AddMFADeviceRequest{})

// io.EOF means the server closed the stream, which can
// happen depending in timing. In either case, it is
// still safe to recv from the stream and check for
// Expect either a success (unlikely because of the Sleep) or an unwrapped
// io.EOF error (meaning the server errored and closed the stream).
// In either case, it is still safe to recv from the stream and check for
// the already exists error.
if sendErr != nil && !errors.Is(sendErr, io.EOF) {
t.Fatalf("Unexpected error: %v", sendErr)
if sendErr != nil && sendErr != io.EOF /* == error comparison on purpose! */ {
t.Fatalf("Unexpected error: %q (%T)", sendErr, sendErr)
}

_, err = stream.Recv()
require.True(t, trace.IsAlreadyExists(err))
require.Equal(t, err.Error(), "already exists")
assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err))
assert.Equal(t, "already exists", err.Error())
_, ok := err.(*trace.TraceErr)
require.False(t, ok, "client error should not include traces originating in the middleware")
assert.False(t, ok, "client error should not include traces originating in the middleware")
assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err))
var remoteErr *RemoteError
assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError")
})
}
10 changes: 6 additions & 4 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2502,11 +2502,11 @@ func TestNodesCRUD(t *testing.T) {

// GetNode should fail if node name isn't provided
_, err = clt.GetNode(ctx, apidefaults.Namespace, "")
require.IsType(t, &trace.BadParameterError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err))

// GetNode should fail if namespace isn't provided
_, err = clt.GetNode(ctx, "", "node1")
require.IsType(t, &trace.BadParameterError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err))
})
})

Expand Down Expand Up @@ -4478,7 +4478,8 @@ func TestUpsertApplicationServerOrigin(t *testing.T) {

ctx = authz.ContextWithUser(parentCtx, admin.I)
_, err = client.UpsertApplicationServer(ctx, appServer)
require.ErrorIs(t, trace.BadParameter("only the Okta role can create app servers and apps with an Okta origin"), err)
require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err))
require.ErrorContains(t, err, "only the Okta role can create app servers and apps with an Okta origin")

// Okta origin should not work with instance and node roles.
client, err = server.NewClient(TestIdentity{
Expand All @@ -4494,7 +4495,8 @@ func TestUpsertApplicationServerOrigin(t *testing.T) {

ctx = authz.ContextWithUser(parentCtx, admin.I)
_, err = client.UpsertApplicationServer(ctx, appServer)
require.ErrorIs(t, trace.BadParameter("only the Okta role can create app servers and apps with an Okta origin"), err)
require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err))
require.ErrorContains(t, err, "only the Okta role can create app servers and apps with an Okta origin")

// Okta origin should work with Okta role in role field.
node := TestIdentity{
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2334,7 +2334,7 @@ func TestGenerateCerts(t *testing.T) {
Format: constants.CertificateFormatStandard,
})
require.Error(t, err)
require.IsType(t, &trace.AccessDeniedError{}, trace.Unwrap(err))
require.True(t, trace.IsAccessDenied(err), "trace.IsAccessDenied failed: err=%v (%T)", err, trace.Unwrap(err))

_, privateKeyPEM, err := utils.MarshalPrivateKey(privateKey.(crypto.Signer))
require.NoError(t, err)
Expand All @@ -2352,7 +2352,7 @@ func TestGenerateCerts(t *testing.T) {
Format: constants.CertificateFormatStandard,
})
require.Error(t, err)
require.IsType(t, &trace.AccessDeniedError{}, trace.Unwrap(err))
require.True(t, trace.IsAccessDenied(err), "trace.IsAccessDenied failed: err=%v (%T)", err, trace.Unwrap(err))
require.Contains(t, err.Error(), "impersonated user can not impersonate anyone else")

// but can renew their own cert, for example set route to cluster
Expand Down
25 changes: 20 additions & 5 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error,
for _, o := range opts {
o(&opt)
}
log.Debugf("Activating relogin on %v.", fnErr)
log.Debugf("Activating relogin on error=%q (type=%T)", fnErr, trace.Unwrap(fnErr))

// check if the error is a private key policy error.
if privateKeyPolicy, err := keys.ParsePrivateKeyPolicyError(fnErr); err == nil {
Expand Down Expand Up @@ -630,11 +630,26 @@ func WithBeforeLoginHook(fn func() error) RetryWithReloginOption {
}
}

// IsErrorResolvableWithRelogin returns true if relogin is attempted on `err`.
func IsErrorResolvableWithRelogin(err error) bool {
// Assume that failed handshake is a result of expired credentials.
return utils.IsHandshakeFailedError(err) || utils.IsCertExpiredError(err) ||
trace.IsBadParameter(err) || trace.IsTrustError(err) ||
keys.IsPrivateKeyPolicyError(err) || IsNoCredentialsError(err)
// Ignore any failures resulting from RPCs.
// These were all materialized as status.Error here before
// https://github.com/gravitational/teleport/pull/30578.
var remoteErr *interceptors.RemoteError
if errors.As(err, &remoteErr) {
return false
}

return keys.IsPrivateKeyPolicyError(err) ||
// TODO(codingllama): Retrying BadParameter is a terrible idea.
// We should fix this and remove the RemoteError condition above as well.
// Any retriable error should be explicitly marked as such.
trace.IsBadParameter(err) ||
trace.IsTrustError(err) ||
utils.IsCertExpiredError(err) ||
// Assume that failed handshake is a result of expired credentials.
utils.IsHandshakeFailedError(err) ||
IsNoCredentialsError(err)
}

// GetProfile gets the profile for the specified proxy address, or
Expand Down