Skip to content

Commit

Permalink
Merge pull request #8 from dutchcoders/master
Browse files Browse the repository at this point in the history
fixed issue with not overriding of existing value of map
  • Loading branch information
darccio committed Apr 6, 2015
2 parents 67b9c0a + b6ee4c7 commit 2f8eb1d
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 17 deletions.
20 changes: 14 additions & 6 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func isExported(field reflect.StructField) bool {
// Traverses recursively both values, assigning src's fields values to dst.
// The map argument tracks comparisons that have already been seen, which allows
// short circuiting on recursive types.
func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err error) {
func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int, overwrite bool) (err error) {
if dst.CanAddr() {
addr := dst.UnsafeAddr()
h := 17 * addr
Expand All @@ -57,7 +57,7 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err
}
fieldName := field.Name
fieldName = changeInitialCase(fieldName, unicode.ToLower)
if v, ok := dstMap[fieldName]; !ok || isEmptyValue(reflect.ValueOf(v)) {
if v, ok := dstMap[fieldName]; !ok || (isEmptyValue(reflect.ValueOf(v)) || overwrite) {
dstMap[fieldName] = src.Field(i).Interface()
}
}
Expand Down Expand Up @@ -89,12 +89,12 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err
continue
}
if srcKind == dstKind {
if err = deepMerge(dstElement, srcElement, visited, depth+1); err != nil {
if err = deepMerge(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
return
}
} else {
if srcKind == reflect.Map {
if err = deepMap(dstElement, srcElement, visited, depth+1); err != nil {
if err = deepMap(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
return
}
} else {
Expand All @@ -118,6 +118,14 @@ func deepMap(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err
// This is separated method from Merge because it is cleaner and it keeps sane
// semantics: merging equal types, mapping different (restricted) types.
func Map(dst, src interface{}) error {
return _map(dst, src, false)
}

func MapWithOverwrite(dst, src interface{}) error {
return _map(dst, src, true)
}

func _map(dst, src interface{}, overwrite bool) error {
var (
vDst, vSrc reflect.Value
err error
Expand All @@ -128,7 +136,7 @@ func Map(dst, src interface{}) error {
// To be friction-less, we redirect equal-type arguments
// to deepMerge. Only because arguments can be anything.
if vSrc.Kind() == vDst.Kind() {
return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0)
return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite)
}
switch vSrc.Kind() {
case reflect.Struct:
Expand All @@ -142,5 +150,5 @@ func Map(dst, src interface{}) error {
default:
return ErrNotSupported
}
return deepMap(vDst, vSrc, make(map[uintptr]*visit), 0)
return deepMap(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite)
}
22 changes: 15 additions & 7 deletions merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// Traverses recursively both values, assigning src's fields values to dst.
// The map argument tracks comparisons that have already been seen, which allows
// short circuiting on recursive types.
func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (err error) {
func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int, overwrite bool) (err error) {
if !src.IsValid() {
return
}
Expand All @@ -35,7 +35,7 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e
switch dst.Kind() {
case reflect.Struct:
for i, n := 0, dst.NumField(); i < n; i++ {
if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1); err != nil {
if err = deepMerge(dst.Field(i), src.Field(i), visited, depth+1, overwrite); err != nil {
return
}
}
Expand All @@ -50,11 +50,11 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e
case reflect.Struct:
fallthrough
case reflect.Map:
if err = deepMerge(dstElement, srcElement, visited, depth+1); err != nil {
if err = deepMerge(dstElement, srcElement, visited, depth+1, overwrite); err != nil {
return
}
}
if !dstElement.IsValid() {
if !isEmptyValue(srcElement) && (overwrite || !dstElement.IsValid()) {
dst.SetMapIndex(key, srcElement)
}
}
Expand All @@ -64,10 +64,10 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e
if src.IsNil() {
break
} else if dst.IsNil() {
if dst.CanSet() && isEmptyValue(dst) {
if dst.CanSet() && (isEmptyValue(dst) || overwrite) {
dst.Set(src)
}
} else if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1); err != nil {
} else if err = deepMerge(dst.Elem(), src.Elem(), visited, depth+1, overwrite); err != nil {
return
}
default:
Expand All @@ -85,6 +85,14 @@ func deepMerge(dst, src reflect.Value, visited map[uintptr]*visit, depth int) (e
// It won't merge unexported (private) fields and will do recursively
// any exported field.
func Merge(dst, src interface{}) error {
return merge(dst, src, false)
}

func MergeWithOverwrite(dst, src interface{}) error {
return merge(dst, src, true)
}

func merge(dst, src interface{}, overwrite bool) error {
var (
vDst, vSrc reflect.Value
err error
Expand All @@ -95,5 +103,5 @@ func Merge(dst, src interface{}) error {
if vDst.Type() != vSrc.Type() {
return ErrDifferentArgumentsTypes
}
return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0)
return deepMerge(vDst, vSrc, make(map[uintptr]*visit), 0, overwrite)
}
64 changes: 60 additions & 4 deletions mergo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ func TestComplexStruct(t *testing.T) {
}
}

func TestComplexStructWithOverwrite(t *testing.T) {
a := complexTest{simpleTest{1}, 1, "do-not-overwrite-with-empty-value"}
b := complexTest{simpleTest{42}, 2, ""}

expect := complexTest{simpleTest{42}, 1, "do-not-overwrite-with-empty-value"}
if err := MergeWithOverwrite(&a, b); err != nil {
t.FailNow()
}

if !reflect.DeepEqual(a, expect) {
t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", a, expect)
}
}

func TestPointerStruct(t *testing.T) {
s1 := simpleTest{}
s2 := simpleTest{19}
Expand Down Expand Up @@ -132,30 +146,72 @@ func TestSliceStruct(t *testing.T) {
}
}

func TestMapsWithOverwrite(t *testing.T) {
m := map[string]simpleTest{
"a": simpleTest{}, // overwritten by 16
"b": simpleTest{42}, // not overwritten by empty value
"c": simpleTest{13}, // overwritten by 12
"d": simpleTest{61},
}
n := map[string]simpleTest{
"a": simpleTest{16},
"b": simpleTest{},
"c": simpleTest{12},
"e": simpleTest{14},
}
expect := map[string]simpleTest{
"a": simpleTest{16},
"b": simpleTest{},
"c": simpleTest{12},
"d": simpleTest{61},
"e": simpleTest{14},
}

if err := MergeWithOverwrite(&m, n); err != nil {
t.Fatalf(err.Error())
}

if !reflect.DeepEqual(m, expect) {
t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect)
}
}

func TestMaps(t *testing.T) {
m := map[string]simpleTest{
"a": simpleTest{},
"b": simpleTest{42},
"c": simpleTest{13},
"d": simpleTest{61},
}
n := map[string]simpleTest{
"a": simpleTest{16},
"b": simpleTest{},
"c": simpleTest{12},
"e": simpleTest{14},
}
expect := map[string]simpleTest{
"a": simpleTest{0},
"b": simpleTest{42},
"c": simpleTest{13},
"d": simpleTest{61},
"e": simpleTest{14},
}

if err := Merge(&m, n); err != nil {
t.Fatalf(err.Error())
}
if len(m) != 3 {
t.Fatalf(`n not merged in m properly, m must have 3 elements instead of %d`, len(m))

if !reflect.DeepEqual(m, expect) {
t.Fatalf("Test failed:\ngot :\n%#v\n\nwant :\n%#v\n\n", m, expect)
}
if m["a"].Value != 0 {
t.Fatalf(`n merged in m because I solved non-addressable map values TODO: m["a"].Value(%d) != n["a"].Value(%d)`, m["a"].Value, n["a"].Value)
}
if m["b"].Value != 42 {
t.Fatalf(`n wrongly merged in m: m["b"].Value(%d) != n["b"].Value(%d)`, m["b"].Value, n["b"].Value)
}
if m["c"].Value != 12 {
t.Fatalf(`n not merged in m: m["c"].Value(%d) != n["c"].Value(%d)`, m["c"].Value, n["c"].Value)
if m["c"].Value != 13 {
t.Fatalf(`n overwritten in m: m["c"].Value(%d) != n["c"].Value(%d)`, m["c"].Value, n["c"].Value)
}
}

Expand Down

0 comments on commit 2f8eb1d

Please sign in to comment.