Skip to content

Commit f39d615

Browse files
DreamsorcererSam Bullpaul-nechifor
authored
Config options (#1543)
Co-authored-by: Sam Bull <Sam.B@snowfalltravel.com> Co-authored-by: Paul Nechifor <paul@nechifor.net>
1 parent 57c8cc6 commit f39d615

19 files changed

Lines changed: 415 additions & 104 deletions

dimos/constants.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,27 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from pathlib import Path
1617

18+
try:
19+
# Not a dependency, just the best way to get config path if available.
20+
from gi.repository import GLib # type: ignore[import-untyped,import-not-found]
21+
except ImportError:
22+
CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"))
23+
STATE_DIR = Path(os.environ.get("XDG_STATE_HOME", Path.home() / ".local" / "state")) / "dimos"
24+
else:
25+
CONFIG_DIR = Path(GLib.get_user_config_dir())
26+
STATE_DIR = Path(GLib.get_user_state_dir()) / "dimos"
27+
1728
DIMOS_PROJECT_ROOT = Path(__file__).parent.parent
1829

19-
DIMOS_LOG_DIR = DIMOS_PROJECT_ROOT / "logs"
30+
if (DIMOS_PROJECT_ROOT / ".git").exists():
31+
# Running from Git repository
32+
LOG_DIR = DIMOS_PROJECT_ROOT / "logs"
33+
else:
34+
# Running from an installed package - use XDG_STATE_HOME
35+
LOG_DIR = STATE_DIR / "logs"
2036

2137
"""
2238
Constants for shared memory

dimos/core/coordination/blueprints.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
from types import MappingProxyType
2222
from typing import TYPE_CHECKING, Any, Literal, Union, get_args, get_origin, get_type_hints
2323

24+
from pydantic import create_model
25+
2426
if TYPE_CHECKING:
2527
from dimos.protocol.service.system_configurator.base import SystemConfigurator
2628

29+
from dimos.core.global_config import GlobalConfig
2730
from dimos.core.module import ModuleBase, is_module_type
2831
from dimos.core.stream import In, Out
2932
from dimos.core.transport import PubSubTransport
@@ -77,7 +80,7 @@ class ModuleRef:
7780

7881

7982
@dataclass(frozen=True)
80-
class _BlueprintAtom:
83+
class BlueprintAtom:
8184
kwargs: dict[str, Any]
8285
module: type[ModuleBase]
8386
streams: tuple[StreamRef, ...]
@@ -140,7 +143,7 @@ def create(cls, module: type[ModuleBase], kwargs: dict[str, Any]) -> Self:
140143

141144
@dataclass(frozen=True)
142145
class Blueprint:
143-
blueprints: tuple[_BlueprintAtom, ...]
146+
blueprints: tuple[BlueprintAtom, ...]
144147
disabled_modules_tuple: tuple[type[ModuleBase], ...] = field(default_factory=tuple)
145148
transport_map: Mapping[tuple[str, type], PubSubTransport[Any]] = field(
146149
default_factory=lambda: MappingProxyType({})
@@ -154,12 +157,20 @@ class Blueprint:
154157

155158
@classmethod
156159
def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint":
157-
blueprint = _BlueprintAtom.create(module, kwargs)
160+
blueprint = BlueprintAtom.create(module, kwargs)
158161
return cls(blueprints=(blueprint,))
159162

160163
def disabled_modules(self, *modules: type[ModuleBase]) -> "Blueprint":
161164
return replace(self, disabled_modules_tuple=self.disabled_modules_tuple + modules)
162165

166+
def config(self) -> type:
167+
configs = {
168+
b.module.name: (get_type_hints(b.module)["config"] | None, None)
169+
for b in self.blueprints
170+
}
171+
configs["g"] = (GlobalConfig | None, None)
172+
return create_model("BlueprintConfig", __config__={"extra": "forbid"}, **configs) # type: ignore[call-overload,no-any-return]
173+
163174
def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint":
164175
return replace(self, transport_map=MappingProxyType({**self.transport_map, **transports}))
165176

@@ -185,7 +196,7 @@ def configurators(self, *checks: "SystemConfigurator") -> "Blueprint":
185196
return replace(self, configurator_checks=self.configurator_checks + tuple(checks))
186197

187198
@cached_property
188-
def active_blueprints(self) -> tuple[_BlueprintAtom, ...]:
199+
def active_blueprints(self) -> tuple[BlueprintAtom, ...]:
189200
if not self.disabled_modules_tuple:
190201
return self.blueprints
191202
disabled = set(self.disabled_modules_tuple)
@@ -219,7 +230,7 @@ def autoconnect(*blueprints: Blueprint) -> Blueprint:
219230
)
220231

221232

222-
def _eliminate_duplicates(blueprints: list[_BlueprintAtom]) -> list[_BlueprintAtom]:
233+
def _eliminate_duplicates(blueprints: list[BlueprintAtom]) -> list[BlueprintAtom]:
223234
# The duplicates are eliminated in reverse so that newer blueprints override older ones.
224235
seen = set()
225236
unique_blueprints = []

dimos/core/coordination/module_coordinator.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
from collections import defaultdict
18-
from collections.abc import Mapping
18+
from collections.abc import Mapping, MutableMapping
1919
import importlib
2020
import shutil
2121
import sys
@@ -35,7 +35,7 @@
3535
from dimos.utils.safe_thread_map import safe_thread_map
3636

3737
if TYPE_CHECKING:
38-
from dimos.core.coordination.blueprints import Blueprint, _BlueprintAtom
38+
from dimos.core.coordination.blueprints import Blueprint, BlueprintAtom
3939
from dimos.core.rpc_client import ModuleProxy, ModuleProxyProtocol
4040

4141
logger = setup_logger()
@@ -56,7 +56,7 @@ def __init__(
5656
cls.deployment_identifier: cls(g=g) for cls in manager_types
5757
}
5858
self._deployed_modules = {}
59-
self._deployed_atoms: dict[type[ModuleBase], _BlueprintAtom] = {}
59+
self._deployed_atoms: dict[type[ModuleBase], BlueprintAtom] = {}
6060
self._resolved_module_refs: dict[tuple[type[ModuleBase], str], type[ModuleBase]] = {}
6161
self._transport_registry: dict[tuple[str, type], PubSubTransport[Any]] = {}
6262
self._class_aliases: dict[type[ModuleBase], type[ModuleBase]] = {}
@@ -114,7 +114,9 @@ def deploy(
114114
self._deployed_modules[module_class] = deployed_module
115115
return deployed_module # type: ignore[return-value]
116116

117-
def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]:
117+
def deploy_parallel(
118+
self, module_specs: list[ModuleSpec], blueprint_args: Mapping[str, Mapping[str, Any]]
119+
) -> list[ModuleProxy]:
118120
if not self._managers:
119121
raise ValueError("Not started")
120122

@@ -130,7 +132,7 @@ def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]:
130132
results: list[Any] = [None] * len(module_specs)
131133

132134
def _deploy_group(dep: str) -> None:
133-
deployed = self._managers[dep].deploy_parallel(specs_by_deployment[dep])
135+
deployed = self._managers[dep].deploy_parallel(specs_by_deployment[dep], blueprint_args)
134136
for index, module in zip(indices_by_deployment[dep], deployed, strict=True):
135137
results[index] = module
136138

@@ -221,12 +223,13 @@ def _connect_streams(self, blueprint: Blueprint) -> None:
221223
def build(
222224
cls,
223225
blueprint: Blueprint,
224-
cli_config_overrides: Mapping[str, Any] | None = None,
226+
blueprint_args: MutableMapping[str, Any] | None = None,
225227
) -> ModuleCoordinator:
226228
logger.info("Building the blueprint")
227229
global_config.update(**dict(blueprint.global_config_overrides))
228-
if cli_config_overrides:
229-
global_config.update(**dict(cli_config_overrides))
230+
blueprint_args = blueprint_args or {}
231+
if "g" in blueprint_args:
232+
global_config.update(**blueprint_args.pop("g"))
230233

231234
_run_configurators(blueprint)
232235
_check_requirements(blueprint)
@@ -236,7 +239,7 @@ def build(
236239
coordinator = cls(g=global_config)
237240
coordinator.start()
238241

239-
_deploy_all_modules(blueprint, coordinator, global_config)
242+
_deploy_all_modules(blueprint, coordinator, global_config, blueprint_args)
240243
coordinator._connect_streams(blueprint)
241244
_connect_module_refs(blueprint, coordinator)
242245

@@ -250,7 +253,7 @@ def build(
250253
def load_blueprint(
251254
self,
252255
blueprint: Blueprint,
253-
cli_config_overrides: Mapping[str, Any] | None = None,
256+
blueprint_args: MutableMapping[str, Mapping[str, Any]] | None = None,
254257
) -> None:
255258
"""Load a blueprint into an already-running coordinator.
256259
@@ -263,8 +266,9 @@ def load_blueprint(
263266

264267
# Apply config overrides.
265268
self._global_config.update(**dict(blueprint.global_config_overrides))
266-
if cli_config_overrides:
267-
self._global_config.update(**dict(cli_config_overrides))
269+
blueprint_args = blueprint_args or {}
270+
if "g" in blueprint_args:
271+
self._global_config.update(**blueprint_args.pop("g"))
268272

269273
# Scale worker pool.
270274
n_extra = int(blueprint.global_config_overrides.get("n_workers", 0))
@@ -288,7 +292,7 @@ def load_blueprint(
288292

289293
before = set(self._deployed_modules)
290294

291-
_deploy_all_modules(blueprint, self, self._global_config)
295+
_deploy_all_modules(blueprint, self, self._global_config, blueprint_args)
292296
self._connect_streams(blueprint)
293297
_connect_module_refs(blueprint, self, existing_modules=before)
294298

@@ -300,8 +304,12 @@ def load_blueprint(
300304

301305
self._send_on_system_modules()
302306

303-
def load_module(self, module_class: type[ModuleBase], **kwargs: Any) -> None:
304-
self.load_blueprint(module_class.blueprint(**kwargs))
307+
def load_module(
308+
self,
309+
module_class: type[ModuleBase],
310+
blueprint_args: MutableMapping[str, Mapping[str, Any]] | None = None,
311+
) -> None:
312+
self.load_blueprint(module_class.blueprint(**blueprint_args or {}))
305313

306314
def unload_module(self, module_class: type[ModuleBase]) -> None:
307315
"""Stop and tear down a single deployed module.
@@ -576,13 +584,16 @@ def _check_requirements(blueprint: Blueprint) -> None:
576584

577585

578586
def _deploy_all_modules(
579-
blueprint: Blueprint, module_coordinator: ModuleCoordinator, gc: GlobalConfig
587+
blueprint: Blueprint,
588+
module_coordinator: ModuleCoordinator,
589+
gc: GlobalConfig,
590+
blueprint_args: Mapping[str, Mapping[str, Any]],
580591
) -> None:
581592
module_specs: list[ModuleSpec] = []
582593
for bp in blueprint.active_blueprints:
583-
module_specs.append((bp.module, gc, bp.kwargs))
594+
module_specs.append((bp.module, gc, bp.kwargs.copy()))
584595

585-
module_coordinator.deploy_parallel(module_specs)
596+
module_coordinator.deploy_parallel(module_specs, blueprint_args)
586597

587598
for bp in blueprint.active_blueprints:
588599
module_coordinator._deployed_atoms[bp.module] = bp

dimos/core/coordination/test_blueprints.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414

1515

1616
import pickle
17-
from typing import Protocol
17+
from typing import Protocol, get_type_hints
18+
19+
from pydantic import ValidationError
20+
import pytest
1821

1922
from dimos.core._test_future_annotations_helper import (
2023
FutureData,
@@ -23,10 +26,10 @@
2326
)
2427
from dimos.core.coordination.blueprints import (
2528
Blueprint,
29+
BlueprintAtom,
2630
DisabledModuleProxy,
2731
ModuleRef,
2832
StreamRef,
29-
_BlueprintAtom,
3033
autoconnect,
3134
)
3235
from dimos.core.core import rpc
@@ -83,7 +86,7 @@ def what_is_as_name(self) -> str:
8386

8487

8588
def test_get_connection_set() -> None:
86-
assert _BlueprintAtom.create(CatModule, kwargs={"k": "v"}) == _BlueprintAtom(
89+
assert BlueprintAtom.create(CatModule, kwargs={"k": "v"}) == BlueprintAtom(
8790
module=CatModule,
8891
streams=(
8992
StreamRef(name="pet_cat", type=Petting, direction="in"),
@@ -99,7 +102,7 @@ def test_autoconnect() -> None:
99102

100103
assert blueprint_set == Blueprint(
101104
blueprints=(
102-
_BlueprintAtom(
105+
BlueprintAtom(
103106
module=ModuleA,
104107
streams=(
105108
StreamRef(name="data1", type=Data1, direction="out"),
@@ -108,7 +111,7 @@ def test_autoconnect() -> None:
108111
module_refs=(),
109112
kwargs={},
110113
),
111-
_BlueprintAtom(
114+
BlueprintAtom(
112115
module=ModuleB,
113116
streams=(
114117
StreamRef(name="data1", type=Data1, direction="in"),
@@ -122,6 +125,17 @@ def test_autoconnect() -> None:
122125
)
123126

124127

128+
def test_config() -> None:
129+
blueprint = autoconnect(ModuleA.blueprint(), ModuleB.blueprint())
130+
config = blueprint.config()
131+
assert config.model_fields.keys() == {"modulea", "moduleb", "g"}
132+
assert config.model_fields["modulea"].annotation == get_type_hints(ModuleA)["config"] | None
133+
assert config.model_fields["moduleb"].annotation == get_type_hints(ModuleB)["config"] | None
134+
135+
with pytest.raises(ValidationError, match="invalid_key"):
136+
config(module_a={"invalid_key": 5})
137+
138+
125139
def test_transports() -> None:
126140
custom_transport = LCMTransport("/custom_topic", Data1)
127141
blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()).transports(
@@ -147,16 +161,16 @@ def test_future_annotations_support() -> None:
147161
"""Test that modules using `from __future__ import annotations` work correctly.
148162
149163
PEP 563 (future annotations) stores annotations as strings instead of actual types.
150-
This test verifies that _BlueprintAtom.create properly resolves string annotations
164+
This test verifies that BlueprintAtom.create properly resolves string annotations
151165
to the actual In/Out types.
152166
"""
153167

154168
# Test that streams are properly extracted from modules with future annotations
155-
out_blueprint = _BlueprintAtom.create(FutureModuleOut, kwargs={})
169+
out_blueprint = BlueprintAtom.create(FutureModuleOut, kwargs={})
156170
assert len(out_blueprint.streams) == 1
157171
assert out_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="out")
158172

159-
in_blueprint = _BlueprintAtom.create(FutureModuleIn, kwargs={})
173+
in_blueprint = BlueprintAtom.create(FutureModuleIn, kwargs={})
160174
assert len(in_blueprint.streams) == 1
161175
assert in_blueprint.streams[0] == StreamRef(name="data", type=FutureData, direction="in")
162176

@@ -186,7 +200,7 @@ class ModuleWithOptionalRef(Module):
186200

187201

188202
def test_optional_module_ref_detected() -> None:
189-
atom = _BlueprintAtom.create(ModuleWithOptionalRef, kwargs={})
203+
atom = BlueprintAtom.create(ModuleWithOptionalRef, kwargs={})
190204
assert len(atom.module_refs) == 1
191205
ref = atom.module_refs[0]
192206
assert ref.name == "calc"

0 commit comments

Comments
 (0)