@@ -146,6 +146,7 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
146146 tools = _prepare_langchain_tools (agent .tools )
147147
148148 system_prompt = agent .system_prompt
149+ structured_subagents : list [str ] = []
149150 if agent .agents :
150151 seen_names : set [str ] = set ()
151152 for subagent in agent .agents :
@@ -161,6 +162,9 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
161162 seen_names .add (subagent .name )
162163 tools .append (tool )
163164
165+ if subagent .input_schema is not None :
166+ structured_subagents .append (subagent .name )
167+
164168 system_prompt = AGENT_AS_TOOLS_PROMPT + "\n " + system_prompt
165169
166170 before_user_middlewares , after_user_middlewares = _debugging_middleware (
@@ -211,7 +215,87 @@ async def awrap_tool_call(
211215
212216 return resp
213217
218+ class _SubagentArgumentPacker (LC_AgentMiddleware ):
219+ # For non-structured subagents, the SubagentCall.args field is an `str | dict[str, Any]`,
220+ # to differentiate that we wrap the resulting args in an SubagentLCArgs.
221+ #
222+ # This middleware performs the corresponding pack/unpack at the two
223+ # points in the LangChain call graph where raw args are needed/retreived.
224+ #
225+ # TODO: once we move middlewares into one LC middleware, we should move
226+ # that piece of logic there (DVPL-12959).
227+ @override
228+ async def awrap_model_call (
229+ self ,
230+ request : LC_ModelRequest ,
231+ handler : Callable [[LC_ModelRequest ], Awaitable [LC_ModelCallResult ]],
232+ ) -> LC_ModelCallResult :
233+ # Unpack existing messages.
234+ messages : list [LC_AnyMessage ] = []
235+ for msg in request .messages :
236+ if isinstance (msg , LC_AIMessage ):
237+ new_calls : list [LC_ToolCall ] = []
238+ for call in msg .tool_calls :
239+ new_calls .append (self .unpack_tool_call (call ))
240+ msg = msg .model_copy (update = {"tool_calls" : new_calls })
241+ messages .append (msg )
242+
243+ response = await handler (request .override (messages = messages ))
244+
245+ ai_message = response
246+ if isinstance (ai_message , LC_ExtendedModelResponse ):
247+ ai_message = ai_message .model_response
248+ if isinstance (ai_message , LC_ModelResponse ):
249+ ai_message = next (
250+ (m for m in ai_message .result if isinstance (m , LC_AIMessage )),
251+ None ,
252+ )
253+ assert ai_message , "AIMessage not found found in response"
254+
255+ # Pack new message.
256+ for call in ai_message .tool_calls :
257+ if call ["name" ].startswith (AGENT_PREFIX ):
258+ if (
259+ _denormalize_agent_name (call ["name" ])
260+ in structured_subagents
261+ ):
262+ args = SubagentLCArgs (call ["args" ])
263+ else :
264+ content : str = call ["args" ].get ("content" , "" )
265+ args = SubagentLCArgs (content )
266+ call ["args" ] = asdict (args )
267+
268+ return response
269+
270+ # Unpack args, just before tool call.
271+ @override
272+ async def awrap_tool_call (
273+ self ,
274+ request : LC_ToolCallRequest ,
275+ handler : Callable [
276+ [LC_ToolCallRequest ], Awaitable [LC_ToolMessage | LC_Command [None ]]
277+ ],
278+ ) -> LC_ToolMessage | LC_Command [None ]:
279+ return await handler (
280+ request .override (
281+ tool_call = self .unpack_tool_call (request .tool_call ),
282+ )
283+ )
284+
285+ def unpack_tool_call (self , call : LC_ToolCall ) -> LC_ToolCall :
286+ if call ["name" ].startswith (AGENT_PREFIX ):
287+ unpacked_args = SubagentLCArgs (** call ["args" ]).args
288+ if isinstance (unpacked_args , str ):
289+ unpacked_args = {"content" : unpacked_args }
290+ return LC_ToolCall (
291+ id = call ["id" ],
292+ name = call ["name" ],
293+ args = unpacked_args ,
294+ )
295+ return call
296+
214297 lc_middleware .append (_ToolFailureArtifact ())
298+ lc_middleware .append (_SubagentArgumentPacker ())
215299
216300 self ._agent = create_agent (
217301 model = model_impl ,
@@ -933,12 +1017,17 @@ async def _run(
9331017 )
9341018
9351019
1020+ @dataclass (frozen = True )
1021+ class SubagentLCArgs :
1022+ args : str | dict [str , Any ]
1023+
1024+
9361025def _map_tool_call_from_langchain (tool_call : LC_ToolCall ) -> ToolCall | SubagentCall :
9371026 name = tool_call ["name" ]
9381027 if name .startswith (AGENT_PREFIX ):
9391028 return SubagentCall (
9401029 name = _denormalize_agent_name (name ),
941- args = tool_call ["args" ],
1030+ args = SubagentLCArgs ( ** tool_call ["args" ]). args ,
9421031 id = tool_call ["id" ],
9431032 )
9441033
@@ -957,10 +1046,12 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
9571046 match call :
9581047 case SubagentCall ():
9591048 name = _normalize_agent_name (call .name )
1049+ args = asdict (SubagentLCArgs (call .args ))
9601050 case ToolCall ():
9611051 name = _normalize_tool_name (call .name , call .type )
1052+ args = call .args
9621053
963- return LC_ToolCall (id = call .id , name = name , args = call . args )
1054+ return LC_ToolCall (id = call .id , name = name , args = args )
9641055
9651056
9661057def _map_message_from_langchain (message : LC_BaseMessage ) -> BaseMessage :
0 commit comments