diff --git a/src/core/client.py b/src/core/client.py index 085a0b8..9c2e57c 100644 --- a/src/core/client.py +++ b/src/core/client.py @@ -1,12 +1,12 @@ import asyncio import json -import os from dataclasses import asdict -from typing import Dict, List, Optional, Type +from typing import Dict, List, Type import websockets from src.core.dataclasses.config import ClientConfig +from src.core.user_config import get_or_create_user_id from .client_senders import ClientSender, get_senders @@ -17,23 +17,13 @@ class Client: def __init__( self, config: ClientConfig, - user_id_file: str = os.path.expanduser("~/.huri_user_id"), + user_id_file: str | None = None, senders_dict: Dict[str, Type[ClientSender]] = get_senders(), ): self.config = config self.user_id_file = user_id_file self.senders_dict = senders_dict - def _load_user_id(self) -> Optional[str]: - if os.path.exists(self.user_id_file): - with open(self.user_id_file) as f: - return f.read().strip() - return None - - def _save_user_id(self, _user_id: str): - with open(self.user_id_file, "w") as f: - f.write(_user_id) - async def _receive_loop(self, ws: websockets.ClientConnection): try: while True: @@ -48,7 +38,7 @@ async def run(self): async with websockets.connect(self.config.huri_url) as ws: print("Connected to server") - self.config.user_id = self._load_user_id() + self.config.user_id = get_or_create_user_id(self.user_id_file) senders: List[ClientSender] = [ self.senders_dict[config.name](ws=ws, **config.args) @@ -60,7 +50,6 @@ async def run(self): init_msg = json.loads(await ws.recv()) if init_msg.get("type") == "session_init": user_id = init_msg["user_id"] - self._save_user_id(user_id) print(f"Session started with _user_id: {user_id}") receive_task = asyncio.create_task(self._receive_loop(ws)) diff --git a/src/core/user_config.py b/src/core/user_config.py new file mode 100644 index 0000000..19a74ef --- /dev/null +++ b/src/core/user_config.py @@ -0,0 +1,60 @@ +import os +import platform +import uuid +from pathlib import Path + + +def get_config_dir() -> Path: + """Cross-platform config directory.""" + system = platform.system() + + if system == "Windows": + # TODO: To be tested -> also consider language-specific if needed + base = os.environ.get("APPDATA", os.path.expanduser("~/AppData/Roaming")) + elif system == "Darwin": + # TODO: To be tested -> also consider language-specific if needed + base = os.path.expanduser("~/Library/Application Support") + else: + base = os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) + + config_dir = Path(base) / "huri" + config_dir.mkdir(parents=True, exist_ok=True) + return config_dir + + +def load_user_id(path: str | None = None) -> str | None: + """Load existing _user_id, or return None if new user.""" + id_file: Path + + if path is None: + id_file = get_config_dir() / "_user_id" + else: + id_file = Path(path) + if id_file.exists(): + uid = id_file.read_text().strip() + if uid: + return uid + return None + + +def save_user_id(_user_id: str, path: str | None = None): + id_file: Path + + if path is None: + id_file = get_config_dir() / "_user_id" + else: + id_file = Path(path) + + id_file.write_text(_user_id) + if platform.system() != "Windows": + id_file.chmod(0o600) + + +def get_or_create_user_id(path: str | None = None) -> str: + """Load existing or generate new _user_id.""" + uid = load_user_id(path) + if uid: + return uid + uid = str(uuid.uuid4()) + save_user_id(uid, path) + return uid diff --git a/src/modules/rag/__init__.py b/src/modules/rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index f4e4dae..a50f747 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -1,5 +1,4 @@ import argparse -import os import re import sys import uuid @@ -17,10 +16,10 @@ PointStruct, VectorParams, ) -from semantic_chunker import SemanticChunker from sentence_transformers import SentenceTransformer -USER_ID_FILE = os.path.expanduser("~/.huri_user_id") +from src.core.user_config import get_or_create_user_id +from src.modules.rag.semantic_chunker import SemanticChunker def _split_sentences(text: str) -> list[str]: @@ -89,21 +88,6 @@ def extract_text_from_pdf(pdf_path: str) -> str: sys.exit(1) -def get_user_id(provided_id: str | None = None) -> str: - if provided_id: - return provided_id - if os.path.exists(USER_ID_FILE): - with open(USER_ID_FILE) as f: - uid = f.read().strip() - if uid: - return uid - new_id = str(uuid.uuid4()) - with open(USER_ID_FILE, "w") as f: - f.write(new_id) - print(f"Generated new user_id: {new_id}") - return new_id - - def ensure_collection(client: QdrantClient, collection: str, vector_size: int): collections = [c.name for c in client.get_collections().collections] if collection not in collections: @@ -145,7 +129,6 @@ def ingest_chunks( ) if points: - # Upsert in batches of 100 batch_size = 100 for i in range(0, len(points), batch_size): batch = points[i : i + batch_size] @@ -403,7 +386,7 @@ def main(): args = parser.parse_args() - _user_id = get_user_id(args._user_id) + _user_id = get_or_create_user_id() print(f"User: {_user_id}") client = QdrantClient(url=args.qdrant_url) @@ -440,7 +423,7 @@ def main(): # Ingest a text file python ingestion.py text notes.txt story.md - # Specify a user ID (otherwise reads from ~/.huri_user_id) + # Specify a user ID (otherwise it will be auto-generated and saved) python ingestion.py --user-id "abc-123" pdf report.pdf # Use a different collection diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 6b9744d..a96a756 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -21,6 +21,7 @@ class RAGQuery: _user_id: str question: str preferences: dict = field(default_factory=dict) + history: list[dict] | None = None # preferences can include: language, tone, # response_format, max_length, system_prompt, extra_instructions, etc. @@ -110,8 +111,9 @@ def _search( ] qdrant_filter = Filter(must=conditions) + doc_results = [] try: - results = qdrant.query_points( + doc_results = qdrant.query_points( collection_name=collection, query=query_vector, query_filter=qdrant_filter, @@ -119,14 +121,15 @@ def _search( score_threshold=self.score_threshold, ).points except Exception: - results = [] + pass + return [ { "text": point.payload.get("text", ""), "score": point.score, "metadata": {k: v for k, v in point.payload.items() if k != "text"}, } - for point in results + for point in doc_results ] def _build_prompt( @@ -134,12 +137,19 @@ def _build_prompt( question: str, chunks: list[dict], preferences: dict, + history=None, ) -> tuple[str, str]: - parts = [ - "You are a robot speaking to a user. Answer based on the provided context.", - "If the context is insufficient, say so clearly.", - ] + parts = [] + + if history: + lines = [f"{m['role']}: {m['content']}" for m in history] + parts.append("[Recent conversation]\n" + "\n".join(lines)) + + parts.append( + "You are a robot speaking to a user. Answer based on the provided context." + + " If the context is insufficient, say so clearly.", + ) if preferences.get("language"): parts.append(f"Always respond in {preferences['language']}.") if preferences.get("tone"): @@ -153,11 +163,7 @@ def _build_prompt( system_prompt = " ".join(parts) if not chunks: - user_prompt = ( - "No relevant context was found.\n\n" - f"Question: {question}\n\n" - "Answer based on general knowledge." - ) + user_prompt = f"Question: {question}\n\n" else: context_parts = [] for i, chunk in enumerate(chunks, 1): @@ -259,7 +265,7 @@ async def process(self, query: RAGQuery) -> RAGResult: print(f" - score: {c['score']:.2f} | {c['text'][:100]}...") system_prompt, user_prompt = self._build_prompt( - query.question, chunks, query.preferences + query.question, chunks, query.preferences, query.history ) print(f"[RAG] System prompt: {system_prompt[:200]}...") answer = await self._llm_generate(system_prompt, user_prompt, query.preferences) @@ -288,6 +294,7 @@ def __init__( response_format="paragraph", max_length=1024, extra_instructions="", + max_history=10, **kwargs, ): super().__init__(_handle=_handle, _user_id=_user_id, **kwargs) @@ -299,6 +306,8 @@ def __init__( "max_length": max_length, "extra_instructions": extra_instructions, } + self.history: list[dict] = [] + self.max_history = max_history async def process(self, data: Sentence) -> Optional[RAGResult]: """ @@ -311,9 +320,26 @@ async def process(self, data: Sentence) -> Optional[RAGResult]: _user_id=self._user_id if self._user_id else "anonymous", question=question_text, preferences=self.preferences, + history=( + self.history + if len(self.history) <= self.max_history + else self.history[-self.max_history :] + ), + ) + + if self._handle is None: + print("[RAG] No handle available, returning None") + return None + + result: RAGResult | None = None + if self._handle is not None: + result = await self._handle.process.remote(query) + + self.history.append({"role": "user", "content": question_text}) + self.history.append( + {"role": "assistant", "content": result.answer if result else None} ) - result: RAGResult = await self._handle.process.remote(query) return result def update_preferences(self, new_preferences: dict):