|
19 | 19 | from typing import Any, ClassVar, Literal, cast |
20 | 20 | from urllib.parse import urlparse |
21 | 21 |
|
22 | | -from openml.enums import APIVersion |
| 22 | +from openml.enums import APIVersion, ServerMode |
23 | 23 |
|
24 | 24 | from .__version__ import __version__ |
25 | 25 |
|
|
56 | 56 | }, |
57 | 57 | APIVersion.V2: { |
58 | 58 | "server": "http://localhost:8082/", |
59 | | - "apikey": "AD000000000000000000000000000000", |
| 59 | + "apikey": "normaluser", |
60 | 60 | }, |
61 | 61 | } |
62 | 62 |
|
63 | | -_SERVERS_REGISTRY: dict[str, dict[APIVersion, dict[str, str | None]]] = { |
64 | | - "production": _PROD_SERVERS, |
65 | | - "test": _TEST_SERVERS_LOCAL |
66 | | - if os.getenv("OPENML_USE_LOCAL_SERVICES") == "true" |
67 | | - else _TEST_SERVERS, |
| 63 | +_SERVERS_REGISTRY: dict[ServerMode, dict[APIVersion, dict[str, str | None]]] = { |
| 64 | + ServerMode.PRODUCTION: _PROD_SERVERS, |
| 65 | + ServerMode.TEST: ( |
| 66 | + _TEST_SERVERS_LOCAL if os.getenv("OPENML_USE_LOCAL_SERVICES") == "true" else _TEST_SERVERS |
| 67 | + ), |
68 | 68 | } |
69 | 69 |
|
70 | 70 |
|
71 | | -def _get_servers(mode: str) -> dict[APIVersion, dict[str, str | None]]: |
72 | | - if mode not in _SERVERS_REGISTRY: |
73 | | - raise ValueError( |
74 | | - f'invalid mode="{mode}" allowed modes: {", ".join(list(_SERVERS_REGISTRY.keys()))}' |
75 | | - ) |
| 71 | +def _get_servers(mode: ServerMode) -> dict[APIVersion, dict[str, str | None]]: |
| 72 | + if mode not in ServerMode: |
| 73 | + raise ValueError(f'invalid mode="{mode}" allowed modes: {", ".join(list(ServerMode))}') |
76 | 74 | return deepcopy(_SERVERS_REGISTRY[mode]) |
77 | 75 |
|
78 | 76 |
|
@@ -112,7 +110,7 @@ class OpenMLConfig: |
112 | 110 | """Dataclass storing the OpenML configuration.""" |
113 | 111 |
|
114 | 112 | servers: dict[APIVersion, dict[str, str | None]] = field( |
115 | | - default_factory=lambda: _get_servers("production") |
| 113 | + default_factory=lambda: _get_servers(ServerMode.PRODUCTION) |
116 | 114 | ) |
117 | 115 | api_version: APIVersion = APIVersion.V1 |
118 | 116 | fallback_api_version: APIVersion | None = None |
@@ -266,24 +264,24 @@ def get_server_base_url(self) -> str: |
266 | 264 | domain, _ = self._config.server.split("/api", maxsplit=1) |
267 | 265 | return domain.replace("api", "www") |
268 | 266 |
|
269 | | - def _get_servers(self, mode: str) -> dict[APIVersion, dict[str, str | None]]: |
| 267 | + def _get_servers(self, mode: ServerMode) -> dict[APIVersion, dict[str, str | None]]: |
270 | 268 | return _get_servers(mode) |
271 | 269 |
|
272 | | - def _set_servers(self, mode: str) -> None: |
| 270 | + def _set_servers(self, mode: ServerMode) -> None: |
273 | 271 | servers = self._get_servers(mode) |
274 | 272 | self._config = replace(self._config, servers=servers) |
275 | 273 |
|
276 | 274 | def get_production_servers(self) -> dict[APIVersion, dict[str, str | None]]: |
277 | | - return self._get_servers(mode="production") |
| 275 | + return self._get_servers(mode=ServerMode.PRODUCTION) |
278 | 276 |
|
279 | 277 | def get_test_servers(self) -> dict[APIVersion, dict[str, str | None]]: |
280 | | - return self._get_servers(mode="test") |
| 278 | + return self._get_servers(mode=ServerMode.TEST) |
281 | 279 |
|
282 | 280 | def use_production_servers(self) -> None: |
283 | | - self._set_servers(mode="production") |
| 281 | + self._set_servers(mode=ServerMode.PRODUCTION) |
284 | 282 |
|
285 | 283 | def use_test_servers(self) -> None: |
286 | | - self._set_servers(mode="test") |
| 284 | + self._set_servers(mode=ServerMode.TEST) |
287 | 285 |
|
288 | 286 | def set_api_version( |
289 | 287 | self, |
|
0 commit comments