@@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
6969 "--spirv-ext=+SPV_EXT_shader_atomic_float_add"
7070 ]
7171
72+ context .extra_compile_options [LLVM_SPIRV_ARGS ] = [
73+ "--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
74+ ]
75+
7276 ptr_type = retty .as_pointer ()
7377 ptr_type .addrspace = atomic_ref_ty .address_space
7478
@@ -118,6 +122,59 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
118122 return _intrinsic_helper (ty_context , ty_atomic_ref , ty_val , "fetch_add" )
119123
120124
125+ def _atomic_sub_float_wrapper (gen_fn ):
126+ def gen (context , builder , sig , args ):
127+ # args is a tuple, which is immutable
128+ # covert tuple to list obj first before replacing arg[1]
129+ # with fneg and convert back to tuple again.
130+ args_lst = list (args )
131+ args_lst [1 ] = builder .fneg (args [1 ])
132+ args = tuple (args_lst )
133+
134+ gen_fn (context , builder , sig , args )
135+
136+ return gen
137+
138+
139+ @intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
140+ def _intrinsic_fetch_sub (ty_context , ty_atomic_ref , ty_val ):
141+ if ty_atomic_ref .dtype in (types .float32 , types .float64 ):
142+ # dpcpp does not support ``__spirv_AtomicFSubEXT``. fetch_sub
143+ # for floats is implemented by negating the value and calling fetch_add.
144+ # For example, A.fetch_sub(A, val) is implemented as A.fetch_add(-val).
145+ sig , gen = _intrinsic_helper (
146+ ty_context , ty_atomic_ref , ty_val , "fetch_add"
147+ )
148+ return sig , _atomic_sub_float_wrapper (gen )
149+
150+ return _intrinsic_helper (ty_context , ty_atomic_ref , ty_val , "fetch_sub" )
151+
152+
153+ @intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
154+ def _intrinsic_fetch_min (ty_context , ty_atomic_ref , ty_val ):
155+ return _intrinsic_helper (ty_context , ty_atomic_ref , ty_val , "fetch_min" )
156+
157+
158+ @intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
159+ def _intrinsic_fetch_max (ty_context , ty_atomic_ref , ty_val ):
160+ return _intrinsic_helper (ty_context , ty_atomic_ref , ty_val , "fetch_max" )
161+
162+
163+ @intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
164+ def _intrinsic_fetch_and (ty_context , ty_atomic_ref , ty_val ):
165+ return _intrinsic_helper (ty_context , ty_atomic_ref , ty_val , "fetch_and" )
166+
167+
168+ @intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
169+ def _intrinsic_fetch_or (ty_context , ty_atomic_ref , ty_val ):
170+ return _intrinsic_helper (ty_context , ty_atomic_ref , ty_val , "fetch_or" )
171+
172+
173+ @intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
174+ def _intrinsic_fetch_xor (ty_context , ty_atomic_ref , ty_val ):
175+ return _intrinsic_helper (ty_context , ty_atomic_ref , ty_val , "fetch_xor" )
176+
177+
121178@intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
122179def _intrinsic_atomic_ref_ctor (
123180 ty_context , ref , ty_index , ty_retty_ref # pylint: disable=unused-argument
@@ -294,3 +351,168 @@ def ol_fetch_add_impl(atomic_ref, val):
294351 return _intrinsic_fetch_add (atomic_ref , val )
295352
296353 return ol_fetch_add_impl
354+
355+
356+ @overload_method (AtomicRefType , "fetch_sub" , target = DPEX_KERNEL_EXP_TARGET_NAME )
357+ def ol_fetch_sub (atomic_ref , val ):
358+ """SPIR-V overload for
359+ :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`.
360+
361+ Generates the same LLVM IR instruction as dpcpp for the
362+ `atomic_ref::fetch_sub` function.
363+
364+ Raises:
365+ TypingError: When the dtype of the aggregator value does not match the
366+ dtype of the AtomicRef type.
367+ """
368+ if atomic_ref .dtype != val :
369+ raise errors .TypingError (
370+ f"Type of value to sub: { val } does not match the type of the "
371+ f"reference: { atomic_ref .dtype } stored in the atomic ref."
372+ )
373+
374+ def ol_fetch_sub_impl (atomic_ref , val ):
375+ # pylint: disable=no-value-for-parameter
376+ return _intrinsic_fetch_sub (atomic_ref , val )
377+
378+ return ol_fetch_sub_impl
379+
380+
381+ @overload_method (AtomicRefType , "fetch_min" , target = DPEX_KERNEL_EXP_TARGET_NAME )
382+ def ol_fetch_min (atomic_ref , val ):
383+ """SPIR-V overload for
384+ :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`.
385+
386+ Generates the same LLVM IR instruction as dpcpp for the
387+ `atomic_ref::fetch_min` function.
388+
389+ Raises:
390+ TypingError: When the dtype of the aggregator value does not match the
391+ dtype of the AtomicRef type.
392+ """
393+ if atomic_ref .dtype != val :
394+ raise errors .TypingError (
395+ f"Type of value to find min: { val } does not match the type of the "
396+ f"reference: { atomic_ref .dtype } stored in the atomic ref."
397+ )
398+
399+ def ol_fetch_min_impl (atomic_ref , val ):
400+ # pylint: disable=no-value-for-parameter
401+ return _intrinsic_fetch_min (atomic_ref , val )
402+
403+ return ol_fetch_min_impl
404+
405+
406+ @overload_method (AtomicRefType , "fetch_max" , target = DPEX_KERNEL_EXP_TARGET_NAME )
407+ def ol_fetch_max (atomic_ref , val ):
408+ """SPIR-V overload for
409+ :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`.
410+
411+ Generates the same LLVM IR instruction as dpcpp for the
412+ `atomic_ref::fetch_max` function.
413+
414+ Raises:
415+ TypingError: When the dtype of the aggregator value does not match the
416+ dtype of the AtomicRef type.
417+ """
418+ if atomic_ref .dtype != val :
419+ raise errors .TypingError (
420+ f"Type of value to find max: { val } does not match the type of the "
421+ f"reference: { atomic_ref .dtype } stored in the atomic ref."
422+ )
423+
424+ def ol_fetch_max_impl (atomic_ref , val ):
425+ # pylint: disable=no-value-for-parameter
426+ return _intrinsic_fetch_max (atomic_ref , val )
427+
428+ return ol_fetch_max_impl
429+
430+
431+ @overload_method (AtomicRefType , "fetch_and" , target = DPEX_KERNEL_EXP_TARGET_NAME )
432+ def ol_fetch_and (atomic_ref , val ):
433+ """SPIR-V overload for
434+ :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`.
435+
436+ Generates the same LLVM IR instruction as dpcpp for the
437+ `atomic_ref::fetch_and` function.
438+
439+ Raises:
440+ TypingError: When the dtype of the aggregator value does not match the
441+ dtype of the AtomicRef type.
442+ """
443+ if atomic_ref .dtype != val :
444+ raise errors .TypingError (
445+ f"Type of value to and: { val } does not match the type of the "
446+ f"reference: { atomic_ref .dtype } stored in the atomic ref."
447+ )
448+
449+ if atomic_ref .dtype not in (types .int32 , types .int64 ):
450+ raise errors .TypingError (
451+ "fetch_and operation only supported on int32 and int64 dtypes."
452+ )
453+
454+ def ol_fetch_and_impl (atomic_ref , val ):
455+ # pylint: disable=no-value-for-parameter
456+ return _intrinsic_fetch_and (atomic_ref , val )
457+
458+ return ol_fetch_and_impl
459+
460+
461+ @overload_method (AtomicRefType , "fetch_or" , target = DPEX_KERNEL_EXP_TARGET_NAME )
462+ def ol_fetch_or (atomic_ref , val ):
463+ """SPIR-V overload for
464+ :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`.
465+
466+ Generates the same LLVM IR instruction as dpcpp for the
467+ `atomic_ref::fetch_or` function.
468+
469+ Raises:
470+ TypingError: When the dtype of the aggregator value does not match the
471+ dtype of the AtomicRef type.
472+ """
473+ if atomic_ref .dtype != val :
474+ raise errors .TypingError (
475+ f"Type of value to or: { val } does not match the type of the "
476+ f"reference: { atomic_ref .dtype } stored in the atomic ref."
477+ )
478+
479+ if atomic_ref .dtype not in (types .int32 , types .int64 ):
480+ raise errors .TypingError (
481+ "fetch_or operation only supported on int32 and int64 dtypes."
482+ )
483+
484+ def ol_fetch_or_impl (atomic_ref , val ):
485+ # pylint: disable=no-value-for-parameter
486+ return _intrinsic_fetch_or (atomic_ref , val )
487+
488+ return ol_fetch_or_impl
489+
490+
491+ @overload_method (AtomicRefType , "fetch_xor" , target = DPEX_KERNEL_EXP_TARGET_NAME )
492+ def ol_fetch_xor (atomic_ref , val ):
493+ """SPIR-V overload for
494+ :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`.
495+
496+ Generates the same LLVM IR instruction as dpcpp for the
497+ `atomic_ref::fetch_xor` function.
498+
499+ Raises:
500+ TypingError: When the dtype of the aggregator value does not match the
501+ dtype of the AtomicRef type.
502+ """
503+ if atomic_ref .dtype != val :
504+ raise errors .TypingError (
505+ f"Type of value to xor: { val } does not match the type of the "
506+ f"reference: { atomic_ref .dtype } stored in the atomic ref."
507+ )
508+
509+ if atomic_ref .dtype not in (types .int32 , types .int64 ):
510+ raise errors .TypingError (
511+ "fetch_xor operation only supported on int32 and int64 dtypes."
512+ )
513+
514+ def ol_fetch_xor_impl (atomic_ref , val ):
515+ # pylint: disable=no-value-for-parameter
516+ return _intrinsic_fetch_xor (atomic_ref , val )
517+
518+ return ol_fetch_xor_impl
0 commit comments