Skip to content

Commit 83056e3

Browse files
Magd BayoumiMagd Bayoumi
authored andcommitted
fixup! fixup! fixup! fixup! fixup! fixup! fixup! fix: handle http2 goaway race conditions
1 parent 2fe76e7 commit 83056e3

4 files changed

Lines changed: 62 additions & 23 deletions

File tree

httpcore/_async/http2.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,10 @@ async def handle_async_request(self, request: Request) -> Response:
195195
if self._connection_terminated: # pragma: nocover
196196
phase = self._stream_requests.get(
197197
stream_id,
198-
{"headers_sent": False, "body_sent": False},
198+
{
199+
"headers_sent": False,
200+
"body_sent": False,
201+
},
199202
)
200203
raise ConnectionGoingAway(
201204
self._connection_terminated, # type: ignore[arg-type]
@@ -213,7 +216,10 @@ async def handle_async_request(self, request: Request) -> Response:
213216
):
214217
phase = self._stream_requests.get(
215218
stream_id,
216-
{"headers_sent": False, "body_sent": False},
219+
{
220+
"headers_sent": False,
221+
"body_sent": False,
222+
},
217223
)
218224
msg = f"Connection closed: {exc}"
219225
raise ConnectionGoingAway(
@@ -396,7 +402,11 @@ async def _receive_events(
396402
last_stream_id = self._connection_terminated.last_stream_id
397403
if stream_id is not None:
398404
phase = self._stream_requests.get(
399-
stream_id, {"headers_sent": False, "body_sent": False}
405+
stream_id,
406+
{
407+
"headers_sent": False,
408+
"body_sent": False,
409+
},
400410
)
401411
if last_stream_id is not None and stream_id > last_stream_id:
402412
# stream_id > last_stream_id: guaranteed unprocessed, safe to retry
@@ -523,7 +533,11 @@ async def _read_incoming_data(
523533
# Server disconnected. Check if this is related to GOAWAY.
524534
if stream_id is not None:
525535
phase = self._stream_requests.get(
526-
stream_id, {"headers_sent": False, "body_sent": False}
536+
stream_id,
537+
{
538+
"headers_sent": False,
539+
"body_sent": False,
540+
},
527541
)
528542
# If we have a GOAWAY recorded, this disconnect is GOAWAY-related
529543
if self._connection_terminated is not None:

httpcore/_sync/connection_pool.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,7 @@ def handle_request(self, request: Request) -> Response:
263263
else:
264264
# Request may have been processed. Propagate error with context so application can decide
265265
# whether to retry.
266-
msg = (
267-
"GOAWAY recieved: request may have been processed"
268-
)
266+
msg = "GOAWAY recieved: request may have been processed"
269267
# QUESTION: What is the best way to propagate the context for the applications?
270268
raise RemoteProtocolError(msg) from exc
271269
else:

httpcore/_sync/http2.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def __init__(
8989
# TODO: Consider shifting this to a dataclass or typeddict
9090
self._stream_requests: dict[int, dict[str, bool]] = {}
9191

92-
9392
def handle_request(self, request: Request) -> Response:
9493
if not self.can_handle_request(request.url.origin):
9594
# This cannot occur in normal operation, since the connection pool
@@ -142,7 +141,10 @@ def handle_request(self, request: Request) -> Response:
142141
stream_id = self._h2_state.get_next_available_stream_id()
143142
self._events[stream_id] = []
144143
# Initialize phase tracking for this stream
145-
self._stream_requests[stream_id] = {"headers_sent": False, "body_sent": False}
144+
self._stream_requests[stream_id] = {
145+
"headers_sent": False,
146+
"body_sent": False,
147+
}
146148
except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover
147149
self._used_all_stream_ids = True
148150
self._request_count -= 1
@@ -192,7 +194,11 @@ def handle_request(self, request: Request) -> Response:
192194
# it as a ConnectionGoingAway if applicable, or RemoteProtocolError.
193195
if self._connection_terminated: # pragma: nocover
194196
phase = self._stream_requests.get(
195-
stream_id, {"headers_sent": False, "body_sent": False},
197+
stream_id,
198+
{
199+
"headers_sent": False,
200+
"body_sent": False,
201+
},
196202
)
197203
raise ConnectionGoingAway(
198204
self._connection_terminated, # type: ignore[arg-type]
@@ -204,15 +210,22 @@ def handle_request(self, request: Request) -> Response:
204210
)
205211
# Check if h2 is in CLOSED state due to GOAWAY. This can happen when
206212
# GOAWAY was recieved but we haven't processed the event yet (race condition).
207-
if self._h2_state.state_machine.state == h2.connection.ConnectionState.CLOSED:
213+
if (
214+
self._h2_state.state_machine.state
215+
== h2.connection.ConnectionState.CLOSED
216+
):
208217
phase = self._stream_requests.get(
209-
stream_id, {"headers_sent": False, "body_sent": False},
218+
stream_id,
219+
{
220+
"headers_sent": False,
221+
"body_sent": False,
222+
},
210223
)
211224
msg = f"Connection closed: {exc}"
212225
raise ConnectionGoingAway(
213226
msg,
214-
last_stream_id=stream_id, # Conservative: assume this stream may have been processed
215-
error_code=0, # Assume graceful shutdown
227+
last_stream_id=stream_id, # Conservative: assume this stream may have been processed
228+
error_code=0, # Assume graceful shutdown
216229
request_stream_id=stream_id,
217230
headers_sent=phase["headers_sent"],
218231
body_sent=phase["body_sent"],
@@ -389,7 +402,11 @@ def _receive_events(
389402
last_stream_id = self._connection_terminated.last_stream_id
390403
if stream_id is not None:
391404
phase = self._stream_requests.get(
392-
stream_id, {"headers_sent": False, "body_sent": False}
405+
stream_id,
406+
{
407+
"headers_sent": False,
408+
"body_sent": False,
409+
},
393410
)
394411
if last_stream_id is not None and stream_id > last_stream_id:
395412
# stream_id > last_stream_id: guaranteed unprocessed, safe to retry
@@ -402,11 +419,13 @@ def _receive_events(
402419
headers_sent=phase["headers_sent"],
403420
body_sent=phase["body_sent"],
404421
)
405-
# stream_id <= last_stream_id: may have been processed
406422
if self._state != HTTPConnectionState.DRAINING:
423+
# stream_id <= last_stream_id: may have been processed
407424
raise ConnectionGoingAway(
408425
f"GOAWAY received: stream {stream_id} <= last_stream_id {last_stream_id}",
409-
last_stream_id=last_stream_id if last_stream_id is not None else 0,
426+
last_stream_id=last_stream_id
427+
if last_stream_id is not None
428+
else 0,
410429
error_code=self._connection_terminated.error_code, # type: ignore[arg-type]
411430
request_stream_id=stream_id,
412431
headers_sent=phase["headers_sent"],
@@ -477,7 +496,7 @@ def _receive_remote_settings_change(
477496
def _response_closed(self, stream_id: int) -> None:
478497
self._max_streams_semaphore.release()
479498
del self._events[stream_id]
480-
self._stream_requests.pop(stream_id, None) # Clean up phase tracking
499+
self._stream_requests.pop(stream_id, None) # Clean up phase tracking
481500
with self._state_lock:
482501
if self._connection_terminated and not self._events:
483502
self.close()
@@ -514,21 +533,28 @@ def _read_incoming_data(
514533
# Server disconnected. Check if this is related to GOAWAY.
515534
if stream_id is not None:
516535
phase = self._stream_requests.get(
517-
stream_id, {"headers_sent": False, "body_sent": False}
536+
stream_id,
537+
{
538+
"headers_sent": False,
539+
"body_sent": False,
540+
},
518541
)
519542
# If we have a GOAWAY recorded, this disconnect is GOAWAY-related
520543
if self._connection_terminated is not None:
521544
last_stream_id = self._connection_terminated.last_stream_id
522545
raise ConnectionGoingAway(
523546
"Server disconnected after GOAWAY",
524547
last_stream_id=last_stream_id if last_stream_id else 0,
525-
error_code=self._connection_terminated.error_code, # type: ignore[arg-type]
548+
error_code=self._connection_terminated.error_code, # type: ignore[arg-type]
526549
request_stream_id=stream_id,
527550
headers_sent=phase["headers_sent"],
528551
body_sent=phase["body_sent"],
529552
)
530553
# Check if h2 is in CLOSED state (GOAWAY received but not processed)
531-
if self._h2_state.state_machine.state == h2.connection.ConnectionState.CLOSED:
554+
if (
555+
self._h2_state.state_machine.state
556+
== h2.connection.ConnectionState.CLOSED
557+
):
532558
raise ConnectionGoingAway(
533559
"Server disconnected (connection closed)",
534560
last_stream_id=stream_id, # Conservative
@@ -607,7 +633,8 @@ def can_handle_request(self, origin: Origin) -> bool:
607633

608634
def is_available(self) -> bool:
609635
return (
610-
self._state not in (HTTPConnectionState.DRAINING, HTTPConnectionState.CLOSED)
636+
self._state
637+
not in (HTTPConnectionState.DRAINING, HTTPConnectionState.CLOSED)
611638
and not self._connection_error
612639
and not self._used_all_stream_ids
613640
and not (

tests/_sync/test_http_proxy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def test_proxy_tunneling_with_403():
224224
"""
225225
network_backend = MockBackend(
226226
[
227-
b"HTTP/1.1 403 Permission Denied\r\n" b"\r\n",
227+
b"HTTP/1.1 403 Permission Denied\r\n\r\n",
228228
]
229229
)
230230

0 commit comments

Comments
 (0)