Skip to content

Commit 61816f4

Browse files
committed
updated the tests
1 parent 808bc5b commit 61816f4

1 file changed

Lines changed: 24 additions & 2 deletions

File tree

openml/runs/run.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import (
1010
TYPE_CHECKING,
1111
Any,
12+
cast,
1213
)
1314

1415
import arff
@@ -154,11 +155,14 @@ def __init__( # noqa: PLR0913
154155
def predictions(self) -> pd.DataFrame:
155156
"""Return a DataFrame with predictions for this run"""
156157
if self._predictions is None:
158+
arff_dict: dict[str, Any]
157159
if self.data_content:
158160
arff_dict = self._generate_arff_dict()
159161
elif self.predictions_url:
160162
arff_text = openml._api_calls._download_text_file(self.predictions_url)
161-
arff_dict = arff.loads(arff_text)
163+
if arff_text is None:
164+
raise RuntimeError("Could not download predictions ARFF content.")
165+
arff_dict = self._load_predictions_arff(arff_text)
162166
else:
163167
raise RuntimeError("Run has no predictions.")
164168
self._predictions = pd.DataFrame(
@@ -167,6 +171,21 @@ def predictions(self) -> pd.DataFrame:
167171
)
168172
return self._predictions
169173

174+
@staticmethod
175+
def _load_predictions_arff(arff_text: str) -> dict[str, Any]:
176+
try:
177+
return cast("dict[str, Any]", arff.loads(arff_text))
178+
except arff.ArffException:
179+
normalized = arff_text.lstrip("\ufeff \t\r\n")
180+
relation_indexes = [
181+
idx
182+
for idx in [normalized.find("@relation"), normalized.find("@RELATION")]
183+
if idx >= 0
184+
]
185+
if relation_indexes:
186+
return cast("dict[str, Any]", arff.loads(normalized[min(relation_indexes) :]))
187+
raise
188+
170189
@property
171190
def id(self) -> int | None:
172191
"""The ID of the run, None if not uploaded to the server yet."""
@@ -525,6 +544,7 @@ def get_metric_fn(self, sklearn_fn: Callable, kwargs: dict | None = None) -> np.
525544
metric results
526545
"""
527546
kwargs = kwargs if kwargs else {}
547+
predictions_arff: dict[str, Any]
528548
if self.data_content is not None and self.task_id is not None:
529549
predictions_arff = self._generate_arff_dict()
530550
elif (self.output_files is not None) and ("predictions" in self.output_files):
@@ -533,7 +553,9 @@ def get_metric_fn(self, sklearn_fn: Callable, kwargs: dict | None = None) -> np.
533553
"predictions.arff",
534554
)
535555
response = openml._api_calls._download_text_file(predictions_file_url)
536-
predictions_arff = arff.loads(response)
556+
if response is None:
557+
raise ValueError("Could not download predictions ARFF content.")
558+
predictions_arff = self._load_predictions_arff(response)
537559
# TODO: make this a stream reader
538560
else:
539561
raise ValueError(

0 commit comments

Comments
 (0)