66import pytest
77from numba .core .errors import TypingError
88
9- import numba_dpex as dpex
109import numba_dpex .experimental as dpex_exp
11- from numba_dpex .kernel_api import AtomicRef
10+ from numba_dpex .kernel_api import AtomicRef , Item , Range
1211from numba_dpex .tests ._helper import get_all_dtypes
1312
1413list_of_supported_dtypes = get_all_dtypes (
@@ -45,8 +44,8 @@ def test_fetch_phi_fn(input_arrays, ref_index, fetch_phi_fn):
4544 """A test for all fetch_phi atomic functions."""
4645
4746 @dpex_exp .kernel
48- def _kernel (a , b , ref_index ):
49- i = dpex . get_global_id (0 )
47+ def _kernel (item : Item , a , b , ref_index ):
48+ i = item . get_id (0 )
5049 v = AtomicRef (b , index = ref_index )
5150 getattr (v , fetch_phi_fn )(a [i ])
5251
@@ -60,9 +59,9 @@ def _kernel(a, b, ref_index):
6059 # fetch_and, fetch_or, fetch_xor accept only int arguments.
6160 # test for TypingError when float arguments are passed.
6261 with pytest .raises (TypingError ):
63- dpex_exp .call_kernel (_kernel , dpex . Range (10 ), a , b , ref_index )
62+ dpex_exp .call_kernel (_kernel , Range (10 ), a , b , ref_index )
6463 else :
65- dpex_exp .call_kernel (_kernel , dpex . Range (10 ), a , b , ref_index )
64+ dpex_exp .call_kernel (_kernel , Range (10 ), a , b , ref_index )
6665 # Verify that `a` accumulated at b[ref_index] by kernel
6766 # matches the `a` accumulated at b[ref_index+1] using Python
6867 for i in range (a .size ):
@@ -76,8 +75,8 @@ def test_fetch_phi_retval(fetch_phi_fn):
7675 """A test for all fetch_phi atomic functions."""
7776
7877 @dpex_exp .kernel
79- def _kernel (a , b , c ):
80- i = dpex . get_global_id (0 )
78+ def _kernel (item : Item , a , b , c ):
79+ i = item . get_id (0 )
8180 v = AtomicRef (b , index = i )
8281 c [i ] = getattr (v , fetch_phi_fn )(a [i ])
8382
@@ -89,7 +88,7 @@ def _kernel(a, b, c):
8988 b_copy = dpnp .copy (b )
9089 c_copy = dpnp .copy (c )
9190
92- dpex_exp .call_kernel (_kernel , dpex . Range (10 ), a , b , c )
91+ dpex_exp .call_kernel (_kernel , Range (10 ), a , b , c )
9392
9493 # Verify if the value returned by fetch_phi kernel
9594 # stored into `c` is same as the value returned
@@ -108,8 +107,8 @@ def test_fetch_phi_diff_types(fetch_phi_fn):
108107 """
109108
110109 @dpex_exp .kernel
111- def _kernel (a , b ):
112- i = dpex . get_global_id (0 )
110+ def _kernel (item : Item , a , b ):
111+ i = item . get_id (0 )
113112 v = AtomicRef (b , index = 0 )
114113 getattr (v , fetch_phi_fn )(a [i ])
115114
@@ -118,19 +117,19 @@ def _kernel(a, b):
118117 b = dpnp .zeros (N , dtype = dpnp .int32 )
119118
120119 with pytest .raises (TypingError ):
121- dpex_exp .call_kernel (_kernel , dpex . Range (10 ), a , b )
120+ dpex_exp .call_kernel (_kernel , Range (10 ), a , b )
122121
123122
124123@dpex_exp .kernel
125- def atomic_ref_0 (a ):
126- i = dpex . get_global_id (0 )
124+ def atomic_ref_0 (item : Item , a ):
125+ i = item . get_id (0 )
127126 v = AtomicRef (a , index = 0 )
128127 v .fetch_add (a [i + 2 ])
129128
130129
131130@dpex_exp .kernel
132- def atomic_ref_1 (a ):
133- i = dpex . get_global_id (0 )
131+ def atomic_ref_1 (item : Item , a ):
132+ i = item . get_id (0 )
134133 v = AtomicRef (a , index = 1 )
135134 v .fetch_add (a [i + 2 ])
136135
@@ -144,24 +143,24 @@ def test_spirv_compiler_flags_add():
144143 N = 10
145144 a = dpnp .ones (N , dtype = dpnp .float32 )
146145
147- dpex_exp .call_kernel (atomic_ref_0 , dpex . Range (N - 2 ), a )
148- dpex_exp .call_kernel (atomic_ref_1 , dpex . Range (N - 2 ), a )
146+ dpex_exp .call_kernel (atomic_ref_0 , Range (N - 2 ), a )
147+ dpex_exp .call_kernel (atomic_ref_1 , Range (N - 2 ), a )
149148
150149 assert a [0 ] == N - 1
151150 assert a [1 ] == N - 1
152151
153152
154153@dpex_exp .kernel
155- def atomic_max_0 (a ):
156- i = dpex . get_global_id (0 )
154+ def atomic_max_0 (item : Item , a ):
155+ i = item . get_id (0 )
157156 v = AtomicRef (a , index = 0 )
158157 if i != 0 :
159158 v .fetch_max (a [i ])
160159
161160
162161@dpex_exp .kernel
163- def atomic_max_1 (a ):
164- i = dpex . get_global_id (0 )
162+ def atomic_max_1 (item : Item , a ):
163+ i = item . get_id (0 )
165164 v = AtomicRef (a , index = 0 )
166165 if i != 0 :
167166 v .fetch_max (a [i ])
@@ -177,8 +176,8 @@ def test_spirv_compiler_flags_max():
177176 a = dpnp .arange (N , dtype = dpnp .float32 )
178177 b = dpnp .arange (N , dtype = dpnp .float32 )
179178
180- dpex_exp .call_kernel (atomic_max_0 , dpex . Range (N ), a )
181- dpex_exp .call_kernel (atomic_max_1 , dpex . Range (N ), b )
179+ dpex_exp .call_kernel (atomic_max_0 , Range (N ), a )
180+ dpex_exp .call_kernel (atomic_max_1 , Range (N ), b )
182181
183182 assert a [0 ] == N - 1
184183 assert b [0 ] == N - 1
0 commit comments