diff --git a/python/packages/jumpstarter/jumpstarter/client/lease.py b/python/packages/jumpstarter/jumpstarter/client/lease.py index f0678556b..25a9bad4e 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease.py @@ -22,6 +22,7 @@ fail_after, sleep, ) +from anyio.abc import SocketStream from anyio.from_thread import BlockingPortal from grpc.aio import AioRpcError, Channel from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc @@ -312,33 +313,70 @@ def __contextmanager__(self) -> Generator[Self]: with self.portal.wrap_async_context_manager(self) as value: yield value - async def handle_async(self, stream): + # DEADLINE_EXCEEDED and CANCELLED are excluded: they indicate client-side + # timeout or cancellation, not server/network transients worth retrying. + _TRANSIENT_GRPC_CODES = frozenset({ + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.ABORTED, + grpc.StatusCode.INTERNAL, + }) + + # UNKNOWN error messages that indicate transient tunnel teardowns. + # We don't blanket-retry all UNKNOWN errors (they could be permanent + # server bugs), but specific messages like "watch channel closed" are + # known to occur during tunnel reconnection. + _TRANSIENT_UNKNOWN_MESSAGES = ("watch channel closed",) + + @staticmethod + def _retry_delay(attempt: int, remaining: float, base: float = 0.3, cap: float = 5.0) -> float: + """Compute exponential-backoff delay, capped by *cap* and *remaining* time.""" + return min(base * (2**attempt), cap, remaining) + + async def _dial_and_connect( + self, stream: SocketStream, channel_ready_timeout: float = 10.0 + ) -> None: + """Single attempt; raises on failure for caller-driven retry.""" + response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name)) + async with connect_router_stream( + response.router_endpoint, + response.router_token, + stream, + self.tls_config, + self.grpc_options, + channel_ready_timeout=channel_ready_timeout, + ): + pass + + async def handle_async(self, stream: SocketStream) -> None: logger.debug("Connecting to Lease with name %s", self.name) - # Retry Dial with exponential backoff for transient "exporter not ready" errors. - # This handles the race condition where the client acquires a lease before - # the exporter has transitioned to LEASE_READY status. - # Uses time-based retry bounded by dial_timeout instead of fixed retry count. - base_delay = 0.3 - max_delay = 2.0 + # Retry Dial + router connection with exponential backoff. + # Handles FAILED_PRECONDITION (exporter not yet ready), transient + # network errors (tunnel drops), and OSError (unreachable endpoint). + # All error paths return instead of raising because handle_async runs + # inside TemporaryUnixListener.serve's task group -- an unhandled + # exception would crash the listener and terminate sibling connections. deadline = time.monotonic() + self.dial_timeout attempt = 0 while True: + remaining = deadline - time.monotonic() + channel_ready_timeout = max(min(10.0, remaining), 0.5) try: - response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name)) - break + await self._dial_and_connect(stream, channel_ready_timeout=channel_ready_timeout) + return except AioRpcError as e: + remaining = deadline - time.monotonic() if e.code() == grpc.StatusCode.FAILED_PRECONDITION and "not ready" in str(e.details()): - remaining = deadline - time.monotonic() if remaining <= 0: - logger.debug( + logger.warning( "Exporter not ready and dial timeout (%.1fs) exceeded after %d attempts", self.dial_timeout, attempt + 1, ) - raise - delay = min(base_delay * (2**attempt), max_delay, remaining) + return + delay = self._retry_delay(attempt, remaining) logger.debug( - "Exporter not ready, retrying Dial in %.1fs (attempt %d, %.1fs remaining)", + "Exporter not ready, retrying in %.1fs (attempt %d, %.1fs remaining)", delay, attempt + 1, remaining, @@ -346,7 +384,31 @@ async def handle_async(self, stream): await sleep(delay) attempt += 1 continue - # Exporter went offline or lease ended - log and exit gracefully + is_transient = e.code() in self._TRANSIENT_GRPC_CODES or ( + e.code() == grpc.StatusCode.UNKNOWN + and any(msg in str(e.details()).lower() for msg in self._TRANSIENT_UNKNOWN_MESSAGES) + ) + if is_transient: + if remaining <= 0: + logger.warning( + "Connection failed with transient error after %d attempts (%.1fs elapsed): %s", + attempt + 1, + self.dial_timeout, + e.details(), + ) + return + delay = self._retry_delay(attempt, remaining) + logger.info( + "Connection failed with %s, retrying in %.1fs (attempt %d, %.1fs remaining): %s", + e.code().name, + delay, + attempt + 1, + remaining, + e.details(), + ) + await sleep(delay) + attempt += 1 + continue if "permission denied" in str(e.details()).lower(): self.lease_transferred = True logger.warning( @@ -356,10 +418,22 @@ async def handle_async(self, stream): else: logger.warning("Connection to exporter lost: %s", e.details()) return - async with connect_router_stream( - response.router_endpoint, response.router_token, stream, self.tls_config, self.grpc_options - ): - pass + except OSError as e: + remaining = deadline - time.monotonic() + if remaining > 0: + delay = self._retry_delay(attempt, remaining) + logger.info( + "Connection failed with OSError, retrying in %.1fs (attempt %d, %.1fs remaining): %s", + delay, + attempt + 1, + remaining, + e, + ) + await sleep(delay) + attempt += 1 + continue + logger.warning("Connection failed: %s", e) + return @asynccontextmanager async def serve_unix_async(self): diff --git a/python/packages/jumpstarter/jumpstarter/client/lease_test.py b/python/packages/jumpstarter/jumpstarter/client/lease_test.py index 87a3f16be..c02052a07 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease_test.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease_test.py @@ -1,10 +1,13 @@ import asyncio import logging import sys +from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, Mock, patch +import grpc import pytest +from grpc.aio import AioRpcError from rich.console import Console from jumpstarter.client.exceptions import LeaseError @@ -554,3 +557,457 @@ async def get_then_fail(): callback.assert_called() _, remain_arg = callback.call_args[0] assert remain_arg == timedelta(0) + + +def _make_aio_rpc_error(code, details="error"): + """Helper to construct an AioRpcError.""" + return AioRpcError( + code=code, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + details=details, + debug_error_string=None, + ) + + +def _make_lease_for_handle(): + """Create a minimal Lease for testing handle_async.""" + lease = object.__new__(Lease) + lease.name = "test-lease" + lease.dial_timeout = 5.0 + lease.tls_config = Mock() + lease.grpc_options = {} + lease.controller = Mock() + lease.lease_transferred = False + return lease + + +class TestHandleAsyncTransientRetry: + """Tests for transient gRPC error retry in handle_async (unified Dial + router loop).""" + + @pytest.mark.anyio + async def test_retries_on_dial_unavailable_then_succeeds(self): + """Should retry on UNAVAILABLE from Dial and succeed on the next attempt.""" + lease = _make_lease_for_handle() + + dial_response = Mock(router_endpoint="ep", router_token="tok") + call_count = 0 + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "tunnel dropped") + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == 2 + + @pytest.mark.anyio + async def test_transient_error_returns_after_timeout(self): + """Should give up and return when dial_timeout is exceeded.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 0.0 # already expired + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "tunnel dropped"), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) + + # Should return without raising + lease.controller.Dial.assert_called_once() + + @pytest.mark.anyio + @pytest.mark.parametrize( + "status_code", + [ + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.ABORTED, + grpc.StatusCode.INTERNAL, + ], + ids=["RESOURCE_EXHAUSTED", "ABORTED", "INTERNAL"], + ) + async def test_retries_multiple_transient_codes(self, status_code): + """Should retry on RESOURCE_EXHAUSTED, ABORTED, INTERNAL.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + call_count = 0 + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_aio_rpc_error(status_code, "transient") + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == 2, f"Expected 2 calls for {status_code}, got {call_count}" + + @pytest.mark.anyio + async def test_retries_unknown_with_watch_channel_closed(self): + """Should retry UNKNOWN only when details contain 'watch channel closed'.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + call_count = 0 + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_aio_rpc_error( + grpc.StatusCode.UNKNOWN, "watch channel closed" + ) + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == 2 + + @pytest.mark.anyio + async def test_unknown_without_known_message_not_retried(self): + """UNKNOWN with an unrecognized message should NOT be retried.""" + lease = _make_lease_for_handle() + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error( + grpc.StatusCode.UNKNOWN, "some unexpected server bug" + ), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) + + # Should return after just one attempt (no retry) + lease.controller.Dial.assert_called_once() + + @pytest.mark.anyio + async def test_router_transient_error_retries_full_dial_and_connect(self): + """Router transient error should retry the full Dial + connect cycle.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + connect_count = 0 + + @asynccontextmanager + async def fake_router(*args, **kwargs): + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "router unreachable") + yield + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fake_router): + await lease.handle_async(Mock()) + + assert connect_count == 2 + # Dial is called fresh each attempt (unified loop) + assert lease.controller.Dial.call_count == 2 + + @pytest.mark.anyio + async def test_non_transient_error_returns_immediately(self): + """Non-transient errors should not be retried.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + @asynccontextmanager + async def fail_router(*args, **kwargs): + raise _make_aio_rpc_error(grpc.StatusCode.NOT_FOUND, "not found") + yield # pragma: no cover + + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fail_router): + await lease.handle_async(Mock()) + + # Only one Dial attempt, no retry + assert lease.controller.Dial.call_count == 1 + + @pytest.mark.anyio + async def test_transient_router_error_returns_after_timeout(self): + """Should give up when dial_timeout is exceeded during router retries.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 0.0 # already expired + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + @asynccontextmanager + async def fail_router(*args, **kwargs): + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "unreachable") + yield # pragma: no cover + + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fail_router): + await lease.handle_async(Mock()) + + # Only one Dial (initial), no retry + assert lease.controller.Dial.call_count == 1 + + @pytest.mark.anyio + async def test_dial_failure_on_retry_is_retried_again(self): + """When Dial fails with a transient error during retry, it should keep retrying.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + + dial_count = 0 + + async def dial_side_effect(req): + nonlocal dial_count + dial_count += 1 + if dial_count == 1: + return dial_response # first Dial succeeds, router will fail + if dial_count == 2: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "re-dial failed") + return dial_response # third Dial succeeds + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + connect_count = 0 + + @asynccontextmanager + async def fake_router(*args, **kwargs): + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + raise _make_aio_rpc_error(grpc.StatusCode.UNAVAILABLE, "router fail") + yield + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fake_router): + await lease.handle_async(Mock()) + + # Attempt 1: Dial OK -> router fails (UNAVAILABLE) + # Attempt 2: Dial fails (UNAVAILABLE) -> retried + # Attempt 3: Dial OK -> router OK + assert dial_count == 3 + assert connect_count == 2 + + @pytest.mark.anyio + async def test_oserror_retries_then_succeeds(self): + """OSError from router should retry the full Dial + connect cycle.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + connect_count = 0 + + @asynccontextmanager + async def fake_router(*args, **kwargs): + nonlocal connect_count + connect_count += 1 + if connect_count == 1: + raise OSError("Connection refused") + yield + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fake_router): + await lease.handle_async(Mock()) + + assert connect_count == 2 + # Dial called fresh each attempt + assert lease.controller.Dial.call_count == 2 + + @pytest.mark.anyio + async def test_oserror_returns_after_timeout(self): + """Should give up on OSError when dial_timeout is exceeded.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 0.0 # already expired + dial_response = Mock(router_endpoint="ep", router_token="tok") + lease.controller.Dial = AsyncMock(return_value=dial_response) + + @asynccontextmanager + async def fail_router(*args, **kwargs): + raise OSError("Connection refused") + yield # pragma: no cover + + with patch("jumpstarter.client.lease.connect_router_stream", side_effect=fail_router): + await lease.handle_async(Mock()) + + # Only the initial Dial, no retry + assert lease.controller.Dial.call_count == 1 + + @pytest.mark.anyio + async def test_exponential_backoff_delay_values(self): + """Verify that sleep delays follow exponential backoff: 0.3, 0.6, 1.2, 2.4, 4.8, capped at 5.0.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 60.0 # large timeout so remaining doesn't cap delays + + # Fail 6 times then succeed on the 7th attempt + total_failures = 6 + call_count = 0 + dial_response = Mock(router_endpoint="ep", router_token="tok") + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count <= total_failures: + raise _make_aio_rpc_error( + grpc.StatusCode.UNAVAILABLE, "tunnel dropped" + ) + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock) as mock_sleep: + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == total_failures + 1 + + # Verify exponential backoff: base_delay=0.3, max_delay=5.0 + # attempt 0: 0.3 * 2^0 = 0.3 + # attempt 1: 0.3 * 2^1 = 0.6 + # attempt 2: 0.3 * 2^2 = 1.2 + # attempt 3: 0.3 * 2^3 = 2.4 + # attempt 4: 0.3 * 2^4 = 4.8 + # attempt 5: min(0.3 * 2^5, 5.0) = min(9.6, 5.0) = 5.0 + expected_delays = [0.3, 0.6, 1.2, 2.4, 4.8, 5.0] + actual_delays = [call.args[0] for call in mock_sleep.call_args_list] + assert len(actual_delays) == len(expected_delays) + for actual, expected in zip(actual_delays, expected_delays, strict=True): + assert actual == pytest.approx(expected), ( + f"Expected delay {expected}, got {actual}" + ) + + + @pytest.mark.anyio + async def test_failed_precondition_not_ready_retries_then_succeeds(self): + """FAILED_PRECONDITION 'not ready' should retry and succeed on next attempt.""" + lease = _make_lease_for_handle() + dial_response = Mock(router_endpoint="ep", router_token="tok") + call_count = 0 + + async def dial_side_effect(req): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise _make_aio_rpc_error( + grpc.StatusCode.FAILED_PRECONDITION, "exporter not ready" + ) + return dial_response + + lease.controller.Dial = AsyncMock(side_effect=dial_side_effect) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + with patch("jumpstarter.client.lease.connect_router_stream") as mock_router: + + @asynccontextmanager + async def fake_router(*args, **kwargs): + yield + + mock_router.side_effect = fake_router + await lease.handle_async(Mock()) + + assert call_count == 2 + + @pytest.mark.anyio + async def test_failed_precondition_returns_after_timeout(self): + """FAILED_PRECONDITION should return (not raise) when dial_timeout is exceeded.""" + lease = _make_lease_for_handle() + lease.dial_timeout = 0.0 # already expired + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error( + grpc.StatusCode.FAILED_PRECONDITION, "exporter not ready" + ), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + # Should return without raising + await lease.handle_async(Mock()) + + lease.controller.Dial.assert_called_once() + + @pytest.mark.anyio + async def test_permission_denied_sets_lease_transferred(self): + """PERMISSION_DENIED should set lease_transferred = True.""" + lease = _make_lease_for_handle() + assert lease.lease_transferred is False + + lease.controller.Dial = AsyncMock( + side_effect=_make_aio_rpc_error( + grpc.StatusCode.PERMISSION_DENIED, "permission denied" + ), + ) + + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) + + assert lease.lease_transferred is True + + +class TestRetryDelay: + """Tests for the _retry_delay static method.""" + + def test_basic_exponential(self): + assert Lease._retry_delay(0, 60.0) == pytest.approx(0.3) + assert Lease._retry_delay(1, 60.0) == pytest.approx(0.6) + assert Lease._retry_delay(2, 60.0) == pytest.approx(1.2) + + def test_capped_by_max(self): + assert Lease._retry_delay(10, 60.0) == pytest.approx(5.0) + + def test_capped_by_remaining(self): + assert Lease._retry_delay(0, 0.1) == pytest.approx(0.1) + + +class TestTransientGrpcCodes: + """Tests for the _TRANSIENT_GRPC_CODES class attribute.""" + + def test_contains_expected_codes(self): + assert grpc.StatusCode.UNAVAILABLE in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.RESOURCE_EXHAUSTED in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.ABORTED in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.INTERNAL in Lease._TRANSIENT_GRPC_CODES + + def test_unknown_not_in_blanket_transient_codes(self): + """UNKNOWN is handled separately via _TRANSIENT_UNKNOWN_MESSAGES.""" + assert grpc.StatusCode.UNKNOWN not in Lease._TRANSIENT_GRPC_CODES + + def test_does_not_contain_non_transient_codes(self): + assert grpc.StatusCode.PERMISSION_DENIED not in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.NOT_FOUND not in Lease._TRANSIENT_GRPC_CODES + assert grpc.StatusCode.FAILED_PRECONDITION not in Lease._TRANSIENT_GRPC_CODES + + def test_transient_unknown_messages(self): + """Should contain the known tunnel teardown messages.""" + assert "watch channel closed" in Lease._TRANSIENT_UNKNOWN_MESSAGES diff --git a/python/packages/jumpstarter/jumpstarter/common/streams.py b/python/packages/jumpstarter/jumpstarter/common/streams.py index 8cdc02330..19c0280ca 100644 --- a/python/packages/jumpstarter/jumpstarter/common/streams.py +++ b/python/packages/jumpstarter/jumpstarter/common/streams.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager from typing import Annotated, Literal, Union from uuid import UUID @@ -34,13 +35,27 @@ class StreamRequestMetadata(BaseModel): @asynccontextmanager -async def connect_router_stream(endpoint, token, stream, tls_config, grpc_options): +async def connect_router_stream(endpoint, token, stream, tls_config, grpc_options, channel_ready_timeout: float = 10): credentials = grpc.composite_channel_credentials( await ssl_channel_credentials(endpoint, tls_config), grpc.access_token_call_credentials(token), ) async with aio_secure_channel(endpoint, credentials, grpc_options) as channel: + # Wait for the channel to be ready before starting the stream. + # Without this, a broken router connection would cause the gRPC + # stream to hang indefinitely waiting for the HTTP/2 SETTINGS frame, + # which manifests as a timeout for the j command on the Unix socket. + try: + await asyncio.wait_for(channel.channel_ready(), timeout=channel_ready_timeout) + except asyncio.TimeoutError: + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.UNAVAILABLE, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + details=f"Timed out waiting for router channel to become ready ({channel_ready_timeout}s)", + debug_error_string=None, + ) from None router = router_pb2_grpc.RouterServiceStub(channel) context = router.Stream(metadata=()) async with RouterStream(context=context) as s: diff --git a/python/packages/jumpstarter/jumpstarter/common/streams_test.py b/python/packages/jumpstarter/jumpstarter/common/streams_test.py new file mode 100644 index 000000000..be89d5ee8 --- /dev/null +++ b/python/packages/jumpstarter/jumpstarter/common/streams_test.py @@ -0,0 +1,88 @@ +import asyncio +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, Mock, patch + +import grpc +import pytest +from grpc.aio import AioRpcError + +from jumpstarter.common.streams import connect_router_stream + + +class TestConnectRouterStreamChannelReady: + """Tests for the channel_ready timeout logic in connect_router_stream.""" + + @pytest.mark.anyio + async def test_raises_unavailable_on_channel_ready_timeout(self): + """When channel_ready() times out, an AioRpcError with UNAVAILABLE should be raised.""" + mock_channel = Mock() + + # Make channel_ready() return a coroutine that never completes + async def hang_forever(): + await asyncio.sleep(999) + + mock_channel.channel_ready = Mock(return_value=hang_forever()) + + @asynccontextmanager + async def fake_secure_channel(*args, **kwargs): + yield mock_channel + + with ( + patch("jumpstarter.common.streams.ssl_channel_credentials", new_callable=AsyncMock), + patch("jumpstarter.common.streams.aio_secure_channel", side_effect=fake_secure_channel), + patch("grpc.composite_channel_credentials", return_value=Mock()), + patch("grpc.access_token_call_credentials", return_value=Mock()), + ): + with pytest.raises(AioRpcError) as exc_info: + async with connect_router_stream( + "endpoint:443", "token", Mock(), Mock(), {}, channel_ready_timeout=0.01 + ): + pass # pragma: no cover + + assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE + assert "Timed out" in str(exc_info.value.details()) + + @pytest.mark.anyio + async def test_proceeds_when_channel_ready_succeeds(self): + """When channel_ready() succeeds quickly, the stream should be set up normally.""" + mock_channel = Mock() + + # channel_ready() resolves immediately + async def ready_immediately(): + pass + + mock_channel.channel_ready = Mock(return_value=ready_immediately()) + + mock_context = Mock() + + @asynccontextmanager + async def fake_secure_channel(*args, **kwargs): + yield mock_channel + + @asynccontextmanager + async def fake_router_stream(*args, **kwargs): + yield Mock() + + @asynccontextmanager + async def fake_forward(*args, **kwargs): + yield + + with ( + patch("jumpstarter.common.streams.ssl_channel_credentials", new_callable=AsyncMock), + patch("jumpstarter.common.streams.aio_secure_channel", side_effect=fake_secure_channel), + patch("grpc.composite_channel_credentials", return_value=Mock()), + patch("grpc.access_token_call_credentials", return_value=Mock()), + patch("jumpstarter.common.streams.router_pb2_grpc.RouterServiceStub") as mock_stub_cls, + patch("jumpstarter.common.streams.RouterStream", side_effect=fake_router_stream), + patch("jumpstarter.common.streams.forward_stream", side_effect=fake_forward), + ): + mock_stub = Mock() + mock_stub.Stream.return_value = mock_context + mock_stub_cls.return_value = mock_stub + + async with connect_router_stream( + "endpoint:443", "token", Mock(), Mock(), {}, channel_ready_timeout=5 + ): + pass # Successfully entered the context + + mock_channel.channel_ready.assert_called_once()