22
33import asyncio
44import inspect
5- from typing import Any , Callable , Protocol , TypedDict , Unpack , get_type_hints
5+ from typing import Any , Coroutine , Protocol , TypedDict , Union , Unpack , get_type_hints
66from uuid import UUID
77
88from pydantic import BaseModel , ConfigDict , PrivateAttr , field_serializer
@@ -22,13 +22,18 @@ class TaskHandlerKwargs(TypedDict):
2222
2323
2424class TaskHandler (Protocol ):
25- """Protocol for task handler functions.
25+ """Protocol for task handler functions (sync or async) .
2626
27- Handlers accept **kwargs matching HandlerKwargs structure.
28- Return value is ignored. Can be sync or async .
27+ Handlers accept **kwargs matching TaskHandlerKwargs structure.
28+ Must return a Pydantic BaseModel (or Coroutine that resolves to BaseModel) .
2929 """
3030
31- def __call__ (self , ** kwargs : Unpack [TaskHandlerKwargs ]) -> None : ...
31+ __name__ : str # All functions have __name__ attribute
32+
33+ def __call__ (
34+ self ,
35+ ** kwargs : Unpack [TaskHandlerKwargs ],
36+ ) -> Union [BaseModel , Coroutine [Any , Any , BaseModel ]]: ...
3237
3338
3439class TaskDefinition :
@@ -51,10 +56,12 @@ async def research(agent_id: UUID, context: ResearchContext):
5156 """
5257
5358 name : str
54- handler : Callable [..., Any ]
59+ handler : TaskHandler
5560 context_class : type [BaseModel ]
61+ # TODO we handle this with serialize/deserialize when writing the result so this can probably go away
62+ result_class : type [BaseModel ]
5663
57- def __init__ (self , name : str , handler : Callable [..., Any ] ):
64+ def __init__ (self , name : str , handler : TaskHandler ):
5865 """Initialize task definition.
5966
6067 Args:
@@ -63,12 +70,14 @@ def __init__(self, name: str, handler: Callable[..., Any]):
6370
6471 Raises:
6572 TypeError: If handler doesn't have a typed 'context' parameter with BaseModel subclass
73+ TypeError: If handler doesn't have a return type annotation with BaseModel subclass
6674 """
6775 self .name = name
6876 self .handler = handler
6977 self .context_class = self ._infer_context_class (handler )
78+ self .result_class = self ._infer_result_class (handler )
7079
71- def _infer_context_class (self , handler : Callable [..., Any ] ) -> type [BaseModel ]:
80+ def _infer_context_class (self , handler : TaskHandler ) -> type [BaseModel ]:
7281 """Infer context class from handler's type annotations.
7382
7483 Looks for a 'context' parameter with a Pydantic BaseModel type hint.
@@ -98,6 +107,36 @@ def _infer_context_class(self, handler: Callable[..., Any]) -> type[BaseModel]:
98107
99108 return context_type
100109
110+ def _infer_result_class (self , handler : TaskHandler ) -> type [BaseModel ]:
111+ """Infer result class from handler's return type annotation.
112+
113+ Looks for a return annotation with a Pydantic BaseModel type hint.
114+
115+ Args:
116+ handler: The task handler function
117+
118+ Returns:
119+ Result class (BaseModel subclass)
120+
121+ Raises:
122+ TypeError: If return annotation is missing or not a BaseModel subclass
123+ """
124+ hints = get_type_hints (handler )
125+ if "return" not in hints :
126+ raise TypeError (
127+ f"Task handler '{ handler .__name__ } ' must have a return type "
128+ f"annotation with a BaseModel subclass"
129+ )
130+
131+ return_type = hints ["return" ]
132+ if not (inspect .isclass (return_type ) and issubclass (return_type , BaseModel )):
133+ raise TypeError (
134+ f"Task handler '{ handler .__name__ } ' return type must be a "
135+ f"BaseModel subclass, got { return_type } "
136+ )
137+
138+ return return_type
139+
101140
102141class Task (BaseModel ):
103142 """Represents a background task instance.
@@ -187,17 +226,16 @@ def create(cls, task_name: str, context: BaseModel) -> Task:
187226 agent_id = agent_id ,
188227 )
189228
190- async def execute (self ) -> Any :
229+ async def execute (self ) -> BaseModel | None :
191230 """Execute the task using its bound definition's handler.
192231
193232 Manages task lifecycle: marks started, runs handler, marks completed/errored.
194233
195234 Returns:
196- Handler return value
235+ Handler return value, or None if handler raised an exception
197236
198237 Raises:
199238 RuntimeError: If task has not been bound to a definition
200- Exception: Re-raises any exception from the handler after marking errored
201239 """
202240 if self ._definition is None :
203241 raise RuntimeError ("Task must be bound to a definition before execution" )
@@ -214,12 +252,12 @@ async def execute(self) -> Any:
214252 "context" : self .context ,
215253 }
216254
255+ result : BaseModel
217256 if asyncio .iscoroutinefunction (self ._definition .handler ):
218257 result = await self ._definition .handler (** kwargs )
219258 else :
220- result = self ._definition .handler (** kwargs )
259+ result = self ._definition .handler (** kwargs ) # type: ignore[assignment]
221260
222- # Store result for pipeline coordination
223261 await state .aset_result (
224262 self .agent_id ,
225263 result ,
@@ -239,3 +277,4 @@ async def execute(self) -> Any:
239277 message = CONF .activity_message_error .format (error = e ),
240278 status = activity .Status .ERROR ,
241279 )
280+ return None
0 commit comments