3131import re
3232
3333from absl import logging
34+ import matplotlib as mpl
3435import matplotlib .pyplot as plt
3536import numpy as np
3637import pandas as pd
38+ from tabulate import tabulate
39+ import re
40+
41+ import logging
3742
3843from algorithmic_efficiency .workloads .workloads import get_base_workload_name
3944import algorithmic_efficiency .workloads .workloads as workloads_registry
6368
6469MAX_EVAL_METRICS = ['mean_average_precision' , 'ssim' , 'accuracy' , 'bleu' ]
6570
71+ #MPL params
72+ mpl .rcParams ['figure.figsize' ] = (16 , 10 ) # Width, height in inches
73+ mpl .rcParams ['font.family' ] = 'serif'
74+ mpl .rcParams ['font.serif' ] = ['Times New Roman' ] + mpl .rcParams ['font.serif' ] # Add Times New Roman as first choice
75+ mpl .rcParams ['font.size' ] = 22
76+ mpl .rcParams ['savefig.dpi' ] = 300 # Set resolution for saved figures
77+
78+ # Plot Elements
79+ mpl .rcParams ['lines.linewidth' ] = 3 # Adjust line thickness if needed
80+ mpl .rcParams ['lines.markersize' ] = 6 # Adjust marker size if needed
81+ mpl .rcParams ['axes.prop_cycle' ] = mpl .cycler (color = ["#1f77b4" , "#ff7f0e" , "#2ca02c" , "#d62728" , "#9467bd" ]) # Example color cycle (consider ColorBrewer or viridis)
82+ mpl .rcParams ['axes.labelsize' ] = 22 # Axis label font size
83+ mpl .rcParams ['xtick.labelsize' ] = 20 # Tick label font size
84+ mpl .rcParams ['ytick.labelsize' ] = 20
85+
86+ # Legends and Gridlines
87+ mpl .rcParams ['legend.fontsize' ] = 20 # Legend font size
88+ mpl .rcParams ['legend.loc' ] = 'best' # Let matplotlib decide the best legend location
89+ mpl .rcParams ['axes.grid' ] = True # Enable grid
90+ mpl .rcParams ['grid.alpha' ] = 0.4 # Gridline transparency
91+
92+ def print_dataframe (df ):
93+ tabulated_df = tabulate (df .T , headers = 'keys' , tablefmt = 'psql' )
94+ logging .info (tabulated_df )
6695
6796def generate_eval_cols (metrics ):
6897 splits = ['train' , 'validation' ]
@@ -177,10 +206,10 @@ def get_workloads_time_to_target(submission,
177206 num_trials = len (group )
178207 if num_trials != NUM_TRIALS and not self_tuning_ruleset :
179208 if strict :
180- raise ValueError (f'Expecting { NUM_TRIALS } trials for workload '
209+ raise ValueError (f'In Study { study } : Expecting { NUM_TRIALS } trials for workload '
181210 f'{ workload } but found { num_trials } trials.' )
182211 else :
183- logging .warning (f'Expecting { NUM_TRIALS } trials for workload '
212+ logging .warning (f'In Study { study } : Expecting { NUM_TRIALS } trials for workload '
184213 f'{ workload } but found { num_trials } trials.' )
185214
186215 # Get trial and time index that reaches target
@@ -194,13 +223,14 @@ def get_workloads_time_to_target(submission,
194223
195224 workloads .append ({
196225 'submission' : submission_name ,
197- 'workload' : workload ,
226+ 'workload' : re . sub ( r'_(jax|pytorch)$' , '' , workload ) ,
198227 time_col : np .median (time_vals_per_study ),
199228 })
200229
201230 df = pd .DataFrame .from_records (workloads )
202231 df = df .pivot (index = 'submission' , columns = 'workload' , values = time_col )
203-
232+ logging .info ("HELLOOOOOOOOO" )
233+ print_dataframe (df )
204234 return df
205235
206236
@@ -269,26 +299,30 @@ def compute_performance_profiles(submissions,
269299 strict ))
270300 df = pd .concat (dfs )
271301
302+ logging .info ("TIME TO TARGET" )
303+ print_dataframe (df )
304+
272305 # Set score to inf if not within 4x of fastest submission
273306 best_scores = df .min (axis = 0 )
274307 df [df .apply (lambda x : x > 4 * best_scores , axis = 1 )] = np .inf
275308
309+ logging .info ("4X of budget" )
310+ print_dataframe (df )
311+
276312 # For each held-out workload if variant target was not hit set submission to inf
277313 framework = None
278314 for workload in df .keys ():
279- # Check if this is a variant
280- framework = workload .split ('_' )[- 1 ]
281- workload_ = workload .split (f'_{ framework } ' )[0 ]
282- if workload_ not in BASE_WORKLOADS :
315+ if workload not in BASE_WORKLOADS :
283316 # If variants do not have finite score set base_workload score to inf
284- base_workload = get_base_workload_name (workload_ )
317+ base_workload = get_base_workload_name (workload )
285318 df [base_workload ] = df .apply (
286- variant_criteria_filter (base_workload + f'_ { framework } ' , workload ),
319+ variant_criteria_filter (base_workload , workload ),
287320 axis = 1 )
321+
322+ logging .info ("HELDOUT_WORKLOAD FILTER" )
323+ print_dataframe (df )
288324
289- base_workloads = [w + f'_{ framework } ' for w in BASE_WORKLOADS ]
290- df = df [base_workloads ]
291- print (df )
325+ df = df [BASE_WORKLOADS ]
292326
293327 if verbosity > 0 :
294328 logging .info ('\n `{time_col}` to reach target:' )
@@ -316,11 +350,17 @@ def compute_performance_profiles(submissions,
316350 1000 ):
317351 logging .info (df )
318352
353+ logging .info ('DIVIDE BY FASTEST' )
354+ print_dataframe (df )
355+
319356 # If no max_tau is supplied, choose the value of tau that would plot all non
320357 # inf or nan data.
321358 if max_tau is None :
322359 max_tau = df .replace (float ('inf' ), - 1 ).replace (np .nan , - 1 ).values .max ()
323360
361+ logging .info ('AFTER MAYBE SETTING MAX TAU' )
362+ print_dataframe (df )
363+
324364 if scale == 'linear' :
325365 points = np .linspace (min_tau , max_tau , num = num_points )
326366 elif scale == 'log' :
@@ -375,8 +415,8 @@ def plot_performance_profiles(perf_df,
375415 df_col ,
376416 scale = 'linear' ,
377417 save_dir = None ,
378- figsize = (30 , 10 ),
379- font_size = 18 ):
418+ figsize = (30 , 10 )
419+ ):
380420 """Plot performance profiles.
381421
382422 Args:
@@ -396,12 +436,13 @@ def plot_performance_profiles(perf_df,
396436 Returns:
397437 None. If a valid save_dir is provided, save both the plot and perf_df.
398438 """
399- fig = perf_df .T .plot (figsize = figsize )
439+ fig = perf_df .T .plot (figsize = figsize , alpha = 0.7 )
400440 df_col_display = f'log10({ df_col } )' if scale == 'log' else df_col
401441 fig .set_xlabel (
402- f'Ratio of `{ df_col_display } ` to best submission' , size = font_size )
403- fig .set_ylabel ('Proportion of workloads' , size = font_size )
404- fig .legend (prop = {'size' : font_size }, bbox_to_anchor = (1.0 , 1.0 ))
442+ f'Ratio of `{ df_col_display } ` to best submission' )
443+ fig .set_ylabel ('Proportion of workloads' )
444+ fig .legend (bbox_to_anchor = (1.0 , 1.0 ))
445+ plt .tight_layout ()
405446 maybe_save_figure (save_dir , f'performance_profile_by_{ df_col_display } ' )
406447 maybe_save_df_to_csv (save_dir ,
407448 perf_df ,
0 commit comments