Skip to content

Commit 8181104

Browse files
authored
Reflect unstructured inputs in SubagentCall.args (#110)
1 parent 7e8d1c3 commit 8181104

4 files changed

Lines changed: 130 additions & 11 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
146146
tools = _prepare_langchain_tools(agent.tools)
147147

148148
system_prompt = agent.system_prompt
149+
structured_subagents: list[str] = []
149150
if agent.agents:
150151
seen_names: set[str] = set()
151152
for subagent in agent.agents:
@@ -161,6 +162,9 @@ def __init__(self, agent: BaseAgent[OutputT]) -> None:
161162
seen_names.add(subagent.name)
162163
tools.append(tool)
163164

165+
if subagent.input_schema is not None:
166+
structured_subagents.append(subagent.name)
167+
164168
system_prompt = AGENT_AS_TOOLS_PROMPT + "\n" + system_prompt
165169

166170
before_user_middlewares, after_user_middlewares = _debugging_middleware(
@@ -211,7 +215,87 @@ async def awrap_tool_call(
211215

212216
return resp
213217

218+
class _SubagentArgumentPacker(LC_AgentMiddleware):
219+
# For non-structured subagents, the SubagentCall.args field is an `str | dict[str, Any]`,
220+
# to differentiate that we wrap the resulting args in an SubagentLCArgs.
221+
#
222+
# This middleware performs the corresponding pack/unpack at the two
223+
# points in the LangChain call graph where raw args are needed/retreived.
224+
#
225+
# TODO: once we move middlewares into one LC middleware, we should move
226+
# that piece of logic there (DVPL-12959).
227+
@override
228+
async def awrap_model_call(
229+
self,
230+
request: LC_ModelRequest,
231+
handler: Callable[[LC_ModelRequest], Awaitable[LC_ModelCallResult]],
232+
) -> LC_ModelCallResult:
233+
# Unpack existing messages.
234+
messages: list[LC_AnyMessage] = []
235+
for msg in request.messages:
236+
if isinstance(msg, LC_AIMessage):
237+
new_calls: list[LC_ToolCall] = []
238+
for call in msg.tool_calls:
239+
new_calls.append(self.unpack_tool_call(call))
240+
msg = msg.model_copy(update={"tool_calls": new_calls})
241+
messages.append(msg)
242+
243+
response = await handler(request.override(messages=messages))
244+
245+
ai_message = response
246+
if isinstance(ai_message, LC_ExtendedModelResponse):
247+
ai_message = ai_message.model_response
248+
if isinstance(ai_message, LC_ModelResponse):
249+
ai_message = next(
250+
(m for m in ai_message.result if isinstance(m, LC_AIMessage)),
251+
None,
252+
)
253+
assert ai_message, "AIMessage not found found in response"
254+
255+
# Pack new message.
256+
for call in ai_message.tool_calls:
257+
if call["name"].startswith(AGENT_PREFIX):
258+
if (
259+
_denormalize_agent_name(call["name"])
260+
in structured_subagents
261+
):
262+
args = SubagentLCArgs(call["args"])
263+
else:
264+
content: str = call["args"].get("content", "")
265+
args = SubagentLCArgs(content)
266+
call["args"] = asdict(args)
267+
268+
return response
269+
270+
# Unpack args, just before tool call.
271+
@override
272+
async def awrap_tool_call(
273+
self,
274+
request: LC_ToolCallRequest,
275+
handler: Callable[
276+
[LC_ToolCallRequest], Awaitable[LC_ToolMessage | LC_Command[None]]
277+
],
278+
) -> LC_ToolMessage | LC_Command[None]:
279+
return await handler(
280+
request.override(
281+
tool_call=self.unpack_tool_call(request.tool_call),
282+
)
283+
)
284+
285+
def unpack_tool_call(self, call: LC_ToolCall) -> LC_ToolCall:
286+
if call["name"].startswith(AGENT_PREFIX):
287+
unpacked_args = SubagentLCArgs(**call["args"]).args
288+
if isinstance(unpacked_args, str):
289+
unpacked_args = {"content": unpacked_args}
290+
return LC_ToolCall(
291+
id=call["id"],
292+
name=call["name"],
293+
args=unpacked_args,
294+
)
295+
return call
296+
214297
lc_middleware.append(_ToolFailureArtifact())
298+
lc_middleware.append(_SubagentArgumentPacker())
215299

216300
self._agent = create_agent(
217301
model=model_impl,
@@ -933,12 +1017,17 @@ async def _run(
9331017
)
9341018

9351019

1020+
@dataclass(frozen=True)
1021+
class SubagentLCArgs:
1022+
args: str | dict[str, Any]
1023+
1024+
9361025
def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | SubagentCall:
9371026
name = tool_call["name"]
9381027
if name.startswith(AGENT_PREFIX):
9391028
return SubagentCall(
9401029
name=_denormalize_agent_name(name),
941-
args=tool_call["args"],
1030+
args=SubagentLCArgs(**tool_call["args"]).args,
9421031
id=tool_call["id"],
9431032
)
9441033

@@ -957,10 +1046,12 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
9571046
match call:
9581047
case SubagentCall():
9591048
name = _normalize_agent_name(call.name)
1049+
args = asdict(SubagentLCArgs(call.args))
9601050
case ToolCall():
9611051
name = _normalize_tool_name(call.name, call.type)
1052+
args = call.args
9621053

963-
return LC_ToolCall(id=call.id, name=name, args=call.args)
1054+
return LC_ToolCall(id=call.id, name=name, args=args)
9641055

9651056

9661057
def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:

splunklib/ai/messages.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ class ToolCall:
3232
@dataclass(frozen=True)
3333
class SubagentCall:
3434
name: str
35-
# TODO: should be a str | dict[str, Any] for subagents without structured inputs
36-
args: dict[str, Any]
35+
args: str | dict[str, Any]
3736
id: str | None # TODO: can be None?
3837

3938

tests/integration/ai/test_agent.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic import BaseModel, Field
1717

1818
from splunklib.ai import Agent
19-
from splunklib.ai.messages import HumanMessage, SubagentMessage
19+
from splunklib.ai.messages import AIMessage, HumanMessage, SubagentCall, SubagentMessage
2020
from tests.ai_testlib import AITestCase
2121

2222
OPENAI_BASE_URL = "http://localhost:11434/v1"
@@ -175,7 +175,17 @@ class NicknameGeneratorInput(BaseModel):
175175
]
176176
)
177177

178-
response = result.final_message.content
178+
first_ai_message = next(
179+
m for m in result.messages if isinstance(m, AIMessage)
180+
)
181+
assert first_ai_message
182+
assert len(first_ai_message.calls) == 1
183+
assert isinstance(first_ai_message.calls[0], SubagentCall)
184+
args = first_ai_message.calls[0].args
185+
assert isinstance(args, dict)
186+
187+
# asserts that can create NicknameGeneratorInput from args
188+
NicknameGeneratorInput(**args)
179189

180190
subagent_message = next(
181191
filter(lambda m: m.role == "subagent", result.messages), None
@@ -184,6 +194,8 @@ class NicknameGeneratorInput(BaseModel):
184194
"Invalid subagent message"
185195
)
186196
assert subagent_message, "No subagent message found in response"
197+
198+
response = result.final_message.content
187199
assert "Chris-zilla" in response, "Agent did generate valid nickname"
188200

189201
@pytest.mark.asyncio
@@ -217,6 +229,15 @@ async def test_subagent_without_input_schema(self):
217229
]
218230
)
219231

232+
first_ai_message = next(
233+
m for m in result.messages if isinstance(m, AIMessage)
234+
)
235+
assert first_ai_message
236+
assert len(first_ai_message.calls) == 1
237+
assert isinstance(first_ai_message.calls[0], SubagentCall)
238+
assert isinstance(first_ai_message.calls[0].args, str)
239+
assert first_ai_message.calls[0].args.lower() == "chris"
240+
220241
response = result.final_message.content
221242
assert "Chris-zilla" in response, "Agent did generate valid nickname"
222243

tests/unit/ai/engine/test_langchain_backend.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_map_message_from_langchain_ai_with_tool_calls(self) -> None:
5858

5959
def test_map_message_from_langchain_ai_with_agent_call(self) -> None:
6060
tool_call = LC_ToolCall(
61-
name=f"{lc.AGENT_PREFIX}assistant", args={"q": "test"}, id="tc-2"
61+
name=f"{lc.AGENT_PREFIX}assistant", args={"args": {"q": "test"}}, id="tc-2"
6262
)
6363
message = LC_AIMessage(content="done", tool_calls=[tool_call])
6464
mapped = lc._map_message_from_langchain(message)
@@ -75,7 +75,7 @@ def test_map_message_from_langchain_ai_with_agent_call(self) -> None:
7575
def test_map_message_from_langchain_ai_with_mixed_calls(self) -> None:
7676
tool_call = LC_ToolCall(name="lookup", args={"q": "test"}, id="tc-1")
7777
agent_call = LC_ToolCall(
78-
name=f"{lc.AGENT_PREFIX}assistant", args={"q": "test"}, id="tc-2"
78+
name=f"{lc.AGENT_PREFIX}assistant", args={"args": {"q": "test"}}, id="tc-2"
7979
)
8080
message = LC_AIMessage(content="done", tool_calls=[tool_call, agent_call])
8181

@@ -155,14 +155,22 @@ def test_map_message_to_langchain_ai(self) -> None:
155155
def test_map_message_to_langchain_ai_with_agent_call(self) -> None:
156156
message = AIMessage(
157157
content="hi",
158-
calls=[SubagentCall(name="assistant", args={"q": "test"}, id="tc-2")],
158+
calls=[
159+
SubagentCall(
160+
name="assistant",
161+
args={"q": "test"},
162+
id="tc-2",
163+
)
164+
],
159165
)
160166
mapped = lc._map_message_to_langchain(message)
161167

162168
assert isinstance(mapped, LC_AIMessage)
163169
assert mapped.tool_calls == [
164170
LC_ToolCall(
165-
name=f"{lc.AGENT_PREFIX}assistant", args={"q": "test"}, id="tc-2"
171+
name=f"{lc.AGENT_PREFIX}assistant",
172+
args={"args": {"q": "test"}},
173+
id="tc-2",
166174
)
167175
]
168176

@@ -268,7 +276,7 @@ def test_map_message_to_langchain_agent_call_with_agent_prefix_raises(
268276
# Fine, but in practice a unnecessary prefix.
269277
assert isinstance(message, LC_AIMessage)
270278
assert message.tool_calls == [
271-
LC_ToolCall(name="__agent-__agent-bad-agent", args={}, id="tc-1")
279+
LC_ToolCall(name="__agent-__agent-bad-agent", args={"args": {}}, id="tc-1")
272280
]
273281

274282
def test_map_message_to_langchain_system(self) -> None:

0 commit comments

Comments
 (0)