diff --git a/.github/workflows/auto-unit-test.yml b/.github/workflows/auto-unit-test.yml index 1595fc769..dace8dab6 100644 --- a/.github/workflows/auto-unit-test.yml +++ b/.github/workflows/auto-unit-test.yml @@ -36,7 +36,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - name: Install uv run: pip install --upgrade uv diff --git a/.github/workflows/sdk_publish.yml b/.github/workflows/sdk_publish.yml index 1e5759277..3cc413381 100644 --- a/.github/workflows/sdk_publish.yml +++ b/.github/workflows/sdk_publish.yml @@ -21,7 +21,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - name: Install build dependencies run: | diff --git a/.gitignore b/.gitignore index ec5b3a3f9..e0bac2b47 100644 --- a/.gitignore +++ b/.gitignore @@ -61,4 +61,9 @@ data/ sdk/benchmark/.env /docker/.env.bak -.venv \ No newline at end of file +.venv + +.pytest-tmp +doc/mermaid + +.claude/skills/python-import-triage \ No newline at end of file diff --git a/backend/adapters/__init__.py b/backend/adapters/__init__.py new file mode 100644 index 000000000..ed46fc888 --- /dev/null +++ b/backend/adapters/__init__.py @@ -0,0 +1,13 @@ +from adapters.exception import JiuwenSDKError, JiuwenSDKUnavailableError, NexentCapabilityError + +try: + from adapters.jiuwen_sdk_adapter import JiuwenSDKAdapter +except ModuleNotFoundError: + JiuwenSDKAdapter = None # type: ignore[assignment, misc] + +__all__ = [ + "JiuwenSDKError", + "JiuwenSDKUnavailableError", + "NexentCapabilityError", + "JiuwenSDKAdapter", +] diff --git a/backend/adapters/exception.py b/backend/adapters/exception.py new file mode 100644 index 000000000..63812d3af --- /dev/null +++ b/backend/adapters/exception.py @@ -0,0 +1,13 @@ +class JiuwenSDKError(Exception): + """Jiuwen SDK 调用失败的通用异常""" + pass + + +class JiuwenSDKUnavailableError(JiuwenSDKError): + """Jiuwen SDK 不可用(依赖缺失或未启用)""" + pass + + +class 
NexentCapabilityError(Exception): + """nexent 原生模式不支持该能力""" + pass diff --git a/backend/adapters/jiuwen_sdk_adapter.py b/backend/adapters/jiuwen_sdk_adapter.py new file mode 100644 index 000000000..f62ce9d06 --- /dev/null +++ b/backend/adapters/jiuwen_sdk_adapter.py @@ -0,0 +1,514 @@ +""" +openjiuwen SDK adapter for Nexent. + +This module must be imported lazily (not at module load time) because +openjiuwen 0.1.13 has circular import bugs in its __init__.py files that +prevent the SDK from loading unless we bypass them. + +Import flow: + backend/adapters/__init__.py -> try/except -> JiuwenSDKAdapter = None + -> when needed: _install_jiuwen_bypasser() -> openjiuwen imports work +""" +import asyncio +import importlib.abc +import importlib.machinery +import json +import logging +import os +import sys +import types +from typing import Any, List, Literal, Optional + +logger = logging.getLogger("jiuwen_adapter") + +from adapters.exception import JiuwenSDKError + + +# ---------------------------------------------------------------------- +# Circular import bypasser for openjiuwen 0.1.13 +# +# openjiuwen has broken __init__.py files that create circular import chains: +# tune/__init__.py -> tune.optimizer -> core.operator -> agent_evolving -> ... +# This bypasser prevents those __init__.py files from executing while still +# allowing regular .py submodule files to load normally. 
+# ---------------------------------------------------------------------- +_CIRCULAR_CHAIN = { + "openjiuwen.agent_evolving", + "openjiuwen.agent_evolving.trainer", + "openjiuwen.agent_evolving.trainer.trainer", + "openjiuwen.agent_evolving.trainer.progress", + "openjiuwen.core", + "openjiuwen.dev_tools", + "openjiuwen.dev_tools.tune", + "openjiuwen.dev_tools.tune.optimizer", + "openjiuwen.dev_tools.tune.optimizer.instruction_optimizer", + "openjiuwen.dev_tools.prompt_builder", + "openjiuwen.dev_tools.prompt_builder.builder", +} + + +class _JiuwenInitBypasser(importlib.abc.MetaPathFinder, importlib.abc.Loader): + """ + Meta path finder that intercepts __init__.py loading within openjiuwen, + blocking only the packages in the circular import chain while letting + all other modules (including base.py files) load normally. + """ + + def find_spec(self, fullname: str, path: Any, target: Any = None) -> Any: + if not fullname.startswith("openjiuwen") or fullname == "openjiuwen": + return None + + try: + import openjiuwen as _oj + + pkg_root = _oj.__path__[0] + except ImportError: + return None + + parts = fullname.split(".")[1:] + file_path = pkg_root + for p in parts: + file_path = os.path.join(file_path, p) + + is_package = os.path.isdir(file_path) + if not is_package: + return None + + init_path = os.path.join(file_path, "__init__.py") + if not os.path.exists(init_path): + return None + + if fullname not in _CIRCULAR_CHAIN: + return None + + spec = importlib.machinery.ModuleSpec( + fullname, self, is_package=True, origin="" + ) + spec.submodule_search_locations = [file_path] + return spec + + def create_module(self, module: Any) -> None: + return None + + def exec_module(self, module: Any) -> None: + import openjiuwen as _oj + + pkg_root = _oj.__path__[0] + parts = module.__name__.split(".")[1:] + file_path = pkg_root + for p in parts: + file_path = os.path.join(file_path, p) + module.__path__ = [file_path] + module.__file__ = os.path.join(file_path, "__init__.py") + 
+ def __getattr__(self, name: str) -> Any: + """Handle special attributes like find_distributions to prevent recursion.""" + import openjiuwen as _oj + import importlib + + # Prevent recursion when Python scans sys.meta_path for find_distributions etc. + if name in ( + "find_distributions", + "find_module", + "__path__", + "__name__", + "__file__", + "__loader__", + "__package__", + "__spec__", + ): + raise AttributeError(name) + + pkg_root = _oj.__path__[0] + parts = self.__name__.split(".")[1:] + [name] + file_path = pkg_root + for p in parts: + file_path = os.path.join(file_path, p) + + # If it's a package directory, import it as a submodule + if os.path.isdir(file_path) and os.path.exists(os.path.join(file_path, "__init__.py")): + return importlib.import_module(f"{self.__name__}.{name}") + # If it's a regular .py file + if os.path.exists(file_path + ".py"): + return importlib.import_module(f"{self.__name__}.{name}") + raise AttributeError(name) + + +_bypasser_installed = False + + +def _install_jiuwen_bypasser() -> bool: + """ + Install the circular import bypasser for openjiuwen. + Returns True if installed, False if already installed or openjiuwen not available. 
+ """ + global _bypasser_installed + if _bypasser_installed: + return True + + # Stub missing optional dependencies before openjiuwen import chain reaches them + _stubbed = [ + ("pymilvus", {"is_successful": lambda *args, **kwargs: True}), + ("dashscope", {}), + ("pdfplumber", {}), + ] + for _name, _attrs in _stubbed: + if _name not in sys.modules: + _mod = types.ModuleType(_name) + for _k, _v in _attrs.items(): + setattr(_mod, _k, _v) + sys.modules[_name] = _mod + _mod.__path__ = [] + + # Pre-create nested stub modules for pymilvus.client.utils chain + if "pymilvus.client" not in sys.modules: + _client_mod = types.ModuleType("pymilvus.client") + _client_mod.__path__ = [] + sys.modules["pymilvus.client"] = _client_mod + if "pymilvus.client.utils" not in sys.modules: + _utils_mod = types.ModuleType("pymilvus.client.utils") + _utils_mod.is_successful = lambda *args, **kwargs: True + sys.modules["pymilvus.client.utils"] = _utils_mod + + # Stub dashscope sub-modules that may be imported lazily + _dashscope_subs = [ + ("dashscope.api_entities", {}), + ("dashscope.api_entities.data", {}), + ("dashscope.api_entities.dashscope_response", {"DashScopeAPIResponse": object}), + ("dashscope.common", {"REQUEST_TIMEOUT_KEYWORD": "timeout"}), + ("dashscope.common.constants", {"REQUEST_TIMEOUT_KEYWORD": "timeout"}), + ] + for _name, _attrs in _dashscope_subs: + if _name not in sys.modules: + _m = types.ModuleType(_name) + _m.__path__ = [] + for _k, _v in _attrs.items(): + setattr(_m, _k, _v) + sys.modules[_name] = _m + + try: + import openjiuwen # noqa: F401 + except ImportError: + return False + + for finder in sys.meta_path: + if isinstance(finder, _JiuwenInitBypasser): + _bypasser_installed = True + return True + + sys.meta_path.insert(0, _JiuwenInitBypasser()) + _bypasser_installed = True + return True + + +# ---------------------------------------------------------------------- +# Language helpers +# ---------------------------------------------------------------------- 
+LANGUAGE_MAP = {"zh": "zh-CN", "en": "en-US"} + + +def normalize_language(language: str) -> str: + return LANGUAGE_MAP.get(language, "zh-CN") + + +def run_async(coro): + """ + Safely run async coroutine from sync context (FastAPI or Celery). + Handles existing event loops properly. + """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + if loop.is_running(): + try: + import nest_asyncio + nest_asyncio.apply() + return loop.run_until_complete(coro) + except ImportError: + import concurrent.futures + + def run_in_thread(): + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + return new_loop.run_until_complete(coro) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_thread) + return future.result() + + return loop.run_until_complete(coro) + + +# ---------------------------------------------------------------------- +# Jiuwen SDK lazy import helpers +# ---------------------------------------------------------------------- +def _lazy_import_jiuwen_config(): + """Lazily import only lightweight Jiuwen config classes.""" + _install_jiuwen_bypasser() + + try: + import openjiuwen # noqa: F401 + except ImportError as e: + raise JiuwenSDKError(f"Jiuwen SDK 未安装: {e}") from e + + from openjiuwen.core.foundation.llm.schema.config import ( + ModelRequestConfig, + ModelClientConfig, + ProviderType, + ) + + return ModelRequestConfig, ModelClientConfig, ProviderType + + +def build_jiuwen_model_configs(model_id: int, tenant_id: str): + """将 nexent 模型配置转换为 Jiuwen 配置对象""" + from database.model_management_db import get_model_by_model_id + from utils.config_utils import get_model_name_from_config + + ModelRequestConfig, ModelClientConfig, ProviderType = _lazy_import_jiuwen_config() + + model_config = get_model_by_model_id(model_id, tenant_id) + if not model_config: + raise JiuwenSDKError(f"model_id={model_id} not found") + + 
api_base = (model_config.get("base_url", "") or "").strip() + if not api_base: + api_base = "https://api.openai.com/v1" + + # Jiuwen ModelClientConfig defaults to timeout=60.0, max_retries=3. + # For prompt optimization calls, 60s can be too small. Reuse Nexent model config timeout_seconds. + timeout_seconds = model_config.get("timeout_seconds") + if timeout_seconds is None: + timeout_seconds = 120 + + ssl_cert = model_config.get("ssl_cert") or None + ssl_verify = model_config.get("ssl_verify", True) + if ssl_verify and not ssl_cert: + ssl_verify = False + + client_config = ModelClientConfig( + client_provider=ProviderType.OpenAI, + api_key=model_config["api_key"], + api_base=api_base, + timeout=float(timeout_seconds), + verify_ssl=ssl_verify, + ssl_cert=ssl_cert, + ) + + request_config = ModelRequestConfig( + model_name=get_model_name_from_config(model_config), + temperature=0.3, + ) + return request_config, client_config + + +def _lazy_import_jiuwen_builders(): + """Lazily import prompt builders only when optimization paths need them.""" + _install_jiuwen_bypasser() + + try: + import openjiuwen # noqa: F401 + except ImportError as e: + raise JiuwenSDKError(f"Jiuwen SDK 未安装: {e}") from e + + from openjiuwen.dev_tools.prompt_builder.builder.feedback_prompt_builder import ( + FeedbackPromptBuilder, + ) + from openjiuwen.dev_tools.prompt_builder.builder.badcase_prompt_builder import ( + BadCasePromptBuilder, + ) + + return FeedbackPromptBuilder, BadCasePromptBuilder + + +def _unwrap_prompt_response(text: str) -> str: + """Strip JSON wrapper or markdown fence that Jiuwen LLM sometimes generates.""" + _logger = logging.getLogger("jiuwen_adapter") + _logger.debug(f"[unwrap] raw ({len(text)} chars): {text[:200]}") + + # Step 1: strip markdown code fences + text = text.strip() + if text.startswith("```"): + for lang in ("json", ""): + fence = f"```{lang}\n" + if text.startswith(fence): + text = text[len(fence):] + if text.endswith("\n```"): + text = text[:-4] + elif 
text.endswith("```"): + text = text[:-3] + break + text = text.strip() + _logger.debug(f"[unwrap] after fence strip ({len(text)} chars)") + + # Step 2: try standard JSON parse (handles format 1 and 2) + if text.startswith("{"): + try: + parsed = json.loads(text) + if isinstance(parsed, dict) and "prompt" in parsed: + result = parsed["prompt"].strip() + _logger.debug(f"[unwrap] extracted prompt ({len(result)} chars)") + return result + if isinstance(parsed, dict) and "result" in parsed: + result = parsed["result"].strip() + _logger.debug(f"[unwrap] extracted result ({len(result)} chars)") + return result + except Exception: + pass + + # Step 3: format 3 and 4 - raw text (possibly multi-line), return as-is + _logger.debug(f"[unwrap] no JSON wrapper, returning raw ({len(text)} chars)") + return text + + +def _lazy_import_jiuwen_tune_types(): + """Lazily import Jiuwen tune types only when badcase flow needs them.""" + _install_jiuwen_bypasser() + from openjiuwen.dev_tools.tune.base import Case, EvaluatedCase + return Case, EvaluatedCase + + +def to_jiuwen_evaluated_case(bad_case) -> Any: + """将 nexent BadCase 转换为 Jiuwen EvaluatedCase""" + Case, EvaluatedCase = _lazy_import_jiuwen_tune_types() + + case = Case( + inputs={"question": bad_case.question}, + label={"answer": bad_case.label or ""}, + ) + return EvaluatedCase( + case=case, + answer={"content": bad_case.answer}, + score=0.0, + reason=bad_case.reason or "", + ) + + +# ---------------------------------------------------------------------- +# Main adapter class +# ---------------------------------------------------------------------- +class JiuwenSDKAdapter: + """ + Jiuwen SDK 调用适配器 + + 封装 Jiuwen SDK 的所有调用,内部不处理降级, + 失败时抛出 JiuwenSDKError,由上层 PromptOptimizationService 决定是否降级 + """ + + def __init__(self, model_id: int, tenant_id: str): + self.model_id = model_id + self.tenant_id = tenant_id + self.logger = logging.getLogger("jiuwen_adapter") + + def _ensure_available(self): + """确保 Jiuwen SDK 可用""" + if not 
_bypasser_installed: + _install_jiuwen_bypasser() + + try: + import openjiuwen # noqa: F401 + except ImportError as e: + raise JiuwenSDKError(f"Jiuwen SDK 未安装: {e}") from e + + def optimize( + self, + prompt: str, + feedback: str, + mode: Literal["general", "insert", "select"] = "general", + start_pos: Optional[int] = None, + end_pos: Optional[int] = None, + language: str = "zh", + ) -> str: + """ + 调用 Jiuwen FeedbackPromptBuilder + + Raises: + JiuwenSDKError: SDK 调用失败 + """ + self._ensure_available() + + logger.info(f"[jiuwen-adapter] mode={mode}, start_pos={start_pos}, end_pos={end_pos}") + + request_config, client_config = build_jiuwen_model_configs( + self.model_id, self.tenant_id + ) + logger.info( + f"[jiuwen-adapter] model_id={self.model_id}, tenant_id={self.tenant_id}, " + f"api_base={client_config.api_base}, model={request_config.model_name}, " + f"timeout={getattr(client_config, 'timeout', None)}, max_retries={getattr(client_config, 'max_retries', None)}" + ) + FeedbackPromptBuilder, _ = _lazy_import_jiuwen_builders() + + builder = FeedbackPromptBuilder( + model_config=request_config, + model_client_config=client_config, + ) + + try: + result = run_async( + builder.build( + prompt=prompt, + feedback=feedback, + mode=mode, + start_pos=start_pos, + end_pos=end_pos, + language=normalize_language(language), + ) + ) + if result is None: + raise JiuwenSDKError("Jiuwen FeedbackPromptBuilder 返回为空") + return _unwrap_prompt_response(str(result)) + except Exception as e: + self.logger.error(f"Jiuwen FeedbackPromptBuilder 调用失败: {e}") + raise JiuwenSDKError(f"优化调用失败: {e}") from e + + def optimize_badcase( + self, + prompt: str, + bad_cases: List, + language: str = "zh", + ) -> str: + """ + 调用 Jiuwen BadCasePromptBuilder + + Raises: + JiuwenSDKError: SDK 调用失败 + """ + self._ensure_available() + + _, BadCasePromptBuilder = _lazy_import_jiuwen_builders() + + request_config, client_config = build_jiuwen_model_configs( + self.model_id, self.tenant_id + ) + builder = 
BadCasePromptBuilder( + model_config=request_config, + model_client_config=client_config, + ) + + jiuwen_cases = [to_jiuwen_evaluated_case(bc) for bc in bad_cases] + + try: + result = run_async( + builder.build( + prompt=prompt, + cases=jiuwen_cases, + language=normalize_language(language), + ) + ) + if result is None: + raise JiuwenSDKError("Jiuwen BadCasePromptBuilder 返回为空") + return _unwrap_prompt_response(str(result)) + except Exception as e: + self.logger.error(f"Jiuwen BadCasePromptBuilder 调用失败: {e}") + raise JiuwenSDKError(f"BadCasePromptBuilder 调用失败: {e}") from e + + def generate(self, **kwargs) -> dict: + """调用 Jiuwen 提示词生成能力""" + self._ensure_available() + raise JiuwenSDKError("Jiuwen 提示词生成能力尚未实现") diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 50df7eb99..7e3b42e28 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -1,12 +1,12 @@ -import threading +import json +import threading import logging -from typing import List, Optional +from typing import Any, Dict, List, Optional from urllib.parse import urljoin -from datetime import datetime from jinja2 import Template, StrictUndefined from nexent.core.utils.observer import MessageObserver -from nexent.core.agents.agent_model import AgentRunInfo, ModelConfig, AgentConfig, ToolConfig, ExternalA2AAgentConfig, AgentHistory +from nexent.core.agents.agent_model import AgentRunInfo, ModelConfig, AgentConfig, ToolConfig, ExternalA2AAgentConfig, AgentHistory, AgentVerificationConfig from nexent.core.agents.agent_context import ContextManagerConfig from nexent.memory.memory_service import search_memory_in_levels @@ -22,7 +22,11 @@ from database.a2a_agent_db import PROTOCOL_JSONRPC from services.memory_config_service import build_memory_context from services.image_service import get_video_understanding_model, get_vlm_model -from database.agent_db import search_agent_info_by_agent_id, query_sub_agents_id_list +from database.agent_db 
import ( + search_agent_info_by_agent_id, + query_sub_agent_relations, + resolve_sub_agent_version_no, +) from database.agent_version_db import query_current_version_no from database.tool_db import search_tools_for_sub_agent from database.model_management_db import get_model_records, get_model_by_model_id @@ -33,12 +37,71 @@ from utils.config_utils import tenant_config_manager, get_model_name_from_config from utils.context_utils import build_context_components from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET +from consts.model import AgentToolParamsRequest, ToolParamsRequest from consts.exceptions import ValidationError logger = logging.getLogger("create_agent_info") logger.setLevel(logging.DEBUG) +def _normalize_tool_params_request(tool_params: Optional[ToolParamsRequest | Dict[str, Any]]) -> ToolParamsRequest: + """Normalize request-scoped tool parameter overrides into a ToolParamsRequest.""" + if tool_params is None: + return ToolParamsRequest() + if isinstance(tool_params, ToolParamsRequest): + return tool_params + if not isinstance(tool_params, dict): + raise ValidationError("tool_params must be an object.") + try: + return ToolParamsRequest.model_validate(tool_params) + except Exception as exc: + raise ValidationError(f"Invalid tool_params payload: {exc}") from exc + + +def _get_agent_tool_overrides( + tool_params: Optional[ToolParamsRequest], + agent_name: Optional[str], +) -> Dict[str, Dict[str, Any]]: + """Resolve tool overrides for a specific agent by its name.""" + if tool_params is None: + return {} + if not agent_name: + return {} + agent_override = tool_params.agents.get(agent_name) + if agent_override is None: + return {} + return dict(agent_override.tools) + + +def _merge_tool_params( + tool_record: Dict[str, Any], + override_params: Optional[Dict[str, Any]], + extra_params: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Merge request overrides on top of tool instance 
defaults from DB. + + Args: + tool_record: Tool configuration from database + override_params: Request-scoped overrides from tool_params + extra_params: Additional internal params not in DB schema (e.g., document_paths) + + Returns: + Merged params dict with DB defaults, overrides, and extra params + """ + merged_params: Dict[str, Any] = {} + for param in tool_record.get("params", []): + merged_params[param["name"]] = param.get("default") + + if override_params: + merged_params.update(override_params) + + # Extra params (e.g., internal access control params) always take precedence + if extra_params: + merged_params.update(extra_params) + + return merged_params + + def _build_internal_s3_url(file: dict) -> str: """Build a valid S3 URL for internal tools from uploaded file metadata.""" if not isinstance(file, dict): @@ -310,18 +373,23 @@ async def create_agent_config( allow_memory_search: bool = True, version_no: int = 0, override_model_id: int | None = None, + tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None, ): + normalized_tool_params = _normalize_tool_params_request(tool_params) agent_info = search_agent_info_by_agent_id( agent_id=agent_id, tenant_id=tenant_id, version_no=version_no) # create sub agent - sub_agent_id_list = query_sub_agents_id_list( + sub_agent_relations = query_sub_agent_relations( main_agent_id=agent_id, tenant_id=tenant_id, version_no=version_no) managed_agents = [] - for sub_agent_id in sub_agent_id_list: - # Get the current published version for this sub-agent (from draft version 0) - sub_agent_version_no = query_current_version_no( - agent_id=sub_agent_id, tenant_id=tenant_id) or 0 + for rel in sub_agent_relations: + sub_agent_id = rel['selected_agent_id'] + sub_agent_version_no = resolve_sub_agent_version_no( + selected_agent_id=sub_agent_id, + selected_agent_version_no=rel.get('selected_agent_version_no'), + tenant_id=tenant_id, + ) sub_agent_config = await create_agent_config( agent_id=sub_agent_id, tenant_id=tenant_id, 
@@ -331,13 +399,20 @@ async def create_agent_config( allow_memory_search=allow_memory_search, version_no=sub_agent_version_no, override_model_id=None, + tool_params=normalized_tool_params, ) managed_agents.append(sub_agent_config) # create external A2A agents (synchronous function, no await needed) external_a2a_agents = _get_external_a2a_agents(agent_id, tenant_id, version_no) - tool_list = await create_tool_config_list(agent_id, tenant_id, user_id, version_no=version_no) + tool_list = await create_tool_config_list( + agent_id, + tenant_id, + user_id, + version_no=version_no, + tool_params=normalized_tool_params, + ) # Build system prompt: prioritize segmented fields, fallback to original prompt field if not available duty_prompt = agent_info.get("duty_prompt", "") @@ -383,6 +458,77 @@ async def create_agent_config( # Bubble up to streaming layer so it can emit and fall back raise Exception(f"Failed to retrieve memory list: {e}") + # Append active memory tools if memory is enabled + if memory_context.user_config.memory_switch and memory_context.memory_config: + try: + memory_metadata = { + "memory_config": memory_context.memory_config, + "memory_user_config": memory_context.user_config, + "tenant_id": memory_context.tenant_id, + "user_id": memory_context.user_id, + "agent_id": memory_context.agent_id, + } + + store_tool_config = ToolConfig( + class_name="StoreMemoryTool", + name="store_memory", + description=( + "Save important information to long-term memory for future recall. " + "Use this when the user shares personal preferences, facts about themselves, " + "project context, or instructions that should persist across conversations. " + "Do NOT store transient information like temporary calculations, information " + "already in the knowledge base, or data the user explicitly says to forget." 
+ ), + inputs=json.dumps({ + "content": { + "type": "string", + "description": "The information to remember", + "description_zh": "需要记住的信息" + } + }, ensure_ascii=False), + output_type="string", + params={}, + source="local", + usage=None, + metadata=memory_metadata, + ) + tool_list.append(store_tool_config) + + search_tool_config = ToolConfig( + class_name="SearchMemoryTool", + name="search_memory", + description=( + "Search long-term memory for relevant information from previous interactions. " + "Use this when you need context about the user's preferences, past decisions, " + "or previously discussed topics that aren't in the current conversation. " + "The system already provides some memory context automatically -- use this tool " + "when you need to search for specific information not already available." + ), + inputs=json.dumps({ + "query": { + "type": "string", + "description": "Natural language query describing what to search for", + "description_zh": "描述要搜索内容的自然语言查询" + }, + "top_k": { + "type": "integer", + "description": "Maximum number of results to return", + "description_zh": "返回结果的最大数量", + "default": 5, + "nullable": True + } + }, ensure_ascii=False), + output_type="string", + params={}, + source="local", + usage=None, + metadata=memory_metadata, + ) + tool_list.append(search_tool_config) + logger.debug("Active memory tools appended to agent tool list") + except Exception as e: + logger.warning(f"Failed to append active memory tools: {e}") + # Build knowledge base summary knowledge_base_summary = "" try: @@ -413,7 +559,6 @@ async def create_agent_config( # Get skills list for prompt template skills = _get_skills_for_template(agent_id, tenant_id, version_no) - time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") is_manager = len(managed_agents) > 0 or len(external_a2a_agents) > 0 render_kwargs = { @@ -428,7 +573,6 @@ async def create_agent_config( "APP_DESCRIPTION": app_description, "memory_list": memory_list, "knowledge_base_summary": 
knowledge_base_summary, - "time": time_str, "user_id": user_id, } system_prompt = Template(prompt_template["system_prompt"], undefined=StrictUndefined).render(render_kwargs) @@ -457,7 +601,6 @@ async def create_agent_config( few_shots=few_shots_prompt, app_name=app_name, app_description=app_description, - time_str=time_str, user_id=user_id, language=language, is_manager=is_manager, @@ -490,21 +633,48 @@ async def create_agent_config( external_a2a_agents=external_a2a_agents, context_manager_config=cm_config, context_components=context_components, + verification_config=AgentVerificationConfig.model_validate(agent_info.get("verification_config") or {}), ) return agent_config -async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int = 0): - # create tool +async def create_tool_config_list( + agent_id, + tenant_id, + user_id, + version_no: int = 0, + tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None, +): tool_config_list = [] langchain_tools = await discover_langchain_tools() + normalized_tool_params = _normalize_tool_params_request(tool_params) # now only admin can modify the agent, user_id is not used tools_list = search_tools_for_sub_agent(agent_id, tenant_id, version_no=version_no) + + # Look up agent name for use in error messages. + # Agent name is optional for tool_params matching (matching uses tool identifiers only), + # but we include it in error messages so callers can identify which agent/tool caused a failure. 
+ agent_info = search_agent_info_by_agent_id(agent_id=agent_id, tenant_id=tenant_id, version_no=version_no) + agent_name = agent_info.get("name") if agent_info else None + agent_tool_overrides = _get_agent_tool_overrides(normalized_tool_params, agent_name) + + tool_keys_seen = set() for tool in tools_list: - param_dict = {} - for param in tool.get("params", []): - param_dict[param["name"]] = param.get("default") + tool_identifier = tool.get("name") or tool.get("class_name") + if tool_identifier in tool_keys_seen: + raise ValidationError( + f"Duplicate tool identifier '{tool_identifier}' found in agent '{agent_name or agent_id}'." + ) + tool_keys_seen.add(tool_identifier) + + override_params = None + if tool.get("name") in agent_tool_overrides: + override_params = agent_tool_overrides[tool.get("name")] + elif tool.get("class_name") in agent_tool_overrides: + override_params = agent_tool_overrides[tool.get("class_name")] + + param_dict = _merge_tool_params(tool, override_params) tool_config = ToolConfig( class_name=tool.get("class_name"), name=tool.get("name"), @@ -523,12 +693,21 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tool_config.metadata = langchain_tool break + # Extract document_paths for KnowledgeBaseSearchTool (internal access control, not in DB schema) + document_paths = None + if override_params and "document_paths" in override_params: + document_paths = override_params.get("document_paths") + # Also check using the tool name as key + if not document_paths: + kb_overrides = agent_tool_overrides.get("knowledge_base_search") + if kb_overrides and "document_paths" in kb_overrides: + document_paths = kb_overrides.get("document_paths") + # special logic for search tools that may use reranking models if tool_config.class_name == "KnowledgeBaseSearchTool": - rerank = param_dict.get("rerank", False) - rerank_model_name = param_dict.get("rerank_model_name", "") + rerank = tool_config.params.get("rerank", False) + 
rerank_model_name = tool_config.params.get("rerank_model_name", "") rerank_model = None - is_multimodal = bool(tool_config.params.pop("multimodal", False)) if rerank and rerank_model_name: rerank_model = get_rerank_model( tenant_id=tenant_id, model_name=rerank_model_name @@ -536,7 +715,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int # Build display_name to index_name mapping for LLM parameter conversion # Also build reverse mapping (index_name -> display_name) for knowledge_base_summary - index_names = param_dict.get("index_names", []) + index_names = tool_config.params.get("index_names", []) display_name_to_index_map = {} index_name_to_display_map = {} if index_names: @@ -552,12 +731,14 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int "rerank_model": rerank_model, "display_name_to_index_map": display_name_to_index_map, "index_name_to_display_map": index_name_to_display_map, + # Internal access control: restrict results to specific document paths (path_or_urls) + "document_paths": document_paths, } - # Must have embedding model for knowledge base search if not index_names: raise ValidationError( - "Embedding model is required for knowledge_base_search but index_names is empty") + f"[{agent_name or agent_id}] knowledge_base_search tool requires index_names, " + f"but it is not configured in the agent and not provided via tool_params.") embedding_model, _, _ = get_embedding_model_by_index_name(tenant_id, index_names[0]) if not embedding_model: @@ -566,8 +747,8 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int f"Please configure an embedding model for this knowledge base.") tool_config.metadata["embedding_model"] = embedding_model elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: - rerank = param_dict.get("rerank", False) - rerank_model_name = param_dict.get("rerank_model_name", "") + rerank = tool_config.params.get("rerank", False) + 
rerank_model_name = tool_config.params.get("rerank_model_name", "") rerank_model = None if rerank and rerank_model_name: rerank_model = get_rerank_model( @@ -861,6 +1042,7 @@ async def create_agent_run_info( is_debug: bool = False, override_version_no: int | None = None, override_model_id: int | None = None, + tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None, ): # Determine which version_no to use based on is_debug flag # If is_debug=false, use the current published version (current_version_no) @@ -893,7 +1075,7 @@ async def create_agent_run_info( if override_model_id is not None: create_config_kwargs["override_model_id"] = override_model_id - agent_config = await create_agent_config(**create_config_kwargs) + agent_config = await create_agent_config(**create_config_kwargs, tool_params=tool_params) remote_mcp_list = await get_remote_mcp_server_list(tenant_id=tenant_id, is_need_auth=True) default_mcp_url = urljoin(LOCAL_MCP_SERVER, "sse") diff --git a/backend/apps/agent_app.py b/backend/apps/agent_app.py index e280ff422..87abbf9e8 100644 --- a/backend/apps/agent_app.py +++ b/backend/apps/agent_app.py @@ -195,8 +195,6 @@ async def export_agent_api(request: AgentIDRequest, authorization: Optional[str] "Content-Disposition": f"attachment; filename=\"{result.get('filename', 'agent_export.zip')}\"" } ) - if isinstance(result, str): - result = json.loads(result) return ConversationResponse(code=0, message="success", data=result) except Exception as e: logger.error(f"Agent export error: {str(e)}") @@ -621,3 +619,5 @@ async def list_published_agents_api( raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Published agents list error." 
) + + diff --git a/backend/apps/agent_repository_app.py b/backend/apps/agent_repository_app.py new file mode 100644 index 000000000..e9da2fde0 --- /dev/null +++ b/backend/apps/agent_repository_app.py @@ -0,0 +1,134 @@ +import logging +from http import HTTPStatus +from typing import Optional + +from fastapi import APIRouter, Body, Header, HTTPException, Query +from starlette.responses import JSONResponse + +from consts.exceptions import SkillDuplicateError, UnauthorizedError +from services.agent_repository_service import ( + create_agent_repository_listing_impl, + import_agent_from_repository_impl, + list_agent_repository_listings_impl, + update_agent_repository_status_impl, +) +from utils.auth_utils import get_current_user_id + +agent_repository_router = APIRouter(prefix="/repository/agent") +logger = logging.getLogger("agent_repository_app") + + +@agent_repository_router.get("") +async def list_agent_repository_listings_api( + status: Optional[str] = Query(None, description="Filter by listing status"), + authorization: str = Header(None), +): + """List all marketplace repository listings with optional status filter.""" + try: + get_current_user_id(authorization) + result = list_agent_repository_listings_impl(status=status) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except UnauthorizedError as e: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=str(e)) + except ValueError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except Exception as e: + logger.error(f"List agent repository listings error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="List agent repository listings error.", + ) + + +@agent_repository_router.patch("/{agent_repository_id}/status") +async def update_agent_repository_status_api( + agent_repository_id: int, + status: str = Body( + ..., + embed=True, + description=( + "New status: NOT_SHARED (未共享) / PENDING_REVIEW (待审核) / " + 
"REJECTED (审核驳回) / SHARED (已共享)" + ), + ), + authorization: str = Header(None), +): + """Update marketplace repository listing status (share, unshare, approve, reject).""" + try: + user_id, _ = get_current_user_id(authorization) + result = update_agent_repository_status_impl( + agent_repository_id=agent_repository_id, + status=status, + user_id=user_id, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except UnauthorizedError as e: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=str(e)) + except ValueError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except Exception as e: + logger.error(f"Update agent repository status error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Update agent repository status error.", + ) + + +@agent_repository_router.post("/{agent_id}/versions/{version_no}") +async def create_agent_repository_listing_api( + agent_id: int, + version_no: int, + authorization: str = Header(None), +): + """Create or update a marketplace repository listing from an agent version snapshot.""" + try: + user_id, tenant_id = get_current_user_id(authorization) + result = await create_agent_repository_listing_impl( + agent_id=agent_id, + tenant_id=tenant_id, + user_id=user_id, + version_no=version_no, + ) + return JSONResponse(status_code=HTTPStatus.OK, content=result) + except UnauthorizedError as e: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=str(e)) + except ValueError as e: + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except Exception as e: + logger.error(f"Create agent repository listing error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Create agent repository listing error.", + ) + + +@agent_repository_router.post("/{agent_repository_id}/import") +async def import_agent_from_repository_api( + agent_repository_id: int, + authorization: 
Optional[str] = Header(None), +): + """Import an agent tree from a marketplace repository listing into the current tenant.""" + try: + await import_agent_from_repository_impl( + agent_repository_id=agent_repository_id, + authorization=authorization, + ) + return JSONResponse(status_code=HTTPStatus.OK, content={}) + except UnauthorizedError as e: + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=str(e)) + except SkillDuplicateError as exc: + raise HTTPException( + status_code=HTTPStatus.CONFLICT, + detail={ + "type": "skill_duplicate", + "duplicate_skills": exc.duplicate_names, + }, + ) + except ValueError as e: + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) + except Exception as e: + logger.error(f"Import agent from repository error: {str(e)}") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="Import agent from repository error.", + ) diff --git a/backend/apps/app_factory.py b/backend/apps/app_factory.py index 219da5b82..02816cec1 100644 --- a/backend/apps/app_factory.py +++ b/backend/apps/app_factory.py @@ -101,6 +101,16 @@ async def generic_exception_handler(request, exc): if isinstance(exc, AppException): return await app_exception_handler(request, exc) + # Handle NexentCapabilityError with a friendly message + from adapters.exception import NexentCapabilityError as _NCE + + if isinstance(exc, _NCE): + logger.warning(f"NexentCapabilityError: {exc}") + return JSONResponse( + status_code=400, + content={"message": str(exc)}, + ) + logger.error(f"Generic Exception: {exc}") return JSONResponse( status_code=500, diff --git a/backend/apps/cas_app.py b/backend/apps/cas_app.py new file mode 100644 index 000000000..dbf4815f8 --- /dev/null +++ b/backend/apps/cas_app.py @@ -0,0 +1,156 @@ +import html +import logging +from http import HTTPStatus +from typing import Optional +from urllib.parse import parse_qs, urlsplit + +from fastapi import APIRouter, HTTPException, Query, Request +from fastapi.responses 
import HTMLResponse, JSONResponse, RedirectResponse + +from services.cas_service import ( + CAS_SERVER_URL, + CasAuthenticationError, + build_login_url, + build_renew_url, + get_cas_config, + login_with_ticket, + renew_with_ticket, + revoke_from_logout_request, +) + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/user/cas", tags=["cas"]) + + +@router.get("/config") +async def config(): + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": get_cas_config()}, + ) + + +@router.get("/login") +async def login(redirect: str = Query("/", description="URL to return to after login")): + try: + login_url = _require_cas_server_redirect(build_login_url(redirect)) + return RedirectResponse(url=login_url, status_code=HTTPStatus.FOUND) + except CasAuthenticationError as exc: + logger.warning("CAS login rejected: %s", exc) + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="CAS login is not available") + + +@router.get("/callback") +async def callback(ticket: str = "", redirect: str = "/"): + try: + result = await login_with_ticket(ticket, redirect) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "CAS login successful", "data": result}, + ) + except CasAuthenticationError as exc: + logger.warning("CAS callback rejected: %s", exc) + raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="CAS authentication failed") + except Exception as exc: + logger.error(f"CAS callback failed: {exc}") + raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="CAS login failed") + + +@router.post("/callback") +async def callback_logout(request: Request, logout_request: Optional[str] = None): + return await _handle_logout_request(request, logout_request, endpoint="callback") + + +@router.get("/renew") +async def renew(): + try: + return RedirectResponse(url=build_renew_url(), status_code=HTTPStatus.FOUND) + except CasAuthenticationError as exc: + logger.warning("CAS renew 
rejected: %s", exc) + return _renew_html(False, "CAS renew failed") + + +@router.get("/renew_callback") +async def renew_callback(ticket: str = ""): + if not ticket: + return _renew_html(False, "CAS session is not active") + try: + result = await renew_with_ticket(ticket) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "CAS renew successful", "data": result}, + ) + except Exception as exc: + logger.warning(f"CAS renew failed: {exc}") + return _renew_html(False, "CAS renew failed") + + +@router.post("/logout_callback") +async def logout_callback( + request: Request, + logout_request: Optional[str] = None, +): + return await _handle_logout_request(request, logout_request, endpoint="logout_callback") + + +async def _handle_logout_request( + request: Request, + logout_request: Optional[str] = None, + endpoint: str = "unknown", +): + logout_request = await _extract_logout_request(request, logout_request) + logger.info( + "CAS SLO %s received logoutRequest: present=%s length=%s", + endpoint, + bool(logout_request), + len(logout_request or ""), + ) + result = revoke_from_logout_request(logout_request) + logger.info("CAS SLO %s revoke result: %s", endpoint, result) + return JSONResponse( + status_code=HTTPStatus.OK, + content={"message": "success", "data": result}, + ) + + +async def _extract_logout_request(request: Request, logout_request: Optional[str] = None) -> str: + if logout_request: + return logout_request + + query_logout_request = request.query_params.get("logoutRequest") or request.query_params.get("logout_request") + if query_logout_request: + return query_logout_request + + body = await request.body() + raw_body = body.decode("utf-8") if body else "" + if not raw_body: + return "" + + parsed = parse_qs(raw_body) + return (parsed.get("logoutRequest") or parsed.get("logout_request") or [raw_body])[0] + + +def _renew_html(success: bool, reason: str = "") -> HTMLResponse: + status = "success" if success else "failed" + safe_reason = 
html.escape(reason) + return HTMLResponse( + status_code=HTTPStatus.OK, + content=f""" +""", + ) + + +def _require_cas_server_redirect(url: str) -> str: + parsed_url = urlsplit(url) + parsed_cas = urlsplit(CAS_SERVER_URL) + if ( + parsed_url.scheme not in {"http", "https"} + or not parsed_url.netloc + or parsed_url.scheme != parsed_cas.scheme + or parsed_url.netloc != parsed_cas.netloc + ): + logger.warning("Blocked CAS redirect outside configured server: %s", url) + raise CasAuthenticationError("Invalid CAS redirect URL") + return url diff --git a/backend/apps/config_app.py b/backend/apps/config_app.py index 8cb383df7..a818ec7cb 100644 --- a/backend/apps/config_app.py +++ b/backend/apps/config_app.py @@ -2,6 +2,7 @@ from apps.app_factory import create_app from apps.agent_app import agent_config_router as agent_router +from apps.agent_repository_app import agent_repository_router from apps.config_sync_app import router as config_sync_router from apps.datamate_app import router as datamate_router from apps.vectordatabase_app import router as vectordatabase_router @@ -32,6 +33,7 @@ from apps.monitoring_app import router as monitoring_router from apps.a2a_server_app import router as a2a_server_router from apps.haotian_app import router as haotian_router +from apps.cas_app import router as cas_router from consts.const import IS_SPEED_MODE from services.prompt_template_service import sync_system_default_prompt_template @@ -54,6 +56,7 @@ async def sync_default_prompt_template_on_startup(): app.include_router(model_manager_router) app.include_router(config_sync_router) app.include_router(agent_router) +app.include_router(agent_repository_router) app.include_router(vectordatabase_router) app.include_router(datamate_router) app.include_router(voice_router) @@ -73,6 +76,7 @@ async def sync_default_prompt_template_on_startup(): app.include_router(user_management_router) app.include_router(oauth_router) +app.include_router(cas_router) app.include_router(summary_router) 
app.include_router(prompt_router) diff --git a/backend/apps/northbound_app.py b/backend/apps/northbound_app.py index e6aff8e06..9f3b7e323 100644 --- a/backend/apps/northbound_app.py +++ b/backend/apps/northbound_app.py @@ -1,14 +1,16 @@ import logging from http import HTTPStatus from typing import Optional, Dict, Any -from urllib.parse import urlparse +from urllib.parse import urlparse, unquote +import re import uuid import httpx -from fastapi import APIRouter, Body, Header, Request, HTTPException, Query +from fastapi import APIRouter, Body, File, Header, HTTPException, Query, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse -from consts.exceptions import LimitExceededError, UnauthorizedError +from consts.exceptions import LimitExceededError, UnauthorizedError, ConversationNotFoundError +from consts.model import ToolParamsRequest from services.northbound_service import ( NorthboundContext, get_conversation_history, @@ -17,16 +19,35 @@ stop_chat, get_agent_info_list, update_conversation_title, + upload_files_for_northbound, ) from utils.auth_utils import validate_bearer_token, get_user_and_tenant_by_access_key +from .file_management_app import build_content_disposition_header + router = APIRouter(prefix="/nb/v1", tags=["northbound"]) __all__ = ["router", "_get_northbound_context"] +def _resolve_proxy_download_filename(presigned_url: str, content_disposition: str) -> str: + """Resolve a stable download filename for the northbound file proxy.""" + if content_disposition: + filename_star_match = re.search(r"filename\*=UTF-8''([^;]+)", content_disposition) + if filename_star_match: + return unquote(filename_star_match.group(1)) or "download" + + filename_match = re.search(r'filename="?([^";]+)"?', content_disposition) + if filename_match: + return filename_match.group(1) or "download" + + path = unquote(urlparse(presigned_url).path) + filename = path.split("/")[-1].strip() + return filename or "download" + + async def 
_get_northbound_context(request: Request) -> NorthboundContext: """ Build northbound context from request. @@ -109,13 +130,119 @@ async def health_check(): return {"status": "healthy", "service": "northbound-api"} -@router.post("/chat/run") +@router.post( + "/chat/attachments/upload", + summary="Upload chat attachments for northbound runs", + description=( + "Upload one or more files for later use in `/nb/v1/chat/run`. " + "Successful uploads return reusable `s3_url` references." + ), +) +async def upload_chat_attachments( + request: Request, + files: list[UploadFile] = File( + ..., + description="List of files to upload", + examples=["report.pdf", "diagram.png"], + ), +): + try: + ctx: NorthboundContext = await _get_northbound_context(request) + return JSONResponse( + status_code=HTTPStatus.OK, + content=await upload_files_for_northbound(ctx=ctx, files=files), + ) + except LimitExceededError as e: + logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, + detail="Too Many Requests: rate limit exceeded") + except ValueError as e: + logging.error(f"Invalid northbound upload request: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except PermissionError as e: + logging.error(f"Permission denied while uploading northbound files: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=str(e)) + except HTTPException as e: + raise e + except Exception as e: + logging.error(f"Failed to upload northbound files: {str(e)}", exc_info=e) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Internal Server Error") + + +@router.post( + "/chat/run", + summary="Start a northbound chat run with optional attachments", + description=( + "Run a northbound chat request. 
Upload attachments first through " + "`/nb/v1/chat/attachments/upload`, then pass the returned `s3_url` values " + "through the `attachments` field." + ), +) async def run_chat( request: Request, - conversation_id: Optional[int] = Body(None, embed=True), - agent_name: str = Body(..., embed=True), - query: str = Body(..., embed=True), - meta_data: Optional[Dict[str, Any]] = Body(None, embed=True), + conversation_id: Optional[int] = Body( + None, + embed=True, + description="Existing conversation ID. Omit to create a new conversation.", + examples=[123], + ), + agent_name: str = Body( + ..., + embed=True, + description="Target agent name.", + examples=["general-assistant"], + ), + query: str = Body( + ..., + embed=True, + description="User input to send to the agent.", + examples=["Summarize the uploaded report and list the key risks."], + ), + attachments: Optional[list] = Body( + None, + embed=True, + description="Attachments for the chat. Can be either a list of S3 URL strings" + "or a list of attachment objects with full metadata.", + examples=[["s3://nexent/attachments/user123/20260609_report.pdf"]], + ), + meta_data: Optional[Dict[str, Any]] = Body( + None, + embed=True, + description="Optional metadata passed through for audit and usage logging.", + examples=[{"source": "crm", "ticket_id": "INC-1001"}], + ), + tool_params: Optional[ToolParamsRequest] = Body( + None, + embed=True, + description="Optional request-scoped overrides for tool initialization parameters. " + "Overrides DB-persisted params (ag_tool_instance_t.params) on a per-run basis. " + "Conflict resolution: request value wins over DB value. " + "Structure: agents -> {agent_name} -> tools -> {tool_name} -> {param_name: param_value}. " + "tool_name matching: first by tool.name, then by tool.class_name. " + "Unknown param names cause a ValidationError (400). 
" + "Metadata-derived fields (e.g., vdb_core, embedding_model) are recalculated " + "from merged params for tools like KnowledgeBaseSearchTool, DifySearchTool, DataMateSearchTool.", + examples=[{ + "agents": { + "common_sense_qa_assistant": { + "tools": { + "analyze_text_file": { + "chunk_size": 4000, + "summary_only": True, + "prompt": "Please provide a concise summary of this document focusing on key facts." + }, + "knowledge_base_search": { + "top_k": 10, + "rerank": True, + "rerank_model_name": "gte-rerank-v2", + "index_names": ["nexent-docs", "faq-index"] + } + } + } + } + }], + ), idempotency_key: Optional[str] = Header(None, alias="Idempotency-Key"), ): try: @@ -125,13 +252,21 @@ async def run_chat( conversation_id=conversation_id, agent_name=agent_name, query=query, + attachments=attachments, meta_data=meta_data, + tool_params=tool_params, idempotency_key=idempotency_key, ) except LimitExceededError as e: logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") + except ValueError as e: + logging.error(f"Invalid northbound chat request: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail=str(e)) + except PermissionError as e: + logging.error(f"Permission denied while running northbound chat: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=str(e)) except HTTPException as e: raise e except Exception as e: @@ -254,6 +389,9 @@ async def update_convs_title( logging.error(f"Too Many Requests: rate limit exceeded: {str(e)}", exc_info=e) raise HTTPException(status_code=HTTPStatus.TOO_MANY_REQUESTS, detail="Too Many Requests: rate limit exceeded") + except ConversationNotFoundError as e: + logging.error(f"Conversation not found while updating title: {str(e)}", exc_info=e) + raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) except HTTPException 
as e: raise e except Exception as e: @@ -312,12 +450,12 @@ async def fetch_file_from_presigned_url( content_type = response.headers.get("Content-Type", "application/octet-stream") content_disposition = response.headers.get("Content-Disposition", "") + download_filename = _resolve_proxy_download_filename(presigned_url, content_disposition) headers = { "Content-Type": content_type, + "Content-Disposition": build_content_disposition_header(download_filename), } - if content_disposition: - headers["Content-Disposition"] = content_disposition return StreamingResponse( content=response.aiter_bytes(), diff --git a/backend/apps/northbound_knowledge_app.py b/backend/apps/northbound_knowledge_app.py index 775d6c567..02739d138 100644 --- a/backend/apps/northbound_knowledge_app.py +++ b/backend/apps/northbound_knowledge_app.py @@ -51,7 +51,8 @@ async def _require_asset_owner_context(request: Request) -> NorthboundContext: @router.get("/indices") async def get_list_indices( request: Request, - pattern: Annotated[str, Query(description="Pattern to match index names")] = "*", + pattern: Annotated[str, Query( + description="Pattern to match index names")] = "*", ): """List knowledge bases visible to the asset-owner tenant. 
@@ -92,7 +93,7 @@ async def create_new_index( Optional[Dict[str, Any]], Body( description=( - "Request body with optional fields (ingroup_permission, group_ids, embedding_model_name)" + "Request body with optional fields (ingroup_permission, group_ids, embedding_model_name, preserve_source_file)" ), ), ] = None, @@ -110,10 +111,12 @@ async def create_new_index( ingroup_permission = None group_ids = None embedding_model_name = None + preserve_source_file = None if body: ingroup_permission = body.get("ingroup_permission") group_ids = body.get("group_ids") embedding_model_name = body.get("embedding_model_name") + preserve_source_file = body.get("preserve_source_file") return ElasticSearchService.create_knowledge_base( knowledge_name=index_name, @@ -124,6 +127,7 @@ async def create_new_index( ingroup_permission=ingroup_permission, group_ids=group_ids, embedding_model_name=embedding_model_name, + preserve_source_file=preserve_source_file, ) except LimitExceededError as e: logger.exception("Rate limit exceeded while creating index") @@ -222,52 +226,65 @@ async def delete_documents( request: Request, index_name: Annotated[str, Path(..., description="Name of the index")], path_or_url: Annotated[str, Query(..., description="Path or URL of documents to delete")], + scope: Annotated[ + str, + Query( + description=( + "source_only: delete MinIO source only; " + "full: delete ES, MinIO, and Redis records" + ), + ), + ] = "full", ): - """Delete documents by path or URL and clean up related Redis records. - - Restricted to asset administrators (same auth as get_list_indices). - """ + """Delete a document by scope. 
Restricted to asset administrators.""" try: - ctx = await _require_asset_owner_context(request) + await _require_asset_owner_context(request) vdb_core = get_vector_db_core(db_type=VectorDatabaseType.ELASTICSEARCH) - logger.debug("Deleting documents for index %s", index_name) - result = ElasticSearchService.delete_documents( - index_name, path_or_url, vdb_core) - - try: - redis_service = get_redis_service() - redis_cleanup_result = redis_service.delete_document_records( - index_name, path_or_url) - - result["redis_cleanup"] = redis_cleanup_result - - original_message = result.get( - "message", "Documents deleted successfully") - result["message"] = ( - f"{original_message}. " - f"Cleaned up {redis_cleanup_result['total_deleted']} Redis records " - f"({redis_cleanup_result['celery_tasks_deleted']} tasks, " - f"{redis_cleanup_result['cache_keys_deleted']} cache keys)." - ) - - if redis_cleanup_result.get("errors"): - result["redis_warnings"] = redis_cleanup_result["errors"] + logger.debug( + "Deleting documents for index %s scope=%s", index_name, scope + ) + result = await ElasticSearchService.delete_document_by_scope( + index_name, path_or_url, scope, vdb_core + ) - except Exception as redis_error: - logger.warning( - "Redis cleanup failed for index %s: %s", - index_name, - redis_error, - ) - result["redis_cleanup_error"] = str(redis_error) - original_message = result.get( - "message", "Documents deleted successfully") - result["message"] = ( - f"{original_message}, but Redis cleanup encountered an error: " - f"{str(redis_error)}" - ) + if scope == "full": + try: + redis_service = get_redis_service() + redis_cleanup_result = redis_service.delete_document_records( + index_name, path_or_url + ) + result["redis_cleanup"] = redis_cleanup_result + original_message = result.get( + "message", "Documents deleted successfully" + ) + result["message"] = ( + f"{original_message}. 
" + f"Cleaned up {redis_cleanup_result['total_deleted']} Redis records " + f"({redis_cleanup_result['celery_tasks_deleted']} tasks, " + f"{redis_cleanup_result['cache_keys_deleted']} cache keys)." + ) + if redis_cleanup_result.get("errors"): + result["redis_warnings"] = redis_cleanup_result["errors"] + except Exception as redis_error: + logger.warning( + "Redis cleanup failed for index %s: %s", + index_name, + redis_error, + ) + result["redis_cleanup_error"] = str(redis_error) + original_message = result.get( + "message", "Documents deleted successfully" + ) + result["message"] = ( + f"{original_message}, but Redis cleanup encountered an error: " + f"{str(redis_error)}" + ) return result + except ValueError as exc: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, detail=str(exc) + ) except LimitExceededError as e: logger.exception("Rate limit exceeded while deleting documents") raise HTTPException( diff --git a/backend/apps/prompt_app.py b/backend/apps/prompt_app.py index 987729e69..6b82a5c82 100644 --- a/backend/apps/prompt_app.py +++ b/backend/apps/prompt_app.py @@ -4,11 +4,19 @@ from fastapi import APIRouter, Header, Request from fastapi.responses import JSONResponse, StreamingResponse -from consts.model import GeneratePromptRequest, OptimizePromptSectionRequest +from consts.model import ( + GeneratePromptRequest, + OptimizePromptSectionRequest, + OptimizePromptBadCaseRequest, + OptimizePromptFromDebugRequest, +) from services.prompt_service import ( gen_system_prompt_streamable, - optimize_prompt_section_impl, + OptimizeRequest, + OptimizeResult, + PromptOptimizationService, ) +from adapters.exception import NexentCapabilityError from utils.auth_utils import get_current_user_info router = APIRouter(prefix="/prompt") @@ -48,30 +56,140 @@ async def optimize_prompt_section_api( http_request: Request, authorization: Optional[str] = Header(None) ): + _, tenant_id, language = get_current_user_info( + authorization, http_request) + + service = 
PromptOptimizationService( + model_id=optimize_request.model_id, + tenant_id=tenant_id, + language=language, + ) + try: - _, tenant_id, language = get_current_user_info( - authorization, http_request) - optimized_section = optimize_prompt_section_impl( + result = service.optimize( + OptimizeRequest( + agent_id=optimize_request.agent_id, + model_id=optimize_request.model_id, + task_description=optimize_request.task_description, + section_type=optimize_request.section_type, + section_title=optimize_request.section_title, + current_content=optimize_request.current_content, + feedback=optimize_request.feedback, + mode=optimize_request.mode, + start_pos=optimize_request.start_pos, + end_pos=optimize_request.end_pos, + tool_ids=optimize_request.tool_ids, + sub_agent_ids=optimize_request.sub_agent_ids, + knowledge_base_display_names=optimize_request.knowledge_base_display_names, + ) + ) + return JSONResponse( + status_code=HTTPStatus.OK, + content={ + "message": "Success", + "data": { + "optimized_content": result.optimized_content, + "section_type": result.section_type, + "section_title": result.section_title, + "original_content": result.original_content, + } + }, + headers={"X-Prompt-Source": result.source}, + ) + except NexentCapabilityError as e: + return JSONResponse( + status_code=HTTPStatus.BAD_REQUEST, + content={"message": str(e)}, + ) + except Exception as exc: + logger.exception(f"Error occurred while optimizing prompt section: {exc}") + raise + + +@router.post("/optimize/badcase") +async def optimize_prompt_badcase_api( + badcase_request: OptimizePromptBadCaseRequest, + http_request: Request, + authorization: Optional[str] = Header(None) +): + _, tenant_id, language = get_current_user_info( + authorization, http_request) + + service = PromptOptimizationService( + model_id=badcase_request.model_id, + tenant_id=tenant_id, + language=language, + ) + + try: + result = service.optimize_badcase( + current_content=badcase_request.current_content, + 
bad_cases=badcase_request.bad_cases, + agent_id=badcase_request.agent_id, + section_type=badcase_request.section_type, + section_title=badcase_request.section_title, + tool_ids=badcase_request.tool_ids, + sub_agent_ids=badcase_request.sub_agent_ids, + knowledge_base_display_names=badcase_request.knowledge_base_display_names, + ) + return JSONResponse( + status_code=HTTPStatus.OK, + content={ + "message": "Success", + "data": { + "optimized_content": result.optimized_content, + "section_type": result.section_type, + "section_title": result.section_title, + "original_content": result.original_content, + } + }, + headers={"X-Prompt-Source": result.source}, + ) + except NexentCapabilityError as e: + return JSONResponse( + status_code=HTTPStatus.BAD_REQUEST, + content={"message": str(e)}, + ) + + +@router.post("/optimize/from_debug") +async def optimize_prompt_from_debug_api( + optimize_request: OptimizePromptFromDebugRequest, + http_request: Request, + authorization: Optional[str] = Header(None) +): + _, tenant_id, language = get_current_user_info( + authorization, http_request) + + service = PromptOptimizationService( + model_id=optimize_request.model_id, + tenant_id=tenant_id, + language=language, + ) + + try: + result = service.optimize_from_debug( agent_id=optimize_request.agent_id, - model_id=optimize_request.model_id, - task_description=optimize_request.task_description, - tenant_id=tenant_id, - language=language, - section_type=optimize_request.section_type, - section_title=optimize_request.section_title, - current_content=optimize_request.current_content, feedback=optimize_request.feedback, - tool_ids=optimize_request.tool_ids, - sub_agent_ids=optimize_request.sub_agent_ids, - knowledge_base_display_names=optimize_request.knowledge_base_display_names, + selected=optimize_request.selected, + history=optimize_request.history, ) return JSONResponse( status_code=HTTPStatus.OK, content={ - "message": "Prompt section optimized successfully", - "data": 
optimized_section, - } + "message": "Success", + "data": { + "original_full_prompt": result.original_content, + "optimized_full_prompt": result.optimized_content, + } + }, + headers={"X-Prompt-Source": result.source}, + ) + except NexentCapabilityError as e: + return JSONResponse( + status_code=HTTPStatus.BAD_REQUEST, + content={"message": str(e)}, ) except Exception as exc: - logger.exception(f"Error occurred while optimizing prompt section: {exc}") + logger.exception(f"Error occurred while optimizing prompt from debug: {exc}") raise diff --git a/backend/apps/tool_config_app.py b/backend/apps/tool_config_app.py index f0b7f9304..bfc8d5ca0 100644 --- a/backend/apps/tool_config_app.py +++ b/backend/apps/tool_config_app.py @@ -160,12 +160,14 @@ async def import_openapi_service_api( server_url: Base URL of the REST API server openapi_json: Complete OpenAPI JSON specification service_description: Optional service description + headers_template: Optional default headers template force_update: If True, replace all existing tools for this service """ service_name = openapi_service_request.get("service_name") server_url = openapi_service_request.get("server_url") openapi_json = openapi_service_request.get("openapi_json") service_description = openapi_service_request.get("service_description") + headers_template = openapi_service_request.get("headers_template") force_update = openapi_service_request.get("force_update", False) if not service_name: @@ -192,6 +194,7 @@ async def import_openapi_service_api( tenant_id=tenant_id, user_id=user_id, service_description=service_description, + headers_template=headers_template, force_update=force_update ) diff --git a/backend/apps/user_management_app.py b/backend/apps/user_management_app.py index edbcdf27d..e79fde887 100644 --- a/backend/apps/user_management_app.py +++ b/backend/apps/user_management_app.py @@ -19,12 +19,13 @@ ValidationError, ) from consts.error_code import ErrorCode +from services.cas_service import build_logout_url, 
CasAuthenticationError from services.user_management_service import get_authorized_client, validate_token, \ check_auth_service_health, signup_user_with_invitation, signin_user, refresh_user_token, \ get_session_by_authorization, get_user_info, create_token, list_tokens_by_user, delete_token, \ update_password from services.user_service import delete_user_and_cleanup -from utils.auth_utils import get_current_user_id +from utils.auth_utils import get_current_user_id, extract_session_id_from_authorization load_dotenv() @@ -150,7 +151,18 @@ async def logout(request: Request): authorization = request.headers.get("Authorization") try: # Make logout idempotent: if no token or token expired, still return success + session_id = None + cas_logout_url = "" if authorization: + session_id = extract_session_id_from_authorization(authorization) + if session_id: + from database.cas_session_db import revoke_cas_session_by_session_id + + revoke_cas_session_by_session_id(session_id, actor="user") + try: + cas_logout_url = build_logout_url() + except CasAuthenticationError as cas_err: + logging.warning(f"CAS logout URL is unavailable: {str(cas_err)}") client = get_authorized_client(authorization) try: client.auth.sign_out() @@ -159,7 +171,12 @@ async def logout(request: Request): logging.warning( f"Sign out encountered an error but will be ignored: {str(signout_err)}") return JSONResponse(status_code=HTTPStatus.OK, - content={"message": "Logout successful"}) + content={ + "message": "Logout successful", + "data": { + "cas_logout_url": cas_logout_url + } + }) except Exception as e: logging.error(f"User logout failed: {str(e)}") @@ -214,6 +231,10 @@ async def get_user_information(request: Request): if not user_info: raise UnauthorizedError("User information not found") + user_info["user"]["auth_provider"] = ( + "cas" if extract_session_id_from_authorization(authorization) else "local" + ) + return JSONResponse(status_code=HTTPStatus.OK, content={"message": "Success", "data": 
user_info}) diff --git a/backend/apps/vectordatabase_app.py b/backend/apps/vectordatabase_app.py index 118537766..505c39559 100644 --- a/backend/apps/vectordatabase_app.py +++ b/backend/apps/vectordatabase_app.py @@ -76,7 +76,7 @@ def create_new_index( embedding_dim: Optional[int] = Query( None, description="Dimension of the embedding vectors"), request: Dict[str, Any] = Body( - None, description="Request body with optional fields (ingroup_permission, group_ids, embedding_model_name)"), + None, description="Request body with optional fields (ingroup_permission, group_ids, embedding_model_name, preserve_source_file)"), vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), authorization: Optional[str] = Header(None) ): @@ -89,11 +89,13 @@ def create_new_index( group_ids = None embedding_model_name: Optional[str] = None is_multimodal: Optional[bool] = None + preserve_source_file: Optional[bool] = None if request: ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") embedding_model_name = request.get("embeddingModel") is_multimodal = request.get("is_multimodal") + preserve_source_file = request.get("preserve_source_file") # Treat path parameter as user-facing knowledge base name for new creations return ElasticSearchService.create_knowledge_base( @@ -106,6 +108,7 @@ def create_new_index( group_ids=group_ids, embedding_model_name=embedding_model_name, is_multimodal=is_multimodal, + preserve_source_file=preserve_source_file, ) except Exception as e: raise HTTPException( @@ -505,54 +508,70 @@ async def get_index_files( @router.delete("/{index_name}/documents") -def delete_documents( +async def delete_documents( index_name: str = Path(..., description="Name of the index"), path_or_url: str = Query(..., description="Path or URL of documents to delete"), + scope: str = Query( + "full", + description=( + "source_only: delete MinIO source only, keep ES chunks/vectors; " + "full: delete ES documents, MinIO source, and Redis task 
records" + ), + ), vdb_core: VectorDatabaseCore = Depends(get_vector_db_core) ): - """Delete documents by path or URL and clean up related Redis records""" + """Delete a document by scope: source file only or full removal from the index.""" try: - # First delete the documents using existing service - result = ElasticSearchService.delete_documents( - index_name, path_or_url, vdb_core) - - # Then clean up Redis records related to this specific document - try: - redis_service = get_redis_service() - redis_cleanup_result = redis_service.delete_document_records( - index_name, path_or_url) - - # Add Redis cleanup info to the result - result["redis_cleanup"] = redis_cleanup_result - - # Update the message to include Redis cleanup info - original_message = result.get( - "message", "Documents deleted successfully") - result["message"] = ( - f"{original_message}. " - f"Cleaned up {redis_cleanup_result['total_deleted']} Redis records " - f"({redis_cleanup_result['celery_tasks_deleted']} tasks, " - f"{redis_cleanup_result['cache_keys_deleted']} cache keys)." 
- ) - - if redis_cleanup_result.get("errors"): - result["redis_warnings"] = redis_cleanup_result["errors"] + result = await ElasticSearchService.delete_document_by_scope( + index_name, path_or_url, scope, vdb_core + ) - except Exception as redis_error: - logger.warning( - f"Redis cleanup failed for document {path_or_url} in index {index_name}: {str(redis_error)}") - result["redis_cleanup_error"] = str(redis_error) - original_message = result.get( - "message", "Documents deleted successfully") - result[ - "message"] = f"{original_message}, but Redis cleanup encountered an error: {str(redis_error)}" + if scope == "full": + try: + redis_service = get_redis_service() + redis_cleanup_result = redis_service.delete_document_records( + index_name, path_or_url + ) + result["redis_cleanup"] = redis_cleanup_result + original_message = result.get( + "message", "Documents deleted successfully" + ) + result["message"] = ( + f"{original_message}. " + f"Cleaned up {redis_cleanup_result['total_deleted']} Redis records " + f"({redis_cleanup_result['celery_tasks_deleted']} tasks, " + f"{redis_cleanup_result['cache_keys_deleted']} cache keys)." 
+ ) + if redis_cleanup_result.get("errors"): + result["redis_warnings"] = redis_cleanup_result["errors"] + except Exception as redis_error: + logger.warning( + "Redis cleanup failed for document %s in index %s: %s", + path_or_url, + index_name, + redis_error, + ) + result["redis_cleanup_error"] = str(redis_error) + original_message = result.get( + "message", "Documents deleted successfully" + ) + result["message"] = ( + f"{original_message}, but Redis cleanup encountered an error: " + f"{str(redis_error)}" + ) return result + except ValueError as exc: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, detail=str(exc) + ) except Exception as e: raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"Error delete indexing documents: {e}") + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"Error delete indexing documents: {e}", + ) @router.get("/{index_name}/documents/{path_or_url:path}/error-info") diff --git a/backend/consts/const.py b/backend/consts/const.py index ac2196c2a..574d550c0 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -90,6 +90,31 @@ class VectorDatabaseType(str, Enum): OAUTH_CA_BUNDLE = os.getenv("OAUTH_CA_BUNDLE", "") +# CAS SSO Configuration +CAS_ENABLED = os.getenv("CAS_ENABLED", "false").lower() in ("true", "1", "yes", "on") +CAS_SERVER_URL = os.getenv("CAS_SERVER_URL", "").rstrip("/") +CAS_VALIDATE_PATH = os.getenv("CAS_VALIDATE_PATH", "/p3/serviceValidate") +CAS_CALLBACK_BASE_URL = os.getenv("CAS_CALLBACK_BASE_URL", OAUTH_CALLBACK_BASE_URL).rstrip("/") +# CAS login mode: +# - disabled: disable CAS login entry and automatic CAS redirects. +# - button: show CAS as an optional login entry. +# - force: automatically redirect unauthenticated users to CAS login. 
+CAS_LOGIN_MODE = os.getenv("CAS_LOGIN_MODE", "disabled").lower() +CAS_USER_ATTRIBUTE = os.getenv("CAS_USER_ATTRIBUTE", "") +CAS_EMAIL_ATTRIBUTE = os.getenv("CAS_EMAIL_ATTRIBUTE", "email") +CAS_ROLE_ATTRIBUTE = os.getenv("CAS_ROLE_ATTRIBUTE", "role") +CAS_TENANT_ATTRIBUTE = os.getenv("CAS_TENANT_ATTRIBUTE", "tenant_id") +CAS_ROLE_MAP_JSON = os.getenv("CAS_ROLE_MAP_JSON", "") +CAS_SESSION_MAX_AGE_SECONDS = int(os.getenv("CAS_SESSION_MAX_AGE_SECONDS", "3600") or 3600) +LOCAL_SESSION_MAX_AGE_SECONDS = int(os.getenv("LOCAL_SESSION_MAX_AGE_SECONDS", "3600") or 3600) +CAS_RENEW_BEFORE_SECONDS = int(os.getenv("CAS_RENEW_BEFORE_SECONDS", "300") or 300) +CAS_RENEW_TIMEOUT_SECONDS = int(os.getenv("CAS_RENEW_TIMEOUT_SECONDS", "10") or 10) +CAS_SYNTHETIC_EMAIL_DOMAIN = os.getenv("CAS_SYNTHETIC_EMAIL_DOMAIN", "cas.local") +CAS_LOGOUT_URL = os.getenv("CAS_LOGOUT_URL", "") +CAS_SSL_VERIFY = os.getenv("CAS_SSL_VERIFY", "true").lower() == "true" +CAS_CA_BUNDLE = os.getenv("CAS_CA_BUNDLE", "") + + # ===== To be migrated to frontend configuration ===== # Email Configuration IMAP_SERVER = os.getenv('IMAP_SERVER') @@ -208,6 +233,7 @@ class VectorDatabaseType(str, Enum): "NEXENT_MCP_DOCKER_IMAGE", "nexent/nexent-mcp:latest") ENABLE_UPLOAD_IMAGE = os.getenv( "ENABLE_UPLOAD_IMAGE", "false").lower() == "true" +ENABLE_JIUWEN_SDK = os.getenv("NEXENT_ENABLE_JIUWEN_SDK", "true").lower() == "true" # Celery Configuration @@ -375,36 +401,47 @@ class VectorDatabaseType(str, Enum): OTEL_SERVICE_NAME = OTEL_SERVICE_NAME_RAW or "nexent-backend" OTEL_EXPORTER_OTLP_ENDPOINT_RAW = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") OTEL_EXPORTER_OTLP_ENDPOINT = OTEL_EXPORTER_OTLP_ENDPOINT_RAW or "http://localhost:4318" -OTEL_EXPORTER_OTLP_TRACES_ENDPOINT = os.getenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "") -OTEL_EXPORTER_OTLP_METRICS_ENDPOINT = os.getenv("OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", "") +OTEL_EXPORTER_OTLP_TRACES_ENDPOINT = os.getenv( + "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "") 
+OTEL_EXPORTER_OTLP_METRICS_ENDPOINT = os.getenv( + "OTEL_EXPORTER_OTLP_METRICS_ENDPOINT", "") OTEL_EXPORTER_OTLP_PROTOCOL_RAW = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL") OTEL_EXPORTER_OTLP_PROTOCOL = OTEL_EXPORTER_OTLP_PROTOCOL_RAW or "http" OTEL_EXPORTER_OTLP_HEADERS_RAW = os.getenv("OTEL_EXPORTER_OTLP_HEADERS") OTEL_EXPORTER_OTLP_HEADERS = OTEL_EXPORTER_OTLP_HEADERS_RAW or "" -OTEL_EXPORTER_OTLP_AUTHORIZATION = os.getenv("OTEL_EXPORTER_OTLP_AUTHORIZATION", "") +OTEL_EXPORTER_OTLP_AUTHORIZATION = os.getenv( + "OTEL_EXPORTER_OTLP_AUTHORIZATION", "") OTEL_EXPORTER_OTLP_X_API_KEY = os.getenv("OTEL_EXPORTER_OTLP_X_API_KEY", "") OTEL_EXPORTER_OTLP_LANGFUSE_INGESTION_VERSION = os.getenv( "OTEL_EXPORTER_OTLP_LANGFUSE_INGESTION_VERSION", "") LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY", "") LANGSMITH_PROJECT = os.getenv("LANGSMITH_PROJECT", "") -OTEL_EXPORTER_OTLP_METRICS_ENABLED_RAW = os.getenv("OTEL_EXPORTER_OTLP_METRICS_ENABLED") +OTEL_EXPORTER_OTLP_METRICS_ENABLED_RAW = os.getenv( + "OTEL_EXPORTER_OTLP_METRICS_ENABLED") OTEL_EXPORTER_OTLP_METRICS_ENABLED = ( OTEL_EXPORTER_OTLP_METRICS_ENABLED_RAW or "true").lower() == "true" -MONITORING_INSTRUMENT_REQUESTS_RAW = os.getenv("MONITORING_INSTRUMENT_REQUESTS") +MONITORING_INSTRUMENT_REQUESTS_RAW = os.getenv( + "MONITORING_INSTRUMENT_REQUESTS") MONITORING_INSTRUMENT_REQUESTS = ( MONITORING_INSTRUMENT_REQUESTS_RAW or "false").lower() == "true" -MONITORING_FASTAPI_INCLUDED_URLS = os.getenv("MONITORING_FASTAPI_INCLUDED_URLS", "") -MONITORING_FASTAPI_EXCLUDED_URLS = os.getenv("MONITORING_FASTAPI_EXCLUDED_URLS", "") -MONITORING_FASTAPI_EXCLUDE_SPANS = os.getenv("MONITORING_FASTAPI_EXCLUDE_SPANS", "receive,send") +MONITORING_FASTAPI_INCLUDED_URLS = os.getenv( + "MONITORING_FASTAPI_INCLUDED_URLS", "") +MONITORING_FASTAPI_EXCLUDED_URLS = os.getenv( + "MONITORING_FASTAPI_EXCLUDED_URLS", "") +MONITORING_FASTAPI_EXCLUDE_SPANS = os.getenv( + "MONITORING_FASTAPI_EXCLUDE_SPANS", "receive,send") MONITORING_PROJECT_NAME = 
os.getenv("MONITORING_PROJECT_NAME", "") MONITORING_DASHBOARD_URL = os.getenv("MONITORING_DASHBOARD_URL", "") -MONITORING_TRACE_CONTENT_MODE = os.getenv("MONITORING_TRACE_CONTENT_MODE", "summary") +MONITORING_TRACE_CONTENT_MODE = os.getenv( + "MONITORING_TRACE_CONTENT_MODE", "summary") MONITORING_TRACE_MAX_CHARS = os.getenv("MONITORING_TRACE_MAX_CHARS", "4000") MONITORING_TRACE_MAX_ITEMS = os.getenv("MONITORING_TRACE_MAX_ITEMS", "20") TELEMETRY_SAMPLE_RATE_RAW = os.getenv("TELEMETRY_SAMPLE_RATE") TELEMETRY_SAMPLE_RATE = float(TELEMETRY_SAMPLE_RATE_RAW or "1.0") # Parse OTLP headers into dict format + + def _parse_otlp_headers(headers_str: str) -> dict: """Parse OTLP headers string into dict. Format: 'key1=value1,key2=value2'""" if not headers_str: @@ -416,6 +453,7 @@ def _parse_otlp_headers(headers_str: str) -> dict: headers[key.strip()] = value.strip() return headers + OTLP_HEADERS = _parse_otlp_headers(OTEL_EXPORTER_OTLP_HEADERS) if OTEL_EXPORTER_OTLP_AUTHORIZATION: OTLP_HEADERS["Authorization"] = OTEL_EXPORTER_OTLP_AUTHORIZATION @@ -448,7 +486,7 @@ def _parse_otlp_headers(headers_str: str) -> dict: # APP Version -APP_VERSION = "v2.2.0" +APP_VERSION = "v2.2.1" # Skill Creation Streaming Configuration diff --git a/backend/consts/model.py b/backend/consts/model.py index 6969999fe..00e5b8a0a 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -1,8 +1,8 @@ from enum import Enum -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List, Dict, Literal from pydantic import BaseModel, Field, EmailStr, ConfigDict, field_validator -from nexent.core.agents.agent_model import ToolConfig +from nexent.core.agents.agent_model import AgentVerificationConfig, ToolConfig from consts.prompt_template import PROMPT_GENERATE_TEMPLATE_FIELD_ALIAS_MAP @@ -230,6 +230,24 @@ class HistoryItem(BaseModel): minio_files: Optional[List[Dict[str, Any]]] = None +class AgentToolParamsRequest(BaseModel): + """Request-scoped tool parameter overrides for 
a single agent.""" + + tools: Dict[str, Dict[str, Any]] = Field( + default_factory=dict, + description="Mapping from tool identifier to request-scoped override params", + ) + + +class ToolParamsRequest(BaseModel): + """Request-scoped tool parameter overrides for main and managed agents.""" + + agents: Dict[str, AgentToolParamsRequest] = Field( + default_factory=dict, + description="Mapping from agent identifier to tool parameter overrides", + ) + + class AgentRequest(BaseModel): query: str conversation_id: Optional[int] = None @@ -240,6 +258,7 @@ class AgentRequest(BaseModel): model_id: Optional[int] = None version_no: Optional[int] = None is_debug: Optional[bool] = False + tool_params: Optional[ToolParamsRequest] = None class MessageUnit(BaseModel): @@ -414,6 +433,9 @@ class OptimizePromptSectionRequest(BaseModel): section_title: str current_content: str feedback: str + mode: Literal["general", "insert", "select"] = "general" + start_pos: Optional[int] = Field(None, description="Start position for insert/select mode") + end_pos: Optional[int] = Field(None, description="End position for insert/select mode") tool_ids: Optional[List[int]] = Field( None, description="Optional: tool IDs from frontend (takes precedence over database query)") sub_agent_ids: Optional[List[int]] = Field( @@ -422,6 +444,38 @@ class OptimizePromptSectionRequest(BaseModel): None, description="Optional: knowledge base display names from frontend (takes precedence over database query)") +class BadCaseItem(BaseModel): + question: str + answer: str + label: Optional[str] = None + reason: Optional[str] = None + + +class OptimizePromptBadCaseRequest(BaseModel): + agent_id: int + model_id: int + current_content: str + bad_cases: List[BadCaseItem] + section_type: str + section_title: str + tool_ids: Optional[List[int]] = Field(None) + sub_agent_ids: Optional[List[int]] = Field(None) + knowledge_base_display_names: Optional[List[str]] = Field(None) + + +class OptimizeFromDebugSelected(BaseModel): + 
user_question: str + assistant_answer: str + + +class OptimizePromptFromDebugRequest(BaseModel): + agent_id: int + model_id: int + feedback: str + selected: OptimizeFromDebugSelected + history: Optional[List[HistoryItem]] = None + + class GenerateTitleRequest(BaseModel): conversation_id: int question: str @@ -454,8 +508,18 @@ class AgentInfoRequest(BaseModel): group_ids: Optional[List[int]] = None ingroup_permission: Optional[str] = None enable_context_manager: Optional[bool] = None + verification_config: Optional[Dict[str, Any]] = None + greeting_message: Optional[str] = None + example_questions: Optional[List[str]] = None version_no: int = 0 + @field_validator("verification_config", mode="before") + @classmethod + def normalize_verification_config(cls, value): + if value is None: + return None + return AgentVerificationConfig.model_validate(value).model_dump() + class AgentIDRequest(BaseModel): agent_id: int @@ -520,6 +584,7 @@ class MessageIdRequest(BaseModel): class ExportAndImportAgentInfo(BaseModel): agent_id: int + tenant_id: Optional[str] = None name: str display_name: Optional[str] = None description: str @@ -527,6 +592,7 @@ class ExportAndImportAgentInfo(BaseModel): author: Optional[str] = None max_steps: int provide_run_summary: bool + verification_config: Optional[Dict[str, Any]] = None duty_prompt: Optional[str] = None constraint_prompt: Optional[str] = None few_shots_prompt: Optional[str] = None @@ -556,6 +622,11 @@ class ExportAndImportDataFormat(BaseModel): mcp_info: List[MCPInfo] +class AgentRepositorySnapshot(ExportAndImportDataFormat): + """Frozen marketplace snapshot: export format plus optional skill ZIP payloads.""" + skills: Optional[List["SkillZipEntry"]] = None + + class SkillZipEntry(BaseModel): """A skill bundled inside an agent export ZIP.""" skill_name: str diff --git a/backend/data_process/tasks.py b/backend/data_process/tasks.py index f2a30f9b7..4dd6edd69 100644 --- a/backend/data_process/tasks.py +++ b/backend/data_process/tasks.py 
@@ -8,9 +8,11 @@ import os import threading import time +from dataclasses import dataclass from typing import Any, Dict, Optional, List, Tuple import aiohttp +import requests import re import ray from celery import Task, chain, states, group, chord @@ -19,6 +21,7 @@ from utils.file_management_utils import get_file_size from database.attachment_db import get_file_stream +from database.knowledge_db import get_knowledge_record from services.redis_service import get_redis_service from .app import app from .ray_actors import DataProcessorRayActor @@ -43,10 +46,12 @@ logger = logging.getLogger("data_process.tasks") -ASYNC_SPLIT_RETRY_MAX = max(FORWARD_REDIS_RETRY_MAX * 5, FORWARD_REDIS_RETRY_MAX) +ASYNC_SPLIT_RETRY_MAX = max( + FORWARD_REDIS_RETRY_MAX * 5, FORWARD_REDIS_RETRY_MAX) FORWARD_ES_CHUNK_BATCH_SIZE = 64 IMAGE_METADATA_PROCESS_SOURCE = "UniversalImageExtractor" + def _wait_for_split_ready(redis_key: str, timeout_s: int, poll_interval_ms: int) -> int: """ Wait until async split aggregation is marked ready in Redis. 
@@ -91,7 +96,8 @@ def _estimate_parallel_parts() -> int: def _compute_split_wait_timeout(parts_count: int) -> int: base_timeout = DP_REDIS_CHUNKS_WAIT_TIMEOUT_S waves = math.ceil(max(1, parts_count) / _estimate_parallel_parts()) - dynamic_timeout = base_timeout + max(0, waves - 1) * max(1, PER_WAVE_TIMEOUT) + dynamic_timeout = base_timeout + \ + max(0, waves - 1) * max(1, PER_WAVE_TIMEOUT) return min(MAX_TIMEOUT, max(base_timeout, dynamic_timeout)) @@ -178,7 +184,6 @@ def _build_balanced_batches( return batches - # Thread lock for initializing Ray to prevent race conditions ray_init_lock = threading.Lock() @@ -327,6 +332,35 @@ def run_in_thread(): raise +def _delete_source_file_via_http_sync( + *, + base_url: str, + index_name: str, + path_or_url: str, + scope: str, + timeout_s: float = 30.0, +) -> Dict[str, Any]: + base = (base_url or "").rstrip("/") + if not base: + raise RuntimeError("ELASTICSEARCH_SERVICE is not configured") + url = f"{base}/indices/{index_name}/documents" + params = {"path_or_url": path_or_url, "scope": scope} + + resp = requests.delete(url, params=params, timeout=timeout_s) + body_text = getattr(resp, "text", "") + parsed = None + try: + parsed = resp.json() + except Exception: + parsed = _parse_json_or_none(body_text) if body_text else None + + return { + "http_status": getattr(resp, "status_code", None), + "response_json": parsed if isinstance(parsed, dict) else None, + "response_text": body_text if not isinstance(parsed, dict) else None, + } + + def _build_forward_error( message: str, index_name: str, @@ -350,6 +384,206 @@ def _parse_json_or_none(text: str) -> Optional[Dict[str, Any]]: return None +@dataclass(frozen=True) +class _ForwardContext: + task_id: str + request_id: str + start_time: float + source: str + index_name: str + source_type: str + original_filename: Optional[str] + + +def _init_forward_context( + *, + task_id: str, + request_id: str, + start_time: float, + source: str, + index_name: str, + source_type: str, + 
original_filename: Optional[str], +) -> _ForwardContext: + return _ForwardContext( + task_id=task_id, + request_id=request_id, + start_time=start_time, + source=source, + index_name=index_name, + source_type=source_type, + original_filename=original_filename, + ) + + +def _is_forward_task_cancelled(ctx: _ForwardContext) -> bool: + try: + redis_service = get_redis_service() + return bool(redis_service.is_task_cancelled(ctx.task_id)) + except Exception as exc: + logger.warning( + f"[{ctx.request_id}] FORWARD TASK: Failed to check cancellation flag for task {ctx.task_id}: " + f"{exc}" + ) + return False + + +def _build_forward_cancelled_result(ctx: _ForwardContext) -> Dict[str, Any]: + return { + 'task_id': ctx.task_id, + 'source': ctx.source, + 'index_name': ctx.index_name, + 'original_filename': ctx.original_filename, + 'chunks_stored': 0, + 'storage_time': 0, + 'es_result': { + "success": False, + "message": "Indexing cancelled because document was deleted.", + "total_indexed": 0, + "total_submitted": 0, + }, + } + + +def _load_forward_chunks( + self: Task, + *, + processed_data: Dict[str, Any], + original_source: str, + original_index_name: str, + filename: Optional[str], +) -> Tuple[Optional[List[Dict[str, Any]]], bool, str, str, Optional[str]]: + chunks = processed_data.get('chunks') + split_async = bool(processed_data.get('split_async')) + + # If chunks are not in payload, try loading from Redis via the redis_key + if (not chunks) and processed_data.get('redis_key'): + redis_key = processed_data.get('redis_key') + if not REDIS_BACKEND_URL: + raise Exception(json.dumps({ + "message": "REDIS_BACKEND_URL not configured to retrieve chunks", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": filename + }, ensure_ascii=False)) + try: + import redis + client = redis.Redis.from_url( + REDIS_BACKEND_URL, decode_responses=True) + ready_key = f"{redis_key}:ready" + if split_async: + ready_flag = 
client.get(ready_key) + if not ready_flag: + retry_num = getattr(self.request, 'retries', 0) + logger.info( + f"[{self.request.id}] FORWARD TASK: Async split not ready for key {redis_key}. Retry {retry_num + 1}/{ASYNC_SPLIT_RETRY_MAX} in {FORWARD_REDIS_RETRY_DELAY_S}s") + raise self.retry( + countdown=FORWARD_REDIS_RETRY_DELAY_S, + max_retries=ASYNC_SPLIT_RETRY_MAX, + exc=Exception(json.dumps({ + "message": "Async split not ready; will retry", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": filename + }, ensure_ascii=False)) + ) + cached = client.get(redis_key) + if cached: + try: + logger.debug( + f"[{self.request.id}] FORWARD TASK: Retrieved Redis key '{redis_key}', payload_length={len(cached)}") + chunks = json.loads(cached) + except json.JSONDecodeError as jde: + # Log raw prefix to help diagnose incorrect writes + raw_preview = cached[:120] if isinstance( + cached, str) else str(type(cached)) + logger.error( + f"[{self.request.id}] FORWARD TASK: JSON decode error for key '{redis_key}': {str(jde)}; raw_prefix={raw_preview!r}") + raise + else: + if split_async: + retry_num = getattr(self.request, 'retries', 0) + logger.info( + f"[{self.request.id}] FORWARD TASK: Async split ready but chunks missing for key {redis_key}. Retry {retry_num + 1}/{ASYNC_SPLIT_RETRY_MAX} in {FORWARD_REDIS_RETRY_DELAY_S}s") + raise self.retry( + countdown=FORWARD_REDIS_RETRY_DELAY_S, + max_retries=ASYNC_SPLIT_RETRY_MAX, + exc=Exception(json.dumps({ + "message": "Async split ready but chunks missing; will retry", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": filename + }, ensure_ascii=False)) + ) + # No busy-wait: release the worker slot and retry later + retry_num = getattr(self.request, 'retries', 0) + logger.info( + f"[{self.request.id}] FORWARD TASK: Chunks not yet available for key {redis_key}. 
Retry {retry_num + 1}/{FORWARD_REDIS_RETRY_MAX} in {FORWARD_REDIS_RETRY_DELAY_S}s") + raise self.retry( + countdown=FORWARD_REDIS_RETRY_DELAY_S, + max_retries=FORWARD_REDIS_RETRY_MAX, + exc=Exception(json.dumps({ + "message": "Chunks not ready in Redis; will retry", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": filename + }, ensure_ascii=False)) + ) + except Retry: + raise + except Exception as exc: + raise Exception(json.dumps({ + "message": f"Failed to retrieve chunks from Redis: {str(exc)}", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": filename + }, ensure_ascii=False)) + + if processed_data.get('source'): + original_source = processed_data.get('source') + if processed_data.get('index_name'): + original_index_name = processed_data.get('index_name') + if processed_data.get('original_filename'): + filename = processed_data.get('original_filename') + + logger.info( + f"[{self.request.id}] FORWARD TASK: Received data for source '{original_source}' with {len(chunks) if chunks else 'None'} chunks") + + if chunks is None: + raise Exception(json.dumps({ + "message": "No chunks received for forwarding", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": filename + }, ensure_ascii=False)) + if len(chunks) == 0: + if split_async and processed_data.get('redis_key'): + retry_num = getattr(self.request, 'retries', 0) + logger.info( + f"[{self.request.id}] FORWARD TASK: Empty chunks while waiting for async split. 
Retry {retry_num + 1}/{ASYNC_SPLIT_RETRY_MAX} in {FORWARD_REDIS_RETRY_DELAY_S}s") + raise self.retry( + countdown=FORWARD_REDIS_RETRY_DELAY_S, + max_retries=ASYNC_SPLIT_RETRY_MAX, + exc=Exception(json.dumps({ + "message": "Chunks not ready in Redis (empty); will retry", + "index_name": original_index_name, + "task_name": "forward", + "source": original_source, + "original_filename": filename + }, ensure_ascii=False)) + ) + logger.warning( + f"[{self.request.id}] FORWARD TASK: Empty chunks list received for source {original_source}") + + return chunks, split_async, original_source, original_index_name, filename + + def _extract_error_code_from_es_response( parsed_body: Optional[Dict[str, Any]], text: str, @@ -404,7 +638,7 @@ async def _post(): try: connector = aiohttp.TCPConnector(verify_ssl=False) timeout = aiohttp.ClientTimeout(total=600) - + request_params: Dict[str, str] = {} if large_mode: @@ -423,7 +657,8 @@ async def _post(): parsed_body = _parse_json_or_none(text) if status >= 400: - error_code = _extract_error_code_from_es_response(parsed_body, text) + error_code = _extract_error_code_from_es_response( + parsed_body, text) if error_code: raise Exception(json.dumps({ "error_code": error_code @@ -508,7 +743,8 @@ def get_actor(self) -> Any: if not self.actors: actor = self._create_and_warm_actor() if actor is None: - raise RuntimeError("Global actor pool is empty and actor warm-up failed") + raise RuntimeError( + "Global actor pool is empty and actor warm-up failed") self.actors.append(actor) idx = self.rr_index % len(self.actors) self.rr_index += 1 @@ -552,10 +788,12 @@ def prewarm_ray_actors(target_size: Optional[int] = None) -> int: """ Ensure a global shared pool of warm Ray actors exists for low-latency task execution. 
""" - desired = RAY_GLOBAL_ACTOR_POOL_SIZE if target_size is None else max(0, int(target_size)) + desired = RAY_GLOBAL_ACTOR_POOL_SIZE if target_size is None else max( + 0, int(target_size)) manager = _get_or_create_global_pool_manager() current_after = ray.get( - manager.ensure_pool.remote(desired=desired, max_allowed=_estimate_parallel_parts()) + manager.ensure_pool.remote( + desired=desired, max_allowed=_estimate_parallel_parts()) ) logger.info( f"Global Ray actor pool ready: current={current_after}, desired={desired}" @@ -578,6 +816,7 @@ def _get_split_actor() -> Any: """ return get_ray_actor() + class LoggingTask(Task): """Base task class with enhanced logging""" @@ -645,7 +884,8 @@ def process_part( "chunks_count": len(chunks), } except Exception as e: - logger.error(f"[process_part] Failed to process part for '{filename}': {str(e)}") + logger.error( + f"[process_part] Failed to process part for '{filename}': {str(e)}") return { "part_redis_key": part_redis_key, "chunks_count": 0, @@ -1159,7 +1399,8 @@ def process( fetch_start = time.perf_counter() file_stream = get_file_stream(source) if file_stream is None: - raise FileNotFoundError(f"Unable to fetch file from URL: {source}") + raise FileNotFoundError( + f"Unable to fetch file from URL: {source}") file_data = file_stream.read() fetch_elapsed = time.perf_counter() - fetch_start logger.info( @@ -1208,7 +1449,8 @@ def process( if cached: cached_chunks = json.loads(cached) if isinstance(cached_chunks, list): - image_metadata_chunk_count = _count_image_metadata_chunks(cached_chunks) + image_metadata_chunk_count = _count_image_metadata_chunks( + cached_chunks) except Exception as image_count_exc: logger.warning( f"[{self.request.id}] PROCESS TASK: Failed counting image metadata chunks for async split: {image_count_exc}") @@ -1232,17 +1474,17 @@ def process( self.update_state( state=states.SUCCESS, meta={ - 'chunks_count': chunk_count, - 'processing_time': elapsed_time, - 'source': source, - 'index_name': 
index_name, - 'original_filename': original_filename, - 'task_name': 'process', - 'stage': 'text_extracted', - 'file_size_mb': file_size_mb, - 'processing_speed_mb_s': file_size_mb / elapsed_time if file_size_mb > 0 and elapsed_time > 0 else 0 - } - ) + 'chunks_count': chunk_count, + 'processing_time': elapsed_time, + 'source': source, + 'index_name': index_name, + 'original_filename': original_filename, + 'task_name': 'process', + 'stage': 'text_extracted', + 'file_size_mb': file_size_mb, + 'processing_speed_mb_s': file_size_mb / elapsed_time if file_size_mb > 0 and elapsed_time > 0 else 0 + } + ) logger.info( f"[{self.request.id}] PROCESS TASK: Processing complete, waiting for forward task") @@ -1408,165 +1650,34 @@ def forward( filename = original_filename try: - # Before doing any heavy work, check whether this task has been - # explicitly cancelled (for example, because the user deleted the - # document from the knowledge base configuration page). - try: - redis_service = get_redis_service() - if redis_service.is_task_cancelled(task_id): - logger.info( - f"[{self.request.id}] FORWARD TASK: Detected cancellation flag for task {task_id}; " - f"skipping chunk forwarding for source '{source}' in index '{index_name}'." - ) - # Treat this as a graceful early exit. We still return a - # structured payload so callers can consider the task done. 
- return { - 'task_id': task_id, - 'source': source, - 'index_name': index_name, - 'original_filename': original_filename, - 'chunks_stored': 0, - 'storage_time': 0, - 'es_result': { - "success": False, - "message": "Indexing cancelled because document was deleted.", - "total_indexed": 0, - "total_submitted": 0, - }, - } - except Exception as cancel_check_exc: - logger.warning( - f"[{self.request.id}] FORWARD TASK: Failed to check cancellation flag for task {task_id}: " - f"{cancel_check_exc}" - ) + ctx = _init_forward_context( + task_id=task_id, + request_id=str(self.request.id), + start_time=start_time, + source=source, + index_name=index_name, + source_type=source_type, + original_filename=original_filename, + ) - chunks = processed_data.get('chunks') - split_async = bool(processed_data.get('split_async')) - # If chunks are not in payload, try loading from Redis via the redis_key - if (not chunks) and processed_data.get('redis_key'): - redis_key = processed_data.get('redis_key') - if not REDIS_BACKEND_URL: - raise Exception(json.dumps({ - "message": "REDIS_BACKEND_URL not configured to retrieve chunks", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": filename - }, ensure_ascii=False)) - try: - import redis - client = redis.Redis.from_url( - REDIS_BACKEND_URL, decode_responses=True) - ready_key = f"{redis_key}:ready" - if split_async: - ready_flag = client.get(ready_key) - if not ready_flag: - retry_num = getattr(self.request, 'retries', 0) - logger.info( - f"[{self.request.id}] FORWARD TASK: Async split not ready for key {redis_key}. 
Retry {retry_num + 1}/{ASYNC_SPLIT_RETRY_MAX} in {FORWARD_REDIS_RETRY_DELAY_S}s") - raise self.retry( - countdown=FORWARD_REDIS_RETRY_DELAY_S, - max_retries=ASYNC_SPLIT_RETRY_MAX, - exc=Exception(json.dumps({ - "message": "Async split not ready; will retry", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": filename - }, ensure_ascii=False)) - ) - cached = client.get(redis_key) - if cached: - try: - logger.debug( - f"[{self.request.id}] FORWARD TASK: Retrieved Redis key '{redis_key}', payload_length={len(cached)}") - chunks = json.loads(cached) - except json.JSONDecodeError as jde: - # Log raw prefix to help diagnose incorrect writes - raw_preview = cached[:120] if isinstance( - cached, str) else str(type(cached)) - logger.error( - f"[{self.request.id}] FORWARD TASK: JSON decode error for key '{redis_key}': {str(jde)}; raw_prefix={raw_preview!r}") - raise - else: - if split_async: - retry_num = getattr(self.request, 'retries', 0) - logger.info( - f"[{self.request.id}] FORWARD TASK: Async split ready but chunks missing for key {redis_key}. Retry {retry_num + 1}/{ASYNC_SPLIT_RETRY_MAX} in {FORWARD_REDIS_RETRY_DELAY_S}s") - raise self.retry( - countdown=FORWARD_REDIS_RETRY_DELAY_S, - max_retries=ASYNC_SPLIT_RETRY_MAX, - exc=Exception(json.dumps({ - "message": "Async split ready but chunks missing; will retry", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": filename - }, ensure_ascii=False)) - ) - # No busy-wait: release the worker slot and retry later - retry_num = getattr(self.request, 'retries', 0) - logger.info( - f"[{self.request.id}] FORWARD TASK: Chunks not yet available for key {redis_key}. 
Retry {retry_num + 1}/{FORWARD_REDIS_RETRY_MAX} in {FORWARD_REDIS_RETRY_DELAY_S}s") - raise self.retry( - countdown=FORWARD_REDIS_RETRY_DELAY_S, - max_retries=FORWARD_REDIS_RETRY_MAX, - exc=Exception(json.dumps({ - "message": "Chunks not ready in Redis; will retry", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": filename - }, ensure_ascii=False)) - ) - except Retry: - raise - except Exception as exc: - raise Exception(json.dumps({ - "message": f"Failed to retrieve chunks from Redis: {str(exc)}", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": filename - }, ensure_ascii=False)) - if processed_data.get('source'): - original_source = processed_data.get('source') - if processed_data.get('index_name'): - original_index_name = processed_data.get('index_name') - if processed_data.get('original_filename'): - filename = processed_data.get('original_filename') - logger.info( - f"[{self.request.id}] FORWARD TASK: Received data for source '{original_source}' with {len(chunks) if chunks else 'None'} chunks") + # Before doing any heavy work, check whether this task has been explicitly cancelled. + if _is_forward_task_cancelled(ctx): + logger.info( + f"[{self.request.id}] FORWARD TASK: Detected cancellation flag for task {task_id}; " + f"skipping chunk forwarding for source '{source}' in index '{index_name}'." 
+ ) + return _build_forward_cancelled_result(ctx) + + chunks, split_async, original_source, original_index_name, filename = _load_forward_chunks( + self, + processed_data=processed_data, + original_source=original_source, + original_index_name=original_index_name, + filename=filename, + ) # Calculate total chunks for progress tracking total_chunks = len(chunks) if chunks else 0 - - if chunks is None: - raise Exception(json.dumps({ - "message": "No chunks received for forwarding", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": original_filename - }, ensure_ascii=False)) - if len(chunks) == 0: - if split_async and processed_data.get('redis_key'): - retry_num = getattr(self.request, 'retries', 0) - logger.info( - f"[{self.request.id}] FORWARD TASK: Empty chunks while waiting for async split. Retry {retry_num + 1}/{ASYNC_SPLIT_RETRY_MAX} in {FORWARD_REDIS_RETRY_DELAY_S}s") - raise self.retry( - countdown=FORWARD_REDIS_RETRY_DELAY_S, - max_retries=ASYNC_SPLIT_RETRY_MAX, - exc=Exception(json.dumps({ - "message": "Chunks not ready in Redis (empty); will retry", - "index_name": original_index_name, - "task_name": "forward", - "source": original_source, - "original_filename": filename - }, ensure_ascii=False)) - ) - logger.warning( - f"[{self.request.id}] FORWARD TASK: Empty chunks list received for source {original_source}") formatted_chunks = [] # Compute once per file to avoid repeated IO/MinIO calls inside loop file_size = get_file_size(source_type, original_source) if isinstance( @@ -1757,6 +1868,7 @@ def forward( logger.info( f"[{self.request.id}] FORWARD TASK: Successfully stored {len(chunks)} chunks to index {original_index_name} in {end_time - start_time:.2f}s") + return { 'task_id': task_id, 'source': original_source, @@ -1839,9 +1951,106 @@ def forward( raise -@app.task(bind=True, base=LoggingTask, name='data_process.tasks.process_and_forward') -def process_and_forward( - self, +@app.task( + 
bind=True, + base=LoggingTask, + name="data_process.tasks.cleanup_source", + queue="forward_q", +) +def cleanup_source(self, forward_result: Dict[str, Any]) -> Dict[str, Any]: + """ + Conditionally delete the MinIO source file after successful indexing. + + If the knowledge base is configured with preserve_source_file=false, call: + DELETE /indices/{index_name}/documents?path_or_url=...&scope=source_only + """ + index_name = (forward_result or {}).get("index_name") + source = (forward_result or {}).get("source") + + cleanup_info: Dict[str, Any] = { + "attempted": False, + "skipped_reason": None, + "success": None, + "http_status": None, + "response": None, + "error": None, + } + + if not index_name or not source: + cleanup_info["skipped_reason"] = "missing_index_name_or_source" + forward_result = dict(forward_result or {}) + forward_result["source_cleanup"] = cleanup_info + return forward_result + + try: + record = get_knowledge_record({"index_name": index_name}) or {} + preserve_source_file = record.get("preserve_source_file", True) + except Exception as exc: + logger.warning( + "[%s] CLEANUP TASK: Failed to load knowledge config for index '%s': %s", + getattr(self.request, "id", "unknown"), + index_name, + exc, + ) + cleanup_info["skipped_reason"] = "knowledge_record_lookup_failed" + forward_result = dict(forward_result or {}) + forward_result["source_cleanup"] = cleanup_info + return forward_result + + if preserve_source_file: + cleanup_info["skipped_reason"] = "preserve_source_file_true" + forward_result = dict(forward_result or {}) + forward_result["source_cleanup"] = cleanup_info + return forward_result + + cleanup_info["attempted"] = True + try: + resp = _delete_source_file_via_http_sync( + base_url=ELASTICSEARCH_SERVICE, + index_name=index_name, + path_or_url=source, + scope="source_only", + ) + cleanup_info["http_status"] = resp.get("http_status") + cleanup_info["response"] = ( + resp.get("response_json") + if resp.get("response_json") is not None + else 
resp.get("response_text") + ) + + ok = False + if isinstance(resp.get("response_json"), dict): + ok = bool(resp["response_json"].get("status") == "success") + elif resp.get("http_status") and 200 <= int(resp["http_status"]) < 300: + ok = True + + cleanup_info["success"] = ok + if not ok: + logger.warning( + "[%s] CLEANUP TASK: Source-only delete did not succeed. index='%s' source='%s' http_status=%s", + getattr(self.request, "id", "unknown"), + index_name, + source, + cleanup_info["http_status"], + ) + except Exception as exc: + cleanup_info["success"] = False + cleanup_info["error"] = str(exc) + logger.warning( + "[%s] CLEANUP TASK: Source-only delete failed. index='%s' source='%s' error=%s", + getattr(self.request, "id", "unknown"), + index_name, + source, + exc, + ) + + forward_result = dict(forward_result or {}) + forward_result["source_cleanup"] = cleanup_info + return forward_result + + +def submit_process_forward_chain( + *, source: str, source_type: str, chunking_strategy: str, @@ -1849,30 +2058,14 @@ def process_and_forward( original_filename: Optional[str] = None, authorization: Optional[str] = None, embedding_model_id: Optional[int] = None, - tenant_id: Optional[str] = None + tenant_id: Optional[str] = None, ) -> str: """ - Combined task that chains processing and forwarding - - This task delegates to a chain of process -> forward - - Args: - source: Source file path, URL, or text content - source_type: source of the file("local", "minio") - chunking_strategy: Strategy for chunking the document - index_name: Name of the index to store documents - original_filename: The original name of the file - authorization: Authorization header for API calls - embedding_model_id: Embedding model ID for chunk size configuration - tenant_id: Tenant ID for retrieving model configuration + Build and enqueue a Celery chain: process -> forward. Returns: - Task ID of the chain + Celery chain task ID, or empty string if enqueue failed. 
""" - logger.info( - f"Starting processing chain for {source}, original_filename={original_filename}, strategy={chunking_strategy}, index={index_name}, model_id={embedding_model_id}") - - # Create a task chain task_chain = chain( process.s( source=source, @@ -1889,20 +2082,66 @@ def process_and_forward( source_type=source_type, original_filename=original_filename, authorization=authorization - ).set(queue='forward_q') + ).set(queue='forward_q'), + cleanup_source.s().set(queue='forward_q'), ) - # Execute the chain result = task_chain.apply_async() if result is None or not hasattr(result, 'id') or result.id is None: logger.error( "Celery chain apply_async() did not return a valid result or result.id") return "" - logger.info(f"Created task chain ID: {result.id}") - return result.id +@app.task(bind=True, base=LoggingTask, name='data_process.tasks.process_and_forward') +def process_and_forward( + self, + source: str, + source_type: str, + chunking_strategy: str, + index_name: Optional[str] = None, + original_filename: Optional[str] = None, + authorization: Optional[str] = None, + embedding_model_id: Optional[int] = None, + tenant_id: Optional[str] = None +) -> str: + """ + Combined task that chains processing and forwarding + + This task delegates to a chain of process -> forward + + Args: + source: Source file path, URL, or text content + source_type: source of the file("local", "minio") + chunking_strategy: Strategy for chunking the document + index_name: Name of the index to store documents + original_filename: The original name of the file + authorization: Authorization header for API calls + embedding_model_id: Embedding model ID for chunk size configuration + tenant_id: Tenant ID for retrieving model configuration + + Returns: + Task ID of the chain + """ + logger.info( + f"Starting processing chain for {source}, original_filename={original_filename}, strategy={chunking_strategy}, index={index_name}, model_id={embedding_model_id}") + + chain_id = 
submit_process_forward_chain( + source=source, + source_type=source_type, + chunking_strategy=chunking_strategy, + index_name=index_name, + original_filename=original_filename, + authorization=authorization, + embedding_model_id=embedding_model_id, + tenant_id=tenant_id, + ) + if chain_id: + logger.info(f"Created task chain ID: {chain_id}") + return chain_id + + @app.task(bind=True, base=LoggingTask, name='data_process.tasks.process_sync') def process_sync( self, diff --git a/backend/database/agent_db.py b/backend/database/agent_db.py index 82696ffab..533659b0f 100644 --- a/backend/database/agent_db.py +++ b/backend/database/agent_db.py @@ -1,9 +1,10 @@ import logging -from typing import List +from typing import List, Optional from sqlalchemy import or_, update from database.client import get_db_session, as_dict, filter_property from database.db_models import AgentInfo, ToolInstance, AgentRelation +from database.agent_version_db import query_current_version_no from consts.const import ASSET_OWNER_TENANT_ID from utils.str_utils import convert_list_to_string @@ -102,6 +103,40 @@ def query_sub_agents_id_list(main_agent_id: int, tenant_id: str, version_no: int return [relation.selected_agent_id for relation in relations] +def query_sub_agent_relations(main_agent_id: int, tenant_id: str, version_no: int = 0) -> List[dict]: + """ + Query sub-agent relations by main agent id, including pinned version info. + Default version_no=0 queries the draft version. + + Args: + main_agent_id: Parent agent ID + tenant_id: Tenant ID + version_no: Version number to filter. 
Default 0 = draft/editing state + """ + with get_db_session() as session: + query = session.query(AgentRelation).filter( + AgentRelation.parent_agent_id == main_agent_id, + AgentRelation.tenant_id == tenant_id, + AgentRelation.version_no == version_no, + AgentRelation.delete_flag != 'Y') + relations = query.all() + return [as_dict(relation) for relation in relations] + + +def resolve_sub_agent_version_no( + selected_agent_id: int, + selected_agent_version_no: Optional[int], + tenant_id: str, +) -> int: + """ + Resolve the effective version number for a sub-agent relation. + Uses pinned version when set; otherwise falls back to child's current published version. + """ + if selected_agent_version_no is not None: + return selected_agent_version_no + return query_current_version_no(agent_id=selected_agent_id, tenant_id=tenant_id) or 0 + + def clear_agent_new_mark(agent_id: int, tenant_id: str, user_id: str, version_no: int = 0): """ Clear the NEW mark for an agent. @@ -163,6 +198,7 @@ def create_agent(agent_info, tenant_id: str, user_id: str): """ info_with_metadata = dict(agent_info) info_with_metadata.setdefault("max_steps", 15) + info_with_metadata.setdefault("verification_config", None) info_with_metadata.update({ "tenant_id": tenant_id, "version_no": 0, # Default to draft version @@ -201,6 +237,9 @@ def create_agent(agent_info, tenant_id: str, user_id: str): "group_ids": new_agent.group_ids, "is_new": new_agent.is_new, "enable_context_manager": new_agent.enable_context_manager, + "verification_config": new_agent.verification_config, + "greeting_message": new_agent.greeting_message, + "example_questions": new_agent.example_questions, "current_version_no": new_agent.current_version_no, "version_no": new_agent.version_no, "created_by": new_agent.created_by, diff --git a/backend/database/agent_repository_db.py b/backend/database/agent_repository_db.py new file mode 100644 index 000000000..a6bb4f48b --- /dev/null +++ b/backend/database/agent_repository_db.py @@ -0,0 
+1,358 @@ +import logging +import math +from typing import Any, Dict, List, Optional + +from sqlalchemy import func, or_, update + +from database.client import as_dict, filter_property, get_db_session +from database.db_models import AgentRepository + +logger = logging.getLogger("agent_repository_db") + +# Listing status: NOT_SHARED (未共享), PENDING_REVIEW (待审核), +# REJECTED (审核驳回), SHARED (已共享) +STATUS_NOT_SHARED = "NOT_SHARED" +STATUS_PENDING_REVIEW = "PENDING_REVIEW" +STATUS_REJECTED = "REJECTED" +STATUS_SHARED = "SHARED" + +VALID_REPOSITORY_STATUSES = frozenset({ + STATUS_NOT_SHARED, + STATUS_PENDING_REVIEW, + STATUS_REJECTED, + STATUS_SHARED, +}) + +_UPSERT_IMMUTABLE_FIELDS = frozenset({ + "agent_id", + "agent_repository_id", + "publisher_tenant_id", +}) + +_UPSERT_SNAPSHOT_FIELDS = frozenset({ + "source_version_no", + "name", + "display_name", + "description", + "author", + "category_id", + "tags", + "tool_count", + "version_label", + "agent_info_json", +}) + + +def insert_agent_repository_record( + repository_data: Dict[str, Any], + publisher_tenant_id: str, + publisher_user_id: str, +) -> int: + """Insert a new agent repository listing record.""" + with get_db_session() as session: + payload = { + **repository_data, + "publisher_tenant_id": publisher_tenant_id, + "publisher_user_id": publisher_user_id, + "created_by": publisher_user_id, + "updated_by": publisher_user_id, + "delete_flag": "N", + } + if payload.get("status") is None: + payload["status"] = STATUS_NOT_SHARED + + new_record = AgentRepository( + **filter_property(payload, AgentRepository) + ) + session.add(new_record) + session.flush() + return int(new_record.agent_repository_id) + + +def get_agent_repository_by_id(repository_id: int) -> Optional[dict]: + """Fetch a repository listing by primary key.""" + with get_db_session() as session: + record = session.query(AgentRepository).filter( + AgentRepository.agent_repository_id == repository_id, + AgentRepository.delete_flag != "Y", + ).first() + 
return as_dict(record) if record else None + + +def get_agent_repository_by_id_and_publisher( + repository_id: int, + publisher_tenant_id: str, +) -> Optional[dict]: + """Fetch a repository listing scoped to the publisher tenant.""" + with get_db_session() as session: + record = session.query(AgentRepository).filter( + AgentRepository.agent_repository_id == repository_id, + AgentRepository.publisher_tenant_id == publisher_tenant_id, + AgentRepository.delete_flag != "Y", + ).first() + return as_dict(record) if record else None + + +def get_agent_repository_by_agent_id(agent_id: int) -> Optional[dict]: + """Fetch an active repository listing by root agent_id.""" + with get_db_session() as session: + record = session.query(AgentRepository).filter( + AgentRepository.agent_id == agent_id, + AgentRepository.delete_flag != "Y", + ).first() + return as_dict(record) if record else None + + +def upsert_agent_repository_record( + repository_data: Dict[str, Any], + publisher_tenant_id: str, + publisher_user_id: str, +) -> tuple[int, bool]: + """Insert or update a repository listing keyed by agent_id. + + When no record exists, inserts a new listing. When a record exists: + - Same source_version_no: updates status (and updated_by) only. + - Different source_version_no: updates all snapshot fields, preserving + agent_id, agent_repository_id, and publisher_tenant_id. + + Returns: + Tuple of (agent_repository_id, is_updated). is_updated is False on insert. 
+ """ + agent_id = repository_data.get("agent_id") + if agent_id is None: + raise ValueError("agent_id is required for repository upsert") + + existing = get_agent_repository_by_agent_id(int(agent_id)) + if not existing: + repository_id = insert_agent_repository_record( + repository_data=repository_data, + publisher_tenant_id=publisher_tenant_id, + publisher_user_id=publisher_user_id, + ) + return repository_id, False + + existing_version = existing.get("source_version_no") + incoming_version = repository_data.get("source_version_no") + repository_id = int(existing["agent_repository_id"]) + + if existing_version == incoming_version: + update_fields: Dict[str, Any] = { + "status": repository_data.get("status", STATUS_NOT_SHARED), + "updated_by": publisher_user_id, + } + else: + update_fields = { + key: repository_data[key] + for key in _UPSERT_SNAPSHOT_FIELDS + if key in repository_data + } + update_fields["publisher_user_id"] = publisher_user_id + update_fields["updated_by"] = publisher_user_id + update_fields["status"] = repository_data.get("status", STATUS_NOT_SHARED) + + with get_db_session() as session: + session.execute( + update(AgentRepository) + .where( + AgentRepository.agent_repository_id == repository_id, + AgentRepository.publisher_tenant_id == publisher_tenant_id, + AgentRepository.delete_flag != "Y", + ) + .values(**update_fields) + ) + return repository_id, True + + +def list_agent_repository_summaries( + *, + status: Optional[str] = None, +) -> List[dict]: + """List all active repository summaries without heavy JSON blobs.""" + with get_db_session() as session: + query = session.query( + AgentRepository.agent_repository_id, + AgentRepository.author, + AgentRepository.name, + AgentRepository.display_name, + AgentRepository.description, + AgentRepository.status, + ).filter( + AgentRepository.delete_flag != "Y", + ) + if status: + query = query.filter(AgentRepository.status == status) + rows = 
query.order_by(AgentRepository.agent_repository_id.desc()).all() + return [ + { + "agent_repository_id": row.agent_repository_id, + "author": row.author, + "name": row.name, + "display_name": row.display_name, + "description": row.description, + "status": row.status, + } + for row in rows + ] + + +def query_agent_repository_list( + *, + page: int = 1, + page_size: int = 20, + search: Optional[str] = None, + tag: Optional[str] = None, + category_id: Optional[int] = None, + status: Optional[str] = STATUS_SHARED, + publisher_tenant_id: Optional[str] = None, +) -> Dict[str, Any]: + """Query repository listings with offset pagination.""" + page = max(page, 1) + page_size = max(min(page_size, 100), 1) + offset = (page - 1) * page_size + + with get_db_session() as session: + query = session.query(AgentRepository).filter( + AgentRepository.delete_flag != "Y", + ) + + if status: + query = query.filter(AgentRepository.status == status) + if publisher_tenant_id: + query = query.filter( + AgentRepository.publisher_tenant_id == publisher_tenant_id + ) + if category_id is not None: + query = query.filter(AgentRepository.category_id == category_id) + if tag: + query = query.filter(AgentRepository.tags.any(tag)) + if search: + keyword = f"%{search}%" + query = query.filter( + or_( + AgentRepository.name.ilike(keyword), + AgentRepository.display_name.ilike(keyword), + AgentRepository.description.ilike(keyword), + AgentRepository.author.ilike(keyword), + func.array_to_string(AgentRepository.tags, ",").ilike(keyword), + ) + ) + + total = query.count() + rows = ( + query.order_by(AgentRepository.agent_repository_id.desc()) + .offset(offset) + .limit(page_size) + .all() + ) + + total_pages = math.ceil(total / page_size) if total else 0 + return { + "items": [as_dict(row) for row in rows], + "pagination": { + "page": page, + "page_size": page_size, + "total": total, + "total_pages": total_pages, + }, + } + + +def update_agent_repository_by_id( + *, + repository_id: int, + 
publisher_tenant_id: str, + user_id: str, + updates: Dict[str, Any], +) -> int: + """Update a repository listing owned by the publisher tenant. Returns affected row count.""" + allowed_fields = { + "display_name", + "description", + "author", + "category_id", + "tags", + "tool_count", + "version_label", + "source_version_no", + "agent_info_json", + "status", + } + update_fields = { + key: value + for key, value in updates.items() + if key in allowed_fields + } + if not update_fields: + return 0 + + update_fields["updated_by"] = user_id + + with get_db_session() as session: + result = session.execute( + update(AgentRepository) + .where( + AgentRepository.agent_repository_id == repository_id, + AgentRepository.publisher_tenant_id == publisher_tenant_id, + AgentRepository.delete_flag != "Y", + ) + .values(**update_fields) + ) + return int(result.rowcount or 0) + + +def update_agent_repository_status_by_id( + *, + repository_id: int, + status: str, + user_id: str, +) -> int: + """Update repository listing status by primary key. 
Returns affected row count.""" + with get_db_session() as session: + result = session.execute( + update(AgentRepository) + .where( + AgentRepository.agent_repository_id == repository_id, + AgentRepository.delete_flag != "Y", + ) + .values(status=status, updated_by=user_id) + ) + return int(result.rowcount or 0) + + +def soft_delete_agent_repository_by_id( + *, + repository_id: int, + publisher_tenant_id: str, + user_id: str, +) -> int: + """Soft-delete a repository listing owned by the publisher tenant.""" + with get_db_session() as session: + result = session.execute( + update(AgentRepository) + .where( + AgentRepository.agent_repository_id == repository_id, + AgentRepository.publisher_tenant_id == publisher_tenant_id, + AgentRepository.delete_flag != "Y", + ) + .values(delete_flag="Y", updated_by=user_id) + ) + return int(result.rowcount or 0) + + +def list_agent_repository_by_publisher( + publisher_tenant_id: str, + *, + publisher_user_id: Optional[str] = None, +) -> List[dict]: + """List all repository listings published by a tenant.""" + with get_db_session() as session: + query = session.query(AgentRepository).filter( + AgentRepository.publisher_tenant_id == publisher_tenant_id, + AgentRepository.delete_flag != "Y", + ) + if publisher_user_id: + query = query.filter( + AgentRepository.publisher_user_id == publisher_user_id + ) + rows = query.order_by(AgentRepository.agent_repository_id.desc()).all() + return [as_dict(row) for row in rows] diff --git a/backend/database/cas_session_db.py b/backend/database/cas_session_db.py new file mode 100644 index 000000000..57d1aa8ea --- /dev/null +++ b/backend/database/cas_session_db.py @@ -0,0 +1,134 @@ +""" +Database operations for CAS-backed web sessions. 
+""" + +from datetime import datetime +from typing import Any, Dict, Optional + +from database.client import as_dict, get_db_session +from database.db_models import UserCasSession + +CAS_SESSION_ACTIVE = "active" +CAS_SESSION_REVOKED = "revoked" + + +def create_cas_session( + *, + session_id: str, + user_id: str, + cas_user_id: str, + expires_at: datetime, + cas_session_index: Optional[str] = None, +) -> Dict[str, Any]: + with get_db_session() as session: + record = UserCasSession( + session_id=session_id, + user_id=user_id, + cas_user_id=cas_user_id, + cas_session_index=cas_session_index, + status=CAS_SESSION_ACTIVE, + expires_at=expires_at, + created_by=user_id, + updated_by=user_id, + ) + session.add(record) + session.flush() + return as_dict(record) + + +def get_cas_session_by_session_id(session_id: str) -> Optional[Dict[str, Any]]: + if not session_id: + return None + with get_db_session() as session: + result = ( + session.query(UserCasSession) + .filter( + UserCasSession.session_id == session_id, + UserCasSession.delete_flag == "N", + ) + .first() + ) + return as_dict(result) if result else None + + +def is_cas_session_active(session_id: str) -> bool: + if not session_id: + return False + with get_db_session() as session: + result = ( + session.query(UserCasSession) + .filter( + UserCasSession.session_id == session_id, + UserCasSession.status == CAS_SESSION_ACTIVE, + UserCasSession.expires_at > datetime.now(), + UserCasSession.delete_flag == "N", + ) + .first() + ) + return result is not None + + +def revoke_cas_session_by_session_id(session_id: str, actor: str = "cas") -> int: + if not session_id: + return 0 + with get_db_session() as session: + result = ( + session.query(UserCasSession) + .filter( + UserCasSession.session_id == session_id, + UserCasSession.status == CAS_SESSION_ACTIVE, + UserCasSession.delete_flag == "N", + ) + .update( + { + "status": CAS_SESSION_REVOKED, + "revoked_at": datetime.now(), + "updated_by": actor, + } + ) + ) + return result + 
+ +def revoke_cas_sessions_by_user_id(cas_user_id: str, actor: str = "cas") -> int: + if not cas_user_id: + return 0 + with get_db_session() as session: + result = ( + session.query(UserCasSession) + .filter( + UserCasSession.cas_user_id == cas_user_id, + UserCasSession.status == CAS_SESSION_ACTIVE, + UserCasSession.delete_flag == "N", + ) + .update( + { + "status": CAS_SESSION_REVOKED, + "revoked_at": datetime.now(), + "updated_by": actor, + } + ) + ) + return result + + +def revoke_cas_session_by_index(cas_session_index: str, actor: str = "cas") -> int: + if not cas_session_index: + return 0 + with get_db_session() as session: + result = ( + session.query(UserCasSession) + .filter( + UserCasSession.cas_session_index == cas_session_index, + UserCasSession.status == CAS_SESSION_ACTIVE, + UserCasSession.delete_flag == "N", + ) + .update( + { + "status": CAS_SESSION_REVOKED, + "revoked_at": datetime.now(), + "updated_by": actor, + } + ) + ) + return result diff --git a/backend/database/conversation_db.py b/backend/database/conversation_db.py index 18c0ee9fc..2d06bb9be 100644 --- a/backend/database/conversation_db.py +++ b/backend/database/conversation_db.py @@ -1016,3 +1016,71 @@ def get_message_id_by_index(conversation_id: int, message_index: int) -> Optiona result = session.execute(stmt).scalar() return result + + +def get_latest_assistant_message_id(conversation_id: int, user_id: Optional[str] = None) -> Optional[int]: + """ + Get the most recent assistant message ID for a conversation. 
+ + Args: + conversation_id: Conversation ID (integer) + user_id: Optional user ID for ownership check + + Returns: + Optional[int]: The latest assistant message ID, or None if not found + """ + with get_db_session() as session: + conversation_id = int(conversation_id) + + stmt = select(ConversationMessage.message_id).where( + ConversationMessage.conversation_id == conversation_id, + ConversationMessage.delete_flag == 'N', + ConversationMessage.message_role == 'assistant' + ).order_by(desc(ConversationMessage.message_index)).limit(1) + + if user_id: + stmt = stmt.join( + ConversationRecord, + ConversationMessage.conversation_id == ConversationRecord.conversation_id + ).where(ConversationRecord.created_by == user_id) + + result = session.execute(stmt).scalar() + return result + + +def update_message_minio_files(message_id: int, skill_file_uploads: List[Dict[str, Any]]) -> bool: + """ + Merge skill file uploads into an existing message's minio_files field. + + Args: + message_id: Message ID to update + skill_file_uploads: List of skill file upload metadata dicts to append + + Returns: + bool: True if the message was updated, False if the message was not found + """ + with get_db_session() as session: + message_id = int(message_id) + + stmt = select(ConversationMessage).where( + ConversationMessage.message_id == message_id, + ConversationMessage.delete_flag == 'N' + ) + record = session.scalars(stmt).first() + if not record: + return False + + existing = record.minio_files + if existing: + try: + if isinstance(existing, str): + existing = json.loads(existing) + except (json.JSONDecodeError, TypeError): + existing = [] + else: + existing = [] + + existing.extend(skill_file_uploads) + record.minio_files = json.dumps(existing, ensure_ascii=False) + + return True diff --git a/backend/database/db_models.py b/backend/database/db_models.py index b779266c9..5450b5f74 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -15,6 +15,8 @@ 
_TENANT_ID_DOC = "Tenant ID for multi-tenancy isolation" # Base class for tables without audit fields + + class SimpleTableBase(DeclarativeBase): pass @@ -297,13 +299,16 @@ class AgentInfo(TableBase): agent_id = Column(Integer, Sequence( "ag_tenant_agent_t_agent_id_seq", schema=SCHEMA), nullable=False, primary_key=True, autoincrement=True, doc="ID") - version_no = Column(Integer, default=0, nullable=False, primary_key=True, doc="Version number. 0 = draft/editing state, >=1 = published snapshot") + version_no = Column(Integer, default=0, nullable=False, primary_key=True, + doc="Version number. 0 = draft/editing state, >=1 = published snapshot") name = Column(String(100), doc="Agent name") display_name = Column(String(100), doc="Agent display name") description = Column(Text, doc="Description") author = Column(String(100), doc="Agent author") - model_name = Column(String(100), doc="[DEPRECATED] Name of the model used, use model_id instead") - model_id = Column(Integer, doc="Model ID, foreign key reference to model_record_t.model_id") + model_name = Column( + String(100), doc="[DEPRECATED] Name of the model used, use model_id instead") + model_id = Column( + Integer, doc="Model ID, foreign key reference to model_record_t.model_id") max_steps = Column(Integer, doc="Maximum number of steps") duty_prompt = Column(Text, doc="Duty prompt content") constraint_prompt = Column(Text, doc="Constraint prompt content") @@ -315,15 +320,22 @@ class AgentInfo(TableBase): Boolean, doc="Whether to provide the running summary to the manager agent") business_description = Column( Text, doc="Manually entered by the user to describe the entire business process") - business_logic_model_name = Column(String(100), doc="Model name used for business logic prompt generation") - business_logic_model_id = Column(Integer, doc="Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id") - prompt_template_id = Column(Integer, doc="Prompt template ID used 
for business logic prompt generation") - prompt_template_name = Column(String(100), doc="Prompt template name used for business logic prompt generation") + business_logic_model_name = Column( + String(100), doc="Model name used for business logic prompt generation") + business_logic_model_id = Column( + Integer, doc="Model ID used for business logic prompt generation, foreign key reference to model_record_t.model_id") + prompt_template_id = Column( + Integer, doc="Prompt template ID used for business logic prompt generation") + prompt_template_name = Column(String( + 100), doc="Prompt template name used for business logic prompt generation") group_ids = Column(String, doc="Agent group IDs list") is_new = Column(Boolean, default=False, doc="Whether this agent is marked as new for the user") current_version_no = Column(Integer, nullable=True, doc="Current published version number. NULL means no version published yet") ingroup_permission = Column(String(30), doc="In-group permission: EDIT, READ_ONLY, PRIVATE") enable_context_manager = Column(Boolean, default=False, doc="Whether to enable context management (compression) for this agent") + verification_config = Column(JSONB, doc="Layered ReAct self-verification configuration") + greeting_message = Column(Text, doc="Agent greeting message displayed on chat initial screen") + example_questions = Column(JSONB, doc="List of example questions for starting a conversation with this agent") class PromptTemplate(TableBase): @@ -352,12 +364,15 @@ class PromptTemplate(TableBase): template_id = Column(Integer, Sequence( "ag_prompt_template_t_template_id_seq", schema=SCHEMA), primary_key=True, nullable=False, autoincrement=True, doc="Prompt template ID") - template_name = Column(String(100), nullable=False, doc="Prompt template name") + template_name = Column(String(100), nullable=False, + doc="Prompt template name") description = Column(String(500), doc="Prompt template description") - template_type = Column(String(50), 
nullable=False, default="agent_generate", doc="Prompt template type") + template_type = Column(String(50), nullable=False, + default="agent_generate", doc="Prompt template type") tenant_id = Column(String(100), nullable=False, doc="Tenant ID") user_id = Column(String(100), nullable=False, doc="User ID") - template_content_zh = Column(JSONB, nullable=False, doc="Chinese prompt template content") + template_content_zh = Column( + JSONB, nullable=False, doc="Chinese prompt template content") template_content_en = Column(JSONB, doc="English prompt template content") @@ -381,7 +396,8 @@ class ToolInstance(TableBase): user_id = Column(String(100), doc="User ID") tenant_id = Column(String(100), doc="Tenant ID") enabled = Column(Boolean, doc="Enabled") - version_no = Column(Integer, default=0, primary_key=True, nullable=False, doc="Version number. 0 = draft/editing state, >=1 = published snapshot") + version_no = Column(Integer, default=0, primary_key=True, nullable=False, + doc="Version number. 
0 = draft/editing state, >=1 = published snapshot") class KnowledgeRecord(TableBase): @@ -397,18 +413,25 @@ class KnowledgeRecord(TableBase): knowledge_name = Column(String(100), doc="User-facing knowledge base name") knowledge_describe = Column(String(3000), doc="Knowledge base description") knowledge_sources = Column(String(300), doc="Knowledge base sources") - embedding_model_name = Column(String(200), doc="Embedding model name, used to record the embedding model used by the knowledge base") - embedding_model_id = Column(Integer, doc="Embedding model ID, foreign key reference to model_record_t.model_id") + embedding_model_name = Column(String( + 200), doc="Embedding model name, used to record the embedding model used by the knowledge base") + embedding_model_id = Column( + Integer, doc="Embedding model ID, foreign key reference to model_record_t.model_id") tenant_id = Column(String(100), doc="Tenant ID") group_ids = Column(String, doc="Knowledge base group IDs list") ingroup_permission = Column( String(30), doc="In-group permission: EDIT, READ_ONLY, PRIVATE") summary_frequency = Column(String(10), nullable=True, - doc="Auto-summary frequency: '3h', '5h', '1d', '1w', or NULL (disabled)") + doc="Auto-summary frequency: '3h', '5h', '1d', '1w', or NULL (disabled)") last_summary_time = Column(TIMESTAMP(timezone=False), nullable=True, - doc="Timestamp of last summary generation") + doc="Timestamp of last summary generation") last_doc_update_time = Column(TIMESTAMP(timezone=False), nullable=True, - doc="Timestamp of last document add/delete operation") + doc="Timestamp of last document add/delete operation") + preserve_source_file = Column( + Boolean, + default=True, + doc="Whether to preserve uploaded source documents after vectorization", + ) class TenantConfig(TableBase): @@ -481,7 +504,8 @@ class McpRecord(TableBase): doc="Custom HTTP headers as JSON object for MCP server requests", default=None, ) - source = Column(String(30), doc="Source type: 
local/mcp_registry/community") + source = Column( + String(30), doc="Source type: local/mcp_registry/community") registry_json = Column(JSONB, doc="Full MCP registry server.json snapshot") config_json = Column(JSON, doc="MCP config data") enabled = Column(Boolean, default=True, doc="Enabled") @@ -509,11 +533,13 @@ class McpCommunityRecord(TableBase): source = Column(String(30), doc="Source type, fixed to community") version = Column(String(50), doc="MCP version") registry_json = Column(JSONB, doc="Full MCP metadata JSON") - transport_type = Column(String(30), doc="Transport type: http/sse/container") + transport_type = Column( + String(30), doc="Transport type: http/sse/container") config_json = Column(JSON, doc="Public-shareable MCP configuration JSON") tags = Column(ARRAY(Text), doc="Tags") description = Column(Text, doc="Description") + class UserTenant(TableBase): """ User and tenant relationship table @@ -525,7 +551,8 @@ class UserTenant(TableBase): primary_key=True, nullable=False, doc="User tenant relationship ID, unique primary key") user_id = Column(String(100), nullable=False, doc="User ID") tenant_id = Column(String(100), nullable=False, doc="Tenant ID") - user_role = Column(String(30), doc="User role: SUPER_ADMIN, ADMIN, DEV, USER") + user_role = Column( + String(30), doc="User role: SUPER_ADMIN, ADMIN, DEV, USER") user_email = Column(String(255), doc="User email address") @@ -536,11 +563,18 @@ class AgentRelation(TableBase): __tablename__ = "ag_agent_relation_t" __table_args__ = {"schema": SCHEMA} - relation_id = Column(Integer, Sequence("ag_agent_relation_t_relation_id_seq", schema=SCHEMA), primary_key=True, nullable=False, doc="Relationship ID, primary key") - selected_agent_id = Column(Integer, primary_key=True, doc="Selected agent ID") + relation_id = Column(Integer, Sequence("ag_agent_relation_t_relation_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Relationship ID, primary key") + selected_agent_id = Column( + Integer, 
primary_key=True, doc="Selected agent ID") parent_agent_id = Column(Integer, doc="Parent agent ID") tenant_id = Column(String(100), doc="Tenant ID") - version_no = Column(Integer, default=0, nullable=False, doc="Version number. 0 = draft/editing state, >=1 = published snapshot") + version_no = Column(Integer, default=0, nullable=False, + doc="Version number. 0 = draft/editing state, >=1 = published snapshot") + selected_agent_version_no = Column( + Integer, nullable=True, + doc="Pinned version of selected_agent_id. NULL = runtime fallback to child current_version_no", + ) class PartnerMappingId(TableBase): @@ -656,13 +690,51 @@ class AgentVersion(TableBase): primary_key=True, nullable=False, doc=_PRIMARY_KEY_DOC) tenant_id = Column(String(100), nullable=False, doc="Tenant ID") agent_id = Column(Integer, nullable=False, doc="Agent ID") - version_no = Column(Integer, nullable=False, doc="Version number, starts from 1. Does not include 0 (draft)") - version_name = Column(String(100), doc="User-defined version name for display") + version_no = Column(Integer, nullable=False, + doc="Version number, starts from 1. Does not include 0 (draft)") + version_name = Column( + String(100), doc="User-defined version name for display") release_note = Column(Text, doc="Release notes / publish remarks") - source_version_no = Column(Integer, doc="Source version number. If this version is a rollback, record the source version") - source_type = Column(String(30), doc="Source type: NORMAL (normal publish) / ROLLBACK (rollback and republish)") - status = Column(String(30), default="RELEASED", doc="Version status: RELEASED / DISABLED / ARCHIVED") - is_a2a = Column(Boolean, default=False, doc="Whether this version is published as an A2A Server agent") + source_version_no = Column( + Integer, doc="Source version number. 
If this version is a rollback, record the source version") + source_type = Column(String( + 30), doc="Source type: NORMAL (normal publish) / ROLLBACK (rollback and republish)") + status = Column(String(30), default="RELEASED", + doc="Version status: RELEASED / DISABLED / ARCHIVED") + is_a2a = Column(Boolean, default=False, + doc="Whether this version is published as an A2A Server agent") + + +class AgentRepository(TableBase): + """ + Agent repository (marketplace) table. Frozen snapshot of a published agent tree for sharing. + """ + __tablename__ = "ag_agent_repository_t" + __table_args__ = {"schema": SCHEMA} + + agent_repository_id = Column(BigInteger, Sequence("ag_agent_repository_t_agent_repository_id_seq", schema=SCHEMA), + primary_key=True, nullable=False, doc="Agent repository listing ID, unique primary key") + publisher_tenant_id = Column(String(100), nullable=False, doc="Publisher tenant ID") + publisher_user_id = Column(String(100), nullable=False, doc="Publisher user ID") + agent_id = Column(Integer, nullable=False, + doc="Root agent ID from ag_tenant_agent_t; upsert key") + source_version_no = Column(Integer, nullable=False, + doc="Published version number frozen at share time") + name = Column(String(100), nullable=False, + doc="Root agent programmatic name for display and search") + display_name = Column(String(100), doc="Root agent display name") + description = Column(Text, doc="Root agent description") + author = Column(String(100), doc="Agent author") + category_id = Column(Integer, doc="Optional marketplace category ID") + tags = Column(ARRAY(Text), doc="Marketplace tags") + tool_count = Column(Integer, + doc="Total tool count across all agents in the bundle (display only)") + version_label = Column(String(100), + doc="Repository entry version label for display (e.g. 
v1.0)") + agent_info_json = Column(JSONB, nullable=False, + doc="Frozen ExportAndImportDataFormat snapshot with optional skills") + status = Column(String(30), default="NOT_SHARED", + doc="Listing status: NOT_SHARED (未共享) / PENDING_REVIEW (待审核) / REJECTED (审核驳回) / SHARED (已共享)") class UserTokenInfo(TableBase): @@ -675,7 +747,8 @@ class UserTokenInfo(TableBase): token_id = Column(Integer, Sequence("user_token_info_t_token_id_seq", schema=SCHEMA), primary_key=True, nullable=False, doc="Token ID, unique primary key") access_key = Column(String(100), nullable=False, doc="Access Key (AK)") - user_id = Column(String(100), nullable=False, doc="User ID who owns this token") + user_id = Column(String(100), nullable=False, + doc="User ID who owns this token") class UserTokenUsageLog(TableBase): @@ -687,16 +760,21 @@ class UserTokenUsageLog(TableBase): token_usage_id = Column(Integer, Sequence("user_token_usage_log_t_token_usage_id_seq", schema=SCHEMA), primary_key=True, nullable=False, doc="Token usage log ID, unique primary key") - token_id = Column(Integer, nullable=False, doc="Foreign key to user_token_info_t.token_id") - call_function_name = Column(String(100), doc="API function name being called") - related_id = Column(Integer, doc="Related resource ID (e.g., conversation_id)") - meta_data = Column(JSONB, doc="Additional metadata for this usage log entry, stored as JSON") + token_id = Column(Integer, nullable=False, + doc="Foreign key to user_token_info_t.token_id") + call_function_name = Column( + String(100), doc="API function name being called") + related_id = Column( + Integer, doc="Related resource ID (e.g., conversation_id)") + meta_data = Column( + JSONB, doc="Additional metadata for this usage log entry, stored as JSON") class UserOAuthAccount(TableBase): __tablename__ = "user_oauth_account_t" __table_args__ = ( - UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"), + UniqueConstraint("provider", "provider_user_id", + 
name="uq_oauth_provider_user"), {"schema": SCHEMA}, ) @@ -714,11 +792,38 @@ class UserOAuthAccount(TableBase): provider_user_id = Column( String(200), nullable=False, doc="User ID from the OAuth provider" ) - provider_email = Column(String(255), doc="Email address from the OAuth provider") - provider_username = Column(String(200), doc="Display name from the OAuth provider") + provider_email = Column( + String(255), doc="Email address from the OAuth provider") + provider_username = Column( + String(200), doc="Display name from the OAuth provider") tenant_id = Column(String(100), doc="Tenant ID at time of linking") +class UserCasSession(TableBase): + __tablename__ = "user_cas_session_t" + __table_args__ = ( + Index("ix_user_cas_session_session_id", "session_id"), + Index("ix_user_cas_session_user_id", "user_id"), + Index("ix_user_cas_session_cas_user_id", "cas_user_id"), + {"schema": SCHEMA}, + ) + + cas_session_id = Column( + Integer, + Sequence("user_cas_session_t_cas_session_id_seq", schema=SCHEMA), + primary_key=True, + nullable=False, + doc="CAS session record ID", + ) + session_id = Column(String(100), nullable=False, unique=True, doc="JWT session ID") + user_id = Column(String(100), nullable=False, doc="Supabase user UUID") + cas_user_id = Column(String(200), nullable=False, doc="User ID from CAS") + cas_session_index = Column(String(500), doc="CAS SessionIndex or service ticket") + status = Column(String(30), nullable=False, default="active", doc="active/revoked") + expires_at = Column(TIMESTAMP(timezone=False), nullable=False, doc="Session expiration time") + revoked_at = Column(TIMESTAMP(timezone=False), doc="Revocation time") + + class SkillInfo(TableBase): """ Skill information table - stores skill metadata and content. 
@@ -728,13 +833,17 @@ class SkillInfo(TableBase): skill_id = Column(Integer, Sequence("ag_skill_info_t_skill_id_seq", schema=SCHEMA), primary_key=True, nullable=False, autoincrement=True, doc="Skill ID") - skill_name = Column(String(100), nullable=False, unique=True, doc="Unique skill name") - tenant_id = Column(String(100), nullable=True, doc="Tenant ID for multi-tenancy. NULL for pre-existing skills.") + skill_name = Column(String(100), nullable=False, + unique=True, doc="Unique skill name") + tenant_id = Column(String(100), nullable=True, + doc="Tenant ID for multi-tenancy. NULL for pre-existing skills.") skill_description = Column(String(1000), doc="Skill description") skill_tags = Column(JSON, doc="Skill tags as JSON array") skill_content = Column(Text, doc="Skill content in markdown format") - config_schemas = Column(JSON, doc="Parameter metadata from config/schema.yaml") - config_values = Column(JSON, doc="Runtime parameter values from config/config.yaml") + config_schemas = Column( + JSON, doc="Parameter metadata from config/schema.yaml") + config_values = Column( + JSON, doc="Runtime parameter values from config/config.yaml") source = Column(String(30), nullable=False, default="official", doc="Skill source: official, custom, etc.") @@ -748,8 +857,10 @@ class SkillToolRelation(TableBase): rel_id = Column(Integer, Sequence("ag_skill_tools_rel_t_rel_id_seq", schema=SCHEMA), primary_key=True, nullable=False, autoincrement=True, doc="Relation ID") - skill_id = Column(Integer, nullable=False, doc="Foreign key to ag_skill_info_t.skill_id") - tool_id = Column(Integer, nullable=False, doc="Foreign key to ag_tool_info_t.tool_id") + skill_id = Column(Integer, nullable=False, + doc="Foreign key to ag_skill_info_t.skill_id") + tool_id = Column(Integer, nullable=False, + doc="Foreign key to ag_tool_info_t.tool_id") class SkillInstance(TableBase): @@ -768,14 +879,19 @@ class SkillInstance(TableBase): nullable=False, doc="Skill instance ID" ) - skill_id = Column(Integer, 
nullable=False, doc="Foreign key to ag_skill_info_t.skill_id") + skill_id = Column(Integer, nullable=False, + doc="Foreign key to ag_skill_info_t.skill_id") agent_id = Column(Integer, nullable=False, doc="Agent ID") user_id = Column(String(100), doc="User ID") tenant_id = Column(String(100), doc="Tenant ID") - enabled = Column(Boolean, default=True, doc="Whether this skill is enabled for the agent") - version_no = Column(Integer, default=0, primary_key=True, nullable=False, doc="Version number. 0 = draft/editing state, >=1 = published snapshot") - config_values = Column(JSON, doc="Per-agent runtime parameter values (mirrors ag_tool_instance_t.params)") - config_schemas = Column(JSON, doc="Per-agent parameter schema overrides from config/schema.yaml") + enabled = Column(Boolean, default=True, + doc="Whether this skill is enabled for the agent") + version_no = Column(Integer, default=0, primary_key=True, nullable=False, + doc="Version number. 0 = draft/editing state, >=1 = published snapshot") + config_values = Column( + JSON, doc="Per-agent runtime parameter values (mirrors ag_tool_instance_t.params)") + config_schemas = Column( + JSON, doc="Per-agent parameter schema overrides from config/schema.yaml") class OuterApiService(TableBase): @@ -788,13 +904,16 @@ class OuterApiService(TableBase): id = Column(BigInteger, Sequence("ag_outer_api_services_id_seq", schema=SCHEMA), primary_key=True, nullable=False, doc="Service ID, unique primary key") - mcp_service_name = Column(String(100), nullable=False, doc="MCP service name (unique identifier per tenant)") + mcp_service_name = Column(String(100), nullable=False, + doc="MCP service name (unique identifier per tenant)") description = Column(Text, doc="Service description from OpenAPI info") openapi_json = Column(JSONB, doc="Complete OpenAPI JSON specification") server_url = Column(String(500), doc="Base URL of the REST API server") headers_template = Column(JSONB, doc="Default headers template as JSON") - tenant_id = 
Column(String(100), nullable=False, doc="Tenant ID for multi-tenancy") - is_available = Column(Boolean, default=True, doc="Whether the service is available") + tenant_id = Column(String(100), nullable=False, + doc="Tenant ID for multi-tenancy") + is_available = Column(Boolean, default=True, + doc="Whether the service is available") # Alias for backward compatibility @@ -809,27 +928,37 @@ class A2ANacosConfig(TableBase): __tablename__ = "ag_a2a_nacos_config_t" __table_args__ = {"schema": SCHEMA} - id = Column(BigInteger, primary_key=True, autoincrement=True, doc=_PRIMARY_KEY_DOC) - config_id = Column(String(64), unique=True, nullable=False, doc="Unique config identifier for API reference") + id = Column(BigInteger, primary_key=True, + autoincrement=True, doc=_PRIMARY_KEY_DOC) + config_id = Column(String(64), unique=True, nullable=False, + doc="Unique config identifier for API reference") # Nacos connection - nacos_addr = Column(String(512), nullable=False, doc="Nacos server address, e.g., http://nacos-server:8848") - nacos_username = Column(String(100), doc="Nacos username for authentication") - nacos_password = Column(String(256), doc="Nacos password, encrypted at rest") + nacos_addr = Column(String(512), nullable=False, + doc="Nacos server address, e.g., http://nacos-server:8848") + nacos_username = Column( + String(100), doc="Nacos username for authentication") + nacos_password = Column( + String(256), doc="Nacos password, encrypted at rest") # Discovery scope - namespace_id = Column(String(100), default="public", doc="Nacos namespace for service discovery") + namespace_id = Column(String(100), default="public", + doc="Nacos namespace for service discovery") # Metadata - name = Column(String(100), nullable=False, doc="Display name for this Nacos config") + name = Column(String(100), nullable=False, + doc="Display name for this Nacos config") description = Column(Text, doc="Description of this Nacos configuration") # Tenant isolation - tenant_id = 
Column(String(100), nullable=False, doc="Tenant ID for multi-tenancy") + tenant_id = Column(String(100), nullable=False, + doc="Tenant ID for multi-tenancy") # Status - is_active = Column(Boolean, default=True, doc="Whether this Nacos config is active") - last_scan_at = Column(TIMESTAMP(timezone=False), doc="Last time a scan was performed using this config") + is_active = Column(Boolean, default=True, + doc="Whether this Nacos config is active") + last_scan_at = Column(TIMESTAMP(timezone=False), + doc="Last time a scan was performed using this config") class A2AExternalAgent(TableBase): @@ -840,39 +969,49 @@ class A2AExternalAgent(TableBase): __tablename__ = "ag_a2a_external_agent_t" __table_args__ = {"schema": SCHEMA} - id = Column(BigInteger, primary_key=True, autoincrement=True, doc=_PRIMARY_KEY_DOC) + id = Column(BigInteger, primary_key=True, + autoincrement=True, doc=_PRIMARY_KEY_DOC) # Agent metadata (cached from Agent Card) - name = Column(String(255), nullable=False, doc="Agent name from Agent Card") + name = Column(String(255), nullable=False, + doc="Agent name from Agent Card") description = Column(Text, doc="Agent description from Agent Card") - version = Column(String(50), doc="Agent version from Agent Card, e.g., 1.2.0") + version = Column( + String(50), doc="Agent version from Agent Card, e.g., 1.2.0") # Primary interface (extracted from supportedInterfaces for quick access) # In A2A 1.0, this should store the http-json-rpc URL - agent_url = Column(String(512), nullable=False, doc="Primary A2A endpoint URL (http-json-rpc by default)") + agent_url = Column(String(512), nullable=False, + doc="Primary A2A endpoint URL (http-json-rpc by default)") # Protocol type for calling this agent: JSONRPC, HTTP+JSON, GRPC - protocol_type = Column(String(20), default=PROTOCOL_JSONRPC, doc="Protocol type for calling this agent") + protocol_type = Column(String(20), default=PROTOCOL_JSONRPC, + doc="Protocol type for calling this agent") # Capabilities - streaming = 
Column(Boolean, default=False, doc="Whether this agent supports SSE streaming") + streaming = Column(Boolean, default=False, + doc="Whether this agent supports SSE streaming") # All supported interfaces (full JSON array from Agent Card) # Format: [{protocolBinding, url, protocolVersion}, ...] supported_interfaces = Column(JSON, doc="All supported interfaces array") # Source information - source_type = Column(String(20), nullable=False, doc="Discovery source: url or nacos") + source_type = Column(String(20), nullable=False, + doc="Discovery source: url or nacos") # For URL mode source_url = Column(String(512), doc="Direct URL to agent card") # For Nacos mode - nacos_config_id = Column(String(64), doc="Reference to Nacos config used for discovery") - nacos_agent_name = Column(String(255), doc="Original name used for Nacos query") + nacos_config_id = Column( + String(64), doc="Reference to Nacos config used for discovery") + nacos_agent_name = Column( + String(255), doc="Original name used for Nacos query") # Base URL for infrastructure health checks - base_url = Column(String(512), doc="Base URL for health checks (service root address), e.g., http://agent:8080") + base_url = Column(String( + 512), doc="Base URL for health checks (service root address), e.g., http://agent:8080") # Tenant isolation tenant_id = Column(String(100), nullable=False, doc=_TENANT_ID_DOC) @@ -881,13 +1020,18 @@ class A2AExternalAgent(TableBase): raw_card = Column(JSON, doc="Full original Agent Card JSON from discovery") # Cache management - cached_at = Column(TIMESTAMP(timezone=False), doc="Timestamp when Agent Card was cached") - cache_expires_at = Column(TIMESTAMP(timezone=False), doc="Timestamp when cache expires") + cached_at = Column(TIMESTAMP(timezone=False), + doc="Timestamp when Agent Card was cached") + cache_expires_at = Column( + TIMESTAMP(timezone=False), doc="Timestamp when cache expires") # Health check status - is_available = Column(Boolean, default=True, doc="Whether this 
agent is currently reachable") - last_check_at = Column(TIMESTAMP(timezone=False), doc="Last health check timestamp") - last_check_result = Column(String(50), doc="Last health check result: OK, ERROR, TIMEOUT") + is_available = Column(Boolean, default=True, + doc="Whether this agent is currently reachable") + last_check_at = Column(TIMESTAMP(timezone=False), + doc="Last health check timestamp") + last_check_result = Column( + String(50), doc="Last health check result: OK, ERROR, TIMEOUT") class A2AExternalAgentRelation(TableBase): @@ -905,19 +1049,23 @@ class A2AExternalAgentRelation(TableBase): {"schema": SCHEMA}, ) - id = Column(BigInteger, primary_key=True, autoincrement=True, doc=_PRIMARY_KEY_DOC) + id = Column(BigInteger, primary_key=True, + autoincrement=True, doc=_PRIMARY_KEY_DOC) # Local agent (parent) - local_agent_id = Column(Integer, nullable=False, doc="Local parent agent ID") + local_agent_id = Column(Integer, nullable=False, + doc="Local parent agent ID") # External A2A agent (sub-agent) - FK to ag_a2a_external_agent_t.id - external_agent_id = Column(BigInteger, nullable=False, doc="External A2A agent ID (FK to ag_a2a_external_agent_t.id)") + external_agent_id = Column( + BigInteger, nullable=False, doc="External A2A agent ID (FK to ag_a2a_external_agent_t.id)") # Tenant isolation tenant_id = Column(String(100), nullable=False, doc=_TENANT_ID_DOC) # Status - is_enabled = Column(Boolean, default=True, doc="Whether this relation is active") + is_enabled = Column(Boolean, default=True, + doc="Whether this relation is active") class A2AServerAgent(TableBase): @@ -928,7 +1076,8 @@ class A2AServerAgent(TableBase): __tablename__ = "ag_a2a_server_agent_t" __table_args__ = {"schema": SCHEMA} - id = Column(BigInteger, primary_key=True, autoincrement=True, doc=_PRIMARY_KEY_DOC) + id = Column(BigInteger, primary_key=True, + autoincrement=True, doc=_PRIMARY_KEY_DOC) # Link to local agent agent_id = Column(Integer, nullable=False, doc="Local agent ID") @@ -938,35 
+1087,44 @@ class A2AServerAgent(TableBase): tenant_id = Column(String(100), nullable=False, doc=_TENANT_ID_DOC) # Generated endpoint ID - endpoint_id = Column(String(64), unique=True, nullable=False, doc="Generated endpoint ID") + endpoint_id = Column(String(64), unique=True, + nullable=False, doc="Generated endpoint ID") # Basic info (extracted from local agent, can be overridden) - name = Column(String(255), nullable=False, doc="Agent name exposed in Agent Card") + name = Column(String(255), nullable=False, + doc="Agent name exposed in Agent Card") description = Column(Text, doc="Agent description exposed in Agent Card") version = Column(String(50), doc="Agent version exposed in Agent Card") # Primary endpoint URL (http-json-rpc by default) - agent_url = Column(String(512), doc="Primary A2A endpoint URL (http-json-rpc by default)") + agent_url = Column( + String(512), doc="Primary A2A endpoint URL (http-json-rpc by default)") # Capabilities - streaming = Column(Boolean, default=False, doc="Whether this agent supports SSE streaming") + streaming = Column(Boolean, default=False, + doc="Whether this agent supports SSE streaming") # All supported interfaces (A2A 1.0 compliant) # Format: [{protocolBinding, url, protocolVersion}, ...] 
- supported_interfaces = Column(JSON, doc="All supported interfaces: [{protocolBinding, url, protocolVersion}, ...]") + supported_interfaces = Column( + JSON, doc="All supported interfaces: [{protocolBinding, url, protocolVersion}, ...]") # Agent Card customization (partial overrides only) - card_overrides = Column(JSON, doc="User customizations for Agent Card (partial override)") + card_overrides = Column( + JSON, doc="User customizations for Agent Card (partial override)") # A2A Server status - is_enabled = Column(Boolean, default=False, doc="Whether A2A Server is enabled for this agent") + is_enabled = Column(Boolean, default=False, + doc="Whether A2A Server is enabled for this agent") # Raw Agent Card (generated from settings, for debugging) raw_card = Column(JSON, doc="Generated Agent Card JSON (for debugging)") # Publishing timestamps - published_at = Column(TIMESTAMP(timezone=False), doc="Timestamp when A2A Server was last enabled") - unpublished_at = Column(TIMESTAMP(timezone=False), doc="Timestamp when A2A Server was disabled") + published_at = Column(TIMESTAMP(timezone=False), + doc="Timestamp when A2A Server was last enabled") + unpublished_at = Column(TIMESTAMP(timezone=False), + doc="Timestamp when A2A Server was disabled") class A2ATask(SimpleTableBase): @@ -979,7 +1137,8 @@ class A2ATask(SimpleTableBase): # Core identifiers (following A2A spec) id = Column(String(64), primary_key=True, doc="Task ID (A2A spec: taskId)") - context_id = Column(String(64), doc="Context ID for grouping related tasks") + context_id = Column( + String(64), doc="Context ID for grouping related tasks") # Endpoint and caller info endpoint_id = Column(String(64), nullable=False, doc="Endpoint ID") @@ -990,16 +1149,21 @@ class A2ATask(SimpleTableBase): raw_request = Column(JSON, doc="Original A2A request payload") # Task state (following A2A TaskState enum) - task_state = Column(String(50), nullable=False, server_default="TASK_STATE_SUBMITTED", doc="Task state: 
TASK_STATE_SUBMITTED, TASK_STATE_WORKING, TASK_STATE_COMPLETED, TASK_STATE_FAILED, TASK_STATE_CANCELED, TASK_STATE_INPUT_REQUIRED, TASK_STATE_REJECTED, TASK_STATE_AUTH_REQUIRED") - state_timestamp = Column(TIMESTAMP(timezone=False), doc="Task state last update timestamp") + task_state = Column(String(50), nullable=False, server_default="TASK_STATE_SUBMITTED", + doc="Task state: TASK_STATE_SUBMITTED, TASK_STATE_WORKING, TASK_STATE_COMPLETED, TASK_STATE_FAILED, TASK_STATE_CANCELED, TASK_STATE_INPUT_REQUIRED, TASK_STATE_REJECTED, TASK_STATE_AUTH_REQUIRED") + state_timestamp = Column(TIMESTAMP(timezone=False), + doc="Task state last update timestamp") # Task result result_data = Column(JSON, doc="Task final result data") # Timestamps - create_time = Column(TIMESTAMP(timezone=False), server_default=func.now(), doc="Task creation timestamp") - update_time = Column(TIMESTAMP(timezone=False), server_default=func.now(), onupdate=func.now(), doc="Task last update timestamp") - completed_at = Column(TIMESTAMP(timezone=False), doc="Task completion timestamp") + create_time = Column(TIMESTAMP(timezone=False), + server_default=func.now(), doc="Task creation timestamp") + update_time = Column(TIMESTAMP(timezone=False), server_default=func.now( + ), onupdate=func.now(), doc="Task last update timestamp") + completed_at = Column(TIMESTAMP(timezone=False), + doc="Task completion timestamp") class A2AMessage(SimpleTableBase): @@ -1011,23 +1175,30 @@ class A2AMessage(SimpleTableBase): __table_args__ = {"schema": SCHEMA} # Core identifiers (following A2A spec) - message_id = Column(String(64), primary_key=True, doc="Message ID (A2A spec: messageId)") - task_id = Column(String(64), nullable=True, doc="Task ID this message belongs to (nullable for standalone/simple requests)") + message_id = Column(String(64), primary_key=True, + doc="Message ID (A2A spec: messageId)") + task_id = Column(String(64), nullable=True, + doc="Task ID this message belongs to (nullable for standalone/simple 
requests)") # Message attributes - message_index = Column(Integer, nullable=False, doc="Order of message in the conversation") - role = Column(String(20), nullable=False, doc="Message sender role: user or agent") + message_index = Column(Integer, nullable=False, + doc="Order of message in the conversation") + role = Column(String(20), nullable=False, + doc="Message sender role: user or agent") # Message content (following A2A Part structure) - parts = Column(JSON, nullable=False, doc="Message parts following A2A Part structure") + parts = Column(JSON, nullable=False, + doc="Message parts following A2A Part structure") meta_data = Column(JSON, doc="Optional metadata") extensions = Column(JSON, doc="Extension URI list") # References to other tasks (optional) - reference_task_ids = Column(JSON, doc="Referenced task IDs array for multi-turn scenarios") + reference_task_ids = Column( + JSON, doc="Referenced task IDs array for multi-turn scenarios") # Timestamp - create_time = Column(TIMESTAMP(timezone=False), server_default=func.now(), doc="Message creation timestamp") + create_time = Column(TIMESTAMP( + timezone=False), server_default=func.now(), doc="Message creation timestamp") class A2AArtifact(SimpleTableBase): @@ -1039,15 +1210,19 @@ class A2AArtifact(SimpleTableBase): # Core identifiers (following A2A spec) id = Column(String(64), primary_key=True, doc="Internal primary key") - artifact_id = Column(String(64), nullable=False, doc="Artifact ID (A2A spec: artifactId)") - task_id = Column(String(64), nullable=False, doc="Task ID this artifact belongs to") + artifact_id = Column(String(64), nullable=False, + doc="Artifact ID (A2A spec: artifactId)") + task_id = Column(String(64), nullable=False, + doc="Task ID this artifact belongs to") # Artifact attributes name = Column(String(255), doc="Human-readable artifact name") description = Column(Text, doc="Artifact description") - parts = Column(JSON, nullable=False, doc="Artifact parts following A2A Part structure") + 
parts = Column(JSON, nullable=False, + doc="Artifact parts following A2A Part structure") meta_data = Column(JSON, doc="Artifact metadata") extensions = Column(JSON, doc="Extension URI list") # Timestamp - create_time = Column(TIMESTAMP(timezone=False), server_default=func.now(), doc="Artifact creation timestamp") + create_time = Column(TIMESTAMP( + timezone=False), server_default=func.now(), doc="Artifact creation timestamp") diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index 9a8b1c8c1..8fc60d6bd 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -34,6 +34,7 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: - user_id: Optional user ID for created_by and updated_by fields - tenant_id: Optional tenant ID for created_by and updated_by fields - embedding_model_name: embedding model name for the knowledge base + - preserve_source_file: whether to preserve uploaded source documents (optional) Returns: Dict[str, Any]: Dictionary with at least 'knowledge_id' and 'index_name' @@ -57,6 +58,7 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: "knowledge_name": knowledge_name, "group_ids": convert_list_to_string(group_ids) if isinstance(group_ids, list) else group_ids, "ingroup_permission": query.get("ingroup_permission"), + "preserve_source_file": query.get("preserve_source_file", True), } # For backward compatibility: if caller explicitly provides index_name, @@ -117,11 +119,16 @@ def upsert_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: if existing_record: # Update existing record - existing_record.knowledge_name = query.get('knowledge_name') or query.get('index_name') - existing_record.knowledge_describe = query.get('knowledge_describe', '') - existing_record.knowledge_sources = query.get('knowledge_sources', 'elasticsearch') - existing_record.embedding_model_name = query.get('embedding_model_name') - existing_record.embedding_model_id = 
query.get('embedding_model_id') + existing_record.knowledge_name = query.get( + 'knowledge_name') or query.get('index_name') + existing_record.knowledge_describe = query.get( + 'knowledge_describe', '') + existing_record.knowledge_sources = query.get( + 'knowledge_sources', 'elasticsearch') + existing_record.embedding_model_name = query.get( + 'embedding_model_name') + existing_record.embedding_model_id = query.get( + 'embedding_model_id') existing_record.updated_by = query.get('user_id') existing_record.update_time = func.current_timestamp() @@ -183,7 +190,7 @@ def update_knowledge_record(query: Dict[str, Any]) -> bool: # Update group IDs if query.get("group_ids") is not None: record.group_ids = query["group_ids"] - + # Update timestamp and user if query.get("user_id"): record.updated_by = query["user_id"] @@ -251,15 +258,17 @@ def get_knowledge_record(query: Optional[Dict[str, Any]] = None) -> Dict[str, An # Support both index_name and knowledge_name queries if 'index_name' in query: - db_query = db_query.filter(KnowledgeRecord.index_name == query['index_name']) + db_query = db_query.filter( + KnowledgeRecord.index_name == query['index_name']) elif 'knowledge_name' in query: - db_query = db_query.filter(KnowledgeRecord.knowledge_name == query['knowledge_name']) + db_query = db_query.filter( + KnowledgeRecord.knowledge_name == query['knowledge_name']) # Add tenant_id filter only if it is provided in the query if 'tenant_id' in query and query['tenant_id'] is not None: db_query = db_query.filter( KnowledgeRecord.tenant_id == query['tenant_id']) - + result = db_query.first() if result: diff --git a/backend/database/user_tenant_db.py b/backend/database/user_tenant_db.py index f1294f8a7..b147eac49 100644 --- a/backend/database/user_tenant_db.py +++ b/backend/database/user_tenant_db.py @@ -75,6 +75,37 @@ def insert_user_tenant(user_id: str, tenant_id: str, user_role: str = "USER", us session.add(user_tenant) +def upsert_user_tenant(user_id: str, tenant_id: str, 
user_role: str = "USER", user_email: str = None) -> Dict[str, Any]: + """ + Create or update the active user-tenant relationship for an external identity login. + """ + with get_db_session() as session: + result = session.query(UserTenant).filter( + UserTenant.user_id == user_id, + UserTenant.delete_flag == "N" + ).first() + + if result: + result.tenant_id = tenant_id + result.user_role = user_role + if user_email is not None: + result.user_email = user_email + result.updated_by = user_id + else: + result = UserTenant( + user_id=user_id, + tenant_id=tenant_id, + user_role=user_role, + user_email=user_email, + created_by=user_id, + updated_by=user_id + ) + session.add(result) + + session.flush() + return as_dict(result) + + def get_users_by_tenant_id(tenant_id: str, page: Optional[int] = 1, page_size: Optional[int] = 20, sort_by: str = "created_at", sort_order: str = "desc") -> Dict[str, Any]: """ diff --git a/backend/mcp_service.py b/backend/mcp_service.py index 0d8ab4c1b..4629d42ad 100644 --- a/backend/mcp_service.py +++ b/backend/mcp_service.py @@ -70,7 +70,7 @@ async def run(self, arguments: Dict[str, Any]) -> Any: nexent_mcp = FastMCP(name="nexent_mcp") -nexent_mcp.mount(local_mcp_service.name, local_mcp_service) +nexent_mcp.mount(local_mcp_service, local_mcp_service.name) _openapi_mcp_services: Dict[str, FastMCP] = {} @@ -188,7 +188,8 @@ def _sanitize_function_name(name: str) -> str: def register_openapi_service( service_name: str, openapi_json: Dict[str, Any], - server_url: str + server_url: str, + headers_template: Dict[str, str], ) -> bool: """ Register an OpenAPI service using FastMCP.from_openapi(). 
@@ -222,7 +223,7 @@ def register_openapi_service( openapi_spec["servers"] = [{"url": server_url}] # Create HTTP client for the underlying REST API - client = httpx.AsyncClient(base_url=server_url, timeout=30.0) + client = httpx.AsyncClient(base_url=server_url, timeout=120.0, headers=headers_template) # Create FastMCP instance from OpenAPI spec mcp_server = FastMCP.from_openapi( @@ -239,7 +240,7 @@ def register_openapi_service( _openapi_mcp_services[service_name] = mcp_server # Mount to the main MCP server - nexent_mcp.mount(service_name, mcp_server) + nexent_mcp.mount(mcp_server, service_name) logger.info(f"Registered OpenAPI service: {service_name}") return True @@ -320,13 +321,14 @@ def refresh_openapi_services_by_tenant(tenant_id: str) -> Dict[str, Any]: service_name = service.get("mcp_service_name") openapi_json = service.get("openapi_json") server_url = service.get("server_url") + headers_template = service.get("headers_template") if not openapi_json: logger.warning(f"Service '{service_name}' has no OpenAPI JSON, skipping") skipped_count += 1 continue - if register_openapi_service(service_name, openapi_json, server_url): + if register_openapi_service(service_name, openapi_json, server_url, headers_template): registered_count += 1 else: skipped_count += 1 @@ -394,6 +396,7 @@ def refresh_single_openapi_service(service_name: str, tenant_id: str) -> Dict[st # Re-register with fresh data openapi_json = service_data.get("openapi_json") server_url = service_data.get("server_url") + headers_template = service_data.get("headers_template") if not openapi_json: logger.warning(f"Service '{service_name}' has no OpenAPI JSON") @@ -403,7 +406,7 @@ def refresh_single_openapi_service(service_name: str, tenant_id: str) -> Dict[st "error": "No OpenAPI JSON found" } - success = register_openapi_service(service_name, openapi_json, server_url) + success = register_openapi_service(service_name, openapi_json, server_url, headers_template) return { "status": "refreshed" if success 
else "error", "service_name": service_name, diff --git a/backend/prompts/managed_system_prompt_template_en.yaml b/backend/prompts/managed_system_prompt_template_en.yaml index 5c2893c39..62e16e946 100644 --- a/backend/prompts/managed_system_prompt_template_en.yaml +++ b/backend/prompts/managed_system_prompt_template_en.yaml @@ -1,6 +1,6 @@ system_prompt: |- ### Basic Information - You are {{APP_NAME}}, {{APP_DESCRIPTION}}, it is {{time|default('current time')}} now + You are {{APP_NAME}}, {{APP_DESCRIPTION}} {%- if memory_list and memory_list|length > 0 %} ### Contextual Memory @@ -66,6 +66,11 @@ system_prompt: |- - Note that executed code is not visible to users. If users need to see the code, use 'code' for displaying code. - **IMPORTANT**: After code execution, the system will return content with "Observation:" marker (this is the real execution result). Please continue your next thinking based on these real results. **Do NOT fabricate observation results before code execution.** + 3. Self-verification: + - After critical events (tool calls, retrieval results, code execution, and final-answer preparation), the system may run explicit verification. + - If verification reports errors, insufficient evidence, incomplete parameters, or unreliable results, you must repair the issue, gather more evidence, call tools again, or clearly state what cannot be completed. + - The final answer is shown to the user only after verification passes. If the system returns Verification feedback, treat it as a real observation and continue revising. + After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop. When generating the final answer, you need to follow these specifications: @@ -178,3 +183,13 @@ final_answer: Original task: {{task}} Please provide a clear and concise summary of the work completed so far. 
+ + +verification: + pre_messages: |- + You are a strict verifier for a ReAct agent. Judge reliability only from the task, candidate answer, tool outputs, and observations. Do not output hidden chain-of-thought. + You must output JSON only. + + post_messages: |- + Verify whether the candidate answer covers the user's intent, is grounded in observations, handles tool errors, uses trustworthy citations, and is formatted for users. + Output fields: passed, score, status, failed_criteria, checks, revision_instruction, user_visible_note. diff --git a/backend/prompts/managed_system_prompt_template_zh.yaml b/backend/prompts/managed_system_prompt_template_zh.yaml index 291e336fb..da3d53469 100644 --- a/backend/prompts/managed_system_prompt_template_zh.yaml +++ b/backend/prompts/managed_system_prompt_template_zh.yaml @@ -2,7 +2,7 @@ system_prompt: |- ### 基本信息 - 你是{{APP_NAME}},{{APP_DESCRIPTION}},现在是{{time|default('当前时间')}},用户ID为{{user_id}} + 你是{{APP_NAME}},{{APP_DESCRIPTION}},用户ID为{{user_id}} {%- if memory_list and memory_list|length > 0 %} ### 上下文记忆 @@ -130,6 +130,11 @@ system_prompt: |- - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码'表达展示代码。 - **重要**:代码执行后,系统会返回 "Observation:" 标记的内容(这是真实的执行结果)。请基于这些真实结果继续下一步思考,**不要在代码执行前自行编造观察结果**。 + 3. 
自验证: + - 关键事件(工具调用、检索结果、代码执行、准备最终回答)后,系统会进行显式自验证。 + - 如果自验证提示存在错误、证据不足、参数不完整或结果不可靠,必须优先修正、补充证据、重新调用工具,或清晰说明无法完成的部分。 + - 最终回答只有在自验证通过后才会展示给用户;如果系统返回 Verification feedback,请把它视为真实观察结果继续修正,不要忽略。 + 在思考结束后,当你认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。 生成最终回答时,你需要遵循以下规范: @@ -271,3 +276,13 @@ final_answer: 原始任务:{{task}} 请对迄今为止完成的工作进行清晰、简洁的总结。 + + +verification: + pre_messages: |- + 你是 ReAct 智能体的严格验证器。请仅根据任务、候选答案、工具输出和观察结果判断答案是否可靠,不要输出隐藏思维链。 + 你必须只输出 JSON。 + + post_messages: |- + 请验证候选答案是否覆盖用户意图、是否有观察结果支撑、是否处理了工具错误、引用是否可信、格式是否适合展示。 + 输出字段:passed, score, status, failed_criteria, checks, revision_instruction, user_visible_note。 diff --git a/backend/prompts/manager_system_prompt_template_en.yaml b/backend/prompts/manager_system_prompt_template_en.yaml index 8ce58db29..d44ed9a71 100644 --- a/backend/prompts/manager_system_prompt_template_en.yaml +++ b/backend/prompts/manager_system_prompt_template_en.yaml @@ -1,6 +1,6 @@ system_prompt: |- ### Basic Information - You are {{APP_NAME}}, {{APP_DESCRIPTION}}, it is {{time|default('current time')}} now + You are {{APP_NAME}}, {{APP_DESCRIPTION}} {%- if memory_list and memory_list|length > 0 %} ### Contextual Memory @@ -67,6 +67,11 @@ system_prompt: |- - Note that executed code is not visible to users. If users need to see the code, use 'code' for displaying code. - **IMPORTANT**: After code execution, the system will return content with "Observation:" marker (this is the real execution result). Please continue your next thinking based on these real results. **Do NOT fabricate observation results before code execution.** + 3. Self-verification: + - After critical events (tool calls, retrieval results, code execution, agent handoffs, and final-answer preparation), the system may run explicit verification. + - If verification reports errors, insufficient evidence, incomplete parameters, or unreliable results, you must repair the issue, gather more evidence, call tools again, or clearly state what cannot be completed. 
+ - The final answer is shown to the user only after verification passes. If the system returns Verification feedback, treat it as a real observation and continue revising. + After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop. When generating the final answer, you need to follow these specifications: @@ -222,3 +227,13 @@ final_answer: Original task: {{task}} Please provide a clear and concise summary of the work completed so far. + + +verification: + pre_messages: |- + You are a strict verifier for a ReAct agent. Judge reliability only from the task, candidate answer, tool outputs, and observations. Do not output hidden chain-of-thought. + You must output JSON only. + + post_messages: |- + Verify whether the candidate answer covers the user's intent, is grounded in observations, handles tool errors, uses trustworthy citations, and is formatted for users. + Output fields: passed, score, status, failed_criteria, checks, revision_instruction, user_visible_note. diff --git a/backend/prompts/manager_system_prompt_template_zh.yaml b/backend/prompts/manager_system_prompt_template_zh.yaml index fc4eb7c0c..a49ced82d 100644 --- a/backend/prompts/manager_system_prompt_template_zh.yaml +++ b/backend/prompts/manager_system_prompt_template_zh.yaml @@ -1,6 +1,6 @@ system_prompt: |- ### 基本信息 - 你是{{APP_NAME}},{{APP_DESCRIPTION}},现在是{{time|default('当前时间')}},用户ID为{{user_id}} + 你是{{APP_NAME}},{{APP_DESCRIPTION}},用户ID为{{user_id}} {%- if memory_list and memory_list|length > 0 %} ### 上下文记忆 @@ -130,6 +130,11 @@ system_prompt: |- - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码'表达展示代码。 - **重要**:代码执行后,系统会返回 "Observation:" 标记的内容(这是真实的执行结果)。请基于这些真实结果继续下一步思考,**不要在代码执行前自行编造观察结果**。 + 3. 
自验证: + - 关键事件(工具调用、检索结果、代码执行、助手返回、准备最终回答)后,系统会进行显式自验证。 + - 如果自验证提示存在错误、证据不足、参数不完整或结果不可靠,必须优先修正、补充证据、重新调用工具,或清晰说明无法完成的部分。 + - 最终回答只有在自验证通过后才会展示给用户;如果系统返回 Verification feedback,请把它视为真实观察结果继续修正,不要忽略。 + 在思考结束后,当你认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。 生成最终回答时,你需要遵循以下规范: @@ -299,3 +304,13 @@ final_answer: 原始任务:{{task}} 请对迄今为止完成的工作进行清晰、简洁的总结。 + + +verification: + pre_messages: |- + 你是 ReAct 智能体的严格验证器。请仅根据任务、候选答案、工具输出和观察结果判断答案是否可靠,不要输出隐藏思维链。 + 你必须只输出 JSON。 + + post_messages: |- + 请验证候选答案是否覆盖用户意图、是否有观察结果支撑、是否处理了工具错误、引用是否可信、格式是否适合展示。 + 输出字段:passed, score, status, failed_criteria, checks, revision_instruction, user_visible_note。 diff --git a/backend/prompts/utils/greeting_generate_en.yaml b/backend/prompts/utils/greeting_generate_en.yaml new file mode 100644 index 000000000..31ea75632 --- /dev/null +++ b/backend/prompts/utils/greeting_generate_en.yaml @@ -0,0 +1,54 @@ +GREETING_SYSTEM_PROMPT: |- + ### You are an expert in generating agent greetings and example questions. You help users create engaging greetings and practical example questions for starting conversations with agents. + You are building an Agent application. The input includes: agent name, duty description, business description, and existing examples. + Generate a concise greeting and 3-5 example questions that help users quickly start a conversation with the agent. + The greeting should reflect the agent's positioning and capabilities. + + ### Requirements: + 1. The greeting should be concise and friendly, 1-2 sentences, introducing the agent's identity and core capabilities. Don't make it too long or too formal. + 2. Example questions should be specific and practical, representing questions users might actually ask, showcasing the agent's core features. + 3. If existing examples contain user query scenarios, prioritize extracting short user questions from them, keeping semantics consistent but simplified to natural conversational form. + 4. Provide 3-5 example questions, each with a clear use case. + 5. 
You MUST output strictly in JSON format, do not output any other content or formatting. + + ### Output format: + ```json + { + "greeting_message": "greeting content", + "example_questions": ["example question 1", "example question 2", "example question 3"] + } + ``` + + ### Examples: + Example 1 (Travel Planning Assistant, existing examples contain "Help me plan a trip from Shanghai to Beijing" etc.): + ```json + { + "greeting_message": "Hello! I'm your travel planning assistant, I can help you plan trips, recommend attractions, and arrange travel routes.", + "example_questions": ["Help me plan a 3-day trip from Shanghai to Beijing", "Recommend some family-friendly attractions", "What's fun to do in Hangzhou tomorrow?"] + } + ``` + + Example 2 (Data Analysis Assistant): + ```json + { + "greeting_message": "Hello! I'm a data analysis assistant, I can help you process and analyze data, provide visual reports and insights.", + "example_questions": ["Help me analyze trends in this sales data", "Generate a quarterly performance comparison report", "Which products have the highest profit margins?"] + } + ``` + +USER_PROMPT: |- + ### Agent Name: + {{display_name}} + + ### Agent Duty Description: + {{duty_description}} + + ### Business Description: + {{business_description}} + + {% if few_shots %} + ### Existing Examples (extract user query scenarios from these as example questions): + {{few_shots}} + {% endif %} + + Please generate the greeting and example questions based on the above information. Output strictly in JSON format. 
\ No newline at end of file diff --git a/backend/prompts/utils/greeting_generate_zh.yaml b/backend/prompts/utils/greeting_generate_zh.yaml new file mode 100644 index 000000000..34b8d85d3 --- /dev/null +++ b/backend/prompts/utils/greeting_generate_zh.yaml @@ -0,0 +1,53 @@ +GREETING_SYSTEM_PROMPT: |- + ### 你是【智能体开场白和示例问题生成专家】,用于帮助用户创建高效、吸引人的智能体开场白和示例问题。 + 现在正在构建一个Agent应用,用户的输入包含:智能体名称、职责描述、业务描述、已有示例。 + 请根据智能体的定位和职责,生成一个简短的开场白和3~5个示例问题,帮助用户快速开始与智能体的对话。 + + ### 要求: + 1.开场白要简洁友好,1-2句话即可,介绍智能体的身份和核心能力,不要过长或过于正式。 + 2.示例问题要具体、实用,是用户真实可能提出的问题,体现智能体的核心功能。 + 3.如果已有示例中包含用户的提问场景,请优先从中提炼简短的用户问题作为示例问题,保持语义一致但简化为自然对话形式。 + 4.示例问题数量为3~5个,每个问题要有明确的使用场景。 + 5.必须严格按照JSON格式输出,不要输出任何其他内容或格式。 + + ### 输出格式: + ```json + { + "greeting_message": "开场白内容", + "example_questions": ["示例问题1", "示例问题2", "示例问题3"] + } + ``` + + ### 参考示例: + 示例1(旅行规划助手,已有示例包含"帮我规划明天从上海出发去北京的行程"等场景): + ```json + { + "greeting_message": "你好!我是你的旅行规划助手,可以帮你规划行程、推荐景点和安排出行路线。", + "example_questions": ["帮我规划一个从上海到北京的三日旅行", "推荐一些适合家庭出游的景点", "明天去杭州有什么好玩的地方?"] + } + ``` + + 示例2(数据分析助手): + ```json + { + "greeting_message": "你好!我是数据分析助手,可以帮你处理和分析各种数据,提供可视化报告和洞察。", + "example_questions": ["帮我分析这组销售数据的趋势", "生成一份季度业绩对比报告", "哪些产品的利润率最高?"] + } + ``` + +USER_PROMPT: |- + ### 智能体名称: + {{display_name}} + + ### 智能体职责描述: + {{duty_description}} + + ### 业务描述: + {{business_description}} + + {% if few_shots %} + ### 已有示例(请从中提炼用户提问场景作为示例问题): + {{few_shots}} + {% endif %} + + 请根据以上信息生成开场白和示例问题。严格按JSON格式输出。 \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index dff0e8693..b8f51dd4c 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "backend" version = "0.1.0" -requires-python = "==3.10.*" +requires-python = ">=3.11,<3.12" dependencies = [ "aiofiles>=0.8.0", "uvicorn>=0.34.0", @@ -11,7 +11,7 @@ dependencies = [ "aiohttp>=3.8.0", "authlib>=1.3.0", "cryptography>=42.0.0", - "psycopg2-binary==2.9.10", + "psycopg2-binary>=2.9.9", "PyJWT>=2.8.0", "sqlalchemy~=2.0.37", 
"greenlet<3.5.0", @@ -21,10 +21,14 @@ dependencies = [ "jsonref>=1.1.0", "ruamel-yaml==0.19.1", "redis>=5.0.0", - "fastmcp==2.12.0", + "fastmcp>=2.14.2,<3.0", "langchain>=0.3.26", "scikit-learn>=1.0.0", "numpy>=1.24.0", + "defusedxml>=0.7.1", + "openjiuwen>=0.1.0", + "pydantic-settings>=2.0.0", + "python-docx>=1.1.0", ] [project.optional-dependencies] @@ -34,7 +38,7 @@ data-process = [ "flower>=2.0.1", "nest_asyncio>=1.5.6", "unstructured[csv,docx,pdf,pptx,xlsx,md]==0.18.14", - "huggingface_hub>=0.19.0,<0.21.0" + "huggingface_hub>=0.30.0,<1.0" ] test = [ "pytest", diff --git a/backend/services/agent_repository_service.py b/backend/services/agent_repository_service.py new file mode 100644 index 000000000..87649bcd1 --- /dev/null +++ b/backend/services/agent_repository_service.py @@ -0,0 +1,306 @@ +import logging +from typing import Any, Dict, Optional + +from consts.const import ASSET_OWNER_TENANT_ID +from consts.model import AgentRepositorySnapshot +from database.agent_db import search_agent_info_by_agent_id +from database.agent_version_db import search_version_by_version_no +from database.agent_repository_db import ( + STATUS_PENDING_REVIEW, + VALID_REPOSITORY_STATUSES, + get_agent_repository_by_agent_id, + get_agent_repository_by_id, + insert_agent_repository_record, + list_agent_repository_summaries, + update_agent_repository_by_id, + update_agent_repository_status_by_id, +) +from services.agent_service import ( + collect_skill_zip_entries, + export_agent_dict_for_repository_impl, + import_agent_impl, + import_agent_with_skills_impl, +) + +logger = logging.getLogger("agent_repository_service") + +_UPDATE_SNAPSHOT_FIELDS = ( + "display_name", + "description", + "author", + "category_id", + "tags", + "tool_count", + "version_label", + "source_version_no", + "agent_info_json", + "status", +) + + +def _to_summary_item(record: Dict[str, Any]) -> Dict[str, Any]: + """Map a DB record to a lightweight marketplace summary item.""" + return { + "agent_repository_id": 
record.get("agent_repository_id"), + "author": record.get("author"), + "name": record.get("name"), + "display_name": record.get("display_name"), + "description": record.get("description"), + "status": record.get("status"), + } + + +def list_agent_repository_listings_impl( + *, + status: Optional[str] = None, +) -> Dict[str, Any]: + """List all repository listings with optional status filter.""" + if status is not None and status not in VALID_REPOSITORY_STATUSES: + raise ValueError( + f"Invalid status '{status}'; must be one of: " + f"{', '.join(sorted(VALID_REPOSITORY_STATUSES))}" + ) + records = list_agent_repository_summaries(status=status) + return {"items": [_to_summary_item(record) for record in records]} + + +def update_agent_repository_status_impl( + *, + agent_repository_id: int, + status: str, + user_id: str, +) -> Dict[str, Any]: + """Update a repository listing status by primary key.""" + if status not in VALID_REPOSITORY_STATUSES: + raise ValueError( + f"Invalid status '{status}'; must be one of: " + f"{', '.join(sorted(VALID_REPOSITORY_STATUSES))}" + ) + + record = get_agent_repository_by_id(agent_repository_id) + if not record: + raise ValueError("Repository listing not found") + + rows_affected = update_agent_repository_status_by_id( + repository_id=agent_repository_id, + status=status, + user_id=user_id, + ) + if rows_affected == 0: + raise ValueError("Repository listing not found") + + updated = get_agent_repository_by_id(agent_repository_id) + if not updated: + raise ValueError("Failed to load repository listing after update") + return _to_summary_item(updated) + + +def _to_list_item(record: Dict[str, Any]) -> Dict[str, Any]: + """Map a DB record to a marketplace list item (without heavy JSON blobs).""" + return { + "id": record.get("agent_repository_id"), + "agent_repository_id": record.get("agent_repository_id"), + "agent_id": record.get("agent_id"), + "name": record.get("name"), + "display_name": record.get("display_name"), + "description": 
record.get("description"), + "author": record.get("author"), + "category_id": record.get("category_id"), + "tags": record.get("tags") or [], + "tool_count": record.get("tool_count"), + "version_label": record.get("version_label"), + "status": record.get("status"), + "source_version_no": record.get("source_version_no"), + "publisher_tenant_id": record.get("publisher_tenant_id"), + "created_at": record.get("create_time"), + "updated_at": record.get("update_time"), + } + + +def _to_detail_item( + record: Dict[str, Any], + *, + include_bundles: bool = True, + is_updated: Optional[bool] = None, +) -> Dict[str, Any]: + """Map a DB record to a marketplace detail payload.""" + detail = _to_list_item(record) + if include_bundles: + detail["agent_info_json"] = record.get("agent_info_json") + if is_updated is not None: + detail["is_updated"] = is_updated + return detail + + +def _validate_create_payload(repository_data: Dict[str, Any]) -> None: + """Validate required fields before inserting a repository listing.""" + required_fields = ( + "agent_id", + "source_version_no", + "name", + "agent_info_json", + ) + missing = [ + field for field in required_fields + if field not in repository_data or repository_data[field] is None + ] + if missing: + raise ValueError(f"Missing required repository fields: {', '.join(missing)}") + if not repository_data.get("name"): + raise ValueError("name must be a non-empty string") + + agent_info_json = repository_data.get("agent_info_json") + if not isinstance(agent_info_json, dict): + raise ValueError("agent_info_json must be a JSON object") + for key in ("agent_id", "agent_info", "mcp_info"): + if key not in agent_info_json: + raise ValueError(f"agent_info_json must contain '{key}'") + + +def _validate_agent_info_json_shareable(agent_info_json: dict) -> None: + """Reject marketplace share when any agent in the tree belongs to ASSET_OWNER tenant.""" + agent_info_map = agent_info_json.get("agent_info") + if not isinstance(agent_info_map, dict): + 
return + for entry in agent_info_map.values(): + if not isinstance(entry, dict): + continue + if entry.get("tenant_id") == ASSET_OWNER_TENANT_ID: + raise ValueError("租户管理员智能体无法共享") + + +async def _build_agent_info_json( + agent_id: int, + tenant_id: str, + user_id: str, + version_no: int, +) -> dict: + """Build marketplace snapshot JSON via the agent export pipeline.""" + export_dict = await export_agent_dict_for_repository_impl( + agent_id=agent_id, + tenant_id=tenant_id, + user_id=user_id, + version_no=version_no, + ) + skills = collect_skill_zip_entries( + agent_id=agent_id, + tenant_id=tenant_id, + version_no=version_no, + ) + snapshot = AgentRepositorySnapshot( + **export_dict, + skills=skills or None, + ) + return snapshot.model_dump() + + +async def _build_repository_data_from_agent( + agent_id: int, + tenant_id: str, + user_id: str, + version_no: int, +) -> Dict[str, Any]: + """Build a repository upsert payload from a published agent version snapshot.""" + agent_info = search_agent_info_by_agent_id(agent_id, tenant_id, version_no) + agent_info_json = await _build_agent_info_json( + agent_id=agent_id, + tenant_id=tenant_id, + user_id=user_id, + version_no=version_no, + ) + _validate_agent_info_json_shareable(agent_info_json) + + version_meta = search_version_by_version_no(agent_id, tenant_id, version_no) + version_label = ( + version_meta.get("version_name") + if version_meta and version_meta.get("version_name") + else f"v{version_no}" + ) + + return { + "agent_id": agent_id, + "source_version_no": version_no, + "name": agent_info["name"], + "display_name": agent_info.get("display_name"), + "description": agent_info.get("description"), + "author": agent_info.get("author"), + "version_label": version_label, + "agent_info_json": agent_info_json, + "status": STATUS_PENDING_REVIEW, + } + + +async def create_agent_repository_listing_impl( + agent_id: int, + tenant_id: str, + user_id: str, + version_no: int, +) -> Dict[str, Any]: + """Create or update a repository 
listing from a published agent version. + + Loads agent metadata and builds agent_info_json via the export pipeline, + then inserts or updates the marketplace table. + + When a listing for the same agent_id already exists, snapshot fields are + updated via update_agent_repository_by_id. + """ + if version_no < 0: + raise ValueError("version_no must be >= 0") + + repository_data = await _build_repository_data_from_agent( + agent_id, tenant_id, user_id, version_no + ) + _validate_create_payload(repository_data) + + existing = get_agent_repository_by_agent_id(agent_id) + if not existing: + repository_id = insert_agent_repository_record( + repository_data=repository_data, + publisher_tenant_id=tenant_id, + publisher_user_id=user_id, + ) + is_updated = False + else: + repository_id = int(existing["agent_repository_id"]) + updates = { + key: repository_data[key] + for key in _UPDATE_SNAPSHOT_FIELDS + if key in repository_data + } + affected = update_agent_repository_by_id( + repository_id=repository_id, + publisher_tenant_id=tenant_id, + user_id=user_id, + updates=updates, + ) + if affected == 0: + raise ValueError("Failed to update repository listing") + is_updated = True + + record = get_agent_repository_by_id(repository_id) + if not record: + raise ValueError("Failed to load repository listing after write") + return _to_detail_item(record, is_updated=is_updated) + + +async def import_agent_from_repository_impl( + agent_repository_id: int, + authorization: str, +) -> Dict[int, int]: + """Import an agent tree from a marketplace repository listing into the current tenant.""" + record = get_agent_repository_by_id(agent_repository_id) + if not record: + raise ValueError("Repository listing not found") + + agent_info_json = record.get("agent_info_json") + if not isinstance(agent_info_json, dict): + raise ValueError("Repository listing has no agent snapshot") + + snapshot = AgentRepositorySnapshot.model_validate(agent_info_json) + if snapshot.skills: + return await 
import_agent_with_skills_impl( + snapshot, + snapshot.skills, + authorization, + ) + return await import_agent_impl(snapshot, authorization) diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 5a340b1d6..643d1995e 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -22,7 +22,8 @@ from utils.prompt_template_utils import normalize_prompt_generate_template_content from consts.const import MEMORY_SEARCH_START_MSG, MEMORY_SEARCH_DONE_MSG, MEMORY_SEARCH_FAIL_MSG, TOOL_TYPE_MAPPING, \ LANGUAGE, MESSAGE_ROLE, MODEL_CONFIG_MAPPING, CAN_EDIT_ALL_USER_ROLES, PERMISSION_EDIT, PERMISSION_READ, PERMISSION_PRIVATE -from consts.exceptions import MemoryPreparationException, SkillDuplicateError +from consts.exceptions import AppException, MemoryPreparationException, SkillDuplicateError +from consts.error_code import ErrorCode from consts.agent_unavailable_reasons import AgentUnavailableReason from consts.model import ( AgentInfoRequest, @@ -45,7 +46,9 @@ delete_related_agent, insert_related_agent, query_all_agent_info_by_tenant_id, + query_sub_agent_relations, query_sub_agents_id_list, + resolve_sub_agent_version_no, search_agent_id_by_agent_name, search_agent_info_by_agent_id, search_blank_sub_agent_by_main_agent_id, @@ -67,8 +70,10 @@ search_tools_for_sub_agent ) from database import skill_db +from database.attachment_db import upload_fileobj from services.skill_service import SkillService -from database.agent_version_db import query_version_list +from services.file_management_service import is_allowed_skill_upload_path +from database.agent_version_db import query_version_list, query_current_version_no from database.group_db import query_group_ids_by_user from database.user_tenant_db import get_user_tenant_by_user_id from database.a2a_agent_db import get_server_agent_ids, query_external_sub_agents @@ -78,7 +83,7 @@ get_prompt_template_summary, ) from utils.str_utils import convert_list_to_string, 
convert_string_to_list -from services.conversation_management_service import save_conversation_assistant, save_conversation_user +from services.conversation_management_service import save_conversation_assistant, save_conversation_user, save_skill_files_to_conversation from services.memory_config_service import build_memory_context from utils.auth_utils import get_current_user_info, get_user_language from utils.config_utils import tenant_config_manager @@ -97,9 +102,139 @@ SAFE_AGENT_STREAM_ERROR_MESSAGE = "Agent execution failed. Please try again later." -# ------------------------------------------------------------- -# Internal helper functions -# ------------------------------------------------------------- +def _extract_json_objects_from_text(text: str) -> list[dict]: + """Extract all JSON objects embedded in a text blob.""" + if not text: + return [] + + decoder = json.JSONDecoder() + results: list[dict] = [] + index = 0 + + while index < len(text): + start_index = text.find("{", index) + if start_index < 0: + break + + try: + payload, end_index = decoder.raw_decode(text, start_index) + except json.JSONDecodeError: + index = start_index + 1 + continue + + if isinstance(payload, dict): + results.append(payload) + index = max(end_index, start_index + 1) + + return results + + +def _extract_skill_file_upload_payloads(content: str) -> list[dict]: + """Extract JSON payloads containing absolute_path from streamed tool output.""" + payloads: list[dict] = [] + for payload in _extract_json_objects_from_text(content): + if payload.get("absolute_path"): + payloads.append(payload) + return payloads + + +def _transform_skill_files_to_standard_format(upload_results: list[dict]) -> list[dict]: + """ + Transform skill file upload results to match the frontend attachment format. 
+ + Skill upload format: + {file_name, absolute_path, object_name, preview_url, url, presigned_url, mime_type, file_size, status} + Frontend format: + {object_name, name, type, size, url, presigned_url, description} + """ + frontend_files = [] + for result in upload_results: + frontend_files.append({ + "object_name": result.get("object_name", ""), + "name": result.get("file_name", result.get("name", "")), + "type": "file", + "size": result.get("file_size", result.get("size", 0)), + "url": result.get("url", ""), + "presigned_url": result.get("presigned_url", result.get("preview_url", "")), + "description": "", + }) + return frontend_files + + +async def _process_skill_file_uploads( + content: str, + user_id: str, + tenant_id: str, +) -> list[dict]: + """Upload generated skill files to storage and return upload metadata.""" + + upload_results: list[dict] = [] + for payload in _extract_skill_file_upload_payloads(content): + absolute_path = str(payload.get("absolute_path") or "").strip() + file_name = str( + payload.get("file_name") + or payload.get("file_path") + or os.path.basename(absolute_path) + ) + mime_type = str(payload.get("mime_type") or payload.get("content_type") or "application/octet-stream") + if not absolute_path: + continue + + if not is_allowed_skill_upload_path(absolute_path): + logger.warning( + "[skill-file] rejected unsafe path absolute_path=%s", + absolute_path, + ) + continue + + if not file_name: + file_name = os.path.basename(absolute_path) + + if not os.path.exists(absolute_path): + continue + + try: + file_size = os.path.getsize(absolute_path) + actual_prefix = f"skill-files/{user_id}" if user_id else "skill-files" + with open(absolute_path, "rb") as file_obj: + upload_result = upload_fileobj( + file_obj=file_obj, + file_name=file_name, + prefix=actual_prefix, + generate_presigned_url=True, + file_size=file_size, + ) + + if upload_result.get("success"): + upload_results.append( + { + "status": "success", + "file_name": file_name, + 
"absolute_path": absolute_path, + "object_name": upload_result.get("object_name"), + "preview_url": upload_result.get("presigned_url") or upload_result.get("url"), + "url": upload_result.get("url"), + "presigned_url": upload_result.get("presigned_url"), + "mime_type": mime_type, + "file_size": upload_result.get("file_size", file_size), + } + ) + else: + error_message = upload_result.get("error") or "Upload failed" + logger.warning( + "[skill-file] upload failed file_name=%s absolute_path=%s error=%s", + file_name, + absolute_path, + error_message, + ) + except Exception as exc: + logger.exception( + "[skill-file] failed to upload file file_name=%s absolute_path=%s", + file_name, + absolute_path, + ) + + return upload_results def _safe_agent_stream_error_chunk() -> str: @@ -647,23 +782,53 @@ async def _stream_agent_chunks( agent_run_info, memory_ctx, ): - """Yield SSE chunks from agent_run while persisting messages & cleanup. - - This utility centralizes the common streaming logic used by both - generate_stream_with_memory and generate_stream_no_memory so that the code - is easier to maintain and less error-prone. 
- """ + """Yield SSE chunks from agent_run while persisting messages and cleanup.""" local_messages = [] captured_final_answer = None + captured_skill_files: dict[str, dict] = {} + skill_file_uploads: list[dict] = [] try: async for chunk in agent_run(agent_run_info): local_messages.append(chunk) - # Try to capture the final answer as it streams by in order to start memory addition try: data = json.loads(chunk) - if data.get("type") == "final_answer": + chunk_type = data.get("type") + if chunk_type == "final_answer": captured_final_answer = data.get("content") + + should_parse_skill_file = chunk_type in {"execution_logs", "parse"} or data.get("role") == "tool-response" + if should_parse_skill_file: + extracted_payload_count = 0 + content_value = data.get("content") + if isinstance(content_value, list): + content_items = content_value + elif content_value: + content_items = [{"type": "text", "text": str(content_value)}] + else: + content_items = [] + + for item in content_items: + if isinstance(item, dict) and item.get("type") == "text": + text_value = item.get("text") + if text_value: + extracted_payloads = _extract_json_objects_from_text(text_value) + for payload in extracted_payloads: + absolute_path = str(payload.get("absolute_path") or "").strip() + if not absolute_path: + continue + if absolute_path in captured_skill_files: + continue + if not os.path.exists(absolute_path): + continue + captured_skill_files[absolute_path] = payload + extracted_payload_count += 1 + if extracted_payload_count: + logger.info( + "[skill-file] captured payloads count=%s current_total=%s", + extracted_payload_count, + len(captured_skill_files), + ) except Exception: pass yield f"data: {chunk}\n\n" @@ -671,7 +836,6 @@ async def _stream_agent_chunks( logger.error("Agent run error: %r", run_exc, exc_info=True) yield _safe_agent_stream_error_chunk() finally: - # Persist assistant messages for non-debug runs if not agent_request.is_debug: save_messages( agent_request, @@ -680,11 +844,54 
@@ async def _stream_agent_chunks( tenant_id=tenant_id, user_id=user_id, ) - # Always unregister the run to release resources agent_run_manager.unregister_agent_run( agent_request.conversation_id, user_id) - # Schedule memory addition in background to avoid blocking SSE termination + try: + skill_file_content_local = "\n".join( + json.dumps(payload, ensure_ascii=False) + for payload in captured_skill_files.values() + ) + if skill_file_content_local: + skill_file_uploads = await _process_skill_file_uploads( + content=skill_file_content_local, + user_id=user_id, + tenant_id=tenant_id, + ) + logger.info( + "[skill-file] upload finished conversation=%s result_count=%s results=%s", + agent_request.conversation_id, + len(skill_file_uploads), skill_file_uploads + ) + if skill_file_uploads: + # Keep original format for real-time SSE display + skill_files_payload = json.dumps( + {"skill_file_uploads": skill_file_uploads}, + ensure_ascii=False, + ) + try: + yield f"data: {json.dumps({'type': 'skill_files', 'content': skill_files_payload}, ensure_ascii=False)}\n\n" + except RuntimeError: + # Stream is closing (e.g., client disconnect). Avoid raising during generator teardown. + pass + # Persist skill file uploads to the conversation history so they + # appear in subsequent GET /conversation/{id} calls. + # Transform to frontend attachment format (object_name, name, type, size, etc.) 
+ try: + frontend_files = _transform_skill_files_to_standard_format(skill_file_uploads) + save_skill_files_to_conversation( + conversation_id=agent_request.conversation_id, + skill_file_uploads=frontend_files, + user_id=user_id, + ) + except Exception: + logger.exception( + "[skill-file] failed to persist skill file uploads to conversation=%s", + agent_request.conversation_id, + ) + except Exception: + logger.exception("Failed to process skill file uploads") + async def _add_memory_background(): try: # Skip if memory recording is disabled @@ -779,14 +986,13 @@ async def get_agent_info_impl(agent_id: int, tenant_id: str, version_no: int = 0 user_role = str(user_tenant_record.get("user_role") or "").upper() can_edit_all = user_role in CAN_EDIT_ALL_USER_ROLES - # Permission logic (same as agent list): - # - If creator or can_edit_all: PERMISSION_EDIT - # - Otherwise: use ingroup_permission, default to PERMISSION_READ if None - if can_edit_all or str(agent_info.get("created_by")) == str(user_id): - agent_info["permission"] = PERMISSION_EDIT - else: - ingroup_permission = agent_info.get("ingroup_permission") - agent_info["permission"] = ingroup_permission if ingroup_permission is not None else PERMISSION_READ + # Permission logic (same as agent list, including ASSET_OWNER read-only override) + agent_info["permission"] = resolve_agent_list_permission( + user_role=user_role, + agent=agent_info, + user_id=user_id, + can_edit_all=can_edit_all, + ) except Exception as e: logger.warning(f"Failed to calculate agent permission: {str(e)}") @@ -862,6 +1068,12 @@ async def get_agent_info_impl(agent_id: int, tenant_id: str, version_no: int = 0 agent_info["is_available"] = is_available agent_info["unavailable_reasons"] = unavailable_reasons + # Set current_version_no from draft record (version_no=0) + # This ensures the returned data always has the current published version info + if version_no > 0: + draft_version_no = query_current_version_no(agent_id, tenant_id) + 
agent_info["current_version_no"] = draft_version_no + return agent_info @@ -906,6 +1118,10 @@ async def get_creating_sub_agent_info_impl(authorization: str = Header(None)): async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = Header(None)): user_id, tenant_id, _ = get_current_user_info(authorization) + + if request.example_questions is not None and len(request.example_questions) > 6: + raise AppException(ErrorCode.COMMON_PARAMETER_INVALID, "example_questions cannot exceed 6 items") + prompt_template_id, prompt_template_name = get_prompt_template_summary( template_id=request.prompt_template_id, tenant_id=tenant_id, @@ -932,9 +1148,12 @@ async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = "prompt_template_name": prompt_template_name, "max_steps": request.max_steps, "provide_run_summary": request.provide_run_summary, + "verification_config": request.verification_config, "duty_prompt": request.duty_prompt, "constraint_prompt": request.constraint_prompt, "few_shots_prompt": request.few_shots_prompt, + "greeting_message": request.greeting_message, + "example_questions": request.example_questions, "enabled": request.enabled if request.enabled is not None else True, "group_ids": convert_list_to_string(request.group_ids) if request.group_ids else user_group_ids, "ingroup_permission": request.ingroup_permission @@ -1202,76 +1421,216 @@ async def clear_agent_memory(agent_id: int, tenant_id: str, user_id: str): # Silently fail to maintain agent deletion process -async def export_agent_impl(agent_id: int, authorization: str = Header(None)) -> str: - """ - Export the configuration information of the specified agent and all its sub-agents. - - Args: - agent_id (int): The ID of the agent to export. - authorization (str): User authentication information, obtained from the Header. - - Returns: - str: A formatted JSON string containing the configuration information of the agent and all its sub-agents. 
async def _export_agent_dict_core(
    root_agent_id: int,
    tenant_id: str,
    user_id: str,
    version_no: int = 0,
) -> dict:
    """Build the ExportAndImportDataFormat dict for an agent tree.

    Walks the tree breadth-first from ``(root_agent_id, version_no)``. Each
    agent is exported with ``export_agent_by_agent_id``; its children come
    from ``query_sub_agent_relations`` and each child's version is resolved
    with ``resolve_sub_agent_version_no``.

    Args:
        root_agent_id: Agent at the root of the tree to export.
        tenant_id: Tenant used for every lookup.
        user_id: Forwarded to ``export_agent_by_agent_id``.
        version_no: Version of the root agent to export. Defaults to 0.

    Returns:
        ``ExportAndImportDataFormat(...).model_dump()`` for the whole tree.
    """
    export_agent_dict = {}
    search_list: deque = deque([(root_agent_id, version_no)])
    # Visited is keyed by (agent_id, version_no) so cycles terminate.
    visited: set = set()
    mcp_info_set = set()

    while search_list:
        current_agent_id, current_version_no = search_list.popleft()
        visit_key = (current_agent_id, current_version_no)
        if visit_key in visited:
            continue
        visited.add(visit_key)

        agent_info = await export_agent_by_agent_id(
            agent_id=current_agent_id,
            tenant_id=tenant_id,
            user_id=user_id,
            version_no=current_version_no,
        )

        # Remember every MCP server referenced by this agent's tools.
        for tool in agent_info.tools:
            if tool.source == "mcp" and tool.usage:
                mcp_info_set.add(tool.usage)

        relations = query_sub_agent_relations(
            main_agent_id=current_agent_id,
            tenant_id=tenant_id,
            version_no=current_version_no,
        )
        for rel in relations:
            child_id = rel["selected_agent_id"]
            child_version = resolve_sub_agent_version_no(
                child_id,
                rel.get("selected_agent_version_no"),
                tenant_id,
            )
            search_list.append((child_id, child_version))

        # NOTE(review): entries are keyed by agent_id only, while traversal is
        # keyed by (agent_id, version_no); if one agent is reachable at two
        # versions the later export overwrites the earlier. Confirm intended.
        export_agent_dict[str(agent_info.agent_id)] = agent_info

    # Sort the names so the exported mcp_info order is reproducible; plain set
    # iteration order for strings can differ from one process to the next.
    mcp_info_list = [
        MCPInfo(
            mcp_server_name=mcp_server_name,
            mcp_url=get_mcp_server_by_name_and_tenant(mcp_server_name, tenant_id),
        )
        for mcp_server_name in sorted(mcp_info_set)
    ]

    export_data = ExportAndImportDataFormat(
        agent_id=root_agent_id,
        agent_info=export_agent_dict,
        mcp_info=mcp_info_list,
    )
    return export_data.model_dump()


async def export_agent_dict_impl(
    agent_id: int,
    authorization: str = Header(None),
    version_no: int = 0,
) -> dict:
    """
    Export the configuration information of the specified agent and all its sub-agents.

    Args:
        agent_id (int): The ID of the agent to export.
        authorization (str): User authentication information, obtained from the Header.
        version_no (int): Version to export. Default 0 = draft.

    Returns:
        dict: ExportAndImportDataFormat as a plain dict (via model_dump).
    """
    user_id, tenant_id, _ = get_current_user_info(authorization)
    return await _export_agent_dict_core(
        root_agent_id=agent_id,
        tenant_id=tenant_id,
        user_id=user_id,
        version_no=version_no,
    )


async def export_agent_dict_for_repository_impl(
    agent_id: int,
    tenant_id: str,
    user_id: str,
    version_no: int,
) -> dict:
    """Export an agent tree for marketplace repository storage.

    Same output as ``export_agent_dict_impl``, but the caller supplies
    ``tenant_id`` and ``user_id`` directly instead of an HTTP auth header.
    """
    return await _export_agent_dict_core(
        root_agent_id=agent_id,
        tenant_id=tenant_id,
        user_id=user_id,
        version_no=version_no,
    )


async def export_agent_impl(
    agent_id: int,
    authorization: str = Header(None),
    version_no: int = 0,
) -> str:
    """Serialize export_agent_dict_impl output to a JSON string for download or ZIP embedding."""
    agent_dict = await export_agent_dict_impl(
        agent_id, authorization, version_no=version_no
    )
    return json.dumps(agent_dict)


def _collect_skill_names_from_tree(
    agent_id: int,
    tenant_id: str,
    version_no: int,
    visited: Optional[set] = None,
) -> List[str]:
    """Collect unique skill names from an agent tree at the given version.

    Walks the tree depth-first. Names are returned in first-seen order and
    each name appears once.

    Args:
        agent_id: Root agent of the tree.
        tenant_id: Tenant used for every lookup.
        version_no: Version of the root agent.
        visited: Optional set of ``(agent_id, version_no)`` pairs to treat as
            already walked; it is updated in place. A fresh set is used when
            omitted.

    Returns:
        The de-duplicated skill names.
    """
    if visited is None:
        visited = set()

    skill_names: List[str] = []
    seen_names: set = set()

    def _walk(current_agent_id: int, current_version_no: int) -> None:
        visit_key = (current_agent_id, current_version_no)
        if visit_key in visited:
            return
        visited.add(visit_key)

        skill_instances = skill_db.query_skill_instances_by_agent_id(
            agent_id=current_agent_id,
            tenant_id=tenant_id,
            version_no=current_version_no,
        )
        for inst in skill_instances:
            skill_id = inst.get("skill_id")
            skill = skill_db.get_skill_by_id(skill_id, tenant_id)
            if skill:
                name = skill.get("name")
                if name and name not in seen_names:
                    seen_names.add(name)
                    skill_names.append(name)

        relations = query_sub_agent_relations(
            main_agent_id=current_agent_id,
            tenant_id=tenant_id,
            version_no=current_version_no,
        )
        for rel in relations:
            child_id = rel["selected_agent_id"]
            child_version = resolve_sub_agent_version_no(
                child_id,
                rel.get("selected_agent_version_no"),
                tenant_id,
            )
            _walk(child_id, child_version)

    _walk(agent_id, version_no)
    return skill_names


def collect_skill_zip_entries(
    agent_id: int,
    tenant_id: str,
    version_no: int = 0,
) -> List[SkillZipEntry]:
    """Export skill ZIP payloads for all skills in an agent tree.

    Args:
        agent_id: Root agent of the tree.
        tenant_id: Tenant used for lookups and for the ``SkillService``.
        version_no: Version of the root agent. Defaults to 0.

    Returns:
        One ``SkillZipEntry`` per entry returned by
        ``SkillService.export_skills_by_names``; ``[]`` when the tree has no
        skills.
    """
    skill_names = _collect_skill_names_from_tree(agent_id, tenant_id, version_no)
    if not skill_names:
        return []

    skill_service = SkillService(tenant_id=tenant_id)
    exported = skill_service.export_skills_by_names(skill_names, tenant_id)
    return [
        SkillZipEntry(
            skill_name=entry["skill_name"],
            skill_zip_base64=entry["skill_zip_base64"],
        )
        for entry in exported
    ]
user_id: str) "display_name") if business_logic_model_info is not None else None agent_info = ExportAndImportAgentInfo(agent_id=agent_id, + tenant_id=agent_info["tenant_id"], name=agent_info["name"], display_name=agent_info["display_name"], description=agent_info["description"], @@ -1314,6 +1674,7 @@ async def export_agent_by_agent_id(agent_id: int, tenant_id: str, user_id: str) author=agent_info.get("author"), max_steps=agent_info["max_steps"], provide_run_summary=agent_info["provide_run_summary"], + verification_config=agent_info.get("verification_config"), duty_prompt=agent_info.get( "duty_prompt"), constraint_prompt=agent_info.get( @@ -1468,6 +1829,7 @@ async def import_agent_by_agent_id( "prompt_template_name": import_agent_info.prompt_template_name or SYSTEM_PROMPT_TEMPLATE_NAME, "max_steps": import_agent_info.max_steps, "provide_run_summary": import_agent_info.provide_run_summary, + "verification_config": getattr(import_agent_info, "verification_config", None), "duty_prompt": import_agent_info.duty_prompt, "constraint_prompt": import_agent_info.constraint_prompt, "few_shots_prompt": import_agent_info.few_shots_prompt, @@ -1835,6 +2197,7 @@ async def prepare_agent_run( is_debug=agent_request.is_debug, override_version_no=agent_request.version_no, override_model_id=agent_request.model_id, + tool_params=agent_request.tool_params, ) # Mount conversation-level reusable ContextManager if enabled @@ -2280,52 +2643,45 @@ def get_sub_agents_recursive(parent_agent_id: int, depth: int = 0, max_depth: in raise ValueError(f"Failed to get agent call relationship: {str(e)}") -async def export_agent_with_skills_impl(agent_id: int, authorization: str) -> dict: - """Export an agent, returning a ZIP if it has skill instances, otherwise plain JSON. +async def export_agent_with_skills_impl( + agent_id: int, + authorization: str, + version_no: int = 0, +) -> dict: + """Export an agent, returning a ZIP if it has skill instances, otherwise a plain dict. 
The response is either: - A dict with {"_zip": True, "data": bytes, "filename": str} when the agent has skills - - A plain dict (JSON string) when the agent has no skills + - ExportAndImportDataFormat as a plain dict when the agent has no skills """ - from services.skill_service import SkillService - user_id, tenant_id, _ = get_current_user_info(authorization) - skill_instances = skill_db.query_skill_instances_by_agent_id( - agent_id=agent_id, tenant_id=tenant_id, version_no=0 + skill_zip_entries = collect_skill_zip_entries( + agent_id=agent_id, tenant_id=tenant_id, version_no=version_no ) - if not skill_instances: - return await export_agent_impl(agent_id, authorization) - - skill_names = [] - for inst in skill_instances: - skill_id = inst.get("skill_id") - skill = skill_db.get_skill_by_id(skill_id, tenant_id) - if skill: - skill_names.append(skill.get("name")) - - if not skill_names: - return await export_agent_impl(agent_id, authorization) - - agent_json_str = await export_agent_impl(agent_id, authorization) + if not skill_zip_entries: + return await export_agent_dict_impl( + agent_id, authorization, version_no=version_no + ) - skill_service = SkillService(tenant_id=tenant_id) - skill_zip_entries = skill_service.export_skills_by_names( - skill_names, tenant_id) + agent_json_str = await export_agent_impl( + agent_id, authorization, version_no=version_no + ) zip_buffer = io.BytesIO() with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf: zf.writestr("agent.json", agent_json_str) for entry in skill_zip_entries: - skill_zip_bytes = base64.b64decode(entry["skill_zip_base64"]) - zf.writestr(f"skills/{entry['skill_name']}.zip", skill_zip_bytes) + skill_zip_bytes = base64.b64decode(entry.skill_zip_base64) + zf.writestr(f"skills/{entry.skill_name}.zip", skill_zip_bytes) zip_buffer.seek(0) zip_data = zip_buffer.read() agent_info = search_agent_info_by_agent_id( - agent_id=agent_id, tenant_id=tenant_id) + agent_id=agent_id, tenant_id=tenant_id, 
version_no=version_no + ) agent_name = agent_info.get( "name", "anonymous") if agent_info else "anonymous" diff --git a/backend/services/agent_version_service.py b/backend/services/agent_version_service.py index d7096727b..8ed6e14d4 100644 --- a/backend/services/agent_version_service.py +++ b/backend/services/agent_version_service.py @@ -49,6 +49,17 @@ def _remove_audit_fields_for_insert(data: dict) -> None: data.pop('delete_flag', None) +def _build_sub_agent_relations(relations: List[dict]) -> List[dict]: + """Map relation snapshots to sub-agent relation payloads for API responses.""" + return [ + { + 'agent_id': r['selected_agent_id'], + 'version_no': r.get('selected_agent_version_no'), + } + for r in relations + ] + + def publish_version_impl( agent_id: int, tenant_id: str, @@ -92,11 +103,18 @@ def publish_version_impl( _remove_audit_fields_for_insert(tool_snapshot) insert_tool_snapshot(tool_snapshot) - # Insert relation snapshots + # Insert relation snapshots with pinned child agent versions for rel in relations_draft: + child_id = rel['selected_agent_id'] + child_version = query_current_version_no(child_id, tenant_id) + if child_version is None: + raise ValueError( + f"Sub-agent {child_id} has no published version; publish the sub-agent first." 
+ ) rel_snapshot = rel.copy() rel_snapshot.pop('version_no', None) rel_snapshot['version_no'] = new_version_no + rel_snapshot['selected_agent_version_no'] = child_version _remove_audit_fields_for_insert(rel_snapshot) insert_relation_snapshot(rel_snapshot) @@ -271,6 +289,7 @@ def get_version_detail_impl( # Extract sub_agent_id_list from relations result['sub_agent_id_list'] = [r['selected_agent_id'] for r in relations_snapshot] + result['sub_agent_relations'] = _build_sub_agent_relations(relations_snapshot) # Get skill instances for this version (from ag_skill_instance_t with version_no) from database import skill_db as skill_db_module @@ -710,6 +729,7 @@ def _get_version_detail_or_draft( # Add tools (only enabled tools) result['tools'] = [t for t in tools_draft if t.get('enabled', True)] result['sub_agent_id_list'] = [r['selected_agent_id'] for r in relations_draft] + result['sub_agent_relations'] = _build_sub_agent_relations(relations_draft) # Get draft skill instances (version_no=0) skills_draft = skill_db_module.query_skill_instances_by_agent_id( @@ -783,12 +803,11 @@ async def list_published_agents_impl( CAN_EDIT_ALL_USER_ROLES, get_user_tenant_by_user_id, query_group_ids_by_user, - PERMISSION_EDIT, - PERMISSION_READ, get_model_by_model_id, check_agent_availability, _apply_duplicate_name_availability_rules, ) + from services.asset_owner_visibility import resolve_agent_list_permission from database.agent_version_db import query_agent_snapshot # Get user role for permission check @@ -858,9 +877,10 @@ async def list_published_agents_impl( # Extract sub_agent_id_list from relations agent_info['sub_agent_id_list'] = [r['selected_agent_id'] for r in relations_snapshot] + agent_info['sub_agent_relations'] = _build_sub_agent_relations(relations_snapshot) - # Add published version info - agent_info['published_version_no'] = current_version_no + # Add current version info + agent_info['current_version_no'] = current_version_no # Check agent availability using the shared 
function _, unavailable_reasons = check_agent_availability( @@ -893,7 +913,12 @@ async def list_published_agents_impl( model_cache[model_id] = get_model_by_model_id(model_id, tenant_id) model_info = model_cache.get(model_id) - permission = PERMISSION_EDIT if can_edit_all or str(agent.get("created_by")) == str(user_id) else PERMISSION_READ + permission = resolve_agent_list_permission( + user_role=user_role, + agent=agent, + user_id=user_id, + can_edit_all=can_edit_all, + ) simple_agent_list.append({ "agent_id": agent.get("agent_id"), @@ -909,7 +934,9 @@ async def list_published_agents_impl( "is_new": agent.get("is_new", False), "group_ids": agent.get("group_ids", []), "permission": permission, - "published_version_no": agent.get("published_version_no"), + "current_version_no": agent.get("current_version_no"), + "greeting_message": agent.get("greeting_message"), + "example_questions": agent.get("example_questions"), }) return simple_agent_list diff --git a/backend/services/cas_service.py b/backend/services/cas_service.py new file mode 100644 index 000000000..7db3fce1a --- /dev/null +++ b/backend/services/cas_service.py @@ -0,0 +1,424 @@ +import json +import logging +import os +import secrets +import ssl +import urllib.parse +import urllib.request +from xml.etree.ElementTree import Element +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +import defusedxml.ElementTree as ET +from defusedxml.common import DefusedXmlException + +from consts.const import ( + CAS_CA_BUNDLE, + CAS_CALLBACK_BASE_URL, + CAS_EMAIL_ATTRIBUTE, + CAS_ENABLED, + CAS_LOGIN_MODE, + CAS_LOGOUT_URL, + CAS_RENEW_BEFORE_SECONDS, + CAS_RENEW_TIMEOUT_SECONDS, + CAS_ROLE_ATTRIBUTE, + CAS_ROLE_MAP_JSON, + CAS_SERVER_URL, + CAS_SESSION_MAX_AGE_SECONDS, + CAS_SSL_VERIFY, + CAS_SYNTHETIC_EMAIL_DOMAIN, + CAS_TENANT_ATTRIBUTE, + CAS_USER_ATTRIBUTE, + CAS_VALIDATE_PATH, + DEFAULT_TENANT_ID, + LOCAL_SESSION_MAX_AGE_SECONDS, +) +from 
database.cas_session_db import ( + create_cas_session, + revoke_cas_session_by_index, + revoke_cas_sessions_by_user_id, +) +from database.oauth_account_db import get_oauth_account_by_provider +from database.user_tenant_db import get_user_tenant_by_user_id, upsert_user_tenant +from services.oauth_service import ( + create_or_update_oauth_account, + find_supabase_user_id_by_email, +) +from services.skill_service import init_skill_list_for_tenant +from services.tool_configuration_service import init_tool_list_for_tenant +from utils.auth_utils import calculate_expires_at, generate_session_jwt, get_supabase_admin_client + +logger = logging.getLogger(__name__) + +CAS_PROVIDER = "cas" +VALID_ROLES = {"SU", "ADMIN", "DEV", "USER"} + + +class CasAuthenticationError(Exception): + pass + + +@dataclass +class CasPrincipal: + cas_user_id: str + email: str + username: str + role: str + tenant_id: str + session_index: str + expires_at: datetime + + +def get_cas_config() -> Dict[str, Any]: + mode = CAS_LOGIN_MODE if CAS_LOGIN_MODE in {"button", "force", "disabled"} else "disabled" + enabled = CAS_ENABLED and bool(CAS_SERVER_URL) + if not enabled: + mode = "disabled" + return { + "enabled": enabled, + "login_mode": mode, + "renew_before_seconds": CAS_RENEW_BEFORE_SECONDS, + "renew_timeout_seconds": CAS_RENEW_TIMEOUT_SECONDS, + "display_name": "CAS", + } + + +def build_login_url(redirect: str = "/") -> str: + _ensure_enabled() + service_url = _build_callback_url("/api/user/cas/callback", {"redirect": _normalize_redirect(redirect)}) + return f"{CAS_SERVER_URL}/login?service={service_url}" + + +def build_renew_url() -> str: + _ensure_enabled() + service_url = _build_callback_url("/api/user/cas/renew_callback", {}) + return f"{CAS_SERVER_URL}/login?service={service_url}&gateway=true" + + +def build_logout_url() -> str: + _ensure_enabled() + configured_logout_url = CAS_LOGOUT_URL.strip() + if not configured_logout_url: + return "" + + parsed_config = 
urllib.parse.urlsplit(configured_logout_url) + if parsed_config.scheme and parsed_config.netloc: + logout_url = configured_logout_url + else: + logout_url = f"{CAS_SERVER_URL}/{configured_logout_url.lstrip('/')}" + + parsed = urllib.parse.urlsplit(logout_url) + if parsed.query: + return logout_url + + query = f"service={CAS_CALLBACK_BASE_URL}" + return urllib.parse.urlunsplit((parsed.scheme, parsed.netloc, parsed.path, query, parsed.fragment)) + + +async def login_with_ticket(ticket: str, redirect: str = "/") -> Dict[str, Any]: + redirect = _normalize_redirect(redirect) + service_url = _build_callback_url("/api/user/cas/callback", {"redirect": redirect}) + principal = validate_service_ticket(ticket, service_url) + return await _create_project_session(principal, redirect=redirect) + + +async def renew_with_ticket(ticket: str) -> Dict[str, Any]: + service_url = _build_callback_url("/api/user/cas/renew_callback", {}) + principal = validate_service_ticket(ticket, service_url) + return await _create_project_session(principal, redirect="/", renew=True) + + +def validate_service_ticket(ticket: str, service_url: str) -> CasPrincipal: + _ensure_enabled() + if not ticket: + raise CasAuthenticationError("CAS ticket is missing") + + validate_path = CAS_VALIDATE_PATH if CAS_VALIDATE_PATH.startswith("/") else f"/{CAS_VALIDATE_PATH}" + validate_url = f"{CAS_SERVER_URL}{validate_path}" + xml_text = _http_get_text(f"{validate_url}?service={service_url}&ticket={ticket}") + logger.info("CAS serviceValidate response: %s", xml_text) + return parse_service_validate_response(xml_text, fallback_session_index=ticket) + + +def parse_service_validate_response(xml_text: str, fallback_session_index: str = "") -> CasPrincipal: + try: + root = ET.fromstring(xml_text) + except (ET.ParseError, DefusedXmlException) as exc: + raise CasAuthenticationError("Invalid CAS validation response") from exc + + failure = _find_first(root, "authenticationFailure") + if failure is not None: + raise 
CasAuthenticationError((failure.text or "CAS authentication failed").strip()) + + success = _find_first(root, "authenticationSuccess") + if success is None: + raise CasAuthenticationError("CAS authentication failed") + + user = _get_child_text(success, "user") + attrs_node = _find_first(success, "attributes") + attrs = _extract_attributes(attrs_node) if attrs_node is not None else {} + + cas_user_id = _attribute_or_default(attrs, CAS_USER_ATTRIBUTE, user) or user + if not cas_user_id: + raise CasAuthenticationError("CAS user id is missing") + + email = _attribute_or_default(attrs, CAS_EMAIL_ATTRIBUTE, "") + username = attrs.get("displayName") or attrs.get("name") or cas_user_id + role = _map_role(_attribute_or_default(attrs, CAS_ROLE_ATTRIBUTE, "USER")) + tenant_id = _attribute_or_default(attrs, CAS_TENANT_ATTRIBUTE, DEFAULT_TENANT_ID) or DEFAULT_TENANT_ID + session_index = attrs.get("SessionIndex") or attrs.get("sessionIndex") or fallback_session_index + expires_at = _resolve_expires_at(attrs) + + if not email: + safe_user = "".join(c if c.isalnum() or c in ("-", "_", ".") else "_" for c in cas_user_id) + email = f"{safe_user}@{CAS_SYNTHETIC_EMAIL_DOMAIN}" + + return CasPrincipal( + cas_user_id=str(cas_user_id), + email=str(email).lower(), + username=str(username), + role=role, + tenant_id=str(tenant_id), + session_index=str(session_index or ""), + expires_at=expires_at, + ) + + +def parse_logout_request(logout_request: str) -> Dict[str, str]: + if not logout_request: + return {"cas_user_id": "", "session_index": ""} + try: + root = ET.fromstring(logout_request) + except (ET.ParseError, DefusedXmlException): + logger.warning("Invalid CAS logoutRequest XML") + return {"cas_user_id": "", "session_index": ""} + + session_index = _get_child_text(root, "SessionIndex") + cas_user_id = ( + _get_child_text(root, "NameID") + or _get_child_text(root, "nameID") + or _get_child_text(root, "user") + or _get_child_text(root, "casUserId") + ) + return {"cas_user_id": cas_user_id 
or "", "session_index": session_index or ""} + + +def revoke_from_logout_request(logout_request: str) -> Dict[str, Any]: + parsed = parse_logout_request(logout_request) + revoked = 0 + if parsed["cas_user_id"]: + revoked = revoke_cas_sessions_by_user_id(parsed["cas_user_id"]) + logger.info( + "CAS SLO revoke by cas_user_id: cas_user_id=%s revoked=%s", + parsed["cas_user_id"], + revoked, + ) + if revoked == 0 and parsed["session_index"]: + revoked = revoke_cas_session_by_index(parsed["session_index"]) + logger.info( + "CAS SLO revoke by session_index: session_index=%s revoked=%s", + parsed["session_index"], + revoked, + ) + if revoked == 0: + logger.warning("CAS SLO did not revoke any session: %s", parsed) + return {"revoked": revoked, **parsed} + + +async def _create_project_session(principal: CasPrincipal, redirect: str = "/", renew: bool = False) -> Dict[str, Any]: + user_id = _resolve_project_user(principal) + existing_tenant = get_user_tenant_by_user_id(user_id) + user_tenant = upsert_user_tenant( + user_id=user_id, + tenant_id=principal.tenant_id, + user_role=principal.role, + user_email=principal.email, + ) + if not existing_tenant: + await init_tool_list_for_tenant(principal.tenant_id, user_id) + await init_skill_list_for_tenant(principal.tenant_id, user_id) + + now = datetime.now() + max_local_expiry = now + timedelta(seconds=LOCAL_SESSION_MAX_AGE_SECONDS) + expires_at_dt = min(principal.expires_at, max_local_expiry) + expires_in_seconds = max(1, int((expires_at_dt - now).total_seconds())) + + session_id = secrets.token_urlsafe(32) + create_cas_session( + session_id=session_id, + user_id=user_id, + cas_user_id=principal.cas_user_id, + cas_session_index=principal.session_index, + expires_at=expires_at_dt, + ) + + jwt_token = generate_session_jwt(user_id, expires_in=expires_in_seconds, session_id=session_id) + + return { + "user": { + "id": str(user_id), + "email": principal.email, + "role": user_tenant.get("user_role", principal.role), + }, + "session": { + 
"access_token": jwt_token, + "refresh_token": "", + "expires_at": calculate_expires_at(jwt_token), + "expires_in_seconds": expires_in_seconds, + }, + "redirect_url": redirect, + "renew": renew, + } + + +def _resolve_project_user(principal: CasPrincipal) -> str: + existing = get_oauth_account_by_provider(CAS_PROVIDER, principal.cas_user_id) + if existing: + create_or_update_oauth_account( + user_id=existing["user_id"], + provider=CAS_PROVIDER, + provider_user_id=principal.cas_user_id, + email=principal.email, + username=principal.username, + tenant_id=principal.tenant_id, + ) + return existing["user_id"] + + admin_client = get_supabase_admin_client() + if not admin_client: + raise RuntimeError("Supabase admin client not available") + + user_id = find_supabase_user_id_by_email(admin_client, principal.email) + if not user_id: + create_resp = admin_client.auth.admin.create_user( + { + "email": principal.email, + "password": secrets.token_urlsafe(32), + "email_confirm": True, + "user_metadata": { + "full_name": principal.username, + "provider": CAS_PROVIDER, + "cas_user_id": principal.cas_user_id, + }, + } + ) + user_id = create_resp.user.id + + create_or_update_oauth_account( + user_id=user_id, + provider=CAS_PROVIDER, + provider_user_id=principal.cas_user_id, + email=principal.email, + username=principal.username, + tenant_id=principal.tenant_id, + ) + return user_id + + +def _ensure_enabled() -> None: + if not CAS_ENABLED or not CAS_SERVER_URL: + raise CasAuthenticationError("CAS is not configured") + + +def _build_callback_url(path: str, params: Dict[str, str]) -> str: + if not CAS_CALLBACK_BASE_URL: + raise CasAuthenticationError("CAS callback base URL is not configured") + query = _build_callback_query(params) + suffix = f"?{query}" if query else "" + return f"{CAS_CALLBACK_BASE_URL}{path}{suffix}" + + +def _build_callback_query(params: Dict[str, str]) -> str: + return "&".join(f"{key}={value}" for key, value in params.items()) + + +def 
_normalize_redirect(redirect: str) -> str: + if not redirect or not redirect.startswith("/") or redirect.startswith("//"): + return "/" + return redirect + + +def _build_ssl_context() -> ssl.SSLContext: + if CAS_CA_BUNDLE and os.path.isfile(CAS_CA_BUNDLE): + return ssl.create_default_context(cafile=CAS_CA_BUNDLE) + if not CAS_SSL_VERIFY: + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx + return ssl.create_default_context() + + +def _http_get_text(url: str) -> str: + req = urllib.request.Request(url, headers={"Accept": "application/xml,text/xml,*/*"}) + with urllib.request.urlopen(req, timeout=15, context=_build_ssl_context()) as resp: + return resp.read().decode("utf-8") + + +def _local_name(tag: str) -> str: + return tag.rsplit("}", 1)[-1] + + +def _find_first(node: Element, name: str) -> Optional[Element]: + for child in node.iter(): + if _local_name(child.tag) == name: + return child + return None + + +def _get_child_text(node: Element, name: str) -> str: + found = _find_first(node, name) + return (found.text or "").strip() if found is not None else "" + + +def _extract_attributes(attrs_node: Element) -> Dict[str, str]: + attrs: Dict[str, str] = {} + for child in list(attrs_node): + value = (child.text or "").strip() + if value: + attrs[_local_name(child.tag)] = value + return attrs + + +def _attribute_or_default(attrs: Dict[str, str], key: str, default: str) -> str: + if key and key in attrs: + return attrs[key] + return default + + +def _map_role(raw_role: str) -> str: + role = (raw_role or "USER").upper() + try: + role_map = json.loads(CAS_ROLE_MAP_JSON) if CAS_ROLE_MAP_JSON else {} + role = str(role_map.get(raw_role, role_map.get(role, role))).upper() + except Exception: + logger.warning("Invalid CAS_ROLE_MAP_JSON; falling back to raw role") + return role if role in VALID_ROLES else "USER" + + +def _resolve_expires_at(attrs: Dict[str, str]) -> datetime: + for key in ("expiresAt", 
"expirationDate", "validUntil", "notOnOrAfter"): + value = attrs.get(key) + if not value: + continue + parsed = _parse_datetime(value) + if parsed: + return parsed + return datetime.now() + timedelta(seconds=CAS_SESSION_MAX_AGE_SECONDS) + + +def _parse_datetime(value: str) -> Optional[datetime]: + try: + if value.isdigit(): + timestamp = int(value) + if timestamp > 10_000_000_000: + timestamp = timestamp / 1000 + return datetime.fromtimestamp(timestamp) + normalized = value.replace("Z", "+00:00") + parsed = datetime.fromisoformat(normalized) + if parsed.tzinfo: + parsed = parsed.astimezone().replace(tzinfo=None) + return parsed + except Exception: + return None diff --git a/backend/services/conversation_management_service.py b/backend/services/conversation_management_service.py index 302ec63a8..0b7345461 100644 --- a/backend/services/conversation_management_service.py +++ b/backend/services/conversation_management_service.py @@ -8,6 +8,7 @@ from consts.const import LANGUAGE, MODEL_CONFIG_MAPPING, MESSAGE_ROLE, DEFAULT_EN_TITLE, DEFAULT_ZH_TITLE from consts.model import AgentRequest, ConversationResponse, MessageRequest, MessageUnit +from consts.exceptions import ConversationNotFoundError from database.conversation_db import ( create_conversation, create_conversation_message, @@ -18,12 +19,14 @@ get_conversation, get_conversation_history, get_conversation_list, + get_latest_assistant_message_id, get_message_id_by_index, get_source_images_by_conversation, get_source_images_by_message, get_source_searches_by_conversation, get_source_searches_by_message, rename_conversation, + update_message_minio_files, update_message_opinion ) from nexent.core.utils.observer import MessageObserver, ProcessType @@ -224,7 +227,7 @@ def save_conversation_assistant(request: AgentRequest, messages: List[str], user message_list.append(message) conversation_req = MessageRequest(conversation_id=request.conversation_id, message_idx=user_role_count * 2 + 1, - role=MESSAGE_ROLE["ASSISTANT"], 
message=message_list, minio_files=request.minio_files) + role=MESSAGE_ROLE["ASSISTANT"], message=message_list, minio_files=None) save_message(conversation_req, user_id=user_id, tenant_id=tenant_id) @@ -296,7 +299,9 @@ def update_conversation_title(conversation_id: int, title: str, user_id: str = N """ success = rename_conversation(conversation_id, title, user_id) if not success: - raise Exception(f"Conversation {conversation_id} does not exist or has been deleted") + raise ConversationNotFoundError( + f"Conversation {conversation_id} does not exist or has been deleted" + ) return success @@ -509,6 +514,10 @@ def get_conversation_history_service(conversation_id: int, user_id: str) -> List 'opinion_flag': msg['opinion_flag'] } + # Add minio_files field (if any, e.g., skill-generated attachments) + if 'minio_files' in msg and msg['minio_files']: + message_item['minio_files'] = msg['minio_files'] + # Add image content (if any) if message_id in image_by_message: message_item['picture'] = image_by_message[message_id] @@ -701,3 +710,52 @@ async def get_message_id_by_index_impl(conversation_id: int, message_index: int) if message_id is None: raise Exception("Message not found.") return message_id + + +def save_skill_files_to_conversation( + conversation_id: int, + skill_file_uploads: List[Dict[str, Any]], + user_id: str, +) -> bool: + """ + Append skill file upload records to the latest assistant message in a conversation. + + This persists generated documents (e.g., DOCX, XLSX created by skills) to the + conversation history so they appear in subsequent GET /conversation/{id} calls. 
+ + Args: + conversation_id: Target conversation ID + skill_file_uploads: List of upload metadata dicts (e.g., from upload_fileobj) + user_id: User ID for ownership validation + + Returns: + bool: True if files were saved, False if no assistant message was found + """ + if not skill_file_uploads: + return False + + try: + message_id = get_latest_assistant_message_id(conversation_id, user_id) + if message_id is None: + logging.warning( + "[skill-file] no assistant message found for conversation=%s, " + "cannot persist skill file uploads", + conversation_id, + ) + return False + + success = update_message_minio_files(message_id, skill_file_uploads) + if success: + logging.info( + "[skill-file] persisted %d file(s) to message_id=%s conversation=%s", + len(skill_file_uploads), + message_id, + conversation_id, + ) + return success + except Exception as exc: + logging.exception( + "[skill-file] failed to persist skill file uploads for conversation=%s", + conversation_id, + ) + return False diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py index ae3d35dcd..a7529127c 100644 --- a/backend/services/data_process_service.py +++ b/backend/services/data_process_service.py @@ -15,7 +15,7 @@ import redis import torch from PIL import Image -from celery import states, chain +from celery import states from transformers import CLIPProcessor, CLIPModel from nexent.data_process.core import DataProcessCore @@ -25,7 +25,7 @@ from database.attachment_db import delete_file, file_exists, get_file_size_from_minio, get_file_stream, upload_file from utils.file_management_utils import convert_office_to_pdf from data_process.app import app as celery_app -from data_process.tasks import process, forward +from data_process.tasks import submit_process_forward_chain from data_process.utils import get_task_info, get_all_task_ids_from_redis # Limit concurrent LibreOffice processes to avoid resource exhaustion @@ -54,7 +54,8 @@ def __init__(self): 
self._inspector = None self._inspector_last_time = 0 - self._inspector_ttl = 300 # 5 minutes - inspector is expensive to create (ping all workers) + # 5 minutes - inspector is expensive to create (ping all workers) + self._inspector_ttl = 300 self._inspector_lock = None self._inspector_lock = threading.Lock() @@ -152,7 +153,8 @@ async def get_all_tasks(self, filter: bool = True) -> List[Dict[str, Any]]: def _normalize_runtime_meta(task: Dict[str, Any]) -> Dict[str, Any]: task_name_full = task.get('name', '') or '' - task_name = task_name_full.split('.')[-1] if task_name_full else '' + task_name = task_name_full.split( + '.')[-1] if task_name_full else '' kwargs = task.get('kwargs') or {} if isinstance(kwargs, str): try: @@ -178,35 +180,43 @@ def _normalize_runtime_meta(task: Dict[str, Any]) -> Dict[str, Any]: def get_active(): t = time.time() # Create fresh inspector with short timeout for each call - short_inspector = celery_app.control.inspect(timeout=short_timeout) + short_inspector = celery_app.control.inspect( + timeout=short_timeout) result = short_inspector.active() elapsed = time.time() - t - logger.info(f"[get_all_tasks] inspector.active() took {elapsed:.3f}s") + logger.info( + f"[get_all_tasks] inspector.active() took {elapsed:.3f}s") return result if result else {} def get_reserved(): t = time.time() - short_inspector = celery_app.control.inspect(timeout=short_timeout) + short_inspector = celery_app.control.inspect( + timeout=short_timeout) result = short_inspector.reserved() elapsed = time.time() - t - logger.info(f"[get_all_tasks] inspector.reserved() took {elapsed:.3f}s") + logger.info( + f"[get_all_tasks] inspector.reserved() took {elapsed:.3f}s") return result if result else {} with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: future_active = executor.submit(get_active) future_reserved = executor.submit(get_reserved) - active_tasks_dict = future_active.result(timeout=short_timeout + 0.5) - reserved_tasks_dict = 
future_reserved.result(timeout=short_timeout + 0.5) + active_tasks_dict = future_active.result( + timeout=short_timeout + 0.5) + reserved_tasks_dict = future_reserved.result( + timeout=short_timeout + 0.5) celery_duration = time.time() - celery_start if celery_duration > 0.5: - logger.warning(f"[get_all_tasks] Inspector took {celery_duration:.3f}s (expected <0.5s)") + logger.warning( + f"[get_all_tasks] Inspector took {celery_duration:.3f}s (expected <0.5s)") if active_tasks_dict: for worker, tasks in active_tasks_dict.items(): for task in tasks: task_id = task.get('id') if task_id: task_ids.add(task_id) - runtime_task_meta[task_id] = _normalize_runtime_meta(task) + runtime_task_meta[task_id] = _normalize_runtime_meta( + task) if reserved_tasks_dict: for worker, tasks in reserved_tasks_dict.items(): for task in tasks: @@ -214,7 +224,8 @@ def get_reserved(): if task_id: task_ids.add(task_id) # Keep active metadata if already present - runtime_task_meta.setdefault(task_id, _normalize_runtime_meta(task)) + runtime_task_meta.setdefault( + task_id, _normalize_runtime_meta(task)) # Get task IDs from Redis backend (covers completed/failed tasks within expiry) try: @@ -241,11 +252,14 @@ def get_reserved(): if not task_info.get('task_name') and runtime_meta.get('task_name'): task_info['task_name'] = runtime_meta.get('task_name') if not task_info.get('index_name') and runtime_meta.get('index_name'): - task_info['index_name'] = runtime_meta.get('index_name') + task_info['index_name'] = runtime_meta.get( + 'index_name') if not task_info.get('path_or_url') and runtime_meta.get('path_or_url'): - task_info['path_or_url'] = runtime_meta.get('path_or_url') + task_info['path_or_url'] = runtime_meta.get( + 'path_or_url') if not task_info.get('original_filename') and runtime_meta.get('original_filename'): - task_info['original_filename'] = runtime_meta.get('original_filename') + task_info['original_filename'] = runtime_meta.get( + 'original_filename') if filter and not 
(task_info.get('index_name') and task_info.get('task_name')): # Keep user-visible queued tasks even before worker updates task meta. @@ -538,30 +552,23 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B f"Missing required field 'index_name' in source config: {source_config}") continue - # Create and submit a chain: process -> forward - task_chain = chain( - process.s( - source=source, - source_type=source_type, - chunking_strategy=chunking_strategy, - index_name=index_name, - original_filename=original_filename, - embedding_model_id=embedding_model_id, - tenant_id=tenant_id - ).set(queue='process_q'), - forward.s( - index_name=index_name, - source=source, - source_type=source_type, - original_filename=original_filename, - authorization=authorization - ).set(queue='forward_q') + chain_id = submit_process_forward_chain( + source=source, + source_type=source_type, + chunking_strategy=chunking_strategy, + index_name=index_name, + original_filename=original_filename, + authorization=authorization, + embedding_model_id=embedding_model_id, + tenant_id=tenant_id, ) + if not chain_id: + logger.error( + f"Failed to enqueue process-forward chain for source: {source}") + continue - task_result = task_chain.apply_async() - - task_ids.append(task_result.id) - logger.debug(f"Created task {task_result.id} for source: {source}") + task_ids.append(chain_id) + logger.debug(f"Created task {chain_id} for source: {source}") logger.info( f"Created {len(task_ids)} individual tasks for batch processing") return task_ids @@ -593,7 +600,7 @@ async def process_uploaded_text_file(self, file_content: bytes, filename: str, c f"Processing uploaded file: {filename} using SDK DataProcessCore") data_processor = DataProcessCore() - chunks = data_processor.file_process( + chunks, _ = data_processor.file_process( file_data=file_content, filename=filename, chunking_strategy=chunking_strategy @@ -642,7 +649,8 @@ async def convert_office_to_pdf_impl(self, object_name: str, 
pdf_object_name: st # Step 1: Download original Office file from MinIO original_stream = get_file_stream(object_name) if original_stream is None: - raise OfficeConversionException(f"Source file not found in storage: {object_name}") + raise OfficeConversionException( + f"Source file not found in storage: {object_name}") original_filename = os.path.basename(object_name) input_path = os.path.join(temp_dir, original_filename) @@ -654,10 +662,12 @@ async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: st try: pdf_path = await convert_office_to_pdf(input_path, temp_dir, timeout=30) except Exception as exc: - raise OfficeConversionException(f"LibreOffice conversion failed: {exc}") from exc + raise OfficeConversionException( + f"LibreOffice conversion failed: {exc}") from exc # Step 3: Upload converted PDF to MinIO - result = upload_file(file_path=pdf_path, object_name=pdf_object_name) + result = upload_file(file_path=pdf_path, + object_name=pdf_object_name) if not result.get('success'): raise OfficeConversionException( f"Failed to upload PDF to MinIO: {result.get('error', 'Unknown error')}" @@ -666,14 +676,16 @@ async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: st # Step 4: Validate the uploaded PDF (header check + minimum size) remote_size = get_file_size_from_minio(pdf_object_name) if remote_size <= 0: - raise OfficeConversionException("PDF validation failed: cannot read remote file size") + raise OfficeConversionException( + "PDF validation failed: cannot read remote file size") if remote_size < 100: raise OfficeConversionException( f"PDF validation failed: file too small ({remote_size} bytes)" ) remote_stream = get_file_stream(pdf_object_name) if remote_stream is None: - raise OfficeConversionException("PDF validation failed: cannot read uploaded file") + raise OfficeConversionException( + "PDF validation failed: cannot read uploaded file") try: header = remote_stream.read(5) finally: @@ -682,7 +694,8 @@ async def 
convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: st except Exception: pass if not header.startswith(b'%PDF-'): - raise OfficeConversionException("PDF validation failed: invalid PDF header") + raise OfficeConversionException( + "PDF validation failed: invalid PDF header") except OfficeConversionException: # Clean up any partially-uploaded remote PDF so a future retry starts clean @@ -690,14 +703,16 @@ async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: st delete_file(pdf_object_name) raise except Exception as exc: - raise OfficeConversionException(f"Unexpected error during conversion: {exc}") from exc + raise OfficeConversionException( + f"Unexpected error during conversion: {exc}") from exc finally: # Step 5: Clean up local temporary directory if temp_dir and os.path.exists(temp_dir): try: shutil.rmtree(temp_dir) except Exception as cleanup_err: - logger.warning(f"Failed to cleanup temp dir '{temp_dir}': {cleanup_err}") + logger.warning( + f"Failed to cleanup temp dir '{temp_dir}': {cleanup_err}") def convert_celery_states_to_custom(self, process_celery_state: Optional[str], forward_celery_state: Optional[str]) -> str: """Map Celery task states to a custom frontend state string. 
diff --git a/backend/services/file_management_service.py b/backend/services/file_management_service.py index b2850403d..585669c0c 100644 --- a/backend/services/file_management_service.py +++ b/backend/services/file_management_service.py @@ -52,6 +52,27 @@ logger = logging.getLogger("file_management_service") +ALLOWED_SKILL_UPLOAD_ROOT = Path("/mnt/nexent").resolve() + + +def is_allowed_skill_upload_path(file_path: str) -> bool: + """Return True when a local file path is under the allowed skill upload root.""" + if not file_path: + return False + + try: + candidate_path = Path(file_path).resolve() + except Exception: + return False + + try: + candidate_path.relative_to(ALLOWED_SKILL_UPLOAD_ROOT) + return True + except ValueError: + return False + + + def resolve_minio_upload_folder( folder: Optional[str], @@ -83,6 +104,11 @@ def resolve_minio_upload_folder( if folder == "knowledge_base": return "knowledge_base" + if folder == "skill-files": + if user_id: + return f"skill-files/{user_id}" + return "skill-files" + if user_id: return f"attachments/{user_id}" @@ -101,7 +127,6 @@ def check_file_access( - knowledge_base/*: All authenticated users can access - attachments/{user_id}/*: Only the owner (user_id) can access - images_in_attachments/*: All authenticated users can access - - preview/*: Accessible if the original file is accessible Args: object_name: File object name in storage @@ -125,6 +150,10 @@ def check_file_access( # Keep them readable for authenticated users to avoid broken image citations. return True + if object_name.startswith("skill-files/"): + # Generated documents are private to the uploader and must stay user-scoped. 
+ return object_name.startswith(f"skill-files/{user_id}/") + # Check if file is in user's attachments folder # Pattern: attachments/{user_id}/* if object_name.startswith(f"attachments/{user_id}/"): @@ -357,14 +386,20 @@ async def upload_to_minio( # Convert file content to BytesIO object file_obj = BytesIO(file_content) + # Store original filename before upload + original_filename = f.filename or "" + # Upload file result = upload_fileobj( file_obj=file_obj, - file_name=f.filename or "", + file_name=original_filename, prefix=actual_folder, file_size=len(file_content) ) + # Preserve original filename in result (upload_fileobj uses it for object name generation) + result["original_file_name"] = original_filename + # Reset file pointer for potential re-reading await f.seek(0) results.append(result) @@ -376,6 +411,7 @@ async def upload_to_minio( results.append({ "success": False, "file_name": f.filename, + "original_file_name": f.filename, "error": "An error occurred while processing the file." 
}) return results diff --git a/backend/services/northbound_service.py b/backend/services/northbound_service.py index a6eaed77d..c5493a551 100644 --- a/backend/services/northbound_service.py +++ b/backend/services/northbound_service.py @@ -1,31 +1,40 @@ import asyncio import hashlib +import json import logging import time from dataclasses import dataclass -from typing import Any, Dict, Optional +from os.path import basename +from typing import Any, Dict, List, Optional +from fastapi import HTTPException, UploadFile from fastapi.responses import StreamingResponse + +from consts.const import ASSET_OWNER_TENANT_ID from consts.exceptions import ( LimitExceededError, UnauthorizedError, + ConversationNotFoundError, ) -from consts.model import AgentRequest -from database.conversation_db import get_conversation_messages +from consts.model import AgentRequest, ToolParamsRequest +from database.conversation_db import get_conversation_messages, get_source_searches_by_message from database.token_db import log_token_usage, get_latest_usage_metadata from services.agent_service import ( run_agent_stream, stop_agent_tasks, - list_all_agent_info_impl, get_agent_id_by_name ) +from services.agent_version_service import list_published_agents_impl from services.conversation_management_service import ( save_conversation_user, get_conversation_list_service, create_new_conversation, update_conversation_title as update_conversation_title_service, ) +from services.file_management_service import upload_to_minio, resolve_minio_upload_folder, validate_urls_access +from database.attachment_db import get_file_url, get_file_size_from_minio +from nexent.multi_modal.utils import parse_s3_url logger = logging.getLogger("northbound_service") @@ -39,6 +48,188 @@ class NorthboundContext: token_id: int = 0 +def _build_northbound_file_descriptor( + upload_result: Dict[str, Any], + original_file_name: str = "", + file_type: Optional[str] = None, + file_size: Optional[int] = None, +) -> Dict[str, Any]: + 
"""Normalize upload metadata for northbound API consumers.""" + object_name = str(upload_result.get("object_name") or "").strip() + # Use original filename if provided, otherwise fall back to upload result or object name + if original_file_name: + file_name = original_file_name + else: + file_name = str(upload_result.get("file_name") or basename(object_name) or "") + # Frontend-compatible field order + descriptor = { + "object_name": object_name, + "name": file_name, + "type": file_type or "file", + # Use provided file_size, or from upload_result, or 0 as fallback + "size": file_size if file_size is not None else upload_result.get("file_size", 0), + # Use relative URL format matching frontend: /nexent/{object_name} + "url": f"/nexent/{object_name}", + "description": "", + } + presigned_url = upload_result.get("presigned_url") + if presigned_url: + descriptor["presigned_url"] = presigned_url + return descriptor + + +async def upload_files_for_northbound( + ctx: NorthboundContext, + files: List[UploadFile], + folder: str = "attachments", +) -> Dict[str, Any]: + """Upload files for northbound callers and return reusable storage references.""" + if not files: + raise ValueError("No files in the request") + + actual_folder = resolve_minio_upload_folder(folder, ctx.user_id, ctx.tenant_id) + results = await upload_to_minio(files=files, folder=actual_folder) + normalized_files = [] + for result, upload_file in zip(results, files): + if result.get("success") and result.get("object_name"): + content_type = result.get("content_type", "") + file_type = "image" if content_type.startswith("image/") else "file" + # Extract original filename - use upload result first, then fallback to UploadFile + # The upload result contains the original filename passed to upload_fileobj + original_file_name = result.get("original_file_name") or upload_file.filename or "" + file_size = result.get("file_size", 0) + # If file_size is 0 but we have the UploadFile, try to get size from headers + if 
file_size == 0 and hasattr(upload_file, 'size') and upload_file.size: + file_size = upload_file.size + descriptor = _build_northbound_file_descriptor( + result, + original_file_name=original_file_name, + file_type=file_type, + file_size=file_size, + ) + normalized_files.append(descriptor) + + if not normalized_files: + raise ValueError("No valid files uploaded") + + success_count = sum(1 for result in results if result.get("success", False)) + failed_count = sum(1 for result in results if not result.get("success", False)) + + return { + "message": f"Processed {len(results)} files", + "requestId": ctx.request_id, + "summary": { + "total": len(results), + "uploaded": success_count, + "failed": failed_count, + }, + "files": normalized_files, + } + + +def _normalize_northbound_attachments( + attachments: Optional[List[Any]], + user_id: str, + tenant_id: str, +) -> Optional[List[Dict[str, Any]]]: + """Convert northbound attachment references into internal minio_files objects. + + Supports two formats: + 1. List of S3 URL strings (backward compatible): ["s3://nexent/...", "/nexent/...", "attachments/..."] + 2. 
List of attachment objects (full metadata): [{"object_name": "...", "name": "...", ...}] + """ + from database.attachment_db import _build_mcp_presigned_url + + if attachments is None: + return None + if not isinstance(attachments, list): + raise ValueError("attachments must be an array") + + normalized_files: List[Dict[str, Any]] = [] + for attachment in attachments: + # Handle dict format (full attachment object) + if isinstance(attachment, dict): + # Use the attachment dict directly, just ensure required fields + normalized_file = { + "object_name": attachment.get("object_name", ""), + "name": attachment.get("name", basename(attachment.get("object_name", ""))), + "type": attachment.get("type", "file"), + "size": attachment.get("size", 0), + "url": attachment.get("url", ""), + "description": attachment.get("description", ""), + } + # Add presigned_url if available, or generate one if we have object_name + if "presigned_url" in attachment: + normalized_file["presigned_url"] = attachment["presigned_url"] + elif normalized_file.get("object_name"): + try: + presigned_result = get_file_url(object_name=normalized_file["object_name"], expires=86400) + if presigned_result.get("success") and presigned_result.get("url"): + normalized_file["presigned_url"] = _build_mcp_presigned_url(presigned_result["url"]) + except Exception: + pass + normalized_files.append(normalized_file) + continue + + # Handle string format (S3 URL) + if not isinstance(attachment, str) or not attachment.strip(): + raise ValueError("attachments must contain non-empty S3 URLs or object paths") + + attachment_url = attachment.strip() + + # Support multiple URL formats: + # 1. s3://nexent/attachments/xxx.md + # 2. /nexent/attachments/xxx.md + # 3. 
attachments/xxx.md (relative path) + if attachment_url.startswith("s3://"): + try: + _, object_name = parse_s3_url(attachment_url) + except ValueError as exc: + raise ValueError(f"Invalid S3 URL format: {attachment_url}") from exc + validate_url = attachment_url + elif attachment_url.startswith("/nexent/"): + object_name = attachment_url[len("/nexent/"):] + validate_url = f"s3://nexent/{object_name}" + elif attachment_url.startswith("attachments/") or attachment_url.startswith("nexent/"): + object_name = attachment_url if attachment_url.startswith("nexent/") else attachment_url + validate_url = f"s3://nexent/{object_name}" + else: + raise ValueError(f"Invalid attachment format: {attachment_url}. Expected s3:// URL, /nexent/ path, or attachments/ path") + + try: + validate_urls_access([validate_url], user_id, tenant_id) + presigned_result = get_file_url(object_name=object_name, expires=86400) + except PermissionError as exc: + detail = str(exc) + if "Invalid S3 URL format" in detail: + raise ValueError(detail) from exc + raise PermissionError(detail) from exc + + # Get file size from MinIO + try: + file_size = get_file_size_from_minio(object_name) + except Exception: + file_size = 0 + + # Build frontend-compatible minio_files format + file_name = basename(object_name.rstrip("/")) + normalized_file = { + "object_name": object_name, + "name": file_name, + "type": "file", + "size": file_size, + # Use relative URL format matching frontend: /nexent/{object_name} + "url": f"/nexent/{object_name}", + "description": "", + } + # Use MCP proxy URL for presigned_url (same as frontend format) + if presigned_result.get("success") and presigned_result.get("url"): + normalized_file["presigned_url"] = _build_mcp_presigned_url(presigned_result["url"]) + normalized_files.append(normalized_file) + + return normalized_files + + # ----------------------------- # In-memory idempotency and rate limit placeholders # ----------------------------- @@ -111,6 +302,12 @@ def 
_build_idempotency_key(*parts: Any) -> str: return ":".join(processed) +def _build_title_update_idempotency_key(tenant_id: str, conversation_id: int, title: str) -> str: + """Build an ASCII-safe idempotency key for title updates.""" + title_hash = hashlib.sha256(title.encode("utf-8")).hexdigest() + return _build_idempotency_key(tenant_id, str(conversation_id), title_hash) + + # ----------------------------- # Agent resolver # ----------------------------- @@ -126,7 +323,9 @@ async def start_streaming_chat( conversation_id: Optional[int], agent_name: str, query: str, + attachments: Optional[List[Any]] = None, meta_data: Optional[Dict[str, Any]] = None, + tool_params: Optional[ToolParamsRequest] = None, idempotency_key: Optional[str] = None ) -> StreamingResponse: try: @@ -145,6 +344,11 @@ async def start_streaming_chat( # Get history according to internal_conversation_id history_resp = await get_conversation_history_internal(ctx, internal_conversation_id) agent_id = await get_agent_id_by_name(agent_name=agent_name, tenant_id=ctx.tenant_id) + normalized_attachments = _normalize_northbound_attachments( + attachments=attachments, + user_id=ctx.user_id, + tenant_id=ctx.tenant_id, + ) # Idempotency: only prevent concurrent duplicate starts composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, str(conversation_id), agent_id, query) await idempotency_start(composed_key) @@ -153,8 +357,9 @@ async def start_streaming_chat( agent_id=agent_id, query=query, history=(history_resp.get("data", {})).get("history", []), - minio_files=None, + minio_files=normalized_attachments, is_debug=False, + tool_params=tool_params, ) # Synchronously persist the user message before starting the stream to avoid race conditions @@ -257,15 +462,58 @@ async def list_conversations(ctx: NorthboundContext) -> Dict[str, Any]: return {"message": "success", "data": conversations, "requestId": ctx.request_id} +def _format_search_record(record: Dict[str, Any]) -> Dict[str, Any]: + """Format 
a search source record for API response.""" + search_item = { + "title": record.get("source_title", ""), + "text": record.get("source_content", ""), + "source_type": record.get("source_type", ""), + "url": record.get("source_location", ""), + "filename": record.get("source_title", "") if record.get("source_type") == "file" else None, + "published_date": None, + "score": float(record["score_overall"]) if record.get("score_overall") is not None else None, + "tool_sign": record.get("tool_sign", ""), + "cite_index": record.get("cite_index") + } + + if record.get("published_date"): + if hasattr(record["published_date"], "strftime"): + search_item["published_date"] = record["published_date"].strftime("%Y-%m-%d") + else: + search_item["published_date"] = str(record["published_date"])[:10] + + return search_item + + async def get_conversation_history_internal(ctx: NorthboundContext, conversation_id: int) -> Dict[str, Any]: """Internal helper to get conversation history without logging.""" history = get_conversation_messages(conversation_id) - # Remove unnecessary fields result = [] for message in history: + # Parse minio_files from database (stored as JSON string) + minio_files = [] + raw_minio_files = message.get("minio_files") + if raw_minio_files: + try: + minio_files = json.loads(raw_minio_files) if isinstance(raw_minio_files, str) else raw_minio_files + except (json.JSONDecodeError, TypeError): + logger.warning(f"Failed to parse minio_files for message {message.get('message_id')}") + + # Fetch search results for this message + message_id = message.get("message_id") + search_results = [] + if message_id: + try: + search_records = get_source_searches_by_message(message_id, user_id=ctx.user_id) + search_results = [_format_search_record(r) for r in search_records] + except Exception as e: + logger.warning(f"Failed to get search records for message {message_id}: {str(e)}") + result.append({ "role": message["message_role"], - "content": message["message_content"] + 
"content": message["message_content"], + "minio_files": minio_files, + "search": search_results }) response = { @@ -284,7 +532,18 @@ async def get_conversation_history(ctx: NorthboundContext, conversation_id: int) async def get_agent_info_list(ctx: NorthboundContext) -> Dict[str, Any]: try: - agent_info_list = await list_all_agent_info_impl(tenant_id=ctx.tenant_id, user_id=ctx.user_id) + agent_info_list = await list_published_agents_impl( + tenant_id=ctx.tenant_id, + user_id=ctx.user_id, + ) + # Match the same scope as /agent/published_list: non-asset-owner tenants + # also get the asset owner's published agents merged in. + if ctx.tenant_id != ASSET_OWNER_TENANT_ID: + asset_agent_list = await list_published_agents_impl( + tenant_id=ASSET_OWNER_TENANT_ID, + user_id=ctx.user_id, + ) + agent_info_list.extend(asset_agent_list) # Remove internal information that partner don't need for agent_info in agent_info_list: agent_info.pop("agent_id", None) @@ -298,7 +557,11 @@ async def update_conversation_title(ctx: NorthboundContext, conversation_id: int composed_key: Optional[str] = None try: # Idempotency: avoid concurrent duplicate title update for same conversation - composed_key = idempotency_key or _build_idempotency_key(ctx.tenant_id, str(conversation_id), title) + composed_key = idempotency_key or _build_title_update_idempotency_key( + ctx.tenant_id, + conversation_id, + title, + ) await idempotency_start(composed_key) update_conversation_title_service(conversation_id, title, ctx.user_id) @@ -324,6 +587,8 @@ async def update_conversation_title(ctx: NorthboundContext, conversation_id: int } except LimitExceededError as _: raise LimitExceededError("Duplicate request is still running, please wait.") + except ConversationNotFoundError: + raise except Exception as e: raise Exception(f"Failed to update conversation title for conversation_id {conversation_id}: {str(e)}") finally: diff --git a/backend/services/prompt_service.py b/backend/services/prompt_service.py index 
ee9704302..f1564cdbc 100644 --- a/backend/services/prompt_service.py +++ b/backend/services/prompt_service.py @@ -1,15 +1,17 @@ import json import logging import queue +import sys import threading from typing import Optional, List from jinja2 import StrictUndefined, Template -from consts.const import LANGUAGE +from consts.const import LANGUAGE, ENABLE_JIUWEN_SDK from consts.error_code import ErrorCode from consts.error_message import ErrorMessage from consts.exceptions import AppException +from consts.model import AgentInfoRequest from database.agent_db import search_agent_info_by_agent_id, query_all_agent_info_by_tenant_id, \ query_sub_agents_id_list from database.model_management_db import get_model_by_model_id @@ -22,15 +24,31 @@ _regenerate_agent_name_with_llm, _regenerate_agent_display_name_with_llm, _generate_unique_agent_name_with_suffix, - _generate_unique_display_name_with_suffix + _generate_unique_display_name_with_suffix, + update_agent, ) from services.prompt_template_service import resolve_prompt_generate_template from utils.llm_utils import call_llm_for_system_prompt from utils.prompt_template_utils import ( - get_prompt_generate_prompt_template, get_prompt_optimize_prompt_template, + get_prompt_template, ) +from dataclasses import dataclass, field +from typing import Optional as Opt + +from adapters.exception import JiuwenSDKError, NexentCapabilityError + + +def _get_jiuwen_adapter_class(): + """Import Jiuwen adapter only when optimization paths need it.""" + try: + from adapters import JiuwenSDKAdapter + except ModuleNotFoundError: + return None + return JiuwenSDKAdapter + + # Configure logging logger = logging.getLogger("prompt_service") @@ -105,14 +123,16 @@ def generate_and_save_system_prompt_impl(agent_id: int, # Get knowledge base display names for few-shot examples # Priority: frontend-provided > database query if knowledge_base_display_names: - logger.debug(f"Using frontend-provided knowledge base display names: 
{knowledge_base_display_names}") + logger.debug( + f"Using frontend-provided knowledge base display names: {knowledge_base_display_names}") else: knowledge_base_display_names = get_knowledge_base_display_names( tool_info_list=tool_info_list, agent_id=agent_id, tenant_id=tenant_id ) - logger.debug(f"Using database query for knowledge base display names: {knowledge_base_display_names}") + logger.debug( + f"Using database query for knowledge base display names: {knowledge_base_display_names}") # Handle sub-agent IDs if sub_agent_ids and len(sub_agent_ids) > 0: @@ -146,7 +166,7 @@ def generate_and_save_system_prompt_impl(agent_id: int, # 1. Real-time streaming push final_results = {"duty": "", "constraint": "", "few_shots": "", "agent_var_name": "", "agent_display_name": "", - "agent_description": ""} + "agent_description": "", "greeting_message": "", "example_questions": ""} # Get all existing agent names and display names for duplicate checking (only if not in create mode) all_agents = query_all_agent_info_by_tenant_id(tenant_id) @@ -192,7 +212,8 @@ def generate_and_save_system_prompt_impl(agent_id: int, exclude_agent_id=agent_id, agents_cache=all_agents ): - logger.info(f"Agent name '{agent_name}' already exists, regenerating with LLM") + logger.info( + f"Agent name '{agent_name}' already exists, regenerating with LLM") try: agent_name = _regenerate_agent_name_with_llm( original_name=agent_name, @@ -206,10 +227,12 @@ def generate_and_save_system_prompt_impl(agent_id: int, prompt_template_id=prompt_template_id, user_id=user_id, ) - logger.info(f"Regenerated agent name: '{agent_name}'") + logger.info( + f"Regenerated agent name: '{agent_name}'") final_results["agent_var_name"] = agent_name except Exception as e: - logger.error(f"Failed to regenerate agent name with LLM: {str(e)}, using fallback") + logger.error( + f"Failed to regenerate agent name with LLM: {str(e)}, using fallback") # Fallback: add suffix agent_name = _generate_unique_agent_name_with_suffix( 
agent_name, @@ -235,7 +258,8 @@ def generate_and_save_system_prompt_impl(agent_id: int, exclude_agent_id=agent_id, agents_cache=all_agents ): - logger.info(f"Agent display_name '{agent_display_name}' already exists, regenerating with LLM") + logger.info( + f"Agent display_name '{agent_display_name}' already exists, regenerating with LLM") try: agent_display_name = _regenerate_agent_display_name_with_llm( original_display_name=agent_display_name, @@ -249,10 +273,12 @@ def generate_and_save_system_prompt_impl(agent_id: int, prompt_template_id=prompt_template_id, user_id=user_id, ) - logger.info(f"Regenerated agent display_name: '{agent_display_name}'") + logger.info( + f"Regenerated agent display_name: '{agent_display_name}'") final_results["agent_display_name"] = agent_display_name except Exception as e: - logger.error(f"Failed to regenerate agent display_name with LLM: {str(e)}, using fallback") + logger.error( + f"Failed to regenerate agent display_name with LLM: {str(e)}, using fallback") # Fallback: add suffix agent_display_name = _generate_unique_display_name_with_suffix( agent_display_name, @@ -285,6 +311,68 @@ def generate_and_save_system_prompt_impl(agent_id: int, if not has_content: raise Exception("Failed to generate prompt content.") + # 3. 
Generate greeting message and example questions + try: + greeting_template = get_prompt_template('greeting_generate', language) + greeting_system_prompt = greeting_template.get("GREETING_SYSTEM_PROMPT", "") + greeting_user_prompt_template = greeting_template.get("USER_PROMPT", "") + + greeting_user_prompt = Template(greeting_user_prompt_template, undefined=StrictUndefined).render({ + "display_name": final_results.get("agent_display_name", ""), + "duty_description": final_results.get("duty", ""), + "business_description": task_description, + "few_shots": final_results.get("few_shots", ""), + }) + + greeting_result = call_llm_for_system_prompt( + model_id=model_id, + user_prompt=greeting_user_prompt, + system_prompt=greeting_system_prompt, + tenant_id=tenant_id, + ) + + parsed = None + try: + json_start = greeting_result.find("{") + json_end = greeting_result.rfind("}") + 1 + if json_start >= 0 and json_end > json_start: + parsed = json.loads(greeting_result[json_start:json_end]) + except json.JSONDecodeError: + logger.warning(f"Failed to parse greeting JSON from LLM output: {greeting_result}") + + if parsed and "greeting_message" in parsed and "example_questions" in parsed: + greeting_message = parsed["greeting_message"] + example_questions = parsed["example_questions"] + if isinstance(example_questions, list) and len(example_questions) > 6: + example_questions = example_questions[:6] + else: + greeting_message = greeting_result.strip() if greeting_result else "" + example_questions = [] + + yield { + "type": "greeting_message", + "content": greeting_message, + "is_complete": True + } + yield { + "type": "example_questions", + "content": json.dumps(example_questions, ensure_ascii=False), + "is_complete": True + } + + final_results["greeting_message"] = greeting_message + final_results["example_questions"] = json.dumps(example_questions, ensure_ascii=False) + + # Update agent with greeting (skip in create mode) + if agent_id != 0: + update_agent(agent_id, 
AgentInfoRequest( + agent_id=agent_id, + greeting_message=greeting_message, + example_questions=example_questions, + ), user_id) + except Exception as e: + logger.warning(f"Greeting generation failed: {str(e)}, skipping greeting") + def optimize_prompt_section_impl( agent_id: int, model_id: int, @@ -339,7 +427,8 @@ def optimize_prompt_section_impl( prompt_context = join_info_for_optimize_prompt_section( prompt_for_optimize=prompt_template, section_type=normalized_section_type, - section_title=section_title or _default_prompt_section_title(normalized_section_type, language), + section_title=section_title or _default_prompt_section_title( + normalized_section_type, language), task_description=task_description, current_content=current_content, feedback=feedback, @@ -398,7 +487,8 @@ def generate_system_prompt(sub_agent_info_list, task_description, tool_info_list # If None or >= 6, no limit (all 6 calls run concurrently) # If < 6, use semaphore to limit concurrent calls model_config = get_model_by_model_id(model_id, tenant_id) - concurrency_limit = model_config.get("concurrency_limit") if model_config else None + concurrency_limit = model_config.get( + "concurrency_limit") if model_config else None # Start all generation threads with concurrency control threads, error_holder = _start_generation_threads( @@ -443,7 +533,8 @@ def _resolve_knowledge_base_display_names( agent_id=agent_id, tenant_id=tenant_id ) - logger.debug(f"Using database query for knowledge base display names: {resolved_names}") + logger.debug( + f"Using database query for knowledge base display names: {resolved_names}") return resolved_names @@ -471,8 +562,9 @@ def _resolve_prompt_generation_sub_agents( tenant_id=tenant_id, agent_id=agent_id ) + def _start_generation_threads(content, prompt_for_generate, produce_queue, latest, stop_flags, tenant_id, model_id, - has_selected_resources = True, concurrency_limit: Optional[int] = None): + has_selected_resources=True, concurrency_limit: Optional[int] = 
None): """Start all prompt generation threads with optional concurrency control.""" # Shared error tracking across threads error_holder = {"error": None} @@ -488,9 +580,11 @@ def _start_generation_threads(content, prompt_for_generate, produce_queue, lates effective_limit = concurrency_limit # Use semaphore if concurrency is limited - semaphore = threading.Semaphore(effective_limit) if effective_limit else None + semaphore = threading.Semaphore( + effective_limit) if effective_limit else None if semaphore: - logger.info(f"Using concurrency limit of {effective_limit} for prompt generation (total tasks: {total_tasks})") + logger.info( + f"Using concurrency limit of {effective_limit} for prompt generation (total tasks: {total_tasks})") else: logger.info("Using unlimited concurrency for prompt generation") @@ -539,7 +633,8 @@ def run_and_flag(tag, sys_prompt): ("few_shots", prompt_for_generate["few_shots_system_prompt"]), ]) else: - logger.info("Skipping constraint and few_shots generation: no tools or sub-agents selected") + logger.info( + "Skipping constraint and few_shots generation: no tools or sub-agents selected") # Mark these sections as already complete with empty content stop_flags["constraint"] = True stop_flags["few_shots"] = True @@ -638,13 +733,15 @@ def join_info_for_generate_system_prompt(prompt_for_generate, sub_agent_info_lis # This is necessary because Jinja2 StrictUndefined raises an error for any # undefined variable, even inside an {% if %} block. 
if knowledge_base_display_names: - kb_names_str = ", ".join(f'"{name}"' for name in knowledge_base_display_names) + kb_names_str = ", ".join( + f'"{name}"' for name in knowledge_base_display_names) else: kb_names_str = "" template_context["knowledge_base_names"] = kb_names_str # Generate content using template - content = Template(prompt_for_generate["user_prompt"], undefined=StrictUndefined).render(template_context) + content = Template( + prompt_for_generate["user_prompt"], undefined=StrictUndefined).render(template_context) return content @@ -672,7 +769,8 @@ def join_info_for_optimize_prompt_section( ) if knowledge_base_display_names: - kb_names_str = ", ".join(f'"{name}"' for name in knowledge_base_display_names) + kb_names_str = ", ".join( + f'"{name}"' for name in knowledge_base_display_names) else: kb_names_str = "" @@ -724,7 +822,8 @@ def get_knowledge_base_display_names(tool_info_list: List[dict], agent_id: int, List of knowledge base display names if knowledge_base_search tool is configured, None otherwise """ # Check if knowledge_base_search tool is in the list - kb_tool_ids = [tool['tool_id'] for tool in tool_info_list if tool.get('name') == 'knowledge_base_search'] + kb_tool_ids = [tool['tool_id'] for tool in tool_info_list if tool.get( + 'name') == 'knowledge_base_search'] if not kb_tool_ids: logger.debug("No knowledge_base_search tool found in tool list") return None @@ -747,19 +846,23 @@ def get_knowledge_base_display_names(tool_info_list: List[dict], agent_id: int, try: all_index_names.extend(json.loads(index_names)) except json.JSONDecodeError: - logger.warning(f"Failed to parse index_names JSON: {index_names}") + logger.warning( + f"Failed to parse index_names JSON: {index_names}") except Exception as e: - logger.warning(f"Failed to get tool instance for tool_id {kb_tool_id}: {e}") + logger.warning( + f"Failed to get tool instance for tool_id {kb_tool_id}: {e}") if not all_index_names: - logger.debug("No index_names configured for 
knowledge_base_search tool") + logger.debug( + "No index_names configured for knowledge_base_search tool") return None # Remove duplicates while preserving order unique_index_names = list(dict.fromkeys(all_index_names)) # Convert to display names - knowledge_name_map = get_knowledge_name_map_by_index_names(unique_index_names) + knowledge_name_map = get_knowledge_name_map_by_index_names( + unique_index_names) # Return list of display names (knowledge_name) for each configured index_name display_names = [] @@ -768,7 +871,8 @@ def get_knowledge_base_display_names(tool_info_list: List[dict], agent_id: int, if display_name and display_name not in display_names: display_names.append(display_name) - logger.debug(f"Converted index_names {unique_index_names} to display_names: {display_names}") + logger.debug( + f"Converted index_names {unique_index_names} to display_names: {display_names}") return display_names if display_names else None @@ -785,3 +889,299 @@ def get_enabled_sub_agent_description_for_generate_prompt(agent_id: int, tenant_ sub_agent_info_list.append(sub_agent_info) return sub_agent_info_list + + +# ── Jiuwen SDK 集成 ─────────────────────────────────────────────────────────── + + +@dataclass +class OptimizeRequest: + """优化请求的统一数据结构""" + agent_id: int + model_id: int + task_description: str + section_type: str + section_title: str + current_content: str + feedback: str + mode: str = "general" + start_pos: Opt[int] = None + end_pos: Opt[int] = None + tool_ids: Opt[list[int]] = None + sub_agent_ids: Opt[list[int]] = None + knowledge_base_display_names: Opt[list[str]] = None + + +@dataclass +class OptimizeResult: + """优化结果的统一数据结构""" + optimized_content: str + source: str + section_type: str = "" + section_title: str = "" + original_content: str = "" + + +class PromptOptimizationService: + """提示词优化服务 — 统一入口,模式二选一""" + + def optimize_from_debug(self, agent_id: int, feedback: str, selected, history=None) -> OptimizeResult: + """基于调试对话自动优化整个 system prompt(完整模板)。 + + 
Args: + selected: OptimizeFromDebugSelected (pydantic model) or any object with user_question/assistant_answer. + history: Optional[List[HistoryItem]] + """ + if not (feedback or "").strip(): + raise AppException( + ErrorCode.COMMON_MISSING_REQUIRED_FIELD, + "Optimization feedback is required.", + ) + + if not self.is_jiuwen_mode_available(): + raise NexentCapabilityError( + "Auto optimize from debug requires Jiuwen SDK to be enabled." + ) + + agent_info = search_agent_info_by_agent_id( + agent_id=agent_id, tenant_id=self.tenant_id, version_no=0) + + duty = (agent_info.get("duty_prompt") or "").strip() + constraint = (agent_info.get("constraint_prompt") or "").strip() + few_shots = (agent_info.get("few_shots_prompt") or "").strip() + + original_full_prompt = "\n\n".join( + [ + "# Duty\n" + duty, + "# Constraint\n" + constraint, + "# FewShots\n" + few_shots, + ] + ).strip() + + if not original_full_prompt: + raise AppException( + ErrorCode.COMMON_MISSING_REQUIRED_FIELD, + "Agent system prompt is empty.", + ) + + user_question = getattr(selected, "user_question", None) or ( + selected.get("user_question") if isinstance(selected, dict) else "") + assistant_answer = getattr(selected, "assistant_answer", None) or ( + selected.get("assistant_answer") if isinstance(selected, dict) else "") + + bad_case_obj = type("_BadCase", (), {}) + bc = bad_case_obj() + bc.question = user_question or "" + bc.answer = assistant_answer or "" + bc.label = "" + bc.reason = feedback + + adapter_cls = _get_jiuwen_adapter_class() + if adapter_cls is None: + raise JiuwenSDKError("Jiuwen SDK adapter is unavailable") + + adapter = adapter_cls( + model_id=self.model_id, tenant_id=self.tenant_id) + + optimized_full_prompt = adapter.optimize_badcase( + prompt=original_full_prompt, + bad_cases=[bc], + language=self.language, + ) + + return OptimizeResult( + optimized_content=optimized_full_prompt, + source="jiuwen", + section_type="full_prompt", + section_title="system_prompt", + 
original_content=original_full_prompt, + ) + + def __init__(self, model_id: int, tenant_id: str, language: str): + self.model_id = model_id + self.tenant_id = tenant_id + self.language = language + + def is_jiuwen_mode_available(self) -> bool: + """判断 Jiuwen SDK 模式是否可用""" + if not ENABLE_JIUWEN_SDK: + return False + + return _get_jiuwen_adapter_class() is not None + + def optimize(self, request: OptimizeRequest) -> OptimizeResult: + """统一优化入口 — 优先 Jiuwen SDK,失败则降级 nexent 原生""" + if self.is_jiuwen_mode_available(): + logger.info( + f"[prompt-optimize] mode={request.mode}, using Jiuwen SDK") + try: + return self._optimize_with_jiuwen(request) + except JiuwenSDKError as e: + logger.warning(f"Jiuwen SDK 模式失败,降级到 nexent 原生: {e}") + return self._optimize_with_nexent(request) + else: + return self._optimize_with_nexent(request) + + def _optimize_with_jiuwen(self, request: OptimizeRequest) -> OptimizeResult: + """Jiuwen SDK 模式""" + logger.info( + f"[jiuwen-optimize] mode={request.mode}, start_pos={request.start_pos}, " + f"end_pos={request.end_pos}, prompt_len={len(request.current_content)}, " + f"feedback_len={len(request.feedback)}" + ) + adapter_cls = _get_jiuwen_adapter_class() + if adapter_cls is None: + raise JiuwenSDKError("Jiuwen SDK adapter is unavailable") + + adapter = adapter_cls( + model_id=self.model_id, + tenant_id=self.tenant_id, + ) + result = adapter.optimize( + prompt=request.current_content, + feedback=request.feedback, + mode=request.mode, + start_pos=request.start_pos, + end_pos=request.end_pos, + language=self.language, + ) + + # Jiuwen insert/select mode returns a fragment by design. + # We reassemble the full prompt here so frontend always receives full optimized content. 
+ if request.mode == "insert": + if request.start_pos is None or not isinstance(request.start_pos, int): + raise JiuwenSDKError("insert mode requires start_pos") + if request.start_pos < 0 or request.start_pos > len(request.current_content): + raise JiuwenSDKError("insert mode start_pos out of bounds") + optimized_full = ( + request.current_content[: request.start_pos] + + result + + request.current_content[request.start_pos:] + ) + elif request.mode == "select": + if request.start_pos is None or request.end_pos is None: + raise JiuwenSDKError( + "select mode requires start_pos and end_pos") + if not isinstance(request.start_pos, int) or not isinstance(request.end_pos, int): + raise JiuwenSDKError( + "select mode start_pos/end_pos must be int") + if request.start_pos < 0 or request.end_pos < 0 or request.start_pos >= request.end_pos: + raise JiuwenSDKError("select mode start_pos/end_pos invalid") + if request.end_pos > len(request.current_content): + raise JiuwenSDKError("select mode end_pos out of bounds") + optimized_full = ( + request.current_content[: request.start_pos] + + result + + request.current_content[request.end_pos:] + ) + else: + optimized_full = result + + return OptimizeResult( + optimized_content=optimized_full, + source="jiuwen", + section_type=request.section_type, + section_title=request.section_title, + original_content=request.current_content, + ) + + def _optimize_with_nexent(self, request: OptimizeRequest) -> OptimizeResult: + """nexent 原生模式 — 只支持 general 模式""" + if request.mode != "general": + raise NexentCapabilityError( + f"nexent 原生模式只支持 general 模式," + f"当前请求 mode={request.mode} 不支持,请启用 Jiuwen SDK" + ) + + result = optimize_prompt_section_impl( + agent_id=request.agent_id, + model_id=self.model_id, + task_description=request.task_description, + tenant_id=self.tenant_id, + language=self.language, + section_type=request.section_type, + section_title=request.section_title, + current_content=request.current_content, + 
feedback=request.feedback, + tool_ids=request.tool_ids, + sub_agent_ids=request.sub_agent_ids, + knowledge_base_display_names=request.knowledge_base_display_names, + ) + return OptimizeResult( + optimized_content=result["optimized_content"], + source="nexent", + section_type=result["section_type"], + section_title=result["section_title"], + original_content=result["original_content"], + ) + + def optimize_badcase( + self, + current_content: str, + bad_cases: list, + agent_id: int, + section_type: str, + section_title: str, + tool_ids: Opt[list[int]] = None, + sub_agent_ids: Opt[list[int]] = None, + knowledge_base_display_names: Opt[list[str]] = None, + ) -> OptimizeResult: + """Bad-case optimization entry point: prefer the Jiuwen SDK; fall back to native nexent when it is unavailable or raises JiuwenSDKError.""" + if self.is_jiuwen_mode_available(): + logger.info("[prompt-badcase] using Jiuwen SDK") + try: + return self._optimize_badcase_with_jiuwen( + current_content, bad_cases, section_type, section_title + ) + except JiuwenSDKError as e: + logger.warning(f"Jiuwen SDK badcase 模式失败,降级到 nexent 原生: {e}") + return self._optimize_badcase_with_nexent( + current_content, bad_cases, agent_id, section_type, section_title, + tool_ids, sub_agent_ids, knowledge_base_display_names, + ) + else: + return self._optimize_badcase_with_nexent( + current_content, bad_cases, agent_id, section_type, section_title, + tool_ids, sub_agent_ids, knowledge_base_display_names, + ) + + def _optimize_badcase_with_jiuwen( + self, current_content: str, bad_cases: list, section_type: str, section_title: str + ) -> OptimizeResult: + """Bad-case optimization through the Jiuwen SDK adapter; raises JiuwenSDKError when the adapter is unavailable.""" + adapter_cls = _get_jiuwen_adapter_class() + if adapter_cls is None: + raise JiuwenSDKError("Jiuwen SDK adapter is unavailable") + + adapter = adapter_cls( + model_id=self.model_id, + tenant_id=self.tenant_id, + ) + result = adapter.optimize_badcase( + prompt=current_content, + bad_cases=bad_cases, + language=self.language, + ) + return OptimizeResult( + optimized_content=result, + source="jiuwen", + section_type=section_type, + 
section_title=section_title, + original_content=current_content, + ) + + def _optimize_badcase_with_nexent( + self, + current_content: str, + bad_cases: list, + agent_id: int, + section_type: str, + section_title: str, + tool_ids: Opt[list[int]] = None, + sub_agent_ids: Opt[list[int]] = None, + knowledge_base_display_names: Opt[list[str]] = None, + ) -> OptimizeResult: + """Native nexent mode does not support bad-case optimization; always raises NexentCapabilityError.""" + raise NexentCapabilityError( + "nexent 原生模式不支持 badcase 优化,请启用 Jiuwen SDK" + ) diff --git a/backend/services/remote_mcp_service.py b/backend/services/remote_mcp_service.py index 56a73fb4b..7e77a9c43 100644 --- a/backend/services/remote_mcp_service.py +++ b/backend/services/remote_mcp_service.py @@ -230,7 +230,7 @@ async def add_mcp_service( server_url: str, tags: list | None, authorization_token: str | None, - custom_headers: dict | None, + custom_headers: dict | None = None, container_config: dict | None, registry_json: dict | None, enabled: bool = False, diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index ba51567dc..3cbf5edc5 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -782,6 +782,8 @@ def _validate_local_tool( 'embedding_model': embedding_model, 'rerank_model': rerank_model, 'display_name_to_index_map': display_name_to_index_map, + # Internal access control: restrict results to specific document paths (path_or_urls) + 'document_paths': instantiation_params.get('document_paths'), } tool_instance = tool_class(**params) elif tool_name in ["dify_search", "datamate_search"]: @@ -982,6 +984,7 @@ def import_openapi_service( tenant_id: str, user_id: str, service_description: str = None, + headers_template: Dict[str, Any] = None, force_update: bool = False ) -> Dict[str, Any]: """ @@ -995,6 +998,7 @@ def import_openapi_service( tenant_id: Tenant ID for multi-tenancy user_id: User ID for audit service_description: Optional service description (if 
not provided, reads from openapi_json.info.description) + headers_template: Optional default headers template force_update: If True, replace all existing tools for this service Returns: @@ -1015,7 +1019,8 @@ def import_openapi_service( server_url=server_url, tenant_id=tenant_id, user_id=user_id, - description=service_description + description=service_description, + headers_template=headers_template, ) logger.info(f"Imported service '{service_name}' for tenant {tenant_id}") diff --git a/backend/services/user_management_service.py b/backend/services/user_management_service.py index a983b25d3..0b38a76bc 100644 --- a/backend/services/user_management_service.py +++ b/backend/services/user_management_service.py @@ -18,6 +18,7 @@ get_supabase_admin_client, calculate_expires_at, get_jwt_expiry_seconds, + ensure_cas_session_active_from_authorization, resolve_tenant_id_from_user_tenant_record, ) from consts.const import ( @@ -107,6 +108,7 @@ def validate_token(token: str) -> Tuple[bool, Optional[Any]]: try: user = get_current_user_from_client(client, token) if user: + ensure_cas_session_active_from_authorization(token) return True, user return False, None except Exception as e: diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index 11c5fd9bf..dd2f6e51a 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -10,6 +10,7 @@ 4. 
Health check interface """ import asyncio +import hashlib import json import logging import os @@ -28,7 +29,7 @@ from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE, PERMISSION_EDIT, PERMISSION_READ, ASSET_OWNER_TENANT_ID from consts.model import ChunkCreateRequest, ChunkUpdateRequest -from database.attachment_db import delete_file, get_file_stream +from database.attachment_db import delete_file, file_exists, get_file_stream from database.knowledge_db import ( create_knowledge_record, delete_knowledge_record, @@ -353,15 +354,18 @@ def get_embedding_model( tenant_id: Tenant ID model_name: Optional display name of the embedding model to use. If provided, will find the model by display_name in the tenant's model list. + model_type: Optional model type filter. When model_name is omitted, queries tenant + model records by this type; when model_type is also omitted, prefers + embedding models, then multi_embedding models. Returns: Tuple of (embedding model instance or None, model_id or None) """ if model_name: try: - normalized_model_type = _normalize_model_type(model_type) - if normalized_model_type: - model = get_model_by_display_name(model_name, tenant_id, normalized_model_type) + model_type = _normalize_model_type(model_type) + if model_type: + model = get_model_by_display_name(model_name, tenant_id, model_type) else: model = get_model_by_display_name(model_name, tenant_id) @@ -372,8 +376,25 @@ def get_embedding_model( return _create_embedding_model(model), model.get("model_id") except Exception as e: logger.warning(f"Failed to get embedding model by name {model_name}: {e}") + else: + try: + if model_type: + records = get_model_records({"model_type": model_type}, tenant_id) + else: + records = get_model_records({"model_type": "embedding"}, tenant_id) + if not records: + records = get_model_records({"model_type": "multi_embedding"}, tenant_id) + + if records: + model = records[0] + if model.get("model_type") in 
["embedding", "multi_embedding"]: + return _create_embedding_model(model), model.get("model_id") + logger.warning( + f"Resolved model is not an embedding model: {model.get('model_type')}" + ) + except Exception as e: + logger.warning(f"Failed to get default embedding model for tenant {tenant_id}: {e}") - # No default fallback - return None, None when no model is specified or found return None, None @@ -636,6 +657,7 @@ def create_knowledge_base( group_ids: Optional[List[int]] = None, embedding_model_name: Optional[str] = None, is_multimodal: Optional[bool] = None, + preserve_source_file: Optional[bool] = None, ): """ Create a new knowledge base with a user-facing name and an internal Elasticsearch index name. @@ -655,6 +677,8 @@ def create_knowledge_base( group_ids: List of group IDs (optional) embedding_model_name: Specific embedding model name to use (optional). If provided, will use this model instead of tenant default. + preserve_source_file: Whether to preserve uploaded source documents after + vectorization (optional; defaults to True when omitted). For backward compatibility, legacy callers can still use create_index() directly with an explicit index_name. 
@@ -694,6 +718,8 @@ def create_knowledge_base( knowledge_data["ingroup_permission"] = ingroup_permission if group_ids is not None: knowledge_data["group_ids"] = group_ids + if preserve_source_file is not None: + knowledge_data["preserve_source_file"] = preserve_source_file record_info = create_knowledge_record(knowledge_data) index_name = record_info["index_name"] @@ -1091,6 +1117,7 @@ def list_indices( # Auto-summary settings "summary_frequency": record.get("summary_frequency"), "last_summary_time": record.get("last_summary_time"), + "preserve_source_file": record.get("preserve_source_file", True), "stats": index_stats, }) @@ -1488,6 +1515,11 @@ async def list_files( # chunk_count is already set from ES aggregation (doc_count) file_data['chunk_count'] = file_data.get('chunk_count', 0) + for file_data in files: + file_data["source_available"] = ( + ElasticSearchService._compute_source_available(file_data) + ) + total_duration = time.time() - total_start_time logger.info(f"[list_files:complete] index={index_name}, total_files={len(files)}, " f"total_duration={total_duration:.3f}s") @@ -1498,6 +1530,100 @@ async def list_files( raise Exception( f"Error getting file list for index {index_name}: {str(e)}") + DOCUMENT_DELETE_SCOPES = ("source_only", "full") + + @staticmethod + def _preview_pdf_cache_object_name(object_name: str) -> str: + """Object key for Office-to-PDF preview cache (matches file_management_service).""" + name_without_ext = ( + object_name.rsplit(".", 1)[0] if "." 
in object_name else object_name + ) + hash_suffix = hashlib.md5(object_name.encode()).hexdigest()[:8] + return f"preview/converted/{name_without_ext}_{hash_suffix}.pdf" + + @staticmethod + def _compute_source_available(file_data: Dict[str, Any]) -> bool: + path_or_url = file_data.get("path_or_url") or "" + status = file_data.get("status", "") + if status != "COMPLETED": + return True + if path_or_url.startswith("knowledge_base/"): + return file_exists(path_or_url) + return True + + @staticmethod + def delete_source_file(path_or_url: str) -> Dict[str, Any]: + """Remove MinIO source (and preview cache); does not touch Elasticsearch.""" + minio_result = delete_file(path_or_url) + deleted_minio = bool(minio_result.get("success")) + + if path_or_url.startswith("knowledge_base/"): + preview_key = ElasticSearchService._preview_pdf_cache_object_name( + path_or_url + ) + try: + if file_exists(preview_key): + delete_file(preview_key) + except Exception as exc: + logger.warning( + "Failed to delete preview cache for '%s': %s", + path_or_url, + exc, + ) + + return {"deleted_minio": deleted_minio} + + @staticmethod + async def _assert_source_only_deletable( + index_name: str, path_or_url: str + ) -> None: + celery_task_files = await get_all_files_status(index_name) + status_info = celery_task_files.get(path_or_url) + if not status_info or not isinstance(status_info, dict): + return + state = status_info.get("state") or "" + if state and state != "COMPLETED": + raise ValueError( + f"Cannot delete source file while document is in state '{state}'. " + "Wait until processing completes or use scope=full to remove the document." + ) + + @staticmethod + async def delete_document_by_scope( + index_name: str, + path_or_url: str, + scope: str, + vdb_core: VectorDatabaseCore, + ) -> Dict[str, Any]: + if scope not in ElasticSearchService.DOCUMENT_DELETE_SCOPES: + raise ValueError( + f"Invalid scope '{scope}'. 
" + f"Must be one of: {ElasticSearchService.DOCUMENT_DELETE_SCOPES}" + ) + + if scope == "source_only": + await ElasticSearchService._assert_source_only_deletable( + index_name, path_or_url + ) + minio_part = ElasticSearchService.delete_source_file(path_or_url) + return { + "status": "success", + "scope": scope, + "deleted_es_count": 0, + "deleted_minio": minio_part.get("deleted_minio", False), + "source_available": False, + "message": ( + "Source file deleted; index chunks and vectors preserved." + ), + } + + result = ElasticSearchService.delete_documents( + index_name, path_or_url, vdb_core + ) + result["scope"] = scope + result["source_available"] = False + return result + @staticmethod def delete_documents( index_name: str = Path(..., description="Name of the index"), diff --git a/backend/utils/auth_utils.py b/backend/utils/auth_utils.py index 04e81e6e3..a7194f050 100644 --- a/backend/utils/auth_utils.py +++ b/backend/utils/auth_utils.py @@ -326,16 +326,13 @@ def calculate_expires_at(token: Optional[str] = None) -> int: return int((datetime.now() + timedelta(seconds=expiry_seconds)).timestamp()) -def _extract_user_id_from_jwt_token(authorization: str) -> Optional[str]: +def _decode_jwt_token(authorization: str) -> dict: """ Extract user ID from JWT token after verifying signature and expiration. Args: authorization: Authorization header value - Returns: - Optional[str]: User ID, return None if parsing fails - Raises: UnauthorizedError: If token is invalid, expired, or signature verification fails """ @@ -355,17 +352,12 @@ def _extract_user_id_from_jwt_token(authorization: str) -> Optional[str]: # Decode and verify JWT (signature + expiration) # verify_aud=False: allow tokens with aud claim (e.g. 
test JWT, Supabase) without strict audience check - decoded = jwt.decode( + return jwt.decode( token, SUPABASE_JWT_SECRET, algorithms=["HS256"], options={"verify_exp": True, "verify_aud": False}, ) - - # Extract user ID from JWT claims - user_id = decoded.get("sub") - - return user_id except jwt.ExpiredSignatureError: logging.warning("Token expired") raise UnauthorizedError("Token has expired") @@ -378,10 +370,47 @@ def _extract_user_id_from_jwt_token(authorization: str) -> Optional[str]: except UnauthorizedError: raise except Exception as e: - logging.error(f"Failed to extract user ID from token: {str(e)}") + logging.error(f"Failed to decode token: {str(e)}") raise UnauthorizedError("Invalid or expired authentication token") +def _extract_user_id_from_jwt_token(authorization: str) -> Optional[str]: + """ + Extract user ID from JWT token after verifying signature and expiration. + """ + decoded = _decode_jwt_token(authorization) + return decoded.get("sub") + + +def extract_session_id_from_authorization(authorization: Optional[str]) -> Optional[str]: + """Extract the sid claim without enforcing token validity, for idempotent logout.""" + if not authorization: + return None + try: + token = ( + authorization.replace("Bearer ", "") + if authorization.startswith("Bearer ") + else authorization + ) + decoded = jwt.decode(token, options={"verify_signature": False}) + sid = decoded.get("sid") + return str(sid) if sid else None + except Exception: + return None + + +def ensure_cas_session_active_from_authorization(authorization: Optional[str]) -> None: + """Reject CAS-issued JWTs whose server-side session is expired or revoked.""" + session_id = extract_session_id_from_authorization(authorization) + if not session_id: + return + + from database.cas_session_db import is_cas_session_active + + if not is_cas_session_active(str(session_id)): + raise UnauthorizedError("CAS session has expired or been revoked") + + def get_current_user_id(authorization: Optional[str] = None) -> 
tuple[str, str]: """ Get current user ID and tenant ID from authorization token @@ -405,10 +434,13 @@ def get_current_user_id(authorization: Optional[str] = None) -> tuple[str, str]: raise UnauthorizedError("No authorization header provided") try: - user_id = _extract_user_id_from_jwt_token(authorization) + decoded = _decode_jwt_token(authorization) + user_id = decoded.get("sub") if not user_id: raise UnauthorizedError("Invalid or expired authentication token") + ensure_cas_session_active_from_authorization(authorization) + user_tenant_record = get_user_tenant_by_user_id(user_id) if user_tenant_record and user_tenant_record.get("tenant_id"): tenant_id = user_tenant_record["tenant_id"] @@ -421,6 +453,8 @@ def get_current_user_id(authorization: Optional[str] = None) -> tuple[str, str]: return user_id, tenant_id + except UnauthorizedError: + raise except Exception as e: logging.error(f"Failed to get user ID and tenant ID: {str(e)}") raise UnauthorizedError("Invalid or expired authentication token") @@ -472,7 +506,7 @@ def generate_test_jwt(user_id: str, expires_in: int = 3600) -> str: return jwt.encode(payload, MOCK_JWT_SECRET_KEY, algorithm="HS256") -def generate_session_jwt(user_id: str, expires_in: int = 3600) -> str: +def generate_session_jwt(user_id: str, expires_in: int = 3600, session_id: str = None) -> str: """Generate a signed JWT compatible with the existing auth verification flow.""" now = int(time.time()) payload = { @@ -483,6 +517,8 @@ def generate_session_jwt(user_id: str, expires_in: int = 3600) -> str: "exp": now + expires_in, "iss": SUPABASE_URL, } + if session_id: + payload["sid"] = session_id return jwt.encode(payload, SUPABASE_JWT_SECRET, algorithm="HS256") diff --git a/backend/utils/context_utils.py b/backend/utils/context_utils.py index 740bf66df..0c3af8915 100644 --- a/backend/utils/context_utils.py +++ b/backend/utils/context_utils.py @@ -8,7 +8,6 @@ allowing ContextManager to assemble them in the correct order. 
""" -from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional if TYPE_CHECKING: @@ -508,13 +507,12 @@ def _format_agent_fallback( return "- 当前没有可用的助手" if language == "zh" else "- No agents are currently available" -def _format_app_context(app_name: str, app_description: str, user_id: str, time_str: str) -> str: +def _format_app_context(app_name: str, app_description: str, user_id: str) -> str: """Format application context for system prompt injection.""" lines = [ f"Application: {app_name}", f"Description: {app_description}", f"Current user: {user_id}", - f"Current time: {time_str}", ] return "\n".join(lines) @@ -528,7 +526,6 @@ def _format_app_context(app_name: str, app_description: str, user_id: str, time_ def build_skeleton_header_component( app_name: str, app_description: str, - time_str: str, user_id: str, language: str = "zh", priority: int = 100, @@ -536,14 +533,17 @@ def build_skeleton_header_component( """Build SystemPromptComponent for the header section. Section: "### 基本信息" / "### Basic Information" - Content: Agent identity, app name/description, time, user_id + Content: Agent identity, app name/description, user_id. + Note: Current time is intentionally excluded from the system prompt so the + static system prefix can hit the LLM KV/prompt cache across requests. The + current time is injected on the user-message side instead (see CoreAgent.run). 
""" from nexent.core.agents.agent_model import SystemPromptComponent if language == "zh": - content = f"### 基本信息\n你是{app_name},{app_description},现在是{time_str},用户ID为{user_id}" + content = f"### 基本信息\n你是{app_name},{app_description},用户ID为{user_id}" else: - content = f"### Basic Information\nYou are {app_name}, {app_description}, it is {time_str} now" + content = f"### Basic Information\nYou are {app_name}, {app_description}" return SystemPromptComponent( content=content, @@ -611,6 +611,11 @@ def build_skeleton_execution_flow_component( lines.append(" - 注意运行的代码不会被用户看到,所以如果用户需要看到代码,你需要使用'代码'表达展示代码。") lines.append(" - **重要**:代码执行后,系统会返回 \"Observation:\" 标记的内容(这是真实的执行结果)。请基于这些真实结果继续下一步思考,**不要在代码执行前自行编造观察结果**。") lines.append("") + lines.append("3. 自验证:") + lines.append(" - 关键事件(工具调用、检索结果、代码执行、助手返回、准备最终回答)后,系统会进行显式自验证。") + lines.append(" - 如果自验证提示存在错误、证据不足、参数不完整或结果不可靠,必须优先修正、补充证据、重新调用工具,或清晰说明无法完成的部分。") + lines.append(" - 最终回答只有在自验证通过后才会展示给用户;如果系统返回 Verification feedback,请把它视为真实观察结果继续修正,不要忽略。") + lines.append("") lines.append("在思考结束后,当你认为可以回答用户问题,那么可以不生成代码,直接生成最终回答给到用户并停止循环。") lines.append("") lines.append("生成最终回答时,你需要遵循以下规范:") @@ -652,6 +657,11 @@ def build_skeleton_execution_flow_component( lines.append(" - Note that executed code is not visible to users. If users need to see the code, use 'code' for displaying code.") lines.append(" - **IMPORTANT**: After code execution, the system will return content with \"Observation:\" marker (this is the real execution result). Please continue your next thinking based on these real results. **Do NOT fabricate observation results before code execution.**") lines.append("") + lines.append("3. 
Self-verification:") + lines.append(" - After critical events (tool calls, retrieval results, code execution, agent handoffs, and final-answer preparation), the system may run explicit verification.") + lines.append(" - If verification reports errors, insufficient evidence, incomplete parameters, or unreliable results, you must repair the issue, gather more evidence, call tools again, or clearly state what cannot be completed.") + lines.append(" - The final answer is shown to the user only after verification passes. If the system returns Verification feedback, treat it as a real observation and continue revising.") + lines.append("") lines.append("After thinking, when you believe you can answer the user's question, you can generate a final answer directly to the user without generating code and stop the loop.") lines.append("") lines.append("When generating the final answer, you need to follow these specifications:") @@ -1112,7 +1122,6 @@ def build_context_components( few_shots: Optional[str] = None, app_name: Optional[str] = None, app_description: Optional[str] = None, - time_str: Optional[str] = None, user_id: Optional[str] = None, language: str = "zh", is_manager: bool = True, @@ -1167,7 +1176,6 @@ def build_context_components( few_shots: Example templates text app_name: Application name app_description: Application description - time_str: Current time string user_id: Current user ID language: Language code ('zh' or 'en') is_manager: Whether this is a manager agent @@ -1188,12 +1196,11 @@ def build_context_components( components: List = [] # 1. 
Header - if app_name and app_description and time_str and user_id: + if app_name and app_description and user_id: components.append( build_skeleton_header_component( app_name=app_name, app_description=app_description, - time_str=time_str, user_id=user_id, language=language, ) @@ -1328,5 +1335,4 @@ def build_app_context_string( Returns: Formatted app context string """ - time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - return _format_app_context(app_name, app_description, user_id, time_str) \ No newline at end of file + return _format_app_context(app_name, app_description, user_id) diff --git a/backend/utils/http_client_utils.py b/backend/utils/http_client_utils.py index 1c1d14af6..262c0a593 100644 --- a/backend/utils/http_client_utils.py +++ b/backend/utils/http_client_utils.py @@ -8,6 +8,7 @@ def create_httpx_client( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None, + **kwargs, ) -> AsyncClient: return AsyncClient( headers=headers, @@ -15,4 +16,5 @@ def create_httpx_client( auth=auth, trust_env=False, verify=False, + **kwargs, ) diff --git a/backend/utils/memory_utils.py b/backend/utils/memory_utils.py index ada7019a1..e3ba01d6d 100644 --- a/backend/utils/memory_utils.py +++ b/backend/utils/memory_utils.py @@ -1,4 +1,5 @@ import logging +import re from typing import Dict, Any from urllib.parse import urlparse @@ -9,6 +10,11 @@ logger = logging.getLogger("memory_utils") +def _sanitize_index_component(value: str) -> str: + """Convert arbitrary text into an Elasticsearch-safe index component.""" + return re.sub(r"[^a-z0-9_.-]", "_", value.lower()) + + def build_memory_config(tenant_id: str) -> Dict[str, Any]: """Return a fully-validated configuration dictionary for *mem0* ``Memory``. 
""" @@ -30,9 +36,8 @@ def build_memory_config(tenant_id: str) -> Dict[str, Any]: es_host = f"{parsed.scheme}://{parsed.hostname}" es_port = parsed.port # Normalize repo/name to avoid problematic characters in index names - safe_repo = embed_raw["model_repo"].lower().replace( - "/", "_") if embed_raw["model_repo"] else "" - safe_name = embed_raw["model_name"].lower().replace("/", "_") + safe_repo = _sanitize_index_component(embed_raw["model_repo"]) if embed_raw["model_repo"] else "" + safe_name = _sanitize_index_component(embed_raw["model_name"]) index_name = ( f"mem0_{safe_repo}_{safe_name}_{embed_raw['max_tokens']}" if embed_raw["model_repo"] @@ -73,4 +78,4 @@ def build_memory_config(tenant_id: str) -> Dict[str, Any]: }, "telemetry": {"enabled": False}, } - return memory_config \ No newline at end of file + return memory_config diff --git a/backend/utils/prompt_template_utils.py b/backend/utils/prompt_template_utils.py index 8822e5fd4..299d3bf94 100644 --- a/backend/utils/prompt_template_utils.py +++ b/backend/utils/prompt_template_utils.py @@ -99,6 +99,10 @@ def get_prompt_template(template_type: str, language: str = LANGUAGE["ZH"], **kw LANGUAGE["ZH"]: 'backend/prompts/utils/generate_title_zh.yaml', LANGUAGE["EN"]: 'backend/prompts/utils/generate_title_en.yaml' }, + 'greeting_generate': { + LANGUAGE["ZH"]: 'backend/prompts/utils/greeting_generate_zh.yaml', + LANGUAGE["EN"]: 'backend/prompts/utils/greeting_generate_en.yaml' + }, 'document_summary': { LANGUAGE["ZH"]: 'backend/prompts/document_summary_agent_zh.yaml', LANGUAGE["EN"]: 'backend/prompts/document_summary_agent_en.yaml' diff --git a/doc/docs/en/quick-start/installation.md b/doc/docs/en/quick-start/installation.md index 0b1544819..7b6a9cb76 100644 --- a/doc/docs/en/quick-start/installation.md +++ b/doc/docs/en/quick-start/installation.md @@ -273,6 +273,114 @@ Provider enablement rules: For local Docker, a GitHub callback example is `http://localhost:3000/api/user/oauth/callback?provider=github`. 
In production, use a public HTTPS domain such as `https://nexent.example.com/api/user/oauth/callback?provider=github` and register the exact same URL in the OAuth provider console. +### CAS Login Configuration + +CAS SSO does not require the `supabase` component. Set `CAS_CALLBACK_BASE_URL` to the browser-accessible Nexent Web URL without a trailing `/`. `CAS_SERVER_URL` is the CAS Server root URL and should also not include a trailing `/`. + +For Docker, configure CAS in `docker/.env`: + +```bash +CAS_ENABLED=true +CAS_SERVER_URL=http://localhost:8080/cas +CAS_VALIDATE_PATH=/p3/serviceValidate +CAS_CALLBACK_BASE_URL=http://localhost:3000 + +# disabled: disable the CAS login entry and automatic redirects +# button: show CAS as an optional login button +# force: redirect unauthenticated Nexent users to CAS automatically +CAS_LOGIN_MODE=force + +# Empty means use <cas:user>; set userName to read <cas:userName> +CAS_USER_ATTRIBUTE= +CAS_EMAIL_ATTRIBUTE=email +CAS_ROLE_ATTRIBUTE=role +CAS_TENANT_ATTRIBUTE=tenant_id +CAS_ROLE_MAP_JSON={"cas-admin":"ADMIN","cas-user":"USER"} +CAS_SESSION_MAX_AGE_SECONDS=3600 +LOCAL_SESSION_MAX_AGE_SECONDS=3600 +CAS_RENEW_BEFORE_SECONDS=300 +CAS_RENEW_TIMEOUT_SECONDS=10 +CAS_SYNTHETIC_EMAIL_DOMAIN=cas.local + +# Empty means Nexent logout will not call the CAS Server logout endpoint. +# /logout is resolved against CAS_SERVER_URL. +CAS_LOGOUT_URL=/logout +CAS_SSL_VERIFY=true +CAS_CA_BUNDLE= +``` + +Common CAS URLs: + +| Purpose | URL | +|---------|-----| +| Nexent login entry | `{CAS_CALLBACK_BASE_URL}/api/user/cas/login?redirect=/` | +| CAS service callback | `{CAS_CALLBACK_BASE_URL}/api/user/cas/callback` | +| CAS silent renewal callback | `{CAS_CALLBACK_BASE_URL}/api/user/cas/renew_callback` | +| CAS single logout callback | `POST {CAS_CALLBACK_BASE_URL}/api/user/cas/logout_callback` | + +For Apereo CAS JSON Service Registry, create a service registration file such as `Nexent-10001.json` in the service registry directory configured by your CAS deployment.
The `id` must be globally unique. This is a local Docker example: + +```json +{ + "@class": "org.apereo.cas.services.RegexRegisteredService", + "serviceId": "http://localhost:3000.*", + "name": "Nexent CAS Client", + "id": 10001, + "description": "Nexent CAS SSO client", + "evaluationOrder": 1, + "logoutType": "BACK_CHANNEL", + "logoutUrl": "http://localhost:3000/api/user/cas/logout_callback" +} +``` + +In production, keep `CAS_SSL_VERIFY=true`; for self-signed certificates, prefer `CAS_CA_BUNDLE` and only use `CAS_SSL_VERIFY=false` for local testing. + +#### CAS Integration with ModelEngine + +When integrating with ModelEngine through the CAS protocol, deploy Nexent with the following configuration: + +```bash +CAS_ENABLED=true +CAS_SERVER_URL=https://<modelengine-host>:5443/SSOSvr +CAS_VALIDATE_PATH=/p3/serviceValidate +CAS_CALLBACK_BASE_URL=http://<nexent-host>:3000 +CAS_LOGIN_MODE=force +CAS_USER_ATTRIBUTE=userName +CAS_EMAIL_ATTRIBUTE=email +CAS_ROLE_ATTRIBUTE=userType +CAS_TENANT_ATTRIBUTE=tenant_id +CAS_ROLE_MAP_JSON={"1":"ADMIN","3":"DEV"} +CAS_SESSION_MAX_AGE_SECONDS=3600 +LOCAL_SESSION_MAX_AGE_SECONDS=3600 +CAS_RENEW_BEFORE_SECONDS=300 +CAS_RENEW_TIMEOUT_SECONDS=10 +CAS_SYNTHETIC_EMAIL_DOMAIN=cas.local +CAS_LOGOUT_URL=/logout?service=http://<nexent-host>:3000 +CAS_SSL_VERIFY=false +CAS_CA_BUNDLE= +``` + +You also need to add a CAS client service registration file in the OMS container. Use the following steps as a reference: + +```bash +# Create the registration file, paste the JSON content into it, and save it. +vim Nexent-10000001.json +{ + "@class": "org.apereo.cas.services.CasRegisteredService", + "serviceId": "http://<nexent-host>:3000.*", + "name": "Nexent CAS Client", + "id": 1000001, + "description": "Nexent CAS SSO client", + "evaluationOrder": 1, + "logoutType": "BACK_CHANNEL", + "logoutUrl": "http://<nexent-host>:3000/api/user/cas/logout_callback" +} + +# Run the following command to copy the registration file into the container.
+kubectl cp Nexent-10000001.json model-engine/$(kubectl get pods -n model-engine -l app=oms --no-headers | awk '{print $1}'):/opt/huawei/fce/apps/platform/webapps/SSOSvr/WEB-INF/classes/services/Nexent-10000001.json +kubectl exec -i -n model-engine $(kubectl get pods -n model-engine -l app=oms --no-headers | awk '{print $1}') -- chown tomcat:fusioncube /opt/huawei/fce/apps/platform/webapps/SSOSvr/WEB-INF/classes/services/Nexent-10000001.json +``` + ### Northbound Interface Configuration (NORTHBOUND_EXTERNAL_URL) If you need to use any of the following features, configure the `NORTHBOUND_EXTERNAL_URL` environment variable: diff --git a/doc/docs/en/quick-start/kubernetes-installation.md b/doc/docs/en/quick-start/kubernetes-installation.md index 8253c411f..a10873c7c 100644 --- a/doc/docs/en/quick-start/kubernetes-installation.md +++ b/doc/docs/en/quick-start/kubernetes-installation.md @@ -291,6 +291,122 @@ Provider callback URLs: For local NodePort, a GitHub callback example is `http://localhost:30000/api/user/oauth/callback?provider=github`. In production, use a public HTTPS domain and register the exact same URL in the OAuth provider console. +### CAS Login Configuration + +CAS SSO does not require the `supabase` component. Set `nexent-common.config.cas.callbackBaseUrl` to the browser-accessible Nexent Web URL without a trailing `/`. `nexent-common.config.cas.serverUrl` is the CAS Server root URL and should also not include a trailing `/`. 
+ +Kubernetes writes CAS settings into backend environment variables through `nexent-common` `config.cas.*` values: + +```bash +helm upgrade --install nexent nexent \ + --namespace nexent --create-namespace \ + --set nexent-common.config.cas.enabled=true \ + --set nexent-common.config.cas.serverUrl=https://cas.example.com/cas \ + --set nexent-common.config.cas.callbackBaseUrl=https://nexent.example.com \ + --set nexent-common.config.cas.loginMode=force \ + --set nexent-common.config.cas.logoutUrl=/logout +``` + +Configurable CAS values: + +| Value | Environment variable | Description | +|-------|----------------------|-------------| +| `nexent-common.config.cas.enabled` | `CAS_ENABLED` | Enables CAS | +| `nexent-common.config.cas.serverUrl` | `CAS_SERVER_URL` | CAS Server root URL | +| `nexent-common.config.cas.validatePath` | `CAS_VALIDATE_PATH` | serviceValidate path, default `/p3/serviceValidate` | +| `nexent-common.config.cas.callbackBaseUrl` | `CAS_CALLBACK_BASE_URL` | Web entry URL; CAS callback paths are appended automatically | +| `nexent-common.config.cas.loginMode` | `CAS_LOGIN_MODE` | `disabled`, `button`, or `force` | +| `nexent-common.config.cas.userAttribute` | `CAS_USER_ATTRIBUTE` | User identifier attribute. 
Empty means use `<cas:user>` | +| `nexent-common.config.cas.emailAttribute` | `CAS_EMAIL_ATTRIBUTE` | Email attribute | +| `nexent-common.config.cas.roleAttribute` | `CAS_ROLE_ATTRIBUTE` | Role attribute | +| `nexent-common.config.cas.tenantAttribute` | `CAS_TENANT_ATTRIBUTE` | Tenant attribute | +| `nexent-common.config.cas.roleMapJson` | `CAS_ROLE_MAP_JSON` | JSON mapping from CAS roles to Nexent roles | +| `nexent-common.config.cas.sessionMaxAgeSeconds` | `CAS_SESSION_MAX_AGE_SECONDS` | Maximum local CAS session lifetime | +| `nexent-common.config.cas.localSessionMaxAgeSeconds` | `LOCAL_SESSION_MAX_AGE_SECONDS` | Nexent local session lifetime | +| `nexent-common.config.cas.renewBeforeSeconds` | `CAS_RENEW_BEFORE_SECONDS` | Trigger silent renewal within this many seconds before expiry | +| `nexent-common.config.cas.renewTimeoutSeconds` | `CAS_RENEW_TIMEOUT_SECONDS` | Silent renewal timeout | +| `nexent-common.config.cas.syntheticEmailDomain` | `CAS_SYNTHETIC_EMAIL_DOMAIN` | Domain used when CAS does not return an email | +| `nexent-common.config.cas.logoutUrl` | `CAS_LOGOUT_URL` | CAS logout URL. Empty means Nexent logout will not call the CAS Server logout endpoint | +| `nexent-common.config.cas.sslVerify` | `CAS_SSL_VERIFY` | Whether to verify CAS Server TLS certificates | +| `nexent-common.config.cas.caBundle` | `CAS_CA_BUNDLE` | Custom CA bundle path | + +Common CAS URLs: + +| Purpose | URL | +|---------|-----| +| Nexent login entry | `{CAS_CALLBACK_BASE_URL}/api/user/cas/login?redirect=/` | +| CAS service callback | `{CAS_CALLBACK_BASE_URL}/api/user/cas/callback` | +| CAS silent renewal callback | `{CAS_CALLBACK_BASE_URL}/api/user/cas/renew_callback` | +| CAS single logout callback | `POST {CAS_CALLBACK_BASE_URL}/api/user/cas/logout_callback` | + +For Apereo CAS JSON Service Registry, create a service registration file such as `Nexent-10001.json` in the service registry directory configured by your CAS deployment. The `id` must be globally unique. 
This is a local NodePort example: + +```json +{ + "@class": "org.apereo.cas.services.RegexRegisteredService", + "serviceId": "http://localhost:30000.*", + "name": "Nexent CAS Client", + "id": 10001, + "description": "Nexent CAS SSO client", + "evaluationOrder": 1, + "logoutType": "BACK_CHANNEL", + "logoutUrl": "http://localhost:30000/api/user/cas/logout_callback" +} +``` + +In production, keep `CAS_SSL_VERIFY=true`; for self-signed certificates, prefer `CAS_CA_BUNDLE` and only use `CAS_SSL_VERIFY=false` for local testing. + +#### CAS Integration with ModelEngine + +When integrating with ModelEngine through the CAS protocol, use a values file to configure Nexent. This avoids complex command-line escaping for `CAS_ROLE_MAP_JSON`. + +Create `cas-modelengine-values.yaml`: + +```yaml +nexent-common: + config: + cas: + enabled: true + serverUrl: "https://:5443/SSOSvr" + validatePath: "/p3/serviceValidate" + callbackBaseUrl: "http://:30000" + loginMode: "force" + userAttribute: "userName" + emailAttribute: "email" + roleAttribute: "userType" + tenantAttribute: "tenant_id" + roleMapJson: '{"1":"ADMIN","3":"DEV"}' + sessionMaxAgeSeconds: 3600 + localSessionMaxAgeSeconds: 3600 + renewBeforeSeconds: 300 + renewTimeoutSeconds: 10 + syntheticEmailDomain: "cas.local" + logoutUrl: "/logout?service=http://:30000" + sslVerify: false + caBundle: "" +``` + +You also need to add a CAS client service registration file in the OMS container. Use the following steps as a reference: + +```bash +# Create the registration file, paste the JSON content into it, and save it. +vim Nexent-10000001.json +{ + "@class": "org.apereo.cas.services.CasRegisteredService", + "serviceId": "http://:30000.*", + "name": "Nexent CAS Client", + "id": 1000001, + "description": "Nexent CAS SSO client", + "evaluationOrder": 1, + "logoutType": "BACK_CHANNEL", + "logoutUrl": "http://:30000/api/user/cas/logout_callback" +} + +# Run the following command to copy the registration file into the container. 
+kubectl cp Nexent-10000001.json model-engine/$(kubectl get pods -n model-engine -l app=oms --no-headers | awk '{print $1}'):/opt/huawei/fce/apps/platform/webapps/SSOSvr/WEB-INF/classes/services/Nexent-10000001.json +kubectl exec -i -n model-engine $(kubectl get pods -n model-engine -l app=oms --no-headers | awk '{print $1}') -- chown tomcat:fusioncube /opt/huawei/fce/apps/platform/webapps/SSOSvr/WEB-INF/classes/services/Nexent-10000001.json +``` + ## 🔍 Troubleshooting ### Check Pod Status diff --git a/doc/docs/en/user-guide/agent-development.md b/doc/docs/en/user-guide/agent-development.md index 7637cd620..8e6b47d4f 100644 --- a/doc/docs/en/user-guide/agent-development.md +++ b/doc/docs/en/user-guide/agent-development.md @@ -111,6 +111,18 @@ In the External A2A Agent list, you can view and manage all discovered external > - Batch integrate all agents from the same service registry through Nacos discovery > - Configure protocols to meet the requirements of different agent service providers +###### Integrate [DataAgent](https://gitcode.com/datagallery/dataagent) A2A Agent via URL + +1. Refer to the [DataAgent documentation](https://gitcode.com/datagallery/dataagent#%F0%9F%8C%90-a2a-10-%E6%9C%8D%E5%8A%A1%E6%A8%A1%E5%BC%8F) and start DataAgent in A2A service mode. + > Nexent does not currently support agents that require authentication. Do not set `auth-token` when starting DataAgent. + +
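+ <img src="./assets/agent-development/dataagent_deploy.png" alt="DataAgent running in A2A service mode" /> +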
+ +2. Refer to [Discover Agent via URL](#discover-agent-via-url) to integrate the agent. The URL is `http://<dataagent-host>:9999/.well-known/agent-card.json`. +3. Refer to [Manage Discovered External Agents](#manage-discovered-external-agents) to configure the invocation protocol, and select HTTP + JSON for integration. + ### 🛠️ Select Agent Tools Agents can use various tools to complete tasks, such as knowledge base search, file parsing, image parsing, email sending/receiving, file management, and other local tools. They can also integrate third-party MCP tools or custom tools. diff --git a/doc/docs/en/user-guide/assets/agent-development/dataagent_deploy.png b/doc/docs/en/user-guide/assets/agent-development/dataagent_deploy.png new file mode 100644 index 000000000..46fa9fde3 Binary files /dev/null and b/doc/docs/en/user-guide/assets/agent-development/dataagent_deploy.png differ diff --git a/doc/docs/zh/quick-start/installation.md b/doc/docs/zh/quick-start/installation.md index 871cae0cc..6d3538b90 100644 --- a/doc/docs/zh/quick-start/installation.md +++ b/doc/docs/zh/quick-start/installation.md @@ -269,6 +269,111 @@ Provider 启用规则: 本地默认回调示例为 `http://localhost:3000/api/user/oauth/callback?provider=github`。生产环境应改为公网 HTTPS 域名,例如 `https://nexent.example.com/api/user/oauth/callback?provider=github`,并在 OAuth provider 控制台中登记相同地址。 +### CAS 登录配置 + +CAS SSO 不依赖 `supabase`。启用 CAS 时,请将 `CAS_CALLBACK_BASE_URL` 设置为浏览器可访问的 Nexent Web 地址,且不要带结尾 `/`。`CAS_SERVER_URL` 是 CAS Server 根地址,也不要带结尾 `/`。 + +Docker 部署在 `docker/.env` 中配置 CAS: + +```bash +CAS_ENABLED=true +CAS_SERVER_URL=http://localhost:8080/cas +CAS_VALIDATE_PATH=/p3/serviceValidate +CAS_CALLBACK_BASE_URL=http://localhost:3000 + +# disabled: 禁用 CAS 登录入口和自动跳转 +# button: 在登录页显示 CAS 登录按钮 +# force: 未登录访问 Nexent 时自动跳转到 CAS +CAS_LOGIN_MODE=force + +# 为空时使用 <cas:user>;填写 userName 时从 <cas:attributes> 取用户标识 +CAS_USER_ATTRIBUTE= +CAS_EMAIL_ATTRIBUTE=email +CAS_ROLE_ATTRIBUTE=role +CAS_TENANT_ATTRIBUTE=tenant_id +CAS_ROLE_MAP_JSON={"cas-admin":"ADMIN","cas-user":"USER"} 
+CAS_SESSION_MAX_AGE_SECONDS=3600 +LOCAL_SESSION_MAX_AGE_SECONDS=3600 +CAS_RENEW_BEFORE_SECONDS=300 +CAS_RENEW_TIMEOUT_SECONDS=10 +CAS_SYNTHETIC_EMAIL_DOMAIN=cas.local + +# 为空时 Nexent 主动退出不会调用 CAS Server 登出接口。 +# 可配置为 /logout,系统会基于 CAS_SERVER_URL 拼接。 +CAS_LOGOUT_URL=/logout +CAS_SSL_VERIFY=true +CAS_CA_BUNDLE= +``` + +常用 CAS 地址: + +| 用途 | 地址 | +|------|------| +| Nexent 登录入口 | `{CAS_CALLBACK_BASE_URL}/api/user/cas/login?redirect=/` | +| CAS service 回调 | `{CAS_CALLBACK_BASE_URL}/api/user/cas/callback` | +| CAS 无感续期回调 | `{CAS_CALLBACK_BASE_URL}/api/user/cas/renew_callback` | +| CAS 单点登出回调 | `POST {CAS_CALLBACK_BASE_URL}/api/user/cas/logout_callback` | + +Apereo CAS 使用 JSON Service Registry 时,可以新增一个服务注册文件,例如 `Nexent-10001.json`。文件需要放到 CAS 部署配置的 service registry 目录中,`id` 必须全局唯一。下面是本地 Docker 示例: + +```json +{ + "@class": "org.apereo.cas.services.RegexRegisteredService", + "serviceId": "http://localhost:3000.*", + "name": "Nexent CAS Client", + "id": 10001, + "description": "Nexent CAS SSO client", + "evaluationOrder": 1, + "logoutType": "BACK_CHANNEL", + "logoutUrl": "http://localhost:3000/api/user/cas/logout_callback" +} +``` + +生产环境建议保持 `CAS_SSL_VERIFY=true`;自签名证书优先配置 `CAS_CA_BUNDLE`,仅本地验证时再临时设置 `CAS_SSL_VERIFY=false`。 + +#### CAS对接ModelEngine +当使用CAS协议对接ModelEngine时,可以使用如下配置部署Nexent: +```bash +CAS_ENABLED=true +CAS_SERVER_URL=https://:5443/SSOSvr +CAS_VALIDATE_PATH=/p3/serviceValidate +CAS_CALLBACK_BASE_URL=http://:3000 +CAS_LOGIN_MODE=force +CAS_USER_ATTRIBUTE=userName +CAS_EMAIL_ATTRIBUTE=email +CAS_ROLE_ATTRIBUTE=userType +CAS_TENANT_ATTRIBUTE=tenant_id +CAS_ROLE_MAP_JSON={"1":"ADMIN","3":"DEV"} +CAS_SESSION_MAX_AGE_SECONDS=3600 +LOCAL_SESSION_MAX_AGE_SECONDS=3600 +CAS_RENEW_BEFORE_SECONDS=300 +CAS_RENEW_TIMEOUT_SECONDS=10 +CAS_SYNTHETIC_EMAIL_DOMAIN=cas.local +CAS_LOGOUT_URL=/logout?service=http://:3000 +CAS_SSL_VERIFY=false +CAS_CA_BUNDLE= +``` + +同时,需要进入oms容器添加cas client的注册配置文件,参考如下步骤: +```bash +# 创建注册配置文件,将json部分输入文件并保存 +vim Nexent-10000001.json +{ + "@class": 
"org.apereo.cas.services.CasRegisteredService", + "serviceId": "http://:3000.*", + "name": "Nexent CAS Client", + "id": 1000001, + "description": "Nexent CAS SSO client", + "evaluationOrder": 1, + "logoutType": "BACK_CHANNEL", + "logoutUrl": "http://:3000/api/user/cas/logout_callback" +} + +# 执行如下命令,将配置文件拷贝到容器中 +kubectl cp Nexent-10000001.json model-engine/$(kubectl get pods -n model-engine -l app=oms --no-headers | awk '{print $1}'):/opt/huawei/fce/apps/platform/webapps/SSOSvr/WEB-INF/classes/services/Nexent-10000001.json +kubectl exec -i -n model-engine $(kubectl get pods -n model-engine -l app=oms --no-headers | awk '{print $1}') -- chown tomcat:fusioncube /opt/huawei/fce/apps/platform/webapps/SSOSvr/WEB-INF/classes/services/Nexent-10000001.json +``` + ### 北向接口配置 (NORTHBOUND_EXTERNAL_URL) 如果您需要使用以下功能,需要配置 `NORTHBOUND_EXTERNAL_URL` 环境变量: diff --git a/doc/docs/zh/quick-start/kubernetes-installation.md b/doc/docs/zh/quick-start/kubernetes-installation.md index 47d2af816..7229f1ea8 100644 --- a/doc/docs/zh/quick-start/kubernetes-installation.md +++ b/doc/docs/zh/quick-start/kubernetes-installation.md @@ -291,6 +291,122 @@ Provider 回调地址: 本地 NodePort 默认回调示例为 `http://localhost:30000/api/user/oauth/callback?provider=github`。生产环境应改为公网 HTTPS 域名,并在 OAuth provider 控制台中登记相同地址。 +### CAS 登录配置 + +CAS SSO 不依赖 `supabase`。启用 CAS 时,请将 `nexent-common.config.cas.callbackBaseUrl` 设置为浏览器可访问的 Nexent Web 地址,且不要带结尾 `/`。`nexent-common.config.cas.serverUrl` 是 CAS Server 根地址,也不要带结尾 `/`。 + +Kubernetes 部署通过 `nexent-common` 的 `config.cas.*` values 写入后端环境变量: + +```bash +helm upgrade --install nexent nexent \ + --namespace nexent --create-namespace \ + --set nexent-common.config.cas.enabled=true \ + --set nexent-common.config.cas.serverUrl=https://cas.example.com/cas \ + --set nexent-common.config.cas.callbackBaseUrl=https://nexent.example.com \ + --set nexent-common.config.cas.loginMode=force \ + --set nexent-common.config.cas.logoutUrl=/logout +``` + +可配置的 CAS values: + +| Values | 对应环境变量 | 说明 
| +|--------|--------------|------| +| `nexent-common.config.cas.enabled` | `CAS_ENABLED` | 是否启用 CAS | +| `nexent-common.config.cas.serverUrl` | `CAS_SERVER_URL` | CAS Server 根地址 | +| `nexent-common.config.cas.validatePath` | `CAS_VALIDATE_PATH` | serviceValidate 路径,默认 `/p3/serviceValidate` | +| `nexent-common.config.cas.callbackBaseUrl` | `CAS_CALLBACK_BASE_URL` | Web 入口地址,CAS 回调路径会自动拼接 | +| `nexent-common.config.cas.loginMode` | `CAS_LOGIN_MODE` | `disabled`、`button` 或 `force` | +| `nexent-common.config.cas.userAttribute` | `CAS_USER_ATTRIBUTE` | 用户标识属性。为空时使用 `` | +| `nexent-common.config.cas.emailAttribute` | `CAS_EMAIL_ATTRIBUTE` | 邮箱属性 | +| `nexent-common.config.cas.roleAttribute` | `CAS_ROLE_ATTRIBUTE` | 角色属性 | +| `nexent-common.config.cas.tenantAttribute` | `CAS_TENANT_ATTRIBUTE` | 租户属性 | +| `nexent-common.config.cas.roleMapJson` | `CAS_ROLE_MAP_JSON` | CAS 角色到 Nexent 角色的 JSON 映射 | +| `nexent-common.config.cas.sessionMaxAgeSeconds` | `CAS_SESSION_MAX_AGE_SECONDS` | CAS 本地会话最长有效期 | +| `nexent-common.config.cas.localSessionMaxAgeSeconds` | `LOCAL_SESSION_MAX_AGE_SECONDS` | Nexent 本地会话有效期 | +| `nexent-common.config.cas.renewBeforeSeconds` | `CAS_RENEW_BEFORE_SECONDS` | 距离过期多少秒内触发无感续期 | +| `nexent-common.config.cas.renewTimeoutSeconds` | `CAS_RENEW_TIMEOUT_SECONDS` | 无感续期等待超时时间 | +| `nexent-common.config.cas.syntheticEmailDomain` | `CAS_SYNTHETIC_EMAIL_DOMAIN` | CAS 未返回邮箱时生成邮箱使用的域名 | +| `nexent-common.config.cas.logoutUrl` | `CAS_LOGOUT_URL` | CAS 登出地址。为空时 Nexent 主动退出不调用 CAS Server 登出接口 | +| `nexent-common.config.cas.sslVerify` | `CAS_SSL_VERIFY` | 访问 CAS Server 时是否校验证书 | +| `nexent-common.config.cas.caBundle` | `CAS_CA_BUNDLE` | 自定义 CA bundle 路径 | + +常用 CAS 地址: + +| 用途 | 地址 | +|------|------| +| Nexent 登录入口 | `{CAS_CALLBACK_BASE_URL}/api/user/cas/login?redirect=/` | +| CAS service 回调 | `{CAS_CALLBACK_BASE_URL}/api/user/cas/callback` | +| CAS 无感续期回调 | `{CAS_CALLBACK_BASE_URL}/api/user/cas/renew_callback` | +| CAS 单点登出回调 | `POST 
{CAS_CALLBACK_BASE_URL}/api/user/cas/logout_callback` | + +Apereo CAS 使用 JSON Service Registry 时,可以新增一个服务注册文件,例如 `Nexent-10001.json`。文件需要放到 CAS 部署配置的 service registry 目录中,`id` 必须全局唯一。本地 NodePort 示例: + +```json +{ + "@class": "org.apereo.cas.services.RegexRegisteredService", + "serviceId": "http://localhost:30000.*", + "name": "Nexent CAS Client", + "id": 10001, + "description": "Nexent CAS SSO client", + "evaluationOrder": 1, + "logoutType": "BACK_CHANNEL", + "logoutUrl": "http://localhost:30000/api/user/cas/logout_callback" +} +``` + +生产环境建议保持 `CAS_SSL_VERIFY=true`;自签名证书优先配置 `CAS_CA_BUNDLE`,仅本地验证时再临时设置 `CAS_SSL_VERIFY=false`。 + +#### CAS 对接 ModelEngine + +当使用 CAS 协议对接 ModelEngine 时,建议通过 values 文件配置 Nexent,避免 `CAS_ROLE_MAP_JSON` 在命令行中转义复杂。 + +创建 `cas-modelengine-values.yaml`: + +```yaml +nexent-common: + config: + cas: + enabled: true + serverUrl: "https://:5443/SSOSvr" + validatePath: "/p3/serviceValidate" + callbackBaseUrl: "http://:30000" + loginMode: "force" + userAttribute: "userName" + emailAttribute: "email" + roleAttribute: "userType" + tenantAttribute: "tenant_id" + roleMapJson: '{"1":"ADMIN","3":"DEV"}' + sessionMaxAgeSeconds: 3600 + localSessionMaxAgeSeconds: 3600 + renewBeforeSeconds: 300 + renewTimeoutSeconds: 10 + syntheticEmailDomain: "cas.local" + logoutUrl: "/logout?service=http://:30000" + sslVerify: false + caBundle: "" +``` + +同时,需要进入 OMS 容器添加 CAS client 的注册配置文件,参考如下步骤: + +```bash +# 创建注册配置文件,将 JSON 部分输入文件并保存 +vim Nexent-10000001.json +{ + "@class": "org.apereo.cas.services.CasRegisteredService", + "serviceId": "http://:30000.*", + "name": "Nexent CAS Client", + "id": 1000001, + "description": "Nexent CAS SSO client", + "evaluationOrder": 1, + "logoutType": "BACK_CHANNEL", + "logoutUrl": "http://:30000/api/user/cas/logout_callback" +} + +# 执行如下命令,将配置文件拷贝到容器中 +kubectl cp Nexent-10000001.json model-engine/$(kubectl get pods -n model-engine -l app=oms --no-headers | awk '{print 
$1}'):/opt/huawei/fce/apps/platform/webapps/SSOSvr/WEB-INF/classes/services/Nexent-10000001.json +kubectl exec -i -n model-engine $(kubectl get pods -n model-engine -l app=oms --no-headers | awk '{print $1}') -- chown tomcat:fusioncube /opt/huawei/fce/apps/platform/webapps/SSOSvr/WEB-INF/classes/services/Nexent-10000001.json +``` + ## 🔍 故障排查 ### 查看 Pod 状态 diff --git a/doc/docs/zh/sdk/vector-database.md b/doc/docs/zh/sdk/vector-database.md index 940af9c33..b940400fd 100644 --- a/doc/docs/zh/sdk/vector-database.md +++ b/doc/docs/zh/sdk/vector-database.md @@ -579,7 +579,11 @@ python -m nexent.service.vectordatabase_service - 参数: - `index_name`: 索引名称 (路径参数) - `path_or_url`: 文档路径或URL (查询参数) - - 返回示例: `{"status": "success", "deleted_count": 1}` + - `scope`: 删除范围 (查询参数,默认 `full`) + - `source_only`: 仅删除 MinIO 源文件,保留 ES 中的切片与向量(检索仍可用,预览不可用) + - `full`: 删除 ES 文档、MinIO 源文件,并清理相关 Redis 任务记录 + - 返回示例 (`source_only`): `{"status": "success", "scope": "source_only", "deleted_es_count": 0, "deleted_minio": true, "source_available": false}` + - 返回示例 (`full`): `{"status": "success", "scope": "full", "deleted_es_count": 5, "deleted_minio": true}` #### 搜索操作 @@ -728,8 +732,11 @@ curl -X POST "http://localhost:8000/indices/search/hybrid" \ "weight_accurate": 0.3 }' -# 删除文档 -curl -X DELETE "http://localhost:8000/indices/my_documents/documents?path_or_url=https://example.com/doc1" +# 删除源文件(保留索引) +curl -X DELETE "http://localhost:8000/indices/my_documents/documents?path_or_url=knowledge_base/doc1.pdf&scope=source_only" + +# 从知识库彻底移除文档 +curl -X DELETE "http://localhost:8000/indices/my_documents/documents?path_or_url=knowledge_base/doc1.pdf&scope=full" # 创建索引 curl -X POST "http://localhost:8000/indices/my_documents" diff --git a/doc/docs/zh/user-guide/agent-development.md b/doc/docs/zh/user-guide/agent-development.md index 3edf31de7..40805aeea 100644 --- a/doc/docs/zh/user-guide/agent-development.md +++ b/doc/docs/zh/user-guide/agent-development.md @@ -113,6 +113,17 @@ Nexent 支持通过 A2A 协议与第三方 
Agent 进行通信。您可以通过 > - 通过 Nacos 发现批量接入同一服务注册中心的所有 Agent > - 配置协议以兼容不同 Agent 服务提供商的要求 + +###### 通过URL对接[DataAgent](https://gitcode.com/datagallery/dataagent) A2A Agent +1. 参考[DataAgent文档](https://gitcode.com/datagallery/dataagent#%F0%9F%8C%90-a2a-10-%E6%9C%8D%E5%8A%A1%E6%A8%A1%E5%BC%8F)以A2A服务模式启动DataAgent + >当前Nexent不支持带认证的agent,启动DataAgent时请勿设置auth-token +
+ <img src="./assets/agent-development/dataagent_deploy.png" alt="以 A2A 服务模式启动 DataAgent" /> +
+ +2. 参考[通过 URL 发现 Agent](#通过-url-发现-agent)接入agent,url为http://\<dataagent-host\>:9999/.well-known/agent-card.json +3. 参考[管理已发现的外部 Agent](#管理已发现的外部-agent)配置调用协议,选择HTTP+JSON方式接入 + ### 🛠️ 选择智能体的工具 智能体可以使用各种工具来完成任务,如知识库检索、文件解析、图片解析、收发邮件、文件管理等本地工具,也可接入第三方 MCP 工具,或自定义工具。 diff --git a/doc/docs/zh/user-guide/assets/agent-development/dataagent_deploy.png b/doc/docs/zh/user-guide/assets/agent-development/dataagent_deploy.png new file mode 100644 index 000000000..46fa9fde3 Binary files /dev/null and b/doc/docs/zh/user-guide/assets/agent-development/dataagent_deploy.png differ diff --git a/doc/procedural-memory-verification.md b/doc/procedural-memory-verification.md new file mode 100644 index 000000000..ea9f53290 --- /dev/null +++ b/doc/procedural-memory-verification.md @@ -0,0 +1,315 @@ +# Procedural Memory Verification Report + +## Summary +**Status: ⚠️ FULLY SUPPORTED but REQUIRES OPTIONAL DEPENDENCY** + +Procedural memory is a fully implemented feature in mem0ai version 0.1.117, **BUT it requires `langchain-core` to be installed separately**. Without this dependency, the feature will fail at runtime. + +--- + +## ⚠️ CRITICAL FINDING: Optional Dependency Required + +**Your colleague is partially correct.** The procedural memory code is NOT empty (it's 50 lines of real implementation), but it has a critical dependency issue: + +### The Problem + +The `_create_procedural_memory()` method contains: + +```python +try: + from langchain_core.messages.utils import convert_to_messages +except Exception: + logger.error( + "Import error while loading langchain-core. " + "Please install 'langchain-core' to use procedural memory." + ) + raise # ← Fails here if langchain-core not installed +``` + +### Reality Check + +| Aspect | Status | +|--------|--------| +| Code exists? | ✅ Yes, 50 lines of real implementation | +| Code is empty/stub? | ❌ No, it's fully implemented | +| Works out of the box? | ❌ **NO** - requires `langchain-core` package | +| Documented requirement? 
| ⚠️ Only in error message, not in main docs | + +### Why Your Colleague Thought It Was Empty + +1. They called `memory.add(..., memory_type="procedural_memory")` +2. Got `ImportError: No module named 'langchain_core'` +3. Saw the error and concluded "it doesn't work" or "it's empty" +4. This is understandable - the feature exists but is **disabled by default** + +--- + +## Verification Results + +### 1. API Support ✅ +The `memory_type` parameter is available in both `AsyncMemory.add()` and `Memory.add()`: + +```python +async def add( + self, + messages, + *, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + run_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + infer: bool = True, + memory_type: Optional[str] = None, # ✅ SUPPORTED + prompt: Optional[str] = None, + llm=None +) +``` + +### 2. MemoryType Enum ✅ +Located in `mem0.configs.enums.MemoryType`: + +```python +class MemoryType(Enum): + SEMANTIC = "semantic_memory" + EPISODIC = "episodic_memory" + PROCEDURAL = "procedural_memory" # ✅ AVAILABLE +``` + +### 3. Implementation ✅ +The `_create_procedural_memory()` method exists in both `AsyncMemory` and `Memory` classes: + +**AsyncMemory signature:** +```python +async def _create_procedural_memory( + self, + messages, + metadata=None, + llm=None, + prompt=None +) +``` + +**Memory (sync) signature:** +```python +def _create_procedural_memory( + self, + messages, + metadata=None, + prompt=None +) +``` + +### 4. Validation Logic ✅ +The `add()` method validates `memory_type` and enforces constraints: + +```python +# Only "procedural_memory" is accepted +if memory_type is not None and memory_type != MemoryType.PROCEDURAL.value: + raise ValueError( + f"Invalid 'memory_type'. Please pass {MemoryType.PROCEDURAL.value} " + "to create procedural memories." 
+ ) + +# agent_id is REQUIRED for procedural memory +if agent_id is not None and memory_type == MemoryType.PROCEDURAL.value: + results = await self._create_procedural_memory( + messages, metadata=processed_metadata, prompt=prompt, llm=llm + ) + return results +``` + +### 5. System Prompt ✅ +A comprehensive 5,100-character system prompt exists in `mem0.configs.prompts.PROCEDURAL_MEMORY_SYSTEM_PROMPT`: + +**Purpose:** Records and preserves complete interaction history between human and AI agent + +**Structure:** +- Overview (Global Metadata) + - Task Objective + - Progress Status +- Sequential Agent Actions (Numbered Steps) + - Agent Action + - Action Result (Mandatory, Unmodified) + - Embedded Metadata (Key Findings, Navigation History, Errors, Current Context) + +**Key Guidelines:** +1. Preserve every output verbatim +2. Maintain chronological order +3. Include exact data (URLs, element indexes, error messages, JSON responses) +4. Output only the structured summary + +--- + +## Usage Example + +```python +from mem0 import AsyncMemory + +# Initialize memory +memory = await AsyncMemory.from_config(config) + +# Create procedural memory +messages = [ + {"role": "user", "content": "Search for AI news"}, + {"role": "assistant", "content": "I'll search for recent AI news..."}, + # ... 
more conversation history +] + +result = await memory.add( + messages=messages, + user_id="user_123", + agent_id="research_agent", # ⚠️ REQUIRED for procedural memory + memory_type="procedural_memory", + metadata={ + "task": "AI news research", + "session_id": "session_456" + } +) + +# Result format: +# { +# "results": [ +# { +# "id": "memory_id_here", +# "memory": "## Summary of the agent's execution history...", +# "event": "ADD" +# } +# ] +# } +``` + +--- + +## Requirements & Constraints + +### Required Parameters +- ✅ `agent_id`: **MUST** be provided when using `memory_type="procedural_memory"` +- ✅ `metadata`: **MUST** be provided (cannot be None) +- ✅ `messages`: List of conversation messages to summarize + +### Optional Parameters +- `prompt`: Custom prompt to override default `PROCEDURAL_MEMORY_SYSTEM_PROMPT` +- `llm`: Custom LangChain ChatModel (async version only) + +### Validation Rules +1. `memory_type` must be exactly `"procedural_memory"` (or None) +2. If `memory_type="procedural_memory"` is set, `agent_id` must be provided +3. `metadata` cannot be None for procedural memories + +--- + +## Implementation Details + +### How It Works +1. **Validation**: Checks `memory_type` and required parameters +2. **Prompt Construction**: Uses default or custom system prompt +3. **LLM Summarization**: Calls LLM to generate comprehensive execution summary +4. **Embedding**: Generates embedding for the summary +5. **Storage**: Stores in vector database with `metadata["memory_type"] = "procedural_memory"` +6. **Return**: Returns memory ID and summary text + +### Async vs Sync +- **AsyncMemory**: Supports custom LangChain `llm` parameter +- **Memory**: Uses internal LLM from config only + +--- + +## Integration with Nexent + +### Current Status +The Nexent codebase does **NOT** currently use procedural memory. The `memory_type` parameter is not passed in any `add_memory()` calls. + +### Recommended Integration Points + +1. 
**Agent Service** (`backend/services/agent_service.py`): + - Detect when agent completes a multi-step task + - Call `add_memory_in_levels()` with `memory_type="procedural_memory"` + - Pass the full conversation history as messages + +2. **Memory Service** (`sdk/nexent/memory/memory_service.py`): + - Add `memory_type` parameter to `add_memory()` and `add_memory_in_levels()` + - Pass through to mem0's `add()` method + +3. **Agent Run Info** (`sdk/nexent/core/agents/agent_model.py`): + - Add `memory_type` field to track if current run should create procedural memory + +### Example Integration + +```python +# In agent_service.py, after agent completes a complex task +if task_complexity >= threshold: # Your logic here + await add_memory_in_levels( + messages=conversation_history, + memory_config=memory_ctx.memory_config, + tenant_id=memory_ctx.tenant_id, + user_id=memory_ctx.user_id, + agent_id=memory_ctx.agent_id, + memory_levels=["agent", "user_agent"], + memory_type="procedural_memory", # ✅ NEW PARAMETER + metadata={ + "task_type": "complex_research", + "duration_seconds": duration, + "steps_completed": step_count + } + ) +``` + +--- + +## Conclusion + +Procedural memory is a **fully functional feature** in mem0ai==0.1.117, **BUT it requires an optional dependency**. It provides: + +- ✅ Complete API support +- ✅ Comprehensive system prompt (5,100 characters) +- ✅ Proper validation and error handling +- ✅ Both sync and async implementations +- ✅ Integration with existing memory infrastructure +- ⚠️ **REQUIRES `langchain-core` package to be installed** + +### The Truth About "Empty Function" Claims + +**The code is NOT empty.** It's a 50-line implementation that: +1. Calls LLM to generate execution summary +2. Creates embeddings +3. Stores in vector database +4. Returns proper results + +**However, it fails at runtime** if `langchain-core` is not installed, which is why your colleague might have thought it was a no-op. 
+ +### How to Enable + +**Option 1: Install the dependency** +```bash +pip install langchain-core +``` + +**Option 2: Add to Nexent's dependencies** +```toml +# In sdk/pyproject.toml +dependencies = [ + # ... existing deps ... + "langchain-core>=0.1.0", # Required for procedural memory +] +``` + +**Option 3: Make it optional with fallback** +```python +try: + result = await memory.add(..., memory_type="procedural_memory") +except ImportError as e: + if "langchain_core" in str(e) or "langchain-core" in str(e): + logger.warning("Procedural memory requires langchain-core. Using regular memory.") + result = await memory.add(...) # Fallback + else: + raise +``` + +### Final Recommendation + +This feature **can be integrated into Nexent**, but you must: +1. Add `langchain-core` to dependencies, OR +2. Implement graceful fallback when dependency is missing, OR +3. Document it as an optional feature requiring extra installation + +Without addressing the dependency issue, procedural memory will fail at runtime despite having complete implementation code. diff --git a/docker/.env.example b/docker/.env.example index c34300523..3970efb95 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -226,3 +226,27 @@ OAUTH_CALLBACK_BASE_URL=http://localhost:3000 # Asset owner role (opt-in; default false). Set true to enable ASSET_OWNER. ENABLE_ASSET_OWNER_ROLE=false + +# ===== CAS SSO Configuration ===== +CAS_ENABLED=false +CAS_SERVER_URL= +CAS_VALIDATE_PATH=/p3/serviceValidate +CAS_CALLBACK_BASE_URL=http://localhost:3000 +# Supported values: +# - disabled: disable CAS login entry and automatic CAS redirects. +# - button: show CAS as an optional login entry. +# - force: automatically redirect unauthenticated users to CAS login. 
+CAS_LOGIN_MODE=disabled +CAS_USER_ATTRIBUTE= +CAS_EMAIL_ATTRIBUTE=email +CAS_ROLE_ATTRIBUTE=role +CAS_TENANT_ATTRIBUTE=tenant_id +CAS_ROLE_MAP_JSON= +CAS_SESSION_MAX_AGE_SECONDS=3600 +LOCAL_SESSION_MAX_AGE_SECONDS=3600 +CAS_RENEW_BEFORE_SECONDS=300 +CAS_RENEW_TIMEOUT_SECONDS=10 +CAS_SYNTHETIC_EMAIL_DOMAIN=cas.local +CAS_LOGOUT_URL=/logout +CAS_SSL_VERIFY=true +CAS_CA_BUNDLE= diff --git a/docker/deploy.sh b/docker/deploy.sh index 2069330d1..fbf3664b5 100755 --- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -1367,7 +1367,7 @@ main_deploy() { echo "--------------------------------" echo "" - APP_VERSION="latest" + APP_VERSION="$(get_app_version)" if [ -z "$APP_VERSION" ]; then echo "❌ Failed to get app version, please check the backend/consts/const.py file" exit 1 diff --git a/docker/init.sql b/docker/init.sql index 0668def01..046bdecf1 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -230,6 +230,7 @@ CREATE TABLE IF NOT EXISTS "knowledge_record_t" ( "summary_frequency" varchar(10) COLLATE "pg_catalog"."default", "last_summary_time" timestamp(0), "last_doc_update_time" timestamp(0), + "preserve_source_file" boolean NOT NULL DEFAULT true, CONSTRAINT "knowledge_record_t_pk" PRIMARY KEY ("knowledge_id") ); ALTER TABLE "knowledge_record_t" OWNER TO "root"; @@ -251,6 +252,7 @@ COMMENT ON COLUMN "knowledge_record_t"."created_by" IS 'User who created the rec COMMENT ON COLUMN "knowledge_record_t"."summary_frequency" IS 'Auto-summary frequency: 1h, 3h, 6h, 1d, 1w, or NULL (disabled)'; COMMENT ON COLUMN "knowledge_record_t"."last_summary_time" IS 'Timestamp of last summary generation'; COMMENT ON COLUMN "knowledge_record_t"."last_doc_update_time" IS 'Timestamp of last document add/delete operation, used for auto-summary optimization to skip unnecessary summary regeneration'; +COMMENT ON COLUMN "knowledge_record_t"."preserve_source_file" IS 'Whether to preserve uploaded source documents after vectorization'; COMMENT ON COLUMN "knowledge_record_t"."updated_by" IS 'Last 
updater ID, audit field'; COMMENT ON COLUMN "knowledge_record_t"."created_by" IS 'Creator ID, audit field'; COMMENT ON TABLE "knowledge_record_t" IS 'Records knowledge base description and status information'; @@ -337,9 +339,12 @@ CREATE TABLE IF NOT EXISTS nexent.ag_tenant_agent_t ( is_new BOOLEAN DEFAULT FALSE, provide_run_summary BOOLEAN DEFAULT FALSE, enable_context_manager BOOLEAN DEFAULT FALSE, + verification_config JSONB, version_no INTEGER DEFAULT 0 NOT NULL, current_version_no INTEGER NULL, ingroup_permission VARCHAR(30), + greeting_message TEXT, + example_questions JSONB, create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, created_by VARCHAR(100), @@ -397,6 +402,9 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.version_no IS 'Version number. 0 = dr COMMENT ON COLUMN nexent.ag_tenant_agent_t.current_version_no IS 'Current published version number. NULL means no version published yet'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.ingroup_permission IS 'In-group permission: EDIT, READ_ONLY, PRIVATE'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.enable_context_manager IS 'Whether to enable context management (compression) for this agent'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.verification_config IS 'Layered ReAct self-verification configuration'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.greeting_message IS 'Agent greeting message displayed on chat initial screen'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.example_questions IS 'List of example questions for starting a conversation with this agent'; -- Create index for is_new queries CREATE INDEX IF NOT EXISTS idx_ag_tenant_agent_t_is_new @@ -715,6 +723,7 @@ CREATE TABLE IF NOT EXISTS nexent.ag_agent_relation_t ( parent_agent_id INTEGER, tenant_id VARCHAR(100), version_no INTEGER DEFAULT 0 NOT NULL, + selected_agent_version_no INTEGER, create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP 
WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, created_by VARCHAR(100), @@ -747,6 +756,7 @@ COMMENT ON COLUMN nexent.ag_agent_relation_t.selected_agent_id IS 'Selected agen COMMENT ON COLUMN nexent.ag_agent_relation_t.parent_agent_id IS 'Parent agent ID'; COMMENT ON COLUMN nexent.ag_agent_relation_t.tenant_id IS 'Tenant ID'; COMMENT ON COLUMN nexent.ag_agent_relation_t.version_no IS 'Version number. 0 = draft/editing state, >=1 = published snapshot'; +COMMENT ON COLUMN nexent.ag_agent_relation_t.selected_agent_version_no IS 'Pinned version of selected_agent_id. NULL = use child current published version at runtime (legacy/draft).'; COMMENT ON COLUMN nexent.ag_agent_relation_t.create_time IS 'Creation time, audit field'; COMMENT ON COLUMN nexent.ag_agent_relation_t.update_time IS 'Update time, audit field'; COMMENT ON COLUMN nexent.ag_agent_relation_t.created_by IS 'Creator ID, audit field'; @@ -1260,7 +1270,6 @@ CREATE TABLE IF NOT EXISTS nexent.ag_skill_info_t ( config_schemas JSON, config_values JSON, source VARCHAR(30) DEFAULT 'official', - tenant_id VARCHAR(100), created_by VARCHAR(100), create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_by VARCHAR(100), @@ -1900,3 +1909,31 @@ FOR EACH ROW EXECUTE FUNCTION update_mcp_community_record_update_time(); COMMENT ON TRIGGER update_mcp_community_record_update_time_trigger ON nexent.mcp_community_record_t IS 'Trigger to maintain update_time'; + +CREATE TABLE IF NOT EXISTS nexent.user_cas_session_t ( + cas_session_id SERIAL PRIMARY KEY, + session_id VARCHAR(100) NOT NULL UNIQUE, + user_id VARCHAR(100) NOT NULL, + cas_user_id VARCHAR(200) NOT NULL, + cas_session_index VARCHAR(500), + status VARCHAR(30) NOT NULL DEFAULT 'active', + expires_at TIMESTAMP NOT NULL, + revoked_at TIMESTAMP, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +CREATE INDEX IF NOT 
EXISTS ix_user_cas_session_session_id + ON nexent.user_cas_session_t (session_id); +CREATE INDEX IF NOT EXISTS ix_user_cas_session_user_id + ON nexent.user_cas_session_t (user_id); +CREATE INDEX IF NOT EXISTS ix_user_cas_session_cas_user_id + ON nexent.user_cas_session_t (cas_user_id); + +COMMENT ON TABLE nexent.user_cas_session_t IS 'Server-side session records for CAS SSO login and logout synchronization'; +COMMENT ON COLUMN nexent.user_cas_session_t.session_id IS 'JWT sid claim for revocation checks'; +COMMENT ON COLUMN nexent.user_cas_session_t.cas_user_id IS 'User identifier returned by CAS'; +COMMENT ON COLUMN nexent.user_cas_session_t.cas_session_index IS 'CAS SessionIndex or service ticket'; diff --git a/docker/official-skills-zip/create-docx.zip b/docker/official-skills-zip/create-docx.zip new file mode 100644 index 000000000..aa53e82b0 Binary files /dev/null and b/docker/official-skills-zip/create-docx.zip differ diff --git a/docker/sql/v2.2.0_0526_add_cas_session_t.sql b/docker/sql/v2.2.0_0526_add_cas_session_t.sql new file mode 100644 index 000000000..3f1aab4fa --- /dev/null +++ b/docker/sql/v2.2.0_0526_add_cas_session_t.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS nexent.user_cas_session_t ( + cas_session_id SERIAL PRIMARY KEY, + session_id VARCHAR(100) NOT NULL UNIQUE, + user_id VARCHAR(100) NOT NULL, + cas_user_id VARCHAR(200) NOT NULL, + cas_session_index VARCHAR(500), + status VARCHAR(30) NOT NULL DEFAULT 'active', + expires_at TIMESTAMP NOT NULL, + revoked_at TIMESTAMP, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N' +); + +CREATE INDEX IF NOT EXISTS ix_user_cas_session_session_id + ON nexent.user_cas_session_t (session_id); +CREATE INDEX IF NOT EXISTS ix_user_cas_session_user_id + ON nexent.user_cas_session_t (user_id); +CREATE INDEX IF NOT EXISTS ix_user_cas_session_cas_user_id + ON 
nexent.user_cas_session_t (cas_user_id); + +COMMENT ON TABLE nexent.user_cas_session_t IS 'Server-side session records for CAS SSO login and logout synchronization'; +COMMENT ON COLUMN nexent.user_cas_session_t.session_id IS 'JWT sid claim for revocation checks'; +COMMENT ON COLUMN nexent.user_cas_session_t.cas_user_id IS 'User identifier returned by CAS'; +COMMENT ON COLUMN nexent.user_cas_session_t.cas_session_index IS 'CAS SessionIndex or service ticket'; diff --git a/docker/sql/v2.2.1_0601_add_agent_verification_config.sql b/docker/sql/v2.2.1_0601_add_agent_verification_config.sql new file mode 100644 index 000000000..d3882e1e2 --- /dev/null +++ b/docker/sql/v2.2.1_0601_add_agent_verification_config.sql @@ -0,0 +1,7 @@ +-- Migration: Add layered ReAct self-verification config to agents +-- Description: Stores per-agent verification controls for step-level and final-answer validation. + +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS verification_config JSONB; + +COMMENT ON COLUMN nexent.ag_tenant_agent_t.verification_config IS 'Layered ReAct self-verification configuration'; diff --git a/docker/sql/v2.2.1_0601_add_preserve_source_file_to_knowledge_record_t.sql b/docker/sql/v2.2.1_0601_add_preserve_source_file_to_knowledge_record_t.sql new file mode 100644 index 000000000..30b588a51 --- /dev/null +++ b/docker/sql/v2.2.1_0601_add_preserve_source_file_to_knowledge_record_t.sql @@ -0,0 +1,8 @@ +-- Migration: Add preserve_source_file to knowledge_record_t table +-- Date: 2026-06-01 +-- Description: Whether to preserve uploaded source documents after vectorization (default: true) + +ALTER TABLE nexent.knowledge_record_t +ADD COLUMN IF NOT EXISTS preserve_source_file BOOLEAN NOT NULL DEFAULT true; + +COMMENT ON COLUMN nexent.knowledge_record_t.preserve_source_file IS 'Whether to preserve uploaded source documents after vectorization'; diff --git a/docker/sql/v2.2.1_0603_add_greeting_fields_to_ag_tenant_agent_t.sql 
b/docker/sql/v2.2.1_0603_add_greeting_fields_to_ag_tenant_agent_t.sql new file mode 100644 index 000000000..7786bb902 --- /dev/null +++ b/docker/sql/v2.2.1_0603_add_greeting_fields_to_ag_tenant_agent_t.sql @@ -0,0 +1,15 @@ +-- Migration: Add greeting_message and example_questions columns to ag_tenant_agent_t table +-- Date: 2026-06-03 +-- Description: Add greeting message and example questions fields for agent chat initial screen + +-- Add greeting_message column to ag_tenant_agent_t table +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS greeting_message TEXT; + +-- Add example_questions column to ag_tenant_agent_t table +ALTER TABLE nexent.ag_tenant_agent_t +ADD COLUMN IF NOT EXISTS example_questions JSONB; + +-- Add comments to the columns +COMMENT ON COLUMN nexent.ag_tenant_agent_t.greeting_message IS 'Agent greeting message displayed on chat initial screen'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.example_questions IS 'List of example questions for starting a conversation with this agent'; \ No newline at end of file diff --git a/docker/sql/v2.2.1_0605_add_ag_agent_repository_t.sql b/docker/sql/v2.2.1_0605_add_ag_agent_repository_t.sql new file mode 100644 index 000000000..d719fc5aa --- /dev/null +++ b/docker/sql/v2.2.1_0605_add_ag_agent_repository_t.sql @@ -0,0 +1,96 @@ +-- Migration: Add ag_agent_repository_t table +-- Date: 2026-06-05 +-- Description: Agent marketplace repository for frozen shareable agent snapshots. 
+ +SET search_path TO nexent; + +BEGIN; + +CREATE SEQUENCE IF NOT EXISTS nexent.ag_agent_repository_t_agent_repository_id_seq; + +CREATE TABLE IF NOT EXISTS nexent.ag_agent_repository_t ( + agent_repository_id BIGINT NOT NULL DEFAULT nextval('nexent.ag_agent_repository_t_agent_repository_id_seq'), + publisher_tenant_id VARCHAR(100) NOT NULL, + publisher_user_id VARCHAR(100) NOT NULL, + agent_id INTEGER NOT NULL, + source_version_no INTEGER NOT NULL, + name VARCHAR(100) NOT NULL, + display_name VARCHAR(100), + description TEXT, + author VARCHAR(100), + category_id INTEGER, + tags TEXT[], + tool_count INTEGER, + version_label VARCHAR(100), + agent_info_json JSONB NOT NULL, + status VARCHAR(30) DEFAULT 'NOT_SHARED', + create_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP, + created_by VARCHAR(100), + updated_by VARCHAR(100), + delete_flag VARCHAR(1) DEFAULT 'N', + CONSTRAINT ag_agent_repository_t_pkey PRIMARY KEY (agent_repository_id) +); + +ALTER SEQUENCE nexent.ag_agent_repository_t_agent_repository_id_seq + OWNED BY nexent.ag_agent_repository_t.agent_repository_id; + +ALTER TABLE nexent.ag_agent_repository_t OWNER TO root; + +COMMENT ON TABLE nexent.ag_agent_repository_t IS 'Agent marketplace repository for frozen shareable agent snapshots'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.agent_repository_id IS 'Agent repository listing ID, unique primary key'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.publisher_tenant_id IS 'Publisher tenant ID'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.publisher_user_id IS 'Publisher user ID'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.agent_id IS 'Root agent ID from ag_tenant_agent_t; upsert key with publisher_tenant_id'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.source_version_no IS 'Published version number frozen at share time'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.name IS 'Root agent programmatic name for 
display and search'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.display_name IS 'Root agent display name'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.description IS 'Root agent description'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.author IS 'Agent author'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.category_id IS 'Optional marketplace category ID'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.tags IS 'Marketplace tags'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.tool_count IS 'Total tool count across all agents in the bundle (display only)'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.version_label IS 'Repository entry version label for display (e.g. v1.0)'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.agent_info_json IS 'Frozen ExportAndImportDataFormat snapshot with optional skills'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.status IS 'Listing status: NOT_SHARED (未共享) / PENDING_REVIEW (待审核) / REJECTED (审核驳回) / SHARED (已共享)'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.create_time IS 'Creation time'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.update_time IS 'Update time'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.created_by IS 'Creator ID'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.updated_by IS 'Updater ID'; +COMMENT ON COLUMN nexent.ag_agent_repository_t.delete_flag IS 'Soft delete flag: Y/N'; + +CREATE UNIQUE INDEX IF NOT EXISTS uq_agent_repository_tenant_agent_active + ON nexent.ag_agent_repository_t (publisher_tenant_id, agent_id) + WHERE delete_flag = 'N'; + +CREATE INDEX IF NOT EXISTS idx_agent_repository_publisher_delete + ON nexent.ag_agent_repository_t (publisher_tenant_id, delete_flag); + +CREATE INDEX IF NOT EXISTS idx_agent_repository_status_delete + ON nexent.ag_agent_repository_t (status, delete_flag); + +CREATE INDEX IF NOT EXISTS idx_agent_repository_name_delete + ON nexent.ag_agent_repository_t (name, delete_flag); + +CREATE INDEX IF NOT EXISTS idx_agent_repository_tags_gin + ON 
nexent.ag_agent_repository_t USING GIN (tags); + +CREATE OR REPLACE FUNCTION update_ag_agent_repository_update_time() +RETURNS TRIGGER AS $$ +BEGIN + NEW.update_time = CURRENT_TIMESTAMP; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +COMMENT ON FUNCTION update_ag_agent_repository_update_time() IS 'Auto-update update_time for ag_agent_repository_t'; + +DROP TRIGGER IF EXISTS update_ag_agent_repository_update_time_trigger ON nexent.ag_agent_repository_t; +CREATE TRIGGER update_ag_agent_repository_update_time_trigger +BEFORE UPDATE ON nexent.ag_agent_repository_t +FOR EACH ROW +EXECUTE FUNCTION update_ag_agent_repository_update_time(); + +COMMENT ON TRIGGER update_ag_agent_repository_update_time_trigger ON nexent.ag_agent_repository_t IS 'Trigger to maintain update_time'; + +COMMIT; diff --git a/docker/sql/v2.2.1_0609_add_selected_agent_version_no_to_agent_relation_t.sql b/docker/sql/v2.2.1_0609_add_selected_agent_version_no_to_agent_relation_t.sql new file mode 100644 index 000000000..9a67c1ab2 --- /dev/null +++ b/docker/sql/v2.2.1_0609_add_selected_agent_version_no_to_agent_relation_t.sql @@ -0,0 +1,15 @@ +-- Migration: Add selected_agent_version_no to ag_agent_relation_t +-- Date: 2026-06-09 +-- Description: Pin child agent version on parent-child relations at publish time. + +SET search_path TO nexent; + +BEGIN; + +ALTER TABLE nexent.ag_agent_relation_t + ADD COLUMN IF NOT EXISTS selected_agent_version_no INTEGER; + +COMMENT ON COLUMN nexent.ag_agent_relation_t.selected_agent_version_no IS + 'Pinned version of selected_agent_id. 
NULL = use child current published version at runtime (legacy/draft).'; + +COMMIT; diff --git a/frontend/app/[locale]/agents/components/AgentConfigComp.tsx b/frontend/app/[locale]/agents/components/AgentConfigComp.tsx index 13484595f..1e750d5eb 100644 --- a/frontend/app/[locale]/agents/components/AgentConfigComp.tsx +++ b/frontend/app/[locale]/agents/components/AgentConfigComp.tsx @@ -29,6 +29,8 @@ export default function AgentConfigComp({}: AgentConfigCompProps) { const currentAgentId = useAgentConfigStore((state) => state.currentAgentId); const isCreatingMode = useAgentConfigStore((state) => state.isCreatingMode); const isReadOnly = useAgentConfigStore((state) => state.isReadOnly()); + const selectedTools = useAgentConfigStore((state) => state.editedAgent.tools); + const selectedSkills = useAgentConfigStore((state) => state.editedAgent.skills); const [isMcpModalOpen, setIsMcpModalOpen] = useState(false); const [isSkillModalOpen, setIsSkillModalOpen] = useState(false); @@ -125,7 +127,12 @@ export default function AgentConfigComp({}: AgentConfigCompProps) { - {t("toolPool.title")} + + {t("toolPool.title")} + {selectedTools.length > 0 && ( + + )} + {t("toolPool.tooltip.functionGuide")}} color="#ffffff" @@ -144,7 +151,14 @@ export default function AgentConfigComp({}: AgentConfigCompProps) { - {t("skillPool.title")} + + + {t("skillPool.title")} + {selectedSkills && selectedSkills.length > 0 && ( + + )} + + diff --git a/frontend/app/[locale]/agents/components/agentConfig/McpConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/McpConfigModal.tsx index 277e85d3d..41c8baa45 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/McpConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/McpConfigModal.tsx @@ -80,6 +80,7 @@ export default function McpConfigModal({ const [openApiJson, setOpenApiJson] = useState(""); const [openApiServiceName, setOpenApiServiceName] = useState(""); const [openApiServerUrl, setOpenApiServerUrl] = 
useState(""); + const [openApiHeadersTemplate, setOpenApiHeadersTemplate] = useState(""); const [importingOpenApi, setImportingOpenApi] = useState(false); const [openapiServices, setOpenapiServices] = useState([]); const [loadingOpenapiServices, setLoadingOpenapiServices] = useState(false); @@ -506,6 +507,7 @@ export default function McpConfigModal({ service_name: openApiServiceName.trim(), server_url: openApiServerUrl.trim(), openapi_json: parsedJson, + headers_template: openApiHeadersTemplate.trim() ? JSON.parse(openApiHeadersTemplate.trim()) : null, }), }); @@ -514,6 +516,7 @@ export default function McpConfigModal({ setOpenApiJson(""); setOpenApiServiceName(""); setOpenApiServerUrl(""); + setOpenApiHeadersTemplate(""); await loadOpenapiServices(); await refreshToolsAndAgents(); } else { @@ -1220,15 +1223,20 @@ export default function McpConfigModal({ style={{ flex: 3 }} /> -
- setOpenApiJson(e.target.value)} - rows={6} - disabled={actionsLocked || importingOpenApi} - /> -
+ setOpenApiHeadersTemplate(e.target.value)} + rows={2} + disabled={actionsLocked || importingOpenApi} + /> + setOpenApiJson(e.target.value)} + rows={6} + disabled={actionsLocked || importingOpenApi} + />
{ + const selectedCount = group.skills.filter(s => originalSelectedSkillIdsSet.has(s.skill_id)).length; + return { key: group.key, label: ( - - {group.label} + + + {group.label} + + {selectedCount > 0 && ( + + )} ), diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index 0cb73de62..62edc3ac8 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -4,7 +4,7 @@ import { useState, useEffect, useCallback } from "react"; import { useTranslation } from "react-i18next"; import ToolConfigModal from "./tool/ToolConfigModal"; import { ToolGroup, Tool, ToolParam } from "@/types/agentConfig"; -import { Tabs, Collapse, message, Tooltip } from "antd"; +import { Tabs, Collapse, message, Tooltip, Badge } from "antd"; import { useAgentConfigStore } from "@/stores/agentConfigStore"; import { useToolList } from "@/hooks/agent/useToolList"; import { usePrefetchKnowledgeBases } from "@/hooks/useKnowledgeBaseSelector"; @@ -307,21 +307,29 @@ export default function ToolManagement({ // Generate Tabs configuration const tabItems = toolGroups.map((group) => { const label = t(group.label); + const selectedCount = group.subGroups + ? 
group.subGroups.reduce( + (sum, sg) => sum + sg.tools.filter(t => originalSelectedToolIdsSet.has(t.id)).length, 0) + : group.tools.filter(t => originalSelectedToolIdsSet.has(t.id)).length; return { key: group.key, label: ( - - {label} + + + {label} + + {selectedCount > 0 && ( + + )} ), @@ -351,17 +359,25 @@ export default function ToolManagement({ items={group.subGroups.map((subGroup, index) => ({ key: subGroup.key, label: ( - - {subGroup.label} + + + {subGroup.label} + + {subGroup.tools.filter(t => originalSelectedToolIdsSet.has(t.id)).length > 0 && ( + originalSelectedToolIdsSet.has(t.id)).length} + size="small" + color="blue" + /> + )} ), className: `tool-category-panel ${ diff --git a/frontend/app/[locale]/agents/components/agentConfig/skill/SkillConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/skill/SkillConfigModal.tsx index 6f372e2b4..9729007e2 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/skill/SkillConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/skill/SkillConfigModal.tsx @@ -12,13 +12,13 @@ import { message, Tag, Skeleton, + Tooltip } from "antd"; import { Settings } from "lucide-react"; import { CloseOutlined } from "@ant-design/icons"; import { Skill, SkillParam } from "@/types/agentConfig"; import { KnowledgeBase } from "@/types/knowledgeBase"; -import { Tooltip } from "@/components/ui/tooltip"; import { saveSkillInstance } from "@/services/agentConfigService"; import KnowledgeBaseSelectorModal from "@/components/tool-config/KnowledgeBaseSelectorModal"; import { diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index 8b6cd82d7..cd46d2aa3 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect, useMemo, 
useRef, useCallback } from "react"; +import { useState, useEffect, useMemo, useRef } from "react"; import { useTranslation } from "react-i18next"; import { Button, @@ -17,9 +17,11 @@ import { } from "antd"; import { Tabs, TabsList, TabsTrigger, TabsContent } from "@/components/ui/tabs"; import { Zap, Maximize2, Settings2, Sparkles } from "lucide-react"; +import { Textarea } from "@/components/ui/textarea"; import { AgentConfigUpdate, + DEFAULT_AGENT_VERIFICATION_CONFIG, PromptTemplate, } from "@/types/agentConfig"; import { @@ -169,6 +171,7 @@ export default function AgentGenerateDetail({}) { constraintPrompt: editedAgent.constraint_prompt || "", fewShotsPrompt: editedAgent.few_shots_prompt || "", provideRunSummary: editedAgent.provide_run_summary || false, + verificationEnabled: editedAgent.verification_config?.enabled ?? false, businessDescription: editedAgent.business_description || "", businessLogicModelName:editedAgent.business_logic_model_name, businessLogicModelId: editedAgent.business_logic_model_id, @@ -233,6 +236,7 @@ export default function AgentGenerateDetail({}) { setOptimizeModalOpen(true); }; + const renderExpandButton = (type: "duty" | "constraint" | "few-shots") => { return (