diff --git a/python/packages/jumpstarter/jumpstarter/client/lease.py b/python/packages/jumpstarter/jumpstarter/client/lease.py index 6c9a86391..839f08b80 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease.py @@ -22,6 +22,7 @@ fail_after, sleep, ) +from anyio.abc import SocketStream from anyio.from_thread import BlockingPortal from grpc.aio import AioRpcError, Channel from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc @@ -312,33 +313,70 @@ def __contextmanager__(self) -> Generator[Self]: with self.portal.wrap_async_context_manager(self) as value: yield value - async def handle_async(self, stream): + # DEADLINE_EXCEEDED and CANCELLED are excluded: they indicate client-side + # timeout or cancellation, not server/network transients worth retrying. + _TRANSIENT_GRPC_CODES = frozenset( + { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.ABORTED, + grpc.StatusCode.INTERNAL, + } + ) + + # UNKNOWN error messages that indicate transient tunnel teardowns. + # We don't blanket-retry all UNKNOWN errors (they could be permanent + # server bugs), but specific messages like "watch channel closed" are + # known to occur during tunnel reconnection. + _TRANSIENT_UNKNOWN_MESSAGES = ("watch channel closed",) + + @staticmethod + def _retry_delay(attempt: int, remaining: float, base: float = 0.3, cap: float = 5.0) -> float: + """Compute exponential-backoff delay, capped by *cap* and *remaining* time.""" + return min(base * (2**attempt), cap, remaining) + + async def _dial_and_connect(self, stream: SocketStream, channel_ready_timeout: float = 10.0) -> None: + """Single attempt; raises on failure for caller-driven retry.""" + response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name)) + async with connect_router_stream( + response.router_endpoint, + response.router_token, + stream, + self.tls_config, + self.grpc_options, + channel_ready_timeout=channel_ready_timeout, + ): + pass + + async def handle_async(self, stream: SocketStream) -> None: logger.debug("Connecting to Lease with name %s", self.name) - # Retry Dial with exponential backoff for transient "exporter not ready" errors. - # This handles the race condition where the client acquires a lease before - # the exporter has transitioned to LEASE_READY status. - # Uses time-based retry bounded by dial_timeout instead of fixed retry count. - base_delay = 0.3 - max_delay = 2.0 + # Retry Dial + router connection with exponential backoff. + # Handles FAILED_PRECONDITION (exporter not yet ready), transient + # network errors (tunnel drops), and OSError (unreachable endpoint). + # All error paths return instead of raising because handle_async runs + # inside TemporaryUnixListener.serve's task group -- an unhandled + # exception would crash the listener and terminate sibling connections. deadline = time.monotonic() + self.dial_timeout attempt = 0 while True: + remaining = deadline - time.monotonic() + channel_ready_timeout = max(min(10.0, remaining), 0) try: - response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name)) - break + await self._dial_and_connect(stream, channel_ready_timeout=channel_ready_timeout) + return except AioRpcError as e: + remaining = deadline - time.monotonic() if e.code() == grpc.StatusCode.FAILED_PRECONDITION and "not ready" in str(e.details()): - remaining = deadline - time.monotonic() if remaining <= 0: - logger.debug( + logger.warning( "Exporter not ready and dial timeout (%.1fs) exceeded after %d attempts", self.dial_timeout, attempt + 1, ) - raise - delay = min(base_delay * (2**attempt), max_delay, remaining) + return + delay = self._retry_delay(attempt, remaining) logger.debug( - "Exporter not ready, retrying Dial in %.1fs (attempt %d, %.1fs remaining)", + "Exporter not ready, retrying in %.1fs (attempt %d, %.1fs remaining)", delay, attempt + 1, remaining, @@ -346,27 +384,32 @@ async def handle_async(self, stream): await sleep(delay) attempt += 1 continue - if e.code() == grpc.StatusCode.UNAVAILABLE: - remaining = deadline - time.monotonic() + is_transient = e.code() in self._TRANSIENT_GRPC_CODES or ( + e.code() == grpc.StatusCode.UNKNOWN + and any(msg in str(e.details()).lower() for msg in self._TRANSIENT_UNKNOWN_MESSAGES) + ) + if is_transient: if remaining <= 0: logger.warning( - "Exporter unavailable and dial timeout (%.1fs) exceeded after %d attempts", - self.dial_timeout, + "Connection failed with transient error after %d attempts (%.1fs elapsed): %s", attempt + 1, + self.dial_timeout, + e.details(), ) - raise - delay = min(base_delay * (2**attempt), max_delay, remaining) - logger.warning( - "Exporter unavailable, retrying Dial in %.1fs (attempt %d, %.1fs remaining)", + return + delay = self._retry_delay(attempt, remaining) + logger.info( + "Connection failed with %s, retrying in %.1fs (attempt %d, %.1fs remaining): %s", + e.code().name, delay, attempt + 1, remaining, + e.details(), ) await sleep(delay) attempt += 1 continue - # Exporter went offline or lease ended - log and exit gracefully - if "permission denied" in str(e.details()).lower(): + if e.code() == grpc.StatusCode.PERMISSION_DENIED: self.lease_transferred = True logger.warning( "Lease %s has been transferred to another client. Your session is no longer valid.", @@ -375,10 +418,22 @@ async def handle_async(self, stream): else: logger.warning("Connection to exporter lost: %s", e.details()) return - async with connect_router_stream( - response.router_endpoint, response.router_token, stream, self.tls_config, self.grpc_options - ): - pass + except OSError as e: + remaining = deadline - time.monotonic() + if remaining > 0: + delay = self._retry_delay(attempt, remaining) + logger.info( + "Connection failed with OSError, retrying in %.1fs (attempt %d, %.1fs remaining): %s", + delay, + attempt + 1, + remaining, + e, + ) + await sleep(delay) + attempt += 1 + continue + logger.warning("Connection failed: %s", e) + return @asynccontextmanager async def serve_unix_async(self): diff --git a/python/packages/jumpstarter/jumpstarter/client/lease_test.py b/python/packages/jumpstarter/jumpstarter/client/lease_test.py index dffd6a288..52a5941c9 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease_test.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease_test.py @@ -1,6 +1,7 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, Mock, patch @@ -572,61 +573,505 @@ async def get_then_fail(): assert remain_arg == timedelta(0) -class TestHandleAsyncUnavailableRetry: - """Tests for Lease.handle_async UNAVAILABLE retry behavior.""" +def _make_aio_rpc_error(code, details="error"): + """Helper to construct an AioRpcError.""" + return AioRpcError( + code=code, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + details=details, + debug_error_string=None, + ) - def _make_lease_for_handle(self): - lease = object.__new__(Lease) - lease.name = "test-lease" - lease.dial_timeout = 5.0 - lease.lease_transferred = False - lease.tls_config = Mock() - lease.grpc_options = {} - lease.controller = Mock() - return lease + +def _make_lease_for_handle(): + """Create a minimal Lease for testing handle_async.""" + lease = object.__new__(Lease) + lease.name = "test-lease" + lease.dial_timeout = 5.0 + lease.tls_config = Mock() + lease.grpc_options = {} + lease.controller = Mock() + lease.lease_transferred = False + return lease + + +class TestHandleAsyncTransientRetry: + """Tests for transient gRPC error retry in handle_async (unified Dial + router loop).""" + + @pytest.mark.anyio + async def test_retries_on_dial_unavailable_then_succeeds(self): + """Should retry on UNAVAILABLE from Dial and succeed on the next attempt.""" + lease = _make_lease_for_handle() + + dial_response = Mock(router_endpoint="ep", router_token="tok") + call_count = 0 + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "tunnel dropped") + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == 2 + + @pytest.mark.anyio + async def test_transient_error_returns_after_timeout(self): + """Should give up and return when dial_timeout is exceeded.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 0.0 # already expired + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "tunnel dropped"), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) + + # Should return without raising + lease.controller.Dial.assert_called_once() + + @pytest.mark.anyio + @pytest.mark.parametrize( + "status_code", + [ + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.ABORTED, + grpc.StatusCode.INTERNAL, + ], + ids=["RESOURCE_EXHAUSTED", "ABORTED", "INTERNAL"], + ) + async def test_retries_multiple_transient_codes(self, status_code): + """Should retry on RESOURCE_EXHAUSTED, ABORTED, INTERNAL.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + call_count = 0 + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_aio_rpc_error(status_code, "transient") + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == 2, f"Expected 2 calls for {status_code}, got {call_count}" + + @pytest.mark.anyio + async def test_retries_unknown_with_watch_channel_closed(self): + """Should retry UNKNOWN only when details contain 'watch channel closed'.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + call_count = 0 + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_aio_rpc_error(grpc.StatusCode.UNKNOWN, "watch channel closed") + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == 2 + + @pytest.mark.anyio + async def test_unknown_without_known_message_not_retried(self): + """UNKNOWN with an unrecognized message should NOT be retried.""" + lease = _make_lease_for_handle() + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error(grpc.StatusCode.UNKNOWN, "some unexpected server bug"), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) + + # Should return after just one attempt (no retry) + lease.controller.Dial.assert_called_once() + + @pytest.mark.anyio + async def test_router_transient_error_retries_full_dial_and_connect(self): + """Router transient error should retry the full Dial + connect cycle.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + connect_count = 0 + + @asynccontextmanager + async def fake_router(*args, **kwargs): + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "router unreachable") + yield + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fake_router): + await lease.handle_async(Mock()) + + assert connect_count == 2 + # Dial is called fresh each attempt (unified loop) + assert lease.controller.Dial.call_count == 2 + + @pytest.mark.anyio + async def test_non_transient_error_returns_immediately(self): + """Non-transient errors should not be retried.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + @asynccontextmanager + async def fail_router(*args, **kwargs): + raise _make_aio_rpc_error(grpc.StatusCode.NOT_FOUND, "not found") + yield # pragma: no cover + + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fail_router): + await lease.handle_async(Mock()) + + # Only one Dial attempt, no retry + assert lease.controller.Dial.call_count == 1 + + @pytest.mark.anyio + async def test_transient_router_error_returns_after_timeout(self): + """Should give up when dial_timeout is exceeded during router retries.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 0.0 # already expired + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + @asynccontextmanager + async def fail_router(*args, **kwargs): + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "unreachable") + yield # pragma: no cover + + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fail_router): + await lease.handle_async(Mock()) + + # Only one Dial (initial), no retry + assert lease.controller.Dial.call_count == 1 + + @pytest.mark.anyio + async def test_dial_failure_on_retry_is_retried_again(self): + """When Dial fails with a transient error during retry, it should keep retrying.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + + dial_count = 0 + + async def dial_side_effect(req): + nonlocal dial_count + dial_count += 1 + if dial_count == 1: + return dial_response # first Dial succeeds, router will fail + if dial_count == 2: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "re-dial failed") + return dial_response # third Dial succeeds + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + connect_count = 0 + + @asynccontextmanager + async def fake_router(*args, **kwargs): + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "router fail") + yield + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fake_router): + await lease.handle_async(Mock()) + + # Attempt 1: Dial OK -> router fails (UNAVAILABLE) + # Attempt 2: Dial fails (UNAVAILABLE) -> retried + # Attempt 3: Dial OK -> router OK + assert dial_count == 3 + assert connect_count == 2 + + @pytest.mark.anyio + async def test_oserror_retries_then_succeeds(self): + """OSError from router should retry the full Dial + connect cycle.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + connect_count = 0 + + @asynccontextmanager + async def fake_router(*args, **kwargs): + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + raise OSError("Connection refused") + yield + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fake_router): + await lease.handle_async(Mock()) + + assert connect_count == 2 + # Dial called fresh each attempt + assert lease.controller.Dial.call_count == 2 + + @pytest.mark.anyio + async def test_oserror_returns_after_timeout(self): + """Should give up on OSError when dial_timeout is exceeded.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 0.0 # already expired + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + @asynccontextmanager + async def fail_router(*args, **kwargs): + raise OSError("Connection refused") + yield # pragma: no cover + + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fail_router): + await lease.handle_async(Mock()) + + # Only the initial Dial, no retry + assert lease.controller.Dial.call_count == 1 + + @pytest.mark.anyio + async def test_exponential_backoff_delay_values(self): + """Verify that sleep delays follow exponential backoff: 0.3, 0.6, 1.2, 2.4, 4.8, capped at 5.0.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 60.0 # large timeout so remaining doesn't cap delays + + # Fail 6 times then succeed on the 7th attempt + total_failures = 6 + call_count = 0 + dial_response = Mock(router_endpoint="ep", router_token="tok") + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count <= total_failures: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "tunnel dropped") + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock) as mock_sleep: + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == total_failures + 1 + + # Verify exponential backoff: base_delay=0.3, max_delay=5.0 + # attempt 0: 0.3 * 2^0 = 0.3 + # attempt 1: 0.3 * 2^1 = 0.6 + # attempt 2: 0.3 * 2^2 = 1.2 + # attempt 3: 0.3 * 2^3 = 2.4 + # attempt 4: 0.3 * 2^4 = 4.8 + # attempt 5: min(0.3 * 2^5, 5.0) = min(9.6, 5.0) = 5.0 + expected_delays = [0.3, 0.6, 1.2, 2.4, 4.8, 5.0] + actual_delays = [call.args[0] for call in mock_sleep.call_args_list] + assert len(actual_delays) == len(expected_delays) + for actual, expected in zip(actual_delays, expected_delays, strict=True): + assert actual == pytest.approx(expected), f"Expected delay {expected}, got {actual}" + + @pytest.mark.anyio + async def test_failed_precondition_not_ready_retries_then_succeeds(self): + """FAILED_PRECONDITION 'not ready' should retry and succeed on next attempt.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + call_count = 0 + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_aio_rpc_error(grpc.StatusCode.FAILED_PRECONDITION, "exporter not ready") + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == 2 + + @pytest.mark.anyio + async def test_failed_precondition_returns_after_timeout(self): + """FAILED_PRECONDITION should return (not raise) when dial_timeout is exceeded.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 0.0 # already expired + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error(grpc.StatusCode.FAILED_PRECONDITION, "exporter not ready"), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + # Should return without raising + await lease.handle_async(Mock()) + + lease.controller.Dial.assert_called_once() @pytest.mark.anyio - async def test_handle_async_retries_unavailable_then_succeeds(self): - """Dial returns UNAVAILABLE once then succeeds on retry.""" - lease = self._make_lease_for_handle() - dial_call_count = 0 + async def test_permission_denied_sets_lease_transferred(self): + """PERMISSION_DENIED should set lease_transferred = True.""" + lease = _make_lease_for_handle() + assert lease.lease_transferred is False - async def mock_dial(request): - nonlocal dial_call_count - dial_call_count += 1 - if dial_call_count == 1: - raise MockAioRpcError(grpc.StatusCode.UNAVAILABLE, "temporarily unavailable") - return Mock(router_endpoint="endpoint", router_token="token") + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error(grpc.StatusCode.PERMISSION_DENIED, "permission denied"), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) + + assert lease.lease_transferred is True + + @pytest.mark.anyio + async def test_permission_denied_with_custom_details_still_detected(self): + """PERMISSION_DENIED with non-standard detail text should still set lease_transferred.""" + lease = _make_lease_for_handle() + assert lease.lease_transferred is False + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error(grpc.StatusCode.PERMISSION_DENIED, "lease reassigned to another client"), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) - lease.controller.Dial = mock_dial + assert lease.lease_transferred is True - with patch("jumpstarter.client.lease.connect_router_stream") as mock_connect: - mock_connect.return_value.__aenter__ = AsyncMock() - mock_connect.return_value.__aexit__ = AsyncMock(return_value=False) - stream = Mock() + @pytest.mark.anyio + async def test_unauthenticated_with_permission_text_does_not_set_transferred(self): + """UNAUTHENTICATED with 'permission denied' in details should NOT set lease_transferred.""" + lease = _make_lease_for_handle() + assert lease.lease_transferred is False + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error( + grpc.StatusCode.UNAUTHENTICATED, + "permission denied: token expired", + ), + ) - await lease.handle_async(stream) + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) - assert dial_call_count == 2 - mock_connect.assert_called_once_with("endpoint", "token", stream, lease.tls_config, lease.grpc_options) + assert lease.lease_transferred is False @pytest.mark.anyio - async def test_handle_async_unavailable_exceeds_dial_timeout(self): - """Dial returns UNAVAILABLE until dial_timeout is exceeded, then raises.""" - lease = self._make_lease_for_handle() - lease.dial_timeout = 0.5 - dial_call_count = 0 + async def test_channel_ready_timeout_bounded_by_remaining(self): + """channel_ready_timeout should decrease as the dial deadline approaches.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 3.0 + + call_count = 0 + captured_timeouts = [] + + async def tracking_dial_and_connect(self_inner, stream, channel_ready_timeout=10.0): + nonlocal call_count + call_count += 1 + captured_timeouts.append(channel_ready_timeout) + if call_count <= 3: + raise _make_aio_rpc_error(grpc.StatusCode.FAILED_PRECONDITION, "exporter not ready") + # Succeed on 4th attempt (won't normally reach here with 3s timeout) + + with ( + patch.object(type(lease), "_dial_and_connect", tracking_dial_and_connect), + patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock), + ): + await lease.handle_async(Mock()) + + # With a 3s dial_timeout, the first call should have channel_ready_timeout <= 3.0 + # and subsequent calls should have progressively smaller values + assert len(captured_timeouts) >= 2 + assert all(t <= 10.0 for t in captured_timeouts), f"All timeouts should be <= 10.0, got {captured_timeouts}" + # The first timeout should be bounded by remaining (~3.0), not the default 10.0 + assert captured_timeouts[0] <= 3.1, ( + f"First timeout should be bounded by dial_timeout (~3.0), got {captured_timeouts[0]}" + ) + + +class TestRetryDelay: + """Tests for the _retry_delay static method.""" + + def test_basic_exponential(self): + assert Lease._retry_delay(0, 60.0) == pytest.approx(0.3) + assert Lease._retry_delay(1, 60.0) == pytest.approx(0.6) + assert Lease._retry_delay(2, 60.0) == pytest.approx(1.2) + + def test_capped_by_max(self): + assert Lease._retry_delay(10, 60.0) == pytest.approx(5.0) + + def test_capped_by_remaining(self): + assert Lease._retry_delay(0, 0.1) == pytest.approx(0.1) + + +class TestTransientGrpcCodes: + """Tests for the _TRANSIENT_GRPC_CODES class attribute.""" - async def mock_dial(request): - nonlocal dial_call_count - dial_call_count += 1 - raise MockAioRpcError(grpc.StatusCode.UNAVAILABLE, "permanently unavailable") + def test_contains_expected_codes(self): + assert grpc.StatusCode.UNAVAILABLE in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.RESOURCE_EXHAUSTED in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.ABORTED in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.INTERNAL in Lease._TRANSIENT_GRPC_CODES - lease.controller.Dial = mock_dial - stream = Mock() + def test_unknown_not_in_blanket_transient_codes(self): + """UNKNOWN is handled separately via _TRANSIENT_UNKNOWN_MESSAGES.""" + assert grpc.StatusCode.UNKNOWN not in Lease._TRANSIENT_GRPC_CODES - with pytest.raises(AioRpcError) as exc_info: - await lease.handle_async(stream) + def test_does_not_contain_non_transient_codes(self): + assert grpc.StatusCode.PERMISSION_DENIED not in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.NOT_FOUND not in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.FAILED_PRECONDITION not in Lease._TRANSIENT_GRPC_CODES - assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE - assert dial_call_count >= 2 + def test_transient_unknown_messages(self): + """Should contain the known tunnel teardown messages.""" + assert "watch channel closed" in Lease._TRANSIENT_UNKNOWN_MESSAGES diff --git a/python/packages/jumpstarter/jumpstarter/common/streams.py b/python/packages/jumpstarter/jumpstarter/common/streams.py index 8cdc02330..19c0280ca 100644 --- a/python/packages/jumpstarter/jumpstarter/common/streams.py +++ b/python/packages/jumpstarter/jumpstarter/common/streams.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager from typing import Annotated, Literal, Union from uuid import UUID @@ -34,13 +35,27 @@ class StreamRequestMetadata(BaseModel): @asynccontextmanager -async def connect_router_stream(endpoint, token, stream, tls_config, grpc_options): +async def connect_router_stream(endpoint, token, stream, tls_config, grpc_options, channel_ready_timeout: float = 10): credentials = grpc.composite_channel_credentials( await ssl_channel_credentials(endpoint, tls_config), grpc.access_token_call_credentials(token), ) async with aio_secure_channel(endpoint, credentials, grpc_options) as channel: + # Wait for the channel to be ready before starting the stream. + # Without this, a broken router connection would cause the gRPC + # stream to hang indefinitely waiting for the HTTP/2 SETTINGS frame, + # which manifests as a timeout for the j command on the Unix socket. + try: + await asyncio.wait_for(channel.channel_ready(), timeout=channel_ready_timeout) + except asyncio.TimeoutError: + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.UNAVAILABLE, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + details=f"Timed out waiting for router channel to become ready ({channel_ready_timeout}s)", + debug_error_string=None, + ) from None router = router_pb2_grpc.RouterServiceStub(channel) context = router.Stream(metadata=()) async with RouterStream(context=context) as s: diff --git a/python/packages/jumpstarter/jumpstarter/common/streams_test.py b/python/packages/jumpstarter/jumpstarter/common/streams_test.py new file mode 100644 index 000000000..be89d5ee8 --- /dev/null +++ b/python/packages/jumpstarter/jumpstarter/common/streams_test.py @@ -0,0 +1,88 @@ +import asyncio +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, Mock, patch + +import grpc +import pytest +from grpc.aio import AioRpcError + +from jumpstarter.common.streams import connect_router_stream + + +class TestConnectRouterStreamChannelReady: + """Tests for the channel_ready timeout logic in connect_router_stream.""" + + @pytest.mark.anyio + async def test_raises_unavailable_on_channel_ready_timeout(self): + """When channel_ready() times out, an AioRpcError with UNAVAILABLE should be raised.""" + mock_channel = Mock() + + # Make channel_ready() return a coroutine that never completes + async def hang_forever(): + await asyncio.sleep(999) + + mock_channel.channel_ready = Mock(return_value=hang_forever()) + + @asynccontextmanager + async def fake_secure_channel(*args, **kwargs): + yield mock_channel + + with ( + patch("jumpstarter.common.streams.ssl_channel_credentials", new_callable=AsyncMock), + patch("jumpstarter.common.streams.aio_secure_channel", side_effect=fake_secure_channel), + patch("grpc.composite_channel_credentials", return_value=Mock()), + patch("grpc.access_token_call_credentials", return_value=Mock()), + ): + with pytest.raises(AioRpcError) as exc_info: + async with connect_router_stream( + "endpoint:443", "token", Mock(), Mock(), {}, channel_ready_timeout=0.01 + ): + pass # pragma: no cover + + assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE + assert "Timed out" in str(exc_info.value.details()) + + @pytest.mark.anyio + async def test_proceeds_when_channel_ready_succeeds(self): + """When channel_ready() succeeds quickly, the stream should be set up normally.""" + mock_channel = Mock() + + # channel_ready() resolves immediately + async def ready_immediately(): + pass + + mock_channel.channel_ready = Mock(return_value=ready_immediately()) + + mock_context = Mock() + + @asynccontextmanager + async def fake_secure_channel(*args, **kwargs): + yield mock_channel + + @asynccontextmanager + async def fake_router_stream(*args, **kwargs): + yield Mock() + + @asynccontextmanager + async def fake_forward(*args, **kwargs): + yield + + with ( + patch("jumpstarter.common.streams.ssl_channel_credentials", new_callable=AsyncMock), + patch("jumpstarter.common.streams.aio_secure_channel", side_effect=fake_secure_channel), + patch("grpc.composite_channel_credentials", return_value=Mock()), + patch("grpc.access_token_call_credentials", return_value=Mock()), + patch("jumpstarter.common.streams.router_pb2_grpc.RouterServiceStub") as mock_stub_cls, + patch("jumpstarter.common.streams.RouterStream", side_effect=fake_router_stream), + patch("jumpstarter.common.streams.forward_stream", side_effect=fake_forward), + ): + mock_stub = Mock() + mock_stub.Stream.return_value = mock_context + mock_stub_cls.return_value = mock_stub + + async with connect_router_stream( + "endpoint:443", "token", Mock(), Mock(), {}, channel_ready_timeout=5 + ): + pass # Successfully entered the context + + mock_channel.channel_ready.assert_called_once()