Skip to content

Commit bdef2b8

Browse files
authored
TST: re-enable xpassing JAX tests (#672)
1 parent 865f716 commit bdef2b8

1 file changed

Lines changed: 2 additions & 14 deletions

File tree

tests/test_funcs.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@
3939
)
4040
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
4141
from array_api_extra._lib._funcs import searchsorted as _funcs_searchsorted
42-
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
42+
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
4343
from array_api_extra._lib._utils._compat import (
4444
array_namespace,
45-
is_jax_namespace,
4645
is_torch_namespace,
4746
)
4847
from array_api_extra._lib._utils._compat import device as get_device
@@ -558,8 +557,6 @@ def test_complex(self, xp: ModuleType):
558557
expect = xp.asarray([[1.0, -1.0j], [1.0j, 1.0]], dtype=xp.complex128)
559558
xp_assert_close(actual, expect)
560559

561-
@pytest.mark.xfail_xp_backend(Backend.JAX_GPU, reason="jax#32296")
562-
@pytest.mark.xfail_xp_backend(Backend.JAX, reason="jax#32296")
563560
def test_empty(self, xp: ModuleType):
564561
with warnings.catch_warnings(record=True):
565562
warnings.simplefilter("always", RuntimeWarning)
@@ -1399,7 +1396,6 @@ def test_assume_unique(self, xp: ModuleType):
13991396
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
14001397
def test_shapes(
14011398
self,
1402-
request: pytest.FixtureRequest,
14031399
assume_unique: bool,
14041400
shape1: tuple[int, ...],
14051401
shape2: tuple[int, ...],
@@ -1408,26 +1404,18 @@ def test_shapes(
14081404
x1 = xp.zeros(shape1)
14091405
x2 = xp.zeros(shape2)
14101406

1411-
if is_jax_namespace(xp) and assume_unique and shape1 != (1,):
1412-
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")
1413-
14141407
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
14151408
xp_assert_equal(actual, xp.empty((0,)))
14161409

14171410
@assume_unique
14181411
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
1419-
def test_python_scalar(
1420-
self, request: pytest.FixtureRequest, xp: ModuleType, assume_unique: bool
1421-
):
1412+
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
14221413
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
14231414
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
14241415
x2 = 3
14251416
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
14261417
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
14271418

1428-
if is_jax_namespace(xp) and assume_unique:
1429-
xfail(request=request, reason="jax#32335 fixed with jax>=0.8.0")
1430-
14311419
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
14321420
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))
14331421

0 commit comments

Comments
 (0)