Skip to content

Commit a023a4d

Browse files
committed
Use internal serialization format instead of pickle
1 parent 80b2338 commit a023a4d

10 files changed

Lines changed: 357 additions & 101 deletions

File tree

src/agentexec/core/queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ async def research(agent_id: UUID, context: ResearchContext):
7070

7171

7272
async def dequeue(
73+
*,
7374
queue_name: str | None = None,
7475
timeout: int = 1,
7576
) -> dict[str, Any] | None:

src/agentexec/core/results.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,32 @@
1-
"""Task result storage and retrieval."""
2-
31
from __future__ import annotations
42

53
import asyncio
64
import time
7-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING
6+
7+
from pydantic import BaseModel
88

99
from agentexec import state
1010

1111
if TYPE_CHECKING:
1212
from agentexec.core.task import Task
1313

1414

15-
async def get_result(task: Task, timeout: float = 300) -> Any:
15+
DEFAULT_TIMEOUT: int = 300 # TODO improve this polling approach
16+
17+
18+
async def get_result(task: Task, timeout: int = DEFAULT_TIMEOUT) -> BaseModel:
1619
"""Poll for a task result.
1720
1821
Waits for a task to complete and returns its result.
22+
Uses automatic type reconstruction from serialized class information.
1923
2024
Args:
2125
task: The Task instance to wait for
2226
timeout: Maximum seconds to wait for result
2327
2428
Returns:
25-
The task's return value
29+
Deserialized result as BaseModel instance
2630
2731
Raises:
2832
TimeoutError: If result not available within timeout
@@ -38,22 +42,23 @@ async def get_result(task: Task, timeout: float = 300) -> Any:
3842
raise TimeoutError(f"Result for {task.agent_id} not available within {timeout}s")
3943

4044

41-
async def gather(*tasks: Task) -> tuple[Any, ...]:
45+
async def gather(*tasks: Task, timeout: int = DEFAULT_TIMEOUT) -> tuple[BaseModel, ...]:
4246
"""Wait for multiple tasks and return their results.
4347
4448
Similar to asyncio.gather, but for background tasks.
4549
4650
Args:
4751
*tasks: Task instances to wait for
52+
timeout: Maximum seconds to wait for each result
4853
4954
Returns:
50-
Tuple of results in the same order as input tasks
55+
Tuple of deserialized results as BaseModel instances
5156
5257
Example:
5358
brand = await ax.enqueue("brand_research", ctx)
5459
market = await ax.enqueue("market_research", ctx)
5560
5661
brand_result, market_result = await ax.gather(brand, market)
5762
"""
58-
results = await asyncio.gather(*[get_result(task) for task in tasks])
63+
results = await asyncio.gather(*[get_result(task, timeout) for task in tasks])
5964
return tuple(results)

src/agentexec/core/task.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
import inspect
5-
from typing import Any, Callable, Protocol, TypedDict, Unpack, get_type_hints
5+
from typing import Any, Coroutine, Protocol, TypedDict, Union, Unpack, get_type_hints
66
from uuid import UUID
77

88
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_serializer
@@ -22,13 +22,18 @@ class TaskHandlerKwargs(TypedDict):
2222

2323

2424
class TaskHandler(Protocol):
25-
"""Protocol for task handler functions.
25+
"""Protocol for task handler functions (sync or async).
2626
27-
Handlers accept **kwargs matching HandlerKwargs structure.
28-
Return value is ignored. Can be sync or async.
27+
Handlers accept **kwargs matching TaskHandlerKwargs structure.
28+
Must return a Pydantic BaseModel (or Coroutine that resolves to BaseModel).
2929
"""
3030

31-
def __call__(self, **kwargs: Unpack[TaskHandlerKwargs]) -> None: ...
31+
__name__: str # All functions have __name__ attribute
32+
33+
def __call__(
34+
self,
35+
**kwargs: Unpack[TaskHandlerKwargs],
36+
) -> Union[BaseModel, Coroutine[Any, Any, BaseModel]]: ...
3237

3338

3439
class TaskDefinition:
@@ -51,10 +56,12 @@ async def research(agent_id: UUID, context: ResearchContext):
5156
"""
5257

5358
name: str
54-
handler: Callable[..., Any]
59+
handler: TaskHandler
5560
context_class: type[BaseModel]
61+
# TODO we handle this with serialize/deserialize when writing the result so this can probably go away
62+
result_class: type[BaseModel]
5663

57-
def __init__(self, name: str, handler: Callable[..., Any]):
64+
def __init__(self, name: str, handler: TaskHandler):
5865
"""Initialize task definition.
5966
6067
Args:
@@ -63,12 +70,14 @@ def __init__(self, name: str, handler: Callable[..., Any]):
6370
6471
Raises:
6572
TypeError: If handler doesn't have a typed 'context' parameter with BaseModel subclass
73+
TypeError: If handler doesn't have a return type annotation with BaseModel subclass
6674
"""
6775
self.name = name
6876
self.handler = handler
6977
self.context_class = self._infer_context_class(handler)
78+
self.result_class = self._infer_result_class(handler)
7079

71-
def _infer_context_class(self, handler: Callable[..., Any]) -> type[BaseModel]:
80+
def _infer_context_class(self, handler: TaskHandler) -> type[BaseModel]:
7281
"""Infer context class from handler's type annotations.
7382
7483
Looks for a 'context' parameter with a Pydantic BaseModel type hint.
@@ -98,6 +107,36 @@ def _infer_context_class(self, handler: Callable[..., Any]) -> type[BaseModel]:
98107

99108
return context_type
100109

110+
def _infer_result_class(self, handler: TaskHandler) -> type[BaseModel]:
111+
"""Infer result class from handler's return type annotation.
112+
113+
Looks for a return annotation with a Pydantic BaseModel type hint.
114+
115+
Args:
116+
handler: The task handler function
117+
118+
Returns:
119+
Result class (BaseModel subclass)
120+
121+
Raises:
122+
TypeError: If return annotation is missing or not a BaseModel subclass
123+
"""
124+
hints = get_type_hints(handler)
125+
if "return" not in hints:
126+
raise TypeError(
127+
f"Task handler '{handler.__name__}' must have a return type "
128+
f"annotation with a BaseModel subclass"
129+
)
130+
131+
return_type = hints["return"]
132+
if not (inspect.isclass(return_type) and issubclass(return_type, BaseModel)):
133+
raise TypeError(
134+
f"Task handler '{handler.__name__}' return type must be a "
135+
f"BaseModel subclass, got {return_type}"
136+
)
137+
138+
return return_type
139+
101140

102141
class Task(BaseModel):
103142
"""Represents a background task instance.
@@ -187,17 +226,16 @@ def create(cls, task_name: str, context: BaseModel) -> Task:
187226
agent_id=agent_id,
188227
)
189228

190-
async def execute(self) -> Any:
229+
async def execute(self) -> BaseModel | None:
191230
"""Execute the task using its bound definition's handler.
192231
193232
Manages task lifecycle: marks started, runs handler, marks completed/errored.
194233
195234
Returns:
196-
Handler return value
235+
Handler return value, or None if handler raised an exception
197236
198237
Raises:
199238
RuntimeError: If task has not been bound to a definition
200-
Exception: Re-raises any exception from the handler after marking errored
201239
"""
202240
if self._definition is None:
203241
raise RuntimeError("Task must be bound to a definition before execution")
@@ -214,12 +252,12 @@ async def execute(self) -> Any:
214252
"context": self.context,
215253
}
216254

255+
result: BaseModel
217256
if asyncio.iscoroutinefunction(self._definition.handler):
218257
result = await self._definition.handler(**kwargs)
219258
else:
220-
result = self._definition.handler(**kwargs)
259+
result = self._definition.handler(**kwargs) # type: ignore[assignment]
221260

222-
# Store result for pipeline coordination
223261
await state.aset_result(
224262
self.agent_id,
225263
result,
@@ -239,3 +277,4 @@ async def execute(self) -> Any:
239277
message=CONF.activity_message_error.format(error=e),
240278
status=activity.Status.ERROR,
241279
)
280+
return None

src/agentexec/state/__init__.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# cspell:ignore acheck
22

3-
from typing import AsyncGenerator, Coroutine
3+
from typing import AsyncGenerator, Coroutine, cast
44
from uuid import UUID
55

6+
from pydantic import BaseModel
7+
68
from agentexec.config import CONF
9+
from agentexec.state.backend import StateBackend
710

811
KEY_RESULT = (CONF.key_prefix, "result")
912
KEY_EVENT = (CONF.key_prefix, "event")
@@ -12,10 +15,11 @@
1215

1316
match CONF.state_backend:
1417
case "redis":
15-
from agentexec.state import redis_backend as backend
18+
from agentexec.state import redis_backend as _backend
1619
case _:
1720
raise RuntimeError(f"Unsupported state backend: {CONF.state_backend}.")
1821

22+
backend: StateBackend = cast(StateBackend, _backend)
1923

2024
__all__ = [
2125
"backend",
@@ -34,42 +38,50 @@
3438
]
3539

3640

37-
def get_result(agent_id: UUID | str) -> object | None:
41+
def get_result(agent_id: UUID | str) -> BaseModel | None:
3842
"""Get result for an agent (sync).
3943
44+
Returns deserialized BaseModel instance with automatic type reconstruction.
45+
4046
Args:
4147
agent_id: Unique agent identifier (UUID or string)
4248
4349
Returns:
44-
Deserialized result object or None if not found
50+
Deserialized BaseModel or None if not found
4551
"""
4652
data = backend.get(backend.format_key(*KEY_RESULT, str(agent_id)))
4753
return backend.deserialize(data) if data else None
4854

4955

50-
def aget_result(agent_id: UUID | str) -> Coroutine[None, None, object | None]:
56+
def aget_result(agent_id: UUID | str) -> Coroutine[None, None, BaseModel | None]:
5157
"""Get result for an agent (async).
5258
59+
Returns deserialized BaseModel instance with automatic type reconstruction.
60+
5361
Args:
5462
agent_id: Unique agent identifier (UUID or string)
5563
5664
Returns:
57-
Coroutine that resolves to deserialized result object or None if not found
65+
Coroutine that resolves to deserialized BaseModel or None if not found
5866
"""
5967

60-
async def _get() -> object | None:
68+
async def _get() -> BaseModel | None:
6169
data = await backend.aget(backend.format_key(*KEY_RESULT, str(agent_id)))
6270
return backend.deserialize(data) if data else None
6371

6472
return _get()
6573

6674

67-
def set_result(agent_id: UUID | str, data: object, ttl_seconds: int | None = None) -> bool:
75+
def set_result(
76+
agent_id: UUID | str,
77+
data: BaseModel,
78+
ttl_seconds: int | None = None,
79+
) -> bool:
6880
"""Set result for an agent (sync).
6981
7082
Args:
7183
agent_id: Unique agent identifier (UUID or string)
72-
data: Result data to store
84+
data: Result data (must be Pydantic BaseModel)
7385
ttl_seconds: Optional time-to-live in seconds
7486
7587
Returns:
@@ -84,14 +96,14 @@ def set_result(agent_id: UUID | str, data: object, ttl_seconds: int | None = Non
8496

8597
def aset_result(
8698
agent_id: UUID | str,
87-
data: object,
99+
data: BaseModel,
88100
ttl_seconds: int | None = None,
89101
) -> Coroutine[None, None, bool]:
90102
"""Set result for an agent (async).
91103
92104
Args:
93105
agent_id: Unique agent identifier (UUID or string)
94-
data: Result data to store
106+
data: Result data (must be Pydantic BaseModel)
95107
ttl_seconds: Optional time-to-live in seconds
96108
97109
Returns:

src/agentexec/state/backend.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import AsyncGenerator, Coroutine, Optional, Protocol
44

5+
from pydantic import BaseModel
6+
57

68
class StateBackend(Protocol):
79
"""Protocol defining the state backend interface.
@@ -147,7 +149,7 @@ def publish(self, channel: str, message: str) -> None:
147149
"""
148150
...
149151

150-
async def subscribe(self, channel: str) -> AsyncGenerator[str, None]:
152+
def subscribe(self, channel: str) -> AsyncGenerator[str, None]:
151153
"""Subscribe to a channel and yield messages.
152154
153155
Args:
@@ -171,35 +173,38 @@ def format_key(self, *args: str) -> str:
171173
...
172174

173175
# Serialization
174-
def serialize(self, obj: object) -> bytes:
175-
"""Serialize a Python object to bytes.
176+
def serialize(self, obj: BaseModel) -> bytes:
177+
"""Serialize a Pydantic BaseModel to bytes.
176178
177-
The serialization method is backend-specific. For example, Redis might
178-
use pickle, while other backends might use JSON or msgpack.
179+
Stores the fully qualified class name alongside the data to enable
180+
automatic type reconstruction during deserialization.
179181
180182
Args:
181-
obj: Python object to serialize
183+
obj: Pydantic BaseModel instance to serialize
182184
183185
Returns:
184186
Serialized bytes
185187
186188
Raises:
187-
Exception: If serialization fails
189+
TypeError: If obj is not a BaseModel instance
188190
"""
189191
...
190192

191-
def deserialize(self, data: bytes) -> object:
192-
"""Deserialize bytes back to a Python object.
193+
def deserialize(self, data: bytes) -> BaseModel:
194+
"""Deserialize bytes back to a Pydantic BaseModel instance.
193195
194-
Must be compatible with the serialize() method for this backend.
196+
Uses the stored class information to dynamically import and reconstruct
197+
the original type.
195198
196199
Args:
197200
data: Serialized bytes
198201
199202
Returns:
200-
Deserialized Python object
203+
Deserialized BaseModel instance
201204
202205
Raises:
203-
Exception: If deserialization fails
206+
ImportError: If the class module cannot be imported
207+
AttributeError: If the class does not exist in the module
208+
ValueError: If the data is invalid
204209
"""
205210
...

0 commit comments

Comments
 (0)