2626 the dictionary of submissions.
2727"""
2828import itertools
29+ import json
2930import operator
3031import os
3132import re
4546BASE_WORKLOADS = workloads_registry .BASE_WORKLOADS
4647WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
4748BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/'
49+ # Open json file to read heldout workloads
50+ # TODO: This probably shouldn't be hardcoded but passed as an argument.
51+ with open ("held_out_workloads_algoperf_v05.json" , "r" ) as f :
52+ HELDOUT_WORKLOADS = json .load (f )
4853# These global variables have to be set according to the current set of
4954# workloads and rules for the scoring to be correct.
5055# We do not use the workload registry since it contains test and development
@@ -248,6 +253,9 @@ def filter(x):
248253 try :
249254 if x [variant_workload ] == np .inf :
250255 return np .inf
256+ # Also check for nan values (e.g. OOMs)
257+ elif np .isnan (x [variant_workload ]):
258+ return np .inf
251259 else :
252260 return x [base_workload ]
253261 except KeyError as e :
@@ -306,8 +314,14 @@ def compute_performance_profiles(submissions,
306314 self_tuning_ruleset ,
307315 strict ))
308316 df = pd .concat (dfs )
309-
310- # For each held-out workload set to inf if the base workload is inf
317+ # Restrict to base and sampled held-out workloads
318+ # (ignore the additional workload variants of the baseline
319+ # as they cause issues when checking for nans in workload variants).
320+ df = df [BASE_WORKLOADS + HELDOUT_WORKLOADS ]
321+ # Sort workloads alphabetically (for better display)
322+ df = df .reindex (sorted (df .columns ), axis = 1 )
323+
324+ # For each held-out workload set to inf if the base workload is inf or nan
311325 for workload in df .keys ():
312326 if workload not in BASE_WORKLOADS :
313327 # If base do not have finite score set variant score to inf
@@ -319,14 +333,13 @@ def compute_performance_profiles(submissions,
319333 best_scores = df .min (axis = 0 )
320334 df [df .apply (lambda x : x > 4 * best_scores , axis = 1 )] = np .inf
321335
322- # For each held-out workload if variant target was not hit set submission to inf
336+ # For each base workload if variant target was not hit set submission to inf
323337 for workload in df .keys ():
324338 if workload not in BASE_WORKLOADS :
325339 # If variants do not have finite score set base_workload score to inf
326340 base_workload = get_base_workload_name (workload )
327341 df [base_workload ] = df .apply (
328342 variant_criteria_filter (base_workload , workload ), axis = 1 )
329-
330343 df = df [BASE_WORKLOADS ]
331344
332345 if verbosity > 0 :
0 commit comments