Skip to content

Commit fe46cd6

Browse files
committed
BUG assert_almost_equal fails on subclasses that cannot handle bool
numpygh-8410 breaks a large number of astropy tests, because it sets up a boolean array for values that should actually be compared (i.e., are not `nan` or `inf`) using `zeros_like`. The latter means that for subclasses, the boolean test array is not a plain `ndarray` but the subclass. But for astropy's `Quantity`, the `all` method is undefined. This commit ensures the test arrays from `isinf` and `isnan` are used directly.
1 parent 83fe06d commit fe46cd6

2 files changed

Lines changed: 40 additions & 17 deletions

File tree

numpy/testing/tests/test_utils.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,24 @@ def test_subclass(self):
299299
a = np.array([[1., 2.], [3., 4.]])
300300
b = np.ma.masked_array([[1., 2.], [0., 4.]],
301301
[[False, False], [True, False]])
302-
assert_array_almost_equal(a, b)
303-
assert_array_almost_equal(b, a)
304-
assert_array_almost_equal(b, b)
302+
self._assert_func(a, b)
303+
self._assert_func(b, a)
304+
self._assert_func(b, b)
305+
306+
def test_subclass_that_cannot_be_bool(self):
307+
# While we cannot guarantee testing functions will always work for
308+
# subclasses, the tests should ideally rely only on subclasses having
309+
# comparison operators, not on them being able to store booleans
310+
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
311+
class MyArray(np.ndarray):
312+
def __lt__(self, other):
313+
return super(MyArray, self).__lt__(other).view(np.ndarray)
314+
315+
def all(self, *args, **kwargs):
316+
raise NotImplementedError
317+
318+
a = np.array([1., 2.]).view(MyArray)
319+
self._assert_func(a, a)
305320

306321

307322
class TestAlmostEqual(_GenericTest, unittest.TestCase):
@@ -387,6 +402,21 @@ def test_error_message(self):
387402
# remove anything that's not the array string
388403
self.assertEqual(str(e).split('%)\n ')[1], b)
389404

405+
def test_subclass_that_cannot_be_bool(self):
406+
# While we cannot guarantee testing functions will always work for
407+
# subclasses, the tests should ideally rely only on subclasses having
408+
# comparison operators, not on them being able to store booleans
409+
# (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
410+
class MyArray(np.ndarray):
411+
def __lt__(self, other):
412+
return super(MyArray, self).__lt__(other).view(np.ndarray)
413+
414+
def all(self, *args, **kwargs):
415+
raise NotImplementedError
416+
417+
a = np.array([1., 2.]).view(MyArray)
418+
self._assert_func(a, a)
419+
390420

391421
class TestApproxEqual(unittest.TestCase):
392422

numpy/testing/utils.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -669,8 +669,7 @@ def assert_array_compare(comparison, x, y, err_msg='', verbose=True,
669669
header='', precision=6, equal_nan=True,
670670
equal_inf=True):
671671
__tracebackhide__ = True # Hide traceback for py.test
672-
from numpy.core import array, isnan, isinf, any, all, inf, zeros_like
673-
from numpy.core.numerictypes import bool_
672+
from numpy.core import array, isnan, isinf, any, inf
674673
x = array(x, copy=False, subok=True)
675674
y = array(y, copy=False, subok=True)
676675

@@ -726,14 +725,13 @@ def chk_same_position(x_id, y_id, hasval='nan'):
726725
raise AssertionError(msg)
727726

728727
if isnumber(x) and isnumber(y):
729-
x_id, y_id = zeros_like(x, dtype=bool_), zeros_like(y, dtype=bool_)
730728
if equal_nan:
731729
x_isnan, y_isnan = isnan(x), isnan(y)
732730
# Validate that NaNs are in the same place
733731
if any(x_isnan) or any(y_isnan):
734732
chk_same_position(x_isnan, y_isnan, hasval='nan')
735-
x_id |= x_isnan
736-
y_id |= y_isnan
733+
x = x[~x_isnan]
734+
y = y[~y_isnan]
737735

738736
if equal_inf:
739737
x_isinf, y_isinf = isinf(x), isinf(y)
@@ -742,19 +740,14 @@ def chk_same_position(x_id, y_id, hasval='nan'):
742740
# Check +inf and -inf separately, since they are different
743741
chk_same_position(x == +inf, y == +inf, hasval='+inf')
744742
chk_same_position(x == -inf, y == -inf, hasval='-inf')
745-
x_id |= x_isinf
746-
y_id |= y_isinf
743+
x = x[~x_isinf]
744+
y = y[~y_isinf]
747745

748746
# Only do the comparison if actual values are left
749-
if all(x_id):
747+
if x.size == 0:
750748
return
751749

752-
if any(x_id):
753-
val = safe_comparison(x[~x_id], y[~y_id])
754-
else:
755-
val = safe_comparison(x, y)
756-
else:
757-
val = safe_comparison(x, y)
750+
val = safe_comparison(x, y)
758751

759752
if isinstance(val, bool):
760753
cond = val

0 commit comments

Comments
 (0)