Skip to content

Commit c1018e0

Browse files
committed
WIP: Async support in pulp-glue
Replaces requests with aiohttp and changes the api.
1 parent 27f983f commit c1018e0

9 files changed

Lines changed: 118 additions & 98 deletions

File tree

CHANGES/pulp-glue/+aiohttp.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
WIP: Added async api to Pulp glue.

CHANGES/pulp-glue/+aiohttp.removal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Replaced requests with aiohttp.
2+
Breaking change: Reworked the contract around the `AuthProvider` to allow authentication to be coded independently of the underlying library.

lint_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ mypy==1.19.1
44
shellcheck-py==0.11.0.1
55

66
# Type annotation stubs
7+
types-aiofiles
78
types-pygments
89
types-PyYAML
9-
types-requests
1010
types-setuptools
1111
types-toml
1212

lower_bounds_constraints.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
aiofiles==25.1.0
2+
aiohttp==3.12.0
13
click==8.0.0
24
packaging==20.0
35
PyYAML==5.3

pulp-glue/pulp_glue/common/openapi.py

Lines changed: 88 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
from io import BufferedReader
1313
from urllib.parse import urlencode, urljoin
1414

15-
import requests
16-
import urllib3
15+
import aiofiles
16+
import aiofiles.os
17+
import aiohttp
1718
from multidict import CIMultiDict, CIMultiDictProxy, MutableMultiMapping
1819

1920
from pulp_glue.common import __version__
@@ -134,37 +135,12 @@ def __init__(
134135
if cid:
135136
self._headers["Correlation-Id"] = cid
136137

137-
self._setup_session()
138-
139138
self._oauth2_lock = asyncio.Lock()
140139
self._oauth2_token: str | None = None
141140
self._oauth2_expires: datetime = datetime.now()
142141

143142
self.load_api(refresh_cache=refresh_cache)
144143

145-
def _setup_session(self) -> None:
146-
# This is specific requests library.
147-
148-
if self._verify_ssl is False:
149-
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
150-
151-
self._session: requests.Session = requests.session()
152-
# Don't redirect, because carrying auth accross redirects is unsafe.
153-
self._session.max_redirects = 0
154-
self._session.headers.update(self._headers)
155-
session_settings = self._session.merge_environment_settings(
156-
self._base_url, {}, None, self._verify_ssl, None
157-
)
158-
self._session.verify = session_settings["verify"]
159-
self._session.proxies = session_settings["proxies"]
160-
161-
if self._auth_provider is not None and self._auth_provider.can_complete_mutualTLS():
162-
cert, key = self._auth_provider.tls_credentials()
163-
if key is not None:
164-
self._session.cert = (cert, key)
165-
else:
166-
self._session.cert = cert
167-
168144
@property
169145
def base_url(self) -> str:
170146
return self._base_url
@@ -188,7 +164,10 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
188164
return _ssl_context
189165

190166
def load_api(self, refresh_cache: bool = False) -> None:
191-
# TODO: Find a way to invalidate caches on upstream change
167+
asyncio.run(self._load_api(refresh_cache=refresh_cache))
168+
169+
async def _load_api(self, refresh_cache: bool = False) -> None:
170+
# TODO: Find a way to invalidate caches on upstream change.
192171
xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache"
193172
apidoc_cache: str = os.path.join(
194173
os.path.expanduser(xdg_cache_home),
@@ -200,17 +179,17 @@ def load_api(self, refresh_cache: bool = False) -> None:
200179
if refresh_cache:
201180
# Fake that we did not find the cache.
202181
raise OSError()
203-
with open(apidoc_cache, "rb") as f:
204-
data: bytes = f.read()
182+
async with aiofiles.open(apidoc_cache, mode="rb") as f:
183+
data: bytes = await f.read()
205184
self._parse_api(data)
206185
except Exception:
207-
# Try again with a freshly downloaded version
208-
data = self._download_api()
186+
# Try again with a freshly downloaded version.
187+
data = await self._download_api()
209188
self._parse_api(data)
210-
# Write to cache as it seems to be valid
211-
os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
212-
with open(apidoc_cache, "bw") as f:
213-
f.write(data)
189+
# Write to cache as it seems to be valid.
190+
await aiofiles.os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
191+
async with aiofiles.open(apidoc_cache, mode="bw") as f:
192+
await f.write(data)
214193

215194
def _parse_api(self, data: bytes) -> None:
216195
self.api_spec: dict[str, t.Any] = json.loads(data)
@@ -225,15 +204,18 @@ def _parse_api(self, data: bytes) -> None:
225204
if method in {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
226205
}
227206

228-
def _download_api(self) -> bytes:
229-
try:
230-
response: requests.Response = self._session.get(urljoin(self._base_url, self._doc_path))
231-
except requests.RequestException as e:
232-
raise OpenAPIError(str(e))
233-
response.raise_for_status()
234-
if "Correlation-Id" in response.headers:
235-
self._set_correlation_id(response.headers["Correlation-Id"])
236-
return response.content
207+
async def _download_api(self) -> bytes:
208+
response = await self._send_request(
209+
_Request(
210+
operation_id="",
211+
method="get",
212+
url=urljoin(self._base_url, self._doc_path),
213+
headers=self._headers,
214+
)
215+
)
216+
if response.status_code != 200:
217+
raise OpenAPIError(_("Failed to find api docs."))
218+
return response.body
237219

238220
def _set_correlation_id(self, correlation_id: str) -> None:
239221
if "Correlation-Id" in self._headers:
@@ -245,8 +227,6 @@ def _set_correlation_id(self, correlation_id: str) -> None:
245227
)
246228
else:
247229
self._headers["Correlation-Id"] = correlation_id
248-
# Do it for requests too...
249-
self._session.headers["Correlation-Id"] = correlation_id
250230

251231
def param_spec(
252232
self, operation_id: str, param_type: str, required: bool = False
@@ -463,7 +443,7 @@ def _render_request(
463443
security=security,
464444
)
465445

466-
def _log_request(self, request: _Request) -> None:
446+
async def _log_request(self, request: _Request) -> None:
467447
if request.params:
468448
qs = urlencode(request.params)
469449
self._debug_callback(1, f"{request.operation_id} : {request.method} {request.url}?{qs}")
@@ -489,7 +469,6 @@ def _select_proposal(
489469
if (
490470
request.security
491471
and "Authorization" not in request.headers
492-
and "Authorization" not in self._session.headers
493472
and self._auth_provider is not None
494473
):
495474
security_schemes: dict[str, dict[str, t.Any]] = self.api_spec["components"][
@@ -561,7 +540,7 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
561540
headers={"Authorization": f"Basic {secret.decode()}"},
562541
data=data,
563542
)
564-
response = self._send_request(request)
543+
response = await self._send_request(request)
565544
if response.status_code < 200 or response.status_code >= 300:
566545
raise OpenAPIError("Failed to fetch OAuth2 token")
567546
result = json.loads(response.body)
@@ -570,38 +549,55 @@ async def _fetch_oauth2_token(self, flow: dict[str, t.Any]) -> bool:
570549
new_token = True
571550
return new_token
572551

573-
def _send_request(
552+
async def _send_request(
574553
self,
575554
request: _Request,
576555
) -> _Response:
577-
# This function uses requests to translate the _Request into a _Response.
556+
# This function uses aiohttp to translate the _Request into a _Response.
557+
data: aiohttp.FormData | dict[str, t.Any] | str | None
558+
if request.files:
559+
assert isinstance(request.data, dict)
560+
# Maybe assert on the content type header.
561+
data = aiohttp.FormData(default_to_multipart=True)
562+
for key, value in request.data.items():
563+
data.add_field(key, encode_param(value))
564+
for key, (name, value, content_type) in request.files.items():
565+
data.add_field(key, value, filename=name, content_type=content_type)
566+
else:
567+
data = request.data
578568
try:
579-
r = self._session.request(
580-
request.method,
581-
request.url,
582-
params=request.params,
583-
headers=request.headers,
584-
data=request.data,
585-
files=request.files,
586-
)
587-
response = _Response(status_code=r.status_code, headers=r.headers, body=r.content)
588-
except requests.TooManyRedirects as e:
589-
assert e.response is not None
569+
async with aiohttp.ClientSession() as session:
570+
async with session.request(
571+
request.method,
572+
request.url,
573+
params=request.params,
574+
headers=request.headers,
575+
data=data,
576+
ssl=self.ssl_context,
577+
max_redirects=0,
578+
) as r:
579+
response_body = await r.read()
580+
response = _Response(
581+
status_code=r.status, headers=r.headers, body=response_body
582+
)
583+
except aiohttp.TooManyRedirects as e:
584+
# We could handle that in the middleware...
585+
assert e.history[-1] is not None
590586
raise OpenAPIError(
591587
_(
592588
"Received redirect to '{new_url} from {old_url}'."
593589
" Please check your configuration."
594590
).format(
595-
new_url=e.response.headers["location"],
591+
new_url=e.history[-1].headers["location"],
596592
old_url=request.url,
597593
)
598594
)
599-
except requests.RequestException as e:
595+
except aiohttp.ClientResponseError as e:
600596
raise OpenAPIError(str(e))
601597

602598
return response
603599

604-
def _log_response(self, response: _Response) -> None:
600+
async def _log_response(self, response: _Response) -> None:
605601
self._debug_callback(
606602
1, _("Response: {status_code}").format(status_code=response.status_code)
607603
)
@@ -648,6 +644,22 @@ def call(
648644
parameters: dict[str, t.Any] | None = None,
649645
body: dict[str, t.Any] | None = None,
650646
validate_body: bool = True,
647+
) -> t.Any:
648+
return asyncio.run(
649+
self.async_call(
650+
operation_id=operation_id,
651+
parameters=parameters,
652+
body=body,
653+
validate_body=validate_body,
654+
)
655+
)
656+
657+
async def async_call(
658+
self,
659+
operation_id: str,
660+
parameters: dict[str, t.Any] | None = None,
661+
body: dict[str, t.Any] | None = None,
662+
validate_body: bool = True,
651663
) -> t.Any:
652664
"""
653665
Make a call to the server.
@@ -702,37 +714,33 @@ def call(
702714
body,
703715
validate_body=validate_body,
704716
)
705-
self._log_request(request)
717+
await self._log_request(request)
706718

707719
if self._dry_run and request.method.upper() not in SAFE_METHODS:
708720
raise UnsafeCallError(_("Call aborted due to safe mode"))
709721

710722
may_retry = False
711723
if proposal := self._select_proposal(request):
712724
assert len(proposal) == 1, "More complex security proposals are not implemented."
713-
may_retry = asyncio.run(self._authenticate_request(request, proposal))
725+
may_retry = await self._authenticate_request(request, proposal)
714726

715-
response = self._send_request(request)
727+
response = await self._send_request(request)
716728

717729
if proposal is not None:
718730
assert self._auth_provider is not None
719731
if may_retry and response.status_code == 401:
720732
self._oauth2_token = None
721-
asyncio.run(self._authenticate_request(request, proposal))
722-
response = self._send_request(request)
733+
await self._authenticate_request(request, proposal)
734+
response = await self._send_request(request)
723735

724736
if response.status_code >= 200 and response.status_code < 300:
725-
asyncio.run(
726-
self._auth_provider.auth_success_hook(
727-
proposal, self.api_spec["components"]["securitySchemes"]
728-
)
737+
await self._auth_provider.auth_success_hook(
738+
proposal, self.api_spec["components"]["securitySchemes"]
729739
)
730740
elif response.status_code == 401:
731-
asyncio.run(
732-
self._auth_provider.auth_failure_hook(
733-
proposal, self.api_spec["components"]["securitySchemes"]
734-
)
741+
await self._auth_provider.auth_failure_hook(
742+
proposal, self.api_spec["components"]["securitySchemes"]
735743
)
736744

737-
self._log_response(response)
745+
await self._log_response(response)
738746
return self._parse_response(method_spec, response)

pulp-glue/pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ classifiers = [
2323
"Typing :: Typed",
2424
]
2525
dependencies = [
26+
"aiofiles>=25.1.0,<25.2",
27+
"aiohttp>=3.12.0,<3.14",
2628
"multidict>=6.0.5,<6.8",
2729
"packaging>=20.0,<=26.0", # CalVer
28-
"requests>=2.24.0,<2.33",
2930
"tomli>=2.0.0,<2.1;python_version<'3.11'",
3031
]
3132

pulp-glue/tests/test_auth_provider.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,7 @@ def test_can_complete_basic(self, provider: AuthProviderBase) -> None:
6363
assert provider.can_complete_http_basic()
6464

6565
def test_provides_username_and_password(self, provider: AuthProviderBase) -> None:
66-
assert asyncio.run(provider.http_basic_credentials()) == (
67-
b"user1",
68-
b"password1",
69-
)
66+
assert asyncio.run(provider.http_basic_credentials()) == (b"user1", b"password1")
7067

7168
def test_cannot_complete_mutualTLS(self, provider: AuthProviderBase) -> None:
7269
assert not provider.can_complete_mutualTLS()
@@ -104,10 +101,7 @@ def test_client_id_needs_client_secret(self) -> None:
104101
def test_can_complete_oauth2_client_credentials_and_provide_them(self) -> None:
105102
provider = GlueAuthProvider(client_id="client1", client_secret="secret1")
106103
assert provider.can_complete_oauth2_client_credentials([]) is True
107-
assert asyncio.run(provider.oauth2_client_credentials()) == (
108-
b"client1",
109-
b"secret1",
110-
)
104+
assert asyncio.run(provider.oauth2_client_credentials()) == (b"client1", b"secret1")
111105

112106
def test_can_complete_mutualTLS_and_provide_cert(self) -> None:
113107
provider = GlueAuthProvider(cert="FAKECERTIFICATE")

pulp-glue/tests/test_openapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
).encode()
9595

9696

97-
def mock_send_request(request: _Request) -> _Response:
97+
async def mock_send_request(request: _Request) -> _Response:
9898
if request.url.endswith("oauth/token"):
9999
assert request.method.lower() == "post"
100100
# $ echo -n "client1:secret1" | base64

0 commit comments

Comments
 (0)