3939)
4040from array_api_extra ._lib ._backends import NUMPY_VERSION , Backend
4141from array_api_extra ._lib ._funcs import searchsorted as _funcs_searchsorted
42- from array_api_extra ._lib ._testing import xfail , xp_assert_close , xp_assert_equal
42+ from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
4343from array_api_extra ._lib ._utils ._compat import (
4444 array_namespace ,
45- is_jax_namespace ,
4645 is_torch_namespace ,
4746)
4847from array_api_extra ._lib ._utils ._compat import device as get_device
@@ -1399,7 +1398,6 @@ def test_assume_unique(self, xp: ModuleType):
13991398 @pytest .mark .parametrize ("shape2" , [(), (1 ,), (1 , 1 )])
14001399 def test_shapes (
14011400 self ,
1402- request : pytest .FixtureRequest ,
14031401 assume_unique : bool ,
14041402 shape1 : tuple [int , ...],
14051403 shape2 : tuple [int , ...],
@@ -1408,26 +1406,18 @@ def test_shapes(
14081406 x1 = xp .zeros (shape1 )
14091407 x2 = xp .zeros (shape2 )
14101408
1411- if is_jax_namespace (xp ) and assume_unique and shape1 != (1 ,):
1412- xfail (request = request , reason = "jax#32335 fixed with jax>=0.8.0" )
1413-
14141409 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
14151410 xp_assert_equal (actual , xp .empty ((0 ,)))
14161411
14171412 @assume_unique
14181413 @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
1419- def test_python_scalar (
1420- self , request : pytest .FixtureRequest , xp : ModuleType , assume_unique : bool
1421- ):
1414+ def test_python_scalar (self , xp : ModuleType , assume_unique : bool ):
14221415 # Test no dtype promotion to xp.asarray(x2); use x1.dtype
14231416 x1 = xp .asarray ([3 , 1 , 2 ], dtype = xp .int16 )
14241417 x2 = 3
14251418 actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
14261419 xp_assert_equal (actual , xp .asarray ([1 , 2 ], dtype = xp .int16 ))
14271420
1428- if is_jax_namespace (xp ) and assume_unique :
1429- xfail (request = request , reason = "jax#32335 fixed with jax>=0.8.0" )
1430-
14311421 actual = setdiff1d (x2 , x1 , assume_unique = assume_unique )
14321422 xp_assert_equal (actual , xp .asarray ([], dtype = xp .int16 ))
14331423
0 commit comments