Skip to content

Commit 5c9d38e

Browse files
author
Mark Saroufim
authored
Add nvfp4_group_gemm to nvidia.yaml + fix eval.py (#93)
* Add nvfp4_group_gemm problem to nvidia.yaml - Deadline: Feb 20, 2026 - Runners: B200 and NVIDIA * Fix eval.py to handle list values in test cases Bypass text serialization and parse YAML directly to properly handle list values for m, n, k in group GEMM test cases.
1 parent 250b004 commit 5c9d38e

1 file changed

Lines changed: 36 additions & 26 deletions

File tree

  • problems/nvidia/nvfp4_group_gemm

problems/nvidia/nvfp4_group_gemm/eval.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,33 @@ def run_profiling(logger: PopcornOutput, tests: list[TestCase]):
355355
return 0
356356

357357

358+
def get_test_cases_from_yaml(yaml_tests: list[dict], seed: Optional[int]) -> list[TestCase]:
359+
"""
360+
Create TestCase objects directly from YAML test definitions.
361+
This bypasses text serialization to properly handle list values.
362+
"""
363+
tests = []
364+
for test in yaml_tests:
365+
# Convert lists to tuples for consistency
366+
args = {}
367+
spec_parts = []
368+
for k, v in test.items():
369+
if isinstance(v, list):
370+
args[k] = tuple(v)
371+
else:
372+
args[k] = v
373+
spec_parts.append(f"{k}: {v}")
374+
spec = "; ".join(spec_parts)
375+
tests.append(TestCase(spec=spec, args=args))
376+
377+
if seed is not None:
378+
for test in tests:
379+
if "seed" in test.args:
380+
test.args["seed"] = _combine(test.args["seed"], seed)
381+
382+
return tests
383+
384+
358385
def main():
359386
fd = os.getenv("POPCORN_FD")
360387
if not fd:
@@ -369,34 +396,17 @@ def main():
369396
seed = int(seed) if seed else None
370397
set_seed(seed or 42)
371398

372-
filename = None
373-
374-
with tempfile.NamedTemporaryFile(delete=False) as tmp:
375-
376-
def build_test_string(tests: list[dict]):
377-
as_str = ""
378-
for test in tests:
379-
kvs = []
380-
for k, v in test.items():
381-
kvs.append(f"{k}: {v}")
382-
as_str += "; ".join(kvs) + "\n"
383-
return as_str
399+
import yaml
384400

385-
import yaml
386-
387-
yaml_content = yaml.safe_load(open(sys.argv[2], "r"))
388-
if mode == "test":
389-
tests_str = build_test_string(yaml_content.get("tests", []))
390-
elif mode in ("benchmark", "leaderboard", "profile"):
391-
tests_str = build_test_string(yaml_content.get("benchmarks", []))
392-
393-
tmp.write(tests_str.encode("utf-8"))
394-
tmp.flush()
395-
filename = tmp.name
396-
397-
tests = get_test_cases(filename, seed)
401+
yaml_content = yaml.safe_load(open(sys.argv[2], "r"))
402+
if mode == "test":
403+
yaml_tests = yaml_content.get("tests", [])
404+
elif mode in ("benchmark", "leaderboard", "profile"):
405+
yaml_tests = yaml_content.get("benchmarks", [])
406+
else:
407+
yaml_tests = []
398408

399-
os.unlink(filename)
409+
tests = get_test_cases_from_yaml(yaml_tests, seed)
400410

401411
with PopcornOutput(int(fd)) as logger:
402412
import multiprocessing

0 commit comments

Comments
 (0)