# backend/app/core/storage_utils.py

# Audio MIME type -> file extension used when naming the stored object.
# Keys are bare, lowercase media types (no parameters).
_MIME_TO_EXT: dict[str, str] = {
    "audio/mpeg": "mp3",
    "audio/mp3": "mp3",
    "audio/ogg": "ogg",
    "audio/wav": "wav",
    "audio/wave": "wav",
    "audio/x-wav": "wav",
    "audio/webm": "webm",
    "audio/mp4": "mp4",
    "audio/aac": "aac",
    "audio/flac": "flac",
}


def upload_audio_bytes_to_s3(
    storage: CloudStorage,
    audio_bytes: bytes,
    call_id: UUID,
    mime_type: str | None,
    prefix: str,
) -> str | None:
    """Upload decoded audio bytes to S3 and return the s3:// URI.

    Args:
        storage: CloudStorage instance
        audio_bytes: Raw audio bytes
        call_id: LLM call UUID used as the filename stem
        mime_type: MIME type of the audio (determines file extension). May
            carry parameters or mixed case, e.g. "audio/ogg; codecs=opus".
        prefix: S3 subdirectory, e.g. "llm/tts/audio" or "llm/stt/audio"

    Returns:
        The result of upload_to_object_store: s3:// URI if successful,
        None on failure
    """
    # Media types are case-insensitive and may be followed by ";"-separated
    # parameters. Reduce to the bare lowercase type before the lookup so
    # "Audio/OGG; codecs=opus" maps to "ogg" instead of falling back to "wav".
    bare_type = (mime_type or "").split(";", 1)[0].strip().lower()
    ext = _MIME_TO_EXT.get(bare_type, "wav")
    filename = f"{call_id}.{ext}"
    # The content type is forwarded as supplied (parameters included);
    # only the extension lookup needs the normalized form.
    return upload_to_object_store(
        storage, audio_bytes, filename, prefix, mime_type or "audio/wav"
    )
# backend/app/crud/llm.py

def update_llm_call_input(
    session: Session,
    llm_call_id: UUID,
    s3_uri: str,
) -> None:
    """Overwrite llm_call.input after the STT audio has been uploaded.

    Args:
        session: Open database session used for the lookup and commit.
        llm_call_id: Primary key of the LlmCall row to update.
        s3_uri: New value for ``input``; it is stored verbatim. Despite the
            name, the caller in execute_llm_call passes a JSON record that
            embeds the URI rather than a bare URI string.

    Raises:
        Exception: Whatever ``session.commit()`` raises, re-raised after the
            session has been rolled back.
    """
    db_llm_call = session.get(LlmCall, llm_call_id)
    if not db_llm_call:
        logger.warning(
            f"[update_llm_call_input] LLM call not found | llm_call_id={llm_call_id}"
        )
        return

    db_llm_call.input = s3_uri
    db_llm_call.updated_at = now()
    session.add(db_llm_call)
    try:
        session.commit()
    except Exception:
        # The caller treats a failure here as non-fatal and keeps using the
        # same session, so the failed transaction must be rolled back first.
        session.rollback()
        raise

    logger.info(
        f"[update_llm_call_input] Updated input URI | llm_call_id={llm_call_id}"
    )
def _resolve_presigned_url(self, output) -> None:
    """Swap the s3:// URI in content.uri for a presigned URL in-place.

    Idempotent: only a raw ``s3://`` URI is signed. A value that has already
    been replaced (this method is invoked for both intermediate and final
    callbacks) is left untouched instead of being re-signed or cleared.

    Non-fatal: clears uri on failure so clients don't receive a raw s3:// address.

    Args:
        output: Response output object; only AudioOutput instances whose
            ``content.uri`` is an ``s3://`` string are modified.
    """
    content = getattr(output, "content", None)
    raw_uri = getattr(content, "uri", None)
    # Nothing to sign: no uri at all, or it is no longer a raw object-store
    # address (e.g. it was presigned by an earlier call on the same object).
    if not isinstance(raw_uri, str) or not raw_uri.startswith("s3://"):
        return
    if not isinstance(output, AudioOutput):
        return

    try:
        with Session(engine) as session:
            storage = get_cloud_storage(session, self._context.project_id)
            output.content.uri = storage.get_signed_url(raw_uri, expires_in=3600)
    except Exception as e:
        logger.warning(
            f"[_resolve_presigned_url] Failed to generate presigned URL: {e} | "
            f"job_id={self._context.job_id}",
            exc_info=True,
        )
        # Never hand a raw s3:// address to clients.
        output.content.uri = None
create_llm_chain, update_llm_chain_status from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest from app.models.llm.request import ( @@ -41,7 +48,15 @@ TextContent, TextInput, ) -from app.models.llm.response import LLMCallResponse, LLMResponse, TextOutput, Usage +from app.core.cloud.storage import get_cloud_storage +from app.core.storage_utils import upload_audio_bytes_to_s3 +from app.models.llm.response import ( + AudioOutput, + LLMCallResponse, + LLMResponse, + TextOutput, + Usage, +) from app.services.llm.chain.types import BlockResult from app.services.llm.guardrails import ( list_validators_config, @@ -52,6 +67,7 @@ from app.utils import ( APIResponse, cleanup_temp_file, + download_audio_bytes, get_webhook_secret, resolve_input, send_callback, @@ -553,6 +569,60 @@ def execute_llm_call( error=f"Failed to create LLM call record: {str(e)}" ) + # Upload STT input audio to S3 and overwrite llm_call.input with the URI. + # Failures are non-fatal: the job proceeds and the provider still gets the original input. + if ( + isinstance(query.input, AudioInput) + and query.input.content.format in ("base64", "url") + and llm_call_id + ): + try: + if query.input.content.format == "url": + stt_bytes, dl_error = download_audio_bytes( + query.input.content.value + ) + if dl_error or not stt_bytes: + raise ValueError(dl_error or "Empty audio bytes from URL") + # Rewrite to base64 in-place so the provider resolve path + # reuses these bytes instead of issuing a second HTTP download. 
+ query.input.content.value = base64.b64encode(stt_bytes).decode() + query.input.content.format = "base64" + else: + stt_bytes = base64.b64decode(query.input.content.value) + + storage = get_cloud_storage(session, project_id) + subfolder_path = f"orgs/{organization_id}/{project_id}/audio/stt" + s3_url = upload_audio_bytes_to_s3( + storage, + stt_bytes, + llm_call_id, + query.input.content.mime_type, + subfolder_path, + ) + if s3_url: + stt_input_record = json.dumps( + { + "type": "audio", + "format": "uri", + "mime_type": query.input.content.mime_type, + "size_bytes": len(stt_bytes), + "uri": s3_url, + } + ) + update_llm_call_input(session, llm_call_id, stt_input_record) + logger.info( + f"[execute_llm_call] STT audio uploaded to S3 | llm_call_id={llm_call_id}" + ) + else: + logger.warning( + f"[execute_llm_call] STT S3 upload failed | llm_call_id={llm_call_id}" + ) + except Exception as e: + logger.warning( + f"[execute_llm_call] STT S3 upload error, continuing: {e} | llm_call_id={llm_call_id}", + exc_info=True, + ) + try: provider_instance = get_llm_provider( session=session, @@ -650,6 +720,59 @@ def execute_llm_call( ) if response: + # db_content is what gets persisted — URI-only for TTS to avoid storing + # large base64 payloads. The in-memory response keeps base64 + uri field + # so existing clients continue to receive base64 unchanged. 
+ db_content = ( + response.response.output.model_dump() + if response.response.output + else None + ) + + tts_output = response.response.output + if ( + isinstance(tts_output, AudioOutput) + and tts_output.content.format == "base64" + and llm_call_id + ): + try: + with Session(engine) as s3_session: + storage = get_cloud_storage(s3_session, project_id) + tts_bytes = base64.b64decode(tts_output.content.value) + subfolder_path = f"orgs/{organization_id}/{project_id}/audio/tts" + s3_url = upload_audio_bytes_to_s3( + storage, + tts_bytes, + llm_call_id, + tts_output.content.mime_type, + subfolder_path, + ) + if s3_url: + # Keep base64 in the response object for backward-compatible clients. + # Set uri so execute_job can swap it for a presigned URL. + tts_output.content.uri = s3_url + # Store only the URI in the DB — not the full base64. + db_content = { + "type": "audio", + "content": { + "format": "uri", + "value": s3_url, + "mime_type": tts_output.content.mime_type, + }, + } + logger.info( + f"[execute_llm_call] TTS audio uploaded to S3 | llm_call_id={llm_call_id}" + ) + else: + logger.warning( + f"[execute_llm_call] TTS S3 upload failed, keeping base64 | llm_call_id={llm_call_id}" + ) + except Exception as e: + logger.warning( + f"[execute_llm_call] TTS S3 upload error, keeping base64: {e} | llm_call_id={llm_call_id}", + exc_info=True, + ) + with Session(engine) as session: if llm_call_id: with tracer.start_as_current_span( @@ -662,7 +785,7 @@ def execute_llm_call( session, llm_call_id=llm_call_id, provider_response_id=response.response.provider_response_id, - content=response.response.output.model_dump(), + content=db_content, usage=response.usage.model_dump(), conversation_id=response.response.conversation_id, ) @@ -802,6 +925,25 @@ def execute_job( ) if result.success: + # Swap the s3:// URI in content.uri for a short-lived presigned URL. + # content.value (base64) is untouched — existing clients keep working. 
# backend/app/utils.py

def download_audio_bytes(url: str) -> tuple[bytes | None, str | None]:
    """Download audio from a public URL. Returns (bytes, error).

    Exactly one side of the tuple is meaningful: on success the error is
    None; on failure the bytes are None and the error is a readable message.
    Only absolute http(s) URLs are fetched; anything else is rejected
    without issuing a request.
    """
    from urllib.parse import urlparse

    # The URL is caller-supplied: reject non-http(s) schemes and host-less
    # values up front instead of handing them to the HTTP client.
    # NOTE(review): this does not block private/internal hosts (SSRF) and
    # does not cap the response size; consider both if URLs are untrusted.
    try:
        parsed = urlparse(url or "")
    except ValueError as e:
        return None, f"Invalid audio URL: {str(e)}"
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        return None, "Invalid audio URL: only absolute http(s) URLs are supported"

    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.content, None
    except requests.exceptions.Timeout:
        return None, f"Timed out downloading audio from URL: {url}"
    except requests.exceptions.HTTPError as e:
        return None, f"HTTP {e.response.status_code} downloading audio from URL: {url}"
    except Exception as e:
        return None, f"Failed to download audio from URL: {str(e)}"


def resolve_audio_url(url: str, mime_type: str) -> tuple[str, str | None]:
    """Download audio from a public URL and write to temp file. Returns (file_path, error).

    On failure the path is "" and the error describes what went wrong. The
    temp file is created with ``delete=False``, so the caller owns cleanup
    of a successfully returned path.
    """
    import contextlib
    import os

    audio_bytes, error = download_audio_bytes(url)
    if error:
        return "", error
    if not audio_bytes:
        # An empty body would otherwise be written out as a zero-byte file.
        return "", "Empty audio bytes from URL"

    ext = get_file_extension(mime_type)
    temp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            suffix=ext, delete=False, prefix="audio_"
        ) as tmp:
            temp_path = tmp.name
            tmp.write(audio_bytes)
    except Exception as e:
        # delete=False means a partially written file would be left behind.
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.unlink(temp_path)
        return "", f"Failed to write audio to temp file: {str(e)}"

    logger.info(f"[resolve_audio_url] Downloaded audio to temp file: {temp_path}")
    return temp_path, None