Skip to content

Commit b8eb632

Browse files
committed
TST: re-enable xpassing JAX tests
1 parent d91d01e commit b8eb632

1 file changed

Lines changed: 2 additions & 12 deletions

File tree

tests/test_funcs.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@
3939
)
4040
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
4141
from array_api_extra._lib._funcs import searchsorted as _funcs_searchsorted
42-
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
42+
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
4343
from array_api_extra._lib._utils._compat import (
4444
array_namespace,
45-
is_jax_namespace,
4645
is_torch_namespace,
4746
)
4847
from array_api_extra._lib._utils._compat import device as get_device
@@ -1399,7 +1398,6 @@ def test_assume_unique(self, xp: ModuleType):
13991398
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
14001399
def test_shapes(
14011400
self,
1402-
request: pytest.FixtureRequest,
14031401
assume_unique: bool,
14041402
shape1: tuple[int, ...],
14051403
shape2: tuple[int, ...],
@@ -1408,26 +1406,18 @@ def test_shapes(
14081406
x1 = xp.zeros(shape1)
14091407
x2 = xp.zeros(shape2)
14101408

1411-
if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
1412-
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")
1413-
14141409
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
14151410
xp_assert_equal(actual, xp.empty((0,)))
14161411

14171412
@assume_unique
14181413
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
1419-
def test_python_scalar(
1420-
self, request: pytest.FixtureRequest, xp: ModuleType, assume_unique: bool
1421-
):
1414+
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
14221415
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
14231416
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
14241417
x2 = 3
14251418
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
14261419
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
14271420

1428-
if is_jax_namespace(xp) and assume_unique:
1429-
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")
1430-
14311421
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
14321422
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))
14331423

0 commit comments

Comments
 (0)