diff --git a/edgeai/ondevice-eval-agent/.dockerignore b/edgeai/ondevice-eval-agent/.dockerignore new file mode 100644 index 00000000..d7222024 --- /dev/null +++ b/edgeai/ondevice-eval-agent/.dockerignore @@ -0,0 +1,11 @@ +**/__pycache__ +**/*.pyc +**/*.pyo +.git +.gitignore +.pytest_cache +.venv +venv +tests/ +frontend/node_modules +frontend/dist diff --git a/edgeai/ondevice-eval-agent/.gitignore b/edgeai/ondevice-eval-agent/.gitignore new file mode 100644 index 00000000..fbc1b0c2 --- /dev/null +++ b/edgeai/ondevice-eval-agent/.gitignore @@ -0,0 +1,19 @@ +# Python +__pycache__/ +*.py[cod] +*.egg-info/ +.pytest_cache/ +.venv/ +venv/ + +# Node / frontend +frontend/node_modules/ +frontend/dist/ +frontend/tsconfig.tsbuildinfo + +# Environment +.env +.env.local + +# OS +.DS_Store diff --git a/edgeai/ondevice-eval-agent/Dockerfile b/edgeai/ondevice-eval-agent/Dockerfile new file mode 100644 index 00000000..37e2b09d --- /dev/null +++ b/edgeai/ondevice-eval-agent/Dockerfile @@ -0,0 +1,62 @@ +# Build context: ondevice-eval-agent/ +# +# Multi-stage build: +# stage 1 (node) — compile the React SPA (frontend/ → dist/) +# stage 2 (python) — install deps, copy backend, drop SPA dist into webapp/spa/ +# +# The Python runtime serves both the API and the built SPA on :8080, +# so the whole app is a single image and a single port. + +# ---------- Stage 1: build the React SPA ---------- +# Pinned to BUILDPLATFORM so multi-arch builds compile the SPA natively +# (output is static JS/HTML/CSS — arch-neutral) instead of via qemu, which +# can crash esbuild. +FROM --platform=$BUILDPLATFORM node:20-alpine AS spa-builder + +# Pin pnpm to a version that still supports Node 20. Without this, corepack +# auto-fetches the newest pnpm (11+) which requires Node 22's built-in +# `node:sqlite` and crashes on cold install. 
+RUN corepack enable && corepack prepare pnpm@9.15.0 --activate + +WORKDIR /spa + +COPY frontend/package.json ./ +RUN pnpm install --no-frozen-lockfile + +COPY frontend/ ./ +# Same-origin API — the Flask server serves the SPA and the /agent/* routes. +ENV VITE_API_BASE="" +RUN pnpm build + + +# ---------- Stage 2: Python runtime ---------- +FROM python:3.11-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + libgl1 \ + libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY client/ client/ +COPY webapp/ webapp/ + +# Drop the built SPA where core.py (SPA_DIST) expects it. +COPY --from=spa-builder /spa/dist/ webapp/spa/ + +ENV MAX_STARTUP_WAIT=300 \ + HEALTH_CHECK_INTERVAL=10 \ + SPA_DIST=/app/webapp/spa \ + PYTHONUNBUFFERED=1 + +EXPOSE 8080 + +HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \ + CMD curl -fsS http://localhost:8080/agent/status || exit 1 + +CMD ["python", "webapp/app.py"] diff --git a/edgeai/ondevice-eval-agent/README.md b/edgeai/ondevice-eval-agent/README.md new file mode 100644 index 00000000..397fda76 --- /dev/null +++ b/edgeai/ondevice-eval-agent/README.md @@ -0,0 +1,296 @@ +# ZEDEDA On-Device AI Agent - Client Container + +Flask-based web application for ML model inference with an AI-powered assistant for model exploration and integration guidance. + +## Features + +- Web interface for image upload and classification +- Multi-model support with dynamic discovery +- Real-time processing logs +- API endpoints for programmatic access +- Customizable preprocessing +- **AI Agent for model exploration and integration guidance** + +## AI Agent (Agentic Demo POC) + +The business logic includes an intelligent AI assistant that helps developers understand and integrate with deployed ML models. 
+ +### Agent Capabilities + +| Capability | Description | +|------------|-------------| +| **Model Discovery** | Identifies available models on Triton/OpenVINO inference servers | +| **Input Requirements** | Explains image formats, preprocessing, and camera feed recommendations | +| **Output Interpretation** | Describes model outputs (bounding boxes, labels, masks) and post-processing | +| **Integration Guidance** | Provides code examples for JavaScript, Python, React, and cURL | + +### Example Questions + +- "What model is currently running on the server?" +- "How should I structure the frontend/client logic to call this model?" +- "What images or camera feed characteristics will this model respond to reliably?" +- "How do I interpret the bounding box outputs from this detection model?" +- "Show me how to preprocess images for this model" + +### Agent API Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/agent/chat` | POST | Send a message to the AI agent | +| `/agent/status` | GET | Check if agent is enabled | + +#### Chat Request Example + +```bash +curl -X POST http://localhost:8080/agent/chat \ + -H "Content-Type: application/json" \ + -d '{"message": "What model is running?", "session_id": "my-session"}' +``` + +#### Response Format + +```json +{ + "success": true, + "response": "Agent's response text...", + "session_id": "my-session", + "enabled": true, + "tool_calls": [...], + "tokens": {"input": 150, "output": 200} +} +``` + +### Agent Tools + +The agent has access to these tools for real-time model exploration: + +| Tool | Purpose | +|------|---------| +| `list_available_models` | Discover all models on the inference server | +| `get_model_metadata` | Get detailed model specifications | +| `get_model_input_requirements` | Get preprocessing and input format guidance | +| `get_model_output_interpretation` | Understand model outputs and post-processing | +| `analyze_model_type` | Infer model type from tensor shapes | 
+| `get_server_status` | Check inference server health | +| `get_api_examples` | Get cURL commands for API testing | +| `get_frontend_integration_guide` | Get full integration code examples | + +### Enabling the Agent + +The agent supports multiple LLM backends. Set one of the following: + +#### Option 1: Anthropic Claude (Recommended) + +Best for reliable tool calling and high-quality responses. + +```bash +export ANTHROPIC_API_KEY=sk-ant-your-key-here +export ANTHROPIC_MODEL=claude-sonnet-4-20250514 # optional +``` + +#### Option 2: OpenAI + +Use GPT-4o or other OpenAI models. + +```bash +export OPENAI_API_KEY=sk-your-key-here +export OPENAI_MODEL=gpt-4o # optional, defaults to gpt-4o +``` + +#### Option 3: Google Gemini + +Use Gemini 1.5 Pro or other Google models. + +```bash +export GOOGLE_API_KEY=your-key-here +export GOOGLE_MODEL=gemini-1.5-pro # optional +``` + +#### Option 4: Local LLM Server (OpenAI-Compatible) + +Use Ollama, LM Studio, vLLM, or any OpenAI-compatible API. + +```bash +export LLM_SERVER_URL=http://your-llm-server:1234 +export LLM_MODEL_NAME=your-model-name # optional +export LLM_API_KEY=your-api-key # optional +``` + +**Server-specific examples:** +```bash +# Ollama +export LLM_SERVER_URL=http://localhost:11434 +export LLM_MODEL_NAME=llama3.1 + +# LM Studio +export LLM_SERVER_URL=http://localhost:1234 + +# vLLM +export LLM_SERVER_URL=http://localhost:8000 +``` + +> **Priority:** If multiple backends are configured, they are used in this order: +> Anthropic → OpenAI → Google → Local LLM Server + +## Configuration + +Environment variables: +- `MODEL_SERVER_URL`: URL of the inference server (Triton or OpenVINO) +- `ANTHROPIC_API_KEY`: Anthropic API key (for Claude backend) +- `ANTHROPIC_MODEL`: Claude model to use (default: `claude-sonnet-4-20250514`) +- `OPENAI_API_KEY`: OpenAI API key (for GPT backend) +- `OPENAI_MODEL`: OpenAI model to use (default: `gpt-4o`) +- `GOOGLE_API_KEY`: Google API key (for Gemini backend) +- `GOOGLE_MODEL`: Google 
model to use (default: `gemini-1.5-pro`) +- `LLM_SERVER_URL`: URL of OpenAI-compatible LLM server +- `LLM_MODEL_NAME`: Model name for OpenAI-compatible server (default: `local-model`) +- `LLM_API_KEY`: API key for OpenAI-compatible server (default: `not-needed`) +- `APP_TITLE`: Application title +- `APP_DESCRIPTION`: Application description +- `LOGO_URL`: URL for logo image +- `PRIMARY_COLOR`: Primary theme color (CSS) + +## API Endpoints + +- `GET /` - Web interface +- `GET /health` - Health check +- `GET /models` - List available models +- `POST /predict` - Run inference +- `GET /models/<model_name>/metadata` - Get model metadata +- `POST /agent/chat` - AI agent chat +- `GET /agent/status` - Agent status + +## Customization + +### Class Names + +Edit `class_names.json` with your model's class labels: + +```json +[ + "cat", + "dog", + "bird", + ... +] +``` + +### Preprocessing + +Modify `client.py` to adjust preprocessing for your model: + +```python +def preprocess_image(self, image_path, ...): + # Customize resize, normalization, etc. 
+``` + +## Local Development + +```bash +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt + +# Option 1: Use Anthropic Claude +export ANTHROPIC_API_KEY=sk-ant-your-key-here + +# Option 2: Use local LLM server +export LLM_SERVER_URL=http://localhost:11434 # e.g., Ollama +export LLM_MODEL_NAME=llama3.1 + +python webapp/app.py +``` + +## Architecture + +``` +ondevice-eval-agent/ +├── client/ # Model server client package (Triton/OpenVINO) +├── requirements.txt # Python dependencies +├── Dockerfile # Container build +└── webapp/ + ├── app.py # Flask application + ├── static/ # CSS, JS assets + ├── templates/ # HTML templates + ├── agent/ # Agent package (backward compatibility) + │ ├── tools.py # Re-exports from mcp package + │ └── prompts.py # LLM chat processing + ├── router/ # LLM Router package + │ ├── config.py # Provider configuration + │ ├── llm_router.py # Multi-provider routing + │ └── adapters/ # Provider-specific adapters + ├── inference/ # Inference client wrapper + └── mcp/ # MCP (Model Context Protocol) tools + ├── base.py # Base utilities (ToolResult, ok, error_response) + ├── session.py # Session storage management + ├── registry.py # Tool registration and execution + └── tools/ # Individual tool modules + ├── list_models.py + ├── model_metadata.py + ├── model_inputs.py + ├── model_outputs.py + ├── model_type.py + ├── server_status.py + ├── api_examples.py + ├── integration_guide.py + └── recommendations.py +``` + +## MCP Package + +The MCP (Model Context Protocol) package provides a modular tool framework for AI agent interactions with ML inference servers. Each tool is in its own file for easy maintenance and extension. 
+ +### Usage + +```python +# Direct imports from mcp package +from mcp import execute_tool, TOOL_SCHEMAS, TOOL_FUNCTIONS +from mcp.tools import list_available_models, get_model_metadata + +# Or use backward-compatible imports +from agent.tools import TOOL_SCHEMAS, execute_tool +``` + +### Adding New Tools + +To add a new tool, create a file in `webapp/mcp/tools/`: + +```python +# mcp/tools/my_new_tool.py +from ..base import ok, error_response, get_client +from ..registry import register_tool + +def my_new_tool(param: str) -> Dict[str, Any]: + """Tool implementation.""" + try: + client = get_client() + # Your tool logic here + return ok(result="success", data={...}) + except Exception as e: + return error_response(e, operation="my_new_tool") + +# Auto-register the tool +register_tool( + name="my_new_tool", + func=my_new_tool, + description="Description for AI agent to understand when to use this tool", + input_schema={ + "type": "object", + "properties": { + "param": { + "type": "string", + "description": "Parameter description" + } + }, + "required": ["param"] + } +) +``` + +Then add the import to `mcp/tools/__init__.py`: + +```python +from .my_new_tool import my_new_tool +``` + +The tool will be automatically available to the AI agent. diff --git a/edgeai/ondevice-eval-agent/client/__init__.py b/edgeai/ondevice-eval-agent/client/__init__.py new file mode 100644 index 00000000..3a2aeb0e --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/__init__.py @@ -0,0 +1,165 @@ +""" +Model Server Client Package. + +A flexible client for NVIDIA Triton Inference Server and OpenVINO Model Server. +Automatically detects model input/output specifications for any image model. +Communicates via the KServe v2 gRPC protocol for low-latency binary tensor +transfer, with an optional HTTP session for Prometheus metrics. + +Thread Safety: + This client uses threading.Lock for mutable caches. Multiple threads can + safely share a single client instance. 
+ +Quick Start: + >>> from client import ModelServerClient + >>> + >>> client = ModelServerClient(grpc_url="localhost:8001") + >>> models = client.get_available_models() + >>> result = client.infer_image(image_bytes, models[0]) + +Context Manager: + >>> with ModelServerClient() as client: + ... result = client.infer_image("image.jpg", "resnet50") + +Modules: + client: Main ModelServerClient facade class + config: Constants, configuration dataclasses, API paths + exceptions: Exception hierarchy for error handling + preprocessing: Image preprocessing and normalization + metadata: Model metadata retrieval and caching (gRPC) + discovery: Server type detection and health checking (gRPC) + inference: Inference request handling and response processing (gRPC) + grpc_client: gRPC client factory and response conversion utilities + http_session: HTTP session creation (metrics endpoint only) +""" + +from .client import ModelServerClient +from .config import ( + APIPath, + COMMON_CHANNEL_COUNTS, + DEFAULT_DATA_FORMAT, + DEFAULT_GRPC_PORT, + DEFAULT_GRPC_PORT_OPENVINO, + DEFAULT_GRPC_PORT_TRITON, + DEFAULT_IMAGENET_MEAN, + DEFAULT_IMAGENET_STD, + DEFAULT_INFERENCE_TIMEOUT_SECONDS, + DEFAULT_INPUT_SPEC, + DEFAULT_METRICS_PATH, + DEFAULT_METRICS_PORT, + DEFAULT_OUTPUT_SPEC, + DEFAULT_TARGET_SIZE, + DEFAULT_TIMEOUT_SECONDS, + InputSpec, + MAX_RETRIES, + OutputSpec, + PIXEL_VALUE_MAX, + PreprocessingConfig, + RETRY_BACKOFF_FACTOR, + SERVER_TYPE_OPENVINO, + SERVER_TYPE_TRITON, + SERVER_TYPE_UNKNOWN, + ServerType, +) +from .discovery import HealthStatus, ModelState, ServerDiscovery, ServerInfo +from .exceptions import ( + ConfigurationError, + ImagePreprocessingError, + InferenceError, + ModelMetadataError, + ModelNotReadyError, + ModelServerError, + ServerConnectionError, +) +from .grpc_client import ( + create_grpc_client, + grpc_url_from_http, + parse_prometheus_metrics, + get_triton_latency_metrics, +) +from .http_session import SessionManager, create_session +from .inference 
import ClassificationResult, InferenceRequest, InferenceResult, InferenceRunner +from .metadata import ModelMetadataManager, TensorSpec +from .preprocessing import ImagePreprocessor, PreprocessingParams +from .llm_client import ( + LLMInferenceClient, + LLMModelInfo, + LLMPerformanceMetrics, + LLMServerMetrics, + LLMServerType, + get_llm_client, +) + +__all__ = [ + # Main client + "ModelServerClient", + # Server types + "ServerType", + "SERVER_TYPE_TRITON", + "SERVER_TYPE_OPENVINO", + "SERVER_TYPE_UNKNOWN", + # Specifications + "InputSpec", + "OutputSpec", + "PreprocessingConfig", + "DEFAULT_INPUT_SPEC", + "DEFAULT_OUTPUT_SPEC", + # Constants + "DEFAULT_IMAGENET_MEAN", + "DEFAULT_IMAGENET_STD", + "DEFAULT_TARGET_SIZE", + "DEFAULT_DATA_FORMAT", + "DEFAULT_TIMEOUT_SECONDS", + "DEFAULT_INFERENCE_TIMEOUT_SECONDS", + "MAX_RETRIES", + "RETRY_BACKOFF_FACTOR", + "PIXEL_VALUE_MAX", + "COMMON_CHANNEL_COUNTS", + "APIPath", + # gRPC + "DEFAULT_GRPC_PORT", + "DEFAULT_GRPC_PORT_TRITON", + "DEFAULT_GRPC_PORT_OPENVINO", + "DEFAULT_METRICS_PORT", + "DEFAULT_METRICS_PATH", + "create_grpc_client", + "grpc_url_from_http", + "parse_prometheus_metrics", + "get_triton_latency_metrics", + # Exceptions + "ModelServerError", + "InferenceError", + "ModelNotReadyError", + "ServerConnectionError", + "ImagePreprocessingError", + "ModelMetadataError", + "ConfigurationError", + # Discovery + "ServerDiscovery", + "ServerInfo", + "HealthStatus", + "ModelState", + # Metadata + "ModelMetadataManager", + "TensorSpec", + # Preprocessing + "ImagePreprocessor", + "PreprocessingParams", + # Inference + "InferenceRunner", + "InferenceRequest", + "InferenceResult", + "ClassificationResult", + # HTTP Session + "create_session", + "SessionManager", + # LLM Client + "LLMInferenceClient", + "LLMModelInfo", + "LLMPerformanceMetrics", + "LLMServerMetrics", + "LLMServerType", + "get_llm_client", +] + +__version__ = "3.0.0" diff --git a/edgeai/ondevice-eval-agent/client/client.py 
b/edgeai/ondevice-eval-agent/client/client.py new file mode 100644 index 00000000..946345e2 --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/client.py @@ -0,0 +1,785 @@ +""" +Model Server Client - Main Facade. + +This module provides the main ModelServerClient class that combines all +components into a cohesive, easy-to-use interface for inference operations. + +The client communicates with NVIDIA Triton Inference Server and OpenVINO +Model Server via the KServe v2 gRPC protocol for low-latency binary +tensor transfer. An optional HTTP session is maintained for fetching +Prometheus metrics from the Triton metrics endpoint. +""" + +from __future__ import annotations + +import json +import logging +import os +from pathlib import Path +from typing import Any, BinaryIO, Dict, Final, List, Literal, Optional, Tuple, Union + +import numpy as np +import requests +from numpy.typing import NDArray + +from .config import ( + DEFAULT_GRPC_PORT, + DEFAULT_INFERENCE_TIMEOUT_SECONDS, + DEFAULT_METRICS_PATH, + DEFAULT_METRICS_PORT, + DEFAULT_TIMEOUT_SECONDS, + MAX_RETRIES, + PreprocessingConfig, + ServerType, +) +from .discovery import ServerDiscovery, HealthStatus +from .exceptions import ( + ImagePreprocessingError, + InferenceError, +) +from .grpc_client import ( + create_grpc_client, + grpc_url_from_http, + parse_prometheus_metrics, + get_triton_latency_metrics, + repository_index_to_list, + _TRITON_TO_NUMPY, + InferenceServerException, +) +import tritonclient.grpc as grpcclient +from .http_session import create_session +from .inference import InferenceRunner +from .metadata import ModelMetadataManager +from .preprocessing import ImagePreprocessor + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Constants +# ============================================================================= + +# Default server URLs +_DEFAULT_SERVER_URL: Final[str] = "http://localhost:8000" +_DEFAULT_GRPC_URL: 
Final[str] = f"localhost:{DEFAULT_GRPC_PORT}" + +# Environment variable names +_ENV_MODEL_SERVER_URL: Final[str] = "MODEL_SERVER_URL" +_ENV_GRPC_URL: Final[str] = "MODEL_SERVER_GRPC_URL" +_ENV_METRICS_URL: Final[str] = "MODEL_SERVER_METRICS_URL" +_ENV_INFERENCE_BACKEND: Final[str] = "INFERENCE_BACKEND" +_ENV_KNOWN_MODELS: Final[str] = "KNOWN_MODELS" +_ENV_MODEL_NAME: Final[str] = "MODEL_NAME" + +# Class names file +_CLASS_NAMES_FILENAME: Final[str] = "class_names.json" + + +# ============================================================================= +# Model Server Client +# ============================================================================= + +class ModelServerClient: + """ + Client for communicating with NVIDIA Triton or OpenVINO Model Server. + + Uses gRPC (KServe v2 protocol) for all inference, metadata, and + health operations. An HTTP session is kept solely for fetching + Prometheus metrics from the Triton metrics endpoint (port 8002). + + Features: + - gRPC binary tensor transfer (no JSON serialization overhead) + - Automatic server type detection (Triton vs OpenVINO) + - Auto-detection of model input/output specifications + - Image preprocessing with configurable normalization + - Thread-safe caching of metadata + - Prometheus metrics integration for accurate server-side latency + - Context manager support for resource cleanup + + Thread Safety: + All mutable caches are protected by locks. Multiple threads can + safely share a single client instance. + + Example: + >>> client = ModelServerClient(grpc_url="localhost:8001") + >>> models = client.get_available_models() + >>> result = client.infer_image(image_bytes, models[0]) + + >>> with ModelServerClient() as client: + ... 
result = client.infer_image("image.jpg", "resnet50") + """ + + __slots__ = ( + "server_url", + "grpc_url", + "metrics_url", + "timeout", + "inference_timeout", + "inference_backend", + "_known_models", + "_grpc_client", + "_http_session", + "_preprocessor", + "_metadata_manager", + "_discovery", + "_inference_runner", + ) + + def __init__( + self, + server_url: Optional[str] = None, + *, + grpc_url: Optional[str] = None, + metrics_url: Optional[str] = None, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + inference_timeout: int = DEFAULT_INFERENCE_TIMEOUT_SECONDS, + max_retries: int = MAX_RETRIES, + test_connectivity: bool = True, + ) -> None: + """ + Initialize the client. + + Args: + server_url: HTTP base URL (used to derive gRPC/metrics URLs + when not given explicitly). Falls back to + ``MODEL_SERVER_URL`` env var or ``http://localhost:8000``. + grpc_url: ``host:port`` for gRPC. Falls back to + ``MODEL_SERVER_GRPC_URL`` env var or derived from + *server_url* (same host, port 8001). + metrics_url: Full URL for Triton metrics endpoint. Falls + back to ``MODEL_SERVER_METRICS_URL`` env var or + derived from *server_url* (same host, port 8002). + timeout: Default timeout for API requests in seconds. + inference_timeout: Timeout for inference requests. + max_retries: Maximum retry attempts for HTTP requests. + test_connectivity: Whether to test server connectivity on init. 
+ """ + # Resolve URLs + self.server_url = self._resolve_server_url(server_url) + self.grpc_url = self._resolve_grpc_url(grpc_url, self.server_url) + self.metrics_url = self._resolve_metrics_url(metrics_url, self.server_url) + self.timeout = timeout + self.inference_timeout = inference_timeout + + # Load configuration from environment + self.inference_backend = os.environ.get(_ENV_INFERENCE_BACKEND, "").lower() + self._known_models = self._parse_known_models() + + # Create gRPC client (primary communication channel) + self._grpc_client = create_grpc_client(self.grpc_url) + + # Create HTTP session (only for metrics endpoint) + self._http_session = create_session(max_retries) + + # Initialize components with gRPC client + self._preprocessor = ImagePreprocessor() + self._metadata_manager = ModelMetadataManager( + self._grpc_client, timeout + ) + self._discovery = ServerDiscovery( + self._grpc_client, timeout, self.inference_backend + ) + self._inference_runner = InferenceRunner( + self._grpc_client, inference_timeout + ) + + # Load class names if available + self._load_class_names() + + # Log initialization + logger.info(f"Model server client initialized (gRPC: {self.grpc_url})") + if self.inference_backend: + logger.info(f"Inference backend preference: {self.inference_backend}") + + # Test connectivity if requested + if test_connectivity: + self._discovery.test_connectivity() + + # ========================================================================= + # Context Manager Protocol + # ========================================================================= + + def __enter__(self) -> "ModelServerClient": + """Context manager entry.""" + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Context manager exit - close resources.""" + self.close() + + def close(self) -> None: + """Close the HTTP session and release resources.""" + if hasattr(self, "_http_session") and self._http_session: + self._http_session.close() + 
logger.debug("HTTP session closed") + if hasattr(self, "_grpc_client") and self._grpc_client: + try: + self._grpc_client.close() + except Exception: + pass + logger.debug("gRPC client closed") + + # ========================================================================= + # Cache Management + # ========================================================================= + + def clear_cache(self) -> None: + """Clear all cached metadata. Thread-safe.""" + self._metadata_manager.clear_cache() + self._discovery.clear_cache() + logger.info("All caches cleared") + + # ========================================================================= + # Configuration Properties + # ========================================================================= + + @property + def preprocessing_config(self) -> Dict[str, Any]: + """Get preprocessing config as dict (backward compatibility).""" + return self._preprocessor.config.to_dict() + + @preprocessing_config.setter + def preprocessing_config(self, value: Dict[str, Any]) -> None: + """Set preprocessing config from dict (backward compatibility).""" + self._preprocessor.config = PreprocessingConfig.from_dict(value) + + def set_preprocessing_config(self, config: Dict[str, Any]) -> None: + """Update preprocessing configuration.""" + self._preprocessor.update_config(config) + + @property + def class_names(self) -> Optional[List[str]]: + """Get class names for labeling predictions.""" + return self._inference_runner.class_names + + @class_names.setter + def class_names(self, value: Optional[List[str]]) -> None: + """Set class names for labeling predictions.""" + self._inference_runner.class_names = value + + # ========================================================================= + # Server Discovery + # ========================================================================= + + def detect_server_type(self) -> str: + """Detect the type of inference server (Triton or OpenVINO).""" + return self._discovery.detect_server_type() + + def 
get_server_info(self) -> Optional[Dict[str, Any]]: + """Get server information. Thread-safe.""" + return self._discovery.get_server_info() + + def check_server_health(self) -> Tuple[bool, str]: + """Check if the inference server is healthy and ready.""" + health = self._discovery.check_server_health() + return (health.is_healthy, health.message) + + def get_server_device_info(self) -> Literal["CPU", "GPU"]: + """Detect compute device (CPU/GPU) from the inference server.""" + return self._discovery.get_server_device_info() + + def check_model_ready(self, model_name: str) -> bool: + """Check if a specific model is ready for inference.""" + return self._discovery.check_model_ready(model_name) + + def get_available_models(self) -> List[str]: + """Get list of available models from the inference server.""" + return self._discovery.get_available_models(self._known_models) + + # ========================================================================= + # Model Repository Management + # ========================================================================= + + def get_repository_index(self) -> List[Dict[str, Any]]: + """ + Get the full model repository index (all states). + + Unlike ``get_available_models()`` which only returns READY models, + this returns every entry including UNAVAILABLE or LOADING models + with their ``state`` and ``reason`` fields. + + Returns: + List of dicts with keys: ``name``, ``version``, ``state``, ``reason``. + + Raises: + InferenceServerException: If the repository index is not supported + (e.g. OpenVINO Model Server without repository index). + """ + index = self._grpc_client.get_model_repository_index() + return repository_index_to_list(index) + + def load_model( + self, + model_name: str, + config: Optional[str] = None, + files: Optional[Dict[str, bytes]] = None, + ) -> None: + """ + Load or reload a model on the inference server. + + Requires Triton to be started with ``--model-control-mode=explicit`` + or ``--model-control-mode=poll``. 
+ + Args: + model_name: Name of the model to load. + config: Optional JSON string of a model config override. + When provided, this config is used instead of config.pbtxt + on disk. + files: Optional dict mapping file paths to bytes content. + Requires *config* to also be provided. + + Raises: + InferenceServerException: If loading fails or model control + mode does not allow it. + """ + try: + self._grpc_client.load_model( + model_name, config=config, files=files, + ) + # Clear stale metadata for the loaded model + self._metadata_manager.clear_cache() + logger.info(f"Model '{model_name}' load request sent") + except InferenceServerException as e: + err_msg = str(e).lower() + if "model control" in err_msg or "not allowed" in err_msg: + raise InferenceServerException( + f"Cannot load model: Triton model control mode does not " + f"allow API-driven load. Start Triton with " + f"--model-control-mode=explicit or poll. " + f"Original error: {e}" + ) from e + raise + + def unload_model(self, model_name: str) -> None: + """ + Unload a model from the inference server. + + Args: + model_name: Name of the model to unload. + + Raises: + InferenceServerException: If unloading fails. + """ + self._grpc_client.unload_model(model_name) + self._metadata_manager.clear_cache() + logger.info(f"Model '{model_name}' unload request sent") + + def send_raw_inference( + self, + model_name: str, + inputs: List[Tuple[str, "NDArray", str]], + ) -> Dict[str, Any]: + """ + Send a raw multi-input inference request via gRPC. + + Unlike ``send_inference_request()`` which is single-input and + image-oriented, this accepts arbitrary inputs for probing + unknown models. + + Args: + inputs: List of ``(name, numpy_array, triton_dtype_string)`` + tuples. Example: ``[("input", data, "FP32")]``. + + Returns: + Dict with ``model_name`` and ``outputs`` list, each output + having ``name``, ``shape``, ``datatype``, and ``data`` keys. + + Raises: + InferenceServerException: On gRPC errors. 
+ """ + grpc_inputs: List[grpcclient.InferInput] = [] + for name, data, dtype in inputs: + inp = grpcclient.InferInput(name, list(data.shape), dtype) + # Map Triton dtype to numpy dtype for correct casting + np_dtype = _TRITON_TO_NUMPY.get(dtype, np.dtype("float32")) + inp.set_data_from_numpy(data.astype(np_dtype)) + grpc_inputs.append(inp) + + result = self._grpc_client.infer( + model_name=model_name, + inputs=grpc_inputs, + client_timeout=self.inference_timeout, + ) + + # Convert result to dict by enumerating the response outputs + outputs: List[Dict[str, Any]] = [] + response = result.get_response() + if hasattr(response, "outputs"): + for idx, out_meta in enumerate(response.outputs): + out_name = out_meta.name if hasattr(out_meta, "name") else f"output_{idx}" + out_data = result.as_numpy(out_name) + outputs.append({ + "name": out_name, + "shape": list(out_data.shape), + "datatype": out_meta.datatype if hasattr(out_meta, "datatype") else "FP32", + "data": out_data, + }) + + return {"model_name": model_name, "outputs": outputs} + + # ========================================================================= + # Model Metadata + # ========================================================================= + + def get_model_metadata( + self, + model_name: str, + use_cache: bool = True, + ) -> Optional[Dict[str, Any]]: + """Get detailed model metadata from inference server.""" + return self._metadata_manager.get_metadata(model_name, use_cache) + + def get_model_config( + self, + model_name: str, + use_cache: bool = True, + ) -> Optional[Dict[str, Any]]: + """Get model configuration (config.pbtxt equivalent) from the server.""" + return self._metadata_manager.get_model_config(model_name, use_cache) + + def get_model_input_spec(self, model_name: str) -> Dict[str, Any]: + """Auto-detect model input specifications from server metadata.""" + return self._metadata_manager.get_input_spec(model_name) + + def get_model_output_spec(self, model_name: str) -> Dict[str, Any]: + 
"""Auto-detect model output specifications.""" + return self._metadata_manager.get_output_spec(model_name) + + def get_all_output_specs(self, model_name: str) -> List[Dict[str, Any]]: + """Get specifications for ALL model outputs.""" + return self._metadata_manager.get_all_output_specs(model_name) + + def get_model_input_shape(self, model_name: str) -> Tuple[int, int]: + """Get the input shape (height, width) for a specific model.""" + return self._metadata_manager.get_input_shape(model_name) + + # ========================================================================= + # Image Preprocessing + # ========================================================================= + + def preprocess_image_bytes( + self, + image_bytes: Union[bytes, BinaryIO], + model_name: Optional[str] = None, + target_size: Optional[Tuple[int, int]] = None, + ) -> Optional[NDArray[np.floating[Any]]]: + """Preprocess image from bytes for model inference.""" + try: + input_spec = self.get_model_input_spec(model_name) if model_name else None + return self._preprocessor.preprocess_bytes(image_bytes, input_spec, target_size) + except ImagePreprocessingError as e: + logger.error(str(e)) + return None + + def preprocess_image( + self, + image_path: str, + model_name: Optional[str] = None, + target_size: Optional[Tuple[int, int]] = None, + ) -> Optional[NDArray[np.floating[Any]]]: + """Preprocess image from file path for model inference.""" + try: + input_spec = self.get_model_input_spec(model_name) if model_name else None + return self._preprocessor.preprocess_file(image_path, input_spec, target_size) + except ImagePreprocessingError as e: + logger.error(str(e)) + return None + + # ========================================================================= + # Inference + # ========================================================================= + + def send_inference_request( + self, + image_array: NDArray[np.floating[Any]], + model_name: str, + measure_latency: bool = False, + ) -> 
def infer_image(
    self,
    image_data: Union[bytes, BinaryIO, str],
    model_name: str,
    *,
    measure_latency: bool = False,
    process_result: bool = True,
) -> Optional[Dict[str, Any]]:
    """
    Preprocess an image and run it through the model in one call.

    Pipeline: preprocess -> send inference request -> (optionally) process
    the prediction. Any stage that yields ``None`` ends the call with ``None``.

    Args:
        image_data: Image input handed to ``_preprocess_image_data``.
        model_name: Target model.
        measure_latency: Forwarded to ``send_inference_request``. When the
            processed result is non-empty and the response has a
            ``"latency"`` key, that value is copied onto the result.
        process_result: When False, the inference response is returned
            without calling ``process_prediction``.

    Returns:
        The processed prediction, the inference response (if
        ``process_result`` is False), or ``None`` on failure.
    """
    tensor = self._preprocess_image_data(image_data, model_name)
    if tensor is None:
        return None

    raw = self.send_inference_request(tensor, model_name, measure_latency=measure_latency)
    if raw is None or not process_result:
        # Either the request failed (None) or the caller asked for the untouched response.
        return raw

    processed = self.process_prediction(raw, model_name)
    if processed and measure_latency and "latency" in raw:
        processed["latency"] = raw["latency"]
    return processed
def get_api_endpoints_info(self, model_name: str) -> Dict[str, Any]:
    """
    Assemble developer-facing endpoint information for ``model_name``.

    The result combines the detected server type, the gRPC and metrics URLs
    and the detected input/output specs with backend-specific entries. The
    Triton builder is used when the detected type equals
    ``ServerType.TRITON.value``; every other type (including an undetected
    one) gets the OpenVINO entries.

    Args:
        model_name: Model whose specs and endpoint examples are described.

    Returns:
        Dictionary of endpoint information.
    """
    detected_input = self.get_model_input_spec(model_name)
    detected_output = self.get_model_output_spec(model_name)
    backend = self.detect_server_type()

    info: Dict[str, Any] = dict(
        server_type=backend,
        protocol="gRPC (KServe v2)",
        grpc_url=self.grpc_url,
        metrics_url=self.metrics_url,
        detected_input_spec=detected_input,
        detected_output_spec=detected_output,
    )

    # Pick the backend-specific builder, then merge its entries over the common ones.
    builder = (
        self._build_triton_endpoints
        if backend == ServerType.TRITON.value
        else self._build_openvino_endpoints
    )
    info.update(builder(model_name))
    return info
self.detect_server_type(), + "server_info": self.get_server_info(), + "ready": self.check_model_ready(model_name), + "input_spec": self.get_model_input_spec(model_name), + "output_spec": self.get_model_output_spec(model_name), + "metadata": self.get_model_metadata(model_name), + } + + # ========================================================================= + # Private - Initialization Helpers + # ========================================================================= + + @staticmethod + def _resolve_server_url(server_url: Optional[str]) -> str: + """Resolve HTTP server URL from parameter or environment.""" + url = server_url or os.environ.get(_ENV_MODEL_SERVER_URL, _DEFAULT_SERVER_URL) + return url.rstrip("/") + + @staticmethod + def _resolve_grpc_url(grpc_url: Optional[str], server_url: str) -> str: + """Resolve gRPC URL from parameter, env var, or derived from HTTP URL.""" + if grpc_url: + # Strip scheme if present + if "://" in grpc_url: + from urllib.parse import urlparse + parsed = urlparse(grpc_url) + return f"{parsed.hostname or 'localhost'}:{parsed.port or DEFAULT_GRPC_PORT}" + return grpc_url + + env_grpc = os.environ.get(_ENV_GRPC_URL, "") + if env_grpc: + return env_grpc + + # Derive from HTTP server_url: same host, gRPC port + return grpc_url_from_http(server_url, DEFAULT_GRPC_PORT) + + @staticmethod + def _resolve_metrics_url(metrics_url: Optional[str], server_url: str) -> str: + """Resolve Triton metrics URL.""" + if metrics_url: + return metrics_url + + env_metrics = os.environ.get(_ENV_METRICS_URL, "") + if env_metrics: + return env_metrics + + # Derive from HTTP server_url: same host, metrics port + from urllib.parse import urlparse + parsed = urlparse(server_url) + host = parsed.hostname or "localhost" + return f"http://{host}:{DEFAULT_METRICS_PORT}{DEFAULT_METRICS_PATH}" + + @staticmethod + def _parse_known_models() -> List[str]: + """Parse known model names from environment variables.""" + models: List[str] = [] + + models_str = 
def _load_class_names(self) -> None:
    """
    Best-effort load of class names from a JSON file.

    The file is named by ``_CLASS_NAMES_FILENAME`` and looked up in the
    parent of this module's directory. When it exists and parses, the
    decoded value is stored on ``self._inference_runner.class_names``.

    A missing file is silently skipped. I/O errors, invalid JSON and
    undecodable (non-UTF-8) content are logged at debug level and ignored,
    so a bad file never prevents the client from starting.
    """
    class_names_path = Path(__file__).parent.parent / _CLASS_NAMES_FILENAME
    try:
        if not class_names_path.exists():
            return
        with open(class_names_path, encoding="utf-8") as f:
            class_names = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError is the common base of json.JSONDecodeError and
        # UnicodeDecodeError, so malformed JSON and bad encodings are both covered.
        logger.debug(f"Could not load class names: {e}")
        return

    self._inference_runner.class_names = class_names
    logger.info(f"Loaded {len(class_names)} class names from file")
+ f"result = client.infer('{model_name}', inputs)" + ), + }, + "metrics": { + "endpoint": self.metrics_url, + "method": "GET", + "description": "Prometheus metrics (latency, throughput, etc.)", + "curl_command": f"curl {self.metrics_url}", + }, + "rest_inference": { + "endpoint": f"{self.server_url}/v2/models/{model_name}/infer", + "method": "POST", + "description": "REST inference (fallback, higher latency than gRPC)", + "curl_command": ( + f'curl -X POST {self.server_url}/v2/models/{model_name}/infer ' + f'-H "Content-Type: application/json" -d \'{{"inputs": [...]}}\'' + ), + }, + } + + def _build_openvino_endpoints(self, model_name: str) -> Dict[str, Any]: + """Build OpenVINO-specific endpoint documentation.""" + return { + "grpc_inference": { + "endpoint": f"{self.grpc_url}", + "protocol": "gRPC", + "description": "Send inference via gRPC (KServe v2 protocol)", + "python_example": ( + f"import tritonclient.grpc as grpcclient\n" + f"client = grpcclient.InferenceServerClient(url='{self.grpc_url}')\n" + f"inputs = [grpcclient.InferInput('input', shape, 'FP32')]\n" + f"inputs[0].set_data_from_numpy(np_array)\n" + f"result = client.infer('{model_name}', inputs)" + ), + }, + "rest_inference": { + "endpoint": f"{self.server_url}/v2/models/{model_name}/infer", + "method": "POST", + "description": "REST inference (KServe v2, higher latency than gRPC)", + "curl_command": ( + f'curl -X POST {self.server_url}/v2/models/{model_name}/infer ' + f'-H "Content-Type: application/json" -d \'{{"inputs": [...]}}\'' + ), + }, + } + + +__all__ = [ + "ModelServerClient", +] diff --git a/edgeai/ondevice-eval-agent/client/config.py b/edgeai/ondevice-eval-agent/client/config.py new file mode 100644 index 00000000..d424f614 --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/config.py @@ -0,0 +1,301 @@ +""" +Constants and configuration for Model Server Client. + +This module centralizes all constants, default values, and configuration +dataclasses used across the client modules. 
class ServerType(str, Enum):
    """
    Inference server types supported by the client.

    The client automatically detects the server type, but users can
    also explicitly specify their preference via INFERENCE_BACKEND.

    Because the enum mixes in ``str``, each member compares equal to its
    plain string value (e.g. ``ServerType.TRITON == "triton"``); the
    ``SERVER_TYPE_*`` module constants expose those values directly.
    """
    # Values are lowercase strings; the discovery code in this project
    # compares detected server types against these same strings.
    TRITON = "triton"
    OPENVINO = "openvino"
    # Placeholder for a server whose type has not been / could not be determined.
    UNKNOWN = "unknown"
class APIPath:
    """
    API endpoint path templates for inference servers.

    Supports both KServe v2 API (Triton and OpenVINO) and
    TensorFlow Serving v1 API (OpenVINO fallback).

    Every value is a path to be appended to a server base URL. Entries that
    contain ``{model_name}`` are ``str.format`` templates and must be
    formatted with a ``model_name`` keyword before use; the remaining
    entries are ready-to-use paths.

    Usage:
        >>> url = f"{base_url}{APIPath.V2_MODEL.format(model_name='resnet50')}"
    """

    # -------------------------------------------------------------------------
    # KServe v2 API paths (both Triton and OpenVINO)
    # -------------------------------------------------------------------------

    # Server endpoints (no formatting required)
    V2_ROOT: Final[str] = "/v2"
    V2_HEALTH_READY: Final[str] = "/v2/health/ready"
    V2_HEALTH_LIVE: Final[str] = "/v2/health/live"

    # Model endpoints (requires model_name parameter)
    V2_MODEL: Final[str] = "/v2/models/{model_name}"
    V2_MODEL_READY: Final[str] = "/v2/models/{model_name}/ready"
    V2_MODEL_INFER: Final[str] = "/v2/models/{model_name}/infer"
    V2_MODEL_CONFIG: Final[str] = "/v2/models/{model_name}/config"

    # Repository management (Triton-specific)
    V2_REPO_INDEX: Final[str] = "/v2/repository/index"

    # -------------------------------------------------------------------------
    # OpenVINO v1 API paths (TensorFlow Serving format)
    # -------------------------------------------------------------------------

    V1_CONFIG: Final[str] = "/v1/config"
    # Model endpoints (requires model_name parameter)
    V1_MODEL: Final[str] = "/v1/models/{model_name}"
    # Note the ":predict" suffix is joined with a colon, not a slash.
    V1_MODEL_PREDICT: Final[str] = "/v1/models/{model_name}:predict"
@dataclass(frozen=True)
class OutputSpec:
    """
    Immutable description of one model output tensor.

    Attributes:
        name: Output tensor name (e.g., 'output0', 'predictions').
        shape: Full tensor shape including batch dimension.
        datatype: Data type string ('FP32', 'FP16', etc.).
        num_classes: Number of classes for classification models,
            ``None`` for non-classification outputs.
    """
    name: str = "output0"
    shape: tuple[int, ...] = (-1, 84, 8400)
    datatype: str = "FP32"
    num_classes: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the spec as a plain dict; ``shape`` is emitted as a list."""
        payload: Dict[str, Any] = dict(name=self.name, shape=[*self.shape])
        payload["datatype"] = self.datatype
        payload["num_classes"] = self.num_classes
        return payload
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "PreprocessingConfig":
    """
    Create PreprocessingConfig from dictionary.

    Missing keys fall back to the module defaults. ``mean`` and ``std`` are
    always copied into fresh lists, so later changes to the config never
    write through to ``data`` (and tuple inputs are normalised to lists).
    An explicit ``None`` for ``mean``/``std`` is treated like a missing key.

    Args:
        data: Configuration dictionary with optional keys
            ``target_size``, ``normalize``, ``mean``, ``std``, ``format``.

    Returns:
        New PreprocessingConfig instance.
    """
    mean = data.get("mean")
    std = data.get("std")
    return cls(
        target_size=tuple(data.get("target_size", DEFAULT_TARGET_SIZE)),
        normalize=data.get("normalize", True),
        # list(...) both copies and converts, so the instance owns its lists.
        mean=list(DEFAULT_IMAGENET_MEAN if mean is None else mean),
        std=list(DEFAULT_IMAGENET_STD if std is None else std),
        format=data.get("format", DEFAULT_DATA_FORMAT),
    )
class ModelState(str, Enum):
    """
    Model readiness states from inference servers.

    Members mix in ``str``, so each compares equal to its upper-case
    string value (e.g. ``ModelState.READY == "READY"``).
    """
    # NOTE(review): presumably READY is the state reported by Triton's
    # repository index and AVAILABLE the OpenVINO Model Server equivalent —
    # confirm against the servers' documentation.
    READY = "READY"
    AVAILABLE = "AVAILABLE"
    # Transitional states: the model is being loaded / unloaded.
    LOADING = "LOADING"
    UNLOADING = "UNLOADING"
@dataclass(frozen=True)
class ServerInfo:
    """
    Immutable server information container.

    Attributes:
        name: Server name, ``"Unknown"`` when not reported.
        version: Server version, ``"Unknown"`` when not reported.
        extensions: Extensions advertised by the server.
        raw_data: Private copy of the dictionary the instance was built from.
    """
    name: str
    version: str
    extensions: tuple[str, ...]
    raw_data: Dict[str, Any]

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ServerInfo":
        """
        Create ServerInfo from a dict (e.g. converted gRPC metadata).

        The input is shallow-copied so later changes to ``data`` do not alter
        the instance. A missing or ``None`` ``extensions`` entry yields an
        empty tuple.
        """
        return cls(
            name=data.get("name", "Unknown"),
            version=data.get("version", "Unknown"),
            extensions=tuple(data.get("extensions") or ()),
            raw_data=dict(data),
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary for JSON serialization.

        Returns a shallow copy of the stored data, so callers can add or
        replace top-level keys without modifying this (possibly cached)
        instance. Nested values are shared with the stored data.
        """
        return dict(self.raw_data)
def clear_cache(self) -> None:
    """
    Forget the cached server type and server info.

    Both cached values are reset to ``None`` while holding the instance
    lock; an info-level message is logged afterwards.
    """
    with self._lock:
        self._server_type = self._server_info = None
    logger.info("Server discovery cache cleared")
def check_server_health(self) -> HealthStatus:
    """
    Ask the server (via gRPC) whether it is ready.

    Returns:
        ``HealthStatus(True, "Server is ready")`` when the readiness call
        returns a truthy value, ``HealthStatus(False, "Server not ready")``
        when it returns a falsy one, and a failed status carrying the error
        text when the call raises.
    """
    try:
        ready = self._grpc_client.is_server_ready()
    except Exception as exc:
        # Also covers InferenceServerException, which produced the same message.
        return HealthStatus(False, f"Health check failed: {exc}")

    if not ready:
        return HealthStatus(False, "Server not ready")
    return HealthStatus(True, "Server is ready")
def get_server_info(self) -> Optional[Dict[str, Any]]:
    """
    Return server information as a dict, fetching it via gRPC if needed.

    A previously stored ``ServerInfo`` is served from the cache. Otherwise
    the server metadata is requested, converted, cached and returned.

    Returns:
        The server info dictionary, or ``None`` when the request or the
        conversion fails (the failure is logged at error level).
    """
    with self._lock:
        cached = self._server_info
    if cached is not None:
        return cached.to_dict()

    try:
        fetched = ServerInfo.from_dict(
            server_metadata_to_dict(self._grpc_client.get_server_metadata())
        )
        with self._lock:
            self._server_info = fetched
        return fetched.to_dict()
    except InferenceServerException as exc:
        logger.error(f"Failed to get server info via gRPC: {exc}")
        return None
    except Exception as exc:
        logger.error(f"Failed to get server info: {exc}")
        return None
def check_model_ready(self, model_name: str) -> bool:
    """
    Check if a specific model is ready for inference.

    The check is best-effort: any error raised by the gRPC call is treated
    as "not ready" rather than propagated, but the reason is logged at
    debug level so failures can be diagnosed.

    Args:
        model_name: Name of the model to check.

    Returns:
        True if the server reports the model as ready, False otherwise
        (including when the check itself fails).
    """
    try:
        if self._grpc_client.is_model_ready(model_name):
            logger.debug(f"Model {model_name} is ready (gRPC)")
            return True
    except Exception as e:
        # InferenceServerException derives from Exception, so one handler
        # covers both server-reported and transport-level failures.
        logger.debug(f"Model {model_name} readiness check failed: {e}")
        return False

    logger.debug(f"Model {model_name} not ready")
    return False
def _detect_by_repository_index(self) -> str:
    """
    Detect server type by probing the model repository index via gRPC.

    A successful call is taken as evidence of Triton. Any failure of the
    call — whatever its cause — results in the OpenVINO assumption; the
    underlying error is logged at debug level so a misdetection caused by
    e.g. a connectivity problem can be traced.

    Returns:
        ``SERVER_TYPE_TRITON`` when the probe succeeds, otherwise
        ``SERVER_TYPE_OPENVINO``.
    """
    try:
        self._grpc_client.get_model_repository_index()
    except Exception as e:
        # Covers InferenceServerException as well as transport errors.
        logger.debug(f"Repository index probe failed: {e}")
        logger.info("Assuming OpenVINO (no repository index via gRPC)")
        return SERVER_TYPE_OPENVINO

    logger.info("Detected Triton via repository index (gRPC)")
    return SERVER_TYPE_TRITON
def _discover_via_repository_index(self) -> List[str]:
    """
    Discover models via gRPC repository index.

    An entry is accepted when it has a non-empty ``name`` and its ``state``
    is either absent/empty/``None`` or equal to ``ModelState.READY``
    (compared case-insensitively).

    Returns:
        Names of the accepted models; an empty list when the index is
        unavailable or the lookup fails.
    """
    try:
        index = self._grpc_client.get_model_repository_index()
        index_list = repository_index_to_list(index)

        models: List[str] = []
        for entry in index_list:
            name = entry.get("name")
            if not name:
                continue
            # `or ""` keeps an explicit None state from raising and thereby
            # discarding every model found so far.
            state = str(entry.get("state") or "").upper()
            # NOTE(review): ModelState.AVAILABLE is not treated as ready here —
            # confirm whether any supported server reports it in the index.
            if state in ("", ModelState.READY.value):
                models.append(name)
                logger.debug(f"Found model: {name} (state: {state or 'not specified'})")

        if models:
            logger.info(f"Discovered {len(models)} models via gRPC repository index: {models}")
        return models

    except InferenceServerException as e:
        logger.debug(f"Repository index not available via gRPC: {e}")
    except Exception as e:
        logger.warning(f"Repository index failed: {e}")
    return []
+""" + +from __future__ import annotations + +from typing import Any, Optional + + +class ModelServerError(Exception): + """ + Base exception for all model server errors. + + Attributes: + message: Human-readable error description + details: Optional dict with additional error context + status_code: Optional HTTP status code if applicable + """ + + def __init__( + self, + message: str, + details: Optional[dict[str, Any]] = None, + status_code: Optional[int] = None, + ) -> None: + super().__init__(message) + self.message = message + self.details = details or {} + self.status_code = status_code + + def to_dict(self) -> dict[str, Any]: + """Convert exception to a dictionary for JSON responses.""" + result: dict[str, Any] = { + "error": self.__class__.__name__, + "message": self.message, + } + if self.details: + result["details"] = self.details + if self.status_code: + result["status_code"] = self.status_code + return result + + +class InferenceError(ModelServerError): + """ + Raised when inference fails. + + This can be due to model execution errors, invalid input, or server issues. + """ + + def __init__( + self, + message: str, + model_name: Optional[str] = None, + details: Optional[dict[str, Any]] = None, + status_code: Optional[int] = None, + ) -> None: + super().__init__(message, details, status_code) + self.model_name = model_name + if model_name: + self.details["model_name"] = model_name + + +class ModelNotReadyError(ModelServerError): + """ + Raised when a model is not ready for inference. + + This typically means the model is loading, unloaded, or in an error state. 
+ """ + + def __init__( + self, + model_name: str, + message: Optional[str] = None, + details: Optional[dict[str, Any]] = None, + ) -> None: + msg = message or f"Model '{model_name}' is not ready for inference" + super().__init__(msg, details, status_code=503) + self.model_name = model_name + self.details["model_name"] = model_name + + +class ServerConnectionError(ModelServerError): + """ + Raised when connection to the inference server fails. + + This covers network errors, timeouts, and server unreachable conditions. + """ + + def __init__( + self, + server_url: str, + message: Optional[str] = None, + cause: Optional[Exception] = None, + details: Optional[dict[str, Any]] = None, + ) -> None: + msg = message or f"Failed to connect to server at '{server_url}'" + super().__init__(msg, details, status_code=503) + self.server_url = server_url + self.cause = cause + self.details["server_url"] = server_url + if cause: + self.details["cause"] = str(cause) + + +class ImagePreprocessingError(ModelServerError): + """ + Raised when image preprocessing fails. + + This covers format errors, invalid images, and preprocessing failures. + """ + + def __init__( + self, + message: str, + image_source: Optional[str] = None, + cause: Optional[Exception] = None, + details: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__(message, details, status_code=400) + self.image_source = image_source + self.cause = cause + if image_source: + self.details["image_source"] = image_source + if cause: + self.details["cause"] = str(cause) + + +class ModelMetadataError(ModelServerError): + """ + Raised when model metadata retrieval fails. + + This can be due to invalid model names or server configuration issues. 
+ """ + + def __init__( + self, + model_name: str, + message: Optional[str] = None, + details: Optional[dict[str, Any]] = None, + ) -> None: + msg = message or f"Failed to retrieve metadata for model '{model_name}'" + super().__init__(msg, details, status_code=404) + self.model_name = model_name + self.details["model_name"] = model_name + + +class ConfigurationError(ModelServerError): + """ + Raised when there is a configuration error. + + This covers invalid settings, missing required configuration, etc. + """ + + def __init__( + self, + message: str, + config_key: Optional[str] = None, + details: Optional[dict[str, Any]] = None, + ) -> None: + super().__init__(message, details, status_code=400) + self.config_key = config_key + if config_key: + self.details["config_key"] = config_key + + +__all__ = [ + "ModelServerError", + "InferenceError", + "ModelNotReadyError", + "ServerConnectionError", + "ImagePreprocessingError", + "ModelMetadataError", + "ConfigurationError", +] diff --git a/edgeai/ondevice-eval-agent/client/grpc_client.py b/edgeai/ondevice-eval-agent/client/grpc_client.py new file mode 100644 index 00000000..866c6ae2 --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/grpc_client.py @@ -0,0 +1,373 @@ +""" +gRPC client wrapper for inference servers. + +This module provides a thin wrapper around tritonclient.grpc for +communicating with Triton and OpenVINO Model Server via the KServe v2 +gRPC protocol. Both servers implement the same gRPC interface, so a +single client works for either backend. 
# Triton datatype string -> numpy dtype mapping.
# "BYTES" maps to the generic object dtype.  There is no "BF16" entry in this
# table, although the config-dtype table below does list one.
_TRITON_TO_NUMPY: Final[Dict[str, np.dtype]] = {
    "BOOL": np.dtype("bool"),
    "UINT8": np.dtype("uint8"),
    "UINT16": np.dtype("uint16"),
    "UINT32": np.dtype("uint32"),
    "UINT64": np.dtype("uint64"),
    "INT8": np.dtype("int8"),
    "INT16": np.dtype("int16"),
    "INT32": np.dtype("int32"),
    "INT64": np.dtype("int64"),
    "FP16": np.dtype("float16"),
    "FP32": np.dtype("float32"),
    "FP64": np.dtype("float64"),
    "BYTES": np.dtype("object"),
}

# Numpy dtype -> Triton datatype string mapping
# (built by inverting the table above, so both stay in sync).
_NUMPY_TO_TRITON: Final[Dict[np.dtype, str]] = {
    v: k for k, v in _TRITON_TO_NUMPY.items()
}

# Triton metadata dtype (e.g. "FP32") -> config.pbtxt dtype (e.g. "TYPE_FP32").
# Every entry is "TYPE_" + the metadata name except "BYTES", which is spelled
# "TYPE_STRING" on the config side.
_TRITON_DTYPE_TO_CONFIG: Final[Dict[str, str]] = {
    "BOOL": "TYPE_BOOL",
    "UINT8": "TYPE_UINT8",
    "UINT16": "TYPE_UINT16",
    "UINT32": "TYPE_UINT32",
    "UINT64": "TYPE_UINT64",
    "INT8": "TYPE_INT8",
    "INT16": "TYPE_INT16",
    "INT32": "TYPE_INT32",
    "INT64": "TYPE_INT64",
    "FP16": "TYPE_FP16",
    "FP32": "TYPE_FP32",
    "FP64": "TYPE_FP64",
    "BYTES": "TYPE_STRING",
    "BF16": "TYPE_BF16",
}

# Reverse: config.pbtxt dtype -> Triton metadata dtype
_CONFIG_TO_TRITON_DTYPE: Final[Dict[str, str]] = {
    v: k for k, v in _TRITON_DTYPE_TO_CONFIG.items()
}
+ """ + # Strip scheme if the caller accidentally included one + url = _strip_scheme(url) + logger.info(f"Creating gRPC client for {url}") + return grpcclient.InferenceServerClient(url=url, verbose=verbose) + + +# ============================================================================= +# URL helpers +# ============================================================================= + +def _strip_scheme(url: str) -> str: + """Remove ``http://`` or ``grpc://`` prefix, returning ``host:port``.""" + if "://" in url: + parsed = urlparse(url) + host = parsed.hostname or "localhost" + port = parsed.port or DEFAULT_GRPC_PORT + return f"{host}:{port}" + return url + + +def grpc_url_from_http(http_url: str, grpc_port: int = DEFAULT_GRPC_PORT) -> str: + """ + Derive a gRPC ``host:port`` from an HTTP base URL. + + Example: + >>> grpc_url_from_http("http://192.168.1.10:8000") + '192.168.1.10:8001' + """ + parsed = urlparse(http_url) + host = parsed.hostname or "localhost" + return f"{host}:{grpc_port}" + + +# ============================================================================= +# Response conversion helpers +# ============================================================================= + +def server_metadata_to_dict(metadata: Any) -> Dict[str, Any]: + """ + Convert a gRPC ``ServerMetadataResponse`` to a plain dict matching + the KServe v2 JSON schema used by the rest of the codebase. + """ + return { + "name": metadata.name, + "version": metadata.version, + "extensions": list(metadata.extensions), + } + + +def model_metadata_to_dict(metadata: Any) -> Dict[str, Any]: + """ + Convert a gRPC ``ModelMetadataResponse`` to a dict matching the + KServe v2 REST ``/v2/models/{name}`` JSON response. 
+ """ + inputs: List[Dict[str, Any]] = [] + for inp in metadata.inputs: + inputs.append({ + "name": inp.name, + "datatype": inp.datatype, + "shape": list(inp.shape), + }) + + outputs: List[Dict[str, Any]] = [] + for out in metadata.outputs: + outputs.append({ + "name": out.name, + "datatype": out.datatype, + "shape": list(out.shape), + }) + + return { + "name": metadata.name, + "versions": list(metadata.versions), + "platform": metadata.platform, + "inputs": inputs, + "outputs": outputs, + } + + +def model_config_to_dict(config: Any) -> Dict[str, Any]: + """ + Convert a gRPC ``ModelConfigResponse`` to a plain dict. + + The config protobuf is complex; we serialise the most commonly + inspected fields and fall back to ``str()`` for anything exotic. + """ + try: + from google.protobuf.json_format import MessageToDict + return MessageToDict(config, preserving_proto_field_name=True) + except Exception: + # Fallback: manually extract the top-level fields + result: Dict[str, Any] = {"name": getattr(config, "name", "")} + if hasattr(config, "platform"): + result["platform"] = config.platform + if hasattr(config, "backend"): + result["backend"] = config.backend + if hasattr(config, "max_batch_size"): + result["max_batch_size"] = config.max_batch_size + return result + + +def repository_index_to_list(index: Any) -> List[Dict[str, Any]]: + """ + Convert a gRPC repository-index response to the list-of-dicts + format returned by the REST ``POST /v2/repository/index`` endpoint. + """ + models: List[Dict[str, Any]] = [] + for entry in index: + models.append({ + "name": entry.name, + "version": getattr(entry, "version", ""), + "state": getattr(entry, "state", ""), + "reason": getattr(entry, "reason", ""), + }) + return models + + +def infer_result_to_dict( + result: grpcclient.InferResult, + model_name: str, +) -> Dict[str, Any]: + """ + Convert a gRPC ``InferResult`` into the dict format matching the + KServe v2 REST inference response used by the rest of the codebase. 
def infer_result_to_dict(
    result: grpcclient.InferResult,
    model_name: str,
) -> Dict[str, Any]:
    """
    Convert a gRPC ``InferResult`` into the dict format matching the
    KServe v2 REST inference response used by the rest of the codebase.

    This allows downstream code (prediction processing, etc.) to remain
    unchanged.

    Outputs are enumerated from the underlying response message
    (``result.get_response().outputs``) and each tensor is fetched by *name*
    with ``as_numpy``.  ``InferResult.get_output`` looks outputs up by name,
    not by position, so it cannot be used to walk them by index.

    Args:
        result: Inference result returned by the gRPC client.
        model_name: Name reported back under ``"model_name"``.

    Returns:
        ``{"model_name": ..., "outputs": [{"name", "shape", "datatype",
        "data"}, ...]}`` where ``data`` is the flattened tensor as a list.
        Outputs whose tensor cannot be materialised are left out.
    """
    response = result.get_response()
    outputs: List[Dict[str, Any]] = []

    for out_meta in getattr(response, "outputs", ()):
        out_name = out_meta.name
        out_data = result.as_numpy(out_name)
        if out_data is None:
            # No tensor available under this name; nothing to report.
            continue
        outputs.append({
            "name": out_name,
            "shape": list(out_data.shape),
            "datatype": out_meta.datatype,
            "data": out_data.flatten().tolist(),
        })

    return {
        "model_name": model_name,
        "outputs": outputs,
    }
+ """ + metrics: Dict[str, Dict[str, float]] = {} + + for line in text.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + + m = _METRIC_LINE_RE.match(line) + if not m: + continue + + name = m.group("name") + labels_str = m.group("labels") or "" + try: + value = float(m.group("value")) + except ValueError: + continue + + # Parse labels + labels: Dict[str, str] = {} + if labels_str: + for pair in labels_str.split(","): + k, _, v = pair.partition("=") + labels[k.strip()] = v.strip().strip('"') + + # Filter by model if requested + if model_name and labels.get("model") != model_name: + continue + + # Store with version suffix for uniqueness + version = labels.get("version", "") + key = f"{name}" if not version else f"{name}:v{version}" + metrics[key] = {"value": value, **labels} + + return metrics + + +def get_triton_latency_metrics( + metrics: Dict[str, Dict[str, float]], +) -> Dict[str, float]: + """ + Extract Triton-specific latency counters (in microseconds) from + parsed Prometheus metrics and convert to milliseconds. + + Returns a dict with keys like ``queue_ms``, ``compute_infer_ms``, etc. + Missing metrics are omitted rather than defaulted. 
+ """ + mapping = { + "nv_inference_request_duration_us": "request_duration_ms", + "nv_inference_queue_duration_us": "queue_ms", + "nv_inference_compute_input_duration_us": "compute_input_ms", + "nv_inference_compute_infer_duration_us": "compute_infer_ms", + "nv_inference_compute_output_duration_us": "compute_output_ms", + } + + result: Dict[str, float] = {} + for prom_name, friendly_name in mapping.items(): + # Try without version suffix first, then with :v1 + for key in (prom_name, f"{prom_name}:v1"): + if key in metrics: + result[friendly_name] = metrics[key]["value"] / 1000.0 + break + + # Also grab request count for computing per-request averages + for key in ("nv_inference_request_success", "nv_inference_request_success:v1"): + if key in metrics: + result["request_count"] = metrics[key]["value"] + break + + return result + + +__all__ = [ + "create_grpc_client", + "grpc_url_from_http", + "server_metadata_to_dict", + "model_metadata_to_dict", + "model_config_to_dict", + "repository_index_to_list", + "infer_result_to_dict", + "parse_prometheus_metrics", + "get_triton_latency_metrics", + "InferenceServerException", + "_TRITON_TO_NUMPY", + "_NUMPY_TO_TRITON", + "_TRITON_DTYPE_TO_CONFIG", + "_CONFIG_TO_TRITON_DTYPE", +] diff --git a/edgeai/ondevice-eval-agent/client/http_session.py b/edgeai/ondevice-eval-agent/client/http_session.py new file mode 100644 index 00000000..e837fb20 --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/http_session.py @@ -0,0 +1,105 @@ +""" +HTTP session management for Model Server Client. + +This module handles HTTP session creation with retry logic and connection pooling. 
+""" + +from __future__ import annotations + +import logging +from typing import Optional + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from .config import MAX_RETRIES, RETRY_BACKOFF_FACTOR + +logger = logging.getLogger(__name__) + + +def create_session( + max_retries: int = MAX_RETRIES, + backoff_factor: float = RETRY_BACKOFF_FACTOR, +) -> requests.Session: + """ + Create a requests session with retry logic and connection pooling. + + Args: + max_retries: Maximum number of retry attempts for failed requests + backoff_factor: Exponential backoff factor between retries + + Returns: + Configured requests.Session instance + """ + session = requests.Session() + + retry_strategy = Retry( + total=max_retries, + backoff_factor=backoff_factor, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=['GET', 'HEAD', 'OPTIONS'], + ) + + adapter = HTTPAdapter(max_retries=retry_strategy) + session.mount('http://', adapter) + session.mount('https://', adapter) + + return session + + +class SessionManager: + """ + Manages HTTP sessions with context manager support. + + Example: + with SessionManager() as session: + response = session.get("http://example.com") + """ + + def __init__( + self, + max_retries: int = MAX_RETRIES, + backoff_factor: float = RETRY_BACKOFF_FACTOR, + ) -> None: + """ + Initialize the session manager. 
+ + Args: + max_retries: Maximum retry attempts + backoff_factor: Exponential backoff factor + """ + self._session: Optional[requests.Session] = None + self._max_retries = max_retries + self._backoff_factor = backoff_factor + + @property + def session(self) -> requests.Session: + """Get or create the HTTP session.""" + if self._session is None: + self._session = create_session( + self._max_retries, + self._backoff_factor, + ) + return self._session + + def close(self) -> None: + """Close the HTTP session and release resources.""" + if self._session is not None: + self._session.close() + self._session = None + logger.debug("HTTP session closed") + + def __enter__(self) -> requests.Session: + """Context manager entry - returns session.""" + return self.session + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit - closes session.""" + self.close() + + +__all__ = [ + "create_session", + "SessionManager", +] diff --git a/edgeai/ondevice-eval-agent/client/inference.py b/edgeai/ondevice-eval-agent/client/inference.py new file mode 100644 index 00000000..1076f02f --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/inference.py @@ -0,0 +1,440 @@ +""" +Inference operations for Model Server Client via gRPC. + +This module handles sending inference requests and processing responses +from both Triton and OpenVINO inference servers using the KServe v2 +gRPC protocol. Tensor data is transferred in binary form, avoiding +the JSON serialization overhead of the REST API. 
+""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Dict, Final, List, Optional + +import numpy as np +import tritonclient.grpc as grpcclient +from numpy.typing import NDArray +from tritonclient.utils import InferenceServerException + +from .config import DEFAULT_INFERENCE_TIMEOUT_SECONDS +from .exceptions import InferenceError + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Constants +# ============================================================================= + +_DEFAULT_INPUT_NAME: Final[str] = "input" +_DEFAULT_DATATYPE: Final[str] = "FP32" +_DEFAULT_TOP_K: Final[int] = 5 + + +# ============================================================================= +# Data Classes +# ============================================================================= + +@dataclass +class InferenceRequest: + """ + Structured inference request for KServe v2 gRPC API. + + Encapsulates all data needed for an inference request. 
+ """ + model_name: str + input_name: str + input_shape: List[int] + input_data: NDArray[np.floating[Any]] + datatype: str = _DEFAULT_DATATYPE + + def to_grpc_inputs(self) -> List[grpcclient.InferInput]: + """Build gRPC InferInput objects from this request.""" + infer_input = grpcclient.InferInput( + self.input_name, + self.input_shape, + self.datatype, + ) + infer_input.set_data_from_numpy(self.input_data.astype(np.float32)) + return [infer_input] + + +@dataclass +class InferenceResult: + """Structured inference result.""" + model_name: str + outputs: List[Dict[str, Any]] + latency: Optional[float] = None + raw_response: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result = { + "model_name": self.model_name, + "outputs": self.outputs, + } + if self.latency is not None: + result["latency"] = self.latency + return result + + +@dataclass +class ClassificationResult: + """Classification prediction result.""" + model_name: str + timestamp: str + num_classes: int + output_name: str + output_shape: List[int] + predictions: List[Dict[str, Any]] + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "timestamp": self.timestamp, + "model_name": self.model_name, + "num_classes": self.num_classes, + "output_name": self.output_name, + "output_shape": self.output_shape, + "top_predictions": self.predictions, + } + + +# ============================================================================= +# Inference Runner +# ============================================================================= + +class InferenceRunner: + """ + Handles inference request execution and response processing via gRPC. + + Uses tritonclient.grpc to send numpy arrays directly over gRPC, + eliminating JSON serialization overhead for tensor data. 
+ """ + + __slots__ = ("_grpc_client", "_timeout", "_class_names") + + def __init__( + self, + grpc_client: grpcclient.InferenceServerClient, + timeout: int = DEFAULT_INFERENCE_TIMEOUT_SECONDS, + ) -> None: + """ + Initialize the inference runner. + + Args: + grpc_client: gRPC inference-server client instance. + timeout: Inference request timeout in seconds. + """ + self._grpc_client = grpc_client + self._timeout = timeout + self._class_names: Optional[List[str]] = None + + # ========================================================================= + # Properties + # ========================================================================= + + @property + def class_names(self) -> Optional[List[str]]: + """Get class names for labeling predictions.""" + return self._class_names + + @class_names.setter + def class_names(self, value: Optional[List[str]]) -> None: + """Set class names for labeling predictions.""" + self._class_names = value + + # ========================================================================= + # Public API - Inference + # ========================================================================= + + def send_inference_request( + self, + image_array: NDArray[np.floating[Any]], + model_name: str, + input_spec: Dict[str, Any], + server_type: str, + measure_latency: bool = False, + ) -> Dict[str, Any]: + """ + Send inference request to inference server via gRPC. + + Args: + image_array: Preprocessed image array with batch dimension. + model_name: Name of the model. + input_spec: Model input specification. + server_type: Server type ('triton', 'openvino', 'unknown'). + measure_latency: Whether to include request latency in result. + + Returns: + Raw inference response dict. + + Raises: + InferenceError: If inference fails. 
+ """ + request = InferenceRequest( + model_name=model_name, + input_name=input_spec.get("name", _DEFAULT_INPUT_NAME), + input_shape=list(image_array.shape), + input_data=image_array, + datatype=input_spec.get("datatype", _DEFAULT_DATATYPE), + ) + + result = self._send_grpc_inference(request, measure_latency) + if result is not None: + return result + + raise InferenceError( + f"gRPC inference failed for model {model_name}", + model_name=model_name, + ) + + def process_prediction( + self, + response: Dict[str, Any], + model_name: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Process the prediction response from inference server. + + For classification models, applies softmax and returns top-k predictions. + For non-classification outputs, returns raw output info. + + Raises: + InferenceError: If response format is invalid. + """ + if not response or "outputs" not in response: + raise InferenceError( + f"Invalid response format for model {model_name}", + model_name=model_name, + ) + + try: + output_data = self._extract_output_data(response) + scores = self._reshape_scores(output_data) + + if self._is_classification_output(scores): + return self._process_classification( + scores, + output_data["name"], + output_data["shape"], + model_name, + ) + + return self._create_raw_output_result( + scores, output_data["name"], output_data["shape"], model_name + ) + + except (KeyError, IndexError, ValueError, TypeError) as e: + raise InferenceError( + f"Error processing prediction: {e}", + model_name=model_name, + details={"cause": str(e)}, + ) from e + + # ========================================================================= + # Private - gRPC Inference + # ========================================================================= + + def _send_grpc_inference( + self, + request: InferenceRequest, + measure_latency: bool, + ) -> Optional[Dict[str, Any]]: + """Send inference using gRPC with binary tensor transfer.""" + try: + grpc_inputs = request.to_grpc_inputs() + + # 
Request all outputs from the model + # (passing None for outputs requests all available outputs) + start_time = time.perf_counter() + grpc_result = self._grpc_client.infer( + model_name=request.model_name, + inputs=grpc_inputs, + client_timeout=self._timeout, + ) + latency = time.perf_counter() - start_time + + # Convert gRPC result to the dict format expected downstream + result = self._grpc_result_to_dict(grpc_result, request.model_name) + + if measure_latency: + result["latency"] = latency + + logger.debug( + f"gRPC inference successful for {request.model_name} " + f"({latency*1000:.1f}ms)" + ) + return result + + except InferenceServerException as e: + logger.warning(f"gRPC inference failed for {request.model_name}: {e}") + return None + except Exception as e: + logger.warning(f"gRPC inference error for {request.model_name}: {e}") + return None + + def _grpc_result_to_dict( + self, + grpc_result: grpcclient.InferResult, + model_name: str, + ) -> Dict[str, Any]: + """ + Convert a gRPC InferResult into the dict format matching the + KServe v2 REST inference response used by downstream code. 
+ """ + outputs: List[Dict[str, Any]] = [] + + # Get the result's response object to enumerate output names + response = grpc_result.get_response() + if hasattr(response, "outputs"): + for out_meta in response.outputs: + out_name = out_meta.name + out_data = grpc_result.as_numpy(out_name) + outputs.append({ + "name": out_name, + "shape": list(out_data.shape), + "datatype": out_meta.datatype, + "data": out_data.flatten().tolist(), + }) + else: + # Fallback: try output_0 + try: + out_data = grpc_result.as_numpy("output_0") + outputs.append({ + "name": "output_0", + "shape": list(out_data.shape), + "datatype": "FP32", + "data": out_data.flatten().tolist(), + }) + except Exception: + pass + + return { + "model_name": model_name, + "outputs": outputs, + } + + # ========================================================================= + # Private - Response Processing + # ========================================================================= + + def _extract_output_data(self, response: Dict[str, Any]) -> Dict[str, Any]: + """Extract first output data from response.""" + outputs = response["outputs"] + + if not isinstance(outputs, list) or len(outputs) == 0: + raise ValueError(f"Unexpected outputs format: {type(outputs)}") + + output = outputs[0] + return { + "name": output.get("name", "output"), + "shape": output.get("shape", []), + "data": output.get("data", []), + } + + def _reshape_scores(self, output_data: Dict[str, Any]) -> NDArray: + """Reshape prediction scores based on output shape.""" + scores = np.array(output_data["data"]) + shape = output_data["shape"] + + if shape: + scores = scores.reshape(shape) + + if len(scores.shape) == 2 and scores.shape[0] == 1: + scores = scores[0] + + return scores + + @staticmethod + def _is_classification_output(scores: NDArray) -> bool: + """Check if output looks like classification (1D array with multiple values).""" + return len(scores.shape) == 1 and len(scores) > 1 + + def _create_raw_output_result( + self, + scores: 
NDArray, + output_name: str, + output_shape: List[int], + model_name: Optional[str], + ) -> Dict[str, Any]: + """Create result dict for non-classification outputs.""" + return { + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "model_name": model_name, + "output_name": output_name, + "output_shape": output_shape, + "raw_output": scores.tolist() if hasattr(scores, "tolist") else scores, + "top_predictions": [], + } + + # ========================================================================= + # Private - Classification Processing + # ========================================================================= + + def _process_classification( + self, + scores: NDArray, + output_name: str, + output_shape: List[int], + model_name: Optional[str], + ) -> Dict[str, Any]: + """Process classification model output.""" + probabilities = self._softmax(scores) + + num_classes = len(probabilities) + top_k = min(_DEFAULT_TOP_K, num_classes) + top_indices = np.argsort(probabilities)[-top_k:][::-1] + top_probs = probabilities[top_indices] + + predictions = [ + self._create_prediction_entry(i, int(idx), float(prob)) + for i, (idx, prob) in enumerate(zip(top_indices, top_probs)) + ] + + return ClassificationResult( + model_name=model_name or "unknown", + timestamp=time.strftime("%Y-%m-%d %H:%M:%S"), + num_classes=num_classes, + output_name=output_name, + output_shape=output_shape, + predictions=predictions, + ).to_dict() + + @staticmethod + def _softmax(scores: NDArray) -> NDArray: + """Apply softmax normalization with numerical stability.""" + exp_scores = np.exp(scores - np.max(scores)) + return exp_scores / np.sum(exp_scores) + + def _create_prediction_entry( + self, + rank: int, + class_id: int, + probability: float, + ) -> Dict[str, Any]: + """Create a single prediction entry with optional class name.""" + class_name = ( + self._class_names[class_id] + if self._class_names and 0 <= class_id < len(self._class_names) + else f"Class_{class_id}" + ) + + return { + "rank": rank + 
1, + "class_id": class_id, + "confidence": probability, + "probability": probability, + "class_name": class_name, + } + + +__all__ = [ + "InferenceRunner", + "InferenceRequest", + "InferenceResult", + "ClassificationResult", +] diff --git a/edgeai/ondevice-eval-agent/client/llm_client.py b/edgeai/ondevice-eval-agent/client/llm_client.py new file mode 100644 index 00000000..6f845466 --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/llm_client.py @@ -0,0 +1,491 @@ +""" +LLM Inference Client. + +Lightweight client for interfacing with LLM serving backends (vLLM, llama.cpp) +over the OpenAI-compatible API. Handles service discovery via environment +variables, health checking, model listing, inference, and performance metrics. + +Both vLLM and llama.cpp expose OpenAI-compatible endpoints: + - GET /v1/models + - POST /v1/chat/completions + - POST /v1/completions + - GET /metrics (Prometheus, vLLM only) + +Service discovery mirrors the Triton pattern. URLs are resolved in order: + OPENAI_API_BASE_URLS -> base URL injected by the on-device Helm chart + (plural; may be comma-separated; may carry a + trailing ``/v1`` path — stripped automatically) + LLM_SERVER_URL -> legacy single-URL fallback + default -> http://localhost:8000 + + LLM_SERVER_TYPE -> "vllm" or "llamacpp" (affects metrics parsing) +""" + +from __future__ import annotations + +import logging +import os +import time +from dataclasses import dataclass, field +from enum import Enum +from functools import lru_cache +from typing import Any, Dict, Final, List, Optional + +import requests +from openai import OpenAI + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Constants +# ============================================================================= + +_ENV_LLM_SERVER_URL: Final[str] = "LLM_SERVER_URL" +_ENV_OPENAI_API_BASE_URLS: Final[str] = "OPENAI_API_BASE_URLS" +_ENV_LLM_SERVER_TYPE: Final[str] = "LLM_SERVER_TYPE" + 
+_DEFAULT_LLM_SERVER_URL: Final[str] = "http://localhost:8000" +_DEFAULT_TIMEOUT: Final[int] = 120 + + +def _resolve_llm_base_url(explicit: Optional[str]) -> str: + """ + Resolve the LLM server base URL from (in order): explicit arg → + ``OPENAI_API_BASE_URLS`` → ``LLM_SERVER_URL`` → localhost default. + + ``OPENAI_API_BASE_URLS`` is injected by the on-device Helm chart in the + OpenWebUI convention: it may be a single URL or a comma-separated list, + and it commonly carries a trailing ``/v1`` path. The first entry is used + and any trailing ``/v1`` is stripped so callers can unconditionally + append OpenAI-style paths (e.g. ``/v1/models``, ``/v1/chat/completions``) + without doubling the prefix. + """ + raw = ( + explicit + or os.environ.get(_ENV_OPENAI_API_BASE_URLS) + or os.environ.get(_ENV_LLM_SERVER_URL) + or _DEFAULT_LLM_SERVER_URL + ) + first = raw.split(",")[0].strip().rstrip("/") + if first.endswith("/v1"): + first = first[:-3] + return first + + +class LLMServerType(str, Enum): + """Supported LLM serving backends.""" + VLLM = "vllm" + LLAMACPP = "llamacpp" + UNKNOWN = "unknown" + + +# ============================================================================= +# Data Classes +# ============================================================================= + +@dataclass(frozen=True) +class LLMModelInfo: + """Information about a served LLM model.""" + id: str + created: Optional[int] = None + owned_by: Optional[str] = None + raw: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class LLMPerformanceMetrics: + """Performance metrics for an LLM inference request.""" + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + time_to_first_token_ms: Optional[float] = None + total_time_ms: float = 0.0 + tokens_per_second: float = 0.0 + + +@dataclass +class LLMServerMetrics: + """Server-level metrics scraped from the Prometheus endpoint.""" + raw: Dict[str, float] = field(default_factory=dict) + avg_generation_throughput_tps: 
Optional[float] = None + avg_prompt_throughput_tps: Optional[float] = None + running_requests: Optional[int] = None + waiting_requests: Optional[int] = None + gpu_cache_usage_pct: Optional[float] = None + + +# ============================================================================= +# LLM Client +# ============================================================================= + +class LLMInferenceClient: + """ + Client for LLM serving backends (vLLM, llama.cpp) over the + OpenAI-compatible REST API. + + Thread Safety: + The OpenAI SDK client is thread-safe. This class is safe for + concurrent use from multiple threads. + """ + + __slots__ = ("_base_url", "_server_type", "_openai", "_timeout") + + def __init__( + self, + base_url: Optional[str] = None, + server_type: Optional[str] = None, + timeout: int = _DEFAULT_TIMEOUT, + ) -> None: + self._base_url = _resolve_llm_base_url(base_url) + + raw_type = ( + server_type + or os.environ.get(_ENV_LLM_SERVER_TYPE, "") + ).lower().strip() + try: + self._server_type = LLMServerType(raw_type) + except ValueError: + self._server_type = LLMServerType.UNKNOWN + + self._timeout = timeout + + self._openai = OpenAI( + base_url=f"{self._base_url}/v1", + api_key="not-needed", + timeout=float(timeout), + ) + + logger.info( + "LLMInferenceClient initialised: base_url=%s, server_type=%s", + self._base_url, + self._server_type.value, + ) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def base_url(self) -> str: + return self._base_url + + @property + def server_type(self) -> LLMServerType: + return self._server_type + + # ------------------------------------------------------------------ + # Health + # ------------------------------------------------------------------ + + def is_healthy(self) -> bool: + """Return True if the LLM server is reachable.""" + try: + resp = requests.get( + 
f"{self._base_url}/v1/models", timeout=10 + ) + return resp.status_code == 200 + except Exception as exc: + logger.debug("LLM health check failed: %s", exc) + return False + + # ------------------------------------------------------------------ + # Model Listing + # ------------------------------------------------------------------ + + def list_models(self) -> List[LLMModelInfo]: + """List models served by the LLM backend.""" + try: + response = self._openai.models.list() + models: List[LLMModelInfo] = [] + for m in response.data: + models.append(LLMModelInfo( + id=m.id, + created=getattr(m, "created", None), + owned_by=getattr(m, "owned_by", None), + raw=m.model_dump() if hasattr(m, "model_dump") else {}, + )) + return models + except Exception as exc: + logger.error("Failed to list LLM models: %s", exc) + raise + + # ------------------------------------------------------------------ + # Inference + # ------------------------------------------------------------------ + + def chat_completion( + self, + model: str, + messages: List[Dict[str, str]], + max_tokens: int = 512, + temperature: float = 0.7, + stream: bool = False, + ) -> Dict[str, Any]: + """ + Send a chat completion request and return the result with timing. + + Returns a dict with keys: response, usage, performance. 
+ """ + t_start = time.perf_counter() + + completion = self._openai.chat.completions.create( + model=model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + stream=False, + ) + + total_time = (time.perf_counter() - t_start) * 1000.0 # ms + + usage = completion.usage + prompt_tokens = usage.prompt_tokens if usage else 0 + completion_tokens = usage.completion_tokens if usage else 0 + total_tokens = usage.total_tokens if usage else 0 + + tokens_per_sec = ( + (completion_tokens / (total_time / 1000.0)) + if total_time > 0 and completion_tokens > 0 + else 0.0 + ) + + response_text = "" + if completion.choices: + response_text = completion.choices[0].message.content or "" + + return { + "response": response_text, + "model": completion.model, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + }, + "performance": { + "total_time_ms": round(total_time, 3), + "tokens_per_second": round(tokens_per_sec, 2), + }, + "finish_reason": ( + completion.choices[0].finish_reason + if completion.choices + else None + ), + } + + def chat_completion_streaming( + self, + model: str, + messages: List[Dict[str, str]], + max_tokens: int = 512, + temperature: float = 0.7, + ) -> Dict[str, Any]: + """ + Send a streaming chat completion to measure time-to-first-token. + + Returns the same dict shape as ``chat_completion()`` with an + additional ``performance.time_to_first_token_ms`` field. Token + usage is best-effort — vLLM returns it via ``stream_options`` + while llama.cpp may not. + """ + t_start = time.perf_counter() + t_first_token: Optional[float] = None + response_parts: List[str] = [] + prompt_tokens = 0 + completion_tokens = 0 + finish_reason: Optional[str] = None + model_id: Optional[str] = None + + # Try with stream_options first (vLLM ≥0.4 supports this). + # Fall back gracefully if the backend rejects the extra kwarg. 
+ stream_kwargs: Dict[str, Any] = dict( + model=model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + stream=True, + ) + try: + stream = self._openai.chat.completions.create( + **stream_kwargs, + stream_options={"include_usage": True}, + ) + except Exception: + # Backend doesn't support stream_options — retry without. + stream = self._openai.chat.completions.create(**stream_kwargs) + + for chunk in stream: + # Record TTFT on the first chunk that carries content. + if chunk.choices: + delta_content = chunk.choices[0].delta.content + if t_first_token is None and delta_content: + t_first_token = time.perf_counter() + if delta_content: + response_parts.append(delta_content) + if chunk.choices[0].finish_reason: + finish_reason = chunk.choices[0].finish_reason + if chunk.model: + model_id = chunk.model + # vLLM sends usage in the final chunk when stream_options is set. + if hasattr(chunk, "usage") and chunk.usage: + prompt_tokens = chunk.usage.prompt_tokens or 0 + completion_tokens = chunk.usage.completion_tokens or 0 + + total_time = (time.perf_counter() - t_start) * 1000.0 + ttft_ms = ( + (t_first_token - t_start) * 1000.0 + if t_first_token is not None + else None + ) + total_tokens = prompt_tokens + completion_tokens + tokens_per_sec = ( + (completion_tokens / (total_time / 1000.0)) + if total_time > 0 and completion_tokens > 0 + else 0.0 + ) + + return { + "response": "".join(response_parts), + "model": model_id or model, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + }, + "performance": { + "total_time_ms": round(total_time, 3), + "tokens_per_second": round(tokens_per_sec, 2), + "time_to_first_token_ms": ( + round(ttft_ms, 3) if ttft_ms is not None else None + ), + }, + "finish_reason": finish_reason, + } + + def text_completion( + self, + model: str, + prompt: str, + max_tokens: int = 512, + temperature: float = 0.7, + ) -> Dict[str, Any]: + """ + Send a text 
completion request (non-chat) and return the result + with timing. + """ + t_start = time.perf_counter() + + completion = self._openai.completions.create( + model=model, + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + ) + + total_time = (time.perf_counter() - t_start) * 1000.0 + + usage = completion.usage + prompt_tokens = usage.prompt_tokens if usage else 0 + completion_tokens = usage.completion_tokens if usage else 0 + total_tokens = usage.total_tokens if usage else 0 + + tokens_per_sec = ( + (completion_tokens / (total_time / 1000.0)) + if total_time > 0 and completion_tokens > 0 + else 0.0 + ) + + response_text = "" + if completion.choices: + response_text = completion.choices[0].text or "" + + return { + "response": response_text, + "model": completion.model, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + }, + "performance": { + "total_time_ms": round(total_time, 3), + "tokens_per_second": round(tokens_per_sec, 2), + }, + "finish_reason": ( + completion.choices[0].finish_reason + if completion.choices + else None + ), + } + + # ------------------------------------------------------------------ + # Server-Level Metrics (Prometheus) + # ------------------------------------------------------------------ + + def get_server_metrics(self) -> Optional[LLMServerMetrics]: + """ + Scrape Prometheus metrics from the LLM server. + + vLLM exposes metrics at GET /metrics. llama.cpp does not have a + standard metrics endpoint, so this returns None for llama.cpp. 
+ """ + try: + resp = requests.get( + f"{self._base_url}/metrics", timeout=10 + ) + if resp.status_code != 200: + logger.debug("Metrics endpoint returned %d", resp.status_code) + return None + + return self._parse_prometheus_metrics(resp.text) + except Exception as exc: + logger.debug("Failed to fetch LLM server metrics: %s", exc) + return None + + @staticmethod + def _parse_prometheus_metrics(text: str) -> LLMServerMetrics: + """Parse Prometheus text format into LLMServerMetrics.""" + raw: Dict[str, float] = {} + + for line in text.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if len(parts) >= 2: + try: + raw[parts[0]] = float(parts[1]) + except ValueError: + continue + + # vLLM-specific gauge names + metrics = LLMServerMetrics(raw=raw) + metrics.avg_generation_throughput_tps = raw.get( + "vllm:avg_generation_throughput_toks_per_s" + ) + metrics.avg_prompt_throughput_tps = raw.get( + "vllm:avg_prompt_throughput_toks_per_s" + ) + metrics.running_requests = ( + int(raw["vllm:num_requests_running"]) + if "vllm:num_requests_running" in raw + else None + ) + metrics.waiting_requests = ( + int(raw["vllm:num_requests_waiting"]) + if "vllm:num_requests_waiting" in raw + else None + ) + metrics.gpu_cache_usage_pct = raw.get("vllm:gpu_cache_usage_perc") + + return metrics + + +# ============================================================================= +# Singleton accessor (mirrors get_client() in mcp/base.py) +# ============================================================================= + +@lru_cache(maxsize=1) +def get_llm_client() -> LLMInferenceClient: + """Get or create the shared LLMInferenceClient singleton.""" + return LLMInferenceClient() diff --git a/edgeai/ondevice-eval-agent/client/metadata.py b/edgeai/ondevice-eval-agent/client/metadata.py new file mode 100644 index 00000000..7b3f34ec --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/metadata.py @@ -0,0 +1,369 @@ +""" +Model metadata 
retrieval and management via gRPC. + +This module handles model metadata operations including input/output +specification detection and thread-safe caching, using the KServe v2 +gRPC protocol. +""" + +from __future__ import annotations + +import logging +import threading +from dataclasses import dataclass, field +from typing import Any, Dict, Final, List, Optional + +import tritonclient.grpc as grpcclient +from tritonclient.utils import InferenceServerException + +from .config import ( + COMMON_CHANNEL_COUNTS, + DEFAULT_INPUT_SPEC, + DEFAULT_OUTPUT_SPEC, + DEFAULT_TARGET_SIZE, + DEFAULT_TIMEOUT_SECONDS, +) +from .grpc_client import model_metadata_to_dict, model_config_to_dict + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Constants +# ============================================================================= + +# Minimum shape length for valid tensor specifications +_MIN_SHAPE_LENGTH: Final[int] = 4 + +# Index positions for NCHW format +_NCHW_CHANNEL_IDX: Final[int] = 1 +_NCHW_HEIGHT_IDX: Final[int] = 2 +_NCHW_WIDTH_IDX: Final[int] = 3 + + +# ============================================================================= +# Data Classes +# ============================================================================= + +@dataclass +class TensorSpec: + """ + Tensor specification for model inputs/outputs. + + Provides structured access to tensor metadata from KServe v2 API. 
+ """ + name: str + shape: List[int] + datatype: str + + # Derived properties (computed from shape) + format: str = "NCHW" + channels: int = 3 + height: int = 224 + width: int = 224 + num_classes: Optional[int] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for backward compatibility.""" + return { + "name": self.name, + "shape": self.shape, + "datatype": self.datatype, + "format": self.format, + "channels": self.channels, + "height": self.height, + "width": self.width, + "num_classes": self.num_classes, + } + + @classmethod + def from_input_info( + cls, + input_info: Dict[str, Any], + default_size: int = DEFAULT_TARGET_SIZE[0], + ) -> "TensorSpec": + """ + Create TensorSpec from KServe v2 input metadata. + + Automatically detects NCHW vs NHWC format based on shape. + """ + name = input_info.get("name", "input") + shape = input_info.get("shape", [-1, 3, *DEFAULT_TARGET_SIZE]) + datatype = input_info.get("datatype", "FP32") + + format_str, channels, height, width = _parse_input_shape(shape, default_size) + + logger.debug(f"Detected input spec: {format_str} {height}x{width}x{channels}") + + return cls( + name=name, + shape=shape, + datatype=datatype, + format=format_str, + channels=channels, + height=height, + width=width, + ) + + @classmethod + def from_output_info(cls, output_info: Dict[str, Any]) -> "TensorSpec": + """Create TensorSpec from KServe v2 output metadata.""" + name = output_info.get("name", "output") + shape = output_info.get("shape", [-1, 1000]) + datatype = output_info.get("datatype", "FP32") + + num_classes = shape[-1] if len(shape) >= 2 and shape[-1] > 0 else None + + logger.debug(f"Detected output spec: {name}, shape={shape}, classes={num_classes}") + + return cls( + name=name, + shape=shape, + datatype=datatype, + num_classes=num_classes, + ) + + +# ============================================================================= +# Shape Parsing Utilities +# 
============================================================================= + +def _parse_input_shape( + shape: List[int], + default_size: int, +) -> tuple[str, int, int, int]: + """ + Parse input shape to extract format and dimensions. + + Handles both NCHW and NHWC formats by detecting channel position. + """ + if len(shape) < _MIN_SHAPE_LENGTH: + return "NCHW", 3, default_size, default_size + + if shape[_NCHW_CHANNEL_IDX] in COMMON_CHANNEL_COUNTS: + return ( + "NCHW", + _resolve_dim(shape[1], 3), + _resolve_dim(shape[2], default_size), + _resolve_dim(shape[3], default_size), + ) + + if shape[-1] in COMMON_CHANNEL_COUNTS: + return ( + "NHWC", + _resolve_dim(shape[-1], 3), + _resolve_dim(shape[1], default_size), + _resolve_dim(shape[2], default_size), + ) + + return ( + "NCHW", + _resolve_dim(shape[1], 3), + _resolve_dim(shape[2], default_size), + _resolve_dim(shape[3], default_size), + ) + + +def _resolve_dim(value: int, default: int) -> int: + """Resolve dimension value, using default for dynamic (-1) dimensions.""" + return value if value > 0 else default + + +# ============================================================================= +# Model Metadata Manager +# ============================================================================= + +class ModelMetadataManager: + """ + Manages model metadata retrieval and caching via gRPC. + + Provides thread-safe access to model metadata from inference servers + with automatic caching to reduce redundant gRPC calls. + + Thread Safety: + All cache operations are protected by a lock. + """ + + __slots__ = ("_grpc_client", "_timeout", "_cache_lock", "_metadata_cache", "_config_cache") + + def __init__( + self, + grpc_client: grpcclient.InferenceServerClient, + timeout: int = DEFAULT_TIMEOUT_SECONDS, + ) -> None: + """ + Initialize the metadata manager. + + Args: + grpc_client: gRPC inference-server client instance. + timeout: Request timeout in seconds. 
+ """ + self._grpc_client = grpc_client + self._timeout = timeout + + # Thread-safe cache + self._cache_lock = threading.Lock() + self._metadata_cache: Dict[str, Dict[str, Any]] = {} + self._config_cache: Dict[str, Dict[str, Any]] = {} + + # ========================================================================= + # Public API - Cache Management + # ========================================================================= + + def clear_cache(self) -> None: + """Clear all cached metadata. Thread-safe.""" + with self._cache_lock: + self._metadata_cache.clear() + self._config_cache.clear() + logger.info("Model metadata cache cleared") + + # ========================================================================= + # Public API - Metadata Retrieval + # ========================================================================= + + def get_metadata( + self, + model_name: str, + use_cache: bool = True, + ) -> Optional[Dict[str, Any]]: + """ + Get detailed model metadata from inference server via gRPC. + + Args: + model_name: Name of the model. + use_cache: Whether to use cached metadata. + + Returns: + Model metadata with input/output specifications, or None on error. 
+ """ + if use_cache: + with self._cache_lock: + if model_name in self._metadata_cache: + return self._metadata_cache[model_name] + + try: + logger.debug(f"Getting model metadata via gRPC for: {model_name}") + grpc_metadata = self._grpc_client.get_model_metadata(model_name) + metadata = model_metadata_to_dict(grpc_metadata) + + with self._cache_lock: + self._metadata_cache[model_name] = metadata + + logger.info(f"Model metadata retrieved and cached for {model_name} (gRPC)") + return metadata + + except InferenceServerException as e: + logger.error(f"gRPC error getting metadata for {model_name}: {e}") + return None + except Exception as e: + logger.error(f"Error getting model metadata for {model_name}: {e}") + return None + + def get_model_config(self, model_name: str, use_cache: bool = True) -> Optional[Dict[str, Any]]: + """Fetch model configuration from the server via gRPC.""" + if use_cache: + with self._cache_lock: + if model_name in self._config_cache: + return self._config_cache[model_name] + + try: + logger.debug(f"Getting model config via gRPC for: {model_name}") + grpc_config = self._grpc_client.get_model_config(model_name) + config = model_config_to_dict(grpc_config) + + with self._cache_lock: + self._config_cache[model_name] = config + + logger.info(f"Model config retrieved and cached for {model_name} (gRPC)") + return config + + except InferenceServerException as e: + logger.error(f"gRPC error getting model config for {model_name}: {e}") + return None + except Exception as e: + logger.error(f"Error getting model config for {model_name}: {e}") + return None + + # ========================================================================= + # Public API - Input/Output Specifications + # ========================================================================= + + def get_input_spec(self, model_name: str) -> Dict[str, Any]: + """Auto-detect model input specifications from server metadata.""" + try: + metadata = self.get_metadata(model_name) + + if not 
metadata: + logger.warning(f"No metadata for {model_name}, using defaults") + return self._get_default_input_spec() + + inputs = metadata.get("inputs", []) + if inputs: + return TensorSpec.from_input_info(inputs[0]).to_dict() + + return self._get_default_input_spec() + + except (KeyError, IndexError, TypeError) as e: + logger.error(f"Error getting input spec for {model_name}: {e}") + return self._get_default_input_spec() + + def get_output_spec(self, model_name: str) -> Dict[str, Any]: + """Auto-detect model output specifications from server metadata.""" + try: + metadata = self.get_metadata(model_name) + + if not metadata: + return self._get_default_output_spec() + + outputs = metadata.get("outputs", []) + if outputs: + return TensorSpec.from_output_info(outputs[0]).to_dict() + + return self._get_default_output_spec() + + except (KeyError, IndexError, TypeError) as e: + logger.error(f"Error getting output spec for {model_name}: {e}") + return self._get_default_output_spec() + + def get_all_output_specs(self, model_name: str) -> List[Dict[str, Any]]: + """Get specifications for ALL model outputs (for multi-output models).""" + try: + metadata = self.get_metadata(model_name) + + if not metadata: + return [self._get_default_output_spec()] + + outputs = metadata.get("outputs", []) + if not outputs: + return [self._get_default_output_spec()] + + return [TensorSpec.from_output_info(output).to_dict() for output in outputs] + + except (KeyError, TypeError) as e: + logger.error(f"Error getting all output specs for {model_name}: {e}") + return [self._get_default_output_spec()] + + def get_input_shape(self, model_name: str) -> tuple[int, int]: + """Get the input shape (height, width) for a specific model.""" + input_spec = self.get_input_spec(model_name) + return (input_spec["height"], input_spec["width"]) + + # ========================================================================= + # Private - Defaults + # 
========================================================================= + + @staticmethod + def _get_default_input_spec() -> Dict[str, Any]: + """Return default input specification.""" + return DEFAULT_INPUT_SPEC.copy() + + @staticmethod + def _get_default_output_spec() -> Dict[str, Any]: + """Return default output specification.""" + return DEFAULT_OUTPUT_SPEC.copy() + + +__all__ = [ + "ModelMetadataManager", + "TensorSpec", +] diff --git a/edgeai/ondevice-eval-agent/client/preprocessing.py b/edgeai/ondevice-eval-agent/client/preprocessing.py new file mode 100644 index 00000000..1d5f5682 --- /dev/null +++ b/edgeai/ondevice-eval-agent/client/preprocessing.py @@ -0,0 +1,405 @@ +""" +Image preprocessing module for Model Server Client. + +This module handles all image loading and preprocessing operations +required before inference. Supports various input formats and +outputs properly formatted numpy arrays for inference servers. +""" + +from __future__ import annotations + +import io +import logging +from dataclasses import dataclass +from typing import Any, BinaryIO, Optional, Union + +import numpy as np +from numpy.typing import NDArray +from PIL import Image + +from .config import ( + DEFAULT_DATA_FORMAT, + DEFAULT_TARGET_SIZE, + PIXEL_VALUE_MAX, + PreprocessingConfig, +) +from .exceptions import ImagePreprocessingError + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Type Aliases +# ============================================================================= + +# Image array after preprocessing: float32 with shape [N, C, H, W] or [N, H, W, C] +ImageArray = NDArray[np.floating[Any]] + +# Supported image input types +ImageInput = Union[bytes, BinaryIO, io.BytesIO, str, Image.Image] + + +# ============================================================================= +# Data Classes +# ============================================================================= + 
+@dataclass(frozen=True) +class PreprocessingParams: + """Immutable preprocessing parameters.""" + width: int + height: int + data_format: str + + @classmethod + def from_input_spec( + cls, + input_spec: Optional[dict[str, Any]] = None, + target_size: Optional[tuple[int, int]] = None, + default_size: tuple[int, int] = DEFAULT_TARGET_SIZE, + default_format: str = DEFAULT_DATA_FORMAT, + ) -> "PreprocessingParams": + """ + Create parameters from input spec and optional overrides. + + Args: + input_spec: Model input specification dict. + target_size: Optional (height, width) override. + default_size: Default (height, width) if not specified. + default_format: Default data format if not specified. + + Returns: + PreprocessingParams instance. + """ + if input_spec: + height = input_spec.get("height", default_size[0]) + width = input_spec.get("width", default_size[1]) + data_format = input_spec.get("format", default_format) + else: + height, width = default_size + data_format = default_format + + # Apply override if provided + if target_size is not None: + height, width = target_size + + return cls(width=width, height=height, data_format=data_format) + + +# ============================================================================= +# Image Preprocessor +# ============================================================================= + +class ImagePreprocessor: + """ + Handles image preprocessing for model inference. + + Supports various input formats (bytes, file paths, file objects, PIL Images) + and outputs properly formatted numpy arrays for inference servers. 
+ + Features: + - Automatic format detection and conversion + - Configurable normalization (ImageNet defaults) + - NCHW/NHWC format conversion + - High-quality LANCZOS resampling + + Example: + >>> preprocessor = ImagePreprocessor() + >>> image_array = preprocessor.preprocess(image_bytes, input_spec) + >>> # image_array.shape: (1, 3, 224, 224) for NCHW format + """ + + __slots__ = ("_config",) + + def __init__(self, config: Optional[PreprocessingConfig] = None) -> None: + """ + Initialize the preprocessor. + + Args: + config: Preprocessing configuration. Uses defaults if not provided. + """ + self._config = config or PreprocessingConfig() + + # ========================================================================= + # Properties + # ========================================================================= + + @property + def config(self) -> PreprocessingConfig: + """Get current preprocessing configuration.""" + return self._config + + @config.setter + def config(self, value: PreprocessingConfig) -> None: + """Set preprocessing configuration.""" + self._config = value + + # ========================================================================= + # Public API + # ========================================================================= + + def update_config(self, updates: dict[str, Any]) -> None: + """ + Update preprocessing configuration with new values. + + Args: + updates: Dictionary of config values to update. + """ + current = self._config.to_dict() + current.update(updates) + self._config = PreprocessingConfig.from_dict(current) + logger.info(f"Updated preprocessing config: {self._config.to_dict()}") + + def get_preprocessing_params( + self, + input_spec: Optional[dict[str, Any]] = None, + target_size: Optional[tuple[int, int]] = None, + ) -> tuple[int, int, str]: + """ + Get preprocessing parameters (width, height, format). + + Args: + input_spec: Model input specification, or None for defaults. + target_size: Override (height, width), or None. 
+ + Returns: + Tuple of (width, height, data_format). + """ + params = PreprocessingParams.from_input_spec( + input_spec=input_spec, + target_size=target_size, + default_size=self._config.target_size, + default_format=self._config.format, + ) + logger.debug(f"Preprocessing params: {params.data_format} {params.height}x{params.width}") + return params.width, params.height, params.data_format + + def preprocess_bytes( + self, + image_bytes: Union[bytes, BinaryIO, io.BytesIO], + input_spec: Optional[dict[str, Any]] = None, + target_size: Optional[tuple[int, int]] = None, + ) -> ImageArray: + """ + Preprocess image from bytes for model inference. + + Args: + image_bytes: bytes, BytesIO, or file-like object containing image data. + input_spec: Optional model input spec for auto-detecting dimensions. + target_size: Optional (height, width) tuple to override auto-detection. + + Returns: + Numpy array ready for inference [1, C, H, W] or [1, H, W, C]. + + Raises: + ImagePreprocessingError: If preprocessing fails. + """ + try: + params = PreprocessingParams.from_input_spec( + input_spec, target_size, self._config.target_size, self._config.format + ) + image = self._load_image_from_bytes(image_bytes) + return self._preprocess_pil_image(image, params) + + except (OSError, ValueError) as e: + raise ImagePreprocessingError( + f"Failed to preprocess image from bytes: {e}", + cause=e, + ) from e + + def preprocess_file( + self, + image_path: str, + input_spec: Optional[dict[str, Any]] = None, + target_size: Optional[tuple[int, int]] = None, + ) -> ImageArray: + """ + Preprocess image from file path for model inference. + + Args: + image_path: Path to image file. + input_spec: Optional model input spec for auto-detecting dimensions. + target_size: Optional (height, width) tuple to override auto-detection. + + Returns: + Numpy array ready for inference [1, C, H, W] or [1, H, W, C]. + + Raises: + ImagePreprocessingError: If preprocessing fails. 
+ """ + try: + params = PreprocessingParams.from_input_spec( + input_spec, target_size, self._config.target_size, self._config.format + ) + image = Image.open(image_path) + return self._preprocess_pil_image(image, params) + + except (OSError, ValueError) as e: + raise ImagePreprocessingError( + f"Failed to preprocess image from file: {e}", + image_source=image_path, + cause=e, + ) from e + + def preprocess( + self, + image_data: ImageInput, + input_spec: Optional[dict[str, Any]] = None, + target_size: Optional[tuple[int, int]] = None, + ) -> ImageArray: + """ + Preprocess image from any supported format. + + This is the recommended unified API for preprocessing. Automatically + detects the input type and delegates to the appropriate handler. + + Args: + image_data: Image bytes, file path, file object, or PIL Image. + input_spec: Optional model input spec for auto-detecting dimensions. + target_size: Optional (height, width) tuple to override. + + Returns: + Numpy array ready for inference [1, C, H, W] or [1, H, W, C]. + + Raises: + ImagePreprocessingError: If preprocessing fails. + """ + try: + params = PreprocessingParams.from_input_spec( + input_spec, target_size, self._config.target_size, self._config.format + ) + image = self._load_image(image_data) + return self._preprocess_pil_image(image, params) + + except (OSError, ValueError) as e: + source = image_data if isinstance(image_data, str) else str(type(image_data)) + raise ImagePreprocessingError( + f"Failed to preprocess image: {e}", + image_source=source, + cause=e, + ) from e + + # ========================================================================= + # Private - Image Loading + # ========================================================================= + + def _load_image(self, image_data: ImageInput) -> Image.Image: + """ + Load image from any supported format. + + Args: + image_data: Image in any supported format. + + Returns: + PIL Image object. 
+ + Raises: + ValueError: If image format is not supported. + """ + if isinstance(image_data, str): + return Image.open(image_data) + + if isinstance(image_data, Image.Image): + return image_data + + if isinstance(image_data, (bytes, io.BytesIO)): + return self._load_image_from_bytes(image_data) + + if hasattr(image_data, "read"): + return self._load_image_from_bytes(image_data) + + raise ValueError(f"Unsupported image_data type: {type(image_data)}") + + def _load_image_from_bytes( + self, + image_bytes: Union[bytes, BinaryIO, io.BytesIO], + ) -> Image.Image: + """ + Load PIL Image from bytes or file-like object. + + Args: + image_bytes: Image data as bytes or file-like object. + + Returns: + PIL Image object. + """ + if isinstance(image_bytes, bytes): + return Image.open(io.BytesIO(image_bytes)) + + if isinstance(image_bytes, io.BytesIO): + return Image.open(image_bytes) + + # File-like object with read() method + content = image_bytes.read() + return Image.open(io.BytesIO(content)) + + # ========================================================================= + # Private - Core Preprocessing + # ========================================================================= + + def _preprocess_pil_image( + self, + image: Image.Image, + params: PreprocessingParams, + ) -> ImageArray: + """ + Core preprocessing logic for PIL images. + + Processing steps: + 1. Convert to RGB (handles grayscale, RGBA, etc.) + 2. Resize to target dimensions using LANCZOS + 3. Convert to float32 and normalize to [0, 1] + 4. Apply ImageNet normalization if configured + 5. Transpose to NCHW format if required + 6. Add batch dimension + + Args: + image: PIL Image to preprocess. + params: Preprocessing parameters. + + Returns: + Preprocessed numpy array with shape [1, C, H, W] or [1, H, W, C]. 
+ """ + # Step 1: Convert to RGB + image = image.convert("RGB") + + # Step 2: Resize with high-quality resampling + image = image.resize((params.width, params.height), Image.Resampling.LANCZOS) + + # Step 3: Convert to numpy and normalize to [0, 1] + image_array = np.array(image, dtype=np.float32) / PIXEL_VALUE_MAX + + # Step 4: Apply ImageNet normalization if configured + if self._config.normalize: + image_array = self._apply_normalization(image_array) + + # Step 5: Convert to NCHW if required (default is HWC from PIL) + if params.data_format == "NCHW": + image_array = np.transpose(image_array, (2, 0, 1)) # HWC -> CHW + + # Step 6: Add batch dimension + image_array = np.expand_dims(image_array, axis=0) + + logger.debug(f"Preprocessed image shape: {image_array.shape}") + return image_array + + def _apply_normalization(self, image_array: NDArray[np.float32]) -> NDArray[np.float32]: + """ + Apply ImageNet normalization: (x - mean) / std. + + Args: + image_array: Image array in HWC format, values in [0, 1]. + + Returns: + Normalized image array. 
+ """ + mean = np.array(self._config.mean, dtype=np.float32) + std = np.array(self._config.std, dtype=np.float32) + return (image_array - mean) / std + + +__all__ = [ + "ImagePreprocessor", + "PreprocessingParams", + "ImageArray", + "ImageInput", +] diff --git a/edgeai/ondevice-eval-agent/frontend/.dockerignore b/edgeai/ondevice-eval-agent/frontend/.dockerignore new file mode 100644 index 00000000..01b8e105 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/.dockerignore @@ -0,0 +1,5 @@ +node_modules +dist +.git +.DS_Store +*.log diff --git a/edgeai/ondevice-eval-agent/frontend/.gitignore b/edgeai/ondevice-eval-agent/frontend/.gitignore new file mode 100644 index 00000000..70e69f71 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/.gitignore @@ -0,0 +1,6 @@ +node_modules +dist +.DS_Store +*.log +.env +.env.local diff --git a/edgeai/ondevice-eval-agent/frontend/Dockerfile.dev b/edgeai/ondevice-eval-agent/frontend/Dockerfile.dev new file mode 100644 index 00000000..89c04763 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/Dockerfile.dev @@ -0,0 +1,21 @@ +# Dev-only image: Vite dev server with HMR. +# For prod you'd do a `pnpm build` + nginx stage instead. +FROM node:20-alpine + +# Pin pnpm to a version that still supports Node 20. Left unpinned, corepack +# fetches the newest pnpm, which requires Node 22 and crashes on cold install. +RUN corepack enable && corepack prepare pnpm@9.15.0 --activate + +WORKDIR /app + +# Copy manifest only so `docker build` caches deps across source changes. +COPY package.json ./ +RUN pnpm install --no-frozen-lockfile + +# Source is bind-mounted at runtime for HMR; this COPY is a fallback +# for `docker run` without a volume. +COPY . . + +EXPOSE 5173 + +CMD ["pnpm", "dev"] diff --git a/edgeai/ondevice-eval-agent/frontend/README.md b/edgeai/ondevice-eval-agent/frontend/README.md new file mode 100644 index 00000000..f1364904 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/README.md @@ -0,0 +1,76 @@ +# ondevice-eval-agent — frontend + +React + TypeScript + Vite + Tailwind. 
Replaces the Jinja + vanilla JS UI in +`webapp/templates/` and `webapp/static/js/` with a proper SPA that consumes +the Flask backend's existing SSE stream at `POST /agent/chat/stream`. + +Design tokens (colors, typography, radii, shadows) are ported verbatim from +`webapp/static/css/variables.css` into `src/index.css` and `tailwind.config.js`, +so the UI shares the ZEDEDA EPI theme with the legacy app. + +## Dev + +```bash +# in ondevice-eval-agent/ +python webapp/app.py # Flask on :8080 + +# in ondevice-eval-agent/frontend/ +pnpm install # or npm install +pnpm dev # Vite on :5173, proxies /agent /llm /core /eval /static → :8080 +``` + +Open http://localhost:5173. + +## Build + +```bash +pnpm build # → dist/ +``` + +Serve `dist/` from any static host, or have Flask serve it. Set +`VITE_API_BASE` at build time if the API is on a different origin. + +## SSE event contract + +Mirrors `webapp/routes/agent.py::_generate_sse_events`: + +| event | payload | +|---------------|------------------------------------------------------------| +| `start` | `{ session_id, warnings? }` | +| `warning` | `{ has_warnings, ... }` | +| (default) | `{ token: string }` — streaming token chunk | +| `tool_start` | `{ name, id }` | +| `tool_end` | `{ name, result }` | +| `done` | `{ response, tool_calls, finish_reason, meta, success }` | +| `complete` | same shape as `done`, used when streaming unavailable | +| `error` | `{ error, limit_exceeded?, enabled? }` | + +Parsed in `src/lib/sse.ts`; reduced into `ChatMessage[]` in +`src/hooks/useStreamingChat.ts`. 
+ +## Layout + +``` +src/ + App.tsx — screen: Sidebar + Header + ChatThread + Composer + SettingsModal + index.css — EPI tokens + prose + hljs + lib/ + api.ts — fetch wrappers + sse.ts — fetch-based SSE parser + types.ts — ChatMessage, ToolCall, AgentStatus + hooks/ + useStreamingChat.ts — send/stop/reset + reducer for SSE events + components/ + layout/{Header,Sidebar}.tsx + ui/{Avatar,AutoResizeTextarea,FloatingCopyButton,ThemeToggle,Toast}.tsx + chat/ + ChatThread.tsx — message list + auto-scroll + Composer.tsx — input + send/stop + WelcomeScreen.tsx — empty state + suggestion pills + UserMessage.tsx + AssistantMessage.tsx — combines tool cards + markdown + cursor + InlineToolCard.tsx — per-tool color, expandable args/result + MarkdownRenderer.tsx — react-markdown + GFM + highlight + CodeBlock.tsx — language label + floating copy button + TypingIndicator.tsx +``` diff --git a/edgeai/ondevice-eval-agent/frontend/index.html b/edgeai/ondevice-eval-agent/frontend/index.html new file mode 100644 index 00000000..54688223 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/index.html @@ -0,0 +1,13 @@ + + + + + + + ZEDEDA Edge AI — Eval Agent + + +
+ + + diff --git a/edgeai/ondevice-eval-agent/frontend/package.json b/edgeai/ondevice-eval-agent/frontend/package.json new file mode 100644 index 00000000..8f94357f --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/package.json @@ -0,0 +1,34 @@ +{ + "name": "ondevice-eval-agent-frontend", + "private": true, + "version": "0.1.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "tsc -b && vite build", + "preview": "vite preview", + "lint": "tsc -b --noEmit" + }, + "dependencies": { + "@fontsource/fira-code": "^5.2.5", + "@fontsource/inter": "^5.2.5", + "clsx": "^2.1.1", + "lucide-react": "^0.468.0", + "react": "^18.3.1", + "react-dom": "^18.3.1", + "react-markdown": "^9.0.3", + "rehype-highlight": "^7.0.2", + "remark-gfm": "^4.0.1" + }, + "devDependencies": { + "@types/react": "^18.3.18", + "@types/react-dom": "^18.3.5", + "@vitejs/plugin-react": "^4.3.4", + "autoprefixer": "^10.4.20", + "highlight.js": "^11.11.1", + "postcss": "^8.5.1", + "tailwindcss": "^3.4.17", + "typescript": "^5.7.2", + "vite": "^6.0.7" + } +} diff --git a/edgeai/ondevice-eval-agent/frontend/postcss.config.js b/edgeai/ondevice-eval-agent/frontend/postcss.config.js new file mode 100644 index 00000000..2aa7205d --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +}; diff --git a/edgeai/ondevice-eval-agent/frontend/src/App.tsx b/edgeai/ondevice-eval-agent/frontend/src/App.tsx new file mode 100644 index 00000000..53858ab2 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/App.tsx @@ -0,0 +1,121 @@ +import { useEffect, useRef, useState } from 'react'; +import { Header } from './components/layout/Header'; +import { Sidebar } from './components/layout/Sidebar'; +import { ChatThread } from './components/chat/ChatThread'; +import { Composer } from './components/chat/Composer'; +import { SessionWarningBanner } from './components/chat/SessionWarningBanner'; +import 
{ SettingsModal } from './components/settings/SettingsModal'; +import { ToastProvider, useToast } from './components/ui/Toast'; +import { ErrorBoundary } from './components/ErrorBoundary'; +import { useStreamingChat } from './hooks/useStreamingChat'; +import { useThreads } from './hooks/useThreads'; +import { useAgentStatus } from './hooks/useAgentStatus'; + +const SIDEBAR_KEY = 'ondevice-eval.sidebarCollapsed'; + +export default function App() { + return ( + + + + + + ); +} + +function Shell() { + const toast = useToast(); + const { active, activeId, ensureActive } = useThreads(); + const [sidebarCollapsed, setSidebarCollapsed] = useState( + () => localStorage.getItem(SIDEBAR_KEY) === 'true', + ); + const [settingsOpen, setSettingsOpen] = useState(false); + const { status, refresh } = useAgentStatus(); + + // Ensure an active thread on first load. + useEffect(() => { + if (!activeId) ensureActive(); + }, [activeId, ensureActive]); + + // Auto-collapse sidebar on narrow viewports so chat stays usable on + // small laptops / phones. User's explicit toggle still wins: we only + // force-collapse when the viewport *becomes* narrow, not every render. + useEffect(() => { + const mq = window.matchMedia('(max-width: 767px)'); + const apply = (narrow: boolean) => { + if (narrow) setSidebarCollapsed(true); + }; + apply(mq.matches); + const handler = (e: MediaQueryListEvent) => apply(e.matches); + mq.addEventListener('change', handler); + return () => mq.removeEventListener('change', handler); + }, []); + + // If the agent reports itself not-configured, nudge the user into Settings + // on the first page load (not when they explicitly dismissed). 
+ const nudgedRef = useRef(false); + useEffect(() => { + if (!status || nudgedRef.current) return; + nudgedRef.current = true; + if (!status.enabled) { + toast.info('No LLM configured — add one in Settings.'); + } + }, [status, toast]); + + const { messages, isStreaming, warning, suggestions, send, stop, clearWarning } = + useStreamingChat(active?.id ?? null); + + const toggleSidebar = () => { + setSidebarCollapsed((v) => { + const n = !v; + localStorage.setItem(SIDEBAR_KEY, String(n)); + return n; + }); + }; + + return ( +
+ {/* Sidebar spans full viewport height — header only covers the + main column on the right. */} + setSettingsOpen(true)} + /> +
+
setSettingsOpen(true)} /> +
+ {warning && ( + + )} + send(t, [])} + /> + send(text, drafts)} + onStop={stop} + isStreaming={isStreaming} + disabled={!active} + /> +
+
+ {/* Isolated boundary so a settings render error doesn't black out the chat. */} + + setSettingsOpen(false)} + onChange={refresh} + /> + +
+ ); +} + diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/ErrorBoundary.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/ErrorBoundary.tsx new file mode 100644 index 00000000..19dc6c6d --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/ErrorBoundary.tsx @@ -0,0 +1,74 @@ +import { Component, type ErrorInfo, type ReactNode } from 'react'; +import { AlertCircle, RotateCw } from 'lucide-react'; + +interface Props { + children: ReactNode; + /** If provided, used as the label in the reset button. */ + resetLabel?: string; +} +interface State { + err: Error | null; +} + +/** + * App-level error boundary. Catches render errors (e.g. a response with an + * unexpected shape being rendered as a React child) and shows a visible + * recovery UI instead of blanking the screen. + */ +export class ErrorBoundary extends Component { + state: State = { err: null }; + + static getDerivedStateFromError(err: Error): State { + return { err }; + } + + componentDidCatch(err: Error, info: ErrorInfo): void { + // Surface to the console for easier debugging without swallowing. + console.error('ErrorBoundary caught:', err, info); + } + + reset = () => this.setState({ err: null }); + + render() { + if (!this.state.err) return this.props.children; + + return ( +
+
+ + Something went wrong. +
+
+          {String(this.state.err?.message ?? this.state.err)}
+        
+ +
+ ); + } +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/AssistantMessage.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/AssistantMessage.tsx new file mode 100644 index 00000000..f2575723 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/AssistantMessage.tsx @@ -0,0 +1,169 @@ +import { AlertCircle } from 'lucide-react'; +import type { ChatMessage, MessageBlock, ToolCall } from '../../lib/types'; +import { MarkdownRenderer } from './MarkdownRenderer'; +import { ToolStepsBlock, ToolStep } from './ToolStepsBlock'; +import { ZThrobber } from './ZThrobber'; +import { MessageCopyButton } from './MessageCopyButton'; +import { useThrottledMarkdown } from '../../hooks/useThrottledMarkdown'; + +/** + * Assistant response — no bubble, no avatar, flows like prose. + * + * Two rendering modes: + * + * 1. Block-aware (new). Messages streamed with useStreamingChat write a + * `blocks` array preserving the order text / tools actually arrived. + * We render each block in place, so a turn that goes + * "Let me check → run_inference → Here's what I found → view_image → ..." + * shows exactly like that instead of bunching every tool up top. + * + * 2. Legacy (fallback). Old persisted messages and image-upload turns + * don't have blocks; for those we render `ToolStepsBlock` above the + * text as before. Nothing in localStorage needs migration. + * + * Streaming note: raw content is throttled (~12 Hz) via useThrottledMarkdown + * to keep react-markdown + rehype-highlight parse cost sane during long + * token bursts. Final value flushes instantly when streaming flips off. + */ +export function AssistantMessage({ message }: { message: ChatMessage }) { + const streaming = message.streaming ?? 
false; + const displayed = useThrottledMarkdown(message.content, streaming); + const hasContent = displayed.trim().length > 0; + const hasTools = message.toolCalls.length > 0; + const runningTool = message.toolCalls.find((t) => t.status === 'running'); + const hasBlocks = Array.isArray(message.blocks) && message.blocks.length > 0; + + // Throbber only while there is nothing to show yet in THIS turn — once + // any block has rendered (text or a tool step) the activity is + // visible via the tool marker or the text itself. + const showThrobber = streaming && !hasContent && !hasTools; + const throbberLabel = runningTool + ? `Running ${runningTool.name}` + : 'Thinking'; + + return ( +
+ {hasBlocks ? ( + + ) : ( + <> + {hasTools && ( + + )} + {hasContent && ( +
+ +
+ )} + + )} + + {showThrobber && } + + {message.error && ( +
+ + {message.error} +
+ )} + + {!streaming && (hasContent || hasBlocks) && ( +
+ +
+ )} +
+ ); +} + +/** + * Renders the ordered block list. Text blocks go through the throttled + * markdown pipeline; tool blocks render as inline ToolStep rows. + * + * The LAST text block during streaming renders the `throttledContent` + * from useThrottledMarkdown so updates feel smooth; earlier text blocks + * are already sealed so they render their stored text as-is. + */ +function BlockList({ + blocks, + toolCalls, + streaming, + throttledContent, +}: { + blocks: MessageBlock[]; + toolCalls: ToolCall[]; + streaming: boolean; + throttledContent: string; +}) { + // Resolve tool refs by id once. O(N) lookup below is fine (N is tiny). + const toolById = new Map(toolCalls.map((t) => [t.id, t])); + + // Find the index of the last text block so we can swap in the throttled + // value for the one that's currently growing. + const lastTextIdx = (() => { + for (let i = blocks.length - 1; i >= 0; i--) { + if (blocks[i].type === 'text') return i; + } + return -1; + })(); + + // Reconstruct the full content (before the last text block) so we can + // subtract it from throttledContent and show just this block's share. + // This matters when a turn has multiple text blocks separated by tool + // calls: earlier blocks are sealed, only the last one streams. + let consumedChars = 0; + for (let i = 0; i < lastTextIdx; i++) { + const b = blocks[i]; + if (b.type === 'text') consumedChars += b.text.length; + } + const streamingTail = streaming ? throttledContent.slice(consumedChars) : ''; + + return ( +
+ {blocks.map((block, idx) => { + if (block.type === 'text') { + const isLastText = idx === lastTextIdx; + const text = + isLastText && streaming ? streamingTail : block.text; + if (!text.trim()) return null; + return ( +
+ +
+ ); + } + + const tc = toolById.get(block.toolCallId); + if (!tc) return null; + return ( + + ); + })} +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/AttachmentPreview.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/AttachmentPreview.tsx new file mode 100644 index 00000000..bc8a73cc --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/AttachmentPreview.tsx @@ -0,0 +1,66 @@ +import { File as FileIcon, X } from 'lucide-react'; +import type { Attachment } from '../../lib/types'; + +/** Inline thumbnail for a draft or sent attachment. */ +export function AttachmentChip({ + attachment, + onRemove, + onOpen, +}: { + attachment: Attachment; + onRemove?: () => void; + onOpen?: () => void; +}) { + const isImage = attachment.kind === 'image'; + + return ( +
+ {isImage && attachment.previewUrl ? ( + + ) : ( +
+ +
+ )} + {attachment.name} + {onRemove && ( + + )} +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/ChatThread.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/ChatThread.tsx new file mode 100644 index 00000000..18120bb1 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/ChatThread.tsx @@ -0,0 +1,103 @@ +import { useEffect, useRef, useState } from 'react'; +import type { ChatMessage } from '../../lib/types'; +import { UserMessage } from './UserMessage'; +import { AssistantMessage } from './AssistantMessage'; +import { WelcomeScreen } from './WelcomeScreen'; +import { ImageModal } from './ImageModal'; +import { isAutoWelcome } from '../../lib/welcomeMessage'; + +interface Props { + messages: ChatMessage[]; + /** Context-aware follow-up prompts shown below the auto-welcome message. */ + suggestions?: string[]; + onPickSuggestion: (text: string) => void; +} + +export function ChatThread({ messages, suggestions, onPickSuggestion }: Props) { + const endRef = useRef(null); + const [preview, setPreview] = useState<{ src: string; alt?: string } | null>( + null, + ); + + const last = messages[messages.length - 1]; + const lastAssistantLen = last?.role === 'assistant' ? last.content.length : 0; + useEffect(() => { + endRef.current?.scrollIntoView({ behavior: 'smooth', block: 'end' }); + }, [messages.length, lastAssistantLen]); + + if (messages.length === 0) { + // The welcome message is fetched and injected by useStreamingChat, so + // this branch only renders very briefly on a brand-new thread while + // the /server-info + /models + /llm/status fetches are in flight. + return ( +
+ +
+ ); + } + + // Show suggestion chips only when the thread is still just the + // auto-welcome (no user messages yet). Once the user sends anything, + // useStreamingChat clears the suggestions list. + const showSuggestions = + (suggestions?.length ?? 0) > 0 && + messages.length === 1 && + isAutoWelcome(messages[0]); + + return ( +
+
+ {messages.map((m) => + m.role === 'user' ? ( + setPreview({ src, alt })} + /> + ) : ( + + ), + )} + + {showSuggestions && ( +
+ {suggestions!.map((s) => ( + + ))} +
+ )} + +
+
+ {preview && ( + setPreview(null)} + /> + )} +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/CodeBlock.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/CodeBlock.tsx new file mode 100644 index 00000000..16194f0b --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/CodeBlock.tsx @@ -0,0 +1,40 @@ +import clsx from 'clsx'; +import { FloatingCopyButton } from '../ui/FloatingCopyButton'; + +/** + * Minimal code block: + * - No full-width header bar + * - Floating copy button in the top-right + * - Explicit dark palette (not theme-flipping CSS vars) so contrast + * stays high regardless of the user's light/dark preference + */ +export function CodeBlock({ + language, + children, + className, +}: { + language?: string; + children: string; + className?: string; +}) { + return ( +
+ {language && ( + + {language} + + )} + +
+        
+          {children}
+        
+      
+
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/Composer.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/Composer.tsx new file mode 100644 index 00000000..feddf2a3 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/Composer.tsx @@ -0,0 +1,210 @@ +import { useRef, useState } from 'react'; +import { ArrowUp, ImageIcon, Square, X } from 'lucide-react'; +import { AutoResizeTextarea } from '../ui/AutoResizeTextarea'; +import { AttachmentChip } from './AttachmentPreview'; +import type { Attachment } from '../../lib/types'; +import { shortId } from '../../lib/ids'; + +export interface DraftAttachment extends Attachment { + file: File; +} + +interface Props { + onSubmit: (text: string, attachments: DraftAttachment[]) => void; + onStop: () => void; + isStreaming: boolean; + disabled?: boolean; +} + +const MAX_MB = 10; + +export function Composer({ onSubmit, onStop, isStreaming, disabled }: Props) { + const [text, setText] = useState(''); + const [drafts, setDrafts] = useState([]); + const textareaRef = useRef(null); + const fileRef = useRef(null); + + const canSend = (text.trim() || drafts.length > 0) && !isStreaming && !disabled; + + const submit = () => { + if (!canSend) return; + onSubmit(text, drafts); + setText(''); + setDrafts([]); + textareaRef.current?.focus(); + }; + + const handleFiles = (files: FileList | null) => { + if (!files) return; + const next: DraftAttachment[] = []; + for (const file of Array.from(files)) { + if (file.size > MAX_MB * 1024 * 1024) { + // Silently drop oversize; toast would require prop drilling. + console.warn(`Dropped ${file.name}: exceeds ${MAX_MB}MB`); + continue; + } + const kind: Attachment['kind'] = file.type.startsWith('image/') + ? 'image' + : 'file'; + const previewUrl = + kind === 'image' ? 
URL.createObjectURL(file) : undefined; + next.push({ + id: shortId('att'), + kind, + name: file.name, + mimeType: file.type, + previewUrl, + file, + }); + } + setDrafts((prev) => [...prev, ...next]); + }; + + const removeDraft = (id: string) => { + setDrafts((prev) => { + const gone = prev.find((d) => d.id === id); + if (gone?.previewUrl) URL.revokeObjectURL(gone.previewUrl); + return prev.filter((d) => d.id !== id); + }); + }; + + return ( +
{ + e.preventDefault(); + }} + onDrop={(e) => { + e.preventDefault(); + handleFiles(e.dataTransfer.files); + }} + > +
+ {drafts.length > 0 && ( +
+ {drafts.map((d) => ( + removeDraft(d.id)} + onOpen={() => { + if (d.previewUrl) window.open(d.previewUrl, '_blank'); + }} + /> + ))} + {drafts.length > 0 && ( + + )} +
+ )} + +
+ + { + handleFiles(e.target.files); + e.target.value = ''; + }} + /> + + setText(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'Enter' && !e.shiftKey) { + e.preventDefault(); + submit(); + } + }} + className="min-h-[24px] flex-1 resize-none border-0 bg-transparent px-2 py-2 text-[15px] outline-none" + style={{ color: 'var(--gray-900)' }} + /> + + {isStreaming ? ( + + ) : ( + + )} +
+

+ Enter to send · Shift+Enter for newline · drag-drop images +

+
+
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/ImageModal.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/ImageModal.tsx new file mode 100644 index 00000000..8962b01d --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/ImageModal.tsx @@ -0,0 +1,43 @@ +import { useEffect } from 'react'; +import { X } from 'lucide-react'; + +export function ImageModal({ + src, + alt, + onClose, +}: { + src: string; + alt?: string; + onClose: () => void; +}) { + useEffect(() => { + const esc = (e: KeyboardEvent) => { + if (e.key === 'Escape') onClose(); + }; + window.addEventListener('keydown', esc); + return () => window.removeEventListener('keydown', esc); + }, [onClose]); + + return ( +
+ {alt} e.stopPropagation()} + /> + +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/MarkdownRenderer.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/MarkdownRenderer.tsx new file mode 100644 index 00000000..dbcbbcf1 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/MarkdownRenderer.tsx @@ -0,0 +1,78 @@ +import { useRef, type ReactNode } from 'react'; +import ReactMarkdown from 'react-markdown'; +import remarkGfm from 'remark-gfm'; +import rehypeHighlight from 'rehype-highlight'; +import { CodeBlock } from './CodeBlock'; +import { FloatingCopyButton } from '../ui/FloatingCopyButton'; +import 'highlight.js/styles/atom-one-dark.css'; + +export function MarkdownRenderer({ content }: { content: string }) { + return ( +
+ typeof c === 'string').join('') + : ''; + return ( + {text.replace(/\n$/, '')} + ); + } + return
{children}
; + }, + + // Wrap tables so they get a floating copy button that exports + // the table as tab-separated values (paste into Sheets/Excel). + table({ children }) { + return {children}; + }, + }} + > + {content} +
+
+ ); +} + +function TableWithCopy({ children }: { children: ReactNode }) { + const ref = useRef(null); + + const getTsv = () => { + const table = ref.current?.querySelector('table'); + if (!table) return ''; + return Array.from(table.rows) + .map((r) => + Array.from(r.cells) + .map((c) => c.innerText.replace(/\t/g, ' ').trim()) + .join('\t'), + ) + .join('\n'); + }; + + // Scroll container holds the table at its natural width so wide tables + // get a horizontal scrollbar instead of stretching the layout; the + // wrapper itself fills 100% of the message column. + return ( +
+ + {children} +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/MessageCopyButton.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/MessageCopyButton.tsx new file mode 100644 index 00000000..2ee399a6 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/MessageCopyButton.tsx @@ -0,0 +1,29 @@ +import { useState } from 'react'; +import { Check, Copy } from 'lucide-react'; + +export function MessageCopyButton({ text }: { text: string }) { + const [copied, setCopied] = useState(false); + const copy = async () => { + try { + await navigator.clipboard.writeText(text); + setCopied(true); + setTimeout(() => setCopied(false), 1500); + } catch { + /* noop */ + } + }; + return ( + + ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/SessionWarningBanner.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/SessionWarningBanner.tsx new file mode 100644 index 00000000..6ec38bd8 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/SessionWarningBanner.tsx @@ -0,0 +1,41 @@ +import { AlertTriangle, X } from 'lucide-react'; +import type { SessionWarning } from '../../hooks/useStreamingChat'; + +export function SessionWarningBanner({ + warning, + onDismiss, +}: { + warning: SessionWarning; + onDismiss: () => void; +}) { + const hard = warning.hard_limit_exceeded; + const near = warning.near_limit_dimensions ?? []; + const text = hard + ? `Session limit reached${warning.exceeded_dimension ? ` (${warning.exceeded_dimension})` : ''}. Start a new chat to continue.` + : near.length > 0 + ? `Session nearing its limit on: ${near.join(', ')}.` + : 'Session warning.'; + + return ( +
+ + {text} + +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/ToolStepsBlock.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/ToolStepsBlock.tsx new file mode 100644 index 00000000..aad47460 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/ToolStepsBlock.tsx @@ -0,0 +1,384 @@ +import { useState } from 'react'; +import { + Activity, + AlertCircle, + Boxes, + ChevronDown, + Compass, + FileText, + Hammer, + Loader2, + ScanSearch, + Sparkles, + Wrench, + Workflow, +} from 'lucide-react'; +import clsx from 'clsx'; +import type { ToolCall } from '../../lib/types'; + +/** + * Grouped timeline of tool calls — renders like Claude's "Completed N + * steps" UI: a single collapsible header with a vertical line running + * through per-step markers. Replaces the flat list of per-tool cards. + */ +export function ToolStepsBlock({ + toolCalls, + isStreaming, +}: { + toolCalls: ToolCall[]; + isStreaming: boolean; +}) { + const [open, setOpen] = useState(true); + if (toolCalls.length === 0) return null; + + const running = toolCalls.some((t) => t.status === 'running'); + const label = running + ? `Working on ${toolCalls.length} step${toolCalls.length === 1 ? '' : 's'}` + : `Completed ${toolCalls.length} step${toolCalls.length === 1 ? '' : 's'}`; + + return ( +
+ + + {open && ( +
+ {/* Vertical timeline line — runs between first and last marker. */} + + {toolCalls.map((tc, i) => ( + + ))} +
+ )} +
+ ); +} + +/** + * Single tool step. Exported so AssistantMessage can render it inline + * between text blocks — with `inline=true` the absolute marker is + * replaced by an inline icon so it doesn't depend on an ancestor + * providing the timeline vertical line. + */ +export function ToolStep({ + tool, + isLast: _isLast, + streaming: _streaming, + inline = false, +}: { + tool: ToolCall; + isLast: boolean; + streaming: boolean; + inline?: boolean; +}) { + const [open, setOpen] = useState(false); + const visual = toolVisual(tool.name); + const Icon = visual.icon; + + const running = tool.status === 'running'; + const errored = tool.status === 'error'; + + const markerStyle = { + background: 'var(--island-bg)', + border: `1.5px solid ${running ? 'var(--zededa-cyan-border)' : errored ? 'rgba(239,68,68,0.4)' : 'var(--gray-200)'}`, + color: running + ? 'var(--zededa-cyan)' + : errored + ? 'var(--color-error)' + : visual.color, + }; + + const markerInner = running ? ( + + ) : errored ? ( + + ) : ( + + ); + + return ( +
+ {inline ? ( + // Inline form: icon sits next to the label, no timeline line. + // Used by AssistantMessage when rendering tool blocks interleaved + // with text so the tool doesn't need an ancestor container + // providing the vertical line. + + {markerInner} + + ) : ( + // Absolute-positioned marker, vertically centered on its row. + // Parent container has pl-6 (24px) and the timeline line sits + // at left:12px, so a 20px marker at left:2px has its centre at + // 12px — exactly on the line. + + {markerInner} + + )} + + + + {/* If this tool returned an image, always render it inline — hiding + it behind the "expand JSON" toggle makes the most useful part of + the output invisible by default (e.g. inference result overlays, + view_image output, DETR visualizations). */} + {(() => { + const img = extractImageFromToolResult(tool.result); + if (!img) return null; + return ( +
+ {img.alt} + {img.caption && ( +
+ {img.caption} +
+ )} +
+ ); + })()} + + {open && ( +
+ {tool.args && Object.keys(tool.args).length > 0 && ( + + )} + {tool.result !== undefined && ( + + )} + {tool.args === undefined && tool.result === undefined && ( +
+ {running ? 'Running…' : 'No details captured.'} +
+ )} +
+ )} +
/** Image payload pulled out of a tool result, in the shape an image element needs. */
interface ExtractedImage {
  src: string;
  alt: string;
  caption?: string;
}

/**
 * Look for an image payload in a tool result. Handles:
 *   { image_base64, mime_type, message? }            // view_image, run_inference
 *   { image, mime_type }                             // generic fallback
 *   { visualization: { image_base64, mime_type } }   // nested helpers
 *
 * Nested containers (`visualization`, `image`, `data`, `output`) are searched
 * first and depth-first, so an image found inside one of them wins over
 * image fields on the enclosing object.
 *
 * @param result - Raw tool result; anything that is not a non-null object yields null.
 * @returns `src` (a data: URL), `alt`, and an optional `caption`, or null when
 *          no image is present.
 */
function extractImageFromToolResult(result: unknown): ExtractedImage | null {
  if (!result || typeof result !== 'object') return null;
  const r = result as Record<string, unknown>;

  // Recurse into common nested containers so we don't miss images that
  // live under `visualization` / `image` / `data`.
  for (const nested of ['visualization', 'image', 'data', 'output']) {
    const v = r[nested];
    if (v && typeof v === 'object') {
      const found = extractImageFromToolResult(v);
      if (found) return found;
    }
  }

  // First non-empty string among the known payload keys.
  const b64 =
    (typeof r.image_base64 === 'string' && r.image_base64) ||
    (typeof r.image === 'string' && r.image) ||
    (typeof r.base64 === 'string' && r.base64);
  if (!b64) return null;

  const mime =
    (typeof r.mime_type === 'string' && r.mime_type) ||
    (typeof r.mimetype === 'string' && r.mimetype) ||
    'image/png';
  // Payloads that are already data: URLs are used verbatim.
  const src = b64.startsWith('data:') ? b64 : `data:${mime};base64,${b64}`;
  const alt =
    (typeof r.description === 'string' && r.description) ||
    (typeof r.message === 'string' && r.message) ||
    'Tool output image';
  // Only show the message as a caption when it is not already the alt text.
  const caption =
    typeof r.message === 'string' && r.message !== alt ? r.message : undefined;

  return { src, alt, caption };
}

/**
 * Return a copy of a tool result with any base64 image payloads replaced
 * with a short marker, so the raw JSON view stays readable.
 *
 * Plain nested objects are processed recursively at every depth. Arrays
 * (top-level or nested) and primitives are returned as-is. The input is
 * never mutated.
 *
 * @param result - Raw tool result of any shape.
 * @returns The stripped copy for objects; the original value otherwise.
 */
function stripImageBase64(result: unknown): unknown {
  if (!result || typeof result !== 'object') return result;
  // Keep arrays as arrays; spreading them through Object.entries would turn
  // a list result into an index-keyed object in the JSON view.
  if (Array.isArray(result)) return result;

  // Object.fromEntries defines own properties, so a key such as "__proto__"
  // in a JSON-decoded result stays visible as data instead of being routed
  // to the prototype setter (which plain `out[k] = v` assignment would do).
  return Object.fromEntries(
    Object.entries(result as Record<string, unknown>).map(
      ([k, v]): [string, unknown] => {
        const isImageKey =
          k === 'image_base64' || k === 'image' || k === 'base64';
        // The length threshold keeps short strings (file names, ids) readable.
        if (isImageKey && typeof v === 'string' && v.length > 200) {
          return [k, `[${v.length} chars of base64 — rendered above]`];
        }
        if (v && typeof v === 'object' && !Array.isArray(v)) {
          return [k, stripImageBase64(v)];
        }
        return [k, v];
      },
    ),
  );
}
+
+ {label} +
+
+        {value}
+      
+
/**
 * Render a tool argument/result value as display text.
 *
 * @param v - Any value; null/undefined become an empty string and strings
 *            are returned unchanged.
 * @returns Pretty-printed JSON (2-space indent), or `String(v)` when the
 *          value cannot be serialised.
 */
function formatJson(v: unknown): string {
  if (v == null) return '';
  if (typeof v === 'string') return v;
  try {
    // JSON.stringify returns undefined (not a string) for functions and
    // symbols; fall back so the declared return type always holds.
    return JSON.stringify(v, null, 2) ?? String(v);
  } catch {
    // Cyclic structures and BigInt values make JSON.stringify throw.
    return String(v);
  }
}

// ------------- tool icon + label mapping -------------

/** Accent colour and icon component used to draw a tool's timeline marker. */
interface Visual {
  color: string;
  icon: typeof Sparkles;
}

/**
 * Pick the marker colour and icon for a tool from keywords in its name.
 * Matching is case-insensitive and the first matching rule wins, so the
 * order of the checks below is significant.
 *
 * @param name - Tool name as reported by the agent.
 * @returns The visual for the first matching rule, or a grey wrench fallback.
 */
function toolVisual(name: string): Visual {
  const n = name.toLowerCase();
  if (n.includes('list') && n.includes('model'))
    return { color: '#5B8DEF', icon: Boxes };
  if (n.includes('analyze') && n.includes('model'))
    return { color: '#A855F7', icon: ScanSearch };
  if (n.includes('metadata')) return { color: '#14B8A6', icon: FileText };
  if (n.includes('input')) return { color: '#06B6D4', icon: Workflow };
  if (n.includes('output') || n.includes('interpret'))
    return { color: '#F59E0B', icon: Compass };
  if (n.includes('integration') || n.includes('frontend'))
    return { color: '#EC4899', icon: Hammer };
  if (n.includes('recommend') || n.includes('next'))
    return { color: '#10B981', icon: Sparkles };
  if (n.includes('predict') || n.includes('infer'))
    return { color: '#6366F1', icon: Activity };
  return { color: '#6B7280', icon: Wrench };
}

/** Hand-written labels for known tools; built once instead of per call. */
const TOOL_LABEL_OVERRIDES: Readonly<Record<string, string>> = {
  list_available_models: 'Listing available models',
  get_model_metadata: 'Fetching model metadata',
  analyze_model_type: 'Analysing model type',
  get_model_input_requirements: 'Checking input requirements',
  get_model_output_interpretation: 'Interpreting model output',
  get_frontend_integration_guide: 'Writing integration snippet',
  recommend_next_steps: 'Recommending next steps',
  get_server_status: 'Checking server status',
};

/**
 * Turn a raw tool name into a human-readable label.
 *
 * @param raw - Tool name, typically snake_case.
 * @returns The hand-written label when one exists; otherwise the name with
 *          underscores replaced by spaces and the first character upper-cased.
 */
function prettyName(raw: string): string {
  // Own-property check: a bare `overrides[raw]` lookup also finds inherited
  // members, so a tool called "constructor" or "toString" would return a
  // function instead of a label.
  if (Object.prototype.hasOwnProperty.call(TOOL_LABEL_OVERRIDES, raw)) {
    return TOOL_LABEL_OVERRIDES[raw];
  }
  return raw.replace(/_/g, ' ').replace(/^\w/, (c) => c.toUpperCase());
}
b/edgeai/ondevice-eval-agent/frontend/src/components/chat/UserMessage.tsx new file mode 100644 index 00000000..2ba6c34c --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/UserMessage.tsx @@ -0,0 +1,42 @@ +import type { ChatMessage } from '../../lib/types'; +import { UserAvatar } from '../ui/Avatar'; +import { AttachmentChip } from './AttachmentPreview'; + +export function UserMessage({ + message, + onOpenImage, +}: { + message: ChatMessage; + onOpenImage: (src: string, alt?: string) => void; +}) { + const atts = message.attachments ?? []; + return ( +
+ +
+ {atts.length > 0 && ( +
+ {atts.map((a) => ( + { + if (a.kind === 'image' && a.previewUrl) { + onOpenImage(a.previewUrl, a.name); + } + }} + /> + ))} +
+ )} + {message.content && ( +
+
+ {message.content} +
+
+ )} +
+
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/WelcomeScreen.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/WelcomeScreen.tsx new file mode 100644 index 00000000..42d9d18a --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/WelcomeScreen.tsx @@ -0,0 +1,78 @@ +import { Boxes, Compass, Hammer, Workflow } from 'lucide-react'; + +const SUGGESTIONS: Array<{ + text: string; + icon: typeof Boxes; + tone: 'green' | 'blue' | 'purple' | 'red'; +}> = [ + { text: 'List the available models', icon: Boxes, tone: 'blue' }, + { text: 'What inputs does the first model need?', icon: Workflow, tone: 'green' }, + { text: 'How do I read its output?', icon: Compass, tone: 'purple' }, + { text: 'Give me a frontend integration snippet', icon: Hammer, tone: 'red' }, +]; + +const TONE_BORDER: Record = { + green: 'rgba(16, 185, 129, 0.3)', + blue: 'rgba(59, 130, 246, 0.3)', + purple: 'rgba(168, 85, 247, 0.3)', + red: 'rgba(239, 68, 68, 0.3)', +}; +const TONE_COLOR: Record = { + green: '#10B981', + blue: '#3B82F6', + purple: '#A855F7', + red: '#EF4444', +}; + +export function WelcomeScreen({ onPick }: { onPick: (text: string) => void }) { + return ( +
+
+
+ +
+

+ Explore on-device models +

+
+

+ Ask about available models, inputs and outputs, or how to wire them into + your app. The agent will call tools and stream its answer back. +

+ +
+ {SUGGESTIONS.map((s) => { + const Icon = s.icon; + return ( + + ); + })} +
+
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/chat/ZThrobber.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/chat/ZThrobber.tsx new file mode 100644 index 00000000..7fe2279f --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/chat/ZThrobber.tsx @@ -0,0 +1,38 @@ +/** + * Thinking indicator: a throbbing ZEDEDA "Z" in brand cyan, paired + * with a softly pulsing label. Replaces the generic three-dot typing + * indicator. + */ +export function ZThrobber({ label = 'Thinking' }: { label?: string }) { + return ( +
+ + + + + + + {label} + + +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/layout/Header.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/layout/Header.tsx new file mode 100644 index 00000000..a542a8e8 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/layout/Header.tsx @@ -0,0 +1,53 @@ +import { StatusDot } from '../ui/StatusDot'; +import type { AgentStatusResponse } from '../../lib/api'; + +/** + * Top bar — the brand/logo lives in the sidebar now. The only thing + * in the header is the LLM status pill, which doubles as a shortcut + * to open Settings. + */ +export function Header({ + status, + onOpenSettings, +}: { + status: AgentStatusResponse | null; + onOpenSettings: () => void; +}) { + const active = status?.llm_router?.active_provider; + const enabled = status?.enabled ?? false; + + const dotState: 'active' | 'warning' | 'offline' = enabled + ? 'active' + : status?.llm_router?.providers && status.llm_router.providers > 0 + ? 'warning' + : 'offline'; + + const label = enabled + ? `${active ?? '?'} · ${status?.model ?? 'no model'}` + : (status?.message ?? 'No LLM configured'); + + return ( +
+ +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/layout/Sidebar.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/layout/Sidebar.tsx new file mode 100644 index 00000000..3e16e319 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/layout/Sidebar.tsx @@ -0,0 +1,579 @@ +import { useMemo, useRef, useState } from 'react'; +import clsx from 'clsx'; +import { + ChevronDown, + ChevronLeft, + ChevronRight, + Download, + MessageSquare, + Moon, + Pencil, + Plus, + PlusCircle, + Search, + Settings, + Sun, + Trash2, + Upload, +} from 'lucide-react'; +import type { Thread } from '../../lib/types'; +import { useThreads } from '../../hooks/useThreads'; +import { useToast } from '../ui/Toast'; + +const PAGE_KEY = 'ondevice-eval.sidebarPageSize'; +const HISTORY_OPEN_KEY = 'ondevice-eval.sidebarHistoryOpen'; + +const SUPPORTED_PAGE_SIZES = [10, 20, 50] as const; +const DEFAULT_PAGE_SIZE = 10; + +function normalizePageSize(value: unknown): number { + const n = Number(value); + return (SUPPORTED_PAGE_SIZES as readonly number[]).includes(n) + ? 
n + : DEFAULT_PAGE_SIZE; +} + +interface Props { + collapsed: boolean; + onToggleCollapsed: () => void; + onOpenSettings: () => void; +} + +export function Sidebar({ collapsed, onToggleCollapsed, onOpenSettings }: Props) { + const { + threads, + activeId, + setActive, + createAndActivate, + remove, + rename, + exportAll, + importAll, + } = useThreads(); + const toast = useToast(); + + const [query, setQuery] = useState(''); + const [page, setPage] = useState(1); + const [pageSize, setPageSize] = useState( + () => normalizePageSize(localStorage.getItem(PAGE_KEY)), + ); + const [historyOpen, setHistoryOpen] = useState( + () => localStorage.getItem(HISTORY_OPEN_KEY) !== 'false', + ); + const [editingId, setEditingId] = useState(null); + const fileRef = useRef(null); + + const filtered = useMemo(() => { + const q = query.trim().toLowerCase(); + if (!q) return threads; + return threads.filter((t) => { + if (t.title.toLowerCase().includes(q)) return true; + return t.messages.some((m) => m.content.toLowerCase().includes(q)); + }); + }, [query, threads]); + + const totalPages = Math.max(1, Math.ceil(filtered.length / pageSize)); + const safePage = Math.min(page, totalPages); + const pageItems = filtered.slice((safePage - 1) * pageSize, safePage * pageSize); + + const toggleHistory = () => { + setHistoryOpen((v) => { + const n = !v; + localStorage.setItem(HISTORY_OPEN_KEY, String(n)); + return n; + }); + }; + + const handleDelete = (t: Thread) => { + if ( + !window.confirm( + `Delete "${t.title}"? ${t.messages.length} message${t.messages.length === 1 ? 
'' : 's'} will be lost.`, + ) + ) + return; + remove(t.id); + toast.info(`Deleted "${t.title}"`); + }; + + const handleExport = () => { + const blob = new Blob([exportAll()], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `chats-${new Date().toISOString().slice(0, 10)}.json`; + a.click(); + URL.revokeObjectURL(url); + toast.success('Exported chats'); + }; + + const handleImportFile = async (f: File) => { + const text = await f.text(); + const { imported, skipped } = importAll(text); + if (imported > 0) + toast.success(`Imported ${imported} thread${imported === 1 ? '' : 's'}`); + if (skipped > 0 && imported === 0) + toast.warning(`Skipped ${skipped} (duplicate ids)`); + }; + + // --------- collapsed rail --------- + // Logo slot swaps Z ↔ expand-chevron based on hover over the whole + // sidebar, not just the button itself. Below the logo sit minimal + // action icons (+, settings, theme) so they're always reachable. + if (collapsed) { + return ( + + ); + } + + // --------- expanded sidebar --------- + return ( + + ); +} + +// ------------------ building blocks ------------------ + +function SidebarRow({ + onClick, + icon, + label, + active, + prominent, + dense, +}: { + onClick: () => void; + icon: React.ReactNode; + label: string; + active?: boolean; + prominent?: boolean; + dense?: boolean; +}) { + return ( + + ); +} + +function RailIcon({ + onClick, + icon, + ...rest +}: { + onClick: () => void; + icon: React.ReactNode; + 'aria-label': string; +}) { + return ( + + ); +} + +function RailThemeToggle() { + const [theme, setTheme] = useState<'light' | 'dark'>( + () => + (document.documentElement.dataset.theme as 'light' | 'dark') || 'light', + ); + const toggle = () => { + const next = theme === 'light' ? 
'dark' : 'light'; + document.documentElement.dataset.theme = next; + localStorage.setItem('theme', next); + setTheme(next); + }; + return ( + + ) : ( + + ) + } + /> + ); +} + +function ThemeRow() { + const [theme, setTheme] = useState<'light' | 'dark'>( + () => (document.documentElement.dataset.theme as 'light' | 'dark') || 'light', + ); + const toggle = () => { + const next = theme === 'light' ? 'dark' : 'light'; + document.documentElement.dataset.theme = next; + localStorage.setItem('theme', next); + setTheme(next); + }; + return ( + + ) : ( + + ) + } + label={theme === 'light' ? 'Dark mode' : 'Light mode'} + /> + ); +} + +function ThreadItem({ + thread, + active, + editing, + onSelect, + onStartEdit, + onFinishEdit, + onDelete, +}: { + thread: Thread; + active: boolean; + editing: boolean; + onSelect: () => void; + onStartEdit: () => void; + onFinishEdit: (next: string | null) => void; + onDelete: () => void; +}) { + const [draft, setDraft] = useState(thread.title); + + return ( +
+ + {editing ? ( + setDraft(e.target.value)} + onBlur={() => onFinishEdit(draft)} + onKeyDown={(e) => { + if (e.key === 'Enter') onFinishEdit(draft); + if (e.key === 'Escape') onFinishEdit(null); + }} + className="flex-1 rounded border bg-transparent px-1 py-0.5 text-sm outline-none" + style={{ borderColor: 'var(--zededa-cyan-border)' }} + /> + ) : ( + + )} + {!editing && ( +
+ + +
+ )} +
+ ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/settings/SettingsModal.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/settings/SettingsModal.tsx new file mode 100644 index 00000000..06f50470 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/settings/SettingsModal.tsx @@ -0,0 +1,633 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { + Check, + CheckCircle2, + Download, + Loader2, + Play, + Plus, + RefreshCw, + ShieldCheck, + Trash2, + Upload, +} from 'lucide-react'; +import { llmApi, type Credential, type RouterStatus } from '../../lib/api'; +import { Modal } from '../ui/Modal'; +import { StatusDot } from '../ui/StatusDot'; +import { useToast } from '../ui/Toast'; + +interface Props { + open: boolean; + onClose: () => void; + /** Called after any mutation so the header badge refetches. */ + onChange?: () => void; +} + +export function SettingsModal({ open, onClose, onChange }: Props) { + const toast = useToast(); + const [creds, setCreds] = useState([]); + const [router, setRouter] = useState(null); + const [loading, setLoading] = useState(false); + const [showAdd, setShowAdd] = useState(false); + const importRef = useRef(null); + + const reload = useCallback(async () => { + setLoading(true); + try { + const [c, r] = await Promise.all([ + llmApi.listCredentials(), + llmApi.routerStatus(), + ]); + setCreds(c.credentials ?? []); + setRouter(r); + } catch (e) { + toast.error(`Failed to load: ${(e as Error).message}`); + } finally { + setLoading(false); + } + }, [toast]); + + useEffect(() => { + if (open) void reload(); + }, [open, reload]); + + // /llm/status returns `active_provider` as the full provider dict, not a + // string. Extract the name for display + comparisons. Defensive in case + // a future payload changes shape. + const activeName = + typeof router?.active_provider === 'string' + ? (router.active_provider as string) + : (router?.active_provider?.name ?? 
null); + + // When the deployment injects EIP_ACCESS_TOKEN, the router + // auto-registers an "edgeai-builtin" openai-compatible provider tagged + // `metadata.builtin = true`. When this provider is present, the agent + // works out of the box with no user API key required — we surface a + // banner and treat custom credentials as optional fallbacks. + const builtinProvider = useMemo( + () => + router?.providers?.find( + (p) => p.metadata?.builtin === true, + ) ?? null, + [router], + ); + const isBuiltinActive = + builtinProvider != null && builtinProvider.name === activeName; + + const activate = async (name: string) => { + try { + await llmApi.activateCredential(name); + toast.success(`Activated ${name}`); + onChange?.(); + await reload(); + } catch (e) { + toast.error(`Activate failed: ${(e as Error).message}`); + } + }; + + const remove = async (name: string) => { + if (!window.confirm(`Delete credential "${name}"?`)) return; + try { + await llmApi.deleteCredential(name); + toast.info(`Deleted ${name}`); + onChange?.(); + await reload(); + } catch (e) { + toast.error(`Delete failed: ${(e as Error).message}`); + } + }; + + const exportCreds = async () => { + try { + const resp = await llmApi.exportCredentials(); + // Save only the portable `bundle` — the rest is response metadata. 
+ const blob = new Blob([JSON.stringify(resp.bundle, null, 2)], { + type: 'application/json', + }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `credentials-${new Date().toISOString().slice(0, 10)}.json`; + a.click(); + URL.revokeObjectURL(url); + if (resp.warning) toast.warning(resp.warning); + else toast.success(`Exported ${resp.credential_count} credential(s)`); + } catch (e) { + toast.error(`Export failed: ${(e as Error).message}`); + } + }; + + const importCreds = async (file: File) => { + try { + const parsed = JSON.parse(await file.text()); + const res = await llmApi.importCredentials(parsed); + if (res.imported_count > 0) await llmApi.activateAll(); + const bits: string[] = []; + if (res.imported_count) + bits.push(`imported ${res.imported_count}`); + if (res.skipped_count) bits.push(`skipped ${res.skipped_count}`); + if (res.error_count) bits.push(`${res.error_count} error(s)`); + toast.success(bits.join(' · ') || 'Nothing to import'); + onChange?.(); + await reload(); + } catch (e) { + toast.error(`Import failed: ${(e as Error).message}`); + } + }; + + return ( + + + + + } + > + {loading && ( +
+ Loading… +
+ )} + + {builtinProvider && ( +
+
+ +
+
+ EdgeAI built-in LLM + {isBuiltinActive ? ' · active' : ' · available'} +
+

+ This deployment ships with a managed OpenAI-compatible + endpoint{builtinProvider.model ? ` (${builtinProvider.model})` : ''} + {' '}authenticated by the platform — no API key needed. You + can still register your own provider below to use as a + fallback. +

+
+
+
+ )} + + {/* Router status */} +
+

+ Router +

+
+ + + {activeName ? ( + <> + Active: {activeName} + + ) : ( + 'No active provider' + )} + + + {router?.providers?.length ?? 0} registered + {router?.routing_strategy + ? ` · strategy ${router.routing_strategy}` + : ''} + +
+
+ + {/* Credentials list */} +
+
+

+ Credentials +

+
+ + + + { + const f = e.target.files?.[0]; + if (f) void importCreds(f); + e.target.value = ''; + }} + /> +
+
+ + {creds.length === 0 ? ( +

+ {builtinProvider + ? 'No additional credentials configured. The built-in EdgeAI provider above is active — adding one here is optional.' + : 'No credentials yet. Add one to enable the agent.'} +

+ ) : ( +
    + {creds.map((c) => ( +
  • +
    +
    + {c.name} + {c.name === activeName && ( + + Active + + )} +
    +
    + {c.provider_type} · {c.model ?? 'no model'}{' '} + {c.url ? `· ${c.url}` : ''} +
    +
    + {c.name !== activeName && ( + + )} + +
  • + ))} +
+ )} +
+ + {showAdd && ( + setShowAdd(false)} + onSaved={async () => { + setShowAdd(false); + onChange?.(); + await reload(); + }} + /> + )} +
+ ); +} + +function AddCredentialForm({ + onCancel, + onSaved, +}: { + onCancel: () => void; + onSaved: () => void; +}) { + const toast = useToast(); + const [name, setName] = useState(''); + const [url, setUrl] = useState(''); + const [apiKey, setApiKey] = useState(''); + const [model, setModel] = useState(''); + const [models, setModels] = useState(null); + const [fetching, setFetching] = useState(false); + const [saving, setSaving] = useState(false); + const [supportsTools, setSupportsTools] = useState(true); + const [providerType, setProviderType] = useState('auto'); + + const detectedType = useMemo(() => { + if (providerType !== 'auto') return providerType; + const u = url.toLowerCase(); + if (!u) return 'anthropic'; + if (u.includes('anthropic.com')) return 'anthropic'; + if (u.includes('openai.com')) return 'openai'; + if (u.includes('googleapis.com')) return 'google'; + if (u.includes('groq.com')) return 'groq'; + if (u.includes('11434') || u.includes('ollama')) return 'ollama'; + return 'openai-compatible'; + }, [url, providerType]); + + const fetchModels = async () => { + setFetching(true); + try { + const { models: list } = await llmApi.fetchModels({ + provider_type: detectedType, + url: url || undefined, + api_key: apiKey || undefined, + }); + setModels(list); + if (list.length === 0) toast.warning('No models returned'); + } catch (e) { + toast.error(`Fetch models failed: ${(e as Error).message}`); + } finally { + setFetching(false); + } + }; + + const save = async () => { + if (!name.trim()) { + toast.warning('Name is required'); + return; + } + setSaving(true); + try { + await llmApi.saveCredential({ + name: name.trim(), + provider_type: providerType === 'auto' ? 
undefined : providerType, + url: url.trim() || undefined, + api_key: apiKey.trim() || undefined, + model: model.trim() || undefined, + supports_tools: supportsTools, + enabled: true, + }); + await llmApi.activateCredential(name.trim()); + toast.success(`Saved & activated ${name.trim()}`); + onSaved(); + } catch (e) { + toast.error(`Save failed: ${(e as Error).message}`); + } finally { + setSaving(false); + } + }; + + return ( +
+

+ New credential +

+
+ + setName(e.target.value)} + placeholder="e.g. padraig-key" + className="form-input" + /> + + + + + + setUrl(e.target.value)} + placeholder="https://api.anthropic.com" + className="form-input" + /> + + + setApiKey(e.target.value)} + type="password" + placeholder="sk-…" + className="form-input" + /> + + +
+ {models && models.length > 0 ? ( + + ) : ( + setModel(e.target.value)} + placeholder="claude-sonnet-4-6" + className="form-input flex-1" + /> + )} + +
+
+
+ +
+ + +
+
+ ); +} + +function Field({ + label, + hint, + className, + children, +}: { + label: string; + hint?: string; + className?: string; + children: React.ReactNode; +}) { + return ( + + ); +} diff --git a/edgeai/ondevice-eval-agent/frontend/src/components/ui/AutoResizeTextarea.tsx b/edgeai/ondevice-eval-agent/frontend/src/components/ui/AutoResizeTextarea.tsx new file mode 100644 index 00000000..73aeb9c0 --- /dev/null +++ b/edgeai/ondevice-eval-agent/frontend/src/components/ui/AutoResizeTextarea.tsx @@ -0,0 +1,22 @@ +import { forwardRef, useEffect, useImperativeHandle, useRef } from 'react'; +import type { TextareaHTMLAttributes } from 'react'; + +interface Props extends TextareaHTMLAttributes { + maxHeight?: number; +} + +export const AutoResizeTextarea = forwardRef( + function AutoResizeTextarea({ maxHeight = 200, value, ...rest }, ref) { + const innerRef = useRef(null); + useImperativeHandle(ref, () => innerRef.current!, []); + + useEffect(() => { + const el = innerRef.current; + if (!el) return; + el.style.height = 'auto'; + el.style.height = `${Math.min(el.scrollHeight, maxHeight)}px`; + }, [value, maxHeight]); + + return