From 21c16ce8b99d8a446b5d99189b63e2e48fd0e777 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=E2=82=82=E2=82=82H=E2=82=82=E2=82=85NO=E2=82=86?= Date: Mon, 15 Jun 2026 18:22:10 +0800 Subject: [PATCH 1/2] =?UTF-8?q?fix:=20=E6=8B=86=E5=88=86=E8=BF=87=E9=95=BF?= =?UTF-8?q?=E5=90=88=E5=B9=B6=E8=BD=AC=E5=8F=91=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 20 + .../core/pipeline/result_decorate/stage.py | 162 +++++- .../en-US/features/config-metadata.json | 8 + .../ru-RU/features/config-metadata.json | 8 + .../zh-CN/features/config-metadata.json | 8 + tests/unit/test_result_decorate_forward.py | 515 ++++++++++++++++++ 6 files changed, 713 insertions(+), 8 deletions(-) create mode 100644 tests/unit/test_result_decorate_forward.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 9d6d8fe5d0..7d174337db 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -62,6 +62,8 @@ }, "reply_prefix": "", "forward_threshold": 1500, + "forward_node_max_length": 1000, + "forward_node_hard_limit": 1200, "enable_id_white_list": True, "id_whitelist": [], "id_whitelist_log": True, @@ -1075,6 +1077,14 @@ "type": "int", "hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。", }, + "forward_node_max_length": { + "type": "int", + "hint": "合并转发内单个节点期望的文本长度,达到该长度后会优先在句号、换行等自然断点附近切开。", + }, + "forward_node_hard_limit": { + "type": "int", + "hint": "合并转发内单个节点的文本硬上限,超过后一定会强制切开,用来避开 QQ/NapCat 对单条转发节点的隐藏限制。", + }, "enable_id_white_list": { "type": "bool", }, @@ -3887,6 +3897,16 @@ "description": "转发消息的字数阈值", "type": "int", }, + "platform_settings.forward_node_max_length": { + "description": "单个转发节点目标长度", + "type": "int", + "hint": "合并转发内单个节点期望容纳的文本长度,达到该长度后会优先寻找句号、换行等自然断点,避免一句话被切得太碎。", + }, + "platform_settings.forward_node_hard_limit": { + "description": "单个转发节点硬上限", + "type": "int", + "hint": "合并转发内单个节点的最大文本长度。超过后一定会强制切开,用来避开 QQ/NapCat 对单条转发节点的隐藏限制。", + }, "platform_settings.empty_mention_waiting": { "description": "只 @ 机器人是否触发等待", "type": "bool", diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 53f2ec49d1..e4df6b6c47 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -5,7 +5,16 @@ from collections.abc import AsyncGenerator from astrbot.core import file_token_service, html_renderer, logger -from astrbot.core.message.components import At, Image, Json, Node, Plain, Record, Reply +from astrbot.core.message.components import ( + At, + Image, + Json, + Node, + Nodes, + Plain, + Record, + Reply, +) from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -43,6 +52,36 @@ async def initialize(self, ctx: PipelineContext) -> None: "forward_threshold" ] + # Long-reply auto-forward node splitting settings + _default_forward_node_max_length = 1000 + _default_forward_node_hard_limit = 1200 + try: + self.forward_node_max_length = int( + ctx.astrbot_config["platform_settings"].get( + "forward_node_max_length", _default_forward_node_max_length + ) + ) + except (TypeError, ValueError): + self.forward_node_max_length = _default_forward_node_max_length + try: + self.forward_node_hard_limit = int( + ctx.astrbot_config["platform_settings"].get( + "forward_node_hard_limit", _default_forward_node_hard_limit + ) + ) + except (TypeError, ValueError): + self.forward_node_hard_limit = _default_forward_node_hard_limit + if self.forward_node_max_length <= 0: + self.forward_node_max_length = _default_forward_node_max_length + if self.forward_node_hard_limit <= 0: + self.forward_node_hard_limit = _default_forward_node_hard_limit + if self.forward_node_max_length > self.forward_node_hard_limit: + logger.warning( + "forward_node_max_length is greater than forward_node_hard_limit; " + "falling back to hard limit as target length." + ) + self.forward_node_max_length = self.forward_node_hard_limit + trigger_probability = ctx.astrbot_config["provider_tts_settings"].get( "trigger_probability", 1, @@ -87,6 +126,20 @@ async def initialize(self, ctx: PipelineContext) -> None: "segmented_reply" ]["content_cleanup_rule"] + # Natural breakpoints for forward node splitting: reuse segmented_reply.split_words plus newline + _forward_split_words = list(self.split_words) if self.split_words else [] + if "\n" not in _forward_split_words: + _forward_split_words.append("\n") + if _forward_split_words: + _escaped = sorted( + [re.escape(word) for word in _forward_split_words], + key=len, + reverse=True, + ) + self.forward_split_pattern = re.compile(f"(?:{'|'.join(_escaped)})+") + else: + self.forward_split_pattern = None + # exception self.content_safe_check_reply = ctx.astrbot_config["content_safety"][ "also_use_in_response" @@ -123,6 +176,92 @@ def _split_text_by_words(self, text: str) -> list[str]: result.append(seg) return result if result else [text] + @staticmethod + def _find_forward_split_pos( + text: str, + target_len: int, + hard_limit: int, + split_pattern: re.Pattern | None, + ) -> int: + """Find a split position for forward node plain text. + + Prefer natural breakpoints between target_len and hard_limit. + If none exists, fall back to the nearest breakpoint before target_len. + If still none, hard-cut at hard_limit. + """ + if len(text) <= target_len: + return len(text) + + search_end = min(hard_limit, len(text)) + if split_pattern is not None: + # Look for breakpoints after target_len but within hard_limit. + for match in split_pattern.finditer(text, target_len, search_end): + return match.end() + + # Fall back to the nearest breakpoint before target_len. + previous_end = 0 + for match in split_pattern.finditer(text, 0, min(target_len, len(text))): + if 0 < match.end() <= target_len: + previous_end = match.end() + if previous_end > 0: + return previous_end + + if len(text) > hard_limit: + return hard_limit + return search_end + + def _build_forward_nodes( + self, + chain: list, + uin: str, + name: str, + ) -> Nodes: + """Split a message chain into multiple forward nodes. + + Non-Plain components are kept in the node where they appear. + Each node's total plain text length will not exceed forward_node_hard_limit. + """ + nodes = Nodes([]) + current_content: list = [] + current_text_len = 0 + target_len = self.forward_node_max_length + hard_limit = self.forward_node_hard_limit + + def flush_current(): + nonlocal current_content, current_text_len + if current_content: + nodes.nodes.append(Node(uin=uin, name=name, content=current_content)) + current_content = [] + current_text_len = 0 + + for comp in chain: + if isinstance(comp, Plain): + rest = comp.text or "" + while rest: + if current_text_len >= target_len: + flush_current() + + remaining_target = max(1, target_len - current_text_len) + remaining_hard = max(1, hard_limit - current_text_len) + split_pos = self._find_forward_split_pos( + rest, + remaining_target, + remaining_hard, + self.forward_split_pattern, + ) + split_pos = max(1, min(split_pos, remaining_hard, len(rest))) + current_content.append(Plain(rest[:split_pos])) + current_text_len += split_pos + rest = rest[split_pos:] + + if rest: + flush_current() + else: + current_content.append(comp) + + flush_current() + return nodes + async def process( self, event: AstrMessageEvent, @@ -394,14 +533,21 @@ async def process( if isinstance(comp, Plain): word_cnt += len(comp.text) if word_cnt > self.forward_threshold: - node = Node( - uin=event.get_self_id(), - name="AstrBot", - content=[*result.chain], - ) - result.chain = [node] + # Skip if the chain already contains forward nodes. + if not any( + isinstance(comp, (Node, Nodes)) for comp in result.chain + ): + nodes = self._build_forward_nodes( + result.chain, + event.get_self_id(), + "AstrBot", + ) + result.chain = [nodes] - # at 回复 / 引用回复仅适用于纯文本或图文消息 + # at 回复 / 引用回复仅适用于纯文本或图文消息。 + # After forward conversion result.chain is [Nodes], so mention/quote + # decorations are not applied to forwarded messages. This matches the + # pre-existing single-Node behavior and keeps pipeline order stable. can_decorate = all( isinstance(item, (Plain, Image)) for item in result.chain ) diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 2e8a786d1c..5221c25a55 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -776,6 +776,14 @@ "forward_threshold": { "description": "Forward Message Word Count Threshold" }, + "forward_node_max_length": { + "description": "Target Length per Forward Node", + "hint": "The expected plain-text length for each forward node. When reached, AstrBot will prefer to split at natural breakpoints like periods or newlines to avoid breaking sentences." + }, + "forward_node_hard_limit": { + "description": "Hard Limit per Forward Node", + "hint": "The maximum plain-text length allowed in a single forward node. AstrBot will force-split beyond this limit to avoid QQ/NapCat hidden limits on a single node." + }, "empty_mention_waiting": { "description": "Trigger Waiting on Mention-only Messages" } diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index 6f79726b33..8b9b9d3c13 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -777,6 +777,14 @@ "forward_threshold": { "description": "Порог количества слов для пересылки" }, + "forward_node_max_length": { + "description": "Целевая длина одного узла пересылки", + "hint": "Ожидаемая длина текста в одном узле пересылки. При достижении этого значения AstrBot будет стараться разделить текст на естественных границах, таких как точка или перевод строки." + }, + "forward_node_hard_limit": { + "description": "Жёсткий лимит одного узла пересылки", + "hint": "Максимально допустимая длина текста в одном узле пересылки. При превышении этого значения AstrBot обязательно разделит текст, чтобы избежать скрытых ограничений QQ/NapCat." + }, "empty_mention_waiting": { "description": "Реагировать на пустое упоминание (@бота)" } diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index c73fcae2d7..3aa803dc4a 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -778,6 +778,14 @@ "forward_threshold": { "description": "转发消息的字数阈值" }, + "forward_node_max_length": { + "description": "单个转发节点目标长度", + "hint": "合并转发内单个节点期望容纳的文本长度,达到该长度后会优先寻找句号、换行等自然断点,避免一句话被切得太碎。" + }, + "forward_node_hard_limit": { + "description": "单个转发节点硬上限", + "hint": "合并转发内单个节点的最大文本长度。超过后一定会强制切开,用来避开 QQ/NapCat 对单条转发节点的隐藏限制。" + }, "empty_mention_waiting": { "description": "只 @ 机器人是否触发等待" } diff --git a/tests/unit/test_result_decorate_forward.py b/tests/unit/test_result_decorate_forward.py new file mode 100644 index 0000000000..36156137af --- /dev/null +++ b/tests/unit/test_result_decorate_forward.py @@ -0,0 +1,515 @@ +"""Tests for ResultDecorateStage long-reply auto-forward node splitting.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest + +from astrbot.core.message.components import At, Image, Node, Nodes, Plain, Reply +from astrbot.core.message.message_event_result import MessageEventResult +from astrbot.core.pipeline.result_decorate.stage import ResultDecorateStage +from astrbot.core.platform.message_type import MessageType + +DEFAULT_SPLIT_WORDS = ["。", "?", "!", "~", "…"] + + +def _make_configured_stage( + forward_threshold: int = 100, + forward_node_max_length: int = 50, + forward_node_hard_limit: int = 70, + split_words: list[str] | None = None, +): + """Create a ResultDecorateStage with only the attributes needed for tests.""" + stage = ResultDecorateStage() + stage.forward_threshold = forward_threshold + stage.forward_node_max_length = forward_node_max_length + stage.forward_node_hard_limit = forward_node_hard_limit + + _split_words = ( + list(split_words) if split_words is not None else list(DEFAULT_SPLIT_WORDS) + ) + if "\n" not in _split_words: + _split_words.append("\n") + if _split_words: + _escaped = sorted( + [re.escape(word) for word in _split_words], key=len, reverse=True + ) + stage.forward_split_pattern = re.compile(f"(?:{'|'.join(_escaped)})+") + else: + stage.forward_split_pattern = None + + # Attributes used by process() + stage.reply_prefix = "" + stage.content_safe_check_reply = False + stage.content_safe_check_stage = None + stage.show_reasoning = False + stage.tts_trigger_probability = 0.0 + stage.reply_with_mention = False + stage.reply_with_quote = False + stage.t2i_word_threshold = 99999 + stage.t2i_use_network = False + stage.t2i_active_template = "base" + stage.enable_segmented_reply = False + stage.only_llm_result = True + stage.split_mode = "regex" + stage.regex = ".*?[。?!~…]+|.+$" + stage.split_words = list(DEFAULT_SPLIT_WORDS) + stage.split_words_pattern = re.compile( + f"(.*?({'|'.join(sorted([re.escape(w) for w in stage.split_words], key=len, reverse=True))})|.+$)", + re.DOTALL, + ) + stage.words_count_threshold = 150 + stage.content_cleanup_rule = "" + + stage.ctx = MagicMock() + stage.ctx.plugin_manager.context.get_using_tts_provider = MagicMock( + return_value=None + ) + stage.ctx.astrbot_config = { + "provider_tts_settings": {"enable": False}, + "t2i": False, + "t2i_use_file_service": False, + "callback_api_base": "", + } + return stage + + +def _make_event( + chain: list, + platform_name: str = "aiocqhttp", + message_type: MessageType = MessageType.FRIEND_MESSAGE, +): + """Create a minimal mock event for ResultDecorateStage.process().""" + event = MagicMock() + event.get_platform_name = MagicMock(return_value=platform_name) + event.get_self_id = MagicMock(return_value="123456") + event.get_sender_id = MagicMock(return_value="987654") + event.get_sender_name = MagicMock(return_value="User") + event.get_message_type = MagicMock(return_value=message_type) + event.message_obj.message_id = "msg-id" + event.get_extra = MagicMock(return_value=None) + event.is_stopped = MagicMock(return_value=False) + result = MessageEventResult() + result.chain = chain + result.use_t2i_ = None + result.result_content_type = MagicMock() + result.result_content_type.value = "TEXT_RESULT" + event.get_result = MagicMock(return_value=result) + return event + + +class TestFindForwardSplitPos: + """Tests for ResultDecorateStage._find_forward_split_pos.""" + + def test_no_split_needed(self): + pattern = re.compile(r"([。!?\n]+)") + assert ( + ResultDecorateStage._find_forward_split_pos("short", 10, 20, pattern) == 5 + ) + + def test_breakpoint_between_target_and_hard_limit(self): + pattern = re.compile(r"([。!?\n]+)") + text = "a" * 45 + "。" + "b" * 100 + # target=40, hard=70 -> should split at the period (position 46) + pos = ResultDecorateStage._find_forward_split_pos(text, 40, 70, pattern) + assert pos == 46 + + def test_fall_back_to_previous_breakpoint(self): + pattern = re.compile(r"([。!?\n]+)") + text = "a" * 30 + "。" + "b" * 100 + # target=40, hard=70 -> no breakpoint between 40 and 70, fall back to position 31 + pos = ResultDecorateStage._find_forward_split_pos(text, 40, 70, pattern) + assert pos == 31 + + def test_hard_cut_when_no_breakpoint(self): + pattern = re.compile(r"([。!?\n]+)") + text = "a" * 200 + pos = ResultDecorateStage._find_forward_split_pos(text, 40, 70, pattern) + assert pos == 70 + + def test_newline_is_a_breakpoint(self): + pattern = re.compile(r"([。!?\n]+)") + text = "a" * 45 + "\n" + "b" * 100 + pos = ResultDecorateStage._find_forward_split_pos(text, 40, 70, pattern) + assert pos == 46 + + +class TestBuildForwardNodes: + """Tests for ResultDecorateStage._build_forward_nodes.""" + + def test_short_text_single_node(self): + stage = _make_configured_stage() + nodes = stage._build_forward_nodes([Plain("hello world")], "123", "Bot") + assert len(nodes.nodes) == 1 + assert len(nodes.nodes[0].content) == 1 + assert nodes.nodes[0].content[0].text == "hello world" + + def test_long_plain_text_multiple_nodes(self): + stage = _make_configured_stage() + text = "x" * 200 + nodes = stage._build_forward_nodes([Plain(text)], "123", "Bot") + assert len(nodes.nodes) > 1 + total = 0 + for node in nodes.nodes: + plain_len = sum(len(c.text) for c in node.content if isinstance(c, Plain)) + assert plain_len <= stage.forward_node_hard_limit + total += plain_len + assert total == len(text) + + def test_breakpoint_before_target_is_used_when_no_later_breakpoint(self): + stage = _make_configured_stage() + # 40 chars, then a period, then more text. target=50, hard=70. + # There is no breakpoint after target, so it falls back to the period. + text = "x" * 40 + "。" + "y" * 100 + nodes = stage._build_forward_nodes([Plain(text)], "123", "Bot") + assert len(nodes.nodes) > 1 + first_plain = sum( + len(c.text) for c in nodes.nodes[0].content if isinstance(c, Plain) + ) + # Should split after the period (41 chars), not hard-cut at 50 or 70. + assert first_plain == 41 + + def test_fall_back_to_breakpoint_before_target(self): + stage = _make_configured_stage() + # Breakpoint at 30, then no breakpoint until past hard limit. + text = "x" * 30 + "。" + "y" * 200 + nodes = stage._build_forward_nodes([Plain(text)], "123", "Bot") + first_plain = sum( + len(c.text) for c in nodes.nodes[0].content if isinstance(c, Plain) + ) + # Should fall back to the breakpoint at 31. + assert first_plain == 31 + + def test_no_breakpoint_hard_cut(self): + stage = _make_configured_stage() + text = "x" * 200 + nodes = stage._build_forward_nodes([Plain(text)], "123", "Bot") + first_plain = sum( + len(c.text) for c in nodes.nodes[0].content if isinstance(c, Plain) + ) + assert first_plain == stage.forward_node_hard_limit + + def test_reply_at_only_in_first_node(self): + stage = _make_configured_stage() + text = "x" * 200 + nodes = stage._build_forward_nodes( + [Reply(id="r1"), At(qq="987"), Plain(text)], + "123", + "Bot", + ) + assert len(nodes.nodes) > 1 + assert any(isinstance(c, Reply) for c in nodes.nodes[0].content) + assert any(isinstance(c, At) for c in nodes.nodes[0].content) + for node in nodes.nodes[1:]: + assert not any(isinstance(c, Reply) for c in node.content) + assert not any(isinstance(c, At) for c in node.content) + + def test_later_reply_at_are_preserved_once(self): + stage = _make_configured_stage() + nodes = stage._build_forward_nodes( + [Plain("x" * 200), At(qq="987"), Reply(id="r1")], + "123", + "Bot", + ) + at_count = sum( + 1 for node in nodes.nodes for c in node.content if isinstance(c, At) + ) + reply_count = sum( + 1 for node in nodes.nodes for c in node.content if isinstance(c, Reply) + ) + assert at_count == 1 + assert reply_count == 1 + assert any(isinstance(c, At) for c in nodes.nodes[-1].content) + assert any(isinstance(c, Reply) for c in nodes.nodes[-1].content) + + def test_image_not_duplicated(self): + stage = _make_configured_stage() + text = "x" * 200 + image = Image(file="http://example.com/img.jpg") + nodes = stage._build_forward_nodes( + [Plain(text), image], + "123", + "Bot", + ) + image_count = sum( + 1 for node in nodes.nodes for c in node.content if isinstance(c, Image) + ) + assert image_count == 1 + # Image should be in the last node because it follows the Plain text. + assert any(isinstance(c, Image) for c in nodes.nodes[-1].content) + + def test_image_before_text_is_not_duplicated(self): + stage = _make_configured_stage() + text = "x" * 200 + image = Image(file="http://example.com/img.jpg") + nodes = stage._build_forward_nodes( + [image, Plain(text)], + "123", + "Bot", + ) + image_count = sum( + 1 for node in nodes.nodes for c in node.content if isinstance(c, Image) + ) + assert image_count == 1 + # Image should be in the first node because it precedes the Plain text. + assert any(isinstance(c, Image) for c in nodes.nodes[0].content) + # Text was split across nodes. + assert len(nodes.nodes) > 1 + + def test_chinese_punctuation_and_newline_breakpoints(self): + stage = _make_configured_stage() + # 30 chars, question mark, 30 chars, newline, 30 chars, exclamation, then more. + text = "x" * 30 + "?" + "y" * 30 + "\n" + "z" * 30 + "!" + "w" * 200 + nodes = stage._build_forward_nodes([Plain(text)], "123", "Bot") + assert len(nodes.nodes) > 1 + first_plain = sum( + len(c.text) for c in nodes.nodes[0].content if isinstance(c, Plain) + ) + # The question mark is before target=50, so the algorithm keeps searching + # and finds the newline at position 62 (within hard=70) first. + assert first_plain == 62 + + +class TestProcessForward: + """Tests for ResultDecorateStage.process() forward conversion behavior.""" + + @pytest.mark.asyncio + async def test_below_forward_threshold_no_conversion(self): + stage = _make_configured_stage(forward_threshold=100) + event = _make_event([Plain("short")]) + with patch( + "astrbot.core.pipeline.result_decorate.stage.SessionServiceManager.should_process_tts_request", + return_value=False, + ): + async for _ in stage.process(event): + pass + result = event.get_result() + assert len(result.chain) == 1 + assert isinstance(result.chain[0], Plain) + + @pytest.mark.asyncio + async def test_above_threshold_single_node_when_under_hard_limit(self): + stage = _make_configured_stage( + forward_threshold=10, + forward_node_max_length=200, + forward_node_hard_limit=250, + ) + event = _make_event([Plain("x" * 100)]) + with patch( + "astrbot.core.pipeline.result_decorate.stage.SessionServiceManager.should_process_tts_request", + return_value=False, + ): + async for _ in stage.process(event): + pass + result = event.get_result() + assert len(result.chain) == 1 + assert isinstance(result.chain[0], Nodes) + assert len(result.chain[0].nodes) == 1 + + @pytest.mark.asyncio + async def test_above_threshold_multiple_nodes(self): + stage = _make_configured_stage( + forward_threshold=10, + forward_node_max_length=50, + forward_node_hard_limit=70, + ) + event = _make_event([Plain("x" * 200)]) + with patch( + "astrbot.core.pipeline.result_decorate.stage.SessionServiceManager.should_process_tts_request", + return_value=False, + ): + async for _ in stage.process(event): + pass + result = event.get_result() + assert len(result.chain) == 1 + assert isinstance(result.chain[0], Nodes) + assert len(result.chain[0].nodes) > 1 + for node in result.chain[0].nodes: + plain_len = sum(len(c.text) for c in node.content if isinstance(c, Plain)) + assert plain_len <= stage.forward_node_hard_limit + + @pytest.mark.asyncio + async def test_non_aiocqhttp_platform_not_converted(self): + stage = _make_configured_stage( + forward_threshold=10, + forward_node_max_length=50, + forward_node_hard_limit=70, + ) + event = _make_event([Plain("x" * 200)], platform_name="telegram") + with patch( + "astrbot.core.pipeline.result_decorate.stage.SessionServiceManager.should_process_tts_request", + return_value=False, + ): + async for _ in stage.process(event): + pass + result = event.get_result() + assert len(result.chain) == 1 + assert isinstance(result.chain[0], Plain) + + @pytest.mark.asyncio + async def test_existing_nodes_are_skipped(self): + stage = _make_configured_stage( + forward_threshold=10, + forward_node_max_length=50, + forward_node_hard_limit=70, + ) + existing_node = Node(uin="123", name="Bot", content=[Plain("x" * 200)]) + event = _make_event([existing_node]) + with patch( + "astrbot.core.pipeline.result_decorate.stage.SessionServiceManager.should_process_tts_request", + return_value=False, + ): + async for _ in stage.process(event): + pass + result = event.get_result() + assert len(result.chain) == 1 + assert isinstance(result.chain[0], Node) + + @pytest.mark.asyncio + async def test_existing_nodes_object_is_skipped(self): + stage = _make_configured_stage( + forward_threshold=10, + forward_node_max_length=50, + forward_node_hard_limit=70, + ) + existing_nodes = Nodes( + [Node(uin="123", name="Bot", content=[Plain("x" * 200)])] + ) + event = _make_event([existing_nodes]) + with patch( + "astrbot.core.pipeline.result_decorate.stage.SessionServiceManager.should_process_tts_request", + return_value=False, + ): + async for _ in stage.process(event): + pass + result = event.get_result() + assert len(result.chain) == 1 + assert isinstance(result.chain[0], Nodes) + + @pytest.mark.asyncio + async def test_custom_config_affects_splitting(self): + stage = _make_configured_stage( + forward_threshold=10, + forward_node_max_length=30, + forward_node_hard_limit=40, + ) + event = _make_event([Plain("x" * 100)]) + with patch( + "astrbot.core.pipeline.result_decorate.stage.SessionServiceManager.should_process_tts_request", + return_value=False, + ): + async for _ in stage.process(event): + pass + result = event.get_result() + nodes = result.chain[0].nodes + assert len(nodes) >= 3 # 100 chars with max 40 per node + for node in nodes: + plain_len = sum(len(c.text) for c in node.content if isinstance(c, Plain)) + assert plain_len <= 40 + + +class TestConfigSanitization: + """Tests for invalid forward node config values.""" + + def test_max_length_greater_than_hard_limit_is_sanitized(self): + stage = ResultDecorateStage() + stage.forward_threshold = 100 + stage.forward_node_max_length = 200 + stage.forward_node_hard_limit = 100 + stage.forward_split_pattern = None + # Should not raise and should respect hard limit. + nodes = stage._build_forward_nodes([Plain("x" * 300)], "123", "Bot") + for node in nodes.nodes: + plain_len = sum(len(c.text) for c in node.content if isinstance(c, Plain)) + assert plain_len <= stage.forward_node_hard_limit + + def test_zero_or_negative_limits_do_not_crash(self): + stage = ResultDecorateStage() + stage.forward_threshold = 100 + stage.forward_node_max_length = 0 + stage.forward_node_hard_limit = 0 + stage.forward_split_pattern = None + nodes = stage._build_forward_nodes([Plain("x" * 300)], "123", "Bot") + assert len(nodes.nodes) >= 1 + + @pytest.mark.asyncio + async def test_initialize_sanitizes_invalid_config(self): + from astrbot.core.pipeline.context import PipelineContext + + cfg = { + "platform_settings": { + "reply_prefix": "", + "reply_with_mention": False, + "reply_with_quote": False, + "forward_threshold": 100, + "forward_node_max_length": -100, + "forward_node_hard_limit": 0, + "segmented_reply": { + "enable": False, + "only_llm_result": True, + "words_count_threshold": 150, + "split_mode": "regex", + "regex": ".*?[。?!~…]+|.+$", + "split_words": ["。", "?", "!", "~", "…"], + "content_cleanup_rule": "", + }, + }, + "provider_tts_settings": {"enable": False, "trigger_probability": 1}, + "provider_settings": {"display_reasoning_text": False}, + "content_safety": {"also_use_in_response": False}, + "t2i_word_threshold": 150, + "t2i_strategy": "remote", + "t2i": False, + "t2i_active_template": "base", + } + ctx = PipelineContext( + astrbot_config=cfg, + plugin_manager=MagicMock(), + astrbot_config_id="test", + ) + stage = ResultDecorateStage() + await stage.initialize(ctx) + assert stage.forward_node_max_length == 1000 + assert stage.forward_node_hard_limit == 1200 + + @pytest.mark.asyncio + async def test_initialize_converges_max_greater_than_hard(self): + from astrbot.core.pipeline.context import PipelineContext + + cfg = { + "platform_settings": { + "reply_prefix": "", + "reply_with_mention": False, + "reply_with_quote": False, + "forward_threshold": 100, + "forward_node_max_length": 5000, + "forward_node_hard_limit": 100, + "segmented_reply": { + "enable": False, + "only_llm_result": True, + "words_count_threshold": 150, + "split_mode": "regex", + "regex": ".*?[。?!~…]+|.+$", + "split_words": [","], + "content_cleanup_rule": "", + }, + }, + "provider_tts_settings": {"enable": False, "trigger_probability": 1}, + "provider_settings": {"display_reasoning_text": False}, + "content_safety": {"also_use_in_response": False}, + "t2i_word_threshold": 150, + "t2i_strategy": "remote", + "t2i": False, + "t2i_active_template": "base", + } + ctx = PipelineContext( + astrbot_config=cfg, + plugin_manager=MagicMock(), + astrbot_config_id="test", + ) + stage = ResultDecorateStage() + await stage.initialize(ctx) + assert stage.forward_node_max_length == 100 + assert stage.forward_node_hard_limit == 100 + # Custom split_words are used and newline is appended. + assert stage.forward_split_pattern is not None From adea2310efb526a2b135d4928bc8dbc5bdf13798 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=E2=82=82=E2=82=82H=E2=82=82=E2=82=85NO=E2=82=86?= Date: Mon, 15 Jun 2026 18:47:43 +0800 Subject: [PATCH 2/2] =?UTF-8?q?fix:=20=E5=AE=8C=E5=96=84=E5=90=88=E5=B9=B6?= =?UTF-8?q?=E8=BD=AC=E5=8F=91=E8=8A=82=E7=82=B9=E5=88=87=E5=88=86=E8=BE=B9?= =?UTF-8?q?=E7=95=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 6 ++-- .../core/pipeline/result_decorate/stage.py | 31 +++++++++---------- tests/unit/test_result_decorate_forward.py | 19 ++++++++++-- 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7d174337db..a8cd09323d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -7,6 +7,8 @@ VERSION = "4.26.0-beta.1" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") +FORWARD_NODE_MAX_LENGTH_DEFAULT = 1000 +FORWARD_NODE_HARD_LIMIT_DEFAULT = 1200 PERSONAL_WECHAT_CONFIG_METADATA = { "weixin_oc_base_url": { "description": "Base URL", @@ -62,8 +64,8 @@ }, "reply_prefix": "", "forward_threshold": 1500, - "forward_node_max_length": 1000, - "forward_node_hard_limit": 1200, + "forward_node_max_length": FORWARD_NODE_MAX_LENGTH_DEFAULT, + "forward_node_hard_limit": FORWARD_NODE_HARD_LIMIT_DEFAULT, "enable_id_white_list": True, "id_whitelist": [], "id_whitelist_log": True, diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index e4df6b6c47..5a1e401db1 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -5,6 +5,10 @@ from collections.abc import AsyncGenerator from astrbot.core import file_token_service, html_renderer, logger +from astrbot.core.config.default import ( + FORWARD_NODE_HARD_LIMIT_DEFAULT, + FORWARD_NODE_MAX_LENGTH_DEFAULT, +) from astrbot.core.message.components import ( At, Image, @@ -53,28 +57,26 @@ async def initialize(self, ctx: PipelineContext) -> None: ] # Long-reply auto-forward node splitting settings - _default_forward_node_max_length = 1000 - _default_forward_node_hard_limit = 1200 try: self.forward_node_max_length = int( ctx.astrbot_config["platform_settings"].get( - "forward_node_max_length", _default_forward_node_max_length + "forward_node_max_length", FORWARD_NODE_MAX_LENGTH_DEFAULT ) ) except (TypeError, ValueError): - self.forward_node_max_length = _default_forward_node_max_length + self.forward_node_max_length = FORWARD_NODE_MAX_LENGTH_DEFAULT try: self.forward_node_hard_limit = int( ctx.astrbot_config["platform_settings"].get( - "forward_node_hard_limit", _default_forward_node_hard_limit + "forward_node_hard_limit", FORWARD_NODE_HARD_LIMIT_DEFAULT ) ) except (TypeError, ValueError): - self.forward_node_hard_limit = _default_forward_node_hard_limit + self.forward_node_hard_limit = FORWARD_NODE_HARD_LIMIT_DEFAULT if self.forward_node_max_length <= 0: - self.forward_node_max_length = _default_forward_node_max_length + self.forward_node_max_length = FORWARD_NODE_MAX_LENGTH_DEFAULT if self.forward_node_hard_limit <= 0: - self.forward_node_hard_limit = _default_forward_node_hard_limit + self.forward_node_hard_limit = FORWARD_NODE_HARD_LIMIT_DEFAULT if self.forward_node_max_length > self.forward_node_hard_limit: logger.warning( "forward_node_max_length is greater than forward_node_hard_limit; " @@ -189,19 +191,16 @@ def _find_forward_split_pos( If none exists, fall back to the nearest breakpoint before target_len. If still none, hard-cut at hard_limit. """ + search_end = min(hard_limit, len(text)) if len(text) <= target_len: return len(text) - search_end = min(hard_limit, len(text)) if split_pattern is not None: - # Look for breakpoints after target_len but within hard_limit. - for match in split_pattern.finditer(text, target_len, search_end): - return match.end() - - # Fall back to the nearest breakpoint before target_len. previous_end = 0 - for match in split_pattern.finditer(text, 0, min(target_len, len(text))): - if 0 < match.end() <= target_len: + for match in split_pattern.finditer(text, 0, search_end): + if match.end() >= target_len: + return match.end() + if match.end() > 0: previous_end = match.end() if previous_end > 0: return previous_end diff --git a/tests/unit/test_result_decorate_forward.py b/tests/unit/test_result_decorate_forward.py index 36156137af..30cdc406a5 100644 --- a/tests/unit/test_result_decorate_forward.py +++ b/tests/unit/test_result_decorate_forward.py @@ -5,6 +5,10 @@ import pytest +from astrbot.core.config.default import ( + FORWARD_NODE_HARD_LIMIT_DEFAULT, + FORWARD_NODE_MAX_LENGTH_DEFAULT, +) from astrbot.core.message.components import At, Image, Node, Nodes, Plain, Reply from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core.pipeline.result_decorate.stage import ResultDecorateStage @@ -54,8 +58,11 @@ def _make_configured_stage( stage.split_mode = "regex" stage.regex = ".*?[。?!~…]+|.+$" stage.split_words = list(DEFAULT_SPLIT_WORDS) + escaped_words = sorted( + [re.escape(w) for w in stage.split_words], key=len, reverse=True + ) stage.split_words_pattern = re.compile( - f"(.*?({'|'.join(sorted([re.escape(w) for w in stage.split_words], key=len, reverse=True))})|.+$)", + f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL, ) stage.words_count_threshold = 150 @@ -133,6 +140,12 @@ def test_newline_is_a_breakpoint(self): pos = ResultDecorateStage._find_forward_split_pos(text, 40, 70, pattern) assert pos == 46 + def test_multichar_breakpoint_crossing_target_boundary(self): + pattern = re.compile(r"(?:END|\n)+") + text = "a" * 49 + "END" + "b" * 100 + pos = ResultDecorateStage._find_forward_split_pos(text, 50, 70, pattern) + assert pos == 52 + class TestBuildForwardNodes: """Tests for ResultDecorateStage._build_forward_nodes.""" @@ -469,8 +482,8 @@ async def test_initialize_sanitizes_invalid_config(self): ) stage = ResultDecorateStage() await stage.initialize(ctx) - assert stage.forward_node_max_length == 1000 - assert stage.forward_node_hard_limit == 1200 + assert stage.forward_node_max_length == FORWARD_NODE_MAX_LENGTH_DEFAULT + assert stage.forward_node_hard_limit == FORWARD_NODE_HARD_LIMIT_DEFAULT @pytest.mark.asyncio async def test_initialize_converges_max_greater_than_hard(self):