diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eaa278e..c61ea82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,8 +54,25 @@ jobs: id: changes run: | echo "Checking if agentrun directory has changes..." - # 获取最近两次提交之间的差异;如果没有父提交,则将所有跟踪文件视为已更改 - if git rev-parse HEAD^ >/dev/null 2>&1; then + # 与默认分支(main)的分叉点比较,取整个分支引入的全部改动, + # 而非仅最近一次提交(避免多 commit 分支漏检早先提交的改动)。 + # CI 以 fetch-depth: 0 检出,origin/main 可用;优先用它,退回本地 main。 + BASE_REF="" + for ref in origin/main main; do + if git rev-parse --verify "$ref" >/dev/null 2>&1; then + BASE_REF="$ref"; break + fi + done + MERGE_BASE="" + if [ -n "$BASE_REF" ]; then + MERGE_BASE=$(git merge-base "$BASE_REF" HEAD 2>/dev/null || echo "") + fi + if [ -n "$MERGE_BASE" ] && [ "$MERGE_BASE" != "$(git rev-parse HEAD)" ]; then + echo "Diffing against base ($BASE_REF, merge-base ${MERGE_BASE})" + git diff --name-only "$MERGE_BASE" HEAD > changed_files.txt + elif git rev-parse HEAD^ >/dev/null 2>&1; then + # 回退(如直接 push 到 main,与 base 无分叉):比较最近一次提交 + echo "No divergence from base; falling back to HEAD^..HEAD" git diff --name-only HEAD^ HEAD > changed_files.txt else echo "No parent commit; treating all tracked files as changed." diff --git a/AGENTS.md b/AGENTS.md index 65cf54f..23daf81 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -471,3 +471,61 @@ A: 建议: 3. 执行相关模块的 ut 测试,确保可以正确执行 4. 进行修改内容的总结汇报 5. 根据汇报内容进行检查,重新检查底层 SDK 和 AgentRun SDK 的定义 + +## 凭证注入与 STS 静默刷新约定(强制) + +STS 临时凭证(ak/sk/security_token)会过期。部署在函数计算(FC)时,最新轮转 +后的 STS 通过**每次请求的 HTTP 头**下发,而非进程级环境变量。为让所有 client 在 +凭证过期后静默刷新,本仓库采用统一机制: + +1. **请求级 overlay**:`agentrun/server/sts_middleware.py` 解析 FC 头 + (默认 `x-fc-access-key-id` / `x-fc-access-key-secret` / `x-fc-security-token`, + 可经构造参数或 `AGENTRUN_STS_HEADER_*` 环境变量覆盖),写入 + `agentrun/utils/credential_context.py` 的 `contextvars` overlay。中间件本身 + 只是 `use_sts_from_headers` 的薄封装(加 FC 门控),二者共用同一套解析逻辑。 + + **非 agentrun server 场景**(自有 FastAPI / Flask / Django、或非 HTTP 任务): + 中间件不会运行,需用户手动注入。SDK 顶层导出两个上下文管理器: + - `agentrun.use_sts_credentials(ak, sk, sts)` —— 显式传值; + - `agentrun.use_sts_from_headers(headers)` —— 从任意请求头映射解析(同 `x-fc-*`)。 + + ```python + from agentrun import use_sts_from_headers + with use_sts_from_headers(request.headers): + ... # 块内所有 SDK 调用使用最新 STS,退出自动复位 + ``` + +2. **Config 懒解析**:`Config` 的三个凭证 getter 按 + **显式传入 > 请求级 overlay(仅当三者均未显式传入)> 环境变量** 解析。 + 切勿在 `Config.__init__` 里把凭证快照成固定字符串。 + +3. **client 一律用 credential provider,禁止传静态 ak/sk/sts**: + + - **alibabacloud OpenAPI**(控制面 / Bailian / GPDB / Devs 等): + 构造 `open_api_util_models.Config` 时传 + `credential=build_openapi_credential(cfg)` + (见 `agentrun/utils/credential_providers.py`),**不要**再传 + `access_key_id` / `access_key_secret` / `security_token`。 + 注意 tea_openapi 的优先级是「静态 ak/sk 优先于 credential」,传了静态值 + 会让 provider 失效。 + + - **TableStore `OTSClient` / `AsyncOTSClient`**: + 构造时传 `credentials_provider=build_ots_credentials_provider(cfg)` + (见 `agentrun/conversation_service/utils.py`),**不要**再传 + `access_key_id` / `access_key_secret` / `sts_token`。 + + 原因:直接传静态凭证会在 client 构造时把凭证冻结,长生命周期 client(如 + server 启动时仅创建一次的 OTSClient)在 STS 过期后所有请求都会失败。 + provider 会在**每次请求**被底层 SDK 调用,从而拿到最新 STS。 + +4. **数据面手写签名**(`agentrun/utils/data_api.py` 的 RAM 签名)无需 provider: + 它本就每次请求调用 `cfg.get_*()`,已随 Config 懒解析自动刷新。 + +5. **自定义 httpx 签名器**(如 `_AgentrunRamAuth`,用于 MCP SSE / OpenAPI 工具) + 必须**持有 `Config`**、在 `auth_flow` 内调用 `cfg.get_*()` 取证, + **不要**在 `__init__` 把 ak/sk/sts 快照成字段。否则长连接(SSE 一次建连、 + 多请求复用)会冻结建连时的 STS。 + +新增任何与阿里云 / TableStore 交互的 client 时,必须遵循第 3 条;新增单测应覆盖 +「overlay 生效」与「显式凭证不被 overlay 覆盖」两种情况 +(参考 `tests/unittests/test_sts_refresh.py`)。 diff --git a/agentrun/__init__.py b/agentrun/__init__.py index 59aa57c..6f9b622 100644 --- a/agentrun/__init__.py +++ b/agentrun/__init__.py @@ -135,6 +135,11 @@ # ToolSet from agentrun.toolset import ToolSet, ToolSetClient from agentrun.utils.config import Config +from agentrun.utils.credential_context import ( + StsCredential, + use_sts_credentials, + use_sts_from_headers, +) from agentrun.utils.exception import ( ResourceAlreadyExistError, ResourceNotExistError, @@ -335,6 +340,10 @@ "ResourceNotExistError", "ResourceAlreadyExistError", "Config", + ######## STS 凭证刷新(非 server 场景手动注入) ######## + "StsCredential", + "use_sts_credentials", + "use_sts_from_headers", ] # Memory Collection 模块的所有导出(延迟加载) diff --git a/agentrun/conversation_service/__session_store_async_template.py b/agentrun/conversation_service/__session_store_async_template.py index 667c38c..83a6248 100644 --- a/agentrun/conversation_service/__session_store_async_template.py +++ b/agentrun/conversation_service/__session_store_async_template.py @@ -880,7 +880,7 @@ async def from_memory_collection_async( "vector_store_config.instance_name 为空。" ) - # 3. 获取凭证 + # 3. 校验凭证存在(fail-fast;运行时由 CredentialsProvider 动态取证) effective_config = config if isinstance(config, Config) else Config() access_key_id = effective_config.get_access_key_id() access_key_secret = effective_config.get_access_key_secret() @@ -891,17 +891,13 @@ async def from_memory_collection_async( "AGENTRUN_ACCESS_KEY_ID / AGENTRUN_ACCESS_KEY_SECRET。" ) - security_token = effective_config.get_security_token() - sts_token = security_token if security_token else None - # 4. 构建 OTSClient + AsyncOTSClient 和 OTSBackend # 使用 utils.build_ots_clients 避免 codegen 替换 AsyncOTSClient + # 传入 config:OTS 经 CredentialsProvider 每次请求动态取最新 STS。 ots_client, async_ots_client = build_ots_clients( endpoint, - access_key_id, - access_key_secret, instance_name, - sts_token=sts_token, + config=effective_config, ) backend = OTSBackend( diff --git a/agentrun/conversation_service/session_store.py b/agentrun/conversation_service/session_store.py index 2360c02..6cd3d8c 100644 --- a/agentrun/conversation_service/session_store.py +++ b/agentrun/conversation_service/session_store.py @@ -1645,7 +1645,7 @@ async def from_memory_collection_async( "vector_store_config.instance_name 为空。" ) - # 3. 获取凭证 + # 3. 校验凭证存在(fail-fast;运行时由 CredentialsProvider 动态取证) effective_config = config if isinstance(config, Config) else Config() access_key_id = effective_config.get_access_key_id() access_key_secret = effective_config.get_access_key_secret() @@ -1656,17 +1656,13 @@ async def from_memory_collection_async( "AGENTRUN_ACCESS_KEY_ID / AGENTRUN_ACCESS_KEY_SECRET。" ) - security_token = effective_config.get_security_token() - sts_token = security_token if security_token else None - # 4. 构建 OTSClient + AsyncOTSClient 和 OTSBackend # 使用 utils.build_ots_clients 避免 codegen 替换 AsyncOTSClient + # 传入 config:OTS 经 CredentialsProvider 每次请求动态取最新 STS。 ots_client, async_ots_client = build_ots_clients( endpoint, - access_key_id, - access_key_secret, instance_name, - sts_token=sts_token, + config=effective_config, ) backend = OTSBackend( @@ -1749,7 +1745,7 @@ def from_memory_collection( "vector_store_config.instance_name 为空。" ) - # 3. 获取凭证 + # 3. 校验凭证存在(fail-fast;运行时由 CredentialsProvider 动态取证) effective_config = config if isinstance(config, Config) else Config() access_key_id = effective_config.get_access_key_id() access_key_secret = effective_config.get_access_key_secret() @@ -1760,17 +1756,13 @@ def from_memory_collection( "AGENTRUN_ACCESS_KEY_ID / AGENTRUN_ACCESS_KEY_SECRET。" ) - security_token = effective_config.get_security_token() - sts_token = security_token if security_token else None - # 4. 构建 OTSClient + OTSClient 和 OTSBackend # 使用 utils.build_ots_clients 避免 codegen 替换 OTSClient + # 传入 config:OTS 经 CredentialsProvider 每次请求动态取最新 STS。 ots_client, async_ots_client = build_ots_clients( endpoint, - access_key_id, - access_key_secret, instance_name, - sts_token=sts_token, + config=effective_config, ) backend = OTSBackend( diff --git a/agentrun/conversation_service/utils.py b/agentrun/conversation_service/utils.py index 89fdab7..4f9b732 100644 --- a/agentrun/conversation_service/utils.py +++ b/agentrun/conversation_service/utils.py @@ -12,6 +12,9 @@ if TYPE_CHECKING: from tablestore import AsyncOTSClient # type: ignore[import-untyped] from tablestore import OTSClient + from tablestore.credentials import CredentialsProvider + + from agentrun.utils.config import Config # OTS 单个属性列值上限为 2MB,留 0.5MB 余量(按字符数计) MAX_COLUMN_SIZE: int = 1_500_000 # 1.5M 字符 @@ -103,38 +106,73 @@ def from_chunks(chunks: list[str]) -> str: return "".join(chunks) +def build_ots_credentials_provider(config: "Config") -> "CredentialsProvider": + """构建 TableStore CredentialsProvider,每次请求从 Config 实时取最新 STS。 + + TableStore client 在**每个请求**调用 ``credentials_provider.get_credentials()`` + (见 tablestore client 的 ``_request_helper``),因此长生命周期的 OTSClient + 也能在每次操作时拿到请求级 overlay 注入的最新 STS(再回退环境变量)。 + + Args: + config: agentrun Config 对象,凭证经其 getter 解析(overlay 优先)。 + + Returns: + TableStore ``CredentialsProvider`` 实例。 + """ + from tablestore.credentials import Credentials, CredentialsProvider + + class _AgentrunOtsCredentialsProvider(CredentialsProvider): + def __init__(self, cfg: "Config") -> None: + self._cfg = cfg + + def get_credentials(self) -> Credentials: + cfg = self._cfg + return Credentials( + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + security_token=cfg.get_security_token() or None, + ) + + return _AgentrunOtsCredentialsProvider(config) + + def build_ots_clients( endpoint: str, - access_key_id: str, - access_key_secret: str, instance_name: str, *, - sts_token: str | None = None, + config: "Config", ) -> tuple[OTSClient, AsyncOTSClient]: """构建 OTSClient 和 AsyncOTSClient 实例。 独立于 codegen 模板,避免 AsyncOTSClient 被替换为 OTSClient。 + 凭证统一通过 :func:`build_ots_credentials_provider` 注入:client 每次请求 + 动态从 ``config`` 取最新 STS(请求级 overlay 优先),STS 过期可静默刷新。 + 遵循 AGENTS.md 约定——不接受静态 ak/sk/sts。 + + Args: + endpoint: OTS endpoint。 + instance_name: OTS 实例名。 + config: agentrun Config 对象,凭证经其 getter 动态解析。 + Returns: (ots_client, async_ots_client) 二元组。 """ from tablestore import AsyncOTSClient # type: ignore[import-untyped] from tablestore import OTSClient, WriteRetryPolicy + # 同一个 provider 可被 sync / async client 共享:无状态,按请求读 overlay。 + provider = build_ots_credentials_provider(config) ots_client = OTSClient( endpoint, - access_key_id, - access_key_secret, - instance_name, - sts_token=sts_token, + instance_name=instance_name, + credentials_provider=provider, retry_policy=WriteRetryPolicy(), ) async_ots_client = AsyncOTSClient( endpoint, - access_key_id, - access_key_secret, - instance_name, - sts_token=sts_token, + instance_name=instance_name, + credentials_provider=provider, retry_policy=WriteRetryPolicy(), ) return ots_client, async_ots_client diff --git a/agentrun/memory_collection/memory_conversation.py b/agentrun/memory_collection/memory_conversation.py index 7ee16ee..3f3ac30 100644 --- a/agentrun/memory_collection/memory_conversation.py +++ b/agentrun/memory_collection/memory_conversation.py @@ -189,19 +189,18 @@ async def _init_memory_store(self): ots_config = await self._get_ots_config_from_memory_collection() # 创建 AsyncOTSClient - # 支持使用 STS 临时凭证访问 TableStore - client_kwargs = { - "end_point": ots_config["endpoint"], - "access_key_id": ots_config["access_key_id"], - "access_key_secret": ots_config["access_key_secret"], - "instance_name": ots_config["instance_name"], - } - - # 如果提供了 security_token,则添加到参数中(支持 STS 临时凭证) - if ots_config.get("security_token"): - client_kwargs["sts_token"] = ots_config["security_token"] + # 通过 CredentialsProvider 注入凭证:长生命周期的 client(server 启动时 + # 仅创建一次)也能在每次请求动态取到请求级 overlay 注入的最新 STS + # (再回退环境变量),无需重建连接。 + from agentrun.conversation_service.utils import ( + build_ots_credentials_provider, + ) - self._ots_client = tablestore.AsyncOTSClient(**client_kwargs) + self._ots_client = tablestore.AsyncOTSClient( + end_point=ots_config["endpoint"], + instance_name=ots_config["instance_name"], + credentials_provider=build_ots_credentials_provider(self.config), + ) # 配置会话表的二级索引元数据字段 # agent_id 字段用于标识会话所属的 Agent @@ -259,10 +258,10 @@ async def _get_ots_config_from_memory_collection(self) -> Dict[str, Any]: Returns: Dict[str, Any]: OTS 配置字典,包含: - endpoint: OTS endpoint - - access_key_id: 访问密钥 ID - - access_key_secret: 访问密钥 Secret - - security_token: STS 安全令牌(可选,用于临时凭证) - instance_name: OTS 实例名称 + + 凭证(ak/sk/sts)不在此返回:OTSClient 通过 CredentialsProvider + 每次请求动态从 Config 取最新 STS,无需在此快照。 """ from agentrun.memory_collection import MemoryCollection @@ -307,15 +306,10 @@ async def _get_ots_config_from_memory_collection(self) -> Dict[str, Any]: f" {original_endpoint} -> {endpoint}" ) - # 构建 OTS 配置 + # 构建 OTS 配置(仅连接信息;凭证由 CredentialsProvider 动态注入) ots_config = { "endpoint": endpoint, "instance_name": vs_config.instance_name or "", - "access_key_id": self.config.get_access_key_id(), - "access_key_secret": self.config.get_access_key_secret(), - "security_token": ( - self.config.get_security_token() - ), # 支持 STS 临时凭证 } return ots_config diff --git a/agentrun/server/invoker.py b/agentrun/server/invoker.py index d08b65b..763e6a0 100644 --- a/agentrun/server/invoker.py +++ b/agentrun/server/invoker.py @@ -10,6 +10,7 @@ """ import asyncio +import contextvars import inspect from typing import ( Any, @@ -201,7 +202,13 @@ async def _call_handler(self, request: AgentRequest) -> Any: else: sync_handler = cast(SyncInvokeAgentHandler, self.invoke_agent) loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, sync_handler, request) + # 拷贝当前 contextvars(含请求级 STS overlay)进 executor 线程: + # loop.run_in_executor 默认不传播调用方 context,否则同步 handler + # 内的凭证 getter 读不到请求注入的最新 STS(会回退到陈旧 env)。 + ctx = contextvars.copy_context() + result = await loop.run_in_executor( + None, lambda: ctx.run(sync_handler, request) + ) return result @@ -299,6 +306,9 @@ async def _iterate_async( else: loop = asyncio.get_running_loop() iterator = iter(content) + # 拷贝当前 contextvars(含请求级 STS overlay)进 executor 线程, + # 使同步生成器每次 next() 内的凭证 getter 都能取到最新 STS。 + ctx = contextvars.copy_context() _STOP = object() @@ -309,7 +319,9 @@ def _safe_next() -> Any: return _STOP while True: - chunk = await loop.run_in_executor(None, _safe_next) + chunk = await loop.run_in_executor( + None, lambda: ctx.run(_safe_next) + ) if chunk is _STOP: break yield chunk diff --git a/agentrun/server/server.py b/agentrun/server/server.py index 5799349..8a5f4d7 100644 --- a/agentrun/server/server.py +++ b/agentrun/server/server.py @@ -18,6 +18,7 @@ from .model import ServerConfig from .openai_protocol import OpenAIProtocolHandler from .protocol import InvokeAgentHandler, ProtocolHandler +from .sts_middleware import StsRefreshMiddleware class AgentRunServer: @@ -123,6 +124,12 @@ def __init__( """ self.app = FastAPI(title="AgentRun Server") + # 注入 STS 刷新中间件:从每次请求的 x-fc-* 头解析最新 STS 临时凭证,写入 + # 请求级 overlay,使本次请求内所有 Config/client 静默使用最新凭证。 + # 默认启用;未携带相关头时不产生任何副作用。如需关闭设环境变量 + # AGENTRUN_STS_REFRESH_ENABLED=false。 + self.app.add_middleware(StsRefreshMiddleware) + # 如果启用了 memory,包装 invoke_agent if memory_collection_name: invoke_agent = self._wrap_with_memory( diff --git a/agentrun/server/sts_middleware.py b/agentrun/server/sts_middleware.py new file mode 100644 index 0000000..0f133c2 --- /dev/null +++ b/agentrun/server/sts_middleware.py @@ -0,0 +1,99 @@ +"""STS 刷新中间件 / STS refresh middleware. + +部署在函数计算(FC)时,每次请求的 HTTP 头会携带最新轮转的 STS 临时凭证。 +此中间件在请求进入时解析这些头,写入请求级 overlay +(:mod:`agentrun.utils.credential_context`),使本次请求内所有 ``Config`` / +client 的取证都拿到最新 STS;请求结束时复位。 + +为何用纯 ASGI 中间件 / Why a plain ASGI middleware: + 本中间件只做一件事——设置 / 复位一个 contextvar,故用**纯 ASGI** 实现 + (``__call__`` 包裹 ``await self.app(scope, receive, send)``),而非 + ``BaseHTTPMiddleware``。优点:overlay 与请求**同任务、同生命周期**——对 + endpoint、``StreamingResponse`` 的 body、``run_in_threadpool`` 的同步处理器、 + 以及**响应后的 background task** 全程可见,并在 app 完全结束后于 ``finally`` + 复位;同时避免 ``BaseHTTPMiddleware`` 的额外 task/stream 包装及其在流式 / + 断连 / 异常传播上的已知坑。 + +注入时机与有效期 / Injection lifetime: + STS 在请求入口注入一份、整条请求固定不变(头只到达一次)。流式响应全程使用 + 这份入口 STS;仅当**单条请求持续时间超过 STS 有效期**时才会中途过期——属按 + 请求头注入模型的固有上限,正常请求 / 流远短于有效期,不受影响。 + +头名可配置 / Configurable header names: + 构造参数 > 环境变量 > 默认值(``x-fc-*``)。头名大小写不敏感。 + +启用开关 / Enable switch: + 中间件**默认启用**。仅在特定情况下关闭:构造参数 ``enabled=False``,或 + 环境变量 ``AGENTRUN_STS_REFRESH_ENABLED`` 设为假值(``0`` / ``false`` / + ``no`` / ``off``)。 + +信任边界 / Trust boundary: + overlay 仅在 ``x-fc-*`` 头**齐全**时注入(覆盖运维方 env 凭证),否则透传。 + 函数计算(FC)拥有该头命名空间并会剥离客户端伪造的同名头,FC 内安全。 + **注意**:若部署在非 FC 环境(裸 uvicorn / 自有网关)且服务可被不可信客户端 + 直达,攻击者可注入 ``x-fc-*`` 头冒用身份——此类场景请前置鉴权 / 由网关剥离 + 这些头,或按上面的开关关闭本中间件。 +""" + +from __future__ import annotations + +import os +from typing import Optional + +from starlette.datastructures import Headers +from starlette.types import ASGIApp, Receive, Scope, Send + +from agentrun.utils.credential_context import use_sts_from_headers + + +def _detect_enabled() -> bool: + """是否启用 overlay:**默认启用**,仅环境变量显式设为假值时关闭。 + + ``AGENTRUN_STS_REFRESH_ENABLED`` 未设置 -> 启用;设为 + ``0`` / ``false`` / ``no`` / ``off`` -> 关闭;其余真值 -> 启用。 + """ + flag = os.getenv("AGENTRUN_STS_REFRESH_ENABLED") + if flag is None: + return True + return flag.strip().lower() in ("1", "true", "yes", "on") + + +class StsRefreshMiddleware: + """纯 ASGI 中间件:从请求头解析最新 STS 并注入请求级 overlay。""" + + def __init__( + self, + app: ASGIApp, + *, + enabled: Optional[bool] = None, + access_key_id_header: Optional[str] = None, + access_key_secret_header: Optional[str] = None, + security_token_header: Optional[str] = None, + ) -> None: + self.app = app + # enabled=None 时按环境变量决定(默认启用, + # AGENTRUN_STS_REFRESH_ENABLED 设为假值时关闭)。 + self._enabled = _detect_enabled() if enabled is None else enabled + # 头名解析(参数 > 环境变量 > 默认)交由 sts_from_headers 处理,这里只存原值。 + self._ak_header = access_key_id_header + self._sk_header = access_key_secret_header + self._sts_header = security_token_header + + async def __call__( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + if scope["type"] != "http" or not self._enabled: + await self.app(scope, receive, send) + return + + # 复用公开上下文管理器:解析请求头 -> 注入 overlay -> app 整体跑完后复位。 + # 三元组不齐全时 use_sts_from_headers 不覆盖(透传),与手动注入完全一致。 + # 纯 ASGI:overlay 在同一任务内对 endpoint / 流式 body / 同步处理器 / + # 响应后的 background task 全程可见,``with`` 在 app 结束后才退出复位。 + with use_sts_from_headers( + Headers(scope=scope), + access_key_id_header=self._ak_header, + access_key_secret_header=self._sk_header, + security_token_header=self._sts_header, + ): + await self.app(scope, receive, send) diff --git a/agentrun/tool/api/mcp.py b/agentrun/tool/api/mcp.py index dd7702d..ab2cf8d 100644 --- a/agentrun/tool/api/mcp.py +++ b/agentrun/tool/api/mcp.py @@ -5,15 +5,13 @@ """ import asyncio -from typing import Any, Dict, Generator, List, Optional +from typing import Any, Dict, List, Optional from urllib.parse import urlparse, urlunparse -import httpx - from agentrun.tool.model import ToolInfo from agentrun.utils.config import Config from agentrun.utils.log import logger -from agentrun.utils.ram_signature import get_agentrun_signed_headers +from agentrun.utils.ram_signature.auth import AgentrunRamAuth _MCP_METADATA_TIMEOUT_SECONDS = 30.0 @@ -33,63 +31,6 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop: return loop -class _AgentrunRamAuth(httpx.Auth): - """httpx Auth handler:为每次请求动态生成 RAM 签名。 - - SSE 场景下同一个 httpx.AsyncClient 会发出 GET(SSE 连接)和 - POST(消息发送)请求,URL / method / body 各不相同,因此必须 - per-request 计算签名,不能在 client 初始化时一次性设置 headers。 - """ - - def __init__( - self, - access_key_id: str, - access_key_secret: str, - region: str, - security_token: Optional[str] = None, - ): - self._ak = access_key_id - self._sk = access_key_secret - self._region = region - self._security_token = security_token - - def auth_flow( - self, request: httpx.Request - ) -> Generator[httpx.Request, httpx.Response, None]: - url = str(request.url) - method = request.method - - body: Optional[bytes] = None - if request.content: - body = request.content - - content_type: Optional[str] = request.headers.get("content-type") - - try: - signed = get_agentrun_signed_headers( - url=url, - method=method, - access_key_id=self._ak, - access_key_secret=self._sk, - security_token=self._security_token, - region=self._region, - product="agentrun", - body=body, - content_type=content_type, - ) - for k, v in signed.items(): - request.headers[k] = v - logger.debug( - "applied RAM signature for MCP %s request to %s", - method, - url[:80] + ("..." if len(url) > 80 else ""), - ) - except ValueError as e: - logger.warning("RAM signing skipped for MCP request: %s", e) - - yield request - - def _rewrite_to_ram_url(url: str) -> str: """将 agentrun-data 域名改写为 -ram 端点。""" parsed = urlparse(url) @@ -218,12 +159,7 @@ def _build_ram_auth(self, url: str) -> tuple: url = _rewrite_to_ram_url(url) - auth = _AgentrunRamAuth( - access_key_id=ak, - access_key_secret=sk, - region=cfg.get_region_id(), - security_token=cfg.get_security_token() or None, - ) + auth = AgentrunRamAuth(config=cfg) return url, auth async def list_tools_async(self) -> List[ToolInfo]: diff --git a/agentrun/tool/api/openapi.py b/agentrun/tool/api/openapi.py index 3874fbb..c5780cb 100644 --- a/agentrun/tool/api/openapi.py +++ b/agentrun/tool/api/openapi.py @@ -8,7 +8,7 @@ from copy import deepcopy import json -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse, urlunparse import httpx @@ -16,59 +16,7 @@ from agentrun.tool.model import ToolInfo, ToolSchema from agentrun.utils.config import Config from agentrun.utils.log import logger -from agentrun.utils.ram_signature import get_agentrun_signed_headers - - -class _AgentrunRamAuth(httpx.Auth): - """httpx Auth handler:为每次请求动态生成 RAM 签名。""" - - def __init__( - self, - access_key_id: str, - access_key_secret: str, - region: str, - security_token: Optional[str] = None, - ): - self._ak = access_key_id - self._sk = access_key_secret - self._region = region - self._security_token = security_token - - def auth_flow( - self, request: httpx.Request - ) -> Generator[httpx.Request, httpx.Response, None]: - url = str(request.url) - method = request.method - - body: Optional[bytes] = None - if request.content: - body = request.content - - content_type: Optional[str] = request.headers.get("content-type") - - try: - signed = get_agentrun_signed_headers( - url=url, - method=method, - access_key_id=self._ak, - access_key_secret=self._sk, - security_token=self._security_token, - region=self._region, - product="agentrun", - body=body, - content_type=content_type, - ) - for k, v in signed.items(): - request.headers[k] = v - logger.debug( - "applied RAM signature for OpenAPI %s request to %s", - method, - url[:80] + ("..." if len(url) > 80 else ""), - ) - except ValueError as e: - logger.warning("RAM signing skipped for OpenAPI request: %s", e) - - yield request +from agentrun.utils.ram_signature.auth import AgentrunRamAuth def _rewrite_to_ram_url(url: str) -> str: @@ -558,10 +506,5 @@ def _build_ram_auth(self, url: str) -> tuple: url = _rewrite_to_ram_url(url) - auth = _AgentrunRamAuth( - access_key_id=ak, - access_key_secret=sk, - region=cfg.get_region_id(), - security_token=cfg.get_security_token() or None, - ) + auth = AgentrunRamAuth(config=cfg) return url, auth diff --git a/agentrun/toolset/api/mcp.py b/agentrun/toolset/api/mcp.py index ab0d3a1..4d84ba8 100644 --- a/agentrun/toolset/api/mcp.py +++ b/agentrun/toolset/api/mcp.py @@ -4,71 +4,12 @@ Handles tool invocations for MCP (Model Context Protocol). """ -from typing import Any, Dict, Generator, Optional +from typing import Any, Dict, Optional from urllib.parse import urlparse, urlunparse -import httpx - from agentrun.utils.config import Config from agentrun.utils.log import logger -from agentrun.utils.ram_signature import get_agentrun_signed_headers - - -class _AgentrunRamAuth(httpx.Auth): - """httpx Auth handler:为每次请求动态生成 RAM 签名。 - - SSE 场景下同一个 httpx.AsyncClient 会发出 GET(SSE 连接)和 - POST(消息发送)请求,URL / method / body 各不相同,因此必须 - per-request 计算签名,不能在 client 初始化时一次性设置 headers。 - """ - - def __init__( - self, - access_key_id: str, - access_key_secret: str, - region: str, - security_token: Optional[str] = None, - ): - self._ak = access_key_id - self._sk = access_key_secret - self._region = region - self._security_token = security_token - - def auth_flow( - self, request: httpx.Request - ) -> Generator[httpx.Request, httpx.Response, None]: - url = str(request.url) - method = request.method - - body: Optional[bytes] = None - if request.content: - body = request.content - - content_type: Optional[str] = request.headers.get("content-type") - - try: - signed = get_agentrun_signed_headers( - url=url, - method=method, - access_key_id=self._ak, - access_key_secret=self._sk, - security_token=self._security_token, - region=self._region, - product="agentrun", - body=body, - content_type=content_type, - ) - for k, v in signed.items(): - request.headers[k] = v - logger.debug( - "applied RAM signature for MCP %s request to %s", - method, - url[:80] + ("..." if len(url) > 80 else ""), - ) - except ValueError as e: - logger.warning("RAM signing skipped for MCP request: %s", e) - - yield request +from agentrun.utils.ram_signature.auth import AgentrunRamAuth def _rewrite_to_ram_url(url: str) -> str: @@ -112,12 +53,7 @@ def _build_ram_auth(self, url: str) -> tuple: url = _rewrite_to_ram_url(url) - auth = _AgentrunRamAuth( - access_key_id=ak, - access_key_secret=sk, - region=cfg.get_region_id(), - security_token=cfg.get_security_token() or None, - ) + auth = AgentrunRamAuth(config=cfg) return url, auth async def __aenter__(self): diff --git a/agentrun/utils/config.py b/agentrun/utils/config.py index 0c54abb..3a9fe4c 100644 --- a/agentrun/utils/config.py +++ b/agentrun/utils/config.py @@ -9,6 +9,8 @@ from dotenv import load_dotenv +from agentrun.utils.credential_context import get_request_sts + load_dotenv() @@ -130,20 +132,11 @@ def __init__( headers: 自定义请求头,可选 / Custom request headers, optional """ - if access_key_id is None: - access_key_id = get_env_with_default( - "", "AGENTRUN_ACCESS_KEY_ID", "ALIBABA_CLOUD_ACCESS_KEY_ID" - ) - if access_key_secret is None: - access_key_secret = get_env_with_default( - "", - "AGENTRUN_ACCESS_KEY_SECRET", - "ALIBABA_CLOUD_ACCESS_KEY_SECRET", - ) - if security_token is None: - security_token = get_env_with_default( - "", "AGENTRUN_SECURITY_TOKEN", "ALIBABA_CLOUD_SECURITY_TOKEN" - ) + # 凭证(ak/sk/sts)不在此处快照环境变量,改为在 getter 中**懒解析**: + # 显式传入 -> 请求级 overlay(仅当三者均未显式传入时)-> 环境变量。 + # 这样在 server 场景下,每次取值都能拿到 FC 头注入的最新 STS;非 server + # 场景(overlay 为空)行为与历史一致(getter 实时回退环境变量)。 + # 显式传入的字段保留为非 None,作为"用户显式凭证"的标记,overlay 不覆盖。 if account_id is None: account_id = get_env_with_default( "", "AGENTRUN_ACCOUNT_ID", "FC_ACCOUNT_ID" @@ -222,17 +215,55 @@ def __repr__(self) -> str: ]) ) + def _credentials_ambient(self) -> bool: + """三个凭证字段是否均未被显式传入。 + + 仅当全部为 ambient(None)时,才允许请求级 overlay 接管,避免把 + overlay 的新 sts 与用户显式传入的 ak/sk 混用。 + """ + return ( + self._access_key_id is None + and self._access_key_secret is None + and self._security_token is None + ) + def get_access_key_id(self) -> str: - """获取 Access Key ID""" - return self._access_key_id + """获取 Access Key ID(显式 -> 请求级 overlay -> 环境变量)""" + if self._access_key_id is not None: + return self._access_key_id + if self._credentials_ambient(): + overlay = get_request_sts() + if overlay is not None and overlay.access_key_id: + return overlay.access_key_id + return get_env_with_default( + "", "AGENTRUN_ACCESS_KEY_ID", "ALIBABA_CLOUD_ACCESS_KEY_ID" + ) def get_access_key_secret(self) -> str: - """获取 Access Key Secret""" - return self._access_key_secret + """获取 Access Key Secret(显式 -> 请求级 overlay -> 环境变量)""" + if self._access_key_secret is not None: + return self._access_key_secret + if self._credentials_ambient(): + overlay = get_request_sts() + if overlay is not None and overlay.access_key_secret: + return overlay.access_key_secret + return get_env_with_default( + "", + "AGENTRUN_ACCESS_KEY_SECRET", + "ALIBABA_CLOUD_ACCESS_KEY_SECRET", + ) def get_security_token(self) -> str: - """获取安全令牌""" - return self._security_token + """获取安全令牌(显式 -> 请求级 overlay -> 环境变量)""" + if self._security_token is not None: + return self._security_token + if self._credentials_ambient(): + overlay = get_request_sts() + if overlay is not None and overlay.security_token: + return overlay.security_token + return get_env_with_default( + "", "AGENTRUN_SECURITY_TOKEN", "ALIBABA_CLOUD_SECURITY_TOKEN" + ) def get_account_id(self) -> str: """获取账号 ID""" diff --git a/agentrun/utils/control_api.py b/agentrun/utils/control_api.py index 3e6728a..e13e395 100644 --- a/agentrun/utils/control_api.py +++ b/agentrun/utils/control_api.py @@ -11,6 +11,7 @@ from alibabacloud_tea_openapi import utils_models as open_api_util_models from agentrun.utils.config import Config +from agentrun.utils.credential_providers import build_openapi_credential # 延迟导入:BailianClient 和 GPDBClient 仅在 knowledgebase 模块使用, # 不在顶层导入以减少非 knowledgebase 场景的依赖加载。 @@ -52,9 +53,7 @@ def _get_client(self, config: Optional[Config] = None) -> "AgentRunClient": endpoint = endpoint.split("://", 1)[1] return AgentRunClient( open_api_util_models.Config( - access_key_id=cfg.get_access_key_id(), - access_key_secret=cfg.get_access_key_secret(), - security_token=cfg.get_security_token(), + credential=build_openapi_credential(cfg), region_id=cfg.get_region_id(), endpoint=endpoint, connect_timeout=cfg.get_timeout(), # type: ignore @@ -75,9 +74,7 @@ def _get_devs_client(self, config: Optional[Config] = None) -> "DevsClient": endpoint = endpoint.split("://", 1)[1] return DevsClient( open_api_util_models.Config( - access_key_id=cfg.get_access_key_id(), - access_key_secret=cfg.get_access_key_secret(), - security_token=cfg.get_security_token(), + credential=build_openapi_credential(cfg), region_id=cfg.get_region_id(), endpoint=endpoint, connect_timeout=cfg.get_timeout(), # type: ignore @@ -102,9 +99,7 @@ def _get_bailian_client( endpoint = endpoint.split("://", 1)[1] return BailianClient( open_api_util_models.Config( - access_key_id=cfg.get_access_key_id(), - access_key_secret=cfg.get_access_key_secret(), - security_token=cfg.get_security_token(), + credential=build_openapi_credential(cfg), region_id=cfg.get_region_id(), endpoint=endpoint, connect_timeout=cfg.get_timeout(), # type: ignore @@ -130,9 +125,7 @@ def _get_gpdb_client(self, config: Optional[Config] = None) -> "GPDBClient": return GPDBClient( open_api_util_models.Config( - access_key_id=cfg.get_access_key_id(), - access_key_secret=cfg.get_access_key_secret(), - security_token=cfg.get_security_token(), + credential=build_openapi_credential(cfg), region_id=cfg.get_region_id(), endpoint=endpoint, connect_timeout=cfg.get_timeout(), # type: ignore diff --git a/agentrun/utils/credential_context.py b/agentrun/utils/credential_context.py new file mode 100644 index 0000000..e016699 --- /dev/null +++ b/agentrun/utils/credential_context.py @@ -0,0 +1,193 @@ +"""请求级凭证上下文 / Per-request credential overlay. + +此模块提供一个进程级、按请求隔离的"最新 STS 凭证"覆盖层(overlay)。 + +背景 / Background: + 所有凭证(ak/sk/sts)默认来自环境变量,在 ``Config`` 构造时被读取。但 STS + 临时凭证会过期;部署在函数计算(FC)时,最新轮转后的 STS 通过**每次请求的 + HTTP 头**下发,而非进程级环境变量。因此需要一个按请求设置、所有 ``Config``/ + client 都能优先读取的"当前凭证"覆盖层。 + + The overlay is backed by a :class:`contextvars.ContextVar`, so it is: + + - **任务隔离 / task-isolated**: 并发请求各自拥有独立的副本,互不串号; + - **线程安全 / thread-safe**: ``run_in_threadpool`` 启动的同步处理器会拷贝 + 当前 context,因此也能读到; + - **流式安全 / streaming-safe**: ``StreamingResponse`` 的 body 生成器在请求 + 协程的 context 中创建,整条 SSE 流可见同一份凭证。 + + 默认值为 ``None`` —— 未设置时(非 server 场景、本地调用)overlay 完全不参与, + 行为与历史完全一致。 +""" + +from __future__ import annotations + +import contextlib +import contextvars +import os +from dataclasses import dataclass +from typing import Iterator, Mapping, Optional + +# FC 注入 STS 的默认头名(可经构造参数或环境变量覆盖)。 +DEFAULT_ACCESS_KEY_ID_HEADER = "x-fc-access-key-id" +DEFAULT_ACCESS_KEY_SECRET_HEADER = "x-fc-access-key-secret" +DEFAULT_SECURITY_TOKEN_HEADER = "x-fc-security-token" + + +@dataclass(frozen=True) +class StsCredential: + """一组完整的 STS 临时凭证 / An atomic STS credential triple. + + STS 轮转时 ak/sk/sts 三者一起更新,必须作为整体提供,绝不能把新 sts 与旧 + ak/sk 混用。某字段为 ``None`` 表示该来源未提供,调用方应回退到下一来源。 + """ + + access_key_id: Optional[str] = None + access_key_secret: Optional[str] = None + security_token: Optional[str] = None + + def is_complete(self) -> bool: + """三个字段是否齐全。 + + STS 轮转时 ak/sk/sts 同时更新,必须作为完整三元组才可作为 overlay, + 避免把新 sts 与陈旧/环境变量里的 ak/sk 混用。 + """ + return bool( + self.access_key_id + and self.access_key_secret + and self.security_token + ) + + +# 默认 None:未在 server 场景注入时 overlay 不参与,getter 回退到 env 快照。 +_current_sts: contextvars.ContextVar[Optional[StsCredential]] = ( + contextvars.ContextVar("agentrun_current_sts", default=None) +) + + +def set_request_sts(cred: Optional[StsCredential]) -> contextvars.Token: + """设置当前请求的 STS 覆盖层,返回用于复位的 token。 + + Args: + cred: 本次请求的最新 STS 三元组;传 ``None`` 表示清除覆盖。 + + Returns: + contextvars.Token: 传给 :func:`reset_request_sts` 以恢复上一状态。 + """ + return _current_sts.set(cred) + + +def reset_request_sts(token: contextvars.Token) -> None: + """恢复 :func:`set_request_sts` 之前的覆盖状态。""" + _current_sts.reset(token) + + +def get_request_sts() -> Optional[StsCredential]: + """获取当前请求的 STS 覆盖层;未设置时返回 ``None``。""" + return _current_sts.get() + + +def _resolve_header_name( + explicit: Optional[str], env_key: str, default: str +) -> str: + """解析头名:构造参数 > 环境变量 > 默认值;统一转小写。""" + return (explicit or os.getenv(env_key) or default).lower() + + +def sts_from_headers( + headers: Mapping[str, str], + *, + access_key_id_header: Optional[str] = None, + access_key_secret_header: Optional[str] = None, + security_token_header: Optional[str] = None, +) -> Optional[StsCredential]: + """从请求头映射解析 STS 三元组;不齐全则返回 ``None``。 + + 仅当 ak/sk/sts 三者齐全才视为有效刷新(避免把新 sts 与陈旧/环境变量里的 + ak/sk 混用)。``headers`` 可为任意 Mapping(如 ``dict`` 或 Starlette + ``Headers``),按头名**大小写不敏感**查找。头名优先级:参数 > 环境变量 + (``AGENTRUN_STS_HEADER_*``)> 默认(``x-fc-*``)。 + """ + ak_name = _resolve_header_name( + access_key_id_header, + "AGENTRUN_STS_HEADER_ACCESS_KEY_ID", + DEFAULT_ACCESS_KEY_ID_HEADER, + ) + sk_name = _resolve_header_name( + access_key_secret_header, + "AGENTRUN_STS_HEADER_ACCESS_KEY_SECRET", + DEFAULT_ACCESS_KEY_SECRET_HEADER, + ) + sts_name = _resolve_header_name( + security_token_header, + "AGENTRUN_STS_HEADER_SECURITY_TOKEN", + DEFAULT_SECURITY_TOKEN_HEADER, + ) + + lower = {str(k).lower(): v for k, v in headers.items()} + cred = StsCredential( + access_key_id=lower.get(ak_name), + access_key_secret=lower.get(sk_name), + security_token=lower.get(sts_name), + ) + return cred if cred.is_complete() else None + + +@contextlib.contextmanager +def use_sts_credentials( + access_key_id: Optional[str] = None, + access_key_secret: Optional[str] = None, + security_token: Optional[str] = None, +) -> Iterator[StsCredential]: + """在 ``with`` 块内临时使用给定 STS 临时凭证(请求级 overlay),退出自动复位。 + + 适用于**不经过 agentrun server** 的场景:自有 FastAPI / Flask / Django,或 + 非 HTTP 的任务里,从上游 / 请求头拿到最新 STS 后注入——块内所有 SDK 调用 + (以及其内创建的 asyncio 任务)即使用这组凭证。 + + Examples: + >>> with use_sts_credentials(ak, sk, sts): + ... knowledgebase.retrieve(...) # 使用最新 STS + + Note: + 基于 ``contextvars``,按当前任务/线程隔离;用户自行 ``threading.Thread`` + 起的裸线程不会继承(``asyncio.create_task`` 会)。 + """ + cred = StsCredential(access_key_id, access_key_secret, security_token) + token = set_request_sts(cred) + try: + yield cred + finally: + reset_request_sts(token) + + +@contextlib.contextmanager +def use_sts_from_headers( + headers: Mapping[str, str], + *, + access_key_id_header: Optional[str] = None, + access_key_secret_header: Optional[str] = None, + security_token_header: Optional[str] = None, +) -> Iterator[Optional[StsCredential]]: + """从请求头映射解析 STS 并在 ``with`` 块内生效;三元组不齐全则不覆盖(透传)。 + + 与 :class:`agentrun.server.sts_middleware.StsRefreshMiddleware` 共用同一套 + 解析逻辑。适用于在自有 Web 框架里手动接入: + + >>> with use_sts_from_headers(request.headers): + ... await invoke_agent(request) + """ + cred = sts_from_headers( + headers, + access_key_id_header=access_key_id_header, + access_key_secret_header=access_key_secret_header, + security_token_header=security_token_header, + ) + if cred is None: + yield None + return + token = set_request_sts(cred) + try: + yield cred + finally: + reset_request_sts(token) diff --git a/agentrun/utils/credential_providers.py b/agentrun/utils/credential_providers.py new file mode 100644 index 0000000..8eae1d7 --- /dev/null +++ b/agentrun/utils/credential_providers.py @@ -0,0 +1,74 @@ +"""动态凭证 Provider / Dynamic credential providers. + +把 :class:`agentrun.utils.config.Config` 接入阿里云 SDK 的"每次请求动态取证" +机制,使长生命周期 client 也能在每次请求时拿到最新 STS(请求级 overlay 优先, +再回退环境变量)。 + +- :class:`OpenApiCredentialsProvider` 适配 ``alibabacloud_credentials`` 的 + ``ICredentialsProvider``,供 ``alibabacloud_tea_openapi`` 控制面 client 使用。 +- TableStore 的 ``CredentialsProvider`` 适配见 + :func:`agentrun.conversation_service.utils.build_ots_credentials_provider` + (延迟导入 ``tablestore``,避免在非会话场景引入该可选依赖)。 + +约定 / Convention: + 所有阿里云 / TableStore client 一律通过 provider 注入凭证,**不要**再传静态 + ak/sk/sts —— 否则凭证会在 client 构造时被冻结,STS 过期后请求全部失败。 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from alibabacloud_credentials.models import CredentialModel +from alibabacloud_credentials_api import ICredentialsProvider + +if TYPE_CHECKING: + from agentrun.utils.config import Config + +PROVIDER_NAME = "agentrun_context" + + +class OpenApiCredentialsProvider(ICredentialsProvider): + """从 ``Config`` 实时解析凭证的 alibabacloud OpenAPI provider。 + + ``alibabacloud_tea_openapi`` 的 client 在**每个请求方法内部**调用 + ``credential.get_credential()``,因此这里返回的凭证总是当前最新值(请求级 + overlay 优先,再回退环境变量)。 + """ + + def __init__(self, config: "Config"): + self._config = config + + def get_provider_name(self) -> str: + return PROVIDER_NAME + + def get_credentials(self) -> CredentialModel: + cfg = self._config + security_token = cfg.get_security_token() or None + return CredentialModel( + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + security_token=security_token, + # 语义化取值,仅在直接使用本 provider 时有意义。注意:经 + # build_openapi_credential 包进 alibabacloud Client 后, + # client.get_credential().type 实际报告为 provider_name + # ("agentrun_context")——任意非 bearer/id_token 的 type 都会进入 + # AK/STS 签名分支并自动附带 security_token,故此处取值不影响签名。 + type="sts" if security_token else "access_key", + provider_name=PROVIDER_NAME, + ) + + async def get_credentials_async(self) -> CredentialModel: + return self.get_credentials() + + +def build_openapi_credential(config: "Config"): + """构造可直接传给 ``open_api_util_models.Config(credential=...)`` 的凭证 client。 + + Returns: + ``alibabacloud_credentials.client.Client`` 实例,内部包裹 + :class:`OpenApiCredentialsProvider`,每次请求动态取证。 + """ + from alibabacloud_credentials.client import Client as CredentialsClient + + return CredentialsClient(provider=OpenApiCredentialsProvider(config)) diff --git a/agentrun/utils/ram_signature/auth.py b/agentrun/utils/ram_signature/auth.py new file mode 100644 index 0000000..ea4573c --- /dev/null +++ b/agentrun/utils/ram_signature/auth.py @@ -0,0 +1,65 @@ +"""共享的 httpx RAM 签名 Auth handler / Shared httpx RAM-signing auth handler.""" + +from __future__ import annotations + +from typing import Generator, Optional, TYPE_CHECKING + +import httpx + +from agentrun.utils.log import logger +from agentrun.utils.ram_signature.signer import get_agentrun_signed_headers + +if TYPE_CHECKING: + from agentrun.utils.config import Config + + +class AgentrunRamAuth(httpx.Auth): + """httpx Auth handler:为每次请求动态生成 RAM 签名。 + + SSE 场景下同一个 ``httpx.AsyncClient`` 会发出 GET(SSE 连接)和 POST + (消息发送)等不同请求,URL / method / body 各异,因此必须 per-request + 计算签名,不能在 client 初始化时一次性设置 headers。 + + 持有 ``Config`` 而非快照凭证:``auth_flow`` 每次请求实时取 ak/sk/sts, + 使长连接(一次建连、多请求复用)也能拿到请求级 overlay 注入的最新 STS。 + """ + + def __init__(self, config: "Config"): + self._config = config + + def auth_flow( + self, request: httpx.Request + ) -> Generator[httpx.Request, httpx.Response, None]: + url = str(request.url) + method = request.method + + body: Optional[bytes] = None + if request.content: + body = request.content + + content_type: Optional[str] = request.headers.get("content-type") + + cfg = self._config + try: + signed = get_agentrun_signed_headers( + url=url, + method=method, + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + security_token=cfg.get_security_token() or None, + region=cfg.get_region_id(), + product="agentrun", + body=body, + content_type=content_type, + ) + for k, v in signed.items(): + request.headers[k] = v + logger.debug( + "applied RAM signature for %s request to %s", + method, + url[:80] + ("..." if len(url) > 80 else ""), + ) + except ValueError as e: + logger.warning("RAM signing skipped: %s", e) + + yield request diff --git a/tests/unittests/conversation_service/test_session_store.py b/tests/unittests/conversation_service/test_session_store.py index 47862b9..27ccc6d 100644 --- a/tests/unittests/conversation_service/test_session_store.py +++ b/tests/unittests/conversation_service/test_session_store.py @@ -785,9 +785,15 @@ def test_with_sts_token( ): store = SessionStore.from_memory_collection("test-mc") - assert isinstance(store, SessionStore) - ots_kwargs = mock_ots_cls.call_args.kwargs - assert ots_kwargs.get("sts_token") == "sts_token" + # OTS 现通过 credentials_provider 动态注入凭证(不再传静态 + # sts_token);在 env 仍生效的上下文内解析 provider 校验 STS。 + assert isinstance(store, SessionStore) + ots_kwargs = mock_ots_cls.call_args.kwargs + provider = ots_kwargs.get("credentials_provider") + assert provider is not None + creds = provider.get_credentials() + assert creds.get_access_key_id() == "ak_id" + assert creds.get_security_token() == "sts_token" @patch("tablestore.WriteRetryPolicy") @patch("tablestore.AsyncOTSClient") diff --git a/tests/unittests/knowledgebase/api/test_data.py b/tests/unittests/knowledgebase/api/test_data.py index 6716cdc..1c1a58c 100644 --- a/tests/unittests/knowledgebase/api/test_data.py +++ b/tests/unittests/knowledgebase/api/test_data.py @@ -532,6 +532,45 @@ def test_init_minimal(self): assert api.provider_settings is None +class TestBailianDataAPINormalizeSearchFilters: + """测试 BailianDataAPI._normalize_search_filters(SearchFilters 规范化)""" + + def test_none_returns_none(self): + """None 直接返回 None。""" + assert BailianDataAPI._normalize_search_filters(None) is None + + def test_empty_list_returns_empty(self): + """空列表返回空列表。""" + assert BailianDataAPI._normalize_search_filters([]) == [] + + def test_list_and_dict_values_json_serialized(self): + """list / dict 值被 JSON 序列化为字符串,标量值转 str。""" + result = BailianDataAPI._normalize_search_filters([ + { + "tags": ["0216", "0217"], + "meta": {"k": "v"}, + "name": "kb", + "count": 5, + } + ]) + assert result == [ + { + "tags": '["0216", "0217"]', + "meta": '{"k": "v"}', + "name": "kb", + "count": "5", + } + ] + + def test_multiple_filter_items(self): + """多个 filter item 各自独立规范化。""" + result = BailianDataAPI._normalize_search_filters([ + {"a": ["x"]}, + {"b": "y"}, + ]) + assert result == [{"a": '["x"]'}, {"b": "y"}] + + class TestBailianDataAPIRetrieve: """测试 BailianDataAPI.retrieve 方法""" diff --git a/tests/unittests/test_sts_refresh.py b/tests/unittests/test_sts_refresh.py new file mode 100644 index 0000000..ee7952e --- /dev/null +++ b/tests/unittests/test_sts_refresh.py @@ -0,0 +1,543 @@ +"""STS 静默刷新机制单元测试 / Unit tests for silent STS refresh. + +覆盖: +- 请求级凭证 overlay(contextvars) +- Config getter 的解析优先级(显式 > overlay(仅 ambient) > 环境变量) +- alibabacloud OpenAPI 动态 credential provider +- TableStore 动态 credentials_provider +- 服务端 StsRefreshMiddleware 端到端(async / sync / streaming / 隔离) +""" + +from __future__ import annotations + +import pytest + +from agentrun.utils.config import Config +from agentrun.utils.credential_context import ( + StsCredential, + get_request_sts, + reset_request_sts, + set_request_sts, + use_sts_credentials, + use_sts_from_headers, +) + + +@pytest.fixture +def overlay(): + """在 with 块内设置 overlay,退出自动复位。""" + from contextlib import contextmanager + + @contextmanager + def _set(ak=None, sk=None, sts=None): + token = set_request_sts( + StsCredential( + access_key_id=ak, access_key_secret=sk, security_token=sts + ) + ) + try: + yield + finally: + reset_request_sts(token) + + return _set + + +# --------------------------------------------------------------------------- # +# overlay 基础行为 +# --------------------------------------------------------------------------- # +def test_overlay_default_is_none(): + assert get_request_sts() is None + + +def test_overlay_set_reset(): + token = set_request_sts(StsCredential("a", "b", "c")) + try: + cur = get_request_sts() + assert cur is not None and cur.access_key_id == "a" + finally: + reset_request_sts(token) + assert get_request_sts() is None + + +# --------------------------------------------------------------------------- # +# Config getter 解析优先级 +# --------------------------------------------------------------------------- # +def test_ambient_falls_back_to_env(monkeypatch): + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_SECRET", "ENV_SK") + cfg = Config() + assert cfg.get_access_key_id() == "ENV_AK" + assert cfg.get_access_key_secret() == "ENV_SK" + assert cfg.get_security_token() == "" + + +def test_ambient_prefers_overlay(overlay): + cfg = Config() + with overlay(ak="OV_AK", sk="OV_SK", sts="OV_STS"): + assert cfg.get_access_key_id() == "OV_AK" + assert cfg.get_access_key_secret() == "OV_SK" + assert cfg.get_security_token() == "OV_STS" + + +def test_explicit_not_overridden_by_overlay(overlay): + cfg = Config(access_key_id="USER_AK", access_key_secret="USER_SK") + with overlay(ak="OV_AK", sk="OV_SK", sts="OV_STS"): + assert cfg.get_access_key_id() == "USER_AK" + assert cfg.get_access_key_secret() == "USER_SK" + # 显式设置了 ak/sk -> 非 ambient -> sts 不得从 overlay 取(避免混用) + assert cfg.get_security_token() == "" + + +def test_overlay_dropped_after_reset(monkeypatch, overlay): + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + cfg = Config() + with overlay(ak="OV_AK"): + assert cfg.get_access_key_id() == "OV_AK" + assert cfg.get_access_key_id() == "ENV_AK" + + +def test_with_configs_stays_overlay_aware(overlay): + merged = Config.with_configs(Config(), Config()) + with overlay(ak="OV_AK", sts="OV_STS"): + assert merged.get_access_key_id() == "OV_AK" + assert merged.get_security_token() == "OV_STS" + + +def test_with_configs_preserves_explicit(overlay): + user = Config(access_key_id="USER_AK", access_key_secret="USER_SK") + merged = Config.with_configs(Config(), user) + with overlay(ak="OV_AK"): + assert merged.get_access_key_id() == "USER_AK" + + +# --------------------------------------------------------------------------- # +# OpenAPI provider(alibabacloud) +# --------------------------------------------------------------------------- # +def test_openapi_provider_per_request_fresh(overlay): + from agentrun.utils.credential_providers import build_openapi_credential + + cred_client = build_openapi_credential(Config()) + with overlay(ak="OV_AK", sk="OV_SK", sts="OV_STS"): + model = cred_client.get_credential() + assert model.access_key_id == "OV_AK" + assert model.access_key_secret == "OV_SK" + assert model.security_token == "OV_STS" + + +def test_openapi_provider_reflects_change_between_calls(overlay): + from agentrun.utils.credential_providers import build_openapi_credential + + cred_client = build_openapi_credential(Config()) + with overlay(ak="AK1", sk="SK1", sts="STS1"): + assert cred_client.get_credential().access_key_id == "AK1" + with overlay(ak="AK2", sk="SK2", sts="STS2"): + assert cred_client.get_credential().access_key_id == "AK2" + + +# --------------------------------------------------------------------------- # +# TableStore provider +# --------------------------------------------------------------------------- # +def test_ots_provider_per_request_fresh(overlay): + from agentrun.conversation_service.utils import ( + build_ots_credentials_provider, + ) + + provider = build_ots_credentials_provider(Config()) + with overlay(ak="OV_AK", sk="OV_SK", sts="OV_STS"): + creds = provider.get_credentials() + assert creds.get_access_key_id() == "OV_AK" + assert creds.get_access_key_secret() == "OV_SK" + assert creds.get_security_token() == "OV_STS" + + +def test_ots_provider_empty_sts_is_none(monkeypatch): + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_SECRET", "ENV_SK") + from agentrun.conversation_service.utils import ( + build_ots_credentials_provider, + ) + + creds = build_ots_credentials_provider(Config()).get_credentials() + assert creds.get_access_key_id() == "ENV_AK" + assert creds.get_security_token() is None + + +# --------------------------------------------------------------------------- # +# 服务端中间件端到端 +# --------------------------------------------------------------------------- # +def _build_app(): + from fastapi import FastAPI + from fastapi.responses import StreamingResponse + + from agentrun.server.sts_middleware import StsRefreshMiddleware + + app = FastAPI() + # 显式启用,避免受 dev 环境 AGENTRUN_STS_REFRESH_ENABLED 影响。 + app.add_middleware(StsRefreshMiddleware, enabled=True) + + @app.get("/async") + async def _async(): + return { + "ak": Config().get_access_key_id(), + "sts": Config().get_security_token(), + } + + @app.get("/sync") + def _sync(): # 在 threadpool 中执行 + return {"ak": Config().get_access_key_id()} + + @app.get("/stream") + async def _stream(): + cfg = Config() + + async def gen(): + yield ( + f"ak={cfg.get_access_key_id()};" + f"sts={cfg.get_security_token()}" + ).encode() + + return StreamingResponse(gen()) + + return app + + +_HEADERS = { + "x-fc-access-key-id": "H_AK", + "x-fc-access-key-secret": "H_SK", + "x-fc-security-token": "H_STS", +} + + +def test_middleware_async_overlay(): + from fastapi.testclient import TestClient + + client = TestClient(_build_app()) + assert client.get("/async", headers=_HEADERS).json() == { + "ak": "H_AK", + "sts": "H_STS", + } + + +def test_middleware_sync_threadpool_overlay(): + from fastapi.testclient import TestClient + + client = TestClient(_build_app()) + assert client.get("/sync", headers=_HEADERS).json() == {"ak": "H_AK"} + + +def test_middleware_streaming_overlay(): + from fastapi.testclient import TestClient + + client = TestClient(_build_app()) + assert client.get("/stream", headers=_HEADERS).text == "ak=H_AK;sts=H_STS" + + +def test_middleware_no_header_falls_back_to_env(monkeypatch): + from fastapi.testclient import TestClient + + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + client = TestClient(_build_app()) + assert client.get("/async").json() == {"ak": "ENV_AK", "sts": ""} + + +def test_middleware_resets_overlay_after_request(): + from fastapi.testclient import TestClient + + client = TestClient(_build_app()) + client.get("/async", headers=_HEADERS) + assert get_request_sts() is None + + +@pytest.mark.parametrize( + "import_path", + [ + "agentrun.toolset.api.mcp", + "agentrun.tool.api.mcp", + "agentrun.tool.api.openapi", + ], +) +def test_ram_auth_helper_signs_with_live_sts(import_path, overlay): + """长生命周期的 _AgentrunRamAuth 实例应在每次 auth_flow 读到最新 STS。""" + import importlib + + import httpx + + module = importlib.import_module(import_path) + # 构造一次(模拟长连接 SSE 会话只在打开时建一次 auth) + # 三个模块均复用共享的 AgentrunRamAuth。 + auth = module.AgentrunRamAuth(config=Config()) + + def signed_token() -> str: + request = httpx.Request( + "POST", + "https://x.agentrun-data.cn-hangzhou.aliyuncs.com/m", + content=b"{}", + ) + flow = auth.auth_flow(request) + next(flow) # 触发签名,写入 request.headers + return request.headers.get("x-acs-security-token") + + with overlay(ak="AK1", sk="SK1", sts="STS1"): + assert signed_token() == "STS1" + # 同一个 auth 实例:overlay 切换后应签出新的 sts + with overlay(ak="AK2", sk="SK2", sts="STS2"): + assert signed_token() == "STS2" + + +def test_middleware_custom_header_names(monkeypatch): + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from agentrun.server.sts_middleware import StsRefreshMiddleware + + app = FastAPI() + app.add_middleware( + StsRefreshMiddleware, + enabled=True, + access_key_id_header="x-custom-ak", + access_key_secret_header="x-custom-sk", + security_token_header="x-custom-sts", + ) + + @app.get("/c") + async def _c(): + return { + "ak": Config().get_access_key_id(), + "sts": Config().get_security_token(), + } + + client = TestClient(app) + resp = client.get( + "/c", + headers={ + "x-custom-ak": "C_AK", + "x-custom-sk": "C_SK", + "x-custom-sts": "C_STS", + }, + ).json() + assert resp == {"ak": "C_AK", "sts": "C_STS"} + + +def test_middleware_partial_headers_ignored(monkeypatch): + """H1:部分头集合(缺一)不应设置 overlay,避免与 env 混用。""" + from fastapi.testclient import TestClient + + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_SECRET", "ENV_SK") + client = TestClient(_build_app()) + # 只发 ak + sts(缺 sk)→ 不构成完整三元组 → 回退 env,不混用。 + resp = client.get( + "/async", + headers={"x-fc-access-key-id": "H_AK", "x-fc-security-token": "H_STS"}, + ).json() + assert resp == {"ak": "ENV_AK", "sts": ""} + + +def test_middleware_enabled_by_default(monkeypatch): + """默认启用(不看 FC 环境,无需任何开关)。""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from agentrun.server.sts_middleware import StsRefreshMiddleware + + monkeypatch.delenv("FC_REGION", raising=False) + monkeypatch.delenv("AGENTRUN_STS_REFRESH_ENABLED", raising=False) + + app = FastAPI() + app.add_middleware(StsRefreshMiddleware) # enabled=None → 默认启用 + + @app.get("/x") + async def _x(): + return {"ak": Config().get_access_key_id()} + + client = TestClient(app) + # 默认开:携带完整 x-fc-* 头即生效。 + assert client.get("/x", headers=_HEADERS).json() == {"ak": "H_AK"} + + +def test_middleware_disabled_via_env(monkeypatch): + """仅在 AGENTRUN_STS_REFRESH_ENABLED 设为假值时关闭。""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from agentrun.server.sts_middleware import StsRefreshMiddleware + + monkeypatch.setenv("AGENTRUN_STS_REFRESH_ENABLED", "false") + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + + app = FastAPI() + app.add_middleware(StsRefreshMiddleware) # enabled=None → 读环境变量 -> 关闭 + + @app.get("/x") + async def _x(): + return {"ak": Config().get_access_key_id()} + + client = TestClient(app) + # 已关闭:即便携带完整 x-fc-* 头也不覆盖 env 凭证。 + assert client.get("/x", headers=_HEADERS).json() == {"ak": "ENV_AK"} + + +def test_middleware_disabled_via_constructor(monkeypatch): + """构造参数 enabled=False 显式关闭。""" + from fastapi import FastAPI + from fastapi.testclient import TestClient + + from agentrun.server.sts_middleware import StsRefreshMiddleware + + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + + app = FastAPI() + app.add_middleware(StsRefreshMiddleware, enabled=False) + + @app.get("/x") + async def _x(): + return {"ak": Config().get_access_key_id()} + + client = TestClient(app) + assert client.get("/x", headers=_HEADERS).json() == {"ak": "ENV_AK"} + + +def test_invoker_sync_path_sees_overlay(): + """C2:AgentInvoker 的同步 handler + 同步生成器路径应读到请求级 overlay。 + + 生产路径为 async 端点 -> AgentInvoker.invoke_stream -> run_in_executor, + 必须拷贝 context,否则 sync 路径取不到最新 STS(回退陈旧 env)。 + """ + import asyncio + + from agentrun.server.invoker import AgentInvoker + from agentrun.server.model import AgentRequest, EventType + + def sync_handler(_request): + # 同步 handler 返回同步生成器:两段都经 run_in_executor。 + def gen(): + cfg = Config() + yield ( + f"ak={cfg.get_access_key_id()};" + f"sts={cfg.get_security_token()}" + ) + + return gen() + + invoker = AgentInvoker(sync_handler) + + async def run(): + token = set_request_sts(StsCredential("OV_AK", "OV_SK", "OV_STS")) + try: + return [ + event + async for event in invoker.invoke_stream(AgentRequest()) + ] + finally: + reset_request_sts(token) + + events = asyncio.run(run()) + text = "".join( + e.data.get("delta", "") + for e in events + if e.event == EventType.TEXT + ) + assert text == "ak=OV_AK;sts=OV_STS", text + + +# --------------------------------------------------------------------------- # +# 公开 API:非 server 场景手动注入 +# --------------------------------------------------------------------------- # +def test_use_sts_credentials_context_manager(monkeypatch): + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + cfg = Config() + assert cfg.get_access_key_id() == "ENV_AK" + with use_sts_credentials("OV_AK", "OV_SK", "OV_STS"): + assert cfg.get_access_key_id() == "OV_AK" + assert cfg.get_access_key_secret() == "OV_SK" + assert cfg.get_security_token() == "OV_STS" + # 退出自动复位 + assert cfg.get_access_key_id() == "ENV_AK" + assert get_request_sts() is None + + +def test_use_sts_from_headers_complete_case_insensitive(): + cfg = Config() + headers = { + "X-Fc-Access-Key-Id": "H_AK", # 大小写不敏感 + "x-fc-access-key-secret": "H_SK", + "X-FC-SECURITY-TOKEN": "H_STS", + } + with use_sts_from_headers(headers) as cred: + assert cred is not None + assert cfg.get_access_key_id() == "H_AK" + assert cfg.get_security_token() == "H_STS" + assert get_request_sts() is None + + +def test_use_sts_from_headers_partial_no_override(monkeypatch): + monkeypatch.setenv("AGENTRUN_ACCESS_KEY_ID", "ENV_AK") + cfg = Config() + # 只有 sts、缺 ak/sk -> 不构成完整三元组 -> 不覆盖 + with use_sts_from_headers({"x-fc-security-token": "H_STS"}) as cred: + assert cred is None + assert cfg.get_access_key_id() == "ENV_AK" + assert cfg.get_security_token() == "" + assert get_request_sts() is None + + +def test_sts_from_headers_helper(): + from agentrun.utils.credential_context import sts_from_headers + + assert sts_from_headers({"x-fc-access-key-id": "a"}) is None # 不齐全 + cred = sts_from_headers({ + "x-fc-access-key-id": "a", + "x-fc-access-key-secret": "b", + "x-fc-security-token": "c", + }) + assert cred is not None + assert ( + cred.access_key_id, + cred.access_key_secret, + cred.security_token, + ) == ("a", "b", "c") + + +def test_public_exports_available(): + import agentrun + + for name in ( + "StsCredential", + "use_sts_credentials", + "use_sts_from_headers", + ): + assert hasattr(agentrun, name), f"{name} not exported" + assert name in agentrun.__all__, f"{name} missing from __all__" + + +def test_middleware_background_task_sees_overlay(): + """纯 ASGI:响应后的 background task 仍能读到请求级 overlay。 + + BaseHTTPMiddleware 下 background task 可能在 overlay 复位后才运行;纯 ASGI + 中间件持有 overlay 直到 app(含 background)整体结束,故覆盖到位。 + """ + from fastapi import FastAPI + from fastapi.responses import JSONResponse + from fastapi.testclient import TestClient + from starlette.background import BackgroundTask + + from agentrun.server.sts_middleware import StsRefreshMiddleware + + captured: dict = {} + + def _bg(): + cfg = Config() + captured["ak"] = cfg.get_access_key_id() + captured["sts"] = cfg.get_security_token() + + app = FastAPI() + app.add_middleware(StsRefreshMiddleware, enabled=True) + + @app.get("/bg") + async def _ep(): + return JSONResponse({"ok": True}, background=BackgroundTask(_bg)) + + client = TestClient(app) + client.get("/bg", headers=_HEADERS) + assert captured == {"ak": "H_AK", "sts": "H_STS"}, captured diff --git a/tests/unittests/utils/test_config.py b/tests/unittests/utils/test_config.py index 013e4ff..029293c 100644 --- a/tests/unittests/utils/test_config.py +++ b/tests/unittests/utils/test_config.py @@ -16,8 +16,12 @@ def test_init_without_parameters(self): }, ): config = Config() - assert config._access_key_id == "mock-access-key-id" - assert config._access_key_secret == "mock-access-key-secret" + # 凭证改为懒解析:未显式传入时私有字段保持 None(ambient), + # 实际值经 getter(overlay 优先 -> 环境变量)解析。 + assert config._access_key_id is None + assert config._access_key_secret is None + assert config.get_access_key_id() == "mock-access-key-id" + assert config.get_access_key_secret() == "mock-access-key-secret" assert config._account_id == "mock-account-id" assert config._use_vpc_endpoint is False diff --git a/tests/unittests/utils/test_control_api.py b/tests/unittests/utils/test_control_api.py index 5757413..cbd2881 100644 --- a/tests/unittests/utils/test_control_api.py +++ b/tests/unittests/utils/test_control_api.py @@ -46,8 +46,10 @@ def test_get_client_basic(self, mock_client_class): assert mock_client_class.called call_args = mock_client_class.call_args config_arg = call_args[0][0] - assert config_arg.access_key_id == "ak" - assert config_arg.access_key_secret == "sk" + # 凭证经 credential provider 注入(不再传静态 ak/sk/sts)。 + creds = config_arg.credential.get_credential() + assert creds.access_key_id == "ak" + assert creds.access_key_secret == "sk" assert config_arg.region_id == "cn-hangzhou" @patch("agentrun.utils.control_api.AgentRunClient") @@ -109,7 +111,8 @@ def test_get_client_with_override_config(self, mock_client_class): call_args = mock_client_class.call_args config_arg = call_args[0][0] - assert config_arg.access_key_id == "override-ak" + creds = config_arg.credential.get_credential() + assert creds.access_key_id == "override-ak" assert config_arg.region_id == "cn-shanghai" @patch("agentrun.utils.control_api.AgentRunClient") @@ -149,7 +152,8 @@ def test_get_client_with_security_token(self, mock_client_class): call_args = mock_client_class.call_args config_arg = call_args[0][0] - assert config_arg.security_token == "sts-token" + creds = config_arg.credential.get_credential() + assert creds.security_token == "sts-token" class TestControlAPIGetDevsClient: @@ -174,8 +178,10 @@ def test_get_devs_client_basic(self, mock_client_class): assert mock_client_class.called call_args = mock_client_class.call_args config_arg = call_args[0][0] - assert config_arg.access_key_id == "ak" - assert config_arg.access_key_secret == "sk" + # 凭证经 credential provider 注入(不再传静态 ak/sk/sts)。 + creds = config_arg.credential.get_credential() + assert creds.access_key_id == "ak" + assert creds.access_key_secret == "sk" assert config_arg.region_id == "cn-hangzhou" @patch("agentrun.utils.control_api.DevsClient") @@ -235,7 +241,8 @@ def test_get_devs_client_with_override_config(self, mock_client_class): call_args = mock_client_class.call_args config_arg = call_args[0][0] - assert config_arg.access_key_id == "override-ak" + creds = config_arg.credential.get_credential() + assert creds.access_key_id == "override-ak" @patch("agentrun.utils.control_api.DevsClient") def test_get_devs_client_without_protocol_prefix(self, mock_client_class): @@ -300,8 +307,10 @@ def test_get_bailian_client_basic(self, mock_client_class): assert mock_client_class.called call_args = mock_client_class.call_args config_arg = call_args[0][0] - assert config_arg.access_key_id == "ak" - assert config_arg.access_key_secret == "sk" + # 凭证经 credential provider 注入(不再传静态 ak/sk/sts)。 + creds = config_arg.credential.get_credential() + assert creds.access_key_id == "ak" + assert creds.access_key_secret == "sk" assert config_arg.region_id == "cn-hangzhou" @patch("alibabacloud_bailian20231229.client.Client")