Skip to content

Commit eb04e60

Browse files
feat(workers): use types for worker messages (#1767)
1 parent bcea7a6 commit eb04e60

2 files changed

Lines changed: 154 additions & 72 deletions

File tree

dimos/core/coordination/python_worker.py

Lines changed: 75 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,17 @@
2222
import traceback
2323
from typing import TYPE_CHECKING, Any
2424

25+
from dimos.core.coordination.worker_messages import (
26+
CallMethodRequest,
27+
DeployModuleRequest,
28+
GetAttrRequest,
29+
SetRefRequest,
30+
ShutdownRequest,
31+
SuppressConsoleRequest,
32+
UndeployModuleRequest,
33+
WorkerRequest,
34+
WorkerResponse,
35+
)
2536
from dimos.core.global_config import GlobalConfig, global_config
2637
from dimos.core.library_config import apply_library_config
2738
from dimos.utils.logging_config import setup_logger
@@ -63,7 +74,12 @@ def __getattr__(self, name: str) -> Any:
6374

6475
def _call(*args: Any, **kwargs: Any) -> ActorFuture:
6576
result = self._actor._send_request_to_worker(
66-
{"type": "call_method", "name": name, "args": args, "kwargs": kwargs}
77+
CallMethodRequest(
78+
module_id=self._actor._module_id,
79+
name=name,
80+
args=args,
81+
kwargs=kwargs,
82+
)
6783
)
6884
return ActorFuture(result)
6985

@@ -91,34 +107,33 @@ def __reduce__(self) -> tuple[type, tuple[None, type, int, int, None]]:
91107
"""Exclude the connection and lock when pickling."""
92108
return (Actor, (None, self._cls, self._worker_id, self._module_id, None))
93109

94-
def _send_request_to_worker(self, request: dict[str, Any]) -> Any:
110+
def _send_request_to_worker(self, request: WorkerRequest) -> Any:
95111
if self._conn is None:
96112
raise RuntimeError("Actor connection not available - cannot send requests")
97-
request["module_id"] = self._module_id
98113
if self._lock is not None:
99114
with self._lock:
100115
self._conn.send(request)
101-
response = self._conn.recv()
116+
response: WorkerResponse = self._conn.recv()
102117
else:
103118
self._conn.send(request)
104119
response = self._conn.recv()
105-
if response.get("error"):
106-
if "AttributeError" in response["error"]: # TODO: better error handling
107-
raise AttributeError(response["error"])
108-
raise RuntimeError(f"Worker error: {response['error']}")
109-
return response.get("result")
120+
if response.error:
121+
if "AttributeError" in response.error: # TODO: better error handling
122+
raise AttributeError(response.error)
123+
raise RuntimeError(f"Worker error: {response.error}")
124+
return response.result
110125

111126
def set_ref(self, ref: Any) -> ActorFuture:
112127
"""Set the actor reference on the remote module."""
113-
result = self._send_request_to_worker({"type": "set_ref", "ref": ref})
128+
result = self._send_request_to_worker(SetRefRequest(module_id=self._module_id, ref=ref))
114129
return ActorFuture(result)
115130

116131
def __getattr__(self, name: str) -> Any:
117132
"""Proxy attribute access to the worker process."""
118133
if name.startswith("_"):
119134
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
120135

121-
return self._send_request_to_worker({"type": "getattr", "name": name})
136+
return self._send_request_to_worker(GetAttrRequest(module_id=self._module_id, name=name))
122137

123138

124139
# Global forkserver context. Using `forkserver` instead of `fork` because it
@@ -208,20 +223,14 @@ def deploy_module(
208223
kwargs["g"] = global_config
209224
module_id = _module_ids.next()
210225

211-
# Send deploy_module request to the worker process
212-
request = {
213-
"type": "deploy_module",
214-
"module_id": module_id,
215-
"module_class": module_class,
216-
"kwargs": kwargs,
217-
}
226+
request = DeployModuleRequest(module_id=module_id, module_class=module_class, kwargs=kwargs)
218227
try:
219228
with self._lock:
220229
self._conn.send(request)
221-
response = self._conn.recv()
230+
response: WorkerResponse = self._conn.recv()
222231

223-
if response.get("error"):
224-
raise RuntimeError(f"Failed to deploy module: {response['error']}")
232+
if response.error:
233+
raise RuntimeError(f"Failed to deploy module: {response.error}")
225234

226235
actor = Actor(self._conn, module_class, self._worker_id, module_id, self._lock)
227236
actor.set_ref(actor).result()
@@ -243,11 +252,11 @@ def undeploy_module(self, module_id: int) -> None:
243252
raise RuntimeError("Worker process not started")
244253

245254
with self._lock:
246-
self._conn.send({"type": "undeploy_module", "module_id": module_id})
247-
response = self._conn.recv()
255+
self._conn.send(UndeployModuleRequest(module_id=module_id))
256+
response: WorkerResponse = self._conn.recv()
248257

249-
if response.get("error"):
250-
raise RuntimeError(f"Failed to undeploy module: {response['error']}")
258+
if response.error:
259+
raise RuntimeError(f"Failed to undeploy module: {response.error}")
251260

252261
self._modules.pop(module_id, None)
253262

@@ -256,7 +265,7 @@ def suppress_console(self) -> None:
256265
return
257266
try:
258267
with self._lock:
259-
self._conn.send({"type": "suppress_console"})
268+
self._conn.send(SuppressConsoleRequest())
260269
self._conn.recv()
261270
except (BrokenPipeError, EOFError, ConnectionResetError):
262271
pass
@@ -265,7 +274,7 @@ def shutdown(self) -> None:
265274
if self._conn is not None:
266275
try:
267276
with self._lock:
268-
self._conn.send({"type": "shutdown"})
277+
self._conn.send(ShutdownRequest())
269278
if self._conn.poll(timeout=5):
270279
self._conn.recv()
271280
else:
@@ -353,54 +362,48 @@ def _worker_loop(conn: Connection, instances: dict[int, Any], worker_id: int) ->
353362
except (EOFError, KeyboardInterrupt):
354363
break
355364

356-
response: dict[str, Any] = {}
365+
response: WorkerResponse
357366
try:
358-
req_type = request.get("type")
359-
360-
if req_type == "deploy_module":
361-
module_class = request["module_class"]
362-
kwargs = request["kwargs"]
363-
module_id = request["module_id"]
364-
instance = module_class(**kwargs)
365-
instances[module_id] = instance
366-
response["result"] = module_id
367-
368-
elif req_type == "set_ref":
369-
module_id = request["module_id"]
370-
instances[module_id].ref = request.get("ref")
371-
response["result"] = worker_id
372-
373-
elif req_type == "getattr":
374-
module_id = request["module_id"]
375-
response["result"] = getattr(instances[module_id], request["name"])
376-
377-
elif req_type == "call_method":
378-
module_id = request["module_id"]
379-
method = getattr(instances[module_id], request["name"])
380-
result = method(*request.get("args", ()), **request.get("kwargs", {}))
381-
response["result"] = result
382-
383-
elif req_type == "undeploy_module":
384-
module_id = request["module_id"]
385-
instance = instances.pop(module_id, None)
386-
if instance is not None:
387-
instance.stop()
388-
response["result"] = True
389-
390-
elif req_type == "suppress_console":
391-
_suppress_console_output()
392-
response["result"] = True
393-
394-
elif req_type == "shutdown":
395-
response["result"] = True
396-
conn.send(response)
397-
break
398-
399-
else:
400-
response["error"] = f"Unknown request type: {req_type}"
367+
match request:
368+
case DeployModuleRequest(
369+
module_id=module_id, module_class=module_class, kwargs=kwargs
370+
):
371+
instance = module_class(**kwargs)
372+
instances[module_id] = instance
373+
response = WorkerResponse(result=module_id)
374+
375+
case SetRefRequest(module_id=module_id, ref=ref):
376+
instances[module_id].ref = ref
377+
response = WorkerResponse(result=worker_id)
378+
379+
case GetAttrRequest(module_id=module_id, name=name):
380+
response = WorkerResponse(result=getattr(instances[module_id], name))
381+
382+
case CallMethodRequest(module_id=module_id, name=name, args=args, kwargs=kwargs):
383+
method = getattr(instances[module_id], name)
384+
response = WorkerResponse(result=method(*args, **kwargs))
385+
386+
case UndeployModuleRequest(module_id=module_id):
387+
instance = instances.pop(module_id, None)
388+
if instance is not None:
389+
instance.stop()
390+
response = WorkerResponse(result=True)
391+
392+
case SuppressConsoleRequest():
393+
_suppress_console_output()
394+
response = WorkerResponse(result=True)
395+
396+
case ShutdownRequest():
397+
conn.send(WorkerResponse(result=True))
398+
break
399+
400+
case _:
401+
response = WorkerResponse(error=f"Unknown request type: {type(request)}")
401402

402403
except Exception as e:
403-
response["error"] = f"{e.__class__.__name__}: {e}\n{traceback.format_exc()}"
404+
response = WorkerResponse(
405+
error=f"{e.__class__.__name__}: {e}\n{traceback.format_exc()}"
406+
)
404407

405408
try:
406409
conn.send(response)
dimos/core/coordination/worker_messages.py

Lines changed: 79 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2026 Dimensional Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from dataclasses import dataclass
17+
from typing import TYPE_CHECKING, Any
18+
19+
if TYPE_CHECKING:
20+
from dimos.core.module import ModuleBase
21+
22+
23+
@dataclass(frozen=True)
24+
class DeployModuleRequest:
25+
module_id: int
26+
module_class: type[ModuleBase]
27+
kwargs: dict[str, Any]
28+
29+
30+
@dataclass(frozen=True)
31+
class SetRefRequest:
32+
module_id: int
33+
ref: Any
34+
35+
36+
@dataclass(frozen=True)
37+
class GetAttrRequest:
38+
module_id: int
39+
name: str
40+
41+
42+
@dataclass(frozen=True)
43+
class CallMethodRequest:
44+
module_id: int
45+
name: str
46+
args: tuple[Any, ...]
47+
kwargs: dict[str, Any]
48+
49+
50+
@dataclass(frozen=True)
51+
class UndeployModuleRequest:
52+
module_id: int
53+
54+
55+
@dataclass(frozen=True)
56+
class SuppressConsoleRequest:
57+
pass
58+
59+
60+
@dataclass(frozen=True)
61+
class ShutdownRequest:
62+
pass
63+
64+
65+
# Union of every message type the worker loop matches on. Anything else
# falls through to the loop's "Unknown request type" error response.
WorkerRequest = (
    DeployModuleRequest
    | SetRefRequest
    | GetAttrRequest
    | CallMethodRequest
    | UndeployModuleRequest
    | SuppressConsoleRequest
    | ShutdownRequest
)
74+
75+
76+
@dataclass(frozen=True)
77+
class WorkerResponse:
78+
result: Any = None
79+
error: str | None = None

0 commit comments

Comments
 (0)