diff --git a/examples/fastapi/app.py b/examples/fastapi/app.py index a925025..6aec3fd 100644 --- a/examples/fastapi/app.py +++ b/examples/fastapi/app.py @@ -9,14 +9,15 @@ # /// import asyncio +from collections.abc import AsyncIterator from datetime import datetime import uvicorn from fastapi import FastAPI -from fastapi.responses import HTMLResponse, StreamingResponse +from fastapi.responses import HTMLResponse +from fastapi.sse import EventSourceResponse, ServerSentEvent from datastar_py.fastapi import ( - DatastarResponse, ReadSignals, ServerSentEventGenerator, ) @@ -64,7 +65,10 @@ async def read_root(): return HTMLResponse(HTML.replace("CURRENT_TIME", f"{datetime.isoformat(datetime.now())}")) -async def time_updates(): +@app.get("/updates", response_class=EventSourceResponse) +async def updates(signals: ReadSignals) -> AsyncIterator[ServerSentEvent]: + # ReadSignals is a dependency that automatically loads the signals from the request + print(signals) while True: yield ServerSentEventGenerator.patch_elements( f"""{datetime.now().isoformat()}""" @@ -76,12 +80,5 @@ async def time_updates(): await asyncio.sleep(1) -@app.get("/updates", response_class=StreamingResponse) -async def updates(signals: ReadSignals): - # ReadSignals is a dependency that automatically loads the signals from the request - print(signals) - return DatastarResponse(time_updates()) - - if __name__ == "__main__": uvicorn.run(app) diff --git a/src/datastar_py/fastapi.py b/src/datastar_py/fastapi.py index 7625b04..4a18048 100644 --- a/src/datastar_py/fastapi.py +++ b/src/datastar_py/fastapi.py @@ -1,18 +1,144 @@ -from typing import Annotated, Any +from collections.abc import Mapping +from typing import Annotated, Any, Literal, overload from fastapi import Depends +from fastapi.sse import ServerSentEvent -from .sse import SSE_HEADERS, ServerSentEventGenerator -from .starlette import DatastarResponse, datastar_response, read_signals +from datastar_py import consts +from datastar_py.attributes import SignalValue + +from .sse import SSE_HEADERS, BaseServerSentEventGenerator, _HtmlProvider +from .starlette import read_signals __all__ = [ "SSE_HEADERS", - "DatastarResponse", "ReadSignals", "ServerSentEventGenerator", - "datastar_response", "read_signals", ] ReadSignals = Annotated[dict[str, Any] | None, Depends(read_signals)] + + +class ServerSentEventGenerator(BaseServerSentEventGenerator): + __slots__ = () + + @overload + @classmethod + def patch_elements( + cls, + *, + selector: str, + mode: Literal[consts.ElementPatchMode.REMOVE], + use_view_transition: bool | None = None, + event_id: str | None = None, + retry_duration: int | None = None, + ) -> ServerSentEvent: ... + @overload + @classmethod + def patch_elements( + cls, + elements: str | _HtmlProvider, + selector: str | None = None, + mode: consts.ElementPatchMode | None = None, + use_view_transition: bool | None = None, + event_id: str | None = None, + retry_duration: int | None = None, + ) -> ServerSentEvent: ... + @classmethod + def patch_elements( # noqa: PLR0913 too many arguments + cls, + elements: str | _HtmlProvider | None = None, + selector: str | None = None, + mode: consts.ElementPatchMode | None = None, + use_view_transition: bool | None = None, + namespace: consts.ElementPatchNamespace | None = None, + event_id: str | None = None, + retry_duration: int | None = None, + ) -> ServerSentEvent: + result = cls._patch_elements( + elements=elements, + selector=selector, + mode=mode, + use_view_transition=use_view_transition, + event_id=event_id, + namespace=namespace, + retry_duration=retry_duration, + ) + return ServerSentEvent( + event=result["event_type"], + id=result["event_id"], + raw_data="\n".join(result["data_lines"]), + retry=result["retry_duration"], + ) + + @classmethod + def remove_elements( + cls, selector: str, event_id: str | None = None, retry_duration: int | None = None + ) -> ServerSentEvent: + result = cls._remove_elements( + selector=selector, + event_id=event_id, + retry_duration=retry_duration, + ) + return ServerSentEvent( + event=result["event_type"], + id=result["event_id"], + raw_data="\n".join(result["data_lines"]), + retry=result["retry_duration"], + ) + + @classmethod + def patch_signals( + cls, + signals: dict[str, SignalValue] | str, + event_id: str | None = None, + only_if_missing: bool | None = None, + retry_duration: int | None = None, + ) -> ServerSentEvent: + result = cls._patch_signals( + signals=signals, + event_id=event_id, + only_if_missing=only_if_missing, + retry_duration=retry_duration, + ) + return ServerSentEvent( + event=result["event_type"], + id=result["event_id"], + raw_data="\n".join(result["data_lines"]), + retry=result["retry_duration"], + ) + + @classmethod + def execute_script( + cls, + script: str, + auto_remove: bool = True, + attributes: Mapping[str, str] | list[str] | None = None, + event_id: str | None = None, + retry_duration: int | None = None, + ) -> ServerSentEvent: + result = cls._execute_script( + script=script, + auto_remove=auto_remove, + attributes=attributes, + event_id=event_id, + retry_duration=retry_duration, + ) + return ServerSentEvent( + event=result["event_type"], + id=result["event_id"], + raw_data="\n".join(result["data_lines"]), + retry=result["retry_duration"], + ) + + @classmethod + def redirect(cls, location: str) -> ServerSentEvent: + result = cls._redirect(location) + return ServerSentEvent( + event=result["event_type"], + id=result["event_id"], + raw_data="\n".join(result["data_lines"]), + retry=result["retry_duration"], + ) diff --git a/src/datastar_py/sse.py b/src/datastar_py/sse.py index f1837f8..f0a50e0 100644 --- a/src/datastar_py/sse.py +++ b/src/datastar_py/sse.py @@ -3,7 +3,7 @@ import json from collections.abc import AsyncIterable, Iterable, Mapping from itertools import chain -from typing import Literal, Protocol, TypeAlias, overload, runtime_checkable +from typing import Any, Literal, Protocol, TypeAlias, overload, runtime_checkable from datastar_py import consts from datastar_py.attributes import SignalValue, _escape @@ -37,53 +37,11 @@ class DatastarEvent(str): ) -class ServerSentEventGenerator: +class BaseServerSentEventGenerator: __slots__ = () @classmethod - def _send( - cls, - event_type: consts.EventType, - data_lines: list[str], - event_id: str | None = None, - retry_duration: int | None = None, - ) -> DatastarEvent: - prefix = [f"event: {event_type}"] - - if event_id: - prefix.append(f"id: {event_id}") - - if retry_duration and retry_duration != consts.DEFAULT_SSE_RETRY_DURATION: - prefix.append(f"retry: {retry_duration}") - - data_lines = [f"data: {line}" for line in data_lines] - - return DatastarEvent("\n".join(chain(prefix, data_lines)) + "\n\n") - - @overload - @classmethod - def patch_elements( - cls, - *, - selector: str, - mode: Literal[consts.ElementPatchMode.REMOVE], - use_view_transition: bool | None = None, - event_id: str | None = None, - retry_duration: int | None = None, - ) -> DatastarEvent: ... - @overload - @classmethod - def patch_elements( - cls, - elements: str | _HtmlProvider, - selector: str | None = None, - mode: consts.ElementPatchMode | None = None, - use_view_transition: bool | None = None, - event_id: str | None = None, - retry_duration: int | None = None, - ) -> DatastarEvent: ... - @classmethod - def patch_elements( # noqa: PLR0913 too many arguments + def _patch_elements( # noqa: PLR0913 too many arguments cls, elements: str | _HtmlProvider | None = None, selector: str | None = None, @@ -92,7 +50,7 @@ def patch_elements( # noqa: PLR0913 too many arguments namespace: consts.ElementPatchNamespace | None = None, event_id: str | None = None, retry_duration: int | None = None, - ) -> DatastarEvent: + ) -> dict[str, Any]: if isinstance(elements, _HtmlProvider): elements = elements.__html__() data_lines = [] @@ -115,18 +73,18 @@ def patch_elements( # noqa: PLR0913 too many arguments f"{consts.ELEMENTS_DATALINE_LITERAL} {x}" for x in elements.splitlines() ) - return ServerSentEventGenerator._send( - consts.EventType.PATCH_ELEMENTS, - data_lines, - event_id, - retry_duration, - ) + return { + "event_type": consts.EventType.PATCH_ELEMENTS, + "data_lines": data_lines, + "event_id": event_id, + "retry_duration": retry_duration, + } @classmethod - def remove_elements( + def _remove_elements( cls, selector: str, event_id: str | None = None, retry_duration: int | None = None - ) -> DatastarEvent: - return ServerSentEventGenerator.patch_elements( + ) -> dict[str, Any]: + return cls._patch_elements( selector=selector, mode=consts.ElementPatchMode.REMOVE, event_id=event_id, @@ -134,13 +92,13 @@ def remove_elements( ) @classmethod - def patch_signals( + def _patch_signals( cls, signals: dict[str, SignalValue] | str, event_id: str | None = None, only_if_missing: bool | None = None, retry_duration: int | None = None, - ) -> DatastarEvent: + ) -> dict[str, Any]: data_lines = [] if ( only_if_missing is not None @@ -157,19 +115,22 @@ def patch_signals( f"{consts.SIGNALS_DATALINE_LITERAL} {line}" for line in signals_str.splitlines() ) - return ServerSentEventGenerator._send( - consts.EventType.PATCH_SIGNALS, data_lines, event_id, retry_duration - ) + return { + "event_type": consts.EventType.PATCH_SIGNALS, + "data_lines": data_lines, + "event_id": event_id, + "retry_duration": retry_duration, + } @classmethod - def execute_script( + def _execute_script( cls, script: str, auto_remove: bool = True, attributes: Mapping[str, str] | list[str] | None = None, event_id: str | None = None, retry_duration: int | None = None, - ) -> DatastarEvent: + ) -> dict[str]: attribute_string = "" if auto_remove: attribute_string += ' data-effect="el.remove()"' @@ -182,7 +143,7 @@ def execute_script( attribute_string += " " + " ".join(attributes) script_tag = f"{script}" - return ServerSentEventGenerator.patch_elements( + return cls._patch_elements( script_tag, mode=consts.ElementPatchMode.APPEND, selector="body", @@ -191,9 +152,131 @@ def execute_script( ) @classmethod - def redirect(cls, location: str) -> DatastarEvent: - return cls.execute_script(f"setTimeout(() => window.location = '{location}')") + def _redirect(cls, location: str) -> dict[str, Any]: + return cls._execute_script(f"setTimeout(() => window.location = '{location}')") def _js_bool(b: bool) -> str: return "true" if b else "false" + + +class ServerSentEventGenerator(BaseServerSentEventGenerator): + __slots__ = () + + @classmethod + def _send( + cls, + event_type: consts.EventType, + data_lines: list[str], + event_id: str | None = None, + retry_duration: int | None = None, + ) -> DatastarEvent: + prefix = [f"event: {event_type}"] + + if event_id: + prefix.append(f"id: {event_id}") + + if retry_duration and retry_duration != consts.DEFAULT_SSE_RETRY_DURATION: + prefix.append(f"retry: {retry_duration}") + + data_lines = [f"data: {line}" for line in data_lines] + + return DatastarEvent("\n".join(chain(prefix, data_lines)) + "\n\n") + + @overload + @classmethod + def patch_elements( + cls, + *, + selector: str, + mode: Literal[consts.ElementPatchMode.REMOVE], + use_view_transition: bool | None = None, + event_id: str | None = None, + retry_duration: int | None = None, + ) -> DatastarEvent: ... + @overload + @classmethod + def patch_elements( + cls, + elements: str | _HtmlProvider, + selector: str | None = None, + mode: consts.ElementPatchMode | None = None, + use_view_transition: bool | None = None, + event_id: str | None = None, + retry_duration: int | None = None, + ) -> DatastarEvent: ... + @classmethod + def patch_elements( # noqa: PLR0913 too many arguments + cls, + elements: str | _HtmlProvider | None = None, + selector: str | None = None, + mode: consts.ElementPatchMode | None = None, + use_view_transition: bool | None = None, + namespace: consts.ElementPatchNamespace | None = None, + event_id: str | None = None, + retry_duration: int | None = None, + ) -> DatastarEvent: + return cls._send( + **cls._patch_elements( + elements=elements, + selector=selector, + mode=mode, + use_view_transition=use_view_transition, + namespace=namespace, + event_id=event_id, + retry_duration=retry_duration, + ) + ) + + @classmethod + def remove_elements( + cls, selector: str, event_id: str | None = None, retry_duration: int | None = None + ) -> DatastarEvent: + return cls._send( + **cls._patch_elements( + selector=selector, + mode=consts.ElementPatchMode.REMOVE, + event_id=event_id, + retry_duration=retry_duration, + ) + ) + + @classmethod + def patch_signals( + cls, + signals: dict[str, SignalValue] | str, + event_id: str | None = None, + only_if_missing: bool | None = None, + retry_duration: int | None = None, + ) -> DatastarEvent: + return cls._send( + **cls._patch_signals( + signals=signals, + event_id=event_id, + only_if_missing=only_if_missing, + retry_duration=retry_duration, + ) + ) + + @classmethod + def execute_script( + cls, + script: str, + auto_remove: bool = True, + attributes: Mapping[str, str] | list[str] | None = None, + event_id: str | None = None, + retry_duration: int | None = None, + ) -> DatastarEvent: + return cls._send( + **cls._execute_script( + script=script, + auto_remove=auto_remove, + attributes=attributes, + event_id=event_id, + retry_duration=retry_duration, + ) + ) + + @classmethod + def redirect(cls, location: str) -> DatastarEvent: + return cls._send(**cls._redirect(location))