From 9e7b5c6627e70e2b501f2c74dd17c9a547f150ba Mon Sep 17 00:00:00 2001 From: claude <6687499+pike00@users.noreply.github.com> Date: Tue, 9 Jun 2026 07:46:17 -0500 Subject: [PATCH] fix: make V1Channel re-subscribable after a failed subscribe --- roborock/devices/device.py | 5 +- roborock/devices/rpc/v1_channel.py | 95 +++++++++++++++----------- tests/devices/rpc/test_v1_channel.py | 60 +++++++++++++++++ tests/devices/test_v1_device.py | 99 +++++++++++++++++++++++++++- 4 files changed, 218 insertions(+), 41 deletions(-) diff --git a/roborock/devices/device.py b/roborock/devices/device.py index bf020814..50cda7be 100644 --- a/roborock/devices/device.py +++ b/roborock/devices/device.py @@ -202,7 +202,10 @@ async def connect(self) -> None: await self.v1_properties.start() elif self.b01_q10_properties is not None: await self.b01_q10_properties.start() - except RoborockException: + except BaseException: + # Any failure in start() must unsubscribe before propagating, so a + # retry by connect_loop() gets a clean channel. Broader than + # RoborockException so non-Roborock errors also release the channel. unsub() raise self._logger.info("Connected to device") diff --git a/roborock/devices/rpc/v1_channel.py b/roborock/devices/rpc/v1_channel.py index 6d310543..51250017 100644 --- a/roborock/devices/rpc/v1_channel.py +++ b/roborock/devices/rpc/v1_channel.py @@ -295,45 +295,66 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab if self._callback is not None: raise ValueError("Only one subscription allowed at a time") - # Make an initial, optimistic attempt to connect to local with the - # cache. The cache information will be refreshed by the background task. - try: - await self._local_connect(prefer_cache=True) - except RoborockException as err: - self._logger.debug("First local connection attempt failed, will retry: %s", err) - - # Start a background task to manage the local connection health. This - # happens independent of whether we were able to connect locally now. - if self._reconnect_task is None: - loop = asyncio.get_running_loop() - self._reconnect_task = loop.create_task(self._background_reconnect()) - - # We maintain an active MQTT subscription even when connected locally to receive - # unsolicited status updates (DPS push messages) directly from the cloud. + # Claim the subscription up front. Any failure in the setup below routes + # through _teardown(), which clears this again so the channel is left in + # a clean, re-subscribable state. Without this, a partially-completed + # subscribe (e.g. a transient failure later in connect()) would leave a + # stale callback and the next subscribe() would raise the guard above. + self._callback = callback try: - self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) - except RoborockException as err: - if not self.is_local_connected: - # Propagate error if both local and MQTT failed - self._logger.debug("MQTT connection also failed: %s", err) - raise - self._logger.debug("MQTT subscription failed, continuing with local-only connection: %s", err) - - def unsub() -> None: - """Unsubscribe from all messages.""" - if self._reconnect_task: - self._reconnect_task.cancel() - self._reconnect_task = None - if self._mqtt_unsub: - self._mqtt_unsub() - self._mqtt_unsub = None - if self._local_unsub: - self._local_unsub() - self._local_unsub = None - self._logger.debug("Unsubscribed from device") + # Make an initial, optimistic attempt to connect to local with the + # cache. The cache information will be refreshed by the background task. + try: + await self._local_connect(prefer_cache=True) + except RoborockException as err: + self._logger.debug("First local connection attempt failed, will retry: %s", err) - self._callback = callback - return unsub + # Start a background task to manage the local connection health. This + # happens independent of whether we were able to connect locally now. + if self._reconnect_task is None: + loop = asyncio.get_running_loop() + self._reconnect_task = loop.create_task(self._background_reconnect()) + + # We maintain an active MQTT subscription even when connected locally to receive + # unsolicited status updates (DPS push messages) directly from the cloud. + try: + self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message) + except RoborockException as err: + if not self.is_local_connected: + # Propagate error if both local and MQTT failed + self._logger.debug("MQTT connection also failed: %s", err) + raise + self._logger.debug("MQTT subscription failed, continuing with local-only connection: %s", err) + except BaseException: + # Any failure during setup must leave the channel re-subscribable: + # cancel the reconnect task, drop subscriptions, and clear _callback. + self._teardown() + raise + + self._logger.debug("Subscribed to device") + return self._teardown + + def _teardown(self) -> None: + """Tear down all subscriptions and reset the channel to a re-subscribable state. + + Returned from subscribe() as the unsubscribe function and also invoked on + any failure partway through subscribe(). Idempotent: each resource is + guarded so repeat calls are no-ops. + """ + if self._reconnect_task: + self._reconnect_task.cancel() + self._reconnect_task = None + if self._mqtt_unsub: + self._mqtt_unsub() + self._mqtt_unsub = None + if self._local_unsub: + self._local_unsub() + self._local_unsub = None + if self._local_channel: + self._local_channel.close() + self._local_channel = None + self._callback = None + self._logger.debug("Unsubscribed from device") def add_dps_listener(self, listener: Callable[[dict[RoborockDataProtocol, Any]], None]) -> Callable[[], None]: """Add a listener for DPS updates. diff --git a/tests/devices/rpc/test_v1_channel.py b/tests/devices/rpc/test_v1_channel.py index fdd4eda9..f77efa02 100644 --- a/tests/devices/rpc/test_v1_channel.py +++ b/tests/devices/rpc/test_v1_channel.py @@ -642,3 +642,63 @@ async def test_v1_channel_dps_listener_raises_exception( unsub_dps1() unsub_dps2() + + +async def test_v1_channel_resubscribe_after_unsub( + v1_channel: V1Channel, + mock_mqtt_channel: FakeChannel, +) -> None: + """A subscribe -> unsub -> subscribe cycle must not raise. + + Regression: unsub() previously failed to clear ``_callback``, so the second + subscribe() tripped the "Only one subscription allowed at a time" guard. + This is the exact failure that bricked a second vacuum sharing an account. + """ + mock_mqtt_channel.response_queue.append(TEST_NETWORK_INFO_RESPONSE) + unsub = await v1_channel.subscribe(Mock()) + assert v1_channel._callback is not None + + unsub() + # The whole point of the fix: tearing down clears the callback. + assert v1_channel._callback is None + assert v1_channel._reconnect_task is None + assert not mock_mqtt_channel.subscribers + + # Re-subscribing must succeed (network info is now cached, no MQTT needed). + unsub2 = await v1_channel.subscribe(Mock()) + assert v1_channel._callback is not None + assert mock_mqtt_channel.subscribers + unsub2() + + +async def test_v1_channel_subscribe_failure_is_atomic( + v1_channel: V1Channel, + mock_mqtt_channel: FakeChannel, + mock_local_channel: FakeChannel, +) -> None: + """A failure partway through subscribe() leaves the channel re-subscribable. + + Regression: a failed subscribe() previously leaked the background reconnect + task and a partial subscription, so the next attempt could neither reuse nor + cleanly recreate the channel. + """ + # Both transports down: local connect fails and the MQTT subscribe fails. + mock_local_channel.connect.side_effect = RoborockException("local down") + mock_mqtt_channel.subscribe.side_effect = RoborockException("mqtt down") + + with pytest.raises(RoborockException): + await v1_channel.subscribe(Mock()) + + # No leaked task, no stale callback, no dangling subscription. + assert v1_channel._callback is None + assert v1_channel._reconnect_task is None + assert v1_channel._mqtt_unsub is None + assert not mock_mqtt_channel.subscribers + + # And the channel is re-subscribable once the transports recover. + mock_local_channel.connect.side_effect = None + mock_mqtt_channel.subscribe.side_effect = mock_mqtt_channel._subscribe + mock_mqtt_channel.response_queue.append(TEST_NETWORK_INFO_RESPONSE) + unsub = await v1_channel.subscribe(Mock()) + assert v1_channel._callback is not None + unsub() diff --git a/tests/devices/test_v1_device.py b/tests/devices/test_v1_device.py index 8afc62cd..44ab8466 100644 --- a/tests/devices/test_v1_device.py +++ b/tests/devices/test_v1_device.py @@ -8,14 +8,18 @@ import pytest from syrupy import SnapshotAssertion -from roborock.data import HomeData, S7MaxVStatus, UserData -from roborock.devices.cache import DeviceCache, NoCache +from roborock.data import HomeData, NetworkInfo, S7MaxVStatus, UserData +from roborock.devices.cache import DeviceCache, DeviceCacheData, InMemoryCache, NoCache from roborock.devices.device import RoborockDevice +from roborock.devices.rpc.v1_channel import V1Channel from roborock.devices.traits import v1 from roborock.devices.traits.v1.common import V1TraitMixin -from roborock.protocols.v1_protocol import decode_rpc_response +from roborock.devices.transport.local_channel import LocalSession +from roborock.exceptions import RoborockException +from roborock.protocols.v1_protocol import SecurityData, decode_rpc_response from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol from tests import mock_data +from tests.fixtures.channel_fixtures import FakeChannel USER_DATA = UserData.from_dict(mock_data.USER_DATA) HOME_DATA = HomeData.from_dict(mock_data.HOME_DATA_RAW) @@ -181,3 +185,92 @@ async def test_device_trait_command_parsing( assert device.v1_properties device_dict = device.diagnostic_data() assert device_dict == snapshot + + +@pytest.mark.parametrize( + "start_error", + [RoborockException("transient status fetch failed"), ValueError("unexpected")], + ids=["roborock-exception", "non-roborock-exception"], +) +async def test_connect_unsubscribes_when_start_fails( + device: RoborockDevice, + channel: AsyncMock, + start_error: Exception, +) -> None: + """connect() must release the channel when start() fails, for any exception. + + Regression: the cleanup was scoped to ``except RoborockException``, so a + non-Roborock failure in start() propagated without unsubscribing, leaving the + channel subscribed and the next attempt unable to re-subscribe. + """ + unsub = Mock() + channel.subscribe = AsyncMock(return_value=unsub) + device.v1_properties.start = AsyncMock(side_effect=start_error) + + with pytest.raises(type(start_error)): + await device.connect() + + channel.subscribe.assert_awaited_once() + unsub.assert_called_once() # channel released before propagating + assert device._unsub is None # not marked connected; safe for connect_loop to retry + + +async def test_connect_retries_after_transient_start_failure() -> None: + """End-to-end regression for the Q5 multi-vacuum bug. + + A device backed by a real V1Channel: the first connect() subscribes, then + start() fails transiently. The retry must re-subscribe cleanly rather than + raising "Only one subscription allowed at a time", and the device must end + up connected. + """ + duid = HOME_DATA.devices[0].duid + + mqtt_channel = FakeChannel() + await mqtt_channel.connect() + local_channel = FakeChannel() + local_session = Mock(spec=LocalSession, return_value=local_channel) + + # Cache the network info so local connect doesn't need an MQTT round-trip. + cache = InMemoryCache() + device_cache = DeviceCache(duid, cache) + await device_cache.set(DeviceCacheData(network_info=NetworkInfo.from_dict(mock_data.NETWORK_INFO))) + + v1_channel = V1Channel( + device_uid=duid, + security_data=SecurityData(endpoint="test_endpoint", nonce=b"test_nonce_16byte"), + mqtt_channel=mqtt_channel, + local_session=local_session, + device_cache=device_cache, + ) + device = RoborockDevice( + device_info=HOME_DATA.devices[0], + product=HOME_DATA.products[0], + channel=v1_channel, + trait=v1.create( + duid, + HOME_DATA.products[0], + HOME_DATA, + AsyncMock(), + AsyncMock(), + AsyncMock(), + Mock(), + AsyncMock(), + device_cache=device_cache, + region=USER_DATA.region, + ), + ) + + # First connect() subscribes successfully, then start() fails transiently; + # the second succeeds. + device.v1_properties.start = AsyncMock(side_effect=[RoborockException("transient"), None]) + + with pytest.raises(RoborockException): + await device.connect() + assert device._unsub is None # channel released after the transient failure + + # The retry must NOT raise "Only one subscription allowed at a time". + await device.connect() + assert device._unsub is not None + assert device.is_connected + + await device.close()