|
11 | 11 | import unittest |
12 | 12 | from pathlib import Path |
13 | 13 | from typing import ClassVar |
14 | | -from urllib.parse import urljoin |
15 | 14 |
|
16 | 15 | import requests |
17 | 16 |
|
18 | 17 | import openml |
19 | 18 | from openml._api import HTTPCache, HTTPClient, MinIOClient |
20 | | -from openml.enums import RetryPolicy |
| 19 | +from openml.enums import APIVersion, RetryPolicy |
21 | 20 | from openml.exceptions import OpenMLServerException |
22 | 21 | from openml.tasks import TaskType |
23 | 22 |
|
@@ -282,96 +281,42 @@ def _check_fold_timing_evaluations( # noqa: PLR0913 |
282 | 281 | assert evaluation <= max_val |
283 | 282 |
|
284 | 283 |
|
285 | | -class TestAPIBase(unittest.TestCase): |
286 | | - server: str |
287 | | - base_url: str |
288 | | - api_key: str |
289 | | - timeout_seconds: int |
290 | | - retries: int |
291 | | - retry_policy: RetryPolicy |
292 | | - dir: str |
293 | | - ttl: int |
| 284 | +class TestAPIBase(TestBase): |
294 | 285 | cache: HTTPCache |
295 | | - http_client: HTTPClient |
296 | | - |
297 | | - def setUp(self) -> None: |
298 | | - self.server = "https://test.openml.org/" |
299 | | - self.base_url = "api/v1/xml" |
300 | | - self.api_key = "normaluser" |
301 | | - self.timeout_seconds = 10 |
302 | | - self.retries = 3 |
303 | | - self.retry_policy = RetryPolicy.HUMAN |
304 | | - self.dir = "test_cache" |
305 | | - self.ttl = 60 * 60 * 24 * 7 |
306 | | - |
307 | | - self.cache = self._get_http_cache( |
308 | | - path=Path(self.dir), |
309 | | - ttl=self.ttl, |
310 | | - ) |
311 | | - self.http_client = self._get_http_client( |
312 | | - server=self.server, |
313 | | - base_url=self.base_url, |
314 | | - api_key=self.api_key, |
315 | | - timeout_seconds=self.timeout_seconds, |
316 | | - retries=self.retries, |
317 | | - retry_policy=self.retry_policy, |
318 | | - cache=self.cache, |
319 | | - ) |
320 | | - self.minio_client = self._get_minio_client(path=Path(self.dir)) |
| 286 | + http_clients: dict[APIVersion, HTTPClient] |
| 287 | + minio_client: MinIOClient |
321 | 288 |
|
322 | | - if self.cache.path.exists(): |
323 | | - shutil.rmtree(self.cache.path) |
| 289 | + def setUp(self, n_levels: int = 1, tmpdir_suffix: str = "") -> None: |
| 290 | + super().setUp(n_levels=n_levels, tmpdir_suffix=tmpdir_suffix) |
324 | 291 |
|
325 | | - def tearDown(self) -> None: |
326 | | - if self.cache.path.exists(): |
327 | | - shutil.rmtree(self.cache.path) |
| 292 | + retries = self.connection_n_retries |
| 293 | + retry_policy = RetryPolicy.HUMAN if self.retry_policy == "human" else RetryPolicy.ROBOT |
| 294 | + ttl = openml._backend.get_config_value("cache.ttl") |
| 295 | + cache_dir = self.static_cache_dir |
328 | 296 |
|
329 | | - def _get_http_cache( |
330 | | - self, |
331 | | - path: Path, |
332 | | - ttl: int, |
333 | | - ) -> HTTPCache: |
334 | | - return HTTPCache( |
335 | | - path=path, |
| 297 | + self.cache = HTTPCache( |
| 298 | + path=cache_dir, |
336 | 299 | ttl=ttl, |
337 | 300 | ) |
338 | | - |
339 | | - def _get_http_client( # noqa: PLR0913 |
340 | | - self, |
341 | | - server: str, |
342 | | - base_url: str, |
343 | | - api_key: str, |
344 | | - timeout_seconds: int, |
345 | | - retries: int, |
346 | | - retry_policy: RetryPolicy, |
347 | | - cache: HTTPCache | None = None, |
348 | | - ) -> HTTPClient: |
349 | | - return HTTPClient( |
350 | | - server=server, |
351 | | - base_url=base_url, |
352 | | - api_key=api_key, |
353 | | - timeout_seconds=timeout_seconds, |
354 | | - retries=retries, |
355 | | - retry_policy=retry_policy, |
356 | | - cache=cache, |
357 | | - ) |
358 | | - |
359 | | - def _get_minio_client( |
360 | | - self, |
361 | | - path: Path | None = None, |
362 | | - ) -> MinIOClient: |
363 | | - return MinIOClient(path=path) |
364 | | - |
365 | | - def _get_url( |
366 | | - self, |
367 | | - server: str | None = None, |
368 | | - base_url: str | None = None, |
369 | | - path: str | None = None, |
370 | | - ) -> str: |
371 | | - server = server if server else self.server |
372 | | - base_url = base_url if base_url else self.base_url |
373 | | - path = path if path else "" |
374 | | - return urljoin(self.server, urljoin(self.base_url, path)) |
| 301 | + self.http_clients = { |
| 302 | + APIVersion.V1: HTTPClient( |
| 303 | + server="https://test.openml.org/", |
| 304 | + base_url="api/v1/xml/", |
| 305 | + api_key="normaluser", |
| 306 | + retries=retries, |
| 307 | + retry_policy=retry_policy, |
| 308 | + cache=self.cache, |
| 309 | + ), |
| 310 | + APIVersion.V2: HTTPClient( |
| 311 | + server="http://localhost:8002/", |
| 312 | + base_url="", |
| 313 | + api_key="", |
| 314 | + retries=retries, |
| 315 | + retry_policy=retry_policy, |
| 316 | + cache=self.cache, |
| 317 | + ), |
| 318 | + } |
| 319 | + self.minio_client = MinIOClient(path=cache_dir) |
375 | 320 |
|
376 | 321 |
|
377 | 322 | def check_task_existence( |
|
0 commit comments