88 Any ,
99 Callable ,
1010 ClassVar ,
11- Coroutine ,
11+ Protocol ,
12+ TypeAlias ,
1213 Union ,
14+ cast ,
1315 get_type_hints ,
1416 get_origin ,
1517 get_args ,
2022
2123from agentexec import activity
2224from agentexec .core import queue
23- from agentexec .core .task import Task
25+ from agentexec .core .task import Task , TaskResult
2426
2527if TYPE_CHECKING :
2628 from agentexec .worker import WorkerPool
2729
28- StepResult = Union [BaseModel , tuple [BaseModel , ...]]
29- StepHandler = Callable [..., Union [StepResult , Coroutine [Any , Any , StepResult ]]]
30+ StepResult : TypeAlias = Union [TaskResult , tuple [TaskResult , ...]]
31+
32+
33+ class _SyncStepHandler (Protocol ):
34+ """Protocol for pipeline step handler methods.
35+
36+ Step handlers are methods on a pipeline class that receive
37+ one or more BaseModel context arguments from the previous step.
38+ """
39+
40+ __name__ : str
41+
42+ def __call__ (
43+ self ,
44+ instance : _PipelineBase ,
45+ ** kwargs : BaseModel ,
46+ ) -> StepResult : ...
47+
48+
49+ class _AsyncStepHandler (Protocol ):
50+ """Protocol for async pipeline step handler methods.
51+
52+ Step handlers are methods on a pipeline class that receive
53+ one or more BaseModel context arguments from the previous step.
54+ """
55+
56+ __name__ : str
57+
58+ async def __call__ (
59+ self ,
60+ instance : _PipelineBase ,
61+ ** kwargs : BaseModel ,
62+ ) -> StepResult : ...
63+
64+
65+ StepHandler : TypeAlias = _SyncStepHandler | _AsyncStepHandler
3066
3167
3268def _format_pipeline_name (cls : type ) -> str :
@@ -38,7 +74,7 @@ class _PipelineBaseMeta(type):
3874 """Metaclass that registers pipeline subclasses with their bound pipeline."""
3975
4076 @classmethod
41- def bind_pipeline (mcs , pipeline : Pipeline ) -> type :
77+ def bind_pipeline (mcs , pipeline : Pipeline ) -> type [ _PipelineBase ] :
4278 """Create a new PipelineBase class bound to the given pipeline.
4379
4480 Args:
@@ -66,8 +102,8 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
66102 if cls .__name__ == "PipelineBase" :
67103 return
68104
69- cls ._pipeline .bind_user_pipeline (cls )
70- cls ._pipeline .register_task ()
105+ cls ._pipeline ._bind_user_pipeline (cls )
106+ cls ._pipeline ._register_task ()
71107
72108
73109class StepDefinition :
@@ -76,17 +112,17 @@ class StepDefinition:
76112 name : str
77113 order : Any
78114 handler : StepHandler
79- return_type : type | None
80- param_types : dict [str , type | None ]
115+ return_type : type [ BaseModel ] | None
116+ param_types : dict [str , type [ BaseModel ] | None ]
81117 _description : str | None = None
82118
83119 def __init__ (
84120 self ,
85121 name : str ,
86122 order : Any ,
87123 handler : StepHandler ,
88- return_type : type | None ,
89- param_types : dict [str , type | None ],
124+ return_type : type [ BaseModel ] | None ,
125+ param_types : dict [str , type [ BaseModel ] | None ],
90126 description : str | None = None ,
91127 ) -> None :
92128 self .name = name
@@ -103,6 +139,15 @@ def description(self) -> str:
103139 return self ._description
104140 return self .handler .__name__
105141
142+ async def __call__ (self , instance : _PipelineBase , ** kwargs : BaseModel ) -> StepResult :
143+ """Invoke the step handler."""
144+ if asyncio .iscoroutinefunction (self .handler ):
145+ handler = cast (_AsyncStepHandler , self .handler )
146+ return await handler (instance , ** kwargs )
147+ else :
148+ handler = cast (_SyncStepHandler , self .handler )
149+ return handler (instance , ** kwargs )
150+
106151
107152class Pipeline :
108153 """Orchestrates multi-step task workflows.
@@ -171,15 +216,15 @@ def Base(self) -> type[_PipelineBase]:
171216 """
172217 return _PipelineBaseMeta .bind_pipeline (self )
173218
174- def bind_user_pipeline (self , cls : type [_PipelineBase ]) -> None :
219+ def _bind_user_pipeline (self , cls : type [_PipelineBase ]) -> None :
175220 """Manually set the pipeline implementation class.
176221
177222 Args:
178223 cls: Pipeline implementation class
179224 """
180225 self ._user_pipeline_class = cls
181226
182- def register_task (self ) -> None :
227+ def _register_task (self ) -> None :
183228 """
184229 Register Task handler with the pipeline's worker pool
185230 """
@@ -190,7 +235,7 @@ def register_task(self) -> None:
190235
191236 self ._pool ._add_task (
192237 name = self .name ,
193- func = self ._run_task , # type: ignore[arg-type]
238+ func = self ._run_task ,
194239 context_type = self ._input_type ,
195240 result_type = self ._output_type ,
196241 )
@@ -306,9 +351,10 @@ async def run(self, context: BaseModel) -> StepResult:
306351
307352 async def _run_task (
308353 self ,
354+ * ,
309355 agent_id : UUID ,
310356 context : BaseModel ,
311- ) -> StepResult :
357+ ) -> TaskResult :
312358 """Run the pipeline as a task handler.
313359
314360 Args:
@@ -330,6 +376,7 @@ async def _run_task(
330376 )
331377 _context = await self ._run_step (step , _context )
332378
379+ assert isinstance (_context , TaskResult ), "Final step must return BaseModel"
333380 return _context
334381
335382 async def _run_step (
@@ -355,9 +402,7 @@ async def _run_step(
355402 else :
356403 kwargs = {}
357404
358- if asyncio .iscoroutinefunction (step .handler ):
359- return await step .handler (instance , ** kwargs )
360- return step .handler (instance , ** kwargs )
405+ return await step (instance , ** kwargs )
361406
362407 def _validate_type_flow (self ) -> None :
363408 """Verify that step return types match next step's parameters.
0 commit comments