
Commit a3e513e

cosmetic and functional fixes
1 parent dc6f189 commit a3e513e

2 files changed: 89 additions & 30 deletions

scoring/performance_profile.py

Lines changed: 60 additions & 19 deletions
@@ -31,9 +31,14 @@
 import re
 
 from absl import logging
+import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+from tabulate import tabulate
+import re
+
+import logging
 
 from algorithmic_efficiency.workloads.workloads import get_base_workload_name
 import algorithmic_efficiency.workloads.workloads as workloads_registry
@@ -63,6 +68,30 @@
 
 MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']
 
+# MPL params
+mpl.rcParams['figure.figsize'] = (16, 10)  # Width, height in inches
+mpl.rcParams['font.family'] = 'serif'
+mpl.rcParams['font.serif'] = ['Times New Roman'] + mpl.rcParams['font.serif']  # Add Times New Roman as first choice
+mpl.rcParams['font.size'] = 22
+mpl.rcParams['savefig.dpi'] = 300  # Set resolution for saved figures
+
+# Plot Elements
+mpl.rcParams['lines.linewidth'] = 3  # Adjust line thickness if needed
+mpl.rcParams['lines.markersize'] = 6  # Adjust marker size if needed
+mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"])  # Example color cycle (consider ColorBrewer or viridis)
+mpl.rcParams['axes.labelsize'] = 22  # Axis label font size
+mpl.rcParams['xtick.labelsize'] = 20  # Tick label font size
+mpl.rcParams['ytick.labelsize'] = 20
+
+# Legends and Gridlines
+mpl.rcParams['legend.fontsize'] = 20  # Legend font size
+mpl.rcParams['legend.loc'] = 'best'  # Let matplotlib decide the best legend location
+mpl.rcParams['axes.grid'] = True  # Enable grid
+mpl.rcParams['grid.alpha'] = 0.4  # Gridline transparency
+
+def print_dataframe(df):
+  tabulated_df = tabulate(df.T, headers='keys', tablefmt='psql')
+  logging.info(tabulated_df)
 
 def generate_eval_cols(metrics):
   splits = ['train', 'validation']
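
The new print_dataframe helper transposes the frame and renders it with tabulate before handing it to the absl logger. A minimal standalone sketch of the same pattern (the submission and workload names below are invented for illustration):

    import pandas as pd
    from tabulate import tabulate

    # Rows are submissions, columns are workloads; values are hypothetical runtimes.
    df = pd.DataFrame({'ogbg': [1200.0, 1500.0], 'wmt': [3400.0, 2900.0]},
                      index=['submission_a', 'submission_b'])

    # Transpose so workloads become rows, then render a psql-style ASCII table.
    print(tabulate(df.T, headers='keys', tablefmt='psql'))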
@@ -177,10 +206,10 @@ def get_workloads_time_to_target(submission,
     num_trials = len(group)
     if num_trials != NUM_TRIALS and not self_tuning_ruleset:
       if strict:
-        raise ValueError(f'Expecting {NUM_TRIALS} trials for workload '
+        raise ValueError(f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
                          f'{workload} but found {num_trials} trials.')
       else:
-        logging.warning(f'Expecting {NUM_TRIALS} trials for workload '
+        logging.warning(f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
                         f'{workload} but found {num_trials} trials.')
 
     # Get trial and time index that reaches target
@@ -194,13 +223,14 @@ def get_workloads_time_to_target(submission,
 
     workloads.append({
         'submission': submission_name,
-        'workload': workload,
+        'workload': re.sub(r'_(jax|pytorch)$', '', workload),
         time_col: np.median(time_vals_per_study),
     })
 
   df = pd.DataFrame.from_records(workloads)
   df = df.pivot(index='submission', columns='workload', values=time_col)
-
+  logging.info('TIME TO TARGET PER WORKLOAD')
+  print_dataframe(df)
   return df
 
 
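The 'workload' entry is now normalized with re.sub so that jax and pytorch runs of the same workload collapse into a single column after the pivot. A quick illustration of that regex on hypothetical workload names:

    import re

    for name in ['wmt_jax', 'ogbg_pytorch', 'criteo1tb']:
      print(re.sub(r'_(jax|pytorch)$', '', name))
    # -> wmt, ogbg, criteo1tb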

@@ -269,26 +299,30 @@ def compute_performance_profiles(submissions,
             strict))
   df = pd.concat(dfs)
 
+  logging.info("TIME TO TARGET")
+  print_dataframe(df)
+
   # Set score to inf if not within 4x of fastest submission
   best_scores = df.min(axis=0)
   df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf
 
+  logging.info("4X of budget")
+  print_dataframe(df)
+
   # For each held-out workload if variant target was not hit set submission to inf
   framework = None
   for workload in df.keys():
-    # Check if this is a variant
-    framework = workload.split('_')[-1]
-    workload_ = workload.split(f'_{framework}')[0]
-    if workload_ not in BASE_WORKLOADS:
+    if workload not in BASE_WORKLOADS:
       # If variants do not have finite score set base_workload score to inf
-      base_workload = get_base_workload_name(workload_)
+      base_workload = get_base_workload_name(workload)
       df[base_workload] = df.apply(
-          variant_criteria_filter(base_workload + f'_{framework}', workload),
+          variant_criteria_filter(base_workload, workload),
           axis=1)
+
+  logging.info("HELDOUT_WORKLOAD FILTER")
+  print_dataframe(df)
 
-  base_workloads = [w + f'_{framework}' for w in BASE_WORKLOADS]
-  df = df[base_workloads]
-  print(df)
+  df = df[BASE_WORKLOADS]
 
   if verbosity > 0:
     logging.info('\n`{time_col}` to reach target:')
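
For context, the 4x filter above takes the column-wise minimum (the fastest submission per workload) and masks any entry slower than four times that value. A toy sketch of the masking with invented numbers:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'wmt': [100.0, 450.0], 'ogbg': [200.0, 300.0]},
                      index=['submission_a', 'submission_b'])
    best_scores = df.min(axis=0)  # fastest time per workload
    df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf
    print(df)  # submission_b on wmt becomes inf (450 > 4 * 100); ogbg survives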
@@ -316,11 +350,17 @@ def compute_performance_profiles(submissions,
                              1000):
       logging.info(df)
 
+  logging.info('DIVIDE BY FASTEST')
+  print_dataframe(df)
+
   # If no max_tau is supplied, choose the value of tau that would plot all non
   # inf or nan data.
   if max_tau is None:
     max_tau = df.replace(float('inf'), -1).replace(np.nan, -1).values.max()
 
+  logging.info('AFTER MAYBE SETTING MAX TAU')
+  print_dataframe(df)
+
   if scale == 'linear':
     points = np.linspace(min_tau, max_tau, num=num_points)
   elif scale == 'log':
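
The tau grid built here feeds the usual performance-profile curve: for each tau, the proportion of workloads on which a submission is within a factor tau of the fastest. A compact sketch of that last step, assuming per-workload time ratios have already been computed (the values are made up):

    import numpy as np

    min_tau, max_tau, num_points = 1.0, 4.0, 100
    points = np.linspace(min_tau, max_tau, num=num_points)  # scale == 'linear'

    ratios = np.array([1.0, 1.5, np.inf])  # hypothetical ratios to the best submission
    profile = [(ratios <= tau).mean() for tau in points]  # fraction solved within tau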
@@ -375,8 +415,8 @@ def plot_performance_profiles(perf_df,
                               df_col,
                               scale='linear',
                               save_dir=None,
-                              figsize=(30, 10),
-                              font_size=18):
+                              figsize=(30, 10)
+                              ):
   """Plot performance profiles.
 
   Args:
@@ -396,12 +436,13 @@ def plot_performance_profiles(perf_df,
   Returns:
     None. If a valid save_dir is provided, save both the plot and perf_df.
   """
-  fig = perf_df.T.plot(figsize=figsize)
+  fig = perf_df.T.plot(figsize=figsize, alpha=0.7)
   df_col_display = f'log10({df_col})' if scale == 'log' else df_col
   fig.set_xlabel(
-      f'Ratio of `{df_col_display}` to best submission', size=font_size)
-  fig.set_ylabel('Proportion of workloads', size=font_size)
-  fig.legend(prop={'size': font_size}, bbox_to_anchor=(1.0, 1.0))
+      f'Ratio of `{df_col_display}` to best submission')
+  fig.set_ylabel('Proportion of workloads')
+  fig.legend(bbox_to_anchor=(1.0, 1.0))
+  plt.tight_layout()
   maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}')
   maybe_save_df_to_csv(save_dir,
                        perf_df,
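
With font_size dropped from the signature, text sizing now comes entirely from the module-level rcParams set at the top of the file. A minimal sketch of the same plotting flow on a made-up single-submission profile:

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    taus = np.linspace(1, 4, 50)
    perf_df = pd.DataFrame([np.clip((taus - 1) / 3, 0, 1)],
                           index=['submission_a'], columns=taus)

    ax = perf_df.T.plot(alpha=0.7)  # font sizes and grid come from mpl.rcParams
    ax.set_xlabel('Ratio of `time` to best submission')
    ax.set_ylabel('Proportion of workloads')
    ax.legend(bbox_to_anchor=(1.0, 1.0))
    plt.tight_layout()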

scoring/score_submissions.py

Lines changed: 29 additions & 11 deletions
@@ -22,8 +22,10 @@
 import pandas as pd
 import scoring_utils
 from tabulate import tabulate
+import json
+import pickle
 
-from scoring import performance_profile
+import performance_profile
 
 flags.DEFINE_string(
     'submission_directory',
@@ -101,8 +103,13 @@ def get_summary_df(workload, workload_df, include_test_split=False):
   return summary_df
 
 
-def print_submission_summary(df, include_test_split=True):
+def get_submission_summary(df, include_test_split=True):
+  """Summarizes the submission results into metric and time tables
+  organized by workload.
+  """
+
   dfs = []
+  print(df)
   for workload, group in df.groupby('workload'):
     summary_df = get_summary_df(
         workload, group, include_test_split=include_test_split)
115122

116123
def main(_):
117124
results = {}
118-
119-
for submission in os.listdir(FLAGS.submission_directory):
120-
experiment_path = os.path.join(FLAGS.submission_directory, submission)
121-
df = scoring_utils.get_experiment_df(experiment_path)
122-
results[submission] = df
123-
summary_df = print_submission_summary(df)
124-
with open(os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
125-
'w') as fout:
126-
summary_df.to_csv(fout)
125+
os.makedirs(FLAGS.output_dir, exist_ok=True)
126+
127+
# for team in os.listdir(FLAGS.submission_directory):
128+
# for submission in os.listdir(os.path.join(FLAGS.submission_directory, team)):
129+
# print(submission)
130+
# experiment_path = os.path.join(FLAGS.submission_directory, team, submission)
131+
# df = scoring_utils.get_experiment_df(experiment_path)
132+
# results[submission] = df
133+
# summary_df = get_submission_summary(df)
134+
# with open(os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
135+
# 'w') as fout:
136+
# summary_df.to_csv(fout)
137+
138+
# # Save results
139+
# with open(os.path.join(FLAGS.output_dir, 'results.pkl'), 'wb') as f:
140+
# pickle.dump(results, f)
141+
142+
# Read results
143+
with open(os.path.join(FLAGS.output_dir, 'results.pkl'), 'rb') as f:
144+
results = pickle.load(f)
127145

128146
if not FLAGS.strict:
129147
logging.warning(
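
The commented-out directory walk plus the unconditional pickle.load make results.pkl a manual cache: as committed, the script assumes the file already exists under output_dir. A hedged sketch of the compute-or-load pattern this appears to be heading toward (load_or_compute_results and compute_fn are hypothetical stand-ins, not part of the diff):

    import os
    import pickle

    def load_or_compute_results(output_dir, compute_fn):
      """Reuse a cached results.pkl when present; otherwise compute and cache it."""
      results_path = os.path.join(output_dir, 'results.pkl')
      if os.path.exists(results_path):
        with open(results_path, 'rb') as f:
          return pickle.load(f)
      results = compute_fn()  # e.g. walk the submission directory and build DataFrames
      with open(results_path, 'wb') as f:
        pickle.dump(results, f)
      return results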
