Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,11 +377,10 @@ def append_or_extend(
append_or_extend(gemini_contents, parts, types.UserContent)

elif role == "assistant":
if isinstance(content, str):
parts = [types.Part.from_text(text=content)]
append_or_extend(gemini_contents, parts, types.ModelContent)
parts = []
if isinstance(content, str) and content:
parts.append(types.Part.from_text(text=content))
elif isinstance(content, list):
parts = []
thinking_signature = None
text = ""
for part in content:
Expand All @@ -406,11 +405,10 @@ def append_or_extend(
thought_signature=thinking_signature,
)
)
append_or_extend(gemini_contents, parts, types.ModelContent)

elif not native_tool_enabled and "tool_calls" in message:
parts = []
for tool in message["tool_calls"]:
tool_calls = message.get("tool_calls") or []
if not native_tool_enabled and tool_calls:
for tool in tool_calls:
part = types.Part.from_function_call(
name=tool["function"]["name"],
args=json.loads(tool["function"]["arguments"]),
Expand All @@ -427,15 +425,16 @@ def append_or_extend(
if ts_bs64:
part.thought_signature = base64.b64decode(ts_bs64)
parts.append(part)
append_or_extend(gemini_contents, parts, types.ModelContent)
else:

if not parts:
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
if native_tool_enabled and "tool_calls" in message:
logger.warning(
"检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文",
)
parts = [types.Part.from_text(text=" ")]
append_or_extend(gemini_contents, parts, types.ModelContent)

append_or_extend(gemini_contents, parts, types.ModelContent)

elif role == "tool" and not native_tool_enabled:
func_name = message.get("name", message["tool_call_id"])
Expand Down
102 changes: 102 additions & 0 deletions tests/test_gemini_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import base64

import pytest

from astrbot.core.exceptions import EmptyModelOutputError
Expand Down Expand Up @@ -27,3 +29,103 @@ def test_gemini_reasoning_only_output_is_allowed():
response_id="resp_reasoning",
finish_reason="STOP",
)


def _make_gemini_provider_for_conversation():
provider = object.__new__(ProviderGoogleGenAI)
provider.provider_config = {
"gm_native_coderunner": False,
"gm_native_search": False,
}
return provider


def _assistant_tool_call_message(content):
return {
"role": "assistant",
"content": content,
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_pull_request_files",
"arguments": '{"owner":"AstrBotDevs","repo":"AstrBot","pull_number":8742}',
},
},
],
}


def _first_model_parts(gemini_contents):
model_content = next(
content
for content in gemini_contents
if content.__class__.__name__ == "ModelContent"
)
return model_content.parts or []


def test_prepare_conversation_keeps_assistant_text_and_tool_calls():
provider = _make_gemini_provider_for_conversation()
payloads = {
"messages": [
{"role": "user", "content": "summarize this PR"},
_assistant_tool_call_message("I will inspect the changed files first."),
]
}

parts = _first_model_parts(provider._prepare_conversation(payloads))

assert any(part.text == "I will inspect the changed files first." for part in parts)
assert [
part.function_call.name
for part in parts
if getattr(part, "function_call", None)
] == ["get_pull_request_files"]


def test_prepare_conversation_keeps_assistant_list_content_and_tool_calls():
provider = _make_gemini_provider_for_conversation()
payloads = {
"messages": [
{"role": "user", "content": "summarize this PR"},
_assistant_tool_call_message(
[
{
"type": "think",
"encrypted": base64.b64encode(b"signature").decode("utf-8"),
},
{"type": "text", "text": "I will inspect the changed files first."},
]
),
]
}

parts = _first_model_parts(provider._prepare_conversation(payloads))

assert any(part.text == "I will inspect the changed files first." for part in parts)
assert [
part.function_call.name
for part in parts
if getattr(part, "function_call", None)
] == ["get_pull_request_files"]


def test_prepare_conversation_ignores_null_tool_calls():
provider = _make_gemini_provider_for_conversation()
payloads = {
"messages": [
{"role": "user", "content": "hello"},
{
"role": "assistant",
"content": "hello back",
"tool_calls": None,
},
]
}

parts = _first_model_parts(provider._prepare_conversation(payloads))

assert [part.text for part in parts] == ["hello back"]
assert not any(getattr(part, "function_call", None) for part in parts)
Loading