2323from array_api_extra ._lib import Backend
2424from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
2525from array_api_extra ._lib ._utils ._compat import device as get_device
26+ from array_api_extra ._lib ._utils ._helpers import ndindex
2627from array_api_extra ._lib ._utils ._typing import Array , Device
2728from array_api_extra .testing import lazy_xp_function
2829
@@ -221,7 +222,7 @@ def test_xp(self, xp: ModuleType):
221222
222223class TestCreateDiagonal :
223224 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
224- def test_1d (self , xp : ModuleType ):
225+ def test_1d_from_numpy (self , xp : ModuleType ):
225226 # from np.diag tests
226227 vals = 100 * xp .arange (5 , dtype = xp .float64 )
227228 b = xp .zeros ((5 , 5 ), dtype = xp .float64 )
@@ -239,7 +240,7 @@ def test_1d(self, xp: ModuleType):
239240 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
240241 @pytest .mark .parametrize ("n" , range (1 , 10 ))
241242 @pytest .mark .parametrize ("offset" , range (1 , 10 ))
242- def test_create_diagonal (self , xp : ModuleType , n : int , offset : int ):
243+ def test_1d_from_scipy (self , xp : ModuleType , n : int , offset : int ):
243244 # from scipy._lib tests
244245 rng = np .random .default_rng (2347823 )
245246 one = xp .asarray (1.0 )
@@ -248,13 +249,35 @@ def test_create_diagonal(self, xp: ModuleType, n: int, offset: int):
248249 B = xp .asarray (np .diag (x , offset ), dtype = one .dtype )
249250 xp_assert_equal (A , B )
250251
251- def test_0d (self , xp : ModuleType ):
252+ def test_0d_raises (self , xp : ModuleType ):
252253 with pytest .raises (ValueError , match = "1-dimensional" ):
253254 create_diagonal (xp .asarray (1 ))
254255
255- def test_2d (self , xp : ModuleType ):
256- with pytest .raises (ValueError , match = "1-dimensional" ):
257- create_diagonal (xp .asarray ([[1 ]]))
256+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
257+ @pytest .mark .parametrize (
258+ "shape" ,
259+ [
260+ (0 ,),
261+ (10 ,),
262+ (0 , 1 ),
263+ (1 , 0 ),
264+ (0 , 0 ),
265+ (4 , 2 , 1 ),
266+ (1 , 1 , 7 ),
267+ (0 , 0 , 1 ),
268+ (3 , 2 , 4 , 5 ),
269+ ],
270+ )
271+ def test_nd (self , xp : ModuleType , shape : tuple [int , ...]):
272+ rng = np .random .default_rng (2347823 )
273+ b = xp .asarray (
274+ rng .integers ((1 << 64 ) - 1 , size = shape , dtype = np .uint64 ), dtype = xp .uint64
275+ )
276+ c = create_diagonal (b )
277+ zero = xp .zeros ((), dtype = xp .uint64 )
278+ assert c .shape == (* b .shape , b .shape [- 1 ])
279+ for i in ndindex (* c .shape ):
280+ xp_assert_equal (c [i ], b [i [:- 1 ]] if i [- 2 ] == i [- 1 ] else zero )
258281
259282 @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no device kwarg in zeros()" )
260283 def test_device (self , xp : ModuleType , device : Device ):
0 commit comments