diff --git a/.gitignore b/.gitignore index e0bac2b47..2666c56a0 100644 --- a/.gitignore +++ b/.gitignore @@ -66,4 +66,7 @@ sdk/benchmark/.env .pytest-tmp doc/mermaid -.claude/skills/python-import-triage \ No newline at end of file +.claude/skills/python-import-triage + +# Debug scripts preserved on feat/opt-agent-context-temp-scripts +sdk/nexent/core/agents/temp_scripts/ \ No newline at end of file diff --git a/sdk/nexent/__init__.py b/sdk/nexent/__init__.py index d0de150cb..643fd1b9c 100644 --- a/sdk/nexent/__init__.py +++ b/sdk/nexent/__init__.py @@ -3,6 +3,7 @@ from .memory import * from .storage import * from .vector_database import * +from .container import * from .skills import * diff --git a/sdk/nexent/core/agents/agent_context.py b/sdk/nexent/core/agents/agent_context.py deleted file mode 100644 index 0b40d325c..000000000 --- a/sdk/nexent/core/agents/agent_context.py +++ /dev/null @@ -1,1409 +0,0 @@ -"""Agent context management for memory compression and summarization. - -Provides ContextManager for token-aware compression of agent memory, -supporting incremental summarization with cache-based optimization. 
- -Also provides ContextManager as the single source of truth for: -- Context component registration and lifecycle -- System prompt assembly from components -- Strategy-based component selection -""" - -import hashlib -import json -import logging -import re -import threading -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Union - -if TYPE_CHECKING: - from .agent_model import ContextComponent, ContextStrategy - -from smolagents.memory import ActionStep, AgentMemory, MemoryStep, TaskStep -from smolagents.models import ChatMessage, MessageRole - -from .summary_cache import CompressionCallRecord, CurrentSummaryCache, PreviousSummaryCache -from .summary_config import ContextManagerConfig, StrategyType - -logger = logging.getLogger("agent_context") - -from ..utils.token_estimation import ( - _extract_text_from_messages, - estimate_tokens, - estimate_tokens_for_steps, - msg_char_count, - msg_token_count, - estimate_tokens_for_system_prompt -) - - -@dataclass -class SummaryTaskStep(TaskStep): - """TaskStep subclass that contains a compressed summary of earlier steps.""" - is_summary: bool = True - prefix: str = "Summary of earlier steps in this task:" # default prefix - - def to_messages(self, summary_mode: bool = False) -> list: - content = [{"type": "text", "text": f"{self.prefix}:\n{self.task}"}] - return [ChatMessage(role=MessageRole.USER, content=content)] - - -# ============================================================ -# Standalone utilities (no ContextManager state required) -# ============================================================ - -def format_summary_output(raw_output: str) -> Optional[str]: - """Clean and validate LLM summary output. - - Strips markdown code fences, attempts JSON parse for normalization, - falls back to plain text if not valid JSON. 
- """ - cleaned = raw_output.strip() - if cleaned.startswith("```"): - cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned) - cleaned = re.sub(r"\n?```\s*$", "", cleaned) - if not cleaned: - return None - try: - parsed = json.loads(cleaned) - return json.dumps(parsed, ensure_ascii=False, indent=2) - except json.JSONDecodeError: - logger.warning("Summary output is not valid JSON; using as plain text") - return cleaned - - -def _is_context_length_error(err: Exception) -> bool: - """Check if an exception indicates a context length / token limit error.""" - msg = str(err).lower() - return any(k in msg for k in ( - "context_length", "context length", "maximum context", "maximum context length", - "prompt is too long", "reduce the length", "too many tokens", - "token limit", "exceeds the maximum", "input is too long", - "input length", "exceeds context", "context window", - )) - - -def compress_history_offline( - pairs: List[Tuple[str, str]], - model, - config: Optional[ContextManagerConfig] = None, - previous_summary: Optional[str] = None, -) -> dict: - """Compress conversation history offline, without ContextManager or AgentMemory. - - This is a standalone function for **Static Compression Inspection** in - benchmarks. It takes plain-text (user, assistant) pairs and produces a - summary using the same prompts and schema as the in-agent compression path, - but without any stateful cache, offload store, or agent runtime. - - Args: - pairs: List of (user_text, assistant_text) tuples representing - conversation turns to compress. - model: An LLM model object compatible with smolagents' call interface. - config: ContextManagerConfig providing prompts, schema, and token budgets. - Defaults to a fresh ContextManagerConfig() if not provided. - previous_summary: Optional existing summary text for incremental - compression. If provided, uses the incremental prompt - to update rather than create from scratch. 
- - Returns: - dict with: - - "summary": the compressed summary text (str or None on failure) - - "is_incremental": whether incremental compression was used - - "is_fallback": whether the LLM failed and fallback truncation was used - - "input_text": the raw text that was fed to the LLM (for debugging) - - "input_chars": character count of the input text - """ - config = config or ContextManagerConfig() - # Same compensation as ContextManager.__init__: when max_summary_input_tokens - # is left at the default 0, derive it from token_threshold so that truncation - # logic doesn't accidentally chop all input. - if config.max_summary_input_tokens <= 0: - config.max_summary_input_tokens = int(config.token_threshold * 1.2) - if not pairs and not previous_summary: - return { - "summary": None, - "is_incremental": False, - "is_fallback": False, - "input_text": "", - "input_chars": 0, - } - - # Build input text from pairs - parts = [] - for user_text, assistant_text in pairs: - parts.append(f"user: {user_text}\nassistant: {assistant_text}") - pairs_text = "\n\n".join(parts) - - # Determine compression mode - is_incremental = previous_summary is not None - - if is_incremental: - input_text = ( - f"## Previous Summary\n{previous_summary}\n\n" - f"## New Conversations\n{pairs_text}" - ) - else: - input_text = pairs_text - - # Truncate if exceeds budget - from ..utils.token_estimation import estimate_tokens_text - input_tokens = estimate_tokens_text(input_text) - if input_tokens > config.max_summary_input_tokens: - # Simple tail-truncation for offline mode - approx_chars = int(config.max_summary_input_tokens * config.chars_per_token * 0.9) - input_text = "...[Earlier content truncated]...\n" + input_text[-approx_chars:] - - # Build prompt - schema_desc = json.dumps(config.summary_json_schema, ensure_ascii=False, indent=2) - if is_incremental: - system_prompt = config.incremental_summary_system_prompt - user_prompt = ( - f"Update the summary following this JSON 
structure:\n{schema_desc}\n\n" - f"{input_text}" - ) - else: - system_prompt = config.summary_system_prompt - user_prompt = ( - f"Create a structured checkpoint summary following this JSON structure:\n{schema_desc}\n\n" - f"TURNS TO SUMMARIZE:\n{input_text}" - ) - - messages = [ - ChatMessage(role=MessageRole.SYSTEM, - content=[{"type": "text", "text": system_prompt}]), - ChatMessage(role=MessageRole.USER, - content=[{"type": "text", "text": user_prompt}]), - ] - - # Call LLM with error handling - is_fallback = False - summary = None - - try: - response = model(messages, stop_sequences=[]) - raw_output = response.content - if isinstance(raw_output, list): - raw_output = " ".join( - block.get("text", "") - for block in raw_output - if isinstance(block, dict) and block.get("type") == "text" - ) - if not isinstance(raw_output, str): - raw_output = str(raw_output) - summary = format_summary_output(raw_output) - except Exception as e: - if _is_context_length_error(e): - logger.warning("Offline compression exceeds context limit; retrying with 2/3 budget") - approx_chars = int(config.max_summary_input_tokens * config.chars_per_token * 0.6) - truncated_input = input_text[-approx_chars:] if len(input_text) > approx_chars else input_text - if is_incremental: - user_prompt = ( - f"Update the summary following this JSON structure:\n{schema_desc}\n\n" - f"{truncated_input}" - ) - else: - user_prompt = ( - f"Create a structured checkpoint summary following this JSON structure:\n{schema_desc}\n\n" - f"TURNS TO SUMMARIZE:\n{truncated_input}" - ) - messages[-1] = ChatMessage( - role=MessageRole.USER, - content=[{"type": "text", "text": user_prompt}], - ) - try: - response = model(messages, stop_sequences=[]) - raw_output = response.content - if isinstance(raw_output, list): - raw_output = " ".join( - block.get("text", "") - for block in raw_output - if isinstance(block, dict) and block.get("type") == "text" - ) - if not isinstance(raw_output, str): - raw_output = str(raw_output) - 
summary = format_summary_output(raw_output) - except Exception as e2: - logger.error(f"Offline compression retry still failed: {e2}") - - if summary is None: - # L3 fallback: hard truncation - is_fallback = True - first_task = pairs[0][0][:200] if pairs else "" - reduced_chars = int(config.max_summary_reduce_tokens * config.chars_per_token) - reduced_text = pairs_text[-reduced_chars:] if len(pairs_text) > reduced_chars else pairs_text - summary = ( - "[CONTEXT COMPACTION — REFERENCE ONLY] Earlier steps were removed to free context space. " - "The removed content cannot be summarized. Continue based on the steps below.\n\n" - f"Original task: {first_task}\n\n" - f"Steps removed: {len(pairs)} of {len(pairs)}\n\n" - "Remaining compressed history:\n" - + reduced_text - ) - - return { - "summary": summary, - "is_incremental": is_incremental, - "is_fallback": is_fallback, - "input_text": input_text, - "input_chars": len(input_text), - } - - -class ContextManager: - def __init__(self, config: Optional[ContextManagerConfig] = None, max_steps: Optional[int] = None): - self.config = config or ContextManagerConfig() - self._previous_summary_cache: Optional[PreviousSummaryCache] = None - self._current_summary_cache: Optional[CurrentSummaryCache] = None - - self._last_run_start_idx: Optional[int] = None - - if max_steps is not None and self.config.keep_recent_steps >= max_steps: - self.config.keep_recent_steps = max_steps - - self.compression_calls_log: List[CompressionCallRecord] = [] - self._step_local_log: List[CompressionCallRecord] = [] - self._lock = threading.Lock() - - # Token accounting for benchmark instrumentation. - # Recorded by compress_if_needed at each return point so benchmarks - # can compute token_reduction = 1 - last_compressed / last_uncompressed. 
- self._last_uncompressed_token_count: Optional[int] = None - self._last_compressed_token_count: Optional[int] = None - - if self.config.max_summary_input_tokens <= 0: - self.config.max_summary_input_tokens = int(self.config.token_threshold * 1.2) - if self.config.max_summary_reduce_tokens <= 0: - self.config.max_summary_reduce_tokens = int(self.config.token_threshold * 0.2) - - self._components: List = [] - - # ============================================================ - # Cache validation - # ============================================================ - - def _is_prev_cache_valid(self, prev_pairs: List[tuple]) -> Tuple[bool, int]: - """Checks whether the previous cache covers a prefix of prev_pairs. - - Returns (is_valid, covered_idx). When is_valid is True, prev_pairs[0:covered_idx] - can be replaced by cache.summary_text, and prev_pairs[covered_idx:] represents - the uncovered incremental portion. - """ - cache = self._previous_summary_cache - if cache is None or not prev_pairs: - return False, 0 - if cache.covered_pairs == 0 or cache.covered_pairs > len(prev_pairs): - return False, 0 - anchor_t, anchor_a = prev_pairs[cache.covered_pairs - 1] - fp = self._pair_fingerprint(anchor_t.task or "", self._action_content(anchor_a)) - if fp != cache.anchor_fingerprint: - return False, 0 - return True, cache.covered_pairs - - def _is_curr_cache_valid(self, action_steps: List[ActionStep]) -> Tuple[bool, int]: - cache = self._current_summary_cache - if cache is None or not action_steps: - return False, 0 - if cache.end_steps == 0 or cache.end_steps > len(action_steps): - return False, 0 - anchor = action_steps[cache.end_steps - 1] - if self._action_fingerprint(anchor) != cache.anchor_fingerprint: - return False, 0 - return True, cache.end_steps - - # ============================================================ - # Effective token estimation - # ============================================================ - - def _effective_tokens(self, memory: AgentMemory, 
current_run_start_idx: int) -> int: - """Estimates the actual token burden of the upcoming _build_messages call. - Uses summary_text for the covered prefix when cache is valid; falls back to raw otherwise. - """ - system_prompt_tokens = estimate_tokens_for_system_prompt(memory) - prev_steps = memory.steps[:current_run_start_idx] - curr_steps = memory.steps[current_run_start_idx:] - return (system_prompt_tokens + self._effective_prev_tokens(prev_steps) - + self._effective_curr_tokens(curr_steps)) - - def _effective_prev_tokens(self, prev_steps: List[MemoryStep]) -> int: - if not prev_steps: - return 0 - prev_pairs = self._extract_pairs(prev_steps) - is_valid, covered_idx = self._is_prev_cache_valid(prev_pairs) - if not is_valid: - return self._estimate_tokens_for_steps(prev_steps) - uncovered = prev_pairs[covered_idx:] - uncovered_tokens = ( - self._estimate_text_tokens(self._pairs_to_text(uncovered)) - if uncovered else 0 - ) - return (self._estimate_text_tokens(self._previous_summary_cache.summary_text) - + uncovered_tokens) - - def _effective_curr_tokens(self, curr_steps: List[MemoryStep]) -> int: - if not curr_steps: - return 0 - curr_task = curr_steps[0] if isinstance(curr_steps[0], TaskStep) else None - action_steps = [s for s in curr_steps if isinstance(s, ActionStep)] - is_valid, covered_idx = self._is_curr_cache_valid(action_steps) - if not is_valid: - return self._estimate_tokens_for_steps(curr_steps) - task_tokens = ( - self._estimate_text_tokens(curr_task.task or "") if curr_task else 0 - ) - uncovered = action_steps[covered_idx:] - uncovered_tokens = ( - self._estimate_text_tokens(self._actions_to_text(uncovered)) - if uncovered else 0 - ) - return (task_tokens - + self._estimate_text_tokens(self._current_summary_cache.summary_text) - + uncovered_tokens) - - # ============================================================ - # Budget helpers - # ============================================================ - - def _estimate_text_tokens(self, text: str) -> 
int: - from ..utils.token_estimation import estimate_tokens_text - return estimate_tokens_text(text) - - def _trim_pairs_to_budget( - self, pairs: List[tuple], max_tokens: int, keep_first: bool = True, - ) -> List[tuple]: - if not pairs: - return [] - pair_tokens = [ - self._estimate_text_tokens(self._pairs_to_text([p])) for p in pairs - ] - sep = self._estimate_text_tokens("\n\n") - total = sum(pair_tokens) + sep * max(0, len(pairs) - 1) - if total <= max_tokens: - return list(pairs) - - if keep_first and len(pairs) > 1: - budget = max_tokens - pair_tokens[0] - sep - kept_tail = [] - for i in range(len(pairs) - 1, 0, -1): - cost = pair_tokens[i] + (sep if kept_tail else 0) - if cost > budget: - break - kept_tail.append(pairs[i]) - budget -= cost - return [pairs[0]] + list(reversed(kept_tail)) - - budget = max_tokens - kept = [] - for i in range(len(pairs) - 1, -1, -1): - cost = pair_tokens[i] + (sep if kept else 0) - if cost > budget: - break - kept.append(pairs[i]) - budget -= cost - return list(reversed(kept)) if kept else [pairs[-1]] - - - - def _is_observation_step(self, action: ActionStep) -> bool: - return action is not None and hasattr(action, 'observations') and action.observations is not None - - def _is_tool_call_step(self, action: ActionStep) -> bool: - return action is not None and hasattr(action, 'tool_calls') and action.tool_calls is not None - - def _trim_actions_to_budget( - self, actions: List[ActionStep], task_text: str, max_tokens: int, - ) -> List[ActionStep]: - if not actions: - return [] - - def _total_tokens(acts): - return self._estimate_text_tokens(task_text + self._actions_to_text(acts)) - - if _total_tokens(actions) <= max_tokens: - return list(actions) - - for drop in range(1, len(actions) + 1): - remaining = actions[drop:] - if not remaining: - break - if self._is_observation_step(remaining[0]) and self._is_tool_call_step(actions[drop - 1]): - continue - if _total_tokens(remaining) <= max_tokens: - return list(remaining) - - return 
self._fallback_trim_actions(actions) - - def _fallback_trim_actions(self, actions: List[ActionStep]) -> List[ActionStep]: - last_action = actions[-1] - if len(actions) >= 2 and self._is_observation_step(last_action): - prev_action = actions[-2] - if self._is_tool_call_step(prev_action): - logger.warning( - "Fallback limit triggered: Retaining the last complete ToolCall + Observation pair intact. " - "This may exceed the token budget, and downstream truncation will be relied upon." - ) - return [prev_action, last_action] - return [last_action] - - # ============================================================ - # Mainly Entry Point - # ============================================================ - - def compress_if_needed( - self, model, memory, original_messages: List[ChatMessage], current_run_start_idx, - ) -> List[ChatMessage]: - # G1 - if not self.config.enabled: - return original_messages - - if self._estimate_tokens(memory) <= self.config.token_threshold: - # No compression needed; record that compressed == uncompressed - # so benchmark token_reduction reads as zero rather than stale. - self._last_uncompressed_token_count = self._msg_token_count(original_messages) - self._last_compressed_token_count = self._last_uncompressed_token_count - return original_messages - - with self._lock: - # Run detection - if (self._last_run_start_idx is not None - and current_run_start_idx != self._last_run_start_idx): - self._current_summary_cache = None - self._last_run_start_idx = current_run_start_idx - - # Note: The memory here always consists of the unmodified, summary-task-step-free - # original previous_run + current_run. - # - previous_run: [(TaskStep, ActionStep), ...] - # - current_run: [TaskStep, ActionStep, ActionStep, ...] - if self._effective_tokens(memory, current_run_start_idx) <= self.config.token_threshold: - # Stable-phase bypass: No LLM call; construct compressed messages directly from existing cache. 
- self._step_local_log.clear() - - prev_steps = memory.steps[:current_run_start_idx] - curr_steps = memory.steps[current_run_start_idx:] - - prev_summary_step = None - prev_tail_steps = list(prev_steps) - prev_pairs = self._extract_pairs(prev_steps) - if prev_pairs: - is_valid, covered_idx = self._is_prev_cache_valid(prev_pairs) - if is_valid: - prev_summary_step = SummaryTaskStep( - task=self._previous_summary_cache.summary_text - ) - uncovered = prev_pairs[covered_idx:] - prev_tail_steps = self._pairs_to_steps(uncovered) - - curr_kept_steps = list(curr_steps) - if curr_steps: - curr_task = curr_steps[0] if isinstance(curr_steps[0], TaskStep) else None - curr_action_steps = [s for s in curr_steps if isinstance(s, ActionStep)] - if curr_action_steps: - is_valid, covered_idx = self._is_curr_cache_valid(curr_action_steps) - if is_valid: - uncovered = curr_action_steps[covered_idx:] - curr_kept_steps = ( - ([curr_task] if curr_task else []) - + [SummaryTaskStep(task=self._current_summary_cache.summary_text)] - + list(uncovered) - ) - - record = CompressionCallRecord( - call_type="stable_bypass", cache_hit=True, - details={"reason": "stable_period_effective_under_threshold"}, - ) - self.compression_calls_log.append(record) - self._step_local_log.append(record) - - compressed_msgs = self._build_messages( - memory, prev_summary_step, prev_tail_steps, curr_kept_steps - ) - self._last_uncompressed_token_count = self._msg_token_count(original_messages) - self._last_compressed_token_count = self._msg_token_count(compressed_msgs) - return compressed_msgs - - self._step_local_log.clear() - - self._last_uncompressed_token_count = self._msg_token_count(original_messages) - - prev_steps = memory.steps[:current_run_start_idx] - curr_steps = memory.steps[current_run_start_idx:] - - prev_tokens = self._effective_prev_tokens(prev_steps) - curr_tokens = self._effective_curr_tokens(curr_steps) - - compress_prev = prev_tokens > self.config.token_threshold * 0.6 - compress_curr = 
curr_tokens > self.config.token_threshold * 0.4 - - total_effective_tokens = prev_tokens + curr_tokens - if compress_prev or compress_curr: - logger.info( - f"Context compression triggered: total_tokens={total_effective_tokens}, " - f"threshold={self.config.token_threshold}, " - f"prev_tokens={prev_tokens} (compress={compress_prev}), " - f"curr_tokens={curr_tokens} (compress={compress_curr})" - ) - - # --------------- Previous phase --------------- - prev_summary_step: Optional[SummaryTaskStep] = None - prev_tail_steps: List[MemoryStep] = list(prev_steps) - prev_pairs = self._extract_pairs(prev_steps) - - if compress_prev and prev_pairs: - keep_n = min(self.config.keep_recent_pairs, len(prev_pairs)) - pairs_to_compress = prev_pairs[:-keep_n] if keep_n > 0 else prev_pairs - pairs_to_keep = prev_pairs[-keep_n:] if keep_n > 0 else [] - if pairs_to_compress: - summary_text = self._compress_previous_with_cache( - pairs_to_compress, model - ) - if summary_text: - if "[CONTEXT COMPACTION" in summary_text: - prev_summary_step = SummaryTaskStep(task=summary_text, prefix="Context fallback, Truncated raw history:") - else: - prev_summary_step = SummaryTaskStep(task=summary_text) - prev_tail_steps = self._pairs_to_steps(pairs_to_keep) - elif prev_pairs: - # if cache is valid, use cache + uncovered display - is_valid, covered_idx = self._is_prev_cache_valid(prev_pairs) - if is_valid: - prev_summary_step = SummaryTaskStep( - task=self._previous_summary_cache.summary_text - ) - uncovered = prev_pairs[covered_idx:] - prev_tail_steps = self._pairs_to_steps(uncovered) - - # --------------- Current phase --------------- - curr_kept_steps: List[MemoryStep] = list(curr_steps) - - if curr_steps: - curr_task = curr_steps[0] if isinstance(curr_steps[0], TaskStep) else None - curr_action_steps = [s for s in curr_steps if isinstance(s, ActionStep)] - - if compress_curr and curr_action_steps: - keep_n = min(self.config.keep_recent_steps, len(curr_action_steps)) - if keep_n > 0 and keep_n < 
len(curr_action_steps): - boundary = curr_action_steps[-keep_n] - prev_a = curr_action_steps[-keep_n - 1] - if (getattr(boundary, "observations", None) is not None - and getattr(prev_a, "tool_calls", None) is not None): - keep_n += 1 - - actions_to_compress = ( - curr_action_steps[:-keep_n] if keep_n > 0 else list(curr_action_steps) - ) - actions_to_keep = ( - curr_action_steps[-keep_n:] if keep_n > 0 else [] - ) - if actions_to_compress: - curr_summary_text = self._compress_current_with_cache( - curr_task, actions_to_compress, model - ) - if curr_summary_text: - if "[CONTEXT COMPACTION" in curr_summary_text: - curr_summary_step = SummaryTaskStep(task=curr_summary_text, prefix="Truncated recent action steps:") - else: - curr_summary_step = SummaryTaskStep(task=curr_summary_text) - curr_kept_steps = ( - ([curr_task] if curr_task else []) - + [curr_summary_step] - + list(actions_to_keep) - ) - elif curr_action_steps: - is_valid, covered_idx = self._is_curr_cache_valid(curr_action_steps) - if is_valid: - uncovered = curr_action_steps[covered_idx:] - curr_kept_steps = ( - ([curr_task] if curr_task else []) - + [SummaryTaskStep(task=self._current_summary_cache.summary_text)] - + list(uncovered) - ) - - final_messages = self._build_messages( - memory, prev_summary_step, prev_tail_steps, curr_kept_steps - ) - final_tokens = self._msg_token_count(final_messages) - self._last_compressed_token_count = final_tokens - # This situation is unlikely to occur unless the threshold itself is set unreasonably small - if final_tokens > int(self.config.token_threshold * 1.1): - logger.warning( - f"Still exceeds threshold after compression: {final_tokens} > {self.config.token_threshold}. 
" - f"Consider reducing keep_recent_pairs ({self.config.keep_recent_pairs}) " - f"or keep_recent_steps({self.config.keep_recent_steps})" - ) - return final_messages - - # ============================================================ - # Previous Compression - # ============================================================ - - def _extract_pairs(self, steps): - pairs = [] - i = 0 - while i < len(steps): - if isinstance(steps[i], TaskStep) and not isinstance(steps[i], SummaryTaskStep): - if i + 1 < len(steps) and isinstance(steps[i + 1], ActionStep): - pairs.append((steps[i], steps[i + 1])) - i += 2 - continue - i += 1 - return pairs - - def _compress_previous_with_cache( - self, pairs_to_compress: List[tuple], model, - ) -> Optional[str]: - if not pairs_to_compress: - return None - - cache = self._previous_summary_cache - if cache is not None and cache.covered_pairs == len(pairs_to_compress): - anchor_t, anchor_a = pairs_to_compress[-1] - fp = self._pair_fingerprint( - anchor_t.task or "", self._action_content(anchor_a) - ) - if fp == cache.anchor_fingerprint: - record = CompressionCallRecord( - call_type="previous_cache_hit", cache_hit=True, - details={"covered_pairs": cache.covered_pairs}, - ) - self.compression_calls_log.append(record) - self._step_local_log.append(record) - return cache.summary_text - - # ===== Incremental Compression Path ===== - if (cache is not None - and 0 < cache.covered_pairs < len(pairs_to_compress)): - anchor_t, anchor_a = pairs_to_compress[cache.covered_pairs - 1] - fp = self._pair_fingerprint( - anchor_t.task or "", self._action_content(anchor_a) - ) - if fp == cache.anchor_fingerprint: - old_summary = cache.summary_text - new_pairs = pairs_to_compress[cache.covered_pairs:] - incremental_input = ( - f"## Previous Summary\n{old_summary}\n\n" - f"## New Conversations\n{self._pairs_to_text(new_pairs)}" - ) - input_tokens = self._estimate_text_tokens(incremental_input) - if input_tokens <= self.config.max_summary_input_tokens: - summary_text 
= self._generate_summary( - incremental_input, model, - call_type="previous_incremental", - prompt_type="incremental", - ) - if summary_text: - last_t, last_a = pairs_to_compress[-1] - self._previous_summary_cache = PreviousSummaryCache( - summary_text=summary_text, - covered_pairs=len(pairs_to_compress), - anchor_fingerprint=self._pair_fingerprint( - last_t.task or "", self._action_content(last_a) - ), - ) - return summary_text - logger.info( - f"Incremental input {input_tokens} tokens exceeds budget " - f"({self.config.max_summary_input_tokens}), " - f"Falling back to full compression." - ) - - # Fresh compression - summary_text, is_cacheable = self._summarize_pairs(pairs_to_compress, model) - # summary_text is valid, not None - if summary_text and is_cacheable: - last_t, last_a = pairs_to_compress[-1] - self._previous_summary_cache = PreviousSummaryCache( - summary_text=summary_text, - covered_pairs=len(pairs_to_compress), - anchor_fingerprint=self._pair_fingerprint( - last_t.task or "", self._action_content(last_a) - ), - ) - # is_cacheable is False, PreviousSummaryCache keep as is - return summary_text - - def _action_content(self, action: ActionStep) -> str: - return action.action_output or getattr(action, "output", "") or "" - - def _pair_fingerprint(self, task_content: str, action_content: str) -> str: - raw = (task_content[-200:] + action_content[-200:]) - return hashlib.md5(raw.encode()).hexdigest() - - def _summarize_pairs( - self, pairs: List[tuple], model, - ) -> Tuple[Optional[str], bool]: - """Fresh compression entry point, returns (summary, is_cacheable). 
- - L1 full summary -> (text, True) - L2 trim summary -> (text, True) # discard long-lived pairs, then summarize - L3 trim origin -> (text, False) # LLM call failed, hard truncated, no summary returned - """ - if not pairs: - return None, False - - full_text = self._pairs_to_text(pairs) - if self._estimate_text_tokens(full_text) <= self.config.max_summary_input_tokens: - target_text = full_text - else: - trimmed_pairs = self._trim_pairs_to_budget( - pairs, self.config.max_summary_input_tokens, keep_first=False - ) - target_text = self._render_steps_with_truncation( - trimmed_pairs, fmt="pair", - max_tokens=self.config.max_summary_input_tokens, - task_budget_chars=800, action_budget_chars=1500 - ) - - summary_text = self._generate_summary(target_text, model, call_type="previous_summary") - if summary_text: - return summary_text, True - logger.warning("previous full/truncated history summary generation failed, triggering L3 fallback truncation") - - reduced_pairs = self._trim_pairs_to_budget(pairs, self.config.max_summary_reduce_tokens, False) - reduced_text = self._render_steps_with_truncation( - reduced_pairs, fmt="pair", max_tokens=self.config.max_summary_reduce_tokens - ) - first_task = pairs[0][0].task[:200] if pairs and pairs[0][0].task else "" - fallback_text = ( - "[CONTEXT COMPACTION — REFERENCE ONLY] Earlier steps were removed to free context space. " - "The removed content cannot be summarized. 
Continue based on the steps below.\n\n" - f"Original task: {first_task}\n\n" - f"Steps removed: {len(pairs) - len(reduced_pairs)} of {len(pairs)}\n\n" - "Remaining compressed history:\n" - + reduced_text - ) - return fallback_text, False - - - # ============================================================ - # Current compression - # ============================================================ - - def _compress_current_with_cache( - self, curr_task: Optional[TaskStep], actions_to_compress: List[ActionStep], model, - ) -> Optional[str]: - if not actions_to_compress: - return None - - current_last_fp = self._action_fingerprint(actions_to_compress[-1]) - task_text = f"Current Task: {curr_task.task}\n\n" if curr_task else "" - cache = self._current_summary_cache - # 1) Full cache hit - if cache is not None and cache.end_steps == len(actions_to_compress): - if cache.anchor_fingerprint == current_last_fp: - record = CompressionCallRecord( - call_type="current_cache_hit", cache_hit=True, - details={"end_steps": cache.end_steps}, - ) - self.compression_calls_log.append(record) - self._step_local_log.append(record) - return cache.summary_text - - # 2) Incremental compression - if cache is not None and 0 < cache.end_steps < len(actions_to_compress): - anchor_action = actions_to_compress[cache.end_steps - 1] - if self._action_fingerprint(anchor_action) == cache.anchor_fingerprint: - old_summary = cache.summary_text - new_actions = actions_to_compress[cache.end_steps:] - incremental_input = ( - f"## Previous Summary\n{old_summary}\n\n" - f"## New Steps\n{task_text}{self._actions_to_text(new_actions)}" - ) - input_tokens = self._estimate_text_tokens(incremental_input) - if input_tokens <= self.config.max_summary_input_tokens: - summary_text = self._generate_summary( - incremental_input, model, - call_type="current_incremental", - prompt_type="incremental", - ) - if summary_text: - self._current_summary_cache = CurrentSummaryCache( - summary_text=summary_text, - 
end_steps=len(actions_to_compress), - anchor_fingerprint=current_last_fp, - ) - return summary_text - logger.info( - f"current incremental input {input_tokens} tokens exceeds budget " - f"({self.config.max_summary_input_tokens}), fallback to full compression or trimmed actions" - ) - - - # 3) Fresh compression: no cache or no valid cache or incremental input exceeds max_summary_input_tokens - safe_actions = self._trim_actions_to_budget( - actions_to_compress, task_text, self.config.max_summary_input_tokens, - ) - is_full_coverage = (len(safe_actions) == len(actions_to_compress)) - if not is_full_coverage: - logger.info( - f"Current full summary trimmed {len(actions_to_compress) - len(safe_actions)} " - f"oldest actions, still using cache" - ) - - actions_budget = max(0, self.config.max_summary_input_tokens - self._estimate_text_tokens(task_text)) - full_text = task_text + self._render_steps_with_truncation( - safe_actions, fmt="action", max_tokens=actions_budget - ) - summary_text = self._generate_summary(full_text, model, call_type="current_summary") - if summary_text: - self._current_summary_cache = CurrentSummaryCache( - summary_text=summary_text, - end_steps=len(actions_to_compress), - anchor_fingerprint=current_last_fp, - ) - return summary_text - else: - reduced_actions = self._trim_actions_to_budget( - actions_to_compress, task_text, self.config.max_summary_reduce_tokens - ) - actions_text = self._render_steps_with_truncation( - reduced_actions, fmt="action", max_tokens=self.config.max_summary_reduce_tokens - ) - fallback_text = ( - "[CONTEXT COMPACTION — REFERENCE ONLY] Some recent action steps were removed to free context space. 
" - "Continue based on the remaining steps below.\n\n" - f"Steps removed: {len(actions_to_compress) - len(reduced_actions)} of {len(actions_to_compress)}\n\n" - "Remaining steps:\n" - + actions_text - ) - return fallback_text - - def _actions_to_text(self, actions: List[ActionStep]) -> str: - parts = [] - for i, step in enumerate(actions): - text = self._render_action_step(step) - parts.append(f"[Step {step.step_number or i+1}]\n{text}") - return "\n\n".join(parts) - - def _render_steps_with_truncation( - self, - steps: List, - fmt: str = "action", - max_tokens: int = None, - min_budget_chars: int = 80, - task_budget_chars: int = 800, - action_budget_chars: int = None, - ) -> str: - if max_tokens is None: - max_tokens = self.config.max_summary_input_tokens - if action_budget_chars is None: - action_budget_chars = self.config.max_memory_step_length - - entries = self._build_step_entries(steps, fmt) - raw_text = "\n\n".join(task + action for task, action in entries) - if self._estimate_text_tokens(raw_text) <= max_tokens: - return raw_text - - return self._truncate_entries_to_budget(entries, max_tokens, min_budget_chars, task_budget_chars, action_budget_chars) - - def _build_step_entries(self, steps: List, fmt: str) -> List[Tuple[str, str]]: - entries = [] - for step in steps: - if fmt == "action": - text = f"[Step {step.step_number or '?'}]\n{self._render_action_step(step)}" - entries.append(("", text)) - else: - task_step, action_step = step - task_str = f"user: {task_step.task or ''}\nassistant: " - action_str = self._render_action_step(action_step) - entries.append((task_str, action_str)) - return entries - - def _truncate_entries_to_budget( - self, entries: List[Tuple[str, str]], max_tokens: int, - min_budget_chars: int, task_budget_chars: int, action_budget_chars: int, - ) -> str: - t_budget = task_budget_chars - a_budget = action_budget_chars - all_text = "" - - while True: - parts = [self._truncate_entry(e, t_budget, a_budget) for e in entries] - all_text = 
"\n\n".join(parts) - - if self._estimate_text_tokens(all_text) <= max_tokens: - break - - t_budget, a_budget = self._reduce_budgets(t_budget, a_budget, min_budget_chars) - if t_budget == min_budget_chars and a_budget == min_budget_chars: - break - - return all_text - - def _truncate_entry(self, entry: Tuple[str, str], task_budget: int, action_budget: int) -> str: - task_str, action_str = entry - task_trunc = self._truncate_text(task_str, task_budget) if task_str else "" - action_trunc = self._truncate_text(action_str, action_budget) - return task_trunc + action_trunc - - def _truncate_text(self, text: str, max_len: int, mark: str = "...[Truncated]") -> str: - if len(text) <= max_len: - return text - return text[:max_len - len(mark)] + mark - - def _reduce_budgets(self, t_budget: int, a_budget: int, min_budget: int) -> Tuple[int, int]: - if a_budget > min_budget: - return t_budget, max(min_budget, int(a_budget * 0.8)) - if t_budget > min_budget: - return max(min_budget, int(t_budget * 0.8)), a_budget - return t_budget, a_budget - - def _actions_to_text_with_limit(self, actions: List[ActionStep], prefill_tokens: int = 0) -> str: - rendered_steps = [] - for i, step in enumerate(actions): - prefix = f"[Step {step.step_number or i+1}]\n" - content = self._render_action_step(step) - rendered_steps.append((prefix, content)) - budget_per_action = self.config.max_memory_step_length - - while True: - parts = [] - - for prefix, content in rendered_steps: - if len(content) > budget_per_action: - text = f"{prefix}{content[:budget_per_action]}\n\n[System Note: Step content too long, partially truncated]" - else: - text = f"{prefix}{content}" - parts.append(text) - - all_text = "\n\n".join(parts) - - if self._estimate_text_tokens(all_text) + prefill_tokens <= self.config.max_summary_input_tokens: - break - budget_per_action = int(budget_per_action * 0.9) - - if budget_per_action < 50: - logger.warning( - f"Per-step compression budget has reached minimum threshold " - 
f"(budget={budget_per_action}), possibly due to excessively long preset prompts. " - f"Forcing return of truncated result." - ) - break - return all_text - - @staticmethod - def _action_fingerprint(action: ActionStep) -> str: - raw = ( - str(action.step_number or "") - + (action.model_output or "")[-200:] - + ( - action.action_output if isinstance(action.action_output, str) - else str(action.action_output) if action.action_output else "" - )[-200:] - ) - return hashlib.md5(raw.encode()).hexdigest() - - # ============================================================ - # LLM call - # ============================================================ - - def _is_context_length_error(self, err: Exception) -> bool: - return _is_context_length_error(err) - - def _generate_summary(self, text: str, model, call_type: str = "summary", - prompt_type: str = "initial") -> Optional[str]: - try: - return self._do_generate_summary(text, model, call_type, prompt_type) - except Exception as e: - if self._is_context_length_error(e): - logger.warning(f"{call_type} exceeds context limit; retrying with 2/3 budget truncation") - shrunk = self._truncate_text_to_tokens( - text, int(self.config.max_summary_input_tokens * 0.66) - ) - try: - return self._do_generate_summary(shrunk, model, call_type + "_retry", prompt_type) - except Exception as e2: - self._record_failed_compression(call_type + "_retry_failed", str(e2)) - logger.error(f"Retry still failed: {e2}") - return None - self._record_failed_compression(call_type + "_failed", str(e)) - logger.error(f"Summary generation exception: {e}") - return None - - def _record_failed_compression(self, call_type: str, error_msg: str): - """Record a failed compression attempt so stats reflect actual compression triggers.""" - - record = CompressionCallRecord( - call_type=call_type, - input_tokens=0, - output_tokens=0, - input_chars=0, - output_chars=0, - cache_hit=False, - details={"error": error_msg}, - ) - self.compression_calls_log.append(record) - 
self._step_local_log.append(record) - - def _do_generate_summary(self, text: str, model, call_type: str = "summary", - prompt_type: str = "initial") -> Optional[str]: - # prompt_type selects which system prompt to render. For "incremental" - # we use the dedicated incremental_summary_system_prompt (with fallback - # to summary_system_prompt if it is empty) and a user prompt phrased - # as an update; "initial" keeps the original fresh-compaction phrasing. - if prompt_type == "incremental": - system_prompt = ( - self.config.incremental_summary_system_prompt - or self.config.summary_system_prompt - ) - else: - system_prompt = self.config.summary_system_prompt - - schema_desc = json.dumps( - self.config.summary_json_schema, ensure_ascii=False, indent=2 - ) - if prompt_type == "incremental": - # text already contains the "## Previous Summary" + "## New ..." - # sections; the prompt only needs to instruct the update. - user_prompt = ( - f"Update the summary following this JSON structure:\n{schema_desc}\n\n" - f"{text}" - ) - else: - user_prompt = ( - f"Output a summary following this JSON structure:\n{schema_desc}\n\n" - f"Conversation content to summarize:\n{text}" - ) - messages = [ - ChatMessage(role=MessageRole.SYSTEM, - content=[{"type": "text", "text": system_prompt}]), - ChatMessage(role=MessageRole.USER, - content=[{"type": "text", "text": user_prompt}]), - ] - response = model(messages, stop_sequences=[]) - - raw_output = response.content - if isinstance(raw_output, list): - raw_output = " ".join( - block.get("text", "") - for block in raw_output - if isinstance(block, dict) and block.get("type") == "text" - ) - if not isinstance(raw_output, str): - raw_output = str(raw_output) - - summary = self._format_summary(raw_output) - self._record_llm_call_token( - input_len=self._msg_char_count(messages), - output_len=len(raw_output), - response=response, call_type=call_type, - ) - return summary - - - def _record_llm_call_token(self, input_len, output_len, response, 
call_type): - record = CompressionCallRecord( - call_type=call_type, - input_tokens=getattr(getattr(response, "token_usage", None), "input_tokens", 0) or 0, - output_tokens=getattr(getattr(response, "token_usage", None), "output_tokens", 0) or 0, - input_chars=input_len, output_chars=output_len, - ) - self.compression_calls_log.append(record) - self._step_local_log.append(record) - - def _format_summary(self, raw_output: str) -> Optional[str]: - cleaned = raw_output.strip() - if cleaned.startswith("```"): - cleaned = re.sub(r"^```(?:json)?\s*\n?", "", cleaned) - cleaned = re.sub(r"\n?```\s*$", "", cleaned) - if not cleaned: - return None - try: - parsed = json.loads(cleaned) - return json.dumps(parsed, ensure_ascii=False, indent=2) - except json.JSONDecodeError: - logger.warning("Summary output is not valid JSON; using as plain text") - return cleaned - - def _render_action_step(self, action: ActionStep) -> str: - msgs = action.to_messages(summary_mode=False) - return _extract_text_from_messages(msgs) or "" - - def _truncate_text_to_tokens(self, text: str, max_tokens: int) -> str: - if max_tokens <= 0: - return "" - if self._estimate_text_tokens(text) <= max_tokens: - return text - units = text.split("\n\n") - kept, total = [], 0 - for u in reversed(units): - u_tokens = self._estimate_text_tokens(u) - if total + u_tokens > max_tokens and kept: - break - kept.append(u) - total += u_tokens - result = "...[Earlier content truncated]...\n\n" + "\n\n".join(reversed(kept)) - if self._estimate_text_tokens(result) > max_tokens: - approx_chars = int(max_tokens * self.config.chars_per_token * 0.9) - result = "...[Earlier content truncated]...\n" + result[:approx_chars] - return result - - def _pairs_to_text(self, pairs: List[tuple]) -> str: - parts = [] - for i, (task_step, action_step) in enumerate(pairs): - task_text = task_step.task or "" - action_text = self._render_action_step(action_step) - parts.append(f"user: {task_text}\nassistant: {action_text}") - return 
"\n\n".join(parts) - - def _pairs_to_steps(self, pairs: List[tuple]) -> List[MemoryStep]: - steps = [] - for task_step, action_step in pairs: - steps.append(task_step) - steps.append(action_step) - return steps - - def _build_messages( - self, memory: AgentMemory, - prev_summary_step: Optional[SummaryTaskStep], - prev_tail_steps: List[MemoryStep], - curr_kept_steps: List[MemoryStep], - ) -> List[ChatMessage]: - result = [] - if memory.system_prompt: - result.extend(memory.system_prompt.to_messages()) - if prev_summary_step: - result.extend(prev_summary_step.to_messages()) - for step in prev_tail_steps: - result.extend(step.to_messages()) - for step in curr_kept_steps: - result.extend(step.to_messages()) - return result - - # ============================================================ - # Token Estimation - # ============================================================ - - def _estimate_tokens_for_steps(self, steps): - return estimate_tokens_for_steps(steps, self.config.chars_per_token) - - def _estimate_tokens(self, memory: AgentMemory) -> int: - return estimate_tokens(memory, self.config.chars_per_token) - - def _msg_char_count(self, msg: Union[ChatMessage, List[ChatMessage]]) -> int: - return msg_char_count(msg) - - def _msg_token_count(self, msg): - return msg_token_count(msg, self.config.chars_per_token) - - def get_step_compression_stats(self) -> dict: - with self._lock: - if not self._step_local_log: - return {"calls": 0, "input_tokens": 0, "output_tokens": 0, "cache_hits": 0, "cache_types": []} - cache_types = [r.call_type for r in self._step_local_log if r.cache_hit] - return { - "calls": len([r for r in self._step_local_log if not r.cache_hit]), - "input_tokens": sum(r.input_tokens for r in self._step_local_log), - "output_tokens": sum(r.output_tokens for r in self._step_local_log), - "input_chars": sum(r.input_chars for r in self._step_local_log), - "output_chars": sum(r.output_chars for r in self._step_local_log), - "cache_hits": sum(1 for r in 
self._step_local_log if r.cache_hit), - "cache_types": cache_types, - } - - def get_all_compression_stats(self) -> dict: - with self._lock: - real_calls = [r for r in self.compression_calls_log if not r.cache_hit] - return { - "total_calls": len(real_calls), - "total_attempts": len(self.compression_calls_log), - "total_input_tokens": sum(r.input_tokens for r in real_calls), - "total_output_tokens": sum(r.output_tokens for r in real_calls), - "total_cache_hits": sum(1 for r in self.compression_calls_log if r.cache_hit), - } - - # ============================================================ - # Benchmark export APIs - # ============================================================ - - def build_compressed_snapshot( - self, model, memory: AgentMemory, current_run_start_idx: int, - ) -> Tuple[List[ChatMessage], dict]: - """Build a frozen compressed message snapshot for probe evaluation. - - Returns (compressed_messages, metadata) without modifying internal - cache state. This enables the Probe Evaluation pattern where each - probe runs independently against a frozen compressed snapshot. - - metadata contains: token counts, which caches were used, and summary export. 
- """ - saved_prev_cache = self._previous_summary_cache - saved_curr_cache = self._current_summary_cache - saved_step_log = list(self._step_local_log) - saved_calls_log = list(self.compression_calls_log) - - try: - original_messages = memory.system_prompt.to_messages() if memory.system_prompt else [] - for step in memory.steps: - original_messages.extend(step.to_messages()) - - compressed_messages = self.compress_if_needed( - model, memory, original_messages, current_run_start_idx - ) - - metadata = { - "token_counts": self.get_token_counts(), - "summary": self.export_summary(), - "compression_stats": self.get_step_compression_stats(), - } - return compressed_messages, metadata - finally: - self._previous_summary_cache = saved_prev_cache - self._current_summary_cache = saved_curr_cache - self._step_local_log = saved_step_log - self.compression_calls_log = saved_calls_log - - def get_token_counts(self) -> dict: - """Return token counts from the most recent compression pass. - - Returns a dict with ``last_uncompressed`` and ``last_compressed`` token - counts, enabling accurate ``token_reduction = 1 - compressed/uncompressed`` - measurement in benchmarks. Values are None before the first compress_if_needed - call on this instance. - """ - with self._lock: - return { - "last_uncompressed": self._last_uncompressed_token_count, - "last_compressed": self._last_compressed_token_count, - } - - def export_summary(self) -> dict: - """Export current compression summary state for benchmark inspection. - - Returns a dict with the cached summary texts, cache metadata, and a - compression_boundary block describing which pairs/steps fed the - summary versus which were retained verbatim. Benchmarks use the - boundary block to validate probe design: probes should only target - information that was actually compressed. 
- """ - with self._lock: - prev_cache = self._previous_summary_cache - curr_cache = self._current_summary_cache - return { - "previous_summary": prev_cache.summary_text if prev_cache else None, - "current_summary": curr_cache.summary_text if curr_cache else None, - "previous_cache_info": ( - { - "covered_pairs": prev_cache.covered_pairs, - "is_fallback": "[CONTEXT COMPACTION" in (prev_cache.summary_text or ""), - } - if prev_cache else None - ), - "current_cache_info": ( - { - "end_steps": curr_cache.end_steps, - "is_fallback": "[CONTEXT COMPACTION" in (curr_cache.summary_text or ""), - } - if curr_cache else None - ), - "compression_boundary": { - "config_keep_recent_pairs": self.config.keep_recent_pairs, - "config_keep_recent_steps": self.config.keep_recent_steps, - "previous_compressed_pairs": ( - prev_cache.covered_pairs if prev_cache else 0 - ), - "previous_retained_pairs": self.config.keep_recent_pairs, - "current_compressed_steps": ( - curr_cache.end_steps if curr_cache else 0 - ), - "current_retained_steps": self.config.keep_recent_steps, - }, - } - - # ============================================================ - # Context Component Management - # ============================================================ - - def register_component(self, component) -> None: - """Register a context component for system prompt assembly. - - Components are accumulated and used by build_system_prompt(). - - Args: - component: A ContextComponent instance (e.g., ToolsComponent, - MemoryComponent, KnowledgeBaseComponent). - """ - with self._lock: - if component.token_estimate == 0: - component.token_estimate = component.estimate_tokens( - self.config.chars_per_token - ) - self._components.append(component) - - def clear_components(self) -> None: - """Clear all registered context components. - - Typically called at the start of a new agent run. 
- """ - with self._lock: - self._components.clear() - - def get_registered_components(self) -> List: - """Return copy of registered components.""" - with self._lock: - return list(self._components) - - def _get_strategy(self): - """Factory method to get strategy instance based on config.""" - from .agent_model import ( - FullStrategy, TokenBudgetStrategy, BufferedStrategy, PriorityWeightedStrategy - ) - strategy_map = { - "full": FullStrategy, - "token_budget": TokenBudgetStrategy, - "buffered": BufferedStrategy, - "priority": PriorityWeightedStrategy, - } - strategy_class = strategy_map.get(self.config.strategy, TokenBudgetStrategy) - - if self.config.strategy == "buffered": - return strategy_class(buffer_size=self.config.buffer_size_per_component) - elif self.config.strategy == "priority": - return strategy_class(relevance_threshold=0.5) - return strategy_class() - - def build_system_prompt(self, token_budget: Optional[int] = None) -> List: - """Build system prompt messages from registered components. - - Uses configured strategy to select components within token budget, - then converts each to message format. - - Args: - token_budget: Maximum tokens for all components. Defaults to - config.component_budgets total minus conversation_history. - - Returns: - List of message dicts with 'role' and 'content' keys. 
- """ - if not self._components: - return [] - - from .agent_model import SystemPromptComponent - - budget = token_budget or self._calculate_component_budget() - strategy = self._get_strategy() - selected = strategy.select_components( - self._components, budget, self.config.component_budgets - ) - - messages = [] - for comp in selected: - comp_messages = comp.to_messages() - for msg in comp_messages: - if not self._message_already_present(messages, msg): - messages.append(msg) - - return messages - - def _calculate_component_budget(self) -> int: - """Calculate total token budget for components (excluding conversation_history).""" - budgets = self.config.component_budgets - excluded = ["conversation_history"] - return sum(v for k, v in budgets.items() if k not in excluded) - - def _message_already_present(self, messages: List, new_msg: dict) -> bool: - """Check if identical message already exists.""" - for existing in messages: - if existing.get("role") == new_msg.get("role") and existing.get("content") == new_msg.get("content"): - return True - return False \ No newline at end of file diff --git a/sdk/nexent/core/agents/agent_context/__init__.py b/sdk/nexent/core/agents/agent_context/__init__.py new file mode 100644 index 000000000..d5da97350 --- /dev/null +++ b/sdk/nexent/core/agents/agent_context/__init__.py @@ -0,0 +1,29 @@ +"""Agent context management for memory compression and summarization. + +Provides ContextManager for token-aware memory compression of agent memory, +supporting incremental summarization with cache-based optimization. +""" + +from .manager import ContextManager +from .offload_store import OffloadStore +from .summary_step import SummaryTaskStep +from .llm_summary import format_summary_output, _is_context_length_error +from .step_renderer import compress_history_offline + +# Re-export types from sibling modules so that +# ``from agent_context import ContextManagerConfig`` still works. 
+from ..summary_config import ContextManagerConfig +from ..summary_cache import CompressionCallRecord, PreviousSummaryCache, CurrentSummaryCache + +__all__ = [ + "ContextManager", + "OffloadStore", + "SummaryTaskStep", + "format_summary_output", + "_is_context_length_error", + "compress_history_offline", + "ContextManagerConfig", + "CompressionCallRecord", + "PreviousSummaryCache", + "CurrentSummaryCache", +] diff --git a/sdk/nexent/core/agents/agent_context/budget.py b/sdk/nexent/core/agents/agent_context/budget.py new file mode 100644 index 000000000..bb2d6d258 --- /dev/null +++ b/sdk/nexent/core/agents/agent_context/budget.py @@ -0,0 +1,215 @@ +"""Budget estimation, trimming, and pure data helpers for ContextManager.""" + +import hashlib +import logging +from typing import Callable, List, Optional, Tuple + +from smolagents.memory import ActionStep, MemoryStep, TaskStep + +from ...utils.token_estimation import estimate_tokens_text +from ..summary_cache import PreviousSummaryCache, CurrentSummaryCache + +logger = logging.getLogger("agent_context.budget") + + +# ============================================================ +# Pure data helpers (no dependencies beyond stdlib + types) +# ============================================================ + +def extract_pairs(steps: List[MemoryStep]) -> List[tuple]: + """Extract (TaskStep, ActionStep) pairs from a step list.""" + pairs = [] + i = 0 + from .summary_step import SummaryTaskStep + while i < len(steps): + if isinstance(steps[i], TaskStep) and not isinstance(steps[i], SummaryTaskStep): + if i + 1 < len(steps) and isinstance(steps[i + 1], ActionStep): + pairs.append((steps[i], steps[i + 1])) + i += 2 + continue + i += 1 + return pairs + + +def action_content(action: ActionStep) -> str: + """Extract the output text from an ActionStep.""" + return action.action_output or getattr(action, "output", "") or "" + + +def pair_fingerprint(task_content: str, action_content: str) -> str: + """Compute a fingerprint hash for a 
(task, action) pair.""" + raw = (task_content[-200:] + action_content[-200:]) + return hashlib.md5(raw.encode()).hexdigest() + + +def action_fingerprint(action: ActionStep) -> str: + """Compute a fingerprint hash for an ActionStep.""" + raw = ( + str(action.step_number or "") + + (action.model_output or "")[-200:] + + ( + action.action_output if isinstance(action.action_output, str) + else str(action.action_output) if action.action_output else "" + )[-200:] + ) + return hashlib.md5(raw.encode()).hexdigest() + + +def has_invoked_tools(action: ActionStep) -> bool: + """Check whether an ActionStep invokes any registered tool. + + Unlike ``is_tool_call_step()`` which only checks for the generic + ``tool_calls is not None`` (always True for CodeAgent steps), this + function checks the ``invoked_tools`` list for actual tool names. + + Returns True only when the step's code called at least one tool + that is registered in the agent's ``self.tools`` dict. + """ + invoked = getattr(action, "invoked_tools", None) + return bool(invoked) + + +def is_observation_step(action: ActionStep) -> bool: + """Check if an ActionStep is an observation step.""" + return action is not None and hasattr(action, 'observations') and action.observations is not None + + +def is_tool_call_step(action: ActionStep) -> bool: + """Check if an ActionStep is a tool call step.""" + return action is not None and hasattr(action, 'tool_calls') and action.tool_calls is not None + + +# ============================================================ +# Cache validation (depends on fingerprint functions, pure) +# ============================================================ + +def is_prev_cache_valid( + prev_pairs: List[tuple], cache: Optional[PreviousSummaryCache], +) -> Tuple[bool, int]: + """Check whether the previous cache covers a prefix of prev_pairs. + + Returns (is_valid, covered_idx). 
When is_valid is True, + prev_pairs[0:covered_idx] can be replaced by cache.summary_text, + and prev_pairs[covered_idx:] represents the uncovered incremental portion. + """ + if cache is None or not prev_pairs: + return False, 0 + if cache.covered_pairs == 0 or cache.covered_pairs > len(prev_pairs): + return False, 0 + anchor_t, anchor_a = prev_pairs[cache.covered_pairs - 1] + fp = pair_fingerprint(anchor_t.task or "", action_content(anchor_a)) + if fp != cache.anchor_fingerprint: + return False, 0 + return True, cache.covered_pairs + + +def is_curr_cache_valid( + action_steps: List[ActionStep], cache: Optional[CurrentSummaryCache], +) -> Tuple[bool, int]: + """Check whether the current cache covers a prefix of action_steps.""" + if cache is None or not action_steps: + return False, 0 + if cache.end_steps == 0 or cache.end_steps > len(action_steps): + return False, 0 + anchor = action_steps[cache.end_steps - 1] + if action_fingerprint(anchor) != cache.anchor_fingerprint: + return False, 0 + return True, cache.end_steps + + +# ============================================================ +# Budget trimming (depends on render_fn for text conversion) +# ============================================================ + +def trim_pairs_to_budget( + pairs: List[tuple], max_tokens: int, + render_fn: Callable[[List[tuple]], str], + keep_first: bool = True, +) -> List[tuple]: + """Trim pairs to fit within a token budget. + + Args: + pairs: List of (TaskStep, ActionStep) tuples. + max_tokens: Maximum token budget. + render_fn: Function to convert pairs to text (e.g. renderer.pairs_to_text). + keep_first: If True, always keep the first pair. 
+ """ + if not pairs: + return [] + pair_tokens = [ + estimate_tokens_text(render_fn([p])) for p in pairs + ] + sep = estimate_tokens_text("\n\n") + total = sum(pair_tokens) + sep * max(0, len(pairs) - 1) + if total <= max_tokens: + return list(pairs) + + if keep_first and len(pairs) > 1: + budget = max_tokens - pair_tokens[0] - sep + kept_tail = [] + for i in range(len(pairs) - 1, 0, -1): + cost = pair_tokens[i] + (sep if kept_tail else 0) + if cost > budget: + break + kept_tail.append(pairs[i]) + budget -= cost + return [pairs[0]] + list(reversed(kept_tail)) + + budget = max_tokens + kept = [] + for i in range(len(pairs) - 1, -1, -1): + cost = pair_tokens[i] + (sep if kept else 0) + if cost > budget: + break + kept.append(pairs[i]) + budget -= cost + return list(reversed(kept)) if kept else [pairs[-1]] + + +def trim_actions_to_budget( + actions: List[ActionStep], task_text: str, max_tokens: int, + render_fn: Callable[[List[ActionStep]], str], +) -> List[ActionStep]: + """Trim actions to fit within a token budget. + + Args: + actions: List of ActionStep instances. + task_text: Task description text. + max_tokens: Maximum token budget. + render_fn: Function to convert actions to text (e.g. renderer.actions_to_text). 
+ """ + if not actions: + return [] + + def _total_tokens(acts): + return estimate_tokens_text(task_text + render_fn(acts)) + + if _total_tokens(actions) <= max_tokens: + return list(actions) + + for drop in range(1, len(actions) + 1): + remaining = actions[drop:] + if not remaining: + break + if is_observation_step(remaining[0]) and is_tool_call_step(actions[drop - 1]): + continue + if _total_tokens(remaining) <= max_tokens: + return list(remaining) + + return _fallback_trim_actions(actions) + + +def _fallback_trim_actions(actions: List[ActionStep]) -> List[ActionStep]: + """Fallback trimming that preserves the last complete tool call pair.""" + if not actions: + return [] + last_action = actions[-1] + if len(actions) >= 2 and is_observation_step(last_action): + prev_action = actions[-2] + if is_tool_call_step(prev_action): + logger.warning( + "Fallback limit triggered: Retaining the last complete ToolCall + Observation pair intact. " + "This may exceed the token budget, and downstream truncation will be relied upon." 
+ ) + return [prev_action, last_action] + return [last_action] diff --git a/sdk/nexent/core/agents/agent_context/current_compression.py b/sdk/nexent/core/agents/agent_context/current_compression.py new file mode 100644 index 000000000..f1593575d --- /dev/null +++ b/sdk/nexent/core/agents/agent_context/current_compression.py @@ -0,0 +1,156 @@ +"""Current run compression for ContextManager.""" + +import logging +from dataclasses import dataclass, field +from typing import List, Optional + +from smolagents.memory import ActionStep, TaskStep + +from ..summary_cache import CompressionCallRecord, CurrentSummaryCache +from ...utils.token_estimation import estimate_tokens_text +from .budget import action_fingerprint, trim_actions_to_budget + +logger = logging.getLogger("agent_context.current_compression") + + +@dataclass +class CurrentCompressResult: + """Result of a current-run compression operation.""" + summary_text: Optional[str] + new_cache: Optional[CurrentSummaryCache] = None # None means no update + records: List[CompressionCallRecord] = field(default_factory=list) + + +class CurrentCompressor: + """Current-run compression logic. + + Owns config, renderer, and llm references. Returns CurrentCompressResult + with summary text, optional new cache, and records instead of mutating + shared state. + """ + + def __init__(self, config, renderer, llm): + self._config = config + self._renderer = renderer + self._llm = llm + + def compress( + self, curr_task: Optional[TaskStep], actions_to_compress: List[ActionStep], + cache: Optional[CurrentSummaryCache], model, + ) -> CurrentCompressResult: + """Compress current-run actions with cache-based optimization. + + Args: + curr_task: Current run's TaskStep, or None. + actions_to_compress: List of ActionStep instances to compress. + cache: Current current-run summary cache, or None. + model: LLM model for summary generation. + + Returns: + CurrentCompressResult with summary_text, optional new_cache, and records. 
+ """ + if not actions_to_compress: + return CurrentCompressResult(summary_text=None) + + current_last_fp = action_fingerprint(actions_to_compress[-1]) + task_text = f"Current Task: {curr_task.task}\n\n" if curr_task else "" + + # 1) Full cache hit + if cache is not None and cache.end_steps == len(actions_to_compress): + if cache.anchor_fingerprint == current_last_fp: + record = CompressionCallRecord( + call_type="current_cache_hit", cache_hit=True, + details={"end_steps": cache.end_steps}, + ) + return CurrentCompressResult( + summary_text=cache.summary_text, + new_cache=None, + records=[record], + ) + + # 2) Incremental compression + if cache is not None and 0 < cache.end_steps < len(actions_to_compress): + anchor_action = actions_to_compress[cache.end_steps - 1] + if action_fingerprint(anchor_action) == cache.anchor_fingerprint: + old_summary = cache.summary_text + new_actions = actions_to_compress[cache.end_steps:] + incremental_input = ( + f"## Previous Summary\n{old_summary}\n\n" + f"## New Steps\n{task_text}{self._renderer.actions_to_text(new_actions, offload_store=self._renderer._offload_store)}" + ) + input_tokens = estimate_tokens_text(incremental_input) + if input_tokens <= self._config.max_summary_input_tokens: + llm_result = self._llm.generate_summary( + incremental_input, model, + call_type="current_incremental", prompt_type="incremental" + ) + if llm_result.summary_text: + new_cache = CurrentSummaryCache( + summary_text=llm_result.summary_text, + end_steps=len(actions_to_compress), + anchor_fingerprint=current_last_fp, + ) + return CurrentCompressResult( + summary_text=llm_result.summary_text, + new_cache=new_cache, + records=llm_result.records, + ) + logger.info( + f"current incremental input {input_tokens} tokens exceeds budget " + f"({self._config.max_summary_input_tokens}), fallback to full compression or trimmed actions" + ) + + # 3) Fresh compression: no cache or no valid cache or incremental input exceeds max_summary_input_tokens + records: 
List[CompressionCallRecord] = [] + + safe_actions = trim_actions_to_budget( + actions_to_compress, task_text, self._config.max_summary_input_tokens, + render_fn=self._renderer.actions_to_text, + ) + is_full_coverage = (len(safe_actions) == len(actions_to_compress)) + if not is_full_coverage: + logger.info( + f"Current full summary trimmed {len(actions_to_compress) - len(safe_actions)} " + f"oldest actions, still using cache" + ) + + actions_budget = max(0, self._config.max_summary_input_tokens - estimate_tokens_text(task_text)) + full_text = task_text + self._renderer.render_steps_with_truncation( + safe_actions, fmt="action", max_tokens=actions_budget, + offload_store=self._renderer._offload_store, + ) + llm_result = self._llm.generate_summary(full_text, model, call_type="current_summary", prompt_type="initial") + records.extend(llm_result.records) + + if llm_result.summary_text: + new_cache = CurrentSummaryCache( + summary_text=llm_result.summary_text, + end_steps=len(actions_to_compress), + anchor_fingerprint=current_last_fp, + ) + return CurrentCompressResult( + summary_text=llm_result.summary_text, + new_cache=new_cache, + records=records, + ) + else: + reduced_actions = trim_actions_to_budget( + actions_to_compress, task_text, self._config.max_summary_reduce_tokens, + render_fn=self._renderer.actions_to_text, + ) + actions_text = self._renderer.render_steps_with_truncation( + reduced_actions, fmt="action", max_tokens=self._config.max_summary_reduce_tokens, + offload_store=self._renderer._offload_store, + ) + fallback_text = ( + "[CONTEXT COMPACTION \u2014 REFERENCE ONLY] Some recent action steps were removed to free context space. 
" + "Continue based on the remaining steps below.\n\n" + f"Steps removed: {len(actions_to_compress) - len(reduced_actions)} of {len(actions_to_compress)}\n\n" + "Remaining steps:\n" + + actions_text + ) + return CurrentCompressResult( + summary_text=fallback_text, + new_cache=None, + records=records, + ) \ No newline at end of file diff --git a/sdk/nexent/core/agents/agent_context/llm_summary.py b/sdk/nexent/core/agents/agent_context/llm_summary.py new file mode 100644 index 000000000..1744cb85a --- /dev/null +++ b/sdk/nexent/core/agents/agent_context/llm_summary.py @@ -0,0 +1,173 @@ +"""LLM summary generation utilities for ContextManager.""" + +import json +import logging +import re +from dataclasses import dataclass, field +from typing import List, Optional + +from smolagents.models import ChatMessage, MessageRole + +from ..summary_cache import CompressionCallRecord +from ...utils.token_estimation import msg_char_count + +logger = logging.getLogger("agent_context.llm_summary") + + +# ============================================================ +# Standalone utilities (no ContextManager state required) +# ============================================================ + +def format_summary_output(raw_output: str) -> Optional[str]: + """Clean and validate LLM summary output. + + Strips markdown code fences, attempts JSON parse for normalization, + falls back to plain text if not valid JSON. 
def _is_context_length_error(err: Exception) -> bool:
    """Report whether *err* looks like a context-window / token-limit failure.

    The decision is a case-insensitive substring search over ``str(err)``;
    neither the exception type nor any error code is inspected.

    Args:
        err: The exception raised by the model call.

    Returns:
        True if the lowercased message contains any of the known phrases,
        otherwise False.
    """
    needles = (
        "context_length", "context length", "maximum context", "maximum context length",
        "prompt is too long", "reduce the length", "too many tokens",
        "token limit", "exceeds the maximum", "input is too long",
        "input length", "exceeds context", "context window",
    )
    lowered = str(err).lower()
    for needle in needles:
        if needle in lowered:
            return True
    return False
def _do_generate_summary(
    self, text: str, model, call_type: str = "summary", prompt_type: str = "initial",
) -> SummaryResult:
    """Run a single summarization call against *model* and package the result.

    Builds a two-message conversation (system + user), invokes ``model`` once
    with an empty ``stop_sequences`` list, normalizes the response content to
    a string and cleans it with ``format_summary_output``. Exceptions raised
    by ``model`` are not caught here.

    Args:
        text: Material to summarize (or previous summary plus new steps).
        model: Callable invoked as ``model(messages, stop_sequences=[])``.
        call_type: Label stored on the returned call record.
        prompt_type: ``"incremental"`` selects the incremental system prompt
            (falling back to the regular one when it is falsy) and the
            "Update the summary" user prompt; any other value selects the
            initial prompts.

    Returns:
        A SummaryResult with the cleaned summary text and one call record.
    """
    cfg = self._config
    incremental = prompt_type == "incremental"

    system_prompt = (
        (cfg.incremental_summary_system_prompt or cfg.summary_system_prompt)
        if incremental
        else cfg.summary_system_prompt
    )

    schema_desc = json.dumps(cfg.summary_json_schema, ensure_ascii=False, indent=2)
    if incremental:
        user_prompt = (
            f"Update the summary following this JSON structure:\n{schema_desc}\n\n"
            f"{text}"
        )
    else:
        user_prompt = (
            f"Output a summary following this JSON structure:\n{schema_desc}\n\n"
            f"Conversation content to summarize:\n{text}"
        )

    messages = [
        ChatMessage(role=role, content=[{"type": "text", "text": body}])
        for role, body in (
            (MessageRole.SYSTEM, system_prompt),
            (MessageRole.USER, user_prompt),
        )
    ]
    response = model(messages, stop_sequences=[])

    raw_output = response.content
    if isinstance(raw_output, list):
        # Structured content: keep only the text blocks, space-separated.
        text_parts = []
        for part in raw_output:
            if isinstance(part, dict) and part.get("type") == "text":
                text_parts.append(part.get("text", ""))
        raw_output = " ".join(text_parts)
    if not isinstance(raw_output, str):
        raw_output = str(raw_output)

    summary = format_summary_output(raw_output)
    call_record = self._record_llm_call_token(
        input_len=msg_char_count(messages),
        output_len=len(raw_output),
        response=response, call_type=call_type,
    )
    return SummaryResult(summary_text=summary, records=[call_record])
def __init__(self, config: Optional[ContextManagerConfig] = None,
             max_steps: Optional[int] = None, offload_store: Optional[OffloadStore] = None):
    """Create the manager, its owned state and its composed sub-components.

    Args:
        config: Compression settings; a default ``ContextManagerConfig`` is
            created when omitted (or falsy).
        max_steps: Optional step limit; ``keep_recent_steps`` is clamped so
            it never exceeds it.
        offload_store: Externally owned store to share. A private
            ``OffloadStore`` sized from the config is created only when this
            is ``None``.
    """
    # NOTE(review): the config object is adjusted in place below
    # (keep_recent_steps and the two max_summary_* budgets), so a config
    # shared with the caller observes these changes.
    self.config = config or ContextManagerConfig()

    # --- Owned state ---
    self._previous_summary_cache: Optional[PreviousSummaryCache] = None
    self._current_summary_cache: Optional[CurrentSummaryCache] = None

    # Run boundary self-detection. The current cache fingerprint is only reused
    # within the current run and must be explicitly cleared at the start of a new run.
    # The previous cache is managed and updated across runs.
    self._last_run_start_idx: Optional[int] = None

    if max_steps is not None and self.config.keep_recent_steps >= max_steps:
        self.config.keep_recent_steps = max_steps

    self.compression_calls_log: List[CompressionCallRecord] = []
    self._step_local_log: List[CompressionCallRecord] = []
    self._lock = threading.Lock()

    # Compare against None instead of relying on truthiness: OffloadStore
    # defines __len__, so an injected store that is still empty is falsy and
    # ``offload_store or OffloadStore(...)`` would silently replace it with a
    # private instance, breaking sharing with the external owner.
    if offload_store is not None:
        self._offload_store = offload_store
    else:
        self._offload_store = OffloadStore(
            max_entries=self.config.max_offload_entries,
            max_entry_chars=self.config.max_offload_entry_chars,
            max_total_chars=self.config.max_offload_total_chars,
        )
    self._last_uncompressed_token_count: Optional[int] = None
    self._last_compressed_token_count: Optional[int] = None

    # Non-positive summary budgets mean "derive from token_threshold".
    if self.config.max_summary_input_tokens <= 0:
        self.config.max_summary_input_tokens = int(self.config.token_threshold * 1.2)
    if self.config.max_summary_reduce_tokens <= 0:
        self.config.max_summary_reduce_tokens = int(self.config.token_threshold * 0.2)

    # --- Composed sub-components ---
    # Built last so they see the normalized config and the final store.
    self._renderer = StepRenderer(self.config, self._offload_store)
    self._llm = LLMSummary(self.config, self._renderer)
    self._prev_compressor = PreviousCompressor(self.config, self._renderer, self._llm)
    self._curr_compressor = CurrentCompressor(self.config, self._renderer, self._llm)

    self._components: List = []
def _effective_curr_tokens(self, curr_steps: List[MemoryStep]) -> int:
    """Estimate the token cost of the current run as it would be rendered now.

    When the current-run summary cache is not valid for the run's action
    steps, the raw steps are estimated as-is. Otherwise the estimate is the
    task text (if the run starts with a TaskStep) plus the cached summary
    plus the action steps the cache does not cover.

    Args:
        curr_steps: Memory steps belonging to the current run.

    Returns:
        Estimated token count; 0 for an empty run.
    """
    if not curr_steps:
        return 0

    first_step = curr_steps[0]
    task_step = first_step if isinstance(first_step, TaskStep) else None
    actions = [step for step in curr_steps if isinstance(step, ActionStep)]

    cache_ok, covered = is_curr_cache_valid(actions, self._current_summary_cache)
    if not cache_ok:
        return estimate_tokens_for_steps(curr_steps, self.config.chars_per_token)

    task_cost = 0
    if task_step:
        task_cost = estimate_tokens_text(task_step.task or "")

    tail_cost = 0
    tail = actions[covered:]
    if tail:
        tail_cost = estimate_tokens_text(self._renderer.actions_to_text(tail))

    summary_cost = estimate_tokens_text(self._current_summary_cache.summary_text)
    return task_cost + summary_cost + tail_cost
+ self._last_compressed_token_count = self._last_uncompressed_token_count + return original_messages + + with self._lock: + # Run detection + if (self._last_run_start_idx is not None + and current_run_start_idx != self._last_run_start_idx): + # Only the per-run compression cache is run-scoped and reset here. + # The offload store is intentionally NOT cleared: it is now + # session-scoped and owned externally (injected via the + # ``offload_store`` parameter), so archived content survives across + # runs within the same session and can be re-listed + reloaded. + self._current_summary_cache = None + self._last_run_start_idx = current_run_start_idx + + # Note: The memory here always consists of the unmodified, summary-task-step-free + # original previous_run + current_run. + # - previous_run: [(TaskStep, ActionStep), ...] + # - current_run: [TaskStep, ActionStep, ActionStep, ...] + if self._effective_tokens(memory, current_run_start_idx) <= self.config.token_threshold: + # Stable-phase bypass: No LLM call; construct compressed messages directly from existing cache. 
+ self._step_local_log.clear() + + prev_steps = memory.steps[:current_run_start_idx] + curr_steps = memory.steps[current_run_start_idx:] + + prev_summary_step = None + prev_tail_steps = list(prev_steps) + prev_pairs = extract_pairs(prev_steps) + if prev_pairs: + is_valid, covered_idx = is_prev_cache_valid(prev_pairs, self._previous_summary_cache) + if is_valid: + prev_summary_step = SummaryTaskStep( + task=self._previous_summary_cache.summary_text, + prefix="Summary of earlier steps in this task:", + ) + uncovered = prev_pairs[covered_idx:] + prev_tail_steps = self._renderer.pairs_to_steps(uncovered) + + curr_kept_steps = list(curr_steps) + if curr_steps: + curr_task = curr_steps[0] if isinstance(curr_steps[0], TaskStep) else None + curr_action_steps = [s for s in curr_steps if isinstance(s, ActionStep)] + if curr_action_steps: + is_valid, covered_idx = is_curr_cache_valid(curr_action_steps, self._current_summary_cache) + if is_valid: + uncovered = curr_action_steps[covered_idx:] + curr_kept_steps = ( + ([curr_task] if curr_task else []) + + [SummaryTaskStep(task=self._current_summary_cache.summary_text, prefix="Summary of earlier steps in this task:")] + + list(uncovered) + ) + + record = CompressionCallRecord( + call_type="stable_bypass", cache_hit=True, + details={"reason": "stable_period_effective_under_threshold"}, + ) + self.compression_calls_log.append(record) + self._step_local_log.append(record) + + compressed_msgs = self._renderer.build_messages( + memory, prev_summary_step, prev_tail_steps, curr_kept_steps + ) + self._last_compressed_token_count = msg_token_count(compressed_msgs, self.config.chars_per_token) + return compressed_msgs + + self._step_local_log.clear() + + self._last_uncompressed_token_count = msg_token_count(original_messages, self.config.chars_per_token) + + prev_steps = memory.steps[:current_run_start_idx] + curr_steps = memory.steps[current_run_start_idx:] + + prev_tokens = self._effective_prev_tokens(prev_steps) + curr_tokens = 
self._effective_curr_tokens(curr_steps) + + compress_prev = prev_tokens > self.config.token_threshold * 0.6 + compress_curr = curr_tokens > self.config.token_threshold * 0.4 + + total_effective_tokens = prev_tokens + curr_tokens + if compress_prev or compress_curr: + logger.info( + f"Context compression triggered: total_tokens={total_effective_tokens}, " + f"threshold={self.config.token_threshold}, " + f"prev_tokens={prev_tokens} (compress={compress_prev}), " + f"curr_tokens={curr_tokens} (compress={compress_curr})" + ) + + # --------------- Previous phase --------------- + prev_summary_step: Optional[SummaryTaskStep] = None + prev_tail_steps: List[MemoryStep] = list(prev_steps) + prev_pairs = extract_pairs(prev_steps) + + if compress_prev and prev_pairs: + keep_n = min(self.config.keep_recent_pairs, len(prev_pairs)) + pairs_to_compress = prev_pairs[:-keep_n] if keep_n > 0 else prev_pairs + pairs_to_keep = prev_pairs[-keep_n:] if keep_n > 0 else [] + if pairs_to_compress: + prev_result = self._prev_compressor.compress( + pairs_to_compress, self._previous_summary_cache, model, + ) + self.compression_calls_log.extend(prev_result.records) + self._step_local_log.extend(prev_result.records) + if prev_result.new_cache is not None: + self._previous_summary_cache = prev_result.new_cache + if prev_result.summary_text: + is_fallback = "[CONTEXT COMPACTION" in prev_result.summary_text + prev_summary_step = SummaryTaskStep( + task=prev_result.summary_text, + prefix="Context fallback, Truncated raw history:" if is_fallback else "Summary of earlier steps in this task:" + ) + prev_tail_steps = self._renderer.pairs_to_steps(pairs_to_keep) + elif prev_pairs: + # if cache is valid, use cache + uncovered display + is_valid, covered_idx = is_prev_cache_valid(prev_pairs, self._previous_summary_cache) + if is_valid: + prev_summary_step = SummaryTaskStep( + task=self._previous_summary_cache.summary_text, + prefix="Summary of earlier steps in this task:", + ) + uncovered = 
prev_pairs[covered_idx:] + prev_tail_steps = self._renderer.pairs_to_steps(uncovered) + + # --------------- Current phase --------------- + curr_kept_steps: List[MemoryStep] = list(curr_steps) + + if curr_steps: + curr_task = curr_steps[0] if isinstance(curr_steps[0], TaskStep) else None + curr_action_steps = [s for s in curr_steps if isinstance(s, ActionStep)] + + if compress_curr and curr_action_steps: + keep_n = min(self.config.keep_recent_steps, len(curr_action_steps)) + # Note: No cross-step pair detection needed here. Each ActionStep + # is self-contained — tool_calls and observations always belong to + # the same step (set in _step_stream), so there is no risk of + # splitting a call-observation pair across the compression boundary. + + actions_to_compress = ( + curr_action_steps[:-keep_n] if keep_n > 0 else list(curr_action_steps) + ) + + actions_to_compress = ( + curr_action_steps[:-keep_n] if keep_n > 0 else list(curr_action_steps) + ) + actions_to_keep = ( + curr_action_steps[-keep_n:] if keep_n > 0 else [] + ) + if actions_to_compress: + curr_result = self._curr_compressor.compress( + curr_task, actions_to_compress, self._current_summary_cache, model, + ) + self.compression_calls_log.extend(curr_result.records) + self._step_local_log.extend(curr_result.records) + if curr_result.new_cache is not None: + self._current_summary_cache = curr_result.new_cache + if curr_result.summary_text: + is_fallback = "[CONTEXT COMPACTION" in curr_result.summary_text + curr_summary_step = SummaryTaskStep( + task=curr_result.summary_text, + prefix="Truncated recent action steps:" if is_fallback else "Summary of earlier steps in this task:" + ) + curr_kept_steps = ( + ([curr_task] if curr_task else []) + + [curr_summary_step] + + list(actions_to_keep) + ) + elif curr_action_steps: + is_valid, covered_idx = is_curr_cache_valid(curr_action_steps, self._current_summary_cache) + if is_valid: + uncovered = curr_action_steps[covered_idx:] + curr_kept_steps = ( + ([curr_task] if 
def build_compressed_snapshot(
    self, model, memory: AgentMemory, current_run_start_idx: int,
) -> tuple:
    """Build a frozen compressed message snapshot for probe evaluation.

    Runs the regular ``compress_if_needed`` pipeline on *memory*, collects
    the resulting stats, and then rolls back the caches, the call logs and
    the run-boundary marker so the probe leaves no trace in later real
    compression calls.

    Args:
        model: Model forwarded to ``compress_if_needed``.
        memory: Agent memory providing ``system_prompt`` and ``steps``.
        current_run_start_idx: Index in ``memory.steps`` where the current
            run begins.

    Returns:
        ``(compressed_messages, metadata)`` where ``metadata`` holds
        ``token_counts``, ``summary`` and ``compression_stats`` as observed
        right after the probe compression, before state is restored.
    """
    # No lock is held here: compress_if_needed and the stats getters take
    # self._lock themselves, and that lock is a non-reentrant threading.Lock.
    saved_state = {
        "_previous_summary_cache": self._previous_summary_cache,
        "_current_summary_cache": self._current_summary_cache,
        "_step_local_log": list(self._step_local_log),
        "compression_calls_log": list(self.compression_calls_log),
        # compress_if_needed overwrites this marker. Leaving the probe's value
        # behind would make the next real call see a spurious run change and
        # discard the current-run summary cache.
        "_last_run_start_idx": self._last_run_start_idx,
    }
    # NOTE(review): the _last_*_token_count fields are still overwritten by
    # the probe (they feed the returned metadata); confirm whether they
    # should be rolled back as well.

    try:
        # Copy so that extending never mutates a list owned by the prompt step.
        original_messages = (
            list(memory.system_prompt.to_messages()) if memory.system_prompt else []
        )
        for step in memory.steps:
            original_messages.extend(step.to_messages())

        compressed_messages = self.compress_if_needed(
            model, memory, original_messages, current_run_start_idx
        )

        metadata = {
            "token_counts": self.get_token_counts(),
            "summary": self.export_summary(),
            "compression_stats": self.get_step_compression_stats(),
        }
        return compressed_messages, metadata
    finally:
        # Restore original state -- snapshot must not mutate cache
        for attr_name, value in saved_state.items():
            setattr(self, attr_name, value)
def build_system_prompt(self, token_budget: Optional[int] = None) -> List:
    """Build system prompt messages from registered components.

    Uses the configured strategy to select components within the token
    budget, converts each selected component to messages and drops messages
    that duplicate one already collected.

    Args:
        token_budget: Maximum tokens for all components. When ``None`` the
            budget is the sum of ``config.component_budgets`` excluding
            ``conversation_history``. An explicit ``0`` is honoured as a
            zero budget rather than being treated as "unset".

    Returns:
        List of message dicts with 'role' and 'content' keys; empty when no
        component is registered.
    """
    # Snapshot under the lock, like the other component accessors, so a
    # concurrent register/clear cannot change the list mid-selection.
    with self._lock:
        components = list(self._components)
    if not components:
        return []

    budget = (
        self._calculate_component_budget() if token_budget is None else token_budget
    )
    strategy = self._get_strategy()
    selected = strategy.select_components(
        components, budget, self.config.component_budgets
    )

    messages = []
    for comp in selected:
        for msg in comp.to_messages():
            if not self._message_already_present(messages, msg):
                messages.append(msg)

    return messages
def store(self, content: str, description: str = "") -> Optional[str]:
    """Store content (+ optional description) and return a UUID handle.

    Before inserting, the oldest entries (dict insertion order) are evicted
    until both the total-character budget and the entry-count budget can
    accommodate the new entry.

    Args:
        content: Original text to archive.
        description: Short hint about the content; it is tokenized once here
            for later relevance scoring.

    Returns:
        The new handle, or None when the content can never fit: it is longer
        than ``max_entry_chars``, longer than ``max_total_chars``, or
        ``max_entries`` is not positive. Existing entries are left untouched
        in those cases.
    """
    size = len(content)
    if size > self._max_entry_chars:
        logger.warning(
            f"Content exceeds max_entry_chars ({self._max_entry_chars}), "
            f"skipping offload for {len(content)} chars"
        )
        return None
    if size > self._max_total_chars or self._max_entries <= 0:
        # Evicting cannot make room for this entry; refuse it up front
        # instead of emptying the store and overshooting the budget anyway
        # (or calling next() on an empty store when max_entries <= 0).
        logger.warning(
            "Content cannot fit the offload budget (max_total_chars=%d, "
            "max_entries=%d), skipping offload for %d chars",
            self._max_total_chars, self._max_entries, size,
        )
        return None

    handle = uuid.uuid4().hex
    # Pre-compute tokens at store time so build_reload_inventory never
    # re-tokenizes descriptions during scoring. Done before taking the lock
    # to keep the critical section short.
    entry = _Entry(content, description, tokens=OffloadStore._tokenize(description))

    with self._lock:
        while self._store and (
            self._current_total + size > self._max_total_chars
            or len(self._store) >= self._max_entries
        ):
            oldest = next(iter(self._store))
            self._current_total -= len(self._store.pop(oldest).content)

        self._store[handle] = entry
        self._current_total += size
    return handle
@staticmethod
def _tokenize(text: str) -> set:
    """Tokenize text for keyword-overlap scoring, supporting CJK + Latin.

    Two kinds of tokens are produced from the lowercased text:

    * Character bigrams for every pair of adjacent characters that both fall
      in the CJK Unified Ideographs range, so "数据库" yields
      {"数据", "据库"}.
    * Whitespace-separated words after punctuation is replaced by spaces,
      keeping only words of at least two characters that are not pure
      digits and not in the stop-word list.

    Returns:
        The union of both token sets.
    """
    lowered = text.lower()
    cjk_lo = OffloadStore._CJK_LO
    cjk_hi = OffloadStore._CJK_HI

    # Sliding over adjacent pairs is equivalent to emitting the bigrams of
    # each maximal run of CJK characters.
    bigrams = {
        left + right
        for left, right in zip(lowered, lowered[1:])
        if cjk_lo <= ord(left) <= cjk_hi and cjk_lo <= ord(right) <= cjk_hi
    }

    words = {
        word
        for word in re.sub(r'[^\w\s]', ' ', lowered).split()
        if len(word) >= 2
        and not word.isdigit()
        and word not in OffloadStore._STOP_WORDS
    }

    return bigrams | words
def build_reload_inventory(
    self,
    enable_reload: bool,
    query: Optional[str] = None,
    max_items: int = 10,
) -> Optional[str]:
    """Build a per-run inventory listing reloadable archives.

    When ``query`` is provided, entries are scored by keyword overlap and
    sorted by relevance (highest first), capped at ``max_items``. Entries
    with zero overlap are dropped so the LLM only sees items with at least
    some lexical connection to the query.

    When ``query`` is None, tokenizes to nothing, or matches no entry, the
    most recently stored entries are used instead (FIFO tail).

    Args:
        enable_reload: If False, returns None immediately.
        query: Optional user query for relevance scoring.
        max_items: Maximum entries to include in the inventory. A value of
            zero or less yields None.

    Returns:
        Inventory text, or None if nothing to list or reload is disabled.
    """
    # A non-positive cap means "list nothing". Without this guard the
    # recency slice ``[-0:]`` would return *every* entry.
    if not enable_reload or max_items <= 0:
        return None

    # Snapshot handle, description and pre-computed tokens in one critical
    # section. Looking tokens up in ``self._store`` after the lock has been
    # released can raise KeyError if another thread evicts the entry.
    with self._lock:
        snapshot = [
            (handle, entry.description, entry.tokens)
            for handle, entry in self._store.items()
        ]
    if not snapshot:
        return None

    selected = None
    if query:
        query_tokens = self._tokenize(query)
        if query_tokens:
            scored = [
                (handle, desc, self._score_description(tokens, query_tokens))
                for handle, desc, tokens in snapshot
            ]
            # Stable sort: equally scored entries keep insertion order.
            scored.sort(key=lambda item: item[2], reverse=True)
            matching = [(h, d) for h, d, score in scored if score > 0]
            if matching:
                selected = matching[:max_items]

    if selected is None:
        # No usable query or nothing matched: fall back to recency.
        selected = [(h, d) for h, d, _ in snapshot][-max_items:]

    lines = [
        f"- handle={handle}: {description}"
        for handle, description in selected
    ]
    return (
        "[System Notice - Not User Input] The following content was archived "
        "(offloaded) earlier in this session. You can retrieve the full "
        "original text by calling reload_original_context_messages with the "
        "corresponding handle. If any of these are relevant to answering the "
        "user's question below, decide whether to reload them; do not guess "
        "based on truncated display text.\n"
        + "\n".join(lines)
    )
def items(self):
    """Return a snapshot list of (handle, content) pairs.

    The list is built while holding ``self._lock`` and is independent of
    the store afterwards, so callers can iterate it without locking.
    """
    snapshot = []
    with self._lock:
        for handle, entry in self._store.items():
            snapshot.append((handle, entry.content))
    return snapshot
def compress(
    self, pairs_to_compress: List[tuple], cache: Optional[PreviousSummaryCache], model,
) -> PreviousCompressResult:
    """Compress previous-run pairs with cache-based optimization.

    Three paths are tried in order:

    1. Full cache hit: the cache covers exactly these pairs and the
       fingerprint of the last pair matches, so the cached summary is
       returned without an LLM call.
    2. Incremental: the cache covers a prefix whose last pair still matches
       the cached fingerprint, so only the new pairs are summarized on top
       of the old summary, provided the combined input fits
       ``max_summary_input_tokens``.
    3. Fresh: everything is summarized via ``_summarize_pairs``.

    Args:
        pairs_to_compress: List of (TaskStep, ActionStep) tuples to compress.
        cache: Current previous summary cache, or None.
        model: LLM model for summary generation.

    Returns:
        PreviousCompressResult with summary_text, optional new_cache, and
        records. ``new_cache`` is None when the cache should be left as-is.
        ``records`` also contains the records of an incremental LLM call
        that produced no summary before the fresh path took over.
    """
    if not pairs_to_compress:
        return PreviousCompressResult(summary_text=None)

    total_pairs = len(pairs_to_compress)

    def _fingerprint(pair: tuple):
        # Fingerprint of one (TaskStep, ActionStep) pair, used as cache anchor.
        task_step, action_step = pair
        return pair_fingerprint(task_step.task or "", action_content(action_step))

    # Records of an incremental attempt that was abandoned. They are merged
    # into the final result so a failed call is still accounted for, the
    # same way _summarize_pairs keeps the records of its failed calls.
    carried_records: List[CompressionCallRecord] = []

    # 1) Full cache hit
    if cache is not None and cache.covered_pairs == total_pairs:
        if _fingerprint(pairs_to_compress[-1]) == cache.anchor_fingerprint:
            record = CompressionCallRecord(
                call_type="previous_cache_hit", cache_hit=True,
                details={"covered_pairs": cache.covered_pairs},
            )
            return PreviousCompressResult(
                summary_text=cache.summary_text,
                new_cache=None,
                records=[record],
            )

    # 2) Incremental compression
    if cache is not None and 0 < cache.covered_pairs < total_pairs:
        anchor_pair = pairs_to_compress[cache.covered_pairs - 1]
        if _fingerprint(anchor_pair) == cache.anchor_fingerprint:
            old_summary = cache.summary_text
            new_pairs = pairs_to_compress[cache.covered_pairs:]
            incremental_input = (
                f"## Previous Summary\n{old_summary}\n\n"
                f"## New Conversations\n{self._renderer.pairs_to_text(new_pairs, offload_store=self._renderer._offload_store)}"
            )
            input_tokens = estimate_tokens_text(incremental_input)
            if input_tokens <= self._config.max_summary_input_tokens:
                llm_result = self._llm.generate_summary(
                    incremental_input, model,
                    call_type="previous_incremental", prompt_type="incremental"
                )
                if llm_result.summary_text:
                    new_cache = PreviousSummaryCache(
                        summary_text=llm_result.summary_text,
                        covered_pairs=total_pairs,
                        anchor_fingerprint=_fingerprint(pairs_to_compress[-1]),
                    )
                    return PreviousCompressResult(
                        summary_text=llm_result.summary_text,
                        new_cache=new_cache,
                        records=llm_result.records,
                    )
                # The call was made but yielded nothing usable.
                carried_records.extend(llm_result.records)
                logger.warning(
                    "Incremental summary generation returned no text, "
                    "falling back to full compression."
                )
            else:
                logger.info(
                    f"Incremental input {input_tokens} tokens exceeds budget "
                    f"({self._config.max_summary_input_tokens}), "
                    f"Falling back to full compression."
                )

    # 3) Fresh compression
    summary_text, is_cacheable, records = self._summarize_pairs(pairs_to_compress, model)
    new_cache = None
    if summary_text and is_cacheable:
        new_cache = PreviousSummaryCache(
            summary_text=summary_text,
            covered_pairs=total_pairs,
            anchor_fingerprint=_fingerprint(pairs_to_compress[-1]),
        )
    # is_cacheable is False: PreviousSummaryCache kept as-is to avoid
    # incremental compression reusing a fallback cache that doesn't
    # represent a true summary.
    return PreviousCompressResult(
        summary_text=summary_text,
        new_cache=new_cache,
        records=carried_records + list(records),
    )
+ + L1 full summary -> (text, True, records) + L2 trim summary -> (text, True, records) # discard long-lived pairs, then summarize + L3 trim origin -> (text, False, records) # LLM call failed, hard truncated + """ + records: List[CompressionCallRecord] = [] + + if not pairs: + return None, False, records + + full_text = self._renderer.pairs_to_text(pairs, offload_store=self._renderer._offload_store) + if estimate_tokens_text(full_text) <= self._config.max_summary_input_tokens: + target_text = full_text + else: + trimmed_pairs = trim_pairs_to_budget( + pairs, self._config.max_summary_input_tokens, + render_fn=self._renderer.pairs_to_text, keep_first=False, + ) + target_text = self._renderer.render_steps_with_truncation( + trimmed_pairs, fmt="pair", + max_tokens=self._config.max_summary_input_tokens, + task_budget_chars=800, action_budget_chars=1500, + offload_store=self._renderer._offload_store, + ) + + llm_result = self._llm.generate_summary( + target_text, model, call_type="previous_summary", prompt_type="initial" + ) + records.extend(llm_result.records) + if llm_result.summary_text: + return llm_result.summary_text, True, records + logger.warning("previous full/truncated history summary generation failed, triggering L3 fallback truncation") + + reduced_pairs = trim_pairs_to_budget( + pairs, self._config.max_summary_reduce_tokens, + render_fn=self._renderer.pairs_to_text, keep_first=False, + ) + reduced_text = self._renderer.render_steps_with_truncation( + reduced_pairs, fmt="pair", max_tokens=self._config.max_summary_reduce_tokens, + offload_store=self._renderer._offload_store, + ) + first_task = pairs[0][0].task[:200] if pairs and pairs[0][0].task else "" + fallback_text = ( + "[CONTEXT COMPACTION \u2014 REFERENCE ONLY] Earlier steps were removed to free context space. " + "The removed content cannot be summarized. 
def get_step_compression_stats(step_local_log: List[CompressionCallRecord]) -> dict:
    """Compute compression statistics for the current step.

    Args:
        step_local_log: Records collected during the current step. An empty
            or None log yields all-zero statistics.

    Returns:
        A dict that always has the same keys, whether or not the log is
        empty: ``calls`` (records that are not cache hits), ``input_tokens``,
        ``output_tokens``, ``input_chars``, ``output_chars`` (sums over all
        records), ``cache_hits`` (number of cache-hit records) and
        ``cache_types`` (``call_type`` of each cache-hit record, in order).
    """
    # One code path for empty and non-empty logs keeps the result shape
    # stable; sums over an empty list are 0.
    records = step_local_log or []
    hits = [r for r in records if r.cache_hit]
    return {
        "calls": len(records) - len(hits),
        "input_tokens": sum(r.input_tokens for r in records),
        "output_tokens": sum(r.output_tokens for r in records),
        "input_chars": sum(r.input_chars for r in records),
        "output_chars": sum(r.output_chars for r in records),
        "cache_hits": len(hits),
        "cache_types": [r.call_type for r in hits],
    }
def export_summary(prev_cache, curr_cache, config) -> dict:
    """Export the current compression summary state for benchmark inspection.

    The result holds the two summary texts, per-cache metadata, and a
    ``compression_boundary`` section describing how much was compressed
    versus retained verbatim, which benchmark authors need so that probes
    only target compressed content.

    Args:
        prev_cache: Previous-run summary cache (reads ``summary_text`` and
            ``covered_pairs``), or a falsy value when absent.
        curr_cache: Current-run summary cache (reads ``summary_text`` and
            ``end_steps``), or a falsy value when absent.
        config: Object providing ``keep_recent_pairs`` and
            ``keep_recent_steps``.

    Returns:
        A dict with keys ``previous_summary``, ``current_summary``,
        ``previous_cache_info``, ``current_cache_info`` and
        ``compression_boundary``.
    """
    def _is_fallback(cache) -> bool:
        # Fallback (hard-truncated) summaries carry this marker text.
        return "[CONTEXT COMPACTION" in (cache.summary_text or "")

    previous_summary = None
    previous_info = None
    previous_compressed = 0
    if prev_cache:
        previous_summary = prev_cache.summary_text
        previous_info = {
            "covered_pairs": prev_cache.covered_pairs,
            "is_fallback": _is_fallback(prev_cache),
        }
        previous_compressed = prev_cache.covered_pairs

    current_summary = None
    current_info = None
    current_compressed = 0
    if curr_cache:
        current_summary = curr_cache.summary_text
        current_info = {
            "end_steps": curr_cache.end_steps,
            "is_fallback": _is_fallback(curr_cache),
        }
        current_compressed = curr_cache.end_steps

    boundary = {
        "config_keep_recent_pairs": config.keep_recent_pairs,
        "config_keep_recent_steps": config.keep_recent_steps,
        "previous_compressed_pairs": previous_compressed,
        "previous_retained_pairs": config.keep_recent_pairs,
        "current_compressed_steps": current_compressed,
        "current_retained_steps": config.keep_recent_steps,
    }
    return {
        "previous_summary": previous_summary,
        "current_summary": current_summary,
        "previous_cache_info": previous_info,
        "current_cache_info": current_info,
        "compression_boundary": boundary,
    }
def compress_history_offline(
    pairs: List[Tuple[str, str]],
    model,
    config: Optional[ContextManagerConfig] = None,
    previous_summary: Optional[str] = None,
) -> dict:
    """Compress conversation history offline, without ContextManager or AgentMemory.

    This is a standalone function for **Static compression Inspection** in
    benchmarks. It takes plain-text (user, assistant) pairs and produces a
    summary using the same prompts and schema as the in-agent compression path,
    but without any stateful cache, offload store, or agent runtime.

    The passed-in ``config`` is only read, never modified.

    Args:
        pairs: List of (user_text, assistant_text) tuples representing
            conversation turns to compress.
        model: An LLM model object compatible with smolagents' call interface.
        config: ContextManagerConfig providing prompts, schema, and token budgets.
            Defaults to a fresh ContextManagerConfig() if not provided.
        previous_summary: Optional existing summary text for incremental
            compression. If provided, uses the incremental prompt
            to update rather than create from scratch.

    Returns:
        dict with:
        - "summary": the compressed summary text; None only when there was
          nothing to compress (no pairs and no previous summary)
        - "is_incremental": whether incremental compression was used
        - "is_fallback": whether the LLM failed and fallback truncation was used
        - "input_text": the raw text that was fed to the LLM (for debugging)
        - "input_chars": character count of the input text
    """
    import json

    from smolagents.models import MessageRole

    from .llm_summary import format_summary_output, _is_context_length_error

    config = config or ContextManagerConfig()
    # Same compensation as ContextManager.__init__: when max_summary_input_tokens
    # is left at the default 0, derive it from token_threshold so that truncation
    # logic doesn't accidentally chop all input. The derived value is kept in a
    # local so the caller's config object is not changed as a side effect.
    max_input_tokens = config.max_summary_input_tokens
    if max_input_tokens <= 0:
        max_input_tokens = int(config.token_threshold * 1.2)

    if not pairs and not previous_summary:
        return {
            "summary": None,
            "is_incremental": False,
            "is_fallback": False,
            "input_text": "",
            "input_chars": 0,
        }

    # Build input text from pairs
    pairs_text = "\n\n".join(
        f"user: {user_text}\nassistant: {assistant_text}"
        for user_text, assistant_text in pairs
    )

    # Determine compression mode
    is_incremental = previous_summary is not None
    if is_incremental:
        input_text = (
            f"## Previous Summary\n{previous_summary}\n\n"
            f"## New Conversations\n{pairs_text}"
        )
    else:
        input_text = pairs_text

    # Truncate if exceeds budget (tail-truncation: preserve newest content)
    if estimate_tokens_text(input_text) > max_input_tokens:
        approx_chars = int(max_input_tokens * config.chars_per_token * 0.9)
        input_text = "...[Earlier content truncated]...\n" + input_text[-approx_chars:]

    schema_desc = json.dumps(config.summary_json_schema, ensure_ascii=False, indent=2)

    def _user_message(body_text: str):
        # Wrap *body_text* in the user prompt matching the compression mode.
        if is_incremental:
            prompt = (
                f"Update the summary following this JSON structure:\n{schema_desc}\n\n"
                f"{body_text}"
            )
        else:
            prompt = (
                f"Create a structured checkpoint summary following this JSON structure:\n{schema_desc}\n\n"
                f"TURNS TO SUMMARIZE:\n{body_text}"
            )
        return ChatMessage(
            role=MessageRole.USER,
            content=[{"type": "text", "text": prompt}],
        )

    def _run_model(chat: list) -> Optional[str]:
        # Call the model once and normalize its output to a summary string.
        response = model(chat, stop_sequences=[])
        raw_output = response.content
        if isinstance(raw_output, list):
            # Content may be a list of typed blocks; keep only text blocks.
            raw_output = " ".join(
                block.get("text", "")
                for block in raw_output
                if isinstance(block, dict) and block.get("type") == "text"
            )
        if not isinstance(raw_output, str):
            raw_output = str(raw_output)
        return format_summary_output(raw_output)

    if is_incremental:
        system_prompt = config.incremental_summary_system_prompt
    else:
        system_prompt = config.summary_system_prompt
    messages = [
        ChatMessage(role=MessageRole.SYSTEM,
                    content=[{"type": "text", "text": system_prompt}]),
        _user_message(input_text),
    ]

    # Call LLM with error handling. Failures are deliberately best-effort:
    # any exception ends in the L3 fallback below rather than propagating.
    is_fallback = False
    summary = None

    try:
        summary = _run_model(messages)
    except Exception as e:
        if _is_context_length_error(e):
            logger.warning("Offline compression exceeds context limit; retrying with 2/3 budget")
            approx_chars = int(max_input_tokens * config.chars_per_token * 0.6)
            truncated_input = input_text[-approx_chars:] if len(input_text) > approx_chars else input_text
            messages[-1] = _user_message(truncated_input)
            try:
                summary = _run_model(messages)
            except Exception as e2:
                logger.error(f"Offline compression retry still failed: {e2}")
        else:
            # Make the cause visible in the logs before falling back.
            logger.error(f"Offline compression failed: {e}")

    if summary is None:
        # L3 fallback: hard truncation
        is_fallback = True
        first_task = pairs[0][0][:200] if pairs else ""
        reduced_chars = int(config.max_summary_reduce_tokens * config.chars_per_token)
        reduced_text = pairs_text[-reduced_chars:] if len(pairs_text) > reduced_chars else pairs_text
        summary = (
            "[CONTEXT COMPACTION \u2014 REFERENCE ONLY] Earlier steps were removed to free context space. "
            "The removed content cannot be summarized. Continue based on the steps below.\n\n"
            f"Original task: {first_task}\n\n"
            f"Steps removed: {len(pairs)} of {len(pairs)}\n\n"
            "Remaining compressed history:\n"
            + reduced_text
        )

    return {
        "summary": summary,
        "is_incremental": is_incremental,
        "is_fallback": is_fallback,
        "input_text": input_text,
        "input_chars": len(input_text),
    }
+ + Each message segment (model_output, tool_call, observation) is independently + checked against the offload threshold. Only oversized segments are offloaded; + short segments remain intact, giving the compression LLM sufficient context + to produce a high-quality summary. + """ + msgs = action.to_messages(summary_mode=False) + + # Fast path: no offload configured — simple concatenation + if offload_store is None or self._config.per_step_render_limit <= 0: + return _extract_text_from_messages(msgs) or "" + + # Per-segment rendering with offload + parts = [] + for msg in msgs: + text = _extract_text_from_messages([msg]) or "" + if text.startswith("Calling tools:"): + # Tool call is always short — keep verbatim + parts.append(text) + else: + # Per-segment offload: observation or model_output + parts.append(self._render_segment(text, action, offload_store)) + return "\n".join(parts) + + def _render_segment(self, text: str, action: ActionStep, offload_store: Optional[OffloadStore] = None) -> str: + """Render a single message segment, offloading if oversized. + + When the segment exceeds ``per_step_render_limit``, the full text is archived + in ``offload_store`` and replaced with a self-describing marker so the + compression LLM knows what was offloaded and how to retrieve it. + + If the action has a ``_raw_observation`` attribute (preserved before + ``max_observation_length`` truncation), the original text is used for + offload so that ``reload`` can retrieve the truly original content. + + Args: + text: The display segment text (possibly already truncated by + ``max_observation_length``). + action: The ActionStep being rendered; may carry ``_raw_observation``. + offload_store: OffloadStore for archiving oversized segments. + + Returns: + Rendered segment — either the full text, or a truncated version ending + with a self-describing offload marker. 
+ """ + limit = self._config.per_step_render_limit + if offload_store is None or limit <= 0: + return text + + # Determine the source text for offload decisions and archiving. + # For observations, prefer the pre-truncation raw content if available. + source_text = text + if text.startswith("Observation:") and hasattr(action, '_raw_observation'): + source_text = action._raw_observation + + # Skip offload for reloaded content: the output of + # ReloadOriginalContextTool already originates from the offload + # store — archiving it again creates a duplicate. + # The observation is prefixed with "Execution logs:\n" but the + # JSON body always contains "offload_handle" near the start. + if '"offload_handle"' in source_text[:300]: + return text + + # If the source (original) content is within limit, no offload needed. + if len(source_text) <= limit: + return text + + # Build the human/LLM-readable description first, so the same hint + # is stored alongside the content (for list_active inventory) + # and embedded in the inline marker. Preview is based on *display* + # text (``text``) because that is what the LLM already sees. + char_count = len(source_text) + if text.startswith("Observation:"): + first_line = text.split("\n")[0] if "\n" in text else text[:100] + description = f"{first_line[:80]} ({char_count} chars)" + else: + preview = text[:80].replace("\n", " ").strip() + description = f"{preview}... ({char_count} chars)" + + # Offload triggered — archive the original (or raw) content together + # with its description. + handle = offload_store.store(source_text, description=description) + if handle is None: + # Content too large even for offload store; fall back to plain truncation. + return text[:limit] + "\n...[CONTENT_TOO_LARGE_TO_OFFLOAD]" + + # Build a self-describing marker so the LLM understands what was offloaded. 
def truncate_text_to_tokens(self, text: str, max_tokens: int) -> str:
    """Shrink *text* to roughly ``max_tokens`` tokens, keeping the newest content.

    The text is split into blank-line-separated units, which are collected
    from the end backwards until the next unit would exceed the budget. At
    least one unit is always kept. If the result is still over budget
    (the last unit alone is too large), the kept text is cut to its final
    characters using ``chars_per_token`` as a size estimate.

    Args:
        text: Text to truncate.
        max_tokens: Token budget; zero or less yields an empty string.

    Returns:
        ``text`` unchanged when it already fits, otherwise the retained
        tail prefixed with a single "...[Earlier content truncated]..."
        marker.
    """
    if max_tokens <= 0:
        return ""
    if estimate_tokens_text(text) <= max_tokens:
        return text

    marker = "...[Earlier content truncated]..."
    kept, total = [], 0
    for unit in reversed(text.split("\n\n")):
        unit_tokens = estimate_tokens_text(unit)
        # ``and kept`` guarantees the newest unit is retained even if it
        # alone is over budget; the character cut below handles that case.
        if total + unit_tokens > max_tokens and kept:
            break
        kept.append(unit)
        total += unit_tokens

    tail = "\n\n".join(reversed(kept))
    result = marker + "\n\n" + tail
    if estimate_tokens_text(result) > max_tokens:
        approx_chars = int(max_tokens * self._config.chars_per_token * 0.9)
        # Cut the kept text itself (not the already-marked string, which
        # would duplicate the marker) and keep its END, since the marker
        # says the earlier part is what was dropped. A zero-length budget
        # must not slice with [-0:], which would return everything.
        clipped = tail[-approx_chars:] if approx_chars > 0 else ""
        result = marker + "\n" + clipped
    return result
def _truncate_entries_to_budget(
    self, entries: List[Tuple[str, str]], max_tokens: int,
    min_budget_chars: int, task_budget_chars: int, action_budget_chars: int,
) -> str:
    """Render *entries* with per-part character budgets that shrink until the text fits.

    Each pass truncates every (task, action) entry to the current budgets
    and joins them with blank lines. If the joined text is over
    ``max_tokens``, ``_reduce_budgets`` is asked for smaller budgets and
    the entries are rendered again. The loop stops when the text fits or
    when the budgets cannot be reduced any further.

    Args:
        entries: (task_text, action_text) pairs to render.
        max_tokens: Token budget for the joined text.
        min_budget_chars: Floor passed to ``_reduce_budgets``.
        task_budget_chars: Initial character budget for the task part.
        action_budget_chars: Initial character budget for the action part.

    Returns:
        The joined text. It can still exceed ``max_tokens`` when even the
        smallest budgets do not fit.
    """
    t_budget = task_budget_chars
    a_budget = action_budget_chars

    while True:
        all_text = "\n\n".join(
            self._truncate_entry(entry, t_budget, a_budget) for entry in entries
        )
        if estimate_tokens_text(all_text) <= max_tokens:
            return all_text

        reduced = tuple(self._reduce_budgets(t_budget, a_budget, min_budget_chars))
        # Stop on "no progress" rather than on "both budgets equal the
        # floor". This renders once more at the final budgets before giving
        # up, and it also terminates when a starting budget is already
        # below the floor, which an equality test would never detect.
        if reduced == (t_budget, a_budget):
            return all_text
        t_budget, a_budget = reduced
def add_skill(self, name: str, description: str, examples: Optional[List[str]] = None) -> None:
    """Add a skill definition.

    Args:
        name: Skill name.
        description: Skill description.
        examples: Optional example strings for the skill. Omitted, None or
            empty yields an empty list.
    """
    self.skills.append({
        "name": name,
        "description": description,
        # Store a copy so later changes to the caller's list do not
        # silently alter the registered skill.
        "examples": list(examples) if examples else [],
    })
from ..utils.token_estimation import msg_token_count +from ..utils.code_analysis import extract_invoked_tools, extract_invoked_tool_signatures + +if not hasattr(ActionStep, "invoked_tools"): + ActionStep.invoked_tools = None +if not hasattr(ActionStep, "invoked_tool_signatures"): + ActionStep.invoked_tool_signatures = None def parse_code_blobs(text: str) -> str: """Extract code blocks from the LLM's output for execution. @@ -236,6 +242,15 @@ def __init__( self.stop_event = threading.Event() self._history_step_count = 0 # For ContextManager, record boundary for compression self.context_manager: ContextManager = None + self._ephemeral_system_messages: Optional[List[ChatMessage]] = None + """Per-run system messages injected before the current user query. + + Set via ``set_ephemeral_messages()`` before ``run()`` and automatically + prepended to ``input_messages`` in ``_step_stream()`` right before the + last USER message (the current query). These messages are NOT stored in + ``memory.steps`` and therefore do not persist into conversation history + or compression. Cleared via ``clear_ephemeral_messages()``. + """ self.step_metrics: List[dict] = [] # Quantitative metrics per step self._last_uncompressed_est = 0 # Override smolagent default to prevent extracting ```python blocks from KB content. @@ -244,6 +259,18 @@ def __init__( # identifiers; omitting "python" and "py" ensures ```python blocks are not extracted. self.code_block_tags = ["", ""] + def set_ephemeral_messages(self, messages: List[ChatMessage]) -> None: + """Set per-run system messages injected before the current user query. + + These are prepended to ``input_messages`` in ``_step_stream()`` right + before the last USER message and are NOT stored in ``memory.steps``. 
+ """ + self._ephemeral_system_messages = messages + + def clear_ephemeral_messages(self) -> None: + """Clear ephemeral system messages after the run completes.""" + self._ephemeral_system_messages = None + def _verification_tool_names(self) -> List[str]: names = set() for container in (getattr(self, "tools", {}) or {}, getattr(self, "managed_agents", {}) or {}): @@ -391,12 +418,28 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: ) input_messages = memory_messages.copy() - # import pdb; pdb.set_trace() + # Trigger context compression if needed before building messages if self.context_manager and self.context_manager.config.enabled: input_messages = self.context_manager.compress_if_needed( self.model, self.memory, input_messages, self._history_step_count ) + + # Inject ephemeral system messages before the last USER message + # (the current query), maximizing prompt cache prefix reuse: the + # system prompt and all history messages stay at the same positions. + if self._ephemeral_system_messages: + insert_at = len(input_messages) + for i in range(len(input_messages) - 1, -1, -1): + if input_messages[i].role == "user": + insert_at = i + break + input_messages = ( + input_messages[:insert_at] + + self._ephemeral_system_messages + + input_messages[insert_at:] + ) + # Add new step in logs memory_step.model_input_messages = input_messages stop_sequences = ["Observation:", "Calling tools:"] @@ -464,8 +507,32 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: arguments=code_action, id=f"call_{len(self.memory.steps)}", ) + # memory_step.invoked_tools = extract_invoked_tools(code_action, self.tools) if self.tools else [] + memory_step.invoked_tool_signatures = ( + extract_invoked_tool_signatures(code_action, self.tools) if self.tools else [] + ) + # When context_manager is enabled, store a COMPACT call-signature in + # tool_call.arguments instead of the full code. 
to_messages() renders + # arguments into the TOOL_CALL message, while the same code already + # lives verbatim in model_output's block. Compacting here makes + # every downstream path emit a bounded tool-call message. The full code + # is preserved untouched in memory_step.code_action and model_output. + if self.context_manager and self.context_manager.config.enabled: + compact_arguments = ( + "\n".join(memory_step.invoked_tool_signatures) + if memory_step.invoked_tool_signatures + else truncate_content(code_action, max_length=100) + ) + else: + compact_arguments = code_action + tool_call = ToolCall( + name="python_interpreter", + arguments=compact_arguments, + id=f"call_{len(self.memory.steps)}", + ) memory_step.tool_calls = [tool_call] + # Execute self.logger.log_code(title="Executing parsed code:", content=code_action, level=LogLevel.INFO) @@ -535,6 +602,28 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: observation += "Last output from code snippet:\n" + truncated_output memory_step.observations = observation + # --- Save raw observation for offload when needed --- + # Only preserve the truly original content when: + # 1. ContextManager is enabled with reload + # 2. per_step_render_limit is active (offload mechanism on) + # 3. max_observation_length will truncate the observation + # This avoids unconditional double-storage for short observations. 
+ ctx_cfg = self.context_manager.config if self.context_manager else None + needs_raw = ( + ctx_cfg + and ctx_cfg.enable_reload + and ctx_cfg.per_step_render_limit > 0 + and ctx_cfg.max_observation_length > 0 + and len(observation) > ctx_cfg.max_observation_length + ) + if needs_raw: + raw_limit = getattr(ctx_cfg, 'max_offload_entry_chars', 30000) + if len(observation) > raw_limit: + memory_step._raw_observation = observation[:raw_limit] + "\n...[RAW_TRUNCATED]" + else: + memory_step._raw_observation = observation + # --- end raw observation save --- + verification_controller = getattr(self, "verification_controller", None) if verification_controller: postcheck = verification_controller.verify_after_tool_call( @@ -560,11 +649,14 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[Any]: max_obs = self.context_manager.config.max_observation_length if max_obs > 0 and memory_step.observations and len(memory_step.observations) > max_obs: obs_text = memory_step.observations - half = max_obs // 2 truncation_marker = ( f"\n...[Output truncated to {max_obs} characters. " f"Use search or read tools to find specific results.]\n" ) + # Reserve space for the marker itself so the total stays + # within max_obs (half + marker + half ≤ max_obs). 
+ content_budget = max(0, max_obs - len(truncation_marker)) + half = content_budget // 2 memory_step.observations = obs_text[:half] + truncation_marker + obs_text[-half:] if not code_output.is_final_answer and truncated_output is not None: @@ -1008,4 +1100,4 @@ def _handle_max_steps_reached(self, task: str) -> Any: self._finalize_step(final_memory_step) self.memory.steps.append(final_memory_step) - return model_output + return model_output \ No newline at end of file diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index a9a31a94b..e20db27ee 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -10,6 +10,7 @@ from smolagents import ActionStep, AgentText, TaskStep, Timing from smolagents.tools import Tool +from smolagents.models import ChatMessage, MessageRole from ...monitor import AgentRunMetadata, get_agent_monitoring_context, get_monitoring_manager @@ -413,6 +414,23 @@ def create_single_agent(self, agent_config: AgentConfig): except Exception as e: raise ValueError(f"Error in creating external A2A agent wrapper: {e}") + # Build ContextManager before agent so the reload tool can be + # included in tool_list and properly sandboxed by the Python executor. 
+ ctx_config = getattr(agent_config, 'context_manager_config', None) + ctx_manager = None + if ctx_config: + ctx_manager = ContextManager( + config=ctx_config, + max_steps=agent_config.max_steps, + ) + if ctx_config.enable_reload: + from ..tools import ReloadOriginalContextTool + tool_list.append( + ReloadOriginalContextTool( + offload_store=ctx_manager.offload_store, + ) + ) + # Create the agent agent = CoreAgent( observer=self.observer, @@ -430,17 +448,15 @@ def create_single_agent(self, agent_config: AgentConfig): ) agent.stop_event = self.stop_event - # Mount context manager if config provided - ctx_config = getattr(agent_config, 'context_manager_config', None) - if ctx_config: - agent.context_manager = ContextManager( - config=ctx_config, - max_steps=agent_config.max_steps - ) - context_components = getattr(agent_config, 'context_components', None) - if context_components: - for component in context_components: - agent.context_manager.register_component(component) + # Mount context manager if configured + if ctx_manager: + agent.context_manager = ctx_manager + + # Register context components if provided + context_components = getattr(agent_config, 'context_components', None) + if context_components and agent.context_manager: + for component in context_components: + agent.context_manager.register_component(component) return agent except Exception as e: @@ -474,6 +490,37 @@ def add_history_to_agent(self, history: List[AgentHistory]): action_output=msg.content, model_output=msg.content)) self.agent._history_step_count = len(self.agent.memory.steps) + def _build_reloadable_archives_messages(self, query: str = "") -> List[ChatMessage]: + """Build ephemeral ChatMessages listing reloadable offload archives. + + Uses ``OffloadStore.build_reload_inventory()`` and wraps the result + in a SYSTEM-role ChatMessage. Returns an empty list when there is nothing + to list or reload is disabled. 
+ + When ``query`` is non-empty, entries are scored by keyword overlap + with the query so only relevant handles are shown. + + The returned messages are injected as ephemeral system messages on the + agent: they are prepended right before the current user query in + ``_step_stream`` and do NOT persist in ``memory.steps``, so they never + become part of the conversation history or compression. + """ + + ctx_mgr = getattr(self.agent, "context_manager", None) + if ctx_mgr is None: + return [] + store = getattr(ctx_mgr, "offload_store", None) + if store is None: + return [] + enable_reload = getattr(ctx_mgr.config, "enable_reload", False) + text = store.build_reload_inventory(enable_reload, query=query) + if not text: + return [] + return [ChatMessage( + role=MessageRole.SYSTEM, + content=[{"type": "text", "text": text}], + )] + def agent_run_with_observer(self, query: str, reset=True): if not isinstance(self.agent, CoreAgent): raise TypeError(f"agent must be a CoreAgent object, not {type(self.agent)}") @@ -488,6 +535,12 @@ def agent_run_with_observer(self, query: str, reset=True): observer = self.agent.observer total_output_tokens = 0 final_answer_for_trace = None + # Set ephemeral system messages for reloadable archive inventory. + # These are injected right before the current user query in + # _step_stream and do NOT persist in memory.steps. 
+ ephemeral_msgs = self._build_reloadable_archives_messages(query=query) + if ephemeral_msgs: + self.agent.set_ephemeral_messages(ephemeral_msgs) with monitoring_manager.start_agent_run(metadata): with monitoring_manager.trace_agent_step( "agent.run.loop", @@ -567,6 +620,8 @@ def agent_run_with_observer(self, query: str, reset=True): raise ValueError(f"Error in interaction: {str(e)}") finally: + if ephemeral_msgs: + self.agent.clear_ephemeral_messages() self._log_step_metrics() if final_answer_for_trace is not None: @@ -662,7 +717,7 @@ def _val_width(vals, extra_val=None): lines.append( "-----" ) - logger.debug("\n".join(lines)) + print("\n".join(lines)) # Optional: write to local file with open("nexent_context_metrics.log", "a", encoding="utf-8") as f: diff --git a/sdk/nexent/core/agents/run_agent.py b/sdk/nexent/core/agents/run_agent.py index 243ca099e..98c1f97ed 100644 --- a/sdk/nexent/core/agents/run_agent.py +++ b/sdk/nexent/core/agents/run_agent.py @@ -88,6 +88,11 @@ def agent_run_thread(agent_run_info: AgentRunInfo): if getattr(agent_run_info, 'context_manager', None) is not None: agent.context_manager = agent_run_info.context_manager + # Sync reload tool to the swapped store (otherwise it still + # points to the internal store created by create_single_agent). 
+ if 'reload_original_context_messages' in agent.tools: + agent.tools['reload_original_context_messages']._offload_store = \ + agent.context_manager.offload_store nexent.add_history_to_agent(agent_run_info.history) nexent.agent_run_with_observer( @@ -109,6 +114,9 @@ def agent_run_thread(agent_run_info: AgentRunInfo): if getattr(agent_run_info, 'context_manager', None) is not None: agent.context_manager = agent_run_info.context_manager + if 'reload_original_context_messages' in agent.tools: + agent.tools['reload_original_context_messages']._offload_store = \ + agent.context_manager.offload_store nexent.add_history_to_agent(agent_run_info.history) nexent.agent_run_with_observer( diff --git a/sdk/nexent/core/agents/summary_config.py b/sdk/nexent/core/agents/summary_config.py index e271ddd34..69aed829d 100644 --- a/sdk/nexent/core/agents/summary_config.py +++ b/sdk/nexent/core/agents/summary_config.py @@ -10,7 +10,7 @@ @dataclass class ContextManagerConfig: """Configuration for ContextManager - handles ALL context building. - + Extends existing compression config with: - Strategy selection for component selection algorithms - Injection flags to enable/disable individual context components @@ -24,6 +24,43 @@ class ContextManagerConfig: max_chunk_count: int = 0 max_memory_step_length: int = 2000 + # === Offload Settings === + # Archives oversized step-render segments to an in-memory OffloadStore + # so the LLM still sees compact context. Requires **both** enable_reload + # AND per_step_render_limit > 0. The agent retrieves archived content + # via the ``reload_original_context_messages`` tool. + + enable_reload: bool = False + """Create an :class:`OffloadStore` and inject the reload tool into the agent. + + Offload is only *triggered* when ``per_step_render_limit > 0``. + """ + + per_step_render_limit: int = 0 + """Character threshold triggering offload **during compression**. 
+ + Only applies to old steps outside the ``keep_recent`` window — recent + steps are never offloaded. When a step's rendered text exceeds this + limit, the full content is archived and replaced with an + ``[[OFFLOAD:handle:desc]]`` marker. Unlike ``max_observation_length`` + this is **reversible**: the agent can reload archived content on demand. + + Set to 0 to disable (the default). Suggested: 3000–10000. + """ + + max_offload_entries: int = 200 + """Max entries in the :class:`OffloadStore`. Oldest evicted (FIFO) when full.""" + + max_offload_entry_chars: int = 30000 + """Max characters per offload entry. Oversized content is rejected by the + store. Safety cap against a single giant observation dominating memory. + """ + + max_offload_total_chars: int = 2_000_000 + """Cumulative character budget across all entries. Oldest evicted (FIFO) + when exceeded. Together with ``max_offload_entries`` bounds total memory. + """ + summary_system_prompt: str = ( "You are a conversation summarization assistant. Compress the following " "conversation history into a structured summary, preserving all key information: " @@ -59,17 +96,17 @@ class ContextManagerConfig: estimated_chunk_summary_tokens: int = 400 chars_per_token: float = 1.5 - # Pre-truncate single observations (model/tool outputs) longer than this - # character limit at execute_action time, before they reach memory. - # 0 = disabled (production default). Only takes effect when ``enabled`` - # is True, so production callers that do not opt in see no behaviour - # change. + # Pre-truncate observations at source (before memory), keeping head+tail + # around a truncation marker. This is per-step, irreversible sanitation — + # not a compression mechanism. For reversible archiving of large content, + # use offload (``per_step_render_limit``) instead. + # 0 = disabled (default). Takes effect only when ``enabled`` is True. 
max_observation_length: int = 0 # === NEW: Strategy Selection === strategy: StrategyType = "token_budget" """Context component selection strategy. - + Options: - 'full': Keep all components (for unlimited context models) - 'token_budget': Select components within token budget by priority @@ -80,22 +117,22 @@ class ContextManagerConfig: # === NEW: Component Injection Flags === inject_system_prompt: bool = True """Whether to inject system prompt into context.""" - + inject_tools: bool = True """Whether to inject tool descriptions into system prompt.""" - + inject_skills: bool = True """Whether to inject skill summaries into system prompt.""" - + inject_memory: bool = True """Whether to search and inject long-term memory (mem0) into system prompt.""" - + inject_knowledge_base: bool = True """Whether to inject knowledge base summaries into system prompt.""" - + inject_agent_definitions: bool = True """Whether to inject sub-agent (managed_agents + external_a2a_agents) definitions.""" - + inject_app_context: bool = True """Whether to inject APP_NAME, APP_DESCRIPTION, time, user_id.""" @@ -111,11 +148,11 @@ class ContextManagerConfig: "conversation_history": 4000, # Reserved for conversation compression }) """Token budget for each context component type. - + Used by token_budget strategy to allocate tokens across components. Total of all budgets should not exceed token_threshold. 
""" # === NEW: Buffered Strategy Settings === buffer_size_per_component: int = 10 - """Number of items to keep per component type for 'buffered' strategy.""" \ No newline at end of file + """Number of items to keep per component type for 'buffered' strategy.""" diff --git a/sdk/nexent/core/tools/__init__.py b/sdk/nexent/core/tools/__init__.py index c35991f6e..97d3e0ba1 100644 --- a/sdk/nexent/core/tools/__init__.py +++ b/sdk/nexent/core/tools/__init__.py @@ -24,8 +24,7 @@ from .run_skill_script_tool import run_skill_script from .read_skill_md_tool import read_skill_md from .read_skill_config_tool import read_skill_config -from .store_memory_tool import StoreMemoryTool -from .search_memory_tool import SearchMemoryTool +from .reload_original_context_tool import ReloadOriginalContextTool __all__ = [ "MySqlTool", @@ -56,6 +55,5 @@ "run_skill_script", "read_skill_md", "read_skill_config", - "StoreMemoryTool", - "SearchMemoryTool", + "ReloadOriginalContextTool" ] diff --git a/sdk/nexent/core/tools/reload_original_context_tool.py b/sdk/nexent/core/tools/reload_original_context_tool.py new file mode 100644 index 000000000..c6c552f37 --- /dev/null +++ b/sdk/nexent/core/tools/reload_original_context_tool.py @@ -0,0 +1,55 @@ +import json +import logging + +from smolagents.tools import Tool + +logger = logging.getLogger("reload_original_context_tool") + + +class ReloadOriginalContextTool(Tool): + """Tool for reloading offloaded original context content. + + When the context manager compresses conversation history, long step content + is offloaded to an in-memory store and replaced with [[OFFLOAD:handle=...]] + markers. The agent can call this tool to recover the full original content + when detailed information from earlier steps is needed. + """ + name = "reload_original_context_messages" + description = ( + "Reload the original full content of an offloaded / archived context step. 
" + "At the start of each conversation turn, a system notice lists available " + "archived handles (e.g. 'handle=abc123: description'). " + "Use this tool with the handle value from that notice when you need to " + "review the detailed original content that was removed to save context space." + ) + + inputs = { + "offload_handle": { + "type": "string", + "description": "The handle value from the system notice inventory (e.g. 'handle=abc123')" + } + } + + output_type = "string" + + def __init__(self, offload_store=None, **kwargs): + super().__init__(**kwargs) + self._offload_store = offload_store + + def forward(self, offload_handle: str) -> str: + if self._offload_store is None: + return json.dumps({"error": "Offload store is not available. Context reload is not enabled."}) + + content = self._offload_store.reload(offload_handle) + if content is None: + return json.dumps({ + "error": f"No offloaded content found for handle '{offload_handle}'. " + f"The content may have been evicted from the store." + }) + + return json.dumps({ + "offload_handle": offload_handle, + "content": content, + "content_length": len(content), + "message": "Original context content retrieved successfully." 
+ }, ensure_ascii=False) \ No newline at end of file diff --git a/sdk/nexent/core/utils/__init__.py b/sdk/nexent/core/utils/__init__.py index 67513717d..ac48a8b39 100644 --- a/sdk/nexent/core/utils/__init__.py +++ b/sdk/nexent/core/utils/__init__.py @@ -1,3 +1,5 @@ + from .observer import MessageObserver, ProcessType +from .code_analysis import extract_invoked_tools -__all__ = ["MessageObserver", "ProcessType"] \ No newline at end of file +__all__ = ["MessageObserver", "ProcessType", "extract_invoked_tools"] diff --git a/sdk/nexent/core/utils/code_analysis.py b/sdk/nexent/core/utils/code_analysis.py new file mode 100644 index 000000000..ba46c0c0a --- /dev/null +++ b/sdk/nexent/core/utils/code_analysis.py @@ -0,0 +1,151 @@ +"""Code analysis utilities for extracting tool usage information from agent-generated code.""" + +import ast +import logging +from typing import List + +logger = logging.getLogger("code_analysis") + + +def extract_invoked_tools(code_action: str, registered_tools: dict) -> List[str]: + """Extract registered tool names called in code_action via AST analysis. + + Walks the AST to find all ``ast.Call`` nodes whose func is an ``ast.Name``, + then intersects with the keys of *registered_tools* (typically + ``self.tools`` on the agent). Returns a **sorted** list of matched tool + names (duplicates removed). + + Known limitations (acceptable gaps): + - Variable aliases (``fn = tool; fn()``) — resolved at runtime, not in AST. + - Dynamic dispatch (``globals()[name]()``) — same reason. + These patterns are exceedingly rare in LLM-generated CodeAgent code. + + Args: + code_action: The Python code string from ``action_step.code_action``. + registered_tools: Dict mapping tool name -> tool object (e.g. ``self.tools``). + Only the keys are used; values are ignored. + + Returns: + Sorted list of tool names that are both called in *code_action* and + present in *registered_tools*. Empty list when no tools are called + or *code_action* has syntax errors. 
+ """ + if not code_action: + return [] + try: + tree = ast.parse(code_action) + except SyntaxError: + logger.warning("Failed to parse code_action for invoked_tools extraction") + return [] + called_names = set() + for node in ast.walk(tree): + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): + called_names.add(node.func.id) + return sorted(name for name in called_names if name in registered_tools) + + +def _render_arg_value(node: ast.AST, max_value_len: int) -> str: + """Render one call-argument value. + + The compression rule is deliberately **semantic-agnostic**: it never looks + at the parameter *name* (we cannot know which tool's params matter). It + decides purely by the *rendered size* of the value: + + - Numbers / bools / None -> always kept verbatim (tiny, often pivotal). + - Strings <= ``max_value_len`` -> kept verbatim (quoted via ``repr``). + - Strings > ``max_value_len`` -> ```` placeholder. + - Any other expression (variable, nested call, list/dict literal): + unparsed to source, kept if short, else ````. + + This means short scalars like ``file_path="a.txt"`` survive while big + payloads like a 5KB ``content=...`` collapse to a self-describing hint, + with no per-tool configuration. + """ + if isinstance(node, ast.Constant): + v = node.value + if isinstance(v, str): + return repr(v) if len(v) <= max_value_len else f"" + # int / float / bool / None: always compact and frequently the + # decision-critical args (ids, flags, thresholds) -> keep verbatim. 
+ return repr(v) + try: + src = ast.unparse(node) # Python 3.9+; project requires >=3.10 + except Exception: + return "" + if len(src) <= max_value_len: + return src + return f"<{type(node).__name__.lower()}:{len(src)} chars>" + + +def _render_call_signature(node: ast.Call, max_value_len: int, max_sig_len: int) -> str: + """Render a single ``tool(...)`` call into a compact, length-bounded signature.""" + parts: List[str] = [] + for arg in node.args: # positional + parts.append(_render_arg_value(arg, max_value_len)) + for kw in node.keywords: # keyword (kw.arg is None for ``**kwargs``) + name = kw.arg if kw.arg is not None else "**" + parts.append(f"{name}={_render_arg_value(kw.value, max_value_len)}") + sig = f"{node.func.id}(" + ", ".join(parts) + ")" + if len(sig) > max_sig_len: + sig = sig[:max_sig_len] + f"...(+{len(parts)} args)" + return sig + + +def extract_invoked_tool_signatures( + code_action: str, + registered_tools: dict, + max_value_len: int = 60, + max_sig_len: int = 200, +) -> List[str]: + """Extract compact call *signatures* for registered tools invoked in code. + + Unlike :func:`extract_invoked_tools` (which returns bare names), this keeps + the call shape -- e.g. ``write_file(file_path='ana.txt', content=analysis)`` + -- so that a compacted TOOL_CALL message preserves causal information + (which file, which id) even if the assistant's ```` text is later + truncated, while large argument values are replaced by size-describing + placeholders to avoid duplicating the payload already present verbatim in + ``model_output``. + + Nested calls that appear *as arguments* to another call are not emitted as + separate top-level entries (they already show up inside the parent + signature). Order of first appearance is preserved. + + Falls back to the bare tool name for any individual call whose signature + rendering raises, and returns ``[]`` on syntax errors -- it must never + raise into the rendering path. 
+ """ + if not code_action: + return [] + try: + tree = ast.parse(code_action) + except SyntaxError: + logger.warning("Failed to parse code_action for invoked_tool signatures") + return [] + + # Collect Call nodes that are themselves arguments of another Call, so we + # can skip emitting them as standalone top-level signatures. + nested_arg_calls = set() + for node in ast.walk(tree): + if isinstance(node, ast.Call): + for sub in list(node.args) + [kw.value for kw in node.keywords]: + if isinstance(sub, ast.Call): + nested_arg_calls.add(id(sub)) + + signatures: List[str] = [] + seen = set() + for node in ast.walk(tree): + if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)): + continue + if node.func.id not in registered_tools: + continue + if id(node) in nested_arg_calls: + continue + try: + sig = _render_call_signature(node, max_value_len, max_sig_len) + except Exception: + sig = node.func.id + if sig not in seen: + seen.add(sig) + signatures.append(sig) + return signatures \ No newline at end of file diff --git a/test/sdk/core/agents/test_agent_context/loader.py b/test/sdk/core/agents/test_agent_context/loader.py index 3d41c07a0..fee93d9ce 100644 --- a/test/sdk/core/agents/test_agent_context/loader.py +++ b/test/sdk/core/agents/test_agent_context/loader.py @@ -1,16 +1,18 @@ """ loader.py ───────── -Loads sdk/nexent/core/agents/agent_context.py in isolation via importlib, +Loads the agent_context package in isolation via importlib, bypassing __init__.py chains that drag in unrelated heavy dependencies. Also injects a fully-functional token_estimation stub so that the module under test executes its real estimation logic without any external imports. +Since agent_context/ is now a package (directory), we load each submodule +manually via importlib in dependency order, then wire them together as +``sdk.nexent.core.agents.agent_context.*`` in ``sys.modules``. 
+ Public names re-exported from this module are the same names that test files used to import at the top of the original monolithic test file. - - """ import importlib.util @@ -146,163 +148,244 @@ def estimate_tokens(memory, chars_per_token=1.5): return stub -# ── 3. Register stub package hierarchy ─────────────────────── - -def _register_stub_packages(): - """Create empty parent ModuleType entries so the dotted import chain resolves.""" - for name in [ - "sdk", - "sdk.nexent", - "sdk.nexent.core", - "sdk.nexent.core.agents", - "sdk.nexent.core.utils", - "sdk.nexent.core.utils.observer", - "sdk.nexent.core.agents.a2a_agent_proxy", - ]: - if name not in sys.modules: - m = ModuleType(name) - if name == "sdk.nexent.core.utils.observer": - m.MessageObserver = type("MessageObserver", (), {}) - if name == "sdk.nexent.core.agents.a2a_agent_proxy": - m.A2AAgentInfo = type("A2AAgentInfo", (), { - "__init__": lambda self, **kwargs: None - }) - sys.modules[name] = m - - token_est_key = "sdk.nexent.core.utils.token_estimation" - if token_est_key not in sys.modules: - sys.modules[token_est_key] = _build_token_estimation_stub() - - -_register_stub_packages() - - -# ── 3.5. Load summary_cache and summary_config modules ──────────────────── +# ── 3. Path helpers ───────────────────────────────────────────── -def _locate_module(module_name: str) -> str: - """Resolve the absolute path to a module in sdk/nexent/core/agents.""" - here = os.path.dirname(os.path.abspath(__file__)) - repo = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(here))))) - filename = module_name + ".py" - target = os.path.join(repo, "sdk", "nexent", "core", "agents", filename) - if not os.path.exists(target): - raise FileNotFoundError(f"Cannot locate {filename}. 
Expected: {target}") - return target +_HERE_DIR = os.path.dirname(os.path.abspath(__file__)) +_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(_HERE_DIR))))) +_SDK_ROOT = os.path.join(_REPO_ROOT, "sdk") +_AGENTS_DIR = os.path.join(_SDK_ROOT, "nexent", "core", "agents") +_AC_DIR = os.path.join(_AGENTS_DIR, "agent_context") -def _load_summary_modules(): - """Load summary_cache.py and summary_config.py before agent_context.py.""" - for module_name in ["summary_cache", "summary_config"]: - full_name = f"sdk.nexent.core.agents.{module_name}" - if full_name in sys.modules: - continue - target = _locate_module(module_name) - spec = importlib.util.spec_from_file_location(full_name, target) - module = importlib.util.module_from_spec(spec) - module.__package__ = "sdk.nexent.core.agents" - sys.modules[full_name] = module - spec.loader.exec_module(module) +def _agents_file(filename: str) -> str: + """Absolute path to a file in sdk/nexent/core/agents/.""" + return os.path.join(_AGENTS_DIR, filename) -_load_summary_modules() +def _ac_file(filename: str) -> str: + """Absolute path to a file in sdk/nexent/core/agents/agent_context/.""" + return os.path.join(_AC_DIR, filename) -# ── 4. Load agent_context.py via importlib ──────────────────── +# ── 4. Register stub package hierarchy ─────────────────────── -def _locate_agent_context() -> str: - """ - Resolve the absolute path to agent_context.py. +def _register_stub_packages(): + """Create parent ModuleType entries with __path__ so sub-package imports work. - Directory layout assumed: - / - sdk/nexent/core/agents/agent_context.py - tests/sdk/core/agents/ ← this file lives here + If a package is already registered (e.g. pytest discovers ``test/sdk/`` + as a real ``sdk`` package), we prepend the real SDK directory to its + ``__path__`` so that ``sdk.nexent.*`` resolves through the real source tree. 
""" - here = os.path.dirname(os.path.abspath(__file__)) - # tests/sdk/core/agents → tests/sdk/core → tests/sdk → tests → repo_root - repo = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(here))))) - target = os.path.join(repo, "sdk", "nexent", "core", "agents", "agent_context.py") - if not os.path.exists(target): - raise FileNotFoundError( - f"Cannot locate agent_context.py.\n" - f"Expected: {target}\n" - f"Check the number of os.path.dirname levels in loader.py." - ) - return target - + for name, pkg_dir in [ + ("sdk", _SDK_ROOT), + ("sdk.nexent", os.path.join(_SDK_ROOT, "nexent")), + ("sdk.nexent.core", os.path.join(_SDK_ROOT, "nexent", "core")), + ("sdk.nexent.core.agents", _AGENTS_DIR), + ("sdk.nexent.core.utils", os.path.join(_SDK_ROOT, "nexent", "core", "utils")), + ]: + if name in sys.modules: + mod = sys.modules[name] + # Prepend so the real SDK directory is searched FIRST + if hasattr(mod, "__path__") and pkg_dir not in mod.__path__: + mod.__path__.insert(0, pkg_dir) + else: + mod = ModuleType(name) + mod.__path__ = [pkg_dir] + mod.__package__ = name + sys.modules[name] = mod + + # Stub for agent_model classes used by manager.py (lazy imports). + # These classes are referenced inside manager._get_strategy() and + # manager.build_system_prompt(). Create minimal stubs so that the + # test harness can exercise those code paths. 
+ _agent_model_stub = ModuleType("sdk.nexent.core.agents.agent_model") + + class _BaseStrategy: + """Minimal strategy base that passes through all components.""" + def select_components(self, components, budget, component_budgets): + return components + + class _FullStrategy(_BaseStrategy): + @staticmethod + def get_strategy_name(): + return "full" + + class _TokenBudgetStrategy(_BaseStrategy): + @staticmethod + def get_strategy_name(): + return "token_budget" + + class _BufferedStrategy(_BaseStrategy): + def __init__(self, buffer_size=5): + self.buffer_size = buffer_size + + @staticmethod + def get_strategy_name(): + return "buffered" + + class _PriorityWeightedStrategy(_BaseStrategy): + def __init__(self, relevance_threshold=0.5): + self.relevance_threshold = relevance_threshold + + @staticmethod + def get_strategy_name(): + return "priority" + + _agent_model_stub.FullStrategy = _FullStrategy + _agent_model_stub.TokenBudgetStrategy = _TokenBudgetStrategy + _agent_model_stub.BufferedStrategy = _BufferedStrategy + _agent_model_stub.PriorityWeightedStrategy = _PriorityWeightedStrategy + _agent_model_stub.ContextStrategy = _BaseStrategy + + # ContextComponent stubs used by build_system_prompt / tests + class _ContextComponent: + component_type: str = "" + priority: int = 10 + token_estimate: int = 0 + _content: str = "" + metadata: dict = {} + + class _SystemPromptComponent(_ContextComponent): + pass + + _agent_model_stub.ContextComponent = _ContextComponent + _agent_model_stub.SystemPromptComponent = _SystemPromptComponent + _agent_model_stub.ToolsComponent = _ContextComponent + _agent_model_stub.SkillsComponent = _ContextComponent + _agent_model_stub.MemoryComponent = _ContextComponent + _agent_model_stub.KnowledgeBaseComponent = _ContextComponent + _agent_model_stub.ManagedAgentsComponent = _ContextComponent + _agent_model_stub.ExternalAgentsComponent = _ContextComponent + + sys.modules["sdk.nexent.core.agents.agent_model"] = _agent_model_stub -def 
_load_agent_context(): - module_name = "sdk.nexent.core.agents.agent_context" - if module_name in sys.modules: - return sys.modules[module_name] + token_est_key = "sdk.nexent.core.utils.token_estimation" + if token_est_key not in sys.modules: + sys.modules[token_est_key] = _build_token_estimation_stub() - target = _locate_agent_context() - spec = importlib.util.spec_from_file_location(module_name, target) - module = importlib.util.module_from_spec(spec) - module.__package__ = "sdk.nexent.core.agents" - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module +_register_stub_packages() -_ctx_mod = _load_agent_context() -# ── 5. Load agent_model.py for ContextComponent classes ────────────────── +# ── 5. Load summary_cache and summary_config modules ──────────────────── -def _load_agent_model(): - """Load agent_model.py containing ContextComponent and ContextStrategy classes.""" - module_name = "sdk.nexent.core.agents.agent_model" - if module_name in sys.modules: - return sys.modules[module_name] - - target = _locate_module("agent_model") - spec = importlib.util.spec_from_file_location(module_name, target) +def _load_file_module(full_name: str, filepath: str, package: str): + """Load a single .py file as a module via importlib.""" + if full_name in sys.modules: + return sys.modules[full_name] + spec = importlib.util.spec_from_file_location(full_name, filepath) module = importlib.util.module_from_spec(spec) - module.__package__ = "sdk.nexent.core.agents" - sys.modules[module_name] = module + module.__package__ = package + sys.modules[full_name] = module spec.loader.exec_module(module) return module -_agent_model_mod = _load_agent_model() - -# Restore real smolagents in sys.modules so sibling test trees (e.g. -# test/backend/utils/test_context_utils.py) that import the real -# nexent.core.agents path can do "from smolagents.memory import AgentMemory" -# without picking up our mock. 
The mock classes captured above as -# module-level attributes on _ctx_mod / _agent_model_mod stay valid for our -# own unit tests, which never touch sys.modules['smolagents.*'] at runtime. -restore_real_smolagents() - -# ── 6. Re-export public names (mirrors original monolithic imports) ── +_load_file_module("sdk.nexent.core.agents.summary_cache", + _agents_file("summary_cache.py"), "sdk.nexent.core.agents") +_load_file_module("sdk.nexent.core.agents.summary_config", + _agents_file("summary_config.py"), "sdk.nexent.core.agents") + + +# ── 6. Load agent_context submodules in dependency order ──────────── + +# Submodules must be loaded in topological (dependency) order. +# Each module uses relative imports (from .X import ...) that we +# satisfy by pre-loading dependencies into sys.modules first. +# +# Dependency graph (no circular deps): +# offload_store, summary_step (leaf, no intra-package deps) +# budget (imports from summary_step) +# step_renderer (imports from budget, offload_store) +# llm_summary (imports from step_renderer) +# previous_compression (imports from budget, step_renderer, llm_summary) +# current_compression (imports from budget, step_renderer, llm_summary) +# stats_export (no intra-package deps beyond summary_cache) +# manager (imports from all above) + +_AC_PREFIX = "sdk.nexent.core.agents.agent_context" +_AC_PKG = _AC_PREFIX + +# Leaf modules (no intra-package dependencies) +_load_file_module(f"{_AC_PREFIX}.offload_store", _ac_file("offload_store.py"), _AC_PKG) +_load_file_module(f"{_AC_PREFIX}.summary_step", _ac_file("summary_step.py"), _AC_PKG) + +# Core modules (depend on leaf modules) +_load_file_module(f"{_AC_PREFIX}.budget", _ac_file("budget.py"), _AC_PKG) +_load_file_module(f"{_AC_PREFIX}.step_renderer", _ac_file("step_renderer.py"), _AC_PKG) +_load_file_module(f"{_AC_PREFIX}.llm_summary", _ac_file("llm_summary.py"), _AC_PKG) +_load_file_module(f"{_AC_PREFIX}.previous_compression", _ac_file("previous_compression.py"), _AC_PKG) 
+_load_file_module(f"{_AC_PREFIX}.current_compression", _ac_file("current_compression.py"), _AC_PKG) +_load_file_module(f"{_AC_PREFIX}.stats_export", _ac_file("stats_export.py"), _AC_PKG) + +# Manager depends on all above +_load_file_module(f"{_AC_PREFIX}.manager", _ac_file("manager.py"), _AC_PKG) + +# Create the agent_context package module that re-exports public names +_ctx_mod = ModuleType(_AC_PREFIX) +_ctx_mod.__package__ = _AC_PKG +_ctx_mod.__path__ = [_AC_DIR] +# Re-export key names from manager so that ``agent_context.ContextManager`` works +_manager_mod = sys.modules[f"{_AC_PREFIX}.manager"] +_ctx_mod.ContextManager = _manager_mod.ContextManager +_ctx_mod.ContextManagerConfig = _manager_mod.ContextManagerConfig +_ctx_mod.CompressionCallRecord = sys.modules["sdk.nexent.core.agents.summary_cache"].CompressionCallRecord +_ctx_mod.PreviousSummaryCache = sys.modules["sdk.nexent.core.agents.summary_cache"].PreviousSummaryCache +_ctx_mod.CurrentSummaryCache = sys.modules["sdk.nexent.core.agents.summary_cache"].CurrentSummaryCache +_ctx_mod.SummaryTaskStep = sys.modules[f"{_AC_PREFIX}.summary_step"].SummaryTaskStep +_ctx_mod.format_summary_output = sys.modules[f"{_AC_PREFIX}.llm_summary"].format_summary_output +_ctx_mod._is_context_length_error = sys.modules[f"{_AC_PREFIX}.llm_summary"]._is_context_length_error +_ctx_mod.compress_history_offline = sys.modules[f"{_AC_PREFIX}.step_renderer"].compress_history_offline +_ctx_mod.OffloadStore = sys.modules[f"{_AC_PREFIX}.offload_store"].OffloadStore +sys.modules[_AC_PREFIX] = _ctx_mod + + +# ── 7. 
Re-export public names (mirrors original monolithic imports) ── ContextManager = _ctx_mod.ContextManager ContextManagerConfig = _ctx_mod.ContextManagerConfig PreviousSummaryCache = _ctx_mod.PreviousSummaryCache CurrentSummaryCache = _ctx_mod.CurrentSummaryCache SummaryTaskStep = _ctx_mod.SummaryTaskStep -TaskStep = _ctx_mod.TaskStep -ActionStep = _ctx_mod.ActionStep -AgentMemory = _ctx_mod.AgentMemory -ChatMessage = _ctx_mod.ChatMessage -MessageRole = _ctx_mod.MessageRole +TaskStep = _manager_mod.TaskStep +ActionStep = _manager_mod.ActionStep +AgentMemory = _manager_mod.AgentMemory +ChatMessage = _manager_mod.ChatMessage +MessageRole = sys.modules["smolagents.models"].MessageRole CompressionCallRecord = _ctx_mod.CompressionCallRecord -# Export ContextComponent classes -ContextComponent = _agent_model_mod.ContextComponent -SystemPromptComponent = _agent_model_mod.SystemPromptComponent -ToolsComponent = _agent_model_mod.ToolsComponent -SkillsComponent = _agent_model_mod.SkillsComponent -MemoryComponent = _agent_model_mod.MemoryComponent -KnowledgeBaseComponent = _agent_model_mod.KnowledgeBaseComponent -ManagedAgentsComponent = _agent_model_mod.ManagedAgentsComponent -ExternalAgentsComponent = _agent_model_mod.ExternalAgentsComponent - -# Export ContextStrategy classes -ContextStrategy = _agent_model_mod.ContextStrategy -FullStrategy = _agent_model_mod.FullStrategy -TokenBudgetStrategy = _agent_model_mod.TokenBudgetStrategy -BufferedStrategy = _agent_model_mod.BufferedStrategy -PriorityWeightedStrategy = _agent_model_mod.PriorityWeightedStrategy - -from stubs import _SystemPromptStep as SystemPromptStep \ No newline at end of file +from stubs import _SystemPromptStep as SystemPromptStep + +# ── 7a. Re-export OffloadStore ────────────────────────────────── +OffloadStore = _ctx_mod.OffloadStore + +# ── 8. 
Re-export new standalone functions and classes ────────────── + +from sdk.nexent.core.agents.agent_context.budget import ( + extract_pairs, action_content, pair_fingerprint, action_fingerprint, + has_invoked_tools, is_observation_step, is_tool_call_step, + is_prev_cache_valid, is_curr_cache_valid, + trim_pairs_to_budget, trim_actions_to_budget, +) +from sdk.nexent.core.utils.token_estimation import ( + estimate_tokens, estimate_tokens_text, estimate_tokens_for_steps, + estimate_tokens_for_system_prompt, msg_token_count, msg_char_count, +) +from sdk.nexent.core.agents.agent_context.step_renderer import StepRenderer, compress_history_offline +from sdk.nexent.core.agents.agent_context.llm_summary import LLMSummary, SummaryResult, format_summary_output +from sdk.nexent.core.agents.agent_context.previous_compression import PreviousCompressor, PreviousCompressResult +from sdk.nexent.core.agents.agent_context.current_compression import CurrentCompressor, CurrentCompressResult +from sdk.nexent.core.agents.agent_context.stats_export import ( + get_step_compression_stats, get_all_compression_stats, + export_summary as export_summary_fn, get_token_counts, +) + + +# ── 9. Restore real smolagents for sibling test trees ─────────── +# Restore real smolagents in sys.modules so sibling test trees (e.g. +# test/backend/utils/test_context_utils.py) that import the real +# nexent.core.agents path can do "from smolagents.memory import AgentMemory" +# without picking up our mock. The mock classes captured above as +# module-level attributes stay valid for our own unit tests, which +# never touch sys.modules['smolagents.*'] at runtime. 
+restore_real_smolagents() diff --git a/test/sdk/core/agents/test_agent_context/stubs.py b/test/sdk/core/agents/test_agent_context/stubs.py index 41eb1917c..89e9a4b31 100644 --- a/test/sdk/core/agents/test_agent_context/stubs.py +++ b/test/sdk/core/agents/test_agent_context/stubs.py @@ -59,6 +59,7 @@ class _ActionStep(_MemoryStep): action_output: Optional[Any] = None observations: Optional[str] = None tool_calls: Optional[list] = None + invoked_tools: Optional[list] = None error: Optional[str] = None token_usage: Optional[Any] = None diff --git a/test/sdk/core/agents/test_agent_context/unit/test_budget_trim.py b/test/sdk/core/agents/test_agent_context/unit/test_budget_trim.py index ebc5fd9f8..a290f9689 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_budget_trim.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_budget_trim.py @@ -1,130 +1,131 @@ -from factories import make_cm, make_pair -from loader import ActionStep - - -class TestBudgetTrimming: - - def test_trim_pairs_within_budget_returns_all(self): - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] - result = cm._trim_pairs_to_budget(pairs, max_tokens=99999) - assert len(result) == 3 - - def test_trim_pairs_empty_input(self): - cm = make_cm() - assert cm._trim_pairs_to_budget([], max_tokens=1000) == [] - - def test_trim_pairs_keeps_at_least_last_when_all_overflow(self): - """Even with minimal budget, at least keep the last pair.""" - cm = make_cm() - pairs = [make_pair("very long task description" * 50, "very long response content" * 50, i) for i in range(3)] - result = cm._trim_pairs_to_budget(pairs, max_tokens=1, keep_first=False) - assert len(result) == 1 - - def test_trim_pairs_keep_first_true_keeps_first_pair(self): - """keep_first=True, first pair must be retained.""" - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(5)] - first_pair_tokens = cm._estimate_text_tokens(cm._pairs_to_text([pairs[0]])) - result = 
cm._trim_pairs_to_budget(pairs, max_tokens=first_pair_tokens + 5, keep_first=True) - assert result[0] == pairs[0] - - def test_trim_actions_within_budget_returns_all(self): - cm = make_cm() - actions = [ActionStep(step_number=i, model_output=f"output{i}") for i in range(3)] - result = cm._trim_actions_to_budget(actions, task_text="", max_tokens=99999) - assert len(result) == 3 - - def test_trim_actions_empty_returns_empty(self): - cm = make_cm() - assert cm._trim_actions_to_budget([], task_text="", max_tokens=1000) == [] - - def test_trim_actions_keeps_last_when_overflow(self): - """Minimal budget, at least keep the last action.""" - cm = make_cm() - actions = [ - ActionStep(step_number=i, model_output="X" * 500, action_output="Y" * 500) - for i in range(4) - ] - result = cm._trim_actions_to_budget(actions, task_text="", max_tokens=1) - assert len(result) >= 1 - assert result[-1] is actions[-1] - - def test_trim_actions_skips_drop_that_splits_tool_call_and_observation(self): - """When truncation point would split tool_calls and observations, skip that truncation point.""" - cm = make_cm() - actions = [ - ActionStep(step_number=0, model_output="A" * 400, tool_calls=[{"name": "tool1"}]), - ActionStep(step_number=1, model_output="B" * 400, observations="obs1"), - ActionStep(step_number=2, model_output="C" * 400), - ] - two_act_tokens = cm._estimate_text_tokens(cm._actions_to_text(actions[1:])) - three_act_tokens = cm._estimate_text_tokens(cm._actions_to_text(actions)) - max_tokens = two_act_tokens + (three_act_tokens - two_act_tokens) // 2 - - result = cm._trim_actions_to_budget(actions, task_text="", max_tokens=max_tokens) - assert result == [actions[2]] - - def test_trim_actions_allows_drop_when_no_tool_call_before_observation(self): - """remaining[0] has observations, but previous action has no tool_calls, should allow truncation.""" - cm = make_cm() - actions = [ - ActionStep(step_number=0, model_output="A" * 400), - ActionStep(step_number=1, model_output="B" * 
400, observations="obs1"), - ActionStep(step_number=2, model_output="C" * 400), - ] - two_act_tokens = cm._estimate_text_tokens(cm._actions_to_text(actions[1:])) - three_act_tokens = cm._estimate_text_tokens(cm._actions_to_text(actions)) - max_tokens = two_act_tokens + (three_act_tokens - two_act_tokens) // 2 - - result = cm._trim_actions_to_budget(actions, task_text="", max_tokens=max_tokens) - assert result == [actions[1], actions[2]] - - def test_trim_actions_allows_drop_when_no_observation_after_tool_call(self): - """actions[drop-1] has tool_calls, but remaining[0] has no observations, should allow truncation.""" - cm = make_cm() - actions = [ - ActionStep(step_number=0, model_output="A" * 400, tool_calls=[{"name": "tool1"}]), - ActionStep(step_number=1, model_output="B" * 400), - ActionStep(step_number=2, model_output="C" * 400), - ] - two_act_tokens = cm._estimate_text_tokens(cm._actions_to_text(actions[1:])) - three_act_tokens = cm._estimate_text_tokens(cm._actions_to_text(actions)) - max_tokens = two_act_tokens + (three_act_tokens - two_act_tokens) // 2 - - result = cm._trim_actions_to_budget(actions, task_text="", max_tokens=max_tokens) - assert result == [actions[1], actions[2]] - - def test_trim_actions_chain_pairs_fallback_returns_complete_pair(self): - """Continuous pairing causes all suffix truncation points invalid or over budget, fallback returns last complete tool_call+observation pair.""" - cm = make_cm() - actions = [ - ActionStep(step_number=0, model_output="A" * 400, tool_calls=[{"name": "t1"}]), - ActionStep(step_number=1, model_output="B" * 400, observations="obs1"), - ActionStep(step_number=2, model_output="C" * 400, tool_calls=[{"name": "t2"}]), - ActionStep(step_number=3, model_output="D" * 400, observations="obs2"), - ] - result = cm._trim_actions_to_budget(actions, task_text="", max_tokens=1) - assert result == [actions[2], actions[3]] - - def test_trim_actions_fallback_returns_pair_when_last_is_observation(self): - """Fallback when last 
action is observation and previous has tool_calls, return complete pair.""" - cm = make_cm() - actions = [ - ActionStep(step_number=0, model_output="A" * 400), - ActionStep(step_number=1, model_output="B" * 400, tool_calls=[{"name": "t1"}]), - ActionStep(step_number=2, model_output="C" * 400, observations="obs1"), - ] - result = cm._trim_actions_to_budget(actions, task_text="", max_tokens=1) - assert result == [actions[1], actions[2]] - - def test_trim_actions_fallback_returns_single_when_last_has_no_observation(self): - """Fallback when last action has no observations, return single last one.""" - cm = make_cm() - actions = [ - ActionStep(step_number=0, model_output="A" * 400), - ActionStep(step_number=1, model_output="B" * 400), - ActionStep(step_number=2, model_output="C" * 400), - ] - result = cm._trim_actions_to_budget(actions, task_text="", max_tokens=1) - assert result == [actions[-1]] \ No newline at end of file +from factories import make_cm, make_pair +from loader import ActionStep +from loader import trim_pairs_to_budget, trim_actions_to_budget, estimate_tokens_text + + +class TestBudgetTrimming: + + def test_trim_pairs_within_budget_returns_all(self): + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] + result = trim_pairs_to_budget(pairs, max_tokens=99999, render_fn=cm._renderer.pairs_to_text) + assert len(result) == 3 + + def test_trim_pairs_empty_input(self): + cm = make_cm() + assert trim_pairs_to_budget([], max_tokens=1000, render_fn=cm._renderer.pairs_to_text) == [] + + def test_trim_pairs_keeps_at_least_last_when_all_overflow(self): + """Even with minimal budget, at least keep the last pair.""" + cm = make_cm() + pairs = [make_pair("very long task description" * 50, "very long response content" * 50, i) for i in range(3)] + result = trim_pairs_to_budget(pairs, max_tokens=1, render_fn=cm._renderer.pairs_to_text, keep_first=False) + assert len(result) == 1 + + def 
test_trim_pairs_keep_first_true_keeps_first_pair(self): + """keep_first=True, first pair must be retained.""" + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(5)] + first_pair_tokens = estimate_tokens_text(cm._renderer.pairs_to_text([pairs[0]])) + result = trim_pairs_to_budget(pairs, max_tokens=first_pair_tokens + 5, render_fn=cm._renderer.pairs_to_text, keep_first=True) + assert result[0] == pairs[0] + + def test_trim_actions_within_budget_returns_all(self): + cm = make_cm() + actions = [ActionStep(step_number=i, model_output=f"output{i}") for i in range(3)] + result = trim_actions_to_budget(actions, task_text="", max_tokens=99999, render_fn=cm._renderer.actions_to_text) + assert len(result) == 3 + + def test_trim_actions_empty_returns_empty(self): + cm = make_cm() + assert trim_actions_to_budget([], task_text="", max_tokens=1000, render_fn=cm._renderer.actions_to_text) == [] + + def test_trim_actions_keeps_last_when_overflow(self): + """Minimal budget, at least keep the last action.""" + cm = make_cm() + actions = [ + ActionStep(step_number=i, model_output="X" * 500, action_output="Y" * 500) + for i in range(4) + ] + result = trim_actions_to_budget(actions, task_text="", max_tokens=1, render_fn=cm._renderer.actions_to_text) + assert len(result) >= 1 + assert result[-1] is actions[-1] + + def test_trim_actions_skips_drop_that_splits_tool_call_and_observation(self): + """When truncation point would split tool_calls and observations, skip that truncation point.""" + cm = make_cm() + actions = [ + ActionStep(step_number=0, model_output="A" * 400, tool_calls=[{"name": "tool1"}]), + ActionStep(step_number=1, model_output="B" * 400, observations="obs1"), + ActionStep(step_number=2, model_output="C" * 400), + ] + two_act_tokens = estimate_tokens_text(cm._renderer.actions_to_text(actions[1:])) + three_act_tokens = estimate_tokens_text(cm._renderer.actions_to_text(actions)) + max_tokens = two_act_tokens + (three_act_tokens - two_act_tokens) 
// 2 + + result = trim_actions_to_budget(actions, task_text="", max_tokens=max_tokens, render_fn=cm._renderer.actions_to_text) + assert result == [actions[2]] + + def test_trim_actions_allows_drop_when_no_tool_call_before_observation(self): + """remaining[0] has observations, but previous action has no tool_calls, should allow truncation.""" + cm = make_cm() + actions = [ + ActionStep(step_number=0, model_output="A" * 400), + ActionStep(step_number=1, model_output="B" * 400, observations="obs1"), + ActionStep(step_number=2, model_output="C" * 400), + ] + two_act_tokens = estimate_tokens_text(cm._renderer.actions_to_text(actions[1:])) + three_act_tokens = estimate_tokens_text(cm._renderer.actions_to_text(actions)) + max_tokens = two_act_tokens + (three_act_tokens - two_act_tokens) // 2 + + result = trim_actions_to_budget(actions, task_text="", max_tokens=max_tokens, render_fn=cm._renderer.actions_to_text) + assert result == [actions[1], actions[2]] + + def test_trim_actions_allows_drop_when_no_observation_after_tool_call(self): + """actions[drop-1] has tool_calls, but remaining[0] has no observations, should allow truncation.""" + cm = make_cm() + actions = [ + ActionStep(step_number=0, model_output="A" * 400, tool_calls=[{"name": "tool1"}]), + ActionStep(step_number=1, model_output="B" * 400), + ActionStep(step_number=2, model_output="C" * 400), + ] + two_act_tokens = estimate_tokens_text(cm._renderer.actions_to_text(actions[1:])) + three_act_tokens = estimate_tokens_text(cm._renderer.actions_to_text(actions)) + max_tokens = two_act_tokens + (three_act_tokens - two_act_tokens) // 2 + + result = trim_actions_to_budget(actions, task_text="", max_tokens=max_tokens, render_fn=cm._renderer.actions_to_text) + assert result == [actions[1], actions[2]] + + def test_trim_actions_chain_pairs_fallback_returns_complete_pair(self): + """Continuous pairing causes all suffix truncation points invalid or over budget, fallback returns last complete tool_call+observation pair.""" + 
cm = make_cm() + actions = [ + ActionStep(step_number=0, model_output="A" * 400, tool_calls=[{"name": "t1"}]), + ActionStep(step_number=1, model_output="B" * 400, observations="obs1"), + ActionStep(step_number=2, model_output="C" * 400, tool_calls=[{"name": "t2"}]), + ActionStep(step_number=3, model_output="D" * 400, observations="obs2"), + ] + result = trim_actions_to_budget(actions, task_text="", max_tokens=1, render_fn=cm._renderer.actions_to_text) + assert result == [actions[2], actions[3]] + + def test_trim_actions_fallback_returns_pair_when_last_is_observation(self): + """Fallback when last action is observation and previous has tool_calls, return complete pair.""" + cm = make_cm() + actions = [ + ActionStep(step_number=0, model_output="A" * 400), + ActionStep(step_number=1, model_output="B" * 400, tool_calls=[{"name": "t1"}]), + ActionStep(step_number=2, model_output="C" * 400, observations="obs1"), + ] + result = trim_actions_to_budget(actions, task_text="", max_tokens=1, render_fn=cm._renderer.actions_to_text) + assert result == [actions[1], actions[2]] + + def test_trim_actions_fallback_returns_single_when_last_has_no_observation(self): + """Fallback when last action has no observations, return single last one.""" + cm = make_cm() + actions = [ + ActionStep(step_number=0, model_output="A" * 400), + ActionStep(step_number=1, model_output="B" * 400), + ActionStep(step_number=2, model_output="C" * 400), + ] + result = trim_actions_to_budget(actions, task_text="", max_tokens=1, render_fn=cm._renderer.actions_to_text) + assert result == [actions[-1]] diff --git a/test/sdk/core/agents/test_agent_context/unit/test_build_message.py b/test/sdk/core/agents/test_agent_context/unit/test_build_message.py index 50ceee1f0..de0c93fa6 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_build_message.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_build_message.py @@ -1,44 +1,44 @@ -from factories import make_cm, make_pair -from loader import 
AgentMemory, SummaryTaskStep, SystemPromptStep - - -class TestBuildMessages: - - def test_build_messages_no_summary(self): - cm = make_cm() - t, a = make_pair("task", "action") - memory = AgentMemory(steps=[]) - msgs = cm._build_messages(memory, None, [], [t, a]) - all_text = " ".join( - b.get("text", "") - for m in msgs for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ) - assert "task" in all_text - assert "action" in all_text - - def test_build_messages_with_prev_summary_comes_first(self): - cm = make_cm() - summary = SummaryTaskStep(task="history summary content") - t, a = make_pair("current task", "current result", 1) - memory = AgentMemory(steps=[]) - msgs = cm._build_messages(memory, summary, [], [t, a]) - all_texts = [ - b.get("text", "") - for m in msgs for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ] - summary_idx = next(i for i, t in enumerate(all_texts) if "history summary content" in t) - curr_idx = next(i for i, t in enumerate(all_texts) if "current task" in t) - assert summary_idx < curr_idx - - def test_build_messages_with_system_prompt(self): - cm = make_cm() - memory = AgentMemory(steps=[], system_prompt=SystemPromptStep(system_prompt="system prompt")) - msgs = cm._build_messages(memory, None, [], []) - all_text = " ".join( - b.get("text", "") - for m in msgs for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ) - assert "system prompt" in all_text \ No newline at end of file +from factories import make_cm, make_pair +from loader import AgentMemory, SummaryTaskStep, SystemPromptStep + + +class TestBuildMessages: + + def test_build_messages_no_summary(self): + cm = make_cm() + t, a = make_pair("task", "action") + memory = AgentMemory(steps=[]) + msgs = cm._renderer.build_messages(memory, None, [], [t, a]) + all_text = " ".join( + b.get("text", "") + for m in msgs for b in (m.content if isinstance(m.content, list) else []) + if 
isinstance(b, dict) + ) + assert "task" in all_text + assert "action" in all_text + + def test_build_messages_with_prev_summary_comes_first(self): + cm = make_cm() + summary = SummaryTaskStep(task="history summary content") + t, a = make_pair("current task", "current result", 1) + memory = AgentMemory(steps=[]) + msgs = cm._renderer.build_messages(memory, summary, [], [t, a]) + all_texts = [ + b.get("text", "") + for m in msgs for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ] + summary_idx = next(i for i, t in enumerate(all_texts) if "history summary content" in t) + curr_idx = next(i for i, t in enumerate(all_texts) if "current task" in t) + assert summary_idx < curr_idx + + def test_build_messages_with_system_prompt(self): + cm = make_cm() + memory = AgentMemory(steps=[], system_prompt=SystemPromptStep(system_prompt="system prompt")) + msgs = cm._renderer.build_messages(memory, None, [], []) + all_text = " ".join( + b.get("text", "") + for m in msgs for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ) + assert "system prompt" in all_text diff --git a/test/sdk/core/agents/test_agent_context/unit/test_cache_valid.py b/test/sdk/core/agents/test_agent_context/unit/test_cache_valid.py index 716f5808f..b4c53a010 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_cache_valid.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_cache_valid.py @@ -1,85 +1,76 @@ -from factories import make_cm, make_pair -from loader import PreviousSummaryCache, CurrentSummaryCache, ActionStep, ContextManager - - -class TestCacheValidation: - - def test_prev_cache_none_returns_false(self): - cm = make_cm() - t, a = make_pair() - valid, idx = cm._is_prev_cache_valid([(t, a)]) - assert valid is False - assert idx == 0 - - def test_prev_cache_empty_pairs_returns_false(self): - cm = make_cm() - cm._previous_summary_cache = PreviousSummaryCache("summary", 1, "fp") - valid, idx = 
cm._is_prev_cache_valid([]) - assert valid is False - - def test_prev_cache_covered_exceeds_pairs_returns_false(self): - cm = make_cm() - t, a = make_pair("task", "action") - fp = cm._pair_fingerprint("task", "action") - cm._previous_summary_cache = PreviousSummaryCache("summary", 5, fp) - valid, _ = cm._is_prev_cache_valid([(t, a)]) - assert valid is False - - def test_prev_cache_fingerprint_mismatch_returns_false(self): - cm = make_cm() - t, a = make_pair("task A", "action A") - cm._previous_summary_cache = PreviousSummaryCache( - "summary", 1, "wrong_fingerprint_xyz" - ) - valid, _ = cm._is_prev_cache_valid([(t, a)]) - assert valid is False - - def test_prev_cache_valid_hit(self): - cm = make_cm() - t, a = make_pair("task", "action") - fp = cm._pair_fingerprint("task", "action") - cm._previous_summary_cache = PreviousSummaryCache("summary text", 1, fp) - valid, covered_idx = cm._is_prev_cache_valid([(t, a)]) - assert valid is True - assert covered_idx == 1 - - def test_prev_cache_valid_partial_coverage(self): - """Cache covers first 2 pairs, total 3 pairs -> valid, return covered=2.""" - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] - t1, a1 = pairs[1] - fp = cm._pair_fingerprint(t1.task, a1.action_output) - cm._previous_summary_cache = PreviousSummaryCache("summary", 2, fp) - valid, covered_idx = cm._is_prev_cache_valid(pairs) - assert valid is True - assert covered_idx == 2 - - def test_curr_cache_none_returns_false(self): - cm = make_cm() - a = ActionStep(step_number=1, model_output="x", action_output="y") - valid, idx = cm._is_curr_cache_valid([a]) - assert valid is False - - def test_curr_cache_fingerprint_mismatch_returns_false(self): - cm = make_cm() - a = ActionStep(step_number=1, model_output="x", action_output="y") - cm._current_summary_cache = CurrentSummaryCache("summary", 1, "wrong_fp") - valid, _ = cm._is_curr_cache_valid([a]) - assert valid is False - - def 
test_curr_cache_end_steps_exceeds_list_returns_false(self): - cm = make_cm() - a = ActionStep(step_number=1, model_output="x", action_output="y") - fp = ContextManager._action_fingerprint(a) - cm._current_summary_cache = CurrentSummaryCache("summary", 5, fp) - valid, _ = cm._is_curr_cache_valid([a]) - assert valid is False - - def test_curr_cache_valid_hit(self): - cm = make_cm() - a = ActionStep(step_number=1, model_output="output", action_output="result") - fp = ContextManager._action_fingerprint(a) - cm._current_summary_cache = CurrentSummaryCache("summary text", 1, fp) - valid, end_steps = cm._is_curr_cache_valid([a]) - assert valid is True - assert end_steps == 1 \ No newline at end of file +from factories import make_cm, make_pair +from loader import ( + PreviousSummaryCache, CurrentSummaryCache, ActionStep, ContextManager, + is_prev_cache_valid, is_curr_cache_valid, pair_fingerprint, action_fingerprint, +) + + +class TestCacheValidation: + + def test_prev_cache_none_returns_false(self): + t, a = make_pair() + valid, idx = is_prev_cache_valid([(t, a)], None) + assert valid is False + assert idx == 0 + + def test_prev_cache_empty_pairs_returns_false(self): + cache = PreviousSummaryCache("summary", 1, "fp") + valid, idx = is_prev_cache_valid([], cache) + assert valid is False + + def test_prev_cache_covered_exceeds_pairs_returns_false(self): + t, a = make_pair("task", "action") + fp = pair_fingerprint("task", "action") + cache = PreviousSummaryCache("summary", 5, fp) + valid, _ = is_prev_cache_valid([(t, a)], cache) + assert valid is False + + def test_prev_cache_fingerprint_mismatch_returns_false(self): + t, a = make_pair("task A", "action A") + cache = PreviousSummaryCache("summary", 1, "wrong_fingerprint_xyz") + valid, _ = is_prev_cache_valid([(t, a)], cache) + assert valid is False + + def test_prev_cache_valid_hit(self): + t, a = make_pair("task", "action") + fp = pair_fingerprint("task", "action") + cache = PreviousSummaryCache("summary text", 1, fp) + 
valid, covered_idx = is_prev_cache_valid([(t, a)], cache) + assert valid is True + assert covered_idx == 1 + + def test_prev_cache_valid_partial_coverage(self): + """Cache covers first 2 pairs, total 3 pairs -> valid, return covered=2.""" + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] + t1, a1 = pairs[1] + fp = pair_fingerprint(t1.task, a1.action_output) + cache = PreviousSummaryCache("summary", 2, fp) + valid, covered_idx = is_prev_cache_valid(pairs, cache) + assert valid is True + assert covered_idx == 2 + + def test_curr_cache_none_returns_false(self): + a = ActionStep(step_number=1, model_output="x", action_output="y") + valid, idx = is_curr_cache_valid([a], None) + assert valid is False + + def test_curr_cache_fingerprint_mismatch_returns_false(self): + a = ActionStep(step_number=1, model_output="x", action_output="y") + cache = CurrentSummaryCache("summary", 1, "wrong_fp") + valid, _ = is_curr_cache_valid([a], cache) + assert valid is False + + def test_curr_cache_end_steps_exceeds_list_returns_false(self): + a = ActionStep(step_number=1, model_output="x", action_output="y") + fp = action_fingerprint(a) + cache = CurrentSummaryCache("summary", 5, fp) + valid, _ = is_curr_cache_valid([a], cache) + assert valid is False + + def test_curr_cache_valid_hit(self): + a = ActionStep(step_number=1, model_output="output", action_output="result") + fp = action_fingerprint(a) + cache = CurrentSummaryCache("summary text", 1, fp) + valid, end_steps = is_curr_cache_valid([a], cache) + assert valid is True + assert end_steps == 1 diff --git a/test/sdk/core/agents/test_agent_context/unit/test_component_management.py b/test/sdk/core/agents/test_agent_context/unit/test_component_management.py index 5f25e1119..f92ba6b4b 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_component_management.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_component_management.py @@ -8,64 +8,54 @@ - build_system_prompt() - _get_strategy() - 
_calculate_component_budget() +- _message_already_present() """ -import sys -import os -from pathlib import Path - -TEST_ROOT = Path(__file__).resolve().parents[2] -PROJECT_ROOT = TEST_ROOT.parent - -for _path in (str(PROJECT_ROOT), str(TEST_ROOT)): - if _path not in sys.path: - sys.path.insert(0, _path) from loader import ContextManager, ContextManagerConfig -from stubs import _SystemPromptStep class MockComponent: """Mock context component for testing.""" - + def __init__(self, component_type="test", content="", priority=10, token_estimate=0): self.component_type = component_type self.priority = priority self.token_estimate = token_estimate self._content = content self.metadata = {} - + def to_messages(self): if self._content: return [{"role": "system", "content": self._content}] return [] - + def estimate_tokens(self, chars_per_token=1.5): return int(len(self._content) / chars_per_token) class TestRegisterComponent: """Tests for register_component() method.""" - + def test_register_single_component(self): cm = ContextManager() comp = MockComponent(component_type="test", content="test content") cm.register_component(comp) assert len(cm.get_registered_components()) == 1 - + def test_register_multiple_components(self): cm = ContextManager() cm.register_component(MockComponent(content="comp1")) cm.register_component(MockComponent(content="comp2")) cm.register_component(MockComponent(content="comp3")) assert len(cm.get_registered_components()) == 3 - + def test_register_sets_token_estimate(self): cm = ContextManager() comp = MockComponent(content="test content here", token_estimate=0) cm.register_component(comp) registered = cm.get_registered_components() assert registered[0].token_estimate > 0 - + def test_register_preserves_existing_token_estimate(self): cm = ContextManager() comp = MockComponent(content="test", token_estimate=100) @@ -76,19 +66,19 @@ def test_register_preserves_existing_token_estimate(self): class TestClearComponents: """Tests for 
clear_components() method.""" - + def test_clear_removes_all_components(self): cm = ContextManager() cm.register_component(MockComponent(content="comp1")) cm.register_component(MockComponent(content="comp2")) cm.clear_components() assert cm.get_registered_components() == [] - + def test_clear_on_empty_manager(self): cm = ContextManager() cm.clear_components() assert cm.get_registered_components() == [] - + def test_clear_allows_new_registration(self): cm = ContextManager() cm.register_component(MockComponent(content="old")) @@ -100,7 +90,7 @@ def test_clear_allows_new_registration(self): class TestGetRegisteredComponents: """Tests for get_registered_components() method.""" - + def test_returns_copy_not_reference(self): cm = ContextManager() cm.register_component(MockComponent(content="original")) @@ -108,12 +98,12 @@ def test_returns_copy_not_reference(self): copy2 = cm.get_registered_components() copy1.clear() assert len(copy2) == 1 - + def test_returns_empty_list_when_no_components(self): cm = ContextManager() result = cm.get_registered_components() assert result == [] - + def test_preserves_component_order(self): cm = ContextManager() cm.register_component(MockComponent(content="first", priority=10)) @@ -125,31 +115,31 @@ def test_preserves_component_order(self): class TestGetStrategy: """Tests for _get_strategy() method.""" - + def test_default_returns_token_budget_strategy(self): cm = ContextManager() strategy = cm._get_strategy() assert strategy.get_strategy_name() == "token_budget" - + def test_full_strategy(self): config = ContextManagerConfig(strategy="full") cm = ContextManager(config) strategy = cm._get_strategy() assert strategy.get_strategy_name() == "full" - + def test_buffered_strategy_with_custom_buffer_size(self): config = ContextManagerConfig(strategy="buffered", buffer_size_per_component=5) cm = ContextManager(config) strategy = cm._get_strategy() assert strategy.get_strategy_name() == "buffered" assert strategy.buffer_size == 5 - + def 
test_priority_strategy(self): config = ContextManagerConfig(strategy="priority") cm = ContextManager(config) strategy = cm._get_strategy() assert strategy.get_strategy_name() == "priority" - + def test_unknown_strategy_defaults_to_token_budget(self): config = ContextManagerConfig(strategy="unknown") cm = ContextManager(config) @@ -159,12 +149,12 @@ def test_unknown_strategy_defaults_to_token_budget(self): class TestBuildSystemPrompt: """Tests for build_system_prompt() method.""" - + def test_empty_components_returns_empty_messages(self): cm = ContextManager() messages = cm.build_system_prompt() assert messages == [] - + def test_single_component_returns_messages(self): cm = ContextManager() cm.register_component(MockComponent(content="test prompt")) @@ -172,14 +162,14 @@ def test_single_component_returns_messages(self): assert len(messages) == 1 assert messages[0]["role"] == "system" assert messages[0]["content"] == "test prompt" - + def test_multiple_components_combined(self): cm = ContextManager() cm.register_component(MockComponent(content="prompt1", priority=20)) cm.register_component(MockComponent(content="prompt2", priority=10)) messages = cm.build_system_prompt() assert len(messages) == 2 - + def test_custom_token_budget(self): cm = ContextManager() cm.register_component(MockComponent(content="short", token_estimate=50)) @@ -187,7 +177,7 @@ def test_custom_token_budget(self): messages = cm.build_system_prompt(token_budget=100) total_content = sum(len(m["content"]) for m in messages) assert total_content < 500 - + def test_deduplicates_identical_messages(self): cm = ContextManager() cm.register_component(MockComponent(content="same content")) @@ -198,14 +188,14 @@ def test_deduplicates_identical_messages(self): class TestCalculateComponentBudget: """Tests for _calculate_component_budget() method.""" - + def test_excludes_conversation_history(self): cm = ContextManager() budget = cm._calculate_component_budget() budgets = cm.config.component_budgets assert 
"conversation_history" in budgets assert budget == sum(v for k, v in budgets.items() if k != "conversation_history") - + def test_sum_of_non_excluded_budgets(self): cm = ContextManager() budget = cm._calculate_component_budget() @@ -223,25 +213,25 @@ def test_sum_of_non_excluded_budgets(self): class TestMessageAlreadyPresent: """Tests for _message_already_present() method.""" - + def test_identical_message_detected(self): cm = ContextManager() messages = [{"role": "system", "content": "test"}] new_msg = {"role": "system", "content": "test"} assert cm._message_already_present(messages, new_msg) is True - + def test_different_content_not_detected(self): cm = ContextManager() messages = [{"role": "system", "content": "test"}] new_msg = {"role": "system", "content": "different"} assert cm._message_already_present(messages, new_msg) is False - + def test_different_role_not_detected(self): cm = ContextManager() messages = [{"role": "system", "content": "test"}] new_msg = {"role": "user", "content": "test"} assert cm._message_already_present(messages, new_msg) is False - + def test_empty_messages_list(self): cm = ContextManager() new_msg = {"role": "system", "content": "test"} @@ -250,20 +240,20 @@ def test_empty_messages_list(self): class TestComponentManagementWithConfig: """Tests for component management with custom ContextManagerConfig.""" - + def test_strategy_selection_from_config(self): config = ContextManagerConfig(strategy="full") cm = ContextManager(config) strategy = cm._get_strategy() assert strategy.get_strategy_name() == "full" - + def test_component_budgets_from_config(self): custom_budgets = {"system_prompt": 2000, "tools": 1000, "conversation_history": 3000} config = ContextManagerConfig(component_budgets=custom_budgets) cm = ContextManager(config) budget = cm._calculate_component_budget() assert budget == 3000 - + def test_chars_per_token_used_in_estimation(self): config = ContextManagerConfig(chars_per_token=2.0) cm = ContextManager(config) diff --git 
a/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed.py b/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed.py index 79dfd5a03..0eec0367f 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed.py @@ -1,189 +1,192 @@ -from factories import make_cm, make_pair, make_model, make_memory_mixed, make_original_messages -from loader import AgentMemory, TaskStep, SystemPromptStep, CurrentSummaryCache, PreviousSummaryCache, ContextManager - - -def _all_texts(messages): - return [ - b.get("text", "") - for m in messages - for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ] - - -def _joined(messages): - return " ".join(_all_texts(messages)) - - -class TestCompressIfNeeded: - - def test_disabled_returns_original_messages(self): - """config.enabled=False returns original_messages without any processing.""" - cm = make_cm(enabled=False, threshold=10) - n_prev_pairs = 1 - n_curr_actions = 1 - memory = make_memory_mixed(n_prev_pairs, n_curr_actions) - original = make_original_messages(memory) - current_run_start_idx = 2 * n_prev_pairs - result = cm.compress_if_needed(None, memory, original, current_run_start_idx=current_run_start_idx) - assert result is original - - def test_under_threshold_returns_original(self): - """raw tokens < threshold returns directly, no LLM call.""" - cm = make_cm(enabled=True, threshold=999999) - n_prev_pairs = 1 - n_curr_actions = 1 - memory = make_memory_mixed(n_prev_pairs, n_curr_actions) - original = make_original_messages(memory) - current_run_start_idx = 2 * n_prev_pairs - model = make_model() - result = cm.compress_if_needed(None, memory, original, current_run_start_idx=current_run_start_idx) - assert result is original - model.assert_not_called() - - def test_over_threshold_triggers_compression(self): - """raw tokens > threshold should call LLM (all previous-run 
scenario).""" - keep_recent_pairs = 1 - keep_recent_steps = 2 - cm = make_cm(enabled=True, threshold=10, keep_recent_steps=keep_recent_steps, keep_recent_pairs=keep_recent_pairs) - n_prev_pairs = 3 - n_curr_actions = 2 - memory = make_memory_mixed(n_prev_pairs=n_prev_pairs, n_curr_actions=n_curr_actions) - original = make_original_messages(memory) - assert len(original) == 1 + n_prev_pairs * 2 + 1 + n_curr_actions - current_run_start_idx = 2 * n_prev_pairs - model = make_model('{"task_overview": "summary"}') - result = cm.compress_if_needed(model, memory, original, current_run_start_idx) - assert result is not None - assert isinstance(result, list) - assert len(result) == 1 + 1 + 2 * keep_recent_pairs + 1 + keep_recent_steps - model.assert_called_once() - all_text = " ".join( - b.get("text", "") - for m in result for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ) - assert "Summary of earlier steps" in all_text - - def test_run_boundary_clears_current_cache(self): - """Switching run (current_run_start_idx changes) and ensuring no current summary triggers, current cache should be cleared.""" - cm = make_cm(enabled=True, threshold=1) - cm._current_summary_cache = CurrentSummaryCache("old cache", 1, "fp") - cm._last_run_start_idx = 5 - memory = make_memory_mixed(1, 0) - original = make_original_messages(memory) - model = make_model('{"task_overview": "summary"}') - try: - cm.compress_if_needed(model, memory, original, current_run_start_idx=0) - except Exception: - pass - assert cm._current_summary_cache is None - - def test_effective_tokens_shortcut_applies_cache(self): - """effective tokens < threshold short-circuit, directly apply existing cache to build messages (all previous-run).""" - cm = make_cm(enabled=True, threshold=10, keep_recent_pairs=0) - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(2)] - all_steps = [] - for t, a in pairs: - all_steps.extend([t, a]) - all_steps.append(TaskStep(task="New Task")) - 
memory = AgentMemory(steps=all_steps, system_prompt=SystemPromptStep(system_prompt="system prompt")) - last_t, last_a = pairs[1] - fp = cm._pair_fingerprint(last_t.task, last_a.action_output) - cm._previous_summary_cache = PreviousSummaryCache("short summary", 2, fp) - - model = make_model('{"task_overview": "summary"}') - original = make_original_messages(memory) - current_run_start_idx = 2 * len(pairs) - result = cm.compress_if_needed(model, memory, original, current_run_start_idx) - model.assert_not_called() - assert isinstance(result, list) - assert len(result) == 3 - all_text = " ".join( - b.get("text", "") - for m in result for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ) - assert "short summary" in all_text - - def test_current_run_cache_full_hit_no_llm_call(self): - """current cache fully hit, current part should be replaced by summary and no LLM call.""" - cm = make_cm(enabled=True, threshold=7) - curr_t, curr_a = make_pair("curr_task", "curr_action", 0) - memory = AgentMemory(steps=[curr_t, curr_a], system_prompt=SystemPromptStep(system_prompt="system prompt")) - - fp = ContextManager._action_fingerprint(curr_a) - cm._current_summary_cache = CurrentSummaryCache("sum_cc", 1, fp) - - model = make_model() - original = make_original_messages(memory) - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=0) - - model.assert_not_called() - assert isinstance(result, list) - assert len(result) == 3 - all_text = " ".join( - b.get("text", "") - for m in result for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ) - assert "sum_cc" in all_text - - def test_both_caches_hit_result_structure(self): - """prev and current cache hit at the same time, result should include two summaries.""" - cm = make_cm(enabled=True, threshold=30) - - prev_t, prev_a = make_pair(f"prev_task:{'X'*50}", f"prev_action: {'Y'*50}", 0) - curr_t, curr_a = make_pair("curr_task", "curr_action", 1) - 
memory = AgentMemory( - steps=[prev_t, prev_a, curr_t, curr_a], - system_prompt=SystemPromptStep(system_prompt="system prompt"), - ) - - assert cm._estimate_tokens(memory) > cm.config.token_threshold - prev_fp = cm._pair_fingerprint(prev_t.task, prev_a.action_output) - cm._previous_summary_cache = PreviousSummaryCache("prev_sum", 1, prev_fp) - - curr_fp = ContextManager._action_fingerprint(curr_a) - cm._current_summary_cache = CurrentSummaryCache("curr_sum", 1, curr_fp) - - model = make_model() - original = make_original_messages(memory) - current_run_start_idx = 2 - - result = cm.compress_if_needed(model, memory, original, current_run_start_idx) - - model.assert_not_called() - assert isinstance(result, list) - assert len(result) == 4 - texts = [ - b.get("text", "") - for m in result for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ] - assert any("prev_sum" in t for t in texts) - assert any("curr_sum" in t for t in texts) - assert cm._msg_token_count(result) < cm.config.token_threshold - - def test_mixed_prev_and_curr_over_threshold(self): - """previous + current both present and over threshold, should trigger compression separately.""" - cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1, keep_recent_steps=1) - memory = make_memory_mixed(n_prev_pairs=3, n_curr_actions=3) - original = make_original_messages(memory) - - current_run_start_idx = 6 - model = make_model('{"task_overview": "summary"}') - result = cm.compress_if_needed(model, memory, original, current_run_start_idx) - - assert result is not None - assert cm._previous_summary_cache is not None - assert cm._current_summary_cache is not None - assert isinstance(result, list) - assert len(result) < len(original) - assert model.call_count >= 2 - all_text = " ".join( - b.get("text", "") - for m in result for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ) - assert "Summary of earlier steps" in all_text \ No newline at end of 
file +from factories import make_cm, make_pair, make_model, make_memory_mixed, make_original_messages +from loader import ( + AgentMemory, TaskStep, SystemPromptStep, CurrentSummaryCache, PreviousSummaryCache, + ContextManager, pair_fingerprint, action_fingerprint, estimate_tokens, msg_token_count, +) + + +def _all_texts(messages): + return [ + b.get("text", "") + for m in messages + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ] + + +def _joined(messages): + return " ".join(_all_texts(messages)) + + +class TestCompressIfNeeded: + + def test_disabled_returns_original_messages(self): + """config.enabled=False returns original_messages without any processing.""" + cm = make_cm(enabled=False, threshold=10) + n_prev_pairs = 1 + n_curr_actions = 1 + memory = make_memory_mixed(n_prev_pairs, n_curr_actions) + original = make_original_messages(memory) + current_run_start_idx = 2 * n_prev_pairs + result = cm.compress_if_needed(None, memory, original, current_run_start_idx=current_run_start_idx) + assert result is original + + def test_under_threshold_returns_original(self): + """raw tokens < threshold returns directly, no LLM call.""" + cm = make_cm(enabled=True, threshold=999999) + n_prev_pairs = 1 + n_curr_actions = 1 + memory = make_memory_mixed(n_prev_pairs, n_curr_actions) + original = make_original_messages(memory) + current_run_start_idx = 2 * n_prev_pairs + model = make_model() + result = cm.compress_if_needed(None, memory, original, current_run_start_idx=current_run_start_idx) + assert result is original + model.assert_not_called() + + def test_over_threshold_triggers_compression(self): + """raw tokens > threshold should call LLM (all previous-run scenario).""" + keep_recent_pairs = 1 + keep_recent_steps = 2 + cm = make_cm(enabled=True, threshold=10, keep_recent_steps=keep_recent_steps, keep_recent_pairs=keep_recent_pairs) + n_prev_pairs = 3 + n_curr_actions = 2 + memory = make_memory_mixed(n_prev_pairs=n_prev_pairs, 
n_curr_actions=n_curr_actions) + original = make_original_messages(memory) + assert len(original) == 1 + n_prev_pairs * 2 + 1 + n_curr_actions + current_run_start_idx = 2 * n_prev_pairs + model = make_model('{"task_overview": "summary"}') + result = cm.compress_if_needed(model, memory, original, current_run_start_idx) + assert result is not None + assert isinstance(result, list) + assert len(result) == 1 + 1 + 2 * keep_recent_pairs + 1 + keep_recent_steps + model.assert_called_once() + all_text = " ".join( + b.get("text", "") + for m in result for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ) + assert "Summary of earlier steps" in all_text + + def test_run_boundary_clears_current_cache(self): + """Switching run (current_run_start_idx changes) and ensuring no current summary triggers, current cache should be cleared.""" + cm = make_cm(enabled=True, threshold=1) + cm._current_summary_cache = CurrentSummaryCache("old cache", 1, "fp") + cm._last_run_start_idx = 5 + memory = make_memory_mixed(1, 0) + original = make_original_messages(memory) + model = make_model('{"task_overview": "summary"}') + try: + cm.compress_if_needed(model, memory, original, current_run_start_idx=0) + except Exception: + pass + assert cm._current_summary_cache is None + + def test_effective_tokens_shortcut_applies_cache(self): + """effective tokens < threshold short-circuit, directly apply existing cache to build messages (all previous-run).""" + cm = make_cm(enabled=True, threshold=10, keep_recent_pairs=0) + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(2)] + all_steps = [] + for t, a in pairs: + all_steps.extend([t, a]) + all_steps.append(TaskStep(task="New Task")) + memory = AgentMemory(steps=all_steps, system_prompt=SystemPromptStep(system_prompt="system prompt")) + last_t, last_a = pairs[1] + fp = pair_fingerprint(last_t.task, last_a.action_output) + cm._previous_summary_cache = PreviousSummaryCache("short summary", 2, fp) + + model = 
make_model('{"task_overview": "summary"}') + original = make_original_messages(memory) + current_run_start_idx = 2 * len(pairs) + result = cm.compress_if_needed(model, memory, original, current_run_start_idx) + model.assert_not_called() + assert isinstance(result, list) + assert len(result) == 3 + all_text = " ".join( + b.get("text", "") + for m in result for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ) + assert "short summary" in all_text + + def test_current_run_cache_full_hit_no_llm_call(self): + """current cache fully hit, current part should be replaced by summary and no LLM call.""" + cm = make_cm(enabled=True, threshold=7) + curr_t, curr_a = make_pair("curr_task", "curr_action", 0) + memory = AgentMemory(steps=[curr_t, curr_a], system_prompt=SystemPromptStep(system_prompt="system prompt")) + + fp = action_fingerprint(curr_a) + cm._current_summary_cache = CurrentSummaryCache("sum_cc", 1, fp) + + model = make_model() + original = make_original_messages(memory) + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=0) + + model.assert_not_called() + assert isinstance(result, list) + assert len(result) == 3 + all_text = " ".join( + b.get("text", "") + for m in result for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ) + assert "sum_cc" in all_text + + def test_both_caches_hit_result_structure(self): + """prev and current cache hit at the same time, result should include two summaries.""" + cm = make_cm(enabled=True, threshold=30) + + prev_t, prev_a = make_pair(f"prev_task:{'X'*50}", f"prev_action: {'Y'*50}", 0) + curr_t, curr_a = make_pair("curr_task", "curr_action", 1) + memory = AgentMemory( + steps=[prev_t, prev_a, curr_t, curr_a], + system_prompt=SystemPromptStep(system_prompt="system prompt"), + ) + + assert estimate_tokens(memory, cm.config.chars_per_token) > cm.config.token_threshold + prev_fp = pair_fingerprint(prev_t.task, prev_a.action_output) + 
cm._previous_summary_cache = PreviousSummaryCache("prev_sum", 1, prev_fp) + + curr_fp = action_fingerprint(curr_a) + cm._current_summary_cache = CurrentSummaryCache("curr_sum", 1, curr_fp) + + model = make_model() + original = make_original_messages(memory) + current_run_start_idx = 2 + + result = cm.compress_if_needed(model, memory, original, current_run_start_idx) + + model.assert_not_called() + assert isinstance(result, list) + assert len(result) == 4 + texts = [ + b.get("text", "") + for m in result for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ] + assert any("prev_sum" in t for t in texts) + assert any("curr_sum" in t for t in texts) + assert msg_token_count(result, cm.config.chars_per_token) < cm.config.token_threshold + + def test_mixed_prev_and_curr_over_threshold(self): + """previous + current both present and over threshold, should trigger compression separately.""" + cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1, keep_recent_steps=1) + memory = make_memory_mixed(n_prev_pairs=3, n_curr_actions=3) + original = make_original_messages(memory) + + current_run_start_idx = 6 + model = make_model('{"task_overview": "summary"}') + result = cm.compress_if_needed(model, memory, original, current_run_start_idx) + + assert result is not None + assert cm._previous_summary_cache is not None + assert cm._current_summary_cache is not None + assert isinstance(result, list) + assert len(result) < len(original) + assert model.call_count >= 2 + all_text = " ".join( + b.get("text", "") + for m in result for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ) + assert "Summary of earlier steps" in all_text diff --git a/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed_extra.py b/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed_extra.py index e09f1090b..fc7f9bf90 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed_extra.py 
+++ b/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed_extra.py @@ -1,367 +1,370 @@ -""" -unit/test_compress_if_needed_extra.py -Supplementary branch coverage for TestCompressIfNeeded. - -Existing tests cover: - G1 disabled / under-threshold / run-boundary / G2 both-cache / G2 prev-only / - G2 curr-only / main-path prev+curr both compress / main-path mixed - -This file adds (corresponding to branch diagram M1-M13): - M1 First call _last_run_start_idx=None -> no exception, no cache clear - M2 G2 shortcut no cache: return raw messages (no LLM call) - M3 compress_prev=True but pairs_to_compress empty (keep_n >= all pairs) - M4 compress_prev=True, LLM returns None -> raw prev displayed, no crash - M5 compress_prev=False with valid prev cache -> main path applies cache (not G2) - M6 compress_curr=True but actions_to_compress empty - M7 compress_curr=True, LLM returns None -> raw curr displayed, no crash - M8 compress_curr=False with valid curr cache -> main path applies cache (not G2) - M9 Only current-run (current_run_start_idx=0), no previous, over threshold, no cache - M10 keep_recent_pairs exceeds total pairs boundary handling - M11 prev+curr both LLM fail -> result still list, no crash - M12 No system_prompt -> no system message in result - M13 Each compress call clears _step_local_log -""" - -import sys -import os -sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) - -from unittest.mock import MagicMock, patch - -from factories import make_cm, make_pair, make_model, make_original_messages -from loader import ( - ActionStep, - AgentMemory, - ContextManager, - ContextManagerConfig, - CurrentSummaryCache, - PreviousSummaryCache, - SummaryTaskStep, - TaskStep, -) -from stubs import _SystemPromptStep as SystemPromptStep - - -def _all_texts(messages): - return [ - b.get("text", "") - for m in messages - for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ] - - -def _joined(messages): 
- return " ".join(_all_texts(messages)) - - -class TestM1FirstCall: - - def test_first_call_no_exception_and_no_cache_clear(self): - """Initial state _last_run_start_idx=None, first call should not clear current cache.""" - cm = make_cm(enabled=True, threshold=999999) - cm._current_summary_cache = CurrentSummaryCache("existing summary", 1, "fp") - assert cm._last_run_start_idx is None - - t, a = make_pair("task", "action", 0) - memory = AgentMemory(steps=[t, a], system_prompt=None) - original = make_original_messages(memory) - - result = cm.compress_if_needed(None, memory, original, current_run_start_idx=2) - - assert result is original - assert cm._current_summary_cache is not None - - -class TestM2G2NoCacheRawReturn: - - def test_g2_shortcut_no_cache_returns_raw_messages(self): - """effective <= threshold but no cache, should use _build_messages to assemble raw steps.""" - cm = make_cm(enabled=True, threshold=10) - t, a = make_pair("x", "y", 0) - memory = AgentMemory(steps=[t, a], system_prompt=None) - original = make_original_messages(memory) - - with patch.object(cm, '_estimate_tokens', return_value=50): - with patch.object(cm, '_effective_tokens', return_value=5): - model = make_model() - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=2) - - model.assert_not_called() - assert isinstance(result, list) - assert "Summary of earlier steps" not in _joined(result) - assert "x" in _joined(result) - - -class TestM3PairsToCompressEmpty: - - def test_compress_prev_true_but_all_pairs_kept_no_llm(self): - """keep_recent_pairs >= len(pairs), pairs_to_compress=[], should not call LLM. - All pairs retained in raw form. 
- """ - cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=10) - t0, a0 = make_pair("task0 " + "X" * 50, "action0 " + "Y" * 50, 0) - t1, a1 = make_pair("task1 " + "X" * 50, "action1 " + "Y" * 50, 1) - memory = AgentMemory(steps=[t0, a0, t1, a1], system_prompt=None) - original = make_original_messages(memory) - - model = make_model('{"task_overview": "summary"}') - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=4) - - model.assert_not_called() - assert isinstance(result, list) - assert "task0" in _joined(result) - assert "task1" in _joined(result) - - -class TestM4PrevLLMReturnsNone: - - def test_prev_llm_returns_none_raw_steps_shown(self): - """When _compress_previous_with_cache returns None, prev_summary_step=None, - raw prev steps appear in result, no crash. - """ - cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1) - t0, a0 = make_pair("task0 " + "X" * 50, "action0 " + "Y" * 50, 0) - t1, a1 = make_pair("task1 " + "X" * 50, "action1 " + "Y" * 50, 1) - memory = AgentMemory(steps=[t0, a0, t1, a1], system_prompt=None) - original = make_original_messages(memory) - - with patch.object(cm, '_compress_previous_with_cache', return_value=None): - model = make_model() - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=4) - - assert isinstance(result, list) - assert "Summary of earlier steps" not in _joined(result) - assert "task1" in _joined(result) - - -class TestM5PrevCacheInMainPath: - - def test_compress_prev_false_with_valid_cache_applied_in_main_path(self): - """ - Scenario: effective_tokens > threshold (enter main path), - but prev_tokens <= threshold*0.6 (compress_prev=False), - and prev cache valid -> elif branch applies prev cache. - Different from G2 shortcut: G2 is effective <= threshold short-circuit. 
- """ - cm = make_cm(enabled=True, threshold=100, keep_recent_pairs=1) - - t, a = make_pair("prev_task" + "X" * 200, "prev_action" + "Y" * 200, 0) - curr_t, curr_a = make_pair("curr_task " + "X" * 200, "curr_action " + "Y" * 200, 1) - memory = AgentMemory( - steps=[t, a, curr_t, curr_a], - system_prompt=SystemPromptStep(system_prompt="sys"), - ) - - fp = cm._pair_fingerprint(t.task, a.action_output) - cm._previous_summary_cache = PreviousSummaryCache("prev_cached_summary", 1, fp) - - def mock_effective_prev(steps): - return 40 - - def mock_effective_curr(steps): - return 80 - - with patch.object(cm, '_effective_prev_tokens', side_effect=mock_effective_prev): - with patch.object(cm, '_effective_curr_tokens', side_effect=mock_effective_curr): - model = make_model('{"task_overview": "curr_summary"}') - original = make_original_messages(memory) - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=2) - texts = _all_texts(result) - assert any("prev_cached_summary" in t for t in texts) - assert any("Summary of earlier steps" in t for t in texts) - - -class TestM6ActionsToCompressEmpty: - - def test_compress_curr_true_but_all_actions_kept_no_llm(self): - """keep_recent_steps >= len(action_steps), actions_to_compress=[], should not call LLM.""" - cm = make_cm(enabled=True, threshold=1, keep_recent_steps=10) - curr_t = TaskStep(task="current_task") - curr_a0 = ActionStep(step_number=0, model_output="output0 " + "Y" * 50, action_output="r0") - curr_a1 = ActionStep(step_number=1, model_output="output1 " + "Y" * 50, action_output="r1") - memory = AgentMemory(steps=[curr_t, curr_a0, curr_a1], system_prompt=None) - original = make_original_messages(memory) - - model = make_model('{"task_overview": "summary"}') - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=0) - - model.assert_not_called() - assert isinstance(result, list) - assert "output0" in _joined(result) - assert "output1" in _joined(result) - - -class 
TestM7CurrLLMReturnsNone: - - def test_curr_llm_returns_none_raw_curr_shown(self): - """When _compress_current_with_cache returns None, curr_kept_steps=list(curr_steps), no crash.""" - cm = make_cm(enabled=True, threshold=1, keep_recent_steps=1) - curr_t = TaskStep(task="current_task") - curr_a0 = ActionStep(step_number=0, model_output="output0 " + "Y" * 50, action_output="r0") - curr_a1 = ActionStep(step_number=1, model_output="output1 " + "Y" * 50, action_output="r1") - memory = AgentMemory(steps=[curr_t, curr_a0, curr_a1], system_prompt=None) - original = make_original_messages(memory) - - with patch.object(cm, '_compress_current_with_cache', return_value=None): - model = make_model() - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=0) - - assert isinstance(result, list) - assert "Summary of earlier steps" not in _joined(result) - assert "output0" in _joined(result) - assert "output1" in _joined(result) - - -class TestM8CurrCacheInMainPath: - - def test_compress_curr_false_with_valid_cache_applied_in_main_path(self): - """ - Scenario: effective_tokens > threshold, - prev_tokens > threshold*0.6 (compress_prev=True), - curr_tokens <= threshold*0.4 (compress_curr=False), - and curr cache valid -> elif branch applies curr cache. 
- """ - cm = make_cm(enabled=True, threshold=100, keep_recent_pairs=1) - - t0, a0 = make_pair("prev0 " + "X" * 100, "pa0 " + "Y" * 100, 0) - t1, a1 = make_pair("prev1 " + "X" * 100, "pa1 " + "Y" * 100, 1) - curr_t = TaskStep(task="curr_task") - curr_a = ActionStep(step_number=2, model_output="curr_out", action_output="curr_r") - memory = AgentMemory( - steps=[t0, a0, t1, a1, curr_t, curr_a], - system_prompt=SystemPromptStep(system_prompt="sys"), - ) - - fp = ContextManager._action_fingerprint(curr_a) - cm._current_summary_cache = CurrentSummaryCache("curr_cached_summary", 1, fp) - - def mock_effective_prev(steps): - return 80 - - def mock_effective_curr(steps): - return 30 - - with patch.object(cm, '_effective_prev_tokens', side_effect=mock_effective_prev): - with patch.object(cm, '_effective_curr_tokens', side_effect=mock_effective_curr): - model = make_model('{"task_overview": "prev_summary"}') - original = make_original_messages(memory) - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=4) - - texts = _all_texts(result) - assert any("curr_cached_summary" in t for t in texts) - model.assert_called_once() - assert "prev_summary" in _joined(result) - - -class TestM9OnlyCurrentNoCache: - - def test_only_current_run_over_threshold_triggers_curr_compression(self): - """current_run_start_idx=0: all current-run, no prev, over threshold, no cache. - Should compress curr and call LLM once. 
- """ - cm = make_cm(enabled=True, threshold=1, keep_recent_steps=1) - curr_t = TaskStep(task="current_task " + "X" * 50) - actions = [ - ActionStep(step_number=i, model_output=f"output{i} " + "Y" * 50, action_output=f"r{i}") - for i in range(3) - ] - memory = AgentMemory(steps=[curr_t] + actions, system_prompt=None) - original = make_original_messages(memory) - - model = make_model('{"task_overview": "curr_summary"}') - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=0) - - assert result is not None - assert isinstance(result, list) - assert len(result) < len(original) - model.assert_called_once() - assert "Summary of earlier steps" in _joined(result) - - -class TestM10KeepRecentPairsBoundary: - - def test_keep_recent_pairs_larger_than_total_pairs_keeps_all(self): - """keep_recent_pairs=999, pairs_to_compress=[], all pairs retained in raw form.""" - cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=999) - pairs = [make_pair(f"task{i} " + "X" * 20, f"action{i} " + "Y" * 20, i) for i in range(3)] - steps = [s for t, a in pairs for s in (t, a)] - memory = AgentMemory(steps=steps, system_prompt=None) - original = make_original_messages(memory) - - model = make_model('{"task_overview": "summary"}') - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=6) - - model.assert_not_called() - for i in range(3): - assert f"task{i}" in _joined(result) - - -class TestM11BothLLMFail: - - def test_both_llm_calls_return_none_still_returns_list(self): - """When both compression calls return None, result is still valid list, no exception.""" - cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1, keep_recent_steps=1) - - t0, a0 = make_pair("prev " + "X" * 50, "pa " + "Y" * 50, 0) - t1, a1 = make_pair("prev1 " + "X" * 50, "pa1 " + "Y" * 50, 1) - curr_t = TaskStep(task="curr_task " + "X" * 50) - curr_a0 = ActionStep(step_number=2, model_output="cout0 " + "Y" * 50, action_output="r0") - curr_a1 = 
ActionStep(step_number=3, model_output="cout1 " + "Y" * 50, action_output="r1") - memory = AgentMemory( - steps=[t0, a0, t1, a1, curr_t, curr_a0, curr_a1], - system_prompt=SystemPromptStep(system_prompt="sys"), - ) - original = make_original_messages(memory) - - with patch.object(cm, '_compress_previous_with_cache', return_value=None): - with patch.object(cm, '_compress_current_with_cache', return_value=None): - result = cm.compress_if_needed(None, memory, original, current_run_start_idx=4) - - assert isinstance(result, list) - assert len(result) > 0 - - -class TestM12NoSystemPrompt: - - def test_no_system_prompt_no_system_message_in_result(self): - """memory.system_prompt=None, _build_messages should not produce system role message.""" - from stubs import _MessageRole - cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1) - t, a = make_pair("task " + "X" * 50, "action " + "Y" * 50, 0) - t1, a1 = make_pair("task1 " + "X" * 50, "action1 " + "Y" * 50, 1) - memory = AgentMemory(steps=[t, a, t1, a1], system_prompt=None) - original = make_original_messages(memory) - - model = make_model('{"task_overview": "summary"}') - result = cm.compress_if_needed(model, memory, original, current_run_start_idx=4) - - roles = [m.role for m in result] - assert _MessageRole.SYSTEM not in roles - - -class TestM13StepLocalLogCleared: - - def test_step_local_log_cleared_at_start_of_each_compress_call(self): - """Two consecutive compression calls, the second _step_local_log should not contain records from the first.""" - cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1) - - def _make_mem(): - t0, a0 = make_pair("task0 " + "X" * 50, "action0 " + "Y" * 50, 0) - t1, a1 = make_pair("task1 " + "X" * 50, "action1 " + "Y" * 50, 1) - return AgentMemory(steps=[t0, a0, t1, a1], system_prompt=None) - - model = make_model('{"task_overview": "summary"}') - - mem1 = _make_mem() - cm.compress_if_needed(model, mem1, make_original_messages(mem1), current_run_start_idx=4) - 
count_after_first = len(cm._step_local_log) - assert count_after_first == 1 - assert cm._step_local_log[0].call_type == "previous_summary" - - mem2 = _make_mem() - cm.compress_if_needed(model, mem2, make_original_messages(mem2), current_run_start_idx=4) - count_after_second = len(cm._step_local_log) - # reuse Previous_summary_cache; cache hit is still recorded in _step_local_log - assert count_after_second == 1 - assert cm._step_local_log[0].call_type == "previous_cache_hit" \ No newline at end of file +""" +unit/test_compress_if_needed_extra.py +Supplementary branch coverage for TestCompressIfNeeded. + +Existing tests cover: + G1 disabled / under-threshold / run-boundary / G2 both-cache / G2 prev-only / + G2 curr-only / main-path prev+curr both compress / main-path mixed + +This file adds (corresponding to branch diagram M1-M13): + M1 First call _last_run_start_idx=None -> no exception, no cache clear + M2 G2 shortcut no cache: return raw messages (no LLM call) + M3 compress_prev=True but pairs_to_compress empty (keep_n >= all pairs) + M4 compress_prev=True, LLM returns None -> raw prev displayed, no crash + M5 compress_prev=False with valid prev cache -> main path applies cache (not G2) + M6 compress_curr=True but actions_to_compress empty + M7 compress_curr=True, LLM returns None -> raw curr displayed, no crash + M8 compress_curr=False with valid curr cache -> main path applies cache (not G2) + M9 Only current-run (current_run_start_idx=0), no previous, over threshold, no cache + M10 keep_recent_pairs exceeds total pairs boundary handling + M11 prev+curr both LLM fail -> result still list, no crash + M12 No system_prompt -> no system message in result + M13 Each compress call clears _step_local_log +""" + +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +from unittest.mock import MagicMock, patch + +from factories import make_cm, make_pair, make_model, make_original_messages +from loader import ( + 
ActionStep, + AgentMemory, + ContextManager, + ContextManagerConfig, + CurrentSummaryCache, + PreviousSummaryCache, + SummaryTaskStep, + TaskStep, + PreviousCompressResult, + CurrentCompressResult, + pair_fingerprint, + action_fingerprint, +) +from stubs import _SystemPromptStep as SystemPromptStep + + +def _all_texts(messages): + return [ + b.get("text", "") + for m in messages + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ] + + +def _joined(messages): + return " ".join(_all_texts(messages)) + + +class TestM1FirstCall: + + def test_first_call_no_exception_and_no_cache_clear(self): + """Initial state _last_run_start_idx=None, first call should not clear current cache.""" + cm = make_cm(enabled=True, threshold=999999) + cm._current_summary_cache = CurrentSummaryCache("existing summary", 1, "fp") + assert cm._last_run_start_idx is None + + t, a = make_pair("task", "action", 0) + memory = AgentMemory(steps=[t, a], system_prompt=None) + original = make_original_messages(memory) + + result = cm.compress_if_needed(None, memory, original, current_run_start_idx=2) + + assert result is original + assert cm._current_summary_cache is not None + + +class TestM2G2NoCacheRawReturn: + + def test_g2_shortcut_no_cache_returns_raw_messages(self): + """effective <= threshold but no cache, should use build_messages to assemble raw steps.""" + cm = make_cm(enabled=True, threshold=10) + t, a = make_pair("x", "y", 0) + memory = AgentMemory(steps=[t, a], system_prompt=None) + original = make_original_messages(memory) + + with patch.object(cm, '_effective_tokens', return_value=5): + model = make_model() + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=2) + + model.assert_not_called() + assert isinstance(result, list) + assert "Summary of earlier steps" not in _joined(result) + assert "x" in _joined(result) + + +class TestM3PairsToCompressEmpty: + + def test_compress_prev_true_but_all_pairs_kept_no_llm(self): + 
"""keep_recent_pairs >= len(pairs), pairs_to_compress=[], should not call LLM. + All pairs retained in raw form. + """ + cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=10) + t0, a0 = make_pair("task0 " + "X" * 50, "action0 " + "Y" * 50, 0) + t1, a1 = make_pair("task1 " + "X" * 50, "action1 " + "Y" * 50, 1) + memory = AgentMemory(steps=[t0, a0, t1, a1], system_prompt=None) + original = make_original_messages(memory) + + model = make_model('{"task_overview": "summary"}') + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=4) + + model.assert_not_called() + assert isinstance(result, list) + assert "task0" in _joined(result) + assert "task1" in _joined(result) + + +class TestM4PrevLLMReturnsNone: + + def test_prev_llm_returns_none_raw_steps_shown(self): + """When _prev_compressor.compress returns summary_text=None, prev_summary_step=None, + raw prev steps appear in result, no crash. + """ + cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1) + t0, a0 = make_pair("task0 " + "X" * 50, "action0 " + "Y" * 50, 0) + t1, a1 = make_pair("task1 " + "X" * 50, "action1 " + "Y" * 50, 1) + memory = AgentMemory(steps=[t0, a0, t1, a1], system_prompt=None) + original = make_original_messages(memory) + + with patch.object(cm._prev_compressor, 'compress', return_value=PreviousCompressResult(summary_text=None)): + model = make_model() + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=4) + + assert isinstance(result, list) + assert "Summary of earlier steps" not in _joined(result) + assert "task1" in _joined(result) + + +class TestM5PrevCacheInMainPath: + + def test_compress_prev_false_with_valid_cache_applied_in_main_path(self): + """ + Scenario: effective_tokens > threshold (enter main path), + but prev_tokens <= threshold*0.6 (compress_prev=False), + and prev cache valid -> elif branch applies prev cache. + Different from G2 shortcut: G2 is effective <= threshold short-circuit. 
+ """ + cm = make_cm(enabled=True, threshold=100, keep_recent_pairs=1) + + t, a = make_pair("prev_task" + "X" * 200, "prev_action" + "Y" * 200, 0) + curr_t, curr_a = make_pair("curr_task " + "X" * 200, "curr_action " + "Y" * 200, 1) + memory = AgentMemory( + steps=[t, a, curr_t, curr_a], + system_prompt=SystemPromptStep(system_prompt="sys"), + ) + + fp = pair_fingerprint(t.task, a.action_output) + cm._previous_summary_cache = PreviousSummaryCache("prev_cached_summary", 1, fp) + + def mock_effective_prev(steps): + return 40 + + def mock_effective_curr(steps): + return 80 + + with patch.object(cm, '_effective_prev_tokens', side_effect=mock_effective_prev): + with patch.object(cm, '_effective_curr_tokens', side_effect=mock_effective_curr): + model = make_model('{"task_overview": "curr_summary"}') + original = make_original_messages(memory) + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=2) + texts = _all_texts(result) + assert any("prev_cached_summary" in t for t in texts) + assert any("Summary of earlier steps" in t for t in texts) + + +class TestM6ActionsToCompressEmpty: + + def test_compress_curr_true_but_all_actions_kept_no_llm(self): + """keep_recent_steps >= len(action_steps), actions_to_compress=[], should not call LLM.""" + cm = make_cm(enabled=True, threshold=1, keep_recent_steps=10) + curr_t = TaskStep(task="current_task") + curr_a0 = ActionStep(step_number=0, model_output="output0 " + "Y" * 50, action_output="r0") + curr_a1 = ActionStep(step_number=1, model_output="output1 " + "Y" * 50, action_output="r1") + memory = AgentMemory(steps=[curr_t, curr_a0, curr_a1], system_prompt=None) + original = make_original_messages(memory) + + model = make_model('{"task_overview": "summary"}') + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=0) + + model.assert_not_called() + assert isinstance(result, list) + assert "output0" in _joined(result) + assert "output1" in _joined(result) + + +class 
TestM7CurrLLMReturnsNone: + + def test_curr_llm_returns_none_raw_curr_shown(self): + """When _curr_compressor.compress returns summary_text=None, curr_kept_steps=list(curr_steps), no crash.""" + cm = make_cm(enabled=True, threshold=1, keep_recent_steps=1) + curr_t = TaskStep(task="current_task") + curr_a0 = ActionStep(step_number=0, model_output="output0 " + "Y" * 50, action_output="r0") + curr_a1 = ActionStep(step_number=1, model_output="output1 " + "Y" * 50, action_output="r1") + memory = AgentMemory(steps=[curr_t, curr_a0, curr_a1], system_prompt=None) + original = make_original_messages(memory) + + with patch.object(cm._curr_compressor, 'compress', return_value=CurrentCompressResult(summary_text=None)): + model = make_model() + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=0) + + assert isinstance(result, list) + assert "Summary of earlier steps" not in _joined(result) + assert "output0" in _joined(result) + assert "output1" in _joined(result) + + +class TestM8CurrCacheInMainPath: + + def test_compress_curr_false_with_valid_cache_applied_in_main_path(self): + """ + Scenario: effective_tokens > threshold, + prev_tokens > threshold*0.6 (compress_prev=True), + curr_tokens <= threshold*0.4 (compress_curr=False), + and curr cache valid -> elif branch applies curr cache. 
+ """ + cm = make_cm(enabled=True, threshold=100, keep_recent_pairs=1) + + t0, a0 = make_pair("prev0 " + "X" * 100, "pa0 " + "Y" * 100, 0) + t1, a1 = make_pair("prev1 " + "X" * 100, "pa1 " + "Y" * 100, 1) + curr_t = TaskStep(task="curr_task") + curr_a = ActionStep(step_number=2, model_output="curr_out", action_output="curr_r") + memory = AgentMemory( + steps=[t0, a0, t1, a1, curr_t, curr_a], + system_prompt=SystemPromptStep(system_prompt="sys"), + ) + + fp = action_fingerprint(curr_a) + cm._current_summary_cache = CurrentSummaryCache("curr_cached_summary", 1, fp) + + def mock_effective_prev(steps): + return 80 + + def mock_effective_curr(steps): + return 30 + + with patch.object(cm, '_effective_prev_tokens', side_effect=mock_effective_prev): + with patch.object(cm, '_effective_curr_tokens', side_effect=mock_effective_curr): + model = make_model('{"task_overview": "prev_summary"}') + original = make_original_messages(memory) + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=4) + + texts = _all_texts(result) + assert any("curr_cached_summary" in t for t in texts) + model.assert_called_once() + assert "prev_summary" in _joined(result) + + +class TestM9OnlyCurrentNoCache: + + def test_only_current_run_over_threshold_triggers_curr_compression(self): + """current_run_start_idx=0: all current-run, no prev, over threshold, no cache. + Should compress curr and call LLM once. 
+ """ + cm = make_cm(enabled=True, threshold=1, keep_recent_steps=1) + curr_t = TaskStep(task="current_task " + "X" * 50) + actions = [ + ActionStep(step_number=i, model_output=f"output{i} " + "Y" * 50, action_output=f"r{i}") + for i in range(3) + ] + memory = AgentMemory(steps=[curr_t] + actions, system_prompt=None) + original = make_original_messages(memory) + + model = make_model('{"task_overview": "curr_summary"}') + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=0) + + assert result is not None + assert isinstance(result, list) + assert len(result) < len(original) + model.assert_called_once() + assert "Summary of earlier steps" in _joined(result) + + +class TestM10KeepRecentPairsBoundary: + + def test_keep_recent_pairs_larger_than_total_pairs_keeps_all(self): + """keep_recent_pairs=999, pairs_to_compress=[], all pairs retained in raw form.""" + cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=999) + pairs = [make_pair(f"task{i} " + "X" * 20, f"action{i} " + "Y" * 20, i) for i in range(3)] + steps = [s for t, a in pairs for s in (t, a)] + memory = AgentMemory(steps=steps, system_prompt=None) + original = make_original_messages(memory) + + model = make_model('{"task_overview": "summary"}') + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=6) + + model.assert_not_called() + for i in range(3): + assert f"task{i}" in _joined(result) + + +class TestM11BothLLMFail: + + def test_both_llm_calls_return_none_still_returns_list(self): + """When both compression calls return summary_text=None, result is still valid list, no exception.""" + cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1, keep_recent_steps=1) + + t0, a0 = make_pair("prev " + "X" * 50, "pa " + "Y" * 50, 0) + t1, a1 = make_pair("prev1 " + "X" * 50, "pa1 " + "Y" * 50, 1) + curr_t = TaskStep(task="curr_task " + "X" * 50) + curr_a0 = ActionStep(step_number=2, model_output="cout0 " + "Y" * 50, action_output="r0") + curr_a1 = 
ActionStep(step_number=3, model_output="cout1 " + "Y" * 50, action_output="r1") + memory = AgentMemory( + steps=[t0, a0, t1, a1, curr_t, curr_a0, curr_a1], + system_prompt=SystemPromptStep(system_prompt="sys"), + ) + original = make_original_messages(memory) + + with patch.object(cm._prev_compressor, 'compress', return_value=PreviousCompressResult(summary_text=None)): + with patch.object(cm._curr_compressor, 'compress', return_value=CurrentCompressResult(summary_text=None)): + result = cm.compress_if_needed(None, memory, original, current_run_start_idx=4) + + assert isinstance(result, list) + assert len(result) > 0 + + +class TestM12NoSystemPrompt: + + def test_no_system_prompt_no_system_message_in_result(self): + """memory.system_prompt=None, build_messages should not produce system role message.""" + from stubs import _MessageRole + cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1) + t, a = make_pair("task " + "X" * 50, "action " + "Y" * 50, 0) + t1, a1 = make_pair("task1 " + "X" * 50, "action1 " + "Y" * 50, 1) + memory = AgentMemory(steps=[t, a, t1, a1], system_prompt=None) + original = make_original_messages(memory) + + model = make_model('{"task_overview": "summary"}') + result = cm.compress_if_needed(model, memory, original, current_run_start_idx=4) + + roles = [m.role for m in result] + assert _MessageRole.SYSTEM not in roles + + +class TestM13StepLocalLogCleared: + + def test_step_local_log_cleared_at_start_of_each_compress_call(self): + """Two consecutive compression calls, the second _step_local_log should not contain records from the first.""" + cm = make_cm(enabled=True, threshold=1, keep_recent_pairs=1) + + def _make_mem(): + t0, a0 = make_pair("task0 " + "X" * 50, "action0 " + "Y" * 50, 0) + t1, a1 = make_pair("task1 " + "X" * 50, "action1 " + "Y" * 50, 1) + return AgentMemory(steps=[t0, a0, t1, a1], system_prompt=None) + + model = make_model('{"task_overview": "summary"}') + + mem1 = _make_mem() + cm.compress_if_needed(model, mem1, 
make_original_messages(mem1), current_run_start_idx=4) + count_after_first = len(cm._step_local_log) + assert count_after_first == 1 + assert cm._step_local_log[0].call_type == "previous_summary" + + mem2 = _make_mem() + cm.compress_if_needed(model, mem2, make_original_messages(mem2), current_run_start_idx=4) + count_after_second = len(cm._step_local_log) + # reuse Previous_summary_cache; cache hit is still recorded in _step_local_log + assert count_after_second == 1 + assert cm._step_local_log[0].call_type == "previous_cache_hit" diff --git a/test/sdk/core/agents/test_agent_context/unit/test_compress_with_cache.py b/test/sdk/core/agents/test_agent_context/unit/test_compress_with_cache.py index 01d05b348..9d0e05436 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_compress_with_cache.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_compress_with_cache.py @@ -1,144 +1,171 @@ -from factories import make_cm, make_pair, make_model -from loader import ActionStep, PreviousSummaryCache, ContextManager, CurrentSummaryCache, TaskStep - - -class TestCompressPreviousWithCache: - - def _make_pairs_with_cache(self, n=2): - """Generate n pairs and preset full cache hit.""" - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(n)] - last_t, last_a = pairs[-1] - fp = cm._pair_fingerprint(last_t.task, last_a.action_output) - cm._previous_summary_cache = PreviousSummaryCache( - summary_text="existing summary", covered_pairs=n, anchor_fingerprint=fp - ) - return cm, pairs - - def test_previous_full_cache_hit_no_llm_call(self): - cm, pairs = self._make_pairs_with_cache(n=2) - model = make_model() - result = cm._compress_previous_with_cache(pairs, model) - assert result == "existing summary" - model.assert_not_called() - - def test_previous_incremental_calls_llm_with_old_summary(self): - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] - anchor_t, anchor_a = pairs[1] - fp = 
cm._pair_fingerprint(anchor_t.task, anchor_a.action_output) - cm._previous_summary_cache = PreviousSummaryCache( - summary_text="old summary", covered_pairs=2, anchor_fingerprint=fp - ) - model = make_model('{"task_overview": "incremental summary"}') - result = cm._compress_previous_with_cache(pairs, model) - assert result is not None - model.assert_called_once() - call_args = model.call_args[0][0] - full_text = " ".join( - b.get("text", "") for m in call_args for b in (m.content if isinstance(m.content, list) else []) - ) - assert "old summary" in full_text - - def test_previous_fresh_compress_writes_cache(self): - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(2)] - model = make_model('{"task_overview": "full summary"}') - result = cm._compress_previous_with_cache(pairs, model) - assert result is not None - assert cm._previous_summary_cache is not None - assert cm._previous_summary_cache.covered_pairs == 2 - - def test_previous_incremental_updates_cache_to_full_coverage(self): - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] - anchor_t, anchor_a = pairs[1] - fp = cm._pair_fingerprint(anchor_t.task, anchor_a.action_output) - cm._previous_summary_cache = PreviousSummaryCache("old summary", 2, fp) - model = make_model('{"task_overview": "new summary"}') - cm._compress_previous_with_cache(pairs, model) - assert cm._previous_summary_cache.covered_pairs == 3 - assert "new summary" in cm._previous_summary_cache.summary_text - - def test_previous_fingerprint_mismatch_falls_through_to_fresh(self): - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] - cm._previous_summary_cache = PreviousSummaryCache("old summary", 2, "wrong_fp") - model = make_model('{"task_overview": "fresh summary"}') - result = cm._compress_previous_with_cache(pairs, model) - assert result is not None - call_args = model.call_args[0][0] - full_text = " ".join( - b.get("text", "") for m in call_args 
for b in (m.content if isinstance(m.content, list) else []) - ) - assert "old summary" not in full_text - assert cm._previous_summary_cache.covered_pairs == 3 - - def test_previous_empty_pairs_returns_none(self): - cm = make_cm() - model = make_model() - assert cm._compress_previous_with_cache([], model) is None - model.assert_not_called() - - -class TestCompressCurrentWithCache: - - def _make_actions_with_cache(self, n=2): - cm = make_cm() - actions = [ActionStep(step_number=i, model_output=f"output{i}", action_output=f"result{i}") for i in range(n)] - fp = ContextManager._action_fingerprint(actions[-1]) - cm._current_summary_cache = CurrentSummaryCache("existing step summary", n, fp) - return cm, actions - - def test_current_full_cache_hit_no_llm_call(self): - cm, actions = self._make_actions_with_cache(n=2) - model = make_model() - task = TaskStep(task="current task") - result = cm._compress_current_with_cache(task, actions, model) - assert result == "existing step summary" - model.assert_not_called() - - def test_current_incremental_calls_llm(self): - cm = make_cm() - actions = [ActionStep(step_number=i, model_output=f"output{i}", action_output=f"result{i}") for i in range(3)] - fp = ContextManager._action_fingerprint(actions[1]) - cm._current_summary_cache = CurrentSummaryCache("old step summary", 2, fp) - model = make_model('{"task_overview": "incremental step summary"}') - task = TaskStep(task="task") - result = cm._compress_current_with_cache(task, actions, model) - assert "incremental step" in result - assert "old step" not in result - assert cm._current_summary_cache.end_steps == 3 - model.assert_called_once() - - def test_current_fresh_writes_cache(self): - cm = make_cm() - actions = [ActionStep(step_number=i, model_output=f"output{i}", action_output=f"result{i}") for i in range(2)] - model = make_model('{"task_overview": "fresh step summary"}') - task = TaskStep(task="task") - cm._compress_current_with_cache(task, actions, model) - assert 
cm._current_summary_cache is not None - assert cm._current_summary_cache.end_steps == 2 - - def test_current_no_task_step(self): - cm = make_cm() - actions = [ActionStep(step_number=1, model_output="output", action_output="result")] - model = make_model('{"task_overview": "summary"}') - result = cm._compress_current_with_cache(None, actions, model) - assert result is not None - - def test_current_empty_actions_returns_none(self): - cm = make_cm() - model = make_model() - assert cm._compress_current_with_cache(TaskStep(task="t"), [], model) is None - model.assert_not_called() - - def test_current_incremental_updates_anchor_fingerprint(self): - cm = make_cm() - actions = [ActionStep(step_number=i, model_output=f"o{i}", action_output=f"r{i}") for i in range(3)] - fp_old = ContextManager._action_fingerprint(actions[1]) - cm._current_summary_cache = CurrentSummaryCache("old summary", 2, fp_old) - model = make_model('{"task_overview": "new summary"}') - cm._compress_current_with_cache(TaskStep(task="t"), actions, model) - fp_new = ContextManager._action_fingerprint(actions[2]) - assert cm._current_summary_cache.anchor_fingerprint == fp_new \ No newline at end of file +from factories import make_cm, make_pair, make_model +from loader import ( + ActionStep, PreviousSummaryCache, ContextManager, CurrentSummaryCache, TaskStep, + pair_fingerprint, action_fingerprint, +) + + +def _compress_previous_with_cache(cm, pairs, model): + """Helper that mimics old cm._compress_previous_with_cache(pairs, model) behavior. + + Calls cm._prev_compressor.compress(), applies cache update, and returns + summary_text (str or None) -- the same return type as the old method. 
+ """ + result = cm._prev_compressor.compress(pairs, cm._previous_summary_cache, model) + if result.new_cache is not None: + cm._previous_summary_cache = result.new_cache + return result.summary_text + + +def _compress_current_with_cache(cm, task, actions, model): + """Helper that mimics old cm._compress_current_with_cache(task, actions, model) behavior. + + Calls cm._curr_compressor.compress(), applies cache update, and returns + summary_text (str or None) -- the same return type as the old method. + """ + result = cm._curr_compressor.compress(task, actions, cm._current_summary_cache, model) + if result.new_cache is not None: + cm._current_summary_cache = result.new_cache + return result.summary_text + + +class TestCompressPreviousWithCache: + + def _make_pairs_with_cache(self, n=2): + """Generate n pairs and preset full cache hit.""" + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(n)] + last_t, last_a = pairs[-1] + fp = pair_fingerprint(last_t.task, last_a.action_output) + cm._previous_summary_cache = PreviousSummaryCache( + summary_text="existing summary", covered_pairs=n, anchor_fingerprint=fp + ) + return cm, pairs + + def test_previous_full_cache_hit_no_llm_call(self): + cm, pairs = self._make_pairs_with_cache(n=2) + model = make_model() + result = _compress_previous_with_cache(cm, pairs, model) + assert result == "existing summary" + model.assert_not_called() + + def test_previous_incremental_calls_llm_with_old_summary(self): + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] + anchor_t, anchor_a = pairs[1] + fp = pair_fingerprint(anchor_t.task, anchor_a.action_output) + cm._previous_summary_cache = PreviousSummaryCache( + summary_text="old summary", covered_pairs=2, anchor_fingerprint=fp + ) + model = make_model('{"task_overview": "incremental summary"}') + result = _compress_previous_with_cache(cm, pairs, model) + assert result is not None + model.assert_called_once() + call_args = 
model.call_args[0][0] + full_text = " ".join( + b.get("text", "") for m in call_args for b in (m.content if isinstance(m.content, list) else []) + ) + assert "old summary" in full_text + + def test_previous_fresh_compress_writes_cache(self): + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(2)] + model = make_model('{"task_overview": "full summary"}') + result = _compress_previous_with_cache(cm, pairs, model) + assert result is not None + assert cm._previous_summary_cache is not None + assert cm._previous_summary_cache.covered_pairs == 2 + + def test_previous_incremental_updates_cache_to_full_coverage(self): + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] + anchor_t, anchor_a = pairs[1] + fp = pair_fingerprint(anchor_t.task, anchor_a.action_output) + cm._previous_summary_cache = PreviousSummaryCache("old summary", 2, fp) + model = make_model('{"task_overview": "new summary"}') + _compress_previous_with_cache(cm, pairs, model) + assert cm._previous_summary_cache.covered_pairs == 3 + assert "new summary" in cm._previous_summary_cache.summary_text + + def test_previous_fingerprint_mismatch_falls_through_to_fresh(self): + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] + cm._previous_summary_cache = PreviousSummaryCache("old summary", 2, "wrong_fp") + model = make_model('{"task_overview": "fresh summary"}') + result = _compress_previous_with_cache(cm, pairs, model) + assert result is not None + call_args = model.call_args[0][0] + full_text = " ".join( + b.get("text", "") for m in call_args for b in (m.content if isinstance(m.content, list) else []) + ) + assert "old summary" not in full_text + assert cm._previous_summary_cache.covered_pairs == 3 + + def test_previous_empty_pairs_returns_none(self): + cm = make_cm() + model = make_model() + assert _compress_previous_with_cache(cm, [], model) is None + model.assert_not_called() + + +class 
TestCompressCurrentWithCache: + + def _make_actions_with_cache(self, n=2): + cm = make_cm() + actions = [ActionStep(step_number=i, model_output=f"output{i}", action_output=f"result{i}") for i in range(n)] + fp = action_fingerprint(actions[-1]) + cm._current_summary_cache = CurrentSummaryCache("existing step summary", n, fp) + return cm, actions + + def test_current_full_cache_hit_no_llm_call(self): + cm, actions = self._make_actions_with_cache(n=2) + model = make_model() + task = TaskStep(task="current task") + result = _compress_current_with_cache(cm, task, actions, model) + assert result == "existing step summary" + model.assert_not_called() + + def test_current_incremental_calls_llm(self): + cm = make_cm() + actions = [ActionStep(step_number=i, model_output=f"output{i}", action_output=f"result{i}") for i in range(3)] + fp = action_fingerprint(actions[1]) + cm._current_summary_cache = CurrentSummaryCache("old step summary", 2, fp) + model = make_model('{"task_overview": "incremental step summary"}') + task = TaskStep(task="task") + result = _compress_current_with_cache(cm, task, actions, model) + assert "incremental step" in result + assert "old step" not in result + assert cm._current_summary_cache.end_steps == 3 + model.assert_called_once() + + def test_current_fresh_writes_cache(self): + cm = make_cm() + actions = [ActionStep(step_number=i, model_output=f"output{i}", action_output=f"result{i}") for i in range(2)] + model = make_model('{"task_overview": "fresh step summary"}') + task = TaskStep(task="task") + _compress_current_with_cache(cm, task, actions, model) + assert cm._current_summary_cache is not None + assert cm._current_summary_cache.end_steps == 2 + + def test_current_no_task_step(self): + cm = make_cm() + actions = [ActionStep(step_number=1, model_output="output", action_output="result")] + model = make_model('{"task_overview": "summary"}') + result = _compress_current_with_cache(cm, None, actions, model) + assert result is not None + + def 
test_current_empty_actions_returns_none(self): + cm = make_cm() + model = make_model() + assert _compress_current_with_cache(cm, TaskStep(task="t"), [], model) is None + model.assert_not_called() + + def test_current_incremental_updates_anchor_fingerprint(self): + cm = make_cm() + actions = [ActionStep(step_number=i, model_output=f"o{i}", action_output=f"r{i}") for i in range(3)] + fp_old = action_fingerprint(actions[1]) + cm._current_summary_cache = CurrentSummaryCache("old summary", 2, fp_old) + model = make_model('{"task_overview": "new summary"}') + _compress_current_with_cache(cm, TaskStep(task="t"), actions, model) + fp_new = action_fingerprint(actions[2]) + assert cm._current_summary_cache.anchor_fingerprint == fp_new diff --git a/test/sdk/core/agents/test_agent_context/unit/test_compress_with_cache_extra.py b/test/sdk/core/agents/test_agent_context/unit/test_compress_with_cache_extra.py index a0fcf0ff0..0b0b46cd5 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_compress_with_cache_extra.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_compress_with_cache_extra.py @@ -1,256 +1,277 @@ -import sys -import os -sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) - -from unittest.mock import MagicMock, patch - -from factories import make_cm, make_pair, make_model -from loader import ( - ActionStep, - ContextManager, - CurrentSummaryCache, - PreviousSummaryCache, - TaskStep, -) - - -def _llm_text(model) -> str: - """Extract concatenated user prompt text from mock model's last call.""" - call_args = model.call_args[0][0] - return " ".join( - b.get("text", "") - for m in call_args - for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ) - - -def _all_texts(messages): - return [ - b.get("text", "") - for m in messages - for b in (m.content if isinstance(m.content, list) else []) - if isinstance(b, dict) - ] - - -def _joined(messages): - return " ".join(_all_texts(messages)) - - 
-class TestCompressPreviousExtra: - - def test_P1_full_hit_fp_mismatch_goes_to_fresh(self): - """covered_pairs == len(pairs) but fingerprint wrong. - Should not take incremental path (covered < len condition not met), - go directly to fresh full compression. - """ - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(2)] - cm._previous_summary_cache = PreviousSummaryCache( - summary_text="old summary", covered_pairs=2, anchor_fingerprint="WRONG" - ) - model = make_model('{"task_overview": "fresh summary"}') - result = cm._compress_previous_with_cache(pairs, model) - - assert result is not None - model.assert_called_once() - assert "old summary" not in _llm_text(model) - assert cm._previous_summary_cache.covered_pairs == 2 - - def test_P2_incremental_over_budget_falls_through_to_fresh(self): - """Incremental input token count exceeds max_summary_input_tokens, - should skip incremental and go to fresh, still call LLM once (fresh). - """ - cm = make_cm() - cm.config.max_summary_input_tokens = 0 - - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] - anchor_t, anchor_a = pairs[1] - fp = cm._pair_fingerprint(anchor_t.task, anchor_a.action_output) - cm._previous_summary_cache = PreviousSummaryCache("old summary", 2, fp) - - model = make_model('{"task_overview": "fresh summary"}') - - result = cm._compress_previous_with_cache(pairs, model) - assert result is not None - model.assert_called_once() - assert "old summary" not in _llm_text(model) - assert "task2" in _llm_text(model) - assert "fresh" in result - - def test_P3_incremental_llm_none_falls_through_to_fresh(self): - """When _generate_summary returns None in incremental path, - code fall-through to fresh, should call LLM again. 
- """ - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] - anchor_t, anchor_a = pairs[1] - fp = cm._pair_fingerprint(anchor_t.task, anchor_a.action_output) - cm._previous_summary_cache = PreviousSummaryCache("old summary", 2, fp) - - call_count = [0] - def side_effect(text, model_, call_type="summary", prompt_type="initial"): - call_count[0] += 1 - if call_count[0] == 1: - return None - return '{"task_overview": "fresh summary"}' - - with patch.object(cm, '_generate_summary', side_effect=side_effect): - result = cm._compress_previous_with_cache(pairs, MagicMock()) - - assert call_count[0] == 2 - assert result is not None - - def test_P4_fresh_llm_none_returns_none_and_preserves_old_cache(self): - """When _summarize_pairs returns (None, False): - - function returns None - - existing _previous_summary_cache not modified - """ - cm = make_cm() - pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(2)] - cm._previous_summary_cache = PreviousSummaryCache("old summary", 99, "bad_fp") - - with patch.object(cm, '_summarize_pairs', return_value=(None, False)): - result = cm._compress_previous_with_cache(pairs, MagicMock()) - - assert result is None - assert cm._previous_summary_cache.summary_text == "old summary" - - def test_P4_fresh_llm_none_no_cache_remains_none(self): - """Initial no cache, fresh LLM returns None -> cache still None.""" - cm = make_cm() - pairs = [make_pair("task", "action", 0)] - assert cm._previous_summary_cache is None - - with patch.object(cm, '_summarize_pairs', return_value=(None, False)): - result = cm._compress_previous_with_cache(pairs, MagicMock()) - - assert result is None - assert cm._previous_summary_cache is None - - -class TestCompressCurrentExtra: - - def _make_actions(self, n): - return [ - ActionStep(step_number=i, model_output=f"output{i}", action_output=f"result{i}") - for i in range(n) - ] - - def test_C1_full_hit_fp_mismatch_goes_to_fresh(self): - """end_steps == len(actions) but 
anchor_fingerprint wrong. - Incremental condition 0 < end_steps < len not met, go directly to fresh. - """ - cm = make_cm() - actions = self._make_actions(2) - cm._current_summary_cache = CurrentSummaryCache( - summary_text="old summary", end_steps=2, anchor_fingerprint="WRONG" - ) - model = make_model('{"task_overview": "fresh summary"}') - result = cm._compress_current_with_cache(TaskStep(task="t"), actions, model) - - assert result is not None - assert "fresh summary" in result - assert "old summary" not in result - model.assert_called_once() - real_fp = ContextManager._action_fingerprint(actions[-1]) - assert cm._current_summary_cache.anchor_fingerprint == real_fp - - def test_C2_incremental_anchor_fp_mismatch_goes_to_fresh(self): - """cache.end_steps < len(actions) (incremental condition met), - but anchor action fingerprint mismatch with cache -> fall-through to fresh. - """ - cm = make_cm() - actions = self._make_actions(3) - cm._current_summary_cache = CurrentSummaryCache( - summary_text="old summary", end_steps=2, anchor_fingerprint="WRONG" - ) - model = make_model('{"task_overview": "fresh summary"}') - result = cm._compress_current_with_cache(TaskStep(task="t"), actions, model) - - assert result is not None - model.assert_called_once() - assert "old summary" not in _llm_text(model) - assert "fresh summary" in result - - def test_C4_incremental_llm_none_falls_through_to_fresh(self): - cm = make_cm() - actions = self._make_actions(3) - fp = ContextManager._action_fingerprint(actions[1]) - cm._current_summary_cache = CurrentSummaryCache("old summary", 2, fp) - - call_count = [0] - def side_effect(text, model_, call_type="summary", prompt_type="initial"): - call_count[0] += 1 - if call_count[0] == 1: - return None - return '{"task_overview": "fresh summary"}' - - with patch.object(cm, '_generate_summary', side_effect=side_effect): - result = cm._compress_current_with_cache(TaskStep(task="t"), actions, MagicMock()) - - assert call_count[0] == 2 - assert 
result is not None - assert cm._current_summary_cache.end_steps == len(actions) - - def test_C5_fresh_actions_trimmed_cache_uses_original_len(self): - """_trim_actions_to_budget trimmed some actions, - but end_steps should still record original len(actions_to_compress), - ensuring next call cache covers same range. - """ - cm = make_cm() - actions = self._make_actions(4) - - with patch.object(cm, '_trim_actions_to_budget', return_value=[actions[-1]]): - model = make_model('{"task_overview": "trimmed summary"}') - result = cm._compress_current_with_cache(TaskStep(task="t"), actions, model) - - assert result is not None - assert cm._current_summary_cache.end_steps == 4 - real_fp = ContextManager._action_fingerprint(actions[-1]) - assert cm._current_summary_cache.anchor_fingerprint == real_fp - - def test_C5_fresh_partial_trim_still_calls_llm_once(self): - """After trim still only call LLM once (no retry).""" - cm = make_cm() - actions = self._make_actions(3) - - with patch.object(cm, '_trim_actions_to_budget', return_value=[actions[-1]]): - model = make_model('{"task_overview": "summary"}') - cm._compress_current_with_cache(TaskStep(task="t"), actions, model) - - model.assert_called_once() - - def test_C6_fresh_llm_none_writes_none_to_cache(self): - """Current fresh path if LLM call fails, no cache. - Only truncation performed. - """ - cm = make_cm() - actions = self._make_actions(2) - - with patch.object(cm, '_generate_summary', return_value=None): - result = cm._compress_current_with_cache(TaskStep(task="t"), actions, MagicMock()) - - assert "[CONTEXT COMPACTION" in result - assert cm._current_summary_cache is None - - def test_C6_vs_previous_asymmetry(self): - """Regression test: clarify asymmetry between previous and current behavior when LLM=None. 
- previous fresh=None -> cache not written (preserve old value) - current fresh=None -> cache not written - """ - cm = make_cm() - pairs = [make_pair("task", "action", 0)] - actions = [ActionStep(step_number=0, model_output="out", action_output="r")] - - old_prev_cache = PreviousSummaryCache("old prev", 99, "bad") - cm._previous_summary_cache = old_prev_cache - - with patch.object(cm, '_summarize_pairs', return_value=(None, False)): - cm._compress_previous_with_cache(pairs, MagicMock()) - assert cm._previous_summary_cache is old_prev_cache - - with patch.object(cm, '_generate_summary', return_value=None): - cm._compress_current_with_cache(TaskStep(task="t"), actions, MagicMock()) - assert cm._current_summary_cache is None \ No newline at end of file +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +from unittest.mock import MagicMock, patch + +from factories import make_cm, make_pair, make_model +from loader import ( + ActionStep, + ContextManager, + CurrentSummaryCache, + PreviousSummaryCache, + TaskStep, + pair_fingerprint, + action_fingerprint, + SummaryResult, + PreviousCompressResult, + CurrentCompressResult, +) + + +def _llm_text(model) -> str: + """Extract concatenated user prompt text from mock model's last call.""" + call_args = model.call_args[0][0] + return " ".join( + b.get("text", "") + for m in call_args + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ) + + +def _all_texts(messages): + return [ + b.get("text", "") + for m in messages + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, dict) + ] + + +def _joined(messages): + return " ".join(_all_texts(messages)) + + +def _compress_previous_with_cache(cm, pairs, model): + """Helper: call prev compressor and apply cache, return summary_text.""" + result = cm._prev_compressor.compress(pairs, cm._previous_summary_cache, model) + if result.new_cache is not None: + 
cm._previous_summary_cache = result.new_cache + return result.summary_text + + +def _compress_current_with_cache(cm, task, actions, model): + """Helper: call curr compressor and apply cache, return summary_text.""" + result = cm._curr_compressor.compress(task, actions, cm._current_summary_cache, model) + if result.new_cache is not None: + cm._current_summary_cache = result.new_cache + return result.summary_text + + +class TestCompressPreviousExtra: + + def test_P1_full_hit_fp_mismatch_goes_to_fresh(self): + """covered_pairs == len(pairs) but fingerprint wrong. + Should not take incremental path (covered < len condition not met), + go directly to fresh full compression. + """ + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(2)] + cm._previous_summary_cache = PreviousSummaryCache( + summary_text="old summary", covered_pairs=2, anchor_fingerprint="WRONG" + ) + model = make_model('{"task_overview": "fresh summary"}') + result = _compress_previous_with_cache(cm, pairs, model) + + assert result is not None + model.assert_called_once() + assert "old summary" not in _llm_text(model) + assert cm._previous_summary_cache.covered_pairs == 2 + + def test_P2_incremental_over_budget_falls_through_to_fresh(self): + """Incremental input token count exceeds max_summary_input_tokens, + should skip incremental and go to fresh, still call LLM once (fresh). 
+ """ + cm = make_cm() + cm.config.max_summary_input_tokens = 0 + + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] + anchor_t, anchor_a = pairs[1] + fp = pair_fingerprint(anchor_t.task, anchor_a.action_output) + cm._previous_summary_cache = PreviousSummaryCache("old summary", 2, fp) + + model = make_model('{"task_overview": "fresh summary"}') + + result = _compress_previous_with_cache(cm, pairs, model) + assert result is not None + model.assert_called_once() + assert "old summary" not in _llm_text(model) + assert "task2" in _llm_text(model) + assert "fresh" in result + + def test_P3_incremental_llm_none_falls_through_to_fresh(self): + """When generate_summary returns SummaryResult(summary_text=None) in incremental path, + code fall-through to fresh, should call LLM again. + """ + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(3)] + anchor_t, anchor_a = pairs[1] + fp = pair_fingerprint(anchor_t.task, anchor_a.action_output) + cm._previous_summary_cache = PreviousSummaryCache("old summary", 2, fp) + + call_count = [0] + def side_effect(text, model_, call_type="summary", prompt_type="initial"): + call_count[0] += 1 + if call_count[0] == 1: + return SummaryResult(summary_text=None, records=[]) + return SummaryResult(summary_text='{"task_overview": "fresh summary"}', records=[]) + + with patch.object(cm._llm, 'generate_summary', side_effect=side_effect): + result = _compress_previous_with_cache(cm, pairs, MagicMock()) + + assert call_count[0] == 2 + assert result is not None + + def test_P4_fresh_llm_none_returns_none_and_preserves_old_cache(self): + """When _summarize_pairs returns (None, False, []): + - function returns None + - existing _previous_summary_cache not modified + """ + cm = make_cm() + pairs = [make_pair(f"task{i}", f"action{i}", i) for i in range(2)] + cm._previous_summary_cache = PreviousSummaryCache("old summary", 99, "bad_fp") + + with patch.object(cm._prev_compressor, '_summarize_pairs', 
return_value=(None, False, [])): + result = _compress_previous_with_cache(cm, pairs, MagicMock()) + + assert result is None + assert cm._previous_summary_cache.summary_text == "old summary" + + def test_P4_fresh_llm_none_no_cache_remains_none(self): + """Initial no cache, fresh LLM returns None -> cache still None.""" + cm = make_cm() + pairs = [make_pair("task", "action", 0)] + assert cm._previous_summary_cache is None + + with patch.object(cm._prev_compressor, '_summarize_pairs', return_value=(None, False, [])): + result = _compress_previous_with_cache(cm, pairs, MagicMock()) + + assert result is None + assert cm._previous_summary_cache is None + + +class TestCompressCurrentExtra: + + def _make_actions(self, n): + return [ + ActionStep(step_number=i, model_output=f"output{i}", action_output=f"result{i}") + for i in range(n) + ] + + def test_C1_full_hit_fp_mismatch_goes_to_fresh(self): + """end_steps == len(actions) but anchor_fingerprint wrong. + Incremental condition 0 < end_steps < len not met, go directly to fresh. + """ + cm = make_cm() + actions = self._make_actions(2) + cm._current_summary_cache = CurrentSummaryCache( + summary_text="old summary", end_steps=2, anchor_fingerprint="WRONG" + ) + model = make_model('{"task_overview": "fresh summary"}') + result = _compress_current_with_cache(cm, TaskStep(task="t"), actions, model) + + assert result is not None + assert "fresh summary" in result + assert "old summary" not in result + model.assert_called_once() + real_fp = action_fingerprint(actions[-1]) + assert cm._current_summary_cache.anchor_fingerprint == real_fp + + def test_C2_incremental_anchor_fp_mismatch_goes_to_fresh(self): + """cache.end_steps < len(actions) (incremental condition met), + but anchor action fingerprint mismatch with cache -> fall-through to fresh. 
+ """ + cm = make_cm() + actions = self._make_actions(3) + cm._current_summary_cache = CurrentSummaryCache( + summary_text="old summary", end_steps=2, anchor_fingerprint="WRONG" + ) + model = make_model('{"task_overview": "fresh summary"}') + result = _compress_current_with_cache(cm, TaskStep(task="t"), actions, model) + + assert result is not None + model.assert_called_once() + assert "old summary" not in _llm_text(model) + assert "fresh summary" in result + + def test_C4_incremental_llm_none_falls_through_to_fresh(self): + cm = make_cm() + actions = self._make_actions(3) + fp = action_fingerprint(actions[1]) + cm._current_summary_cache = CurrentSummaryCache("old summary", 2, fp) + + call_count = [0] + def side_effect(text, model_, call_type="summary", prompt_type="initial"): + call_count[0] += 1 + if call_count[0] == 1: + return SummaryResult(summary_text=None, records=[]) + return SummaryResult(summary_text='{"task_overview": "fresh summary"}', records=[]) + + with patch.object(cm._llm, 'generate_summary', side_effect=side_effect): + result = _compress_current_with_cache(cm, TaskStep(task="t"), actions, MagicMock()) + + assert call_count[0] == 2 + assert result is not None + assert cm._current_summary_cache.end_steps == len(actions) + + def test_C5_fresh_actions_trimmed_cache_uses_original_len(self): + """trim_actions_to_budget trimmed some actions, + but end_steps should still record original len(actions_to_compress), + ensuring next call cache covers same range. 
+ """ + cm = make_cm() + actions = self._make_actions(4) + + with patch.object(cm._renderer, 'actions_to_text', return_value="short text"): + model = make_model('{"task_overview": "trimmed summary"}') + result = _compress_current_with_cache(cm, TaskStep(task="t"), actions, model) + + assert result is not None + assert cm._current_summary_cache.end_steps == 4 + real_fp = action_fingerprint(actions[-1]) + assert cm._current_summary_cache.anchor_fingerprint == real_fp + + def test_C5_fresh_partial_trim_still_calls_llm_once(self): + """After trim still only call LLM once (no retry).""" + cm = make_cm() + actions = self._make_actions(3) + + with patch.object(cm._renderer, 'actions_to_text', return_value="short text"): + model = make_model('{"task_overview": "summary"}') + _compress_current_with_cache(cm, TaskStep(task="t"), actions, model) + + model.assert_called_once() + + def test_C6_fresh_llm_none_writes_none_to_cache(self): + """Current fresh path if LLM call fails, no cache. + Only truncation performed. + """ + cm = make_cm() + actions = self._make_actions(2) + + with patch.object(cm._llm, 'generate_summary', return_value=SummaryResult(summary_text=None, records=[])): + result = _compress_current_with_cache(cm, TaskStep(task="t"), actions, MagicMock()) + + assert "[CONTEXT COMPACTION" in result + assert cm._current_summary_cache is None + + def test_C6_vs_previous_asymmetry(self): + """Regression test: clarify asymmetry between previous and current behavior when LLM=None. 
+ previous fresh=None -> cache not written (preserve old value) + current fresh=None -> cache not written + """ + cm = make_cm() + pairs = [make_pair("task", "action", 0)] + actions = [ActionStep(step_number=0, model_output="out", action_output="r")] + + old_prev_cache = PreviousSummaryCache("old prev", 99, "bad") + cm._previous_summary_cache = old_prev_cache + + with patch.object(cm._prev_compressor, '_summarize_pairs', return_value=(None, False, [])): + _compress_previous_with_cache(cm, pairs, MagicMock()) + assert cm._previous_summary_cache is old_prev_cache + + with patch.object(cm._llm, 'generate_summary', return_value=SummaryResult(summary_text=None, records=[])): + _compress_current_with_cache(cm, TaskStep(task="t"), actions, MagicMock()) + assert cm._current_summary_cache is None diff --git a/test/sdk/core/agents/test_agent_context/unit/test_estimate_token.py b/test/sdk/core/agents/test_agent_context/unit/test_estimate_token.py index f767931fe..4d5f72fae 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_estimate_token.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_estimate_token.py @@ -1,59 +1,60 @@ -""" -unit/test_estimate_token.py -Verify ContextManager._estimate_tokens(memory) and -ContextManager._msg_token_count(flat_messages) result consistency. 
-""" - -import pytest - -from factories import make_cm, make_memory_with_steps, make_original_messages, make_pair -from loader import AgentMemory, PreviousSummaryCache -from stubs import _SystemPromptStep - - -class TestEstimateTokenConsistency: - """_estimate_tokens and _msg_token_count(flat) must return the same value.""" - - def test_msg_token_count_equal_estimate_token_for_memory(self): - cm = make_cm(enabled=True, threshold=10) - memory = make_memory_with_steps(3) - original = make_original_messages(memory) - assert cm._estimate_tokens(memory) == cm._msg_token_count(original) - - -class TestEffectiveTokens: - - def test_effective_prev_tokens_no_cache(self): - """No cache should equal raw estimation.""" - cm = make_cm() - t, a = make_pair("task", "action") - steps = [t, a] - raw = cm._estimate_tokens_for_steps(steps) - effective = cm._effective_prev_tokens(steps) - assert effective == raw - - def test_effective_prev_tokens_with_valid_cache_less_than_raw(self): - """Valid cache exists, effective tokens should be <= raw (summary shorter than full text).""" - cm = make_cm() - t, a = make_pair("X" * 200, "Y" * 200, 1) - pairs = [(t, a)] - fp = cm._pair_fingerprint(t.task, a.model_output) - cm._previous_summary_cache = PreviousSummaryCache("short summary", 1, fp) - steps = [t, a] - raw = cm._estimate_tokens_for_steps(steps) - effective = cm._effective_prev_tokens(steps) - assert effective < raw - - def test_effective_curr_tokens_empty(self): - cm = make_cm() - assert cm._effective_curr_tokens([]) == 0 - - def test_effective_tokens_sums_prev_and_curr(self): - cm = make_cm() - t1, a1 = make_pair("prev task", "prev action", 1) - t2, a2 = make_pair("curr task", "curr action", 2) - memory = AgentMemory(steps=[t1, a1, t2, a2]) - total = cm._effective_tokens(memory, current_run_start_idx=2) - prev = cm._effective_prev_tokens([t1, a1]) - curr = cm._effective_curr_tokens([t2, a2]) - assert total == prev + curr \ No newline at end of file +""" +unit/test_estimate_token.py 
+Verify estimate_tokens(memory) and msg_token_count(flat_messages) result consistency. +""" + +import pytest + +from factories import make_cm, make_memory_with_steps, make_original_messages, make_pair +from loader import AgentMemory, PreviousSummaryCache +from loader import estimate_tokens, estimate_tokens_for_steps, msg_token_count +from stubs import _SystemPromptStep + + +class TestEstimateTokenConsistency: + """estimate_tokens and msg_token_count(flat) must return the same value.""" + + def test_msg_token_count_equal_estimate_token_for_memory(self): + cm = make_cm(enabled=True, threshold=10) + memory = make_memory_with_steps(3) + original = make_original_messages(memory) + assert estimate_tokens(memory, cm.config.chars_per_token) == msg_token_count(original, cm.config.chars_per_token) + + +class TestEffectiveTokens: + + def test_effective_prev_tokens_no_cache(self): + """No cache should equal raw estimation.""" + cm = make_cm() + t, a = make_pair("task", "action") + steps = [t, a] + raw = estimate_tokens_for_steps(steps, cm.config.chars_per_token) + effective = cm._effective_prev_tokens(steps) + assert effective == raw + + def test_effective_prev_tokens_with_valid_cache_less_than_raw(self): + """Valid cache exists, effective tokens should be <= raw (summary shorter than full text).""" + cm = make_cm() + t, a = make_pair("X" * 200, "Y" * 200, 1) + pairs = [(t, a)] + from loader import pair_fingerprint + fp = pair_fingerprint(t.task, a.model_output) + cm._previous_summary_cache = PreviousSummaryCache("short summary", 1, fp) + steps = [t, a] + raw = estimate_tokens_for_steps(steps, cm.config.chars_per_token) + effective = cm._effective_prev_tokens(steps) + assert effective < raw + + def test_effective_curr_tokens_empty(self): + cm = make_cm() + assert cm._effective_curr_tokens([]) == 0 + + def test_effective_tokens_sums_prev_and_curr(self): + cm = make_cm() + t1, a1 = make_pair("prev task", "prev action", 1) + t2, a2 = make_pair("curr task", "curr action", 2) + 
memory = AgentMemory(steps=[t1, a1, t2, a2]) + total = cm._effective_tokens(memory, current_run_start_idx=2) + prev = cm._effective_prev_tokens([t1, a1]) + curr = cm._effective_curr_tokens([t2, a2]) + assert total == prev + curr diff --git a/test/sdk/core/agents/test_agent_context/unit/test_offload_store.py b/test/sdk/core/agents/test_agent_context/unit/test_offload_store.py new file mode 100644 index 000000000..2bb0a4bf8 --- /dev/null +++ b/test/sdk/core/agents/test_agent_context/unit/test_offload_store.py @@ -0,0 +1,397 @@ +""" +unit/test_offload_store.py +Tests for OffloadStore: store, reload, list_active, build_reload_inventory, +tokenize, score, eviction, and diagnostics. +""" + +import uuid +import pytest + +from loader import OffloadStore + + +# ────────────────────────────────────────────────────────────── +# store() +# ────────────────────────────────────────────────────────────── + +class TestStore: + + def test_store_returns_valid_handle_and_retrievable_content(self): + store = OffloadStore() + handle = store.store("hello world", description="short desc") + assert isinstance(handle, str) and len(handle) == 32 + uuid.UUID(hex=handle) + assert store.reload(handle) == "hello world" + assert store.list_active()[0][1] == "short desc" + + def test_store_max_entry_chars_boundary(self): + store = OffloadStore(max_entry_chars=10) + assert store.store("x" * 11) is None + assert store.store("x" * 10) is not None + + def test_multiple_stores_increment_count(self): + store = OffloadStore() + for _ in range(3): + store.store("x") + assert len(store) == 3 + + def test_store_empty_inputs(self): + store = OffloadStore() + handle = store.store("") + assert handle is not None and store.reload(handle) == "" + assert store.list_active()[0][1] == "" + + +# ────────────────────────────────────────────────────────────── +# reload() +# ────────────────────────────────────────────────────────────── + +class TestReload: + + def test_reload_valid_and_invalid_handles(self): + store 
= OffloadStore() + h = store.store("exact content") + assert store.reload(h) == "exact content" + assert store.reload("nonexistent") is None + + def test_reload_after_eviction_returns_none(self): + store = OffloadStore(max_entries=1) + h1 = store.store("first") + store.store("second") # evicts first + assert store.reload(h1) is None + + def test_reload_hits_and_misses_diagnostics(self): + store = OffloadStore() + h = store.store("data") + store.reload(h) + store.reload(h) + store.reload("bad") + assert store.reload_hits == 2 + assert store.reload_misses == 1 + + +# ────────────────────────────────────────────────────────────── +# list_active() +# ────────────────────────────────────────────────────────────── + +class TestListActive: + + def test_empty_store_returns_empty_list(self): + store = OffloadStore() + assert store.list_active() == [] + + def test_returns_handle_description_pairs(self): + store = OffloadStore() + h1 = store.store("content1", "desc1") + h2 = store.store("content2", "desc2") + active = store.list_active() + assert len(active) == 2 + assert (h1, "desc1") in active + assert (h2, "desc2") in active + + def test_evicted_entries_not_listed(self): + store = OffloadStore(max_entries=2) + store.store("a", "first") + store.store("b", "second") + store.store("c", "third") # evicts first + active = store.list_active() + assert len(active) == 2 + descriptions = [d for _, d in active] + assert "first" not in descriptions + assert "second" in descriptions + assert "third" in descriptions + + +# ────────────────────────────────────────────────────────────── +# Eviction +# ────────────────────────────────────────────────────────────── + +class TestEviction: + + def test_count_based_eviction_fifo(self): + store = OffloadStore(max_entries=3) + handles = [store.store(f"item{i}") for i in range(5)] + assert len(store) == 3 + # First two should be evicted + assert store.reload(handles[0]) is None + assert store.reload(handles[1]) is None + assert 
store.reload(handles[2]) == "item2" + assert store.reload(handles[3]) == "item3" + assert store.reload(handles[4]) == "item4" + + def test_size_based_eviction(self): + store = OffloadStore(max_total_chars=20) + store.store("A" * 12) # 12 chars + h2 = store.store("B" * 12) # would be 24 total, evicts first + assert len(store) == 1 + assert store.reload(h2) == "B" * 12 + + def test_size_eviction_evicts_multiple_if_needed(self): + store = OffloadStore(max_total_chars=30) + store.store("A" * 10) # 10 + store.store("B" * 10) # 20 total + store.store("C" * 25) # 25 would make 45, evict both A and B + assert len(store) == 1 + + def test_store_rejects_oversized_content_only(self): + """Content over max_entry_chars is rejected; smaller ones still accepted.""" + store = OffloadStore(max_entry_chars=20) + h1 = store.store("ok size") + h2 = store.store("x" * 25) + assert h1 is not None + assert h2 is None + assert len(store) == 1 + + def test_clear_removes_all_entries(self): + store = OffloadStore() + store.store("a") + store.store("b") + store.clear() + assert len(store) == 0 + assert store.list_active() == [] + + def test_clear_resets_total_chars(self): + store = OffloadStore(max_total_chars=10) + store.store("1234567890") # 10 chars fills it up + store.clear() + # After clear, should be able to store another 10 chars + h = store.store("abcdefghij") + assert h is not None + + +# ────────────────────────────────────────────────────────────── +# items() and __len__ +# ────────────────────────────────────────────────────────────── + +class TestItemsAndLen: + + def test_len_reflects_entry_count(self): + store = OffloadStore() + assert len(store) == 0 + store.store("a") + assert len(store) == 1 + store.store("b") + assert len(store) == 2 + + def test_items_returns_handle_content_pairs(self): + store = OffloadStore() + h1 = store.store("content_a") + h2 = store.store("content_b") + items = store.items() + assert len(items) == 2 + contents = [c for _, c in items] + assert 
"content_a" in contents + assert "content_b" in contents + + def test_items_is_snapshot(self): + store = OffloadStore() + store.store("a") + items = store.items() + store.store("b") + assert len(items) == 1 # snapshot is frozen + + def test_items_empty_store(self): + store = OffloadStore() + assert store.items() == [] + + +# ────────────────────────────────────────────────────────────── +# build_reload_inventory() +# ────────────────────────────────────────────────────────────── + +class TestBuildReloadInventory: + + def test_disabled_returns_none(self): + store = OffloadStore() + store.store("content", "desc") + assert store.build_reload_inventory(enable_reload=False) is None + + def test_empty_store_returns_none(self): + store = OffloadStore() + assert store.build_reload_inventory(enable_reload=True) is None + + def test_no_query_returns_recent_entries(self): + store = OffloadStore() + store.store("old", "oldest") + store.store("mid", "middle") + store.store("new", "newest") + result = store.build_reload_inventory(enable_reload=True, max_items=2) + assert result is not None + assert "newest" in result + assert "middle" in result + assert "oldest" not in result + + def test_header_text_present(self): + store = OffloadStore() + store.store("data", "test_desc") + result = store.build_reload_inventory(enable_reload=True) + assert "[System Notice" in result + assert "handle=" in result + assert "test_desc" in result + + def test_max_items_caps_output(self): + store = OffloadStore() + for i in range(10): + store.store(f"content{i}", f"desc{i}") + result = store.build_reload_inventory(enable_reload=True, max_items=3) + lines = [l for l in result.split("\n") if l.startswith("- handle=")] + assert len(lines) == 3 + + def test_with_query_scores_and_ranks(self): + store = OffloadStore() + store.store("irrelevant", "nothing here") + store.store("target", "important database query result") + store.store("other", "some other text") + result = store.build_reload_inventory( + 
enable_reload=True, query="database result", max_items=2 + ) + assert result is not None + # "database" should match in the description + assert "important database query result" in result + + def test_with_query_no_matches_falls_back_to_recency(self): + store = OffloadStore() + store.store("a", "first entry") + store.store("b", "second entry") + store.store("c", "third entry") + result = store.build_reload_inventory( + enable_reload=True, query="zzz_nonexistent_xyz", max_items=2 + ) + assert result is not None + # Falls back to recency (tail) + assert "third entry" in result + assert "second entry" in result + + def test_query_with_unicode(self): + store = OffloadStore() + store.store("数据", "数据库查询结果") + store.store("other", "unrelated") + result = store.build_reload_inventory( + enable_reload=True, query="数据库", max_items=1 + ) + assert result is not None + assert "数据库查询结果" in result + + +# ────────────────────────────────────────────────────────────── +# _tokenize (static method) +# ────────────────────────────────────────────────────────────── + +class TestTokenize: + + def test_latin_tokenization(self): + tokens = OffloadStore._tokenize("Hello, WORLD! test.") + assert tokens >= {"hello", "world", "test"} + assert "," not in tokens and "!" 
not in tokens + + def test_filters_stop_words_short_tokens_and_digits(self): + tokens = OffloadStore._tokenize("the a is 42 score cd points 123") + assert tokens & {"the", "a", "is", "42", "123"} == set() + assert tokens >= {"score", "points"} + assert "cd" in tokens + + def test_cjk_tokenization(self): + # multi-char → bigrams + tokens = OffloadStore._tokenize("数据库查询") + assert tokens >= {"数据", "据库", "库查", "查询"} + # single CJK char → no bigrams produced + assert OffloadStore._tokenize("数") == set() + + def test_mixed_cjk_and_latin(self): + tokens = OffloadStore._tokenize("hello 数据库 world") + assert tokens >= {"hello", "world", "数据", "据库"} + + def test_tokenize_empty_string(self): + assert OffloadStore._tokenize("") == set() + + +# ────────────────────────────────────────────────────────────── +# _score_description (private, tested via public API + direct) +# ────────────────────────────────────────────────────────────── + +class TestScoreDescription: + + def test_exact_match_score(self): + store = OffloadStore() + desc_tokens = store._tokenize("database query result") + query_tokens = store._tokenize("database query") + score = store._score_description(desc_tokens, query_tokens) + assert score > 0 + # exact matches: database, query → overlap=2, 2^2/min(3,8)=4/3≈1.33 + assert score == pytest.approx(4.0 / 3.0, abs=0.01) + + def test_no_match_returns_zero(self): + store = OffloadStore() + desc_tokens = store._tokenize("hello world") + query_tokens = store._tokenize("xyzzy") + score = store._score_description(desc_tokens, query_tokens) + assert score == 0.0 + + def test_empty_desc_tokens_returns_zero(self): + store = OffloadStore() + score = store._score_description(set(), {"hello"}) + assert score == 0.0 + + def test_partial_substring_match(self): + store = OffloadStore() + desc_tokens = store._tokenize("download") + query_tokens = store._tokenize("down") + score = store._score_description(desc_tokens, query_tokens) + # "down" in "download" → 0.5, squared/min(1,8) = 
0.25 + assert score == pytest.approx(0.25, abs=0.01) + + def test_multiple_matches_amplified(self): + store = OffloadStore() + desc_tokens = store._tokenize("database sql query result") + query_tokens = store._tokenize("database query sql") + score = store._score_description(desc_tokens, query_tokens) + # 3 exact matches → overlap=3, 3^2/min(4,8)=9/4=2.25 + assert score == pytest.approx(9.0 / 4.0, abs=0.01) + + def test_cjk_score(self): + store = OffloadStore() + desc_tokens = store._tokenize("数据库查询优化") + query_tokens = store._tokenize("数据库") + score = store._score_description(desc_tokens, query_tokens) + # desc tokens: 7 (6 CJK bigrams + full word "数据库查询优化") + # query tokens: 3 (2 CJK bigrams + full word "数据库") + # Exact: "数据" + "据库" → overlap=2.0 + # Partial: "数据库" in "数据库查询优化" → +0.5 + # score = 2.5^2 / min(7,8) = 6.25/7 ≈ 0.8929 + assert score == pytest.approx(6.25 / 7.0, abs=0.01) + + +# ────────────────────────────────────────────────────────────── +# Custom constructor parameters +# ────────────────────────────────────────────────────────────── + +class TestCustomConfig: + + @pytest.mark.parametrize("kwargs, expected", [ + ({}, (200, 2_000_000, 30000)), + ({"max_entries": 50, "max_total_chars": 10000, "max_entry_chars": 500}, + (50, 10000, 500)), + ]) + def test_config_defaults_and_custom(self, kwargs, expected): + store = OffloadStore(**kwargs) + assert store._max_entries == expected[0] + assert store._max_total_chars == expected[1] + assert store._max_entry_chars == expected[2] + + +# ────────────────────────────────────────────────────────────── +# Integration with ContextManager +# ────────────────────────────────────────────────────────────── + +class TestContextManagerOffloadStore: + + def test_cm_offload_store_integration(self): + from factories import make_cm + cm = make_cm() + assert isinstance(cm.offload_store, OffloadStore) + # Singleton property + assert cm.offload_store is cm.offload_store + # Functional end-to-end + handle = 
cm.offload_store.store("cm content", "cm desc") + assert handle and cm.offload_store.reload(handle) == "cm content" + assert cm.offload_store.reload_hits == 1 diff --git a/test/sdk/core/agents/test_agent_context/unit/test_pure_functions.py b/test/sdk/core/agents/test_agent_context/unit/test_pure_functions.py index eef2f8194..7f9ea8e2d 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_pure_functions.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_pure_functions.py @@ -1,104 +1,98 @@ -import json -import pytest - -from factories import make_cm, make_memory_with_steps, make_original_messages, make_pair -from loader import ContextManager, SummaryTaskStep, TaskStep, ActionStep - - -class TestPureFunctions: - - def test_format_summary_valid_json(self): - cm = make_cm() - raw = '{"task_overview": "did something", "completed_work": "completed"}' - result = cm._format_summary(raw) - parsed = json.loads(result) - assert parsed["task_overview"] == "did something" - - def test_format_summary_strips_markdown_fence(self): - cm = make_cm() - raw = '```json\n{"task_overview": "x"}\n```' - result = cm._format_summary(raw) - assert result is not None - assert "```" not in result - - def test_format_summary_invalid_json_returns_plain_text(self): - cm = make_cm() - raw = "This is not JSON format text content" - result = cm._format_summary(raw) - assert result == raw - - def test_format_summary_empty_string_returns_none(self): - cm = make_cm() - assert cm._format_summary("") is None - assert cm._format_summary(" ") is None - - def test_extract_pairs_basic(self): - cm = make_cm() - t1, a1 = make_pair("task1", "result1", 1) - t2, a2 = make_pair("task2", "result2", 2) - steps = [t1, a1, t2, a2] - pairs = cm._extract_pairs(steps) - assert len(pairs) == 2 - assert pairs[0][0].task == "task1" - assert pairs[1][0].task == "task2" - - def test_extract_pairs_skips_summary_task_step(self): - cm = make_cm() - summary = SummaryTaskStep(task="existing summary") - t1, a1 = 
make_pair("task1", "result1", 1) - steps = [summary, t1, a1] - pairs = cm._extract_pairs(steps) - assert len(pairs) == 1 - assert pairs[0][0].task == "task1" - - def test_extract_pairs_ignores_orphan_task(self): - """A TaskStep without following ActionStep should not form a pair.""" - cm = make_cm() - t1, a1 = make_pair("task1", "result1", 1) - t_orphan = TaskStep(task="orphan task") - steps = [t1, a1, t_orphan] - pairs = cm._extract_pairs(steps) - assert len(pairs) == 1 - - def test_extract_pairs_empty_steps(self): - cm = make_cm() - assert cm._extract_pairs([]) == [] - - def test_pair_fingerprint_is_deterministic(self): - cm = make_cm() - fp1 = cm._pair_fingerprint("task content", "action content") - fp2 = cm._pair_fingerprint("task content", "action content") - assert fp1 == fp2 - - def test_pair_fingerprint_differs_on_content_change(self): - cm = make_cm() - fp1 = cm._pair_fingerprint("task A", "action A") - fp2 = cm._pair_fingerprint("task A", "action B") - assert fp1 != fp2 - - def test_action_fingerprint_is_deterministic(self): - a = ActionStep(step_number=3, model_output="output", action_output="result") - fp1 = ContextManager._action_fingerprint(a) - fp2 = ContextManager._action_fingerprint(a) - assert fp1 == fp2 - - def test_action_fingerprint_differs_on_output_change(self): - a1 = ActionStep(step_number=1, model_output="output A", action_output="result A") - a2 = ActionStep(step_number=1, model_output="output A", action_output="result B") - assert ContextManager._action_fingerprint(a1) != ContextManager._action_fingerprint(a2) - - def test_pairs_to_text_format(self): - cm = make_cm() - t, a = make_pair("user question", "model response", 1) - text = cm._pairs_to_text([(t, a)]) - assert "user question" in text - assert "model response" in text - assert "user:" in text - assert "assistant:" in text - - def test_pairs_to_text_multiple_pairs_joined_by_blank_line(self): - cm = make_cm() - pair1 = make_pair("question1", "answer1", 1) - pair2 = 
make_pair("question2", "answer2", 2) - text = cm._pairs_to_text([pair1, pair2]) - assert "\n\n" in text \ No newline at end of file +import json +import pytest + +from factories import make_cm, make_memory_with_steps, make_original_messages, make_pair +from loader import ( + ContextManager, SummaryTaskStep, TaskStep, ActionStep, + extract_pairs, pair_fingerprint, action_fingerprint, + format_summary_output, +) + + +class TestPureFunctions: + + def test_format_summary_valid_json(self): + raw = '{"task_overview": "did something", "completed_work": "completed"}' + result = format_summary_output(raw) + parsed = json.loads(result) + assert parsed["task_overview"] == "did something" + + def test_format_summary_strips_markdown_fence(self): + raw = '```json\n{"task_overview": "x"}\n```' + result = format_summary_output(raw) + assert result is not None + assert "```" not in result + + def test_format_summary_invalid_json_returns_plain_text(self): + raw = "This is not JSON format text content" + result = format_summary_output(raw) + assert result == raw + + def test_format_summary_empty_string_returns_none(self): + assert format_summary_output("") is None + assert format_summary_output(" ") is None + + def test_extract_pairs_basic(self): + t1, a1 = make_pair("task1", "result1", 1) + t2, a2 = make_pair("task2", "result2", 2) + steps = [t1, a1, t2, a2] + pairs = extract_pairs(steps) + assert len(pairs) == 2 + assert pairs[0][0].task == "task1" + assert pairs[1][0].task == "task2" + + def test_extract_pairs_skips_summary_task_step(self): + summary = SummaryTaskStep(task="existing summary") + t1, a1 = make_pair("task1", "result1", 1) + steps = [summary, t1, a1] + pairs = extract_pairs(steps) + assert len(pairs) == 1 + assert pairs[0][0].task == "task1" + + def test_extract_pairs_ignores_orphan_task(self): + """A TaskStep without following ActionStep should not form a pair.""" + t1, a1 = make_pair("task1", "result1", 1) + t_orphan = TaskStep(task="orphan task") + steps = [t1, a1, 
t_orphan] + pairs = extract_pairs(steps) + assert len(pairs) == 1 + + def test_extract_pairs_empty_steps(self): + assert extract_pairs([]) == [] + + def test_pair_fingerprint_is_deterministic(self): + fp1 = pair_fingerprint("task content", "action content") + fp2 = pair_fingerprint("task content", "action content") + assert fp1 == fp2 + + def test_pair_fingerprint_differs_on_content_change(self): + fp1 = pair_fingerprint("task A", "action A") + fp2 = pair_fingerprint("task A", "action B") + assert fp1 != fp2 + + def test_action_fingerprint_is_deterministic(self): + a = ActionStep(step_number=3, model_output="output", action_output="result") + fp1 = action_fingerprint(a) + fp2 = action_fingerprint(a) + assert fp1 == fp2 + + def test_action_fingerprint_differs_on_output_change(self): + a1 = ActionStep(step_number=1, model_output="output A", action_output="result A") + a2 = ActionStep(step_number=1, model_output="output A", action_output="result B") + assert action_fingerprint(a1) != action_fingerprint(a2) + + def test_pairs_to_text_format(self): + cm = make_cm() + t, a = make_pair("user question", "model response", 1) + text = cm._renderer.pairs_to_text([(t, a)]) + assert "user question" in text + assert "model response" in text + assert "user:" in text + assert "assistant:" in text + + def test_pairs_to_text_multiple_pairs_joined_by_blank_line(self): + cm = make_cm() + pair1 = make_pair("question1", "answer1", 1) + pair2 = make_pair("question2", "answer2", 2) + text = cm._renderer.pairs_to_text([pair1, pair2]) + assert "\n\n" in text diff --git a/test/sdk/core/agents/test_agent_context/unit/test_stats_export.py b/test/sdk/core/agents/test_agent_context/unit/test_stats_export.py new file mode 100644 index 000000000..98b897139 --- /dev/null +++ b/test/sdk/core/agents/test_agent_context/unit/test_stats_export.py @@ -0,0 +1,313 @@ +""" +unit/test_stats_export.py +Tests for stats_export.py pure functions and their ContextManager wrappers. 
+""" + +from factories import make_cm +from loader import ( + CompressionCallRecord, + ContextManagerConfig, + CurrentSummaryCache, + PreviousSummaryCache, + export_summary_fn, + get_all_compression_stats, + get_step_compression_stats, + get_token_counts, +) + + +# ────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────── + +def _make_record(call_type="summary", input_tokens=100, output_tokens=50, + input_chars=400, output_chars=200, cache_hit=False): + return CompressionCallRecord( + call_type=call_type, + input_tokens=input_tokens, + output_tokens=output_tokens, + input_chars=input_chars, + output_chars=output_chars, + cache_hit=cache_hit, + ) + + +# ────────────────────────────────────────────────────────────── +# get_step_compression_stats +# ────────────────────────────────────────────────────────────── + +class TestGetStepCompressionStats: + + def test_empty_log_returns_defaults(self): + result = get_step_compression_stats([]) + assert result == { + "calls": 0, "input_tokens": 0, "output_tokens": 0, + "cache_hits": 0, "cache_types": [], + } + + def test_single_real_call(self): + log = [_make_record(call_type="previous_summary")] + result = get_step_compression_stats(log) + assert result["calls"] == 1 + assert result["input_tokens"] == 100 + assert result["output_tokens"] == 50 + assert result["input_chars"] == 400 + assert result["output_chars"] == 200 + assert result["cache_hits"] == 0 + assert result["cache_types"] == [] + + def test_single_cache_hit(self): + log = [_make_record(call_type="previous_cache_hit", cache_hit=True, + input_tokens=0, output_tokens=0)] + result = get_step_compression_stats(log) + assert result["calls"] == 0 + assert result["input_tokens"] == 0 + assert result["output_tokens"] == 0 + assert result["cache_hits"] == 1 + assert result["cache_types"] == ["previous_cache_hit"] + + def test_mixed_real_and_cache_hits(self): + log = [ + 
_make_record(call_type="previous_summary", input_tokens=200, output_tokens=80), + _make_record(call_type="previous_cache_hit", cache_hit=True, + input_tokens=0, output_tokens=0), + _make_record(call_type="current_summary", input_tokens=150, output_tokens=60), + _make_record(call_type="current_cache_hit", cache_hit=True, + input_tokens=0, output_tokens=0), + ] + result = get_step_compression_stats(log) + assert result["calls"] == 2 + assert result["input_tokens"] == 350 # 200 + 150 + assert result["output_tokens"] == 140 # 80 + 60 + assert result["cache_hits"] == 2 + assert result["cache_types"] == ["previous_cache_hit", "current_cache_hit"] + + def test_input_output_chars_summed(self): + log = [ + _make_record(input_chars=1000, output_chars=500), + _make_record(input_chars=2000, output_chars=800), + ] + result = get_step_compression_stats(log) + assert result["input_chars"] == 3000 + assert result["output_chars"] == 1300 + + def test_cache_types_only_includes_hits(self): + log = [ + _make_record(call_type="previous_summary", cache_hit=False), + _make_record(call_type="current_summary", cache_hit=False), + ] + result = get_step_compression_stats(log) + assert result["cache_types"] == [] + + +# ────────────────────────────────────────────────────────────── +# get_all_compression_stats +# ────────────────────────────────────────────────────────────── + +class TestGetAllCompressionStats: + + def test_empty_log_returns_zeros(self): + result = get_all_compression_stats([]) + assert result == { + "total_calls": 0, "total_attempts": 0, + "total_input_tokens": 0, "total_output_tokens": 0, + "total_cache_hits": 0, + } + + def test_only_cache_hits(self): + log = [ + _make_record(call_type="previous_cache_hit", cache_hit=True, + input_tokens=0, output_tokens=0), + _make_record(call_type="current_cache_hit", cache_hit=True, + input_tokens=0, output_tokens=0), + ] + result = get_all_compression_stats(log) + assert result["total_calls"] == 0 + assert result["total_attempts"] == 2 
+ assert result["total_input_tokens"] == 0 + assert result["total_output_tokens"] == 0 + assert result["total_cache_hits"] == 2 + + def test_only_real_calls(self): + log = [ + _make_record(input_tokens=300, output_tokens=100), + _make_record(input_tokens=200, output_tokens=80), + ] + result = get_all_compression_stats(log) + assert result["total_calls"] == 2 + assert result["total_attempts"] == 2 + assert result["total_input_tokens"] == 500 + assert result["total_output_tokens"] == 180 + assert result["total_cache_hits"] == 0 + + def test_mixed_real_and_cache(self): + log = [ + _make_record(call_type="previous_summary", input_tokens=300, output_tokens=100), + _make_record(call_type="previous_cache_hit", cache_hit=True, + input_tokens=0, output_tokens=0), + _make_record(call_type="current_summary", input_tokens=200, output_tokens=80), + _make_record(call_type="current_cache_hit", cache_hit=True, + input_tokens=0, output_tokens=0), + ] + result = get_all_compression_stats(log) + assert result["total_calls"] == 2 + assert result["total_attempts"] == 4 + assert result["total_input_tokens"] == 500 + assert result["total_output_tokens"] == 180 + assert result["total_cache_hits"] == 2 + + +# ────────────────────────────────────────────────────────────── +# export_summary +# ────────────────────────────────────────────────────────────── + +class TestExportSummary: + + def test_both_caches_none(self): + config = ContextManagerConfig() + result = export_summary_fn(None, None, config) + assert result["previous_summary"] is None + assert result["current_summary"] is None + assert result["previous_cache_info"] is None + assert result["current_cache_info"] is None + assert result["compression_boundary"]["previous_compressed_pairs"] == 0 + assert result["compression_boundary"]["current_compressed_steps"] == 0 + + def test_both_caches_present(self): + config = ContextManagerConfig(keep_recent_pairs=2, keep_recent_steps=3) + prev = PreviousSummaryCache("prev summary text", 5, 
"fp_prev") + curr = CurrentSummaryCache("curr summary text", 4, "fp_curr") + result = export_summary_fn(prev, curr, config) + + assert result["previous_summary"] == "prev summary text" + assert result["current_summary"] == "curr summary text" + + assert result["previous_cache_info"]["covered_pairs"] == 5 + assert result["previous_cache_info"]["is_fallback"] is False + + assert result["current_cache_info"]["end_steps"] == 4 + assert result["current_cache_info"]["is_fallback"] is False + + assert result["compression_boundary"]["config_keep_recent_pairs"] == 2 + assert result["compression_boundary"]["config_keep_recent_steps"] == 3 + assert result["compression_boundary"]["previous_compressed_pairs"] == 5 + assert result["compression_boundary"]["previous_retained_pairs"] == 2 + assert result["compression_boundary"]["current_compressed_steps"] == 4 + assert result["compression_boundary"]["current_retained_steps"] == 3 + + def test_previous_only_current_none(self): + config = ContextManagerConfig() + prev = PreviousSummaryCache("prev only", 3, "fp") + result = export_summary_fn(prev, None, config) + + assert result["previous_summary"] == "prev only" + assert result["current_summary"] is None + assert result["previous_cache_info"] is not None + assert result["current_cache_info"] is None + assert result["compression_boundary"]["previous_compressed_pairs"] == 3 + assert result["compression_boundary"]["current_compressed_steps"] == 0 + + def test_current_only_previous_none(self): + config = ContextManagerConfig() + curr = CurrentSummaryCache("curr only", 2, "fp") + result = export_summary_fn(None, curr, config) + + assert result["previous_summary"] is None + assert result["current_summary"] == "curr only" + assert result["previous_cache_info"] is None + assert result["current_cache_info"] is not None + assert result["compression_boundary"]["previous_compressed_pairs"] == 0 + assert result["compression_boundary"]["current_compressed_steps"] == 2 + + def 
test_fallback_detection_previous(self): + """summary_text containing '[CONTEXT COMPACTION' marks is_fallback=True.""" + config = ContextManagerConfig() + prev = PreviousSummaryCache("[CONTEXT COMPACTION... truncated]", 1, "fp") + result = export_summary_fn(prev, None, config) + assert result["previous_cache_info"]["is_fallback"] is True + + def test_fallback_detection_current(self): + config = ContextManagerConfig() + curr = CurrentSummaryCache("[CONTEXT COMPACTION fallback text]", 1, "fp") + result = export_summary_fn(None, curr, config) + assert result["current_cache_info"]["is_fallback"] is True + + def test_boundary_uses_zero_when_cache_none(self): + config = ContextManagerConfig(keep_recent_pairs=10, keep_recent_steps=20) + result = export_summary_fn(None, None, config) + assert result["compression_boundary"]["previous_compressed_pairs"] == 0 + assert result["compression_boundary"]["current_compressed_steps"] == 0 + assert result["compression_boundary"]["previous_retained_pairs"] == 10 + assert result["compression_boundary"]["current_retained_steps"] == 20 + + +# ────────────────────────────────────────────────────────────── +# get_token_counts +# ────────────────────────────────────────────────────────────── + +class TestGetTokenCounts: + + def test_both_values_present(self): + result = get_token_counts(5000, 2000) + assert result == {"last_uncompressed": 5000, "last_compressed": 2000} + + def test_both_none(self): + result = get_token_counts(None, None) + assert result == {"last_uncompressed": None, "last_compressed": None} + + def test_only_uncompressed(self): + result = get_token_counts(3000, None) + assert result == {"last_uncompressed": 3000, "last_compressed": None} + + def test_only_compressed(self): + result = get_token_counts(None, 500) + assert result == {"last_uncompressed": None, "last_compressed": 500} + + def test_zero_values(self): + result = get_token_counts(0, 0) + assert result == {"last_uncompressed": 0, "last_compressed": 0} + + +# 
────────────────────────────────────────────────────────────── +# ContextManager wrapper methods +# ────────────────────────────────────────────────────────────── + +class TestContextManagerStatsMethods: + + def test_get_step_compression_stats_delegates(self): + """ContextManager.get_step_compression_stats() delegates to the pure function.""" + cm = make_cm() + # After a compress call, _step_local_log should have records + assert isinstance(cm.get_step_compression_stats(), dict) + assert cm.get_step_compression_stats()["calls"] == 0 + + def test_get_all_compression_stats_delegates(self): + cm = make_cm() + assert isinstance(cm.get_all_compression_stats(), dict) + assert cm.get_all_compression_stats()["total_calls"] == 0 + + def test_export_summary_delegates(self): + cm = make_cm() + result = cm.export_summary() + assert isinstance(result, dict) + assert "previous_summary" in result + assert "current_summary" in result + assert "previous_cache_info" in result + assert "compression_boundary" in result + + def test_get_token_counts_delegates(self): + cm = make_cm() + # Initial state: both are None + result = cm.get_token_counts() + assert result == {"last_uncompressed": None, "last_compressed": None} + + def test_export_summary_reflects_cache_state(self): + """After setting caches, export_summary should reflect them.""" + cm = make_cm() + cm._previous_summary_cache = PreviousSummaryCache("test prev", 3, "fp1") + cm._current_summary_cache = CurrentSummaryCache("test curr", 2, "fp2") + result = cm.export_summary() + assert result["previous_summary"] == "test prev" + assert result["current_summary"] == "test curr" + assert result["previous_cache_info"]["covered_pairs"] == 3 + assert result["current_cache_info"]["end_steps"] == 2 diff --git a/test/sdk/core/agents/test_agent_context/unit/test_step_renderer.py b/test/sdk/core/agents/test_agent_context/unit/test_step_renderer.py new file mode 100644 index 000000000..8880ab8d7 --- /dev/null +++ 
b/test/sdk/core/agents/test_agent_context/unit/test_step_renderer.py @@ -0,0 +1,369 @@ +""" +unit/test_step_renderer.py +Tests for StepRenderer methods not covered by other test files: + - render_action_step / _render_segment + - truncate_text_to_tokens + - pairs_to_steps + - render_steps_with_truncation + - compress_history_offline (standalone function) +""" + +import json + +import pytest +from unittest.mock import MagicMock, patch + +from factories import make_cm, make_pair, make_model +from loader import ( + ActionStep, + ContextManagerConfig, + OffloadStore, + StepRenderer, + compress_history_offline as cho, + estimate_tokens_text, +) +from stubs import _SystemPromptStep as SystemPromptStep + + +# ────────────────────────────────────────────────────────────── +# render_action_step +# ────────────────────────────────────────────────────────────── + +class TestRenderActionStep: + + def test_renders_model_output(self): + cm = make_cm() + action = ActionStep(step_number=1, model_output="hello world") + text = cm._renderer.render_action_step(action) + assert "hello world" in text + + def test_renders_observation(self): + """Stub ActionStep puts model_output in to_messages(); + 'Observation:' prefix triggers observation-specific rendering.""" + cm = make_cm() + action = ActionStep(step_number=1, + model_output="Observation:\nobserved result text here") + text = cm._renderer.render_action_step(action) + assert "observed result text here" in text + + def test_tool_call_kept_verbatim(self): + cm = make_cm() + action = ActionStep(step_number=1, model_output="Calling tools: tool1, tool2", + tool_calls=[{"name": "tool1"}]) + text = cm._renderer.render_action_step(action) + assert "Calling tools:" in text + + def test_no_offload_when_limit_zero(self): + """per_step_render_limit=0 disables offload — full text kept.""" + cm = make_cm() + cm._renderer._config.per_step_render_limit = 0 + action = ActionStep(step_number=1, model_output="x" * 5000) + text = 
cm._renderer.render_action_step(action) + assert "x" * 5000 in text + assert "OBS_OFFLOAD" not in text + assert "CONTENT_OFFLOAD" not in text + + def test_no_offload_when_offload_store_none(self): + cm = make_cm() + cm._renderer._config.per_step_render_limit = 100 + action = ActionStep(step_number=1, model_output="x" * 200) + text = cm._renderer.render_action_step(action, offload_store=None) + assert "x" * 200 in text + assert "CONTENT_OFFLOAD" not in text + + def test_offload_triggered_when_over_limit(self): + cm = make_cm() + cm._renderer._config.per_step_render_limit = 20 + store = OffloadStore() + action = ActionStep(step_number=1, model_output="abcdefghijklmnopqrstuvwxyz" * 3) + text = cm._renderer.render_action_step(action, offload_store=store) + assert "CONTENT_OFFLOAD" in text + assert "handle=" in text + assert len(store) >= 1 + + def test_offload_observation_uses_obs_offload_marker(self): + """Model output starting with 'Observation:' gets OBS_OFFLOAD marker.""" + cm = make_cm() + cm._renderer._config.per_step_render_limit = 10 + store = OffloadStore() + action = ActionStep(step_number=1, + model_output="Observation:\n" + "x" * 200) + text = cm._renderer.render_action_step(action, offload_store=store) + assert "OBS_OFFLOAD" in text + + def test_raw_observation_used_for_offload(self): + """When _raw_observation exists, offload archives the raw content.""" + cm = make_cm() + cm._renderer._config.per_step_render_limit = 10 + store = OffloadStore() + action = ActionStep(step_number=1, + model_output="Observation: short") + action._raw_observation = "Observation: " + "R" * 500 + text = cm._renderer.render_action_step(action, offload_store=store) + assert "OBS_OFFLOAD" in text + handles = [h for h, _ in store.list_active()] + assert len(handles) == 1 + reloaded = store.reload(handles[0]) + assert "R" * 500 in reloaded + + def test_reloaded_content_skips_re_offload(self): + """Content containing 'offload_handle' near the start is not re-offloaded.""" + cm = 
make_cm() + cm._renderer._config.per_step_render_limit = 10 + store = OffloadStore() + # "offload_handle" inside first 300 chars of model_output triggers skip + action = ActionStep(step_number=1, + model_output='{"offload_handle": "abc123", "data": "' + "y" * 500 + '"}') + text = cm._renderer.render_action_step(action, offload_store=store) + assert "OBS_OFFLOAD" not in text + assert len(store) == 0 + + def test_content_too_large_for_offload_store_falls_back_to_truncation(self): + cm = make_cm() + cm._renderer._config.per_step_render_limit = 10 + store = OffloadStore(max_entry_chars=20) + action = ActionStep(step_number=1, model_output="x" * 500) + text = cm._renderer.render_action_step(action, offload_store=store) + assert "CONTENT_TOO_LARGE_TO_OFFLOAD" in text + + +# ────────────────────────────────────────────────────────────── +# truncate_text_to_tokens +# ────────────────────────────────────────────────────────────── + +class TestTruncateTextToTokens: + + @pytest.mark.parametrize("max_tokens", [0, -1]) + def test_zero_or_negative_returns_empty(self, max_tokens): + cm = make_cm() + assert cm._renderer.truncate_text_to_tokens("hello world", max_tokens) == "" + + def test_within_budget_returns_unchanged(self): + cm = make_cm() + text = "short text" + assert cm._renderer.truncate_text_to_tokens(text, 99999) == text + + def test_over_budget_truncates_keeping_newest(self): + cm = make_cm() + paragraphs = [f"paragraph {i}: " + "X" * 100 for i in range(20)] + text = "\n\n".join(paragraphs) + result = cm._renderer.truncate_text_to_tokens(text, max_tokens=10) + assert len(result) < len(text) + assert "Earlier content truncated" in result + + def test_empty_string_returns_empty(self): + cm = make_cm() + assert cm._renderer.truncate_text_to_tokens("", 100) == "" + + def test_very_small_budget_uses_char_fallback(self): + cm = make_cm() + text = "A" * 5000 + result = cm._renderer.truncate_text_to_tokens(text, max_tokens=1) + assert len(result) < 5000 + assert "Earlier 
content truncated" in result + + +# ────────────────────────────────────────────────────────────── +# pairs_to_steps +# ────────────────────────────────────────────────────────────── + +class TestPairsToSteps: + + def test_converts_pairs_to_flat_list(self): + cm = make_cm() + t1, a1 = make_pair("task1", "action1", 1) + t2, a2 = make_pair("task2", "action2", 2) + assert cm._renderer.pairs_to_steps([(t1, a1), (t2, a2)]) == [t1, a1, t2, a2] + + def test_empty_and_single_pair_edge_cases(self): + cm = make_cm() + assert cm._renderer.pairs_to_steps([]) == [] + t, a = make_pair("only", "only", 1) + assert cm._renderer.pairs_to_steps([(t, a)]) == [t, a] + + +# ────────────────────────────────────────────────────────────── +# render_steps_with_truncation +# ────────────────────────────────────────────────────────────── + +class TestRenderStepsWithTruncation: + + def test_within_budget_returns_full_text(self): + cm = make_cm() + t, a = make_pair("short task", "short action", 1) + pairs = [(t, a)] + text = cm._renderer.render_steps_with_truncation( + pairs, fmt="pairs", max_tokens=99999 + ) + assert "short task" in text + assert "short action" in text + + def test_action_fmt_within_budget(self): + cm = make_cm() + actions = [ActionStep(step_number=1, model_output="hello")] + text = cm._renderer.render_steps_with_truncation( + actions, fmt="action", max_tokens=99999 + ) + assert "hello" in text + + def test_over_budget_truncates_with_fallback(self): + cm = make_cm() + actions = [ + ActionStep(step_number=i, model_output="X" * 500, action_output="Y" * 500) + for i in range(10) + ] + text = cm._renderer.render_steps_with_truncation( + actions, fmt="action", max_tokens=1, + min_budget_chars=20, task_budget_chars=30, action_budget_chars=40, + ) + assert len(text) < 500 * 10 + assert len(text) > 0 + + def test_default_max_tokens_from_config(self): + cm = make_cm() + cm._renderer._config.max_summary_input_tokens = 1 + actions = [ + ActionStep(step_number=i, model_output="X" * 200, 
action_output="Y" * 200) + for i in range(5) + ] + text = cm._renderer.render_steps_with_truncation(actions, fmt="action") + assert len(text) < 200 * 5 + + def test_empty_steps_returns_empty(self): + cm = make_cm() + text = cm._renderer.render_steps_with_truncation([], fmt="action") + assert text == "" + + def test_pairs_fmt_uses_user_assistant_prefix(self): + cm = make_cm() + t, a = make_pair("user question", "assistant answer", 1) + text = cm._renderer.render_steps_with_truncation( + [(t, a)], fmt="pairs", max_tokens=99999 + ) + assert "user:" in text + assert "assistant:" in text + + +# ────────────────────────────────────────────────────────────── +# _truncate_text and _reduce_budgets (private, critical logic) +# ────────────────────────────────────────────────────────────── + +class TestTruncateText: + + def test_within_or_at_limit_returns_unchanged(self): + cm = make_cm() + assert cm._renderer._truncate_text("hello", max_len=100) == "hello" + assert cm._renderer._truncate_text("abcde", max_len=5) == "abcde" + + def test_over_limit_truncates_with_mark(self): + cm = make_cm() + result = cm._renderer._truncate_text("a" * 50, max_len=20) + assert "...[Truncated]" in result and len(result) == 20 + + +class TestReduceBudgets: + + def test_reduces_action_budget_first(self): + cm = make_cm() + t, a = cm._renderer._reduce_budgets(800, 400, 80) + assert a == 320 # 400 * 0.8 + assert t == 800 + + def test_reduces_task_budget_when_action_at_min(self): + cm = make_cm() + t, a = cm._renderer._reduce_budgets(800, 80, 80) + assert t == 640 # 800 * 0.8 + assert a == 80 + + def test_both_at_min_no_reduction(self): + cm = make_cm() + t, a = cm._renderer._reduce_budgets(80, 80, 80) + assert t == 80 + assert a == 80 + + def test_action_above_min_clamped_to_min(self): + cm = make_cm() + t, a = cm._renderer._reduce_budgets(800, 100, 80) + assert a == 80 # max(min, 100*0.8) = max(80, 80) + assert t == 800 + + +# ────────────────────────────────────────────────────────────── +# 
compress_history_offline (standalone function) +# ────────────────────────────────────────────────────────────── + +class TestCompressHistoryOffline: + + def test_empty_pairs_and_no_prev_summary_returns_none(self): + result = cho([], MagicMock()) + assert result["summary"] is None + assert result["is_incremental"] is False + assert result["is_fallback"] is False + + def test_basic_compression_success(self): + model = make_model('{"task_overview": "test overview"}') + pairs = [("user question", "assistant answer")] + result = cho(pairs, model) + # format_summary_output reformats JSON — check content, not exact string + assert result["summary"] is not None + parsed = json.loads(result["summary"]) + assert parsed["task_overview"] == "test overview" + assert result["is_incremental"] is False + assert result["is_fallback"] is False + assert "user question" in result["input_text"] + assert result["input_chars"] > 0 + + def test_incremental_with_previous_summary(self): + model = make_model('{"task_overview": "updated summary"}') + pairs = [("new question", "new answer")] + result = cho(pairs, model, previous_summary="old summary text") + assert result["summary"] is not None + assert result["is_incremental"] is True + assert "old summary text" in result["input_text"] + assert "new question" in result["input_text"] + + @pytest.mark.parametrize("prev_summary, expected_phrase", [ + (None, "Create a structured checkpoint summary"), + ("existing summary", "Update the summary"), + ]) + def test_prompt_varies_by_mode(self, prev_summary, expected_phrase): + model = MagicMock() + response = MagicMock() + response.content = '{"task_overview": "x"}' + model.return_value = response + + cho([("q", "a")], model, previous_summary=prev_summary) + call_args = model.call_args[0][0] + roles = [m.role for m in call_args] + assert "system" in roles and "user" in roles + user_text = next(m for m in call_args if m.role == "user").content[0]["text"] + assert expected_phrase in user_text + + def 
test_llm_exception_falls_back_to_truncation(self): + """When LLM raises an unrecoverable exception, fallback summary is produced.""" + model = MagicMock() + model.side_effect = Exception("unrecoverable error") + + result = cho([("task", "answer")], model) + + assert result["summary"] is not None + assert result["is_fallback"] is True + assert "[CONTEXT COMPACTION" in result["summary"] + + def test_multiple_pairs_rendered(self): + model = make_model('{"task_overview": "summary"}') + pairs = [ + ("question 1", "answer 1"), + ("question 2", "answer 2"), + ("question 3", "answer 3"), + ] + result = cho(pairs, model) + assert "question 1" in result["input_text"] + assert "question 2" in result["input_text"] + assert "question 3" in result["input_text"] + + @pytest.mark.parametrize("config", [None, ContextManagerConfig(max_summary_input_tokens=100)]) + def test_config_default_and_custom(self, config): + model = make_model('{"task_overview": "x"}') + result = cho([("q", "a")], model, config=config) + assert result["summary"] is not None