Skip to content

Commit 6dac61f

Browse files
authored
Merge pull request #99 from djsaunde/unseriable-exceptions
add MLIRError, UNSERIALIZABLE_EXCEPTIONS tuple
2 parents 53801cc + 62e4b61 commit 6dac61f

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

problems/nvidia/eval_better_bench_grouped_gemm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def _init_worker():
2121

2222
import torch.cuda
2323
from cutlass.cute.nvgpu.common import OpError
24+
from cutlass._mlir.ir import MLIRError
25+
2426
from torch.cuda.nvtx import range as nvtx_range
2527

2628
from utils import set_seed, clear_l2_cache_large as clear_l2_cache
@@ -33,6 +35,7 @@ def _init_worker():
3335
from reference import check_implementation, generate_input
3436

3537
NUM_ITERATIONS_PER_BENCHMARK = 15
38+
UNSERIALIZABLE_EXCEPTIONS = (OpError, MLIRError)
3639

3740

3841
class PopcornOutput:
@@ -181,7 +184,7 @@ def _run_single_test(test: TestCase):
181184
try:
182185
submission_output = custom_kernel(_clone_data(data))
183186

184-
except OpError as E:
187+
except UNSERIALIZABLE_EXCEPTIONS as E:
185188
print(f"Encountered {E}", file=sys.stderr)
186189
return False, str(E)
187190
torch.cuda.synchronize()
@@ -253,7 +256,7 @@ def _run_single_benchmark(
253256
for data in data_list:
254257
output = custom_kernel(_clone_data(data))
255258
outputs.append(output)
256-
except OpError as E:
259+
except UNSERIALIZABLE_EXCEPTIONS as E:
257260
return f"Encountered {E}"
258261
for reference_output, custom_output in zip(check_copy, outputs):
259262
good, message = check_implementation(reference_output, custom_output)

0 commit comments

Comments
 (0)