Skip to content

Commit 0942e96

Browse files
author
Mark Saroufim
authored
Fix eval.py to properly parse list values in test cases (#94)
* Fix eval.py to properly parse list values in test cases
  - Updated regex to use [^\]]* instead of [^\]]+ to handle edge cases
  - Added underscores to key pattern [a-zA-Z_]+
  - Skip empty lines and empty parts when parsing
  - Use re.fullmatch directly instead of both re.match and re.fullmatch
  - Handle empty tuples/lists in value parsing

* Fix eval.py to use text parsing instead of YAML

  Kernelbot passes a text file with format like:
      m: [96, 128]; n: [128, 256]; k: [128, 512]; g: 2; seed: 1111
  Use get_test_cases() to parse this text format directly.
  Remove unused get_test_cases_from_yaml function.
1 parent 3f23047 commit 0942e96

1 file changed

Lines changed: 19 additions & 42 deletions

File tree

  • problems/nvidia/nvfp4_group_gemm

problems/nvidia/nvfp4_group_gemm/eval.py

Lines changed: 19 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -67,13 +67,22 @@ def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]:
6767

6868
tests = []
6969
lines = content.splitlines()
70-
match = r"\s*([a-zA-Z]+):\s*(\([^)]+\)|\[[^\]]+\]|[a-zA-Z]+|[+-]?[0-9]+)\s*"
70+
# Match key: value pairs where value can be:
71+
# - a list like [1, 2, 3]
72+
# - a tuple like (1, 2, 3)
73+
# - an integer
74+
# - an alphabetic string
75+
match = r"\s*([a-zA-Z_]+)\s*:\s*(\[[^\]]*\]|\([^)]*\)|[a-zA-Z_]+|[+-]?[0-9]+)\s*"
7176
for line in lines:
77+
if not line.strip():
78+
continue
7279
parts = line.split(";")
7380
case = {}
7481
for part in parts:
75-
matched = re.match(match, part)
76-
if not re.fullmatch(match, part):
82+
if not part.strip():
83+
continue
84+
matched = re.fullmatch(match, part)
85+
if not matched:
7786
print(f"invalid test case: '{line}': '{part}'", file=sys.stderr)
7887
exit(113)
7988
key = matched[1]
@@ -84,7 +93,11 @@ def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]:
8493
# Try parsing as tuple/list
8594
if (val.startswith('(') and val.endswith(')')) or (val.startswith('[') and val.endswith(']')):
8695
try:
87-
val = tuple(int(x.strip()) for x in val[1:-1].split(','))
96+
inner = val[1:-1].strip()
97+
if inner:
98+
val = tuple(int(x.strip()) for x in inner.split(','))
99+
else:
100+
val = tuple()
88101
except ValueError:
89102
pass
90103

@@ -355,33 +368,6 @@ def run_profiling(logger: PopcornOutput, tests: list[TestCase]):
355368
return 0
356369

357370

358-
def get_test_cases_from_yaml(yaml_tests: list[dict], seed: Optional[int]) -> list[TestCase]:
359-
"""
360-
Create TestCase objects directly from YAML test definitions.
361-
This bypasses text serialization to properly handle list values.
362-
"""
363-
tests = []
364-
for test in yaml_tests:
365-
# Convert lists to tuples for consistency
366-
args = {}
367-
spec_parts = []
368-
for k, v in test.items():
369-
if isinstance(v, list):
370-
args[k] = tuple(v)
371-
else:
372-
args[k] = v
373-
spec_parts.append(f"{k}: {v}")
374-
spec = "; ".join(spec_parts)
375-
tests.append(TestCase(spec=spec, args=args))
376-
377-
if seed is not None:
378-
for test in tests:
379-
if "seed" in test.args:
380-
test.args["seed"] = _combine(test.args["seed"], seed)
381-
382-
return tests
383-
384-
385371
def main():
386372
fd = os.getenv("POPCORN_FD")
387373
if not fd:
@@ -396,17 +382,8 @@ def main():
396382
seed = int(seed) if seed else None
397383
set_seed(seed or 42)
398384

399-
import yaml
400-
401-
yaml_content = yaml.safe_load(open(sys.argv[2], "r"))
402-
if mode == "test":
403-
yaml_tests = yaml_content.get("tests", [])
404-
elif mode in ("benchmark", "leaderboard", "profile"):
405-
yaml_tests = yaml_content.get("benchmarks", [])
406-
else:
407-
yaml_tests = []
408-
409-
tests = get_test_cases_from_yaml(yaml_tests, seed)
385+
# Parse test cases from temp file (text format from kernelbot)
386+
tests = get_test_cases(sys.argv[2], seed)
410387

411388
with PopcornOutput(int(fd)) as logger:
412389
import multiprocessing

0 commit comments

Comments
 (0)