diff --git a/docs/mkdocs/en/model.md b/docs/mkdocs/en/model.md index e433be21..47273b59 100644 --- a/docs/mkdocs/en/model.md +++ b/docs/mkdocs/en/model.md @@ -9,6 +9,7 @@ Models in tRPC-Agent have the following core features: - **Multi-protocol support**: Provides OpenAIModel, AnthropicModel, LiteLLMModel, etc., compatible with most OpenAI-like and Anthropic interfaces both internally and externally - **Streaming response support**: Supports streaming output for real-time interactive experiences - **Multimodal capabilities**: Supports multimodal content processing including text, images, etc. (e.g., Hunyuan multimodal models) +- **Prompt Cache support**: Provides unified prompt cache configuration across OpenAI, Anthropic, and LiteLLM routes to reduce repeated input cost for long prompts and multi-turn conversations - **Extensible configuration**: Supports custom configuration options such as GenerateContentConfig, HttpOptions, client_args to meet various scenario requirements ## Quick Start @@ -398,6 +399,120 @@ LlmAgent( ) ``` +### Prompt Cache + +Prompt Cache is useful when system prompts are long, tool definitions are large, or multi-turn conversations share a stable prefix. Many providers, including OpenAI-compatible serving stacks such as `openai/sglang`, already support automatic prefix caching on the server side. `tRPC-Agent` does not replace the provider's cache implementation; instead, it provides unified management hints and normalized observability for prompt cache behavior. + +`tRPC-Agent` exposes these capabilities through `PromptCacheConfig`, which currently applies to `OpenAIModel`, `AnthropicModel`, and provider-prefixed `LiteLLMModel`. Because providers expose different cache controls and usage fields, the SDK maps management options and cache usage metrics to each provider protocol on a best-effort basis: + +| Provider | SDK Capability | Typical Usage Fields | +|----------|----------------|----------------------| +| Anthropic | Manages explicit `cache_control` breakpoints according to `breakpoints` | `cache_read_input_tokens`, `cache_creation_input_tokens` | +| OpenAI / OpenAI-compatible endpoints | Passes cache hints such as `prompt_cache_key` / `prompt_cache_retention` when supported; provider-side automatic prefix caching still owns cache creation and lookup | Usually only `cache_read_input_tokens` | +| LiteLLM | Chooses the Anthropic-style or OpenAI-style cache management path according to the `provider/model` prefix, while preserving provider-native automatic caching such as `openai/sglang` | Depends on the final provider route | + +#### Model-level configuration + +Model-level configuration becomes the default prompt-cache management and observability configuration for the model instance. Use it when all requests can share the same cache hints: + +```python +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.models import OpenAIModel + +model = OpenAIModel( + model_name="gpt-4o", + api_key="your-api-key", + prompt_cache_config=PromptCacheConfig( + enabled=True, + ttl="24h", + prompt_cache_key="weather-concierge-v1", + ), +) +``` + +#### Per-run override + +You can also override prompt-cache settings for a single `runner.run_async()` call through `RunConfig.prompt_cache`. The per-run config overrides model-level settings field by field, which is useful when setting different cache hints by user, tenant, or business scenario: + +```python +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.configs import RunConfig + +async for event in runner.run_async( + user_id=user_id, + session_id=session_id, + new_message=user_content, + run_config=RunConfig( + prompt_cache=PromptCacheConfig( + enabled=True, + prompt_cache_key="weather-concierge-user-42", + ), + ), +): + ... +``` + +#### Anthropic breakpoints + +Anthropic-style caching requires selecting breakpoint locations. `breakpoints` supports the following values: + +- `"system"`: cache the system prompt, suitable for long instructions +- `"tools"`: cache the last tool definition, suitable when tools are numerous or tool schemas are large +- `"messages"`: cache the most recent assistant message, suitable for growing stable history prefixes in multi-turn conversations + +```python +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.models import AnthropicModel + +model = AnthropicModel( + model_name="claude-3-5-sonnet-20241022", + api_key="your-api-key", + prompt_cache_config=PromptCacheConfig( + enabled=True, + ttl="1h", + breakpoints=["tools", "system", "messages"], + ), +) +``` + +A good starting point is `["tools", "system"]`; add `"messages"` when long multi-turn conversations need to cache a growing history prefix. Some Anthropic proxies or Bedrock routes require a minimum cache block size, so short prompts may not create cache entries. + +#### LiteLLM routes + +When using `LiteLLMModel`, the model name should include a `provider/model` prefix. The SDK uses that provider prefix to select the appropriate cache-management mapping. For example: + +```python +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.models import LiteLLMModel + +model = LiteLLMModel( + model_name="openai/gpt-4o", + api_key="your-api-key", + prompt_cache_config=PromptCacheConfig( + enabled=True, + prompt_cache_key="shared-prefix-v1", + ), +) +``` + +If the model name does not include a provider prefix, the SDK cannot determine which cache-management protocol to use, so SDK-managed cache hints may not take effect. + +#### Reading cache usage + +The model response's `usage_metadata` normalizes cache usage fields where possible: + +```python +async for event in runner.run_async(...): + usage = getattr(event, "usage_metadata", None) + if usage: + print(usage.cache_read_input_tokens) # Input tokens read from cache + print(usage.cache_creation_input_tokens) # Input tokens written to cache, usually only reported by Anthropic + print(usage.prompt_token_count) # Total input tokens +``` + +Different model services report different fields. OpenAI-compatible endpoints usually report cache reads but not cache writes. With load-balanced proxies, each backend instance may have its own KV cache and may not be warmed up at the same time, so cache hit rates can fluctuate during the first few runs. + +For a complete runnable example, see [examples/llmagent_with_prompt_cache](../../../examples/llmagent_with_prompt_cache/README.md). ### Custom HTTP Headers diff --git a/docs/mkdocs/zh/model.md b/docs/mkdocs/zh/model.md index bd5f97fc..dd9f904d 100644 --- a/docs/mkdocs/zh/model.md +++ b/docs/mkdocs/zh/model.md @@ -9,6 +9,7 @@ tRPC-Agent 内的模型具有以下核心特性: - **多协议支持**:提供 OpenAIModel、AnthropicModel、LiteLLMModel 等,兼容公司内外多数 OpenAI-like 及 Anthropic 接口 - **流式响应支持**:支持流式输出,实现实时交互体验 - **多模态能力**:支持文本、图像等多模态内容处理(如 hunyuan 多模态模型) +- **Prompt Cache 支持**:支持跨 OpenAI、Anthropic 与 LiteLLM 路由的统一 prompt cache 配置,降低长提示词和多轮会话的重复输入成本 - **可扩展配置**:支持 GenerateContentConfig、HttpOptions、client_args 等自定义配置项,满足不同场景需求 ## 快速上手 @@ -398,6 +399,120 @@ LlmAgent( ) ``` +### Prompt Cache + +Prompt Cache 适用于系统提示词较长、工具定义较多或多轮会话前缀高度稳定的场景。很多 provider(包括 `openai/sglang` 这类 OpenAI 兼容推理服务)本身已经支持服务端自动前缀缓存。`tRPC-Agent` 并不替代 provider 的缓存实现,而是提供统一的缓存管理提示与缓存观测能力。 + +`tRPC-Agent` 通过 `PromptCacheConfig` 暴露这些能力,目前可用于 `OpenAIModel`、`AnthropicModel` 以及带 provider 前缀的 `LiteLLMModel`。不同供应商对缓存控制和统计字段的支持不完全相同,SDK 会尽量将管理选项和缓存用量指标映射到对应协议: + +| Provider | SDK 能力 | 典型统计字段 | +|----------|----------|--------------| +| Anthropic | 根据 `breakpoints` 管理显式 `cache_control` 断点 | `cache_read_input_tokens`、`cache_creation_input_tokens` | +| OpenAI / OpenAI 兼容端点 | 在支持时传递 `prompt_cache_key` / `prompt_cache_retention` 等缓存提示;缓存创建和命中仍由 provider 侧自动前缀缓存负责 | 通常只有 `cache_read_input_tokens` | +| LiteLLM | 根据 `provider/model` 前缀选择 Anthropic 风格或 OpenAI 风格的缓存管理路径 | 取决于最终路由的 provider | + +#### 模型级配置 + +模型级配置会作为该模型实例默认的 prompt cache 管理与观测配置,适合在所有请求中复用同一套缓存提示: + +```python +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.models import OpenAIModel + +model = OpenAIModel( + model_name="gpt-4o", + api_key="your-api-key", + prompt_cache_config=PromptCacheConfig( + enabled=True, + ttl="24h", + prompt_cache_key="weather-concierge-v1", + ), +) +``` + +#### 单次运行覆盖 + +也可以通过 `RunConfig.prompt_cache` 对单次 `runner.run_async()` 覆盖 prompt cache 配置。单次运行配置会按字段覆盖模型级配置,适合按用户、租户或业务场景设置不同的缓存提示: + +```python +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.configs import RunConfig + +async for event in runner.run_async( + user_id=user_id, + session_id=session_id, + new_message=user_content, + run_config=RunConfig( + prompt_cache=PromptCacheConfig( + enabled=True, + prompt_cache_key="weather-concierge-user-42", + ), + ), +): + ... +``` + +#### Anthropic 断点配置 + +Anthropic 风格的缓存需要选择断点位置。`breakpoints` 支持以下值: + +- `"system"`:缓存系统提示词,适合长 instruction 场景 +- `"tools"`:缓存最后一个工具定义,适合工具较多或工具 schema 较大的场景 +- `"messages"`:缓存最近一条 assistant 消息,适合多轮会话中不断增长的稳定历史前缀 + +```python +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.models import AnthropicModel + +model = AnthropicModel( + model_name="claude-3-5-sonnet-20241022", + api_key="your-api-key", + prompt_cache_config=PromptCacheConfig( + enabled=True, + ttl="1h", + breakpoints=["tools", "system", "messages"], + ), +) +``` + +建议从 `["tools", "system"]` 开始;当长多轮会话需要缓存不断增长的历史前缀时,再加入 `"messages"`。部分 Anthropic 代理或 Bedrock 路由对最小缓存块大小有要求,如果提示词过短,可能不会产生缓存写入。 + +#### LiteLLM 路由 + +使用 `LiteLLMModel` 时,模型名需要带 `provider/model` 前缀。SDK 会根据 provider 前缀选择对应的缓存管理映射,例如: + +```python +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.models import LiteLLMModel + +model = LiteLLMModel( + model_name="openai/gpt-4o", + api_key="your-api-key", + prompt_cache_config=PromptCacheConfig( + enabled=True, + prompt_cache_key="shared-prefix-v1", + ), +) +``` + +如果模型名缺少 provider 前缀,SDK 无法判断应使用哪类缓存管理协议,因此 SDK 管理的缓存提示可能不会生效。 + +#### 读取缓存统计 + +模型响应的 `usage_metadata` 中会尽量归一化缓存统计字段: + +```python +async for event in runner.run_async(...): + usage = getattr(event, "usage_metadata", None) + if usage: + print(usage.cache_read_input_tokens) # 从缓存读取的输入 token 数 + print(usage.cache_creation_input_tokens) # 写入缓存的输入 token 数,通常仅 Anthropic 上报 + print(usage.prompt_token_count) # 总输入 token 数 +``` + +不同模型服务上报的字段并不完全一致。OpenAI 兼容端点通常只上报缓存读取,不上报缓存写入;负载均衡代理场景下,不同后端实例的 KV 缓存可能尚未全部预热,因此命中率可能在前几次运行中波动。 + +完整可运行示例见 [examples/llmagent_with_prompt_cache](../../../examples/llmagent_with_prompt_cache/README.md)。 ### 自定义 HTTP Header diff --git a/examples/llmagent_with_prompt_cache/.env b/examples/llmagent_with_prompt_cache/.env new file mode 100644 index 00000000..dc791393 --- /dev/null +++ b/examples/llmagent_with_prompt_cache/.env @@ -0,0 +1,4 @@ +# Set TRPC_AGENT_API_KEY、TRPC_AGENT_BASE_URL、TRPC_AGENT_MODEL_NAME +TRPC_AGENT_API_KEY=your-api-key +TRPC_AGENT_BASE_URL=your-base-url +TRPC_AGENT_MODEL_NAME=your-model-name diff --git a/examples/llmagent_with_prompt_cache/README.md b/examples/llmagent_with_prompt_cache/README.md new file mode 100644 index 00000000..72f3d335 --- /dev/null +++ b/examples/llmagent_with_prompt_cache/README.md @@ -0,0 +1,64 @@ +# Prompt Cache 示例 + +本示例演示如何在 OpenAI、Anthropic 上,以及其他经 LiteLLM 接入的兼容端点上,使用 SDK 统一的 prompt cache。 +所有场景使用同一个「天气管家」Agent,区别仅在于所选的模型类和缓存配置。 + +运行这个例子后,在支持 prompt cache 的 API 上,期望能够看到较高的 prompt cache 命中率,以及随轮次增长的 TTFT 改善(Turn 2 起缓存命中后响应明显变快)。本示例中的 TTFT 指从请求开始到第一个有效生成 token 出现的耗时;无论该 token 属于普通 message 还是 tool call 都计入。 + +--- + +## 目录结构 + +``` +llmagent_with_prompt_cache/ +├── agent/ +│ ├── agent.py ← 三个工厂函数 + 自动探测 helper +│ ├── config.py ← 环境变量 helper +│ ├── prompts.py ← 长系统提示词(约 4 900 token) +│ └── tools.py ← 模拟天气工具 +│ +├── run_agent.py ← 根据环境变量自动探测 provider 并运行 demo 循环 +│ +└── .env ← 环境变量配置(三个 provider 均在此注释分段) +``` + +--- + +## 环境与运行 + +### 环境要求 + +- Python 3.12 + +### 安装步骤 + +```bash +git clone https://github.com/trpc-group/trpc-agent-python.git +cd trpc-agent-python +python3 -m venv .venv +source .venv/bin/activate +pip3 install -e . +``` + +### 环境变量要求 + +在 [examples/llmagent_with_prompt_cache/.env](./.env) 中填入凭证: + +- `TRPC_AGENT_API_KEY` +- `TRPC_AGENT_BASE_URL` +- `TRPC_AGENT_MODEL_NAME` + +### 运行命令 + +```bash +cd examples/llmagent_with_prompt_cache +python3 run_agent.py # 根据 .env 的模型名字自动选择 provider +``` + +--- + +## FQA +### 缓存命中不稳定(命中后又未命中又命中) + +在负载均衡的代理部署下属于正常现象。每个后端实例都有独立的 KV 缓存。无论其他实例 +预热了多少,落到冷实例上的请求总会显示未命中。把脚本多跑几次即可提高命中率。 diff --git a/examples/llmagent_with_prompt_cache/agent/__init__.py b/examples/llmagent_with_prompt_cache/agent/__init__.py new file mode 100644 index 00000000..bc6e483f --- /dev/null +++ b/examples/llmagent_with_prompt_cache/agent/__init__.py @@ -0,0 +1,5 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. diff --git a/examples/llmagent_with_prompt_cache/agent/agent.py b/examples/llmagent_with_prompt_cache/agent/agent.py new file mode 100644 index 00000000..b086b1d8 --- /dev/null +++ b/examples/llmagent_with_prompt_cache/agent/agent.py @@ -0,0 +1,223 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Weather agent factory for three prompt-cache providers. + +Each factory function shows exactly how to wire ``PromptCacheConfig`` for one +provider family: + +* ``create_anthropic_agent`` – Anthropic / Claude. + Uses explicit ``cache_control`` breakpoints (``tools`` + ``system``) to + mark the stable prefix. The provider stamps the cache automatically; + ``cache_creation_input_tokens`` is reported on the first turn and + ``cache_read_input_tokens`` on subsequent turns. + +* ``create_openai_agent`` – Any OpenAI-compatible endpoint. + Provider-managed prefix caching: no breakpoints needed; the provider + caches a common prefix automatically. Use ``cache_key`` to pin + requests to the same backend cache slot. + +* ``create_litellm_agent`` – LiteLLM router (``provider/model`` naming). + LiteLLM forwards the request to the matching provider; cache semantics + follow the underlying provider (OpenAI-managed for ``openai/…``). +""" + +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.configs import PromptCacheConfig +from trpc_agent_sdk.models import AnthropicModel +from trpc_agent_sdk.models import LiteLLMModel +from trpc_agent_sdk.models import OpenAIModel +from trpc_agent_sdk.tools import FunctionTool + +from .config import get_model_config +from .config import is_anthropic_model +from .prompts import INSTRUCTION +from .tools import get_weather_forecast +from .tools import get_weather_report + + +def _tools() -> list: + return [FunctionTool(get_weather_report), FunctionTool(get_weather_forecast)] + + +# --------------------------------------------------------------------------- +# Provider 1 – Anthropic / Claude +# --------------------------------------------------------------------------- + + +def create_anthropic_agent() -> LlmAgent: + """Anthropic / Claude: explicit ``cache_control`` breakpoints. + + How it works + ------------ + ``PromptCacheConfig(breakpoints=["tools", "system"])`` tells the SDK to + stamp ``cache_control: {type: ephemeral}`` on the last tool definition + *and* on the system message before sending to the Anthropic API. + Anthropic then caches up to that point. + + What to expect in the output + ----------------------------- + * Turn 1 – ``cache_creation_input_tokens`` is non-zero (cache written). + * Turn 2+ – ``cache_read_input_tokens`` is non-zero (cache hit). + + Required env vars (uncomment Anthropic section in .env) + ----------------------------------------------- + TRPC_AGENT_API_KEY = sk-ant-… + TRPC_AGENT_BASE_URL = https://api.anthropic.com + TRPC_AGENT_MODEL_NAME= claude-3-5-sonnet-20241022 + """ + api_key, base_url, model_name = get_model_config() + cache_config = PromptCacheConfig( + enabled=True, + ttl='1h', + breakpoints=['tools', 'system', 'messages'], + ) + model = AnthropicModel( + model_name=model_name, + api_key=api_key, + base_url=base_url, + prompt_cache_config=cache_config, + ) + return LlmAgent( + name='weather_concierge', + description='Weather concierge – Anthropic prompt cache demo.', + model=model, + instruction=INSTRUCTION, + tools=_tools(), + ) + + +# --------------------------------------------------------------------------- +# Provider 2 – OpenAI-compatible endpoint +# --------------------------------------------------------------------------- + + +def create_openai_agent() -> LlmAgent: + """OpenAI-compatible endpoint: provider-managed prefix caching. + + How it works + ------------ + No ``breakpoints`` are needed. OpenAI (and compatible proxies that + support it) automatically cache a common prefix based on the start of + the messages array. ``cache_key`` is forwarded as ``prompt_cache_key`` + in the request body – some proxies use it to route sticky traffic to + the same cached backend. + + What to expect in the output + ----------------------------- + * ``cache_read_input_tokens`` becomes non-zero after the prefix is warm + (typically from the 2nd–4th request; proxy-dependent). + * ``cache_creation_input_tokens`` is only reported by Anthropic; it will + be ``None`` here. + + Required env vars (uncomment OpenAI section in .env) + ------------------------------------------- + TRPC_AGENT_API_KEY = + TRPC_AGENT_BASE_URL = + TRPC_AGENT_MODEL_NAME= + """ + api_key, base_url, model_name = get_model_config() + cache_config = PromptCacheConfig( + enabled=True, + ttl='24h', + prompt_cache_key='weather-concierge-v1', + ) + model = OpenAIModel( + model_name=model_name, + api_key=api_key, + base_url=base_url, + prompt_cache_config=cache_config, + ) + return LlmAgent( + name='weather_concierge', + description='Weather concierge – OpenAI-compatible prompt cache demo.', + model=model, + instruction=INSTRUCTION, + tools=_tools(), + ) + + +# --------------------------------------------------------------------------- +# Provider 3 – LiteLLM router +# --------------------------------------------------------------------------- + + +def create_litellm_agent() -> LlmAgent: + """LiteLLM router: ``provider/model`` naming, cache via underlying provider. + + How it works + ------------ + LiteLLM inspects the ``provider/`` prefix to decide which backend SDK to + use. For ``openai/…`` the SDK sends an OpenAI-compatible request and + ``PromptCacheConfig`` flows through as ``extra_body`` (``prompt_cache_key`` + / ``prompt_cache_retention``). + + Note: if the model name starts with ``anthropic/`` the cache family + switches automatically to ``cache_control`` breakpoints. + + What to expect in the output + ----------------------------- + Same as the OpenAI-compatible path for ``openai/…`` model names. + + Required env vars (uncomment LiteLLM section in .env) + -------------------------------------------- + TRPC_AGENT_API_KEY = + TRPC_AGENT_BASE_URL = + TRPC_AGENT_MODEL_NAME= openai/ ← provider prefix required + """ + api_key, base_url, model_name = get_model_config() + if '/' not in model_name: + raise ValueError(f"LiteLLM model_name must include a provider prefix, e.g. " + f"'openai/{model_name}'. Got: '{model_name}'") + if is_anthropic_model(model_name): + cache_config = PromptCacheConfig( + enabled=True, + ttl='1h', + breakpoints=['tools', 'system'], + ) + else: + cache_config = PromptCacheConfig( + enabled=True, + ttl='24h', + prompt_cache_key='weather-concierge-v1', + ) + model = LiteLLMModel( + model_name=model_name, + api_key=api_key, + api_base=base_url, + prompt_cache_config=cache_config, + ) + return LlmAgent( + name='weather_concierge', + description='Weather concierge – LiteLLM prompt cache demo.', + model=model, + instruction=INSTRUCTION, + tools=_tools(), + ) + + +# --------------------------------------------------------------------------- +# Auto-detect factory (used by the legacy run_agent.py) +# --------------------------------------------------------------------------- + + +def create_agent() -> LlmAgent: + """Auto-detect provider from model name and delegate to the right factory. + + Selection order + --------------- + 1. model_name contains ``/`` (``provider/model`` format) → :func:`create_litellm_agent` + 2. model_name starts with ``claude`` → :func:`create_anthropic_agent` + 3. Anything else → :func:`create_openai_agent` + """ + _, _, model_name = get_model_config() + if '/' in model_name: + return create_litellm_agent() + if is_anthropic_model(model_name): + return create_anthropic_agent() + return create_openai_agent() + + +root_agent = create_agent() diff --git a/examples/llmagent_with_prompt_cache/agent/config.py b/examples/llmagent_with_prompt_cache/agent/config.py new file mode 100644 index 00000000..2c64ebb5 --- /dev/null +++ b/examples/llmagent_with_prompt_cache/agent/config.py @@ -0,0 +1,37 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Agent config module. + +Reads model connection settings from environment variables and exposes a +helper to detect which prompt-cache "family" the resolved model belongs to. +The prompt cache knobs differ per provider, so the agent uses this to build a +``PromptCacheConfig`` that actually applies to the running model. +""" + +import os + + +def get_model_config() -> tuple[str, str, str]: + """Get model config from environment variables.""" + api_key = os.getenv('TRPC_AGENT_API_KEY', '') + url = os.getenv('TRPC_AGENT_BASE_URL', '') + model_name = os.getenv('TRPC_AGENT_MODEL_NAME', '') + if not api_key or not url or not model_name: + raise ValueError('TRPC_AGENT_API_KEY, TRPC_AGENT_BASE_URL, and ' + 'TRPC_AGENT_MODEL_NAME must be set in environment variables') + return api_key, url, model_name + + +def is_anthropic_model(model_name: str) -> bool: + """Return True when the model belongs to the Anthropic cache_control family. + + Anthropic (Claude) uses explicit ``cache_control`` breakpoints and a + ``5m`` / ``1h`` TTL. Everything else is treated as the OpenAI-managed + family, which uses ``prompt_cache_key`` and a ``in_memory`` / ``24h`` + retention instead. + """ + bare = model_name.split('/', 1)[-1] # strip provider prefix for LiteLLM names + return bare.lower().startswith('claude') diff --git a/examples/llmagent_with_prompt_cache/agent/prompts.py b/examples/llmagent_with_prompt_cache/agent/prompts.py new file mode 100644 index 00000000..f0f12b1a --- /dev/null +++ b/examples/llmagent_with_prompt_cache/agent/prompts.py @@ -0,0 +1,356 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +""" Prompts for the agent. + +The instruction below is intentionally long and *stable* across turns. Prompt +caching only pays off when a large, identical prefix is reused, and providers +enforce a minimum cacheable size (Anthropic requires ~1024 tokens; OpenAI also +requires ~1024 tokens). Keeping a big static system prompt here is what makes +the cache read / write counters in ``run_agent.py`` light up on the second turn. +""" + +INSTRUCTION = """ +You are "Atlas", a meticulous, friendly, and highly professional weather +concierge serving {user_name}. You combine accurate meteorological reporting +with practical, human-centered advice. You always stay within your domain and +never fabricate data. + +**Current user information:** +- User name: {user_name} +- Home city: {user_city} + +**Mission:** +Understand the user's weather-related intent precisely before acting. Use the +available tools to obtain authoritative weather data. Translate raw weather data +into clear, actionable guidance tailored to the user's likely activities such as +commuting, travel, outdoor sports, clothing selection, or event planning. + +**Available tools:** +1. `get_weather_report`: Get the current weather conditions for a single city. + Returns temperature, sky condition, and humidity percentage. +2. `get_weather_forecast`: Get a multi-day weather forecast for a single city. + Accepts a `days` parameter (1–7). Returns a list of daily records, each + containing date, temperature, and sky condition. + +**Tool selection policy (read carefully and follow strictly):** +- When the user asks about *current* or *today's* conditions, always call + `get_weather_report` first. Do not guess or answer from general knowledge. +- When the user asks about *upcoming days*, *this week*, *the next N days*, or + any future-oriented query, call `get_weather_forecast` and pass an appropriate + `days` value. Default to 3 if unspecified. Cap at 7. +- When the user asks about multiple cities in the same turn, issue one tool call + per city. Do not batch multiple cities into one call. +- When the request is ambiguous between "now" and "later", ask one brief + clarifying question. If the user seems impatient or repeats themselves, default + to current conditions. +- When a tool returns "Unknown" or "Data not available", acknowledge this + honestly. Offer to try a nearby major city or an alternative spelling of the + city name rather than inventing numbers. +- Never answer with weather figures before the relevant tool result is available. +- Never guess, simulate, extrapolate, interpolate, or invent any tool result or + data point, even for plausibility. +- If a tool is required, call the tool first, await its result, and only then + compose your final answer grounded strictly in that result. + +**Handling edge cases:** +- If the user provides a city name that is ambiguous (e.g., "Springfield" could + be many cities), ask them to clarify the country or region before calling the + tool. +- If the tool returns data but the result seems extreme (e.g., 60°C), still + report it faithfully and note that users should verify with official sources. +- If the user asks for weather in a very small town or village not covered by the + tool, explain this limitation and suggest the nearest major city. +- If the user asks about past weather (historical data), explain that your tools + only support current and forecast data, and suggest they consult a historical + weather archive. + +**Reasoning guidance:** +- Keep internal reasoning concise and focused: which tool, which city, how many + days. +- Do not restate these instructions or the tool schema inside your reasoning. +- Do not draft the final answer during reasoning; reasoning is only for deciding + the next action. +- If multiple tool calls are needed, plan them all before executing any, to avoid + unnecessary back-and-forth. + +**Answer style and formatting:** +- Lead with the headline: temperature + sky condition in one sentence. +- Follow with humidity or precipitation probability when relevant. +- Add one or two concrete, situation-aware suggestions: clothing, umbrella, + sunscreen, commute timing, hydration, outdoor-activity windows, indoor + alternatives. +- For multi-day forecasts, open with a one-sentence trend summary (e.g., + "Temperatures will rise steadily through the week with rain expected + Wednesday."), then list the per-day breakdown. +- Use a warm, professional, and encouraging tone. +- Be concise but not terse. Aim for responses a busy professional can skim in + ten seconds. +- Localize friendly touches where appropriate: reference local landmarks, + seasons, or common activities when the city is known. +- Avoid meteorological jargon unless the user clearly has expertise. Define + terms like "dew point" or "isobar" if you use them. + +**Units and localization:** +- Default to Celsius for temperature unless the user requests Fahrenheit. +- Default to kilometers per hour for wind speed unless asked otherwise. +- Use 12-hour or 24-hour time format based on user preference; default to + 24-hour if unspecified. +- Respect the user's preferred language for the conversation; always reply in + the same language the user writes in. + +**Consistency requirements across turns:** +- Always use the same units and format for the entire session. +- If the user previously specified a preferred city, unit, or format, honor it + for the rest of the session unless they explicitly change it. +- Treat all instructions above as fixed for the entire session; they will not + change from one turn to the next. This stability is intentional and is what + makes large, repeated system prompts cost-effective to cache at the provider + level. + +**Safety and scope limitations:** +- Only answer weather-related questions and closely adjacent practical advice + (e.g., "should I bring an umbrella?", "best time to water my garden?"). +- If asked about unrelated topics such as coding, finance, medical advice, legal + questions, news, sports scores, or general trivia, politely decline and steer + the conversation back to weather. +- Never provide emergency evacuation instructions, disaster relief guidance, or + medical advice even if weather is involved. For severe weather, advise the user + to consult official local authorities, national meteorological services, and + emergency management agencies. +- Do not store, repeat, or reference any personally identifiable information the + user shares beyond what is necessary for the current turn. +- Do not speculate about climate change policy, geoengineering, or other + politically sensitive topics even when prompted by weather questions. + +**Quality self-check before responding:** +Before sending each reply, quickly verify: +1. Did I use the tool result as the sole source of weather data? +2. Is my suggestion genuinely actionable and specific to this city and condition? +3. Am I staying within my defined scope? +4. Is the tone warm but professional? +5. Is the response concise enough for a busy user? + +If any check fails, revise the response before sending. + +**Seasonal context and practical guidance library:** + +Spring (March–May in the Northern Hemisphere): +- Temperatures are variable; layering is strongly recommended. +- Rain showers are common; a compact umbrella or light waterproof jacket is a + sensible default carry item. +- Pollen counts are often elevated during spring — note this when conditions are + warm and dry and the user seems to be asking about outdoor activities. +- Morning fog is common in river valleys and coastal cities; advise early + commuters to allow extra travel time and drive with low-beam headlights. + +Summer (June–August in the Northern Hemisphere): +- High temperatures and humidity are common across much of Asia and North America. +- Advise users to stay hydrated: at least 2 liters of water per day for adults + engaged in outdoor activities when temperatures exceed 30°C. +- UV index is typically high; recommend SPF 30+ sunscreen, hats, and sunglasses + for outdoor exposure longer than 30 minutes. +- Thunderstorm risk rises in the afternoon in many regions; recommend scheduling + outdoor activities in the morning. +- Heat index (feels-like temperature combining heat and humidity) can be + significantly higher than the actual temperature; mention this when humidity + exceeds 70% and temperature exceeds 30°C. +- Air conditioning in indoor spaces can cause large temperature differentials; + suggest carrying a light layer even in summer. + +Autumn (September–November in the Northern Hemisphere): +- Temperatures drop quickly, especially after sunset; morning and evening + commutes can be considerably cooler than midday. +- Leaf-peeping season in temperate regions; if conditions are sunny with mild + temperatures, mention that it is a good time for outdoor activities. +- Early frosts are possible in northern latitudes by October; advise gardeners + to protect sensitive plants. +- Typhoon and hurricane season persists into October in many Pacific and Atlantic + regions; be alert to severe weather warnings if the user's city is in a + typhoon-prone area. + +Winter (December–February in the Northern Hemisphere): +- Snow and ice significantly increase travel risk; strongly recommend checking + road and transit conditions before commuting. +- Wind chill can make temperatures feel much colder than the thermometer reading; + always mention the feels-like temperature when wind is a factor in winter. +- Daylight hours are short; remind users planning outdoor activities to finish + before sunset. +- Heating indoor spaces can dry the air significantly; recommend staying hydrated + and using a humidifier if the user mentions dry skin or respiratory discomfort. +- Black ice forms when temperatures hover around 0°C after rain; caution drivers + and cyclists on untreated roads and bridges. + +**City-specific notes (expand as needed):** + +Beijing: +- High air pollution (PM2.5) is common in winter; on heavily polluted days, + recommend wearing an N95 mask for outdoor activities and keeping windows + closed. +- Summer brings intense heat combined with occasional sandstorms from Inner + Mongolia; advise covering eyes and nose when sandstorm warnings are issued. +- Spring dust storms reduce visibility; check air quality index (AQI) alongside + weather. + +Shanghai: +- Plum rain season (梅雨, typically mid-June to mid-July) brings persistent rain + and high humidity; advise waterproofing belongings and checking for mold in + poorly ventilated spaces. +- Typhoon season peaks in August–September; always check typhoon warnings if + the user is planning travel or outdoor events. +- Winter is damp and cold; the combination of low temperature and high humidity + feels colder than the thermometer suggests — mention this. + +Guangzhou: +- Sub-tropical climate means heat and rain year-round; remind users that even + "mild" forecasts can include brief heavy showers. +- Spring (February–April) is characterized by persistent drizzle and overcast + skies; natural drying of laundry is ineffective — recommend indoor drying. +- Typhoon impacts are frequent in summer and early autumn. + +**Response examples for common scenarios (use as stylistic guidance, not as +canned copy):** + +Scenario: "What's the weather like today in Beijing?" +Good response structure: + 1. Headline temperature and condition from tool result. + 2. Humidity note if above 70% or below 30%. + 3. One practical suggestion (e.g., jacket, umbrella, sunscreen). + 4. Optional: brief air quality note if it is a known high-pollution period. + +Scenario: "Should I go jogging tomorrow morning in Shanghai?" +Good response structure: + 1. Tomorrow morning's forecast from tool result. + 2. Direct yes/no recommendation with reasoning. + 3. Suggested timing window if weather improves or worsens during the day. + +Scenario: "What will the weather be like in Guangzhou for the next five days?" +Good response structure: + 1. One-sentence trend summary. + 2. Day-by-day breakdown (date, temperature, condition). + 3. One overall recommendation (e.g., "Pack an umbrella for Wednesday and + Thursday."). + +**Extended domain knowledge — weather phenomena and advisories:** + +Thunderstorms and lightning safety: +- Lightning is the most underestimated weather hazard. If you hear thunder, you + are within striking distance. Advise users to seek shelter immediately in a + substantial building or hard-topped vehicle. Avoid open fields, hilltops, + isolated trees, water bodies, and metal structures during active lightning. +- The 30-30 rule: if the gap between lightning and thunder is less than 30 + seconds, seek shelter; wait 30 minutes after the last thunder before resuming + outdoor activities. +- Flash floods frequently accompany thunderstorms in urban areas with poor + drainage. Warn users against walking or driving through floodwaters — just + 15 cm of fast-moving water can knock a person down; 30 cm can sweep away a car. + +Wind and typhoon advisories: +- Wind speeds above 60 km/h (Beaufort 7) make walking difficult and can topple + unsecured outdoor furniture, signage, and scaffolding. Advise users to avoid + elevated walkways and bridges, and to secure or move outdoor belongings. +- Typhoon signal systems vary by city: Hong Kong uses the T1–T10 scale; Macau + uses T1–T10; mainland China uses the blue–yellow–orange–red four-tier alert + system. Always cite the local signal level when relevant. +- Storm surge accompanying typhoons can flood low-lying coastal areas hours + before the storm center arrives. Alert users in coastal districts to evacuate + if local authorities issue storm-surge warnings. + +Air quality and pollution advisories: +- The Air Quality Index (AQI) is measured on a 0–500 scale. Thresholds: + 0–50 Good: no precautions needed. + 51–100 Moderate: unusually sensitive groups should limit prolonged outdoor + exertion. + 101–150 Unhealthy for sensitive groups: reduce prolonged or heavy outdoor + exertion for sensitive individuals (elderly, children, those with + respiratory or cardiovascular conditions). + 151–200 Unhealthy: everyone should reduce prolonged outdoor exertion; move + strenuous activities indoors. + 201–300 Very unhealthy: avoid all outdoor exertion. + 301+ Hazardous: remain indoors with windows closed; run air purifiers if + available. +- Fine particulate matter (PM2.5) penetrates deep into lung tissue. N95 or + KN95 masks provide meaningful protection; surgical masks and cloth masks + offer limited protection against PM2.5. +- Ground-level ozone peaks in the afternoon on hot, sunny, low-wind days. + Advise users who must exercise outdoors on high-ozone days to do so early + in the morning when ozone levels are lower. + +Heat-related illness prevention: +- Heat exhaustion symptoms include heavy sweating, weakness, cold or pale skin, + fast or weak pulse, nausea, and fainting. Move the affected person to a cool + place, apply cool wet cloths, and have them sip water. +- Heat stroke is a medical emergency: body temperature above 40°C, hot and red + skin (dry or damp), rapid strong pulse, possible unconsciousness. Call + emergency services immediately; cool the person rapidly by any means available. +- Vulnerable populations — the elderly, infants, outdoor workers, and athletes — + face higher risk. Check on vulnerable neighbours and family members during + extended heat waves. +- Cooling centres (air-conditioned public spaces such as libraries, malls, and + community centres) are valuable resources during extreme heat; mention their + availability when relevant. + +Cold-weather and frost advisories: +- Frostbite risk rises significantly when wind chill drops below −25 °C. Exposed + skin can freeze in minutes. Advise users to cover all skin, wear moisture- + wicking base layers, insulating mid-layers, and waterproof outer layers. +- Hypothermia can occur at temperatures well above freezing (0–10 °C) when a + person is wet, exhausted, or insufficiently clothed. Symptoms: shivering, + slurred speech, drowsiness, loss of coordination. Seek warm shelter and + medical attention immediately. +- Black ice forms invisibly on roads and footpaths when air temperature is near + 0 °C and surfaces are wet. It is most common on bridges, overpasses, and + shaded sections of road. Advise drivers to reduce speed and increase following + distance; advise pedestrians to take smaller steps and walk on grassy verges + where possible. +- Pipes in unheated spaces (garages, basements, exterior walls) can freeze and + burst when temperatures stay below −6 °C for extended periods. Advise users + to let faucets drip slightly and to insulate exposed pipes. + +Coastal and marine weather: +- Rip currents account for the majority of lifeguard rescues. If caught in a + rip current, do not swim against it; swim parallel to shore until out of the + current, then swim back to the beach at an angle. +- Sea breezes develop in coastal cities during warm afternoons as cooler marine + air moves onshore; they can significantly reduce apparent temperatures near + the coast and may bring fog in the early morning hours. +- Wave height advisories for small craft: waves above 1.5 m are considered + rough for small watercraft; above 2.5 m, most recreational boating is + dangerous; above 4 m, most professional vessels exercise caution. + +Visibility and fog advisories: +- Dense fog (visibility below 200 m) significantly increases road accident risk. + Advise drivers to use low-beam headlights and fog lights (never high beams), + reduce speed, and increase following distance to at least double the normal + stopping distance. +- Advise flight passengers and marine travellers to check for fog-related delays + and cancellations before departing for airports or harbours. +- Radiation fog (common in valleys and plains on calm, clear nights) typically + lifts within two to three hours of sunrise. Advise early-morning commuters to + plan for possible fog. + +UV radiation guidance: +- UV Index 1–2: Low; no protection needed for most people. +- UV Index 3–5: Moderate; wear SPF 30+ sunscreen, a hat, and sunglasses. +- UV Index 6–7: High; seek shade during midday hours (10 am–4 pm); reapply + sunscreen every two hours. +- UV Index 8–10: Very high; minimize sun exposure; wear protective clothing. +- UV Index 11+: Extreme; avoid sun entirely; full-coverage protective measures. +- UV radiation can penetrate light cloud cover; a cloudy sky does not eliminate + UV risk. Snow and water reflect UV and can increase exposure. + +Seasonal allergen guidance: +- Tree pollen: peaks in late winter to spring (February–April in temperate zones). + Common culprits: cedar, oak, birch. +- Grass pollen: peaks in late spring to early summer (May–July). Levels are + highest on warm, dry, windy days. +- Weed pollen (especially ragweed): peaks in late summer to autumn (August– + October). Ragweed is particularly widespread in North America. +- Mould spores: elevated in warm, humid conditions and after heavy rain. Can + cause year-round symptoms in damp climates. +- Advise allergy sufferers to check local pollen counts before outdoor plans, + keep windows closed on high-pollen days, shower after being outdoors, and + consult a physician about antihistamine or other treatment options. +""" diff --git a/examples/llmagent_with_prompt_cache/agent/tools.py b/examples/llmagent_with_prompt_cache/agent/tools.py new file mode 100644 index 00000000..afd75b5d --- /dev/null +++ b/examples/llmagent_with_prompt_cache/agent/tools.py @@ -0,0 +1,49 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +""" Tools for the agent. """ + + +def get_weather_report(city: str) -> dict: + """get weather information for the specified city""" + weather_data = { + 'Beijing': { + 'temperature': '25°C', + 'condition': 'Sunny', + 'humidity': '60%' + }, + 'Shanghai': { + 'temperature': '28°C', + 'condition': 'Cloudy', + 'humidity': '70%' + }, + 'Guangzhou': { + 'temperature': '32°C', + 'condition': 'Thunderstorm', + 'humidity': '85%' + }, + } + return weather_data.get(city, {'temperature': 'Unknown', 'condition': 'Data not available', 'humidity': 'Unknown'}) + + +def get_weather_forecast(city: str, days: int = 3) -> list: + """get the multi-day weather forecast for the specified city""" + return [ + { + 'date': '2024-01-01', + 'temperature': '25°C', + 'condition': 'Sunny' + }, + { + 'date': '2024-01-02', + 'temperature': '23°C', + 'condition': 'Cloudy' + }, + { + 'date': '2024-01-03', + 'temperature': '20°C', + 'condition': 'Light rain' + }, + ][:days] diff --git a/examples/llmagent_with_prompt_cache/run_agent.py b/examples/llmagent_with_prompt_cache/run_agent.py new file mode 100644 index 00000000..0d49e963 --- /dev/null +++ b/examples/llmagent_with_prompt_cache/run_agent.py @@ -0,0 +1,165 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Prompt cache demo – auto-detect provider from env vars. + +This script infers the right agent factory from environment variables and +runs the prompt-cache demo loop. + +Provider auto-detection order +------------------------------ +1. model name contains ``/`` (``provider/model`` format) → LiteLLM router +2. model name starts with ``claude`` → Anthropic / Claude +3. Anything else → OpenAI-compatible + +Setup +----- +In ``.env``, uncomment the section for your target provider and fill in +credentials, then run:: + + python3 run_agent.py +""" + +import asyncio +import time +import uuid + +from dotenv import load_dotenv +from trpc_agent_sdk.agents import LlmAgent +from trpc_agent_sdk.events import Event +from trpc_agent_sdk.events import analyze_cache_performance +from trpc_agent_sdk.runners import Runner +from trpc_agent_sdk.sessions import InMemorySessionService +from trpc_agent_sdk.types import Content +from trpc_agent_sdk.types import Part + +load_dotenv() + +from agent.agent import create_agent # noqa: E402 + +_DEMO_QUERIES = [ + "What's the weather like today?", + "What's the current weather in Guangzhou?", + 'What will the weather be like in Shanghai for the next three days?', + "What's the current weather in Shenzhen?", + 'Compare the weather in Beijing and Guangzhou today.', +] + + +def _format_turn_stats(turn_events: list[Event]) -> str: + """Render cache stats for a single turn using CacheMetrics.""" + m = analyze_cache_performance(turn_events) + if m.total_requests == 0: + return 'no usage metadata' + cache_pct = f' ({m.cache_hit_ratio:.0f}%)' if m.total_cache_read_tokens else '' + return (f'llm_calls={m.total_requests} | prompt={m.total_prompt_tokens} | ' + f'cache_read={m.total_cache_read_tokens}{cache_pct}, ' + f'cache_creation={m.total_cache_creation_tokens}') + + +async def run_demo( + agent: LlmAgent, + *, + app_name: str = 'prompt_cache_demo', +) -> None: + """Run the prompt-cache demonstration loop.""" + user_id = 'demo_user' + session_service = InMemorySessionService() + runner = Runner(app_name=app_name, agent=agent, session_service=session_service) + + session_id = str(uuid.uuid4()) + await session_service.create_session( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state={ + 'user_name': user_id, + 'user_city': 'Beijing' + }, + ) + print(f'🆔 Session: {session_id[:8]}… (shared across all turns)\n') + + all_events: list[Event] = [] + + for turn, query in enumerate(_DEMO_QUERIES, start=1): + print(f'===== Turn {turn} =====') + print(f'📝 User: {query}') + + user_content = Content(parts=[Part.from_text(text=query)]) + turn_events: list[Event] = [] + assistant_started = False + turn_error: str | None = None + t_start = time.perf_counter() + ttft: float | None = None + + async for event in runner.run_async( + user_id=user_id, + session_id=session_id, + new_message=user_content, + ): + turn_events.append(event) + + # Surface API/streaming failures instead of silently reporting + # "no usage metadata": the model yields an error event whose + # content is empty, so without this the turn would look like a + # no-op success. + if event.is_error(): + turn_error = f'{event.error_code}: {event.error_message}' + print(f'\n❌ Error: {turn_error}') + continue + + if not event.content or not event.content.parts: + continue + + is_partial = getattr(event, 'partial', False) + + for part in event.content.parts: + if part.text and not part.thought: + # Streaming models emit both per-chunk partial=True events + # (text deltas) and a final partial=False event with the full + # accumulated text. Print text only from one of them. + if ttft is None: + ttft = time.perf_counter() - t_start + if is_partial: + if not assistant_started: + print('🤖 Assistant: ', end='', flush=True) + assistant_started = True + print(part.text, end='', flush=True) + elif not assistant_started: + print('🤖 Assistant: ', end='', flush=True) + print(part.text, end='', flush=True) + assistant_started = True + elif part.function_call and not is_partial: + if ttft is None: + ttft = time.perf_counter() - t_start + print(f'\n🔧 [Tool call: {part.function_call.name}' + f'({part.function_call.args})]') + assistant_started = False + elif part.function_response and not is_partial: + print(f'📊 [Tool result: {part.function_response.response}]') + assistant_started = False + + all_events.extend(turn_events) + elapsed = time.perf_counter() - t_start + ttft_str = f'{ttft * 1000:.0f}ms' if ttft is not None else 'N/A' + print(f'\n⏱ TTFT={ttft_str}, total={elapsed * 1000:.0f}ms') + print(f'📊 Token stats: {_format_turn_stats(turn_events)}') + print('-' * 56 + '\n') + + m = analyze_cache_performance(all_events) + print('===== Session Cache Summary =====') + print(f' Total LLM calls : {m.total_requests}') + print(f' Cache hit ratio : {m.cache_hit_ratio:.1f}%') + print(f' Utilization : {m.cache_utilization_ratio:.1f}%') + print(f' Avg cached tok : {m.avg_cached_tokens_per_request:.0f}/call') + + +def main() -> None: + """Synchronous entry point.""" + asyncio.run(run_demo(create_agent())) + + +if __name__ == '__main__': + main() diff --git a/tests/configs/test_run_config.py b/tests/configs/test_run_config.py index c1db372e..9e9d4def 100644 --- a/tests/configs/test_run_config.py +++ b/tests/configs/test_run_config.py @@ -270,6 +270,7 @@ def test_model_dump_returns_all_fields(self): "agent_run_config", "custom_data", "save_history_enabled", + "prompt_cache", "start_from_last_agent", } assert set(d.keys()) == expected_keys diff --git a/tests/events/test_cache_analyzer.py b/tests/events/test_cache_analyzer.py new file mode 100644 index 00000000..4444ddf8 --- /dev/null +++ b/tests/events/test_cache_analyzer.py @@ -0,0 +1,178 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Unit tests for CachePerformanceAnalyzer and analyze_cache_performance.""" + +from __future__ import annotations + +import pytest + +from trpc_agent_sdk.events import CacheMetrics +from trpc_agent_sdk.events import analyze_cache_performance +from trpc_agent_sdk.events._event import Event +from trpc_agent_sdk.types import GenerateContentResponseUsageMetadata + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _event_with_usage( + prompt: int, + candidates: int, + cache_read: int | None = None, + cache_creation: int | None = None, +) -> Event: + """Build an Event with the given usage metadata values.""" + usage = GenerateContentResponseUsageMetadata( + prompt_token_count=prompt, + candidates_token_count=candidates, + total_token_count=prompt + candidates, + cache_read_input_tokens=cache_read, + cache_creation_input_tokens=cache_creation, + ) + return Event(invocation_id="test-inv", author="model", usage_metadata=usage) + + +def _event_without_usage() -> Event: + """Build an Event without usage metadata.""" + return Event(invocation_id="test-inv", author="model") + + +# --------------------------------------------------------------------------- +# CacheMetrics default state +# --------------------------------------------------------------------------- + + +class TestCacheMetricsDefaults: + + def test_all_zero_by_default(self): + m = CacheMetrics() + assert m.total_requests == 0 + assert m.requests_with_cache_hits == 0 + assert m.total_prompt_tokens == 0 + assert m.total_cache_read_tokens == 0 + assert m.total_cache_creation_tokens == 0 + assert m.cache_hit_ratio == 0.0 + assert m.cache_utilization_ratio == 0.0 + assert m.avg_cached_tokens_per_request == 0.0 + + +# --------------------------------------------------------------------------- +# analyze_cache_performance — mirrors Rust test cases +# --------------------------------------------------------------------------- + + +class TestAnalyzeCachePerformance: + + def test_empty_events(self): + metrics = analyze_cache_performance([]) + assert metrics.total_requests == 0 + assert metrics.requests_with_cache_hits == 0 + assert metrics.total_prompt_tokens == 0 + assert metrics.total_cache_read_tokens == 0 + assert metrics.total_cache_creation_tokens == 0 + assert metrics.cache_hit_ratio == 0.0 + assert metrics.cache_utilization_ratio == 0.0 + assert metrics.avg_cached_tokens_per_request == 0.0 + + def test_events_without_usage_metadata(self): + events = [_event_without_usage(), _event_without_usage()] + metrics = analyze_cache_performance(events) + assert metrics.total_requests == 0 + assert metrics.cache_hit_ratio == 0.0 + + def test_single_event_no_cache(self): + events = [_event_with_usage(1000, 200, None, None)] + metrics = analyze_cache_performance(events) + assert metrics.total_requests == 1 + assert metrics.requests_with_cache_hits == 0 + assert metrics.total_prompt_tokens == 1000 + assert metrics.total_cache_read_tokens == 0 + assert metrics.total_cache_creation_tokens == 0 + assert metrics.cache_hit_ratio == 0.0 + assert metrics.cache_utilization_ratio == 0.0 + assert metrics.avg_cached_tokens_per_request == 0.0 + + def test_single_event_with_cache_hit(self): + events = [_event_with_usage(1000, 200, cache_read=500)] + metrics = analyze_cache_performance(events) + assert metrics.total_requests == 1 + assert metrics.requests_with_cache_hits == 1 + assert metrics.total_prompt_tokens == 1000 + assert metrics.total_cache_read_tokens == 500 + assert metrics.cache_hit_ratio == pytest.approx(50.0) + assert metrics.cache_utilization_ratio == pytest.approx(100.0) + assert metrics.avg_cached_tokens_per_request == pytest.approx(500.0) + + def test_mixed_events(self): + events = [ + _event_with_usage(1000, 200, cache_read=800, cache_creation=200), + _event_with_usage(1000, 300, None, None), + _event_with_usage(1000, 100, cache_read=600), + _event_without_usage(), + ] + metrics = analyze_cache_performance(events) + assert metrics.total_requests == 3 + assert metrics.requests_with_cache_hits == 2 + assert metrics.total_prompt_tokens == 3000 + assert metrics.total_cache_read_tokens == 1400 + assert metrics.total_cache_creation_tokens == 200 + assert metrics.cache_hit_ratio == pytest.approx(1400 / 3000 * 100) + assert metrics.cache_utilization_ratio == pytest.approx(2 / 3 * 100) + assert metrics.avg_cached_tokens_per_request == pytest.approx(1400 / 3) + + def test_all_cache_hits(self): + events = [ + _event_with_usage(500, 100, cache_read=500), + _event_with_usage(500, 100, cache_read=500), + ] + metrics = analyze_cache_performance(events) + assert metrics.total_requests == 2 + assert metrics.requests_with_cache_hits == 2 + assert metrics.cache_hit_ratio == pytest.approx(100.0) + assert metrics.cache_utilization_ratio == pytest.approx(100.0) + assert metrics.avg_cached_tokens_per_request == pytest.approx(500.0) + + def test_zero_prompt_tokens_no_division_by_zero(self): + events = [_event_with_usage(0, 100, None, None)] + metrics = analyze_cache_performance(events) + assert metrics.total_requests == 1 + assert metrics.total_prompt_tokens == 0 + assert metrics.cache_hit_ratio == 0.0 + assert metrics.cache_utilization_ratio == 0.0 + + def test_cache_creation_only(self): + events = [_event_with_usage(2000, 500, None, cache_creation=1500)] + metrics = analyze_cache_performance(events) + assert metrics.total_requests == 1 + assert metrics.requests_with_cache_hits == 0 + assert metrics.total_cache_creation_tokens == 1500 + assert metrics.cache_hit_ratio == 0.0 + assert metrics.cache_utilization_ratio == 0.0 + + def test_accepts_any_iterable(self): + """Should work with generators, not just lists.""" + + def gen(): + yield _event_with_usage(100, 50, cache_read=100) + + metrics = analyze_cache_performance(gen()) + assert metrics.total_requests == 1 + assert metrics.cache_hit_ratio == pytest.approx(100.0) + + +# --------------------------------------------------------------------------- +# Public export sanity check +# --------------------------------------------------------------------------- + + +class TestPublicExports: + + def test_imports_from_events_package(self): + from trpc_agent_sdk.events import CacheMetrics as _M + from trpc_agent_sdk.events import analyze_cache_performance as _F + assert _M is CacheMetrics + assert _F is analyze_cache_performance diff --git a/tests/models/test_anthropic_model.py b/tests/models/test_anthropic_model.py index f02825ac..f1d6380e 100644 --- a/tests/models/test_anthropic_model.py +++ b/tests/models/test_anthropic_model.py @@ -260,7 +260,12 @@ async def test_generate_simple_text(self): # Mock the client and response mock_message = MagicMock(spec=anthropic_types.Message) mock_message.content = [anthropic_types.TextBlock(text="Hello! How can I help you?", type="text")] - mock_message.usage = MagicMock(input_tokens=10, output_tokens=7) + mock_message.usage = MagicMock( + input_tokens=10, + output_tokens=7, + cache_read_input_tokens=None, + cache_creation_input_tokens=None, + ) mock_message.model_dump_json = MagicMock(return_value="{}") mock_client = AsyncMock() @@ -405,7 +410,12 @@ async def mock_aiter(self): # Mock final message with content blocks mock_final_message = MagicMock() - mock_final_message.usage = MagicMock(input_tokens=5, output_tokens=3) + mock_final_message.usage = MagicMock( + input_tokens=5, + output_tokens=3, + cache_read_input_tokens=None, + cache_creation_input_tokens=None, + ) mock_final_message.content = [anthropic_types.TextBlock(text="Hello world!", type="text")] mock_stream.get_final_message = AsyncMock(return_value=mock_final_message) @@ -435,5 +445,310 @@ async def mock_aiter(self): assert final_response.usage_metadata.candidates_token_count == 3 +class TestAnthropicInjectCacheControl: + """Tests for the _inject_cache_control helper and its subordinate functions.""" + + # Re-import helpers inside each test to avoid polluting the module namespace. + @staticmethod + def _helpers(): + from trpc_agent_sdk.models._anthropic_model import ( + _inject_cache_control, + _apply_tools_cache_control, + _apply_system_cache_control, + _apply_messages_cache_control, + ) + return _inject_cache_control, _apply_tools_cache_control, _apply_system_cache_control, _apply_messages_cache_control + + # --- tools breakpoint ------------------------------------------------- + + def test_tools_stamps_last_tool_only(self): + """Only the last tool in the list receives cache_control.""" + inject, *_ = self._helpers() + tools = [{"name": "a"}, {"name": "b"}] + api_params = {"tools": tools} + inject(api_params, ["tools"], None) + assert "cache_control" not in api_params["tools"][0] + assert api_params["tools"][1]["cache_control"] == {"type": "ephemeral"} + + def test_tools_breakpoint_noop_when_no_tools(self): + """No mutation when tools list is empty.""" + inject, *_ = self._helpers() + api_params = {"tools": []} + inject(api_params, ["tools"], None) + assert api_params["tools"] == [] + + def test_tools_breakpoint_noop_when_key_absent(self): + """No mutation when 'tools' key is absent.""" + inject, *_ = self._helpers() + api_params = {} + inject(api_params, ["tools"], None) + assert "tools" not in api_params + + # --- system breakpoint ------------------------------------------------ + + def test_system_converts_string_to_text_block_with_cache_control(self): + """system string is replaced by a text block list with cache_control.""" + inject, *_ = self._helpers() + api_params = {"system": "You are helpful."} + inject(api_params, ["system"], None) + system = api_params["system"] + assert isinstance(system, list) + assert len(system) == 1 + assert system[0]["type"] == "text" + assert system[0]["text"] == "You are helpful." + assert system[0]["cache_control"] == {"type": "ephemeral"} + + def test_system_breakpoint_noop_when_key_absent(self): + """No mutation when 'system' key is absent.""" + inject, *_ = self._helpers() + api_params = {} + inject(api_params, ["system"], None) + assert "system" not in api_params + + def test_system_breakpoint_warns_and_skips_non_string_system(self): + """Non-string system values are left unchanged instead of being stringified.""" + inject, *_ = self._helpers() + system = [{"type": "text", "text": "sys"}] + api_params = {"system": system} + + with patch("trpc_agent_sdk.models._anthropic_model.logger") as mock_log: + inject(api_params, ["system"], None) + + mock_log.warning.assert_called_once() + assert api_params["system"] is system + assert "cache_control" not in api_params["system"][0] + + # --- messages breakpoint ---------------------------------------------- + + def test_messages_stamps_last_assistant_message_last_block(self): + """cache_control is applied to the last content block of the last assistant message.""" + inject, *_ = self._helpers() + messages = [ + { + "role": "user", + "content": [{ + "type": "text", + "text": "hi" + }] + }, + { + "role": "assistant", + "content": [{ + "type": "text", + "text": "hello" + }, { + "type": "text", + "text": "bye" + }] + }, + ] + api_params = {"messages": messages} + inject(api_params, ["messages"], None) + # last assistant message, last block should be stamped + stamped_block = messages[1]["content"][-1] + assert stamped_block["cache_control"] == {"type": "ephemeral"} + # first block of assistant message is NOT stamped + assert "cache_control" not in messages[1]["content"][0] + # user message is NOT stamped + assert "cache_control" not in messages[0]["content"][0] + + def test_messages_skips_latest_user_message(self): + """When the last message is a user turn, the stamp lands on the prior assistant turn.""" + inject, *_ = self._helpers() + messages = [ + { + "role": "assistant", + "content": [{ + "type": "text", + "text": "answer" + }] + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "next question" + }] + }, + ] + api_params = {"messages": messages} + inject(api_params, ["messages"], None) + # assistant message is stamped + assert messages[0]["content"][0]["cache_control"] == {"type": "ephemeral"} + # user message is NOT stamped + assert "cache_control" not in messages[1]["content"][0] + + def test_messages_noop_when_no_assistant_message(self): + """No mutation when there is no assistant message in history.""" + inject, *_ = self._helpers() + messages = [{"role": "user", "content": [{"type": "text", "text": "hi"}]}] + api_params = {"messages": messages} + inject(api_params, ["messages"], None) + assert "cache_control" not in messages[0]["content"][0] + + def test_messages_noop_when_key_absent(self): + inject, *_ = self._helpers() + api_params = {} + inject(api_params, ["messages"], None) + assert api_params == {} + + # --- TTL handling ----------------------------------------------------- + + def test_ttl_is_forwarded_in_cache_control(self): + """TTL is provider-specific and should be forwarded inside cache_control.""" + inject, *_ = self._helpers() + tools = [{"name": "tool1"}] + api_params = {"tools": tools} + inject(api_params, ["tools"], "custom-ttl") + assert api_params["tools"][0]["cache_control"] == { + "type": "ephemeral", + "ttl": "custom-ttl", + } + + def test_none_ttl_produces_minimal_cache_control(self): + """None TTL produces cache_control with only the type field.""" + inject, *_ = self._helpers() + tools = [{"name": "tool1"}] + api_params = {"tools": tools} + inject(api_params, ["tools"], None) + assert api_params["tools"][0]["cache_control"] == {"type": "ephemeral"} + + # --- empty breakpoints ------------------------------------------------ + + def test_empty_breakpoints_noop(self): + """No changes when breakpoints list is empty.""" + inject, *_ = self._helpers() + tools = [{"name": "tool1"}] + api_params = {"tools": tools, "system": "sys", "messages": []} + original_tools = [dict(t) for t in tools] + inject(api_params, [], None) + assert api_params["tools"] == original_tools + assert api_params["system"] == "sys" + + +class TestAnthropicApplyPromptCache: + """Tests for AnthropicModel._apply_prompt_cache delegation.""" + + def test_disabled_config_leaves_api_params_unchanged(self): + """Disabled PromptCacheConfig is a no-op.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = AnthropicModel( + model_name="claude-3-5-sonnet-20241022", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=False), + ) + api_params = {"tools": [{"name": "t1"}], "system": "sys"} + model._apply_prompt_cache(api_params, None) + assert "cache_control" not in api_params["tools"][0] + assert isinstance(api_params["system"], str) + + def test_empty_breakpoints_leaves_api_params_unchanged(self): + """Enabled config with no breakpoints is a no-op.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = AnthropicModel( + model_name="claude-3-5-sonnet-20241022", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=True, breakpoints=[]), + ) + api_params = {"system": "sys", "tools": [{"name": "t1"}]} + model._apply_prompt_cache(api_params, None) + assert isinstance(api_params["system"], str) + assert "cache_control" not in api_params["tools"][0] + + def test_all_breakpoints_inject_all_points(self): + """Enabled config with tools+system+messages injects all three breakpoints.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = AnthropicModel( + model_name="claude-3-5-sonnet-20241022", + api_key="k", + prompt_cache_config=PromptCacheConfig( + enabled=True, + ttl="1h", + breakpoints=["tools", "system", "messages"], + ), + ) + api_params = { + "tools": [{ + "name": "t1" + }], + "system": + "You are helpful.", + "messages": [ + { + "role": "assistant", + "content": [{ + "type": "text", + "text": "previous" + }] + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "new question" + }] + }, + ], + } + model._apply_prompt_cache(api_params, None) + # tools stamped + assert api_params["tools"][0]["cache_control"]["type"] == "ephemeral" + assert api_params["tools"][0]["cache_control"]["ttl"] == "1h" + # system converted to list + assert isinstance(api_params["system"], list) + assert api_params["system"][0]["cache_control"]["type"] == "ephemeral" + # assistant message stamped + assert api_params["messages"][0]["content"][0]["cache_control"]["type"] == "ephemeral" + + +class TestAnthropicBuildUsageMetadata: + """Tests for AnthropicModel._build_usage_metadata cache-inclusive normalization.""" + + @staticmethod + def _usage(input_tokens=100, output_tokens=50, cache_read=None, cache_creation=None): + usage = MagicMock() + usage.input_tokens = input_tokens + usage.output_tokens = output_tokens + usage.cache_read_input_tokens = cache_read + usage.cache_creation_input_tokens = cache_creation + return usage + + def test_cache_read_and_creation_folded_into_prompt_tokens(self): + """prompt_token_count = input_tokens + cache_read + cache_creation.""" + usage = self._usage(input_tokens=100, output_tokens=50, cache_read=500, cache_creation=200) + meta = AnthropicModel._build_usage_metadata(usage) + assert meta.prompt_token_count == 100 + 500 + 200 + assert meta.candidates_token_count == 50 + assert meta.total_token_count == (100 + 500 + 200 + 50) + + def test_cache_fields_preserved_on_metadata(self): + """cache_read_input_tokens and cache_creation_input_tokens are directly set.""" + usage = self._usage(input_tokens=100, output_tokens=50, cache_read=500, cache_creation=200) + meta = AnthropicModel._build_usage_metadata(usage) + assert meta.cache_read_input_tokens == 500 + assert meta.cache_creation_input_tokens == 200 + + def test_none_cache_tokens_treated_as_zero(self): + """When cache fields are None, prompt_token_count equals input_tokens only.""" + usage = self._usage(input_tokens=100, output_tokens=50, cache_read=None, cache_creation=None) + meta = AnthropicModel._build_usage_metadata(usage) + assert meta.prompt_token_count == 100 + assert meta.total_token_count == 150 + + def test_zero_cache_tokens(self): + """When both cache fields are 0, prompt_token_count equals input_tokens only.""" + usage = self._usage(input_tokens=200, output_tokens=30, cache_read=0, cache_creation=0) + meta = AnthropicModel._build_usage_metadata(usage) + assert meta.prompt_token_count == 200 + + def test_only_cache_read_no_creation(self): + """Only cache_read; cache_creation is None.""" + usage = self._usage(input_tokens=50, output_tokens=10, cache_read=300, cache_creation=None) + meta = AnthropicModel._build_usage_metadata(usage) + assert meta.prompt_token_count == 50 + 300 + assert meta.cache_read_input_tokens == 300 + assert meta.cache_creation_input_tokens is None + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/models/test_litellm_model.py b/tests/models/test_litellm_model.py index 54370fee..a7810015 100644 --- a/tests/models/test_litellm_model.py +++ b/tests/models/test_litellm_model.py @@ -27,6 +27,7 @@ class TestLiteLLMModelInit: + def test_init_valid_provider_model(self): model = LiteLLMModel(model_name="openai/gpt-4", api_key="key", base_url="https://api.example.com") assert model._model_name == "openai/gpt-4" @@ -54,6 +55,7 @@ def test_ensure_litellm_imported_raises_when_litellm_not_installed(self): class TestGetMessageContent: + def test_content_string(self): model = LiteLLMModel(model_name="openai/gpt-4") msg = {"content": "hello"} @@ -76,11 +78,16 @@ def test_content_list_block_with_text_key(self): class TestCreateResponseWithContent: + def test_no_choices_returns_usage_and_error_code(self): model = LiteLLMModel(model_name="openai/gpt-4") response_dict = { CHOICES: [], - USAGE: {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + USAGE: { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + }, } r = model._create_response_with_content(response_dict, partial=False) assert r.content is None @@ -94,7 +101,11 @@ def test_no_message_returns_usage_and_error_code(self): model = LiteLLMModel(model_name="openai/gpt-4") response_dict = { CHOICES: [{}], - USAGE: {"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5}, + USAGE: { + "prompt_tokens": 5, + "completion_tokens": 0, + "total_tokens": 5 + }, } r = model._create_response_with_content(response_dict, partial=False) assert r.content is None @@ -106,10 +117,17 @@ def test_normal_text_and_usage_no_error_code_when_stop(self): model = LiteLLMModel(model_name="openai/gpt-4") response_dict = { CHOICES: [{ - MESSAGE: {"content": "Hi there", "role": "assistant"}, + MESSAGE: { + "content": "Hi there", + "role": "assistant" + }, FINISH_REASON: "stop", }], - USAGE: {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}, + USAGE: { + "prompt_tokens": 10, + "completion_tokens": 2, + "total_tokens": 12 + }, } r = model._create_response_with_content(response_dict, partial=False) assert r.error_code is None @@ -123,10 +141,17 @@ def test_finish_reason_not_stop_sets_error_code(self): model = LiteLLMModel(model_name="openai/gpt-4") response_dict = { CHOICES: [{ - MESSAGE: {"content": "x", "role": "assistant"}, + MESSAGE: { + "content": "x", + "role": "assistant" + }, FINISH_REASON: "length", }], - USAGE: {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + USAGE: { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2 + }, } r = model._create_response_with_content(response_dict, partial=False) assert r.error_code == "length" @@ -138,17 +163,26 @@ def test_tool_calls_from_message(self): response_dict = { CHOICES: [{ MESSAGE: { - "content": None, - "role": "assistant", + "content": + None, + "role": + "assistant", TOOL_CALLS: [{ "id": "call_1", "type": "function", - "function": {"name": "get_weather", "arguments": '{"city": "NYC"}'}, + "function": { + "name": "get_weather", + "arguments": '{"city": "NYC"}' + }, }], }, FINISH_REASON: "tool_calls", }], - USAGE: {"prompt_tokens": 2, "completion_tokens": 3, "total_tokens": 5}, + USAGE: { + "prompt_tokens": 2, + "completion_tokens": 3, + "total_tokens": 5 + }, } r = model._create_response_with_content(response_dict, partial=False) assert r.content is not None @@ -162,7 +196,11 @@ def test_first_choice_none_uses_empty_dict(self): model = LiteLLMModel(model_name="openai/gpt-4") response_dict = { CHOICES: [None], - USAGE: {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + USAGE: { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }, } r = model._create_response_with_content(response_dict, partial=False) assert r.error_code == "NO_MESSAGE" @@ -172,10 +210,17 @@ def test_empty_parts_yields_single_empty_text_part_with_thought_false(self): model = LiteLLMModel(model_name="openai/gpt-4") response_dict = { CHOICES: [{ - MESSAGE: {"content": "", "role": "assistant"}, + MESSAGE: { + "content": "", + "role": "assistant" + }, FINISH_REASON: "stop", }], - USAGE: {"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}, + USAGE: { + "prompt_tokens": 1, + "completion_tokens": 0, + "total_tokens": 1 + }, } r = model._create_response_with_content(response_dict, partial=False) assert r.content is not None @@ -185,6 +230,7 @@ def test_empty_parts_yields_single_empty_text_part_with_thought_false(self): class TestBuildResponseFormatForLitellm: + def test_gemini_model_returns_response_schema(self): schema = {"type": "object", "properties": {"x": {"type": "string"}}} out = _build_response_format_for_litellm(schema, "gemini/gemini-1.5-pro") @@ -215,12 +261,16 @@ def test_is_litellm_gemini_model(self): class TestLogUnsupportedConfigOptions: + @pytest.mark.filterwarnings("ignore:.*is not a valid.*:UserWarning") def test_logs_warning_when_unsupported_set(self): model = LiteLLMModel(model_name="openai/gpt-4") config = GenerateContentConfig( top_k=1, - safety_settings=[{"category": "HARM", "threshold": "BLOCK_MEDIUM"}], + safety_settings=[{ + "category": "HARM", + "threshold": "BLOCK_MEDIUM" + }], ) with patch("trpc_agent_sdk.models._litellm_model.logger") as mock_logger: model._log_unsupported_config_options(config) @@ -239,6 +289,7 @@ def test_no_warning_when_none_of_unsupported_set(self): class TestGenerateAsyncNonStream: + @pytest.mark.asyncio async def test_generate_async_non_stream_success(self): model = LiteLLMModel(model_name="openai/gpt-4", api_key="key") @@ -248,10 +299,17 @@ async def test_generate_async_non_stream_success(self): mock_response = Mock() mock_response.model_dump.return_value = { CHOICES: [{ - MESSAGE: {"content": "Hi!", "role": "assistant"}, + MESSAGE: { + "content": "Hi!", + "role": "assistant" + }, FINISH_REASON: "stop", }], - USAGE: {"prompt_tokens": 2, "completion_tokens": 1, "total_tokens": 3}, + USAGE: { + "prompt_tokens": 2, + "completion_tokens": 1, + "total_tokens": 3 + }, } async def fake_acompletion(**kwargs): @@ -261,10 +319,12 @@ async def fake_acompletion(**kwargs): real_import = builtins.__import__ mock_litellm = Mock() mock_litellm.acompletion = AsyncMock(side_effect=fake_acompletion) + def fake_import(name, *a, **k): if name == "litellm": return mock_litellm return real_import(name, *a, **k) + with patch.object(model, "_ensure_litellm_imported"): with patch("builtins.__import__", side_effect=fake_import): responses = [] @@ -286,10 +346,12 @@ async def test_generate_async_non_stream_api_error(self): real_import = builtins.__import__ mock_litellm = Mock() mock_litellm.acompletion = AsyncMock(side_effect=Exception("Connection refused")) + def fake_import(name, *a, **k): if name == "litellm": return mock_litellm return real_import(name, *a, **k) + with patch.object(model, "_ensure_litellm_imported"): with patch("builtins.__import__", side_effect=fake_import): @@ -305,20 +367,34 @@ async def test_generate_async_passes_response_format_for_openai_model(self): import builtins real_import = builtins.__import__ captured_kwargs = {} + async def capture_acompletion(**kwargs): captured_kwargs.update(kwargs) mock_resp = Mock() mock_resp.model_dump.return_value = { - CHOICES: [{MESSAGE: {"content": "ok", "role": "assistant"}, FINISH_REASON: "stop"}], - USAGE: {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + CHOICES: [{ + MESSAGE: { + "content": "ok", + "role": "assistant" + }, + FINISH_REASON: "stop" + }], + USAGE: { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2 + }, } return mock_resp + mock_litellm = Mock() mock_litellm.acompletion = AsyncMock(side_effect=capture_acompletion) + def fake_import(name, *a, **k): if name == "litellm": return mock_litellm return real_import(name, *a, **k) + model = LiteLLMModel(model_name="openai/gpt-4") schema = {"type": "object", "properties": {"answer": {"type": "string"}}} config = GenerateContentConfig(max_output_tokens=10) @@ -340,20 +416,34 @@ async def test_generate_async_passes_response_format_for_gemini_model(self): import builtins real_import = builtins.__import__ captured_kwargs = {} + async def capture_acompletion(**kwargs): captured_kwargs.update(kwargs) mock_resp = Mock() mock_resp.model_dump.return_value = { - CHOICES: [{MESSAGE: {"content": "ok", "role": "assistant"}, FINISH_REASON: "stop"}], - USAGE: {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, + CHOICES: [{ + MESSAGE: { + "content": "ok", + "role": "assistant" + }, + FINISH_REASON: "stop" + }], + USAGE: { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2 + }, } return mock_resp + mock_litellm = Mock() mock_litellm.acompletion = AsyncMock(side_effect=capture_acompletion) + def fake_import(name, *a, **k): if name == "litellm": return mock_litellm return real_import(name, *a, **k) + model = LiteLLMModel(model_name="gemini/gemini-1.5-pro") schema = {"type": "object", "properties": {"x": {"type": "string"}}} config = GenerateContentConfig(max_output_tokens=10) @@ -372,6 +462,7 @@ def fake_import(name, *a, **k): class TestGenerateAsyncStream: + @pytest.mark.asyncio async def test_generate_async_stream_yields_partial_then_final(self): model = LiteLLMModel(model_name="openai/gpt-4") @@ -380,18 +471,30 @@ async def test_generate_async_stream_yields_partial_then_final(self): chunk1 = Mock() chunk1.model_dump.return_value = { - CHOICES: [{DELTA: {CONTENT: "Hello "}}], + CHOICES: [{ + DELTA: { + CONTENT: "Hello " + } + }], USAGE: None, } chunk2 = Mock() chunk2.model_dump.return_value = { - CHOICES: [{DELTA: {CONTENT: "world"}}], + CHOICES: [{ + DELTA: { + CONTENT: "world" + } + }], USAGE: None, } chunk3 = Mock() chunk3.model_dump.return_value = { CHOICES: [], - USAGE: {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + USAGE: { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3 + }, } async def fake_stream(**kwargs): @@ -403,10 +506,12 @@ async def fake_stream(**kwargs): real_import = builtins.__import__ mock_litellm = Mock() mock_litellm.acompletion = AsyncMock(return_value=fake_stream()) + def fake_import(name, *a, **k): if name == "litellm": return mock_litellm return real_import(name, *a, **k) + with patch.object(model, "_ensure_litellm_imported"): with patch("builtins.__import__", side_effect=fake_import): responses = [] @@ -431,18 +536,22 @@ async def test_generate_async_stream_includes_stream_options_in_api_params(self) async def capture_acompletion(**kwargs): captured_kwargs.update(kwargs) + async def empty_stream(): yield None + return empty_stream() import builtins real_import = builtins.__import__ mock_litellm = Mock() mock_litellm.acompletion = AsyncMock(side_effect=capture_acompletion) + def fake_import(name, *a, **k): if name == "litellm": return mock_litellm return real_import(name, *a, **k) + with patch.object(model, "_ensure_litellm_imported"): with patch("builtins.__import__", side_effect=fake_import): async for _ in model.generate_async(request, stream=True): @@ -463,10 +572,12 @@ async def failing_stream(**kwargs): real_import = builtins.__import__ mock_litellm = Mock() mock_litellm.acompletion = AsyncMock(side_effect=failing_stream) + def fake_import(name, *a, **k): if name == "litellm": return mock_litellm return real_import(name, *a, **k) + with patch.object(model, "_ensure_litellm_imported"): with patch("builtins.__import__", side_effect=fake_import): responses = [] @@ -478,6 +589,7 @@ def fake_import(name, *a, **k): class TestLiteLLMModelValidateRequest: + def test_validate_request_empty_contents_raises(self): model = LiteLLMModel(model_name="openai/gpt-4") request = LlmRequest(contents=[], config=None, tools_dict={}) @@ -489,3 +601,404 @@ def test_validate_request_valid_passes(self): content = Content(parts=[Part.from_text(text="Hi")], role="user") request = LlmRequest(contents=[content], config=None, tools_dict={}) model.validate_request(request) + + +# =========================================================================== +# Prompt cache — provider family classification +# =========================================================================== + + +class TestLiteLLMCacheFamily: + """_litellm_cache_family must route known prefixes to the correct family.""" + + def _family(self, model_name: str): + from trpc_agent_sdk.models._litellm_model import _litellm_cache_family + return _litellm_cache_family(model_name) + + def test_anthropic_prefix_is_anthropic_family(self): + assert self._family("anthropic/claude-3-5-sonnet") == "anthropic" + + def test_bedrock_prefix_is_anthropic_family(self): + assert self._family("bedrock/anthropic.claude-3-haiku") == "anthropic" + + def test_vertex_ai_prefix_is_anthropic_family(self): + assert self._family("vertex_ai/claude-3-opus") == "anthropic" + + def test_vertex_ai_beta_prefix_is_anthropic_family(self): + assert self._family("vertex_ai_beta/gemini-1.5") == "anthropic" + + def test_gemini_prefix_is_anthropic_family(self): + assert self._family("gemini/gemini-1.5-flash") == "anthropic" + + def test_openai_prefix_is_openai_managed_family(self): + assert self._family("openai/gpt-4o") == "openai_managed" + + def test_azure_prefix_is_openai_managed_family(self): + assert self._family("azure/gpt-35-turbo") == "openai_managed" + + def test_deepseek_prefix_is_openai_managed_family(self): + assert self._family("deepseek/deepseek-chat") == "openai_managed" + + def test_xai_prefix_is_openai_managed_family(self): + assert self._family("xai/grok-1") == "openai_managed" + + def test_unknown_prefix_returns_none(self): + assert self._family("unknown/some-model") is None + + def test_groq_prefix_returns_none(self): + """groq is not in either cache family list.""" + assert self._family("groq/llama-3") is None + + def test_prefix_matching_is_case_insensitive(self): + assert self._family("ANTHROPIC/claude-3") == "anthropic" + assert self._family("OpenAI/gpt-4") == "openai_managed" + + +# =========================================================================== +# Prompt cache — Anthropic-family (non-Bedrock) request shaping +# =========================================================================== + + +class TestLiteLLMApplyPromptCacheAnthropicFamily: + """For anthropic-family non-Bedrock models, _apply_prompt_cache should: + - stamp cache_control on the last tool directly (not via injection_points), + - add cache_control_injection_points for system / messages breakpoints. + """ + + def _model(self, model_name: str = "anthropic/claude-3-5-sonnet", **kw): + from trpc_agent_sdk.configs import PromptCacheConfig + kw.setdefault("prompt_cache_config", + PromptCacheConfig( + enabled=True, + ttl="1h", + breakpoints=["tools", "system", "messages"], + )) + return LiteLLMModel(model_name=model_name, api_key="k", **kw) + + def test_tools_breakpoint_stamps_last_tool_directly(self): + """Non-Bedrock anthropic provider stamps cache_control on tools[-1] directly.""" + model = self._model() + api_params = { + "tools": [{ + "name": "t1" + }, { + "name": "t2" + }], + "messages": [], + } + model._apply_prompt_cache(api_params, None) + assert "cache_control" not in api_params["tools"][0] + assert api_params["tools"][-1]["cache_control"]["type"] == "ephemeral" + assert api_params["tools"][-1]["cache_control"]["ttl"] == "1h" + + def test_injection_points_added_for_system_and_messages(self): + """cache_control_injection_points contains entries for system and latest assistant message.""" + model = self._model() + api_params = { + "tools": [], + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "again"}, + ], + } + model._apply_prompt_cache(api_params, None) + points = api_params.get("cache_control_injection_points", []) + locations = {p.get("location") for p in points} + roles = {p.get("role") for p in points if "role" in p} + assert "message" in locations + assert "system" in roles + index_points = [p for p in points if p.get("index") == 1] + assert len(index_points) == 1 + + def test_disabled_config_leaves_api_params_unchanged(self): + from trpc_agent_sdk.configs import PromptCacheConfig + model = LiteLLMModel( + model_name="anthropic/claude-3-5-sonnet", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=False), + ) + api_params = {"tools": [{"name": "t1"}]} + model._apply_prompt_cache(api_params, None) + assert "cache_control" not in api_params["tools"][0] + assert "cache_control_injection_points" not in api_params + + def test_empty_breakpoints_leaves_api_params_unchanged(self): + from trpc_agent_sdk.configs import PromptCacheConfig + model = LiteLLMModel( + model_name="anthropic/claude-3-5-sonnet", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=True, breakpoints=[]), + ) + api_params = {"tools": [{"name": "t1"}]} + model._apply_prompt_cache(api_params, None) + assert "cache_control" not in api_params["tools"][0] + assert "cache_control_injection_points" not in api_params + + def test_anthropic_family_ttl_is_forwarded_to_litellm(self): + """Anthropic-family LiteLLM routes pass TTL through for provider adapters to handle.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = LiteLLMModel( + model_name="anthropic/claude-3-5-sonnet", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=True, ttl="in_memory", breakpoints=["tools"]), + ) + api_params = {"tools": [{"name": "t1"}]} + + model._apply_prompt_cache(api_params, None) + + assert api_params["tools"][-1]["cache_control"] == { + "type": "ephemeral", + "ttl": "in_memory", + } + + +# =========================================================================== +# Prompt cache — Bedrock request shaping +# =========================================================================== + + +class TestLiteLLMApplyPromptCacheBedrock: + """For Bedrock models, tool-level cache_control must NOT be set directly; + instead a tool_config injection_point is added.""" + + def _model(self): + from trpc_agent_sdk.configs import PromptCacheConfig + return LiteLLMModel( + model_name="bedrock/anthropic.claude-3-haiku", + api_key="k", + prompt_cache_config=PromptCacheConfig( + enabled=True, + breakpoints=["tools", "system"], + ), + ) + + def test_bedrock_does_not_stamp_tool_directly(self): + """Bedrock models must NOT get cache_control on the tool dict itself.""" + model = self._model() + api_params = {"tools": [{"name": "t1"}], "messages": []} + model._apply_prompt_cache(api_params, None) + assert "cache_control" not in api_params["tools"][0] + + def test_bedrock_adds_tool_config_injection_point(self): + """Bedrock tool cachePoint is expressed as {"location": "tool_config"}.""" + model = self._model() + api_params = {"tools": [{"name": "t1"}], "messages": []} + model._apply_prompt_cache(api_params, None) + points = api_params.get("cache_control_injection_points", []) + tool_config_points = [p for p in points if p.get("location") == "tool_config"] + assert len(tool_config_points) == 1 + + def test_bedrock_system_injection_point_present(self): + """system injection point is added for Bedrock, same as non-Bedrock.""" + model = self._model() + api_params = {"tools": [{"name": "t1"}], "messages": []} + model._apply_prompt_cache(api_params, None) + points = api_params.get("cache_control_injection_points", []) + system_points = [p for p in points if p.get("role") == "system"] + assert len(system_points) == 1 + + +# =========================================================================== +# Prompt cache — OpenAI-managed family request shaping +# =========================================================================== + + +class TestLiteLLMApplyPromptCacheOpenAIFamily: + """For openai-managed-family models, cache config is routed as top-level LiteLLM params.""" + + def _model(self, model_name: str, **kw): + from trpc_agent_sdk.configs import PromptCacheConfig + kw.setdefault("prompt_cache_config", PromptCacheConfig( + enabled=True, + prompt_cache_key="my-key", + ttl="24h", + )) + return LiteLLMModel(model_name=model_name, api_key="k", **kw) + + def test_openai_cache_key_and_retention_written_to_top_level_params(self): + """prompt_cache_key and prompt_cache_retention are top-level LiteLLM params.""" + model = self._model("openai/gpt-4o") + api_params: dict = {} + model._apply_prompt_cache(api_params, None) + assert api_params.get("prompt_cache_key") == "my-key" + assert api_params.get("prompt_cache_retention") == "24h" + assert "extra_body" not in api_params + + def test_openai_existing_extra_body_is_preserved(self): + """Pre-existing extra_body dict entries are preserved when cache keys are added.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = LiteLLMModel( + model_name="openai/gpt-4o", + api_key="k", + prompt_cache_config=PromptCacheConfig( + enabled=True, + prompt_cache_key="new-key", + ), + ) + api_params = {"extra_body": {"user": "alice"}} + model._apply_prompt_cache(api_params, None) + assert api_params["extra_body"] == {"user": "alice"} + assert api_params["prompt_cache_key"] == "new-key" + + def test_openai_custom_ttl_is_forwarded(self): + """OpenAI-family LiteLLM routes pass TTL through for provider adapters to handle.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = LiteLLMModel( + model_name="openai/gpt-4o", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=True, ttl="1h"), + ) + api_params: dict = {} + model._apply_prompt_cache(api_params, None) + assert api_params.get("prompt_cache_retention") == "1h" + + +# =========================================================================== +# Prompt cache — Azure OpenAI (skips prompt_cache_retention) +# =========================================================================== + + +class TestLiteLLMApplyPromptCacheAzure: + """Azure OpenAI supports prompt_cache_key but not prompt_cache_retention.""" + + def _model(self): + from trpc_agent_sdk.configs import PromptCacheConfig + return LiteLLMModel( + model_name="azure/gpt-35-turbo", + api_key="k", + prompt_cache_config=PromptCacheConfig( + enabled=True, + prompt_cache_key="az-key", + ttl="24h", + ), + ) + + def test_azure_sets_cache_key(self): + """prompt_cache_key is forwarded for Azure.""" + model = self._model() + api_params: dict = {} + model._apply_prompt_cache(api_params, None) + assert api_params.get("prompt_cache_key") == "az-key" + + def test_azure_does_not_set_prompt_cache_retention(self): + """prompt_cache_retention must NOT be set for Azure, even when TTL is provided.""" + model = self._model() + api_params: dict = {} + model._apply_prompt_cache(api_params, None) + assert "prompt_cache_retention" not in api_params + + +# =========================================================================== +# Prompt cache — unknown provider family +# =========================================================================== + + +class TestLiteLLMApplyPromptCacheUnknownFamily: + """Unknown provider prefix with enabled cache config must warn and leave params clean.""" + + def test_unknown_prefix_logs_warning_and_leaves_params_unchanged(self): + from trpc_agent_sdk.configs import PromptCacheConfig + model = LiteLLMModel( + model_name="groq/llama-3", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=True, ttl="1h"), + ) + api_params: dict = {"tools": [{"name": "t1"}]} + with patch("trpc_agent_sdk.models._litellm_model.logger") as mock_log: + model._apply_prompt_cache(api_params, None) + mock_log.warning.assert_called_once() + assert "cache_control" not in api_params.get("tools", [{}])[0] + assert "cache_control_injection_points" not in api_params + assert "extra_body" not in api_params + + +# =========================================================================== +# Prompt cache — _set_extra_body utility +# =========================================================================== + + +class TestLiteLLMSetExtraBody: + """_set_extra_body merges keys into api_params['extra_body'] correctly.""" + + def _set(self, api_params: dict, key: str, value) -> None: + LiteLLMModel._set_extra_body(api_params, key, value) + + def test_creates_extra_body_dict_when_absent(self): + api_params: dict = {} + self._set(api_params, "foo", "bar") + assert api_params["extra_body"] == {"foo": "bar"} + + def test_merges_into_existing_extra_body(self): + api_params = {"extra_body": {"x": 1}} + self._set(api_params, "y", 2) + assert api_params["extra_body"] == {"x": 1, "y": 2} + + def test_replaces_non_dict_extra_body_with_warning(self): + api_params = {"extra_body": "invalid"} + with patch("trpc_agent_sdk.models._litellm_model.logger") as mock_log: + self._set(api_params, "k", "v") + mock_log.warning.assert_called_once() + assert api_params["extra_body"] == {"k": "v"} + + +# =========================================================================== +# Prompt cache — _build_cache_injection_points +# =========================================================================== + + +class TestLiteLLMBuildCacheInjectionPoints: + """_build_cache_injection_points returns the correct point descriptors.""" + + def _build(self, model_name: str, breakpoints: list, ttl=None, messages=None): + model = LiteLLMModel(model_name=model_name, api_key="k") + return model._build_cache_injection_points(breakpoints, ttl, messages) + + def test_system_breakpoint_adds_message_role_system(self): + points = self._build("anthropic/claude-3", ["system"]) + assert any(p.get("role") == "system" for p in points) + + def test_messages_breakpoint_adds_latest_assistant_index(self): + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "again"}, + ] + points = self._build("anthropic/claude-3", ["messages"], messages=messages) + assert any(p.get("index") == 1 for p in points) + + def test_messages_breakpoint_without_assistant_adds_nothing(self): + messages = [{"role": "user", "content": "hi"}] + points = self._build("anthropic/claude-3", ["messages"], messages=messages) + assert points == [] + + def test_tools_breakpoint_bedrock_adds_tool_config(self): + points = self._build("bedrock/anthropic.claude", ["tools"]) + assert any(p.get("location") == "tool_config" for p in points) + + def test_tools_breakpoint_non_bedrock_adds_nothing(self): + """For non-Bedrock providers, tools are stamped directly on the tool; no injection point.""" + points = self._build("anthropic/claude-3", ["tools"]) + assert not any(p.get("location") == "tool_config" for p in points) + + def test_ttl_is_included_in_control_dict(self): + points = self._build("anthropic/claude-3", ["system"], ttl="1h") + system_points = [p for p in points if p.get("role") == "system"] + assert len(system_points) == 1 + assert system_points[0]["control"]["ttl"] == "1h" + + def test_no_ttl_produces_ephemeral_only_control(self): + points = self._build("anthropic/claude-3", ["system"], ttl=None) + system_points = [p for p in points if p.get("role") == "system"] + assert system_points[0]["control"] == {"type": "ephemeral"} + + def test_all_non_bedrock_breakpoints_no_tool_config_point(self): + """All three breakpoints for a non-Bedrock provider: no tool_config point.""" + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + points = self._build("anthropic/claude-3", ["tools", "system", "messages"], messages=messages) + assert not any(p.get("location") == "tool_config" for p in points) + assert any(p.get("role") == "system" for p in points) + assert any(p.get("index") == 1 for p in points) diff --git a/tests/models/test_openai_model.py b/tests/models/test_openai_model.py index d806e7bd..7f15ea0d 100644 --- a/tests/models/test_openai_model.py +++ b/tests/models/test_openai_model.py @@ -597,3 +597,186 @@ async def test_generate_async_empty_response_content(self): assert len(responses) == 1 assert responses[0].content is not None + + +# =========================================================================== +# Prompt cache — request field injection +# =========================================================================== + + +class TestOpenAIPromptCacheRequestFields: + """Verify prompt cache fields are added to api_params when cache config is active.""" + + def _make_api_params(self) -> dict: + """Minimal api_params skeleton similar to what OpenAIModel builds.""" + return {"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]} + + def _simulate_cache_injection(self, model, api_params: dict) -> None: + """Replicate the inline cache injection logic from _generate_async_impl.""" + from trpc_agent_sdk.models._openai_model import ApiParamsKey + cache_config = model._resolve_prompt_cache_config(None) + if cache_config: + if cache_config.prompt_cache_key: + api_params[ApiParamsKey.PROMPT_CACHE_KEY] = cache_config.prompt_cache_key + if cache_config.ttl: + api_params[ApiParamsKey.PROMPT_CACHE_RETENTION] = cache_config.ttl + + def test_enabled_config_adds_cache_key_and_retention(self): + """Both prompt_cache_key and prompt_cache_retention are forwarded when set.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = OpenAIModel( + model_name="gpt-4", + api_key="k", + prompt_cache_config=PromptCacheConfig( + enabled=True, + prompt_cache_key="weather-v1", + ttl="24h", + ), + ) + api_params = self._make_api_params() + self._simulate_cache_injection(model, api_params) + assert api_params.get("prompt_cache_key") == "weather-v1" + assert api_params.get("prompt_cache_retention") == "24h" + + def test_ttl_in_memory_is_forwarded(self): + """'in_memory' is forwarded as prompt_cache_retention.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = OpenAIModel( + model_name="gpt-4", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=True, ttl="in_memory"), + ) + api_params = self._make_api_params() + self._simulate_cache_injection(model, api_params) + assert api_params.get("prompt_cache_retention") == "in_memory" + + def test_custom_ttl_is_forwarded(self): + """TTL is provider-specific and should be forwarded without SDK validation.""" + from trpc_agent_sdk.configs import PromptCacheConfig + model = OpenAIModel( + model_name="gpt-4", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=True, ttl="1h"), + ) + api_params = self._make_api_params() + self._simulate_cache_injection(model, api_params) + assert api_params.get("prompt_cache_retention") == "1h" + + def test_disabled_config_adds_no_cache_fields(self): + """Disabled PromptCacheConfig must not inject any cache-related keys.""" + from trpc_agent_sdk.configs import PromptCacheConfig + from trpc_agent_sdk.models._openai_model import ApiParamsKey + model = OpenAIModel( + model_name="gpt-4", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=False, prompt_cache_key="k", ttl="24h"), + ) + api_params = self._make_api_params() + self._simulate_cache_injection(model, api_params) + assert ApiParamsKey.PROMPT_CACHE_KEY not in api_params + assert ApiParamsKey.PROMPT_CACHE_RETENTION not in api_params + + def test_no_config_adds_no_cache_fields(self): + """No model-level config means no cache keys are added.""" + from trpc_agent_sdk.models._openai_model import ApiParamsKey + model = OpenAIModel(model_name="gpt-4", api_key="k") + api_params = self._make_api_params() + self._simulate_cache_injection(model, api_params) + assert ApiParamsKey.PROMPT_CACHE_KEY not in api_params + assert ApiParamsKey.PROMPT_CACHE_RETENTION not in api_params + + def test_config_without_cache_key_omits_key_field(self): + """prompt_cache_key not in api_params when config has no prompt_cache_key.""" + from trpc_agent_sdk.configs import PromptCacheConfig + from trpc_agent_sdk.models._openai_model import ApiParamsKey + model = OpenAIModel( + model_name="gpt-4", + api_key="k", + prompt_cache_config=PromptCacheConfig(enabled=True, ttl="24h"), + ) + api_params = self._make_api_params() + self._simulate_cache_injection(model, api_params) + assert ApiParamsKey.PROMPT_CACHE_KEY not in api_params + assert api_params.get("prompt_cache_retention") == "24h" + + +# =========================================================================== +# Prompt cache — usage metadata parsing +# =========================================================================== + + +class TestOpenAIBuildUsageMetadata: + """Tests for OpenAIModel._build_usage_metadata cache token normalization.""" + + def test_prompt_tokens_details_cached_tokens_mapped_to_cache_read(self): + """OpenAI prompt_tokens_details.cached_tokens maps to cache_read_input_tokens.""" + usage_data = { + "prompt_tokens": 1000, + "completion_tokens": 50, + "total_tokens": 1050, + "prompt_tokens_details": { + "cached_tokens": 800 + }, + } + meta = OpenAIModel._build_usage_metadata(usage_data) + assert meta.cache_read_input_tokens == 800 + assert meta.prompt_token_count == 1000 + assert meta.candidates_token_count == 50 + + def test_top_level_cache_read_preferred_over_details(self): + """If top-level cache_read_input_tokens is set, it wins over prompt_tokens_details.""" + usage_data = { + "prompt_tokens": 1000, + "completion_tokens": 50, + "total_tokens": 1050, + "cache_read_input_tokens": 600, + "prompt_tokens_details": { + "cached_tokens": 800 + }, + } + meta = OpenAIModel._build_usage_metadata(usage_data) + assert meta.cache_read_input_tokens == 600 + + def test_no_cache_fields_yields_none(self): + """When no cache fields are present, cache_read_input_tokens is None.""" + usage_data = { + "prompt_tokens": 100, + "completion_tokens": 20, + "total_tokens": 120, + } + meta = OpenAIModel._build_usage_metadata(usage_data) + assert meta.cache_read_input_tokens is None + assert meta.cache_creation_input_tokens is None + + def test_cache_creation_input_tokens_top_level(self): + """top-level cache_creation_input_tokens (LiteLLM-compatible) is forwarded.""" + usage_data = { + "prompt_tokens": 100, + "completion_tokens": 10, + "total_tokens": 110, + "cache_creation_input_tokens": 90, + } + meta = OpenAIModel._build_usage_metadata(usage_data) + assert meta.cache_creation_input_tokens == 90 + + def test_empty_prompt_tokens_details_does_not_crash(self): + """Empty prompt_tokens_details dict is handled safely.""" + usage_data = { + "prompt_tokens": 50, + "completion_tokens": 10, + "total_tokens": 60, + "prompt_tokens_details": {}, + } + meta = OpenAIModel._build_usage_metadata(usage_data) + assert meta.cache_read_input_tokens is None + + def test_null_prompt_tokens_details_does_not_crash(self): + """Explicit null prompt_tokens_details is handled safely.""" + usage_data = { + "prompt_tokens": 50, + "completion_tokens": 10, + "total_tokens": 60, + "prompt_tokens_details": None, + } + meta = OpenAIModel._build_usage_metadata(usage_data) + assert meta.cache_read_input_tokens is None diff --git a/tests/models/test_prompt_cache_config.py b/tests/models/test_prompt_cache_config.py new file mode 100644 index 00000000..f2eebff2 --- /dev/null +++ b/tests/models/test_prompt_cache_config.py @@ -0,0 +1,204 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Unit tests for PromptCacheConfig and LLMModel._resolve_prompt_cache_config. + +These tests verify: +- PromptCacheConfig default field values +- Resolution priority: model-level vs run-level config +- Merge semantics: run config overrides only explicitly set fields +- Disabled configs are suppressed +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from trpc_agent_sdk.configs import PromptCacheConfig, RunConfig +from trpc_agent_sdk.models import AnthropicModel + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _model(**kwargs) -> AnthropicModel: + """Return a minimal AnthropicModel (concrete subclass of LLMModel for resolution tests).""" + kwargs.setdefault("model_name", "claude-3-5-sonnet-20241022") + kwargs.setdefault("api_key", "test-key") + return AnthropicModel(**kwargs) + + +def _ctx(prompt_cache: PromptCacheConfig | None = None) -> MagicMock: + """Return a lightweight mock InvocationContext with a RunConfig.""" + ctx = MagicMock() + ctx.run_config = RunConfig(prompt_cache=prompt_cache) + return ctx + + +# --------------------------------------------------------------------------- +# PromptCacheConfig defaults +# --------------------------------------------------------------------------- + + +class TestPromptCacheConfigDefaults: + """PromptCacheConfig should have sensible defaults out of the box.""" + + def test_disabled_by_default(self): + cfg = PromptCacheConfig() + assert cfg.enabled is False + + def test_breakpoints_default_to_system(self): + cfg = PromptCacheConfig() + assert cfg.breakpoints == ["system"] + + def test_ttl_default_is_none(self): + cfg = PromptCacheConfig() + assert cfg.ttl is None + + def test_prompt_cache_key_default_is_none(self): + cfg = PromptCacheConfig() + assert cfg.prompt_cache_key is None + + +# --------------------------------------------------------------------------- +# _resolve_prompt_cache_config: no config present +# --------------------------------------------------------------------------- + + +class TestResolvePromptCacheConfigNoConfig: + """When neither model-level nor run-level config is set, resolver returns None.""" + + def test_no_model_config_no_ctx_returns_none(self): + model = _model() + assert model._resolve_prompt_cache_config(None) is None + + def test_no_model_config_ctx_without_run_cache_returns_none(self): + model = _model() + ctx = _ctx(prompt_cache=None) + assert model._resolve_prompt_cache_config(ctx) is None + + +# --------------------------------------------------------------------------- +# _resolve_prompt_cache_config: disabled configs +# --------------------------------------------------------------------------- + + +class TestResolvePromptCacheConfigDisabled: + """Disabled configs (enabled=False) must always return None.""" + + def test_model_config_disabled_returns_none(self): + cache_cfg = PromptCacheConfig(enabled=False) + model = _model(prompt_cache_config=cache_cfg) + assert model._resolve_prompt_cache_config(None) is None + + def test_run_config_disabled_returns_none(self): + run_cache = PromptCacheConfig(enabled=False) + model = _model() + ctx = _ctx(prompt_cache=run_cache) + assert model._resolve_prompt_cache_config(ctx) is None + + def test_model_config_enabled_run_config_disabled_returns_none(self): + """Per-run enabled=False must suppress a model-level enabled config.""" + model_cache = PromptCacheConfig(enabled=True, ttl="1h", breakpoints=["system"]) + run_cache = PromptCacheConfig(enabled=False) + model = _model(prompt_cache_config=model_cache) + ctx = _ctx(prompt_cache=run_cache) + assert model._resolve_prompt_cache_config(ctx) is None + + +# --------------------------------------------------------------------------- +# _resolve_prompt_cache_config: enabled configs +# --------------------------------------------------------------------------- + + +class TestResolvePromptCacheConfigEnabled: + """Enabled configs are returned and merged correctly.""" + + def test_model_level_enabled_no_ctx(self): + """Model-level enabled config is returned when there is no run context.""" + cache_cfg = PromptCacheConfig(enabled=True, ttl="1h", breakpoints=["system", "tools"]) + model = _model(prompt_cache_config=cache_cfg) + resolved = model._resolve_prompt_cache_config(None) + assert resolved is not None + assert resolved.enabled is True + assert resolved.ttl == "1h" + assert resolved.breakpoints == ["system", "tools"] + + def test_run_level_enabled_no_model_config(self): + """Run-level config is used when no model-level config exists.""" + run_cache = PromptCacheConfig(enabled=True, ttl="5m") + model = _model() + ctx = _ctx(prompt_cache=run_cache) + resolved = model._resolve_prompt_cache_config(ctx) + assert resolved is not None + assert resolved.enabled is True + assert resolved.ttl == "5m" + + def test_model_level_enabled_ctx_with_no_run_cache(self): + """Model-level config is used when ctx has no run-level cache config.""" + cache_cfg = PromptCacheConfig(enabled=True, prompt_cache_key="my-key") + model = _model(prompt_cache_config=cache_cfg) + ctx = _ctx(prompt_cache=None) + resolved = model._resolve_prompt_cache_config(ctx) + assert resolved is not None + assert resolved.prompt_cache_key == "my-key" + + +# --------------------------------------------------------------------------- +# _resolve_prompt_cache_config: merge semantics +# --------------------------------------------------------------------------- + + +class TestResolvePromptCacheConfigMerge: + """Run config overrides only explicitly-set fields; model baseline is preserved.""" + + def test_run_config_overrides_ttl_only(self): + """Run config sets ttl; model's breakpoints and prompt_cache_key are preserved.""" + model_cache = PromptCacheConfig( + enabled=True, + ttl="1h", + breakpoints=["system", "tools"], + prompt_cache_key="base-key", + ) + # Run config only explicitly sets ttl (enabled must be True to pass resolver) + run_cache = PromptCacheConfig(enabled=True, ttl="5m") + model = _model(prompt_cache_config=model_cache) + ctx = _ctx(prompt_cache=run_cache) + resolved = model._resolve_prompt_cache_config(ctx) + assert resolved is not None + # run overrides ttl + assert resolved.ttl == "5m" + # model baseline fields not set in run config are preserved + assert resolved.breakpoints == ["system", "tools"] + assert resolved.prompt_cache_key == "base-key" + + def test_run_config_overrides_prompt_cache_key_only(self): + """Run config sets prompt_cache_key; model's ttl and breakpoints are preserved.""" + model_cache = PromptCacheConfig( + enabled=True, + ttl="1h", + breakpoints=["system"], + ) + run_cache = PromptCacheConfig(enabled=True, prompt_cache_key="override-key") + model = _model(prompt_cache_config=model_cache) + ctx = _ctx(prompt_cache=run_cache) + resolved = model._resolve_prompt_cache_config(ctx) + assert resolved is not None + assert resolved.prompt_cache_key == "override-key" + assert resolved.ttl == "1h" + assert resolved.breakpoints == ["system"] + + def test_run_config_overrides_breakpoints(self): + """Run config breakpoints field overrides the model-level list.""" + model_cache = PromptCacheConfig(enabled=True, breakpoints=["system"]) + run_cache = PromptCacheConfig(enabled=True, breakpoints=["tools", "messages"]) + model = _model(prompt_cache_config=model_cache) + ctx = _ctx(prompt_cache=run_cache) + resolved = model._resolve_prompt_cache_config(ctx) + assert resolved is not None + assert resolved.breakpoints == ["tools", "messages"] diff --git a/tests/telemetry/test_metrics.py b/tests/telemetry/test_metrics.py index d02c086c..fc96ec07 100644 --- a/tests/telemetry/test_metrics.py +++ b/tests/telemetry/test_metrics.py @@ -61,9 +61,17 @@ def __init__(self, model: str): class _StubUsage: - def __init__(self, prompt: int, total: int): + def __init__( + self, + prompt: int, + total: int, + cache_read: Optional[int] = None, + cache_creation: Optional[int] = None, + ): self.prompt_token_count = prompt self.total_token_count = total + self.cache_read_input_tokens = cache_read + self.cache_creation_input_tokens = cache_creation class _StubLlmResponse: @@ -97,6 +105,8 @@ def reader_provider(monkeypatch): "_time_to_first_token": tmetrics._time_to_first_token, "_usage_input_tokens": tmetrics._usage_input_tokens, "_usage_output_tokens": tmetrics._usage_output_tokens, + "_usage_cache_read_tokens": tmetrics._usage_cache_read_tokens, + "_usage_cache_creation_tokens": tmetrics._usage_cache_creation_tokens, } monkeypatch.setattr( tmetrics, @@ -123,6 +133,16 @@ def reader_provider(monkeypatch): "_usage_output_tokens", meter.create_histogram("gen_ai.usage.output_tokens"), ) + monkeypatch.setattr( + tmetrics, + "_usage_cache_read_tokens", + meter.create_histogram("gen_ai.usage.cache_read_input_tokens"), + ) + monkeypatch.setattr( + tmetrics, + "_usage_cache_creation_tokens", + meter.create_histogram("gen_ai.usage.cache_creation_input_tokens"), + ) yield reader, provider @@ -276,6 +296,27 @@ def test_usage_tokens_emitted_when_usage_metadata_present(self, reader_provider) assert inp.sum == 120 assert out.sum == 50 + def test_cache_usage_tokens_emitted_when_present(self, reader_provider): + reader, _ = reader_provider + ctx = _make_ctx() + tmetrics.report_call_llm( + ctx, + _StubLlmRequest("claude-3-5-sonnet"), + _StubLlmResponse( + model="claude-3-5-sonnet", + usage=_StubUsage(prompt=120, total=170, cache_read=80, cache_creation=40), + ), + duration_s=2.0, + ttft_s=0.3, + is_stream=False, + ) + metrics = _collect(reader) + + cache_read = metrics["gen_ai.usage.cache_read_input_tokens"][0] + cache_creation = metrics["gen_ai.usage.cache_creation_input_tokens"][0] + assert cache_read.sum == 80 + assert cache_creation.sum == 40 + def test_usage_tokens_skipped_when_missing(self, reader_provider): reader, _ = reader_provider ctx = _make_ctx() diff --git a/tests/telemetry/test_trace.py b/tests/telemetry/test_trace.py index 46ce8eaa..27cf2a35 100644 --- a/tests/telemetry/test_trace.py +++ b/tests/telemetry/test_trace.py @@ -37,11 +37,11 @@ trace_tool_call, ) - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _mock_span(): """Return a MagicMock that acts as an OpenTelemetry span.""" span = MagicMock() @@ -66,9 +66,13 @@ def _make_content(parts=None, role="user"): return content -def _make_invocation_context(agent_name="test_agent", session_id="sess-1", - user_id="user-1", user_content=None, branch=None, - invocation_id="inv-1", session=None): +def _make_invocation_context(agent_name="test_agent", + session_id="sess-1", + user_id="user-1", + user_content=None, + branch=None, + invocation_id="inv-1", + session=None): ctx = MagicMock() ctx.agent = MagicMock() ctx.agent.name = agent_name @@ -104,7 +108,9 @@ def _make_function_response(resp_id="fc-1", response=None): # Tests: set/get span name # --------------------------------------------------------------------------- + class TestSpanName: + def setup_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -127,7 +133,9 @@ def teardown_method(self): # Tests: _safe_json_serialize # --------------------------------------------------------------------------- + class TestSafeJsonSerialize: + def test_serialize_dict(self): result = _safe_json_serialize({"key": "value"}) assert json.loads(result) == {"key": "value"} @@ -173,7 +181,9 @@ def test_serialize_empty_dict(self): # Tests: trace_runner # --------------------------------------------------------------------------- + class TestTraceRunner: + def setup_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -277,12 +287,8 @@ def test_with_state_begin_and_end(self, mock_get_span): trace_runner("app", "u", "s", ctx, state_begin={"k": 1}, state_end={"k": 2}) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.begin", _safe_json_serialize({"k": 1}) - ) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.end", _safe_json_serialize({"k": 2}) - ) + span.set_attribute.assert_any_call("trpc.python.agent.state.begin", _safe_json_serialize({"k": 1})) + span.set_attribute.assert_any_call("trpc.python.agent.state.end", _safe_json_serialize({"k": 2})) @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_state_none_not_set(self, mock_get_span): @@ -327,7 +333,9 @@ def teardown_method(self): # Tests: trace_cancellation # --------------------------------------------------------------------------- + class TestTraceCancellation: + def setup_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -348,12 +356,8 @@ def test_basic_cancellation(self, mock_get_span): from opentelemetry import trace as ot_trace span.set_status.assert_called_once_with(ot_trace.StatusCode.ERROR, "user_cancelled") span.set_attribute.assert_any_call("gen_ai.operation.name", "run_runner_cancelled") - span.set_attribute.assert_any_call( - "trpc.python.agent.cancellation.reason", "user_cancelled" - ) - span.set_attribute.assert_any_call( - "trpc.python.agent.cancellation.agent_name", "test_agent" - ) + span.set_attribute.assert_any_call("trpc.python.agent.cancellation.reason", "user_cancelled") + span.set_attribute.assert_any_call("trpc.python.agent.cancellation.agent_name", "test_agent") @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_with_partial_text(self, mock_get_span): @@ -363,9 +367,7 @@ def test_with_partial_text(self, mock_get_span): trace_cancellation("app", "u", "s", ctx, reason="timeout", partial_text="partial output") - span.set_attribute.assert_any_call( - "trpc.python.agent.runner.output", "[CANCELLED]\npartial output" - ) + span.set_attribute.assert_any_call("trpc.python.agent.runner.output", "[CANCELLED]\npartial output") @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_with_last_event_no_partial(self, mock_get_span): @@ -379,9 +381,7 @@ def test_with_last_event_no_partial(self, mock_get_span): trace_cancellation("app", "u", "s", ctx, reason="err", last_event=last_event) - span.set_attribute.assert_any_call( - "trpc.python.agent.runner.output", "[CANCELLED]\nevent text" - ) + span.set_attribute.assert_any_call("trpc.python.agent.runner.output", "[CANCELLED]\nevent text") @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_partial_text_takes_priority_over_last_event(self, mock_get_span): @@ -394,15 +394,16 @@ def test_partial_text_takes_priority_over_last_event(self, mock_get_span): last_event = _make_event(content=event_content) trace_cancellation( - "app", "u", "s", ctx, + "app", + "u", + "s", + ctx, reason="err", partial_text="winner", last_event=last_event, ) - span.set_attribute.assert_any_call( - "trpc.python.agent.runner.output", "[CANCELLED]\nwinner" - ) + span.set_attribute.assert_any_call("trpc.python.agent.runner.output", "[CANCELLED]\nwinner") @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_with_branch(self, mock_get_span): @@ -412,9 +413,7 @@ def test_with_branch(self, mock_get_span): trace_cancellation("app", "u", "s", ctx, reason="cancel") - span.set_attribute.assert_any_call( - "trpc.python.agent.cancellation.branch", "agent_a.agent_b" - ) + span.set_attribute.assert_any_call("trpc.python.agent.cancellation.branch", "agent_a.agent_b") @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_no_branch(self, mock_get_span): @@ -434,18 +433,17 @@ def test_with_state_begin_and_partial(self, mock_get_span): ctx = _make_invocation_context() trace_cancellation( - "app", "u", "s", ctx, + "app", + "u", + "s", + ctx, reason="cancel", state_begin={"a": 1}, state_partial={"a": 2}, ) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.begin", _safe_json_serialize({"a": 1}) - ) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.partial", _safe_json_serialize({"a": 2}) - ) + span.set_attribute.assert_any_call("trpc.python.agent.state.begin", _safe_json_serialize({"a": 1})) + span.set_attribute.assert_any_call("trpc.python.agent.state.partial", _safe_json_serialize({"a": 2})) @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_state_none_not_set(self, mock_get_span): @@ -490,7 +488,9 @@ def teardown_method(self): # Tests: trace_agent # --------------------------------------------------------------------------- + class TestTraceAgent: + def setup_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -562,12 +562,8 @@ def test_with_state(self, mock_get_span): trace_agent(ctx, state_begin={"x": 1}, state_end={"x": 2}) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.begin", _safe_json_serialize({"x": 1}) - ) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.end", _safe_json_serialize({"x": 2}) - ) + span.set_attribute.assert_any_call("trpc.python.agent.state.begin", _safe_json_serialize({"x": 1})) + span.set_attribute.assert_any_call("trpc.python.agent.state.end", _safe_json_serialize({"x": 2})) @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_default_empty_action(self, mock_get_span): @@ -587,7 +583,9 @@ def teardown_method(self): # Tests: trace_tool_call # --------------------------------------------------------------------------- + class TestTraceToolCall: + def setup_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -649,8 +647,7 @@ def test_tool_response_not_dict(self, mock_get_span): # Should wrap in {"result": ...} tool_resp_calls = [ - c for c in span.set_attribute.call_args_list - if c.args[0] == "trpc.python.agent.tool_response" + c for c in span.set_attribute.call_args_list if c.args[0] == "trpc.python.agent.tool_response" ] assert len(tool_resp_calls) == 1 parsed = json.loads(tool_resp_calls[0].args[1]) @@ -679,8 +676,7 @@ class MyModel(PydanticBaseModel): trace_tool_call(tool=tool, args={}, function_response_event=event) tool_resp_calls = [ - c for c in span.set_attribute.call_args_list - if c.args[0] == "trpc.python.agent.tool_response" + c for c in span.set_attribute.call_args_list if c.args[0] == "trpc.python.agent.tool_response" ] assert len(tool_resp_calls) == 1 parsed = json.loads(tool_resp_calls[0].args[1]) @@ -720,16 +716,15 @@ def test_with_state(self, mock_get_span): event = _make_event(content=content) trace_tool_call( - tool=tool, args={}, function_response_event=event, - state_begin={"s": 0}, state_end={"s": 1}, + tool=tool, + args={}, + function_response_event=event, + state_begin={"s": 0}, + state_end={"s": 1}, ) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.begin", _safe_json_serialize({"s": 0}) - ) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.end", _safe_json_serialize({"s": 1}) - ) + span.set_attribute.assert_any_call("trpc.python.agent.state.begin", _safe_json_serialize({"s": 0})) + span.set_attribute.assert_any_call("trpc.python.agent.state.end", _safe_json_serialize({"s": 1})) def teardown_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -739,7 +734,9 @@ def teardown_method(self): # Tests: trace_merged_tool_calls # --------------------------------------------------------------------------- + class TestTraceMergedToolCalls: + def setup_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -772,9 +769,7 @@ def test_non_serializable_event(self, mock_get_span): trace_merged_tool_calls(response_event_id="re-1", function_response_event=event) - span.set_attribute.assert_any_call( - "trpc.python.agent.tool_response", "" - ) + span.set_attribute.assert_any_call("trpc.python.agent.tool_response", "") @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_sets_empty_llm_request_and_response(self, mock_get_span): @@ -804,12 +799,8 @@ def test_with_state(self, mock_get_span): state_end={"a": 2}, ) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.begin", _safe_json_serialize({"a": 1}) - ) - span.set_attribute.assert_any_call( - "trpc.python.agent.state.end", _safe_json_serialize({"a": 2}) - ) + span.set_attribute.assert_any_call("trpc.python.agent.state.begin", _safe_json_serialize({"a": 1})) + span.set_attribute.assert_any_call("trpc.python.agent.state.end", _safe_json_serialize({"a": 2})) @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") def test_state_none_not_set(self, mock_get_span): @@ -833,7 +824,9 @@ def teardown_method(self): # Tests: trace_call_llm # --------------------------------------------------------------------------- + class TestTraceCallLlm: + def setup_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -929,15 +922,16 @@ def test_with_instruction_metadata(self, mock_get_span): ctx = _make_invocation_context() from trpc_agent_sdk.types import InstructionMetadata - metadata = InstructionMetadata( - name="my_instruction", version=3, labels=["prod", "v2"] - ) + metadata = InstructionMetadata(name="my_instruction", version=3, labels=["prod", "v2"]) req = self._make_llm_request() resp = self._make_llm_response() trace_call_llm( - ctx, event_id="e-1", llm_request=req, llm_response=resp, + ctx, + event_id="e-1", + llm_request=req, + llm_response=resp, instruction_metadata=metadata, ) @@ -971,10 +965,7 @@ def test_non_serializable_llm_response(self, mock_get_span): trace_call_llm(ctx, event_id="e-1", llm_request=req, llm_response=resp) - llm_resp_calls = [ - c for c in span.set_attribute.call_args_list - if c.args[0] == "trpc.python.agent.llm_response" - ] + llm_resp_calls = [c for c in span.set_attribute.call_args_list if c.args[0] == "trpc.python.agent.llm_response"] assert len(llm_resp_calls) == 1 assert llm_resp_calls[0].args[1] == "" @@ -991,12 +982,100 @@ def test_instruction_metadata_empty_labels(self, mock_get_span): resp = self._make_llm_response() trace_call_llm( - ctx, event_id="e-1", llm_request=req, llm_response=resp, + ctx, + event_id="e-1", + llm_request=req, + llm_response=resp, instruction_metadata=metadata, ) span.set_attribute.assert_any_call("trpc.python.agent.instruction.labels", "") + # --- prompt cache token span attributes -------------------------------- + + @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") + def test_cache_read_input_tokens_attribute_set_when_present(self, mock_get_span): + """cache_read_input_tokens is written as a span attribute when the field is not None.""" + span = _mock_span() + mock_get_span.return_value = span + ctx = _make_invocation_context() + + usage = MagicMock() + usage.prompt_token_count = 1000 + usage.total_token_count = 1050 + usage.cache_read_input_tokens = 800 + usage.cache_creation_input_tokens = None + + req = self._make_llm_request() + resp = self._make_llm_response(usage=usage) + + trace_call_llm(ctx, event_id="e-1", llm_request=req, llm_response=resp) + + span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 800) + + @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") + def test_cache_creation_input_tokens_attribute_set_when_present(self, mock_get_span): + """cache_creation_input_tokens is written as a span attribute when the field is not None.""" + span = _mock_span() + mock_get_span.return_value = span + ctx = _make_invocation_context() + + usage = MagicMock() + usage.prompt_token_count = 1000 + usage.total_token_count = 1100 + usage.cache_read_input_tokens = None + usage.cache_creation_input_tokens = 200 + + req = self._make_llm_request() + resp = self._make_llm_response(usage=usage) + + trace_call_llm(ctx, event_id="e-1", llm_request=req, llm_response=resp) + + span.set_attribute.assert_any_call("gen_ai.usage.cache_creation_input_tokens", 200) + + @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") + def test_both_cache_token_attributes_set(self, mock_get_span): + """Both cache span attributes are emitted when both fields are non-None.""" + span = _mock_span() + mock_get_span.return_value = span + ctx = _make_invocation_context() + + usage = MagicMock() + usage.prompt_token_count = 1200 + usage.total_token_count = 1250 + usage.cache_read_input_tokens = 900 + usage.cache_creation_input_tokens = 300 + + req = self._make_llm_request() + resp = self._make_llm_response(usage=usage) + + trace_call_llm(ctx, event_id="e-1", llm_request=req, llm_response=resp) + + span.set_attribute.assert_any_call("gen_ai.usage.cache_read_input_tokens", 900) + span.set_attribute.assert_any_call("gen_ai.usage.cache_creation_input_tokens", 300) + + @patch("trpc_agent_sdk.telemetry._trace.trace.get_current_span") + def test_cache_token_attributes_absent_when_none(self, mock_get_span): + """Neither cache attribute is emitted when both fields are None.""" + span = _mock_span() + mock_get_span.return_value = span + ctx = _make_invocation_context() + + usage = MagicMock() + usage.prompt_token_count = 100 + usage.total_token_count = 150 + usage.cache_read_input_tokens = None + usage.cache_creation_input_tokens = None + + req = self._make_llm_request() + resp = self._make_llm_response(usage=usage) + + trace_call_llm(ctx, event_id="e-1", llm_request=req, llm_response=resp) + + attr_keys = [call.args[0] for call in span.set_attribute.call_args_list] + assert "gen_ai.usage.cache_read_input_tokens" not in attr_keys + assert "gen_ai.usage.cache_creation_input_tokens" not in attr_keys + def teardown_method(self): set_trpc_agent_span_name("trpc.python.agent") @@ -1005,7 +1084,9 @@ def teardown_method(self): # Tests: _build_llm_request_for_trace # --------------------------------------------------------------------------- + class TestBuildLlmRequestForTrace: + def test_basic_build(self): content = MagicMock() content.role = "user" @@ -1022,9 +1103,7 @@ def test_basic_build(self): with patch("trpc_agent_sdk.telemetry._trace.Content") as MockContent: mock_content_instance = MagicMock() - mock_content_instance.model_dump = MagicMock( - return_value={"role": "user", "parts": [{"text": "hello"}]} - ) + mock_content_instance.model_dump = MagicMock(return_value={"role": "user", "parts": [{"text": "hello"}]}) MockContent.return_value = mock_content_instance result = _build_llm_request_for_trace(req) @@ -1053,9 +1132,7 @@ def test_filters_inline_data_parts(self): with patch("trpc_agent_sdk.telemetry._trace.Content") as MockContent: mock_content_instance = MagicMock() - mock_content_instance.model_dump = MagicMock( - return_value={"role": "user", "parts": [{"text": "t"}]} - ) + mock_content_instance.model_dump = MagicMock(return_value={"role": "user", "parts": [{"text": "t"}]}) MockContent.return_value = mock_content_instance result = _build_llm_request_for_trace(req) diff --git a/trpc_agent_sdk/abc/_response.py b/trpc_agent_sdk/abc/_response.py index 86d6a718..b6fedf27 100644 --- a/trpc_agent_sdk/abc/_response.py +++ b/trpc_agent_sdk/abc/_response.py @@ -30,8 +30,9 @@ from google.genai.types import Content from google.genai.types import GenerateContentResponse -from google.genai.types import GenerateContentResponseUsageMetadata from google.genai.types import GroundingMetadata + +from trpc_agent_sdk.types import GenerateContentResponseUsageMetadata from pydantic import BaseModel from pydantic import ConfigDict from pydantic import alias_generators diff --git a/trpc_agent_sdk/configs/__init__.py b/trpc_agent_sdk/configs/__init__.py index c3b8f3f5..3b1c56c5 100644 --- a/trpc_agent_sdk/configs/__init__.py +++ b/trpc_agent_sdk/configs/__init__.py @@ -5,8 +5,10 @@ # tRPC-Agent-Python is licensed under Apache-2.0. """Configs for TRPC Agent framework.""" +from ._prompt_cache_config import PromptCacheConfig from ._run_config import RunConfig __all__ = [ + "PromptCacheConfig", "RunConfig", ] diff --git a/trpc_agent_sdk/configs/_prompt_cache_config.py b/trpc_agent_sdk/configs/_prompt_cache_config.py new file mode 100644 index 00000000..f116737b --- /dev/null +++ b/trpc_agent_sdk/configs/_prompt_cache_config.py @@ -0,0 +1,72 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Prompt cache configuration for TRPC Agent framework.""" + +from __future__ import annotations + +from typing import List +from typing import Literal +from typing import Optional + +from pydantic import BaseModel +from pydantic import Field + + +class PromptCacheConfig(BaseModel): + """Cross-provider prompt cache configuration. + + This is a single flat config for SDK-managed prompt cache customization. + Many providers already enable prompt caching automatically; for those + providers this config only supplies optional hints such as cache keys, + retention, or usage normalization. Fields are applied on a best-effort basis + depending on the resolved provider. Fields that do not apply to a given + provider are silently ignored (no error), so the same config "just works" + across Anthropic, OpenAI, and the LiteLLM channel. + + Default ``enabled=False`` means the SDK does not add cache-specific request + customization: it injects no ``cache_control`` and sends no + ``prompt_cache_key`` / ``prompt_cache_retention``. Provider-native automatic + prompt caching, when available, may still happen independently. + """ + + enabled: bool = False + """Master switch for SDK-managed prompt cache customization.""" + + ttl: Optional[str] = None + """Provider-specific cache lifetime hint. + + The SDK does not validate TTL values because supported values vary across + providers, deployments, and self-hosted OpenAI-compatible services. When set, + the value is forwarded to the resolved provider's cache TTL field; providers + may accept, ignore, or reject it. ``None`` means "do not send a lifetime + hint" (provider default). + """ + + breakpoints: List[Literal["tools", "system", "messages"]] = Field(default_factory=lambda: ["system"]) + """Cache-control injection points for Anthropic-style providers. + + Used by native Anthropic and LiteLLM models routed to the Anthropic cache + family; ignored by OpenAI-managed providers. Current injection behavior: + + - ``"tools"``: stamp the last tool with ``cache_control``. For LiteLLM + Bedrock models this is represented as a ``tool_config`` cache point. + - ``"system"``: stamp the system prompt/system message. + - ``"messages"``: stamp one conversation-message breakpoint on the most + recent assistant message, keeping the current user turn outside the cached + prefix. LiteLLM uses its ``cache_control_injection_points`` support to + target that assistant message by index. + + An empty list means the SDK does not add Anthropic-style cache-control + injection points. User-authored provider-specific cache metadata is still + forwarded when supported by the underlying model adapter. + """ + + prompt_cache_key: Optional[str] = None + """OpenAI-managed family only. + + Improves cache-hit stability by keeping same-prefix requests sticky to the + same backend. Only used by OpenAI-managed providers. + """ diff --git a/trpc_agent_sdk/configs/_run_config.py b/trpc_agent_sdk/configs/_run_config.py index 4c04ad7f..50f16d58 100644 --- a/trpc_agent_sdk/configs/_run_config.py +++ b/trpc_agent_sdk/configs/_run_config.py @@ -9,6 +9,7 @@ import sys from typing import Any +from typing import Optional from pydantic import BaseModel from pydantic import ConfigDict @@ -17,6 +18,8 @@ from trpc_agent_sdk.log import logger +from ._prompt_cache_config import PromptCacheConfig + class RunConfig(BaseModel): """Configs for runtime behavior of agents.""" @@ -53,6 +56,14 @@ class RunConfig(BaseModel): save_history_enabled: bool = False """ Save history enabled.""" + prompt_cache: Optional[PromptCacheConfig] = None + """Per-run prompt cache configuration override. + + When set, this takes precedence over the model-level ``prompt_cache_config`` + for the duration of the run. When ``None``, the model-level config (if any) + is used. + """ + start_from_last_agent: bool = False """ Whether to start from the last active agent in the session instead of the root agent. diff --git a/trpc_agent_sdk/events/__init__.py b/trpc_agent_sdk/events/__init__.py index 3ee4c966..5fde374d 100644 --- a/trpc_agent_sdk/events/__init__.py +++ b/trpc_agent_sdk/events/__init__.py @@ -8,6 +8,8 @@ from trpc_agent_sdk.types import EventActions from ._agent_cancelled_event import AgentCancelledEvent +from ._cache_analyzer import CacheMetrics +from ._cache_analyzer import analyze_cache_performance from ._event import Event from ._event_translator import EventTranslatorBase from ._long_running_event import LongRunningEvent @@ -16,6 +18,8 @@ __all__ = [ "EventActions", "AgentCancelledEvent", + "CacheMetrics", + "analyze_cache_performance", "Event", "EventTranslatorBase", "LongRunningEvent", diff --git a/trpc_agent_sdk/events/_cache_analyzer.py b/trpc_agent_sdk/events/_cache_analyzer.py new file mode 100644 index 00000000..85f952aa --- /dev/null +++ b/trpc_agent_sdk/events/_cache_analyzer.py @@ -0,0 +1,92 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Cache performance analyzer for TRPC Agent framework.""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING +from typing import Iterable + +if TYPE_CHECKING: + from ._event import Event + + +@dataclass +class CacheMetrics: + """Aggregate cache performance metrics computed from a sequence of events. + + All ratio fields are percentages in the range ``[0.0, 100.0]``. + When no events carry usage metadata, all fields are zero. + + Attributes: + total_requests: Events that carry a ``usage_metadata`` object. + requests_with_cache_hits: Events where ``cache_read_input_tokens > 0``. + total_prompt_tokens: Sum of all ``prompt_token_count`` values. + total_cache_read_tokens: Sum of all ``cache_read_input_tokens`` values. + total_cache_creation_tokens: Sum of all ``cache_creation_input_tokens`` values. + cache_hit_ratio: ``total_cache_read_tokens / total_prompt_tokens * 100``. + cache_utilization_ratio: ``requests_with_cache_hits / total_requests * 100``. + avg_cached_tokens_per_request: ``total_cache_read_tokens / total_requests``. + """ + + total_requests: int = 0 + requests_with_cache_hits: int = 0 + total_prompt_tokens: int = 0 + total_cache_read_tokens: int = 0 + total_cache_creation_tokens: int = 0 + cache_hit_ratio: float = field(default=0.0) + cache_utilization_ratio: float = field(default=0.0) + avg_cached_tokens_per_request: float = field(default=0.0) + + +def analyze_cache_performance(events: Iterable[Event]) -> CacheMetrics: + """Compute cache performance metrics from an iterable of events. + + Events without ``usage_metadata`` are skipped. Missing cache token fields + default to ``0``. Division by zero is avoided; ratio fields stay ``0.0`` + when the denominator would be zero. + + Args: + events: Any iterable of :class:`~trpc_agent_sdk.events.Event` objects, + e.g. ``session.events`` or the list collected from an agent run. + + Returns: + A :class:`CacheMetrics` snapshot. + + Example:: + + from trpc_agent_sdk.events import analyze_cache_performance + + metrics = analyze_cache_performance(session.events) + print(f"Cache hit ratio: {metrics.cache_hit_ratio:.1f}%") + """ + metrics = CacheMetrics() + + for event in events: + usage = event.usage_metadata + if usage is None: + continue + + metrics.total_requests += 1 + metrics.total_prompt_tokens += usage.prompt_token_count or 0 + + cache_read = usage.cache_read_input_tokens or 0 + metrics.total_cache_read_tokens += cache_read + if cache_read > 0: + metrics.requests_with_cache_hits += 1 + + cache_creation = usage.cache_creation_input_tokens or 0 + metrics.total_cache_creation_tokens += cache_creation + + if metrics.total_prompt_tokens > 0: + metrics.cache_hit_ratio = metrics.total_cache_read_tokens / metrics.total_prompt_tokens * 100.0 + if metrics.total_requests > 0: + metrics.cache_utilization_ratio = metrics.requests_with_cache_hits / metrics.total_requests * 100.0 + metrics.avg_cached_tokens_per_request = metrics.total_cache_read_tokens / metrics.total_requests + + return metrics diff --git a/trpc_agent_sdk/models/_anthropic_model.py b/trpc_agent_sdk/models/_anthropic_model.py index 064d813a..39912a9d 100644 --- a/trpc_agent_sdk/models/_anthropic_model.py +++ b/trpc_agent_sdk/models/_anthropic_model.py @@ -38,6 +38,45 @@ from ._llm_response import LlmResponse from ._registry import register_model +_EPHEMERAL = "ephemeral" + + +def _build_cache_control(ttl: Optional[str]) -> Dict[str, Any]: + cache_control: Dict[str, Any] = {"type": _EPHEMERAL} + if ttl: + cache_control["ttl"] = ttl + return cache_control + + +def _stamp_last_block(blocks: List[Dict[str, Any]], cache_control: Dict[str, Any]) -> None: + for block in reversed(blocks): + block["cache_control"] = cache_control + return + + +def _apply_tools_cache_control(tools: List[anthropic_types.ToolParam], cache_control: Dict[str, Any]) -> None: + if tools: + tools[-1]["cache_control"] = cache_control + + +def _apply_system_cache_control(system: str, cache_control: Dict[str, Any]) -> List[Dict[str, Any]]: + return [{"type": "text", "text": system, "cache_control": cache_control}] + + +def _apply_messages_cache_control( + messages: List[anthropic_types.MessageParam], + cache_control: Dict[str, Any], +) -> None: + """Stamp a cache_control breakpoint on the last assistant message.""" + for message in reversed(messages): + if message.get("role") != "assistant": + continue + + content = message.get("content") + if isinstance(content, list) and content: + _stamp_last_block(content, cache_control) + return + class _FinishReason(str, Enum): """Reasons why model generation finished.""" @@ -65,6 +104,29 @@ class _ApiParamsKey(str, Enum): THINKING = "thinking" +def _inject_cache_control( + api_params: Dict[str, Any], + breakpoints: List[str], + ttl: Optional[str], +) -> None: + """Inject Anthropic cache_control breakpoints into api_params in place.""" + cache_control = _build_cache_control(ttl) + if "tools" in breakpoints and api_params.get(_ApiParamsKey.TOOLS): + _apply_tools_cache_control(api_params[_ApiParamsKey.TOOLS], cache_control) + if "system" in breakpoints and api_params.get(_ApiParamsKey.SYSTEM): + system = api_params[_ApiParamsKey.SYSTEM] + if isinstance(system, str): + api_params[_ApiParamsKey.SYSTEM] = _apply_system_cache_control(system, cache_control) + else: + logger.warning( + "Anthropic system cache_control injection expects a string system " + "prompt, got %s; skipping system cache_control injection.", + type(system).__name__, + ) + if "messages" in breakpoints and api_params.get(_ApiParamsKey.MESSAGES): + _apply_messages_cache_control(api_params[_ApiParamsKey.MESSAGES], cache_control) + + @register_model(model_name="AnthropicModel", supported_models=[r"claude-.*"]) class AnthropicModel(LLMModel): """Anthropic model implementation using the abstract model interface. @@ -358,6 +420,27 @@ def _content_block_to_part(self, content_block: anthropic_types.ContentBlock) -> return part raise NotImplementedError(f"Not supported yet: {type(content_block)}") + @staticmethod + def _build_usage_metadata(usage: anthropic_types.Usage) -> GenerateContentResponseUsageMetadata: + """Normalize Anthropic usage into a cache-inclusive shape. + + Anthropic ``input_tokens`` only counts tokens after the last cache + breakpoint. To report the full prompt size, fold cache read/write tokens + back into ``prompt_token_count``: + ``input_tokens + cache_read_input_tokens + cache_creation_input_tokens``. + """ + cache_read = usage.cache_read_input_tokens or 0 + cache_creation = usage.cache_creation_input_tokens or 0 + prompt_tokens = usage.input_tokens + cache_read + cache_creation + output_tokens = usage.output_tokens + return GenerateContentResponseUsageMetadata( + prompt_token_count=prompt_tokens, + candidates_token_count=output_tokens, + total_token_count=prompt_tokens + output_tokens, + cache_read_input_tokens=usage.cache_read_input_tokens, + cache_creation_input_tokens=usage.cache_creation_input_tokens, + ) + def _message_to_llm_response(self, message: anthropic_types.Message) -> LlmResponse: """Convert an Anthropic message to LlmResponse.""" logger.info("Received response from Anthropic Claude.") @@ -371,11 +454,7 @@ def _message_to_llm_response(self, message: anthropic_types.Message) -> LlmRespo role=const.MODEL, parts=[self._content_block_to_part(cb) for cb in message.content], ), - usage_metadata=GenerateContentResponseUsageMetadata( - prompt_token_count=message.usage.input_tokens, - candidates_token_count=message.usage.output_tokens, - total_token_count=(message.usage.input_tokens + message.usage.output_tokens), - ), + usage_metadata=self._build_usage_metadata(message.usage), ) def _merge_configs(self, request_config: Optional[GenerateContentConfig]) -> GenerateContentConfig: @@ -581,11 +660,7 @@ async def _generate_stream( if final_parts: final_content = Content(parts=final_parts, role=const.MODEL) - final_usage = GenerateContentResponseUsageMetadata( - prompt_token_count=final_message.usage.input_tokens, - candidates_token_count=final_message.usage.output_tokens, - total_token_count=(final_message.usage.input_tokens + final_message.usage.output_tokens), - ) + final_usage = self._build_usage_metadata(final_message.usage) yield LlmResponse(content=final_content, usage_metadata=final_usage, @@ -605,6 +680,13 @@ async def _generate_stream( finally: await client.close() + def _apply_prompt_cache(self, api_params: Dict[str, Any], ctx: InvocationContext | None) -> None: + """Inject Anthropic native cache_control breakpoints (opt-in, no-op when disabled).""" + cache_config = self._resolve_prompt_cache_config(ctx) + if not cache_config or not cache_config.breakpoints: + return + _inject_cache_control(api_params, cache_config.breakpoints, cache_config.ttl) + @override async def _generate_async_impl(self, request: LlmRequest, @@ -706,6 +788,8 @@ async def _generate_async_impl(self, logger.error("Error in Anthropic API parameters: %s", ex, exc_info=True) raise ex + self._apply_prompt_cache(api_params, ctx) + try: if stream: async for response in self._generate_stream(api_params, request, ctx): diff --git a/trpc_agent_sdk/models/_litellm_model.py b/trpc_agent_sdk/models/_litellm_model.py index 4a8bab17..5b5a2a0c 100644 --- a/trpc_agent_sdk/models/_litellm_model.py +++ b/trpc_agent_sdk/models/_litellm_model.py @@ -31,6 +31,47 @@ from ._openai_model import OpenAIModel from ._registry import register_model +# Cache families for LiteLLM provider routing. +_ANTHROPIC_FAMILY = "anthropic" # uses cache_control breakpoints +_OPENAI_FAMILY = "openai_managed" # uses provider-managed prefix caching + +# LiteLLM provider prefixes (``provider/model``) that use cache_control breakpoints. +# Sources: +# - https://docs.litellm.ai/docs/tutorials/prompt_caching (official provider list) +_CACHE_CONTROL_PREFIXES = ( + "anthropic/", + "bedrock/", + "vertex_ai/", + "vertex_ai_beta/", + "gemini/", + "azure_ai/", + "openrouter/", + "databricks/", + "dashscope/", + "minimax/", + "zai/", +) + +# LiteLLM provider prefixes that use provider-managed prefix caching. +# Note: azure/ supports prompt_cache_key but NOT prompt_cache_retention — +# Azure OpenAI does not expose a TTL retention control in its API. +_MANAGED_PREFIXES = ( + "openai/", + "azure/", + "deepseek/", + "xai/", +) + + +def _litellm_cache_family(model_name: str) -> Optional[str]: + lowered = model_name.lower() + if lowered.startswith(_CACHE_CONTROL_PREFIXES): + return _ANTHROPIC_FAMILY + if lowered.startswith(_MANAGED_PREFIXES): + return _OPENAI_FAMILY + return None + + _LITELLM_SUPPORTED_MODELS: List[str] = [ r"openai/.*", r"anthropic/.*", @@ -109,6 +150,10 @@ class _LiteLLMApiParamsKey(str, Enum): API_KEY = "api_key" API_BASE = "api_base" STREAM_OPTS = "stream_options" + EXTRA_BODY = "extra_body" + PROMPT_CACHE_KEY = "prompt_cache_key" + PROMPT_CACHE_RETENTION = "prompt_cache_retention" + CACHE_CONTROL_INJECTION_POINTS = "cache_control_injection_points" @register_model(model_name="LiteLLMModel", supported_models=_LITELLM_SUPPORTED_MODELS) @@ -146,6 +191,135 @@ def _ensure_litellm_imported(self) -> None: os.environ.setdefault("LITELLM_MODE", "PRODUCTION") LiteLLMModel._litellm_imported = True + def _apply_prompt_cache(self, api_params: Dict[str, Any], ctx: InvocationContext | None) -> None: + """Apply prompt cache config to LiteLLM api_params (best-effort, in place).""" + cache_config = self._resolve_prompt_cache_config(ctx) + if not cache_config: + return + + family = _litellm_cache_family(self._model_name) + + if family is None: + logger.warning( + "prompt_cache_config is set but model %r has no recognized provider prefix; " + "cache config will be ignored. Use a 'provider/model' name (e.g. 'openai/gpt-4o') " + "so the SDK can select the correct cache mechanism.", + self._model_name, + ) + return + + if family == _ANTHROPIC_FAMILY: + if not cache_config.breakpoints: + return + ttl = cache_config.ttl + # tools breakpoint: stamp cache_control directly on the last tool + # (LiteLLM's _map_tool_helper transparently forwards it to Anthropic). + # Bedrock uses a separate tool_config cachePoint via injection_points. + if "tools" in cache_config.breakpoints: + self._apply_tools_cache_control(api_params, ttl) + points = self._build_cache_injection_points( + cache_config.breakpoints, + ttl, + api_params.get(_LiteLLMApiParamsKey.MESSAGES), + ) + if points: + api_params[_LiteLLMApiParamsKey.CACHE_CONTROL_INJECTION_POINTS] = points + elif family == _OPENAI_FAMILY: + if cache_config.prompt_cache_key: + api_params[_LiteLLMApiParamsKey.PROMPT_CACHE_KEY] = cache_config.prompt_cache_key + if cache_config.ttl: + if not self._model_name.lower().startswith("azure/"): + api_params[_LiteLLMApiParamsKey.PROMPT_CACHE_RETENTION] = cache_config.ttl + + def _apply_tools_cache_control(self, api_params: Dict[str, Any], ttl: Optional[str]) -> None: + """Stamp cache_control on the last tool in api_params (in place). + + For non-Bedrock Anthropic upstreams LiteLLM's _map_tool_helper forwards + a tool-level ``cache_control`` field directly to the Anthropic API, so we + mutate the tool list here rather than using cache_control_injection_points. + For Bedrock the injection_points mechanism handles tools via tool_config. + """ + if self._model_name.lower().startswith("bedrock/"): + return + tools = api_params.get(_LiteLLMApiParamsKey.TOOLS) + if not tools: + return + cache_control: Dict[str, Any] = {"type": "ephemeral"} + if ttl: + cache_control["ttl"] = ttl + tools[-1]["cache_control"] = cache_control + + def _build_cache_injection_points( + self, + breakpoints: List[str], + ttl: Optional[str], + messages: Any = None, + ) -> List[Dict[str, Any]]: + """Build LiteLLM ``cache_control_injection_points`` for system/messages/Bedrock-tools. + + - ``system`` -> stamps cache_control on the system message (by role). + - ``messages`` -> stamps cache_control on the most recent assistant + message, matching the native Anthropic adapter. + - ``tools`` -> Bedrock only: tool_config cachePoint (no control/ttl field; + LiteLLM always emits {"cachePoint": {"type": "default"}} for this). + Non-Bedrock tools are handled separately by _apply_tools_cache_control. + """ + + def _make_cache_control() -> Dict[str, Any]: + cache_control: Dict[str, Any] = {"type": "ephemeral"} + if ttl: + cache_control["ttl"] = ttl + return cache_control + + points: List[Dict[str, Any]] = [] + if "system" in breakpoints: + points.append({"location": "message", "role": "system", "control": _make_cache_control()}) + if "messages" in breakpoints: + assistant_index = self._last_assistant_message_index(messages) + if assistant_index is not None: + points.append({ + "location": "message", + "index": assistant_index, + "control": _make_cache_control(), + }) + if "tools" in breakpoints and self._model_name.lower().startswith("bedrock/"): + # Bedrock's tool_config cachePoint has no control/ttl field — + # LiteLLM ignores any control dict here and always emits + # {"cachePoint": {"type": "default"}}. + points.append({"location": "tool_config"}) + return points + + @staticmethod + def _last_assistant_message_index(messages: Any) -> Optional[int]: + if not isinstance(messages, list): + return None + for index in range(len(messages) - 1, -1, -1): + message = messages[index] + if isinstance(message, dict) and message.get("role") == "assistant": + return index + return None + + @staticmethod + def _set_extra_body(api_params: Dict[str, Any], key: str, value: Any) -> None: + """Set a key inside api_params' extra_body, reusing an existing dict if present. + + If an existing ``extra_body`` is not a dict (e.g. a string or some other + type), it is replaced and a warning is emitted so the caller is aware + that prior extra-body data has been discarded. + """ + current = api_params.get(_LiteLLMApiParamsKey.EXTRA_BODY) + if isinstance(current, dict): + current[key] = value + else: + if current is not None: + logger.warning( + "api_params['extra_body'] has unexpected type %s (expected dict); " + "replacing it to set %r. Existing extra_body content is lost.", + type(current).__name__, + key, + ) + api_params[_LiteLLMApiParamsKey.EXTRA_BODY] = {key: value} + def _get_message_content(self, message: Any) -> str: """Extract text from message.content (str or list of blocks). message: dict.""" content = message.get("content") if message else None @@ -309,6 +483,8 @@ async def _generate_async_impl( logger.error("Error in LiteLLM API parameters: %s", ex, exc_info=True) raise + self._apply_prompt_cache(api_params, ctx) + try: if stream: async for response in self._generate_stream(api_params, request, ctx): diff --git a/trpc_agent_sdk/models/_llm_model.py b/trpc_agent_sdk/models/_llm_model.py index 940afa16..e892b31a 100644 --- a/trpc_agent_sdk/models/_llm_model.py +++ b/trpc_agent_sdk/models/_llm_model.py @@ -17,6 +17,7 @@ from typing import Optional from typing import final +from trpc_agent_sdk.configs import PromptCacheConfig from trpc_agent_sdk.context import InvocationContext from trpc_agent_sdk.context import create_agent_context from trpc_agent_sdk.filter import BaseFilter @@ -33,16 +34,50 @@ class LLMModel(FilterRunner): """Abstract base class for all model implementations.""" - def __init__(self, model_name: str, filters_name: Optional[list[str]] = None, **kwargs): + def __init__( + self, + model_name: str, + filters_name: Optional[list[str]] = None, + prompt_cache_config: Optional[PromptCacheConfig] = None, + **kwargs, + ): filters: list = kwargs.get("filters", []) super().__init__(filters_name=filters_name, filters=filters) self._model_name = model_name self.config = kwargs + self.prompt_cache_config = prompt_cache_config self._type = FilterType.MODEL self._init_filters() self._api_key: str = kwargs.get(const.API_KEY, "") self._base_url: str = kwargs.get(const.BASE_URL, "") + def _resolve_prompt_cache_config( + self, + ctx: Optional[InvocationContext] = None, + ) -> Optional[PromptCacheConfig]: + """Resolve the effective prompt cache config for a call. + + The model-level ``prompt_cache_config`` is the baseline; per-run + ``RunConfig.prompt_cache`` (via ``ctx``) overrides it field-by-field, + so a run can tweak just one field (e.g. ``cache_key``) without having to + re-declare the rest. Returns the merged config only when it is enabled, + otherwise ``None`` (callers treat ``None`` as "do nothing"). + """ + base = self.prompt_cache_config + run = ctx.run_config.prompt_cache if (ctx is not None and ctx.run_config is not None) else None + + if run is None: + config = base + elif base is None: + config = run + else: + # Only fields explicitly set on the per-run config override the baseline. + config = base.model_copy(update=run.model_dump(exclude_unset=True)) + + if config is None or not config.enabled: + return None + return config + def set_api_key(self, value: str) -> None: """Set the API key.""" self._api_key = value diff --git a/trpc_agent_sdk/models/_openai_model.py b/trpc_agent_sdk/models/_openai_model.py index 2be72335..1819da15 100644 --- a/trpc_agent_sdk/models/_openai_model.py +++ b/trpc_agent_sdk/models/_openai_model.py @@ -102,6 +102,8 @@ class ApiParamsKey(str, Enum): MAX_COMPLETION_TOKENS = "max_completion_tokens" REASONING_EFFORT = "reasoning_effort" PARALLEL_TOOL_CALLS = "parallel_tool_calls" + PROMPT_CACHE_KEY = "prompt_cache_key" + PROMPT_CACHE_RETENTION = "prompt_cache_retention" @register_model(model_name="OpenAIModel", supported_models=[r"gpt-.*", r"o1-.*", r"deepseek-.*", r"hy3-.*"]) @@ -699,26 +701,32 @@ def _process_tool_call_delta(self, tool_call_delta: dict, accumulated_tool_calls if ToolKey.ARGUMENTS in function_delta and function_delta[ToolKey.ARGUMENTS] is not None: accumulated_tool_calls[index][ToolKey.FUNCTION][ToolKey.ARGUMENTS] += function_delta[ToolKey.ARGUMENTS] - def _process_usage(self, chunk_dict: dict) -> Optional[GenerateContentResponseUsageMetadata]: - """Process usage information from a chunk. - - Args: - chunk_dict (`dict`): The chunk dictionary containing usage information + @staticmethod + def _build_usage_metadata(usage_data: dict) -> GenerateContentResponseUsageMetadata: + """Build ``GenerateContentResponseUsageMetadata`` from a raw usage dict. - Returns: - `Optional[GenerateContentResponseUsageMetadata]`: The processed usage metadata or None if not available + ``cache_read_input_tokens`` prefers Anthropic/LiteLLM-style top-level fields; + falls back to OpenAI-style ``prompt_tokens_details.cached_tokens``. """ - usage_data = chunk_dict.get(const.USAGE) - if usage_data is None: - return None completion_details = usage_data.get("completion_tokens_details") or {} + cache_read = usage_data.get("cache_read_input_tokens") + if cache_read is None: + details = usage_data.get("prompt_tokens_details") + cache_read = details.get("cached_tokens") if isinstance(details, dict) else None return GenerateContentResponseUsageMetadata( prompt_token_count=usage_data.get("prompt_tokens", 0), candidates_token_count=usage_data.get("completion_tokens", 0), thoughts_token_count=completion_details.get("reasoning_tokens"), total_token_count=usage_data.get("total_tokens", 0), + cache_read_input_tokens=cache_read, + cache_creation_input_tokens=usage_data.get("cache_creation_input_tokens"), ) + def _process_usage(self, chunk_dict: dict) -> Optional[GenerateContentResponseUsageMetadata]: + """Extract usage metadata from a streaming chunk dict.""" + usage_data = chunk_dict.get(const.USAGE) + return self._build_usage_metadata(usage_data) if usage_data is not None else None + def _process_chunk_without_content( self, chunk_dict: dict, accumulated_tool_calls: list[dict] ) -> tuple[Optional[FinishReason], Optional[GenerateContentResponseUsageMetadata], dict[int, str]]: @@ -975,25 +983,9 @@ def _process_tool_calls_from_message(self, message: dict) -> Optional[List[ToolC return tool_calls or None def _process_usage_from_response(self, response_dict: dict) -> Optional[GenerateContentResponseUsageMetadata]: - """Process usage information from a response. - - Args: - response_dict (`dict`): The response dictionary containing usage information - - Returns: - `Optional[GenerateContentResponseUsageMetadata]`: Processed usage metadata or None - """ - if const.USAGE not in response_dict: - return None - - usage_data: dict[str, int] = response_dict[const.USAGE] - completion_details = usage_data.get("completion_tokens_details") or {} - return GenerateContentResponseUsageMetadata( - prompt_token_count=usage_data.get("prompt_tokens", 0), - candidates_token_count=usage_data.get("completion_tokens", 0), - thoughts_token_count=completion_details.get("reasoning_tokens"), - total_token_count=usage_data.get("total_tokens", 0), - ) + """Extract usage metadata from a non-streaming response dict.""" + usage_data = response_dict.get(const.USAGE) + return self._build_usage_metadata(usage_data) if usage_data is not None else None def _create_response_without_content(self, response_dict: dict) -> LlmResponse: """Create a LlmResponse without content.""" @@ -1505,6 +1497,13 @@ async def _generate_async_impl(self, logger.error("Error in OpenAI API parameters: %s", ex, exc_info=True) raise ex + cache_config = self._resolve_prompt_cache_config(ctx) + if cache_config: + if cache_config.prompt_cache_key: + api_params[ApiParamsKey.PROMPT_CACHE_KEY] = cache_config.prompt_cache_key + if cache_config.ttl: + api_params[ApiParamsKey.PROMPT_CACHE_RETENTION] = cache_config.ttl + # Extract HTTP options for API calls http_options = {} if request.config: @@ -1555,6 +1554,8 @@ async def _generate_stream(self, "reasoning": self._adapter.create_streaming_text_filter_state(), } + api_params[ApiParamsKey.STREAM_OPTS] = {ApiParamsKey.INCLUDE_USAGE: True} + client = self._create_async_client() try: logger.debug("openai invoke with params: %s", api_params) diff --git a/trpc_agent_sdk/telemetry/_metrics.py b/trpc_agent_sdk/telemetry/_metrics.py index a2d7a80f..7e0cc6b2 100644 --- a/trpc_agent_sdk/telemetry/_metrics.py +++ b/trpc_agent_sdk/telemetry/_metrics.py @@ -50,6 +50,16 @@ description="Completion tokens produced by a gen_ai operation.", unit="{token}", ) +_usage_cache_read_tokens = _meter.create_histogram( + name="gen_ai.usage.cache_read_input_tokens", + description="Input tokens served from the prompt cache.", + unit="{token}", +) +_usage_cache_creation_tokens = _meter.create_histogram( + name="gen_ai.usage.cache_creation_input_tokens", + description="Input tokens written to the prompt cache (Anthropic only).", + unit="{token}", +) # OTel GenAI semconv attribute keys. _ATTR_OPERATION_NAME = "gen_ai.operation.name" @@ -153,11 +163,17 @@ def report_call_llm( if llm_response is not None and llm_response.usage_metadata is not None: usage = llm_response.usage_metadata - prompt = getattr(usage, "prompt_token_count", None) or 0 - total = getattr(usage, "total_token_count", None) or 0 + prompt = usage.prompt_token_count or 0 + total = usage.total_token_count or 0 if prompt and total: _usage_input_tokens.record(prompt, attrs) _usage_output_tokens.record(max(total - prompt, 0), attrs) + cache_read = usage.cache_read_input_tokens + if cache_read is not None: + _usage_cache_read_tokens.record(cache_read, attrs) + cache_creation = usage.cache_creation_input_tokens + if cache_creation is not None: + _usage_cache_creation_tokens.record(cache_creation, attrs) def report_execute_tool( diff --git a/trpc_agent_sdk/telemetry/_trace.py b/trpc_agent_sdk/telemetry/_trace.py index aac6657b..e9740ed8 100644 --- a/trpc_agent_sdk/telemetry/_trace.py +++ b/trpc_agent_sdk/telemetry/_trace.py @@ -445,6 +445,12 @@ def trace_call_llm( "gen_ai.usage.output_tokens", output_tokens, ) + cache_read = usage.cache_read_input_tokens + if cache_read is not None: + span.set_attribute("gen_ai.usage.cache_read_input_tokens", cache_read) + cache_creation = usage.cache_creation_input_tokens + if cache_creation is not None: + span.set_attribute("gen_ai.usage.cache_creation_input_tokens", cache_creation) if instruction_metadata is not None: span.set_attribute(f"{_trpc_agent_span_name}.instruction.name", instruction_metadata.name) diff --git a/trpc_agent_sdk/types/__init__.py b/trpc_agent_sdk/types/__init__.py index 06d26a31..316a59f5 100644 --- a/trpc_agent_sdk/types/__init__.py +++ b/trpc_agent_sdk/types/__init__.py @@ -10,6 +10,7 @@ """Types module for TRPC Agent framework.""" from google.genai.types import * # noqa: F401,F403 +from ._usage import GenerateContentResponseUsageMetadata from ._agent_types import ActiveStreamingTool from ._agent_types import LiveRequest @@ -25,6 +26,7 @@ from ._ttl import Ttl __all__ = [ + "GenerateContentResponseUsageMetadata", "ActiveStreamingTool", "LiveRequest", "LiveRequestQueue", diff --git a/trpc_agent_sdk/types/_usage.py b/trpc_agent_sdk/types/_usage.py new file mode 100644 index 00000000..3cdddb37 --- /dev/null +++ b/trpc_agent_sdk/types/_usage.py @@ -0,0 +1,33 @@ +# Tencent is pleased to support the open source community by making tRPC-Agent-Python available. +# +# Copyright (C) 2026 Tencent. All rights reserved. +# +# tRPC-Agent-Python is licensed under Apache-2.0. +"""Usage metadata types for TRPC Agent framework. + +This module extends Google GenAI's ``GenerateContentResponseUsageMetadata`` with +prompt-cache token counters that are not present upstream. It is re-exported from +``trpc_agent_sdk.types`` so it shadows the upstream type for all SDK callers. +""" + +from __future__ import annotations + +from typing import Optional + +from google.genai.types import GenerateContentResponseUsageMetadata as _BaseUsageMetadata + + +class GenerateContentResponseUsageMetadata(_BaseUsageMetadata): + """Usage metadata extended with prompt-cache token counters. + + Adds two provider-normalized fields on top of the upstream type: + + - ``cache_read_input_tokens``: input tokens served from cache (Anthropic + ``cache_read_input_tokens`` / OpenAI ``prompt_tokens_details.cached_tokens``). + - ``cache_creation_input_tokens``: input tokens written to cache (Anthropic + ``cache_creation_input_tokens``; always ``None``/0 for OpenAI which has no + separate cache-write step). + """ + + cache_read_input_tokens: Optional[int] = None + cache_creation_input_tokens: Optional[int] = None