1313)
1414
1515from numba_dpex import config
16- from numba_dpex .core .datamodel .models import dpex_data_model_manager as dpex_dmm
1716from numba_dpex .core .parfors .reduction_helper import (
1817 ReductionHelper ,
1918 ReductionKernelVariables ,
2019)
2120from numba_dpex .core .utils .kernel_launcher import KernelLaunchIRBuilder
21+ from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
22+ from numba_dpex .core .datamodel .models import (
23+ dpex_data_model_manager as kernel_dmm ,
24+ )
2225
2326from ..exceptions import UnsupportedParforError
2427from ..types .dpnp_ndarray_type import DpnpNdArray
2831 create_reduction_remainder_kernel_for_parfor ,
2932)
3033
31- _KernelArgs = namedtuple (
32- "_KernelArgs" ,
33- ["num_flattened_args" , "arg_vals" , "arg_types" ],
34- )
35-
3634
3735# A global list of kernels to keep the objects alive indefinitely.
3836keep_alive_kernels = []
@@ -68,11 +66,8 @@ def _getvar(lowerer, x):
6866 var_val = lowerer .varmap [x ]
6967
7068 if var_val :
71- if not isinstance (var_val .type , llvmir .PointerType ):
72- with lowerer .builder .goto_entry_block ():
73- var_val_ptr = lowerer .builder .alloca (var_val .type )
74- lowerer .builder .store (var_val , var_val_ptr )
75- return var_val_ptr
69+ if isinstance (var_val .type , llvmir .PointerType ):
70+ return lowerer .builder .load (var_val )
7671 else :
7772 return var_val
7873 else :
@@ -91,56 +86,6 @@ class ParforLowerImpl:
9186 for a parfor and submits it to a queue.
9287 """
9388
94- def _build_kernel_arglist (
95- self , kernel_fn , lowerer , kernel_builder : KernelLaunchIRBuilder
96- ):
97- """Creates local variables for all the arguments and the argument types
98- that are passed to the kernel function.
99-
100- Args:
101- kernel_fn: Kernel function to be launched.
102- lowerer: The Numba lowerer used to generate the LLVM IR
103-
104- Raises:
105- AssertionError: If the LLVM IR Value for an argument defined in
106- Numba IR is not found.
107- """
108- num_flattened_args = 0
109-
110- # Compute number of args to be passed to the kernel. Note that the
111- # actual number of kernel arguments is greater than the count of
112- # kernel_fn.kernel_args as arrays get flattened.
113- for arg_type in kernel_fn .kernel_arg_types :
114- if isinstance (arg_type , DpnpNdArray ):
115- datamodel = dpex_dmm .lookup (arg_type )
116- num_flattened_args += datamodel .flattened_field_count
117- elif arg_type == types .complex64 or arg_type == types .complex128 :
118- num_flattened_args += 2
119- else :
120- num_flattened_args += 1
121-
122- # Create LLVM values for the kernel args list and kernel arg types list
123- args_list = kernel_builder .allocate_kernel_arg_array (num_flattened_args )
124- args_ty_list = kernel_builder .allocate_kernel_arg_ty_array (
125- num_flattened_args
126- )
127- callargs_ptrs = []
128- for arg in kernel_fn .kernel_args :
129- callargs_ptrs .append (_getvar (lowerer , arg ))
130-
131- kernel_builder .populate_kernel_args_and_args_ty_arrays (
132- kernel_argtys = kernel_fn .kernel_arg_types ,
133- callargs_ptrs = callargs_ptrs ,
134- args_list = args_list ,
135- args_ty_list = args_ty_list ,
136- )
137-
138- return _KernelArgs (
139- num_flattened_args = num_flattened_args ,
140- arg_vals = args_list ,
141- arg_types = args_ty_list ,
142- )
143-
14489 def _loop_ranges (
14590 self ,
14691 lowerer ,
@@ -163,7 +108,10 @@ def _loop_ranges(
163108 "non-unit strides are not yet supported."
164109 )
165110 global_range .append (stop )
166-
111+ # For now the local_range is always an empty list, as numba_dpex always
112+ # submits kernels generated for parfor nodes as range kernels.
113+ # The provision is kept here in case future functionality allows
114+ # submitting these kernels as ndrange kernels.
167115 local_range = []
168116
169117 return global_range , local_range
@@ -215,31 +163,34 @@ def _submit_parfor_kernel(
215163 # Ensure that the Python arguments are kept alive for the duration of
216164 # the kernel execution
217165 keep_alive_kernels .append (kernel_fn .kernel )
218- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
166+ kl_builder = KernelLaunchIRBuilder (
167+ lowerer .context , lowerer .builder , kernel_dmm
168+ )
169+
170+ queue_ref = kl_builder .get_queue (exec_queue = kernel_fn .queue )
219171
220- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
221- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
172+ kernel_args = []
173+ for arg in kernel_fn .kernel_args :
174+ kernel_args .append (_getvar (lowerer , arg ))
222175
223176 kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
224177 kernel_ref = lowerer .builder .inttoptr (
225178 lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
226179 cgutils .voidptr_t ,
227180 )
228- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
229-
230- # Submit a synchronous kernel
231- kernel_builder .submit_sycl_kernel (
232- sycl_kernel_ref = kernel_ref ,
233- sycl_queue_ref = curr_queue_ref ,
234- total_kernel_args = args .num_flattened_args ,
235- arg_list = args .arg_vals ,
236- arg_ty_list = args .arg_types ,
237- global_range = global_range ,
238- local_range = local_range ,
181+
182+ kl_builder .set_kernel (kernel_ref )
183+ kl_builder .set_queue (queue_ref )
184+ kl_builder .set_range (global_range , local_range )
185+ kl_builder .set_arguments (
186+ kernel_fn .kernel_arg_types , kernel_args = kernel_args
239187 )
188+ kl_builder .set_dependant_event_list (dep_events = [])
189+ event_ref = kl_builder .submit ()
240190
241- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
242- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
191+ sycl .dpctl_event_wait (lowerer .builder , event_ref )
192+ sycl .dpctl_event_delete (lowerer .builder , event_ref )
193+ sycl .dpctl_queue_delete (lowerer .builder , queue_ref )
243194
244195 def _reduction_codegen (
245196 self ,
0 commit comments