diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 3f74f0ec9b..8f48c9c20c 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -107,6 +107,53 @@ class _ToolExecutionInterrupted(Exception): ToolExecutorResultT = T.TypeVar("ToolExecutorResultT") +def _normalize_text(text: str) -> str: + """标准化文本用于比对:strip + 合并连续空白。""" + return " ".join(text.strip().split()) + + +def _extract_text_from_chain(chain) -> str | None: + """从 MessageChain 中提取纯文本内容。""" + if not chain or not hasattr(chain, "chain") or not chain.chain: + return None + texts = [] + for comp in chain.chain: + if hasattr(comp, "text"): + text = comp.text.strip() + if text: + texts.append(text) + return " ".join(texts) if texts else None + + +def _extract_send_message_texts( + tools_call_name: list[str] | None, + tools_call_args: list[dict] | None, +) -> list[str]: + """从所有 send_message_to_user 调用中提取纯文本。 + + 用于比对 completion_text / result_chain 与工具 payload + 是否一致,判断是否为重复发送。 + 仅提取 type="plain" 的文本部分。 + """ + if not tools_call_name or not tools_call_args: + return [] + results = [] + for name, args in zip(tools_call_name, tools_call_args): + if name == "send_message_to_user" and isinstance(args, dict): + messages = args.get("messages") + if not isinstance(messages, list): + continue + texts = [] + for msg in messages: + if isinstance(msg, dict) and msg.get("type") == "plain": + text = msg.get("text", "").strip() + if text: + texts.append(text) + if texts: + results.append(" ".join(texts)) + return results + + class ToolLoopAgentRunner(BaseAgentRunner[TContext]): TOOL_RESULT_MAX_ESTIMATED_TOKENS = 27_500 TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS = 7000 @@ -792,6 +839,36 @@ async def step(self): await self._complete_with_assistant_response(llm_resp) # 返回 LLM 结果 + # 当 send_message_to_user 的 payload 与 completion_text + # 或 result_chain 内容一致时,抑制 yield, + # 避免 respond 阶段重复发送。 + # 仅在内容匹配时抑制,不影响发到其他会话或内容不同的场景。 + _should_suppress_text = False + if ( + llm_resp.tools_call_name + and "send_message_to_user" in llm_resp.tools_call_name + ): + _tool_texts = _extract_send_message_texts( + llm_resp.tools_call_name, llm_resp.tools_call_args + ) + if _tool_texts: + _completion = _normalize_text( + llm_resp.completion_text or "" + ) + _chain_text = _normalize_text( + _extract_text_from_chain(llm_resp.result_chain) or "" + ) + for _tt in _tool_texts: + _nt = _normalize_text(_tt) + if (_nt and _completion and _nt == _completion) or ( + _nt and _chain_text and _nt == _chain_text + ): + _should_suppress_text = True + logger.info( + "send_message_to_user payload 与响应文本一致," + f"抑制以避免重复发送。 text={_tt[:50]!r}" + ) + break if llm_resp.reasoning_content: yield AgentResponse( type="llm_result", @@ -802,17 +879,19 @@ async def step(self): ), ) if llm_resp.result_chain: - yield AgentResponse( - type="llm_result", - data=AgentResponseData(chain=llm_resp.result_chain), - ) + if not _should_suppress_text: + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=llm_resp.result_chain), + ) elif llm_resp.completion_text: - yield AgentResponse( - type="llm_result", - data=AgentResponseData( - chain=MessageChain().message(llm_resp.completion_text), - ), - ) + if not _should_suppress_text: + yield AgentResponse( + type="llm_result", + data=AgentResponseData( + chain=MessageChain().message(llm_resp.completion_text), + ), + ) # 如果有工具调用,还需处理工具调用 if llm_resp.tools_call_name: diff --git a/astrbot/core/tools/message_tools.py b/astrbot/core/tools/message_tools.py index 40516d5297..6e65452bf8 100644 --- a/astrbot/core/tools/message_tools.py +++ b/astrbot/core/tools/message_tools.py @@ -1,3 +1,4 @@ +import hashlib import json import os import shlex @@ -27,6 +28,7 @@ ) + def _file_send_allowed_roots(umo: str | None) -> tuple[Path, ...]: roots = [] if umo: @@ -316,6 +318,29 @@ async def call( else: return f"error: invalid session: {session}" + # 去重:按事件作用域记录已发送的消息指纹, + # 拦截同一 agent run 内的重复调用。 + # 作用域限定在当前事件,不影响其他事件的合法重复发送。 + dedup_key = str(target_session) + json.dumps( + messages, ensure_ascii=False, sort_keys=True + ) + fingerprint = hashlib.md5(dedup_key.encode()).hexdigest() + existing = context.context.event.get_extra( + "_send_message_fingerprints" + ) + sent_fingerprints = set(existing) if existing else set() + if fingerprint in sent_fingerprints: + logger.info( + f"[send_message_to_user] 当前事件内重复发送,已跳过。" + f" session={session}, target_session={target_session}," + f" fingerprint={fingerprint[:8]}" + ) + return f"Message skipped (duplicate), session={target_session}" + sent_fingerprints.add(fingerprint) + context.context.event.set_extra( + "_send_message_fingerprints", sent_fingerprints + ) + await context.context.context.send_message( target_session, MessageChain(chain=components), diff --git a/tests/unit/test_message_tools.py b/tests/unit/test_message_tools.py index 2b4659f5b2..1dec7681a8 100644 --- a/tests/unit/test_message_tools.py +++ b/tests/unit/test_message_tools.py @@ -8,6 +8,24 @@ from astrbot.core.tools.message_tools import SendMessageToUserTool +class _MockEvent: + """Minimal event mock with get_extra/set_extra support.""" + + def __init__(self, unified_msg_origin, role): + self.unified_msg_origin = unified_msg_origin + self.role = role + self._extras: dict = {} + + def get_sender_id(self): + return "user-1" + + def get_extra(self, key, default=None): + return self._extras.get(key, default) + + def set_extra(self, key, value): + self._extras[key] = value + + def _make_context( current_session="feishu:GroupMessage:oc_xxx", role="admin", @@ -23,11 +41,7 @@ def _make_context( } return SimpleNamespace( context=SimpleNamespace( - event=SimpleNamespace( - unified_msg_origin=current_session, - role=role, - get_sender_id=lambda: "user-1", - ), + event=_MockEvent(current_session, role), context=SimpleNamespace( get_config=lambda umo: cfg, send_message=AsyncMock(),