diff --git a/stream.go b/stream.go index 01e66c1ed88f..3e98c79f0f15 100644 --- a/stream.go +++ b/stream.go @@ -1167,7 +1167,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { } else if err != nil { return toRPCErr(err) } - return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) + return status.Errorf(codes.Internal, "cardinality violation: expected for non server-streaming RPCs, but received another message") } func (a *csAttempt) finish(err error) { @@ -1491,7 +1491,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { } else if err != nil { return toRPCErr(err) } - return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) + return status.Errorf(codes.Internal, "cardinality violation: expected for non server-streaming RPCs, but received another message") } func (as *addrConnStream) finish(err error) { diff --git a/test/end2end_test.go b/test/end2end_test.go index a425877155e8..1fbde8b9df5f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -3739,6 +3739,39 @@ func (s) TestClientStreaming_ReturnErrorAfterSendAndClose(t *testing.T) { } } +// Tests that a client receives a cardinality violation error for client-streaming +// RPCs if the server call SendMsg multiple times. +func (s) TestClientStreaming_ServerHandlerSendMsgAfterSendMsg(t *testing.T) { + ss := stubserver.StubServer{ + StreamingInputCallF: func(stream testgrpc.TestService_StreamingInputCallServer) error { + if err := stream.SendMsg(&testpb.StreamingInputCallResponse{}); err != nil { + t.Errorf("stream.SendMsg(_) = %v, want ", err) + } + if err := stream.SendMsg(&testpb.StreamingInputCallResponse{}); err != nil { + t.Errorf("stream.SendMsg(_) = %v, want ", err) + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatal("Error starting server:", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := ss.Client.StreamingInputCall(ctx) + if err != nil { + t.Fatalf(".StreamingInputCall(_) = _, %v, want ", err) + } + if err := stream.Send(&testpb.StreamingInputCallRequest{}); err != nil { + t.Fatalf("stream.Send(_) = %v, want ", err) + } + if _, err := stream.CloseAndRecv(); status.Code(err) != codes.Internal { + t.Fatalf("stream.CloseAndRecv() = %v, want error with status code %s", err, codes.Internal) + } +} + func (s) TestExceedMaxStreamsLimit(t *testing.T) { for _, e := range listTestEnv() { testExceedMaxStreamsLimit(t, e)