diff --git a/astrbot/builtin_stars/builtin_commands/commands/provider.py b/astrbot/builtin_stars/builtin_commands/commands/provider.py index 971d6ca8a0..8128c1d2cf 100644 --- a/astrbot/builtin_stars/builtin_commands/commands/provider.py +++ b/astrbot/builtin_stars/builtin_commands/commands/provider.py @@ -3,7 +3,7 @@ import asyncio from astrbot import logger -from astrbot.api import star +from astrbot.api import sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.core.provider.entities import ProviderType from astrbot.core.utils.error_redaction import safe_error @@ -13,6 +13,23 @@ class ProviderCommands: def __init__(self, context: star.Context) -> None: self.context = context + async def _reset_provider_override( + self, + event: AstrMessageEvent, + umo: str, + provider_type: ProviderType, + display_name: str, + ) -> None: + await sp.session_remove( + umo, + f"provider_perf_{provider_type.value}", + ) + event.set_result( + MessageEventResult().message( + f"✅ Successfully reset {display_name} provider to global default." + ) + ) + def _log_reachability_failure( self, provider, @@ -112,7 +129,7 @@ async def provider( self, event: AstrMessageEvent, idx: str | int | None = None, - idx2: int | None = None, + idx2: str | int | None = None, ) -> None: """查看或者切换 LLM Provider""" umo = event.unified_msg_origin @@ -184,6 +201,7 @@ async def provider( ret += "\nUse /provider tts to switch TTS providers." if stts: ret += "\nUse /provider stt to switch STT providers." + ret += "\nUse /provider reset to clear session override." event.set_result(MessageEventResult().message(ret)) elif idx == "tts": @@ -192,12 +210,25 @@ async def provider( MessageEventResult().message("Please enter the index.") ) return - if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: + if idx2 in ("default", "reset", "clear"): + await self._reset_provider_override( + event, umo, ProviderType.TEXT_TO_SPEECH, "TTS" + ) + return + try: + idx2_int = int(idx2) + except ValueError: + event.set_result( + MessageEventResult().message("❌ Invalid provider index.") + ) + return + providers = list(self.context.get_all_tts_providers()) + if idx2_int > len(providers) or idx2_int < 1: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_tts_providers()[idx2 - 1] + provider = providers[idx2_int - 1] id_ = provider.meta().id await self.context.provider_manager.set_provider( provider_id=id_, @@ -213,36 +244,61 @@ async def provider( MessageEventResult().message("Please enter the index.") ) return - if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: + if idx2 in ("default", "reset", "clear"): + await self._reset_provider_override( + event, umo, ProviderType.SPEECH_TO_TEXT, "STT" + ) + return + try: + idx2_int = int(idx2) + except ValueError: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_stt_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.SPEECH_TO_TEXT, - umo=umo, - ) - event.set_result( - MessageEventResult().message(f"✅ Successfully switched to {id_}.") - ) - elif isinstance(idx, int): - if idx > len(self.context.get_all_providers()) or idx < 1: + providers = list(self.context.get_all_stt_providers()) + if idx2_int > len(providers) or idx2_int < 1: event.set_result( MessageEventResult().message("❌ Invalid provider index.") ) return - provider = self.context.get_all_providers()[idx - 1] + provider = providers[idx2_int - 1] id_ = provider.meta().id await self.context.provider_manager.set_provider( provider_id=id_, - provider_type=ProviderType.CHAT_COMPLETION, + provider_type=ProviderType.SPEECH_TO_TEXT, umo=umo, ) event.set_result( MessageEventResult().message(f"✅ Successfully switched to {id_}.") ) + elif idx in ("default", "reset", "clear"): + await self._reset_provider_override( + event, umo, ProviderType.CHAT_COMPLETION, "Chat Completion" + ) else: - event.set_result(MessageEventResult().message("❌ Invalid parameter.")) + try: + idx_int = int(idx) + is_int = True + except (ValueError, TypeError): + is_int = False + + if is_int: + providers = list(self.context.get_all_providers()) + if idx_int > len(providers) or idx_int < 1: + event.set_result( + MessageEventResult().message("❌ Invalid provider index.") + ) + return + provider = providers[idx_int - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + event.set_result( + MessageEventResult().message(f"✅ Successfully switched to {id_}.") + ) + else: + event.set_result(MessageEventResult().message("❌ Invalid parameter.")) diff --git a/astrbot/builtin_stars/builtin_commands/main.py b/astrbot/builtin_stars/builtin_commands/main.py index 4c5ce3f8ca..e69bbfb280 100644 --- a/astrbot/builtin_stars/builtin_commands/main.py +++ b/astrbot/builtin_stars/builtin_commands/main.py @@ -67,7 +67,7 @@ async def provider( self, event: AstrMessageEvent, idx: str | int | None = None, - idx2: int | None = None, + idx2: str | int | None = None, ) -> None: """View or switch LLM Provider""" await self.provider_c.provider(event, idx, idx2) diff --git a/tests/unit/test_provider_commands.py b/tests/unit/test_provider_commands.py new file mode 100644 index 0000000000..020a3277e6 --- /dev/null +++ b/tests/unit/test_provider_commands.py @@ -0,0 +1,70 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from astrbot.api.event import AstrMessageEvent +from astrbot.builtin_stars.builtin_commands.commands.provider import ProviderCommands +from astrbot.core.provider.entities import ProviderType + + +@pytest.mark.asyncio +async def test_provider_reset_chat_completion(): + context = MagicMock() + cmd = ProviderCommands(context) + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "session-123" + + with patch( + "astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", + new_callable=AsyncMock, + ) as mock_remove: + await cmd.provider(event, idx="reset") + mock_remove.assert_called_once_with( + "session-123", + f"provider_perf_{ProviderType.CHAT_COMPLETION.value}", + ) + event.set_result.assert_called_once() + res = event.set_result.call_args[0][0] + assert "reset Chat Completion provider" in res.get_plain_text() + + +@pytest.mark.asyncio +async def test_provider_reset_tts(): + context = MagicMock() + cmd = ProviderCommands(context) + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "session-123" + + with patch( + "astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", + new_callable=AsyncMock, + ) as mock_remove: + await cmd.provider(event, idx="tts", idx2="reset") + mock_remove.assert_called_once_with( + "session-123", + f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}", + ) + event.set_result.assert_called_once() + res = event.set_result.call_args[0][0] + assert "reset TTS provider" in res.get_plain_text() + + +@pytest.mark.asyncio +async def test_provider_reset_stt(): + context = MagicMock() + cmd = ProviderCommands(context) + event = MagicMock(spec=AstrMessageEvent) + event.unified_msg_origin = "session-123" + + with patch( + "astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove", + new_callable=AsyncMock, + ) as mock_remove: + await cmd.provider(event, idx="stt", idx2="reset") + mock_remove.assert_called_once_with( + "session-123", + f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}", + ) + event.set_result.assert_called_once() + res = event.set_result.call_args[0][0] + assert "reset STT provider" in res.get_plain_text()