Skip to content

Commit 1ecbbba

Browse files
committed
update config for ServerMode
1 parent 7d61107 commit 1ecbbba

1 file changed

Lines changed: 16 additions & 18 deletions

File tree

openml/_config.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Any, ClassVar, Literal, cast
2020
from urllib.parse import urlparse
2121

22-
from openml.enums import APIVersion
22+
from openml.enums import APIVersion, ServerMode
2323

2424
from .__version__ import __version__
2525

@@ -60,19 +60,17 @@
6060
},
6161
}
6262

63-
_SERVERS_REGISTRY: dict[str, dict[APIVersion, dict[str, str | None]]] = {
64-
"production": _PROD_SERVERS,
65-
"test": _TEST_SERVERS_LOCAL
66-
if os.getenv("OPENML_USE_LOCAL_SERVICES") == "true"
67-
else _TEST_SERVERS,
63+
_SERVERS_REGISTRY: dict[ServerMode, dict[APIVersion, dict[str, str | None]]] = {
64+
ServerMode.PRODUCTION: _PROD_SERVERS,
65+
ServerMode.TEST: (
66+
_TEST_SERVERS_LOCAL if os.getenv("OPENML_USE_LOCAL_SERVICES") == "true" else _TEST_SERVERS
67+
),
6868
}
6969

7070

71-
def _get_servers(mode: str) -> dict[APIVersion, dict[str, str | None]]:
72-
if mode not in _SERVERS_REGISTRY:
73-
raise ValueError(
74-
f'invalid mode="{mode}" allowed modes: {", ".join(list(_SERVERS_REGISTRY.keys()))}'
75-
)
71+
def _get_servers(mode: ServerMode) -> dict[APIVersion, dict[str, str | None]]:
72+
if mode not in ServerMode:
73+
raise ValueError(f'invalid mode="{mode}" allowed modes: {", ".join(list(ServerMode))}')
7674
return deepcopy(_SERVERS_REGISTRY[mode])
7775

7876

@@ -112,7 +110,7 @@ class OpenMLConfig:
112110
"""Dataclass storing the OpenML configuration."""
113111

114112
servers: dict[APIVersion, dict[str, str | None]] = field(
115-
default_factory=lambda: _get_servers("production")
113+
default_factory=lambda: _get_servers(ServerMode.PRODUCTION)
116114
)
117115
api_version: APIVersion = APIVersion.V1
118116
fallback_api_version: APIVersion | None = None
@@ -266,24 +264,24 @@ def get_server_base_url(self) -> str:
266264
domain, _ = self._config.server.split("/api", maxsplit=1)
267265
return domain.replace("api", "www")
268266

269-
def _get_servers(self, mode: str) -> dict[APIVersion, dict[str, str | None]]:
267+
def _get_servers(self, mode: ServerMode) -> dict[APIVersion, dict[str, str | None]]:
270268
return _get_servers(mode)
271269

272-
def _set_servers(self, mode: str) -> None:
270+
def _set_servers(self, mode: ServerMode) -> None:
273271
servers = self._get_servers(mode)
274272
self._config = replace(self._config, servers=servers)
275273

276274
def get_production_servers(self) -> dict[APIVersion, dict[str, str | None]]:
277-
return self._get_servers(mode="production")
275+
return self._get_servers(mode=ServerMode.PRODUCTION)
278276

279277
def get_test_servers(self) -> dict[APIVersion, dict[str, str | None]]:
280-
return self._get_servers(mode="test")
278+
return self._get_servers(mode=ServerMode.TEST)
281279

282280
def use_production_servers(self) -> None:
283-
self._set_servers(mode="production")
281+
self._set_servers(mode=ServerMode.PRODUCTION)
284282

285283
def use_test_servers(self) -> None:
286-
self._set_servers(mode="test")
284+
self._set_servers(mode=ServerMode.TEST)
287285

288286
def set_api_version(
289287
self,

0 commit comments

Comments
 (0)