diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index de5caad554..8c3ed661f9 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -266,8 +266,13 @@ def _build_handoff_toolset( # "all tools", including runtime computer-use tools. if tools is None: toolset = ToolSet() - for registered_tool in llm_tools.func_list: - if isinstance(registered_tool, HandoffTool): + handoff_names = { + tool.name + for tool in tool_mgr.func_list + if isinstance(tool, HandoffTool) + } + for registered_tool in tool_mgr.get_full_tool_set(): + if registered_tool.name in handoff_names: continue if registered_tool.active: toolset.add_tool(registered_tool) diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 4de3d77bfb..6aa0e523a8 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -456,10 +456,10 @@ async def _ensure_persona_and_skills( cfg: dict, plugin_context: Context, event: AstrMessageEvent, -) -> None: +) -> set[str] | None: """Ensure persona and skills are applied to the request's system prompt or user prompt.""" if not req.conversation: - return + return None ( persona_id, @@ -514,11 +514,13 @@ async def _ensure_persona_and_skills( # inject toolset in the persona if (persona and persona.get("tools") is None) or not persona: + persona_allowed_tools = None persona_toolset = tmgr.get_full_tool_set() for tool in list(persona_toolset): if not tool.active: persona_toolset.remove_tool(tool.name) else: + persona_allowed_tools = {str(tool_name) for tool_name in persona["tools"]} persona_toolset = ToolSet() if persona["tools"]: for tool_name in persona["tools"]: @@ -599,6 +601,7 @@ async def _ensure_persona_and_skills( ) except Exception: pass + return persona_allowed_tools async def _request_img_caption( @@ -931,12 +934,13 @@ async def _decorate_llm_request( plugin_context: Context, config: MainAgentBuildConfig, provider: Provider | None = None, -) -> None: +) -> set[str] | None: cfg = config.provider_settings or plugin_context.get_config( umo=event.unified_msg_origin ).get("provider_settings", {}) _apply_prompt_prefix(req, cfg) + persona_allowed_tools = None main_provider_supports_image = provider is not None and _provider_supports_modality( provider, "image" @@ -945,7 +949,9 @@ async def _decorate_llm_request( quote_images_already_captioned = False if req.conversation: - await _ensure_persona_and_skills(req, cfg, plugin_context, event) + persona_allowed_tools = await _ensure_persona_and_skills( + req, cfg, plugin_context, event + ) if img_cap_prov_id and req.image_urls and not main_provider_supports_image: await _ensure_img_caption( @@ -974,6 +980,7 @@ async def _decorate_llm_request( tz = plugin_context.get_config().get("timezone") _append_system_reminders(event, req, cfg, tz) _apply_workspace_extra_prompt(event, req) + return persona_allowed_tools def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: @@ -1502,7 +1509,9 @@ async def build_main_agent( else: return None - await _decorate_llm_request(event, req, plugin_context, config, provider=provider) + persona_allowed_tools = await _decorate_llm_request( + event, req, plugin_context, config, provider=provider + ) await _apply_kb(event, req, plugin_context, config) @@ -1538,6 +1547,11 @@ async def build_main_agent( ) ) + if persona_allowed_tools is not None and req.func_tool: + for tool in list(req.func_tool): + if tool.name not in persona_allowed_tools: + req.func_tool.remove_tool(tool.name) + fallback_providers = _get_fallback_chat_providers( provider, plugin_context, config.provider_settings ) diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 5fab9fe0a2..61fb4048c8 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -3,9 +3,16 @@ import mcp import pytest +from astrbot.core.agent.agent import Agent +from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import FunctionTool from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.message.components import Image +from astrbot.core.provider.func_tool_manager import ( + FunctionToolManager, + _PermissionGuardedTool, +) class _DummyEvent: @@ -29,6 +36,32 @@ def _build_run_context(message_components: list[object] | None = None): return ContextWrapper(context=ctx) +def test_build_handoff_toolset_keeps_permission_guards_for_default_tools(): + mgr = FunctionToolManager() + plugin_tool = FunctionTool( + name="admin_only_mcp", + description="admin tool", + parameters={"type": "object", "properties": {}}, + ) + handoff = HandoffTool(Agent(name="child")) + mgr.func_list = [plugin_tool, handoff] + + event = _DummyEvent() + context = SimpleNamespace( + get_config=lambda **_kwargs: { + "provider_settings": {"computer_use_runtime": "none"} + }, + get_llm_tool_manager=lambda: mgr, + ) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + + toolset = FunctionToolExecutor._build_handoff_toolset(run_context, tools=None) + + assert toolset is not None + assert isinstance(toolset.get_tool("admin_only_mcp"), _PermissionGuardedTool) + assert toolset.get_tool("transfer_to_child") is None + + @pytest.mark.asyncio async def test_collect_handoff_image_urls_normalizes_filters_and_appends_event_image( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 6265cb077e..e49cc4c319 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -817,6 +817,57 @@ async def test_ensure_tools_from_persona(self, mock_event, mock_context): assert req.func_tool is not None + @pytest.mark.asyncio + async def test_persona_empty_tools_filters_late_builtin_tools( + self, mock_event, mock_context, mock_provider + ): + module = ama + persona = {"name": "locked", "prompt": "No tools.", "tools": []} + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=("locked", persona, None, False) + ) + mock_context.get_config.return_value = { + "provider_settings": { + "web_search": True, + "websearch_provider": "baidu_ai_search", + } + } + config = module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={ + "web_search": True, + "websearch_provider": "baidu_ai_search", + }, + computer_use_runtime="none", + ) + req = ProviderRequest(prompt="hello") + req.conversation = MagicMock(persona_id="locked", history="[]") + + with ( + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=config, + provider=mock_provider, + req=req, + apply_reset=False, + ) + assert result is not None + try: + assert result.provider_request.func_tool is None or ( + result.provider_request.func_tool.empty() + ) + finally: + if result.reset_coro: + result.reset_coro.close() + @pytest.mark.asyncio async def test_subagent_dedupe_uses_default_persona_tools( self, mock_event, mock_context