2121
2222import dpctl
2323import dpctl .tensor as dpt
24+ from dpctl .tensor ._numpy_helper import AxisError
2425from dpctl .tests .helper import get_queue_or_skip
2526from dpctl .utils import ExecutionPlacementError
2627
@@ -59,7 +60,7 @@ def test_permute_dims_0d_1d():
5960 assert_array_equal (dpt .asnumpy (Y_1d ), dpt .asnumpy (X_1d ))
6061
6162 pytest .raises (ValueError , dpt .permute_dims , X_1d , ())
62- pytest .raises (np . AxisError , dpt .permute_dims , X_1d , (1 ))
63+ pytest .raises (AxisError , dpt .permute_dims , X_1d , (1 ))
6364 pytest .raises (ValueError , dpt .permute_dims , X_1d , (1 , 0 ))
6465 pytest .raises (
6566 ValueError , dpt .permute_dims , dpt .reshape (X_1d , (2 , 3 )), (1 , 1 )
@@ -105,8 +106,8 @@ def test_expand_dims_0d():
105106 Ynp = np .expand_dims (Xnp , axis = - 1 )
106107 assert_array_equal (Ynp , dpt .asnumpy (Y ))
107108
108- pytest .raises (np . AxisError , dpt .expand_dims , X , axis = 1 )
109- pytest .raises (np . AxisError , dpt .expand_dims , X , axis = - 2 )
109+ pytest .raises (AxisError , dpt .expand_dims , X , axis = 1 )
110+ pytest .raises (AxisError , dpt .expand_dims , X , axis = - 2 )
110111
111112
112113@pytest .mark .parametrize ("shapes" , [(3 ,), (3 , 3 ), (3 , 3 , 3 )])
@@ -123,8 +124,8 @@ def test_expand_dims_1d_3d(shapes):
123124 Ynp = np .expand_dims (Xnp , axis = axis )
124125 assert_array_equal (Ynp , dpt .asnumpy (Y ))
125126
126- pytest .raises (np . AxisError , dpt .expand_dims , X , axis = shape_len + 1 )
127- pytest .raises (np . AxisError , dpt .expand_dims , X , axis = - shape_len - 2 )
127+ pytest .raises (AxisError , dpt .expand_dims , X , axis = shape_len + 1 )
128+ pytest .raises (AxisError , dpt .expand_dims , X , axis = - shape_len - 2 )
128129
129130
130131@pytest .mark .parametrize (
@@ -145,9 +146,9 @@ def test_expand_dims_incorrect_tuple():
145146 X = dpt .empty ((3 , 3 , 3 ), dtype = "i4" )
146147 except dpctl .SyclDeviceCreationError :
147148 pytest .skip ("No SYCL devices available" )
148- with pytest .raises (np . AxisError ):
149+ with pytest .raises (AxisError ):
149150 dpt .expand_dims (X , axis = (0 , - 6 ))
150- with pytest .raises (np . AxisError ):
151+ with pytest .raises (AxisError ):
151152 dpt .expand_dims (X , axis = (0 , 5 ))
152153
153154 with pytest .raises (ValueError ):
@@ -181,10 +182,10 @@ def test_squeeze_0d():
181182 Ynp = Xnp .squeeze (- 1 )
182183 assert_array_equal (Ynp , dpt .asnumpy (Y ))
183184
184- pytest .raises (np . AxisError , dpt .squeeze , X , 1 )
185- pytest .raises (np . AxisError , dpt .squeeze , X , - 2 )
186- pytest .raises (np . AxisError , dpt .squeeze , X , (1 ))
187- pytest .raises (np . AxisError , dpt .squeeze , X , (- 2 ))
185+ pytest .raises (AxisError , dpt .squeeze , X , 1 )
186+ pytest .raises (AxisError , dpt .squeeze , X , - 2 )
187+ pytest .raises (AxisError , dpt .squeeze , X , (1 ))
188+ pytest .raises (AxisError , dpt .squeeze , X , (- 2 ))
188189 pytest .raises (ValueError , dpt .squeeze , X , (0 , 0 ))
189190
190191
@@ -446,10 +447,10 @@ def test_flip_axis_incorrect():
446447 X_np = np .ones ((4 , 4 ))
447448 X = dpt .asarray (X_np , sycl_queue = q )
448449
449- pytest .raises (np . AxisError , dpt .flip , dpt .asarray (np .ones (4 )), axis = 1 )
450- pytest .raises (np . AxisError , dpt .flip , X , axis = 2 )
451- pytest .raises (np . AxisError , dpt .flip , X , axis = - 3 )
452- pytest .raises (np . AxisError , dpt .flip , X , axis = (0 , 3 ))
450+ pytest .raises (AxisError , dpt .flip , dpt .asarray (np .ones (4 )), axis = 1 )
451+ pytest .raises (AxisError , dpt .flip , X , axis = 2 )
452+ pytest .raises (AxisError , dpt .flip , X , axis = - 3 )
453+ pytest .raises (AxisError , dpt .flip , X , axis = (0 , 3 ))
453454
454455
455456def test_flip_0d ():
@@ -461,9 +462,9 @@ def test_flip_0d():
461462 Y = dpt .flip (X )
462463 assert_array_equal (Ynp , dpt .asnumpy (Y ))
463464
464- pytest .raises (np . AxisError , dpt .flip , X , axis = 0 )
465- pytest .raises (np . AxisError , dpt .flip , X , axis = 1 )
466- pytest .raises (np . AxisError , dpt .flip , X , axis = - 1 )
465+ pytest .raises (AxisError , dpt .flip , X , axis = 0 )
466+ pytest .raises (AxisError , dpt .flip , X , axis = 1 )
467+ pytest .raises (AxisError , dpt .flip , X , axis = - 1 )
467468
468469
469470def test_flip_1d ():
@@ -588,9 +589,9 @@ def test_roll_empty():
588589 Y = dpt .roll (X , 1 )
589590 Ynp = np .roll (Xnp , 1 )
590591 assert_array_equal (Ynp , dpt .asnumpy (Y ))
591- with pytest .raises (np . AxisError ):
592+ with pytest .raises (AxisError ):
592593 dpt .roll (X , 1 , axis = 0 )
593- with pytest .raises (np . AxisError ):
594+ with pytest .raises (AxisError ):
594595 dpt .roll (X , 1 , axis = 1 )
595596
596597
@@ -1086,13 +1087,13 @@ def test_moveaxis_errors():
10861087 pytest .skip ("No SYCL devices available" )
10871088 x = dpt .reshape (x_flat , (1 , 2 , 3 ))
10881089 assert_raises_regex (
1089- np . AxisError , "source.*out of bounds" , dpt .moveaxis , x , 3 , 0
1090+ AxisError , "source.*out of bounds" , dpt .moveaxis , x , 3 , 0
10901091 )
10911092 assert_raises_regex (
1092- np . AxisError , "source.*out of bounds" , dpt .moveaxis , x , - 4 , 0
1093+ AxisError , "source.*out of bounds" , dpt .moveaxis , x , - 4 , 0
10931094 )
10941095 assert_raises_regex (
1095- np . AxisError , "destination.*out of bounds" , dpt .moveaxis , x , 0 , 5
1096+ AxisError , "destination.*out of bounds" , dpt .moveaxis , x , 0 , 5
10961097 )
10971098 assert_raises_regex (
10981099 ValueError , "repeated axis in `source`" , dpt .moveaxis , x , [0 , 0 ], [0 , 1 ]
0 commit comments