Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -13614,6 +13614,34 @@
"type": "array",
"title": "High-level inference providers",
"description": "Unified-mode synthesis input (Decision S5): a high-level, backend-agnostic list of inference providers the synthesizer expands into Llama Stack provider entries. Lives at the configuration root so it survives a future backend change. A non-empty list signals unified mode. Empty (the default) leaves legacy/remote modes unaffected. The sibling default_model / default_provider keep their query-time routing meaning and are independent of this list."
},
"max_infer_iters": {
"anyOf": [
{
"type": "integer",
"exclusiveMinimum": 0.0
},
{
"type": "null"
}
],
"title": "Default max inference iterations",
"description": "Server-side default for the maximum number of inference iterations a model can perform in a single request. Prevents small models from looping indefinitely on tool calls. Per-request values take precedence over this default. Set to None to disable the limit.",
"default": 10
},
"max_tool_calls": {
"anyOf": [
{
"type": "integer",
"exclusiveMinimum": 0.0
},
{
"type": "null"
}
],
"title": "Default max tool calls",
"description": "Server-side default for the maximum number of tool calls allowed in a single response. Prevents small models from exhausting the context window with repeated tool calls. Per-request values take precedence over this default. Set to None to disable the limit.",
"default": 30
}
},
"additionalProperties": false,
Expand Down
5 changes: 5 additions & 0 deletions src/app/endpoints/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,11 @@ async def responses_endpoint_handler(
original_request.input, inline_rag_context.context_text
)

if "max_infer_iters" not in original_request.model_fields_set:
updated_request.max_infer_iters = configuration.inference.max_infer_iters
if "max_tool_calls" not in original_request.model_fields_set:
updated_request.max_tool_calls = configuration.inference.max_tool_calls

api_params = ResponsesApiParams.model_validate(updated_request.model_dump())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a nit, but could you please follow the same pattern as other parameter overrides do. Namely, override attributes of updated_request (if not explicitly set) above the model_validate command.

# Compact the conversation if it is approaching the context window limit.
Expand Down
20 changes: 20 additions & 0 deletions src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,6 +1727,26 @@ class InferenceConfiguration(ConfigurationBase):
"meaning and are independent of this list.",
)

max_infer_iters: Optional[PositiveInt] = Field(
default=10,
title="Default max inference iterations",
description="Server-side default for the maximum number of inference "
"iterations a model can perform in a single request. Prevents small "
"models from looping indefinitely on tool calls. "
"Per-request values take precedence over this default. "
"Set to None to disable the limit.",
)

max_tool_calls: Optional[PositiveInt] = Field(
default=30,
title="Default max tool calls",
description="Server-side default for the maximum number of tool calls "
"allowed in a single response. Prevents small models from exhausting "
"the context window with repeated tool calls. "
"Per-request values take precedence over this default. "
"Set to None to disable the limit.",
)

@model_validator(mode="after")
def check_default_model_and_provider(self) -> Self:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/utils/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,8 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
stream=stream,
store=store,
extra_headers=extra_headers,
max_infer_iters=configuration.inference.max_infer_iters,
max_tool_calls=configuration.inference.max_tool_calls,
)


Expand Down
68 changes: 68 additions & 0 deletions tests/unit/models/config/test_inference_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,71 @@ def test_context_windows_rejects_negative_size() -> None:
InferenceConfiguration(
context_windows={"openai/gpt-4o-mini": -1},
) # pyright: ignore[reportCallIssue]


def test_max_infer_iters_default() -> None:
"""Test that max_infer_iters defaults to 10."""
config = InferenceConfiguration() # pyright: ignore[reportCallIssue]
assert config.max_infer_iters == 10


def test_max_tool_calls_default() -> None:
"""Test that max_tool_calls defaults to 30."""
config = InferenceConfiguration() # pyright: ignore[reportCallIssue]
assert config.max_tool_calls == 30


def test_max_infer_iters_accepts_positive_int() -> None:
"""Test that max_infer_iters accepts a positive integer."""
config = InferenceConfiguration(
max_infer_iters=5
) # pyright: ignore[reportCallIssue]
assert config.max_infer_iters == 5


def test_max_tool_calls_accepts_positive_int() -> None:
"""Test that max_tool_calls accepts a positive integer."""
config = InferenceConfiguration(
max_tool_calls=20
) # pyright: ignore[reportCallIssue]
assert config.max_tool_calls == 20


def test_max_infer_iters_rejects_zero() -> None:
"""Test that max_infer_iters rejects zero."""
with pytest.raises(ValueError):
InferenceConfiguration(max_infer_iters=0) # pyright: ignore[reportCallIssue]


def test_max_infer_iters_rejects_negative() -> None:
"""Test that max_infer_iters rejects a negative value."""
with pytest.raises(ValueError):
InferenceConfiguration(max_infer_iters=-1) # pyright: ignore[reportCallIssue]


def test_max_tool_calls_rejects_zero() -> None:
"""Test that max_tool_calls rejects zero."""
with pytest.raises(ValueError):
InferenceConfiguration(max_tool_calls=0) # pyright: ignore[reportCallIssue]


def test_max_tool_calls_rejects_negative() -> None:
"""Test that max_tool_calls rejects a negative value."""
with pytest.raises(ValueError):
InferenceConfiguration(max_tool_calls=-1) # pyright: ignore[reportCallIssue]


def test_max_infer_iters_accepts_none() -> None:
"""Test that max_infer_iters accepts None to disable the limit."""
config = InferenceConfiguration(
max_infer_iters=None
) # pyright: ignore[reportCallIssue]
assert config.max_infer_iters is None


def test_max_tool_calls_accepts_none() -> None:
"""Test that max_tool_calls accepts None to disable the limit."""
config = InferenceConfiguration(
max_tool_calls=None
) # pyright: ignore[reportCallIssue]
assert config.max_tool_calls is None
23 changes: 14 additions & 9 deletions tests/unit/utils/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@
import constants
from models.api.requests import QueryRequest
from models.common.responses.types import InputTool, InputToolMCP
from models.config import ApprovalFilter, ByokRag, ModelContextProtocolServer
from models.config import (
ApprovalFilter,
ByokRag,
InferenceConfiguration,
ModelContextProtocolServer,
)
from utils.query import normalize_vertex_ai_model_id
from utils.responses import (
_build_chunk_attributes,
Expand Down Expand Up @@ -1976,7 +1981,7 @@ async def test_prepare_responses_params_with_conversation_id(
) # pyright: ignore[reportCallIssue]

mock_config = mocker.Mock()
mock_config.inference = None
mock_config.inference = InferenceConfiguration()
mocker.patch("utils.responses.configuration", mock_config)
mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
mocker.patch("utils.responses.prepare_tools", return_value=None)
Expand Down Expand Up @@ -2012,7 +2017,7 @@ async def test_prepare_responses_params_create_conversation(
query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]

mock_config = mocker.Mock()
mock_config.inference = None
mock_config.inference = InferenceConfiguration()
mocker.patch("utils.responses.configuration", mock_config)
mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
mocker.patch("utils.responses.prepare_tools", return_value=None)
Expand All @@ -2038,7 +2043,7 @@ async def test_prepare_responses_params_connection_error_on_models(

query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
mock_config = mocker.Mock()
mock_config.inference = None
mock_config.inference = InferenceConfiguration()
mocker.patch("utils.responses.configuration", mock_config)

with pytest.raises(HTTPException) as exc_info:
Expand All @@ -2064,7 +2069,7 @@ async def test_prepare_responses_params_connection_error_on_conversation(
query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]

mock_config = mocker.Mock()
mock_config.inference = None
mock_config.inference = InferenceConfiguration()
mocker.patch("utils.responses.configuration", mock_config)
mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
mocker.patch("utils.responses.prepare_tools", return_value=None)
Expand All @@ -2088,7 +2093,7 @@ async def test_prepare_responses_params_api_status_error_on_models(

query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
mock_config = mocker.Mock()
mock_config.inference = None
mock_config.inference = InferenceConfiguration()
mocker.patch("utils.responses.configuration", mock_config)

with pytest.raises(HTTPException) as exc_info:
Expand Down Expand Up @@ -2131,7 +2136,7 @@ async def test_prepare_responses_params_includes_mcp_provider_data_headers(
]

mock_config = mocker.Mock()
mock_config.inference = None
mock_config.inference = InferenceConfiguration()
mocker.patch("utils.responses.configuration", mock_config)
mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
mocker.patch(
Expand Down Expand Up @@ -2179,7 +2184,7 @@ async def test_prepare_responses_params_no_extra_headers_without_mcp_tools(
query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]

mock_config = mocker.Mock()
mock_config.inference = None
mock_config.inference = InferenceConfiguration()
mocker.patch("utils.responses.configuration", mock_config)
mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
mocker.patch("utils.responses.prepare_tools", return_value=None)
Expand Down Expand Up @@ -2211,7 +2216,7 @@ async def test_prepare_responses_params_api_status_error_on_conversation(
query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]

mock_config = mocker.Mock()
mock_config.inference = None
mock_config.inference = InferenceConfiguration()
mocker.patch("utils.responses.configuration", mock_config)
mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
mocker.patch("utils.responses.prepare_tools", return_value=None)
Expand Down
Loading