diff --git a/src/mlpa/core/completions.py b/src/mlpa/core/completions.py index 38b0c4d..9074f2d 100644 --- a/src/mlpa/core/completions.py +++ b/src/mlpa/core/completions.py @@ -255,6 +255,14 @@ async def _read_next_chunk( snapshot=litellm_routing_snapshot, ) result = PrometheusResult.SUCCESS + except (GeneratorExit, asyncio.CancelledError): + # Client went away mid-stream: Starlette tears the generator down by + # throwing GeneratorExit (or cancelling the task) at the paused + # `yield chunk`. This often beats the disconnect poller, so classify it + # as an abort here rather than letting the initial ERROR stand. + result = PrometheusResult.ABORT + logger.info(_client_disconnected_msg) + raise except httpx.ReadError as e: if disconnect_event.is_set() or await request.is_disconnected(): disconnect_event.set() diff --git a/src/tests/unit/test_completions.py b/src/tests/unit/test_completions.py index 083d0b0..c2e819d 100644 --- a/src/tests/unit/test_completions.py +++ b/src/tests/unit/test_completions.py @@ -1246,6 +1246,38 @@ async def test_stream_sends_error_sse_on_empty_200_response( _assert_error_latency(metrics_spy) +async def test_stream_completion_client_disconnect_records_abort( + mocker, mock_request, metrics_spy +): + """ + Client disconnect mid-stream tears the generator down via GeneratorExit at + the paused `yield chunk` (this is what Starlette does when the client goes + away). Even when the disconnect poller has not fired yet — `is_disconnected` + still returns False, so `disconnect_event` is unset — this must be recorded + as ABORT, not ERROR. Otherwise normal client cancellations pollute the error + rate. + """ + role_chunk = ( + b'data: {"choices":[{"delta":{"role":"assistant","content":null}}]}\n\n' + ) + + async def _aiter_bytes(): + yield role_chunk + yield b'data: {"choices":[{"delta":{"content":"hi"}}]}\n\n' + + _patch_mock_stream_client(mocker, _aiter_bytes) + + gen = stream_completion(SAMPLE_REQUEST, mock_request) + first = await gen.__anext__() + assert first == role_chunk + + # Client goes away: the response generator is closed mid-stream. + await gen.aclose() + + assert _latency_count(metrics_spy, PrometheusResult.ABORT) == 1 + assert _latency_count(metrics_spy, PrometheusResult.ERROR) == 0 + + async def test_stream_uses_httpx_timeout_object_preserving_pool_timeout( mocker, mock_request, metrics_spy ):