Skip to content

Commit

Permalink
Add RawMessage, Marshaler, and Unmarshaler
Browse files Browse the repository at this point in the history
Add RawMessage type. RawMessage can be used to delay CBOR decoding or
precompute CBOR encoding.  Nil or empty RawMessage marshals to CBOR
nil value.

Add Marshaler and Unmarshaler interfaces to let user-defined types
implement their own CBOR encoding and decoding.
  • Loading branch information
fxamacker committed Nov 1, 2019
1 parent 9ff43a1 commit 1a29187
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 27 deletions.
53 changes: 37 additions & 16 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

var (
typeTime = reflect.TypeOf(time.Time{})
typeUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
typeBinaryUnmarshaler = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
)

Expand Down Expand Up @@ -93,6 +94,14 @@ func (e *UnmarshalTypeError) Error() string {
return s
}

// Unmarshaler is the interface implemented by types that can unmarshal a CBOR
// representation of themselves. The input can be assumed to be a valid encoding
// of a CBOR value. UnmarshalCBOR must copy the CBOR data if it wishes to retain
// the data after returning.
type Unmarshaler interface {
UnmarshalCBOR([]byte) error
}

type decodeState struct {
data []byte
offset int // next read offset in data
Expand Down Expand Up @@ -184,20 +193,30 @@ func (d *decodeState) parse(v reflect.Value) (err error) {
return nil
}

// Process cbor nil/undefined.
if d.data[d.offset] == 0xf6 || d.data[d.offset] == 0xf7 {
d.offset++
return fillNil(cborTypePrimitives, v)
// Create new value for the pointer v to point to if CBOR value is not nil/undefined.
if d.data[d.offset] != 0xf6 && d.data[d.offset] != 0xf7 {
for v.Kind() == reflect.Ptr {
if v.IsNil() {
if !v.CanSet() {
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
}

for v.Kind() == reflect.Ptr {
if v.IsNil() {
if !v.CanSet() {
return errors.New("cbor: cannot set new value for " + v.Type().String())
}
v.Set(reflect.New(v.Type().Elem()))
if reflect.PtrTo(v.Type()).Implements(typeUnmarshaler) {
pv := reflect.New(v.Type())
pv.Elem().Set(v)
u := pv.Interface().(Unmarshaler)
start := d.offset
d.skip()
if err := u.UnmarshalCBOR(d.data[start:d.offset]); err != nil {
return err
}
v = v.Elem()
v.Set(pv.Elem())
return nil
}

// Process byte/text string.
Expand Down Expand Up @@ -231,6 +250,8 @@ func (d *decodeState) parse(v reflect.Value) (err error) {
switch ai {
case 20, 21:
return fillBool(t, ai == 21, v)
case 22, 23:
return fillNil(t, v)
case 24:
return fillPositiveInt(t, uint64(val), v)
case 25:
Expand Down Expand Up @@ -281,11 +302,6 @@ func (d *decodeState) parse(v reflect.Value) (err error) {

// parseInterface assumes data is well-formed, and does not perform bounds checking.
func (d *decodeState) parseInterface() (_ interface{}, err error) {
if d.data[d.offset] == 0xf6 || d.data[d.offset] == 0xf7 {
d.offset++
return nil, nil
}

// Process byte/text string.
t := cborType(d.data[d.offset] & 0xE0)
if t == cborTypeByteString {
Expand Down Expand Up @@ -316,6 +332,8 @@ func (d *decodeState) parseInterface() (_ interface{}, err error) {
switch ai {
case 20, 21:
return (ai == 21), nil
case 22, 23:
return nil, nil
case 24:
return uint64(val), nil
case 25:
Expand Down Expand Up @@ -874,6 +892,9 @@ func isHashableKind(k reflect.Kind) bool {
// time.Time value, Unmarshal creates an unix time with integer/float as seconds
// and fractional seconds since January 1, 1970 UTC.
//
// To unmarshal CBOR into a value implementing the Unmarshaler interface,
// Unmarshal calls that value's UnmarshalCBOR method.
//
// Unmarshal decodes a CBOR byte string into a value implementing
// encoding.BinaryUnmarshaler.
//
Expand Down
204 changes: 195 additions & 9 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,13 @@ func TestUnmarshal(t *testing.T) {
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
// Test unmarshalling CBOR into RawMessage.
var r cbor.RawMessage
if err := cbor.Unmarshal(tc.cborData, &r); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", tc.cborData, err)
} else if !bytes.Equal(r, tc.cborData) {
t.Errorf("Unmarshal(0x%0x) returns RawMessage %v, want %v", tc.cborData, r, tc.cborData)
}
// Test unmarshalling CBOR into compatible data types.
for _, value := range tc.values {
v := reflect.New(reflect.TypeOf(value))
Expand Down Expand Up @@ -656,6 +663,13 @@ func TestUnmarshalFloat(t *testing.T) {
}
}
}
// Test unmarshalling CBOR into RawMessage.
var r cbor.RawMessage
if err := cbor.Unmarshal(tc.cborData, &r); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", tc.cborData, err)
} else if !bytes.Equal(r, tc.cborData) {
t.Errorf("Unmarshal(0x%0x) returns RawMessage %v, want %v", tc.cborData, r, tc.cborData)
}
// Test unmarshalling CBOR into compatible data types.
for _, value := range tc.values {
v := reflect.New(reflect.TypeOf(value))
Expand Down Expand Up @@ -701,6 +715,7 @@ func TestUnmarshalIntoPointer(t *testing.T) {

var p1 *int
var p2 *string
var p3 *cbor.RawMessage

var i int
pi := &i
Expand All @@ -710,39 +725,67 @@ func TestUnmarshalIntoPointer(t *testing.T) {
ps := &s
pps := &ps

// Unmarshal CBOR nil into a pointer.
var r cbor.RawMessage
pr := &r
ppr := &pr

// Unmarshal CBOR nil into a nil pointer.
if err := cbor.Unmarshal(cborDataNil, &p1); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataNil, err)
} else if p1 != nil {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want nil", cborDataNil, p1, p1)
}
if err := cbor.Unmarshal(cborDataNil, &p2); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataNil, err)
} else if p2 != nil {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want nil", cborDataNil, p1, p1)
}
if err := cbor.Unmarshal(cborDataNil, &p3); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataNil, err)
} else if p3 != nil {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want nil", cborDataNil, p1, p1)
}

// Unmarshal CBOR integer into a non-nil pointer.
if err := cbor.Unmarshal(cborDataInt, &ppi); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataNil, err)
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataInt, err)
} else if i != 24 {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want 24", cborDataNil, i, i)
t.Errorf("Unmarshal(0x%0x) = %v (%T), want 24", cborDataInt, i, i)
}

// Unmarshal CBOR integer into a nil pointer.
if err := cbor.Unmarshal(cborDataInt, &p1); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataNil, err)
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataInt, err)
} else if *p1 != 24 {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want 24", cborDataNil, *pi, pi)
t.Errorf("Unmarshal(0x%0x) = %v (%T), want 24", cborDataInt, *pi, pi)
}

// Unmarshal CBOR string into a non-nil pointer.
if err := cbor.Unmarshal(cborDataString, &pps); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataNil, err)
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataString, err)
} else if s != "streaming" {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want \"streaming\"", cborDataNil, s, s)
t.Errorf("Unmarshal(0x%0x) = %v (%T), want \"streaming\"", cborDataString, s, s)
}

// Unmarshal CBOR string into a nil pointer.
if err := cbor.Unmarshal(cborDataString, &p2); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataNil, err)
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataString, err)
} else if *p2 != "streaming" {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want \"streaming\"", cborDataNil, *p2, p2)
t.Errorf("Unmarshal(0x%0x) = %v (%T), want \"streaming\"", cborDataString, *p2, p2)
}

// Unmarshal CBOR string into a non-nil cbor.RawMessage.
if err := cbor.Unmarshal(cborDataString, &ppr); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataString, err)
} else if !bytes.Equal(r, cborDataString) {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want %v", cborDataString, r, r, cborDataString)
}

// Unmarshal CBOR string into a nil pointer to cbor.RawMessage.
if err := cbor.Unmarshal(cborDataString, &p3); err != nil {
t.Errorf("Unmarshal(0x%0x) returns error %v", cborDataString, err)
} else if !bytes.Equal(*p3, cborDataString) {
t.Errorf("Unmarshal(0x%0x) = %v (%T), want %v", cborDataString, *p3, p3, cborDataString)
}
}

Expand Down Expand Up @@ -1461,6 +1504,12 @@ func (s *stru) UnmarshalBinary(data []byte) (err error) {
return
}

type marshalBinaryError string

func (n marshalBinaryError) MarshalBinary() (data []byte, err error) {
return nil, errors.New(string(n))
}

func TestBinaryUnmarshal(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -1529,3 +1578,140 @@ func TestBinaryUnmarshalError(t *testing.T) {
})
}
}

func TestBinaryMarshalError(t *testing.T) {
wantErrorMsg := "MarshalBinary: error"
v := marshalBinaryError(wantErrorMsg)
if _, err := cbor.Marshal(v, cbor.EncOptions{}); err == nil {
t.Errorf("Unmarshal(0x%0x) doesn't return error, want error msg %s\n", v, wantErrorMsg)
} else if err.Error() != wantErrorMsg {
t.Errorf("Unmarshal(0x%0x) returns error %s, want %s", v, err, wantErrorMsg)
}
}

type number2 uint64

func (n number2) MarshalCBOR() (data []byte, err error) {
m := map[string]uint64{"num": uint64(n)}
return cbor.Marshal(m, cbor.EncOptions{})
}

func (n *number2) UnmarshalCBOR(data []byte) (err error) {
var v map[string]uint64
if err := cbor.Unmarshal(data, &v); err != nil {
return err
}
*n = number2(v["num"])
return nil
}

type stru2 struct {
a, b, c string
}

func (s *stru2) MarshalCBOR() ([]byte, error) {
v := []string{s.a, s.b, s.c}
return cbor.Marshal(v, cbor.EncOptions{})
}

func (s *stru2) UnmarshalCBOR(data []byte) (err error) {
var v []string
if err := cbor.Unmarshal(data, &v); err != nil {
return err
}
if len(v) > 0 {
s.a = v[0]
}
if len(v) > 1 {
s.b = v[1]
}
if len(v) > 2 {
s.c = v[2]
}
return nil
}

type marshalCBORError string

func (n marshalCBORError) MarshalCBOR() (data []byte, err error) {
return nil, errors.New(string(n))
}

func TestUnmarshalCBOR(t *testing.T) {
testCases := []struct {
name string
obj interface{}
wantCborData []byte
}{
{
name: "primitive obj",
obj: number2(1),
wantCborData: hexDecode("a1636e756d01"),
},
{
name: "struct obj",
obj: stru2{a: "a", b: "b", c: "c"},
wantCborData: hexDecode("83616161626163"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
b, err := cbor.Marshal(tc.obj, cbor.EncOptions{})
if err != nil {
t.Errorf("Marshal(%+v) returns error %v\n", tc.obj, err)
}
if !bytes.Equal(b, tc.wantCborData) {
t.Errorf("Marshal(%+v) = 0x%0x, want 0x%0x", tc.obj, b, tc.wantCborData)
}
v := reflect.New(reflect.TypeOf(tc.obj))
if err := cbor.Unmarshal(b, v.Interface()); err != nil {
t.Errorf("Unmarshal() returns error %v\n", err)
}
if !reflect.DeepEqual(tc.obj, v.Elem().Interface()) {
t.Errorf("Marshal-Unmarshal return different values: %v, %v\n", tc.obj, v.Elem().Interface())
}
})
}
}

func TestUnmarshalCBORError(t *testing.T) {
testCases := []struct {
name string
typ reflect.Type
cborData []byte
wantErrorMsg string
}{
{
name: "primitive type",
typ: reflect.TypeOf(number2(0)),
cborData: hexDecode("44499602d2"),
wantErrorMsg: "cbor: cannot unmarshal byte string into Go value of type map[string]uint64",
},
{
name: "struct type",
typ: reflect.TypeOf(stru2{}),
cborData: hexDecode("47612C622C632C64"),
wantErrorMsg: "cbor: cannot unmarshal byte string into Go value of type []string",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
v := reflect.New(tc.typ)
if err := cbor.Unmarshal(tc.cborData, v.Interface()); err == nil {
t.Errorf("Unmarshal(0x%0x) doesn't return error, want error msg %s\n", tc.cborData, tc.wantErrorMsg)
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("Unmarshal(0x%0x) returns error %s, want %s", tc.cborData, err, tc.wantErrorMsg)
}
})
}
}

func TestMarshalCBORError(t *testing.T) {
wantErrorMsg := "MarshalCBOR: error"
v := marshalCBORError(wantErrorMsg)
if _, err := cbor.Marshal(v, cbor.EncOptions{}); err == nil {
t.Errorf("Marshal(%+v) doesn't return error, want error msg %s\n", v, wantErrorMsg)
} else if err.Error() != wantErrorMsg {
t.Errorf("Marshal(%+v) returns error %s, want %s", v, err, wantErrorMsg)
}
}
Loading

0 comments on commit 1a29187

Please sign in to comment.