Skip to content

Commit 3ab78bd

Browse files
committed
ENH: add dtype_helpers.widest_{real,complex}_dtype, use in test_special_cases
1 parent 914124a commit 3ab78bd

2 files changed

Lines changed: 15 additions & 2 deletions

File tree

array_api_tests/dtype_helpers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,23 @@ def __contains__(self, other):
281281
)
282282

283283

284+
# complex128 if available else complex64
285+
widest_complex_dtype = max(
286+
[(dt, dtype_nbits[dt]) for dt in complex_dtypes], key=lambda x: x[1]
287+
)[0]
288+
289+
290+
# float64 if available else float32
291+
widest_real_dtype = max(
292+
[(dt, dtype_nbits[dt]) for dt in real_float_dtypes], key=lambda x: x[1]
293+
)[0]
294+
295+
284296
dtype_components = _make_dtype_mapping_from_names(
285297
{"complex64": xp.float32, "complex128": xp.float64}
286298
)
287299

300+
288301
def as_real_dtype(dtype):
289302
"""
290303
Return the corresponding real dtype for a given floating-point dtype.

array_api_tests/test_special_cases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,10 +1530,10 @@ def test_unary(func_name, func, case):
15301530

15311531
# Use the is_complex flag to determine the appropriate dtype
15321532
if case.is_complex:
1533-
dtype = xp.complex128
1533+
dtype = dh.widest_complex_dtype
15341534
in_value = case.cond_from_dtype(dtype).example()
15351535
else:
1536-
dtype = xp.float64
1536+
dtype = dh.widest_real_dtype
15371537
in_value = case.cond_from_dtype(dtype).example()
15381538

15391539
# Create array and compute result based on dtype

0 commit comments

Comments
 (0)