@@ -393,10 +393,11 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
393393 return o1_dtype , o2_dtype
394394
395395
396- def _resolve_weak_types_comparisons (o1_dtype , o2_dtype , dev ):
397- "Resolves weak data type per NEP-0050 for comparisons,"
398- "where result type is known to be `bool` and special behavior"
399- "is needed to handle mixed integer kinds"
396+ def _resolve_weak_types_all_py_ints (o1_dtype , o2_dtype , dev ):
397+ "Resolves weak data type per NEP-0050 for comparisons and"
398+ " divide, where result type is known and special behavior"
399+ "is needed to handle mixed integer kinds and Python integers"
400+ "without overflow"
400401 if _is_weak_dtype (o1_dtype ):
401402 if _is_weak_dtype (o2_dtype ):
402403 raise ValueError
@@ -414,11 +415,13 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
414415 )
415416 return _to_device_supported_dtype (dpt .float64 , dev ), o2_dtype
416417 else :
417- if isinstance (o1_dtype , WeakIntegralType ):
418- if o2_dtype .kind == "u" :
419- # Python scalar may be negative, assumes mixed int loops
420- # exist
421- return dpt .dtype (ti .default_device_int_type (dev )), o2_dtype
418+ if o1_kind_num == o2_kind_num and isinstance (
419+ o1_dtype , WeakIntegralType
420+ ):
421+ o1_val = o1_dtype .get ()
422+ o2_iinfo = dpt .iinfo (o2_dtype )
423+ if (o1_val < o2_iinfo .min ) or (o1_val > o2_iinfo .max ):
424+ return dpt .dtype (np .min_scalar_type (o1_val )), o2_dtype
422425 return o2_dtype , o2_dtype
423426 elif _is_weak_dtype (o2_dtype ):
424427 o1_kind_num = _strong_dtype_num_kind (o1_dtype )
@@ -435,11 +438,13 @@ def _resolve_weak_types_comparisons(o1_dtype, o2_dtype, dev):
435438 _to_device_supported_dtype (dpt .float64 , dev ),
436439 )
437440 else :
438- if isinstance (o2_dtype , WeakIntegralType ):
439- if o1_dtype .kind == "u" :
440- # Python scalar may be negative, assumes mixed int loops
441- # exist
442- return o1_dtype , dpt .dtype (ti .default_device_int_type (dev ))
441+ if o1_kind_num == o2_kind_num and isinstance (
442+ o2_dtype , WeakIntegralType
443+ ):
444+ o2_val = o2_dtype .get ()
445+ o1_iinfo = dpt .iinfo (o1_dtype )
446+ if (o2_val < o1_iinfo .min ) or (o2_val > o1_iinfo .max ):
447+ return o1_dtype , dpt .dtype (np .min_scalar_type (o2_val ))
443448 return o1_dtype , o1_dtype
444449 else :
445450 return o1_dtype , o2_dtype
@@ -834,7 +839,7 @@ def _default_accumulation_dtype_fp_types(inp_dt, q):
834839 "_acceptance_fn_negative" ,
835840 "_acceptance_fn_subtract" ,
836841 "_resolve_weak_types" ,
837- "_resolve_weak_types_comparisons " ,
842+ "_resolve_weak_types_all_py_ints " ,
838843 "_weak_type_num_kind" ,
839844 "_strong_dtype_num_kind" ,
840845 "can_cast" ,
0 commit comments