1212import openml ._api_calls
1313from ..exceptions import PyOpenMLError
1414from ..flows import get_flow
15- from ..tasks import get_task , TaskTypeEnum
15+ from ..tasks import (get_task ,
16+ TaskTypeEnum ,
17+ OpenMLClassificationTask ,
18+ OpenMLLearningCurveTask ,
19+ OpenMLClusteringTask ,
20+ OpenMLRegressionTask
21+ )
1622
1723
1824class OpenMLRun (object ):
@@ -108,24 +114,24 @@ def from_filesystem(cls, directory: str, expect_model: bool = True) -> 'OpenMLRu
108114 if not os .path .isfile (model_path ) and expect_model :
109115 raise ValueError ('Could not find model.pkl' )
110116
111- with open (description_path , 'r' ) as fp :
112- xml_string = fp .read ()
117+ with open (description_path , 'r' ) as fht :
118+ xml_string = fht .read ()
113119 run = openml .runs .functions ._create_run_from_xml (xml_string , from_server = False )
114120
115121 if run .flow_id is None :
116122 flow = openml .flows .OpenMLFlow .from_filesystem (directory )
117123 run .flow = flow
118124 run .flow_name = flow .name
119125
120- with open (predictions_path , 'r' ) as fp :
121- predictions = arff .load (fp )
126+ with open (predictions_path , 'r' ) as fht :
127+ predictions = arff .load (fht )
122128 run .data_content = predictions ['data' ]
123129
124130 if os .path .isfile (model_path ):
125131 # note that it will load the model if the file exists, even if
126132 # expect_model is False
127- with open (model_path , 'rb' ) as fp :
128- run .model = pickle .load (fp )
133+ with open (model_path , 'rb' ) as fhb :
134+ run .model = pickle .load (fhb )
129135
130136 if os .path .isfile (trace_path ):
131137 run .trace = openml .runs .OpenMLRunTrace ._from_filesystem (trace_path )
@@ -208,7 +214,18 @@ def _generate_arff_dict(self) -> 'OrderedDict[str, Any]':
208214 arff_dict ['relation' ] = \
209215 'openml_task_{}_predictions' .format (task .task_id )
210216
211- if task .task_type_id == TaskTypeEnum .SUPERVISED_CLASSIFICATION :
217+ if isinstance (task , OpenMLLearningCurveTask ):
218+ class_labels = task .class_labels # type: ignore
219+ arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
220+ ('fold' , 'NUMERIC' ),
221+ ('sample' , 'NUMERIC' ),
222+ ('row_id' , 'NUMERIC' )] + \
223+ [('confidence.' + class_labels [i ],
224+ 'NUMERIC' ) for i in
225+ range (len (class_labels ))] + \
226+ [('prediction' , class_labels ),
227+ ('correct' , class_labels )]
228+ elif isinstance (task , OpenMLClassificationTask ):
212229 class_labels = task .class_labels
213230 instance_specifications = [('repeat' , 'NUMERIC' ),
214231 ('fold' , 'NUMERIC' ),
@@ -222,27 +239,14 @@ def _generate_arff_dict(self) -> 'OrderedDict[str, Any]':
222239 arff_dict ['attributes' ] = (instance_specifications
223240 + prediction_confidences
224241 + prediction_and_true )
225-
226- elif task .task_type_id == TaskTypeEnum .LEARNING_CURVE :
227- class_labels = task .class_labels
228- arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
229- ('fold' , 'NUMERIC' ),
230- ('sample' , 'NUMERIC' ),
231- ('row_id' , 'NUMERIC' )] + \
232- [('confidence.' + class_labels [i ],
233- 'NUMERIC' ) for i in
234- range (len (class_labels ))] + \
235- [('prediction' , class_labels ),
236- ('correct' , class_labels )]
237-
238- elif task .task_type_id == TaskTypeEnum .SUPERVISED_REGRESSION :
242+ elif isinstance (task , OpenMLRegressionTask ):
239243 arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
240244 ('fold' , 'NUMERIC' ),
241245 ('row_id' , 'NUMERIC' ),
242246 ('prediction' , 'NUMERIC' ),
243247 ('truth' , 'NUMERIC' )]
244248
245- elif task . task_type == TaskTypeEnum . CLUSTERING :
249+ elif isinstance ( task , OpenMLClusteringTask ) :
246250 arff_dict ['attributes' ] = [('repeat' , 'NUMERIC' ),
247251 ('fold' , 'NUMERIC' ),
248252 ('row_id' , 'NUMERIC' ),
0 commit comments