diff --git a/api/client/client_test.go b/api/client/client_test.go index 35c812731e6f1..7d6428549fbdf 100644 --- a/api/client/client_test.go +++ b/api/client/client_test.go @@ -502,7 +502,7 @@ func TestListResources(t *testing.T) { ResourceType: test.resourceType, }) require.Error(t, err) - require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err)) }) } @@ -528,7 +528,7 @@ func testGetResources[T types.ResourceWithLabels](t *testing.T, clt *Client, kin ResourceType: kind, }) require.Error(t, err) - require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err)) // Test getting a page of resources page, err := GetResourcePage[T](ctx, clt, &proto.ListResourcesRequest{ @@ -636,7 +636,7 @@ func TestGetResourcesWithFilters(t *testing.T) { ResourceType: test.resourceType, }) require.Error(t, err) - require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err)) // Test getting all resources by chunks to handle limit exceeded. resources, err := GetResourcesWithFilters(ctx, clt, proto.ListResourcesRequest{ diff --git a/api/utils/grpc/interceptors/errors.go b/api/utils/grpc/interceptors/errors.go index d387c9adbc34c..e765cc10a94a0 100644 --- a/api/utils/grpc/interceptors/errors.go +++ b/api/utils/grpc/interceptors/errors.go @@ -16,6 +16,8 @@ package interceptors import ( "context" + "errors" + "io" "github.com/gravitational/trace" "github.com/gravitational/trace/trail" @@ -48,12 +50,22 @@ type grpcClientStreamWrapper struct { // SendMsg wraps around ClientStream.SendMsg func (s *grpcClientStreamWrapper) SendMsg(m interface{}) error { - return trace.Unwrap(trail.FromGRPC(s.ClientStream.SendMsg(m))) + if err := s.ClientStream.SendMsg(m); err != nil { + return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(s.ClientStream.SendMsg(m)))} + } + return nil } // RecvMsg wraps around ClientStream.RecvMsg func (s *grpcClientStreamWrapper) RecvMsg(m interface{}) error { - return trace.Unwrap(trail.FromGRPC(s.ClientStream.RecvMsg(m))) + switch err := s.ClientStream.RecvMsg(m); { + case errors.Is(err, io.EOF): + // Do not wrap io.EOF errors, they are often used as stop guards for streams. + return err + case err != nil: + return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(s.ClientStream.RecvMsg(m)))} + } + return nil } // GRPCServerUnaryErrorInterceptor is a gRPC unary server interceptor that @@ -63,10 +75,31 @@ func GRPCServerUnaryErrorInterceptor(ctx context.Context, req interface{}, info return resp, trace.Unwrap(trail.ToGRPC(err)) } +// RemoteError annotates server-side errors translated into trace by +// [GRPCClientUnaryErrorInterceptor] or [GRPCClientStreamErrorInterceptor]. +type RemoteError struct { + // Err is the underlying error. + Err error +} + +func (e *RemoteError) Error() string { + if e.Err == nil { + return "" + } + return e.Err.Error() +} + +func (e *RemoteError) Unwrap() error { + return e.Err +} + // GRPCClientUnaryErrorInterceptor is a gRPC unary client interceptor that // handles converting errors to the appropriate grpc status error. func GRPCClientUnaryErrorInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - return trace.Unwrap(trail.FromGRPC(invoker(ctx, method, req, reply, cc, opts...))) + if err := invoker(ctx, method, req, reply, cc, opts...); err != nil { + return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(err))} + } + return nil } // GRPCServerStreamErrorInterceptor is a gRPC server stream interceptor that @@ -81,7 +114,7 @@ func GRPCServerStreamErrorInterceptor(srv interface{}, ss grpc.ServerStream, inf func GRPCClientStreamErrorInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { s, err := streamer(ctx, desc, cc, method, opts...) if err != nil { - return nil, trace.Unwrap(trail.ToGRPC(err)) + return nil, &RemoteError{Err: trace.Unwrap(trail.ToGRPC(err))} } return &grpcClientStreamWrapper{s}, nil } diff --git a/api/utils/grpc/interceptors/errors_test.go b/api/utils/grpc/interceptors/errors_test.go index dcd808845ce0e..ffa0d28356e5c 100644 --- a/api/utils/grpc/interceptors/errors_test.go +++ b/api/utils/grpc/interceptors/errors_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -73,11 +74,13 @@ func TestGRPCErrorWrapping(t *testing.T) { t.Run("unary interceptor", func(t *testing.T) { resp, err := client.Ping(context.Background(), &proto.PingRequest{}) - require.Nil(t, resp) - require.True(t, trace.IsNotFound(err)) - require.Equal(t, "not found", err.Error()) + assert.Nil(t, resp, "resp is non-nil") + assert.True(t, trace.IsNotFound(err), "trace.IsNotFound failed: err=%v (%T)", err, trace.Unwrap(err)) + assert.Equal(t, "not found", err.Error()) _, ok := err.(*trace.TraceErr) - require.False(t, ok, "client error should not include traces originating in the middleware") + assert.False(t, ok, "client error should not include traces originating in the middleware") + var remoteErr *interceptors.RemoteError + assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError") }) t.Run("stream interceptor", func(t *testing.T) { @@ -97,9 +100,12 @@ func TestGRPCErrorWrapping(t *testing.T) { } _, err = stream.Recv() - require.True(t, trace.IsAlreadyExists(err)) - require.Equal(t, "already exists", err.Error()) + assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err)) + assert.Equal(t, "already exists", err.Error()) _, ok := err.(*trace.TraceErr) - require.False(t, ok, "client error should not include traces originating in the middleware") + assert.False(t, ok, "client error should not include traces originating in the middleware") + assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err)) + var remoteErr *interceptors.RemoteError + assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError") }) } diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 8f6b66138c6e0..e53eefe83e18a 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -2374,11 +2374,11 @@ func TestNodesCRUD(t *testing.T) { // GetNode should fail if node name isn't provided _, err = clt.GetNode(ctx, apidefaults.Namespace, "") - require.IsType(t, &trace.BadParameterError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err)) // GetNode should fail if namespace isn't provided _, err = clt.GetNode(ctx, "", "node1") - require.IsType(t, &trace.BadParameterError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err)) }) }) @@ -4206,7 +4206,8 @@ func TestUpsertApplicationServerOrigin(t *testing.T) { ctx = authz.ContextWithUser(parentCtx, admin.I) _, err = client.UpsertApplicationServer(ctx, appServer) - require.ErrorIs(t, trace.BadParameter("only the Okta role can create app servers and apps with an Okta origin"), err) + require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err)) + require.ErrorContains(t, err, "only the Okta role can create app servers and apps with an Okta origin") // Okta origin should not work with instance and node roles. client, err = server.NewClient(TestIdentity{ @@ -4222,7 +4223,8 @@ func TestUpsertApplicationServerOrigin(t *testing.T) { ctx = authz.ContextWithUser(parentCtx, admin.I) _, err = client.UpsertApplicationServer(ctx, appServer) - require.ErrorIs(t, trace.BadParameter("only the Okta role can create app servers and apps with an Okta origin"), err) + require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err)) + require.ErrorContains(t, err, "only the Okta role can create app servers and apps with an Okta origin") // Okta origin should work with Okta role in role field. node := TestIdentity{ diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 365135251931f..0b18da2209fa7 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -2327,7 +2327,7 @@ func TestGenerateCerts(t *testing.T) { Format: constants.CertificateFormatStandard, }) require.Error(t, err) - require.IsType(t, &trace.AccessDeniedError{}, trace.Unwrap(err)) + require.True(t, trace.IsAccessDenied(err), "trace.IsAccessDenied failed: err=%v (%T)", err, trace.Unwrap(err)) _, privateKeyPEM, err := utils.MarshalPrivateKey(privateKey.(crypto.Signer)) require.NoError(t, err) @@ -2345,7 +2345,7 @@ func TestGenerateCerts(t *testing.T) { Format: constants.CertificateFormatStandard, }) require.Error(t, err) - require.IsType(t, &trace.AccessDeniedError{}, trace.Unwrap(err)) + require.True(t, trace.IsAccessDenied(err), "trace.IsAccessDenied failed: err=%v (%T)", err, trace.Unwrap(err)) require.Contains(t, err.Error(), "impersonated user can not impersonate anyone else") // but can renew their own cert, for example set route to cluster diff --git a/lib/client/api.go b/lib/client/api.go index 871b267b8293a..dc43e2875e78e 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -597,7 +597,7 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, for _, o := range opts { o(&opt) } - log.Debugf("Activating relogin on %v.", fnErr) + log.Debugf("Activating relogin on error=%q (type=%T)", fnErr, trace.Unwrap(fnErr)) if keys.IsPrivateKeyPolicyError(fnErr) { privateKeyPolicy, err := keys.ParsePrivateKeyPolicyError(fnErr) @@ -668,11 +668,26 @@ func WithBeforeLoginHook(fn func() error) RetryWithReloginOption { } } +// IsErrorResolvableWithRelogin returns true if relogin is attempted on `err`. func IsErrorResolvableWithRelogin(err error) bool { - // Assume that failed handshake is a result of expired credentials. - return utils.IsHandshakeFailedError(err) || utils.IsCertExpiredError(err) || - trace.IsBadParameter(err) || trace.IsTrustError(err) || - keys.IsPrivateKeyPolicyError(err) || IsNoCredentialsError(err) + // Ignore any failures resulting from RPCs. + // These were all materialized as status.Error here before + // https://github.com/gravitational/teleport/pull/30578. + var remoteErr *interceptors.RemoteError + if errors.As(err, &remoteErr) { + return false + } + + return keys.IsPrivateKeyPolicyError(err) || + // TODO(codingllama): Retrying BadParameter is a terrible idea. + // We should fix this and remove the RemoteError condition above as well. + // Any retriable error should be explicitly marked as such. + trace.IsBadParameter(err) || + trace.IsTrustError(err) || + utils.IsCertExpiredError(err) || + // Assume that failed handshake is a result of expired credentials. + utils.IsHandshakeFailedError(err) || + IsNoCredentialsError(err) } // GetProfile gets the profile for the specified proxy address, or