|
1 | 1 | # License: BSD 3-Clause |
2 | 2 | from __future__ import annotations |
3 | 3 |
|
| 4 | +import csv |
4 | 5 | import pickle |
5 | 6 | import time |
6 | 7 | from collections import OrderedDict |
@@ -183,9 +184,56 @@ def _load_predictions_arff(arff_text: str) -> dict[str, Any]: |
183 | 184 | if idx >= 0 |
184 | 185 | ] |
185 | 186 | if relation_indexes: |
186 | | - return cast("dict[str, Any]", arff.loads(normalized[min(relation_indexes) :])) |
| 187 | + arff_candidate = normalized[min(relation_indexes) :] |
| 188 | + try: |
| 189 | + return cast("dict[str, Any]", arff.loads(arff_candidate)) |
| 190 | + except arff.ArffException: |
| 191 | + sanitized = OpenMLRun._sanitize_arff_text(arff_candidate) |
| 192 | + return cast("dict[str, Any]", arff.loads(sanitized)) |
187 | 193 | raise |
188 | 194 |
|
@staticmethod
def _sanitize_arff_text(arff_text: str) -> str:
    """Drop dense data rows whose field count differs from the attribute count.

    This is a best-effort repair: rows that do not fit the declared header
    are removed silently so that the remaining text can be parsed.

    Every line up to and including the ``@data`` marker is copied unchanged,
    and each line starting with ``@attribute`` (case-insensitive) is counted.
    After ``@data``:

    * blank lines, ``%`` comment lines and lines starting with ``{``
      (sparse rows) are copied unchanged;
    * any other line is kept only if it splits into exactly as many
      comma-separated fields as there are declared attributes.

    A field may be wrapped in single or double quotes. Inside a quoted
    field a comma does not split, a backslash escapes the following
    character, and a doubled quote character is taken literally.

    Parameters
    ----------
    arff_text : str
        ARFF document text.

    Returns
    -------
    str
        The retained lines joined with ``"\\n"`` and terminated by a single
        trailing newline.
    """

    def count_fields(row: str) -> int:
        """Count the top-level comma-separated fields of one dense row."""
        fields = 1
        quote = ""  # active quote character; empty while outside quotes
        at_field_start = True
        pos = 0
        end = len(row)
        while pos < end:
            char = row[pos]
            if quote:
                if char == "\\":
                    # Escaped character: skip it so an escaped quote
                    # cannot terminate the quoted value.
                    pos += 1
                elif char == quote:
                    if row[pos + 1 : pos + 2] == quote:
                        # Doubled quote is a literal quote, not a terminator.
                        pos += 1
                    else:
                        quote = ""
            elif char == ",":
                fields += 1
                at_field_start = True
            elif at_field_start and char in "'\"":
                # A quote only opens a quoted value at the start of a field.
                quote = char
                at_field_start = False
            elif not char.isspace():
                # Leading whitespace does not end the "start of field" state.
                at_field_start = False
            pos += 1
        return fields

    attribute_count = 0
    in_data = False
    kept_lines: list[str] = []

    for line in arff_text.splitlines():
        stripped = line.strip()

        if not in_data:
            # Header section: copy verbatim while counting declarations.
            lowered = stripped.lower()
            if lowered.startswith("@attribute"):
                attribute_count += 1
            elif lowered.startswith("@data"):
                in_data = True
            kept_lines.append(line)
            continue

        # Blank lines, comments and sparse rows are never validated.
        if not stripped or stripped.startswith(("%", "{")):
            kept_lines.append(line)
            continue

        if count_fields(line) == attribute_count:
            kept_lines.append(line)

    return "\n".join(kept_lines) + "\n"
| 236 | + |
189 | 237 | @property |
190 | 238 | def id(self) -> int | None: |
191 | 239 | """The ID of the run, None if not uploaded to the server yet.""" |
|
0 commit comments