diff --git a/api/client/client_test.go b/api/client/client_test.go index 9adb3136fe269..d70b50af2b969 100644 --- a/api/client/client_test.go +++ b/api/client/client_test.go @@ -497,7 +497,7 @@ func TestListResources(t *testing.T) { ResourceType: test.resourceType, }) require.Error(t, err) - require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err)) }) } @@ -523,7 +523,7 @@ func testGetResources[T types.ResourceWithLabels](t *testing.T, clt *Client, kin ResourceType: kind, }) require.Error(t, err) - require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err)) // Test getting a page of resources page, err := GetResourcePage[T](ctx, clt, &proto.ListResourcesRequest{ @@ -631,7 +631,7 @@ func TestGetResourcesWithFilters(t *testing.T) { ResourceType: test.resourceType, }) require.Error(t, err) - require.IsType(t, &trace.LimitExceededError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsLimitExceeded(err), "trace.IsLimitExceeded failed: err=%v (%T)", err, trace.Unwrap(err)) // Test getting all resources by chunks to handle limit exceeded. resources, err := GetResourcesWithFilters(ctx, clt, proto.ListResourcesRequest{ diff --git a/api/utils/grpc/interceptors/errors.go b/api/utils/grpc/interceptors/errors.go index d387c9adbc34c..3b9e834b60c18 100644 --- a/api/utils/grpc/interceptors/errors.go +++ b/api/utils/grpc/interceptors/errors.go @@ -16,6 +16,8 @@ package interceptors import ( "context" + "errors" + "io" "github.com/gravitational/trace" "github.com/gravitational/trace/trail" @@ -48,12 +50,24 @@ type grpcClientStreamWrapper struct { // SendMsg wraps around ClientStream.SendMsg func (s *grpcClientStreamWrapper) SendMsg(m interface{}) error { - return trace.Unwrap(trail.FromGRPC(s.ClientStream.SendMsg(m))) + return wrapStreamErr(s.ClientStream.SendMsg(m)) } // RecvMsg wraps around ClientStream.RecvMsg func (s *grpcClientStreamWrapper) RecvMsg(m interface{}) error { - return trace.Unwrap(trail.FromGRPC(s.ClientStream.RecvMsg(m))) + return wrapStreamErr(s.ClientStream.RecvMsg(m)) +} + +func wrapStreamErr(err error) error { + switch { + case err == nil: + return nil + case errors.Is(err, io.EOF): + // Do not wrap io.EOF errors, they are often used as stop guards for streams. + return err + default: + return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(err))} + } } // GRPCServerUnaryErrorInterceptor is a gRPC unary server interceptor that @@ -63,10 +77,31 @@ func GRPCServerUnaryErrorInterceptor(ctx context.Context, req interface{}, info return resp, trace.Unwrap(trail.ToGRPC(err)) } +// RemoteError annotates server-side errors translated into trace by +// [GRPCClientUnaryErrorInterceptor] or [GRPCClientStreamErrorInterceptor]. +type RemoteError struct { + // Err is the underlying error. + Err error +} + +func (e *RemoteError) Error() string { + if e.Err == nil { + return "" + } + return e.Err.Error() +} + +func (e *RemoteError) Unwrap() error { + return e.Err +} + // GRPCClientUnaryErrorInterceptor is a gRPC unary client interceptor that // handles converting errors to the appropriate grpc status error. func GRPCClientUnaryErrorInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - return trace.Unwrap(trail.FromGRPC(invoker(ctx, method, req, reply, cc, opts...))) + if err := invoker(ctx, method, req, reply, cc, opts...); err != nil { + return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(err))} + } + return nil } // GRPCServerStreamErrorInterceptor is a gRPC server stream interceptor that @@ -81,7 +116,7 @@ func GRPCServerStreamErrorInterceptor(srv interface{}, ss grpc.ServerStream, inf func GRPCClientStreamErrorInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { s, err := streamer(ctx, desc, cc, method, opts...) if err != nil { - return nil, trace.Unwrap(trail.ToGRPC(err)) + return nil, &RemoteError{Err: trace.Unwrap(trail.ToGRPC(err))} } return &grpcClientStreamWrapper{s}, nil } diff --git a/api/utils/grpc/interceptors/errors_test.go b/api/utils/grpc/interceptors/errors_test.go index b3a9bfecd4028..02e917dc073fc 100644 --- a/api/utils/grpc/interceptors/errors_test.go +++ b/api/utils/grpc/interceptors/errors_test.go @@ -16,12 +16,13 @@ package interceptors import ( "context" - "errors" "io" "net" "testing" + "time" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -73,31 +74,41 @@ func TestGRPCErrorWrapping(t *testing.T) { t.Run("unary interceptor", func(t *testing.T) { resp, err := client.Ping(context.Background(), &proto.PingRequest{}) - require.Nil(t, resp) - require.True(t, trace.IsNotFound(err)) - require.Equal(t, err.Error(), "not found") + assert.Nil(t, resp, "resp is non-nil") + assert.True(t, trace.IsNotFound(err), "trace.IsNotFound failed: err=%v (%T)", err, trace.Unwrap(err)) + assert.Equal(t, "not found", err.Error()) _, ok := err.(*trace.TraceErr) - require.False(t, ok, "client error should not include traces originating in the middleware") + assert.False(t, ok, "client error should not include traces originating in the middleware") + var remoteErr *RemoteError + assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError") }) t.Run("stream interceptor", func(t *testing.T) { stream, err := client.AddMFADevice(context.Background()) require.NoError(t, err) + // Give the server time to close the stream. This allows us to more + // consistently hit the io.EOF error. + time.Sleep(100 * time.Millisecond) + + //nolint:staticcheck // SA1019. The specific stream used here doesn't matter. sendErr := stream.Send(&proto.AddMFADeviceRequest{}) - // io.EOF means the server closed the stream, which can - // happen depending in timing. In either case, it is - // still safe to recv from the stream and check for + // Expect either a success (unlikely because of the Sleep) or an unwrapped + // io.EOF error (meaning the server errored and closed the stream). + // In either case, it is still safe to recv from the stream and check for // the already exists error. - if sendErr != nil && !errors.Is(sendErr, io.EOF) { - t.Fatalf("Unexpected error: %v", sendErr) + if sendErr != nil && sendErr != io.EOF /* == error comparison on purpose! */ { + t.Fatalf("Unexpected error: %q (%T)", sendErr, sendErr) } _, err = stream.Recv() - require.True(t, trace.IsAlreadyExists(err)) - require.Equal(t, err.Error(), "already exists") + assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err)) + assert.Equal(t, "already exists", err.Error()) _, ok := err.(*trace.TraceErr) - require.False(t, ok, "client error should not include traces originating in the middleware") + assert.False(t, ok, "client error should not include traces originating in the middleware") + assert.True(t, trace.IsAlreadyExists(err), "trace.IsAlreadyExists failed: err=%v (%T)", err, trace.Unwrap(err)) + var remoteErr *RemoteError + assert.ErrorAs(t, err, &remoteErr, "Remote error is not marked as an interceptors.RemoteError") }) } diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index e7d6d755639a8..7e221591b3e38 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -2502,11 +2502,11 @@ func TestNodesCRUD(t *testing.T) { // GetNode should fail if node name isn't provided _, err = clt.GetNode(ctx, apidefaults.Namespace, "") - require.IsType(t, &trace.BadParameterError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err)) // GetNode should fail if namespace isn't provided _, err = clt.GetNode(ctx, "", "node1") - require.IsType(t, &trace.BadParameterError{}, err.(*trace.TraceErr).OrigError()) + require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err)) }) }) @@ -4478,7 +4478,8 @@ func TestUpsertApplicationServerOrigin(t *testing.T) { ctx = authz.ContextWithUser(parentCtx, admin.I) _, err = client.UpsertApplicationServer(ctx, appServer) - require.ErrorIs(t, trace.BadParameter("only the Okta role can create app servers and apps with an Okta origin"), err) + require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err)) + require.ErrorContains(t, err, "only the Okta role can create app servers and apps with an Okta origin") // Okta origin should not work with instance and node roles. client, err = server.NewClient(TestIdentity{ @@ -4494,7 +4495,8 @@ func TestUpsertApplicationServerOrigin(t *testing.T) { ctx = authz.ContextWithUser(parentCtx, admin.I) _, err = client.UpsertApplicationServer(ctx, appServer) - require.ErrorIs(t, trace.BadParameter("only the Okta role can create app servers and apps with an Okta origin"), err) + require.True(t, trace.IsBadParameter(err), "trace.IsBadParameter failed: err=%v (%T)", err, trace.Unwrap(err)) + require.ErrorContains(t, err, "only the Okta role can create app servers and apps with an Okta origin") // Okta origin should work with Okta role in role field. node := TestIdentity{ diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index da5f8c8f379b6..9ae34aa384aa4 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -2334,7 +2334,7 @@ func TestGenerateCerts(t *testing.T) { Format: constants.CertificateFormatStandard, }) require.Error(t, err) - require.IsType(t, &trace.AccessDeniedError{}, trace.Unwrap(err)) + require.True(t, trace.IsAccessDenied(err), "trace.IsAccessDenied failed: err=%v (%T)", err, trace.Unwrap(err)) _, privateKeyPEM, err := utils.MarshalPrivateKey(privateKey.(crypto.Signer)) require.NoError(t, err) @@ -2352,7 +2352,7 @@ func TestGenerateCerts(t *testing.T) { Format: constants.CertificateFormatStandard, }) require.Error(t, err) - require.IsType(t, &trace.AccessDeniedError{}, trace.Unwrap(err)) + require.True(t, trace.IsAccessDenied(err), "trace.IsAccessDenied failed: err=%v (%T)", err, trace.Unwrap(err)) require.Contains(t, err.Error(), "impersonated user can not impersonate anyone else") // but can renew their own cert, for example set route to cluster diff --git a/lib/client/api.go b/lib/client/api.go index e7826c19b461f..09b91f41731d6 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -565,7 +565,7 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, for _, o := range opts { o(&opt) } - log.Debugf("Activating relogin on %v.", fnErr) + log.Debugf("Activating relogin on error=%q (type=%T)", fnErr, trace.Unwrap(fnErr)) // check if the error is a private key policy error. if privateKeyPolicy, err := keys.ParsePrivateKeyPolicyError(fnErr); err == nil { @@ -630,11 +630,26 @@ func WithBeforeLoginHook(fn func() error) RetryWithReloginOption { } } +// IsErrorResolvableWithRelogin returns true if relogin is attempted on `err`. func IsErrorResolvableWithRelogin(err error) bool { - // Assume that failed handshake is a result of expired credentials. - return utils.IsHandshakeFailedError(err) || utils.IsCertExpiredError(err) || - trace.IsBadParameter(err) || trace.IsTrustError(err) || - keys.IsPrivateKeyPolicyError(err) || IsNoCredentialsError(err) + // Ignore any failures resulting from RPCs. + // These were all materialized as status.Error here before + // https://github.com/gravitational/teleport/pull/30578. + var remoteErr *interceptors.RemoteError + if errors.As(err, &remoteErr) { + return false + } + + return keys.IsPrivateKeyPolicyError(err) || + // TODO(codingllama): Retrying BadParameter is a terrible idea. + // We should fix this and remove the RemoteError condition above as well. + // Any retriable error should be explicitly marked as such. + trace.IsBadParameter(err) || + trace.IsTrustError(err) || + utils.IsCertExpiredError(err) || + // Assume that failed handshake is a result of expired credentials. + utils.IsHandshakeFailedError(err) || + IsNoCredentialsError(err) } // GetProfile gets the profile for the specified proxy address, or