Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 76 additions & 20 deletions astrbot/builtin_stars/builtin_commands/commands/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio

from astrbot import logger
from astrbot.api import star
from astrbot.api import sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType
from astrbot.core.utils.error_redaction import safe_error
Expand All @@ -13,6 +13,23 @@ class ProviderCommands:
def __init__(self, context: star.Context) -> None:
self.context = context

async def _reset_provider_override(
self,
event: AstrMessageEvent,
umo: str,
provider_type: ProviderType,
display_name: str,
) -> None:
await sp.session_remove(
umo,
f"provider_perf_{provider_type.value}",
)
event.set_result(
MessageEventResult().message(
f"✅ Successfully reset {display_name} provider to global default."
)
)

def _log_reachability_failure(
self,
provider,
Expand Down Expand Up @@ -112,7 +129,7 @@ async def provider(
self,
event: AstrMessageEvent,
idx: str | int | None = None,
idx2: int | None = None,
idx2: str | int | None = None,
) -> None:
"""查看或者切换 LLM Provider"""
umo = event.unified_msg_origin
Expand Down Expand Up @@ -184,6 +201,7 @@ async def provider(
ret += "\nUse /provider tts <idx> to switch TTS providers."
if stts:
ret += "\nUse /provider stt <idx> to switch STT providers."
ret += "\nUse /provider reset to clear session override."

event.set_result(MessageEventResult().message(ret))
elif idx == "tts":
Expand All @@ -192,12 +210,25 @@ async def provider(
MessageEventResult().message("Please enter the index.")
)
return
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
if idx2 in ("default", "reset", "clear"):
await self._reset_provider_override(
event, umo, ProviderType.TEXT_TO_SPEECH, "TTS"
)
return
Comment thread
Rat0323 marked this conversation as resolved.
try:
idx2_int = int(idx2)
except ValueError:
event.set_result(
MessageEventResult().message("❌ Invalid provider index.")
)
return
providers = list(self.context.get_all_tts_providers())
if idx2_int > len(providers) or idx2_int < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index.")
)
return
provider = self.context.get_all_tts_providers()[idx2 - 1]
provider = providers[idx2_int - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
provider_id=id_,
Expand All @@ -213,36 +244,61 @@ async def provider(
MessageEventResult().message("Please enter the index.")
)
return
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
if idx2 in ("default", "reset", "clear"):
await self._reset_provider_override(
event, umo, ProviderType.SPEECH_TO_TEXT, "STT"
)
return
try:
idx2_int = int(idx2)
except ValueError:
event.set_result(
MessageEventResult().message("❌ Invalid provider index.")
)
return
provider = self.context.get_all_stt_providers()[idx2 - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
provider_id=id_,
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
elif isinstance(idx, int):
if idx > len(self.context.get_all_providers()) or idx < 1:
providers = list(self.context.get_all_stt_providers())
if idx2_int > len(providers) or idx2_int < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index.")
)
return
provider = self.context.get_all_providers()[idx - 1]
provider = providers[idx2_int - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
provider_id=id_,
provider_type=ProviderType.CHAT_COMPLETION,
provider_type=ProviderType.SPEECH_TO_TEXT,
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
elif idx in ("default", "reset", "clear"):
await self._reset_provider_override(
event, umo, ProviderType.CHAT_COMPLETION, "Chat Completion"
)
else:
event.set_result(MessageEventResult().message("❌ Invalid parameter."))
try:
idx_int = int(idx)
is_int = True
except (ValueError, TypeError):
is_int = False

if is_int:
providers = list(self.context.get_all_providers())
if idx_int > len(providers) or idx_int < 1:
event.set_result(
MessageEventResult().message("❌ Invalid provider index.")
)
return
provider = providers[idx_int - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
provider_id=id_,
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
event.set_result(
MessageEventResult().message(f"✅ Successfully switched to {id_}.")
)
else:
event.set_result(MessageEventResult().message("❌ Invalid parameter."))
2 changes: 1 addition & 1 deletion astrbot/builtin_stars/builtin_commands/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def provider(
self,
event: AstrMessageEvent,
idx: str | int | None = None,
idx2: int | None = None,
idx2: str | int | None = None,
) -> None:
"""View or switch LLM Provider"""
await self.provider_c.provider(event, idx, idx2)
Expand Down
70 changes: 70 additions & 0 deletions tests/unit/test_provider_commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from astrbot.api.event import AstrMessageEvent
from astrbot.builtin_stars.builtin_commands.commands.provider import ProviderCommands
from astrbot.core.provider.entities import ProviderType


@pytest.mark.asyncio
async def test_provider_reset_chat_completion():
context = MagicMock()
cmd = ProviderCommands(context)
event = MagicMock(spec=AstrMessageEvent)
event.unified_msg_origin = "session-123"

with patch(
"astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove",
new_callable=AsyncMock,
) as mock_remove:
await cmd.provider(event, idx="reset")
mock_remove.assert_called_once_with(
"session-123",
f"provider_perf_{ProviderType.CHAT_COMPLETION.value}",
)
event.set_result.assert_called_once()
res = event.set_result.call_args[0][0]
assert "reset Chat Completion provider" in res.get_plain_text()


@pytest.mark.asyncio
async def test_provider_reset_tts():
context = MagicMock()
cmd = ProviderCommands(context)
event = MagicMock(spec=AstrMessageEvent)
event.unified_msg_origin = "session-123"

with patch(
"astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove",
new_callable=AsyncMock,
) as mock_remove:
await cmd.provider(event, idx="tts", idx2="reset")
mock_remove.assert_called_once_with(
"session-123",
f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}",
)
event.set_result.assert_called_once()
res = event.set_result.call_args[0][0]
assert "reset TTS provider" in res.get_plain_text()


@pytest.mark.asyncio
async def test_provider_reset_stt():
context = MagicMock()
cmd = ProviderCommands(context)
event = MagicMock(spec=AstrMessageEvent)
event.unified_msg_origin = "session-123"

with patch(
"astrbot.builtin_stars.builtin_commands.commands.provider.sp.session_remove",
new_callable=AsyncMock,
) as mock_remove:
await cmd.provider(event, idx="stt", idx2="reset")
mock_remove.assert_called_once_with(
"session-123",
f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}",
)
event.set_result.assert_called_once()
res = event.set_result.call_args[0][0]
assert "reset STT provider" in res.get_plain_text()
Loading