2222
2323from ..exceptions import UnsupportedParforError
2424from ..types .dpnp_ndarray_type import DpnpNdArray
25- from .kernel_builder import create_kernel_for_parfor
25+ from .kernel_builder import ParforKernel , create_kernel_for_parfor
2626from .reduction_kernel_builder import (
2727 create_reduction_main_kernel_for_parfor ,
2828 create_reduction_remainder_kernel_for_parfor ,
@@ -91,7 +91,9 @@ class ParforLowerImpl:
9191 for a parfor and submits it to a queue.
9292 """
9393
94- def _build_kernel_arglist (self , kernel_fn , lowerer , kernel_builder ):
94+ def _build_kernel_arglist (
95+ self , kernel_fn , lowerer , kernel_builder : KernelLaunchIRBuilder
96+ ):
9597 """Creates local variables for all the arguments and the argument types
9698 that are passes to the kernel function.
9799
@@ -139,27 +141,15 @@ def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
139141 arg_types = args_ty_list ,
140142 )
141143
142- def _submit_parfor_kernel (
144+ def _loop_ranges (
143145 self ,
144146 lowerer ,
145- kernel_fn ,
146147 loop_ranges ,
147148 ):
148- """
149- Adds a call to submit a kernel function into the function body of the
150- current Numba JIT compiled function.
151- """
152- # Ensure that the Python arguments are kept alive for the duration of
153- # the kernel execution
154- keep_alive_kernels .append (kernel_fn .kernel )
155- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
156-
157- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
158- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
159-
160149 # Create a global range over which to submit the kernel based on the
161150 # loop_ranges of the parfor
162151 global_range = []
152+
163153 # SYCL ranges can have at max 3 dimension. If the parfor is of a higher
164154 # dimension then the indexing for the higher dimensions is done inside
165155 # the kernel.
@@ -176,45 +166,13 @@ def _submit_parfor_kernel(
176166
177167 local_range = []
178168
179- kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
180- kernel_ref = lowerer .builder .inttoptr (
181- lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
182- cgutils .voidptr_t ,
183- )
184- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
185-
186- # Submit a synchronous kernel
187- kernel_builder .submit_sycl_kernel (
188- sycl_kernel_ref = kernel_ref ,
189- sycl_queue_ref = curr_queue_ref ,
190- total_kernel_args = args .num_flattened_args ,
191- arg_list = args .arg_vals ,
192- arg_ty_list = args .arg_types ,
193- global_range = global_range ,
194- local_range = local_range ,
195- )
196-
197- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
198- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
169+ return global_range , local_range
199170
200- def _submit_reduction_main_parfor_kernel (
171+ def _reduction_ranges (
201172 self ,
202173 lowerer ,
203- kernel_fn ,
204174 reductionHelper = None ,
205175 ):
206- """
207- Adds a call to submit the main kernel of a parfor reduction into the
208- function body of the current Numba JIT compiled function.
209- """
210- # Ensure that the Python arguments are kept alive for the duration of
211- # the kernel execution
212- keep_alive_kernels .append (kernel_fn .kernel )
213- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
214-
215- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
216-
217- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
218176 # Create a global range over which to submit the kernel based on the
219177 # loop_ranges of the parfor
220178 global_range = []
@@ -228,54 +186,39 @@ def _submit_reduction_main_parfor_kernel(
228186 _load_range (lowerer , reductionHelper .work_group_size )
229187 )
230188
231- kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
232- kernel_ref = lowerer .builder .inttoptr (
233- lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
234- cgutils .voidptr_t ,
235- )
236- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
189+ return global_range , local_range
237190
238- # Submit a synchronous kernel
239- kernel_builder .submit_sycl_kernel (
240- sycl_kernel_ref = kernel_ref ,
241- sycl_queue_ref = curr_queue_ref ,
242- total_kernel_args = args .num_flattened_args ,
243- arg_list = args .arg_vals ,
244- arg_ty_list = args .arg_types ,
245- global_range = global_range ,
246- local_range = local_range ,
247- )
191+ def _remainder_ranges (self , lowerer ):
192+ # Create a global range over which to submit the kernel based on the
193+ # loop_ranges of the parfor
194+ global_range = []
248195
249- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
250- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
196+ stop = _load_range (lowerer , 1 )
197+
198+ global_range .append (stop )
251199
252- def _submit_reduction_remainder_parfor_kernel (
200+ local_range = []
201+
202+ return global_range , local_range
203+
204+ def _submit_parfor_kernel (
253205 self ,
254206 lowerer ,
255- kernel_fn ,
207+ kernel_fn : ParforKernel ,
208+ global_range ,
209+ local_range ,
256210 ):
257211 """
258- Adds a call to submit the remainder kernel of a parfor reduction into
259- the function body of the current Numba JIT compiled function.
212+ Adds a call to submit a kernel function into the function body of the
213+ current Numba JIT compiled function.
260214 """
261215 # Ensure that the Python arguments are kept alive for the duration of
262216 # the kernel execution
263217 keep_alive_kernels .append (kernel_fn .kernel )
264-
265218 kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
266219
267220 ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
268-
269221 args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
270- # Create a global range over which to submit the kernel based on the
271- # loop_ranges of the parfor
272- global_range = []
273-
274- stop = _load_range (lowerer , 1 )
275-
276- global_range .append (stop )
277-
278- local_range = []
279222
280223 kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
281224 kernel_ref = lowerer .builder .inttoptr (
@@ -360,10 +303,15 @@ def _reduction_codegen(
360303 parfor_reddict ,
361304 )
362305
363- self ._submit_reduction_main_parfor_kernel (
306+ global_range , local_range = self ._reduction_ranges (
307+ lowerer , reductionHelperList [0 ]
308+ )
309+
310+ self ._submit_parfor_kernel (
364311 lowerer ,
365312 parfor_kernel ,
366- reductionHelperList [0 ],
313+ global_range ,
314+ local_range ,
367315 )
368316
369317 parfor_kernel = create_reduction_remainder_kernel_for_parfor (
@@ -376,9 +324,13 @@ def _reduction_codegen(
376324 reductionHelperList ,
377325 )
378326
379- self ._submit_reduction_remainder_parfor_kernel (
327+ global_range , local_range = self ._remainder_ranges (lowerer )
328+
329+ self ._submit_parfor_kernel (
380330 lowerer ,
381331 parfor_kernel ,
332+ global_range ,
333+ local_range ,
382334 )
383335
384336 reductionKernelVar .copy_final_sum_to_host (parfor_kernel )
@@ -492,11 +444,14 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
492444 # FIXME: Make the exception more informative
493445 raise UnsupportedParforError
494446
447+ global_range , local_range = self ._loop_ranges (lowerer , loop_ranges )
448+
495449 # Finally submit the kernel
496450 self ._submit_parfor_kernel (
497451 lowerer ,
498452 parfor_kernel ,
499- loop_ranges ,
453+ global_range ,
454+ local_range ,
500455 )
501456
502457 # TODO: free the kernel at this point
0 commit comments