2626 the dictionary of submissions.
2727"""
2828import itertools
29+ import json
2930import operator
3031import os
3132import re
3233
3334from absl import logging
35+ import matplotlib as mpl
3436import matplotlib .pyplot as plt
3537import numpy as np
3638import pandas as pd
39+ from tabulate import tabulate
3740
3841from algorithmic_efficiency .workloads .workloads import get_base_workload_name
3942import algorithmic_efficiency .workloads .workloads as workloads_registry
4346BASE_WORKLOADS = workloads_registry .BASE_WORKLOADS
4447WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
4548BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/'
# Load the sampled held-out workloads for this scoring round.
# The path can be overridden via the HELDOUT_WORKLOADS_PATH environment
# variable; the default preserves the previous hardcoded behavior.
# TODO: This should be passed as an argument rather than read at import time.
_HELDOUT_WORKLOADS_PATH = os.environ.get(
    'HELDOUT_WORKLOADS_PATH', 'held_out_workloads_algoperf_v05.json')
with open(_HELDOUT_WORKLOADS_PATH, 'r') as f:
  HELDOUT_WORKLOADS = json.load(f)
4653# These global variables have to be set according to the current set of
4754# workloads and rules for the scoring to be correct.
4855# We do not use the workload registry since it contains test and development
6370
6471MAX_EVAL_METRICS = ['mean_average_precision' , 'ssim' , 'accuracy' , 'bleu' ]
6572
# MPL params: global matplotlib defaults for all scoring plots.
mpl.rcParams.update({
    # Figure
    'figure.figsize': (16, 10),  # width, height in inches
    'font.family': 'serif',
    'font.size': 22,
    'savefig.dpi': 300,  # resolution for saved figures
    # Plot elements
    'lines.linewidth': 3,  # adjust line thickness if needed
    'lines.markersize': 6,  # adjust marker size if needed
    'axes.labelsize': 22,  # axis label font size
    'xtick.labelsize': 20,  # tick label font size
    'ytick.labelsize': 20,
    # Legends and gridlines
    'legend.fontsize': 20,  # legend font size
    'legend.loc': 'best',  # let matplotlib decide the best legend location
    'axes.grid': True,  # enable grid
    'grid.alpha': 0.4,  # gridline transparency
})
# Prepend Times New Roman as first-choice serif font (reads current value,
# so it must stay a separate statement).
mpl.rcParams['font.serif'] = ['Times New Roman'] + mpl.rcParams['font.serif']
# Example color cycle (consider ColorBrewer or viridis).
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(
    color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"])
98+
99+
def print_dataframe(df):
  """Log `df` transposed as a psql-style table via absl logging."""
  logging.info(tabulate(df.T, headers='keys', tablefmt='psql'))
103+
66104
67105def generate_eval_cols (metrics ):
68106 splits = ['train' , 'validation' ]
@@ -150,10 +188,10 @@ def get_workloads_time_to_target(submission,
150188 if strict :
151189 raise ValueError (
152190 f'Expecting { NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS } workloads '
153- f'but found { num_workloads } workloads.' )
191+ f'but found { num_workloads } workloads for { submission_name } .' )
154192 logging .warning (
155193 f'Expecting { NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS } workloads '
156- f'but found { num_workloads } workloads.' )
194+ f'but found { num_workloads } workloads for { submission_name } .' )
157195
158196 # For each workload get submission time get the submission times to target.
159197 for workload , group in submission .groupby ('workload' ):
@@ -164,11 +202,13 @@ def get_workloads_time_to_target(submission,
164202 num_studies = len (group .groupby ('study' ))
165203 if num_studies != NUM_STUDIES :
166204 if strict :
167- raise ValueError (f'Expecting { NUM_STUDIES } trials for workload '
168- f'{ workload } but found { num_studies } trials.' )
205+ raise ValueError (f'Expecting { NUM_STUDIES } studies for workload '
206+ f'{ workload } but found { num_studies } studies '
207+ f'for { submission_name } .' )
169208 else :
170- logging .warning (f'Expecting { NUM_STUDIES } trials for workload '
171- f'{ workload } but found { num_studies } trials.' )
209+ logging .warning (f'Expecting { NUM_STUDIES } studies for workload '
210+ f'{ workload } but found { num_studies } studies '
211+ f'for { submission_name } .' )
172212
173213 # For each study check trials
174214 for study , group in group .groupby ('study' ):
@@ -177,11 +217,15 @@ def get_workloads_time_to_target(submission,
177217 num_trials = len (group )
178218 if num_trials != NUM_TRIALS and not self_tuning_ruleset :
179219 if strict :
180- raise ValueError (f'Expecting { NUM_TRIALS } trials for workload '
181- f'{ workload } but found { num_trials } trials.' )
220+ raise ValueError (
221+ f'In Study { study } : Expecting { NUM_TRIALS } trials for workload '
222+ f'{ workload } but found { num_trials } trials '
223+ f'for { submission_name } .' )
182224 else :
183- logging .warning (f'Expecting { NUM_TRIALS } trials for workload '
184- f'{ workload } but found { num_trials } trials.' )
225+ logging .warning (
226+ f'In Study { study } : Expecting { NUM_TRIALS } trials for workload '
227+ f'{ workload } but found { num_trials } trials '
228+ f'for { submission_name } .' )
185229
186230 # Get trial and time index that reaches target
187231 trial_idx , time_idx = get_best_trial_index (
@@ -194,13 +238,12 @@ def get_workloads_time_to_target(submission,
194238
195239 workloads .append ({
196240 'submission' : submission_name ,
197- 'workload' : workload ,
241+ 'workload' : re . sub ( r'_(jax|pytorch)$' , '' , workload ) ,
198242 time_col : np .median (time_vals_per_study ),
199243 })
200244
201245 df = pd .DataFrame .from_records (workloads )
202246 df = df .pivot (index = 'submission' , columns = 'workload' , values = time_col )
203-
204247 return df
205248
206249
@@ -210,6 +253,9 @@ def filter(x):
210253 try :
211254 if x [variant_workload ] == np .inf :
212255 return np .inf
256+ # Also check for nan values (e.g. OOMs)
257+ elif np .isnan (x [variant_workload ]):
258+ return np .inf
213259 else :
214260 return x [base_workload ]
215261 except KeyError as e :
@@ -268,27 +314,33 @@ def compute_performance_profiles(submissions,
268314 self_tuning_ruleset ,
269315 strict ))
270316 df = pd .concat (dfs )
317+ # Restrict to base and sampled held-out workloads
318+ # (ignore the additional workload variants of the baseline
319+ # as they cause issues when checking for nans in workload variants).
320+ df = df [BASE_WORKLOADS + HELDOUT_WORKLOADS ]
321+ # Sort workloads alphabetically (for better display)
322+ df = df .reindex (sorted (df .columns ), axis = 1 )
323+
324+ # For each held-out workload set to inf if the base workload is inf or nan
325+ for workload in df .keys ():
326+ if workload not in BASE_WORKLOADS :
327+ # If base do not have finite score set variant score to inf
328+ base_workload = get_base_workload_name (workload )
329+ df [workload ] = df .apply (
330+ variant_criteria_filter (workload , base_workload ), axis = 1 )
271331
272332 # Set score to inf if not within 4x of fastest submission
273333 best_scores = df .min (axis = 0 )
274334 df [df .apply (lambda x : x > 4 * best_scores , axis = 1 )] = np .inf
275335
276- # For each held-out workload if variant target was not hit set submission to inf
277- framework = None
336+ # For each base workload if variant target was not hit set submission to inf
278337 for workload in df .keys ():
279- # Check if this is a variant
280- framework = workload .split ('_' )[- 1 ]
281- workload_ = workload .split (f'_{ framework } ' )[0 ]
282- if workload_ not in BASE_WORKLOADS :
338+ if workload not in BASE_WORKLOADS :
283339 # If variants do not have finite score set base_workload score to inf
284- base_workload = get_base_workload_name (workload_ )
340+ base_workload = get_base_workload_name (workload )
285341 df [base_workload ] = df .apply (
286- variant_criteria_filter (base_workload + f'_{ framework } ' , workload ),
287- axis = 1 )
288-
289- base_workloads = [w + f'_{ framework } ' for w in BASE_WORKLOADS ]
290- df = df [base_workloads ]
291- print (df )
342+ variant_criteria_filter (base_workload , workload ), axis = 1 )
343+ df = df [BASE_WORKLOADS ]
292344
293345 if verbosity > 0 :
294346 logging .info ('\n `{time_col}` to reach target:' )
@@ -375,8 +427,7 @@ def plot_performance_profiles(perf_df,
375427 df_col ,
376428 scale = 'linear' ,
377429 save_dir = None ,
378- figsize = (30 , 10 ),
379- font_size = 18 ):
430+ figsize = (30 , 10 )):
380431 """Plot performance profiles.
381432
382433 Args:
@@ -396,12 +447,12 @@ def plot_performance_profiles(perf_df,
396447 Returns:
397448 None. If a valid save_dir is provided, save both the plot and perf_df.
398449 """
399- fig = perf_df .T .plot (figsize = figsize )
450+ fig = perf_df .T .plot (figsize = figsize , alpha = 0.7 )
400451 df_col_display = f'log10({ df_col } )' if scale == 'log' else df_col
401- fig .set_xlabel (
402- f'Ratio of ` { df_col_display } ` to best submission' , size = font_size )
403- fig .set_ylabel ( 'Proportion of workloads' , size = font_size )
404- fig . legend ( prop = { 'size' : font_size }, bbox_to_anchor = ( 1.0 , 1.0 ) )
452+ fig .set_xlabel (f'Ratio of ` { df_col_display } ` to best submission' )
453+ fig . set_ylabel ( 'Proportion of workloads' )
454+ fig .legend ( bbox_to_anchor = ( 1.0 , 1.0 ) )
455+ plt . tight_layout ( )
405456 maybe_save_figure (save_dir , f'performance_profile_by_{ df_col_display } ' )
406457 maybe_save_df_to_csv (save_dir ,
407458 perf_df ,
0 commit comments