diff --git a/src/hooks/useMessageStream.test.ts b/src/hooks/useMessageStream.test.ts index 10370a8b..8fa3544e 100644 --- a/src/hooks/useMessageStream.test.ts +++ b/src/hooks/useMessageStream.test.ts @@ -234,6 +234,107 @@ describe("useMessageStream", () => { }); }); + it("resets connection state when conversationId is cleared", async () => { + const { result, rerender } = renderHook( + ({ conversationId }) => useMessageStream(conversationId), + { initialProps: { conversationId: "conv-123" as string | null } } + ); + + act(() => { + MockEventSource.instances[0].simulateOpen(); + }); + + await waitFor(() => { + expect(result.current.isConnected).toBe(true); + }); + + const firstInstance = MockEventSource.instances[0]; + + rerender({ conversationId: null }); + + await waitFor(() => { + expect(firstInstance.closed).toBe(true); + expect(result.current.isConnected).toBe(false); + expect(result.current.error).toBeNull(); + expect(MockEventSource.instances).toHaveLength(1); + }); + }); + + it("waits for onopen when the same conversation reconnects after being cleared", async () => { + const { result, rerender } = renderHook( + ({ conversationId }) => useMessageStream(conversationId), + { initialProps: { conversationId: "conv-123" as string | null } } + ); + + act(() => { + MockEventSource.instances[0].simulateOpen(); + }); + + await waitFor(() => { + expect(result.current.isConnected).toBe(true); + }); + + const firstInstance = MockEventSource.instances[0]; + + rerender({ conversationId: null }); + + await waitFor(() => { + expect(firstInstance.closed).toBe(true); + expect(result.current.isConnected).toBe(false); + }); + + rerender({ conversationId: "conv-123" }); + + expect(MockEventSource.instances).toHaveLength(2); + expect(result.current.isConnected).toBe(false); + expect(result.current.error).toBeNull(); + + act(() => { + MockEventSource.instances[1].simulateOpen(); + }); + + await waitFor(() => { + expect(result.current.isConnected).toBe(true); + }); + }); + + it("does not expose stale errors before same conversation reconnect opens", async () => { + const { result, rerender } = renderHook( + ({ conversationId }) => useMessageStream(conversationId), + { initialProps: { conversationId: "conv-123" as string | null } } + ); + + act(() => { + MockEventSource.instances[0].simulateError(); + }); + + await waitFor(() => { + expect(result.current.error).toBeTruthy(); + expect(result.current.isConnected).toBe(false); + }); + + rerender({ conversationId: null }); + + await waitFor(() => { + expect(result.current.error).toBeNull(); + }); + + rerender({ conversationId: "conv-123" }); + + expect(MockEventSource.instances).toHaveLength(2); + expect(result.current.error).toBeNull(); + expect(result.current.isConnected).toBe(false); + + act(() => { + MockEventSource.instances[1].simulateOpen(); + }); + + await waitFor(() => { + expect(result.current.error).toBeNull(); + expect(result.current.isConnected).toBe(true); + }); + }); + it("clears error on successful reconnect", async () => { const { result } = renderHook(() => useMessageStream("conv-123")); diff --git a/src/hooks/useMessageStream.ts b/src/hooks/useMessageStream.ts index 09808559..91f44fb6 100644 --- a/src/hooks/useMessageStream.ts +++ b/src/hooks/useMessageStream.ts @@ -20,12 +20,22 @@ export function useMessageStream( conversationId: string | null, options: UseMessageStreamOptions = {} ): UseMessageStreamReturn { - const [isConnected, setIsConnected] = useState(false); - const [error, setError] = useState(null); + const [connectedStreamId, setConnectedStreamId] = useState( + null + ); + const [connectionError, setConnectionError] = useState<{ + streamId: number; + conversationId: string; + error: Error; + } | null>(null); const [reconnectTrigger, setReconnectTrigger] = useState(0); - const [isOtherTyping, setIsOtherTyping] = useState(false); + const [otherTypingState, setOtherTypingState] = useState({ + conversationId: null as string | null, + streamId: null as number | null, + isTyping: false, + }); const eventSourceRef = useRef(null); - const typingTimeoutRef = useRef(null); + const streamIdRef = useRef(0); const lastTypingSentRef = useRef(0); const { onMessage, onTypingChange } = options; @@ -40,17 +50,39 @@ export function useMessageStream( onTypingChangeRef.current = onTypingChange; }, [onTypingChange]); + const isConnected = + conversationId !== null && connectedStreamId === streamIdRef.current; + const error = + connectionError?.conversationId === conversationId && + connectionError.streamId === streamIdRef.current + ? connectionError.error + : null; + const isOtherTyping = + conversationId !== null && + otherTypingState.conversationId === conversationId && + otherTypingState.streamId === streamIdRef.current && + otherTypingState.isTyping; + // Poll for typing status useEffect(() => { - if (!conversationId || !isConnected) return; + if (!conversationId || !isConnected || connectedStreamId === null) return; + + let cancelled = false; + const streamId = connectedStreamId; const pollTyping = async () => { try { const response = await fetch(`/api/conversations/${conversationId}/typing`); if (response.ok) { const data = await response.json(); - const hasTyping = data.typing && data.typing.length > 0; - setIsOtherTyping(hasTyping); + if (cancelled || streamIdRef.current !== streamId) return; + + const hasTyping = Boolean(data.typing?.length); + setOtherTypingState({ + conversationId, + streamId, + isTyping: hasTyping, + }); onTypingChangeRef.current?.(hasTyping, data.typing?.[0]); } } catch { @@ -62,15 +94,30 @@ export function useMessageStream( const interval = setInterval(pollTyping, 2000); pollTyping(); // Initial poll - return () => clearInterval(interval); - }, [conversationId, isConnected]); + return () => { + cancelled = true; + clearInterval(interval); + }; + }, [conversationId, connectedStreamId, isConnected]); useEffect(() => { - if (!conversationId) return; - - // Clean up existing connection if (eventSourceRef.current) { eventSourceRef.current.close(); + eventSourceRef.current = null; + } + + const streamId = streamIdRef.current + 1; + streamIdRef.current = streamId; + setConnectedStreamId(null); + setConnectionError(null); + setOtherTypingState({ + conversationId: null, + streamId: null, + isTyping: false, + }); + + if (!conversationId) { + return; } const eventSource = new EventSource( @@ -79,31 +126,48 @@ export function useMessageStream( eventSourceRef.current = eventSource; eventSource.onopen = () => { - setIsConnected(true); - setError(null); + if (streamIdRef.current !== streamId) return; + + setConnectedStreamId(streamId); + setConnectionError(null); }; eventSource.onmessage = (event) => { + if (streamIdRef.current !== streamId) return; + try { const message = JSON.parse(event.data) as MessageWithSender; onMessageRef.current?.(message); // Clear typing indicator when a message is received - setIsOtherTyping(false); + setOtherTypingState({ + conversationId, + streamId, + isTyping: false, + }); } catch (e) { console.error("Failed to parse message:", e); } }; eventSource.onerror = () => { - setIsConnected(false); - setError(new Error("Connection lost")); + if (streamIdRef.current !== streamId) return; + + setConnectedStreamId(null); + setConnectionError({ + streamId, + conversationId, + error: new Error("Connection lost"), + }); eventSource.close(); }; return () => { eventSource.close(); - if (typingTimeoutRef.current) { - clearTimeout(typingTimeoutRef.current); + if (eventSourceRef.current === eventSource) { + eventSourceRef.current = null; + } + if (streamIdRef.current === streamId) { + streamIdRef.current += 1; } }; }, [conversationId, reconnectTrigger]);