Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ func TestListResources(t *testing.T) {
ResourceType: test.resourceType,
})
require.Error(t, err)
require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err))
})
}

Expand All @@ -528,7 +528,7 @@ func testGetResources[T types.ResourceWithLabels](t *testing.T, clt *Client, kin
ResourceType: kind,
})
require.Error(t, err)
require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err))

// Test getting a page of resources
page, err := GetResourcePage[T](ctx, clt, &proto.ListResourcesRequest{
Expand Down Expand Up @@ -636,7 +636,7 @@ func TestGetResourcesWithFilters(t *testing.T) {
ResourceType: test.resourceType,
})
require.Error(t, err)
require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err))

// Test getting all resources by chunks to handle limit exceeded.
resources, err := GetResourcesWithFilters(ctx, clt, proto.ListResourcesRequest{
Expand Down
41 changes: 37 additions & 4 deletions api/utils/grpc/interceptors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package interceptors

import (
"context"
"errors"
"io"

"github.com/gravitational/trace"
"github.com/gravitational/trace/trail"
Expand Down Expand Up @@ -48,12 +50,22 @@ type grpcClientStreamWrapper struct {

// SendMsg wraps around ClientStream.SendMsg
func (s *grpcClientStreamWrapper) SendMsg(m interface{}) error {
return trace.Unwrap(trail.FromGRPC(s.ClientStream.SendMsg(m)))
if err := s.ClientStream.SendMsg(m); err != nil {
return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(s.ClientStream.SendMsg(m)))}
}
return nil
}

// RecvMsg wraps around ClientStream.RecvMsg
func (s *grpcClientStreamWrapper) RecvMsg(m interface{}) error {
return trace.Unwrap(trail.FromGRPC(s.ClientStream.RecvMsg(m)))
switch err := s.ClientStream.RecvMsg(m); {
case errors.Is(err, io.EOF):
// Do not wrap io.EOF errors, they are often used as stop guards for streams.
return err
case err != nil:
return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(s.ClientStream.RecvMsg(m)))}
}
return nil
}

// GRPCServerUnaryErrorInterceptor is a gRPC unary server interceptor that
Expand All @@ -63,10 +75,31 @@ func GRPCServerUnaryErrorInterceptor(ctx context.Context, req interface{}, info
return resp, trace.Unwrap(trail.ToGRPC(err))
}

// RemoteError annotates server-side errors translated into trace by
// [GRPCClientUnaryErrorInterceptor] or [GRPCClientStreamErrorInterceptor].
type RemoteError struct {
// Err is the underlying error.
Err error
}

func (e *RemoteError) Error() string {
if e.Err == nil {
return ""
}
return e.Err.Error()
}

func (e *RemoteError) Unwrap() error {
return e.Err
}

// GRPCClientUnaryErrorInterceptor is a gRPC unary client interceptor that
// handles converting errors to the appropriate grpc status error.
func GRPCClientUnaryErrorInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return trace.Unwrap(trail.FromGRPC(invoker(ctx, method, req, reply, cc, opts...)))
if err := invoker(ctx, method, req, reply, cc, opts...); err != nil {
return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(err))}
}
return nil
}

// GRPCServerStreamErrorInterceptor is a gRPC server stream interceptor that
Expand All @@ -81,7 +114,7 @@ func GRPCServerStreamErrorInterceptor(srv interface{}, ss grpc.ServerStream, inf
func GRPCClientStreamErrorInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
s, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
return nil, trace.Unwrap(trail.ToGRPC(err))
return nil, &RemoteError{Err: trace.Unwrap(trail.ToGRPC(err))}
}
return &grpcClientStreamWrapper{s}, nil
}
20 changes: 13 additions & 7 deletions api/utils/grpc/interceptors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
Expand Down Expand Up @@ -73,11 +74,13 @@ func TestGRPCErrorWrapping(t *testing.T) {

t.Run("unary interceptor", func(t *testing.T) {
resp, err := client.Ping(context.Background(), &proto.PingRequest{})
require.Nil(t, resp)
require.True(t, trace.IsNotFound(err))
require.Equal(t, "not found", err.Error())
assert.Nil(t, resp, "resp is non-nil")
assert.True(t, trace.IsNotFound(err), "trace.IsNotFound failed: err=%v (%T)", err, trace.Unwrap(err))
assert.Equal(t, "not found", err.Error())
_, ok := err.(*trace.TraceErr)
require.False(t, ok, "client error should not include traces originating in the middleware")
assert.False(t, ok, "client error should not include traces originating in the middleware")
var remoteErr *interceptors.RemoteError
assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError")
})

t.Run("stream interceptor", func(t *testing.T) {
Expand All @@ -97,9 +100,12 @@ func TestGRPCErrorWrapping(t *testing.T) {
}

_, err = stream.Recv()
require.True(t, trace.IsAlreadyExists(err))
require.Equal(t, "already exists", err.Error())
assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err))
assert.Equal(t, "already exists", err.Error())
_, ok := err.(*trace.TraceErr)
require.False(t, ok, "client error should not include traces originating in the middleware")
assert.False(t, ok, "client error should not include traces originating in the middleware")
assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err))
var remoteErr *interceptors.RemoteError
assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError")
})
}
10 changes: 6 additions & 4 deletions lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2374,11 +2374,11 @@ func TestNodesCRUD(t *testing.T) {

// GetNode should fail if node name isn't provided
_, err = clt.GetNode(ctx, apidefaults.Namespace, "")
require.IsType(t, &trace.BadParameterError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err))

// GetNode should fail if namespace isn't provided
_, err = clt.GetNode(ctx, "", "node1")
require.IsType(t, &trace.BadParameterError{}, err.(*trace.TraceErr).OrigError())
require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err))
})
})

Expand Down Expand Up @@ -4206,7 +4206,8 @@ func TestUpsertApplicationServerOrigin(t *testing.T) {

ctx = authz.ContextWithUser(parentCtx, admin.I)
_, err = client.UpsertApplicationServer(ctx, appServer)
require.ErrorIs(t, trace.BadParameter("only the Okta role can create app servers and apps with an Okta origin"), err)
require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err))
require.ErrorContains(t, err, "only the Okta role can create app servers and apps with an Okta origin")

// Okta origin should not work with instance and node roles.
client, err = server.NewClient(TestIdentity{
Expand All @@ -4222,7 +4223,8 @@ func TestUpsertApplicationServerOrigin(t *testing.T) {

ctx = authz.ContextWithUser(parentCtx, admin.I)
_, err = client.UpsertApplicationServer(ctx, appServer)
require.ErrorIs(t, trace.BadParameter("only the Okta role can create app servers and apps with an Okta origin"), err)
require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err))
require.ErrorContains(t, err, "only the Okta role can create app servers and apps with an Okta origin")

// Okta origin should work with Okta role in role field.
node := TestIdentity{
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2327,7 +2327,7 @@ func TestGenerateCerts(t *testing.T) {
Format: constants.CertificateFormatStandard,
})
require.Error(t, err)
require.IsType(t, &trace.AccessDeniedError{}, trace.Unwrap(err))
require.True(t, trace.IsAccessDenied(err), "trace.IsAccessDenied failed: err=%v (%T)", err, trace.Unwrap(err))

_, privateKeyPEM, err := utils.MarshalPrivateKey(privateKey.(crypto.Signer))
require.NoError(t, err)
Expand All @@ -2345,7 +2345,7 @@ func TestGenerateCerts(t *testing.T) {
Format: constants.CertificateFormatStandard,
})
require.Error(t, err)
require.IsType(t, &trace.AccessDeniedError{}, trace.Unwrap(err))
require.True(t, trace.IsAccessDenied(err), "trace.IsAccessDenied failed: err=%v (%T)", err, trace.Unwrap(err))
require.Contains(t, err.Error(), "impersonated user can not impersonate anyone else")

// but can renew their own cert, for example set route to cluster
Expand Down
25 changes: 20 additions & 5 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error,
for _, o := range opts {
o(&opt)
}
log.Debugf("Activating relogin on %v.", fnErr)
log.Debugf("Activating relogin on error=%q (type=%T)", fnErr, trace.Unwrap(fnErr))

if keys.IsPrivateKeyPolicyError(fnErr) {
privateKeyPolicy, err := keys.ParsePrivateKeyPolicyError(fnErr)
Expand Down Expand Up @@ -668,11 +668,26 @@ func WithBeforeLoginHook(fn func() error) RetryWithReloginOption {
}
}

// IsErrorResolvableWithRelogin returns true if relogin is attempted on `err`.
func IsErrorResolvableWithRelogin(err error) bool {
// Assume that failed handshake is a result of expired credentials.
return utils.IsHandshakeFailedError(err) || utils.IsCertExpiredError(err) ||
trace.IsBadParameter(err) || trace.IsTrustError(err) ||
keys.IsPrivateKeyPolicyError(err) || IsNoCredentialsError(err)
// Ignore any failures resulting from RPCs.
// These were all materialized as status.Error here before
// https://github.com/gravitational/teleport/pull/30578.
var remoteErr *interceptors.RemoteError
if errors.As(err, &remoteErr) {
return false
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@codingllama should we treat all remote errors equally here?
We can get Original Error: *interceptors.RemoteError access denied: client credentials have expired, please relogin. for which we return false but it seems that for this particular case it should be true?

I ask because I'm working on adding a client cache to Connect, and this means that the client no longer checks the user cert before making call, but instead it has to rely on errors from the server.

The original comment #38202 (comment).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI @ravicious

This block here is a poor stopgap to avoid tsh from retrying all kinds of errors it shouldn't be. As you have observed it is prone to making bad choices - the entire idea of IsErrorResolvableWithRelogin is, as it loses all context from the call site.

I think the better way of doing this is explicitly marking errors as retriable by wrapping them with a RetryableError type. This way we can check for that wrapper (with errors.As) and return true if we find it. That means finding the client-side callsite for GenerateUserCerts, inspecting the response and marking it as retriable accordingly.

I'd also recommend that we have a guard error for "client credentials have expired" - as in an exported var we can check for in its entirety - so we don't go looking for specific substrings.

That's my 2c on the issue. Happy to talk more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That means finding the client-side callsite for GenerateUserCerts, inspecting the response and marking it as retriable accordingly.

So we would have to wrap methods in api/client/client.go manually? Or could we use the interceptor (where we add RemoteError) and check for 'client credentials have expired' there?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrapping client.go manually is the best choice, imo, as there is no loss of context and little chance of false-positives. That's how I think this should have started.

For a "generic" place I would do it here, not in the interceptor.

}

return keys.IsPrivateKeyPolicyError(err) ||
// TODO(codingllama): Retrying BadParameter is a terrible idea.
// We should fix this and remove the RemoteError condition above as well.
// Any retriable error should be explicitly marked as such.
trace.IsBadParameter(err) ||
trace.IsTrustError(err) ||
utils.IsCertExpiredError(err) ||
// Assume that failed handshake is a result of expired credentials.
utils.IsHandshakeFailedError(err) ||
IsNoCredentialsError(err)
}

// GetProfile gets the profile for the specified proxy address, or
Expand Down