From f7ffb370ba81995bec9adaf4957a407d57d8531f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=91=E6=9B=9C?= Date: Wed, 17 Jun 2026 00:31:50 +0800 Subject: [PATCH] =?UTF-8?q?feat(openapi):=20=E6=94=AF=E6=8C=81=E5=9B=BA?= =?UTF-8?q?=E5=AE=9A=E8=AF=B7=E6=B1=82=E5=A4=B4=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 用户要求:从 OpenAPI protocolSpec 中解析字符串类型的 header 固定值,让 remote OpenAPI 工具调用时自动携带这些 header。 实现思路: - 仅将 in: header 且 schema.const 为字符串的参数解析为 fixed_headers - 仅将 in: header 且 schema.enum 只有一个字符串值的参数解析为 fixed_headers - 调用工具时合并 fixed_headers,并过滤同名调用参数,避免误传到 query 或 body - 明确不支持 parameter $ref 和非字符串固定值,并补充对应测试 Signed-off-by: 黑曜 --- agentrun/tool/api/openapi.py | 126 ++++++++++- tests/unittests/tool/test_openapi.py | 325 +++++++++++++++++++++++++++ 2 files changed, 442 insertions(+), 9 deletions(-) diff --git a/agentrun/tool/api/openapi.py b/agentrun/tool/api/openapi.py index 5202dc9..3874fbb 100644 --- a/agentrun/tool/api/openapi.py +++ b/agentrun/tool/api/openapi.py @@ -6,8 +6,9 @@ extracts operations as ToolInfo list, and makes HTTP calls via Server URL. """ +from copy import deepcopy import json -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Dict, Generator, List, Optional, Tuple from urllib.parse import urlparse, urlunparse import httpx @@ -204,6 +205,85 @@ def _resolve_schema( return result + @staticmethod + def _collect_parameters( + path_item: Dict[str, Any], operation: Dict[str, Any] + ) -> List[Dict[str, Any]]: + """合并 path 和 operation 参数 / Merge path and operation parameters.""" + merged: Dict[Tuple[str, str], Dict[str, Any]] = {} + for source in ( + path_item.get("parameters", []), + operation.get("parameters", []), + ): + if not isinstance(source, list): + continue + for param in source: + if not isinstance(param, dict): + continue + key = (str(param.get("name", "")), str(param.get("in", ""))) + merged[key] = param + return list(merged.values()) + + @staticmethod + def _fixed_header_value_from_schema( + schema: Optional[Dict[str, Any]], + ) -> Optional[str]: + """从 schema.const 或单值 enum 提取固定 header 值。""" + if not schema or not isinstance(schema, dict): + return None + + if "const" in schema: + value = schema.get("const") + return value if isinstance(value, str) else None + + enum_values = schema.get("enum") + if isinstance(enum_values, list) and len(enum_values) == 1: + value = enum_values[0] + return value if isinstance(value, str) else None + + return None + + @staticmethod + def _merge_headers( + base_headers: Dict[str, str], fixed_headers: Dict[str, str] + ) -> Dict[str, str]: + """合并请求头,固定 header 按大小写不敏感规则覆盖已有值。""" + merged = dict(base_headers) + for fixed_key, fixed_value in fixed_headers.items(): + fixed_key_lower = fixed_key.lower() + for existing_key in list(merged.keys()): + if existing_key.lower() == fixed_key_lower: + del merged[existing_key] + merged[fixed_key] = fixed_value + return merged + + @staticmethod + def _remove_fixed_header_arguments( + arguments: Optional[Dict[str, Any]], fixed_headers: Dict[str, str] + ) -> Dict[str, Any]: + """从调用参数中移除固定 header,避免误传到 query 或 body。""" + if not arguments: + return {} + + cleaned = dict(arguments) + fixed_names = {name.lower() for name in fixed_headers} + for name in list(cleaned.keys()): + if name.lower() in fixed_names: + del cleaned[name] + return cleaned + + def _prepare_request_inputs( + self, + arguments: Optional[Dict[str, Any]], + fixed_headers: Dict[str, str], + ) -> Tuple[Dict[str, str], Dict[str, Any]]: + """准备请求头和参数 / Prepare request headers and arguments.""" + request_headers = self._merge_headers(self.headers, fixed_headers) + request_arguments = self._remove_fixed_header_arguments( + arguments, fixed_headers + ) + return request_headers, request_arguments + def _parse_operations(self) -> List[Dict[str, Any]]: """解析 OpenAPI Schema 中的所有 operations / Parse all operations from OpenAPI Schema""" if self._operations is not None: @@ -235,7 +315,8 @@ def _parse_operations(self) -> List[Dict[str, Any]]: request_body_schema = self._resolve_schema(raw_schema) parameters_schema = None - parameters = operation.get("parameters", []) + parameters = self._collect_parameters(path_item, operation) + fixed_headers: Dict[str, str] = {} if parameters and isinstance(parameters, list): props = {} required_params = [] @@ -243,7 +324,25 @@ def _parse_operations(self) -> List[Dict[str, Any]]: if not isinstance(param, dict): continue param_name = param.get("name", "") - param_schema = param.get("schema", {"type": "string"}) + if not param_name: + continue + + raw_param_schema = param.get( + "schema", {"type": "string"} + ) + param_schema = self._resolve_schema(raw_param_schema) + if not isinstance(param_schema, dict): + param_schema = {"type": "string"} + + if param.get("in") == "header": + fixed_value = self._fixed_header_value_from_schema( + param_schema + ) + if fixed_value is not None: + fixed_headers[str(param_name)] = fixed_value + continue + + param_schema = deepcopy(param_schema) param_schema["description"] = param.get( "description", "" ) @@ -267,6 +366,7 @@ def _parse_operations(self) -> List[Dict[str, Any]]: "method": method.upper(), "path": path, "input_schema": input_schema, + "fixed_headers": fixed_headers, }) return self._operations @@ -337,6 +437,10 @@ def call_tool( url = f"{base_url.rstrip('/')}{target_operation['path']}" method = target_operation["method"] + fixed_headers = target_operation.get("fixed_headers", {}) + request_headers, request_arguments = self._prepare_request_inputs( + arguments, fixed_headers + ) # 应用 RAM 签名 url, auth = self._build_ram_auth(url) @@ -346,12 +450,12 @@ def call_tool( ) with httpx.Client( - headers=self.headers, timeout=30.0, auth=auth + headers=request_headers, timeout=30.0, auth=auth ) as client: if method in ("POST", "PUT", "PATCH"): - response = client.request(method, url, json=arguments or {}) + response = client.request(method, url, json=request_arguments) else: - response = client.request(method, url, params=arguments or {}) + response = client.request(method, url, params=request_arguments) response.raise_for_status() @@ -396,6 +500,10 @@ async def call_tool_async( url = f"{base_url.rstrip('/')}{target_operation['path']}" method = target_operation["method"] + fixed_headers = target_operation.get("fixed_headers", {}) + request_headers, request_arguments = self._prepare_request_inputs( + arguments, fixed_headers + ) # 应用 RAM 签名 url, auth = self._build_ram_auth(url) @@ -406,15 +514,15 @@ async def call_tool_async( ) async with httpx.AsyncClient( - headers=self.headers, timeout=30.0, auth=auth + headers=request_headers, timeout=30.0, auth=auth ) as client: if method in ("POST", "PUT", "PATCH"): response = await client.request( - method, url, json=arguments or {} + method, url, json=request_arguments ) else: response = await client.request( - method, url, params=arguments or {} + method, url, params=request_arguments ) response.raise_for_status() diff --git a/tests/unittests/tool/test_openapi.py b/tests/unittests/tool/test_openapi.py index 995c40d..1ce4041 100644 --- a/tests/unittests/tool/test_openapi.py +++ b/tests/unittests/tool/test_openapi.py @@ -686,6 +686,331 @@ def test_parse_operations_required_parameters(self): assert "id" in op["input_schema"]["properties"] assert "id" in op["input_schema"]["required"] + def test_parse_operations_fixed_header_const_not_exposed(self): + """测试 header schema.const 会转为固定 header 且不暴露给工具参数""" + spec = json.dumps({ + "openapi": "3.1.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/reset": { + "post": { + "operationId": "resetWorkspaceData", + "parameters": [ + { + "name": "X-Custom-Auth", + "in": "header", + "schema": { + "type": "string", + "const": "fixed-token", + }, + }, + { + "name": "traceId", + "in": "query", + "schema": {"type": "string"}, + }, + ], + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert len(operations) == 1 + op = operations[0] + assert op["fixed_headers"] == {"X-Custom-Auth": "fixed-token"} + assert "traceId" in op["input_schema"]["properties"] + assert "X-Custom-Auth" not in op["input_schema"]["properties"] + + def test_parse_operations_fixed_header_single_enum_not_exposed(self): + """测试单值 enum 会转为固定 header,不依赖 required 字段""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/reset": { + "parameters": [{ + "name": "X-Path-Auth", + "in": "header", + "schema": {"type": "string", "enum": ["path-token"]}, + }], + "post": { + "operationId": "resetWorkspaceData", + "parameters": [{ + "name": "X-Operation-Auth", + "in": "header", + "required": False, + "schema": {"enum": ["operation-token"]}, + }], + }, + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert operations[0]["fixed_headers"] == { + "X-Path-Auth": "path-token", + "X-Operation-Auth": "operation-token", + } + assert operations[0]["input_schema"] is None + + def test_parse_operations_multi_enum_header_remains_parameter(self): + """测试多值 enum header 不是固定值,仍作为工具参数暴露""" + spec = json.dumps({ + "openapi": "3.0.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/reset": { + "post": { + "operationId": "resetWorkspaceData", + "parameters": [{ + "name": "X-Custom-Auth", + "in": "header", + "schema": { + "type": "string", + "enum": ["a", "b"], + }, + }], + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + op = operations[0] + assert op["fixed_headers"] == {} + assert "X-Custom-Auth" in op["input_schema"]["properties"] + + def test_parse_operations_non_string_fixed_header_remains_parameter(self): + """测试非字符串 const/单值 enum 不会转为固定 header""" + spec = json.dumps({ + "openapi": "3.1.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/reset": { + "post": { + "operationId": "resetWorkspaceData", + "parameters": [ + { + "name": "X-Number-Auth", + "in": "header", + "schema": {"const": 123}, + }, + { + "name": "X-Bool-Auth", + "in": "header", + "schema": {"enum": [True]}, + }, + ], + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + op = operations[0] + assert op["fixed_headers"] == {} + assert "X-Number-Auth" in op["input_schema"]["properties"] + assert "X-Bool-Auth" in op["input_schema"]["properties"] + + def test_parse_operations_fixed_header_with_request_body(self): + """测试 requestBody 存在时仍会解析固定 header""" + spec = json.dumps({ + "openapi": "3.1.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/reset": { + "post": { + "operationId": "resetWorkspaceData", + "parameters": [{ + "name": "X-Custom-Auth", + "in": "header", + "schema": {"const": "fixed-token"}, + }], + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "code": {"type": "string"} + }, + "required": ["code"], + } + } + }, + }, + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + op = operations[0] + assert op["fixed_headers"] == {"X-Custom-Auth": "fixed-token"} + assert op["input_schema"]["properties"] == {"code": {"type": "string"}} + + def test_parse_operations_parameter_ref_is_not_fixed_header(self): + """测试 parameter $ref 不作为固定 header 解析""" + spec = json.dumps({ + "openapi": "3.1.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "components": { + "parameters": { + "CustomAuth": { + "name": "X-Custom-Auth", + "in": "header", + "schema": {"const": "fixed-token"}, + } + } + }, + "paths": { + "/reset": { + "post": { + "operationId": "resetWorkspaceData", + "parameters": [ + {"$ref": "#/components/parameters/CustomAuth"} + ], + } + } + }, + }) + client = ToolOpenAPIClient(protocol_spec=spec) + operations = client._parse_operations() + + assert operations[0]["fixed_headers"] == {} + assert operations[0]["input_schema"] is None + + @patch("agentrun.tool.api.openapi.httpx.Client") + def test_call_tool_uses_fixed_headers_and_filters_arguments( + self, mock_client_class + ): + """测试调用时固定 header 覆盖默认 header,且不会进入 query""" + mock_response = Mock() + mock_response.json.return_value = {"ok": True} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_client_instance.__enter__ = Mock(return_value=mock_client_instance) + mock_client_instance.__exit__ = Mock(return_value=False) + mock_client_class.return_value = mock_client_instance + + spec = json.dumps({ + "openapi": "3.1.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/reset": { + "get": { + "operationId": "resetWorkspaceData", + "parameters": [ + { + "name": "X-Custom-Auth", + "in": "header", + "schema": {"const": "fixed-token"}, + }, + { + "name": "traceId", + "in": "query", + "schema": {"type": "string"}, + }, + ], + } + } + }, + }) + + client = ToolOpenAPIClient( + protocol_spec=spec, + headers={"x-custom-auth": "caller-token", "X-Trace": "base"}, + ) + result = client.call_tool( + "resetWorkspaceData", + {"X-Custom-Auth": "wrong-token", "traceId": "trace-1"}, + ) + + assert result == {"ok": True} + assert mock_client_class.call_args[1]["headers"] == { + "X-Trace": "base", + "X-Custom-Auth": "fixed-token", + } + request_kwargs = mock_client_instance.request.call_args[1] + assert request_kwargs["params"] == {"traceId": "trace-1"} + + @patch("agentrun.tool.api.openapi.httpx.AsyncClient") + async def test_call_tool_async_uses_fixed_headers_and_filters_arguments( + self, mock_async_client_class + ): + """测试异步调用时固定 header 覆盖默认 header,且不会进入 query""" + mock_response = Mock() + mock_response.json.return_value = {"ok": True} + mock_response.headers = {"content-type": "application/json"} + mock_response.raise_for_status = Mock() + + mock_client_instance = AsyncMock() + mock_client_instance.request = AsyncMock(return_value=mock_response) + mock_client_instance.__aenter__ = AsyncMock( + return_value=mock_client_instance + ) + mock_client_instance.__aexit__ = AsyncMock() + mock_async_client_class.return_value = mock_client_instance + + spec = json.dumps({ + "openapi": "3.1.0", + "info": {"title": "Test API"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/reset": { + "get": { + "operationId": "resetWorkspaceData", + "parameters": [ + { + "name": "X-Custom-Auth", + "in": "header", + "schema": {"enum": ["fixed-token"]}, + }, + { + "name": "traceId", + "in": "query", + "schema": {"type": "string"}, + }, + ], + } + } + }, + }) + + client = ToolOpenAPIClient( + protocol_spec=spec, + headers={"x-custom-auth": "caller-token", "X-Trace": "base"}, + ) + result = await client.call_tool_async( + "resetWorkspaceData", + {"X-Custom-Auth": "wrong-token", "traceId": "trace-1"}, + ) + + assert result == {"ok": True} + assert mock_async_client_class.call_args[1]["headers"] == { + "X-Trace": "base", + "X-Custom-Auth": "fixed-token", + } + request_kwargs = mock_client_instance.request.call_args[1] + assert request_kwargs["params"] == {"traceId": "trace-1"} + class AsyncMock(Mock): """Async mock helper"""