Skip to content

Commit a20c9ea

Browse files
Implment issue 28 with FLAG_IGNORE_SLICE_ORDER
1 parent 47c10a1 commit a20c9ea

File tree

2 files changed

+152
-17
lines changed

2 files changed

+152
-17
lines changed

deep.go

Lines changed: 88 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,23 @@ var (
5656
ErrNotHandled = errors.New("cannot compare the reflect.Kind")
5757
)
5858

59+
const (
60+
// FLAG_NONE is a placeholder for default Equal behavior. You don't have to
61+
// pass it to Equal; if you do, it does nothing.
62+
FLAG_NONE byte = iota
63+
64+
// FLAG_IGNORE_SLICE_ORDER causes Equal to ignore slice order and, instead,
65+
// compare value counts. For example, []{1, 2} and []{2, 1} are equal
66+
// because each value has the same count. But []{1, 2, 2} and []{1, 2}
67+
// are not equal because the first slice has two occurrences of value 2.
68+
FLAG_IGNORE_SLICE_ORDER
69+
)
70+
5971
type cmp struct {
6072
diff []string
6173
buff []string
6274
floatFormat string
75+
flag map[byte]bool
6376
}
6477

6578
var errorType = reflect.TypeOf((*error)(nil)).Elem()
@@ -74,13 +87,17 @@ var errorType = reflect.TypeOf((*error)(nil)).Elem()
7487
//
7588
// When comparing a struct, if a field has the tag `deep:"-"` then it will be
7689
// ignored.
77-
func Equal(a, b interface{}) []string {
90+
func Equal(a, b interface{}, flags ...interface{}) []string {
7891
aVal := reflect.ValueOf(a)
7992
bVal := reflect.ValueOf(b)
8093
c := &cmp{
8194
diff: []string{},
8295
buff: []string{},
8396
floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
97+
flag: map[byte]bool{},
98+
}
99+
for i := range flags {
100+
c.flag[flags[i].(byte)] = true
84101
}
85102
if a == nil && b == nil {
86103
return nil
@@ -339,29 +356,54 @@ func (c *cmp) equals(a, b reflect.Value, level int) {
339356
}
340357
}
341358

359+
// Equal if same underlying pointer and same length, this latter handles
360+
// foo := []int{1, 2, 3, 4}
361+
// a := foo[0:2] // == {1,2}
362+
// b := foo[2:4] // == {3,4}
363+
// a and b are same pointer but different slices (lengths) of the underlying
364+
// array, so not equal.
342365
aLen := a.Len()
343366
bLen := b.Len()
344-
345367
if a.Pointer() == b.Pointer() && aLen == bLen {
346368
return
347369
}
348370

349-
n := aLen
350-
if bLen > aLen {
351-
n = bLen
352-
}
353-
for i := 0; i < n; i++ {
354-
c.push(fmt.Sprintf("slice[%d]", i))
355-
if i < aLen && i < bLen {
356-
c.equals(a.Index(i), b.Index(i), level+1)
357-
} else if i < aLen {
358-
c.saveDiff(a.Index(i), "<no value>")
359-
} else {
360-
c.saveDiff("<no value>", b.Index(i))
371+
if c.flag[FLAG_IGNORE_SLICE_ORDER] {
372+
// Compare slices by value and value count; ignore order.
373+
// Value equality is impliclity established by the maps:
374+
// any value v1 will hash to the same map value if it's equal
375+
// to another value v2. Then equality is determiend by value
376+
// count: presuming v1==v2, then the slics are equal if there
377+
// are equal numbers of v1 in each slice.
378+
am := map[interface{}]int{}
379+
for i := 0; i < a.Len(); i++ {
380+
am[a.Index(i).Interface()] += 1
361381
}
362-
c.pop()
363-
if len(c.diff) >= MaxDiff {
364-
break
382+
bm := map[interface{}]int{}
383+
for i := 0; i < b.Len(); i++ {
384+
bm[b.Index(i).Interface()] += 1
385+
}
386+
c.cmpMapValueCounts(a, b, am, bm, true) // a cmp b
387+
c.cmpMapValueCounts(b, a, bm, am, false) // b cmp a
388+
} else {
389+
// Compare slices by order
390+
n := aLen
391+
if bLen > aLen {
392+
n = bLen
393+
}
394+
for i := 0; i < n; i++ {
395+
c.push(fmt.Sprintf("slice[%d]", i))
396+
if i < aLen && i < bLen {
397+
c.equals(a.Index(i), b.Index(i), level+1)
398+
} else if i < aLen {
399+
c.saveDiff(a.Index(i), "<no value>")
400+
} else {
401+
c.saveDiff("<no value>", b.Index(i))
402+
}
403+
c.pop()
404+
if len(c.diff) >= MaxDiff {
405+
break
406+
}
365407
}
366408
}
367409

@@ -435,6 +477,35 @@ func (c *cmp) saveDiff(aval, bval interface{}) {
435477
}
436478
}
437479

480+
func (c *cmp) cmpMapValueCounts(a, b reflect.Value, am, bm map[interface{}]int, a2b bool) {
481+
for v := range am {
482+
aCount, _ := am[v]
483+
bCount, _ := bm[v]
484+
485+
if aCount != bCount {
486+
c.push(fmt.Sprintf("(unordered) slice[]=%v: value count", v))
487+
if a2b {
488+
c.saveDiff(fmt.Sprintf("%d", aCount), fmt.Sprintf("%d", bCount))
489+
} else {
490+
c.saveDiff(fmt.Sprintf("%d", bCount), fmt.Sprintf("%d", aCount))
491+
}
492+
c.pop()
493+
}
494+
delete(am, v)
495+
delete(bm, v)
496+
497+
/*else {
498+
c.push(fmt.Sprintf("(unordered) slice[]=%v: value count", v))
499+
c.saveDiff(fmt.Sprintf("%d", aCount), "<no value>")
500+
} else {
501+
c.saveDiff("<no value>", fmt.Sprintf("%d occurrences", bCount))
502+
}
503+
c.pop()
504+
}
505+
*/
506+
}
507+
}
508+
438509
func logError(err error) {
439510
if LogErrors {
440511
log.Println(err)

deep_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,3 +1495,67 @@ func TestFunc(t *testing.T) {
14951495
t.Errorf("expected 0 diff, got %d: %s", len(diff), diff)
14961496
}
14971497
}
1498+
1499+
func TestSliceOrderString(t *testing.T) {
1500+
// https://github.com/go-test/deep/issues/28
1501+
1502+
// These are equal if we ignore order
1503+
a := []string{"foo", "bar"}
1504+
b := []string{"bar", "foo"}
1505+
diff := deep.Equal(a, b, deep.FLAG_IGNORE_SLICE_ORDER)
1506+
if len(diff) != 0 {
1507+
t.Fatalf("expected 0 diff, got %d: %s", len(diff), diff)
1508+
}
1509+
1510+
// Equal with dupes
1511+
a = []string{"foo", "foo", "bar"}
1512+
b = []string{"bar", "foo", "foo"}
1513+
diff = deep.Equal(a, b, deep.FLAG_IGNORE_SLICE_ORDER)
1514+
if len(diff) != 0 {
1515+
t.Fatalf("expected 0 diff, got %d: %s", len(diff), diff)
1516+
}
1517+
1518+
// NOT equal with dupes
1519+
a = []string{"foo", "foo", "bar"}
1520+
b = []string{"bar", "bar", "foo"}
1521+
diff = deep.Equal(a, b, deep.FLAG_IGNORE_SLICE_ORDER)
1522+
if len(diff) != 2 {
1523+
t.Fatalf("expected 2 diff, got %d: %s", len(diff), diff)
1524+
}
1525+
m1 := "(unordered) slice[]=foo: value count: 2 != 1"
1526+
m2 := "(unordered) slice[]=bar: value count: 1 != 2"
1527+
if diff[0] != m1 && diff[0] != m2 {
1528+
t.Errorf("got %s, expected '%s' or '%s'", diff[0], m1, m2)
1529+
}
1530+
if diff[1] != m1 && diff[1] != m2 {
1531+
t.Errorf("got %s, expected '%s' or '%s'", diff[1], m1, m2)
1532+
}
1533+
1534+
// NOT equal with one missing
1535+
a = []string{"foo", "bar"}
1536+
b = []string{"bar", "foo", "gone"}
1537+
diff = deep.Equal(a, b, deep.FLAG_IGNORE_SLICE_ORDER)
1538+
if len(diff) != 1 {
1539+
t.Fatalf("expected 2 diff, got %d: %s", len(diff), diff)
1540+
}
1541+
if diff[0] != "(unordered) slice[]=gone: value count: 0 != 1" {
1542+
t.Errorf("got %s, expected ''", diff[0])
1543+
}
1544+
1545+
// NOT equal at all
1546+
a = []string{"foo", "bar"}
1547+
b = []string{"x"}
1548+
diff = deep.Equal(a, b, deep.FLAG_IGNORE_SLICE_ORDER)
1549+
if len(diff) != 3 {
1550+
t.Fatalf("expected 2 diff, got %d: %s", len(diff), diff)
1551+
}
1552+
if diff[0] != "(unordered) slice[]=foo: value count: 1 != 0" {
1553+
t.Errorf("got %s, expected '(unordered) slice[]=foo: value count: 1 != 0", diff[0])
1554+
}
1555+
if diff[1] != "(unordered) slice[]=bar: value count: 1 != 0" {
1556+
t.Errorf("got %s, expected '(unordered) slice[]=bar: value count: 1 != 0'", diff[1])
1557+
}
1558+
if diff[2] != "(unordered) slice[]=x: value count: 0 != 1" {
1559+
t.Errorf("got %s, expected '(unordered) slice[]=x: value count: 0 != 1'", diff[2])
1560+
}
1561+
}

0 commit comments

Comments
 (0)