Skip to content

Commit 4403578

Browse files
authored
diff: detect cycles in the values being compared. (#64)
This prevents infinite recursion, similar to the fixes to the formatter in #13. In addition to detecting cycles, we also check that the two values contain the same cyclic structure.
1 parent ead4522 commit 4403578

File tree

2 files changed

+87
-13
lines changed

2 files changed

+87
-13
lines changed

diff.go

+31-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@ type Printfer interface {
4141
// It calls Printf once for each difference, with no trailing newline.
4242
// The standard library log.Logger is a Printfer.
4343
func Pdiff(p Printfer, a, b interface{}) {
44-
diffPrinter{w: p}.diff(reflect.ValueOf(a), reflect.ValueOf(b))
44+
d := diffPrinter{
45+
w: p,
46+
aVisited: make(map[visit]visit),
47+
bVisited: make(map[visit]visit),
48+
}
49+
d.diff(reflect.ValueOf(a), reflect.ValueOf(b))
4550
}
4651

4752
type Logfer interface {
@@ -66,6 +71,9 @@ func Ldiff(l Logfer, a, b interface{}) {
6671
type diffPrinter struct {
6772
w Printfer
6873
l string // label
74+
75+
aVisited map[visit]visit
76+
bVisited map[visit]visit
6977
}
7078

7179
func (w diffPrinter) printf(f string, a ...interface{}) {
@@ -96,6 +104,28 @@ func (w diffPrinter) diff(av, bv reflect.Value) {
96104
return
97105
}
98106

107+
if av.CanAddr() && bv.CanAddr() {
108+
avis := visit{av.UnsafeAddr(), at}
109+
bvis := visit{bv.UnsafeAddr(), bt}
110+
var cycle bool
111+
112+
// Have we seen this value before?
113+
if vis, ok := w.aVisited[avis]; ok {
114+
cycle = true
115+
if vis != bvis {
116+
w.printf("%# v (previously visited) != %# v", formatter{v: av, quote: true}, formatter{v: bv, quote: true})
117+
}
118+
} else if _, ok := w.bVisited[bvis]; ok {
119+
cycle = true
120+
w.printf("%# v != %# v (previously visited)", formatter{v: av, quote: true}, formatter{v: bv, quote: true})
121+
}
122+
w.aVisited[avis] = bvis
123+
w.bVisited[bvis] = avis
124+
if cycle {
125+
return
126+
}
127+
}
128+
99129
switch kind := at.Kind(); kind {
100130
case reflect.Bool:
101131
if a, b := av.Bool(), bv.Bool(); a != b {

diff_test.go

+56-12
Original file line numberDiff line numberDiff line change
@@ -130,20 +130,23 @@ var diffs = []difftest{
130130

131131
func TestDiff(t *testing.T) {
132132
for _, tt := range diffs {
133-
got := Diff(tt.a, tt.b)
134-
eq := len(got) == len(tt.exp)
135-
if eq {
136-
for i := range got {
137-
eq = eq && got[i] == tt.exp[i]
138-
}
139-
}
140-
if !eq {
141-
t.Errorf("diffing % #v", tt.a)
142-
t.Errorf("with % #v", tt.b)
143-
diffdiff(t, got, tt.exp)
144-
continue
133+
expectDiffOutput(t, tt.a, tt.b, tt.exp)
134+
}
135+
}
136+
137+
func expectDiffOutput(t *testing.T, a, b interface{}, exp []string) {
138+
got := Diff(a, b)
139+
eq := len(got) == len(exp)
140+
if eq {
141+
for i := range got {
142+
eq = eq && got[i] == exp[i]
145143
}
146144
}
145+
if !eq {
146+
t.Errorf("diffing % #v", a)
147+
t.Errorf("with % #v", b)
148+
diffdiff(t, got, exp)
149+
}
147150
}
148151

149152
func TestKeyEqual(t *testing.T) {
@@ -193,6 +196,47 @@ func TestFdiff(t *testing.T) {
193196
}
194197
}
195198

199+
func TestDiffCycle(t *testing.T) {
200+
// Diff two cyclic structs
201+
a := &I{i: 1, R: nil}
202+
a.R = a
203+
b := &I{i: 2, R: nil}
204+
b.R = b
205+
expectDiffOutput(t, a, b, []string{
206+
`i: 1 != 2`,
207+
})
208+
209+
// Diff two equal cyclic structs
210+
b.i = 1
211+
expectDiffOutput(t, a, b, []string{})
212+
213+
// Diff two structs with different cycles
214+
b2 := &I{i: 1, R: b}
215+
b.R = b2
216+
expectDiffOutput(t, a, b, []string{`R: pretty.I{
217+
i: 1,
218+
R: &pretty.I{(CYCLIC REFERENCE)},
219+
} (previously visited) != pretty.I{
220+
i: 1,
221+
R: &pretty.I{
222+
i: 1,
223+
R: &pretty.I{(CYCLIC REFERENCE)},
224+
},
225+
}`})
226+
227+
// ... and the same in the other direction
228+
expectDiffOutput(t, b, a, []string{`R: pretty.I{
229+
i: 1,
230+
R: &pretty.I{
231+
i: 1,
232+
R: &pretty.I{(CYCLIC REFERENCE)},
233+
},
234+
} != pretty.I{
235+
i: 1,
236+
R: &pretty.I{(CYCLIC REFERENCE)},
237+
} (previously visited)`})
238+
}
239+
196240
func diffdiff(t *testing.T, got, exp []string) {
197241
minus(t, "unexpected:", got, exp)
198242
minus(t, "missing:", exp, got)

0 commit comments

Comments
 (0)