Skip to content

Commit a64203d

Browse files
committed
add run migrtation test file
1 parent e14532f commit a64203d

2 files changed

Lines changed: 187 additions & 0 deletions

File tree

openml/_api/resources/base/resources.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
from __future__ import annotations
22

3+
import builtins
4+
from abc import abstractmethod
5+
from typing import TYPE_CHECKING
6+
37
from openml.enums import ResourceType
48

59
from .base import ResourceAPI
610

11+
if TYPE_CHECKING:
12+
import pandas as pd
13+
14+
from openml.runs.run import OpenMLRun
15+
from openml.tasks.task import TaskType
16+
717

818
class DatasetAPI(ResourceAPI):
919
resource_type: ResourceType = ResourceType.DATASET
@@ -36,6 +46,30 @@ class StudyAPI(ResourceAPI):
3646
class RunAPI(ResourceAPI):
3747
resource_type: ResourceType = ResourceType.RUN
3848

49+
@abstractmethod
50+
def get(
51+
self,
52+
run_id: int,
53+
*,
54+
reset_cache: bool = False,
55+
) -> OpenMLRun: ...
56+
57+
def list( # type: ignore[valid-type] # noqa: PLR0913
58+
self,
59+
limit: int,
60+
offset: int,
61+
*,
62+
ids: builtins.list[int] | None = None,
63+
task: builtins.list[int] | None = None,
64+
setup: builtins.list[int] | None = None,
65+
flow: builtins.list[int] | None = None,
66+
uploader: builtins.list[int] | None = None,
67+
study: int | None = None,
68+
tag: str | None = None,
69+
display_errors: bool = False,
70+
task_type: TaskType | int | None = None,
71+
) -> pd.DataFrame: ...
72+
3973

4074
class SetupAPI(ResourceAPI):
4175
resource_type: ResourceType = ResourceType.SETUP

tests/test_api/test_run.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""Tests for Run V1 → V2 API Migration."""
2+
from __future__ import annotations
3+
4+
import pytest
5+
6+
import openml
7+
from openml._api.resources import FallbackProxy, RunV1API, RunV2API
8+
from openml.exceptions import OpenMLNotSupportedError
9+
from openml.runs.run import OpenMLRun
10+
from openml.testing import TestAPIBase
11+
12+
13+
class TestRunsV1(TestAPIBase):
14+
"""Test RunsV1 resource implementation."""
15+
16+
def setUp(self):
17+
super().setUp()
18+
self.resource = RunV1API(self.http_client)
19+
20+
@pytest.mark.uses_test_server()
21+
def test_get(self):
22+
"""Test getting a run from the V1 API."""
23+
run = self.resource.get(run_id=1)
24+
25+
assert isinstance(run, OpenMLRun)
26+
assert run.run_id == 1
27+
assert isinstance(run.task_id, int)
28+
29+
@pytest.mark.uses_test_server()
30+
def test_list(self):
31+
"""Test listing runs from the V1 API."""
32+
runs_df = self.resource.list(limit=5, offset=0)
33+
34+
assert len(runs_df) > 0
35+
assert len(runs_df) <= 5
36+
assert "run_id" in runs_df.columns
37+
assert "task_id" in runs_df.columns
38+
assert "setup_id" in runs_df.columns
39+
assert "flow_id" in runs_df.columns
40+
41+
@pytest.mark.uses_test_server()
42+
def test_publish(self):
43+
"""Test publishing a small run using V1 API."""
44+
from sklearn.neighbors import KNeighborsClassifier
45+
46+
task = openml.tasks.get_task(19)
47+
clf = KNeighborsClassifier(n_neighbors=3)
48+
run = openml.runs.run_model_on_task(clf, task)
49+
50+
file_elements = run._get_file_elements()
51+
if "description" not in file_elements:
52+
file_elements["description"] = run._to_xml()
53+
54+
run_id = self.resource.publish(path="run", files=file_elements)
55+
assert isinstance(run_id, int)
56+
assert run_id > 0
57+
58+
@pytest.mark.uses_test_server()
59+
def test_delete_run(self):
60+
"""Test deleting a run using V1 API."""
61+
# First, create and publish a run to delete
62+
from sklearn.neighbors import KNeighborsClassifier
63+
64+
task = openml.tasks.get_task(19)
65+
clf = KNeighborsClassifier(n_neighbors=3)
66+
run = openml.runs.run_model_on_task(clf, task)
67+
68+
file_elements = run._get_file_elements()
69+
if "description" not in file_elements:
70+
file_elements["description"] = run._to_xml()
71+
72+
run_id = self.resource.publish(path="run", files=file_elements)
73+
assert isinstance(run_id, int)
74+
assert run_id > 0
75+
76+
# Now delete the run
77+
self.resource.delete(run_id)
78+
79+
# Verify deletion by attempting to fetch the run
80+
with pytest.raises(Exception):
81+
self.resource.get(run_id=run_id)
82+
83+
84+
class TestRunsV2(TestAPIBase):
85+
"""Test RunsV2 resource implementation."""
86+
87+
def setUp(self):
88+
super().setUp()
89+
self.v2_http_client = self._get_http_client(
90+
server="http://127.0.0.1:8001/",
91+
base_url="",
92+
api_key=self.api_key,
93+
timeout=self.timeout,
94+
retries=self.retries,
95+
retry_policy=self.retry_policy,
96+
cache=self.cache,
97+
)
98+
self.resource = RunV2API(self.v2_http_client)
99+
100+
@pytest.mark.uses_test_server()
101+
def test_get_not_supported(self):
102+
"""Test that V2 get is not implemented."""
103+
with pytest.raises(OpenMLNotSupportedError):
104+
_ = self.resource.get(run_id=1)
105+
106+
@pytest.mark.uses_test_server()
107+
def test_list_not_supported(self):
108+
"""Test that V2 list is not implemented."""
109+
with pytest.raises(OpenMLNotSupportedError):
110+
_ = self.resource.list(limit=5, offset=0)
111+
112+
113+
class TestRunsCombined(TestAPIBase):
114+
"""Test fallback behavior between V2 and V1 for Runs."""
115+
116+
def setUp(self):
117+
super().setUp()
118+
self.v1_client = self._get_http_client(
119+
server=self.server,
120+
base_url=self.base_url,
121+
api_key=self.api_key,
122+
timeout=self.timeout,
123+
retries=self.retries,
124+
retry_policy=self.retry_policy,
125+
cache=self.cache,
126+
)
127+
self.v2_client = self._get_http_client(
128+
server="http://127.0.0.1:8001/",
129+
base_url="",
130+
api_key=self.api_key,
131+
timeout=self.timeout,
132+
retries=self.retries,
133+
retry_policy=self.retry_policy,
134+
cache=self.cache,
135+
)
136+
137+
self.resource_v1 = RunV1API(self.v1_client)
138+
self.resource_v2 = RunV2API(self.v2_client)
139+
self.resource_fallback = FallbackProxy(self.resource_v2, self.resource_v1)
140+
141+
@pytest.mark.uses_test_server()
142+
def test_get_fallback(self):
143+
"""Test fallback for get() when V2 is not implemented."""
144+
run = self.resource_fallback.get(run_id=1)
145+
assert isinstance(run, OpenMLRun)
146+
assert run.run_id == 1
147+
148+
@pytest.mark.uses_test_server()
149+
def test_list_fallback(self):
150+
"""Test fallback for list() when V2 is not implemented."""
151+
runs_df = self.resource_fallback.list(limit=5, offset=0)
152+
assert len(runs_df) > 0
153+
assert "run_id" in runs_df.columns

0 commit comments

Comments
 (0)