Skip to content

Commit 914124a

Browse files
committed
STY: rm whitespace-only lines
1 parent 5c50ff0 commit 914124a

1 file changed

Lines changed: 32 additions & 34 deletions

File tree

array_api_tests/test_special_cases.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,7 @@ def check_result(result: float) -> bool:
496496
def parse_complex_value(value_str: str) -> complex:
497497
"""
498498
Parses a complex value string to return a complex number, e.g.
499-
499+
500500
>>> parse_complex_value('+0 + 0j')
501501
0j
502502
>>> parse_complex_value('NaN + NaN j')
@@ -507,21 +507,21 @@ def parse_complex_value(value_str: str) -> complex:
507507
1.5707963267948966j
508508
>>> parse_complex_value('+infinity + 3πj/4')
509509
(inf+2.356194490192345j)
510-
510+
511511
Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M"
512512
"""
513513
m = r_complex_value.match(value_str)
514514
if m is None:
515515
raise ParseError(value_str)
516-
516+
517517
# Parse real part with its sign
518518
# Normalize ± to + (we choose positive arbitrarily since sign is unspecified)
519519
real_sign = m.group(1) if m.group(1) else "+"
520520
if '±' in real_sign:
521521
real_sign = '+'
522522
real_val_str = m.group(2)
523523
real_val = parse_value(real_sign + real_val_str)
524-
524+
525525
# Parse imaginary part with its sign
526526
# Normalize ± to + for imaginary part as well
527527
imag_sign = m.group(3)
@@ -536,9 +536,9 @@ def parse_complex_value(value_str: str) -> complex:
536536
imag_val_str_raw = m.group(5)
537537
# Strip trailing 'j' if present: "0j" -> "0"
538538
imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw
539-
539+
540540
imag_val = parse_value(imag_sign + imag_val_str)
541-
541+
542542
return complex(real_val, imag_val)
543543

544544

@@ -548,10 +548,10 @@ def make_strict_eq_complex(v: complex) -> Callable[[complex], bool]:
548548
"""
549549
real_check = make_strict_eq(v.real)
550550
imag_check = make_strict_eq(v.imag)
551-
551+
552552
def strict_eq_complex(z: complex) -> bool:
553553
return real_check(z.real) and imag_check(z.imag)
554-
554+
555555
return strict_eq_complex
556556

557557

@@ -560,7 +560,7 @@ def parse_complex_cond(
560560
) -> Tuple[Callable[[complex], bool], str, FromDtypeFunc]:
561561
"""
562562
Parses complex condition strings for real (a) and imaginary (b) parts.
563-
563+
564564
Returns:
565565
- cond: Function that checks if a complex number meets the condition
566566
- expr: String expression for the condition
@@ -569,16 +569,16 @@ def parse_complex_cond(
569569
# Parse conditions for real and imaginary parts separately
570570
a_cond, a_expr_template, a_from_dtype = parse_cond(a_cond_str)
571571
b_cond, b_expr_template, b_from_dtype = parse_cond(b_cond_str)
572-
572+
573573
# Create compound condition
574574
def complex_cond(z: complex) -> bool:
575575
return a_cond(z.real) and b_cond(z.imag)
576-
576+
577577
# Create expression
578578
a_expr = a_expr_template.replace("{}", "real(x_i)")
579579
b_expr = b_expr_template.replace("{}", "imag(x_i)")
580580
expr = f"{a_expr} and {b_expr}"
581-
581+
582582
# Create strategy that generates complex numbers
583583
def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]:
584584
assert len(kw) == 0 # sanity check
@@ -589,7 +589,7 @@ def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]:
589589
real_strat = a_from_dtype(float_dtype)
590590
imag_strat = b_from_dtype(float_dtype)
591591
return st.builds(complex, real_strat, imag_strat)
592-
592+
593593
return complex_cond, expr, complex_from_dtype
594594

595595

@@ -609,7 +609,7 @@ def _check_component_with_tolerance(actual: float, expected: float, allow_any_si
609609
def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]:
610610
"""
611611
Parses a complex result string to return a checker and expression.
612-
612+
613613
Handles cases like:
614614
- "``+0 + 0j``" - exact complex value
615615
- "``0 + NaN j`` (sign of the real component is unspecified)"
@@ -618,7 +618,7 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
618618
# Check for unspecified sign notes (text-based detection)
619619
unspecified_real_sign = "sign of the real component is unspecified" in result_str
620620
unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str
621-
621+
622622
# Extract the complex value from backticks - need to handle spaces in complex values
623623
# Pattern: ``...`` where ... can contain spaces (for complex values like "0 + NaN j")
624624
m = re.search(r"``([^`]+)``", result_str)
@@ -640,12 +640,12 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
640640

641641
# Check if the value contains π expressions (for approximate comparison)
642642
has_pi = 'π' in value_str
643-
643+
644644
try:
645645
expected = parse_complex_value(value_str)
646646
except ParseError:
647647
raise ParseError(result_str)
648-
648+
649649
# Create checker based on whether signs are unspecified and whether π is involved
650650
if has_pi:
651651
# Use approximate equality for both real and imaginary parts if they involve π
@@ -670,7 +670,7 @@ def check_result(z: complex) -> bool:
670670
else:
671671
# Exact match including signs
672672
check_result = make_strict_eq_complex(expected)
673-
673+
674674
expr = value_str
675675
return check_result, expr
676676
else:
@@ -884,35 +884,34 @@ def parse_unary_case_block(case_block: str, func_name: str, record_list: Optiona
884884
cases = []
885885
# Check if the case block contains complex cases by looking for the marker
886886
in_complex_section = r_complex_marker.search(case_block) is not None
887-
887+
888888
for case_m in r_case.finditer(case_block):
889889
case_str = case_m.group(1)
890-
890+
891891
# Record this special case if a record list is provided
892892
if record_list is not None:
893893
record_list.append(f"{func_name}: {case_str}.")
894-
895-
894+
896895
# Try to parse complex cases if we're in the complex section
897896
if in_complex_section and (m := r_complex_case.search(case_str)):
898897
try:
899898
a_cond_str = m.group(1)
900899
b_cond_str = m.group(2)
901900
result_str = m.group(3)
902-
901+
903902
# Skip cases with complex expressions like "cis(b)"
904903
if "cis" in result_str or "*" in result_str:
905904
warn(f"case for {func_name} not machine-readable: '{case_str}'")
906905
continue
907-
906+
908907
# Parse the complex condition and result
909908
complex_cond, cond_expr, complex_from_dtype = parse_complex_cond(
910909
a_cond_str, b_cond_str
911910
)
912911
_check_result, result_expr = parse_complex_result(result_str)
913-
912+
914913
check_result = make_complex_unary_check_result(_check_result)
915-
914+
916915
case = UnaryCase(
917916
cond_expr=cond_expr,
918917
cond=complex_cond,
@@ -926,7 +925,7 @@ def parse_unary_case_block(case_block: str, func_name: str, record_list: Optiona
926925
except ParseError as e:
927926
warn(f"case for {func_name} not machine-readable: '{e.value}'")
928927
continue
929-
928+
930929
# Parse regular (real-valued) cases
931930
if r_already_int_case.search(case_str):
932931
cases.append(already_int_case)
@@ -1394,11 +1393,11 @@ def parse_binary_case_block(case_block: str, func_name: str, record_list: Option
13941393
cases = []
13951394
for case_m in r_case.finditer(case_block):
13961395
case_str = case_m.group(1)
1397-
1396+
13981397
# Record this special case if a record list is provided
13991398
if record_list is not None:
14001399
record_list.append(f"{func_name}: {case_str}.")
1401-
1400+
14021401
if r_redundant_case.search(case_str):
14031402
continue
14041403
if r_binary_case.match(case_str):
@@ -1528,24 +1527,24 @@ def test_unary(func_name, func, case):
15281527
# drawing multiple examples like a normal test, or just hard-coding a
15291528
# single example test case without using hypothesis.
15301529
filterwarnings('ignore', category=NonInteractiveExampleWarning)
1531-
1530+
15321531
# Use the is_complex flag to determine the appropriate dtype
15331532
if case.is_complex:
15341533
dtype = xp.complex128
15351534
in_value = case.cond_from_dtype(dtype).example()
15361535
else:
15371536
dtype = xp.float64
15381537
in_value = case.cond_from_dtype(dtype).example()
1539-
1538+
15401539
# Create array and compute result based on dtype
15411540
x = xp.asarray(in_value, dtype=dtype)
15421541
out = func(x)
1543-
1542+
15441543
if case.is_complex:
15451544
out_value = complex(out)
15461545
else:
15471546
out_value = float(out)
1548-
1547+
15491548
assert case.check_result(in_value, out_value), (
15501549
f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n"
15511550
)
@@ -1572,7 +1571,6 @@ def test_binary(func_name, func, case, data):
15721571
)
15731572

15741573

1575-
15761574
@pytest.mark.parametrize("iop_name, iop, case", iop_params)
15771575
@settings(max_examples=1)
15781576
@given(data=st.data())

0 commit comments

Comments
 (0)