Skip to content

Commit 7e8d1c3

Browse files
authored
Move LC agent construction logic to the constructor (#108)
1 parent f8a1daf commit 7e8d1c3

1 file changed

Lines changed: 62 additions & 83 deletions

File tree

splunklib/ai/engines/langchain.py

Lines changed: 62 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
from langgraph.types import Command as LC_Command
5151

5252
from splunklib.ai.base_agent import BaseAgent
53-
from splunklib.ai.conversation_store import ConversationStore
5453
from splunklib.ai.core.backend import (
5554
AgentImpl,
5655
Backend,
@@ -125,27 +124,58 @@
125124
ANTHROPIC_CHAT_MODEL_TYPE = "anthropic-chat"
126125

127126

127+
@final
128+
class LangChainBackend(Backend):
129+
@override
130+
async def create_agent(
131+
self,
132+
agent: BaseAgent[OutputT],
133+
) -> AgentImpl[OutputT]:
134+
return LangChainAgentImpl(agent)
135+
136+
128137
@dataclass
129138
class LangChainAgentImpl(AgentImpl[OutputT]):
130139
_agent: CompiledStateGraph[Any]
131-
_output_schema: type[OutputT] | None
132-
_middleware: Sequence[AgentMiddleware]
133-
_conversation_store: ConversationStore | None = None
140+
_sdk_agent: BaseAgent[OutputT]
134141

135-
def __init__(
136-
self,
137-
system_prompt: str,
138-
model: BaseChatModel,
139-
tools: list[BaseTool],
140-
output_schema: type[OutputT] | None,
141-
lc_middleware: list[LC_AgentMiddleware],
142-
middleware: Sequence[AgentMiddleware] | None = None,
143-
conversation_store: ConversationStore | None = None,
144-
) -> None:
142+
def __init__(self, agent: BaseAgent[OutputT]) -> None:
145143
super().__init__()
146-
self._output_schema = output_schema
147-
self._middleware = middleware or []
148-
self._conversation_store = conversation_store
144+
self._sdk_agent = agent
145+
146+
tools = _prepare_langchain_tools(agent.tools)
147+
148+
system_prompt = agent.system_prompt
149+
if agent.agents:
150+
seen_names: set[str] = set()
151+
for subagent in agent.agents:
152+
# Call _agent_as_tool first, so that the empty name exception is
153+
# checked and raised first, before the duplicated name exception.
154+
tool = _agent_as_tool(subagent)
155+
156+
if subagent.name in seen_names:
157+
raise AssertionError(
158+
f"Subagents share the same name: {subagent.name}"
159+
)
160+
161+
seen_names.add(subagent.name)
162+
tools.append(tool)
163+
164+
system_prompt = AGENT_AS_TOOLS_PROMPT + "\n" + system_prompt
165+
166+
before_user_middlewares, after_user_middlewares = _debugging_middleware(
167+
agent.logger
168+
)
169+
170+
middleware = before_user_middlewares
171+
middleware.extend(agent.middleware or [])
172+
middleware.extend(after_user_middlewares)
173+
174+
model_impl = _create_langchain_model(agent.model)
175+
176+
lc_middleware: list[LC_AgentMiddleware] = [
177+
_Middleware(m, model_impl, agent.logger) for m in (middleware or [])
178+
]
149179

150180
# This middleware is executed just after the tool execution and populates
151181
# the artifact field for failed tool calls, since in such cases we can't
@@ -184,10 +214,10 @@ async def awrap_tool_call(
184214
lc_middleware.append(_ToolFailureArtifact())
185215

186216
self._agent = create_agent(
187-
model=model,
217+
model=model_impl,
188218
tools=tools,
189219
system_prompt=system_prompt,
190-
response_format=output_schema,
220+
response_format=agent.output_schema,
191221
middleware=lc_middleware,
192222
)
193223

@@ -211,7 +241,7 @@ def _with_agent_middleware(
211241
# so the first middleware in the list becomes the outermost one.
212242

213243
invoke = agent_invoke
214-
for middleware in reversed(self._middleware):
244+
for middleware in reversed(self._sdk_agent.middleware or []):
215245

216246
def make_next(
217247
m: AgentMiddleware, h: AgentMiddlewareHandler
@@ -237,8 +267,8 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
237267
langchain_msgs = []
238268

239269
# Prepend messages from conversation store.
240-
if self._conversation_store:
241-
msgs = await self._conversation_store.get_messages(thread_id)
270+
if self._sdk_agent.conversation_store:
271+
msgs = await self._sdk_agent.conversation_store.get_messages(thread_id)
242272
langchain_msgs.extend([_map_message_to_langchain(m) for m in msgs])
243273

244274
langchain_msgs.extend([_map_message_to_langchain(m) for m in req.messages])
@@ -258,11 +288,11 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
258288
# if an LLM made any mistakes or not is _always_ up to the developer.
259289

260290
assert (
261-
self._output_schema is None
262-
or type(result["structured_response"]) is self._output_schema
291+
self._sdk_agent.output_schema is None
292+
or type(result["structured_response"]) is self._sdk_agent.output_schema
263293
)
264294

265-
if self._output_schema:
295+
if self._sdk_agent.output_schema:
266296
return AgentResponse(
267297
structured_output=result["structured_response"],
268298
messages=sdk_msgs,
@@ -287,19 +317,19 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
287317
if len(result.messages[-1].calls) != 0:
288318
raise AssertionError("AgentMiddleware included tool calls in AIMessage")
289319

290-
if self._output_schema:
320+
if self._sdk_agent.output_schema:
291321
if result.structured_output is None:
292322
raise AssertionError("Agent middleware discarded a structured output")
293323

294-
if type(result.structured_output) is not self._output_schema:
324+
if type(result.structured_output) is not self._sdk_agent.output_schema:
295325
raise AssertionError(
296-
f"Agent middleware returned an invalid structured_output type: {type(result.structured_output)}, want: {self._output_schema}"
326+
f"Agent middleware returned an invalid structured_output type: {type(result.structured_output)}, want: {self._sdk_agent.output_schema}"
297327
)
298328

299329
# Store the resulting messages in the conversation store, after all
300330
# agent middlewares have been executed.
301-
if self._conversation_store:
302-
await self._conversation_store.store_messages(
331+
if self._sdk_agent.conversation_store:
332+
await self._sdk_agent.conversation_store.store_messages(
303333
thread_id, result.messages
304334
)
305335

@@ -315,8 +345,8 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
315345

316346
# Store the resulting messages in the conversation store, after all
317347
# agent middlewares have been executed.
318-
if self._conversation_store:
319-
await self._conversation_store.store_messages(
348+
if self._sdk_agent.conversation_store:
349+
await self._sdk_agent.conversation_store.store_messages(
320350
thread_id, result.messages
321351
)
322352

@@ -337,57 +367,6 @@ def _prepare_langchain_tools(agent_tools: Sequence[Tool]) -> list[BaseTool]:
337367
return tools
338368

339369

340-
@final
341-
class LangChainBackend(Backend):
342-
@override
343-
async def create_agent(
344-
self,
345-
agent: BaseAgent[OutputT],
346-
) -> AgentImpl[OutputT]:
347-
tools = _prepare_langchain_tools(agent.tools)
348-
349-
system_prompt = agent.system_prompt
350-
if agent.agents:
351-
seen_names: set[str] = set()
352-
for subagent in agent.agents:
353-
# Call _agent_as_tool first, so that the empty name exception is
354-
# checked and raised first, before the duplicated name exception.
355-
tool = _agent_as_tool(subagent)
356-
357-
if subagent.name in seen_names:
358-
raise AssertionError(
359-
f"Subagents share the same name: {subagent.name}"
360-
)
361-
362-
seen_names.add(subagent.name)
363-
tools.append(tool)
364-
365-
system_prompt = AGENT_AS_TOOLS_PROMPT + "\n" + system_prompt
366-
367-
before_user_middlewares, after_user_middlewares = _debugging_middleware(
368-
agent.logger
369-
)
370-
371-
middleware = before_user_middlewares
372-
middleware.extend(agent.middleware or [])
373-
middleware.extend(after_user_middlewares)
374-
375-
model_impl = _create_langchain_model(agent.model)
376-
lc_middleware: list[LC_AgentMiddleware] = [
377-
_Middleware(m, model_impl, agent.logger) for m in middleware or []
378-
]
379-
380-
return LangChainAgentImpl(
381-
system_prompt=system_prompt,
382-
model=model_impl,
383-
tools=tools,
384-
output_schema=agent.output_schema,
385-
lc_middleware=lc_middleware,
386-
middleware=agent.middleware,
387-
conversation_store=agent.conversation_store,
388-
)
389-
390-
391370
class _Middleware(LC_AgentMiddleware):
392371
_middleware: AgentMiddleware
393372
_model: BaseChatModel

0 commit comments

Comments
 (0)