diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 4de3d77bfb..cf07d3d1cb 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -87,6 +87,8 @@ BraveWebSearchTool, FirecrawlExtractWebPageTool, FirecrawlWebSearchTool, + KeenableExtractWebPageTool, + KeenableWebSearchTool, TavilyExtractWebPageTool, TavilyWebSearchTool, normalize_legacy_web_search_config, @@ -1194,6 +1196,9 @@ async def _apply_web_search_tools( req.func_tool.add_tool(tool_mgr.get_builtin_tool(FirecrawlExtractWebPageTool)) elif provider == "baidu_ai_search": req.func_tool.add_tool(tool_mgr.get_builtin_tool(BaiduWebSearchTool)) + elif provider == "keenable": + req.func_tool.add_tool(tool_mgr.get_builtin_tool(KeenableWebSearchTool)) + req.func_tool.add_tool(tool_mgr.get_builtin_tool(KeenableExtractWebPageTool)) def _apply_web_search_citation_prompt( diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 9d6d8fe5d0..b0bd4b8cc8 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -112,6 +112,7 @@ "websearch_brave_key": [], "websearch_baidu_app_builder_key": "", "websearch_firecrawl_key": [], + "websearch_keenable_key": [], "web_search_link": False, "display_reasoning_text": False, "identifier": False, @@ -3284,6 +3285,7 @@ "bocha", "brave", "firecrawl", + "keenable", ], "condition": { "provider_settings.web_search": True, @@ -3338,6 +3340,16 @@ "provider_settings.web_search": True, }, }, + "provider_settings.websearch_keenable_key": { + "description": "Keenable API Key", + "type": "list", + "items": {"type": "string"}, + "hint": "可添加多个 Key 进行轮询。获取地址:https://keenable.ai", + "condition": { + "provider_settings.websearch_provider": "keenable", + "provider_settings.web_search": True, + }, + }, "provider_settings.web_search_link": { "description": "显示来源引用", "type": "bool", diff --git a/astrbot/core/tools/web_search_tools.py b/astrbot/core/tools/web_search_tools.py index ebd13d0102..a21669dc0f 100644 --- a/astrbot/core/tools/web_search_tools.py +++ b/astrbot/core/tools/web_search_tools.py @@ -21,6 +21,8 @@ "web_search_brave", "web_search_firecrawl", "firecrawl_extract_web_page", + "web_search_keenable", + "keenable_extract_web_page", ] _TAVILY_WEB_SEARCH_TOOL_CONFIG = { "provider_settings.web_search": True, @@ -42,6 +44,10 @@ "provider_settings.web_search": True, "provider_settings.websearch_provider": "baidu_ai_search", } +_KEENABLE_WEB_SEARCH_TOOL_CONFIG = { + "provider_settings.web_search": True, + "provider_settings.websearch_provider": "keenable", +} @std_dataclass @@ -76,6 +82,7 @@ async def get(self, provider_settings: dict) -> str: _BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha") _BRAVE_KEY_ROTATOR = _KeyRotator("websearch_brave_key", "Brave") _FIRECRAWL_KEY_ROTATOR = _KeyRotator("websearch_firecrawl_key", "Firecrawl") +_KEENABLE_KEY_ROTATOR = _KeyRotator("websearch_keenable_key", "Keenable") def normalize_legacy_web_search_config(cfg) -> None: @@ -99,6 +106,7 @@ def normalize_legacy_web_search_config(cfg) -> None: "websearch_bocha_key", "websearch_brave_key", "websearch_firecrawl_key", + "websearch_keenable_key", ): value = provider_settings.get(setting_name) if isinstance(value, str): @@ -370,6 +378,61 @@ async def _baidu_search( ] +async def _keenable_search( + provider_settings: dict, + payload: dict, +) -> list[SearchResult]: + api_key = await _KEENABLE_KEY_ROTATOR.get(provider_settings) + header = { + "X-API-Key": api_key, + "X-Keenable-Title": "astrbot", + "Content-Type": "application/json", + } + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + "https://api.keenable.ai/v1/search", + json=payload, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Keenable web search failed: {reason}, status: {response.status}", + ) + data = await response.json() + return [ + SearchResult( + title=item.get("title", ""), + url=item.get("url", ""), + snippet=item.get("snippet") or item.get("description") or "", + ) + for item in (data.get("results") or []) + if item and item.get("url") + ] + + +async def _keenable_fetch(provider_settings: dict, params: dict) -> dict: + api_key = await _KEENABLE_KEY_ROTATOR.get(provider_settings) + header = {"X-API-Key": api_key, "X-Keenable-Title": "astrbot"} + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + "https://api.keenable.ai/v1/fetch", + params=params, + headers=header, + ) as response: + if response.status != 200: + reason = await response.text() + raise Exception( + f"Keenable web fetch failed: {reason}, status: {response.status}", + ) + data = await response.json() + if not data.get("content"): + raise ValueError( + "Error: Keenable web fetcher does not return any results." + ) + return data + + @builtin_tool(config=_TAVILY_WEB_SEARCH_TOOL_CONFIG) @pydantic_dataclass class TavilyWebSearchTool(FunctionTool[AstrAgentContext]): @@ -803,10 +866,117 @@ async def call(self, context, **kwargs) -> ToolExecResult: return _search_result_payload(results) +@builtin_tool(config=_KEENABLE_WEB_SEARCH_TOOL_CONFIG) +@pydantic_dataclass +class KeenableWebSearchTool(FunctionTool[AstrAgentContext]): + name: str = "web_search_keenable" + description: str = ( + "A web search tool based on Keenable Search API, used to retrieve web " + "pages related to the user's query." + ) + parameters: dict = Field( + default_factory=lambda: { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Required. Search query."}, + "site": { + "type": "string", + "description": 'Optional. Restrict results to a specific site, for example "techcrunch.com".', + }, + "published_after": { + "type": "string", + "description": 'Optional. Only include pages published at or after this time. Accepts "YYYY-MM-DD", an ISO 8601 datetime, or a relative delta like "7d", "3mo".', + }, + "published_before": { + "type": "string", + "description": "Optional. Only include pages published at or before this time. Same formats as published_after.", + }, + "acquired_after": { + "type": "string", + "description": "Optional. Only include pages indexed at or after this time. Same formats as published_after.", + }, + "acquired_before": { + "type": "string", + "description": "Optional. Only include pages indexed at or before this time. Same formats as published_after.", + }, + }, + "required": ["query"], + } + ) + + async def call(self, context, **kwargs) -> ToolExecResult: + _, provider_settings, _ = _get_runtime(context) + if not provider_settings.get("websearch_keenable_key", []): + return "Error: Keenable API key is not configured in AstrBot." + + payload = {"query": kwargs["query"]} + for key in ( + "site", + "published_after", + "published_before", + "acquired_after", + "acquired_before", + ): + if kwargs.get(key): + payload[key] = kwargs[key] + + results = await _keenable_search(provider_settings, payload) + if not results: + return "Error: Keenable web searcher does not return any results." + return _search_result_payload(results) + + +@builtin_tool(config=_KEENABLE_WEB_SEARCH_TOOL_CONFIG) +@pydantic_dataclass +class KeenableExtractWebPageTool(FunctionTool[AstrAgentContext]): + name: str = "keenable_extract_web_page" + description: str = ( + "Extract the content of a web page using Keenable. " + "Only URLs indexed by Keenable are supported." + ) + parameters: dict = Field( + default_factory=lambda: { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "Required. A URL to extract content from.", + }, + "max_chars": { + "type": "integer", + "description": "Optional. Maximum number of characters of content to return. Default is 50000.", + }, + }, + "required": ["url"], + } + ) + + async def call(self, context, **kwargs) -> ToolExecResult: + _, provider_settings, _ = _get_runtime(context) + if not provider_settings.get("websearch_keenable_key", []): + return "Error: Keenable API key is not configured in AstrBot." + + url = str(kwargs.get("url", "")).strip() + if not url: + return "Error: url must be a non-empty string." + + params = {"url": url} + if kwargs.get("max_chars"): + params["max_chars"] = kwargs["max_chars"] + + result = await _keenable_fetch(provider_settings, params) + content = result.get("content", "") + result_url = result.get("url") or url + ret = f"URL: {result_url}\nContent: {content}" if content else "" + return ret or "Error: Keenable web fetcher does not return any results." + + __all__ = [ "BaiduWebSearchTool", "BochaWebSearchTool", "BraveWebSearchTool", + "KeenableExtractWebPageTool", + "KeenableWebSearchTool", "TavilyExtractWebPageTool", "TavilyWebSearchTool", "WEB_SEARCH_TOOL_NAMES", diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 2e8a786d1c..b3f239964a 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -133,6 +133,10 @@ "description": "Baidu Qianfan Smart Cloud APP Builder API Key", "hint": "Reference: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)" }, + "websearch_keenable_key": { + "description": "Keenable API Key", + "hint": "Multiple keys can be added for rotation. Get one at [https://keenable.ai](https://keenable.ai)." + }, "web_search_link": { "description": "Display Source Citations" } diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index 6f79726b33..1e4e09ba20 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -133,6 +133,10 @@ "description": "API-ключ Baidu Qianfan APP Builder", "hint": "Ссылка: [https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)" }, + "websearch_keenable_key": { + "description": "API-ключ Keenable", + "hint": "Можно добавить несколько ключей для ротации. Получить ключ: [https://keenable.ai](https://keenable.ai)." + }, "web_search_link": { "description": "Показывать ссылки на источники" } diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index c73fcae2d7..c5502fa7ed 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -135,6 +135,10 @@ "description": "百度千帆智能云 APP Builder API Key", "hint": "参考:[https://console.bce.baidu.com/iam/#/iam/apikey/list](https://console.bce.baidu.com/iam/#/iam/apikey/list)" }, + "websearch_keenable_key": { + "description": "Keenable API Key", + "hint": "可添加多个 Key 进行轮询。获取地址:[https://keenable.ai](https://keenable.ai)" + }, "web_search_link": { "description": "显示来源引用" } diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 6265cb077e..f2c59cb94c 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -510,6 +510,37 @@ async def test_apply_web_search_tools_adds_firecrawl_search_and_extract_tools( assert req.func_tool.get_tool("web_search_firecrawl") is search_tool assert req.func_tool.get_tool("firecrawl_extract_web_page") is extract_tool + @pytest.mark.asyncio + async def test_apply_web_search_tools_adds_keenable_search_and_extract_tools( + self, mock_event, mock_context + ): + """Test Keenable web search injects search and extract tools.""" + module = ama + req = ProviderRequest() + mock_context.get_config.return_value = { + "provider_settings": { + "web_search": True, + "websearch_provider": "keenable", + } + } + search_tool = MagicMock(spec=FunctionTool) + search_tool.name = "web_search_keenable" + extract_tool = MagicMock(spec=FunctionTool) + extract_tool.name = "keenable_extract_web_page" + tool_mgr = MagicMock() + tool_mgr.get_builtin_tool.side_effect = [search_tool, extract_tool] + mock_context.get_llm_tool_manager.return_value = tool_mgr + + await module._apply_web_search_tools(mock_event, req, mock_context) + + assert tool_mgr.get_builtin_tool.call_args_list == [ + ((module.KeenableWebSearchTool,),), + ((module.KeenableExtractWebPageTool,),), + ] + assert req.func_tool is not None + assert req.func_tool.get_tool("web_search_keenable") is search_tool + assert req.func_tool.get_tool("keenable_extract_web_page") is extract_tool + def test_apply_web_search_citation_prompt_for_webchat(self, mock_event): module = ama req = ProviderRequest(system_prompt="base") diff --git a/tests/unit/test_func_tool_manager.py b/tests/unit/test_func_tool_manager.py index d53ed3296f..e6e3aec2d5 100644 --- a/tests/unit/test_func_tool_manager.py +++ b/tests/unit/test_func_tool_manager.py @@ -9,6 +9,8 @@ from astrbot.core.tools.web_search_tools import ( FirecrawlExtractWebPageTool, FirecrawlWebSearchTool, + KeenableExtractWebPageTool, + KeenableWebSearchTool, ) @@ -345,3 +347,15 @@ def test_firecrawl_tools_are_registered_as_builtin_tools(): assert extract_tool.name == "firecrawl_extract_web_page" assert manager.is_builtin_tool("web_search_firecrawl") is True assert manager.is_builtin_tool("firecrawl_extract_web_page") is True + + +def test_keenable_tools_are_registered_as_builtin_tools(): + manager = FunctionToolManager() + + search_tool = manager.get_builtin_tool(KeenableWebSearchTool) + extract_tool = manager.get_builtin_tool(KeenableExtractWebPageTool) + + assert search_tool.name == "web_search_keenable" + assert extract_tool.name == "keenable_extract_web_page" + assert manager.is_builtin_tool("web_search_keenable") is True + assert manager.is_builtin_tool("keenable_extract_web_page") is True diff --git a/tests/unit/test_web_search_tools.py b/tests/unit/test_web_search_tools.py index c0ac3cf800..3df2f284a6 100644 --- a/tests/unit/test_web_search_tools.py +++ b/tests/unit/test_web_search_tools.py @@ -371,6 +371,270 @@ def post(self, url, json, headers): return self.response +class _FakeKeenableSession: + def __init__(self, response): + self.response = response + self.trust_env = None + self.entered = False + self.exited = False + self.posted = None + self.got = None + + async def __aenter__(self): + self.entered = True + return self + + async def __aexit__(self, exc_type, exc, tb): + self.exited = True + return None + + def post(self, url, json, headers): + self.posted = {"url": url, "json": json, "headers": headers} + return self.response + + def get(self, url, params, headers): + self.got = {"url": url, "params": params, "headers": headers} + return self.response + + +def test_normalize_legacy_web_search_config_migrates_keenable_key(): + config = _FakeConfig( + {"provider_settings": {"websearch_keenable_key": "keenable-key"}} + ) + + tools.normalize_legacy_web_search_config(config) + + assert config["provider_settings"]["websearch_keenable_key"] == ["keenable-key"] + assert config.saved is True + + +@pytest.mark.asyncio +async def test_keenable_search_maps_results(monkeypatch): + async def fake_keenable_search(provider_settings, payload): + assert provider_settings["websearch_keenable_key"] == ["keenable-key"] + assert payload == {"query": "AstrBot", "site": "example.com"} + return [ + tools.SearchResult( + title="AstrBot", + url="https://example.com", + snippet="Search result", + ) + ] + + monkeypatch.setattr(tools, "_keenable_search", fake_keenable_search) + tool = tools.KeenableWebSearchTool() + context = _context_with_provider_settings( + {"websearch_keenable_key": ["keenable-key"]} + ) + + result = await tool.call(context, query="AstrBot", site="example.com") + + assert json.loads(result)["results"] == [ + { + "title": "AstrBot", + "url": "https://example.com", + "snippet": "Search result", + "index": json.loads(result)["results"][0]["index"], + } + ] + + +@pytest.mark.asyncio +async def test_keenable_search_uses_api_key_header_and_falls_back_to_description( + monkeypatch, +): + session = _FakeKeenableSession( + _FakeFirecrawlResponse( + status=200, + json_data={ + "query": "AstrBot", + "results": [ + { + "title": "AstrBot", + "url": "https://example.com", + "description": "From description", + }, + {"title": "No URL", "description": "dropped"}, + ], + }, + ) + ) + + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + results = await tools._keenable_search( + {"websearch_keenable_key": ["keenable-key"]}, + {"query": "AstrBot"}, + ) + + assert session.trust_env is True + assert session.entered is True + assert session.exited is True + assert session.posted == { + "url": "https://api.keenable.ai/v1/search", + "json": {"query": "AstrBot"}, + "headers": { + "X-API-Key": "keenable-key", + "X-Keenable-Title": "astrbot", + "Content-Type": "application/json", + }, + } + assert results == [ + tools.SearchResult( + title="AstrBot", url="https://example.com", snippet="From description" + ) + ] + + +@pytest.mark.asyncio +async def test_keenable_search_handles_null_results_and_items(monkeypatch): + session = _FakeKeenableSession( + _FakeFirecrawlResponse(status=200, json_data={"results": None}) + ) + + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + assert ( + await tools._keenable_search( + {"websearch_keenable_key": ["keenable-key"]}, {"query": "AstrBot"} + ) + == [] + ) + + session.response = _FakeFirecrawlResponse( + status=200, + json_data={ + "results": [None, {"title": "AstrBot", "url": "https://example.com"}] + }, + ) + results = await tools._keenable_search( + {"websearch_keenable_key": ["keenable-key"]}, {"query": "AstrBot"} + ) + assert results == [ + tools.SearchResult(title="AstrBot", url="https://example.com", snippet="") + ] + + +@pytest.mark.asyncio +async def test_keenable_search_raises_error_for_http_errors(monkeypatch): + session = _FakeKeenableSession( + _FakeFirecrawlResponse(status=401, text_data="Unauthorized") + ) + + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + with pytest.raises( + Exception, + match="Keenable web search failed: Unauthorized, status: 401", + ): + await tools._keenable_search( + {"websearch_keenable_key": ["keenable-key"]}, + {"query": "AstrBot"}, + ) + + +@pytest.mark.asyncio +async def test_keenable_extract_returns_fetched_markdown(monkeypatch): + async def fake_keenable_fetch(provider_settings, params): + assert provider_settings["websearch_keenable_key"] == ["keenable-key"] + assert params == {"url": "https://example.com", "max_chars": 1000} + return {"url": "https://example.com", "content": "# Example"} + + monkeypatch.setattr(tools, "_keenable_fetch", fake_keenable_fetch) + tool = tools.KeenableExtractWebPageTool() + context = _context_with_provider_settings( + {"websearch_keenable_key": ["keenable-key"]} + ) + + result = await tool.call(context, url="https://example.com", max_chars=1000) + + assert result == "URL: https://example.com\nContent: # Example" + + +@pytest.mark.asyncio +async def test_keenable_fetch_uses_get_with_api_key_header(monkeypatch): + session = _FakeKeenableSession( + _FakeFirecrawlResponse( + status=200, + json_data={"url": "https://example.com", "content": "# Example"}, + ) + ) + + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + result = await tools._keenable_fetch( + {"websearch_keenable_key": ["keenable-key"]}, + {"url": "https://example.com"}, + ) + + assert result == {"url": "https://example.com", "content": "# Example"} + assert session.got == { + "url": "https://api.keenable.ai/v1/fetch", + "params": {"url": "https://example.com"}, + "headers": {"X-API-Key": "keenable-key", "X-Keenable-Title": "astrbot"}, + } + + +@pytest.mark.asyncio +async def test_keenable_fetch_raises_when_no_content(monkeypatch): + session = _FakeKeenableSession( + _FakeFirecrawlResponse(status=200, json_data={"url": "https://example.com"}) + ) + + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + with pytest.raises( + ValueError, + match="Keenable web fetcher does not return any results.", + ): + await tools._keenable_fetch( + {"websearch_keenable_key": ["keenable-key"]}, + {"url": "https://example.com"}, + ) + + +@pytest.mark.asyncio +async def test_keenable_fetch_raises_error_for_http_errors(monkeypatch): + session = _FakeKeenableSession( + _FakeFirecrawlResponse(status=403, text_data="Forbidden") + ) + + def fake_client_session(*, trust_env): + session.trust_env = trust_env + return session + + monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session) + + with pytest.raises( + Exception, + match="Keenable web fetch failed: Forbidden, status: 403", + ): + await tools._keenable_fetch( + {"websearch_keenable_key": ["keenable-key"]}, + {"url": "https://example.com"}, + ) + + def _context_with_provider_settings(provider_settings): config = {"provider_settings": provider_settings} agent_context = SimpleNamespace(