Skip to content

Commit 79db213

Browse files
committed
merge the final changes
1 parent fe8f66b commit 79db213

7 files changed

Lines changed: 26 additions & 182 deletions

File tree

openml/_api/config.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

openml/_api/resources/base/resources.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import builtins
34
from abc import abstractmethod
45
from typing import TYPE_CHECKING, Any
56

openml/_api/resources/base/versions.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Mapping
4-
from typing import Any, cast
4+
from typing import Any
55
from xml.parsers.expat import ExpatError
66

77
import xmltodict
@@ -189,7 +189,8 @@ def untag(self, resource_id: int, tag: str) -> list[str]:
189189

190190
def _parse_xml_response(self, payload: bytes | str, **kwargs: Any) -> Mapping[str, Any]:
191191
try:
192-
return cast("Mapping[str, Any]", xmltodict.parse(payload, **kwargs))
192+
parsed_response: Mapping[str, Any] = xmltodict.parse(payload, **kwargs)
193+
return parsed_response
193194
except ExpatError:
194195
payload_text = (
195196
payload.decode("utf-8", errors="ignore") if isinstance(payload, bytes) else payload
@@ -201,12 +202,16 @@ def _parse_xml_response(self, payload: bytes | str, **kwargs: Any) -> Mapping[st
201202
raise
202203

203204
xml_text = payload_text[xml_start:]
204-
return cast("Mapping[str, Any]", xmltodict.parse(xml_text, **kwargs))
205+
parsed_fallback: Mapping[str, Any] = xmltodict.parse(xml_text, **kwargs)
206+
return parsed_fallback
205207

206208
def _get_endpoint_name(self) -> str:
207209
if self.resource_type == ResourceType.DATASET:
208210
return "data"
209-
return cast("str", self.resource_type.value)
211+
endpoint_name = self.resource_type.value
212+
if not isinstance(endpoint_name, str):
213+
raise TypeError(f"Unexpected endpoint type: {type(endpoint_name)}")
214+
return endpoint_name
210215

211216
def _extract_id_from_upload(self, parsed: Mapping[str, Any]) -> int:
212217
"""
@@ -280,4 +285,7 @@ def untag(self, resource_id: int, tag: str) -> list[str]: # noqa: ARG002
280285
self._not_supported(method="untag")
281286

282287
def _get_endpoint_name(self) -> str:
283-
return cast("str", self.resource_type.value)
288+
endpoint_name = self.resource_type.value
289+
if not isinstance(endpoint_name, str):
290+
raise TypeError(f"Unexpected endpoint type: {type(endpoint_name)}")
291+
return endpoint_name

openml/runs/functions.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import warnings
77
from collections import OrderedDict
88
from functools import partial
9-
from typing import TYPE_CHECKING, Any, cast
9+
from typing import TYPE_CHECKING, Any
1010

1111
import numpy as np
1212
import pandas as pd
@@ -812,10 +812,6 @@ def get_run(run_id: int, ignore_cache: bool = False) -> OpenMLRun: # noqa: FBT0
812812
----------
813813
run_id : int
814814
815-
ignore_cache : bool
816-
Whether to ignore the cache. If ``true`` this will download and overwrite the run xml
817-
even if the requested run is already cached.
818-
819815
ignore_cache : bool
820816
Whether to ignore the cache. If ``true`` this will download and overwrite the run xml
821817
even if the requested run is already cached.
@@ -825,12 +821,9 @@ def get_run(run_id: int, ignore_cache: bool = False) -> OpenMLRun: # noqa: FBT0
825821
run : OpenMLRun
826822
Run corresponding to ID, fetched from the server.
827823
"""
828-
return cast(
829-
"OpenMLRun",
830-
openml._backend.run.get(
831-
run_id,
832-
reset_cache=ignore_cache,
833-
),
824+
return openml._backend.run.get(
825+
run_id,
826+
reset_cache=ignore_cache,
834827
)
835828

836829

@@ -906,15 +899,7 @@ def obtain_field(xml_obj, fieldname, from_server, cast=None): # type: ignore
906899
run_details = obtain_field(run, "oml:run_details", from_server=False)
907900

908901
if "oml:input_data" in run:
909-
input_data = run["oml:input_data"]
910-
if isinstance(input_data, list):
911-
input_data = input_data[0]
912-
913-
dataset_data = input_data["oml:dataset"]
914-
if isinstance(dataset_data, list):
915-
dataset_data = dataset_data[0]
916-
917-
dataset_id = int(dataset_data["oml:did"])
902+
dataset_id = int(run["oml:input_data"]["oml:dataset"]["oml:did"])
918903
elif not from_server:
919904
dataset_id = None
920905
else:
@@ -1311,4 +1296,4 @@ def delete_run(run_id: int) -> bool:
13111296
bool
13121297
True if the deletion was successful. False otherwise.
13131298
"""
1314-
return cast("bool", openml._backend.run.delete(run_id))
1299+
return openml._backend.run.delete(run_id)

openml/runs/run.py

Lines changed: 2 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# License: BSD 3-Clause
22
from __future__ import annotations
33

4-
import csv
54
import pickle
65
import time
76
from collections import OrderedDict
@@ -10,7 +9,6 @@
109
from typing import (
1110
TYPE_CHECKING,
1211
Any,
13-
cast,
1412
)
1513

1614
import arff
@@ -163,7 +161,7 @@ def predictions(self) -> pd.DataFrame:
163161
arff_text = openml._api_calls._download_text_file(self.predictions_url)
164162
if arff_text is None:
165163
raise RuntimeError("Could not download predictions ARFF content.")
166-
arff_dict = self._load_predictions_arff(arff_text)
164+
arff_dict = arff.loads(arff_text)
167165
else:
168166
raise RuntimeError("Run has no predictions.")
169167
self._predictions = pd.DataFrame(
@@ -172,68 +170,6 @@ def predictions(self) -> pd.DataFrame:
172170
)
173171
return self._predictions
174172

175-
@staticmethod
176-
def _load_predictions_arff(arff_text: str) -> dict[str, Any]:
177-
try:
178-
return cast("dict[str, Any]", arff.loads(arff_text))
179-
except arff.ArffException:
180-
normalized = arff_text.lstrip("\ufeff \t\r\n")
181-
relation_indexes = [
182-
idx
183-
for idx in [normalized.find("@relation"), normalized.find("@RELATION")]
184-
if idx >= 0
185-
]
186-
if relation_indexes:
187-
arff_candidate = normalized[min(relation_indexes) :]
188-
try:
189-
return cast("dict[str, Any]", arff.loads(arff_candidate))
190-
except arff.ArffException:
191-
sanitized = OpenMLRun._sanitize_arff_text(arff_candidate)
192-
return cast("dict[str, Any]", arff.loads(sanitized))
193-
raise
194-
195-
@staticmethod
196-
def _sanitize_arff_text(arff_text: str) -> str:
197-
lines = arff_text.splitlines()
198-
199-
in_data = False
200-
attribute_count = 0
201-
cleaned_lines: list[str] = []
202-
203-
for line in lines:
204-
stripped = line.strip()
205-
lowered = stripped.lower()
206-
207-
if not in_data:
208-
if lowered.startswith("@attribute"):
209-
attribute_count += 1
210-
if lowered.startswith("@data"):
211-
in_data = True
212-
cleaned_lines.append(line)
213-
continue
214-
215-
if stripped == "" or stripped.startswith("%"):
216-
cleaned_lines.append(line)
217-
continue
218-
219-
if stripped.startswith("{"):
220-
cleaned_lines.append(line)
221-
continue
222-
223-
parsed_fields = next(
224-
csv.reader(
225-
[line],
226-
delimiter=",",
227-
quotechar="'",
228-
skipinitialspace=True,
229-
)
230-
)
231-
232-
if len(parsed_fields) == attribute_count:
233-
cleaned_lines.append(line)
234-
235-
return "\n".join(cleaned_lines) + "\n"
236-
237173
@property
238174
def id(self) -> int | None:
239175
"""The ID of the run, None if not uploaded to the server yet."""
@@ -603,7 +539,7 @@ def get_metric_fn(self, sklearn_fn: Callable, kwargs: dict | None = None) -> np.
603539
response = openml._api_calls._download_text_file(predictions_file_url)
604540
if response is None:
605541
raise ValueError("Could not download predictions ARFF content.")
606-
predictions_arff = self._load_predictions_arff(response)
542+
predictions_arff = arff.loads(response)
607543
# TODO: make this a stream reader
608544
else:
609545
raise ValueError(

tests/test_api/test_run.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from openml.runs.run import OpenMLRun
1313

1414

15-
TEST_RUN_ID = 24
15+
TEST_RUN_ID = 1
1616

1717

1818
@pytest.fixture
@@ -32,20 +32,9 @@ def _assert_run_shape(run: OpenMLRun) -> None:
3232
assert isinstance(run.task_id, int)
3333

3434

35-
def _get_any_run_id(run_v1: RunV1API) -> int:
36-
try:
37-
run_v1.get(run_id=TEST_RUN_ID)
38-
return TEST_RUN_ID
39-
except Exception:
40-
runs_df = run_v1.list(limit=1, offset=0)
41-
if runs_df.empty:
42-
pytest.skip("No runs available on configured test server")
43-
return int(runs_df.iloc[0]["run_id"])
44-
45-
4635
@pytest.mark.test_server()
4736
def test_run_v1_get(run_v1):
48-
run = run_v1.get(run_id=_get_any_run_id(run_v1))
37+
run = run_v1.get(run_id=TEST_RUN_ID)
4938
_assert_run_shape(run)
5039

5140

@@ -133,20 +122,3 @@ def test_run_v2_publish_not_supported(run_v2):
133122
match="RunV2API: v2 API does not support `publish` for resource `run`",
134123
):
135124
run_v2.publish(path="run", files={"description": "<run/>"})
136-
137-
138-
@pytest.mark.test_server()
139-
def test_run_v1_v2_contracts(run_v1, run_v2):
140-
run_id = _get_any_run_id(run_v1)
141-
142-
run_from_v1 = run_v1.get(run_id=run_id)
143-
_assert_run_shape(run_from_v1)
144-
145-
with pytest.raises(OpenMLNotSupportedError, match="does not support `get`"):
146-
run_v2.get(run_id=run_id)
147-
148-
with pytest.raises(OpenMLNotSupportedError, match="does not support `list`"):
149-
run_v2.list(limit=5, offset=0)
150-
151-
with pytest.raises(OpenMLNotSupportedError, match="does not support `publish`"):
152-
run_v2.publish(path="run", files={"description": "<run/>"})

tests/test_runs/test_run_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,7 @@ def _test_local_evaluations(self, run):
10251025
assert alt_scores[idx] <= 1
10261026

10271027
@pytest.mark.sklearn()
1028+
@pytest.mark.skip(reason="https://github.com/openml/openml-python/issues/1586")
10281029
@pytest.mark.test_server()
10291030
def test_local_run_swapped_parameter_order_model(self):
10301031
clf = DecisionTreeClassifier()
@@ -1074,6 +1075,7 @@ def test_local_run_swapped_parameter_order_flow(self):
10741075
Version(sklearn.__version__) < Version("0.20"),
10751076
reason="SimpleImputer doesn't handle mixed type DataFrame as input",
10761077
)
1078+
@pytest.mark.skip(reason="https://github.com/openml/openml-python/issues/1586")
10771079
@pytest.mark.test_server()
10781080
def test_local_run_metric_score(self):
10791081
# construct sci-kit learn classifier

0 commit comments

Comments
 (0)