Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options to disable BinaryMarshaler/BinaryUnmarshaler support. #526

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,25 @@ func (bseem ByteSliceExpectedEncodingMode) valid() bool {
return bseem >= 0 && bseem < maxByteSliceExpectedEncodingMode
}

// BinaryUnmarshalerMode specifies how to decode into types that implement
// encoding.BinaryUnmarshaler.
type BinaryUnmarshalerMode int

const (
// BinaryUnmarshalerByteString will invoke UnmarshalBinary on the contents of a CBOR byte
// string when decoding into a value that implements BinaryUnmarshaler.
BinaryUnmarshalerByteString BinaryUnmarshalerMode = iota

// BinaryUnmarshalerNone does not recognize BinaryUnmarshaler implementations during decode.
BinaryUnmarshalerNone

maxBinaryUnmarshalerMode
)

func (bum BinaryUnmarshalerMode) valid() bool {
return bum >= 0 && bum < maxBinaryUnmarshalerMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -751,6 +770,10 @@ type DecOptions struct {
// ByteSliceExpectedEncodingMode specifies how to decode a byte string NOT enclosed in an
// "expected later encoding" tag (RFC 8949 Section 3.4.5.2) into a Go byte slice.
ByteSliceExpectedEncoding ByteSliceExpectedEncodingMode

// BinaryUnmarshaler specifies how to decode into types that implement
// encoding.BinaryUnmarshaler.
BinaryUnmarshaler BinaryUnmarshalerMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -954,6 +977,10 @@ func (opts DecOptions) decMode() (*decMode, error) {
return nil, errors.New("cbor: invalid ByteSliceExpectedEncoding " + strconv.Itoa(int(opts.ByteSliceExpectedEncoding)))
}

if !opts.BinaryUnmarshaler.valid() {
return nil, errors.New("cbor: invalid BinaryUnmarshaler " + strconv.Itoa(int(opts.BinaryUnmarshaler)))
}

dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -979,6 +1006,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
infDec: opts.Inf,
byteStringToTime: opts.ByteStringToTime,
byteSliceExpectedEncoding: opts.ByteSliceExpectedEncoding,
binaryUnmarshaler: opts.BinaryUnmarshaler,
}

return &dm, nil
Expand Down Expand Up @@ -1056,6 +1084,7 @@ type decMode struct {
infDec InfMode
byteStringToTime ByteStringToTimeMode
byteSliceExpectedEncoding ByteSliceExpectedEncodingMode
binaryUnmarshaler BinaryUnmarshalerMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand Down Expand Up @@ -1094,6 +1123,7 @@ func (dm *decMode) DecOptions() DecOptions {
Inf: dm.infDec,
ByteStringToTime: dm.byteStringToTime,
ByteSliceExpectedEncoding: dm.byteSliceExpectedEncoding,
BinaryUnmarshaler: dm.binaryUnmarshaler,
}
}

Expand Down Expand Up @@ -1413,7 +1443,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return err
}
copied = copied || converted
return fillByteString(t, b, !copied, v, d.dm.byteStringToString)
return fillByteString(t, b, !copied, v, d.dm.byteStringToString, d.dm.binaryUnmarshaler)

case cborTypeTextString:
b, err := d.parseTextString()
Expand Down Expand Up @@ -1465,7 +1495,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler)
}
if bi.IsUint64() {
return fillPositiveInt(t, bi.Uint64(), v)
Expand All @@ -1487,7 +1517,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler)
}
if bi.IsInt64() {
return fillNegativeInt(t, bi.Int64(), v)
Expand Down Expand Up @@ -2885,8 +2915,8 @@ func fillFloat(t cborType, val float64, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode) error {
if reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) {
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode, bum BinaryUnmarshalerMode) error {
if bum == BinaryUnmarshalerByteString && reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) {
if v.CanAddr() {
v = v.Addr()
if u, ok := v.Interface().(encoding.BinaryUnmarshaler); ok {
Expand Down
81 changes: 81 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4923,6 +4923,7 @@ func TestDecOptions(t *testing.T) {
Inf: InfDecodeForbidden,
ByteStringToTime: ByteStringToTimeAllowed,
ByteSliceExpectedEncoding: ByteSliceToByteStringWithExpectedConversionToBase64,
BinaryUnmarshaler: BinaryUnmarshalerNone,
}
ov := reflect.ValueOf(opts1)
for i := 0; i < ov.NumField(); i++ {
Expand Down Expand Up @@ -9926,3 +9927,83 @@ func TestUnmarshalByteStringTextConversion(t *testing.T) {
})
}
}

func TestDecModeInvalidBinaryUnmarshaler(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{BinaryUnmarshaler: -1},
wantErrorMsg: "cbor: invalid BinaryUnmarshaler -1",
},
{
name: "above range of valid modes",
opts: DecOptions{BinaryUnmarshaler: 101},
wantErrorMsg: "cbor: invalid BinaryUnmarshaler 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

type testBinaryUnmarshaler []byte

func (bu *testBinaryUnmarshaler) UnmarshalBinary(_ []byte) error {
*bu = []byte("UnmarshalBinary")
return nil
}

func TestBinaryUnmarshalerMode(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
in []byte
want interface{}
}{
{
name: "UnmarshalBinary is called by default",
opts: DecOptions{},
in: []byte("\x45hello"), // 'hello'
want: testBinaryUnmarshaler("UnmarshalBinary"),
},
{
name: "UnmarshalBinary is called with BinaryUnmarshalerByteString",
opts: DecOptions{BinaryUnmarshaler: BinaryUnmarshalerByteString},
in: []byte("\x45hello"), // 'hello'
want: testBinaryUnmarshaler("UnmarshalBinary"),
},
{
name: "default byte slice unmarshaling behavior is used with BinaryUnmarshalerNone",
opts: DecOptions{BinaryUnmarshaler: BinaryUnmarshalerNone},
in: []byte("\x45hello"), // 'hello'
want: testBinaryUnmarshaler("hello"),
},
} {
t.Run(tc.name, func(t *testing.T) {
dm, err := tc.opts.DecMode()
if err != nil {
t.Fatal(err)
}

gotrv := reflect.New(reflect.TypeOf(tc.want))
if err := dm.Unmarshal(tc.in, gotrv.Interface()); err != nil {
t.Fatal(err)
}

got := gotrv.Elem().Interface()
if !reflect.DeepEqual(tc.want, got) {
t.Errorf("want: %v, got: %v", tc.want, got)
}
})
}
}
111 changes: 79 additions & 32 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,23 @@ func (bam ByteArrayMode) valid() bool {
return bam >= 0 && bam < maxByteArrayMode
}

// BinaryMarshalerMode specifies how to encode types that implement encoding.BinaryMarshaler.
type BinaryMarshalerMode int

const (
// BinaryMarshalerByteString encodes the output of MarshalBinary to a CBOR byte string.
BinaryMarshalerByteString BinaryMarshalerMode = iota

// BinaryMarshalerNone does not recognize BinaryMarshaler implementations during encode.
BinaryMarshalerNone

maxBinaryMarshalerMode
)

func (bmm BinaryMarshalerMode) valid() bool {
return bmm >= 0 && bmm < maxBinaryMarshalerMode
}

// EncOptions specifies encoding options.
type EncOptions struct {
// Sort specifies sorting order.
Expand Down Expand Up @@ -493,6 +510,9 @@ type EncOptions struct {

// ByteArray specifies how to encode byte arrays.
ByteArray ByteArrayMode

// BinaryMarshaler specifies how to encode types that implement encoding.BinaryMarshaler.
BinaryMarshaler BinaryMarshalerMode
}

// CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding,
Expand Down Expand Up @@ -685,6 +705,9 @@ func (opts EncOptions) encMode() (*encMode, error) {
if !opts.ByteArray.valid() {
return nil, errors.New("cbor: invalid ByteArray " + strconv.Itoa(int(opts.ByteArray)))
}
if !opts.BinaryMarshaler.valid() {
return nil, errors.New("cbor: invalid BinaryMarshaler " + strconv.Itoa(int(opts.BinaryMarshaler)))
}
em := encMode{
sort: opts.Sort,
shortestFloat: opts.ShortestFloat,
Expand All @@ -703,6 +726,7 @@ func (opts EncOptions) encMode() (*encMode, error) {
byteSlice: opts.ByteSlice,
byteSliceEncodingTag: byteSliceEncodingTag,
byteArray: opts.ByteArray,
binaryMarshaler: opts.BinaryMarshaler,
}
return &em, nil
}
Expand Down Expand Up @@ -733,6 +757,7 @@ type encMode struct {
byteSlice ByteSliceMode
byteSliceEncodingTag uint64
byteArray ByteArrayMode
binaryMarshaler BinaryMarshalerMode
}

var defaultEncMode, _ = EncOptions{}.encMode()
Expand Down Expand Up @@ -809,21 +834,22 @@ func getMarshalerDecMode(indefLength IndefLengthMode, tagsMd TagsMode) *decMode
// EncOptions returns user specified options used to create this EncMode.
func (em *encMode) EncOptions() EncOptions {
return EncOptions{
Sort: em.sort,
ShortestFloat: em.shortestFloat,
NaNConvert: em.nanConvert,
InfConvert: em.infConvert,
BigIntConvert: em.bigIntConvert,
Time: em.time,
TimeTag: em.timeTag,
IndefLength: em.indefLength,
NilContainers: em.nilContainers,
TagsMd: em.tagsMd,
OmitEmpty: em.omitEmpty,
String: em.stringType,
FieldName: em.fieldName,
ByteSlice: em.byteSlice,
ByteArray: em.byteArray,
Sort: em.sort,
ShortestFloat: em.shortestFloat,
NaNConvert: em.nanConvert,
InfConvert: em.infConvert,
BigIntConvert: em.bigIntConvert,
Time: em.time,
TimeTag: em.timeTag,
IndefLength: em.indefLength,
NilContainers: em.nilContainers,
TagsMd: em.tagsMd,
OmitEmpty: em.omitEmpty,
String: em.stringType,
FieldName: em.fieldName,
ByteSlice: em.byteSlice,
ByteArray: em.byteArray,
BinaryMarshaler: em.binaryMarshaler,
}
}

Expand Down Expand Up @@ -1513,7 +1539,16 @@ func encodeBigInt(e *bytes.Buffer, em *encMode, v reflect.Value) error {
return nil
}

func encodeBinaryMarshalerType(e *bytes.Buffer, em *encMode, v reflect.Value) error {
type binaryMarshalerEncoder struct {
alternateEncode encodeFunc
alternateIsEmpty isEmptyFunc
}

func (bme binaryMarshalerEncoder) encode(e *bytes.Buffer, em *encMode, v reflect.Value) error {
if em.binaryMarshaler != BinaryMarshalerByteString {
return bme.alternateEncode(e, em, v)
}

vt := v.Type()
m, ok := v.Interface().(encoding.BinaryMarshaler)
if !ok {
Expand All @@ -1533,6 +1568,24 @@ func encodeBinaryMarshalerType(e *bytes.Buffer, em *encMode, v reflect.Value) er
return nil
}

func (bme binaryMarshalerEncoder) isEmpty(em *encMode, v reflect.Value) (bool, error) {
if em.binaryMarshaler != BinaryMarshalerByteString {
return bme.alternateIsEmpty(em, v)
}

m, ok := v.Interface().(encoding.BinaryMarshaler)
if !ok {
pv := reflect.New(v.Type())
pv.Elem().Set(v)
m = pv.Interface().(encoding.BinaryMarshaler)
}
data, err := m.MarshalBinary()
if err != nil {
return false, err
}
return len(data) == 0, nil
}

func encodeMarshalerType(e *bytes.Buffer, em *encMode, v reflect.Value) error {
if em.tagsMd == TagsForbidden && v.Type() == typeRawTag {
return errors.New("cbor: cannot encode cbor.RawTag when TagsMd is TagsForbidden")
Expand Down Expand Up @@ -1618,7 +1671,7 @@ var (
typeByteString = reflect.TypeOf(ByteString(""))
)

func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) {
func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc) {
k := t.Kind()
if k == reflect.Ptr {
return getEncodeIndirectValueFunc(t), isEmptyPtr
Expand All @@ -1641,7 +1694,15 @@ func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) {
return encodeMarshalerType, alwaysNotEmpty
}
if reflect.PtrTo(t).Implements(typeBinaryMarshaler) {
return encodeBinaryMarshalerType, isEmptyBinaryMarshaler
defer func() {
// capture encoding method used for modes that disable BinaryMarshaler
bme := binaryMarshalerEncoder{
alternateEncode: ef,
alternateIsEmpty: ief,
}
ef = bme.encode
ief = bme.isEmpty
}()
}
switch k {
case reflect.Bool:
Expand Down Expand Up @@ -1795,20 +1856,6 @@ func isEmptyStruct(em *encMode, v reflect.Value) (bool, error) {
return true, nil
}

func isEmptyBinaryMarshaler(_ *encMode, v reflect.Value) (bool, error) {
m, ok := v.Interface().(encoding.BinaryMarshaler)
if !ok {
pv := reflect.New(v.Type())
pv.Elem().Set(v)
m = pv.Interface().(encoding.BinaryMarshaler)
}
data, err := m.MarshalBinary()
if err != nil {
return false, err
}
return len(data) == 0, nil
}

func cannotFitFloat32(f64 float64) bool {
f32 := float32(f64)
return float64(f32) != f64
Expand Down
Loading
Loading