From 3a25b3c0b33d82d08f3b8734e403a980c9f84a01 Mon Sep 17 00:00:00 2001 From: Rat0323 <261020116+Rat0323@users.noreply.github.com> Date: Sat, 13 Jun 2026 07:50:29 +0800 Subject: [PATCH 1/5] feat(provider): support reset/default subcommands to clear session provider override --- .../builtin_commands/commands/provider.py | 87 ++++++++++++++----- .../builtin_stars/builtin_commands/main.py | 2 +- tests/unit/test_provider_commands.py | 56 ++++++++++++ 3 files changed, 124 insertions(+), 21 deletions(-) create mode 100644 tests/unit/test_provider_commands.py diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 971d6ca8a0..edcd4419ef 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -3,7 +3,7 @@ import asyncio from astrbot import logger -from astrbot.api import star +from astrbot.api import sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType from astrbot.core.utils.error_redaction import safe_error @@ -112,7 +112,7 @@ async def provider( self, event: AstrMessageEvent, idx: str | int | None = None, - idx2: int | None = None, + idx2: str | int | None = None, ) -> None: """查看或者切换 LLM Provider""" umo = event.unified_msg_origin @@ -192,12 +192,28 @@ async def provider( MessageEventResult().message("Please enter the index.") ) return - if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: + if idx2 in ("default", "reset", "clear"): + await sp.session_remove( + umo, + f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}", + ) + event.set_result( + MessageEventResult().message("✅ Successfully reset TTS provider to global default.") + ) + return + try: + idx2_int = int(idx2) + except ValueError: + event.set_result( + MessageEventResult().message("❌ Invalid provider index.") + ) + return + if idx2_int > len(self.context.get_all_tts_providers()) or idx2_int < 1: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_tts_providers()[idx2 - 1] + provider = self.context.get_all_tts_providers()[idx2_int - 1] id_ = provider.meta().id await self.context.provider_manager.set_provider( provider_id=id_, @@ -213,36 +229,67 @@ async def provider( MessageEventResult().message("Please enter the index.") ) return - if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: + if idx2 in ("default", "reset", "clear"): + await sp.session_remove( + umo, + f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}", + ) + event.set_result( + MessageEventResult().message("✅ Successfully reset STT provider to global default.") + ) + return + try: + idx2_int = int(idx2) + except ValueError: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_stt_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.SPEECH_TO_TEXT, - umo=umo, - ) - event.set_result( - MessageEventResult().message(f"✅ Successfully switched to {id_}.") - ) - elif isinstance(idx, int): - if idx > len(self.context.get_all_providers()) or idx < 1: + if idx2_int > len(self.context.get_all_stt_providers()) or idx2_int < 1: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_providers()[idx - 1] + provider = self.context.get_all_stt_providers()[idx2_int - 1] id_ = provider.meta().id await self.context.provider_manager.set_provider( provider_id=id_, - provider_type=ProviderType.CHAT_COMPLETION, + provider_type=ProviderType.SPEECH_TO_TEXT, umo=umo, ) event.set_result( MessageEventResult().message(f"✅ Successfully switched to {id_}.") ) + elif idx in ("default", "reset", "clear"): + await sp.session_remove( + umo, + f"provider_perf_{ProviderType.CHAT_COMPLETION.value}", + ) + event.set_result( + MessageEventResult().message("✅ Successfully reset Chat Completion provider to global default.") + ) else: - event.set_result(MessageEventResult().message("❌ Invalid parameter.")) + try: + idx_int = int(idx) + is_int = True + except (ValueError, TypeError): + is_int = False + + if is_int: + if idx_int > len(self.context.get_all_providers()) or idx_int < 1: + event.set_result( + MessageEventResult().message("❌ Invalid provider index.") + ) + return + provider = self.context.get_all_providers()[idx_int - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + event.set_result( + MessageEventResult().message(f"✅ Successfully switched to {id_}.") + ) + else: + event.set_result(MessageEventResult().message("❌ Invalid parameter.")) diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py index 4c5ce3f8ca..e69bbfb280 100644 --- a/astrbot/builtin_stars/builtin_commands/main.py +++ b/astrbot/builtin_stars/builtin_commands/main.py @@ -67,7 +67,7 @@ async def provider( self, event: AstrMessageEvent, idx: str | int | None = None, - idx2: int | None = None, + idx2: str | int | None = None, ) -> None: """View or switch LLM Provider""" await self.provider_c.provider(event, idx, idx2) diff --git a/tests/unit/test_provider_commands.py b/tests/unit/test_provider_commands.py new file mode 100644 index 0000000000..db8c0009df --- /dev/null +++ b/tests/unit/test_provider_commands.py @@ -0,0 +1,56 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from astrbot.builtin_stars.builtin_commands.commands.provider import ProviderCommands +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.provider.entities import ProviderType + +@pytest.mark.asyncio +async def test_provider_reset_chat_completion(): + context = MagicMock() + cmd = ProviderCommands(context) + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "session-123" + + with patch("astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", new_callable=AsyncMock) as mock_remove: + await cmd.provider(event, idx="reset") + mock_remove.assert_called_once_with( + "session-123", + f"provider_perf_{ProviderType.CHAT_COMPLETION.value}", + ) + event.set_result.assert_called_once() + res = event.set_result.call_args[0][0] + assert "reset Chat Completion provider" in res.get_plain_text() + +@pytest.mark.asyncio +async def test_provider_reset_tts(): + context = MagicMock() + cmd = ProviderCommands(context) + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "session-123" + + with patch("astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", new_callable=AsyncMock) as mock_remove: + await cmd.provider(event, idx="tts", idx2="reset") + mock_remove.assert_called_once_with( + "session-123", + f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}", + ) + event.set_result.assert_called_once() + res = event.set_result.call_args[0][0] + assert "reset TTS provider" in res.get_plain_text() + +@pytest.mark.asyncio +async def test_provider_reset_stt(): + context = MagicMock() + cmd = ProviderCommands(context) + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "session-123" + + with patch("astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", new_callable=AsyncMock) as mock_remove: + await cmd.provider(event, idx="stt", idx2="reset") + mock_remove.assert_called_once_with( + "session-123", + f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}", + ) + event.set_result.assert_called_once() + res = event.set_result.call_args[0][0] + assert "reset STT provider" in res.get_plain_text() From 9030f8893e9a5b235cec3ee4ec0e3769c5225174 Mon Sep 17 00:00:00 2001 From: Rat0323 <261020116+Rat0323@users.noreply.github.com> Date: Sat, 13 Jun 2026 07:55:38 +0800 Subject: [PATCH 2/5] refactor(provider): address code review recommendations and format code --- .../builtin_commands/commands/provider.py | 27 ++++++++++++------- tests/unit/test_provider_commands.py | 18 ++++++++++--- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index edcd4419ef..e9bcb47f73 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -198,7 +198,9 @@ async def provider( f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}", ) event.set_result( - MessageEventResult().message("✅ Successfully reset TTS provider to global default.") + MessageEventResult().message( + "✅ Successfully reset TTS provider to global default." + ) ) return try: @@ -208,12 +210,13 @@ async def provider( MessageEventResult().message("❌ Invalid provider index.") ) return - if idx2_int > len(self.context.get_all_tts_providers()) or idx2_int < 1: + providers = list(self.context.get_all_tts_providers()) + if idx2_int > len(providers) or idx2_int < 1: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_tts_providers()[idx2_int - 1] + provider = providers[idx2_int - 1] id_ = provider.meta().id await self.context.provider_manager.set_provider( provider_id=id_, @@ -235,7 +238,9 @@ async def provider( f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}", ) event.set_result( - MessageEventResult().message("✅ Successfully reset STT provider to global default.") + MessageEventResult().message( + "✅ Successfully reset STT provider to global default." + ) ) return try: @@ -245,12 +250,13 @@ async def provider( MessageEventResult().message("❌ Invalid provider index.") ) return - if idx2_int > len(self.context.get_all_stt_providers()) or idx2_int < 1: + providers = list(self.context.get_all_stt_providers()) + if idx2_int > len(providers) or idx2_int < 1: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_stt_providers()[idx2_int - 1] + provider = providers[idx2_int - 1] id_ = provider.meta().id await self.context.provider_manager.set_provider( provider_id=id_, @@ -266,7 +272,9 @@ async def provider( f"provider_perf_{ProviderType.CHAT_COMPLETION.value}", ) event.set_result( - MessageEventResult().message("✅ Successfully reset Chat Completion provider to global default.") + MessageEventResult().message( + "✅ Successfully reset Chat Completion provider to global default." + ) ) else: try: @@ -276,12 +284,13 @@ async def provider( is_int = False if is_int: - if idx_int > len(self.context.get_all_providers()) or idx_int < 1: + providers = list(self.context.get_all_providers()) + if idx_int > len(providers) or idx_int < 1: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_providers()[idx_int - 1] + provider = providers[idx_int - 1] id_ = provider.meta().id await self.context.provider_manager.set_provider( provider_id=id_, diff --git a/tests/unit/test_provider_commands.py b/tests/unit/test_provider_commands.py index db8c0009df..de996c1410 100644 --- a/tests/unit/test_provider_commands.py +++ b/tests/unit/test_provider_commands.py @@ -4,6 +4,7 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType + @pytest.mark.asyncio async def test_provider_reset_chat_completion(): context = MagicMock() @@ -11,7 +12,10 @@ async def test_provider_reset_chat_completion(): event = MagicMock(spec=AstrMessageEvent) event.unified_msg_origin = "session-123" - with patch("astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", new_callable=AsyncMock) as mock_remove: + with patch( + "astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", + new_callable=AsyncMock, + ) as mock_remove: await cmd.provider(event, idx="reset") mock_remove.assert_called_once_with( "session-123", @@ -21,6 +25,7 @@ async def test_provider_reset_chat_completion(): res = event.set_result.call_args[0][0] assert "reset Chat Completion provider" in res.get_plain_text() + @pytest.mark.asyncio async def test_provider_reset_tts(): context = MagicMock() @@ -28,7 +33,10 @@ async def test_provider_reset_tts(): event = MagicMock(spec=AstrMessageEvent) event.unified_msg_origin = "session-123" - with patch("astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", new_callable=AsyncMock) as mock_remove: + with patch( + "astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", + new_callable=AsyncMock, + ) as mock_remove: await cmd.provider(event, idx="tts", idx2="reset") mock_remove.assert_called_once_with( "session-123", @@ -38,6 +46,7 @@ async def test_provider_reset_tts(): res = event.set_result.call_args[0][0] assert "reset TTS provider" in res.get_plain_text() + @pytest.mark.asyncio async def test_provider_reset_stt(): context = MagicMock() @@ -45,7 +54,10 @@ async def test_provider_reset_stt(): event = MagicMock(spec=AstrMessageEvent) event.unified_msg_origin = "session-123" - with patch("astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", new_callable=AsyncMock) as mock_remove: + with patch( + "astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", + new_callable=AsyncMock, + ) as mock_remove: await cmd.provider(event, idx="stt", idx2="reset") mock_remove.assert_called_once_with( "session-123", From d0ff433f4f973a3f6fbf966b032c8d762b381b46 Mon Sep 17 00:00:00 2001 From: Rat0323 <261020116+Rat0323@users.noreply.github.com> Date: Sat, 13 Jun 2026 07:57:47 +0800 Subject: [PATCH 3/5] refactor(provider): extract _reset_provider_override helper method to reduce code duplication --- .../builtin_commands/commands/provider.py | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index e9bcb47f73..f7c892fcd2 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -13,6 +13,23 @@ class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context + async def _reset_provider_override( + self, + event: AstrMessageEvent, + umo: str, + provider_type: ProviderType, + display_name: str, + ) -> None: + await sp.session_remove( + umo, + f"provider_perf_{provider_type.value}", + ) + event.set_result( + MessageEventResult().message( + f"✅ Successfully reset {display_name} provider to global default." + ) + ) + def _log_reachability_failure( self, provider, @@ -193,14 +210,8 @@ async def provider( ) return if idx2 in ("default", "reset", "clear"): - await sp.session_remove( - umo, - f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}", - ) - event.set_result( - MessageEventResult().message( - "✅ Successfully reset TTS provider to global default." - ) + await self._reset_provider_override( + event, umo, ProviderType.TEXT_TO_SPEECH, "TTS" ) return try: @@ -233,14 +244,8 @@ async def provider( ) return if idx2 in ("default", "reset", "clear"): - await sp.session_remove( - umo, - f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}", - ) - event.set_result( - MessageEventResult().message( - "✅ Successfully reset STT provider to global default." - ) + await self._reset_provider_override( + event, umo, ProviderType.SPEECH_TO_TEXT, "STT" ) return try: @@ -267,14 +272,8 @@ async def provider( MessageEventResult().message(f"✅ Successfully switched to {id_}.") ) elif idx in ("default", "reset", "clear"): - await sp.session_remove( - umo, - f"provider_perf_{ProviderType.CHAT_COMPLETION.value}", - ) - event.set_result( - MessageEventResult().message( - "✅ Successfully reset Chat Completion provider to global default." - ) + await self._reset_provider_override( + event, umo, ProviderType.CHAT_COMPLETION, "Chat Completion" ) else: try: From 0763878fbca263afc8f8c15ce73d8c28586726bb Mon Sep 17 00:00:00 2001 From: Rat0323 <261020116+Rat0323@users.noreply.github.com> Date: Sat, 13 Jun 2026 08:22:50 +0800 Subject: [PATCH 4/5] fix(provider): add reset command hint to help text --- astrbot/builtin_stars/builtin_commands/commands/provider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index f7c892fcd2..8128c1d2cf 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -201,6 +201,7 @@ async def provider( ret += "\nUse /provider tts to switch TTS providers." if stts: ret += "\nUse /provider stt to switch STT providers." + ret += "\nUse /provider reset to clear session override." event.set_result(MessageEventResult().message(ret)) elif idx == "tts": From 0f8ab794bbfe527ed78167a8e7aef4460ff64b9a Mon Sep 17 00:00:00 2001 From: Rat0323 <261020116+Rat0323@users.noreply.github.com> Date: Mon, 15 Jun 2026 07:28:50 +0800 Subject: [PATCH 5/5] test(provider): clean provider command test imports --- tests/unit/test_provider_commands.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_provider_commands.py b/tests/unit/test_provider_commands.py index de996c1410..020a3277e6 100644 --- a/tests/unit/test_provider_commands.py +++ b/tests/unit/test_provider_commands.py @@ -1,7 +1,9 @@ -import pytest from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.api.event import AstrMessageEvent from astrbot.builtin_stars.builtin_commands.commands.provider import ProviderCommands -from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType