From fe43cf5d406879aac3fc2e4a0d95453f735c56a6 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Tue, 6 Feb 2018 18:19:34 -0800 Subject: [PATCH] proto: treat bad wire types as unknown fields Previously, an error was returned during unmarshal when a wiretype was encountered that did not match the expected type. In order to match the behavior of the C++ and Python implementations, we no longer return an error and instead store the bad wire fragment as an unknown field (or skip them if unknown field preservation is disabled). The generator still produces code that references ErrInternalBadWireType for unmarshal logic for oneof fields. However, the current proto package does not use the generated unmarshalers for oneofs, so their existence has no bearing on unmarshal semantics. Cleaning up the generator to stop producing these is future work. --- proto/all_test.go | 30 +++++++----- proto/table_unmarshal.go | 103 +++++++++++++++++++-------------------- 2 files changed, 68 insertions(+), 65 deletions(-) diff --git a/proto/all_test.go b/proto/all_test.go index 573410885b..a9da89fca7 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -1131,21 +1131,25 @@ func TestBigRepeated(t *testing.T) { } } -// Verify we give a useful message when decoding to the wrong structure type. -func TestTypeMismatch(t *testing.T) { - pb1 := initGoTest(true) +func TestBadWireTypeUnknown(t *testing.T) { + var b []byte + fmt.Sscanf("0a01780d00000000080b101612036161611521000000202c220362626225370000002203636363214200000000000000584d5a036464645900000000000056405d63000000", "%x", &b) - // Marshal - o := old() - o.Marshal(pb1) + m := new(MyMessage) + if err := Unmarshal(b, m); err != nil { + t.Errorf("unexpected Unmarshal error: %v", err) + } - // Now Unmarshal it to the wrong type. - pb2 := initGoTestField() - err := o.Unmarshal(pb2) - if err == nil { - t.Error("expected error, got no error") - } else if !strings.Contains(err.Error(), "bad wiretype") { - t.Error("expected bad wiretype error, got", err) + var unknown []byte + fmt.Sscanf("0a01780d0000000010161521000000202c2537000000214200000000000000584d5a036464645d63000000", "%x", &unknown) + if !bytes.Equal(m.XXX_unrecognized, unknown) { + t.Errorf("unknown bytes mismatch:\ngot %x\nwant %x", m.XXX_unrecognized, unknown) + } + DiscardUnknown(m) + + want := &MyMessage{Count: Int32(11), Name: String("aaa"), Pet: []string{"bbb", "ccc"}, Bigfloat: Float64(88)} + if !Equal(m, want) { + t.Errorf("message mismatch:\ngot %v\nwant %v", m, want) } } diff --git a/proto/table_unmarshal.go b/proto/table_unmarshal.go index afc556a2f2..a7ee274381 100644 --- a/proto/table_unmarshal.go +++ b/proto/table_unmarshal.go @@ -179,11 +179,10 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { rnse = r continue } - if err == errInternalBadWireType { - err = fmt.Errorf("proto: bad wiretype for field at offset %d of type %s: got wiretype %d", - f.field, u.typ, wire) + if err == nil || err != errInternalBadWireType { + return err } - return err + // Fragments with bad wire type are treated as unknown fields. } // Unknown tag. @@ -688,7 +687,7 @@ func typeUnmarshaler(t reflect.Type, tags string) unmarshaler { func unmarshalInt64Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -702,7 +701,7 @@ func unmarshalInt64Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalInt64Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -739,7 +738,7 @@ func unmarshalInt64Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -754,7 +753,7 @@ func unmarshalInt64Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalSint64Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -768,7 +767,7 @@ func unmarshalSint64Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalSint64Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -805,7 +804,7 @@ func unmarshalSint64Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -820,7 +819,7 @@ func unmarshalSint64Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalUint64Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -834,7 +833,7 @@ func unmarshalUint64Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalUint64Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -871,7 +870,7 @@ func unmarshalUint64Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -886,7 +885,7 @@ func unmarshalUint64Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalInt32Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -900,7 +899,7 @@ func unmarshalInt32Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalInt32Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -936,7 +935,7 @@ func unmarshalInt32Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -950,7 +949,7 @@ func unmarshalInt32Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalSint32Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -964,7 +963,7 @@ func unmarshalSint32Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalSint32Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1000,7 +999,7 @@ func unmarshalSint32Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1014,7 +1013,7 @@ func unmarshalSint32Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalUint32Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1028,7 +1027,7 @@ func unmarshalUint32Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalUint32Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1065,7 +1064,7 @@ func unmarshalUint32Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1080,7 +1079,7 @@ func unmarshalUint32Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFixed64Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1092,7 +1091,7 @@ func unmarshalFixed64Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFixed64Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1126,7 +1125,7 @@ func unmarshalFixed64Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1139,7 +1138,7 @@ func unmarshalFixed64Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFixedS64Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1151,7 +1150,7 @@ func unmarshalFixedS64Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFixedS64Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1185,7 +1184,7 @@ func unmarshalFixedS64Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1198,7 +1197,7 @@ func unmarshalFixedS64Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFixed32Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1210,7 +1209,7 @@ func unmarshalFixed32Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFixed32Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1244,7 +1243,7 @@ func unmarshalFixed32Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1257,7 +1256,7 @@ func unmarshalFixed32Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFixedS32Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1269,7 +1268,7 @@ func unmarshalFixedS32Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFixedS32Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1302,7 +1301,7 @@ func unmarshalFixedS32Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1314,7 +1313,7 @@ func unmarshalFixedS32Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalBoolValue(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } // Note: any length varint is allowed, even though any sane // encoder will use one byte. @@ -1331,7 +1330,7 @@ func unmarshalBoolValue(b []byte, f pointer, w int) ([]byte, error) { func unmarshalBoolPtr(b []byte, f pointer, w int) ([]byte, error) { if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1367,7 +1366,7 @@ func unmarshalBoolSlice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireVarint { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1381,7 +1380,7 @@ func unmarshalBoolSlice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFloat64Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1393,7 +1392,7 @@ func unmarshalFloat64Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFloat64Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1427,7 +1426,7 @@ func unmarshalFloat64Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireFixed64 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 8 { return nil, io.ErrUnexpectedEOF @@ -1440,7 +1439,7 @@ func unmarshalFloat64Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFloat32Value(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1452,7 +1451,7 @@ func unmarshalFloat32Value(b []byte, f pointer, w int) ([]byte, error) { func unmarshalFloat32Ptr(b []byte, f pointer, w int) ([]byte, error) { if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1486,7 +1485,7 @@ func unmarshalFloat32Slice(b []byte, f pointer, w int) ([]byte, error) { return res, nil } if w != WireFixed32 { - return nil, errInternalBadWireType + return b, errInternalBadWireType } if len(b) < 4 { return nil, io.ErrUnexpectedEOF @@ -1499,7 +1498,7 @@ func unmarshalFloat32Slice(b []byte, f pointer, w int) ([]byte, error) { func unmarshalStringValue(b []byte, f pointer, w int) ([]byte, error) { if w != WireBytes { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1519,7 +1518,7 @@ func unmarshalStringValue(b []byte, f pointer, w int) ([]byte, error) { func unmarshalStringPtr(b []byte, f pointer, w int) ([]byte, error) { if w != WireBytes { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1539,7 +1538,7 @@ func unmarshalStringPtr(b []byte, f pointer, w int) ([]byte, error) { func unmarshalStringSlice(b []byte, f pointer, w int) ([]byte, error) { if w != WireBytes { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1562,7 +1561,7 @@ var emptyBuf [0]byte func unmarshalBytesValue(b []byte, f pointer, w int) ([]byte, error) { if w != WireBytes { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1583,7 +1582,7 @@ func unmarshalBytesValue(b []byte, f pointer, w int) ([]byte, error) { func unmarshalBytesSlice(b []byte, f pointer, w int) ([]byte, error) { if w != WireBytes { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1602,7 +1601,7 @@ func unmarshalBytesSlice(b []byte, f pointer, w int) ([]byte, error) { func makeUnmarshalMessagePtr(sub *unmarshalInfo, name string) unmarshaler { return func(b []byte, f pointer, w int) ([]byte, error) { if w != WireBytes { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1636,7 +1635,7 @@ func makeUnmarshalMessagePtr(sub *unmarshalInfo, name string) unmarshaler { func makeUnmarshalMessageSlicePtr(sub *unmarshalInfo, name string) unmarshaler { return func(b []byte, f pointer, w int) ([]byte, error) { if w != WireBytes { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, n := decodeVarint(b) if n == 0 { @@ -1663,7 +1662,7 @@ func makeUnmarshalMessageSlicePtr(sub *unmarshalInfo, name string) unmarshaler { func makeUnmarshalGroupPtr(sub *unmarshalInfo, name string) unmarshaler { return func(b []byte, f pointer, w int) ([]byte, error) { if w != WireStartGroup { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, y := findEndGroup(b) if x < 0 { @@ -1689,7 +1688,7 @@ func makeUnmarshalGroupPtr(sub *unmarshalInfo, name string) unmarshaler { func makeUnmarshalGroupSlicePtr(sub *unmarshalInfo, name string) unmarshaler { return func(b []byte, f pointer, w int) ([]byte, error) { if w != WireStartGroup { - return nil, errInternalBadWireType + return b, errInternalBadWireType } x, y := findEndGroup(b) if x < 0 {