Skip to content

Commit

Permalink
GODRIVER-2725 Allow setting Encoder and Decoder options on a Client. (#…
Browse files Browse the repository at this point in the history
…1282)

Co-authored-by: Preston Vasquez <[email protected]>
  • Loading branch information
matthewdale and prestonvasquez authored Jun 21, 2023
1 parent 836d408 commit 41ebbc3
Show file tree
Hide file tree
Showing 31 changed files with 1,716 additions and 697 deletions.
8 changes: 4 additions & 4 deletions bson/bsoncodec/bsoncodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,16 +269,16 @@ func (dc *DecodeContext) ZeroStructs() {
dc.zeroStructs = true
}

// DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as
// "interface{}" or "map[string]interface{}".
// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
//
// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentM] instead.
func (dc *DecodeContext) DefaultDocumentM() {
dc.defaultDocumentType = reflect.TypeOf(primitive.M{})
}

// DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as
// "interface{}" or "map[string]interface{}".
// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
//
// Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentD] instead.
func (dc *DecodeContext) DefaultDocumentD() {
Expand Down
11 changes: 7 additions & 4 deletions bson/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,14 @@ func (d *Decoder) Decode(val interface{}) error {
// Reset will reset the state of the decoder, using the same *DecodeContext used in
// the original construction but using vr for reading.
func (d *Decoder) Reset(vr bsonrw.ValueReader) error {
// TODO:(GODRIVER-2719): Remove error return value.
d.vr = vr
return nil
}

// SetRegistry replaces the current registry of the decoder with r.
func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error {
// TODO:(GODRIVER-2719): Remove error return value.
d.dc.Registry = r
return nil
}
Expand All @@ -151,18 +153,19 @@ func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error {
//
// Deprecated: Use the Decoder configuration methods to set the desired unmarshal behavior instead.
func (d *Decoder) SetContext(dc bsoncodec.DecodeContext) error {
// TODO:(GODRIVER-2719): Remove error return value.
d.dc = dc
return nil
}

// DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as
// "interface{}" or "map[string]interface{}".
// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
func (d *Decoder) DefaultDocumentM() {
d.defaultDocumentM = true
}

// DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as
// "interface{}" or "map[string]interface{}".
// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This
// behavior is restricted to data typed as "interface{}" or "map[string]interface{}".
func (d *Decoder) DefaultDocumentD() {
d.defaultDocumentD = true
}
Expand Down
4 changes: 4 additions & 0 deletions bson/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type Encoder struct {

// NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw.
func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) {
// TODO:(GODRIVER-2719): Remove error return value.
if vw == nil {
return nil, errors.New("cannot create a new Encoder with a nil ValueWriter")
}
Expand Down Expand Up @@ -121,12 +122,14 @@ func (e *Encoder) Encode(val interface{}) error {
// Reset will reset the state of the Encoder, using the same *EncodeContext used in
// the original construction but using vw.
func (e *Encoder) Reset(vw bsonrw.ValueWriter) error {
// TODO:(GODRIVER-2719): Remove error return value.
e.vw = vw
return nil
}

// SetRegistry replaces the current registry of the Encoder with r.
func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error {
// TODO:(GODRIVER-2719): Remove error return value.
e.ec.Registry = r
return nil
}
Expand All @@ -135,6 +138,7 @@ func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error {
//
// Deprecated: Use the Encoder configuration methods set the desired marshal behavior instead.
func (e *Encoder) SetContext(ec bsoncodec.EncodeContext) error {
// TODO:(GODRIVER-2719): Remove error return value.
e.ec = ec
return nil
}
Expand Down
45 changes: 1 addition & 44 deletions bson/unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"math/rand"
"reflect"
"testing"
"unsafe"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
Expand Down Expand Up @@ -770,49 +769,7 @@ func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) {

// Assert that the byte slice in the unmarshaled value does not share any memory
// addresses with the input byte slice.
assertDifferentArrays(t, data, tc.getByteSlice(got))
assert.DifferentAddressRanges(t, data, tc.getByteSlice(got))
})
}
}

// assertDifferentArrays asserts that two byte slices reference distinct memory ranges, meaning
// they reference different underlying byte arrays.
func assertDifferentArrays(t *testing.T, a, b []byte) {
// Find the start and end memory addresses for the underlying byte array for each input byte
// slice.
sliceAddrRange := func(b []byte) (uintptr, uintptr) {
sh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
return sh.Data, sh.Data + uintptr(sh.Cap-1)
}
aStart, aEnd := sliceAddrRange(a)
bStart, bEnd := sliceAddrRange(b)

// If "b" starts after "a" ends or "a" starts after "b" ends, there is no overlap.
if bStart > aEnd || aStart > bEnd {
return
}

// Otherwise, calculate the overlap start and end and print the memory overlap error message.
min := func(a, b uintptr) uintptr {
if a < b {
return a
}
return b
}
max := func(a, b uintptr) uintptr {
if a > b {
return a
}
return b
}
overlapLow := max(aStart, bStart)
overlapHigh := min(aEnd, bEnd)

t.Errorf("Byte slices point to the same the same underlying byte array:\n"+
"\ta addresses:\t%d ... %d\n"+
"\tb addresses:\t%d ... %d\n"+
"\toverlap:\t%d ... %d",
aStart, aEnd,
bStart, bEnd,
overlapLow, overlapHigh)
}
90 changes: 90 additions & 0 deletions internal/assert/assertion_mongo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

// assertion_mongo.go contains MongoDB-specific extensions to the "assert"
// package.

package assert

import (
"fmt"
"reflect"
"unsafe"
)

// DifferentAddressRanges asserts that two byte slices reference distinct memory
// address ranges, meaning they reference different underlying byte arrays.
func DifferentAddressRanges(t TestingT, a, b []byte) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
}

if len(a) == 0 || len(b) == 0 {
return true
}

// Find the start and end memory addresses for the underlying byte array for
// each input byte slice.
sliceAddrRange := func(b []byte) (uintptr, uintptr) {
sh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
return sh.Data, sh.Data + uintptr(sh.Cap-1)
}
aStart, aEnd := sliceAddrRange(a)
bStart, bEnd := sliceAddrRange(b)

// If "b" starts after "a" ends or "a" starts after "b" ends, there is no
// overlap.
if bStart > aEnd || aStart > bEnd {
return true
}

// Otherwise, calculate the overlap start and end and print the memory
// overlap error message.
min := func(a, b uintptr) uintptr {
if a < b {
return a
}
return b
}
max := func(a, b uintptr) uintptr {
if a > b {
return a
}
return b
}
overlapLow := max(aStart, bStart)
overlapHigh := min(aEnd, bEnd)

t.Errorf("Byte slices point to the same underlying byte array:\n"+
"\ta addresses:\t%d ... %d\n"+
"\tb addresses:\t%d ... %d\n"+
"\toverlap:\t%d ... %d",
aStart, aEnd,
bStart, bEnd,
overlapLow, overlapHigh)

return false
}

// EqualBSON asserts that the expected and actual BSON binary values are equal.
// If the values are not equal, it prints both the binary and Extended JSON diff
// of the BSON values. The provided BSON value types must implement the
// fmt.Stringer interface.
func EqualBSON(t TestingT, expected, actual interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}

return Equal(t,
expected,
actual,
`expected and actual BSON values do not match
As Extended JSON:
Expected: %s
Actual : %s`,
expected.(fmt.Stringer).String(),
actual.(fmt.Stringer).String())
}
125 changes: 125 additions & 0 deletions internal/assert/assertion_mongo_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package assert

import (
"testing"

"go.mongodb.org/mongo-driver/bson"
)

func TestDifferentAddressRanges(t *testing.T) {
t.Parallel()

slice := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

testCases := []struct {
name string
a []byte
b []byte
want bool
}{
{
name: "distinct byte slices",
a: []byte{0, 1, 2, 3},
b: []byte{0, 1, 2, 3},
want: true,
},
{
name: "same byte slice",
a: slice,
b: slice,
want: false,
},
{
name: "whole and subslice",
a: slice,
b: slice[:4],
want: false,
},
{
name: "two subslices",
a: slice[1:2],
b: slice[3:4],
want: false,
},
{
name: "empty",
a: []byte{0, 1, 2, 3},
b: []byte{},
want: true,
},
{
name: "nil",
a: []byte{0, 1, 2, 3},
b: nil,
want: true,
},
}

for _, tc := range testCases {
tc := tc // Capture range variable.

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got := DifferentAddressRanges(new(testing.T), tc.a, tc.b)
if got != tc.want {
t.Errorf("DifferentAddressRanges(%p, %p) = %v, want %v", tc.a, tc.b, got, tc.want)
}
})
}
}

func TestEqualBSON(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
expected interface{}
actual interface{}
want bool
}{
{
name: "equal bson.Raw",
expected: bson.Raw{5, 0, 0, 0, 0},
actual: bson.Raw{5, 0, 0, 0, 0},
want: true,
},
{
name: "different bson.Raw",
expected: bson.Raw{8, 0, 0, 0, 10, 120, 0, 0},
actual: bson.Raw{5, 0, 0, 0, 0},
want: false,
},
{
name: "invalid bson.Raw",
expected: bson.Raw{99, 99, 99, 99},
actual: bson.Raw{5, 0, 0, 0, 0},
want: false,
},
{
name: "nil bson.Raw",
expected: bson.Raw(nil),
actual: bson.Raw(nil),
want: true,
},
}

for _, tc := range testCases {
tc := tc // Capture range variable.

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

got := EqualBSON(new(testing.T), tc.expected, tc.actual)
if got != tc.want {
t.Errorf("EqualBSON(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want)
}
})
}
}
Loading

0 comments on commit 41ebbc3

Please sign in to comment.