Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions astrbot/core/astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,13 @@ def _build_handoff_toolset(
# "all tools", including runtime computer-use tools.
if tools is None:
toolset = ToolSet()
for registered_tool in llm_tools.func_list:
if isinstance(registered_tool, HandoffTool):
handoff_names = {
tool.name
for tool in tool_mgr.func_list
if isinstance(tool, HandoffTool)
}
for registered_tool in tool_mgr.get_full_tool_set():
if registered_tool.name in handoff_names:
continue
if registered_tool.active:
toolset.add_tool(registered_tool)
Expand Down
24 changes: 19 additions & 5 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,10 @@ async def _ensure_persona_and_skills(
cfg: dict,
plugin_context: Context,
event: AstrMessageEvent,
) -> None:
) -> set[str] | None:
"""Ensure persona and skills are applied to the request's system prompt or user prompt."""
if not req.conversation:
return
return None

(
persona_id,
Expand Down Expand Up @@ -514,11 +514,13 @@ async def _ensure_persona_and_skills(

# inject toolset in the persona
if (persona and persona.get("tools") is None) or not persona:
persona_allowed_tools = None
persona_toolset = tmgr.get_full_tool_set()
for tool in list(persona_toolset):
if not tool.active:
persona_toolset.remove_tool(tool.name)
else:
persona_allowed_tools = {str(tool_name) for tool_name in persona["tools"]}
persona_toolset = ToolSet()
if persona["tools"]:
for tool_name in persona["tools"]:
Expand Down Expand Up @@ -599,6 +601,7 @@ async def _ensure_persona_and_skills(
)
except Exception:
pass
return persona_allowed_tools


async def _request_img_caption(
Expand Down Expand Up @@ -931,12 +934,13 @@ async def _decorate_llm_request(
plugin_context: Context,
config: MainAgentBuildConfig,
provider: Provider | None = None,
) -> None:
) -> set[str] | None:
cfg = config.provider_settings or plugin_context.get_config(
umo=event.unified_msg_origin
).get("provider_settings", {})

_apply_prompt_prefix(req, cfg)
persona_allowed_tools = None

main_provider_supports_image = provider is not None and _provider_supports_modality(
provider, "image"
Expand All @@ -945,7 +949,9 @@ async def _decorate_llm_request(
quote_images_already_captioned = False

if req.conversation:
await _ensure_persona_and_skills(req, cfg, plugin_context, event)
persona_allowed_tools = await _ensure_persona_and_skills(
req, cfg, plugin_context, event
)

if img_cap_prov_id and req.image_urls and not main_provider_supports_image:
await _ensure_img_caption(
Expand Down Expand Up @@ -974,6 +980,7 @@ async def _decorate_llm_request(
tz = plugin_context.get_config().get("timezone")
_append_system_reminders(event, req, cfg, tz)
_apply_workspace_extra_prompt(event, req)
return persona_allowed_tools


def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
Expand Down Expand Up @@ -1502,7 +1509,9 @@ async def build_main_agent(
else:
return None

await _decorate_llm_request(event, req, plugin_context, config, provider=provider)
persona_allowed_tools = await _decorate_llm_request(
event, req, plugin_context, config, provider=provider
)

await _apply_kb(event, req, plugin_context, config)

Expand Down Expand Up @@ -1538,6 +1547,11 @@ async def build_main_agent(
)
)

if persona_allowed_tools is not None and req.func_tool:
for tool in list(req.func_tool):
if tool.name not in persona_allowed_tools:
req.func_tool.remove_tool(tool.name)
Comment on lines +1550 to +1553

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of iterating over the tools and calling remove_tool (which performs a list comprehension internally on every call, leading to $O(N^2)$ complexity), you can filter the tools list in-place in a single $O(N)$ pass.

Suggested change
if persona_allowed_tools is not None and req.func_tool:
for tool in list(req.func_tool):
if tool.name not in persona_allowed_tools:
req.func_tool.remove_tool(tool.name)
if persona_allowed_tools is not None and req.func_tool:
req.func_tool.tools = [
tool for tool in req.func_tool.tools
if tool.name in persona_allowed_tools
]


fallback_providers = _get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
)
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_astr_agent_tool_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
import mcp
import pytest

from astrbot.core.agent.agent import Agent
from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.message.components import Image
from astrbot.core.provider.func_tool_manager import (
FunctionToolManager,
_PermissionGuardedTool,
)


class _DummyEvent:
Expand All @@ -29,6 +36,32 @@ def _build_run_context(message_components: list[object] | None = None):
return ContextWrapper(context=ctx)


def test_build_handoff_toolset_keeps_permission_guards_for_default_tools():
mgr = FunctionToolManager()
plugin_tool = FunctionTool(
name="admin_only_mcp",
description="admin tool",
parameters={"type": "object", "properties": {}},
)
handoff = HandoffTool(Agent(name="child"))
mgr.func_list = [plugin_tool, handoff]

event = _DummyEvent()
context = SimpleNamespace(
get_config=lambda **_kwargs: {
"provider_settings": {"computer_use_runtime": "none"}
},
get_llm_tool_manager=lambda: mgr,
)
run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context))

toolset = FunctionToolExecutor._build_handoff_toolset(run_context, tools=None)

assert toolset is not None
assert isinstance(toolset.get_tool("admin_only_mcp"), _PermissionGuardedTool)
assert toolset.get_tool("transfer_to_child") is None


@pytest.mark.asyncio
async def test_collect_handoff_image_urls_normalizes_filters_and_appends_event_image(
monkeypatch: pytest.MonkeyPatch,
Expand Down
51 changes: 51 additions & 0 deletions tests/unit/test_astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,57 @@ async def test_ensure_tools_from_persona(self, mock_event, mock_context):

assert req.func_tool is not None

@pytest.mark.asyncio
async def test_persona_empty_tools_filters_late_builtin_tools(
self, mock_event, mock_context, mock_provider
):
module = ama
persona = {"name": "locked", "prompt": "No tools.", "tools": []}
mock_context.persona_manager.resolve_selected_persona = AsyncMock(
return_value=("locked", persona, None, False)
)
mock_context.get_config.return_value = {
"provider_settings": {
"web_search": True,
"websearch_provider": "baidu_ai_search",
}
}
config = module.MainAgentBuildConfig(
tool_call_timeout=60,
provider_settings={
"web_search": True,
"websearch_provider": "baidu_ai_search",
},
computer_use_runtime="none",
)
req = ProviderRequest(prompt="hello")
req.conversation = MagicMock(persona_id="locked", history="[]")

with (
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner

result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=config,
provider=mock_provider,
req=req,
apply_reset=False,
)
assert result is not None
try:
assert result.provider_request.func_tool is None or (
result.provider_request.func_tool.empty()
)
finally:
if result.reset_coro:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If build_main_agent returns None (e.g., due to a configuration or provider resolution failure), result will be None. In the finally block, attempting to access result.reset_coro will raise an AttributeError, which masks the actual test assertion failure (assert result is not None). Adding a null check for result prevents this masking.

Suggested change
if result.reset_coro:
if result and result.reset_coro:

result.reset_coro.close()

@pytest.mark.asyncio
async def test_subagent_dedupe_uses_default_persona_tools(
self, mock_event, mock_context
Expand Down
Loading