From 36d917b58327b2ab3e07cf274109ab4af6e39841 Mon Sep 17 00:00:00 2001 From: Rat0323 <261020116+Rat0323@users.noreply.github.com> Date: Mon, 15 Jun 2026 06:55:42 +0800 Subject: [PATCH] fix(gemini): preserve assistant tool calls during history conversion --- .../core/provider/sources/gemini_source.py | 30 ++-- tests/test_gemini_source.py | 130 ++++++++++++++++++ 2 files changed, 147 insertions(+), 13 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 0c58174897..ff5f664dcb 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -377,11 +377,10 @@ def append_or_extend( append_or_extend(gemini_contents, parts, types.UserContent) elif role == "assistant": - if isinstance(content, str): - parts = [types.Part.from_text(text=content)] - append_or_extend(gemini_contents, parts, types.ModelContent) + parts = [] + if isinstance(content, str) and content: + parts.append(types.Part.from_text(text=content)) elif isinstance(content, list): - parts = [] thinking_signature = None text = "" for part in content: @@ -389,7 +388,7 @@ def append_or_extend( if part.get("type") == "think": thinking_signature = part.get("encrypted") or None else: - text += str(part.get("text")) + text += str(part.get("text") or "") if thinking_signature and isinstance(thinking_signature, str): try: @@ -406,11 +405,10 @@ def append_or_extend( thought_signature=thinking_signature, ) ) - append_or_extend(gemini_contents, parts, types.ModelContent) - elif not native_tool_enabled and "tool_calls" in message: - parts = [] - for tool in message["tool_calls"]: + tool_calls = message.get("tool_calls") or [] + if not native_tool_enabled and tool_calls: + for tool in tool_calls: part = types.Part.from_function_call( name=tool["function"]["name"], args=json.loads(tool["function"]["arguments"]), @@ -427,15 +425,21 @@ def append_or_extend( if ts_bs64: part.thought_signature = base64.b64decode(ts_bs64) parts.append(part) - append_or_extend(gemini_contents, parts, types.ModelContent) - else: + + parts = [ + part + for part in parts + if part.text or part.thought_signature or part.function_call + ] + if not parts: logger.warning("assistant 角色的消息内容为空,已添加空格占位") - if native_tool_enabled and "tool_calls" in message: + if native_tool_enabled and tool_calls: logger.warning( "检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文", ) parts = [types.Part.from_text(text=" ")] - append_or_extend(gemini_contents, parts, types.ModelContent) + + append_or_extend(gemini_contents, parts, types.ModelContent) elif role == "tool" and not native_tool_enabled: func_name = message.get("name", message["tool_call_id"]) diff --git a/tests/test_gemini_source.py b/tests/test_gemini_source.py index 4db8e92bfe..fa19e58319 100644 --- a/tests/test_gemini_source.py +++ b/tests/test_gemini_source.py @@ -1,10 +1,18 @@ +import base64 + import pytest +from google.genai import types from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI +class _ConversationOnlyGeminiProvider(ProviderGoogleGenAI): + def _init_client(self) -> None: + self.client = None + + def test_gemini_empty_output_raises_empty_model_output_error(): llm_response = LLMResponse(role="assistant") @@ -27,3 +35,125 @@ def test_gemini_reasoning_only_output_is_allowed(): response_id="resp_reasoning", finish_reason="STOP", ) + + +def _make_gemini_provider_for_conversation(): + return _ConversationOnlyGeminiProvider( + { + "key": ["test-key"], + "model": "gemini-test", + "gm_native_coderunner": False, + "gm_native_search": False, + }, + {}, + ) + + +def _assistant_tool_call_message(content): + return { + "role": "assistant", + "content": content, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "get_pull_request_files", + "arguments": '{"owner":"AstrBotDevs","repo":"AstrBot","pull_number":8742}', + }, + }, + ], + } + + +def _first_model_parts(gemini_contents): + model_content = next( + content + for content in gemini_contents + if isinstance(content, types.ModelContent) + ) + return model_content.parts or [] + + +def test_prepare_conversation_keeps_assistant_text_and_tool_calls(): + provider = _make_gemini_provider_for_conversation() + payloads = { + "messages": [ + {"role": "user", "content": "summarize this PR"}, + _assistant_tool_call_message("I will inspect the changed files first."), + ] + } + + parts = _first_model_parts(provider._prepare_conversation(payloads)) + + assert any(part.text == "I will inspect the changed files first." for part in parts) + assert [ + part.function_call.name + for part in parts + if getattr(part, "function_call", None) + ] == ["get_pull_request_files"] + + +def test_prepare_conversation_keeps_assistant_only_tool_calls_without_placeholder(): + provider = _make_gemini_provider_for_conversation() + payloads = { + "messages": [ + {"role": "user", "content": "summarize this PR"}, + _assistant_tool_call_message(None), + ] + } + + parts = _first_model_parts(provider._prepare_conversation(payloads)) + + assert not any(part.text for part in parts) + assert [ + part.function_call.name + for part in parts + if getattr(part, "function_call", None) + ] == ["get_pull_request_files"] + + +def test_prepare_conversation_keeps_assistant_list_content_and_tool_calls(): + provider = _make_gemini_provider_for_conversation() + payloads = { + "messages": [ + {"role": "user", "content": "summarize this PR"}, + _assistant_tool_call_message( + [ + { + "type": "think", + "encrypted": base64.b64encode(b"signature").decode("utf-8"), + }, + {"type": "text", "text": "I will inspect the changed files first."}, + ] + ), + ] + } + + parts = _first_model_parts(provider._prepare_conversation(payloads)) + + assert any(part.text == "I will inspect the changed files first." for part in parts) + assert [ + part.function_call.name + for part in parts + if getattr(part, "function_call", None) + ] == ["get_pull_request_files"] + + +def test_prepare_conversation_ignores_null_tool_calls(): + provider = _make_gemini_provider_for_conversation() + payloads = { + "messages": [ + {"role": "user", "content": "hello"}, + { + "role": "assistant", + "content": "hello back", + "tool_calls": None, + }, + ] + } + + parts = _first_model_parts(provider._prepare_conversation(payloads)) + + assert [part.text for part in parts] == ["hello back"] + assert not any(getattr(part, "function_call", None) for part in parts)