Skip to content

Commit f8a1daf

Browse files
authored
Add ConversationStore (#96)
1 parent 806da0b commit f8a1daf

9 files changed

Lines changed: 515 additions & 97 deletions

File tree

splunklib/ai/README.md

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,83 @@ async with Agent(
264264
) as agent: ...
265265
```
266266

267+
## Conversation stores
268+
269+
By default, each call to `agent.invoke` is stateless - the agent has no memory of previous interactions,
270+
unless you provide the previouis message history explicitly. A conversation store enables the agent to persist
271+
and recall message history across invocations.
272+
273+
### `InMemoryStore`
274+
275+
The built-in `InMemoryStore` keeps conversation history in process memory.
276+
277+
```py
278+
from splunklib.ai import Agent, OpenAIModel
279+
from splunklib.ai.conversation_store import InMemoryStore
280+
from splunklib.ai.messages import HumanMessage
281+
from splunklib.client import connect
282+
283+
model = OpenAIModel(...)
284+
service = connect(...)
285+
286+
async with Agent(
287+
model=model,
288+
service=service,
289+
system_prompt="",
290+
conversation_store=InMemoryStore(),
291+
) as agent:
292+
await agent.invoke([HumanMessage(content="Hi, my name is Chris.")])
293+
result = await agent.invoke([HumanMessage(content="What is my name?")])
294+
print(result.final_message.content) # Chris
295+
```
296+
297+
### Multiple conversation threads
298+
299+
Each conversation is isolated by a `thread_id`. You can pass a `thread_id` per invocation to maintain
300+
separate histories for different users or sessions within the same agent instance.
301+
302+
```py
303+
async with Agent(
304+
model=model,
305+
service=service,
306+
system_prompt="",
307+
conversation_store=InMemoryStore(),
308+
) as agent:
309+
await agent.invoke(
310+
[HumanMessage(content="Hi, my name is Alice.")],
311+
thread_id="user-alice",
312+
)
313+
await agent.invoke(
314+
[HumanMessage(content="Hi, my name is Bob.")],
315+
thread_id="user-bob",
316+
)
317+
318+
result = await agent.invoke(
319+
[HumanMessage(content="What is my name?")],
320+
thread_id="user-alice",
321+
)
322+
print(result.final_message.content) # Alice - Bob's thread is unaffected
323+
```
324+
325+
A custom `thread_id` can also be set on the agent constructor. When `invoke` is called without an explicit
326+
`thread_id`, the `thread_id` from the constructor is used. If no `thread_id` is provided in the constructor, one
327+
is generated implicitly.
328+
329+
```py
330+
async with Agent(
331+
model=model,
332+
service=service,
333+
system_prompt="",
334+
conversation_store=InMemoryStore(),
335+
thread_id="session-42",
336+
) as agent:
337+
await agent.invoke([HumanMessage(content="Hi, my name is Chris.")])
338+
339+
# No thread_id supplied — falls back to "session-42"
340+
result = await agent.invoke([HumanMessage(content="What is my name?")])
341+
print(result.final_message.content) # Chris
342+
```
343+
267344
## Subagents
268345

269346
The `Agent` constructor can accept subagents as input parameters.

splunklib/ai/agent.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
1818
from logging import Logger
1919
from typing import Self, final, override
20+
from uuid import uuid4
2021

2122
from pydantic import BaseModel
2223

2324
from splunklib.ai.base_agent import BaseAgent
25+
from splunklib.ai.conversation_store import ConversationStore
2426
from splunklib.ai.core.backend import AgentImpl
2527
from splunklib.ai.core.backend_registry import get_backend
2628
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
@@ -107,6 +109,29 @@ class Agent(BaseAgent[OutputT]):
107109
logger:
108110
Optional logger instance used for tracing and debugging the agent's execution.
109111
Additionally logs from the local tools are forwarded to this logger.
112+
113+
conversation_store:
114+
Optional `ConversationStore` instance used to persist conversation history
115+
across multiple `invoke` calls. When provided, the agent automatically loads
116+
prior messages for the active thread before each invocation and saves the
117+
full updated history afterwards.
118+
119+
Use the built-in `InMemoryStore` for in-process persistence, or implement
120+
`ConversationStore` to back history with an external store.
121+
122+
Without a store, each `invoke` call is stateless and the agent has no memory
123+
of previous turns.
124+
125+
thread_id:
126+
Identifies the conversation thread used when reading from and writing to the
127+
`conversation_store`. Each unique `thread_id` maintains a separate history,
128+
so different users or sessions can share one store without interference.
129+
130+
If omitted, a random ID is generated automatically. The `thread_id` can
131+
also be overridden per-call by passing it directly to `invoke`.
132+
133+
Never invoke an Agent using the same thread_id more than once concurrently
134+
while using the same conversation_store.
110135
"""
111136

112137
_impl: AgentImpl[OutputT] | None
@@ -129,17 +154,22 @@ def __init__(
129154
name: str = "", # Only used by Subagents
130155
description: str = "", # Only used by Subagents
131156
logger: Logger | None = None,
157+
conversation_store: ConversationStore | None = None,
158+
thread_id: str | None = None,
132159
) -> None:
133160
super().__init__(
134161
model=model,
135162
system_prompt=system_prompt,
136163
name=name,
137164
description=description,
165+
tools=None,
138166
agents=agents,
139167
input_schema=input_schema,
140168
output_schema=output_schema,
141169
middleware=middleware,
142170
logger=logger,
171+
conversation_store=conversation_store,
172+
thread_id=thread_id if thread_id is not None else str(uuid4()),
143173
)
144174

145175
self._use_mcp_tools = use_mcp_tools
@@ -242,12 +272,19 @@ async def __aexit__(
242272
self._agent_context_manager = None
243273
return result
244274

275+
# TODO: for now we have a thread_id as an optional param, should
276+
# we wrap it in a dataclass? Might help with future-proofing the API??
245277
@override
246-
async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]:
278+
async def invoke(
279+
self, messages: list[BaseMessage], thread_id: str | None = None
280+
) -> AgentResponse[OutputT]:
247281
if not self._impl:
248282
raise AssertionError("Agent must be used inside 'async with'")
249283

250-
return await self._impl.invoke(messages)
284+
if thread_id is None:
285+
thread_id = self._thread_id
286+
287+
return await self._impl.invoke(messages, thread_id)
251288

252289

253290
def _local_tools_path() -> tuple[str | None, str]:

splunklib/ai/base_agent.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from pydantic import BaseModel
2222

23+
from splunklib.ai.conversation_store import ConversationStore
2324
from splunklib.ai.messages import AgentResponse, BaseMessage, OutputT
2425
from splunklib.ai.middleware import AgentMiddleware
2526
from splunklib.ai.model import PredefinedModel
@@ -38,19 +39,23 @@ class BaseAgent(Generic[OutputT], ABC):
3839
_middleware: Sequence[AgentMiddleware] | None = None
3940
_trace_id: str
4041
_logger: logging.Logger
42+
_conversation_store: ConversationStore | None = None
43+
_thread_id: str
4144

4245
def __init__(
4346
self,
4447
system_prompt: str,
4548
model: PredefinedModel,
46-
description: str = "",
47-
name: str = "",
48-
tools: Sequence[Tool] | None = None,
49-
agents: Sequence["BaseAgent[BaseModel | None]"] | None = None,
50-
input_schema: type[BaseModel] | None = None,
51-
output_schema: type[OutputT] | None = None,
52-
middleware: Sequence[AgentMiddleware] | None = None,
53-
logger: logging.Logger | None = None,
49+
description: str,
50+
name: str,
51+
tools: Sequence[Tool] | None,
52+
agents: Sequence["BaseAgent[BaseModel | None]"] | None,
53+
input_schema: type[BaseModel] | None,
54+
output_schema: type[OutputT] | None,
55+
middleware: Sequence[AgentMiddleware] | None,
56+
logger: logging.Logger | None,
57+
conversation_store: ConversationStore | None,
58+
thread_id: str,
5459
) -> None:
5560
self._system_prompt = system_prompt
5661
self._model = model
@@ -62,6 +67,8 @@ def __init__(
6267
self._output_schema = output_schema
6368
self._middleware = tuple(middleware) if middleware else ()
6469
self._trace_id = secrets.token_hex(16) # 32 Hex characters
70+
self._conversation_store = conversation_store
71+
self._thread_id = thread_id
6572

6673
if logger is None:
6774
# Create a no-op logger to skip checking for its existence.
@@ -70,7 +77,9 @@ def __init__(
7077
self._logger = logger
7178

7279
@abstractmethod
73-
async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]: ...
80+
async def invoke(
81+
self, messages: list[BaseMessage], thread_id: str | None = None
82+
) -> AgentResponse[OutputT]: ...
7483

7584
@property
7685
def logger(self) -> logging.Logger:
@@ -115,3 +124,11 @@ def middleware(self) -> Sequence[AgentMiddleware] | None:
115124
@property
116125
def trace_id(self) -> str:
117126
return self._trace_id
127+
128+
@property
129+
def conversation_store(self) -> ConversationStore | None:
130+
return self._conversation_store
131+
132+
@property
133+
def default_thread_id(self) -> str:
134+
return self._thread_id

splunklib/ai/conversation_store.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright © 2011-2026 Splunk, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"): you may
4+
# not use this file except in compliance with the License. You may obtain
5+
# a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
# License for the specific language governing permissions and limitations
13+
# under the License.
14+
15+
from collections.abc import Sequence
16+
from typing import Protocol, override
17+
18+
from splunklib.ai.messages import BaseMessage
19+
20+
21+
class ConversationStore(Protocol):
22+
async def get_messages(self, thread_id: str) -> Sequence[BaseMessage]: ...
23+
24+
async def store_messages(
25+
self, thread_id: str, messages: list[BaseMessage]
26+
) -> None: ...
27+
28+
29+
class InMemoryStore(ConversationStore):
30+
_threads: dict[str, Sequence[BaseMessage]]
31+
32+
def __init__(self) -> None:
33+
self._threads = {}
34+
35+
@override
36+
async def get_messages(self, thread_id: str) -> Sequence[BaseMessage]:
37+
return self._threads.get(thread_id, [])
38+
39+
@override
40+
async def store_messages(self, thread_id: str, messages: list[BaseMessage]) -> None:
41+
self._threads[thread_id] = messages.copy()

splunklib/ai/core/backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class InvalidMessageTypeError(Exception):
2929
class AgentImpl(Protocol[OutputT]):
3030
"""Backend-specific agent implementation used by the public `Agent` wrapper."""
3131

32-
async def invoke(self, messages: list[BaseMessage]) -> AgentResponse[OutputT]: ...
32+
async def invoke(
33+
self, messages: list[BaseMessage], thread_id: str
34+
) -> AgentResponse[OutputT]: ...
3335

3436

3537
class Backend(Protocol):

0 commit comments

Comments
 (0)