diff --git a/api/utils/grpc/interceptors/errors.go b/api/utils/grpc/interceptors/errors.go index e765cc10a94a0..3b9e834b60c18 100644 --- a/api/utils/grpc/interceptors/errors.go +++ b/api/utils/grpc/interceptors/errors.go @@ -50,22 +50,24 @@ type grpcClientStreamWrapper struct { // SendMsg wraps around ClientStream.SendMsg func (s *grpcClientStreamWrapper) SendMsg(m interface{}) error { - if err := s.ClientStream.SendMsg(m); err != nil { - return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(s.ClientStream.SendMsg(m)))} - } - return nil + return wrapStreamErr(s.ClientStream.SendMsg(m)) } // RecvMsg wraps around ClientStream.RecvMsg func (s *grpcClientStreamWrapper) RecvMsg(m interface{}) error { - switch err := s.ClientStream.RecvMsg(m); { + return wrapStreamErr(s.ClientStream.RecvMsg(m)) +} + +func wrapStreamErr(err error) error { + switch { + case err == nil: + return nil case errors.Is(err, io.EOF): // Do not wrap io.EOF errors, they are often used as stop guards for streams. return err - case err != nil: - return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(s.ClientStream.RecvMsg(m)))} + default: + return &RemoteError{Err: trace.Unwrap(trail.FromGRPC(err))} } - return nil } // GRPCServerUnaryErrorInterceptor is a gRPC unary server interceptor that diff --git a/api/utils/grpc/interceptors/errors_test.go b/api/utils/grpc/interceptors/errors_test.go index ffa0d28356e5c..3e01ffca507ff 100644 --- a/api/utils/grpc/interceptors/errors_test.go +++ b/api/utils/grpc/interceptors/errors_test.go @@ -16,10 +16,10 @@ package interceptors_test import ( "context" - "errors" "io" "net" "testing" + "time" "github.com/gravitational/trace" "github.com/stretchr/testify/assert" @@ -88,15 +88,19 @@ func TestGRPCErrorWrapping(t *testing.T) { stream, err := client.AddMFADevice(context.Background()) require.NoError(t, err) + // Give the server time to close the stream. This allows us to more + // consistently hit the io.EOF error. + time.Sleep(100 * time.Millisecond) + //nolint:staticcheck // SA1019. The specific stream used here doesn't matter. sendErr := stream.Send(&proto.AddMFADeviceRequest{}) - // io.EOF means the server closed the stream, which can - // happen depending in timing. In either case, it is - // still safe to recv from the stream and check for + // Expect either a success (unlikely because of the Sleep) or an unwrapped + // io.EOF error (meaning the server errored and closed the stream). + // In either case, it is still safe to recv from the stream and check for // the already exists error. - if sendErr != nil && !errors.Is(sendErr, io.EOF) { - t.Fatalf("Unexpected error: %v", sendErr) + if sendErr != nil && sendErr != io.EOF /* == error comparison on purpose! */ { + t.Fatalf("Unexpected error: %q (%T)", sendErr, sendErr) } _, err = stream.Recv()