Skip to content

Commit aa91d92

Browse files
committed
Merge branch 'migration' of https://github.com/geetu040/openml-python into runs-migration-stacked
2 parents 4a24b83 + f926092 commit aa91d92

5 files changed

Lines changed: 41 additions & 29 deletions

File tree

openml/_api/clients/http.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def get_key(self, url: str, params: dict[str, Any]) -> str:
5050
"""
5151
Generate a filesystem-safe cache key for a request.
5252
53-
The key is constructed from the reversed domain components, URL path
54-
segments, and URL-encoded query parameters (excluding ``api_key``).
53+
The key is constructed from URL path segments and
54+
URL-encoded query parameters (excluding ``api_key``).
5555
5656
Parameters
5757
----------
@@ -66,13 +66,12 @@ def get_key(self, url: str, params: dict[str, Any]) -> str:
6666
A relative path string representing the cache key.
6767
"""
6868
parsed_url = urlparse(url)
69-
netloc_parts = parsed_url.netloc.split(".")[::-1]
7069
path_parts = parsed_url.path.strip("/").split("/")
7170

7271
filtered_params = {k: v for k, v in params.items() if k != "api_key"}
7372
params_part = [urlencode(filtered_params)] if filtered_params else []
7473

75-
return str(Path(*netloc_parts, *path_parts, *params_part))
74+
return str(Path(*path_parts, *params_part))
7675

7776
def _key_to_path(self, key: str) -> Path:
7877
"""

openml/_config.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Any, ClassVar, Literal, cast
2020
from urllib.parse import urlparse
2121

22-
from openml.enums import APIVersion
22+
from openml.enums import APIVersion, ServerMode
2323

2424
from .__version__ import __version__
2525

@@ -56,23 +56,21 @@
5656
},
5757
APIVersion.V2: {
5858
"server": "http://localhost:8082/",
59-
"apikey": "AD000000000000000000000000000000",
59+
"apikey": "normaluser",
6060
},
6161
}
6262

63-
_SERVERS_REGISTRY: dict[str, dict[APIVersion, dict[str, str | None]]] = {
64-
"production": _PROD_SERVERS,
65-
"test": _TEST_SERVERS_LOCAL
66-
if os.getenv("OPENML_USE_LOCAL_SERVICES") == "true"
67-
else _TEST_SERVERS,
63+
_SERVERS_REGISTRY: dict[ServerMode, dict[APIVersion, dict[str, str | None]]] = {
64+
ServerMode.PRODUCTION: _PROD_SERVERS,
65+
ServerMode.TEST: (
66+
_TEST_SERVERS_LOCAL if os.getenv("OPENML_USE_LOCAL_SERVICES") == "true" else _TEST_SERVERS
67+
),
6868
}
6969

7070

71-
def _get_servers(mode: str) -> dict[APIVersion, dict[str, str | None]]:
72-
if mode not in _SERVERS_REGISTRY:
73-
raise ValueError(
74-
f'invalid mode="{mode}" allowed modes: {", ".join(list(_SERVERS_REGISTRY.keys()))}'
75-
)
71+
def _get_servers(mode: ServerMode) -> dict[APIVersion, dict[str, str | None]]:
72+
if mode not in ServerMode:
73+
raise ValueError(f'invalid mode="{mode}" allowed modes: {", ".join(list(ServerMode))}')
7674
return deepcopy(_SERVERS_REGISTRY[mode])
7775

7876

@@ -112,7 +110,7 @@ class OpenMLConfig:
112110
"""Dataclass storing the OpenML configuration."""
113111

114112
servers: dict[APIVersion, dict[str, str | None]] = field(
115-
default_factory=lambda: _get_servers("production")
113+
default_factory=lambda: _get_servers(ServerMode.PRODUCTION)
116114
)
117115
api_version: APIVersion = APIVersion.V1
118116
fallback_api_version: APIVersion | None = None
@@ -266,24 +264,24 @@ def get_server_base_url(self) -> str:
266264
domain, _ = self._config.server.split("/api", maxsplit=1)
267265
return domain.replace("api", "www")
268266

269-
def _get_servers(self, mode: str) -> dict[APIVersion, dict[str, str | None]]:
267+
def _get_servers(self, mode: ServerMode) -> dict[APIVersion, dict[str, str | None]]:
270268
return _get_servers(mode)
271269

272-
def _set_servers(self, mode: str) -> None:
270+
def _set_servers(self, mode: ServerMode) -> None:
273271
servers = self._get_servers(mode)
274272
self._config = replace(self._config, servers=servers)
275273

276274
def get_production_servers(self) -> dict[APIVersion, dict[str, str | None]]:
277-
return self._get_servers(mode="production")
275+
return self._get_servers(mode=ServerMode.PRODUCTION)
278276

279277
def get_test_servers(self) -> dict[APIVersion, dict[str, str | None]]:
280-
return self._get_servers(mode="test")
278+
return self._get_servers(mode=ServerMode.TEST)
281279

282280
def use_production_servers(self) -> None:
283-
self._set_servers(mode="production")
281+
self._set_servers(mode=ServerMode.PRODUCTION)
284282

285283
def use_test_servers(self) -> None:
286-
self._set_servers(mode="test")
284+
self._set_servers(mode=ServerMode.TEST)
287285

288286
def set_api_version(
289287
self,

openml/enums.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
from enum import Enum
44

55

6+
class ServerMode(str, Enum):
7+
"""Supported modes in server."""
8+
9+
PRODUCTION = "production"
10+
TEST = "test"
11+
12+
613
class APIVersion(str, Enum):
714
"""Supported OpenML API versions."""
815

tests/test_api/test_http.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,30 @@ def sample_download_url_v1(test_server_v1) -> str:
4242
def test_cache(cache, sample_url_v1):
4343
params = {"param1": "value1", "param2": "value2"}
4444

45+
# validate key
46+
4547
parsed_url = urlparse(sample_url_v1)
46-
netloc_parts = parsed_url.netloc.split(".")[::-1]
4748
path_parts = parsed_url.path.strip("/").split("/")
4849
params_key = "&".join([f"{k}={v}" for k, v in params.items()])
4950

50-
5151
key = cache.get_key(sample_url_v1, params)
5252

5353
expected_key = os.path.join(
54-
*netloc_parts,
5554
*path_parts,
5655
params_key,
5756
)
5857

5958
assert key == expected_key
6059

60+
# validate path
61+
62+
path = cache._key_to_path(key)
63+
expected_path = Path(openml.config.get_cache_directory()).joinpath(key)
64+
65+
assert path == expected_path
66+
67+
# validate save/load
68+
6169
# mock response
6270
req = Request("GET", sample_url_v1).prepare()
6371
response = Response()

tests/test_openml/test_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import openml
1717
import openml.testing
1818
from openml.testing import TestBase
19-
from openml.enums import APIVersion
19+
from openml.enums import APIVersion, ServerMode
2020

2121

2222
@contextmanager
@@ -193,7 +193,7 @@ def test_openml_cache_dir_env_var(tmp_path: Path) -> None:
193193
assert openml.config.get_cache_directory() == str(expected_path / "org" / "openml" / "www")
194194

195195

196-
@pytest.mark.parametrize("mode", ["production", "test"])
196+
@pytest.mark.parametrize("mode", list(ServerMode))
197197
@pytest.mark.parametrize("api_version", [APIVersion.V1, APIVersion.V2])
198198
def test_get_servers(mode, api_version):
199199
orig_servers = openml.config._get_servers(mode)
@@ -208,7 +208,7 @@ def test_get_servers(mode, api_version):
208208
assert openml.config._get_servers(mode) == orig_servers
209209

210210

211-
@pytest.mark.parametrize("mode", ["production", "test"])
211+
@pytest.mark.parametrize("mode", list(ServerMode))
212212
@pytest.mark.parametrize("api_version", [APIVersion.V1, APIVersion.V2])
213213
def test_set_servers(mode, api_version):
214214
openml.config._set_servers(mode)

0 commit comments

Comments
 (0)