11# License: BSD 3-Clause
22from __future__ import annotations
33
4- import csv
54import pickle
65import time
76from collections import OrderedDict
109from typing import (
1110 TYPE_CHECKING ,
1211 Any ,
13- cast ,
1412)
1513
1614import arff
@@ -163,7 +161,7 @@ def predictions(self) -> pd.DataFrame:
163161 arff_text = openml ._api_calls ._download_text_file (self .predictions_url )
164162 if arff_text is None :
165163 raise RuntimeError ("Could not download predictions ARFF content." )
166- arff_dict = self . _load_predictions_arff (arff_text )
164+ arff_dict = arff . loads (arff_text )
167165 else :
168166 raise RuntimeError ("Run has no predictions." )
169167 self ._predictions = pd .DataFrame (
@@ -172,68 +170,6 @@ def predictions(self) -> pd.DataFrame:
172170 )
173171 return self ._predictions
174172
175- @staticmethod
176- def _load_predictions_arff (arff_text : str ) -> dict [str , Any ]:
177- try :
178- return cast ("dict[str, Any]" , arff .loads (arff_text ))
179- except arff .ArffException :
180- normalized = arff_text .lstrip ("\ufeff \t \r \n " )
181- relation_indexes = [
182- idx
183- for idx in [normalized .find ("@relation" ), normalized .find ("@RELATION" )]
184- if idx >= 0
185- ]
186- if relation_indexes :
187- arff_candidate = normalized [min (relation_indexes ) :]
188- try :
189- return cast ("dict[str, Any]" , arff .loads (arff_candidate ))
190- except arff .ArffException :
191- sanitized = OpenMLRun ._sanitize_arff_text (arff_candidate )
192- return cast ("dict[str, Any]" , arff .loads (sanitized ))
193- raise
194-
195- @staticmethod
196- def _sanitize_arff_text (arff_text : str ) -> str :
197- lines = arff_text .splitlines ()
198-
199- in_data = False
200- attribute_count = 0
201- cleaned_lines : list [str ] = []
202-
203- for line in lines :
204- stripped = line .strip ()
205- lowered = stripped .lower ()
206-
207- if not in_data :
208- if lowered .startswith ("@attribute" ):
209- attribute_count += 1
210- if lowered .startswith ("@data" ):
211- in_data = True
212- cleaned_lines .append (line )
213- continue
214-
215- if stripped == "" or stripped .startswith ("%" ):
216- cleaned_lines .append (line )
217- continue
218-
219- if stripped .startswith ("{" ):
220- cleaned_lines .append (line )
221- continue
222-
223- parsed_fields = next (
224- csv .reader (
225- [line ],
226- delimiter = "," ,
227- quotechar = "'" ,
228- skipinitialspace = True ,
229- )
230- )
231-
232- if len (parsed_fields ) == attribute_count :
233- cleaned_lines .append (line )
234-
235- return "\n " .join (cleaned_lines ) + "\n "
236-
237173 @property
238174 def id (self ) -> int | None :
239175 """The ID of the run, None if not uploaded to the server yet."""
@@ -603,7 +539,7 @@ def get_metric_fn(self, sklearn_fn: Callable, kwargs: dict | None = None) -> np.
603539 response = openml ._api_calls ._download_text_file (predictions_file_url )
604540 if response is None :
605541 raise ValueError ("Could not download predictions ARFF content." )
606- predictions_arff = self . _load_predictions_arff (response )
542+ predictions_arff = arff . loads (response )
607543 # TODO: make this a stream reader
608544 else :
609545 raise ValueError (
0 commit comments