1212from io import BufferedReader
1313from urllib .parse import urlencode , urljoin
1414
15- import requests
16- import urllib3
15+ import aiofiles
16+ import aiofiles .os
17+ import aiohttp
1718from multidict import CIMultiDict , CIMultiDictProxy , MutableMultiMapping
1819
1920from pulp_glue .common import __version__
@@ -134,37 +135,12 @@ def __init__(
134135 if cid :
135136 self ._headers ["Correlation-Id" ] = cid
136137
137- self ._setup_session ()
138-
139138 self ._oauth2_lock = asyncio .Lock ()
140139 self ._oauth2_token : str | None = None
141140 self ._oauth2_expires : datetime = datetime .now ()
142141
143142 self .load_api (refresh_cache = refresh_cache )
144143
145- def _setup_session (self ) -> None :
146- # This is specific requests library.
147-
148- if self ._verify_ssl is False :
149- urllib3 .disable_warnings (urllib3 .exceptions .InsecureRequestWarning )
150-
151- self ._session : requests .Session = requests .session ()
152- # Don't redirect, because carrying auth accross redirects is unsafe.
153- self ._session .max_redirects = 0
154- self ._session .headers .update (self ._headers )
155- session_settings = self ._session .merge_environment_settings (
156- self ._base_url , {}, None , self ._verify_ssl , None
157- )
158- self ._session .verify = session_settings ["verify" ]
159- self ._session .proxies = session_settings ["proxies" ]
160-
161- if self ._auth_provider is not None and self ._auth_provider .can_complete_mutualTLS ():
162- cert , key = self ._auth_provider .tls_credentials ()
163- if key is not None :
164- self ._session .cert = (cert , key )
165- else :
166- self ._session .cert = cert
167-
168144 @property
169145 def base_url (self ) -> str :
170146 return self ._base_url
@@ -188,7 +164,10 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
188164 return _ssl_context
189165
190166 def load_api (self , refresh_cache : bool = False ) -> None :
191- # TODO: Find a way to invalidate caches on upstream change
167+ asyncio .run (self ._load_api (refresh_cache = refresh_cache ))
168+
169+ async def _load_api (self , refresh_cache : bool = False ) -> None :
170+ # TODO: Find a way to invalidate caches on upstream change.
192171 xdg_cache_home : str = os .environ .get ("XDG_CACHE_HOME" ) or "~/.cache"
193172 apidoc_cache : str = os .path .join (
194173 os .path .expanduser (xdg_cache_home ),
@@ -200,17 +179,17 @@ def load_api(self, refresh_cache: bool = False) -> None:
200179 if refresh_cache :
201180 # Fake that we did not find the cache.
202181 raise OSError ()
203- with open (apidoc_cache , "rb" ) as f :
204- data : bytes = f .read ()
182+ async with aiofiles . open (apidoc_cache , mode = "rb" ) as f :
183+ data : bytes = await f .read ()
205184 self ._parse_api (data )
206185 except Exception :
207- # Try again with a freshly downloaded version
208- data = self ._download_api ()
186+ # Try again with a freshly downloaded version.
187+ data = await self ._download_api ()
209188 self ._parse_api (data )
210- # Write to cache as it seems to be valid
211- os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
212- with open (apidoc_cache , "bw" ) as f :
213- f .write (data )
189+ # Write to cache as it seems to be valid.
190+ await aiofiles . os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
191+ async with aiofiles . open (apidoc_cache , mode = "bw" ) as f :
192+ await f .write (data )
214193
215194 def _parse_api (self , data : bytes ) -> None :
216195 self .api_spec : dict [str , t .Any ] = json .loads (data )
@@ -225,15 +204,18 @@ def _parse_api(self, data: bytes) -> None:
225204 if method in {"get" , "put" , "post" , "delete" , "options" , "head" , "patch" , "trace" }
226205 }
227206
228- def _download_api (self ) -> bytes :
229- try :
230- response : requests .Response = self ._session .get (urljoin (self ._base_url , self ._doc_path ))
231- except requests .RequestException as e :
232- raise OpenAPIError (str (e ))
233- response .raise_for_status ()
234- if "Correlation-Id" in response .headers :
235- self ._set_correlation_id (response .headers ["Correlation-Id" ])
236- return response .content
207+ async def _download_api (self ) -> bytes :
208+ response = await self ._send_request (
209+ _Request (
210+ operation_id = "" ,
211+ method = "get" ,
212+ url = urljoin (self ._base_url , self ._doc_path ),
213+ headers = self ._headers ,
214+ )
215+ )
216+ if response .status_code != 200 :
217+ raise OpenAPIError (_ ("Failed to find api docs." ))
218+ return response .body
237219
238220 def _set_correlation_id (self , correlation_id : str ) -> None :
239221 if "Correlation-Id" in self ._headers :
@@ -245,8 +227,6 @@ def _set_correlation_id(self, correlation_id: str) -> None:
245227 )
246228 else :
247229 self ._headers ["Correlation-Id" ] = correlation_id
248- # Do it for requests too...
249- self ._session .headers ["Correlation-Id" ] = correlation_id
250230
251231 def param_spec (
252232 self , operation_id : str , param_type : str , required : bool = False
@@ -463,7 +443,7 @@ def _render_request(
463443 security = security ,
464444 )
465445
466- def _log_request (self , request : _Request ) -> None :
446+ async def _log_request (self , request : _Request ) -> None :
467447 if request .params :
468448 qs = urlencode (request .params )
469449 self ._debug_callback (1 , f"{ request .operation_id } : { request .method } { request .url } ?{ qs } " )
@@ -489,7 +469,6 @@ def _select_proposal(
489469 if (
490470 request .security
491471 and "Authorization" not in request .headers
492- and "Authorization" not in self ._session .headers
493472 and self ._auth_provider is not None
494473 ):
495474 security_schemes : dict [str , dict [str , t .Any ]] = self .api_spec ["components" ][
@@ -561,7 +540,7 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
561540 headers = {"Authorization" : f"Basic { secret .decode ()} " },
562541 data = data ,
563542 )
564- response = self ._send_request (request )
543+ response = await self ._send_request (request )
565544 if response .status_code < 200 or response .status_code >= 300 :
566545 raise OpenAPIError ("Failed to fetch OAuth2 token" )
567546 result = json .loads (response .body )
@@ -570,38 +549,55 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
570549 new_token = True
571550 return new_token
572551
573- def _send_request (
552+ async def _send_request (
574553 self ,
575554 request : _Request ,
576555 ) -> _Response :
577- # This function uses requests to translate the _Request into a _Response.
556+ # This function uses aiohttp to translate the _Request into a _Response.
557+ data : aiohttp .FormData | dict [str , t .Any ] | str | None
558+ if request .files :
559+ assert isinstance (request .data , dict )
560+ # Maybe assert on the content type header.
561+ data = aiohttp .FormData (default_to_multipart = True )
562+ for key , value in request .data .items ():
563+ data .add_field (key , encode_param (value ))
564+ for key , (name , value , content_type ) in request .files .items ():
565+ data .add_field (key , value , filename = name , content_type = content_type )
566+ else :
567+ data = request .data
578568 try :
579- r = self ._session .request (
580- request .method ,
581- request .url ,
582- params = request .params ,
583- headers = request .headers ,
584- data = request .data ,
585- files = request .files ,
586- )
587- response = _Response (status_code = r .status_code , headers = r .headers , body = r .content )
588- except requests .TooManyRedirects as e :
589- assert e .response is not None
569+ async with aiohttp .ClientSession () as session :
570+ async with session .request (
571+ request .method ,
572+ request .url ,
573+ params = request .params ,
574+ headers = request .headers ,
575+ data = data ,
576+ ssl = self .ssl_context ,
577+ max_redirects = 0 ,
578+ ) as r :
579+ response_body = await r .read ()
580+ response = _Response (
581+ status_code = r .status , headers = r .headers , body = response_body
582+ )
583+ except aiohttp .TooManyRedirects as e :
584+ # We could handle that in the middleware...
585+ assert e .history [- 1 ] is not None
590586 raise OpenAPIError (
591587 _ (
592588 "Received redirect to '{new_url} from {old_url}'."
593589 " Please check your configuration."
594590 ).format (
595- new_url = e .response .headers ["location" ],
591+ new_url = e .history [ - 1 ] .headers ["location" ],
596592 old_url = request .url ,
597593 )
598594 )
599- except requests . RequestException as e :
595+ except aiohttp . ClientResponseError as e :
600596 raise OpenAPIError (str (e ))
601597
602598 return response
603599
604- def _log_response (self , response : _Response ) -> None :
600+ async def _log_response (self , response : _Response ) -> None :
605601 self ._debug_callback (
606602 1 , _ ("Response: {status_code}" ).format (status_code = response .status_code )
607603 )
@@ -648,6 +644,22 @@ def call(
648644 parameters : dict [str , t .Any ] | None = None ,
649645 body : dict [str , t .Any ] | None = None ,
650646 validate_body : bool = True ,
647+ ) -> t .Any :
648+ return asyncio .run (
649+ self .async_call (
650+ operation_id = operation_id ,
651+ parameters = parameters ,
652+ body = body ,
653+ validate_body = validate_body ,
654+ )
655+ )
656+
657+ async def async_call (
658+ self ,
659+ operation_id : str ,
660+ parameters : dict [str , t .Any ] | None = None ,
661+ body : dict [str , t .Any ] | None = None ,
662+ validate_body : bool = True ,
651663 ) -> t .Any :
652664 """
653665 Make a call to the server.
@@ -702,37 +714,33 @@ def call(
702714 body ,
703715 validate_body = validate_body ,
704716 )
705- self ._log_request (request )
717+ await self ._log_request (request )
706718
707719 if self ._dry_run and request .method .upper () not in SAFE_METHODS :
708720 raise UnsafeCallError (_ ("Call aborted due to safe mode" ))
709721
710722 may_retry = False
711723 if proposal := self ._select_proposal (request ):
712724 assert len (proposal ) == 1 , "More complex security proposals are not implemented."
713- may_retry = asyncio . run ( self ._authenticate_request (request , proposal ) )
725+ may_retry = await self ._authenticate_request (request , proposal )
714726
715- response = self ._send_request (request )
727+ response = await self ._send_request (request )
716728
717729 if proposal is not None :
718730 assert self ._auth_provider is not None
719731 if may_retry and response .status_code == 401 :
720732 self ._oauth2_token = None
721- asyncio . run ( self ._authenticate_request (request , proposal ) )
722- response = self ._send_request (request )
733+ await self ._authenticate_request (request , proposal )
734+ response = await self ._send_request (request )
723735
724736 if response .status_code >= 200 and response .status_code < 300 :
725- asyncio .run (
726- self ._auth_provider .auth_success_hook (
727- proposal , self .api_spec ["components" ]["securitySchemes" ]
728- )
737+ await self ._auth_provider .auth_success_hook (
738+ proposal , self .api_spec ["components" ]["securitySchemes" ]
729739 )
730740 elif response .status_code == 401 :
731- asyncio .run (
732- self ._auth_provider .auth_failure_hook (
733- proposal , self .api_spec ["components" ]["securitySchemes" ]
734- )
741+ await self ._auth_provider .auth_failure_hook (
742+ proposal , self .api_spec ["components" ]["securitySchemes" ]
735743 )
736744
737- self ._log_response (response )
745+ await self ._log_response (response )
738746 return self ._parse_response (method_spec , response )
0 commit comments