Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 117 additions & 9 deletions agentrun/tool/api/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
extracts operations as ToolInfo list, and makes HTTP calls via Server URL.
"""

from copy import deepcopy
import json
from typing import Any, Dict, Generator, List, Optional
from typing import Any, Dict, Generator, List, Optional, Tuple
from urllib.parse import urlparse, urlunparse

import httpx
Expand Down Expand Up @@ -204,6 +205,85 @@ def _resolve_schema(

return result

@staticmethod
def _collect_parameters(
path_item: Dict[str, Any], operation: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""合并 path 和 operation 参数 / Merge path and operation parameters."""
merged: Dict[Tuple[str, str], Dict[str, Any]] = {}
for source in (
path_item.get("parameters", []),
operation.get("parameters", []),
):
if not isinstance(source, list):
continue
for param in source:
if not isinstance(param, dict):
continue
key = (str(param.get("name", "")), str(param.get("in", "")))
merged[key] = param
return list(merged.values())

@staticmethod
def _fixed_header_value_from_schema(
schema: Optional[Dict[str, Any]],
) -> Optional[str]:
"""从 schema.const 或单值 enum 提取固定 header 值。"""
if not schema or not isinstance(schema, dict):
return None

if "const" in schema:
value = schema.get("const")
return value if isinstance(value, str) else None

enum_values = schema.get("enum")
if isinstance(enum_values, list) and len(enum_values) == 1:
value = enum_values[0]
return value if isinstance(value, str) else None

return None

@staticmethod
def _merge_headers(
base_headers: Dict[str, str], fixed_headers: Dict[str, str]
) -> Dict[str, str]:
"""合并请求头,固定 header 按大小写不敏感规则覆盖已有值。"""
merged = dict(base_headers)
for fixed_key, fixed_value in fixed_headers.items():
fixed_key_lower = fixed_key.lower()
for existing_key in list(merged.keys()):
if existing_key.lower() == fixed_key_lower:
del merged[existing_key]
merged[fixed_key] = fixed_value
return merged

@staticmethod
def _remove_fixed_header_arguments(
arguments: Optional[Dict[str, Any]], fixed_headers: Dict[str, str]
) -> Dict[str, Any]:
"""从调用参数中移除固定 header,避免误传到 query 或 body。"""
if not arguments:
return {}

cleaned = dict(arguments)
fixed_names = {name.lower() for name in fixed_headers}
for name in list(cleaned.keys()):
if name.lower() in fixed_names:
del cleaned[name]
return cleaned

def _prepare_request_inputs(
self,
arguments: Optional[Dict[str, Any]],
fixed_headers: Dict[str, str],
) -> Tuple[Dict[str, str], Dict[str, Any]]:
"""准备请求头和参数 / Prepare request headers and arguments."""
request_headers = self._merge_headers(self.headers, fixed_headers)
request_arguments = self._remove_fixed_header_arguments(
arguments, fixed_headers
)
return request_headers, request_arguments

def _parse_operations(self) -> List[Dict[str, Any]]:
"""解析 OpenAPI Schema 中的所有 operations / Parse all operations from OpenAPI Schema"""
if self._operations is not None:
Expand Down Expand Up @@ -235,15 +315,34 @@ def _parse_operations(self) -> List[Dict[str, Any]]:
request_body_schema = self._resolve_schema(raw_schema)

parameters_schema = None
parameters = operation.get("parameters", [])
parameters = self._collect_parameters(path_item, operation)
fixed_headers: Dict[str, str] = {}
if parameters and isinstance(parameters, list):
props = {}
required_params = []
for param in parameters:
if not isinstance(param, dict):
continue
param_name = param.get("name", "")
param_schema = param.get("schema", {"type": "string"})
if not param_name:
continue

raw_param_schema = param.get(
"schema", {"type": "string"}
)
param_schema = self._resolve_schema(raw_param_schema)
if not isinstance(param_schema, dict):
param_schema = {"type": "string"}

if param.get("in") == "header":
fixed_value = self._fixed_header_value_from_schema(
param_schema
)
if fixed_value is not None:
fixed_headers[str(param_name)] = fixed_value
continue

param_schema = deepcopy(param_schema)
param_schema["description"] = param.get(
"description", ""
)
Expand All @@ -267,6 +366,7 @@ def _parse_operations(self) -> List[Dict[str, Any]]:
"method": method.upper(),
"path": path,
"input_schema": input_schema,
"fixed_headers": fixed_headers,
})

return self._operations
Expand Down Expand Up @@ -337,6 +437,10 @@ def call_tool(

url = f"{base_url.rstrip('/')}{target_operation['path']}"
method = target_operation["method"]
fixed_headers = target_operation.get("fixed_headers", {})
request_headers, request_arguments = self._prepare_request_inputs(
arguments, fixed_headers
)

# 应用 RAM 签名
url, auth = self._build_ram_auth(url)
Expand All @@ -346,12 +450,12 @@ def call_tool(
)

with httpx.Client(
headers=self.headers, timeout=30.0, auth=auth
headers=request_headers, timeout=30.0, auth=auth
) as client:
if method in ("POST", "PUT", "PATCH"):
response = client.request(method, url, json=arguments or {})
response = client.request(method, url, json=request_arguments)
else:
response = client.request(method, url, params=arguments or {})
response = client.request(method, url, params=request_arguments)

response.raise_for_status()

Expand Down Expand Up @@ -396,6 +500,10 @@ async def call_tool_async(

url = f"{base_url.rstrip('/')}{target_operation['path']}"
method = target_operation["method"]
fixed_headers = target_operation.get("fixed_headers", {})
request_headers, request_arguments = self._prepare_request_inputs(
arguments, fixed_headers
)

# 应用 RAM 签名
url, auth = self._build_ram_auth(url)
Expand All @@ -406,15 +514,15 @@ async def call_tool_async(
)

async with httpx.AsyncClient(
headers=self.headers, timeout=30.0, auth=auth
headers=request_headers, timeout=30.0, auth=auth
) as client:
if method in ("POST", "PUT", "PATCH"):
response = await client.request(
method, url, json=arguments or {}
method, url, json=request_arguments
)
else:
response = await client.request(
method, url, params=arguments or {}
method, url, params=request_arguments
)

response.raise_for_status()
Expand Down
Loading