Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/gen_worker/cli/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,14 @@ def _label(c: _SelectedFunction) -> str:
if c.cls is not None and c.cls.__name__ == cls_name
]
if method_name:
# fn_name is the canonical slug; accept the python attr name or
# either slug spelling.
from gen_worker.discovery.names import slugify_name

wanted_slug = slugify_name(method_name)
matches = [
c for c in matches
if c.attr_name == method_name or c.fn_name == method_name
if c.attr_name == method_name or c.fn_name == wanted_slug
]

if not matches:
Expand All @@ -254,8 +259,10 @@ def _label(c: _SelectedFunction) -> str:
)
if len(matches) > 1:
if not cls_name and not method_name and default_name:
wanted = default_name.replace("-", "_").lower()
defaults = [m for m in matches if m.fn_name.lower() == wanted]
from gen_worker.discovery.names import slugify_name

wanted = slugify_name(default_name)
defaults = [m for m in matches if m.fn_name == wanted]
if len(defaults) == 1:
return defaults[0]
listing = "\n - " + "\n - ".join(_label(c) for c in matches)
Expand Down
12 changes: 10 additions & 2 deletions src/gen_worker/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import msgspec

from ..api.errors import CanceledError
from ..discovery.names import slugify_name
from ..models import memory
from ..models.residency import Residency, Tier
from . import run as run_mod
Expand Down Expand Up @@ -197,9 +198,14 @@ def _filter_candidates_by_function(

# Accept both repeated flags (--function a --function b) and a
# comma-separated list (--function a,b,c), so either spelling works.
# Function names are canonical slugs; slugify so marco_polo == marco-polo.
names: List[str] = []
for raw in wanted:
names.extend(part.strip() for part in str(raw or "").split(",") if part.strip())
names.extend(
slugify_name(part.strip())
for part in str(raw or "").split(",")
if part.strip()
)
if not names:
return candidates

Expand Down Expand Up @@ -468,7 +474,9 @@ def dispatch(
``on_event`` may raise (e.g. on client disconnect) to abort the handler.
"""
self.last_activity = time.time() # reset the --idle-timeout clock
served = self.functions.get(function_name)
served = self.functions.get(function_name) or self.functions.get(
slugify_name(function_name)
)
if served is None:
return _error_envelope(
"not_found",
Expand Down
13 changes: 11 additions & 2 deletions src/gen_worker/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .api.binding import Binding
from .api.decorators import ATTR, EndpointDecl, Resources
from .discovery.names import slugify_name
from .discovery.walk import find_endpoints

_ITER_ORIGINS = (
Expand All @@ -40,7 +41,7 @@ class EndpointSpec:
calls ``method`` directly with no instance/setup.
"""

name: str # routable function name (pre-slug)
name: str # routable function name (canonical slug)
method: Callable[..., Any] # unbound function object
kind: str # inference | training | dataset | conversion
payload_type: type # msgspec.Struct
Expand Down Expand Up @@ -127,8 +128,16 @@ def _spec_for_handler(
raise ValueError(f"{owner}: missing return type annotation")
output_mode, output_type, delta_type = _inspect_return(owner, ret)

# The wire/dispatch name is the SLUG (matches the discovery manifest and
# tensorhub's canonical function names — `_` -> `-`): the orchestrator's
# RunJob.function_name and the worker's advertised available_functions
# must agree, and the platform normalizes to slugs everywhere.
slug = slugify_name(fn_name)
if not slug:
raise ValueError(f"{owner}: function name {fn_name!r} cannot be normalized")

return EndpointSpec(
name=fn_name,
name=slug,
method=method,
kind=decl.kind,
payload_type=payload_type,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_cli_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _fake_send(sock_path, request, timeout=0.0, on_frame=None):

assert cli.main(["run", "--attach", "--module", "_test_marco", "--payload", json.dumps({"text": "marco"})]) == 0
# Routed through the warm server with the resolved function NAME + payload.
assert captured["request"]["function"] == "marco_polo"
assert captured["request"]["function"] == "marco-polo"
assert captured["request"]["payload"] == {"text": "marco"}
assert _last_event(capsys)["value"]["response"] == "polo"

Expand Down Expand Up @@ -223,9 +223,9 @@ def test_run_list_emits_description_document(capsys, monkeypatch) -> None:
doc = json.loads(capsys.readouterr().out)
assert doc["protocol_version"] >= 1
fns = {f["name"]: f for f in doc["functions"]}
assert "marco_polo" in fns
assert fns["marco_polo"]["class"] == "MarcoPolo"
assert "properties" in fns["marco_polo"]["input_schema"]
assert "marco-polo" in fns
assert fns["marco-polo"]["class"] == "MarcoPolo"
assert "properties" in fns["marco-polo"]["input_schema"]


def test_pyproject_tool_gen_worker_main(tmp_path) -> None:
Expand Down
14 changes: 7 additions & 7 deletions tests/test_cli_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_lazy_setup_on_first_invoke_only_invoked_class_and_eager_at_boot() -> No

ep = serve_mod._Endpoint(offline=False, allow_publish=False)
ep.boot(candidates)
assert sorted(ep.function_names()) == ["do_alpha", "do_beta"]
assert sorted(ep.function_names()) == ["do-alpha", "do-beta"]
assert setups == {} # LAZY: boot indexes but does NOT setup

# First invoke sets up ONLY Alpha; Beta (never called) stays cold. Warm after.
Expand Down Expand Up @@ -186,8 +186,8 @@ def test_serve_drives_residency_demote_promote_evict(monkeypatch) -> None:
# Tiny VRAM budget so the two models over-subscribe (forces eviction).
ep._residency = Residency(vram_budget_bytes=5 * 1024 ** 3)
res = ep._residency
big_id = ep._model_id_by_inst[id(ep.functions["gen_big"].instance)]
small_id = ep._model_id_by_inst[id(ep.functions["gen_small"].instance)]
big_id = ep._model_id_by_inst[id(ep.functions["gen-big"].instance)]
small_id = ep._model_id_by_inst[id(ep.functions["gen-small"].instance)]

# 1) big: cold setup + registered VRAM-resident.
assert ep.dispatch("gen_big", {"text": "x"})["ok"] is True
Expand All @@ -199,7 +199,7 @@ def test_serve_drives_residency_demote_promote_evict(monkeypatch) -> None:
assert ep.dispatch("gen_small", {"text": "x"})["ok"] is True
assert res.tier(small_id) is Tier.VRAM
assert res.tier(big_id) is Tier.RAM # demoted, NOT re-setup
big_pipe = ep._pipeline_by_inst[id(ep.functions["gen_big"].instance)]
big_pipe = ep._pipeline_by_inst[id(ep.functions["gen-big"].instance)]
assert big_pipe.moves[-1] == "cpu" # real .to("cpu") demote
assert setups == {"big": 1, "small": 1}

Expand Down Expand Up @@ -250,11 +250,11 @@ def beta_one(self, ctx: RequestContext, data: _In) -> _Out:

# --function alpha_one keeps the WHOLE Alpha class (both fns), drops Beta.
filtered = serve_mod._filter_candidates_by_function(candidates, ["alpha_one"])
assert sorted(c.fn_name for c in filtered) == ["alpha_one", "alpha_two"]
assert sorted(c.fn_name for c in filtered) == ["alpha-one", "alpha-two"]
assert all(c.cls.__name__ == "Alpha" for c in filtered)
ep = serve_mod._Endpoint(offline=False, allow_publish=False)
ep.boot(filtered)
assert ep.function_names() == ["alpha_one", "alpha_two"]
assert ep.function_names() == ["alpha-one", "alpha-two"]
ep.dispatch("alpha_one", {"text": "x"})
assert setups == {"Alpha": 1} # only Alpha; Beta never set up
ep.shutdown()
Expand All @@ -268,7 +268,7 @@ def beta_one(self, ctx: RequestContext, data: _In) -> _Out:
rc = cli.main(["serve", "--list-functions", "--module", name])
assert rc == 0
out = capsys.readouterr().out
assert "alpha_one" in out and "beta_one" in out and "Alpha" in out and "Beta" in out
assert "alpha-one" in out and "beta-one" in out and "Alpha" in out and "Beta" in out
assert setups == {}


Expand Down
2 changes: 1 addition & 1 deletion tests/test_cli_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_serve_sidecar_written_and_removed(tmp_path) -> None:
assert doc["pid"] == proc.pid
assert "protocol_version" in doc and "functions" in doc
assert doc["listen"] == str(sock)
assert "marco_polo" in doc["functions"]
assert "marco-polo" in doc["functions"]
finally:
proc.send_signal(signal.SIGTERM)
try:
Expand Down
37 changes: 37 additions & 0 deletions tests/test_registry_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Registry spec names are canonical slugs — the ONE wire/dispatch vocabulary.

The orchestrator's canonical function name is the slug (`_` -> `-`, tensorhub
builder.NormalizeSlug); the discovery manifest publishes the same slug. The
worker must advertise and match RunJob.function_name on it, so EndpointSpec
names slugify at extraction time.
"""

import msgspec

from gen_worker import RequestContext, endpoint
from gen_worker.registry import extract_specs


# Request payload struct used as the `data` parameter type of the SnakeCase
# handlers below.
class _In(msgspec.Struct):
    # Input text; defaults to the empty string so the field may be omitted.
    text: str = ""


# Response struct used as the return type of the SnakeCase handlers below.
class _Out(msgspec.Struct):
    # Reply text; no default, so it must be supplied at construction.
    response: str


@endpoint
class SnakeCase:
    # Fixture endpoint class: both handlers deliberately use snake_case Python
    # names so the test below can check how their spec names are reported.

    def marco_polo(self, ctx: RequestContext, data: _In) -> _Out:
        # Fixed reply; the input is not inspected.
        reply = _Out(response="polo")
        return reply

    def marco_polo_slow(self, ctx: RequestContext, data: _In) -> _Out:
        # Same fixed reply under a second, longer snake_case name.
        reply = _Out(response="polo")
        return reply


def test_spec_names_are_slugs() -> None:
    """Check that spec names are hyphenated while attr names stay snake_case."""
    slug_names = [spec.name for spec in extract_specs(SnakeCase)]
    slug_names.sort()
    expected_slugs = ["marco-polo", "marco-polo-slow"]
    assert slug_names == expected_slugs

    # The Python attribute names are kept on a separate field (used for the
    # manifest's python_name), so they must keep their underscore spelling.
    python_names = [spec.attr_name for spec in extract_specs(SnakeCase)]
    python_names.sort()
    expected_attrs = ["marco_polo", "marco_polo_slow"]
    assert python_names == expected_attrs
18 changes: 9 additions & 9 deletions tests/test_worker_grpc_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,17 +405,17 @@ def test_marco_polo_example_serves_under_the_new_core() -> None:
conn.wait_for(
lambda m: m.WhichOneof("msg") == "state_delta"
and m.state_delta.phase == pb.WORKER_PHASE_READY
and "marco_polo" in m.state_delta.available_functions
and "marco-polo" in m.state_delta.available_functions
)
conn.send(run_job=pb.RunJob(
request_id="mp-1", attempt=1, function_name="marco_polo",
request_id="mp-1", attempt=1, function_name="marco-polo",
input_payload=_msgpack("marco")))
res = conn.wait_for(_is_result_for("mp-1")).job_result
assert res.status == pb.JOB_STATUS_OK
assert _decode_out(res.inline).response == "polo"

conn.send(run_job=pb.RunJob(
request_id="mp-2", attempt=1, function_name="marco_polo_stream",
request_id="mp-2", attempt=1, function_name="marco-polo-stream",
input_payload=_msgpack("marco")))
res = conn.wait_for(_is_result_for("mp-2")).job_result
assert res.status == pb.JOB_STATUS_OK
Expand Down Expand Up @@ -562,8 +562,8 @@ def log_message(self, *a):
and m.state_delta.phase == pb.WORKER_PHASE_READY
)
# Gated until its model loads: present in loading, absent from available.
assert "model_echo" not in ready.state_delta.available_functions
assert "model_echo" in ready.state_delta.loading_functions
assert "model-echo" not in ready.state_delta.available_functions
assert "model-echo" in ready.state_delta.loading_functions

# DOWNLOAD -> DOWNLOADING then ON_DISK.
conn.send(model_op=pb.ModelOp(
Expand All @@ -578,12 +578,12 @@ def log_message(self, *a):
conn.wait_for(_is_model_event("e2e/tiny", pb.MODEL_STATE_IN_RAM))
conn.wait_for(
lambda m: m.WhichOneof("msg") == "state_delta"
and "model_echo" in m.state_delta.available_functions
and "model-echo" in m.state_delta.available_functions
)

# The handler sees the materialized snapshot content.
conn.send(run_job=pb.RunJob(
request_id="r-model", attempt=1, function_name="model_echo",
request_id="r-model", attempt=1, function_name="model-echo",
input_payload=_msgpack("marco")))
res = conn.wait_for(_is_result_for("r-model")).job_result
assert res.status == pb.JOB_STATUS_OK
Expand All @@ -602,8 +602,8 @@ def log_message(self, *a):
time.sleep(0.02)
conn.wait_for(
lambda m: m.WhichOneof("msg") == "state_delta"
and "model_echo" not in m.state_delta.available_functions
and "model_echo" in m.state_delta.loading_functions
and "model-echo" not in m.state_delta.available_functions
and "model-echo" in m.state_delta.loading_functions
)
finally:
harness.stop()
Expand Down
Loading