Skip to content

Commit d716ecf

Browse files
committed
update server methods in config
1 parent b1a9e7f commit d716ecf

5 files changed

Lines changed: 121 additions & 76 deletions

File tree

openml/_config.py

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,39 +27,54 @@
2727
openml_logger = logging.getLogger("openml")
2828

2929

30-
_SERVERS_REGISTRY: dict[str, dict[APIVersion, dict[str, str | None]]] = {
31-
"production": {
32-
APIVersion.V1: {
33-
"server": "https://www.openml.org/api/v1/xml/",
34-
"apikey": None,
35-
},
36-
APIVersion.V2: {
37-
"server": None,
38-
"apikey": None,
39-
},
30+
_PROD_SERVERS: dict[APIVersion, dict[str, str | None]] = {
31+
APIVersion.V1: {
32+
"server": "https://www.openml.org/api/v1/xml/",
33+
"apikey": None,
34+
},
35+
APIVersion.V2: {
36+
"server": None,
37+
"apikey": None,
38+
},
39+
}
40+
41+
_TEST_SERVERS: dict[APIVersion, dict[str, str | None]] = {
42+
APIVersion.V1: {
43+
"server": "https://test.openml.org/api/v1/xml/",
44+
"apikey": "normaluser",
4045
},
41-
"test": {
42-
APIVersion.V1: {
43-
"server": "https://test.openml.org/api/v1/xml/",
44-
"apikey": "normaluser",
45-
},
46-
APIVersion.V2: {
47-
"server": None,
48-
"apikey": None,
49-
},
46+
APIVersion.V2: {
47+
"server": None,
48+
"apikey": None,
5049
},
51-
"local": {
52-
APIVersion.V1: {
53-
"server": "http://localhost:8080/api/v1/xml/",
54-
"apikey": "normaluser",
55-
},
56-
APIVersion.V2: {
57-
"server": "http://localhost:8082/",
58-
"apikey": "AD000000000000000000000000000000",
59-
},
50+
}
51+
52+
_TEST_SERVERS_LOCAL: dict[APIVersion, dict[str, str | None]] = {
53+
APIVersion.V1: {
54+
"server": "http://localhost:8080/api/v1/xml/",
55+
"apikey": "normaluser",
56+
},
57+
APIVersion.V2: {
58+
"server": "http://localhost:8082/",
59+
"apikey": "AD000000000000000000000000000000",
6060
},
6161
}
6262

63+
_SERVERS_REGISTRY: dict[str, dict[APIVersion, dict[str, str | None]]] = {
64+
"production": _PROD_SERVERS,
65+
"test": _TEST_SERVERS_LOCAL
66+
if os.getenv("OPENML_USE_LOCAL_SERVICES") == "true"
67+
else _TEST_SERVERS,
68+
}
69+
70+
71+
def _get_servers(mode: str) -> dict[APIVersion, dict[str, str | None]]:
72+
if mode not in _SERVERS_REGISTRY:
73+
raise ValueError(
74+
f'invalid mode="{mode}" allowed modes: {", ".join(list(_SERVERS_REGISTRY.keys()))}'
75+
)
76+
return deepcopy(_SERVERS_REGISTRY[mode])
77+
6378

6479
def _resolve_default_cache_dir() -> Path:
6580
user_defined_cache_dir = os.environ.get("OPENML_CACHE_DIR")
@@ -97,7 +112,7 @@ class OpenMLConfig:
97112
"""Dataclass storing the OpenML configuration."""
98113

99114
servers: dict[APIVersion, dict[str, str | None]] = field(
100-
default_factory=lambda: deepcopy(_SERVERS_REGISTRY["production"])
115+
default_factory=lambda: _get_servers("production")
101116
)
102117
api_version: APIVersion = APIVersion.V1
103118
fallback_api_version: APIVersion | None = None
@@ -251,17 +266,25 @@ def get_server_base_url(self) -> str:
251266
domain, _ = self._config.server.split("/api", maxsplit=1)
252267
return domain.replace("api", "www")
253268

254-
def get_servers(self, mode: str) -> dict[APIVersion, dict[str, str | None]]:
255-
if mode not in _SERVERS_REGISTRY:
256-
raise ValueError(
257-
f'invalid mode="{mode}" allowed modes: {", ".join(list(_SERVERS_REGISTRY.keys()))}'
258-
)
259-
return deepcopy(_SERVERS_REGISTRY[mode])
269+
def _get_servers(self, mode: str) -> dict[APIVersion, dict[str, str | None]]:
270+
return _get_servers(mode)
260271

261-
def set_servers(self, mode: str) -> None:
262-
servers = self.get_servers(mode)
272+
def _set_servers(self, mode: str) -> None:
273+
servers = self._get_servers(mode)
263274
self._config = replace(self._config, servers=servers)
264275

276+
def get_production_servers(self) -> dict[APIVersion, dict[str, str | None]]:
277+
return self._get_servers(mode="production")
278+
279+
def get_test_servers(self) -> dict[APIVersion, dict[str, str | None]]:
280+
return self._get_servers(mode="test")
281+
282+
def use_production_servers(self) -> None:
283+
self._set_servers(mode="production")
284+
285+
def use_test_servers(self) -> None:
286+
self._set_servers(mode="test")
287+
265288
def set_api_version(
266289
self,
267290
api_version: APIVersion,
@@ -498,7 +521,7 @@ class ConfigurationForExamples:
498521

499522
def __init__(self, manager: OpenMLConfigManager):
500523
self._manager = manager
501-
self._test_servers = manager.get_servers("test")
524+
self._test_servers = manager.get_test_servers()
502525

503526
def start_using_configuration_for_example(self) -> None:
504527
"""Sets the configuration to connect to the test server with valid apikey.

openml/cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def check_server(server: str) -> str:
112112

113113
def replace_shorthand(server: str) -> str:
114114
if server == "test":
115-
return cast("str", openml.config.get_servers("test")[APIVersion.V1]["server"])
115+
return cast("str", openml.config.get_test_servers()[APIVersion.V1]["server"])
116116
if server == "production_server":
117-
return cast("str", openml.config.get_servers("production")[APIVersion.V1]["server"])
117+
return cast("str", openml.config.get_production_servers()[APIVersion.V1]["server"])
118118
return server
119119

120120
configure_field(

openml/testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def use_production_server(self) -> None:
124124
125125
Please use this sparingly - it is better to use the test server.
126126
"""
127-
openml.config.set_servers("production")
127+
openml.config.use_production_servers()
128128

129129
def tearDown(self) -> None:
130130
"""Tear down the test"""

tests/conftest.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def delete_remote_files(tracker, flow_names) -> None:
9999
:param tracker: Dict
100100
:return: None
101101
"""
102-
openml.config.set_servers("test")
102+
openml.config.use_test_servers()
103103

104104
# reordering to delete sub flows at the end of flows
105105
# sub-flows have shorter names, hence, sorting by descending order of flow name length
@@ -252,22 +252,22 @@ def test_files_directory() -> Path:
252252

253253
@pytest.fixture(scope="session")
254254
def test_server_v1() -> str:
255-
return openml.config.get_servers("test")[APIVersion.V1]["server"]
255+
return openml.config.get_test_servers()[APIVersion.V1]["server"]
256256

257257

258258
@pytest.fixture(scope="session")
259259
def test_apikey_v1() -> str:
260-
return openml.config.get_servers("test")[APIVersion.V1]["apikey"]
260+
return openml.config.get_test_servers()[APIVersion.V1]["apikey"]
261261

262262

263263
@pytest.fixture(scope="session")
264264
def test_server_v2() -> str:
265-
return openml.config.get_servers("test")[APIVersion.V2]["server"]
265+
return openml.config.get_test_servers()[APIVersion.V2]["server"]
266266

267267

268268
@pytest.fixture(scope="session")
269269
def test_apikey_v2() -> str:
270-
return openml.config.get_servers("test")[APIVersion.V2]["apikey"]
270+
return openml.config.get_test_servers()[APIVersion.V2]["apikey"]
271271

272272

273273
@pytest.fixture(autouse=True, scope="function")
@@ -288,18 +288,11 @@ def as_robot() -> Iterator[None]:
288288

289289
@pytest.fixture(autouse=True)
290290
def with_server(request):
291-
openml.config.set_api_version(APIVersion.V1)
292-
293291
if "production_server" in request.keywords:
294-
# use-production-server (remote)
295-
openml.config.set_servers("production")
296-
elif os.getenv("OPENML_USE_LOCAL_SERVICES") == "true":
297-
# use-test-server (local)
298-
openml.config.set_servers("local")
299-
else:
300-
# use-test-server (remote)
301-
openml.config.set_servers("test")
302-
292+
openml.config.use_production_servers()
293+
yield
294+
return
295+
openml.config.use_test_servers()
303296
yield
304297

305298

tests/test_openml/test_config.py

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_get_config_as_dict(self):
8181
_config = {}
8282
_config["api_version"] = APIVersion.V1
8383
_config["fallback_api_version"] = None
84-
_config["servers"] = openml.config.get_servers("test")
84+
_config["servers"] = openml.config.get_test_servers()
8585
_config["cachedir"] = self.workdir
8686
_config["avoid_duplicate_runs"] = False
8787
_config["connection_n_retries"] = 20
@@ -96,7 +96,7 @@ def test_setup_with_config(self):
9696
_config = {}
9797
_config["api_version"] = APIVersion.V1
9898
_config["fallback_api_version"] = None
99-
_config["servers"] = openml.config.get_servers("test")
99+
_config["servers"] = openml.config.get_test_servers()
100100
_config["cachedir"] = self.workdir
101101
_config["avoid_duplicate_runs"] = True
102102
_config["retry_policy"] = "human"
@@ -113,21 +113,22 @@ class TestConfigurationForExamples(openml.testing.TestBase):
113113
@pytest.mark.production_server()
114114
def test_switch_to_example_configuration(self):
115115
"""Verifies the test configuration is loaded properly."""
116-
openml.config.set_servers("production")
116+
openml.config.use_production_servers()
117117

118118
openml.config.start_using_configuration_for_example()
119119

120-
openml.config.servers = openml.config.get_servers("test")
120+
assert openml.config.servers == openml.config.get_test_servers()
121121

122122
@pytest.mark.production_server()
123123
def test_switch_from_example_configuration(self):
124124
"""Verifies the previous configuration is loaded after stopping."""
125125
# Below is the default test key which would be used anyway, but just for clarity:
126-
openml.config.set_servers("production")
126+
openml.config.use_production_servers()
127127

128128
openml.config.start_using_configuration_for_example()
129129
openml.config.stop_using_configuration_for_example()
130-
openml.config.servers = openml.config.get_servers("production")
130+
131+
assert openml.config.servers == openml.config.get_production_servers()
131132

132133
def test_example_configuration_stop_before_start(self):
133134
"""Verifies an error is raised if `stop_...` is called before `start_...`."""
@@ -144,13 +145,13 @@ def test_example_configuration_stop_before_start(self):
144145
@pytest.mark.production_server()
145146
def test_example_configuration_start_twice(self):
146147
"""Checks that the original config can be returned to if `start..` is called twice."""
147-
openml.config.set_servers("production")
148+
openml.config.use_production_servers()
148149

149150
openml.config.start_using_configuration_for_example()
150151
openml.config.start_using_configuration_for_example()
151152
openml.config.stop_using_configuration_for_example()
152153

153-
assert openml.config.servers == openml.config.get_servers("production")
154+
assert openml.config.servers == openml.config.get_production_servers()
154155

155156

156157
def test_configuration_file_not_overwritten_on_load():
@@ -192,28 +193,28 @@ def test_openml_cache_dir_env_var(tmp_path: Path) -> None:
192193
assert openml.config.get_cache_directory() == str(expected_path / "org" / "openml" / "www")
193194

194195

195-
@pytest.mark.parametrize("mode", ["production", "test", "local"])
196+
@pytest.mark.parametrize("mode", ["production", "test"])
196197
@pytest.mark.parametrize("api_version", [APIVersion.V1, APIVersion.V2])
197198
def test_get_servers(mode, api_version):
198-
orig_servers = openml.config.get_servers(mode)
199+
orig_servers = openml.config._get_servers(mode)
199200

200-
openml.config.set_servers(mode)
201+
openml.config._set_servers(mode)
201202
openml.config.set_api_version(api_version)
202203
openml.config.server = "temp-server1"
203204
openml.config.apikey = "temp-apikey1"
204-
openml.config.get_servers(mode)["server"] = 'temp-server2'
205-
openml.config.get_servers(mode)["apikey"] = 'temp-server2'
205+
openml.config._get_servers(mode)["server"] = 'temp-server2'
206+
openml.config._get_servers(mode)["apikey"] = 'temp-server2'
206207

207-
assert openml.config.get_servers(mode) == orig_servers
208+
assert openml.config._get_servers(mode) == orig_servers
208209

209210

210-
@pytest.mark.parametrize("mode", ["production", "test", "local"])
211+
@pytest.mark.parametrize("mode", ["production", "test"])
211212
@pytest.mark.parametrize("api_version", [APIVersion.V1, APIVersion.V2])
212213
def test_set_servers(mode, api_version):
213-
openml.config.set_servers(mode)
214+
openml.config._set_servers(mode)
214215
openml.config.set_api_version(api_version)
215216

216-
assert openml.config.servers == openml.config.get_servers(mode)
217+
assert openml.config.servers == openml.config._get_servers(mode)
217218
assert openml.config.api_version == api_version
218219

219220
openml.config.server = "temp-server"
@@ -224,6 +225,34 @@ def test_set_servers(mode, api_version):
224225

225226
for version, servers in openml.config.servers.items():
226227
if version == api_version:
227-
assert servers != openml.config.get_servers(mode)[version]
228+
assert servers != openml.config._get_servers(mode)[version]
228229
else:
229-
assert servers == openml.config.get_servers(mode)[version]
230+
assert servers == openml.config._get_servers(mode)[version]
231+
232+
233+
def test_get_production_servers():
234+
assert openml.config.get_production_servers() == openml.config._get_servers("production")
235+
236+
237+
def test_get_test_servers():
238+
assert openml.config.get_test_servers() == openml.config._get_servers("test")
239+
240+
241+
def test_use_production_servers():
242+
openml.config.use_production_servers()
243+
servers_1 = openml.config.servers
244+
245+
openml.config._set_servers("production")
246+
servers_2 = openml.config.servers
247+
248+
assert servers_1 == servers_2
249+
250+
251+
def test_use_test_servers():
252+
openml.config.use_test_servers()
253+
servers_1 = openml.config.servers
254+
255+
openml.config._set_servers("test")
256+
servers_2 = openml.config.servers
257+
258+
assert servers_1 == servers_2

0 commit comments

Comments
 (0)