diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..9e2a8d2 --- /dev/null +++ b/Makefile @@ -0,0 +1,7 @@ +.PHONY: test lint + +test: + uv run --dev pytest + +lint: + uv run --dev pre-commit run --all-files diff --git a/README.md b/README.md index a861838..63d2653 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,10 @@ headers of the SSE response yourself. A datastar response consists of 0..N datastar events. There are response classes included to make this easy in all of the supported frameworks. +Each framework also exposes a `@datastar_response` decorator that will wrap +return values (including generators) into the right response class while +preserving sync handlers as sync so frameworks can keep them in their +threadpools. The following examples will work across all supported frameworks when the response class is imported from the appropriate framework package. diff --git a/pyproject.toml b/pyproject.toml index 7c9c70d..8344cdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,12 +48,14 @@ urls.GitHub = "https://github.com/starfederation/datastar-python" dev = [ "django>=4.2.23", "fastapi>=0.116.1", + "httpx>=0.27", "litestar>=2.17", "pre-commit>=4.2", "python-fasthtml>=0.12.25; python_full_version>='3.10'", "quart>=0.20", "sanic>=25.3", "starlette>=0.47.3", + "uvicorn>=0.30", ] [tool.ruff] @@ -88,5 +90,6 @@ lint.ignore = [ "E501", ] lint.per-file-ignores."examples/**/*.py" = [ "ANN", "DTZ005", "PLC0415" ] +lint.per-file-ignores."tests/**/*.py" = [ "ANN", "PLC0415", "PLR2004" ] lint.fixable = [ "ALL" ] lint.pylint.allow-magic-value-types = [ "int", "str" ] diff --git a/src/datastar_py/django.py b/src/datastar_py/django.py index 64e39e9..955a343 100644 --- a/src/datastar_py/django.py +++ b/src/datastar_py/django.py @@ -2,6 +2,7 @@ from collections.abc import Awaitable, Callable, Mapping from functools import wraps +from inspect import isasyncgenfunction, isawaitable, iscoroutinefunction from typing import Any, ParamSpec from django.http import HttpRequest @@ -45,20 +46,30 @@ def __init__( def datastar_response( func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents], -) -> Callable[P, Awaitable[DatastarResponse]]: +) -> Callable[P, Awaitable[DatastarResponse] | DatastarResponse]: """A decorator which wraps a function result in DatastarResponse. Can be used on a sync or async function or generator function. + Preserves the sync/async nature of the decorated function. """ + if iscoroutinefunction(func) or isasyncgenfunction(func): + + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + result = func(*args, **kwargs) + if isawaitable(result): + result = await result + return DatastarResponse(result) + + async_wrapper.__annotations__["return"] = DatastarResponse + return async_wrapper @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: - r = func(*args, **kwargs) - if isinstance(r, Awaitable): - return DatastarResponse(await r) - return DatastarResponse(r) + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + return DatastarResponse(func(*args, **kwargs)) - return wrapper + sync_wrapper.__annotations__["return"] = DatastarResponse + return sync_wrapper def read_signals(request: HttpRequest) -> dict[str, Any] | None: diff --git a/src/datastar_py/litestar.py b/src/datastar_py/litestar.py index 6e3590f..a692536 100644 --- a/src/datastar_py/litestar.py +++ b/src/datastar_py/litestar.py @@ -2,6 +2,7 @@ from collections.abc import Awaitable, Callable, Mapping from functools import wraps +from inspect import isasyncgenfunction, isawaitable, iscoroutinefunction from typing import ( TYPE_CHECKING, Any, @@ -64,21 +65,30 @@ def __init__( def datastar_response( func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents], -) -> Callable[P, Awaitable[DatastarResponse]]: +) -> Callable[P, Awaitable[DatastarResponse] | DatastarResponse]: """A decorator which wraps a function result in DatastarResponse. Can be used on a sync or async function or generator function. + Preserves the sync/async nature of the decorated function. """ + if iscoroutinefunction(func) or isasyncgenfunction(func): + + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + result = func(*args, **kwargs) + if isawaitable(result): + result = await result + return DatastarResponse(result) + + async_wrapper.__annotations__["return"] = DatastarResponse + return async_wrapper @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: - r = func(*args, **kwargs) - if isinstance(r, Awaitable): - return DatastarResponse(await r) - return DatastarResponse(r) - - wrapper.__annotations__["return"] = DatastarResponse - return wrapper + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + return DatastarResponse(func(*args, **kwargs)) + + sync_wrapper.__annotations__["return"] = DatastarResponse + return sync_wrapper async def read_signals(request: Request) -> dict[str, Any] | None: diff --git a/src/datastar_py/quart.py b/src/datastar_py/quart.py index 1866523..6395529 100644 --- a/src/datastar_py/quart.py +++ b/src/datastar_py/quart.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable, Mapping from functools import wraps -from inspect import isasyncgen, isasyncgenfunction, isgenerator +from inspect import isasyncgen, isasyncgenfunction, iscoroutinefunction, isgenerator from typing import Any, ParamSpec from quart import Response, copy_current_request_context, request, stream_with_context @@ -43,20 +43,37 @@ def __init__( def datastar_response( func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents], -) -> Callable[P, Awaitable[DatastarResponse]]: +) -> Callable[P, Awaitable[DatastarResponse] | DatastarResponse]: """A decorator which wraps a function result in DatastarResponse. Can be used on a sync or async function or generator function. + Preserves the sync/async nature of the decorated function. """ + # Async generators require stream_with_context wrapping at decoration time + if isasyncgenfunction(func): - @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: - if isasyncgenfunction(func): + @wraps(func) + async def async_gen_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: return DatastarResponse(stream_with_context(func)(*args, **kwargs)) - return DatastarResponse(await copy_current_request_context(func)(*args, **kwargs)) - wrapper.__annotations__["return"] = DatastarResponse - return wrapper + async_gen_wrapper.__annotations__["return"] = DatastarResponse + return async_gen_wrapper + + if iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + return DatastarResponse(await copy_current_request_context(func)(*args, **kwargs)) + + async_wrapper.__annotations__["return"] = DatastarResponse + return async_wrapper + + @wraps(func) + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + return DatastarResponse(func(*args, **kwargs)) + + sync_wrapper.__annotations__["return"] = DatastarResponse + return sync_wrapper async def read_signals() -> dict[str, Any] | None: diff --git a/src/datastar_py/sanic.py b/src/datastar_py/sanic.py index 878ac8e..4786443 100644 --- a/src/datastar_py/sanic.py +++ b/src/datastar_py/sanic.py @@ -3,7 +3,7 @@ from collections.abc import Awaitable, Callable, Collection, Mapping from contextlib import aclosing, closing from functools import wraps -from inspect import isasyncgen, isgenerator +from inspect import isasyncgen, isawaitable, isgenerator from typing import Any, ParamSpec from sanic import HTTPResponse, Request @@ -70,7 +70,7 @@ def datastar_response( @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse | None: r = func(*args, **kwargs) - if isinstance(r, Awaitable): + if isawaitable(r): return DatastarResponse(await r) if isasyncgen(r): request = args[0] diff --git a/src/datastar_py/starlette.py b/src/datastar_py/starlette.py index 60b4aac..abe1c13 100644 --- a/src/datastar_py/starlette.py +++ b/src/datastar_py/starlette.py @@ -2,6 +2,7 @@ from collections.abc import Awaitable, Callable, Mapping from functools import wraps +from inspect import isasyncgenfunction, isawaitable, iscoroutinefunction from typing import ( TYPE_CHECKING, Any, @@ -53,21 +54,30 @@ def __init__( def datastar_response( func: Callable[P, Awaitable[DatastarEvents] | DatastarEvents], -) -> Callable[P, Awaitable[DatastarResponse]]: +) -> Callable[P, Awaitable[DatastarResponse] | DatastarResponse]: """A decorator which wraps a function result in DatastarResponse. Can be used on a sync or async function or generator function. + Preserves the sync/async nature of the decorated function. """ + if iscoroutinefunction(func) or isasyncgenfunction(func): + + @wraps(func) + async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + result = func(*args, **kwargs) + if isawaitable(result): + result = await result + return DatastarResponse(result) + + async_wrapper.__annotations__["return"] = DatastarResponse + return async_wrapper @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: - r = func(*args, **kwargs) - if isinstance(r, Awaitable): - return DatastarResponse(await r) - return DatastarResponse(r) - - wrapper.__annotations__["return"] = DatastarResponse - return wrapper + def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> DatastarResponse: + return DatastarResponse(func(*args, **kwargs)) + + sync_wrapper.__annotations__["return"] = DatastarResponse + return sync_wrapper async def read_signals(request: Request) -> dict[str, Any] | None: diff --git a/tests/test_datastar_decorator_runtime.py b/tests/test_datastar_decorator_runtime.py new file mode 100644 index 0000000..f7f2038 --- /dev/null +++ b/tests/test_datastar_decorator_runtime.py @@ -0,0 +1,82 @@ +"""Runtime regression test for datastar_response: sync handlers must not stall the event loop.""" + +from __future__ import annotations + +import threading +import time +from typing import Any + +import anyio +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.responses import PlainTextResponse +from starlette.routing import Route + +from datastar_py.sse import ServerSentEventGenerator as SSE + + +@pytest.fixture +def anyio_backend() -> str: + """Limit anyio plugin to asyncio backend for these tests.""" + return "asyncio" + + +async def _fetch( + client: httpx.AsyncClient, path: str, timings: dict[str, float], key: str +) -> None: + start = time.perf_counter() + resp = await client.get(path, timeout=5.0) + timings[key] = time.perf_counter() - start + resp.raise_for_status() + + +@pytest.mark.anyio("asyncio") +async def test_sync_handler_runs_off_event_loop() -> None: + """Sync routes should stay in the threadpool; otherwise they block the event loop.""" + entered = threading.Event() + + from datastar_py.starlette import datastar_response + + @datastar_response + def slow(request) -> Any: + entered.set() + time.sleep(1.0) # if run on the event loop, this blocks other requests + return SSE.patch_signals({"slow": True}) + + async def ping(request) -> PlainTextResponse: + return PlainTextResponse("pong") + + app = Starlette(routes=[Route("/slow", slow), Route("/ping", ping)]) + + config = uvicorn.Config(app, host="127.0.0.1", port=0, log_level="warning", lifespan="off") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + try: + # Wait for server to start and expose sockets + for _ in range(50): + if server.started and getattr(server, "servers", None): + break + await anyio.sleep(0.05) + else: + pytest.fail("Server did not start") + + sock = server.servers[0].sockets[0] + host, port = sock.getsockname()[:2] + base_url = f"http://{host}:{port}" + + async with httpx.AsyncClient(base_url=base_url) as client: + timings: dict[str, float] = {} + async with anyio.create_task_group() as tg: + tg.start_soon(_fetch, client, "/slow", timings, "slow") + await anyio.to_thread.run_sync(entered.wait, 1.0) + tg.start_soon(_fetch, client, "/ping", timings, "ping") + + assert timings["slow"] >= 0.9 + assert timings["ping"] < 0.3, "Ping should not be blocked by slow sync handler" + finally: + server.should_exit = True + thread.join(timeout=2) diff --git a/tests/test_decorator_matrix.py b/tests/test_decorator_matrix.py new file mode 100644 index 0000000..00fd50f --- /dev/null +++ b/tests/test_decorator_matrix.py @@ -0,0 +1,110 @@ +"""Matrix tests for datastar_response across frameworks and callable types.""" + +from __future__ import annotations + +import importlib +import inspect +from collections.abc import Iterable +from typing import Any + +import pytest + +from datastar_py.sse import ServerSentEventGenerator as SSE + +FRAMEWORKS = [ + # name, module path, iterator attribute on response (None means use response directly) + ("starlette", "datastar_py.starlette", "body_iterator"), + ("fasthtml", "datastar_py.fasthtml", "body_iterator"), + ("fastapi", "datastar_py.fastapi", "body_iterator"), + ("litestar", "datastar_py.litestar", "iterator"), + ("django", "datastar_py.django", None), + # Quart and Sanic need full request contexts; covered elsewhere + ("quart", "datastar_py.quart", None), + ("sanic", "datastar_py.sanic", None), +] + + +@pytest.fixture +def anyio_backend() -> str: + """Limit anyio plugin to asyncio backend for these tests.""" + return "asyncio" + + +def _require_module(module_path: str) -> Any: + if not importlib.util.find_spec(module_path): + pytest.skip(f"{module_path} not installed") + return importlib.import_module(module_path) + + +async def _collect_events(resp: Any, iterator_attr: str | None) -> list[Any]: + """Gather events from response regardless of iterator style.""" + iterator = getattr(resp, iterator_attr) if iterator_attr else resp + events: list[Any] = [] + + if hasattr(iterator, "__aiter__"): + async for event in iterator: # type: ignore[has-type] + events.append(event) + elif isinstance(iterator, Iterable): + for event in iterator: + events.append(event) + else: + raise TypeError(f"Cannot iterate response events for {type(resp)}") + + return events + + +@pytest.mark.anyio +@pytest.mark.parametrize("framework_name,module_path,iterator_attr", FRAMEWORKS) +@pytest.mark.parametrize( + "variant", + ["sync_value", "sync_generator", "async_value", "async_generator"], +) +async def test_datastar_response_matrix( + framework_name: str, module_path: str, iterator_attr: str | None, variant: str +) -> None: + """Ensure decorator works for sync/async and generator/non-generator functions.""" + if framework_name in {"quart", "sanic"}: + pytest.skip(f"{framework_name} decorator requires full request context to exercise") + if framework_name == "django": + from django.conf import settings + + if not settings.configured: + settings.configure(DEFAULT_CHARSET="utf-8") + + mod = _require_module(module_path) + datastar_response = mod.datastar_response + DatastarResponse = mod.DatastarResponse + + if variant == "sync_value": + + @datastar_response + def handler() -> Any: + return SSE.patch_signals({"ok": True}) + elif variant == "sync_generator": + + @datastar_response + def handler() -> Any: + yield SSE.patch_signals({"ok": True}) + elif variant == "async_value": + + @datastar_response + async def handler() -> Any: + return SSE.patch_signals({"ok": True}) + else: + + @datastar_response + async def handler() -> Any: + yield SSE.patch_signals({"ok": True}) + + result = handler() + try: + if inspect.isawaitable(result): + result = await result + + assert isinstance(result, DatastarResponse) + events = await _collect_events(result, iterator_attr) + assert events, "Expected at least one event from response iterator" + finally: + # Avoid "coroutine was never awaited" warnings when assertions fail + if inspect.iscoroutine(result): + result.close()