Skip to content

Commit 1843a05

Browse files
committed
fixed summarization issues
1 parent 53ea418 commit 1843a05

1 file changed

Lines changed: 7 additions & 4 deletions

File tree

willa/chatbot/graph_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class WillaChatbotState(TypedDict):
1919
messages: Annotated[list[AnyMessage], add_messages]
2020
filtered_messages: NotRequired[list[AnyMessage]]
2121
summarized_messages: NotRequired[list[AnyMessage]]
22+
messages_for_generation: NotRequired[list[AnyMessage]]
2223
search_query: NotRequired[str]
2324
tind_metadata: NotRequired[str]
2425
documents: NotRequired[list[Any]]
@@ -41,7 +42,6 @@ def _create_workflow(self) -> CompiledStateGraph:
4142
summarization_node = SummarizationNode(
4243
max_tokens=int(CONFIG['SUMMARIZATION_MAX_TOKENS']),
4344
model=self._model,
44-
token_counter=self._model.get_num_tokens_from_messages,
4545
input_messages_key="filtered_messages",
4646
output_messages_key="summarized_messages"
4747
)
@@ -70,7 +70,10 @@ def _filter_messages(self, state: WillaChatbotState) -> dict[str, list[AnyMessag
7070
"""Filter out TIND messages from the conversation history."""
7171
messages = state["messages"]
7272

73-
filtered = [msg for msg in messages if 'tind' not in msg.response_metadata]
73+
filtered = [
74+
msg for msg in messages
75+
if 'tind' not in msg.response_metadata and msg.type != "system"
76+
]
7477
return {"filtered_messages": filtered}
7578

7679
def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
@@ -132,14 +135,14 @@ def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[An
132135
else:
133136
all_messages = summarized_conversation + [system_messages]
134137

135-
return {"messages": all_messages}
138+
return {"messages_for_generation": all_messages}
136139

137140
def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
138141
"""Generate response using the model."""
139142
tind_metadata = state.get("tind_metadata", "")
140143
model = self._model
141144
documents = state.get("documents", [])
142-
messages = state["messages"]
145+
messages = state["messages_for_generation"]
143146

144147
if not model:
145148
return {"messages": [AIMessage(content="Model not available.")]}

0 commit comments

Comments
 (0)