Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions examples/fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
# ///

import asyncio
from collections.abc import AsyncIterator
from datetime import datetime

import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.responses import HTMLResponse
from fastapi.sse import EventSourceResponse, ServerSentEvent

from datastar_py.fastapi import (
DatastarResponse,
ReadSignals,
ServerSentEventGenerator,
)
Expand Down Expand Up @@ -64,7 +65,10 @@ async def read_root():
return HTMLResponse(HTML.replace("CURRENT_TIME", f"{datetime.isoformat(datetime.now())}"))


async def time_updates():
@app.get("/updates", response_class=EventSourceResponse)
async def updates(signals: ReadSignals) -> AsyncIterator[ServerSentEvent]:
# ReadSignals is a dependency that automatically loads the signals from the request
print(signals)
while True:
yield ServerSentEventGenerator.patch_elements(
f"""<span id="currentTime">{datetime.now().isoformat()}"""
Expand All @@ -76,12 +80,5 @@ async def time_updates():
await asyncio.sleep(1)


@app.get("/updates", response_class=StreamingResponse)
async def updates(signals: ReadSignals):
# ReadSignals is a dependency that automatically loads the signals from the request
print(signals)
return DatastarResponse(time_updates())


if __name__ == "__main__":
uvicorn.run(app)
136 changes: 131 additions & 5 deletions src/datastar_py/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,144 @@
from typing import Annotated, Any
from collections.abc import Mapping
from typing import Annotated, Any, Literal, overload

from fastapi import Depends
from fastapi.sse import ServerSentEvent

from .sse import SSE_HEADERS, ServerSentEventGenerator
from .starlette import DatastarResponse, datastar_response, read_signals
from datastar_py import consts
from datastar_py.attributes import SignalValue

from .sse import SSE_HEADERS, BaseServerSentEventGenerator, _HtmlProvider
from .starlette import read_signals

__all__ = [
"SSE_HEADERS",
"DatastarResponse",
"ReadSignals",
"ServerSentEventGenerator",
"datastar_response",
"read_signals",
]


ReadSignals = Annotated[dict[str, Any] | None, Depends(read_signals)]


class ServerSentEventGenerator(BaseServerSentEventGenerator):
__slots__ = ()

@overload
@classmethod
def patch_elements(
cls,
*,
selector: str,
mode: Literal[consts.ElementPatchMode.REMOVE],
use_view_transition: bool | None = None,
event_id: str | None = None,
retry_duration: int | None = None,
) -> ServerSentEvent: ...
@overload
@classmethod
def patch_elements(
cls,
elements: str | _HtmlProvider,
selector: str | None = None,
mode: consts.ElementPatchMode | None = None,
use_view_transition: bool | None = None,
event_id: str | None = None,
retry_duration: int | None = None,
) -> ServerSentEvent: ...
@classmethod
def patch_elements( # noqa: PLR0913 too many arguments
cls,
elements: str | _HtmlProvider | None = None,
selector: str | None = None,
mode: consts.ElementPatchMode | None = None,
use_view_transition: bool | None = None,
namespace: consts.ElementPatchNamespace | None = None,
event_id: str | None = None,
retry_duration: int | None = None,
) -> ServerSentEvent:
result = cls._patch_elements(
elements=elements,
selector=selector,
mode=mode,
use_view_transition=use_view_transition,
event_id=event_id,
namespace=namespace,
retry_duration=retry_duration,
)
return ServerSentEvent(
event=result["event_type"],
id=result["event_id"],
raw_data="\n".join(result["data_lines"]),
retry=result["retry_duration"],
)

@classmethod
def remove_elements(
cls, selector: str, event_id: str | None = None, retry_duration: int | None = None
) -> ServerSentEvent:
result = cls._remove_elements(
selector=selector,
event_id=event_id,
retry_duration=retry_duration,
)
return ServerSentEvent(
event=result["event_type"],
id=result["event_id"],
raw_data="\n".join(result["data_lines"]),
retry=result["retry_duration"],
)

@classmethod
def patch_signals(
cls,
signals: dict[str, SignalValue] | str,
event_id: str | None = None,
only_if_missing: bool | None = None,
retry_duration: int | None = None,
) -> ServerSentEvent:
result = cls._patch_signals(
signals=signals,
event_id=event_id,
only_if_missing=only_if_missing,
retry_duration=retry_duration,
)
return ServerSentEvent(
event=result["event_type"],
id=result["event_id"],
raw_data="\n".join(result["data_lines"]),
retry=result["retry_duration"],
)

@classmethod
def execute_script(
cls,
script: str,
auto_remove: bool = True,
attributes: Mapping[str, str] | list[str] | None = None,
event_id: str | None = None,
retry_duration: int | None = None,
) -> ServerSentEvent:
result = cls._execute_script(
script=script,
auto_remove=auto_remove,
attributes=attributes,
event_id=event_id,
retry_duration=retry_duration,
)
return ServerSentEvent(
event=result["event_type"],
id=result["event_id"],
raw_data="\n".join(result["data_lines"]),
retry=result["retry_duration"],
)

@classmethod
def redirect(cls, location: str) -> ServerSentEvent:
result = cls._redirect(location)
return ServerSentEvent(
event=result["event_type"],
id=result["event_id"],
raw_data="\n".join(result["data_lines"]),
retry=result["retry_duration"],
)
Loading