|
22 | 22 | default_dtype, |
23 | 23 | expand_dims, |
24 | 24 | isclose, |
| 25 | + isin, |
25 | 26 | kron, |
26 | 27 | nan_to_num, |
27 | 28 | nunique, |
@@ -888,7 +889,7 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool): |
888 | 889 | b = xp.asarray([1e-9, 1e-4, xp.nan], device=device) |
889 | 890 | res = isclose(a, b, equal_nan=equal_nan) |
890 | 891 | assert get_device(res) == device |
891 | | - |
| 892 | + |
892 | 893 | def test_array_on_device_with_scalar(self, xp: ModuleType, device: Device): |
893 | 894 | a = xp.asarray([0.01, 0.5, 0.8, 0.9, 1.00001], device=device) |
894 | 895 | b = 1 |
@@ -1476,3 +1477,55 @@ def test_nd(self, xp: ModuleType, ndim: int): |
1476 | 1477 | @override |
1477 | 1478 | def test_input_validation(self, xp: ModuleType): |
1478 | 1479 | self._test_input_validation(xp) |
| 1480 | + |
| 1481 | + |
| 1482 | +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no unique_inverse") |
| 1483 | +class TestIsIn: |
| 1484 | + def test_simple(self, xp: ModuleType, library: Backend): |
| 1485 | + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): |
| 1486 | + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") |
| 1487 | + |
| 1488 | + b = xp.asarray([1, 2, 3, 4]) |
| 1489 | + |
| 1490 | + # `a` with 1 dimension |
| 1491 | + a = xp.asarray([1, 3, 6, 10]) |
| 1492 | + expected = xp.asarray([True, True, False, False]) |
| 1493 | + res = isin(a, b) |
| 1494 | + xp_assert_equal(res, expected) |
| 1495 | + |
| 1496 | + # `a` with 2 dimensions |
| 1497 | + a = xp.asarray([[0, 2], [4, 6]]) |
| 1498 | + expected = xp.asarray([[False, True], [True, False]]) |
| 1499 | + res = isin(a, b) |
| 1500 | + xp_assert_equal(res, expected) |
| 1501 | + |
| 1502 | + def test_device(self, xp: ModuleType, device: Device, library: Backend): |
| 1503 | + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): |
| 1504 | + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") |
| 1505 | + |
| 1506 | + a = xp.asarray([1, 3, 6], device=device) |
| 1507 | + b = xp.asarray([1, 2, 3], device=device) |
| 1508 | + assert get_device(isin(a, b)) == device |
| 1509 | + |
| 1510 | + def test_assume_unique_and_invert( |
| 1511 | + self, xp: ModuleType, device: Device, library: Backend |
| 1512 | + ): |
| 1513 | + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): |
| 1514 | + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") |
| 1515 | + |
| 1516 | + a = xp.asarray([0, 3, 6, 10], device=device) |
| 1517 | + b = xp.asarray([1, 2, 3, 10], device=device) |
| 1518 | + expected = xp.asarray([True, False, True, False]) |
| 1519 | + res = isin(a, b, assume_unique=True, invert=True) |
| 1520 | + assert get_device(res) == device |
| 1521 | + xp_assert_equal(res, expected) |
| 1522 | + |
| 1523 | + def test_kind(self, xp: ModuleType, library: Backend): |
| 1524 | + if library.like(Backend.NUMPY) and NUMPY_VERSION < (1, 24): |
| 1525 | + pytest.xfail("NumPy <1.24 has no kind kwarg in isin") |
| 1526 | + |
| 1527 | + a = xp.asarray([0, 3, 6, 10]) |
| 1528 | + b = xp.asarray([1, 2, 3, 10]) |
| 1529 | + expected = xp.asarray([False, True, False, True]) |
| 1530 | + res = isin(a, b, kind="sort") |
| 1531 | + xp_assert_equal(res, expected) |
0 commit comments