Skip to content

Commit f6bc7f7

Browse files
committed
make TestAPIBase inherit TestBase
1 parent 9608c36 commit f6bc7f7

1 file changed

Lines changed: 14 additions & 24 deletions

File tree

openml/testing.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -281,52 +281,42 @@ def _check_fold_timing_evaluations( # noqa: PLR0913
281281
assert evaluation <= max_val
282282

283283

284-
class TestAPIBase(unittest.TestCase):
285-
retries: int
286-
retry_policy: RetryPolicy
287-
ttl: int
288-
cache_dir: Path
284+
class TestAPIBase(TestBase):
289285
cache: HTTPCache
290286
http_clients: dict[APIVersion, HTTPClient]
291287
minio_client: MinIOClient
292288

293-
def setUp(self) -> None:
294-
config = openml._backend.get_config()
295-
296-
self.retries = config.connection.retries
297-
self.retry_policy = config.connection.retry_policy
298-
self.ttl = config.cache.ttl
289+
def setUp(self, n_levels: int = 1, tmpdir_suffix: str = "") -> None:
290+
super().setUp(n_levels=n_levels, tmpdir_suffix=tmpdir_suffix)
299291

300-
abspath_this_file = Path(inspect.getfile(self.__class__)).absolute()
301-
self.cache_dir = abspath_this_file.parent.parent / "files"
302-
if not self.cache_dir.is_dir():
303-
raise ValueError(
304-
f"Cannot find test cache dir, expected it to be {self.cache_dir}!",
305-
)
292+
retries = self.connection_n_retries
293+
retry_policy = RetryPolicy.HUMAN if self.retry_policy == "human" else RetryPolicy.ROBOT
294+
ttl = openml._backend.get_config_value("cache.ttl")
295+
cache_dir = self.static_cache_dir
306296

307297
self.cache = HTTPCache(
308-
path=self.cache_dir,
309-
ttl=self.ttl,
298+
path=cache_dir,
299+
ttl=ttl,
310300
)
311301
self.http_clients = {
312302
APIVersion.V1: HTTPClient(
313303
server="https://test.openml.org/",
314304
base_url="api/v1/xml/",
315305
api_key="normaluser",
316-
retries=self.retries,
317-
retry_policy=self.retry_policy,
306+
retries=retries,
307+
retry_policy=retry_policy,
318308
cache=self.cache,
319309
),
320310
APIVersion.V2: HTTPClient(
321311
server="http://localhost:8002/",
322312
base_url="",
323313
api_key="",
324-
retries=self.retries,
325-
retry_policy=self.retry_policy,
314+
retries=retries,
315+
retry_policy=retry_policy,
326316
cache=self.cache,
327317
),
328318
}
329-
self.minio_client = MinIOClient(path=self.cache_dir)
319+
self.minio_client = MinIOClient(path=cache_dir)
330320

331321

332322
def check_task_existence(

0 commit comments

Comments
 (0)