@@ -21,6 +21,8 @@ def _init_worker():
2121
2222import torch .cuda
2323from cutlass .cute .nvgpu .common import OpError
24+ from cutlass ._mlir .ir import MLIRError
25+
2426from torch .cuda .nvtx import range as nvtx_range
2527
2628from utils import set_seed , clear_l2_cache_large as clear_l2_cache
@@ -33,6 +35,7 @@ def _init_worker():
3335from reference import check_implementation , generate_input
3436
3537NUM_ITERATIONS_PER_BENCHMARK = 15
38+ UNSERIALIZABLE_EXCEPTIONS = (OpError , MLIRError )
3639
3740
3841class PopcornOutput :
@@ -181,7 +184,7 @@ def _run_single_test(test: TestCase):
181184 try :
182185 submission_output = custom_kernel (_clone_data (data ))
183186
184- except OpError as E :
187+ except UNSERIALIZABLE_EXCEPTIONS as E :
185188 print (f"Encountered { E } " , file = sys .stderr )
186189 return False , str (E )
187190 torch .cuda .synchronize ()
@@ -253,7 +256,7 @@ def _run_single_benchmark(
253256 for data in data_list :
254257 output = custom_kernel (_clone_data (data ))
255258 outputs .append (output )
256- except OpError as E :
259+ except UNSERIALIZABLE_EXCEPTIONS as E :
257260 return f"Encountered { E } "
258261 for reference_output , custom_output in zip (check_copy , outputs ):
259262 good , message = check_implementation (reference_output , custom_output )
0 commit comments