@@ -496,7 +496,7 @@ def check_result(result: float) -> bool:
496496def parse_complex_value (value_str : str ) -> complex :
497497 """
498498 Parses a complex value string to return a complex number, e.g.
499-
499+
500500 >>> parse_complex_value('+0 + 0j')
501501 0j
502502 >>> parse_complex_value('NaN + NaN j')
@@ -507,21 +507,21 @@ def parse_complex_value(value_str: str) -> complex:
507507 1.5707963267948966j
508508 >>> parse_complex_value('+infinity + 3πj/4')
509509 (inf+2.356194490192345j)
510-
510+
511511 Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M"
512512 """
513513 m = r_complex_value .match (value_str )
514514 if m is None :
515515 raise ParseError (value_str )
516-
516+
517517 # Parse real part with its sign
518518 # Normalize ± to + (we choose positive arbitrarily since sign is unspecified)
519519 real_sign = m .group (1 ) if m .group (1 ) else "+"
520520 if '±' in real_sign :
521521 real_sign = '+'
522522 real_val_str = m .group (2 )
523523 real_val = parse_value (real_sign + real_val_str )
524-
524+
525525 # Parse imaginary part with its sign
526526 # Normalize ± to + for imaginary part as well
527527 imag_sign = m .group (3 )
@@ -536,9 +536,9 @@ def parse_complex_value(value_str: str) -> complex:
536536 imag_val_str_raw = m .group (5 )
537537 # Strip trailing 'j' if present: "0j" -> "0"
538538 imag_val_str = imag_val_str_raw [:- 1 ] if imag_val_str_raw .endswith ('j' ) else imag_val_str_raw
539-
539+
540540 imag_val = parse_value (imag_sign + imag_val_str )
541-
541+
542542 return complex (real_val , imag_val )
543543
544544
@@ -548,10 +548,10 @@ def make_strict_eq_complex(v: complex) -> Callable[[complex], bool]:
548548 """
549549 real_check = make_strict_eq (v .real )
550550 imag_check = make_strict_eq (v .imag )
551-
551+
552552 def strict_eq_complex (z : complex ) -> bool :
553553 return real_check (z .real ) and imag_check (z .imag )
554-
554+
555555 return strict_eq_complex
556556
557557
@@ -560,7 +560,7 @@ def parse_complex_cond(
560560) -> Tuple [Callable [[complex ], bool ], str , FromDtypeFunc ]:
561561 """
562562 Parses complex condition strings for real (a) and imaginary (b) parts.
563-
563+
564564 Returns:
565565 - cond: Function that checks if a complex number meets the condition
566566 - expr: String expression for the condition
@@ -569,16 +569,16 @@ def parse_complex_cond(
569569 # Parse conditions for real and imaginary parts separately
570570 a_cond , a_expr_template , a_from_dtype = parse_cond (a_cond_str )
571571 b_cond , b_expr_template , b_from_dtype = parse_cond (b_cond_str )
572-
572+
573573 # Create compound condition
574574 def complex_cond (z : complex ) -> bool :
575575 return a_cond (z .real ) and b_cond (z .imag )
576-
576+
577577 # Create expression
578578 a_expr = a_expr_template .replace ("{}" , "real(x_i)" )
579579 b_expr = b_expr_template .replace ("{}" , "imag(x_i)" )
580580 expr = f"{ a_expr } and { b_expr } "
581-
581+
582582 # Create strategy that generates complex numbers
583583 def complex_from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [complex ]:
584584 assert len (kw ) == 0 # sanity check
@@ -589,7 +589,7 @@ def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]:
589589 real_strat = a_from_dtype (float_dtype )
590590 imag_strat = b_from_dtype (float_dtype )
591591 return st .builds (complex , real_strat , imag_strat )
592-
592+
593593 return complex_cond , expr , complex_from_dtype
594594
595595
@@ -609,7 +609,7 @@ def _check_component_with_tolerance(actual: float, expected: float, allow_any_si
609609def parse_complex_result (result_str : str ) -> Tuple [Callable [[complex ], bool ], str ]:
610610 """
611611 Parses a complex result string to return a checker and expression.
612-
612+
613613 Handles cases like:
614614 - "``+0 + 0j``" - exact complex value
615615 - "``0 + NaN j`` (sign of the real component is unspecified)"
@@ -618,7 +618,7 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
618618 # Check for unspecified sign notes (text-based detection)
619619 unspecified_real_sign = "sign of the real component is unspecified" in result_str
620620 unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str
621-
621+
622622 # Extract the complex value from backticks - need to handle spaces in complex values
623623 # Pattern: ``...`` where ... can contain spaces (for complex values like "0 + NaN j")
624624 m = re .search (r"``([^`]+)``" , result_str )
@@ -640,12 +640,12 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
640640
641641 # Check if the value contains π expressions (for approximate comparison)
642642 has_pi = 'π' in value_str
643-
643+
644644 try :
645645 expected = parse_complex_value (value_str )
646646 except ParseError :
647647 raise ParseError (result_str )
648-
648+
649649 # Create checker based on whether signs are unspecified and whether π is involved
650650 if has_pi :
651651 # Use approximate equality for both real and imaginary parts if they involve π
@@ -670,7 +670,7 @@ def check_result(z: complex) -> bool:
670670 else :
671671 # Exact match including signs
672672 check_result = make_strict_eq_complex (expected )
673-
673+
674674 expr = value_str
675675 return check_result , expr
676676 else :
@@ -884,35 +884,34 @@ def parse_unary_case_block(case_block: str, func_name: str, record_list: Optiona
884884 cases = []
885885 # Check if the case block contains complex cases by looking for the marker
886886 in_complex_section = r_complex_marker .search (case_block ) is not None
887-
887+
888888 for case_m in r_case .finditer (case_block ):
889889 case_str = case_m .group (1 )
890-
890+
891891 # Record this special case if a record list is provided
892892 if record_list is not None :
893893 record_list .append (f"{ func_name } : { case_str } ." )
894-
895-
894+
896895 # Try to parse complex cases if we're in the complex section
897896 if in_complex_section and (m := r_complex_case .search (case_str )):
898897 try :
899898 a_cond_str = m .group (1 )
900899 b_cond_str = m .group (2 )
901900 result_str = m .group (3 )
902-
901+
903902 # Skip cases with complex expressions like "cis(b)"
904903 if "cis" in result_str or "*" in result_str :
905904 warn (f"case for { func_name } not machine-readable: '{ case_str } '" )
906905 continue
907-
906+
908907 # Parse the complex condition and result
909908 complex_cond , cond_expr , complex_from_dtype = parse_complex_cond (
910909 a_cond_str , b_cond_str
911910 )
912911 _check_result , result_expr = parse_complex_result (result_str )
913-
912+
914913 check_result = make_complex_unary_check_result (_check_result )
915-
914+
916915 case = UnaryCase (
917916 cond_expr = cond_expr ,
918917 cond = complex_cond ,
@@ -926,7 +925,7 @@ def parse_unary_case_block(case_block: str, func_name: str, record_list: Optiona
926925 except ParseError as e :
927926 warn (f"case for { func_name } not machine-readable: '{ e .value } '" )
928927 continue
929-
928+
930929 # Parse regular (real-valued) cases
931930 if r_already_int_case .search (case_str ):
932931 cases .append (already_int_case )
@@ -1394,11 +1393,11 @@ def parse_binary_case_block(case_block: str, func_name: str, record_list: Option
13941393 cases = []
13951394 for case_m in r_case .finditer (case_block ):
13961395 case_str = case_m .group (1 )
1397-
1396+
13981397 # Record this special case if a record list is provided
13991398 if record_list is not None :
14001399 record_list .append (f"{ func_name } : { case_str } ." )
1401-
1400+
14021401 if r_redundant_case .search (case_str ):
14031402 continue
14041403 if r_binary_case .match (case_str ):
@@ -1528,24 +1527,24 @@ def test_unary(func_name, func, case):
15281527 # drawing multiple examples like a normal test, or just hard-coding a
15291528 # single example test case without using hypothesis.
15301529 filterwarnings ('ignore' , category = NonInteractiveExampleWarning )
1531-
1530+
15321531 # Use the is_complex flag to determine the appropriate dtype
15331532 if case .is_complex :
15341533 dtype = xp .complex128
15351534 in_value = case .cond_from_dtype (dtype ).example ()
15361535 else :
15371536 dtype = xp .float64
15381537 in_value = case .cond_from_dtype (dtype ).example ()
1539-
1538+
15401539 # Create array and compute result based on dtype
15411540 x = xp .asarray (in_value , dtype = dtype )
15421541 out = func (x )
1543-
1542+
15441543 if case .is_complex :
15451544 out_value = complex (out )
15461545 else :
15471546 out_value = float (out )
1548-
1547+
15491548 assert case .check_result (in_value , out_value ), (
15501549 f"out={ out_value } , but should be { case .result_expr } [{ func_name } ()]\n "
15511550 )
@@ -1572,7 +1571,6 @@ def test_binary(func_name, func, case, data):
15721571 )
15731572
15741573
1575-
15761574@pytest .mark .parametrize ("iop_name, iop, case" , iop_params )
15771575@settings (max_examples = 1 )
15781576@given (data = st .data ())
0 commit comments