Skip to content

Commit f429083

Browse files
committed
Type checking with ty, better protocol defs.
1 parent 87aaca6 commit f429083

11 files changed

Lines changed: 227 additions & 161 deletions

File tree

pyproject.toml

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,19 @@ dev-dependencies = [
4747
"pytest-asyncio>=0.23.0",
4848
"pytest-cov>=4.1.0",
4949
"ruff>=0.5.0",
50-
"mypy>=1.10.0",
50+
"ty>=0.0.1a7",
5151
"fakeredis>=2.32.1",
5252
]
5353

5454
[tool.ruff]
5555
line-length = 100
5656
target-version = "py311"
5757

58-
[tool.mypy]
59-
python_version = "3.11"
60-
warn_return_any = true
61-
warn_unused_configs = true
62-
disallow_untyped_defs = true
63-
disallow_incomplete_defs = true
58+
[tool.ty.environment]
59+
python-version = "3.11"
60+
61+
[tool.ty.src]
62+
include = ["src/agentexec"]
6463

6564
[tool.pytest.ini_options]
6665
asyncio_mode = "auto"

src/agentexec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def search(agent_id: UUID, context: Input) -> Output:
5959

6060
# OpenAI runner is only available if agents package is installed
6161
try:
62-
from agentexec.runners import OpenAIRunner
62+
from agentexec.runners import OpenAIRunner # type: ignore[possibly-missing-import]
6363

6464
__all__.append("OpenAIRunner")
6565
except ImportError:

src/agentexec/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ class Config(BaseSettings):
7474
)
7575

7676
state_backend: str = Field(
77-
default="redis",
78-
description="State backend to use (redis, memory, etc.)",
77+
default="agentexec.state.redis_backend",
78+
description="State backend to use (fully-qualified module path)",
7979
validation_alias="AGENTEXEC_STATE_BACKEND",
8080
)
8181

src/agentexec/core/task.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import inspect
5-
from typing import Any, Protocol, TypedDict, get_type_hints
4+
from typing import Any, Protocol, TypeAlias, TypedDict, cast, get_type_hints
65
from uuid import UUID
76

87
from pydantic import BaseModel, ConfigDict, PrivateAttr, field_serializer
@@ -21,21 +20,38 @@ class TaskHandlerKwargs(TypedDict):
2120
context: BaseModel
2221

2322

24-
TaskResult = BaseModel # Union[BaseModel, tuple[BaseModel, ...]]
23+
TaskResult: TypeAlias = BaseModel
2524

2625

27-
class TaskHandler(Protocol):
26+
class _SyncTaskHandler(Protocol):
2827
"""Protocol for task handler functions."""
2928

3029
__name__: str
3130

3231
def __call__(
3332
self,
33+
*,
34+
agent_id: UUID,
35+
context: BaseModel,
36+
) -> TaskResult: ...
37+
38+
39+
class _AsyncTaskHandler(Protocol):
40+
"""Protocol for async task handler functions."""
41+
42+
__name__: str
43+
44+
async def __call__(
45+
self,
46+
*,
3447
agent_id: UUID,
3548
context: BaseModel,
3649
) -> TaskResult: ...
3750

3851

52+
TaskHandler: TypeAlias = _SyncTaskHandler | _AsyncTaskHandler
53+
54+
3955
class TaskDefinition:
4056
"""Definition of a task type (created at registration time).
4157
@@ -83,6 +99,15 @@ def __init__(
8399
self.context_type = context_type or self._infer_context_type(handler)
84100
self.result_type = result_type or self._infer_result_type(handler)
85101

102+
async def __call__(self, agent_id: UUID, context: BaseModel) -> TaskResult:
103+
"""Delegate calls to the handler function."""
104+
if inspect.iscoroutinefunction(self.handler):
105+
handler = cast(_AsyncTaskHandler, self.handler)
106+
return await handler(agent_id=agent_id, context=context)
107+
else:
108+
handler = cast(_SyncTaskHandler, self.handler)
109+
return handler(agent_id=agent_id, context=context)
110+
86111
def _infer_context_type(self, handler: TaskHandler) -> type[BaseModel]:
87112
"""Infer context class from handler's type annotations.
88113
@@ -244,24 +269,19 @@ async def execute(self) -> TaskResult | None:
244269
)
245270

246271
try:
247-
result: TaskResult
248-
if asyncio.iscoroutinefunction(self._definition.handler):
249-
result = await self._definition.handler(
250-
agent_id=self.agent_id,
251-
context=self.context,
252-
)
253-
else:
254-
result = self._definition.handler(
255-
agent_id=self.agent_id,
256-
context=self.context,
257-
)
258-
259-
await state.aset_result(
260-
self.agent_id,
261-
result,
262-
ttl_seconds=CONF.result_ttl,
272+
result = await self._definition(
273+
agent_id=self.agent_id,
274+
context=self.context,
263275
)
264276

277+
# TODO ensure we are properly supporting None return values
278+
if isinstance(result, BaseModel):
279+
await state.aset_result(
280+
self.agent_id,
281+
result,
282+
ttl_seconds=CONF.result_ttl,
283+
)
284+
265285
activity.update(
266286
agent_id=self.agent_id,
267287
message=CONF.activity_message_complete,

src/agentexec/pipeline.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
Any,
99
Callable,
1010
ClassVar,
11-
Coroutine,
11+
Protocol,
12+
TypeAlias,
1213
Union,
14+
cast,
1315
get_type_hints,
1416
get_origin,
1517
get_args,
@@ -20,13 +22,47 @@
2022

2123
from agentexec import activity
2224
from agentexec.core import queue
23-
from agentexec.core.task import Task
25+
from agentexec.core.task import Task, TaskResult
2426

2527
if TYPE_CHECKING:
2628
from agentexec.worker import WorkerPool
2729

28-
StepResult = Union[BaseModel, tuple[BaseModel, ...]]
29-
StepHandler = Callable[..., Union[StepResult, Coroutine[Any, Any, StepResult]]]
30+
StepResult: TypeAlias = Union[TaskResult, tuple[TaskResult, ...]]
31+
32+
33+
class _SyncStepHandler(Protocol):
34+
"""Protocol for pipeline step handler methods.
35+
36+
Step handlers are methods on a pipeline class that receive
37+
one or more BaseModel context arguments from the previous step.
38+
"""
39+
40+
__name__: str
41+
42+
def __call__(
43+
self,
44+
instance: _PipelineBase,
45+
**kwargs: BaseModel,
46+
) -> StepResult: ...
47+
48+
49+
class _AsyncStepHandler(Protocol):
50+
"""Protocol for async pipeline step handler methods.
51+
52+
Step handlers are methods on a pipeline class that receive
53+
one or more BaseModel context arguments from the previous step.
54+
"""
55+
56+
__name__: str
57+
58+
async def __call__(
59+
self,
60+
instance: _PipelineBase,
61+
**kwargs: BaseModel,
62+
) -> StepResult: ...
63+
64+
65+
StepHandler: TypeAlias = _SyncStepHandler | _AsyncStepHandler
3066

3167

3268
def _format_pipeline_name(cls: type) -> str:
@@ -38,7 +74,7 @@ class _PipelineBaseMeta(type):
3874
"""Metaclass that registers pipeline subclasses with their bound pipeline."""
3975

4076
@classmethod
41-
def bind_pipeline(mcs, pipeline: Pipeline) -> type:
77+
def bind_pipeline(mcs, pipeline: Pipeline) -> type[_PipelineBase]:
4278
"""Create a new PipelineBase class bound to the given pipeline.
4379
4480
Args:
@@ -66,8 +102,8 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
66102
if cls.__name__ == "PipelineBase":
67103
return
68104

69-
cls._pipeline.bind_user_pipeline(cls)
70-
cls._pipeline.register_task()
105+
cls._pipeline._bind_user_pipeline(cls)
106+
cls._pipeline._register_task()
71107

72108

73109
class StepDefinition:
@@ -76,17 +112,17 @@ class StepDefinition:
76112
name: str
77113
order: Any
78114
handler: StepHandler
79-
return_type: type | None
80-
param_types: dict[str, type | None]
115+
return_type: type[BaseModel] | None
116+
param_types: dict[str, type[BaseModel] | None]
81117
_description: str | None = None
82118

83119
def __init__(
84120
self,
85121
name: str,
86122
order: Any,
87123
handler: StepHandler,
88-
return_type: type | None,
89-
param_types: dict[str, type | None],
124+
return_type: type[BaseModel] | None,
125+
param_types: dict[str, type[BaseModel] | None],
90126
description: str | None = None,
91127
) -> None:
92128
self.name = name
@@ -103,6 +139,15 @@ def description(self) -> str:
103139
return self._description
104140
return self.handler.__name__
105141

142+
async def __call__(self, instance: _PipelineBase, **kwargs: BaseModel) -> StepResult:
143+
"""Invoke the step handler."""
144+
if asyncio.iscoroutinefunction(self.handler):
145+
handler = cast(_AsyncStepHandler, self.handler)
146+
return await handler(instance, **kwargs)
147+
else:
148+
handler = cast(_SyncStepHandler, self.handler)
149+
return handler(instance, **kwargs)
150+
106151

107152
class Pipeline:
108153
"""Orchestrates multi-step task workflows.
@@ -171,15 +216,15 @@ def Base(self) -> type[_PipelineBase]:
171216
"""
172217
return _PipelineBaseMeta.bind_pipeline(self)
173218

174-
def bind_user_pipeline(self, cls: type[_PipelineBase]) -> None:
219+
def _bind_user_pipeline(self, cls: type[_PipelineBase]) -> None:
175220
"""Manually set the pipeline implementation class.
176221
177222
Args:
178223
cls: Pipeline implementation class
179224
"""
180225
self._user_pipeline_class = cls
181226

182-
def register_task(self) -> None:
227+
def _register_task(self) -> None:
183228
"""
184229
Register Task handler with the pipeline's worker pool
185230
"""
@@ -190,7 +235,7 @@ def register_task(self) -> None:
190235

191236
self._pool._add_task(
192237
name=self.name,
193-
func=self._run_task, # type: ignore[arg-type]
238+
func=self._run_task,
194239
context_type=self._input_type,
195240
result_type=self._output_type,
196241
)
@@ -306,9 +351,10 @@ async def run(self, context: BaseModel) -> StepResult:
306351

307352
async def _run_task(
308353
self,
354+
*,
309355
agent_id: UUID,
310356
context: BaseModel,
311-
) -> StepResult:
357+
) -> TaskResult:
312358
"""Run the pipeline as a task handler.
313359
314360
Args:
@@ -330,6 +376,7 @@ async def _run_task(
330376
)
331377
_context = await self._run_step(step, _context)
332378

379+
assert isinstance(_context, TaskResult), "Final step must return BaseModel"
333380
return _context
334381

335382
async def _run_step(
@@ -355,9 +402,7 @@ async def _run_step(
355402
else:
356403
kwargs = {}
357404

358-
if asyncio.iscoroutinefunction(step.handler):
359-
return await step.handler(instance, **kwargs)
360-
return step.handler(instance, **kwargs)
405+
return await step(instance, **kwargs)
361406

362407
def _validate_type_flow(self) -> None:
363408
"""Verify that step return types match next step's parameters.

src/agentexec/runners/openai.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from agents import Agent, MaxTurnsExceeded, Runner, function_tool
66
from agents.items import TResponseInputItem
77
from agents.result import RunResult, RunResultStreaming
8+
from openai.types.responses.easy_input_message_param import EasyInputMessageParam
89

910
from agentexec.runners.base import BaseAgentRunner, _RunnerTools
1011

@@ -29,7 +30,7 @@ def _extract_input(e: MaxTurnsExceeded) -> list[TResponseInputItem]:
2930
final_input: list[TResponseInputItem] = (
3031
list(e.run_data.input)
3132
if isinstance(e.run_data.input, list)
32-
else [{"role": "user", "content": e.run_data.input}]
33+
else [EasyInputMessageParam(role="user", content=e.run_data.input)]
3334
)
3435

3536
# Add all the conversation items that were generated
@@ -107,6 +108,12 @@ def __init__(
107108
# Override with OpenAI-specific tools
108109
self.tools = _OpenAIRunnerTools(self.agent_id)
109110

111+
def _wrap_up_prompt(self) -> EasyInputMessageParam:
112+
return EasyInputMessageParam(
113+
role="system",
114+
content=self.prompts.wrap_up,
115+
)
116+
110117
async def run(
111118
self,
112119
agent: Agent[Any],
@@ -142,12 +149,7 @@ async def run(
142149

143150
logger.info("Max turns exceeded, attempting recovery")
144151
final_input = _extract_input(e)
145-
final_input.append(
146-
{
147-
"role": "user",
148-
"content": self.prompts.wrap_up,
149-
}
150-
)
152+
final_input.append(self._wrap_up_prompt())
151153
result = await Runner.run(
152154
agent,
153155
final_input,
@@ -212,12 +214,7 @@ async def run_streamed(
212214

213215
logger.info("Max turns exceeded, attempting recovery")
214216
final_input = _extract_input(e)
215-
final_input.append(
216-
{
217-
"role": "user",
218-
"content": self.prompts.wrap_up,
219-
}
220-
)
217+
final_input.append(self._wrap_up_prompt())
221218
result = Runner.run_streamed(
222219
agent,
223220
final_input,

0 commit comments

Comments
 (0)