11from __future__ import annotations
22
3+ import hashlib
34import json
45import logging
56import math
67import random
78import time
89import xml
9- from collections .abc import Mapping
10+ from collections .abc import Callable , Mapping
1011from pathlib import Path
1112from typing import Any
1213from urllib .parse import urlencode , urljoin , urlparse
1819from openml .__version__ import __version__
1920from openml .enums import RetryPolicy
2021from openml .exceptions import (
22+ OpenMLCacheRequiredError ,
23+ OpenMLHashException ,
2124 OpenMLNotAuthorizedError ,
2225 OpenMLServerError ,
2326 OpenMLServerException ,
@@ -315,14 +318,15 @@ def _request( # noqa: PLR0913
315318
316319 return response , retry_raise_e
317320
318- def request (
321+ def request ( # noqa: PLR0913, C901
319322 self ,
320323 method : str ,
321324 path : str ,
322325 * ,
323326 use_cache : bool = False ,
324327 reset_cache : bool = False ,
325328 use_api_key : bool = False ,
329+ md5_checksum : str | None = None ,
326330 ** request_kwargs : Any ,
327331 ) -> Response :
328332 url = urljoin (self .server , urljoin (self .base_url , path ))
@@ -384,15 +388,28 @@ def request(
384388 cache_key = self .cache .get_key (url , params )
385389 self .cache .save (cache_key , response )
386390
391+ if md5_checksum is not None :
392+ self ._verify_checksum (response , md5_checksum )
393+
387394 return response
388395
def _verify_checksum(self, response: Response, md5_checksum: str) -> None:
    """Compare the MD5 digest of a response body with an expected digest.

    Parameters
    ----------
    response : Response
        Response whose raw ``content`` bytes are hashed.
    md5_checksum : str
        Expected MD5 digest as a hexadecimal string. Letter case and
        surrounding whitespace are ignored.

    Raises
    ------
    OpenMLHashException
        If the digest of ``response.content`` differs from ``md5_checksum``.
    """
    # ruff sees hashlib.md5 as insecure
    actual = hashlib.md5(response.content).hexdigest()  # noqa: S324
    # hexdigest() always yields lowercase hex, so normalise the expected value;
    # otherwise an uppercase digest would be reported as a mismatch.
    expected = md5_checksum.strip().lower()
    if actual != expected:
        raise OpenMLHashException(
            f"Checksum of downloaded file is unequal to the expected checksum {md5_checksum} "
            f"when downloading {response.url}.",
        )
404+
389405 def get (
390406 self ,
391407 path : str ,
392408 * ,
393409 use_cache : bool = False ,
394410 reset_cache : bool = False ,
395411 use_api_key : bool = False ,
412+ md5_checksum : str | None = None ,
396413 ** request_kwargs : Any ,
397414 ) -> Response :
398415 return self .request (
@@ -401,19 +418,22 @@ def get(
401418 use_cache = use_cache ,
402419 reset_cache = reset_cache ,
403420 use_api_key = use_api_key ,
421+ md5_checksum = md5_checksum ,
404422 ** request_kwargs ,
405423 )
406424
def post(
    self,
    path: str,
    *,
    use_api_key: bool = True,
    **request_kwargs: Any,
) -> Response:
    """Issue a POST request for ``path`` through ``self.request``.

    Parameters
    ----------
    path : str
        Path handed to ``self.request``.
    use_api_key : bool, default=True
        Forwarded unchanged to ``self.request``.
    **request_kwargs : Any
        Additional keyword arguments forwarded to ``self.request``.

    Returns
    -------
    Response
        Whatever ``self.request`` returns.
    """
    # POST calls always bypass the cache.
    response = self.request(
        method="POST",
        path=path,
        use_cache=False,
        use_api_key=use_api_key,
        **request_kwargs,
    )
    return response
419439
@@ -429,3 +449,33 @@ def delete(
429449 use_api_key = True ,
430450 ** request_kwargs ,
431451 )
452+
def download(
    self,
    url: str,
    handler: Callable[[Response, Path, str], Path] | None = None,
    encoding: str = "utf-8",
    file_name: str = "response.txt",
    md5_checksum: str | None = None,
) -> Path:
    """Fetch ``url`` and store the result below the cache directory.

    The target is ``<cache path>/downloads/<url path>/<file_name>``. If that
    file already exists it is returned as-is and no request is made.

    Parameters
    ----------
    url : str
        Address to fetch; only its path component determines the target
        directory.
    handler : callable, optional
        Called as ``handler(response, file_path, encoding)`` to persist the
        response. When omitted, ``self._text_handler`` is used.
    encoding : str, default="utf-8"
        Passed through to the handler.
    file_name : str, default="response.txt"
        Name of the file inside the target directory.
    md5_checksum : str, optional
        Forwarded to ``self.get``.

    Returns
    -------
    Path
        The existing file, or whatever the handler returns.

    Raises
    ------
    OpenMLCacheRequiredError
        If ``self.cache`` is ``None``.
    """
    cache = self.cache
    if cache is None:
        raise OpenMLCacheRequiredError(
            "A cache object is required for download, but none was provided in the HTTPClient."
        )

    url_dir = urlparse(url).path.lstrip("/")
    target = (cache.path / "downloads" / url_dir / file_name).expanduser()
    if target.exists():
        return target
    target.parent.mkdir(parents=True, exist_ok=True)

    response = self.get(url, md5_checksum=md5_checksum)
    writer = self._text_handler if handler is None else handler
    return writer(response, target, encoding)
477+
def _text_handler(self, response: Response, path: Path, encoding: str) -> Path:
    """Write ``response.text`` to ``path`` and return ``path``.

    The text is first written to a sibling ``<name>.part`` file and moved
    into place only after the write succeeded, so a failed write never
    leaves an empty or truncated file at ``path``.

    Parameters
    ----------
    response : Response
        Response whose ``text`` is written.
    path : Path
        Destination file; its parent directory must already exist.
    encoding : str
        Text encoding used for the file.

    Returns
    -------
    Path
        The ``path`` argument.
    """
    partial = path.with_name(path.name + ".part")
    try:
        with partial.open("w", encoding=encoding) as f:
            f.write(response.text)
        partial.replace(path)
    finally:
        # After a successful replace the temporary file is already gone;
        # after a failure this removes the incomplete data.
        partial.unlink(missing_ok=True)
    return path
0 commit comments