99from typing import (
1010 TYPE_CHECKING ,
1111 Any ,
12+ cast ,
1213)
1314
1415import arff
@@ -154,11 +155,14 @@ def __init__( # noqa: PLR0913
154155 def predictions (self ) -> pd .DataFrame :
155156 """Return a DataFrame with predictions for this run"""
156157 if self ._predictions is None :
158+ arff_dict : dict [str , Any ]
157159 if self .data_content :
158160 arff_dict = self ._generate_arff_dict ()
159161 elif self .predictions_url :
160162 arff_text = openml ._api_calls ._download_text_file (self .predictions_url )
161- arff_dict = arff .loads (arff_text )
163+ if arff_text is None :
164+ raise RuntimeError ("Could not download predictions ARFF content." )
165+ arff_dict = self ._load_predictions_arff (arff_text )
162166 else :
163167 raise RuntimeError ("Run has no predictions." )
164168 self ._predictions = pd .DataFrame (
@@ -167,6 +171,21 @@ def predictions(self) -> pd.DataFrame:
167171 )
168172 return self ._predictions
169173
@staticmethod
def _load_predictions_arff(arff_text: str) -> dict[str, Any]:
    """Parse predictions ARFF text, tolerating junk before the header.

    The text is first handed to ``arff.loads`` unchanged. Only if that raises
    ``arff.ArffException`` is a fallback attempted: a leading byte-order mark
    and leading whitespace are stripped, everything in front of the first
    ``@relation`` keyword is dropped, and the remainder is parsed again.

    Parameters
    ----------
    arff_text : str
        Raw ARFF document, e.g. the text of a downloaded predictions file.

    Returns
    -------
    dict[str, Any]
        The object returned by ``arff.loads``.

    Raises
    ------
    arff.ArffException
        If the text cannot be parsed as-is and either no ``@relation`` keyword
        is present, trimming would not change the text, or the trimmed text
        fails to parse as well.
    """
    import re  # local import: only needed on the fallback path

    try:
        return cast("dict[str, Any]", arff.loads(arff_text))
    except arff.ArffException:
        # Drop a BOM and leading whitespace that may precede the header.
        normalized = arff_text.lstrip("\ufeff \t\r\n")
        # ARFF declarations are case-insensitive, so "@Relation" must be
        # recognised just like "@relation" and "@RELATION".
        header = re.search("@relation", normalized, flags=re.IGNORECASE)
        if header is None:
            raise
        candidate = normalized[header.start() :]
        if candidate == arff_text:
            # Nothing was trimmed; parsing the same text again cannot succeed.
            raise
        return cast("dict[str, Any]", arff.loads(candidate))
188+
170189 @property
171190 def id (self ) -> int | None :
172191 """The ID of the run, None if not uploaded to the server yet."""
@@ -525,6 +544,7 @@ def get_metric_fn(self, sklearn_fn: Callable, kwargs: dict | None = None) -> np.
525544 metric results
526545 """
527546 kwargs = kwargs if kwargs else {}
547+ predictions_arff : dict [str , Any ]
528548 if self .data_content is not None and self .task_id is not None :
529549 predictions_arff = self ._generate_arff_dict ()
530550 elif (self .output_files is not None ) and ("predictions" in self .output_files ):
@@ -533,7 +553,9 @@ def get_metric_fn(self, sklearn_fn: Callable, kwargs: dict | None = None) -> np.
533553 "predictions.arff" ,
534554 )
535555 response = openml ._api_calls ._download_text_file (predictions_file_url )
536- predictions_arff = arff .loads (response )
556+ if response is None :
557+ raise ValueError ("Could not download predictions ARFF content." )
558+ predictions_arff = self ._load_predictions_arff (response )
537559 # TODO: make this a stream reader
538560 else :
539561 raise ValueError (
0 commit comments