From f3631740b81734d51bb0b4ac239954be712afd58 Mon Sep 17 00:00:00 2001 From: sewon Date: Fri, 5 Jun 2026 20:49:53 +0900 Subject: [PATCH 1/3] =?UTF-8?q?fix:=20OpenAI=20gpt-4o-mini,=20Neon=20DSN?= =?UTF-8?q?=20=ED=8C=8C=EC=8B=B1=20=EB=B2=84=EA=B7=B8=20=EC=88=98=EC=A0=95?= =?UTF-8?q?,=20=EC=9E=90=EB=8F=99=20=EC=9E=AC=EC=8B=9C=EC=9E=91=20?= =?UTF-8?q?=EC=8A=A4=ED=81=AC=EB=A6=BD=ED=8A=B8=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 9 +++++ dev.sh | 39 ++++++++++++++++++++++ src/lang2sql/adapters/db/dsn_builder.py | 9 +++-- src/lang2sql/adapters/llm/openai_.py | 15 ++++++--- src/lang2sql/frontends/discord/bot.py | 21 +++++++++--- src/lang2sql/frontends/discord/commands.py | 24 ++++++++++++- 6 files changed, 106 insertions(+), 11 deletions(-) create mode 100755 dev.sh diff --git a/.env.example b/.env.example index 3525ef0..fb5fbf0 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,15 @@ DISCORD_BOT_TOKEN= # for a smoke run, not for real answers). OPENAI_API_KEY= +# ── Local LLM (vLLM / Ollama) ──────────────────────────────────────────── +# When LANG2SQL_LLM_BASE_URL is set it takes priority over OPENAI_API_KEY. +# Point to the base URL of any OpenAI-compatible server. +# vLLM: http://localhost:8000 +# Ollama: http://localhost:11434 +LANG2SQL_LLM_BASE_URL= +# Model name as the server expects it (e.g. Qwen/Qwen3-14B-AWQ for vLLM). +LANG2SQL_LLM_MODEL= + # Fernet key used to encrypt stored secrets (DSNs / API keys) at rest. Optional: # if unset, a key is auto-generated and persisted in the SQLite kv table. Set it # in production so secrets decrypt across restarts and machines. Generate one: diff --git a/dev.sh b/dev.sh new file mode 100755 index 0000000..fc2513e --- /dev/null +++ b/dev.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Auto-reload dev runner for lang2sql-bot. +# Watches src/ for .py changes and restarts the bot automatically. + +set -a +source "$(dirname "$0")/.env" +set +a + +REF=$(mktemp) + +restart_bot() { + if [ -n "$BOT_PID" ] && kill -0 "$BOT_PID" 2>/dev/null; then + echo "[watch] stopping PID $BOT_PID..." + kill "$BOT_PID" + wait "$BOT_PID" 2>/dev/null + fi + echo "[watch] starting bot..." + .venv/bin/lang2sql-bot & + BOT_PID=$! + touch "$REF" + echo "[watch] PID $BOT_PID" +} + +trap 'kill $BOT_PID 2>/dev/null; rm -f $REF; exit' INT TERM + +restart_bot + +while true; do + sleep 2 + if find src/ -name "*.py" -newer "$REF" | grep -q .; then + CHANGED=$(find src/ -name "*.py" -newer "$REF" | head -3 | tr '\n' ' ') + echo "[watch] changed: $CHANGED" + restart_bot + elif ! kill -0 "$BOT_PID" 2>/dev/null; then + echo "[watch] bot crashed, restarting in 2s..." + sleep 2 + restart_bot + fi +done diff --git a/src/lang2sql/adapters/db/dsn_builder.py b/src/lang2sql/adapters/db/dsn_builder.py index 7de2fdd..188ea4f 100644 --- a/src/lang2sql/adapters/db/dsn_builder.py +++ b/src/lang2sql/adapters/db/dsn_builder.py @@ -10,7 +10,7 @@ from __future__ import annotations from dataclasses import dataclass -from urllib.parse import quote_plus +from urllib.parse import quote_plus, urlsplit @dataclass @@ -37,8 +37,13 @@ def _quote(s: str) -> str: def build_postgresql(*, host: str, port: str, database: str, user: str, password: str) -> ConnectionSpec: + # User may paste a full URL (e.g. "host/db?sslmode=require") into the host field. + # Extract just the hostname to avoid corrupting the assembled DSN. + parsed = urlsplit("//" + host) + clean_host = parsed.hostname or host p = int(port) if port else 5432 - dsn = f"postgresql+psycopg://{_quote(user)}:{_quote(password)}@{host}:{p}/{database}" + suffix = "?sslmode=require" if clean_host.endswith(".neon.tech") else "" + dsn = f"postgresql+psycopg://{_quote(user)}:{_quote(password)}@{clean_host}:{p}/{database}{suffix}" return ConnectionSpec(dsn=dsn, extras={}) diff --git a/src/lang2sql/adapters/llm/openai_.py b/src/lang2sql/adapters/llm/openai_.py index 74b5b02..17f2ceb 100644 --- a/src/lang2sql/adapters/llm/openai_.py +++ b/src/lang2sql/adapters/llm/openai_.py @@ -11,8 +11,10 @@ from __future__ import annotations +import asyncio import json import os +import re import urllib.error import urllib.request from typing import Any, Sequence @@ -27,7 +29,7 @@ class OpenAILLM: def __init__( self, - model: str = "gpt-4.1-mini", + model: str = "gpt-4o-mini", api_key: str | None = None, *, base_url: str = _DEFAULT_URL, @@ -35,7 +37,8 @@ def __init__( ) -> None: self.model = model # Resolve lazily-ish: read env now, but tolerate absence until complete(). - self._api_key = api_key if api_key is not None else os.environ.get("OPENAI_API_KEY") + raw_key = api_key if api_key is not None else os.environ.get("OPENAI_API_KEY") + self._api_key = raw_key.strip() if raw_key else raw_key self._base_url = base_url self._timeout = timeout @@ -54,7 +57,7 @@ async def complete( if tools: payload["tools"] = [_encode_tool(t) for t in tools] - raw = self._post(payload) + raw = await asyncio.to_thread(self._post, payload) return _decode_completion(raw) def _post(self, payload: dict[str, Any]) -> dict[str, Any]: @@ -83,6 +86,10 @@ def _post(self, payload: dict[str, Any]) -> dict[str, Any]: raise RuntimeError(f"OpenAI returned non-JSON response: {text[:200]!r}") from exc +def _strip_thinking(text: str) -> str: + return re.sub(r".*?", "", text, flags=re.DOTALL).strip() + + def _encode_message(m: Message) -> dict[str, Any]: """Core :class:`Message` → an OpenAI chat message dict.""" out: dict[str, Any] = {"role": m.role.value} @@ -141,7 +148,7 @@ def _decode_completion(raw: dict[str, Any]) -> Completion: ) return Completion( - content=msg.get("content") or "", + content=_strip_thinking(msg.get("content") or ""), tool_calls=tool_calls, finish_reason=choice.get("finish_reason"), ) diff --git a/src/lang2sql/frontends/discord/bot.py b/src/lang2sql/frontends/discord/bot.py index aed6f78..949f9d9 100644 --- a/src/lang2sql/frontends/discord/bot.py +++ b/src/lang2sql/frontends/discord/bot.py @@ -148,27 +148,40 @@ async def _run(self, interaction: discord.Interaction, coro) -> None: await interaction.response.defer(thinking=True) message = await coro content, file = _to_sendable(message) - await interaction.followup.send(content=content or "(empty)", file=file) + kwargs: dict = {"content": content or "(empty)"} + if file is not None: + kwargs["file"] = file + await interaction.followup.send(**kwargs) async def on_message(self, message: discord.Message) -> None: """Treat an @mention (or a reply inside a thread) as a free-form query.""" if message.author == self.user: return + print(f"[DEBUG] message from {message.author}: content={message.content!r} mentions={[u.id for u in message.mentions]}") mentioned = self.user is not None and self.user.mentioned_in(message) in_thread = isinstance(message.channel, discord.Thread) + print(f"[DEBUG] mentioned={mentioned} in_thread={in_thread} self.user={self.user}") if not mentioned and not in_thread: return text = message.content if self.user is not None: text = text.replace(self.user.mention, "").strip() + print(f"[DEBUG] text after strip: {text!r}") if not text: return identity = to_identity(_message_context(message)) - out = await self._handlers.query(identity, text) - content, file = _to_sendable(out) - await message.channel.send(content=content or "(empty)", file=file) + try: + out = await self._handlers.query(identity, text) + content, file = _to_sendable(out) + if content and len(content) > 1900: + content = content[:1900] + "\n…(truncated)" + await message.channel.send(content=content or "(empty)", file=file) + except Exception as exc: + import traceback + traceback.print_exc() + await message.channel.send(content=f"❌ Error: {type(exc).__name__}: {exc}") def run() -> None: diff --git a/src/lang2sql/frontends/discord/commands.py b/src/lang2sql/frontends/discord/commands.py index 63e2ee8..89378f3 100644 --- a/src/lang2sql/frontends/discord/commands.py +++ b/src/lang2sql/frontends/discord/commands.py @@ -22,6 +22,7 @@ from ...adapters.db.dsn_builder import assemble from ...core.identity import Identity from ...core.ports.frontend import OutboundMessage +from ...core.types import Role from ...harness.loop import agent_loop from ...tenancy.concierge import ContextConcierge from .render import render_answer @@ -43,7 +44,28 @@ async def query(self, identity: Identity, text: str) -> OutboundMessage: ctx = await self._concierge.build_context(identity, user_text=text) answer = await agent_loop(ctx, text) await self._concierge.store.save(identity.session_key(), ctx.session) - return render_answer(answer) + + history = ctx.session.history() + + sql_queries = [ + tc.arguments["sql"] + for msg in history + if msg.role == Role.ASSISTANT and msg.tool_calls + for tc in msg.tool_calls + if tc.name == "run_sql" and "sql" in tc.arguments + ] + sql_results = [ + msg.content + for msg in history + if msg.role == Role.TOOL and msg.name == "run_sql" and msg.content + ] + + suffix = "" + if sql_queries: + suffix += "\n\n**SQL:**\n```sql\n" + "\n\n".join(sql_queries) + "\n```" + if sql_results: + suffix += "\n\n**결과:**\n```\n" + "\n\n".join(sql_results) + "\n```" + return render_answer(answer + suffix) async def define_metric( self, From 233e78d10cd3f126d4e64efb711dc438fbc853a8 Mon Sep 17 00:00:00 2001 From: sewon Date: Sat, 6 Jun 2026 09:31:16 +0900 Subject: [PATCH 2/3] =?UTF-8?q?feat:=20DB=20=EA=B0=95=EA=B1=B4=EC=84=B1=20?= =?UTF-8?q?=EA=B3=A0=EB=8F=84=ED=99=94=20=E2=80=94=20enrich=20=EA=B8=B0?= =?UTF-8?q?=EB=8A=A5=20=EA=B5=AC=ED=98=84=20+=20RowLimitLayer=20+=20?= =?UTF-8?q?=EC=84=B8=EC=85=98=20=EC=95=95=EC=B6=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - EnrichSchemaTool: 컬럼 샘플 데이터 기반 LLM 자동 메타데이터 보강 (/enrich 커맨드) - DISTINCT 값 샘플링 후 단일 LLM 호출로 컬럼 설명 + FK 관계 추론 - KV 영속화(enriched_desc:table:col, schema_relationships) - 시스템 프롬프트에 컬럼 설명 + Table relationships 섹션 주입 - RowLimitLayer: 최상위 SELECT에 LIMIT 자동 추가 (서브쿼리 LIMIT 무시) - Safety pipeline 기본 순서: WhitelistLayer → RowLimitLayer → TimeoutLayer - Session.compress(): 매 턴 저장 전 tool call/result 메시지 제거로 세션 오염 방지 - 세션 SQL 추출 범위를 현재 턴(pre_loop_len)으로 제한 - OpenAI ASSISTANT null content 400 오류 수정 - NeonDB psycopg3 드라이버 자동 변환 + sslmode 자동 추가 - Discord tree.sync() rate limit 방지 (LANG2SQL_SYNC_COMMANDS=true 시만 동기화) Co-Authored-By: Claude Sonnet 4.6 --- .env.example | 4 + src/lang2sql/adapters/db/factory.py | 7 +- .../adapters/db/sqlalchemy_explorer.py | 11 +- src/lang2sql/adapters/llm/openai_.py | 8 +- src/lang2sql/adapters/storage/sqlite_store.py | 9 + src/lang2sql/frontends/discord/bot.py | 17 +- src/lang2sql/frontends/discord/commands.py | 18 +- src/lang2sql/harness/context.py | 5 + src/lang2sql/harness/session.py | 14 ++ src/lang2sql/harness/system_prompt.py | 41 ++++- src/lang2sql/safety/__init__.py | 4 +- src/lang2sql/safety/layers/__init__.py | 3 +- src/lang2sql/safety/layers/row_limit.py | 41 +++++ src/lang2sql/safety/pipeline.py | 6 +- src/lang2sql/tenancy/concierge.py | 12 +- src/lang2sql/tools/__init__.py | 6 +- src/lang2sql/tools/enrich_schema.py | 165 ++++++++++++++++++ src/lang2sql/tools/explore_schema.py | 20 +++ tests/test_integration.py | 4 +- tests/test_safety.py | 57 +++++- 20 files changed, 425 insertions(+), 27 deletions(-) create mode 100644 src/lang2sql/safety/layers/row_limit.py create mode 100644 src/lang2sql/tools/enrich_schema.py diff --git a/.env.example b/.env.example index fb5fbf0..8090667 100644 --- a/.env.example +++ b/.env.example @@ -4,6 +4,10 @@ # Required to run `lang2sql-bot`; the bot exits with a clear error if it's unset. DISCORD_BOT_TOKEN= +# Set to true only when you add/remove slash commands to re-sync with Discord. +# Leave unset (or false) during normal development to avoid rate limits. +LANG2SQL_SYNC_COMMANDS= + # OpenAI API key. Optional: when set, the agent uses gpt-4.1-mini. When unset, # it falls back to the offline FakeLLM (deterministic canned tool cycles — fine # for a smoke run, not for real answers). diff --git a/src/lang2sql/adapters/db/factory.py b/src/lang2sql/adapters/db/factory.py index c714f39..b2180a0 100644 --- a/src/lang2sql/adapters/db/factory.py +++ b/src/lang2sql/adapters/db/factory.py @@ -55,6 +55,10 @@ def build_explorer( token=extras.get("d1_token"), ) + # Normalize bare postgresql:// → postgresql+psycopg:// (psycopg3 is installed). + if scheme == "postgresql": + connection = "postgresql+psycopg" + connection[len("postgresql"):] + # Anything else is assumed to be a SQLAlchemy URL (driver loaded lazily). return SqlAlchemyExplorer(connection, schema=schema) @@ -67,7 +71,8 @@ def explorer_from_env() -> ExplorerPort | None: """ url = os.environ.get("LANG2SQL_DB_URL") if url: - return build_explorer(url, schema=os.environ.get("LANG2SQL_DB_SCHEMA")) + schema = os.environ.get("LANG2SQL_DB_SCHEMA") or None + return build_explorer(url, schema=schema) account = os.environ.get("CLOUDFLARE_D1_ACCOUNT_ID") database = os.environ.get("CLOUDFLARE_D1_DATABASE_ID") diff --git a/src/lang2sql/adapters/db/sqlalchemy_explorer.py b/src/lang2sql/adapters/db/sqlalchemy_explorer.py index fbacaad..c7129d8 100644 --- a/src/lang2sql/adapters/db/sqlalchemy_explorer.py +++ b/src/lang2sql/adapters/db/sqlalchemy_explorer.py @@ -55,10 +55,15 @@ async def execute(self, sql: str, limit: int = 1000) -> list[dict]: def _list_tables_sync(self) -> list[Table]: from sqlalchemy import inspect - insp = inspect(self._get_engine()) - schema = self._schema or insp.default_schema_name + engine = self._get_engine() + engine.dispose() # flush stale pool connections so schema changes are visible + insp = inspect(engine) + default = insp.default_schema_name + effective = self._schema or default + # Omit schema when it's the connection default so SQL stays unqualified. + display_schema = "" if (not self._schema or self._schema == default) else effective return [ - Table(name=t, schema=schema or "") + Table(name=t, schema=display_schema) for t in insp.get_table_names(schema=self._schema) ] diff --git a/src/lang2sql/adapters/llm/openai_.py b/src/lang2sql/adapters/llm/openai_.py index 17f2ceb..df3a112 100644 --- a/src/lang2sql/adapters/llm/openai_.py +++ b/src/lang2sql/adapters/llm/openai_.py @@ -93,8 +93,12 @@ def _strip_thinking(text: str) -> str: def _encode_message(m: Message) -> dict[str, Any]: """Core :class:`Message` → an OpenAI chat message dict.""" out: dict[str, Any] = {"role": m.role.value} - # OpenAI wants content present (may be null when only tool_calls are set). - out["content"] = m.content or None + # OpenAI allows null content only when tool_calls are present. + # For plain assistant messages (after session compress), force empty string. + if m.role == Role.ASSISTANT and not m.tool_calls: + out["content"] = m.content or "" + else: + out["content"] = m.content or None if m.role == Role.TOOL: out["tool_call_id"] = m.tool_call_id if m.name: diff --git a/src/lang2sql/adapters/storage/sqlite_store.py b/src/lang2sql/adapters/storage/sqlite_store.py index 99ce42d..ef0c674 100644 --- a/src/lang2sql/adapters/storage/sqlite_store.py +++ b/src/lang2sql/adapters/storage/sqlite_store.py @@ -130,6 +130,15 @@ def kv_delete(self, scope: str, key: str) -> None: ) self._conn.commit() + def kv_delete_prefix(self, scope: str, prefix: str) -> int: + """Delete all keys under scope that start with prefix. Returns count deleted.""" + cur = self._conn.execute( + "DELETE FROM kv WHERE scope = ? AND key LIKE ?", + (scope, prefix + "%"), + ) + self._conn.commit() + return cur.rowcount + # -- Session (de)serialization ------------------------------------------ diff --git a/src/lang2sql/frontends/discord/bot.py b/src/lang2sql/frontends/discord/bot.py index 949f9d9..1ebfb17 100644 --- a/src/lang2sql/frontends/discord/bot.py +++ b/src/lang2sql/frontends/discord/bot.py @@ -102,8 +102,11 @@ def __init__(self, handlers: CommandHandlers) -> None: self._register_commands() async def setup_hook(self) -> None: - # Sync slash commands with Discord on startup. - await self.tree.sync() + # Sync only when LANG2SQL_SYNC_COMMANDS=true (e.g. after adding/removing commands). + # Skipping sync on every restart avoids Discord rate limits during dev. + if os.environ.get("LANG2SQL_SYNC_COMMANDS", "").lower() == "true": + await self.tree.sync() + print("[bot] slash commands synced") def _register_commands(self) -> None: tree = self.tree @@ -135,6 +138,13 @@ async def define_metric( async def remember(interaction: discord.Interaction, text: str) -> None: await self._run(interaction, handlers.remember(to_identity(_interaction_context(interaction)), text)) + @tree.command(name="enrich", description="LLM으로 DB 컬럼 메타데이터 자동 보강 (clear=True로 초기화)") + async def enrich(interaction: discord.Interaction, table: str = "", clear: bool = False) -> None: + await self._run( + interaction, + handlers.enrich(to_identity(_interaction_context(interaction)), table=table, clear=clear), + ) + @tree.command(name="semantic_show", description="Show definitions in effect here") async def semantic_show(interaction: discord.Interaction) -> None: await self._run(interaction, handlers.semantic_show(to_identity(_interaction_context(interaction)))) @@ -195,6 +205,7 @@ def run() -> None: raise RuntimeError( f"{TOKEN_ENV} is not set; export your Discord bot token to run the bot." ) - handlers = CommandHandlers(ContextConcierge()) + data_path = os.environ.get("LANG2SQL_DATA_PATH", "lang2sql_data.db") + handlers = CommandHandlers(ContextConcierge(path=data_path)) client = Lang2SQLBot(handlers) client.run(token) diff --git a/src/lang2sql/frontends/discord/commands.py b/src/lang2sql/frontends/discord/commands.py index 89378f3..a213a5b 100644 --- a/src/lang2sql/frontends/discord/commands.py +++ b/src/lang2sql/frontends/discord/commands.py @@ -42,24 +42,28 @@ async def query(self, identity: Identity, text: str) -> OutboundMessage: thread/DM continues the conversation (tiebreaker #4). """ ctx = await self._concierge.build_context(identity, user_text=text) + pre_loop_len = len(ctx.session.history()) answer = await agent_loop(ctx, text) - await self._concierge.store.save(identity.session_key(), ctx.session) history = ctx.session.history() + current_turn = history[pre_loop_len:] sql_queries = [ tc.arguments["sql"] - for msg in history + for msg in current_turn if msg.role == Role.ASSISTANT and msg.tool_calls for tc in msg.tool_calls if tc.name == "run_sql" and "sql" in tc.arguments ] sql_results = [ msg.content - for msg in history + for msg in current_turn if msg.role == Role.TOOL and msg.name == "run_sql" and msg.content ] + ctx.session.compress() + await self._concierge.store.save(identity.session_key(), ctx.session) + suffix = "" if sql_queries: suffix += "\n\n**SQL:**\n```sql\n" + "\n\n".join(sql_queries) + "\n```" @@ -171,6 +175,14 @@ async def register_db_for_guild( ) ) + async def enrich(self, identity: Identity, table: str = "", clear: bool = False) -> OutboundMessage: + """Run EnrichSchema tool: sample DB columns and LLM-infer descriptions.""" + ctx = await self._concierge.build_context(identity) + result = await ctx.tools.dispatch( + "enrich_schema", {"table": table, "clear": clear}, ctx, "cmd:enrich" + ) + return OutboundMessage(text=result.content) + async def connect(self, identity: Identity, dsn: str) -> OutboundMessage: """V1 stub: stash a DB DSN keyed by guild/DM in the concierge kv store. diff --git a/src/lang2sql/harness/context.py b/src/lang2sql/harness/context.py index 1d3431b..5266521 100644 --- a/src/lang2sql/harness/context.py +++ b/src/lang2sql/harness/context.py @@ -9,8 +9,12 @@ from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING from ..core.identity import Identity + +if TYPE_CHECKING: + from ..adapters.storage.sqlite_store import SqliteStore from ..core.ports.audit import AuditPort from ..core.ports.explorer import ExplorerPort from ..core.ports.llm import LLMPort @@ -30,4 +34,5 @@ class HarnessContext: safety: SafetyPipelinePort | None = None audit: AuditPort | None = None scope_resolver: ScopeResolverPort | None = None + store: SqliteStore | None = None max_turns: int = 8 diff --git a/src/lang2sql/harness/session.py b/src/lang2sql/harness/session.py index 6eda61f..b3f31d6 100644 --- a/src/lang2sql/harness/session.py +++ b/src/lang2sql/harness/session.py @@ -26,3 +26,17 @@ def history(self) -> list[Message]: def reset(self) -> None: self.transcript.clear() + + def compress(self) -> None: + """Remove tool call/result messages to prevent context pollution across turns.""" + from ..core.types import Role + cleaned: list[Message] = [] + for msg in self.transcript: + if msg.role == Role.TOOL: + continue + if msg.role == Role.ASSISTANT and msg.tool_calls: + if msg.content: # skip if no text content — empty assistant messages confuse OpenAI + cleaned.append(Message(role=Role.ASSISTANT, content=msg.content)) + else: + cleaned.append(msg) + self.transcript = cleaned diff --git a/src/lang2sql/harness/system_prompt.py b/src/lang2sql/harness/system_prompt.py index 10d2837..3899a6f 100644 --- a/src/lang2sql/harness/system_prompt.py +++ b/src/lang2sql/harness/system_prompt.py @@ -8,6 +8,8 @@ from __future__ import annotations +import json + from .context import HarnessContext _BASE = """\ @@ -18,7 +20,7 @@ - When you need data, call the run_sql tool with a single SELECT/WITH query. - Discover schema with explore_schema before guessing table or column names. - Prefer definitions from the semantic layer below over your own assumptions. -- Answer concisely; show the SQL you ran. +- Answer concisely. Show only the final successful SQL you ran, not intermediate attempts. """ @@ -34,7 +36,40 @@ async def build_system_prompt(ctx: HarnessContext) -> str: if ctx.explorer is not None: tables = await ctx.explorer.list_tables() if tables: - names = ", ".join(t.qualified for t in tables) - parts.append("## Known tables\n" + names) + scope = (ctx.identity.guild_id or f"dm:{ctx.identity.user_id}") if ctx.store else None + has_enrichment = bool( + scope and ctx.store and + ctx.store.kv_get(scope, "schema_relationships") + ) + + if has_enrichment and scope and ctx.store: + schema_lines: list[str] = [] + for tbl in tables: + try: + described = await ctx.explorer.describe_table(tbl.name) + except Exception: + schema_lines.append(f"- {tbl.qualified}") + continue + col_lines = [] + for col in described.columns: + desc = col.description or ctx.store.kv_get(scope, f"enriched_desc:{tbl.name}:{col.name}") or "" + col_lines.append(f" - {col.name}{': ' + desc if desc else ''}") + schema_lines.append(f"- {tbl.qualified}\n" + "\n".join(col_lines)) + parts.append("## Known tables (with column descriptions)\n" + "\n".join(schema_lines)) + else: + names = ", ".join(t.qualified for t in tables) + parts.append("## Known tables\n" + names) + + if ctx.store is not None: + scope = ctx.identity.guild_id or f"dm:{ctx.identity.user_id}" + raw = ctx.store.kv_get(scope, "schema_relationships") + if raw: + try: + rels = json.loads(raw) + if rels: + rel_text = "\n".join(f"- {r}" for r in rels) + parts.append("## Table relationships (use these for JOINs)\n" + rel_text) + except (ValueError, TypeError): + pass return "\n\n".join(parts) diff --git a/src/lang2sql/safety/__init__.py b/src/lang2sql/safety/__init__.py index f0cf9f5..d8cc9f4 100644 --- a/src/lang2sql/safety/__init__.py +++ b/src/lang2sql/safety/__init__.py @@ -2,7 +2,7 @@ from __future__ import annotations -from .layers import TimeoutLayer, WhitelistLayer +from .layers import RowLimitLayer, TimeoutLayer, WhitelistLayer from .pipeline import SafetyPipeline -__all__ = ["SafetyPipeline", "WhitelistLayer", "TimeoutLayer"] +__all__ = ["SafetyPipeline", "WhitelistLayer", "RowLimitLayer", "TimeoutLayer"] diff --git a/src/lang2sql/safety/layers/__init__.py b/src/lang2sql/safety/layers/__init__.py index b311b42..98b9c6a 100644 --- a/src/lang2sql/safety/layers/__init__.py +++ b/src/lang2sql/safety/layers/__init__.py @@ -2,7 +2,8 @@ from __future__ import annotations +from .row_limit import RowLimitLayer from .timeout import TimeoutLayer from .whitelist import WhitelistLayer -__all__ = ["WhitelistLayer", "TimeoutLayer"] +__all__ = ["WhitelistLayer", "RowLimitLayer", "TimeoutLayer"] diff --git a/src/lang2sql/safety/layers/row_limit.py b/src/lang2sql/safety/layers/row_limit.py new file mode 100644 index 0000000..919299e --- /dev/null +++ b/src/lang2sql/safety/layers/row_limit.py @@ -0,0 +1,41 @@ +"""RowLimitLayer — automatically appends LIMIT to queries that lack one. + +Returns PASS (not REWRITE) so the rewritten SQL flows into the next layer +(TimeoutLayer) and ctx.timeout_seconds still gets configured. +""" + +from __future__ import annotations + +import re + +from ...core.ports.safety import SafetyContext, SafetyDecision, Verdict + +_LIMIT_RE = re.compile(r"\bLIMIT\s+\d+", re.IGNORECASE) + + +def _has_top_level_limit(sql: str) -> bool: + """True if a LIMIT clause exists at parenthesis depth 0.""" + depth = 0 + for part in re.split(r"(\(|\))", sql): + if part == "(": + depth += 1 + elif part == ")": + depth -= 1 + elif depth == 0 and _LIMIT_RE.search(part): + return True + return False + + +class RowLimitLayer: + """Appends LIMIT when absent; leaves explicit LIMITs untouched.""" + + @property + def name(self) -> str: + return "row_limit" + + def check(self, sql: str, ctx: SafetyContext) -> SafetyDecision: + if _has_top_level_limit(sql): + return SafetyDecision(verdict=Verdict.PASS, sql=sql, layer=self.name) + limit = ctx.row_limit if ctx.row_limit and ctx.row_limit > 0 else 1000 + rewritten = sql.rstrip().rstrip(";") + f"\nLIMIT {limit}" + return SafetyDecision(verdict=Verdict.PASS, sql=rewritten, layer=self.name) diff --git a/src/lang2sql/safety/pipeline.py b/src/lang2sql/safety/pipeline.py index 968d498..96909a4 100644 --- a/src/lang2sql/safety/pipeline.py +++ b/src/lang2sql/safety/pipeline.py @@ -18,12 +18,12 @@ SafetyLayerPort, Verdict, ) -from .layers import TimeoutLayer, WhitelistLayer +from .layers import RowLimitLayer, TimeoutLayer, WhitelistLayer def _default_layers() -> list[SafetyLayerPort]: - # Whitelist first (cheap, fail-closed reject), then Timeout (exec config). - return [WhitelistLayer(), TimeoutLayer()] + # Whitelist (reject) → RowLimit (rewrite) → Timeout (exec config). + return [WhitelistLayer(), RowLimitLayer(), TimeoutLayer()] class SafetyPipeline: diff --git a/src/lang2sql/tenancy/concierge.py b/src/lang2sql/tenancy/concierge.py index 2d2e318..526eadf 100644 --- a/src/lang2sql/tenancy/concierge.py +++ b/src/lang2sql/tenancy/concierge.py @@ -155,12 +155,22 @@ async def build_context( safety=self._safety, audit=self._audit, scope_resolver=self._scope_resolver, + store=self._store, max_turns=self._max_turns, ) def _default_llm() -> LLMPort: - """OpenAI when a key is present, otherwise the offline FakeLLM.""" + """Local vLLM/Ollama when LANG2SQL_LLM_BASE_URL is set, OpenAI when keyed, else FakeLLM.""" + base_url = os.environ.get("LANG2SQL_LLM_BASE_URL") + if base_url: + model = os.environ.get("LANG2SQL_LLM_MODEL", "default") + # Local servers (vLLM, Ollama) speak OpenAI-compatible API; dummy key satisfies the header. + api_key = os.environ.get("OPENAI_API_KEY") or "local" + url = base_url.rstrip("/") + if not url.endswith("/chat/completions"): + url = url + "/v1/chat/completions" + return OpenAILLM(model=model, api_key=api_key, base_url=url) if os.environ.get("OPENAI_API_KEY"): return OpenAILLM() return FakeLLM() diff --git a/src/lang2sql/tools/__init__.py b/src/lang2sql/tools/__init__.py index cc5b398..cbd2e8f 100644 --- a/src/lang2sql/tools/__init__.py +++ b/src/lang2sql/tools/__init__.py @@ -14,6 +14,7 @@ from ..memory.service import MemoryService from .ask_user import AskUser from .define_metric import DefineMetric +from .enrich_schema import EnrichSchema from .explore_schema import ExploreSchema from .ingest_doc import IngestDoc from .remember import Remember @@ -21,7 +22,7 @@ __all__ = [ "build_default_tools", - "RunSQL", "ExploreSchema", "DefineMetric", "Remember", "AskUser", "IngestDoc", + "RunSQL", "ExploreSchema", "EnrichSchema", "DefineMetric", "Remember", "AskUser", "IngestDoc", ] @@ -32,10 +33,11 @@ def build_default_tools( source: SourcePort, extractor: DocExtractorPort, ) -> list[ToolPort]: - """The six V1 tools (v4.1 §4.1).""" + """The V1 tools.""" return [ RunSQL(), ExploreSchema(), + EnrichSchema(), DefineMetric(), AskUser(), Remember(memory), diff --git a/src/lang2sql/tools/enrich_schema.py b/src/lang2sql/tools/enrich_schema.py new file mode 100644 index 0000000..7898b20 --- /dev/null +++ b/src/lang2sql/tools/enrich_schema.py @@ -0,0 +1,165 @@ +"""enrich_schema — LLM-powered column metadata enrichment. + +Samples DISTINCT values from each column, sends the full schema + samples to +the LLM in a single call, and stores the inferred descriptions in the KV store. +Subsequent explore_schema calls read from this cache (highest priority). + +KV key pattern: enriched_desc:{table}:{column} +""" + +from __future__ import annotations + +import json +import re +from typing import TYPE_CHECKING, Any + +from ..core.types import Message, Role, ToolResult, ToolSpec + +if TYPE_CHECKING: + from ..harness.context import HarnessContext + +_SAMPLE_LIMIT = 10 +_KV_PREFIX = "enriched_desc" +_KV_RELATIONSHIPS = "schema_relationships" + + +def _kv_key(table: str, column: str) -> str: + return f"{_KV_PREFIX}:{table}:{column}" + + +def _build_prompt(schema_block: str) -> str: + return ( + "다음은 DB 테이블들의 스키마와 실제 샘플 데이터야.\n" + "각 컬럼의 의미와 테이블 간 JOIN 관계를 추론해줘.\n\n" + f"{schema_block}\n\n" + "아래 JSON 형식으로만 응답해:\n" + "{\n" + ' "columns": {"테이블명.컬럼명": "컬럼 설명 (추론 불확실하면 빈 문자열)"},\n' + ' "relationships": ["tableA.col = tableB.col", ...]\n' + "}\n\n" + "설명 작성 규칙:\n" + "- 코드값 컬럼(짧은 문자열 샘플): 샘플에서 추론한 각 값의 의미 명시\n" + "- 계산에 쓰이는 컬럼 쌍: 실제 계산 공식을 설명에 포함\n" + "- relationships: 샘플값이 겹치거나 의미상 같은 컬럼 쌍을 'A.x = B.y' 형식으로 나열\n" + " (FK 선언이 없어도 값이 같으면 포함)" + ) + + +def _extract_result(text: str) -> tuple[dict[str, str], list[str]]: + """Extract columns dict and relationships list from LLM response.""" + m = re.search(r"\{.*\}", text, re.DOTALL) + if not m: + return {}, [] + try: + data = json.loads(m.group(0)) + columns = data.get("columns", {}) if isinstance(data, dict) else {} + relationships = data.get("relationships", []) if isinstance(data, dict) else [] + return columns, relationships + except (ValueError, TypeError): + return {}, [] + + +class EnrichSchema: + @property + def spec(self) -> ToolSpec: + return ToolSpec( + name="enrich_schema", + description=( + "DB 컬럼 메타데이터를 실제 샘플 데이터 기반으로 LLM이 자동 보강한다. " + "테이블 간 FK 관계도 추론한다. /enrich 명령으로 호출." + ), + parameters={ + "type": "object", + "properties": { + "table": { + "type": "string", + "description": "보강할 테이블명 (생략 시 전체 테이블)", + }, + "clear": { + "type": "boolean", + "description": "true이면 보강 캐시를 초기화", + }, + }, + }, + ) + + async def run(self, args: dict[str, Any], ctx: "HarnessContext") -> ToolResult: + if ctx.explorer is None: + return ToolResult(call_id="", content="DB가 연결되지 않았습니다 (/connect 먼저).", is_error=True) + if ctx.store is None: + return ToolResult(call_id="", content="KV store를 사용할 수 없습니다.", is_error=True) + + scope = ctx.identity.guild_id or f"dm:{ctx.identity.user_id}" + + if args.get("clear"): + count = ctx.store.kv_delete_prefix(scope, _KV_PREFIX + ":") + ctx.store.kv_delete(scope, _KV_RELATIONSHIPS) + return ToolResult(call_id="", content=f"🗑️ 보강 캐시 초기화 완료 ({count}개 삭제)") + + target = (args.get("table") or "").strip() + all_tables = await ctx.explorer.list_tables() + if target: + tables = [t for t in all_tables if t.name == target or t.qualified == target] + if not tables: + return ToolResult(call_id="", content=f"테이블 '{target}'을 찾을 수 없습니다.", is_error=True) + else: + tables = all_tables + + # Build schema block with sample values for each column. + schema_lines: list[str] = [] + for tbl in tables: + described = await ctx.explorer.describe_table(tbl.name) + schema_lines.append(f"테이블: {tbl.name}") + for col in described.columns: + try: + sample_sql = ( + f"SELECT DISTINCT {col.name} FROM {tbl.qualified} " + f"WHERE {col.name} IS NOT NULL LIMIT {_SAMPLE_LIMIT}" + ) + rows = await ctx.explorer.execute(sample_sql, _SAMPLE_LIMIT) + samples = [str(r.get(col.name, r.get(list(r.keys())[0], ""))) for r in rows] + except Exception: + samples = [] + sample_str = f" 샘플: {samples}" if samples else "" + schema_lines.append(f"- {col.name} ({col.type}){sample_str}") + schema_lines.append("") + + schema_block = "\n".join(schema_lines) + prompt = _build_prompt(schema_block) + + # Single LLM call for all tables at once. + completion = await ctx.llm.complete( + [Message(role=Role.USER, content=prompt)] + ) + columns, relationships = _extract_result(completion.content) + + if not columns and not relationships: + return ToolResult( + call_id="", + content="LLM이 JSON을 반환하지 않았습니다. 다시 시도해주세요.", + is_error=True, + ) + + saved: list[str] = [] + for key, desc in columns.items(): + if not desc: + continue + parts = key.split(".", 1) + if len(parts) != 2: + continue + tbl_name, col_name = parts + ctx.store.kv_set(scope, _kv_key(tbl_name, col_name), desc) + saved.append(f"- {key}: {desc}") + + rel_lines: list[str] = [] + if relationships: + ctx.store.kv_set(scope, _KV_RELATIONSHIPS, json.dumps(relationships, ensure_ascii=False)) + rel_lines = [f"- {r}" for r in relationships] + + result_parts = [] + if saved: + result_parts.append("✅ 컬럼 메타데이터 보강 완료:\n" + "\n".join(saved)) + if rel_lines: + result_parts.append("🔗 테이블 관계 추론:\n" + "\n".join(rel_lines)) + + return ToolResult(call_id="", content="\n\n".join(result_parts) or "보강된 내용이 없습니다.") diff --git a/src/lang2sql/tools/explore_schema.py b/src/lang2sql/tools/explore_schema.py index f5e08bb..327f3c3 100644 --- a/src/lang2sql/tools/explore_schema.py +++ b/src/lang2sql/tools/explore_schema.py @@ -6,13 +6,32 @@ from __future__ import annotations +from dataclasses import replace from typing import TYPE_CHECKING, Any +from ..core.ports.explorer import Column, Table from ..core.types import ToolResult, ToolSpec if TYPE_CHECKING: from ..harness.context import HarnessContext +_KV_PREFIX = "enriched_desc" + + +def _apply_enrichment_cache(table: Table, ctx: "HarnessContext") -> Table: + """Overlay KV-cached descriptions onto columns that lack one.""" + if ctx.store is None: + return table + scope = ctx.identity.guild_id or f"dm:{ctx.identity.user_id}" + enriched_cols: list[Column] = [] + for col in table.columns: + if col.description: + enriched_cols.append(col) + continue + cached = ctx.store.kv_get(scope, f"{_KV_PREFIX}:{table.name}:{col.name}") + enriched_cols.append(replace(col, description=cached or "")) + return replace(table, columns=enriched_cols) + class ExploreSchema: @property @@ -39,6 +58,7 @@ async def run(self, args: dict[str, Any], ctx: "HarnessContext") -> ToolResult: return ToolResult(call_id="", content="Tables:\n" + names) t = await ctx.explorer.describe_table(table) + t = _apply_enrichment_cache(t, ctx) cols = "\n".join( f"- {c.name}: {c.type}{'' if c.nullable else ' NOT NULL'}" f"{(' — ' + c.description) if c.description else ''}" diff --git a/tests/test_integration.py b/tests/test_integration.py index 62d2016..413f2ff 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -23,10 +23,10 @@ def _ctx(): return ident, asyncio.run(concierge.build_context(ident)) -def test_six_v1_tools_registered(): +def test_v1_tools_registered(): _, ctx = _ctx() names = {s.name for s in ctx.tools.specs()} - assert names == {"run_sql", "explore_schema", "define_metric", "ask_user", "remember", "ingest_doc"} + assert names == {"run_sql", "explore_schema", "enrich_schema", "define_metric", "ask_user", "remember", "ingest_doc"} def test_run_sql_passes_gate_and_returns_rows(): diff --git a/tests/test_safety.py b/tests/test_safety.py index 669f9db..4532339 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -122,4 +122,59 @@ def test_default_timeout_set_on_pass(): def test_pipeline_exposes_default_layers(): pipeline = SafetyPipeline() names = [layer.name for layer in pipeline.layers] - assert names == ["whitelist", "timeout"] + assert names == ["whitelist", "row_limit", "timeout"] + + +# --- RowLimitLayer tests ------------------------------------------------------ + + +def test_row_limit_added_when_absent(): + ctx = SafetyContext(row_limit=500) + decision = SafetyPipeline().evaluate("SELECT * FROM users", ctx) + assert decision.verdict is Verdict.PASS + assert "LIMIT 500" in decision.sql.upper() + + +def test_row_limit_not_added_when_present(): + sql = "SELECT * FROM users LIMIT 10" + ctx = SafetyContext(row_limit=500) + decision = SafetyPipeline().evaluate(sql, ctx) + assert decision.verdict is Verdict.PASS + assert decision.sql.upper().count("LIMIT") == 1 + + +def test_row_limit_ignores_subquery_limit(): + sql = "SELECT * FROM (SELECT id FROM t LIMIT 5) sub" + ctx = SafetyContext(row_limit=100) + decision = SafetyPipeline().evaluate(sql, ctx) + assert decision.verdict is Verdict.PASS + assert "LIMIT 100" in decision.sql.upper() + + +def test_row_limit_ignores_cte_limit(): + sql = "WITH a AS (SELECT * FROM t LIMIT 10) SELECT * FROM a" + ctx = SafetyContext(row_limit=200) + decision = SafetyPipeline().evaluate(sql, ctx) + assert decision.verdict is Verdict.PASS + assert "LIMIT 200" in decision.sql.upper() + + +def test_row_limit_and_timeout_both_applied(): + ctx = SafetyContext(row_limit=42, timeout_seconds=0) + decision = SafetyPipeline().evaluate("SELECT * FROM t", ctx) + assert decision.verdict is Verdict.PASS + assert "LIMIT 42" in decision.sql.upper() + assert ctx.timeout_seconds == 30 + + +def test_row_limit_default_1000_when_unset(): + ctx = SafetyContext() + decision = SafetyPipeline().evaluate("SELECT * FROM t", ctx) + assert "LIMIT 1000" in decision.sql.upper() + + +def test_case_08_huge_table_gets_limit(): + ctx = SafetyContext(row_limit=1000) + decision = SafetyPipeline().evaluate("SELECT * FROM huge_table", ctx) + assert decision.verdict is Verdict.PASS + assert "LIMIT 1000" in decision.sql.upper() From db97bd4ee0734eb4bba8227c026ba9ba343cde31 Mon Sep 17 00:00:00 2001 From: sewon Date: Sat, 6 Jun 2026 21:59:34 +0900 Subject: [PATCH 3/3] fix: remove debug prints, filter successful SQL only in Discord output PR #234 review feedback: replace stdout debug prints with logging, and pair run_sql tool_call_id with results to show only successful queries. Co-Authored-By: Claude Sonnet 4.6 --- src/lang2sql/frontends/discord/bot.py | 8 ++++---- src/lang2sql/frontends/discord/commands.py | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/lang2sql/frontends/discord/bot.py b/src/lang2sql/frontends/discord/bot.py index 1ebfb17..edace0f 100644 --- a/src/lang2sql/frontends/discord/bot.py +++ b/src/lang2sql/frontends/discord/bot.py @@ -15,6 +15,7 @@ from __future__ import annotations import io +import logging import os import discord @@ -25,6 +26,8 @@ from .commands import CommandHandlers from .session_router import InteractionContext, to_identity +logger = logging.getLogger(__name__) + TOKEN_ENV = "DISCORD_BOT_TOKEN" @@ -106,7 +109,7 @@ async def setup_hook(self) -> None: # Skipping sync on every restart avoids Discord rate limits during dev. if os.environ.get("LANG2SQL_SYNC_COMMANDS", "").lower() == "true": await self.tree.sync() - print("[bot] slash commands synced") + logger.info("slash commands synced") def _register_commands(self) -> None: tree = self.tree @@ -167,17 +170,14 @@ async def on_message(self, message: discord.Message) -> None: """Treat an @mention (or a reply inside a thread) as a free-form query.""" if message.author == self.user: return - print(f"[DEBUG] message from {message.author}: content={message.content!r} mentions={[u.id for u in message.mentions]}") mentioned = self.user is not None and self.user.mentioned_in(message) in_thread = isinstance(message.channel, discord.Thread) - print(f"[DEBUG] mentioned={mentioned} in_thread={in_thread} self.user={self.user}") if not mentioned and not in_thread: return text = message.content if self.user is not None: text = text.replace(self.user.mention, "").strip() - print(f"[DEBUG] text after strip: {text!r}") if not text: return diff --git a/src/lang2sql/frontends/discord/commands.py b/src/lang2sql/frontends/discord/commands.py index a213a5b..3c1edb4 100644 --- a/src/lang2sql/frontends/discord/commands.py +++ b/src/lang2sql/frontends/discord/commands.py @@ -48,18 +48,23 @@ async def query(self, identity: Identity, text: str) -> OutboundMessage: history = ctx.session.history() current_turn = history[pre_loop_len:] - sql_queries = [ - tc.arguments["sql"] + call_id_to_sql: dict[str, str] = { + tc.id: tc.arguments["sql"] for msg in current_turn if msg.role == Role.ASSISTANT and msg.tool_calls for tc in msg.tool_calls if tc.name == "run_sql" and "sql" in tc.arguments - ] - sql_results = [ - msg.content - for msg in current_turn - if msg.role == Role.TOOL and msg.name == "run_sql" and msg.content - ] + } + + sql_queries: list[str] = [] + sql_results: list[str] = [] + for msg in current_turn: + if msg.role != Role.TOOL or msg.name != "run_sql" or not msg.content: + continue + sql = call_id_to_sql.get(msg.tool_call_id or "") + if sql and ("row(s):" in msg.content or "(0 rows)" in msg.content): + sql_queries.append(sql) + sql_results.append(msg.content) ctx.session.compress() await self._concierge.store.save(identity.session_key(), ctx.session)