Skip to content

Commit 5149348

Browse files
committed
moved trace attribute extraction elsewhere
1 parent c4a23d1 commit 5149348

3 files changed

Lines changed: 41 additions & 30 deletions

File tree

openml/runs/functions.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,16 @@ def _run_task_get_arffcontent(model, task, class_labels):
189189

190190
if not isinstance(model, BaseSearchCV):
191191
arff_tracecontent = None
192+
arff_trace_attributes = None
193+
else:
194+
# arff_tracecontent is already set
195+
arff_trace_attributes = _extract_arfftrace_attributes(model_fold)
192196

193-
return arff_datacontent, arff_tracecontent
197+
return arff_datacontent, arff_tracecontent, arff_trace_attributes
194198

195199

196-
def _add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no):
200+
def _extract_arfftrace(model, rep_no, fold_no):
201+
arff_tracecontent = []
197202
for itt_no in range(0, len(model.cv_results_['mean_test_score'])):
198203
# we use the string values for True and False, as it is defined in this way by the OpenML server
199204
selected = 'false'
@@ -205,6 +210,30 @@ def _add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no):
205210
if key.startswith("param_"):
206211
arff_line.append(str(model.cv_results_[key][itt_no]))
207212
arff_tracecontent.append(arff_line)
213+
return arff_tracecontent
214+
215+
def _extract_arfftrace_attributes(model):
216+
# attributes that will be in trace arff, regardless of the model
217+
trace_attributes = [('repeat', 'NUMERIC'),
218+
('fold', 'NUMERIC'),
219+
('iteration', 'NUMERIC'),
220+
('evaluation', 'NUMERIC'),
221+
('selected', ['true', 'false'])]
222+
223+
# model dependent attributes for trace arff
224+
for key in model.cv_results_:
225+
if key.startswith("param_"):
226+
if all(isinstance(i, (bool)) for i in model.cv_results_[key]):
227+
type = ['True', 'False']
228+
elif all(isinstance(i, (int, float)) for i in model.cv_results_[key]):
229+
type = 'NUMERIC'
230+
else:
231+
values = list(set(model.cv_results_[key])) # unique values
232+
type = [str(i) for i in values]
233+
234+
attribute = ("parameter_" + key[6:], type)
235+
trace_attributes.append(attribute)
236+
return trace_attributes
208237

209238

210239
def get_runs(run_ids):

openml/runs/run.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class OpenMLRun(object):
2323
def __init__(self, task_id, flow_id, dataset_id, setup_string=None,
2424
files=None, setup_id=None, tags=None, uploader=None, uploader_name=None,
2525
evaluations=None, detailed_evaluations=None,
26-
data_content=None, trace_content=None, model=None, task_type=None,
27-
task_evaluation_measure=None, flow_name=None,
26+
data_content=None, trace_attributes=None, trace_content=None,
27+
model=None, task_type=None, task_evaluation_measure=None, flow_name=None,
2828
parameter_settings=None, predictions_url=None, task=None,
2929
flow=None, run_id=None):
3030
self.uploader = uploader
@@ -42,6 +42,7 @@ def __init__(self, task_id, flow_id, dataset_id, setup_string=None,
4242
self.evaluations = evaluations
4343
self.detailed_evaluations = detailed_evaluations
4444
self.data_content = data_content
45+
self.trace_attributes = trace_attributes
4546
self.trace_content = trace_content
4647
self.error_message = None
4748
self.task = task
@@ -80,7 +81,7 @@ def _generate_arff_dict(self):
8081
arff_dict['relation'] = 'openml_task_' + str(task.task_id) + '_predictions'
8182
return arff_dict
8283

83-
def _generate_trace_arff_dict(self, model):
84+
def _generate_trace_arff_dict(self):
8485
"""Generates the arff dictionary for uploading predictions to the server.
8586
8687
Assumes that the run has been executed.
@@ -91,32 +92,13 @@ def _generate_trace_arff_dict(self, model):
9192
Dictionary representation of the ARFF file that will be uploaded.
9293
Contains information about the optimization trace.
9394
"""
94-
if self.trace_content is None:
95+
if self.trace_content is None or len(self.trace_content) == 0:
9596
raise ValueError('No trace content avaiable.')
96-
if not isinstance(model, BaseSearchCV):
97-
raise PyOpenMLError('Cannot generate trace on provided classifier. (This should never happen.)')
97+
if len(self.trace_attributes) != len(self.trace_content[0]):
98+
raise ValueError('Trace_attributes and trace_content not compatible')
9899

99100
arff_dict = {}
100-
arff_dict['attributes'] = [('repeat', 'NUMERIC'),
101-
('fold', 'NUMERIC'),
102-
('iteration', 'NUMERIC'),
103-
('evaluation', 'NUMERIC'),
104-
('selected', ['true', 'false'])]
105-
for key in model.cv_results_:
106-
if key.startswith("param_"):
107-
type = 'STRING'
108-
if all(isinstance(i, (bool)) for i in model.cv_results_[key]):
109-
type = ['True', 'False']
110-
elif all(isinstance(i, (int, float)) for i in model.cv_results_[key]):
111-
type = 'NUMERIC'
112-
else:
113-
values = list(set(model.cv_results_[key])) # unique values
114-
type = [str(i) for i in values]
115-
print(key + ": " + str(type))
116-
117-
attribute = ("parameter_" + key[6:], type)
118-
arff_dict['attributes'].append(attribute)
119-
101+
arff_dict['attributes'] = self.trace_attributes
120102
arff_dict['data'] = self.trace_content
121103
arff_dict['relation'] = 'openml_task_' + str(self.task_id) + '_predictions'
122104

@@ -145,7 +127,7 @@ def publish(self):
145127
file_elements['predictions'] = ("predictions.arff", predictions)
146128

147129
if self.trace_content is not None:
148-
trace_arff = arff.dumps(self._generate_trace_arff_dict(self.model))
130+
trace_arff = arff.dumps(self._generate_trace_arff_dict())
149131
file_elements['trace'] = ("trace.arff", trace_arff)
150132

151133
return_value = _perform_api_call("/run/", file_elements=file_elements)

tests/test_runs/test_run_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test__run_task_get_arffcontent(self):
124124
clf, task, class_labels)
125125

126126
clf = SGDClassifier(loss='log', random_state=1)
127-
arff_datacontent, arff_tracecontent = openml.runs.functions._run_task_get_arffcontent(
127+
arff_datacontent, arff_tracecontent, _ = openml.runs.functions._run_task_get_arffcontent(
128128
clf, task, class_labels)
129129
# predictions
130130
self.assertIsInstance(arff_datacontent, list)

0 commit comments

Comments
 (0)