diff --git a/python/antigravity/harness_server.py b/python/antigravity/harness_server.py index c11165d..6a8aaa9 100644 --- a/python/antigravity/harness_server.py +++ b/python/antigravity/harness_server.py @@ -97,6 +97,33 @@ def hydrate_ax_history_to_steps(historical_messages) -> list[Step]: class AntigravityHarnessServiceServicer(ax_pb2_grpc.HarnessServiceServicer): """Implements the ax.HarnessService protocol over gRPC.""" + def __init__(self): + # TODO: Implement an eviction/idle-timeout policy to prevent unbounded memory growth in production. + self._agents = {} + self._lock = asyncio.Lock() + + async def _get_or_create_agent(self, conversation_id: str) -> Agent: + async with self._lock: + if conversation_id not in self._agents: + global loaded_config + if not loaded_config: + raise ValueError("Agent config is not loaded on the server") + print(f"[gRPC] Creating new Agent instance for conv_id={conversation_id}") + agent = Agent(loaded_config) + await agent.__aenter__() + self._agents[conversation_id] = agent + return self._agents[conversation_id] + + async def cleanup(self): + print("[gRPC] Cleaning up agent instances...") + async with self._lock: + for conv_id, agent in self._agents.items(): + try: + await agent.__aexit__(None, None, None) + except Exception as e: + print(f"Error closing agent for conv_id={conv_id}: {e}") + self._agents.clear() + async def Connect(self, request_iterator, context): # Each HarnessRequest{start} drives one stateless turn; the stream stays # open across turns until the client half-closes. @@ -129,69 +156,62 @@ async def _run_turn(self, request): return latest_query_text = latest_message.content.text.text - # 2. Initialize the Antigravity Agent session - global loaded_config - if not loaded_config: - yield ax_pb2.HarnessResponse( - conversation_id=request.conversation_id, - end=ax_pb2.HarnessEnd(state=ax_pb2.STATE_FAILED, error_message="Agent config is not loaded on the server") - ) - return - + # 2. Initialize or get the Antigravity Agent session try: - async with Agent(loaded_config) as agent: - conversation = agent.conversation - - # Hydrate history - print(f"[gRPC] Hydrating {len(historical_messages)} historical messages...") - history_steps = hydrate_ax_history_to_steps(historical_messages) - conversation._steps.extend(history_steps) - - # Run the turn with streaming - print(f"[gRPC] Running chat query: {latest_query_text}") - response = await conversation.chat(latest_query_text) - - async for chunk in response.chunks: - if isinstance(chunk, Text): - msg = ax_pb2.Message( - role="assistant", - content=content_pb2.Content(text=content_pb2.TextContent(text=chunk.text)) - ) - yield ax_pb2.HarnessResponse( - conversation_id=request.conversation_id, - outputs=ax_pb2.HarnessOutputs(messages=[msg]) - ) - elif isinstance(chunk, Thought): - summary = [ - content_pb2.ThoughtSummaryContent(text=content_pb2.TextContent(text=chunk.text)) - ] - msg = ax_pb2.Message( - role="model", - content=content_pb2.Content(thought=content_pb2.ThoughtContent(summary=summary)) - ) - yield ax_pb2.HarnessResponse( - conversation_id=request.conversation_id, - outputs=ax_pb2.HarnessOutputs(messages=[msg]) - ) - elif isinstance(chunk, ToolCall): - struct_args = Struct() - struct_args.update(chunk.args) - - func_call = content_pb2.FunctionCallContent( - name=str(chunk.name), - arguments=struct_args - ) - msg = ax_pb2.Message( - role="model", - content=content_pb2.Content(tool_call=content_pb2.ToolCallContent( - id=chunk.id or "", - function_call=func_call - )) - ) - yield ax_pb2.HarnessResponse( - conversation_id=request.conversation_id, - outputs=ax_pb2.HarnessOutputs(messages=[msg]) - ) + agent = await self._get_or_create_agent(request.conversation_id) + conversation = agent.conversation + + # Hydrate history (clear first to prevent duplication) + print(f"[gRPC] Hydrating {len(historical_messages)} historical messages...") + history_steps = hydrate_ax_history_to_steps(historical_messages) + conversation._steps.clear() + conversation._steps.extend(history_steps) + + # Run the turn with streaming + print(f"[gRPC] Running chat query: {latest_query_text}") + response = await conversation.chat(latest_query_text) + + async for chunk in response.chunks: + if isinstance(chunk, Text): + msg = ax_pb2.Message( + role="assistant", + content=content_pb2.Content(text=content_pb2.TextContent(text=chunk.text)) + ) + yield ax_pb2.HarnessResponse( + conversation_id=request.conversation_id, + outputs=ax_pb2.HarnessOutputs(messages=[msg]) + ) + elif isinstance(chunk, Thought): + summary = [ + content_pb2.ThoughtSummaryContent(text=content_pb2.TextContent(text=chunk.text)) + ] + msg = ax_pb2.Message( + role="model", + content=content_pb2.Content(thought=content_pb2.ThoughtContent(summary=summary)) + ) + yield ax_pb2.HarnessResponse( + conversation_id=request.conversation_id, + outputs=ax_pb2.HarnessOutputs(messages=[msg]) + ) + elif isinstance(chunk, ToolCall): + struct_args = Struct() + struct_args.update(chunk.args) + + func_call = content_pb2.FunctionCallContent( + name=str(chunk.name), + arguments=struct_args + ) + msg = ax_pb2.Message( + role="model", + content=content_pb2.Content(tool_call=content_pb2.ToolCallContent( + id=chunk.id or "", + function_call=func_call + )) + ) + yield ax_pb2.HarnessResponse( + conversation_id=request.conversation_id, + outputs=ax_pb2.HarnessOutputs(messages=[msg]) + ) # Yield completion end frame yield ax_pb2.HarnessResponse( @@ -210,7 +230,8 @@ async def _run_turn(self, request): async def serve(host: str, port: int): server = grpc.aio.server() - ax_pb2_grpc.add_HarnessServiceServicer_to_server(AntigravityHarnessServiceServicer(), server) + servicer = AntigravityHarnessServiceServicer() + ax_pb2_grpc.add_HarnessServiceServicer_to_server(servicer, server) # Serve the standard gRPC health protocol. health_servicer = health.aio.HealthServicer() @@ -221,7 +242,10 @@ async def serve(host: str, port: int): server.add_insecure_port(listen_addr) print(f"Starting gRPC harness server on {listen_addr}...") await server.start() - await server.wait_for_termination() + try: + await server.wait_for_termination() + finally: + await servicer.cleanup() def resolve_localhost(): """Ensure `localhost` resolves to 127.0.0.1. diff --git a/python/antigravity/harness_server_test.py b/python/antigravity/harness_server_test.py index ae14f6c..2c866d7 100644 --- a/python/antigravity/harness_server_test.py +++ b/python/antigravity/harness_server_test.py @@ -95,6 +95,94 @@ async def request_iter(): asyncio.run(_run()) +def test_grpc_connect_agent_reused(mock_config, monkeypatch): + async def _run(): + server = grpc.aio.server() + servicer = AntigravityHarnessServiceServicer() + ax_pb2_grpc.add_HarnessServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("localhost:0") + await server.start() + + addr = f"localhost:{port}" + async with grpc.aio.insecure_channel(addr) as channel: + stub = ax_pb2_grpc.HarnessServiceStub(channel) + + class MockConversation: + def __init__(self): + self._steps = [] + async def chat(self, text): + class MockResponse: + def __init__(self): + self.chunks = self._chunk_generator() + async def _chunk_generator(self): + from google.antigravity.types import Text + yield Text(text="Response", step_index=0) + return MockResponse() + + agent_instances = [] + class MockAgent: + def __init__(self, config): + self.conversation = MockConversation() + self.closed = False + agent_instances.append(self) + async def __aenter__(self): + return self + async def __aexit__(self, exc_type, exc, tb): + self.closed = True + + monkeypatch.setattr("python.antigravity.harness_server.Agent", MockAgent) + + # Fire first turn for conv-1 + req1 = ax_pb2.HarnessRequest( + conversation_id="conv-1", + harness_id="antigravity", + start=ax_pb2.HarnessStart( + messages=[ax_pb2.Message(role="user", content=content_pb2.Content(text=content_pb2.TextContent(text="Hi")))] + ) + ) + async def req_iter1(): + yield req1 + async for _ in stub.Connect(req_iter1()): + pass + + # Fire second turn for same conv-1 + req2 = ax_pb2.HarnessRequest( + conversation_id="conv-1", + harness_id="antigravity", + start=ax_pb2.HarnessStart( + messages=[ax_pb2.Message(role="user", content=content_pb2.Content(text=content_pb2.TextContent(text="Hi again")))] + ) + ) + async def req_iter2(): + yield req2 + async for _ in stub.Connect(req_iter2()): + pass + + # Fire third turn for a different conv-2 + req3 = ax_pb2.HarnessRequest( + conversation_id="conv-2", + harness_id="antigravity", + start=ax_pb2.HarnessStart( + messages=[ax_pb2.Message(role="user", content=content_pb2.Content(text=content_pb2.TextContent(text="New conv")))] + ) + ) + async def req_iter3(): + yield req3 + async for _ in stub.Connect(req_iter3()): + pass + + # Verify only 2 agents were instantiated (reused the first one) + assert len(agent_instances) == 2 + + # Verify cleanup closes all agents + await servicer.cleanup() + assert all(a.closed for a in agent_instances) + + await server.stop(0) + + asyncio.run(_run()) + + def test_health_check(): async def _run(): from grpc_health.v1 import health, health_pb2, health_pb2_grpc