Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UnmarshalFirst #398

Merged
merged 1 commit into from
May 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,54 @@ func BenchmarkUnmarshal(b *testing.B) {
}
}

func BenchmarkUnmarshalFirst(b *testing.B) {
// Random trailing data
trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
name := "CBOR " + bm.name + " to Go " + t.String()
if t.Kind() == reflect.Struct {
name = "CBOR " + bm.name + " to Go " + t.Kind().String()
}
data := make([]byte, 0, len(bm.cborData)+len(trailingData))
data = append(data, bm.cborData...)
data = append(data, trailingData...)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
vPtr := reflect.New(t).Interface()
if _, err := UnmarshalFirst(data, vPtr); err != nil {
b.Fatal("UnmarshalFirst:", err)
}
}
})
}
}
}

func BenchmarkUnmarshalFirstViaDecoder(b *testing.B) {
// Random trailing data
trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
name := "CBOR " + bm.name + " to Go " + t.String()
if t.Kind() == reflect.Struct {
name = "CBOR " + bm.name + " to Go " + t.Kind().String()
}
data := make([]byte, 0, len(bm.cborData)+len(trailingData))
data = append(data, bm.cborData...)
data = append(data, trailingData...)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
vPtr := reflect.New(t).Interface()
if err := NewDecoder(bytes.NewReader(data)).Decode(vPtr); err != nil {
b.Fatal("UnmarshalDecoder:", err)
}
}
})
}
}
}

func BenchmarkDecode(b *testing.B) {
for _, bm := range decodeBenchmarks {
for _, t := range bm.decodeToTypes {
Expand Down
44 changes: 44 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,25 @@ import (
//
// Unmarshal supports CBOR tag 55799 (self-describe CBOR), tag 0 and 1 (time),
// and tag 2 and 3 (bignum).
//
// Unmarshal returns ExtraneousDataError error (without decoding into v)
// if there are any remaining bytes following the first valid CBOR data item.
// See UnmarshalFirst, if you want to unmarshal only the first
// CBOR data item without ExtraneousDataError caused by remaining bytes.
func Unmarshal(data []byte, v interface{}) error {
return defaultDecMode.Unmarshal(data, v)
}

// UnmarshalFirst parses the first CBOR data item into the value pointed to by v
// using default decoding options. Any remaining bytes are returned in rest.
//
// If v is nil, not a pointer, or a nil pointer, UnmarshalFirst returns an error.
//
// See the documentation for Unmarshal for details.
func UnmarshalFirst(data []byte, v interface{}) (rest []byte, err error) {
return defaultDecMode.UnmarshalFirst(data, v)
}

// Valid checks whether data is a well-formed encoded CBOR data item and
// that it complies with default restrictions such as MaxNestedLevels,
// MaxArrayElements, MaxMapPairs, etc.
Expand Down Expand Up @@ -604,6 +619,35 @@ func (dm *decMode) Unmarshal(data []byte, v interface{}) error {
return d.value(v)
}

// UnmarshalFirst parses the first CBOR data item into the value pointed to by v
// using dm decoding mode. Any remaining bytes are returned in rest.
//
// If v is nil, not a pointer, or a nil pointer, UnmarshalFirst returns an error.
//
// See the documentation for Unmarshal for details.
func (dm *decMode) UnmarshalFirst(data []byte, v interface{}) (rest []byte, err error) {
d := decoder{data: data, dm: dm}

// check well-formedness.
off := d.off // Save offset before data validation
err = d.wellformed(true) // allow extra data after well-formed data item
d.off = off // Restore offset

// If it is well-formed, parse the value. This is structured like this to allow
// better test coverage
if err == nil {
err = d.value(v)
}

// If either wellformed or value returned an error, do not return rest bytes
if err != nil {
return nil, err
}

// Return the rest of the data slice (which might be len 0)
return d.data[d.off:], nil
}

// Valid checks whether data is a well-formed encoded CBOR data item and
// that it complies with configurable restrictions such as MaxNestedLevels,
// MaxArrayElements, MaxMapPairs, etc.
Expand Down
57 changes: 57 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5981,3 +5981,60 @@ func TestUnmarshalToDefaultMapType(t *testing.T) {
})
}
}

func TestUnmarshalFirstNoTrailing(t *testing.T) {
for _, tc := range unmarshalTests {
var v interface{}
if rest, err := UnmarshalFirst(tc.cborData, &v); err != nil {
t.Errorf("UnmarshalFirst(0x%x) returned error %v", tc.cborData, err)
} else {
if len(rest) != 0 {
t.Errorf("UnmarshalFirst(0x%x) returned rest %x (want [])", tc.cborData, rest)
}
// Check the value as well, although this is covered by other tests
if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
}
}
}

func TestUnmarshalfirstTrailing(t *testing.T) {
// Random trailing data
trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
for _, tc := range unmarshalTests {
data := make([]byte, 0, len(tc.cborData)+len(trailingData))
data = append(data, tc.cborData...)
data = append(data, trailingData...)
var v interface{}
if rest, err := UnmarshalFirst(data, &v); err != nil {
t.Errorf("UnmarshalFirst(0x%x) returned error %v", data, err)
} else {
if !bytes.Equal(trailingData, rest) {
t.Errorf("UnmarshalFirst(0x%x) returned rest %x (want %x)", data, rest, trailingData)
}
// Check the value as well, although this is covered by other tests
if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
}
}
}

func TestUnmarshalFirstInvalidItem(t *testing.T) {
// UnmarshalFirst should not return "rest" if the item was not well-formed
invalidCBOR := hexDecode("83FF20030102")
var v interface{}
rest, err := UnmarshalFirst(invalidCBOR, &v)
if rest != nil {
t.Errorf("UnmarshalFirst(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err)
}
}