1- from typing import TYPE_CHECKING , Optional , Any , Callable
1+ from typing import Optional , Any , Callable
22from pathlib import Path
33import ast
44from agentstack import conf , log
55from agentstack .exceptions import ValidationError
6+ from agentstack .generation import InsertionPoint
67from agentstack ._tools import ToolConfig
78from agentstack .tasks import TaskConfig
89from agentstack .agents import AgentConfig
910from agentstack .generation import asttools
1011from agentstack import graph
1112
12- if TYPE_CHECKING :
13- from agentstack .generation import InsertionPoint
1413
1514NAME : str = "OpenAI Swarm"
1615ENTRYPOINT : Path = Path ('src/stack.py' )
@@ -62,19 +61,54 @@ def add_task_method(self, task: TaskConfig):
6261 pos , _ = self .get_node_range (main_method )
6362
6463 code = f""" @agentstack.task
65- def { task .name } (self) :
64+ def { task .name } (self, messages: list[str] = []) -> Agent :
6665 task_config = agentstack.get_task('{ task .name } ')
66+ agent = getattr(self, task_config.agent)
6767 messages = [
68+ *messages,
6869 task_config.prompt,
6970 ]
70- agent = getattr(self, task_config.agent)
7171 return agent(messages)"""
7272
7373 if not self .source [:pos ].endswith ('\n ' ):
7474 code = '\n \n ' + code
7575 if not self .source [pos :].startswith ('\n ' ):
7676 code += '\n \n '
7777 self .edit_node_range (pos , pos , code )
78+
79+ # add a new task to the last agent in the stack
80+ existing_agent_methods = self .get_agent_methods ()
81+ if not len (existing_agent_methods ):
82+ return # no agents to update
83+
84+ # add a call to `self._handoff(task_name)` to the front of the update_method's
85+ # `function` argument which is a list of functions
86+ update_method = existing_agent_methods [- 1 ]
87+ try :
88+ agent_instance = asttools .find_method_calls (update_method , 'Agent' )[0 ]
89+ except IndexError :
90+ raise ValidationError (f"Agent method `{ update_method .name } ` does not instantiate `Agent` in { ENTRYPOINT } " )
91+
92+ existing_agent_tools = asttools .find_kwarg_in_method_call (agent_instance , 'functions' )
93+ if not existing_agent_tools :
94+ raise ValidationError (
95+ f"`@agent` method `{ update_method .name } ` does not have a keyword argument `functions` in { ENTRYPOINT } "
96+ )
97+
98+ assert isinstance (existing_agent_tools .value , ast .List )
99+ existing_elts = existing_agent_tools .value .elts
100+ existing_elts .insert (0 , ast .Call (
101+ func = ast .Attribute (
102+ value = ast .Name (id = 'self' , ctx = ast .Load ()),
103+ attr = '_handoff' ,
104+ ctx = ast .Load (),
105+ ),
106+ args = [ast .Constant (value = task .name )],
107+ keywords = [],
108+ ))
109+ new_node = ast .List (elts = existing_elts , ctx = ast .Load ())
110+ start , end = self .get_node_range (existing_agent_tools .value )
111+ self .edit_node_range (start , end , new_node )
78112
79113 def get_agent_methods (self ) -> list [ast .FunctionDef ]:
80114 """An `agent` method is a method decorated with `@agent`."""
@@ -92,12 +126,16 @@ def add_agent_method(self, agent: AgentConfig) -> None:
92126 pos , _ = self .get_node_range (main_method )
93127
94128 code = f""" @agentstack.agent
95- def { agent .name } (self, messages: list[str] = []):
129+ def { agent .name } (self, messages: list[str] = []) -> Agent :
96130 agent_config = agentstack.get_agent('{ agent .name } ')
131+ messages = [
132+ agent_config.prompt,
133+ *messages,
134+ ]
97135 return Agent(
98136 name=agent_config.name,
99137 model=agent_config.model,
100- instructions='\\ n'.join([agent_config.prompt, * messages, ] ),
138+ instructions='\\ n'.join(messages),
101139 functions=[],
102140 )"""
103141
@@ -259,7 +297,7 @@ def get_agent_tool_names(agent_name: str) -> list[Any]:
259297 return entrypoint .get_agent_tool_names (agent_name )
260298
261299
262- def add_agent (agent : AgentConfig , position : Optional [' InsertionPoint' ] = None ) -> None :
300+ def add_agent (agent : AgentConfig , position : Optional [InsertionPoint ] = None ) -> None :
263301 """
264302 Add an agent method to the entrypoint.
265303 """
0 commit comments