diff --git a/examples/sd15-image/README.md b/examples/sd15-image/README.md new file mode 100644 index 0000000..3973c2b --- /dev/null +++ b/examples/sd15-image/README.md @@ -0,0 +1,20 @@ +# sd15-image + +Real GPU inference: Stable Diffusion 1.5 text-to-image. Send +`{"prompt": "a lighthouse at dusk"}`, get back a PNG. + +## What it demonstrates + +- `@endpoint(model=HF(...), resources=Resources(vram_gb=6))` — a + HuggingFace model binding (fp16 variant only, via `files=` allow + patterns) downloaded through `ensure_local` and injected into + `setup()` as a constructed `StableDiffusionPipeline` (dtype and + device placement are worker policy, not endpoint code). +- Image outputs through `gen_worker.io.write_image` — PNGs exceed the + inline threshold, so they ride the stored blob_ref path. +- GPU jobs serialize on the worker's GPU semaphore; no CUDA OOM under + concurrent submits. + +Driven end to end by the cozy e2e J4 GPU journey (`task e2e-gpu` in the +e2e repo): cold HF download -> IN_VRAM residency -> 8 real generations -> +billing captured per request. diff --git a/examples/sd15-image/pyproject.toml b/examples/sd15-image/pyproject.toml new file mode 100644 index 0000000..b5b05ca --- /dev/null +++ b/examples/sd15-image/pyproject.toml @@ -0,0 +1,27 @@ +[project] +name = "sd15_image" +version = "1.0.0" +description = "Stable Diffusion 1.5 text-to-image example (real GPU inference)" +requires-python = ">=3.11,<3.13" +dependencies = [ + "gen-worker>=0.8.3", + "diffusers>=0.31", + "transformers>=4.46", + "accelerate>=1.0", + "msgspec>=0.18", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/sd15_image"] + +[tool.gen_worker] +main = "sd15_image.main" + +# The example tracks the in-repo gen-worker (the @endpoint API it uses is +# unreleased); published wheels can't run this source. +[tool.uv.sources] +gen-worker = { path = "../..", editable = true } diff --git a/examples/sd15-image/src/sd15_image/__init__.py b/examples/sd15-image/src/sd15_image/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/sd15-image/src/sd15_image/main.py b/examples/sd15-image/src/sd15_image/main.py new file mode 100644 index 0000000..22881d6 --- /dev/null +++ b/examples/sd15-image/src/sd15_image/main.py @@ -0,0 +1,69 @@ +"""sd15-image — real GPU inference: Stable Diffusion 1.5 text-to-image. + +The platform downloads the HF snapshot through ``ensure_local`` (fp16 +variant only), constructs the pipeline from the ``setup()`` annotation +(dtype + placement are worker policy), and the handler returns the PNG +through the typed-parts output path (>64KB rides blob_ref, not inline). +""" + +from __future__ import annotations + +import msgspec + +from diffusers import StableDiffusionPipeline +from gen_worker import HF, RequestContext, Resources, ValidationError, endpoint +from gen_worker import io as gw_io +from gen_worker.api.types import Asset + + +class SD15Input(msgspec.Struct): + prompt: str = "" + steps: int = 12 + width: int = 512 + height: int = 512 + seed: int = 0 + + +class SD15Output(msgspec.Struct): + image: Asset + width: int + height: int + seed: int + + +@endpoint( + model=HF( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + dtype="fp16", + files=("*.json", "*.txt", "*.fp16.safetensors"), + ), + resources=Resources(vram_gb=6), +) +class SD15Image: + def setup(self, model: StableDiffusionPipeline) -> None: + self._pipe = model + + def generate(self, ctx: RequestContext, p: SD15Input) -> SD15Output: + import torch + + if not str(p.prompt or "").strip(): + raise ValidationError("prompt required") + if p.width % 8 or p.height % 8 or not (64 <= p.width <= 1024) or not (64 <= p.height <= 1024): + raise ValidationError("width/height must be multiples of 8 in [64, 1024]") + steps = max(1, min(int(p.steps), 50)) + + ctx.raise_if_cancelled() + generator = torch.Generator(device=self._pipe.device).manual_seed(int(p.seed)) + image = self._pipe( + p.prompt, + num_inference_steps=steps, + width=p.width, + height=p.height, + generator=generator, + ).images[0] + + asset = gw_io.write_image(ctx, "image", image, format="png") + return SD15Output(image=asset, width=image.width, height=image.height, seed=p.seed) + + +__all__ = ["SD15Image", "SD15Input", "SD15Output"]