Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 88 additions & 64 deletions python/antigravity/harness_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,33 @@ def hydrate_ax_history_to_steps(historical_messages) -> list[Step]:
class AntigravityHarnessServiceServicer(ax_pb2_grpc.HarnessServiceServicer):
"""Implements the ax.HarnessService protocol over gRPC."""

def __init__(self):
# TODO: Implement an eviction/idle-timeout policy to prevent unbounded memory growth in production.
self._agents = {}
self._lock = asyncio.Lock()

async def _get_or_create_agent(self, conversation_id: str) -> Agent:
    """Return the cached Agent for ``conversation_id``, building it on first use.

    The lookup-or-create sequence runs while ``self._lock`` is held. A newly
    constructed agent is entered with ``__aenter__`` before it is stored in
    the cache.

    Args:
        conversation_id: Key under which the agent is cached.

    Returns:
        The agent stored for ``conversation_id``.

    Raises:
        ValueError: If no agent is cached for this conversation yet and the
            module-level ``loaded_config`` is falsy.
    """
    async with self._lock:
        # Fast path: an agent already exists for this conversation.
        if conversation_id in self._agents:
            return self._agents[conversation_id]

        # loaded_config is only read here, so no ``global`` statement is needed.
        if not loaded_config:
            raise ValueError("Agent config is not loaded on the server")

        print(f"[gRPC] Creating new Agent instance for conv_id={conversation_id}")
        new_agent = Agent(loaded_config)
        # Enter the agent's async context manually; it stays open after this
        # method returns.
        await new_agent.__aenter__()
        self._agents[conversation_id] = new_agent
        return new_agent

async def cleanup(self):
print("[gRPC] Cleaning up agent instances...")
async with self._lock:
for conv_id, agent in self._agents.items():
try:
await agent.__aexit__(None, None, None)
except Exception as e:
print(f"Error closing agent for conv_id={conv_id}: {e}")
self._agents.clear()

async def Connect(self, request_iterator, context):
# Each HarnessRequest{start} drives one stateless turn; the stream stays
# open across turns until the client half-closes.
Expand Down Expand Up @@ -129,69 +156,62 @@ async def _run_turn(self, request):
return
latest_query_text = latest_message.content.text.text

# 2. Initialize the Antigravity Agent session
global loaded_config
if not loaded_config:
yield ax_pb2.HarnessResponse(
conversation_id=request.conversation_id,
end=ax_pb2.HarnessEnd(state=ax_pb2.STATE_FAILED, error_message="Agent config is not loaded on the server")
)
return

# 2. Initialize or get the Antigravity Agent session
try:
async with Agent(loaded_config) as agent:
conversation = agent.conversation

# Hydrate history
print(f"[gRPC] Hydrating {len(historical_messages)} historical messages...")
history_steps = hydrate_ax_history_to_steps(historical_messages)
conversation._steps.extend(history_steps)

# Run the turn with streaming
print(f"[gRPC] Running chat query: {latest_query_text}")
response = await conversation.chat(latest_query_text)

async for chunk in response.chunks:
if isinstance(chunk, Text):
msg = ax_pb2.Message(
role="assistant",
content=content_pb2.Content(text=content_pb2.TextContent(text=chunk.text))
)
yield ax_pb2.HarnessResponse(
conversation_id=request.conversation_id,
outputs=ax_pb2.HarnessOutputs(messages=[msg])
)
elif isinstance(chunk, Thought):
summary = [
content_pb2.ThoughtSummaryContent(text=content_pb2.TextContent(text=chunk.text))
]
msg = ax_pb2.Message(
role="model",
content=content_pb2.Content(thought=content_pb2.ThoughtContent(summary=summary))
)
yield ax_pb2.HarnessResponse(
conversation_id=request.conversation_id,
outputs=ax_pb2.HarnessOutputs(messages=[msg])
)
elif isinstance(chunk, ToolCall):
struct_args = Struct()
struct_args.update(chunk.args)

func_call = content_pb2.FunctionCallContent(
name=str(chunk.name),
arguments=struct_args
)
msg = ax_pb2.Message(
role="model",
content=content_pb2.Content(tool_call=content_pb2.ToolCallContent(
id=chunk.id or "",
function_call=func_call
))
)
yield ax_pb2.HarnessResponse(
conversation_id=request.conversation_id,
outputs=ax_pb2.HarnessOutputs(messages=[msg])
)
agent = await self._get_or_create_agent(request.conversation_id)
conversation = agent.conversation

# Hydrate history (clear first to prevent duplication)
print(f"[gRPC] Hydrating {len(historical_messages)} historical messages...")
history_steps = hydrate_ax_history_to_steps(historical_messages)
conversation._steps.clear()
conversation._steps.extend(history_steps)

# Run the turn with streaming
print(f"[gRPC] Running chat query: {latest_query_text}")
response = await conversation.chat(latest_query_text)

async for chunk in response.chunks:
if isinstance(chunk, Text):
msg = ax_pb2.Message(
role="assistant",
content=content_pb2.Content(text=content_pb2.TextContent(text=chunk.text))
)
yield ax_pb2.HarnessResponse(
conversation_id=request.conversation_id,
outputs=ax_pb2.HarnessOutputs(messages=[msg])
)
elif isinstance(chunk, Thought):
summary = [
content_pb2.ThoughtSummaryContent(text=content_pb2.TextContent(text=chunk.text))
]
msg = ax_pb2.Message(
role="model",
content=content_pb2.Content(thought=content_pb2.ThoughtContent(summary=summary))
)
yield ax_pb2.HarnessResponse(
conversation_id=request.conversation_id,
outputs=ax_pb2.HarnessOutputs(messages=[msg])
)
elif isinstance(chunk, ToolCall):
struct_args = Struct()
struct_args.update(chunk.args)

func_call = content_pb2.FunctionCallContent(
name=str(chunk.name),
arguments=struct_args
)
msg = ax_pb2.Message(
role="model",
content=content_pb2.Content(tool_call=content_pb2.ToolCallContent(
id=chunk.id or "",
function_call=func_call
))
)
yield ax_pb2.HarnessResponse(
conversation_id=request.conversation_id,
outputs=ax_pb2.HarnessOutputs(messages=[msg])
)

# Yield completion end frame
yield ax_pb2.HarnessResponse(
Expand All @@ -210,7 +230,8 @@ async def _run_turn(self, request):

async def serve(host: str, port: int):
server = grpc.aio.server()
ax_pb2_grpc.add_HarnessServiceServicer_to_server(AntigravityHarnessServiceServicer(), server)
servicer = AntigravityHarnessServiceServicer()
ax_pb2_grpc.add_HarnessServiceServicer_to_server(servicer, server)

# Serve the standard gRPC health protocol.
health_servicer = health.aio.HealthServicer()
Expand All @@ -221,7 +242,10 @@ async def serve(host: str, port: int):
server.add_insecure_port(listen_addr)
print(f"Starting gRPC harness server on {listen_addr}...")
await server.start()
await server.wait_for_termination()
try:
await server.wait_for_termination()
finally:
await servicer.cleanup()

def resolve_localhost():
"""Ensure `localhost` resolves to 127.0.0.1.
Expand Down
88 changes: 88 additions & 0 deletions python/antigravity/harness_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,94 @@ async def request_iter():
asyncio.run(_run())


def test_grpc_connect_agent_reused(mock_config, monkeypatch):
    """Verify the servicer builds one Agent per conversation_id and reuses it.

    Two turns are sent for ``conv-1`` and one for ``conv-2``, each over its
    own ``Connect`` stream. The test then checks that exactly two agents were
    constructed and that ``cleanup()`` exits every one of them.
    """

    class MockResponse:
        """Chat response whose ``chunks`` stream yields a single Text chunk."""

        def __init__(self):
            self.chunks = self._chunk_generator()

        async def _chunk_generator(self):
            from google.antigravity.types import Text
            yield Text(text="Response", step_index=0)

    class MockConversation:
        """Conversation stand-in exposing ``_steps`` and ``chat``."""

        def __init__(self):
            self._steps = []

        async def chat(self, text):
            return MockResponse()

    # Every MockAgent ever constructed, in creation order.
    agent_instances = []

    class MockAgent:
        """Agent stand-in that records construction and async-context exit."""

        def __init__(self, config):
            self.conversation = MockConversation()
            self.closed = False
            agent_instances.append(self)

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            self.closed = True

    async def _send_turn(stub, conversation_id, text):
        """Open a Connect stream, send one start request, and drain the replies."""
        request = ax_pb2.HarnessRequest(
            conversation_id=conversation_id,
            harness_id="antigravity",
            start=ax_pb2.HarnessStart(
                messages=[
                    ax_pb2.Message(
                        role="user",
                        content=content_pb2.Content(
                            text=content_pb2.TextContent(text=text)
                        ),
                    )
                ]
            ),
        )

        async def request_iter():
            yield request

        async for _ in stub.Connect(request_iter()):
            pass

    async def _run():
        monkeypatch.setattr("python.antigravity.harness_server.Agent", MockAgent)

        server = grpc.aio.server()
        servicer = AntigravityHarnessServiceServicer()
        ax_pb2_grpc.add_HarnessServiceServicer_to_server(servicer, server)
        port = server.add_insecure_port("localhost:0")
        await server.start()

        try:
            async with grpc.aio.insecure_channel(f"localhost:{port}") as channel:
                stub = ax_pb2_grpc.HarnessServiceStub(channel)

                # Two turns on the same conversation, then one on a new one.
                await _send_turn(stub, "conv-1", "Hi")
                await _send_turn(stub, "conv-1", "Hi again")
                await _send_turn(stub, "conv-2", "New conv")

            # Verify only 2 agents were instantiated (reused the first one)
            assert len(agent_instances) == 2

            # Verify cleanup closes all agents
            await servicer.cleanup()
            assert all(a.closed for a in agent_instances)
        finally:
            # Stop the server even when an assertion above fails, so a failing
            # run does not leave a listening server behind.
            await server.stop(0)

    asyncio.run(_run())


def test_health_check():
async def _run():
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
Expand Down
Loading