|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -"""MuJoCo simulation adapter for ControlCoordinator integration. |
16 | | -
|
17 | | -Thin wrapper around SimManipInterface that plugs into the adapter registry. |
18 | | -Arm joint methods are inherited from SimManipInterface. |
| 15 | +"""Shared-memory adapter for MuJoCo-based manipulator simulation. |
| 16 | +this adapter reads from and writes to the same SHM buffers. |
19 | 17 | """ |
20 | 18 |
|
21 | 19 | from __future__ import annotations |
22 | 20 |
|
23 | | -from pathlib import Path |
| 21 | +import math |
| 22 | +import time |
24 | 23 | from typing import TYPE_CHECKING, Any |
25 | 24 |
|
26 | | -from dimos.simulation.engines.mujoco_engine import MujocoEngine |
27 | | -from dimos.simulation.manipulators.sim_manip_interface import SimManipInterface |
| 25 | +from dimos.hardware.manipulators.spec import ( |
| 26 | + ControlMode, |
| 27 | + JointLimits, |
| 28 | + ManipulatorInfo, |
| 29 | +) |
| 30 | +from dimos.simulation.engines.mujoco_shm import ( |
| 31 | + ManipShmReader, |
| 32 | + shm_key_from_path, |
| 33 | +) |
| 34 | +from dimos.utils.logging_config import setup_logger |
28 | 35 |
|
29 | 36 | if TYPE_CHECKING: |
30 | 37 | from dimos.hardware.manipulators.registry import AdapterRegistry |
31 | 38 |
|
32 | 39 |
|
33 | | -class SimMujocoAdapter(SimManipInterface): |
34 | | - """Uses ``address`` as the MJCF XML path (same field real adapters use for IP/port). |
35 | | - If the engine has more joints than ``dof``, the extra joint at index ``dof`` |
36 | | - is treated as the gripper, with ctrl range scaled automatically. |
| 40 | +logger = setup_logger() |
| 41 | + |
| 42 | +_READY_WAIT_TIMEOUT_S = 60.0 |
| 43 | +_READY_WAIT_POLL_S = 0.1 |
| 44 | +_ATTACH_RETRY_TIMEOUT_S = 30.0 |
| 45 | +_ATTACH_RETRY_POLL_S = 0.2 |
| 46 | + |
| 47 | + |
| 48 | +class ShmMujocoAdapter: |
| 49 | + """``ManipulatorAdapter`` that proxies to a ``MujocoSimModule`` via SHM. |
| 50 | +
|
| 51 | + Uses ``address`` (the MJCF XML path) as the discovery key. The sim module |
| 52 | + must be running and have signalled ready before ``connect()`` returns. |
37 | 53 | """ |
38 | 54 |
|
39 | 55 | def __init__( |
40 | 56 | self, |
41 | 57 | dof: int = 7, |
42 | 58 | address: str | None = None, |
43 | | - headless: bool = True, |
| 59 | + hardware_id: str | None = None, |
44 | 60 | **_: Any, |
45 | 61 | ) -> None: |
46 | 62 | if address is None: |
47 | 63 | raise ValueError("address (MJCF XML path) is required for sim_mujoco adapter") |
48 | | - engine = MujocoEngine(config_path=Path(address), headless=headless) |
| 64 | + self._dof = dof |
| 65 | + self._address = address |
| 66 | + self._hardware_id = hardware_id |
| 67 | + self._shm_key = shm_key_from_path(address) |
| 68 | + self._shm: ManipShmReader | None = None |
| 69 | + self._connected = False |
| 70 | + self._servos_enabled = False |
| 71 | + self._control_mode = ControlMode.POSITION |
| 72 | + self._error_code = 0 |
| 73 | + self._error_message = "" |
| 74 | + self._has_gripper = False |
| 75 | + self._effort_mode_warned = False |
| 76 | + |
| 77 | + def connect(self) -> bool: |
| 78 | + deadline = time.monotonic() + _ATTACH_RETRY_TIMEOUT_S |
| 79 | + while True: |
| 80 | + try: |
| 81 | + self._shm = ManipShmReader(self._shm_key) |
| 82 | + break |
| 83 | + except FileNotFoundError: |
| 84 | + if time.monotonic() > deadline: |
| 85 | + logger.error( |
| 86 | + "SHM buffers not found", |
| 87 | + address=self._address, |
| 88 | + shm_key=self._shm_key, |
| 89 | + timeout_s=_ATTACH_RETRY_TIMEOUT_S, |
| 90 | + ) |
| 91 | + return False |
| 92 | + time.sleep(_ATTACH_RETRY_POLL_S) |
| 93 | + |
| 94 | + # Wait for sim module to signal ready. |
| 95 | + deadline = time.monotonic() + _READY_WAIT_TIMEOUT_S |
| 96 | + while not self._shm.is_ready(): |
| 97 | + if time.monotonic() > deadline: |
| 98 | + logger.error("sim module not ready", timeout_s=_READY_WAIT_TIMEOUT_S) |
| 99 | + self._shm.cleanup() |
| 100 | + self._shm = None |
| 101 | + return False |
| 102 | + time.sleep(_READY_WAIT_POLL_S) |
| 103 | + |
| 104 | + num_joints = self._shm.num_joints() |
| 105 | + self._has_gripper = num_joints > self._dof |
| 106 | + self._connected = True |
| 107 | + self._servos_enabled = True |
| 108 | + logger.info("ShmMujocoAdapter connected", dof=self._dof, gripper=self._has_gripper) |
| 109 | + return True |
| 110 | + |
| 111 | + def disconnect(self) -> None: |
| 112 | + try: |
| 113 | + if self._shm is not None: |
| 114 | + self._shm.cleanup() |
| 115 | + finally: |
| 116 | + self._shm = None |
| 117 | + self._connected = False |
| 118 | + |
| 119 | + def is_connected(self) -> bool: |
| 120 | + return self._connected and self._shm is not None |
| 121 | + |
| 122 | + def get_info(self) -> ManipulatorInfo: |
| 123 | + return ManipulatorInfo( |
| 124 | + vendor="Simulation", |
| 125 | + model="Simulation", |
| 126 | + dof=self._dof, |
| 127 | + firmware_version=None, |
| 128 | + serial_number=None, |
| 129 | + ) |
| 130 | + |
| 131 | + def get_dof(self) -> int: |
| 132 | + return self._dof |
| 133 | + |
| 134 | + def get_limits(self) -> JointLimits: |
| 135 | + lower = [-math.pi] * self._dof |
| 136 | + upper = [math.pi] * self._dof |
| 137 | + max_vel_rad = math.radians(180.0) |
| 138 | + return JointLimits( |
| 139 | + position_lower=lower, |
| 140 | + position_upper=upper, |
| 141 | + velocity_max=[max_vel_rad] * self._dof, |
| 142 | + ) |
| 143 | + |
| 144 | + def set_control_mode(self, mode: ControlMode) -> bool: |
| 145 | + self._control_mode = mode |
| 146 | + return True |
| 147 | + |
| 148 | + def get_control_mode(self) -> ControlMode: |
| 149 | + return self._control_mode |
| 150 | + |
| 151 | + def read_joint_positions(self) -> list[float]: |
| 152 | + if self._shm is None: |
| 153 | + return [0.0] * self._dof |
| 154 | + return self._shm.read_positions(self._dof) |
| 155 | + |
| 156 | + def read_joint_velocities(self) -> list[float]: |
| 157 | + if self._shm is None: |
| 158 | + return [0.0] * self._dof |
| 159 | + return self._shm.read_velocities(self._dof) |
| 160 | + |
| 161 | + def read_joint_efforts(self) -> list[float]: |
| 162 | + if self._shm is None: |
| 163 | + return [0.0] * self._dof |
| 164 | + return self._shm.read_efforts(self._dof) |
| 165 | + |
| 166 | + def read_state(self) -> dict[str, int]: |
| 167 | + velocities = self.read_joint_velocities() |
| 168 | + is_moving = any(abs(v) > 1e-4 for v in velocities) |
| 169 | + mode_int = list(ControlMode).index(self._control_mode) |
| 170 | + return {"state": 1 if is_moving else 0, "mode": mode_int} |
| 171 | + |
| 172 | + def read_error(self) -> tuple[int, str]: |
| 173 | + return self._error_code, self._error_message |
| 174 | + |
| 175 | + def write_joint_positions(self, positions: list[float], velocity: float = 1.0) -> bool: |
| 176 | + if not self._servos_enabled or self._shm is None: |
| 177 | + return False |
| 178 | + self._control_mode = ControlMode.POSITION |
| 179 | + self._shm.write_position_command(positions[: self._dof]) |
| 180 | + return True |
| 181 | + |
| 182 | + def write_joint_velocities(self, velocities: list[float]) -> bool: |
| 183 | + if not self._servos_enabled or self._shm is None: |
| 184 | + return False |
| 185 | + self._control_mode = ControlMode.VELOCITY |
| 186 | + self._shm.write_velocity_command(velocities[: self._dof]) |
| 187 | + return True |
| 188 | + |
| 189 | + def write_joint_efforts(self, efforts: list[float]) -> bool: |
| 190 | + # Effort mode not exposed via SHM yet; caller can fall back to position. |
| 191 | + if not self._effort_mode_warned: |
| 192 | + logger.warning( |
| 193 | + "write_joint_efforts not supported by sim adapter; ignoring and returning False", |
| 194 | + dof=self._dof, |
| 195 | + ) |
| 196 | + self._effort_mode_warned = True |
| 197 | + return False |
| 198 | + |
| 199 | + def write_stop(self) -> bool: |
| 200 | + # Hold current position. |
| 201 | + if self._shm is None: |
| 202 | + return False |
| 203 | + positions = self._shm.read_positions(self._dof) |
| 204 | + self._shm.write_position_command(positions) |
| 205 | + return True |
| 206 | + |
| 207 | + def write_enable(self, enable: bool) -> bool: |
| 208 | + self._servos_enabled = enable |
| 209 | + return True |
| 210 | + |
| 211 | + def read_enabled(self) -> bool: |
| 212 | + return self._servos_enabled |
| 213 | + |
| 214 | + def write_clear_errors(self) -> bool: |
| 215 | + self._error_code = 0 |
| 216 | + self._error_message = "" |
| 217 | + return True |
| 218 | + |
| 219 | + def read_cartesian_position(self) -> dict[str, float] | None: |
| 220 | + return None |
| 221 | + |
| 222 | + def write_cartesian_position(self, pose: dict[str, float], velocity: float = 1.0) -> bool: |
| 223 | + return False |
| 224 | + |
| 225 | + def read_gripper_position(self) -> float | None: |
| 226 | + if not self._has_gripper or self._shm is None: |
| 227 | + return None |
| 228 | + return self._shm.read_gripper_position() |
49 | 229 |
|
50 | | - # Detect gripper from engine joints |
51 | | - gripper_idx = None |
52 | | - gripper_kwargs = {} |
53 | | - joint_names = list(engine.joint_names) |
54 | | - if len(joint_names) > dof: |
55 | | - gripper_idx = dof |
56 | | - ctrl_range = engine.get_actuator_ctrl_range(dof) |
57 | | - joint_range = engine.get_joint_range(dof) |
58 | | - if ctrl_range is None or joint_range is None: |
59 | | - raise ValueError(f"Gripper joint at index {dof} missing ctrl/joint range in MJCF") |
60 | | - gripper_kwargs = {"gripper_ctrl_range": ctrl_range, "gripper_joint_range": joint_range} |
| 230 | + def write_gripper_position(self, position: float) -> bool: |
| 231 | + if not self._has_gripper or self._shm is None: |
| 232 | + return False |
| 233 | + self._shm.write_gripper_command(position) |
| 234 | + return True |
61 | 235 |
|
62 | | - super().__init__(engine=engine, dof=dof, gripper_idx=gripper_idx, **gripper_kwargs) |
| 236 | + def read_force_torque(self) -> list[float] | None: |
| 237 | + return None |
63 | 238 |
|
64 | 239 |
|
65 | 240 | def register(registry: AdapterRegistry) -> None: |
66 | 241 | """Register this adapter with the registry.""" |
67 | | - registry.register("sim_mujoco", SimMujocoAdapter) |
| 242 | + registry.register("sim_mujoco", ShmMujocoAdapter) |
68 | 243 |
|
69 | 244 |
|
70 | | -__all__ = ["SimMujocoAdapter"] |
| 245 | +__all__ = ["ShmMujocoAdapter"] |
0 commit comments