Skip to content

Commit 935f0f4

Browse files
committed
implement HTTPClient.download and add tests
1 parent 10d134a commit 935f0f4

3 files changed

Lines changed: 123 additions & 3 deletions

File tree

openml/_api/clients/http.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3+
import hashlib
34
import json
45
import logging
56
import math
67
import random
78
import time
89
import xml
9-
from collections.abc import Mapping
10+
from collections.abc import Callable, Mapping
1011
from pathlib import Path
1112
from typing import Any
1213
from urllib.parse import urlencode, urljoin, urlparse
@@ -18,6 +19,8 @@
1819
from openml.__version__ import __version__
1920
from openml.enums import RetryPolicy
2021
from openml.exceptions import (
22+
OpenMLCacheRequiredError,
23+
OpenMLHashException,
2124
OpenMLNotAuthorizedError,
2225
OpenMLServerError,
2326
OpenMLServerException,
@@ -315,14 +318,15 @@ def _request( # noqa: PLR0913
315318

316319
return response, retry_raise_e
317320

318-
def request(
321+
def request( # noqa: PLR0913, C901
319322
self,
320323
method: str,
321324
path: str,
322325
*,
323326
use_cache: bool = False,
324327
reset_cache: bool = False,
325328
use_api_key: bool = False,
329+
md5_checksum: str | None = None,
326330
**request_kwargs: Any,
327331
) -> Response:
328332
url = urljoin(self.server, urljoin(self.base_url, path))
@@ -384,15 +388,28 @@ def request(
384388
cache_key = self.cache.get_key(url, params)
385389
self.cache.save(cache_key, response)
386390

391+
if md5_checksum is not None:
392+
self._verify_checksum(response, md5_checksum)
393+
387394
return response
388395

396+
def _verify_checksum(self, response: Response, md5_checksum: str) -> None:
397+
# ruff sees hashlib.md5 as insecure
398+
actual = hashlib.md5(response.content).hexdigest() # noqa: S324
399+
if actual != md5_checksum:
400+
raise OpenMLHashException(
401+
f"Checksum of downloaded file is unequal to the expected checksum {md5_checksum} "
402+
f"when downloading {response.url}.",
403+
)
404+
389405
def get(
390406
self,
391407
path: str,
392408
*,
393409
use_cache: bool = False,
394410
reset_cache: bool = False,
395411
use_api_key: bool = False,
412+
md5_checksum: str | None = None,
396413
**request_kwargs: Any,
397414
) -> Response:
398415
return self.request(
@@ -401,19 +418,22 @@ def get(
401418
use_cache=use_cache,
402419
reset_cache=reset_cache,
403420
use_api_key=use_api_key,
421+
md5_checksum=md5_checksum,
404422
**request_kwargs,
405423
)
406424

407425
def post(
408426
self,
409427
path: str,
428+
*,
429+
use_api_key: bool = True,
410430
**request_kwargs: Any,
411431
) -> Response:
412432
return self.request(
413433
method="POST",
414434
path=path,
415435
use_cache=False,
416-
use_api_key=True,
436+
use_api_key=use_api_key,
417437
**request_kwargs,
418438
)
419439

@@ -429,3 +449,33 @@ def delete(
429449
use_api_key=True,
430450
**request_kwargs,
431451
)
452+
453+
def download(
454+
self,
455+
url: str,
456+
handler: Callable[[Response, Path, str], Path] | None = None,
457+
encoding: str = "utf-8",
458+
file_name: str = "response.txt",
459+
md5_checksum: str | None = None,
460+
) -> Path:
461+
if self.cache is None:
462+
raise OpenMLCacheRequiredError(
463+
"A cache object is required for download, but none was provided in the HTTPClient."
464+
)
465+
base = self.cache.path
466+
file_path = base / "downloads" / urlparse(url).path.lstrip("/") / file_name
467+
file_path = file_path.expanduser()
468+
file_path.parent.mkdir(parents=True, exist_ok=True)
469+
if file_path.exists():
470+
return file_path
471+
472+
response = self.get(url, md5_checksum=md5_checksum)
473+
if handler is not None:
474+
return handler(response, file_path, encoding)
475+
476+
return self._text_handler(response, file_path, encoding)
477+
478+
def _text_handler(self, response: Response, path: Path, encoding: str) -> Path:
479+
with path.open("w", encoding=encoding) as f:
480+
f.write(response.text)
481+
return path

openml/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,7 @@ class ObjectNotPublishedError(PyOpenMLError):
6969

7070
class OpenMLNotSupportedError(PyOpenMLError):
7171
"""Raised when an API operation is not supported for a resource/version."""
72+
73+
74+
class OpenMLCacheRequiredError(PyOpenMLError):
75+
"""Raised when a cache object is required but not provided."""

tests/test_api/test_http.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import pytest
55
from openml.testing import TestAPIBase
66
import os
7+
from pathlib import Path
78
from urllib.parse import urljoin
89
from openml.enums import APIVersion
910
from openml._api import HTTPClient
11+
from openml.exceptions import OpenMLCacheRequiredError
1012

1113

1214
class TestHTTPClient(TestAPIBase):
@@ -174,3 +176,67 @@ def test_post_and_delete(self):
174176
if task_id is not None:
175177
del_response = self.http_client.delete(f"task/{task_id}")
176178
self.assertEqual(del_response.status_code, 200)
179+
180+
def test_download_requires_cache(self):
181+
client = HTTPClient(
182+
server=self.http_client.server,
183+
base_url=self.http_client.base_url,
184+
api_key=self.http_client.api_key,
185+
retries=1,
186+
retry_policy=self.http_client.retry_policy,
187+
cache=None,
188+
)
189+
190+
with pytest.raises(OpenMLCacheRequiredError):
191+
client.download("https://www.openml.org")
192+
193+
@pytest.mark.uses_test_server()
194+
def test_download_creates_file(self):
195+
# small stable resource
196+
url = self.http_client.server
197+
198+
path = self.http_client.download(
199+
url,
200+
file_name="index.html",
201+
)
202+
203+
assert path.exists()
204+
assert path.is_file()
205+
assert path.read_text(encoding="utf-8")
206+
207+
@pytest.mark.uses_test_server()
208+
def test_download_is_cached_on_disk(self):
209+
url = self.http_client.server
210+
211+
path1 = self.http_client.download(
212+
url,
213+
file_name="cached.html",
214+
)
215+
mtime1 = path1.stat().st_mtime
216+
217+
# second call should NOT re-download
218+
path2 = self.http_client.download(
219+
url,
220+
file_name="cached.html",
221+
)
222+
mtime2 = path2.stat().st_mtime
223+
224+
assert path1 == path2
225+
assert mtime1 == mtime2
226+
227+
@pytest.mark.uses_test_server()
228+
def test_download_respects_custom_handler(self):
229+
url = self.http_client.server
230+
231+
def handler(response, path: Path, encoding: str):
232+
path.write_text("HANDLED", encoding=encoding)
233+
return path
234+
235+
path = self.http_client.download(
236+
url,
237+
handler=handler,
238+
file_name="handled.txt",
239+
)
240+
241+
assert path.exists()
242+
assert path.read_text() == "HANDLED"

0 commit comments

Comments
 (0)