Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed issue with not overriding of existing value of map #8

Merged
merged 2 commits into from
Apr 6, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func isExported(field reflect.StructField) bool {
// Traverses recursively both values, assigning src's fields values to dst.
// The map argument tracks comparisons that have already been seen, which allows
// short circuiting on recursive types.
func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err error) {
func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, overwrite bool) (err error) {
if dst.CanAddr() {
addr := dst.UnsafeAddr()
h := 17 * addr
Expand All @@ -57,7 +57,7 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err
}
fieldName := field.Name
fieldName = changeInitialCase(fieldName, unicode.ToLower)
if v, ok := dstMap[fieldName]; !ok || isEmptyValue(reflect.ValueOf(v)) {
if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v)) || overwrite) {
dstMap[fieldName] = src.Field(i).Interface()
}
}
Expand Down Expand Up @@ -89,12 +89,12 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err
continue
}
if srcKind == dstKind {
if err = deepMerge(dstElement, srcElement, visited, depth+1); err != nil {
if err = deepMerge(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
return
}
} else {
if srcKind == reflect.Map {
if err = deepMap(dstElement, srcElement, visited, depth+1); err != nil {
if err = deepMap(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
return
}
} else {
Expand All @@ -118,6 +118,14 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err
// This is separated method from Merge because it is cleaner and it keeps sane
// semantics: merging equal types, mapping different (restricted) types.
func Map(dst, src interface{}) error {
return _map(dst, src, false)
}

func MapWithOverwrite(dst, src interface{}) error {
return _map(dst, src, true)
}

func _map(dst, src interface{}, overwrite bool) error {
var (
vDst, vSrc reflect.Value
err error
Expand All @@ -128,7 +136,7 @@ func Map(dst, src interface{}) error {
// To be friction-less, we redirect equal-type arguments
// to deepMerge. Only because arguments can be anything.
if vSrc.Kind() == vDst.Kind() {
return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0)
return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite)
}
switch vSrc.Kind() {
case reflect.Struct:
Expand All @@ -142,5 +150,5 @@ func Map(dst, src interface{}) error {
default:
return ErrNotSupported
}
return deepMap(vDst, vSrc, make(map[uintptr]*visit), 0)
return deepMap(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite)
}
22 changes: 15 additions & 7 deletions merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// Traverses recursively both values, assigning src's fields values to dst.
// The map argument tracks comparisons that have already been seen, which allows
// short circuiting on recursive types.
func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err error) {
func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, overwrite bool) (err error) {
if !src.IsValid() {
return
}
Expand All @@ -35,7 +35,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e
switch dst.Kind() {
case reflect.Struct:
for i, n := 0, dst.NumField(); i < n; i++ {
if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1); err != nil {
if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1, overwrite); err != nil {
return
}
}
Expand All @@ -50,11 +50,11 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e
case reflect.Struct:
fallthrough
case reflect.Map:
if err = deepMerge(dstElement, srcElement, visited, depth+1); err != nil {
if err = deepMerge(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
return
}
}
if !dstElement.IsValid() {
if !isEmptyValue(srcElement) && (overwrite || !dstElement.IsValid()) {
dst.SetMapIndex(key, srcElement)
}
}
Expand All @@ -64,10 +64,10 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e
if src.IsNil() {
break
} else if dst.IsNil() {
if dst.CanSet() && isEmptyValue(dst) {
if dst.CanSet() && (isEmptyValue(dst) || overwrite) {
dst.Set(src)
}
} else if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1); err != nil {
} else if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, overwrite); err != nil {
return
}
default:
Expand All @@ -85,6 +85,14 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e
// It won't merge unexported (private) fields and will do recursively
// any exported field.
func Merge(dst, src interface{}) error {
return merge(dst, src, false)
}

func MergeWithOverwrite(dst, src interface{}) error {
return merge(dst, src, true)
}

func merge(dst, src interface{}, overwrite bool) error {
var (
vDst, vSrc reflect.Value
err error
Expand All @@ -95,5 +103,5 @@ func Merge(dst, src interface{}) error {
if vDst.Type() != vSrc.Type() {
return ErrDifferentArgumentsTypes
}
return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0)
return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite)
}
64 changes: 60 additions & 4 deletions mergo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ func TestComplexStruct(t *testing.T) {
}
}

func TestComplexStructWithOverwrite(t *testing.T) {
a := complexTest{simpleTest{1}, 1, "do-not-overwrite-with-empty-value"}
b := complexTest{simpleTest{42}, 2, ""}

expect := complexTest{simpleTest{42}, 1, "do-not-overwrite-with-empty-value"}
if err := MergeWithOverwrite(&a, b); err != nil {
t.FailNow()
}

if !reflect.DeepEqual(a, expect) {
t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", a, expect)
}
}

func TestPointerStruct(t *testing.T) {
s1 := simpleTest{}
s2 := simpleTest{19}
Expand Down Expand Up @@ -132,30 +146,72 @@ func TestSliceStruct(t *testing.T) {
}
}

func TestMapsWithOverwrite(t *testing.T) {
m := map[string]simpleTest{
"a": simpleTest{}, // overwritten by 16
"b": simpleTest{42}, // not overwritten by empty value
"c": simpleTest{13}, // overwritten by 12
"d": simpleTest{61},
}
n := map[string]simpleTest{
"a": simpleTest{16},
"b": simpleTest{},
"c": simpleTest{12},
"e": simpleTest{14},
}
expect := map[string]simpleTest{
"a": simpleTest{16},
"b": simpleTest{},
"c": simpleTest{12},
"d": simpleTest{61},
"e": simpleTest{14},
}

if err := MergeWithOverwrite(&m, n); err != nil {
t.Fatalf(err.Error())
}

if !reflect.DeepEqual(m, expect) {
t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect)
}
}

func TestMaps(t *testing.T) {
m := map[string]simpleTest{
"a": simpleTest{},
"b": simpleTest{42},
"c": simpleTest{13},
"d": simpleTest{61},
}
n := map[string]simpleTest{
"a": simpleTest{16},
"b": simpleTest{},
"c": simpleTest{12},
"e": simpleTest{14},
}
expect := map[string]simpleTest{
"a": simpleTest{0},
"b": simpleTest{42},
"c": simpleTest{13},
"d": simpleTest{61},
"e": simpleTest{14},
}

if err := Merge(&m, n); err != nil {
t.Fatalf(err.Error())
}
if len(m) != 3 {
t.Fatalf(`n not merged in m properly, m must have 3 elements instead of %d`, len(m))

if !reflect.DeepEqual(m, expect) {
t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect)
}
if m["a"].Value != 0 {
t.Fatalf(`n merged in m because I solved non-addressable map values TODO: m["a"].Value(%d) != n["a"].Value(%d)`, m["a"].Value, n["a"].Value)
}
if m["b"].Value != 42 {
t.Fatalf(`n wrongly merged in m: m["b"].Value(%d) != n["b"].Value(%d)`, m["b"].Value, n["b"].Value)
}
if m["c"].Value != 12 {
t.Fatalf(`n not merged in m: m["c"].Value(%d) != n["c"].Value(%d)`, m["c"].Value, n["c"].Value)
if m["c"].Value != 13 {
t.Fatalf(`n overwritten in m: m["c"].Value(%d) != n["c"].Value(%d)`, m["c"].Value, n["c"].Value)
}
}

Expand Down