Skip to content

Commit 5911e38

Browse files
Merge pull request #1531 from arjunmahishi/fix-equal-values
assert: Fix EqualValues to handle overflow/underflow
2 parents d25ac14 + 4c4d011 commit 5911e38

File tree

2 files changed

+61
-20
lines changed

2 files changed

+61
-20
lines changed

assert/assertions.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,40 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
165165
return true
166166
}
167167

168-
actualType := reflect.TypeOf(actual)
169-
if actualType == nil {
168+
expectedValue := reflect.ValueOf(expected)
169+
actualValue := reflect.ValueOf(actual)
170+
if !expectedValue.IsValid() || !actualValue.IsValid() {
170171
return false
171172
}
172-
expectedValue := reflect.ValueOf(expected)
173-
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
173+
174+
expectedType := expectedValue.Type()
175+
actualType := actualValue.Type()
176+
if !expectedType.ConvertibleTo(actualType) {
177+
return false
178+
}
179+
180+
if !isNumericType(expectedType) || !isNumericType(actualType) {
174181
// Attempt comparison after type conversion
175-
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
182+
return reflect.DeepEqual(
183+
expectedValue.Convert(actualType).Interface(), actual,
184+
)
176185
}
177186

178-
return false
187+
// If BOTH values are numeric, there are chances of false positives due
188+
// to overflow or underflow. So, we need to make sure to always convert
189+
// the smaller type to a larger type before comparing.
190+
if expectedType.Size() >= actualType.Size() {
191+
return actualValue.Convert(expectedType).Interface() == expected
192+
}
193+
194+
return expectedValue.Convert(actualType).Interface() == actual
195+
}
196+
197+
// isNumericType returns true if the type is one of:
198+
// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64,
199+
// float32, float64, complex64, complex128
200+
func isNumericType(t reflect.Type) bool {
201+
return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128
179202
}
180203

181204
/* CallerInfo is necessary because the assert functions use the testing object

assert/assertions_test.go

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -135,24 +135,42 @@ func TestObjectsAreEqual(t *testing.T) {
135135

136136
})
137137
}
138+
}
138139

139-
// Cases where type differ but values are equal
140-
if !ObjectsAreEqualValues(uint32(10), int32(10)) {
141-
t.Error("ObjectsAreEqualValues should return true")
142-
}
143-
if ObjectsAreEqualValues(0, nil) {
144-
t.Fail()
145-
}
146-
if ObjectsAreEqualValues(nil, 0) {
147-
t.Fail()
148-
}
140+
func TestObjectsAreEqualValues(t *testing.T) {
141+
now := time.Now()
149142

150-
tm := time.Now()
151-
tz := tm.In(time.Local)
152-
if !ObjectsAreEqualValues(tm, tz) {
153-
t.Error("ObjectsAreEqualValues should return true for time.Time objects with different time zones")
143+
cases := []struct {
144+
expected interface{}
145+
actual interface{}
146+
result bool
147+
}{
148+
{uint32(10), int32(10), true},
149+
{0, nil, false},
150+
{nil, 0, false},
151+
{now, now.In(time.Local), true}, // should be time zone independent
152+
{int(270), int8(14), false}, // should handle overflow/underflow
153+
{int8(14), int(270), false},
154+
{[]int{270, 270}, []int8{14, 14}, false},
155+
{complex128(1e+100 + 1e+100i), complex64(complex(math.Inf(0), math.Inf(0))), false},
156+
{complex64(complex(math.Inf(0), math.Inf(0))), complex128(1e+100 + 1e+100i), false},
157+
{complex128(1e+100 + 1e+100i), 270, false},
158+
{270, complex128(1e+100 + 1e+100i), false},
159+
{complex128(1e+100 + 1e+100i), 3.14, false},
160+
{3.14, complex128(1e+100 + 1e+100i), false},
161+
{complex128(1e+10 + 1e+10i), complex64(1e+10 + 1e+10i), true},
162+
{complex64(1e+10 + 1e+10i), complex128(1e+10 + 1e+10i), true},
154163
}
155164

165+
for _, c := range cases {
166+
t.Run(fmt.Sprintf("ObjectsAreEqualValues(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
167+
res := ObjectsAreEqualValues(c.expected, c.actual)
168+
169+
if res != c.result {
170+
t.Errorf("ObjectsAreEqualValues(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
171+
}
172+
})
173+
}
156174
}
157175

158176
type Nested struct {

0 commit comments

Comments
 (0)