diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 800b15374..2cdd5e326 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -233,7 +233,7 @@ async def query_endpoint_handler( responses_params, moderation_result, endpoint_path, - compaction.original_input if compaction.compacted else None, + compaction.original_input, ) if moderation_result.decision == "passed": diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index e6b269e4c..c266e80d2 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -338,6 +338,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals responses_params=responses_params, context=context, endpoint_path=endpoint_path, + original_input=None, ) # Combine inline RAG results (BYOK + Solr) with tool-based results @@ -353,6 +354,8 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals responses_params=responses_params, turn_summary=turn_summary, background_topic_summary_tasks=_background_topic_summary_tasks, + emit_start=True, + original_input=None, ), media_type=response_media_type, ) @@ -387,7 +390,6 @@ async def retrieve_response_generator( if context.moderation_result.decision == "blocked": turn_summary.llm_response = context.moderation_result.message turn_summary.id = context.moderation_result.moderation_id - turn_summary.output_items = [context.moderation_result.refusal_response] # In compacted mode the conversation parameter was omitted, so the # refusal turn (with the original input) is persisted by # generate_response; storing it here too would duplicate it. @@ -506,6 +508,7 @@ async def generate_response_with_compaction( responses_params=responses_params, context=context, endpoint_path=endpoint_path, + original_input=compacted_original_input, ) except HTTPException as e: yield http_exception_stream_event(e) @@ -699,7 +702,7 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi if original_input is not None else context.query_request.query ), - turn_summary.output_items, + [], # field was removed from TurnSummary ) except Exception: # pylint: disable=broad-except logger.exception( @@ -873,10 +876,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat getattr(chunk, "response"), # noqa: B009 ) turn_summary.llm_response = turn_summary.llm_response or "".join(text_parts) - # Capture structured output items for compacted-mode turn storage - # (LCORE-1572), so the persisted turn keeps non-text output items - # rather than being flattened to the response text. - turn_summary.output_items = list(latest_response_object.output or []) yield stream_event( { "id": chunk_id, @@ -893,9 +892,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat OpenAIResponseObject, getattr(chunk, "response"), # noqa: B009 ) - # Capture any partial output items so a compacted-mode turn is not - # persisted with empty output on these terminals (LCORE-1572). - turn_summary.output_items = list(latest_response_object.output or []) error_message = ( latest_response_object.error.message if latest_response_object.error diff --git a/src/models/common/turn_summary.py b/src/models/common/turn_summary.py index f09e24845..c3de9e572 100644 --- a/src/models/common/turn_summary.py +++ b/src/models/common/turn_summary.py @@ -5,7 +5,6 @@ from typing import Any, Optional -from llama_stack_api import OpenAIResponseOutput from pydantic import AnyUrl, BaseModel, Field from utils.token_counter import TokenCounter @@ -109,11 +108,6 @@ class TurnSummary(BaseModel): rag_chunks: list[RAGChunk] = Field(default_factory=list) referenced_documents: list[ReferencedDocument] = Field(default_factory=list) token_usage: TokenCounter = Field(default_factory=TokenCounter) - output_items: list[OpenAIResponseOutput] = Field( - default_factory=list, - description="Structured response output items, captured for compacted-mode " - "turn persistence (LCORE-1572). Empty on the non-compacted path.", - ) class ToolInfoSummary(BaseModel): diff --git a/src/pydantic_ai_lightspeed/llamastack/__init__.py b/src/pydantic_ai_lightspeed/llamastack/__init__.py index fac9ee826..80cb95193 100644 --- a/src/pydantic_ai_lightspeed/llamastack/__init__.py +++ b/src/pydantic_ai_lightspeed/llamastack/__init__.py @@ -1,6 +1,13 @@ """Pydantic AI provider for Llama Stack.""" -from pydantic_ai_lightspeed.llamastack._model import LlamaStackResponsesModel +from pydantic_ai_lightspeed.llamastack._model import ( + CompactionTurnContext, + LlamaStackResponsesModel, +) from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider -__all__ = ["LlamaStackProvider", "LlamaStackResponsesModel"] +__all__ = [ + "CompactionTurnContext", + "LlamaStackProvider", + "LlamaStackResponsesModel", +] diff --git a/src/pydantic_ai_lightspeed/llamastack/_model.py b/src/pydantic_ai_lightspeed/llamastack/_model.py index 3df0b9640..588ec7978 100644 --- a/src/pydantic_ai_lightspeed/llamastack/_model.py +++ b/src/pydantic_ai_lightspeed/llamastack/_model.py @@ -1,87 +1,97 @@ -"""Custom OpenAI Responses model that works around Llama Stack streaming quirks. +"""Llama Stack Responses model adapter for pydantic-ai. -Llama Stack's Responses API emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP -tool calls *before* the corresponding ``ResponseOutputItemAddedEvent``. pydantic_ai's -default handler creates an orphan ``ToolCallPartDelta`` for the unannounced item_id, -which later causes an IndexError in ``part_end_event``. - -Additionally, MCP tool calls arrive as ``McpCall`` items (not ``ResponseFunctionToolCall``), -and pydantic_ai registers them with a ``-call`` vendor_part_id suffix. The buffered -deltas must be replayed with the matching suffix so pydantic_ai can append the -streamed ``tool_args`` content to the correct part. - -This module provides ``LlamaStackResponsesModel`` which wraps the event stream to -buffer those early delta events and replay them correctly once the item is announced. +Patches client.responses.create to reorder streaming tool-call events and to +persist compacted conversation turns when the conversation parameter is omitted +from inference requests. """ from __future__ import annotations as _annotations from collections import defaultdict -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from typing import Any, cast +from collections.abc import AsyncIterator, Sequence +from dataclasses import dataclass +from typing import Any, Optional, Self, cast +from llama_stack_client import AsyncLlamaStackClient from openai import AsyncStream from openai.types import responses -from pydantic_ai import UnexpectedModelBehavior -from pydantic_ai._run_context import RunContext -from pydantic_ai._utils import PeekableAsyncStream, Unset, number_to_datetime -from pydantic_ai.messages import ModelMessage -from pydantic_ai.models import ( - ModelRequestParameters, - StreamedResponse, - check_allow_model_requests, -) -from pydantic_ai.models.openai import ( - OpenAIResponsesModel, - OpenAIResponsesModelSettings, - OpenAIResponsesStreamedResponse, - _map_api_errors, -) +from pydantic_ai.models.openai import OpenAIResponsesModel from pydantic_ai.settings import ModelSettings from log import get_logger +from models.common.responses.types import ResponseInput +from utils.conversations import append_turn_items_to_conversation logger = get_logger(__name__) -class _FilteredResponseStream: - """Wraps an OpenAI AsyncStream to reorder spurious events from Llama Stack. +@dataclass +class CompactionTurnContext: + """Mutable state for manually persisting compacted agent turns. + + latest_round_input is initialized to the real user query. The create patch + leaves it unchanged on the first LLM round, then records pydantic-ai input + for follow-up rounds after that turn is persisted. + + Attributes: + client: Llama Stack client used to append conversation items. + conversation_id: Conversation to store turns against. + latest_round_input: Input stored for the current or next inference round. + original_input_persisted: Whether the first compacted round was appended. + """ + + client: AsyncLlamaStackClient + conversation_id: str + latest_round_input: ResponseInput + original_input_persisted: bool = False + - Llama Stack emits ``ResponseFunctionCallArgumentsDeltaEvent`` for MCP tool calls - *before* the ``ResponseOutputItemAddedEvent`` that announces them. This wrapper - buffers those early deltas and replays them once the announcement arrives. +class _NormalizedLlamaStackStream: + """AsyncStream wrapper that normalizes Llama Stack response events. - For ``McpCall`` items specifically, pydantic_ai registers the part with a - ``-call`` vendor_part_id suffix. Buffered deltas are therefore replayed as a - single combined event with the suffixed ``item_id`` so they match the part, plus - a closing ``}`` to complete the outer JSON object that pydantic_ai opens. + Buffers early tool-call argument deltas and replays them after the matching + output item is announced. Optionally appends completed turns when compacted. """ - def __init__(self, source: AsyncStream[responses.ResponseStreamEvent]) -> None: - """Wrap an existing stream with reordering logic. + def __init__( + self, + source: AsyncStream[responses.ResponseStreamEvent], + compaction: Optional[CompactionTurnContext] = None, + ) -> None: + """Initialize the stream wrapper. Args: - source: The raw OpenAI AsyncStream to reorder. + source: Raw Responses API stream from the OpenAI SDK. + compaction: Compaction state for turn persistence, if active. """ self._source = source + self._compaction = compaction self._announced_item_ids: set[str] = set() self._buffered_deltas: dict[ str, list[responses.ResponseFunctionCallArgumentsDeltaEvent] ] = defaultdict(list) async def close(self) -> None: - """Close the underlying stream.""" + """Close the underlying SDK stream.""" await self._source.close() + async def __aenter__(self) -> Self: + """Enter the underlying stream context manager.""" + await self._source.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> None: + """Exit the underlying stream context manager.""" + await self._source.__aexit__(*args) + def __aiter__(self) -> AsyncIterator[responses.ResponseStreamEvent]: - """Return async iterator that reorders events.""" - return self._filtered_iter() + """Return an async iterator over normalized stream events.""" + return self._iter_normalized_events() - async def _filtered_iter( + async def _iter_normalized_events( self, ) -> AsyncIterator[responses.ResponseStreamEvent]: - """Yield events, buffering early argument deltas until their item is announced.""" + """Yield stream events in the order expected by pydantic-ai.""" async for event in self._source: if isinstance(event, responses.ResponseOutputItemAddedEvent): if ( @@ -91,7 +101,7 @@ async def _filtered_iter( item_id = event.item.id self._announced_item_ids.add(item_id) yield event - for delta in self._replay_buffered_deltas(item_id): + for delta in self._replay_function_tool_deltas(item_id): yield delta continue @@ -99,7 +109,7 @@ async def _filtered_iter( item_id = event.item.id self._announced_item_ids.add(item_id) yield event - for delta in self._replay_mcp_buffered_deltas(item_id): + for delta in self._replay_mcp_tool_deltas(item_id): yield delta continue @@ -112,18 +122,31 @@ async def _filtered_iter( self._buffered_deltas[event.item_id].append(event) continue + if ( + isinstance(event, responses.ResponseCompletedEvent) + and self._compaction is not None + ): + compaction = self._compaction + await append_turn_items_to_conversation( + compaction.client, + compaction.conversation_id, + compaction.latest_round_input, + cast(Sequence[Any], event.response.output), + ) + compaction.original_input_persisted = True + yield event - def _replay_buffered_deltas( + def _replay_function_tool_deltas( self, item_id: str ) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]: - """Return buffered deltas for a ``ResponseFunctionToolCall`` announcement. + """Return buffered deltas for a function tool-call item. Args: - item_id: The announced item ID. + item_id: Output item id from ResponseOutputItemAddedEvent. Returns: - List of buffered delta events to yield, unchanged. + Buffered argument delta events for the item, if any. """ buffered = self._buffered_deltas.pop(item_id, []) if buffered: @@ -134,28 +157,22 @@ def _replay_buffered_deltas( ) return buffered - def _replay_mcp_buffered_deltas( + def _replay_mcp_tool_deltas( self, item_id: str ) -> list[responses.ResponseFunctionCallArgumentsDeltaEvent]: - """Return buffered deltas for an ``McpCall`` announcement. - - pydantic_ai registers ``McpCall`` parts with ``vendor_part_id=f'{id}-call'`` - and seeds the args string with everything up to ``"tool_args":``. The - buffered deltas contain the actual ``tool_args`` content. We combine them - into a single delta with the suffixed ``item_id`` and append a closing ``}`` - to complete the outer JSON object that pydantic_ai opened. + """Return buffered MCP deltas as a single function-call delta event. Args: - item_id: The announced McpCall item ID. + item_id: MCP output item id from ResponseOutputItemAddedEvent. Returns: - List containing one synthetic delta event, or empty if nothing buffered. + A one-element list with a synthetic delta for pydantic-ai, or empty. """ buffered = self._buffered_deltas.pop(item_id, []) if not buffered: return [] - combined_args = "".join(d.delta for d in buffered) + "}" + combined_args = "".join(delta.delta for delta in buffered) + "}" logger.debug( "Replaying %d buffered MCP argument deltas as single event " "for item_id=%s-call", @@ -174,72 +191,87 @@ def _replay_mcp_buffered_deltas( class LlamaStackResponsesModel(OpenAIResponsesModel): - """OpenAI Responses model with Llama Stack streaming compatibility fixes. - - Overrides the streaming response processing to buffer and replay - ``ResponseFunctionCallArgumentsDeltaEvent`` events that Llama Stack emits - before the corresponding ``McpCall`` or ``ResponseFunctionToolCall`` item. - """ + """OpenAI Responses model with Llama Stack streaming and compaction support.""" - @asynccontextmanager - async def request_stream( + def __init__( # pylint: disable=too-many-arguments self, - messages: list[ModelMessage], - model_settings: ModelSettings | None, - model_request_parameters: ModelRequestParameters, - run_context: RunContext[Any] | None = None, - ) -> AsyncIterator[StreamedResponse]: - """Request a streaming response, filtering Llama Stack-specific event quirks. + model_name: str, + *, + provider: Any = "openai", + profile: Any = None, + settings: ModelSettings | None = None, + compaction: Optional[CompactionTurnContext] = None, + ) -> None: + """Initialize the model and patch client.responses.create. Args: - messages: Model messages for the request. - model_settings: Model-specific settings. - model_request_parameters: Request parameters for the model. - run_context: Optional run context from the agent. - - Yields: - A StreamedResponse with the filtered event stream. + model_name: Model identifier passed to pydantic-ai. + provider: Pydantic AI provider or provider name. + profile: Optional model profile override. + settings: Optional pydantic-ai model settings. + compaction: Compaction state when turns must be stored manually. """ - check_allow_model_requests() - model_settings, model_request_parameters = self.prepare_request( - model_settings, - model_request_parameters, - ) - model_settings_cast = cast(OpenAIResponsesModelSettings, model_settings or {}) - response = await self._responses_create( - messages, True, model_settings_cast, model_request_parameters + super().__init__( + model_name, + provider=provider, + profile=profile, + settings=settings, ) + self.compaction = compaction + self._patch_responses_create() + + def _patch_responses_create(self) -> None: + """Replace client.responses.create with a wrapper for the model lifetime. - filtered_stream = _FilteredResponseStream(response) + pydantic-ai calls responses.create for every inference round. The wrapper + runs before and after the real SDK method: - async with response: - peekable: PeekableAsyncStream[ - responses.ResponseStreamEvent, _FilteredResponseStream - ] = PeekableAsyncStream(filtered_stream) + Before (compacted mode only, after the first round is persisted): + Copy kwargs input into CompactionTurnContext.latest_round_input so + follow-up tool-loop rounds can be appended with the input pydantic-ai + actually sent. The first round is left unchanged so the real user query + is stored instead of the compacted explicit rewrite. - with _map_api_errors(self.model_name): - first_chunk = await peekable.peek() + After, depending on stream: - if isinstance(first_chunk, Unset): - raise UnexpectedModelBehavior( - "Streamed response ended without content or tool calls" + * stream=True — return _NormalizedLlamaStackStream around the SDK stream to + reorder early tool-call deltas and, when compacted, append the turn on + response.completed. + * stream=False — when compacted, append the completed turn immediately + using latest_round_input and mark the first round as persisted. + + Stream normalization is always applied; compaction hooks run only when + self.compaction is set. + """ + responses_api = self.client.responses + original_create = responses_api.create + + async def create(*args: Any, **kwargs: Any) -> Any: + if ( + self.compaction is not None + and "input" in kwargs + and self.compaction.original_input_persisted + ): + self.compaction.latest_round_input = cast( + ResponseInput, kwargs["input"] + ) + + result = await original_create(*args, **kwargs) + + if kwargs.get("stream"): + return _NormalizedLlamaStackStream( + cast(AsyncStream[responses.ResponseStreamEvent], result), + self.compaction, ) - if not isinstance(first_chunk, responses.ResponseCreatedEvent): - raise UnexpectedModelBehavior( - f"Expected ResponseCreatedEvent, got {type(first_chunk).__name__}" + if self.compaction is not None: + await append_turn_items_to_conversation( + self.compaction.client, + self.compaction.conversation_id, + self.compaction.latest_round_input, + cast(Sequence[Any], result.output), ) + self.compaction.original_input_persisted = True + return result - yield OpenAIResponsesStreamedResponse( - model_request_parameters=model_request_parameters, - _model_name=first_chunk.response.model, - _model_settings=model_settings_cast, - _response=peekable, # type: ignore[arg-type] - _provider_name=self._provider.name, - _provider_url=self._provider.base_url, - _provider_timestamp=( - number_to_datetime(first_chunk.response.created_at) - if first_chunk.response.created_at - else None - ), - ) + responses_api.create = create # type: ignore[method-assign] diff --git a/src/utils/agents/query.py b/src/utils/agents/query.py index c0f5ad958..4fe7f305b 100644 --- a/src/utils/agents/query.py +++ b/src/utils/agents/query.py @@ -282,7 +282,7 @@ async def retrieve_agent_response( responses_params: ResponsesApiParams, moderation_result: ShieldModerationResult, endpoint_path: str, - _original_input: Optional[ResponseInput] = None, + original_input: Optional[ResponseInput] = None, ) -> TurnSummary: """Retrieve a turn summary from a blocking agent run. @@ -293,7 +293,7 @@ async def retrieve_agent_response( responses_params: Prepared Responses API parameters. moderation_result: Shield moderation outcome for the turn. endpoint_path: Endpoint path used for metric labeling. - _original_input: Original user input before the explicit-input rewrite. + original_input: Original user input before the explicit-input rewrite. Returns: Turn summary for the completed agent run. @@ -305,7 +305,7 @@ async def retrieve_agent_response( await append_turn_items_to_conversation( client, responses_params.conversation, - responses_params.input, + original_input or responses_params.input, [moderation_result.refusal_response], ) return TurnSummary( @@ -313,7 +313,9 @@ async def retrieve_agent_response( llm_response=moderation_result.message, ) try: - agent = build_agent(client, responses_params, configuration.skills) + agent = build_agent( + client, responses_params, configuration.skills, original_input + ) logger.debug("Starting agent non-streaming response processing") run_result = await agent.run(cast(str, responses_params.input)) except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc: diff --git a/src/utils/agents/streaming.py b/src/utils/agents/streaming.py index 138852bc2..7a6acd456 100644 --- a/src/utils/agents/streaming.py +++ b/src/utils/agents/streaming.py @@ -84,6 +84,7 @@ async def retrieve_agent_response_generator( responses_params: ResponsesApiParams, context: ResponseGeneratorContext, endpoint_path: str, + original_input: Optional[ResponseInput] = None, ) -> tuple[AsyncIterator[str], TurnSummary]: """Return the SSE generator and mutable turn summary for an agent run. @@ -91,6 +92,9 @@ async def retrieve_agent_response_generator( responses_params: Prepared Responses API parameters. context: Streaming request context and moderation result. endpoint_path: Endpoint path used for metric labeling. + original_input: In compacted mode, the original user input before the + explicit-input rewrite. Used to persist the completed turn with its + structured input (preserving attachments); ``None`` otherwise. Returns: Tuple of SSE async iterator and mutable turn summary. @@ -100,14 +104,12 @@ async def retrieve_agent_response_generator( if context.moderation_result.decision == "blocked": turn_summary.llm_response = context.moderation_result.message turn_summary.id = context.moderation_result.moderation_id - turn_summary.output_items = [context.moderation_result.refusal_response] - if not responses_params.omit_conversation: - await append_turn_items_to_conversation( - context.client, - responses_params.conversation, - responses_params.input, - [context.moderation_result.refusal_response], - ) + await append_turn_items_to_conversation( + context.client, + responses_params.conversation, + original_input or responses_params.input, + [context.moderation_result.refusal_response], + ) media_type = context.query_request.media_type or MEDIA_TYPE_JSON return ( shield_violation_generator( @@ -117,7 +119,9 @@ async def retrieve_agent_response_generator( turn_summary, ) - agent = build_agent(context.client, responses_params, configuration.skills) + agent = build_agent( + context.client, responses_params, configuration.skills, original_input + ) return ( agent_response_generator( diff --git a/src/utils/pydantic_ai.py b/src/utils/pydantic_ai.py index f4a1cf18c..f5235e9a8 100644 --- a/src/utils/pydantic_ai.py +++ b/src/utils/pydantic_ai.py @@ -10,9 +10,11 @@ from pydantic_ai.models.openai import OpenAIResponsesModelSettings from pydantic_ai_skills import SkillsCapability +from models.common.responses import ResponseInput from models.common.responses.responses_api_params import ResponsesApiParams from models.config import SkillsConfiguration from pydantic_ai_lightspeed.llamastack import ( + CompactionTurnContext, LlamaStackProvider, LlamaStackResponsesModel, ) @@ -115,6 +117,7 @@ def build_agent( client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient, responses_params: ResponsesApiParams, skills: Optional[SkillsConfiguration], + original_input: Optional[ResponseInput] = None, ) -> Agent[None, str]: """Build a Pydantic AI agent that mirrors ``responses_params`` on the Llama Stack backend. @@ -127,6 +130,7 @@ def build_agent( client: Initialized Llama Stack client from ``AsyncLlamaStackClientHolder().get_client()``. responses_params: Parameters produced by ``prepare_responses_params`` for this turn. skills: Agent skills configuration from LCS, or None when skills are disabled. + original_input: When set, enables compacted-turn persistence on the model. Returns: ``Agent`` configured for ``await agent.run(...)`` (or streaming) against the same @@ -135,10 +139,21 @@ def build_agent( provider = _llama_stack_provider_from_client(client) settings = _model_settings_from_responses_params(responses_params) + compaction = ( + CompactionTurnContext( + client=client, + conversation_id=responses_params.conversation, + latest_round_input=original_input, + ) + if original_input is not None + else None + ) + model = LlamaStackResponsesModel( responses_params.model, provider=provider, settings=settings, + compaction=compaction, ) return Agent( model, diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index dd5efd227..de2ca41c5 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -706,55 +706,9 @@ async def test_retrieve_response_generator_shield_blocked( assert isinstance(turn_summary, TurnSummary) assert turn_summary.llm_response == "Content blocked" - # Structured refusal captured for compacted-mode persistence (LCORE-1572). - assert turn_summary.output_items == [mock_moderation_result.refusal_response] # Non-compacted: the refusal turn is stored here. mock_append.assert_awaited_once() - @pytest.mark.asyncio - async def test_retrieve_response_generator_shield_blocked_compacted( - self, mocker: MockerFixture - ) -> None: - """In compacted mode the shield refusal is not stored here (no double-store). - - generate_response persists the compacted turn (with the original input), - so storing it again in the shield branch would duplicate it (LCORE-1572). - """ - mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) - - mock_responses_params = mocker.Mock(spec=ResponsesApiParams) - mock_responses_params.model = "provider1/model1" - mock_responses_params.input = "explicit input" - mock_responses_params.conversation = "conv_123" - mock_responses_params.omit_conversation = True # compacted - - mock_context = mocker.Mock(spec=ResponseGeneratorContext) - mock_context.client = mock_client - mock_context.vector_store_ids = [] - mock_context.rag_id_mapping = {} - mock_context.inline_rag_context = RAGContext() - mock_context.query_request = QueryRequest( - query="test", media_type=MEDIA_TYPE_TEXT - ) # pyright: ignore[reportCallIssue] - - mock_moderation_result = mocker.Mock() - mock_moderation_result.decision = "blocked" - mock_moderation_result.message = "Content blocked" - mock_moderation_result.moderation_id = "mod_123" - mock_moderation_result.refusal_response = mocker.Mock() - mock_context.moderation_result = mock_moderation_result - mock_append = mocker.patch( - "app.endpoints.streaming_query.append_turn_items_to_conversation", - new=mocker.AsyncMock(), - ) - - _generator, turn_summary = await retrieve_response_generator( - mock_responses_params, mock_context, endpoint_path="" - ) - - assert turn_summary.output_items == [mock_moderation_result.refusal_response] - mock_append.assert_not_awaited() # compacted: generate_response stores it - @pytest.mark.asyncio async def test_retrieve_response_generator_connection_error( self, mocker: MockerFixture @@ -1019,71 +973,6 @@ async def mock_generator() -> AsyncIterator[str]: assert any("start" in item for item in result) assert any("end" in item for item in result) - @pytest.mark.asyncio - async def test_generate_response_compacted_persists_structured_turn( - self, mocker: MockerFixture - ) -> None: - """Compacted mode persists the turn via store_compacted_turn with the - original input and structured output items, not flattened strings - (LCORE-1572).""" - - async def mock_generator() -> AsyncIterator[str]: - yield "data: token\n\n" - - conv_id = "123e4567-e89b-12d3-a456-426614174000" - mock_context = mocker.Mock(spec=ResponseGeneratorContext) - mock_context.conversation_id = conv_id - mock_context.user_id = "user_123" - mock_context.query_request = QueryRequest( - query="test", conversation_id=conv_id - ) # pyright: ignore[reportCallIssue] - mock_context.started_at = "2024-01-01T00:00:00Z" - mock_context.skip_userid_check = False - mock_context.request_id = "223e4567-e89b-12d3-a456-426614174000" - mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) - - mock_responses_params = mocker.Mock(spec=ResponsesApiParams) - mock_responses_params.model = "provider1/model1" - mock_responses_params.conversation = conv_id - - turn_summary = TurnSummary() - turn_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5) - output_item = mocker.Mock() - turn_summary.output_items = [output_item] - - mock_config = mocker.Mock() - mock_config.quota_limiters = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mocker.patch("app.endpoints.streaming_query.consume_query_tokens") - mocker.patch( - "app.endpoints.streaming_query.get_available_quotas", return_value={} - ) - mocker.patch("app.endpoints.streaming_query.store_query_results") - store_mock = mocker.patch( - "app.endpoints.streaming_query.store_compacted_turn", - new_callable=mocker.AsyncMock, - ) - - result = [ - item - async for item in generate_response( - mock_generator(), - mock_context, - mock_responses_params, - turn_summary, - compacted=True, - original_input="the original query", - ) - ] - - assert any("end" in item for item in result) - store_mock.assert_awaited_once_with( - mock_context.client, - conv_id, - "the original query", - [output_item], - ) - @pytest.mark.asyncio async def test_generate_response_with_topic_summary( self, mocker: MockerFixture @@ -2659,45 +2548,3 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: # Should have both tool call and result (fallback behavior) assert len(mock_turn_summary.tool_calls) == 1 assert len(mock_turn_summary.tool_results) == 1 - - -@pytest.mark.asyncio -async def test_response_generator_failed_captures_output_items( - mocker: MockerFixture, -) -> None: - """A failed terminal captures output_items for compacted persistence (LCORE-1572).""" - out_item = mocker.Mock() - - async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: - chunk = mocker.Mock(spec=FailedChunk) - chunk.type = "response.failed" - mock_response = mocker.Mock() - mock_response.output = [out_item] - mock_response.error = mocker.Mock(message="boom") - chunk.response = mock_response - yield chunk - - mock_context = mocker.Mock(spec=ResponseGeneratorContext) - mock_context.query_request = QueryRequest( - query="test", media_type=MEDIA_TYPE_JSON - ) # pyright: ignore[reportCallIssue] - mock_context.model_id = "provider1/model1" - mock_context.vector_store_ids = [] - mock_context.rag_id_mapping = {} - mock_context.inline_rag_context = RAGContext() - - turn_summary = TurnSummary() - mocker.patch( - "app.endpoints.streaming_query.extract_token_usage", - return_value=TokenCounter(input_tokens=0, output_tokens=0), - ) - mocker.patch( - "app.endpoints.streaming_query.parse_referenced_documents", return_value=[] - ) - - async for _ in response_generator( - mock_turn_response(), mock_context, turn_summary, endpoint_path="" - ): - pass - - assert turn_summary.output_items == [out_item] diff --git a/tests/unit/utils/agents/test_streaming.py b/tests/unit/utils/agents/test_streaming.py index b93c314cc..11e494817 100644 --- a/tests/unit/utils/agents/test_streaming.py +++ b/tests/unit/utils/agents/test_streaming.py @@ -528,16 +528,20 @@ async def test_blocked_moderation_returns_shield_generator( assert turn_summary.id == blocked_moderation.moderation_id @pytest.mark.asyncio - async def test_blocked_moderation_skips_append_when_omit_conversation( + async def test_blocked_moderation_compacted_appends_with_original_input( self, mocker: MockerFixture, make_generator_context: Callable[..., ResponseGeneratorContext], make_responses_params: Callable[..., ResponsesApiParams], blocked_moderation: ShieldModerationBlocked, ) -> None: - """Test compacted mode does not append blocked turn to conversation.""" + """Test compacted blocked turns persist with original_input, not explicit input.""" context = make_generator_context(moderation_result=blocked_moderation) - responses_params = make_responses_params(omit_conversation=True) + responses_params = make_responses_params( + omit_conversation=True, + input_text="explicit summaries-plus-recent input", + ) + original_input = "the real user query" mocker.patch( "utils.agents.streaming.shield_violation_generator", return_value=_async_iter([]), @@ -551,9 +555,15 @@ async def test_blocked_moderation_skips_append_when_omit_conversation( responses_params, context, ENDPOINT_PATH_STREAMING_QUERY, + original_input=original_input, ) - mock_append.assert_not_awaited() + mock_append.assert_awaited_once_with( + context.client, + responses_params.conversation, + original_input, + [blocked_moderation.refusal_response], + ) @pytest.mark.asyncio async def test_success_returns_agent_generator(