2626 the dictionary of submissions.
2727"""
2828import itertools
29+ import logging
2930import operator
3031import os
3132import re
3233
3334from absl import logging
35+ import matplotlib as mpl
3436import matplotlib .pyplot as plt
3537import numpy as np
3638import pandas as pd
39+ from tabulate import tabulate
3740
3841from algorithmic_efficiency .workloads .workloads import get_base_workload_name
3942import algorithmic_efficiency .workloads .workloads as workloads_registry
6366
6467MAX_EVAL_METRICS = ['mean_average_precision' , 'ssim' , 'accuracy' , 'bleu' ]
6568
# Global matplotlib styling for all plots produced by this module.
# Figure geometry, fonts, and output resolution.
mpl.rcParams.update({
    'figure.figsize': (16, 10),  # width, height in inches
    'font.family': 'serif',
    # Prefer Times New Roman, falling back to the default serif list.
    'font.serif': ['Times New Roman'] + mpl.rcParams['font.serif'],
    'font.size': 22,
    'savefig.dpi': 300,  # resolution for saved figures
})

# Plot elements: line/marker sizing, color cycle, and axis label sizes.
mpl.rcParams.update({
    'lines.linewidth': 3,  # adjust line thickness if needed
    'lines.markersize': 6,  # adjust marker size if needed
    # Example color cycle (consider ColorBrewer or viridis).
    'axes.prop_cycle': mpl.cycler(
        color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"]),
    'axes.labelsize': 22,  # axis label font size
    'xtick.labelsize': 20,  # tick label font size
    'ytick.labelsize': 20,
})

# Legends and gridlines.
mpl.rcParams.update({
    'legend.fontsize': 20,  # legend font size
    'legend.loc': 'best',  # let matplotlib pick the legend location
    'axes.grid': True,  # enable grid
    'grid.alpha': 0.4,  # gridline transparency
})
94+
95+
def print_dataframe(df):
  """Log *df* (transposed) as a psql-style ASCII table via absl logging."""
  logging.info(tabulate(df.T, headers='keys', tablefmt='psql'))
99+
66100
67101def generate_eval_cols (metrics ):
68102 splits = ['train' , 'validation' ]
@@ -177,11 +211,13 @@ def get_workloads_time_to_target(submission,
177211 num_trials = len (group )
178212 if num_trials != NUM_TRIALS and not self_tuning_ruleset :
179213 if strict :
180- raise ValueError (f'Expecting { NUM_TRIALS } trials for workload '
181- f'{ workload } but found { num_trials } trials.' )
214+ raise ValueError (
215+ f'In Study { study } : Expecting { NUM_TRIALS } trials for workload '
216+ f'{ workload } but found { num_trials } trials.' )
182217 else :
183- logging .warning (f'Expecting { NUM_TRIALS } trials for workload '
184- f'{ workload } but found { num_trials } trials.' )
218+ logging .warning (
219+ f'In Study { study } : Expecting { NUM_TRIALS } trials for workload '
220+ f'{ workload } but found { num_trials } trials.' )
185221
186222 # Get trial and time index that reaches target
187223 trial_idx , time_idx = get_best_trial_index (
@@ -194,13 +230,12 @@ def get_workloads_time_to_target(submission,
194230
195231 workloads .append ({
196232 'submission' : submission_name ,
197- 'workload' : workload ,
233+ 'workload' : re . sub ( r'_(jax|pytorch)$' , '' , workload ) ,
198234 time_col : np .median (time_vals_per_study ),
199235 })
200236
201237 df = pd .DataFrame .from_records (workloads )
202238 df = df .pivot (index = 'submission' , columns = 'workload' , values = time_col )
203-
204239 return df
205240
206241
@@ -276,19 +311,13 @@ def compute_performance_profiles(submissions,
276311 # For each held-out workload if variant target was not hit set submission to inf
277312 framework = None
278313 for workload in df .keys ():
279- # Check if this is a variant
280- framework = workload .split ('_' )[- 1 ]
281- workload_ = workload .split (f'_{ framework } ' )[0 ]
282- if workload_ not in BASE_WORKLOADS :
314+ if workload not in BASE_WORKLOADS :
283315 # If variants do not have finite score set base_workload score to inf
284- base_workload = get_base_workload_name (workload_ )
316+ base_workload = get_base_workload_name (workload )
285317 df [base_workload ] = df .apply (
286- variant_criteria_filter (base_workload + f'_{ framework } ' , workload ),
287- axis = 1 )
318+ variant_criteria_filter (base_workload , workload ), axis = 1 )
288319
289- base_workloads = [w + f'_{ framework } ' for w in BASE_WORKLOADS ]
290- df = df [base_workloads ]
291- print (df )
320+ df = df [BASE_WORKLOADS ]
292321
293322 if verbosity > 0 :
294323 logging .info ('\n `{time_col}` to reach target:' )
@@ -375,8 +404,7 @@ def plot_performance_profiles(perf_df,
375404 df_col ,
376405 scale = 'linear' ,
377406 save_dir = None ,
378- figsize = (30 , 10 ),
379- font_size = 18 ):
407+ figsize = (30 , 10 )):
380408 """Plot performance profiles.
381409
382410 Args:
@@ -396,12 +424,12 @@ def plot_performance_profiles(perf_df,
396424 Returns:
397425 None. If a valid save_dir is provided, save both the plot and perf_df.
398426 """
399- fig = perf_df .T .plot (figsize = figsize )
427+ fig = perf_df .T .plot (figsize = figsize , alpha = 0.7 )
400428 df_col_display = f'log10({ df_col } )' if scale == 'log' else df_col
401- fig .set_xlabel (
402- f'Ratio of ` { df_col_display } ` to best submission' , size = font_size )
403- fig .set_ylabel ( 'Proportion of workloads' , size = font_size )
404- fig . legend ( prop = { 'size' : font_size }, bbox_to_anchor = ( 1.0 , 1.0 ) )
429+ fig .set_xlabel (f'Ratio of ` { df_col_display } ` to best submission' )
430+ fig . set_ylabel ( 'Proportion of workloads' )
431+ fig .legend ( bbox_to_anchor = ( 1.0 , 1.0 ) )
432+ plt . tight_layout ( )
405433 maybe_save_figure (save_dir , f'performance_profile_by_{ df_col_display } ' )
406434 maybe_save_df_to_csv (save_dir ,
407435 perf_df ,
0 commit comments