Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 89 additions & 10 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,53 @@ class _ToolExecutionInterrupted(Exception):
ToolExecutorResultT = T.TypeVar("ToolExecutorResultT")


def _normalize_text(text: str) -> str:
"""标准化文本用于比对:strip + 合并连续空白。"""
return " ".join(text.strip().split())


def _extract_text_from_chain(chain) -> str | None:
"""从 MessageChain 中提取纯文本内容。"""
if not chain or not hasattr(chain, "chain") or not chain.chain:
return None
texts = []
for comp in chain.chain:
if hasattr(comp, "text"):
text = comp.text.strip()
if text:
texts.append(text)
return " ".join(texts) if texts else None


def _extract_send_message_texts(
tools_call_name: list[str] | None,
tools_call_args: list[dict] | None,
) -> list[str]:
"""从所有 send_message_to_user 调用中提取纯文本。

用于比对 completion_text / result_chain 与工具 payload
是否一致,判断是否为重复发送。
仅提取 type="plain" 的文本部分。
"""
if not tools_call_name or not tools_call_args:
return []
results = []
for name, args in zip(tools_call_name, tools_call_args):
if name == "send_message_to_user" and isinstance(args, dict):
messages = args.get("messages")
if not isinstance(messages, list):
continue
texts = []
for msg in messages:
if isinstance(msg, dict) and msg.get("type") == "plain":
text = msg.get("text", "").strip()
if text:
texts.append(text)
if texts:
results.append(" ".join(texts))
return results


class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
TOOL_RESULT_MAX_ESTIMATED_TOKENS = 27_500
TOOL_RESULT_PREVIEW_MAX_ESTIMATED_TOKENS = 7000
Expand Down Expand Up @@ -792,6 +839,36 @@ async def step(self):
await self._complete_with_assistant_response(llm_resp)

# 返回 LLM 结果
# 当 send_message_to_user 的 payload 与 completion_text
# 或 result_chain 内容一致时,抑制 yield,
# 避免 respond 阶段重复发送。
# 仅在内容匹配时抑制,不影响发到其他会话或内容不同的场景。
_should_suppress_text = False
if (
llm_resp.tools_call_name
and "send_message_to_user" in llm_resp.tools_call_name
):
_tool_texts = _extract_send_message_texts(
llm_resp.tools_call_name, llm_resp.tools_call_args
)
if _tool_texts:
_completion = _normalize_text(
llm_resp.completion_text or ""
)
_chain_text = _normalize_text(
_extract_text_from_chain(llm_resp.result_chain) or ""
)
for _tt in _tool_texts:
_nt = _normalize_text(_tt)
if (_nt and _completion and _nt == _completion) or (
_nt and _chain_text and _nt == _chain_text
):
_should_suppress_text = True
logger.info(
"send_message_to_user payload 与响应文本一致,"
f"抑制以避免重复发送。 text={_tt[:50]!r}"
)
break
if llm_resp.reasoning_content:
yield AgentResponse(
type="llm_result",
Expand All @@ -802,17 +879,19 @@ async def step(self):
),
)
if llm_resp.result_chain:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=llm_resp.result_chain),
)
if not _should_suppress_text:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=llm_resp.result_chain),
)
elif llm_resp.completion_text:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text),
),
)
if not _should_suppress_text:
yield AgentResponse(
type="llm_result",
data=AgentResponseData(
chain=MessageChain().message(llm_resp.completion_text),
),
)

# 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name:
Expand Down
25 changes: 25 additions & 0 deletions astrbot/core/tools/message_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import json
import os
import shlex
Expand Down Expand Up @@ -27,6 +28,7 @@
)



def _file_send_allowed_roots(umo: str | None) -> tuple[Path, ...]:
roots = []
if umo:
Expand Down Expand Up @@ -316,6 +318,29 @@ async def call(
else:
return f"error: invalid session: {session}"

# 去重:按事件作用域记录已发送的消息指纹,
# 拦截同一 agent run 内的重复调用。
# 作用域限定在当前事件,不影响其他事件的合法重复发送。
dedup_key = str(target_session) + json.dumps(
messages, ensure_ascii=False, sort_keys=True
)
fingerprint = hashlib.md5(dedup_key.encode()).hexdigest()
existing = context.context.event.get_extra(
"_send_message_fingerprints"
)
sent_fingerprints = set(existing) if existing else set()
if fingerprint in sent_fingerprints:
logger.info(
f"[send_message_to_user] 当前事件内重复发送,已跳过。"
f" session={session}, target_session={target_session},"
f" fingerprint={fingerprint[:8]}"
)
return f"Message skipped (duplicate), session={target_session}"
sent_fingerprints.add(fingerprint)
context.context.event.set_extra(
"_send_message_fingerprints", sent_fingerprints
)

await context.context.context.send_message(
target_session,
MessageChain(chain=components),
Expand Down
24 changes: 19 additions & 5 deletions tests/unit/test_message_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@
from astrbot.core.tools.message_tools import SendMessageToUserTool


class _MockEvent:
"""Minimal event mock with get_extra/set_extra support."""

def __init__(self, unified_msg_origin, role):
self.unified_msg_origin = unified_msg_origin
self.role = role
self._extras: dict = {}

def get_sender_id(self):
return "user-1"

def get_extra(self, key, default=None):
return self._extras.get(key, default)

def set_extra(self, key, value):
self._extras[key] = value


def _make_context(
current_session="feishu:GroupMessage:oc_xxx",
role="admin",
Expand All @@ -23,11 +41,7 @@ def _make_context(
}
return SimpleNamespace(
context=SimpleNamespace(
event=SimpleNamespace(
unified_msg_origin=current_session,
role=role,
get_sender_id=lambda: "user-1",
),
event=_MockEvent(current_session, role),
context=SimpleNamespace(
get_config=lambda umo: cfg,
send_message=AsyncMock(),
Expand Down
Loading