Skip to content

Commit 8ead2ef

Browse files
committed
Add new tasks as delegates of the last agent in the project.
1 parent 5f7fca9 commit 8ead2ef

2 files changed

Lines changed: 60 additions & 15 deletions

File tree

  • agentstack

agentstack/frameworks/openai_swarm.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1-
from typing import TYPE_CHECKING, Optional, Any, Callable
1+
from typing import Optional, Any, Callable
22
from pathlib import Path
33
import ast
44
from agentstack import conf, log
55
from agentstack.exceptions import ValidationError
6+
from agentstack.generation import InsertionPoint
67
from agentstack._tools import ToolConfig
78
from agentstack.tasks import TaskConfig
89
from agentstack.agents import AgentConfig
910
from agentstack.generation import asttools
1011
from agentstack import graph
1112

12-
if TYPE_CHECKING:
13-
from agentstack.generation import InsertionPoint
1413

1514
NAME: str = "OpenAI Swarm"
1615
ENTRYPOINT: Path = Path('src/stack.py')
@@ -62,19 +61,54 @@ def add_task_method(self, task: TaskConfig):
6261
pos, _ = self.get_node_range(main_method)
6362

6463
code = f""" @agentstack.task
65-
def {task.name}(self):
64+
def {task.name}(self, messages: list[str] = []) -> Agent:
6665
task_config = agentstack.get_task('{task.name}')
66+
agent = getattr(self, task_config.agent)
6767
messages = [
68+
*messages,
6869
task_config.prompt,
6970
]
70-
agent = getattr(self, task_config.agent)
7171
return agent(messages)"""
7272

7373
if not self.source[:pos].endswith('\n'):
7474
code = '\n\n' + code
7575
if not self.source[pos:].startswith('\n'):
7676
code += '\n\n'
7777
self.edit_node_range(pos, pos, code)
78+
79+
# add a new task to the last agent in the stack
80+
existing_agent_methods = self.get_agent_methods()
81+
if not len(existing_agent_methods):
82+
return # no agents to update
83+
84+
# add a call to `self._handoff(task_name)` to the front of the update_method's
85+
# `function` argument which is a list of functions
86+
update_method = existing_agent_methods[-1]
87+
try:
88+
agent_instance = asttools.find_method_calls(update_method, 'Agent')[0]
89+
except IndexError:
90+
raise ValidationError(f"Agent method `{update_method.name}` does not instantiate `Agent` in {ENTRYPOINT}")
91+
92+
existing_agent_tools = asttools.find_kwarg_in_method_call(agent_instance, 'functions')
93+
if not existing_agent_tools:
94+
raise ValidationError(
95+
f"`@agent` method `{update_method.name}` does not have a keyword argument `functions` in {ENTRYPOINT}"
96+
)
97+
98+
assert isinstance(existing_agent_tools.value, ast.List)
99+
existing_elts = existing_agent_tools.value.elts
100+
existing_elts.insert(0, ast.Call(
101+
func=ast.Attribute(
102+
value=ast.Name(id='self', ctx=ast.Load()),
103+
attr='_handoff',
104+
ctx=ast.Load(),
105+
),
106+
args=[ast.Constant(value=task.name)],
107+
keywords=[],
108+
))
109+
new_node = ast.List(elts=existing_elts, ctx=ast.Load())
110+
start, end = self.get_node_range(existing_agent_tools.value)
111+
self.edit_node_range(start, end, new_node)
78112

79113
def get_agent_methods(self) -> list[ast.FunctionDef]:
80114
"""An `agent` method is a method decorated with `@agent`."""
@@ -92,12 +126,16 @@ def add_agent_method(self, agent: AgentConfig) -> None:
92126
pos, _ = self.get_node_range(main_method)
93127

94128
code = f""" @agentstack.agent
95-
def {agent.name}(self, messages: list[str] = []):
129+
def {agent.name}(self, messages: list[str] = []) -> Agent:
96130
agent_config = agentstack.get_agent('{agent.name}')
131+
messages = [
132+
agent_config.prompt,
133+
*messages,
134+
]
97135
return Agent(
98136
name=agent_config.name,
99137
model=agent_config.model,
100-
instructions='\\n'.join([agent_config.prompt, *messages, ]),
138+
instructions='\\n'.join(messages),
101139
functions=[],
102140
)"""
103141

@@ -259,7 +297,7 @@ def get_agent_tool_names(agent_name: str) -> list[Any]:
259297
return entrypoint.get_agent_tool_names(agent_name)
260298

261299

262-
def add_agent(agent: AgentConfig, position: Optional['InsertionPoint'] = None) -> None:
300+
def add_agent(agent: AgentConfig, position: Optional[InsertionPoint] = None) -> None:
263301
"""
264302
Add an agent method to the entrypoint.
265303
"""

agentstack/templates/openai_swarm/{{cookiecutter.project_metadata.project_slug}}/src/stack.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
class {{ cookiecutter.project_metadata.project_name|replace('-', '')|replace('_', '')|capitalize }}Stack:
66

7-
def _handoff(self, agent_name: str):
8-
"""Return a tool for handing off to another agent."""
9-
agent = getattr(self, agent_name)
10-
def func(context_variables: list[str]):
11-
return agent(context_variables)
7+
def _handoff(self, task_name: str):
8+
"""Return a task formatted as a tool for handing off to another agent."""
9+
task = getattr(self, task_name)
10+
def func(messages: list[str] = []) -> Agent:
11+
return task(messages=messages)
12+
func.__name__ = task_name
1213
return func
1314

14-
def _get_first_task(self):
15+
def _get_first_task(self) -> Agent:
1516
"""Get the first task."""
1617
task_name = agentstack.get_all_task_names()[0]
1718
return getattr(self, task_name)()
@@ -26,5 +27,11 @@ def run(self, inputs: list[str]):
2627
)
2728

2829
for message in response.messages:
29-
agentstack.log.info(message['content'])
30+
if message.get('tool_calls'):
31+
for tool_call in message['tool_calls']:
32+
agentstack.log.notify(f"Calling tool `{tool_call['function']['name']}`")
33+
agentstack.log.debug(tool_call['function']['arguments'])
34+
elif message.get('role') != 'tool':
35+
agentstack.log.notify(f"{message['role']}:")
36+
agentstack.log.info(message['content'])
3037

0 commit comments

Comments
 (0)