Skip to content

Commit 0608e7a

Browse files
committed
Changes to satisfy mypy.
1 parent bf34e11 commit 0608e7a

2 files changed

Lines changed: 29 additions & 25 deletions

File tree

openml/runs/run.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212
import openml._api_calls
1313
from ..exceptions import PyOpenMLError
1414
from ..flows import get_flow
15-
from ..tasks import get_task, TaskTypeEnum
15+
from ..tasks import (get_task,
16+
TaskTypeEnum,
17+
OpenMLClassificationTask,
18+
OpenMLLearningCurveTask,
19+
OpenMLClusteringTask,
20+
OpenMLRegressionTask
21+
)
1622

1723

1824
class OpenMLRun(object):
@@ -108,24 +114,24 @@ def from_filesystem(cls, directory: str, expect_model: bool = True) -> 'OpenMLRu
108114
if not os.path.isfile(model_path) and expect_model:
109115
raise ValueError('Could not find model.pkl')
110116

111-
with open(description_path, 'r') as fp:
112-
xml_string = fp.read()
117+
with open(description_path, 'r') as fht:
118+
xml_string = fht.read()
113119
run = openml.runs.functions._create_run_from_xml(xml_string, from_server=False)
114120

115121
if run.flow_id is None:
116122
flow = openml.flows.OpenMLFlow.from_filesystem(directory)
117123
run.flow = flow
118124
run.flow_name = flow.name
119125

120-
with open(predictions_path, 'r') as fp:
121-
predictions = arff.load(fp)
126+
with open(predictions_path, 'r') as fht:
127+
predictions = arff.load(fht)
122128
run.data_content = predictions['data']
123129

124130
if os.path.isfile(model_path):
125131
# note that it will load the model if the file exists, even if
126132
# expect_model is False
127-
with open(model_path, 'rb') as fp:
128-
run.model = pickle.load(fp)
133+
with open(model_path, 'rb') as fhb:
134+
run.model = pickle.load(fhb)
129135

130136
if os.path.isfile(trace_path):
131137
run.trace = openml.runs.OpenMLRunTrace._from_filesystem(trace_path)
@@ -208,7 +214,18 @@ def _generate_arff_dict(self) -> 'OrderedDict[str, Any]':
208214
arff_dict['relation'] =\
209215
'openml_task_{}_predictions'.format(task.task_id)
210216

211-
if task.task_type_id == TaskTypeEnum.SUPERVISED_CLASSIFICATION:
217+
if isinstance(task, OpenMLLearningCurveTask):
218+
class_labels = task.class_labels # type: ignore
219+
arff_dict['attributes'] = [('repeat', 'NUMERIC'),
220+
('fold', 'NUMERIC'),
221+
('sample', 'NUMERIC'),
222+
('row_id', 'NUMERIC')] + \
223+
[('confidence.' + class_labels[i],
224+
'NUMERIC') for i in
225+
range(len(class_labels))] + \
226+
[('prediction', class_labels),
227+
('correct', class_labels)]
228+
elif isinstance(task, OpenMLClassificationTask):
212229
class_labels = task.class_labels
213230
instance_specifications = [('repeat', 'NUMERIC'),
214231
('fold', 'NUMERIC'),
@@ -222,27 +239,14 @@ def _generate_arff_dict(self) -> 'OrderedDict[str, Any]':
222239
arff_dict['attributes'] = (instance_specifications
223240
+ prediction_confidences
224241
+ prediction_and_true)
225-
226-
elif task.task_type_id == TaskTypeEnum.LEARNING_CURVE:
227-
class_labels = task.class_labels
228-
arff_dict['attributes'] = [('repeat', 'NUMERIC'),
229-
('fold', 'NUMERIC'),
230-
('sample', 'NUMERIC'),
231-
('row_id', 'NUMERIC')] + \
232-
[('confidence.' + class_labels[i],
233-
'NUMERIC') for i in
234-
range(len(class_labels))] + \
235-
[('prediction', class_labels),
236-
('correct', class_labels)]
237-
238-
elif task.task_type_id == TaskTypeEnum.SUPERVISED_REGRESSION:
242+
elif isinstance(task, OpenMLRegressionTask):
239243
arff_dict['attributes'] = [('repeat', 'NUMERIC'),
240244
('fold', 'NUMERIC'),
241245
('row_id', 'NUMERIC'),
242246
('prediction', 'NUMERIC'),
243247
('truth', 'NUMERIC')]
244248

245-
elif task.task_type == TaskTypeEnum.CLUSTERING:
249+
elif isinstance(task, OpenMLClusteringTask):
246250
arff_dict['attributes'] = [('repeat', 'NUMERIC'),
247251
('fold', 'NUMERIC'),
248252
('row_id', 'NUMERIC'),

openml/runs/trace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self, run_id, trace_iterations):
3232
self.run_id = run_id
3333
self.trace_iterations = trace_iterations
3434

35-
def get_selected_iteration(self, fold: int, repeat: int) -> 'OpenMLTraceIteration':
35+
def get_selected_iteration(self, fold: int, repeat: int) -> int:
3636
"""
3737
Returns the trace iteration that was marked as selected. In
3838
case multiple are marked as selected (should not happen) the
@@ -46,7 +46,7 @@ def get_selected_iteration(self, fold: int, repeat: int) -> 'OpenMLTraceIteratio
4646
4747
Returns
4848
----------
49-
OpenMLTraceIteration
49+
int
5050
The trace iteration from the given fold and repeat that was
5151
selected as the best iteration by the search procedure
5252
"""

0 commit comments

Comments
 (0)