@@ -23,8 +23,8 @@ class OpenMLRun(object):
2323 def __init__ (self , task_id , flow_id , dataset_id , setup_string = None ,
2424 files = None , setup_id = None , tags = None , uploader = None , uploader_name = None ,
2525 evaluations = None , detailed_evaluations = None ,
26- data_content = None , trace_content = None , model = None , task_type = None ,
27- task_evaluation_measure = None , flow_name = None ,
26+ data_content = None , trace_attributes = None , trace_content = None ,
27+ model = None , task_type = None , task_evaluation_measure = None , flow_name = None ,
2828 parameter_settings = None , predictions_url = None , task = None ,
2929 flow = None , run_id = None ):
3030 self .uploader = uploader
@@ -42,6 +42,7 @@ def __init__(self, task_id, flow_id, dataset_id, setup_string=None,
4242 self .evaluations = evaluations
4343 self .detailed_evaluations = detailed_evaluations
4444 self .data_content = data_content
45+ self .trace_attributes = trace_attributes
4546 self .trace_content = trace_content
4647 self .error_message = None
4748 self .task = task
@@ -80,7 +81,7 @@ def _generate_arff_dict(self):
8081 arff_dict ['relation' ] = 'openml_task_' + str (task .task_id ) + '_predictions'
8182 return arff_dict
8283
83- def _generate_trace_arff_dict (self , model ):
84+ def _generate_trace_arff_dict (self ):
8485 """Generates the arff dictionary for uploading predictions to the server.
8586
8687 Assumes that the run has been executed.
@@ -91,32 +92,13 @@ def _generate_trace_arff_dict(self, model):
9192 Dictionary representation of the ARFF file that will be uploaded.
9293 Contains information about the optimization trace.
9394 """
94- if self .trace_content is None :
95+ if self .trace_content is None or len ( self . trace_content ) == 0 :
9596 raise ValueError ('No trace content avaiable.' )
96- if not isinstance ( model , BaseSearchCV ):
97- raise PyOpenMLError ( 'Cannot generate trace on provided classifier. (This should never happen.) ' )
97+ if len ( self . trace_attributes ) != len ( self . trace_content [ 0 ] ):
98+ raise ValueError ( 'Trace_attributes and trace_content not compatible ' )
9899
99100 arff_dict = {}
100- arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
101- ('fold' , 'NUMERIC' ),
102- ('iteration' , 'NUMERIC' ),
103- ('evaluation' , 'NUMERIC' ),
104- ('selected' , ['true' , 'false' ])]
105- for key in model .cv_results_ :
106- if key .startswith ("param_" ):
107- type = 'STRING'
108- if all (isinstance (i , (bool )) for i in model .cv_results_ [key ]):
109- type = ['True' , 'False' ]
110- elif all (isinstance (i , (int , float )) for i in model .cv_results_ [key ]):
111- type = 'NUMERIC'
112- else :
113- values = list (set (model .cv_results_ [key ])) # unique values
114- type = [str (i ) for i in values ]
115- print (key + ": " + str (type ))
116-
117- attribute = ("parameter_" + key [6 :], type )
118- arff_dict ['attributes' ].append (attribute )
119-
101+ arff_dict ['attributes' ] = self .trace_attributes
120102 arff_dict ['data' ] = self .trace_content
121103 arff_dict ['relation' ] = 'openml_task_' + str (self .task_id ) + '_predictions'
122104
@@ -145,7 +127,7 @@ def publish(self):
145127 file_elements ['predictions' ] = ("predictions.arff" , predictions )
146128
147129 if self .trace_content is not None :
148- trace_arff = arff .dumps (self ._generate_trace_arff_dict (self . model ))
130+ trace_arff = arff .dumps (self ._generate_trace_arff_dict ())
149131 file_elements ['trace' ] = ("trace.arff" , trace_arff )
150132
151133 return_value = _perform_api_call ("/run/" , file_elements = file_elements )
0 commit comments