55
66import openml
77from openml ._api .resources import FallbackProxy , RunV1API , RunV2API
8+ from openml .enums import APIVersion
89from openml .exceptions import OpenMLNotSupportedError
910from openml .runs .run import OpenMLRun
1011from openml .testing import TestAPIBase
1112
1213
14+ @pytest .mark .uses_test_server ()
1315class TestRunsV1 (TestAPIBase ):
1416 """Test RunsV1 resource implementation."""
1517
1618 def setUp (self ):
1719 super ().setUp ()
18- self .resource = RunV1API (self .http_client )
20+ http_client = self .http_clients [APIVersion .V1 ]
21+ self .resource = RunV1API (http_client )
1922
20- @pytest .mark .uses_test_server ()
2123 def test_get (self ):
2224 """Test getting a run from the V1 API."""
2325 run = self .resource .get (run_id = 1 )
2426
25- assert isinstance (run , OpenMLRun )
26- assert run .run_id == 1
27- assert isinstance (run .task_id , int )
27+ self . assertIsInstance (run , OpenMLRun )
28+ self . assertEqual ( run .run_id , 1 )
29+ self . assertIsInstance (run .task_id , int )
2830
29- @pytest .mark .uses_test_server ()
3031 def test_list (self ):
3132 """Test listing runs from the V1 API."""
32- runs_df = self .resource .list (limit = 5 , offset = 0 )
33-
34- assert len (runs_df ) > 0
35- assert len (runs_df ) <= 5
36- assert "run_id" in runs_df .columns
37- assert "task_id" in runs_df .columns
38- assert "setup_id" in runs_df .columns
39- assert "flow_id" in runs_df .columns
40-
41- @pytest .mark .uses_test_server ()
42- def test_publish (self ):
43- """Test publishing a small run using V1 API."""
44- from sklearn .neighbors import KNeighborsClassifier
33+ limit = 5
34+ runs_df = self .resource .list (limit = limit , offset = 0 )
4535
46- task = openml .tasks .get_task (19 )
47- clf = KNeighborsClassifier (n_neighbors = 3 )
48- run = openml .runs .run_model_on_task (clf , task )
36+ self .assertEqual (len (runs_df ), limit )
37+ self .assertIn ("run_id" , runs_df .columns )
38+ self .assertIn ("task_id" , runs_df .columns )
39+ self .assertIn ("setup_id" , runs_df .columns )
40+ self .assertIn ("flow_id" , runs_df .columns )
4941
50- file_elements = run ._get_file_elements ()
51- if "description" not in file_elements :
52- file_elements ["description" ] = run ._to_xml ()
53-
54- run_id = self .resource .publish (path = "run" , files = file_elements )
55- assert isinstance (run_id , int )
56- assert run_id > 0
57-
58- @pytest .mark .uses_test_server ()
59- def test_delete_run (self ):
60- """Test deleting a run using V1 API."""
42+ def test_delete_and_publish_run (self ):
43+ """Test publishing then deleting a run using V1 API."""
6144 # First, create and publish a run to delete
6245 from sklearn .neighbors import KNeighborsClassifier
6346
@@ -70,84 +53,47 @@ def test_delete_run(self):
7053 file_elements ["description" ] = run ._to_xml ()
7154
7255 run_id = self .resource .publish (path = "run" , files = file_elements )
73- assert isinstance (run_id , int )
74- assert run_id > 0
56+ self . assertIsInstance (run_id , int )
57+ self . assertGreater ( run_id , 0 )
7558
76- # Now delete the run
7759 self .resource .delete (run_id )
7860
79- # Verify deletion by attempting to fetch the run
8061 with pytest .raises (Exception ):
8162 self .resource .get (run_id = run_id )
8263
8364
84- class TestRunsV2 (TestAPIBase ):
65+ @pytest .mark .uses_test_server ()
66+ class TestRunsV2 (TestRunsV1 ):
8567 """Test RunsV2 resource implementation."""
8668
8769 def setUp (self ):
8870 super ().setUp ()
89- self .v2_http_client = self ._get_http_client (
90- server = "http://127.0.0.1:8001/" ,
91- base_url = "" ,
92- api_key = self .api_key ,
93- timeout = self .timeout ,
94- retries = self .retries ,
95- retry_policy = self .retry_policy ,
96- cache = self .cache ,
97- )
98- self .resource = RunV2API (self .v2_http_client )
99-
100- @pytest .mark .uses_test_server ()
101- def test_get_not_supported (self ):
102- """Test that V2 get is not implemented."""
71+ http_client = self .http_clients [APIVersion .V2 ]
72+ self .resource = RunV2API (http_client )
73+
74+ def test_get (self ):
75+ with pytest .raises (OpenMLNotSupportedError ):
76+ super ().test_get ()
77+
78+ def test_list (self ):
10379 with pytest .raises (OpenMLNotSupportedError ):
104- _ = self . resource . get ( run_id = 1 )
80+ super (). test_list ( )
10581
106- @pytest .mark .uses_test_server ()
107- def test_list_not_supported (self ):
108- """Test that V2 list is not implemented."""
82+ def test_delete_and_publish_run (self ):
10983 with pytest .raises (OpenMLNotSupportedError ):
110- _ = self . resource . list ( limit = 5 , offset = 0 )
84+ super (). test_delete_and_publish_run ( )
11185
11286
113- class TestRunsCombined (TestAPIBase ):
114- """Test fallback behavior between V2 and V1 for Runs."""
87+ @pytest .mark .uses_test_server ()
88+ class TestRunsFallback (TestRunsV1 ):
89+ """Test combined functionality and fallback between V1 and V2."""
11590
11691 def setUp (self ):
11792 super ().setUp ()
118- self .v1_client = self ._get_http_client (
119- server = self .server ,
120- base_url = self .base_url ,
121- api_key = self .api_key ,
122- timeout = self .timeout ,
123- retries = self .retries ,
124- retry_policy = self .retry_policy ,
125- cache = self .cache ,
126- )
127- self .v2_client = self ._get_http_client (
128- server = "http://127.0.0.1:8001/" ,
129- base_url = "" ,
130- api_key = self .api_key ,
131- timeout = self .timeout ,
132- retries = self .retries ,
133- retry_policy = self .retry_policy ,
134- cache = self .cache ,
135- )
136-
137- self .resource_v1 = RunV1API (self .v1_client )
138- self .resource_v2 = RunV2API (self .v2_client )
139- self .resource_fallback = FallbackProxy (self .resource_v2 , self .resource_v1 )
140-
141- @pytest .mark .uses_test_server ()
142- def test_get_fallback (self ):
143- """Test fallback for get() when V2 is not implemented."""
144- run = self .resource_fallback .get (run_id = 1 )
145- assert isinstance (run , OpenMLRun )
146- assert run .run_id == 1
147-
148- @pytest .mark .uses_test_server ()
149- def test_list_fallback (self ):
150- """Test fallback for list() when V2 is not implemented."""
151- runs_df = self .resource_fallback .list (limit = 5 , offset = 0 )
152- assert len (runs_df ) > 0
153- assert "run_id" in runs_df .columns
93+ http_client_v1 = self .http_clients [APIVersion .V1 ]
94+ resource_v1 = RunV1API (http_client_v1 )
95+
96+ http_client_v2 = self .http_clients [APIVersion .V2 ]
97+ resource_v2 = RunV2API (http_client_v2 )
98+
99+ self .resource = FallbackProxy (resource_v2 , resource_v1 )
0 commit comments