Skip to content

Commit 5d2a412

Browse files
committed
moved trace attribute extraction elsewhere
1 parent ae5999a commit 5d2a412

3 files changed

Lines changed: 43 additions & 32 deletions

File tree

openml/runs/functions.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def run_task(task, model):
7070
run = OpenMLRun(task_id=task.task_id, flow_id=flow_id, dataset_id=dataset.dataset_id, model=model)
7171

7272
try:
73-
run.data_content, run.trace_content = _run_task_get_arffcontent(model, task, class_labels)
73+
run.data_content, run.trace_content, run.trace_attributes = _run_task_get_arffcontent(model, task, class_labels)
7474
except PyOpenMLError as message:
7575
run.error_message = str(message)
7676
warnings.warn("Run terminated with error: %s" %run.error_message)
@@ -159,7 +159,7 @@ def _run_task_get_arffcontent(model, task, class_labels):
159159
model_fold.fit(trainX, trainY)
160160

161161
if isinstance(model_fold, BaseSearchCV):
162-
_add_results_to_arfftrace(arff_tracecontent, fold_no, model_fold, rep_no)
162+
arff_tracecontent.extend(_extract_arfftrace(model_fold, rep_no, fold_no))
163163
model_classes = model_fold.best_estimator_.classes_
164164
else:
165165
model_classes = model_fold.classes_
@@ -181,11 +181,16 @@ def _run_task_get_arffcontent(model, task, class_labels):
181181

182182
if not isinstance(model, BaseSearchCV):
183183
arff_tracecontent = None
184+
arff_trace_attributes = None
185+
else:
186+
# arff_tracecontent is already set
187+
arff_trace_attributes = _extract_arfftrace_attributes(model_fold)
184188

185-
return arff_datacontent, arff_tracecontent
189+
return arff_datacontent, arff_tracecontent, arff_trace_attributes
186190

187191

188-
def _add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no):
192+
def _extract_arfftrace(model, rep_no, fold_no):
193+
arff_tracecontent = []
189194
for itt_no in range(0, len(model.cv_results_['mean_test_score'])):
190195
# we use the string values for True and False, as it is defined in this way by the OpenML server
191196
selected = 'false'
@@ -197,6 +202,30 @@ def _add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no):
197202
if key.startswith("param_"):
198203
arff_line.append(str(model.cv_results_[key][itt_no]))
199204
arff_tracecontent.append(arff_line)
205+
return arff_tracecontent
206+
207+
def _extract_arfftrace_attributes(model):
208+
# attributes that will be in trace arff, regardless of the model
209+
trace_attributes = [('repeat', 'NUMERIC'),
210+
('fold', 'NUMERIC'),
211+
('iteration', 'NUMERIC'),
212+
('evaluation', 'NUMERIC'),
213+
('selected', ['true', 'false'])]
214+
215+
# model dependent attributes for trace arff
216+
for key in model.cv_results_:
217+
if key.startswith("param_"):
218+
if all(isinstance(i, (bool)) for i in model.cv_results_[key]):
219+
type = ['True', 'False']
220+
elif all(isinstance(i, (int, float)) for i in model.cv_results_[key]):
221+
type = 'NUMERIC'
222+
else:
223+
values = list(set(model.cv_results_[key])) # unique values
224+
type = [str(i) for i in values]
225+
226+
attribute = ("parameter_" + key[6:], type)
227+
trace_attributes.append(attribute)
228+
return trace_attributes
200229

201230

202231
def get_runs(run_ids):

openml/runs/run.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class OpenMLRun(object):
2323
def __init__(self, task_id, flow_id, dataset_id, setup_string=None,
2424
files=None, setup_id=None, tags=None, uploader=None, uploader_name=None,
2525
evaluations=None, detailed_evaluations=None,
26-
data_content=None, trace_content=None, model=None, task_type=None,
27-
task_evaluation_measure=None, flow_name=None,
26+
data_content=None, trace_attributes=None, trace_content=None,
27+
model=None, task_type=None, task_evaluation_measure=None, flow_name=None,
2828
parameter_settings=None, predictions_url=None, task=None,
2929
flow=None, run_id=None):
3030
self.uploader = uploader
@@ -42,6 +42,7 @@ def __init__(self, task_id, flow_id, dataset_id, setup_string=None,
4242
self.evaluations = evaluations
4343
self.detailed_evaluations = detailed_evaluations
4444
self.data_content = data_content
45+
self.trace_attributes = trace_attributes
4546
self.trace_content = trace_content
4647
self.error_message = None
4748
self.task = task
@@ -80,7 +81,7 @@ def _generate_arff_dict(self):
8081
arff_dict['relation'] = 'openml_task_' + str(task.task_id) + '_predictions'
8182
return arff_dict
8283

83-
def _generate_trace_arff_dict(self, model):
84+
def _generate_trace_arff_dict(self):
8485
"""Generates the arff dictionary for uploading predictions to the server.
8586
8687
Assumes that the run has been executed.
@@ -91,32 +92,13 @@ def _generate_trace_arff_dict(self, model):
9192
Dictionary representation of the ARFF file that will be uploaded.
9293
Contains information about the optimization trace.
9394
"""
94-
if self.trace_content is None:
95+
if self.trace_content is None or len(self.trace_content) == 0:
9596
raise ValueError('No trace content avaiable.')
96-
if not isinstance(model, BaseSearchCV):
97-
raise PyOpenMLError('Cannot generate trace on provided classifier. (This should never happen.)')
97+
if len(self.trace_attributes) != len(self.trace_content[0]):
98+
raise ValueError('Trace_attributes and trace_content not compatible')
9899

99100
arff_dict = {}
100-
arff_dict['attributes'] = [('repeat', 'NUMERIC'),
101-
('fold', 'NUMERIC'),
102-
('iteration', 'NUMERIC'),
103-
('evaluation', 'NUMERIC'),
104-
('selected', ['true', 'false'])]
105-
for key in model.cv_results_:
106-
if key.startswith("param_"):
107-
type = 'STRING'
108-
if all(isinstance(i, (bool)) for i in model.cv_results_[key]):
109-
type = ['True', 'False']
110-
elif all(isinstance(i, (int, float)) for i in model.cv_results_[key]):
111-
type = 'NUMERIC'
112-
else:
113-
values = list(set(model.cv_results_[key])) # unique values
114-
type = [str(i) for i in values]
115-
print(key + ": " + str(type))
116-
117-
attribute = ("parameter_" + key[6:], type)
118-
arff_dict['attributes'].append(attribute)
119-
101+
arff_dict['attributes'] = self.trace_attributes
120102
arff_dict['data'] = self.trace_content
121103
arff_dict['relation'] = 'openml_task_' + str(self.task_id) + '_predictions'
122104

@@ -145,7 +127,7 @@ def publish(self):
145127
file_elements['predictions'] = ("predictions.arff", predictions)
146128

147129
if self.trace_content is not None:
148-
trace_arff = arff.dumps(self._generate_trace_arff_dict(self.model))
130+
trace_arff = arff.dumps(self._generate_trace_arff_dict())
149131
file_elements['trace'] = ("trace.arff", trace_arff)
150132

151133
return_value = _perform_api_call("/run/", file_elements=file_elements)

tests/test_runs/test_run_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test__run_task_get_arffcontent(self):
117117
clf, task, class_labels)
118118

119119
clf = SGDClassifier(loss='log', random_state=1)
120-
arff_datacontent, arff_tracecontent = openml.runs.functions._run_task_get_arffcontent(
120+
arff_datacontent, arff_tracecontent, _ = openml.runs.functions._run_task_get_arffcontent(
121121
clf, task, class_labels)
122122
# predictions
123123
self.assertIsInstance(arff_datacontent, list)

0 commit comments

Comments
 (0)