Skip to content

Commit cb4a446

Browse files
committed
fix: extract interaction_id from SSE events for Interactions API chaining
SSE streaming events carry the interaction ID in different attributes than the non-streaming Interaction response: - InteractionStartEvent/CompleteEvent: event.interaction.id - InteractionStatusUpdate: event.interaction_id The code only checked event.id which doesn't exist on SSE event types, so current_interaction_id was never set during streaming. This caused _find_previous_interaction_id() to fail, breaking function calling with StreamingMode.SSE + Interactions API. Fixes #5169
1 parent 114deef commit cb4a446

2 files changed

Lines changed: 111 additions & 3 deletions

File tree

src/google/adk/models/interactions_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,9 +1013,18 @@ async def generate_content_via_interactions(
10131013
# Log the streaming event
10141014
logger.debug(build_interactions_event_log(event))
10151015

1016-
# Extract interaction ID from event if available
1017-
if hasattr(event, 'id') and event.id:
1018-
current_interaction_id = event.id
1016+
# Extract interaction ID from event if available.
1017+
# SSE events carry the ID in different attributes depending on type:
1018+
# - InteractionStartEvent/CompleteEvent: event.interaction.id
1019+
# - InteractionStatusUpdate: event.interaction_id
1020+
if (
1021+
hasattr(event, 'interaction')
1022+
and hasattr(event.interaction, 'id')
1023+
and event.interaction.id
1024+
):
1025+
current_interaction_id = event.interaction.id
1026+
elif hasattr(event, 'interaction_id') and event.interaction_id:
1027+
current_interaction_id = event.interaction_id
10191028
llm_response = convert_interaction_event_to_llm_response(
10201029
event, aggregated_parts, current_interaction_id
10211030
)

tests/unittests/models/test_interactions_utils.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,3 +955,102 @@ def test_unknown_event_type_returns_none(self):
955955

956956
assert result is None
957957
assert not aggregated_parts
958+
959+
960+
class TestSSEInteractionIdExtraction:
961+
"""Tests for interaction_id extraction from SSE events.
962+
963+
SSE events carry the interaction ID in different attributes:
964+
- InteractionStartEvent/CompleteEvent: event.interaction.id
965+
- InteractionStatusUpdate: event.interaction_id
966+
- ContentDelta/ContentStop: no interaction ID
967+
968+
The call_interactions_api function must extract the ID from these
969+
locations so it can be propagated to LlmResponse objects and
970+
ultimately stored in session events for chaining.
971+
"""
972+
973+
def test_interaction_start_event_carries_id(self):
974+
"""InteractionStartEvent has interaction.id — should be found."""
975+
event = MagicMock()
976+
event.event_type = 'interaction.start'
977+
event.interaction = MagicMock()
978+
event.interaction.id = 'int_start_abc'
979+
# Should NOT have direct interaction_id
980+
del event.interaction_id
981+
982+
# Verify the extraction logic matches what call_interactions_api does
983+
current_id = None
984+
if (
985+
hasattr(event, 'interaction')
986+
and hasattr(event.interaction, 'id')
987+
and event.interaction.id
988+
):
989+
current_id = event.interaction.id
990+
elif hasattr(event, 'interaction_id') and event.interaction_id:
991+
current_id = event.interaction_id
992+
993+
assert current_id == 'int_start_abc'
994+
995+
def test_status_update_event_carries_interaction_id(self):
996+
"""InteractionStatusUpdate has interaction_id — should be found."""
997+
event = MagicMock(spec=['event_type', 'interaction_id', 'status'])
998+
event.event_type = 'interaction.status_update'
999+
event.interaction_id = 'int_status_xyz'
1000+
event.status = 'requires_action'
1001+
1002+
current_id = None
1003+
if (
1004+
hasattr(event, 'interaction')
1005+
and hasattr(event.interaction, 'id')
1006+
and event.interaction.id
1007+
):
1008+
current_id = event.interaction.id
1009+
elif hasattr(event, 'interaction_id') and event.interaction_id:
1010+
current_id = event.interaction_id
1011+
1012+
assert current_id == 'int_status_xyz'
1013+
1014+
def test_content_delta_has_no_interaction_id(self):
1015+
"""ContentDelta events don't carry interaction ID."""
1016+
event = MagicMock(spec=['event_type', 'delta', 'index', 'event_id'])
1017+
event.event_type = 'content.delta'
1018+
1019+
current_id = None
1020+
if (
1021+
hasattr(event, 'interaction')
1022+
and hasattr(event.interaction, 'id')
1023+
and event.interaction.id
1024+
):
1025+
current_id = event.interaction.id
1026+
elif hasattr(event, 'interaction_id') and event.interaction_id:
1027+
current_id = event.interaction_id
1028+
1029+
assert current_id is None
1030+
1031+
def test_interaction_id_propagated_to_status_update_response(self):
1032+
"""When interaction_id is extracted from earlier events, it should
1033+
be passed to convert_interaction_event_to_llm_response and appear
1034+
in the resulting LlmResponse for status_update events."""
1035+
event = MagicMock()
1036+
event.event_type = 'interaction.status_update'
1037+
event.status = 'requires_action'
1038+
1039+
# Function call was aggregated earlier
1040+
aggregated_parts = [
1041+
types.Part(
1042+
function_call=types.FunctionCall(
1043+
id='call_1', name='get_weather', args={'city': 'Tokyo'}
1044+
)
1045+
)
1046+
]
1047+
1048+
# The interaction_id should have been extracted from an earlier
1049+
# InteractionStartEvent and passed here
1050+
result = interactions_utils.convert_interaction_event_to_llm_response(
1051+
event, aggregated_parts, interaction_id='int_from_start'
1052+
)
1053+
1054+
assert result is not None
1055+
assert result.interaction_id == 'int_from_start'
1056+
assert result.turn_complete is True

0 commit comments

Comments
 (0)