Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,19 +377,18 @@ def append_or_extend(
append_or_extend(gemini_contents, parts, types.UserContent)

elif role == "assistant":
if isinstance(content, str):
parts = [types.Part.from_text(text=content)]
append_or_extend(gemini_contents, parts, types.ModelContent)
parts = []
if isinstance(content, str) and content:
parts.append(types.Part.from_text(text=content))
elif isinstance(content, list):
parts = []
thinking_signature = None
text = ""
for part in content:
# for most cases, assistant content only contains two parts: think and text
if part.get("type") == "think":
thinking_signature = part.get("encrypted") or None
else:
text += str(part.get("text"))
text += str(part.get("text") or "")

if thinking_signature and isinstance(thinking_signature, str):
try:
Expand All @@ -406,11 +405,10 @@ def append_or_extend(
thought_signature=thinking_signature,
)
)
append_or_extend(gemini_contents, parts, types.ModelContent)

elif not native_tool_enabled and "tool_calls" in message:
parts = []
for tool in message["tool_calls"]:
tool_calls = message.get("tool_calls") or []
if not native_tool_enabled and tool_calls:
for tool in tool_calls:
part = types.Part.from_function_call(
name=tool["function"]["name"],
args=json.loads(tool["function"]["arguments"]),
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Expand All @@ -427,15 +425,21 @@ def append_or_extend(
if ts_bs64:
part.thought_signature = base64.b64decode(ts_bs64)
parts.append(part)
append_or_extend(gemini_contents, parts, types.ModelContent)
else:

parts = [
part
for part in parts
if part.text or part.thought_signature or part.function_call
]
if not parts:
Comment thread
Rat0323 marked this conversation as resolved.
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
if native_tool_enabled and "tool_calls" in message:
if native_tool_enabled and tool_calls:
logger.warning(
"检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文",
)
parts = [types.Part.from_text(text=" ")]
append_or_extend(gemini_contents, parts, types.ModelContent)

append_or_extend(gemini_contents, parts, types.ModelContent)

elif role == "tool" and not native_tool_enabled:
func_name = message.get("name", message["tool_call_id"])
Expand Down
130 changes: 130 additions & 0 deletions tests/test_gemini_source.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import base64

import pytest
from google.genai import types

from astrbot.core.exceptions import EmptyModelOutputError
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI


class _ConversationOnlyGeminiProvider(ProviderGoogleGenAI):
def _init_client(self) -> None:
self.client = None


def test_gemini_empty_output_raises_empty_model_output_error():
llm_response = LLMResponse(role="assistant")

Expand All @@ -27,3 +35,125 @@ def test_gemini_reasoning_only_output_is_allowed():
response_id="resp_reasoning",
finish_reason="STOP",
)


def _make_gemini_provider_for_conversation():
return _ConversationOnlyGeminiProvider(
{
"key": ["test-key"],
"model": "gemini-test",
"gm_native_coderunner": False,
"gm_native_search": False,
},
{},
)


def _assistant_tool_call_message(content):
return {
"role": "assistant",
"content": content,
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_pull_request_files",
"arguments": '{"owner":"AstrBotDevs","repo":"AstrBot","pull_number":8742}',
},
},
],
}


def _first_model_parts(gemini_contents):
model_content = next(
content
for content in gemini_contents
if isinstance(content, types.ModelContent)
)
return model_content.parts or []


def test_prepare_conversation_keeps_assistant_text_and_tool_calls():
provider = _make_gemini_provider_for_conversation()
payloads = {
"messages": [
{"role": "user", "content": "summarize this PR"},
_assistant_tool_call_message("I will inspect the changed files first."),
]
}

parts = _first_model_parts(provider._prepare_conversation(payloads))

assert any(part.text == "I will inspect the changed files first." for part in parts)
assert [
part.function_call.name
for part in parts
if getattr(part, "function_call", None)
] == ["get_pull_request_files"]
Comment thread
Rat0323 marked this conversation as resolved.


def test_prepare_conversation_keeps_assistant_only_tool_calls_without_placeholder():
provider = _make_gemini_provider_for_conversation()
payloads = {
"messages": [
{"role": "user", "content": "summarize this PR"},
_assistant_tool_call_message(None),
]
}

parts = _first_model_parts(provider._prepare_conversation(payloads))

assert not any(part.text for part in parts)
assert [
part.function_call.name
for part in parts
if getattr(part, "function_call", None)
] == ["get_pull_request_files"]


def test_prepare_conversation_keeps_assistant_list_content_and_tool_calls():
provider = _make_gemini_provider_for_conversation()
payloads = {
"messages": [
{"role": "user", "content": "summarize this PR"},
_assistant_tool_call_message(
[
{
"type": "think",
"encrypted": base64.b64encode(b"signature").decode("utf-8"),
},
{"type": "text", "text": "I will inspect the changed files first."},
]
),
]
}

parts = _first_model_parts(provider._prepare_conversation(payloads))

assert any(part.text == "I will inspect the changed files first." for part in parts)
assert [
part.function_call.name
for part in parts
if getattr(part, "function_call", None)
] == ["get_pull_request_files"]


def test_prepare_conversation_ignores_null_tool_calls():
provider = _make_gemini_provider_for_conversation()
payloads = {
"messages": [
{"role": "user", "content": "hello"},
{
"role": "assistant",
"content": "hello back",
"tool_calls": None,
},
]
}

parts = _first_model_parts(provider._prepare_conversation(payloads))

assert [part.text for part in parts] == ["hello back"]
assert not any(getattr(part, "function_call", None) for part in parts)
Loading