Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions examples/sd15-image/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# sd15-image

Real GPU inference: Stable Diffusion 1.5 text-to-image. Send
`{"prompt": "a lighthouse at dusk"}`, get back a PNG.

## What it demonstrates

- `@endpoint(model=HF(...), resources=Resources(vram_gb=6))` — a
HuggingFace model binding (fp16 variant only, via `files=` allow
patterns) downloaded through `ensure_local` and injected into
`setup()` as a constructed `StableDiffusionPipeline` (dtype and
device placement are worker policy, not endpoint code).
- Image outputs through `gen_worker.io.write_image` — PNGs at the default
512x512 size typically exceed the inline threshold, so they ride the stored
blob_ref path.
- GPU jobs serialize on the worker's GPU semaphore; no CUDA OOM under
concurrent submits.

Driven end to end by the cozy e2e J4 GPU journey (`task e2e-gpu` in the
e2e repo): cold HF download -> IN_VRAM residency -> 8 real generations ->
billing captured per request.
27 changes: 27 additions & 0 deletions examples/sd15-image/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
[project]
name = "sd15_image"
version = "1.0.0"
description = "Stable Diffusion 1.5 text-to-image example (real GPU inference)"
requires-python = ">=3.11,<3.13"
dependencies = [
"gen-worker>=0.8.3",
"diffusers>=0.31",
"transformers>=4.46",
"accelerate>=1.0",
"msgspec>=0.18",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/sd15_image"]

[tool.gen_worker]
main = "sd15_image.main"

# The example tracks the in-repo gen-worker (the @endpoint API it uses is
# unreleased); published wheels can't run this source.
[tool.uv.sources]
gen-worker = { path = "../..", editable = true }
Empty file.
69 changes: 69 additions & 0 deletions examples/sd15-image/src/sd15_image/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""sd15-image — real GPU inference: Stable Diffusion 1.5 text-to-image.

The platform downloads the HF snapshot through ``ensure_local`` (fp16
variant only), constructs the pipeline from the ``setup()`` annotation
(dtype + placement are worker policy), and the handler returns the PNG
through the typed-parts output path (>64KB rides blob_ref, not inline).
"""

from __future__ import annotations

import msgspec

from diffusers import StableDiffusionPipeline
from gen_worker import HF, RequestContext, Resources, ValidationError, endpoint
from gen_worker import io as gw_io
from gen_worker.api.types import Asset


class SD15Input(msgspec.Struct):
    """Request payload for ``SD15Image.generate``.

    Every field has a default, but ``generate`` rejects a blank prompt, so
    a usable request must at least supply ``prompt``.
    """

    # Text prompt; ``generate`` raises ValidationError when it is empty or
    # whitespace-only.
    prompt: str = ""
    # Requested number of inference steps; ``generate`` clamps it to 1..50.
    steps: int = 12
    # Output size in pixels; ``generate`` requires each to be a multiple of 8
    # within [64, 1024].
    width: int = 512
    height: int = 512
    # Seed handed to the torch generator in ``generate``.
    seed: int = 0


class SD15Output(msgspec.Struct):
    """Response payload returned by ``SD15Image.generate``."""

    # Asset returned by ``gen_worker.io.write_image`` for the generated PNG.
    image: Asset
    # Dimensions read back from the generated image.
    width: int
    height: int
    # The seed from the request, echoed back to the caller.
    seed: int


@endpoint(
    model=HF(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        dtype="fp16",
        files=("*.json", "*.txt", "*.fp16.safetensors"),
    ),
    resources=Resources(vram_gb=6),
)
class SD15Image:
    """Text-to-image endpoint backed by a Stable Diffusion 1.5 pipeline."""

    def setup(self, model: StableDiffusionPipeline) -> None:
        """Keep the pipeline passed in so ``generate`` can call it."""
        self._pipe = model

    def generate(self, ctx: RequestContext, p: SD15Input) -> SD15Output:
        """Render one image for ``p.prompt`` and return it as a PNG asset.

        Args:
            ctx: Request context; used for the cancellation check and
                passed to ``gen_worker.io.write_image``.
            p: Request payload. ``steps`` is clamped to 1..50 rather than
                rejected.

        Returns:
            The written image asset, the dimensions of the generated image
            and the seed that was requested.

        Raises:
            ValidationError: If the prompt is blank, if width/height are not
                multiples of 8 in [64, 1024], or if the seed is outside the
                range ``torch.Generator.manual_seed`` accepts.
        """
        # torch is only needed at request time, so it is imported here.
        import torch

        if not str(p.prompt or "").strip():
            raise ValidationError("prompt required")
        if p.width % 8 or p.height % 8 or not (64 <= p.width <= 1024) or not (64 <= p.height <= 1024):
            raise ValidationError("width/height must be multiples of 8 in [64, 1024]")

        # manual_seed() raises RuntimeError outside this range; reject such
        # seeds up front so a bad request is reported as a validation error
        # like the other input checks, not as an internal failure.
        seed = int(p.seed)
        if not (-(2**63) <= seed <= 2**64 - 1):
            raise ValidationError("seed must be in [-2**63, 2**64 - 1]")

        steps = max(1, min(int(p.steps), 50))

        # Check for cancellation once, before starting the pipeline call.
        ctx.raise_if_cancelled()
        # Seed a generator on the pipeline's own device and pass it to the
        # pipeline call.
        generator = torch.Generator(device=self._pipe.device).manual_seed(seed)
        image = self._pipe(
            p.prompt,
            num_inference_steps=steps,
            width=p.width,
            height=p.height,
            generator=generator,
        ).images[0]

        asset = gw_io.write_image(ctx, "image", image, format="png")
        return SD15Output(image=asset, width=image.width, height=image.height, seed=p.seed)


# Names exported by ``from sd15_image.main import *``.
__all__ = ["SD15Image", "SD15Input", "SD15Output"]
Loading