Skip to content

Commit

Permalink
Add options to disable BinaryMarshaler/BinaryUnmarshaler support.
Browse files Browse the repository at this point in the history
By default, values whose type implements BinaryMarshaler encode to a byte string whose contents are
the result of calling MarshalBinary, and decoding a byte string into a BinaryUnmarshaler calls
UnmarshalBinary on the contents of the byte string. These options make it possible to disable both
behaviors.

Signed-off-by: Ben Luddy <[email protected]>
  • Loading branch information
benluddy committed May 10, 2024
1 parent 5a131e1 commit fcbe98d
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 52 deletions.
40 changes: 35 additions & 5 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,25 @@ func (bseem ByteSliceExpectedEncodingMode) valid() bool {
return bseem >= 0 && bseem < maxByteSliceExpectedEncodingMode
}

// BinaryUnmarshalerMode specifies how to decode into types that implement
// encoding.BinaryUnmarshaler.
type BinaryUnmarshalerMode int

const (
// BinaryUnmarshalerByteString will invoke UnmarshalBinary on the contents of a CBOR byte
// string when decoding into a value that implements BinaryUnmarshaler.
BinaryUnmarshalerByteString BinaryUnmarshalerMode = iota

// BinaryUnmarshalerNone does not recognize BinaryUnmarshaler implementations during decode.
BinaryUnmarshalerNone

maxBinaryUnmarshalerMode
)

func (bum BinaryUnmarshalerMode) valid() bool {
return bum >= 0 && bum < maxBinaryUnmarshalerMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -751,6 +770,10 @@ type DecOptions struct {
// ByteSliceExpectedEncodingMode specifies how to decode a byte string NOT enclosed in an
// "expected later encoding" tag (RFC 8949 Section 3.4.5.2) into a Go byte slice.
ByteSliceExpectedEncoding ByteSliceExpectedEncodingMode

// BinaryUnmarshaler specifies how to decode into types that implement
// encoding.BinaryUnmarshaler.
BinaryUnmarshaler BinaryUnmarshalerMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -954,6 +977,10 @@ func (opts DecOptions) decMode() (*decMode, error) {
return nil, errors.New("cbor: invalid ByteSliceExpectedEncoding " + strconv.Itoa(int(opts.ByteSliceExpectedEncoding)))
}

if !opts.BinaryUnmarshaler.valid() {
return nil, errors.New("cbor: invalid BinaryUnmarshaler " + strconv.Itoa(int(opts.BinaryUnmarshaler)))
}

dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -979,6 +1006,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
infDec: opts.Inf,
byteStringToTime: opts.ByteStringToTime,
byteSliceExpectedEncoding: opts.ByteSliceExpectedEncoding,
binaryUnmarshaler: opts.BinaryUnmarshaler,
}

return &dm, nil
Expand Down Expand Up @@ -1056,6 +1084,7 @@ type decMode struct {
infDec InfMode
byteStringToTime ByteStringToTimeMode
byteSliceExpectedEncoding ByteSliceExpectedEncodingMode
binaryUnmarshaler BinaryUnmarshalerMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand Down Expand Up @@ -1094,6 +1123,7 @@ func (dm *decMode) DecOptions() DecOptions {
Inf: dm.infDec,
ByteStringToTime: dm.byteStringToTime,
ByteSliceExpectedEncoding: dm.byteSliceExpectedEncoding,
BinaryUnmarshaler: dm.binaryUnmarshaler,
}
}

Expand Down Expand Up @@ -1413,7 +1443,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return err
}
copied = copied || converted
return fillByteString(t, b, !copied, v, d.dm.byteStringToString)
return fillByteString(t, b, !copied, v, d.dm.byteStringToString, d.dm.binaryUnmarshaler)

case cborTypeTextString:
b, err := d.parseTextString()
Expand Down Expand Up @@ -1465,7 +1495,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler)
}
if bi.IsUint64() {
return fillPositiveInt(t, bi.Uint64(), v)
Expand All @@ -1487,7 +1517,7 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
return nil
}
if tInfo.nonPtrKind == reflect.Slice || tInfo.nonPtrKind == reflect.Array {
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden)
return fillByteString(t, b, !copied, v, ByteStringToStringForbidden, d.dm.binaryUnmarshaler)
}
if bi.IsInt64() {
return fillNegativeInt(t, bi.Int64(), v)
Expand Down Expand Up @@ -2885,8 +2915,8 @@ func fillFloat(t cborType, val float64, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode) error {
if reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) {
func fillByteString(t cborType, val []byte, shared bool, v reflect.Value, bsts ByteStringToStringMode, bum BinaryUnmarshalerMode) error {
if bum == BinaryUnmarshalerByteString && reflect.PtrTo(v.Type()).Implements(typeBinaryUnmarshaler) {
if v.CanAddr() {
v = v.Addr()
if u, ok := v.Interface().(encoding.BinaryUnmarshaler); ok {
Expand Down
81 changes: 81 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4923,6 +4923,7 @@ func TestDecOptions(t *testing.T) {
Inf: InfDecodeForbidden,
ByteStringToTime: ByteStringToTimeAllowed,
ByteSliceExpectedEncoding: ByteSliceToByteStringWithExpectedConversionToBase64,
BinaryUnmarshaler: BinaryUnmarshalerNone,
}
ov := reflect.ValueOf(opts1)
for i := 0; i < ov.NumField(); i++ {
Expand Down Expand Up @@ -9926,3 +9927,83 @@ func TestUnmarshalByteStringTextConversion(t *testing.T) {
})
}
}

func TestDecModeInvalidBinaryUnmarshaler(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{BinaryUnmarshaler: -1},
wantErrorMsg: "cbor: invalid BinaryUnmarshaler -1",
},
{
name: "above range of valid modes",
opts: DecOptions{BinaryUnmarshaler: 101},
wantErrorMsg: "cbor: invalid BinaryUnmarshaler 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

type testBinaryUnmarshaler []byte

func (bu *testBinaryUnmarshaler) UnmarshalBinary(_ []byte) error {
*bu = []byte("UnmarshalBinary")
return nil
}

func TestBinaryUnmarshalerMode(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
in []byte
want interface{}
}{
{
name: "UnmarshalBinary is called by default",
opts: DecOptions{},
in: []byte("\x45hello"), // 'hello'
want: testBinaryUnmarshaler("UnmarshalBinary"),
},
{
name: "UnmarshalBinary is called with BinaryUnmarshalerByteString",
opts: DecOptions{BinaryUnmarshaler: BinaryUnmarshalerByteString},
in: []byte("\x45hello"), // 'hello'
want: testBinaryUnmarshaler("UnmarshalBinary"),
},
{
name: "default byte slice unmarshaling behavior is used with BinaryUnmarshalerNone",
opts: DecOptions{BinaryUnmarshaler: BinaryUnmarshalerNone},
in: []byte("\x45hello"), // 'hello'
want: testBinaryUnmarshaler("hello"),
},
} {
t.Run(tc.name, func(t *testing.T) {
dm, err := tc.opts.DecMode()
if err != nil {
t.Fatal(err)
}

gotrv := reflect.New(reflect.TypeOf(tc.want))
if err := dm.Unmarshal(tc.in, gotrv.Interface()); err != nil {
t.Fatal(err)
}

got := gotrv.Elem().Interface()
if !reflect.DeepEqual(tc.want, got) {
t.Errorf("want: %v, got: %v", tc.want, got)
}
})
}
}
111 changes: 79 additions & 32 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,23 @@ func (bam ByteArrayMode) valid() bool {
return bam >= 0 && bam < maxByteArrayMode
}

// BinaryMarshalerMode specifies how to encode types that implement encoding.BinaryMarshaler.
type BinaryMarshalerMode int

const (
// BinaryMarshalerByteString encodes the output of MarshalBinary to a CBOR byte string.
BinaryMarshalerByteString BinaryMarshalerMode = iota

// BinaryMarshalerNone does not recognize BinaryMarshaler implementations during encode.
BinaryMarshalerNone

maxBinaryMarshalerMode
)

func (bmm BinaryMarshalerMode) valid() bool {
return bmm >= 0 && bmm < maxBinaryMarshalerMode
}

// EncOptions specifies encoding options.
type EncOptions struct {
// Sort specifies sorting order.
Expand Down Expand Up @@ -493,6 +510,9 @@ type EncOptions struct {

// ByteArray specifies how to encode byte arrays.
ByteArray ByteArrayMode

// BinaryMarshaler specifies how to encode types that implement encoding.BinaryMarshaler.
BinaryMarshaler BinaryMarshalerMode
}

// CanonicalEncOptions returns EncOptions for "Canonical CBOR" encoding,
Expand Down Expand Up @@ -685,6 +705,9 @@ func (opts EncOptions) encMode() (*encMode, error) {
if !opts.ByteArray.valid() {
return nil, errors.New("cbor: invalid ByteArray " + strconv.Itoa(int(opts.ByteArray)))
}
if !opts.BinaryMarshaler.valid() {
return nil, errors.New("cbor: invalid BinaryMarshaler " + strconv.Itoa(int(opts.BinaryMarshaler)))
}
em := encMode{
sort: opts.Sort,
shortestFloat: opts.ShortestFloat,
Expand All @@ -703,6 +726,7 @@ func (opts EncOptions) encMode() (*encMode, error) {
byteSlice: opts.ByteSlice,
byteSliceEncodingTag: byteSliceEncodingTag,
byteArray: opts.ByteArray,
binaryMarshaler: opts.BinaryMarshaler,
}
return &em, nil
}
Expand Down Expand Up @@ -733,6 +757,7 @@ type encMode struct {
byteSlice ByteSliceMode
byteSliceEncodingTag uint64
byteArray ByteArrayMode
binaryMarshaler BinaryMarshalerMode
}

var defaultEncMode, _ = EncOptions{}.encMode()
Expand Down Expand Up @@ -809,21 +834,22 @@ func getMarshalerDecMode(indefLength IndefLengthMode, tagsMd TagsMode) *decMode
// EncOptions returns user specified options used to create this EncMode.
func (em *encMode) EncOptions() EncOptions {
return EncOptions{
Sort: em.sort,
ShortestFloat: em.shortestFloat,
NaNConvert: em.nanConvert,
InfConvert: em.infConvert,
BigIntConvert: em.bigIntConvert,
Time: em.time,
TimeTag: em.timeTag,
IndefLength: em.indefLength,
NilContainers: em.nilContainers,
TagsMd: em.tagsMd,
OmitEmpty: em.omitEmpty,
String: em.stringType,
FieldName: em.fieldName,
ByteSlice: em.byteSlice,
ByteArray: em.byteArray,
Sort: em.sort,
ShortestFloat: em.shortestFloat,
NaNConvert: em.nanConvert,
InfConvert: em.infConvert,
BigIntConvert: em.bigIntConvert,
Time: em.time,
TimeTag: em.timeTag,
IndefLength: em.indefLength,
NilContainers: em.nilContainers,
TagsMd: em.tagsMd,
OmitEmpty: em.omitEmpty,
String: em.stringType,
FieldName: em.fieldName,
ByteSlice: em.byteSlice,
ByteArray: em.byteArray,
BinaryMarshaler: em.binaryMarshaler,
}
}

Expand Down Expand Up @@ -1508,7 +1534,16 @@ func encodeBigInt(e *encoderBuffer, em *encMode, v reflect.Value) error {
return nil
}

func encodeBinaryMarshalerType(e *encoderBuffer, em *encMode, v reflect.Value) error {
type binaryMarshalerEncoder struct {
alternateEncode encodeFunc
alternateIsEmpty isEmptyFunc
}

func (bme binaryMarshalerEncoder) encode(e *encoderBuffer, em *encMode, v reflect.Value) error {
if em.binaryMarshaler != BinaryMarshalerByteString {
return bme.alternateEncode(e, em, v)
}

vt := v.Type()
m, ok := v.Interface().(encoding.BinaryMarshaler)
if !ok {
Expand All @@ -1528,6 +1563,24 @@ func encodeBinaryMarshalerType(e *encoderBuffer, em *encMode, v reflect.Value) e
return nil
}

func (bme binaryMarshalerEncoder) isEmpty(em *encMode, v reflect.Value) (bool, error) {
if em.binaryMarshaler != BinaryMarshalerByteString {
return bme.alternateIsEmpty(em, v)
}

m, ok := v.Interface().(encoding.BinaryMarshaler)
if !ok {
pv := reflect.New(v.Type())
pv.Elem().Set(v)
m = pv.Interface().(encoding.BinaryMarshaler)
}
data, err := m.MarshalBinary()
if err != nil {
return false, err
}
return len(data) == 0, nil
}

func encodeMarshalerType(e *encoderBuffer, em *encMode, v reflect.Value) error {
if em.tagsMd == TagsForbidden && v.Type() == typeRawTag {
return errors.New("cbor: cannot encode cbor.RawTag when TagsMd is TagsForbidden")
Expand Down Expand Up @@ -1611,7 +1664,7 @@ var (
typeByteString = reflect.TypeOf(ByteString(""))
)

func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) {
func getEncodeFuncInternal(t reflect.Type) (ef encodeFunc, ief isEmptyFunc) {
k := t.Kind()
if k == reflect.Ptr {
return getEncodeIndirectValueFunc(t), isEmptyPtr
Expand All @@ -1634,7 +1687,15 @@ func getEncodeFuncInternal(t reflect.Type) (encodeFunc, isEmptyFunc) {
return encodeMarshalerType, alwaysNotEmpty
}
if reflect.PtrTo(t).Implements(typeBinaryMarshaler) {
return encodeBinaryMarshalerType, isEmptyBinaryMarshaler
defer func() {
// capture encoding method used for modes that disable BinaryMarshaler
bme := binaryMarshalerEncoder{
alternateEncode: ef,
alternateIsEmpty: ief,
}
ef = bme.encode
ief = bme.isEmpty
}()
}
switch k {
case reflect.Bool:
Expand Down Expand Up @@ -1788,20 +1849,6 @@ func isEmptyStruct(em *encMode, v reflect.Value) (bool, error) {
return true, nil
}

func isEmptyBinaryMarshaler(_ *encMode, v reflect.Value) (bool, error) {
m, ok := v.Interface().(encoding.BinaryMarshaler)
if !ok {
pv := reflect.New(v.Type())
pv.Elem().Set(v)
m = pv.Interface().(encoding.BinaryMarshaler)
}
data, err := m.MarshalBinary()
if err != nil {
return false, err
}
return len(data) == 0, nil
}

func cannotFitFloat32(f64 float64) bool {
f32 := float32(f64)
return float64(f32) != f64
Expand Down
Loading

0 comments on commit fcbe98d

Please sign in to comment.