From 79a7597c9623b16248e517769ef2ba754c4aa874 Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Sat, 2 May 2026 10:08:38 +0200 Subject: [PATCH 01/13] Add dynamic joint replay state --- python/rcs/_core/sim.pyi | 15 + python/rcs/envs/sim.py | 215 +++++++++++++ python/rcs/sim/sim.py | 34 ++ python/rcs/sim_state_replay.py | 264 ++++++++++++++++ python/tests/test_sim_state_record_replay.py | 310 +++++++++++++++++++ src/pybind/rcs.cpp | 16 + src/sim/sim.cpp | 166 +++++++++- src/sim/sim.h | 33 ++ 8 files changed, 1052 insertions(+), 1 deletion(-) create mode 100644 python/rcs/sim_state_replay.py create mode 100644 python/tests/test_sim_state_record_replay.py diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi index fa5a8899..d419aeee 100644 --- a/python/rcs/_core/sim.pyi +++ b/python/rcs/_core/sim.pyi @@ -89,14 +89,29 @@ class GuiClient: def set_model_and_data(self, arg0: int, arg1: int) -> None: ... def sync(self) -> None: ... +class DynamicJointSchema: + joint_names: list[str] + joint_types: list[int] + qpos_sizes: list[int] + qvel_sizes: list[int] + def __init__(self) -> None: ... + +class DynamicJointState: + qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + def __init__(self) -> None: ... + class Sim: def __init__(self, mjmdl: int, mjdata: int) -> None: ... def _start_gui_server(self, id: str) -> None: ... def _stop_gui_server(self) -> None: ... def get_config(self) -> SimConfig: ... + def get_dynamic_joint_schema(self) -> DynamicJointSchema: ... + def get_dynamic_joint_state(self) -> DynamicJointState: ... def is_converged(self) -> bool: ... def reset(self) -> None: ... def set_config(self, cfg: SimConfig) -> bool: ... + def set_dynamic_joint_state(self, schema: DynamicJointSchema, state: DynamicJointState) -> None: ... def step(self, k: int) -> None: ... def step_until_convergence(self) -> None: ... def sync_gui(self) -> None: ... diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index 77354065..b5348455 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -2,9 +2,12 @@ from typing import Any, cast import gymnasium as gym +import numpy as np from rcs._core.common import RobotPlatform +from rcs.envs.base import GripperWrapper from rcs.envs.space_utils import ActObsInfoWrapper +import rcs from rcs import sim logger = logging.getLogger(__name__) @@ -40,6 +43,37 @@ def reset( return super().reset(seed=seed, options=options) +class SimStateObservationWrapper(ActObsInfoWrapper): + DYNAMIC_JOINT_SCHEMA_KEY = "dynamic_joint_schema" + DYNAMIC_JOINT_QPOS_KEY = "dynamic_joint_qpos" + DYNAMIC_JOINT_QVEL_KEY = "dynamic_joint_qvel" + + def __init__(self, env): + super().__init__(env) + assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation." + self.sim = cast(sim.Sim, self.get_wrapper_attr("sim")) + self._dynamic_joint_schema = self.sim.get_dynamic_joint_schema() + self._include_schema_in_next_step = True + + def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + observation = dict(observation) + dynamic_joint_state = self.sim.get_dynamic_joint_state() + observation[self.DYNAMIC_JOINT_QPOS_KEY] = dynamic_joint_state["qpos"] + observation[self.DYNAMIC_JOINT_QVEL_KEY] = dynamic_joint_state["qvel"] + if self._include_schema_in_next_step: + observation[self.DYNAMIC_JOINT_SCHEMA_KEY] = self._dynamic_joint_schema + self._include_schema_in_next_step = False + return observation, info + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + obs, info = super().reset(seed=seed, options=options) + # Re-emit the schema on the first recorded step after each reset. + self._include_schema_in_next_step = True + return obs, info + + class GripperWrapperSim(ActObsInfoWrapper): def __init__(self, env): super().__init__(env) @@ -94,6 +128,187 @@ def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tupl return observation, info +class RandomObjectPos(gym.Wrapper): + """ + Wrapper to randomly re-place an object in the lab environments. + Given the object's joint name and initial pose, its x, y coordinates are randomized, while z remains fixed. + If include_rotation is true, the object's z-axis rotation (yaw) is also randomized. + + Args: + env (gym.Env): The environment to wrap. + simulation (sim.Sim): The simulation instance. + joint_name (str): The name of the free joint attached to the object to manipulate. + init_object_pose (rcs.common.Pose): The initial pose of the object. + include_rotation (bool): Whether to include rotation in the randomization. + """ + + def __init__( + self, + env: gym.Env, + joint_name: str, + init_object_pose: rcs.common.Pose, + include_position: bool = True, + include_rotation: bool = False, + x_scale: float = 0.2, + y_scale: float = 0.2, + x_offset: float = 0.1, + y_offset: float = 0.1, + ): + super().__init__(env) + self.joint_name = joint_name + self.init_object_pose = init_object_pose + self.include_position = include_position + self.include_rotation = include_rotation + self.x_scale = x_scale + self.y_scale = y_scale + self.x_offset = x_offset + self.y_offset = y_offset + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + if options is not None and "RandomObjectPos.init_object_pose" in options: + assert isinstance( + options["RandomObjectPos.init_object_pose"], rcs.common.Pose + ), "RandomObjectPos.init_object_pose must be a rcs.common.Pose" + + self.init_object_pose = options["RandomObjectPos.init_object_pose"] + print("Got random object pos!\n", self.init_object_pose) + del options["RandomObjectPos.init_object_pose"] + obs, info = super().reset(seed=seed, options=options) + + pos_z = self.init_object_pose.translation()[2] + if self.include_position: + pos_x = self.init_object_pose.translation()[0] + np.random.random() * self.x_scale + self.x_offset + pos_y = self.init_object_pose.translation()[1] + np.random.random() * self.y_scale + self.y_offset + else: + pos_x = self.init_object_pose.translation()[0] + pos_y = self.init_object_pose.translation()[1] + + quat = self.init_object_pose.rotation_q() # xyzw format + if self.include_rotation: + random_z_rotation = (np.random.random() - 0.5) * (0.7071068 * 2) + self.get_wrapper_attr("sim").data.joint(self.joint_name).qpos = [ + pos_x, + pos_y, + pos_z, + quat[3] + random_z_rotation, + quat[0], + quat[1], + quat[2] + random_z_rotation, + ] + else: + self.get_wrapper_attr("sim").data.joint(self.joint_name).qpos = [ + pos_x, + pos_y, + pos_z, + quat[3], + quat[0], + quat[1], + quat[2], + ] + + return obs, info + + +class RandomCubePos(gym.Wrapper): + """Wrapper to randomly place cube in the lab environments. + + Works only for single robot + """ + + def __init__(self, env: gym.Env, include_rotation: bool = False, cube_joint_name="box_joint"): + super().__init__(env) + self.include_rotation = include_rotation + self.cube_joint_name = cube_joint_name + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + obs, info = super().reset(seed=seed, options=options) + + iso_cube = np.array([0.498, 0.0, 0.226]) + iso_cube_pose = rcs.common.Pose(translation=np.array(iso_cube), rpy_vector=np.array([0, 0, 0])) # type: ignore + iso_cube = self.get_wrapper_attr("robot").to_pose_in_world_coordinates(iso_cube_pose).translation() + pos_z = 0.0288 + pos_x = iso_cube[0] + np.random.random() * 0.2 - 0.1 + pos_y = iso_cube[1] + np.random.random() * 0.2 - 0.1 + + if self.include_rotation: + self.get_wrapper_attr("sim").data.joint(self.cube_joint_name).qpos = [ + pos_x, + pos_y, + pos_z, + 2 * np.random.random() - 1, + 0, + 0, + 1, + ] + else: + self.get_wrapper_attr("sim").data.joint(self.cube_joint_name).qpos = [pos_x, pos_y, pos_z, 0, 0, 0, 1] + + return obs, info + + +class PickCubeSuccessWrapper(gym.Wrapper): + """ + Wrapper to check if the cube is successfully picked up in the FR3SimplePickUpSim environment. + Cube must be lifted 10 cm above the robot base. + Computes a reward between 0 and 1 based on: + - TCP to object distance + - cube z height + - whether the arm is standing still once the task is solved. + """ + + def __init__(self, env, cube_geom_name="box_geom"): + super().__init__(env) + assert isinstance(self.get_wrapper_attr("robot"), sim.SimRobot), "Robot must be a sim.SimRobot instance." + self._robot = cast(sim.SimRobot, self.get_wrapper_attr("robot")) + self.sim = self.env.get_wrapper_attr("sim") + self.cube_geom_name = cube_geom_name + self.home_pose = self._robot.get_cartesian_position() + self._gripper_closing = 0 + self._gripper = self.get_wrapper_attr("gripper") + + def step(self, action: dict[str, Any]): # type: ignore + obs, reward, _, truncated, info = super().step(action) + if ( + self._gripper.get_normalized_width() > 0.01 + and self._gripper.get_normalized_width() < 0.99 + and obs["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED + ): + self._gripper_closing += 1 + else: + self._gripper_closing = 0 + cube_pose = rcs.common.Pose(translation=self.sim.data.geom(self.cube_geom_name).xpos) + cube_pose = self._robot.to_pose_in_robot_coordinates(cube_pose) + tcp_to_obj_dist = np.linalg.norm(cube_pose.translation() - self._robot.get_cartesian_position().translation()) + obj_to_goal_dist = 0.10 - min(cube_pose.translation()[-1], 0.10) + obj_to_goal_dist = np.linalg.norm(cube_pose.translation() - self.home_pose.translation()) + # NOTE: 4 depends on the time passing between each step. + is_grasped = ( + self._gripper_closing >= 4 # gripper is closing since more than 4 steps + and obs["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED # command is still close + and tcp_to_obj_dist <= 0.01 # tcp to cube center is max 1cm + ) + success = obj_to_goal_dist <= 0.022 and info["is_grasped"] + movement = np.linalg.norm(self.sim.data.qvel) + + reaching_reward = 1 - np.tanh(5 * tcp_to_obj_dist) + place_reward = 1 - np.tanh(5 * obj_to_goal_dist) + static_reward = 1 - np.tanh(5 * movement) + info["is_grasped"] = is_grasped + info["success"] = success + reward = reaching_reward + place_reward * is_grasped + static_reward * success + reward /= 3 # type: ignore + return obs, reward, success, truncated, info + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + obs, info = super().reset() + self.home_pose = self._robot.get_cartesian_position() + return obs, info + + class DigitalTwin(gym.Wrapper): def __init__(self, env, twin_env): super().__init__(env) diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index 07963e4d..80e4b13a 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -12,6 +12,8 @@ import mujoco as mj import mujoco.viewer import numpy as np +from rcs._core.sim import DynamicJointSchema as _DynamicJointSchema +from rcs._core.sim import DynamicJointState as _DynamicJointState from rcs._core.sim import GuiClient as _GuiClient from rcs._core.sim import Sim as _Sim from rcs.sim import SimConfig, egl_bootstrap @@ -97,6 +99,38 @@ def set_state(self, state: np.ndarray, spec: int | None = None): mj.mj_setState(self.model, self.data, state_array, state_spec) mj.mj_forward(self.model, self.data) + def get_dynamic_joint_schema(self) -> dict[str, list[str] | list[int]]: + schema = super().get_dynamic_joint_schema() + return { + "joint_names": list(schema.joint_names), + "joint_types": list(schema.joint_types), + "qpos_sizes": list(schema.qpos_sizes), + "qvel_sizes": list(schema.qvel_sizes), + } + + def get_dynamic_joint_state(self) -> dict[str, np.ndarray]: + state = super().get_dynamic_joint_state() + return { + "qpos": np.asarray(state.qpos, dtype=np.float64), + "qvel": np.asarray(state.qvel, dtype=np.float64), + } + + def set_dynamic_joint_state( + self, + schema: dict[str, list[str] | list[int]], + state: dict[str, np.ndarray], + ): + dynamic_joint_schema = _DynamicJointSchema() + dynamic_joint_schema.joint_names = list(schema["joint_names"]) + dynamic_joint_schema.joint_types = [int(value) for value in schema["joint_types"]] + dynamic_joint_schema.qpos_sizes = [int(value) for value in schema["qpos_sizes"]] + dynamic_joint_schema.qvel_sizes = [int(value) for value in schema["qvel_sizes"]] + + dynamic_joint_state = _DynamicJointState() + dynamic_joint_state.qpos = np.asarray(state["qpos"], dtype=np.float64) + dynamic_joint_state.qvel = np.asarray(state["qvel"], dtype=np.float64) + super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) + def close_gui(self): if self._stop_event is not None: self._stop_event.set() diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py new file mode 100644 index 00000000..4a0c7c22 --- /dev/null +++ b/python/rcs/sim_state_replay.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Annotated, Any + +import gymnasium as gym +import numpy as np +import pyarrow.compute as pc +import pyarrow.dataset as ds +import typer +from PIL import Image +from rcs.envs.base import ControlMode +from rcs.envs.sim import SimStateObservationWrapper + +import rcs # noqa: F401 + +app = typer.Typer(help="Replay recorded MuJoCo trajectories from a parquet dataset.") + +DATASET_ARGUMENT = typer.Argument(..., exists=True, file_okay=False, dir_okay=True) +ENV_ID_OPTION = typer.Option("rcs/FR3SimplePickUpSim-v0", help="Gymnasium env id used for replay.") +TRAJECTORY_UUID_OPTION = typer.Option(None, help="UUID of the recorded trajectory to replay.") +CAMERA_OPTION = typer.Option([], "--camera", help="Camera names to enable on the replay env.") +RESOLUTION_OPTION = typer.Option((256, 256), help="Replay camera resolution as WIDTH HEIGHT.") +FRAME_RATE_OPTION = typer.Option(0, help="Replay camera frame rate.") +RENDER_MODE_OPTION = typer.Option("human", help="Gym render mode for the replay env.") +CONTROL_MODE_OPTION = typer.Option(ControlMode.CARTESIAN_TRPY.name, help="Control mode name for env creation.") +SLEEP_OPTION = typer.Option(0.0, help="Optional delay between restored states.") +OUTPUT_DIR_OPTION = typer.Option(None, help="Optional directory for re-rendered RGB frames.") +PREFER_DUCKDB_OPTION = typer.Option(True, help="Use duckdb for parquet loading when it is available.") + + +@dataclass(frozen=True) +class RecordedSimStep: + step: int + uuid: str + timestamp: float | None + observation: dict[str, Any] + + @property + def dynamic_joint_schema(self) -> dict[str, Any] | None: + schema = self.observation.get(SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY) + return dict(schema) if schema is not None else None + + @property + def dynamic_joint_state(self) -> dict[str, np.ndarray] | None: + if ( + SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY not in self.observation + or SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY not in self.observation + ): + return None + return { + "qpos": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY], dtype=np.float64), + "qvel": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY], dtype=np.float64), + } + + +class DuckDBUnavailableError(RuntimeError): + pass + + +def _get_duckdb_module(): + try: + import duckdb + except ModuleNotFoundError as exc: + msg = ( + "duckdb is required for the preferred parquet read path but is not installed. " + "Install the 'duckdb' Python package or rely on the pyarrow fallback in library calls." + ) + raise DuckDBUnavailableError(msg) from exc + return duckdb + + +def _load_distinct_uuids_with_duckdb(dataset_path: Path) -> list[str]: + duckdb = _get_duckdb_module() + connection = duckdb.connect() + try: + rows = connection.execute( + "SELECT DISTINCT uuid FROM read_parquet(?) ORDER BY uuid", + [str(dataset_path)], + ).fetchall() + finally: + connection.close() + return [row[0] for row in rows] + + +def _load_distinct_uuids_with_pyarrow(dataset_path: Path) -> list[str]: + dataset = ds.dataset(str(dataset_path), format="parquet") + uuids = dataset.to_table(columns=["uuid"])["uuid"] + return sorted(str(uuid) for uuid in pc.unique(uuids).to_pylist()) + + +def list_trajectory_ids(dataset_path: Path, prefer_duckdb: bool = True) -> list[str]: + if prefer_duckdb: + try: + return _load_distinct_uuids_with_duckdb(dataset_path) + except DuckDBUnavailableError: + pass + return _load_distinct_uuids_with_pyarrow(dataset_path) + + +def _load_trajectory_with_duckdb(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]: + duckdb = _get_duckdb_module() + connection = duckdb.connect() + try: + table = connection.execute( + "SELECT uuid, step, timestamp, obs FROM read_parquet(?) WHERE uuid = ? ORDER BY step", + [str(dataset_path), trajectory_uuid], + ).to_arrow_table() + finally: + connection.close() + return [ + RecordedSimStep( + step=int(row["step"]), + uuid=str(row["uuid"]), + timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None, + observation=row["obs"], + ) + for row in table.to_pylist() + ] + + +def _load_trajectory_with_pyarrow(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]: + dataset = ds.dataset(str(dataset_path), format="parquet") + table = dataset.to_table(filter=pc.field("uuid") == trajectory_uuid, columns=["uuid", "step", "timestamp", "obs"]) + rows = table.sort_by([("step", "ascending")]).to_pylist() + return [ + RecordedSimStep( + step=int(row["step"]), + uuid=str(row["uuid"]), + timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None, + observation=row["obs"], + ) + for row in rows + ] + + +def load_trajectory(dataset_path: Path, trajectory_uuid: str, prefer_duckdb: bool = True) -> list[RecordedSimStep]: + if prefer_duckdb: + try: + return _load_trajectory_with_duckdb(dataset_path, trajectory_uuid) + except DuckDBUnavailableError: + pass + return _load_trajectory_with_pyarrow(dataset_path, trajectory_uuid) + + +def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, prefer_duckdb: bool = True) -> str: + if trajectory_uuid is not None: + return trajectory_uuid + available_uuids = list_trajectory_ids(dataset_path, prefer_duckdb=prefer_duckdb) + if len(available_uuids) == 1: + return available_uuids[0] + msg = ( + f"Dataset {dataset_path} contains {len(available_uuids)} trajectories. " + f"Pass --trajectory-uuid and choose one of: {available_uuids}" + ) + raise ValueError(msg) + + +def restore_sim_step( + env: gym.Env, + recorded_step: RecordedSimStep, + dynamic_joint_schema: dict[str, Any] | None = None, +): + sim = env.get_wrapper_attr("sim") + dynamic_joint_state = recorded_step.dynamic_joint_state + if dynamic_joint_state is None: + msg = "Recorded step is missing dynamic joint state data." + raise ValueError(msg) + + resolved_schema = dynamic_joint_schema or recorded_step.dynamic_joint_schema + if resolved_schema is None: + msg = "Recorded dynamic joint state is missing its schema." + raise ValueError(msg) + sim.set_dynamic_joint_state(resolved_schema, dynamic_joint_state) + + +def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]: + try: + camera_set = env.get_wrapper_attr("camera_set") + except AttributeError: + return {} + + frameset = camera_set.get_latest_frames() + if frameset is None: + return {} + + rgb_frames: dict[str, np.ndarray] = {} + for camera_name, frame in frameset.frames.items(): + lower_name = camera_name.lower() + if "digit" in lower_name or "tactile" in lower_name: + continue + rgb_frames[camera_name] = np.asarray(frame.camera.color.data) + return rgb_frames + + +def save_rgb_frames(output_dir: Path, recorded_step: RecordedSimStep, rgb_frames: dict[str, np.ndarray]): + output_dir.mkdir(parents=True, exist_ok=True) + for camera_name, rgb_frame in rgb_frames.items(): + Image.fromarray(rgb_frame).save(output_dir / f"step-{recorded_step.step:06d}-{camera_name}.png") + + +def replay_trajectory( + env: gym.Env, + recorded_steps: list[RecordedSimStep], + *, + sleep_s: float = 0.0, + output_dir: Path | None = None, +): + if not recorded_steps: + msg = "No recorded dynamic joint states found in the requested trajectory." + raise ValueError(msg) + + dynamic_joint_schema = next( + (recorded_step.dynamic_joint_schema for recorded_step in recorded_steps if recorded_step.dynamic_joint_schema), + None, + ) + + env.reset() + for recorded_step in recorded_steps: + restore_sim_step(env, recorded_step, dynamic_joint_schema=dynamic_joint_schema) + if output_dir is not None: + save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) + if sleep_s > 0: + time.sleep(sleep_s) + + +@app.command() +def replay( + dataset: Annotated[Path, DATASET_ARGUMENT], + env_id: Annotated[str, ENV_ID_OPTION], + trajectory_uuid: Annotated[str | None, TRAJECTORY_UUID_OPTION], + camera: Annotated[list[str], CAMERA_OPTION], + resolution: Annotated[tuple[int, int], RESOLUTION_OPTION], + frame_rate: Annotated[int, FRAME_RATE_OPTION], + render_mode: Annotated[str, RENDER_MODE_OPTION], + control_mode: Annotated[str, CONTROL_MODE_OPTION], + sleep_s: Annotated[float, SLEEP_OPTION], + output_dir: Annotated[Path | None, OUTPUT_DIR_OPTION], + prefer_duckdb: Annotated[bool, PREFER_DUCKDB_OPTION], +): + resolved_uuid = resolve_trajectory_uuid(dataset, trajectory_uuid, prefer_duckdb=prefer_duckdb) + env = gym.make( + env_id, + render_mode=render_mode, + control_mode=ControlMode[control_mode], + resolution=resolution, + frame_rate=frame_rate, + cam_list=camera, + ) + try: + recorded_steps = load_trajectory(dataset, resolved_uuid, prefer_duckdb=prefer_duckdb) + replay_trajectory(env, recorded_steps, sleep_s=sleep_s, output_dir=output_dir) + finally: + env.close() + + typer.echo(f"Replayed {len(recorded_steps)} steps from trajectory {resolved_uuid}.") + if output_dir is not None: + typer.echo(f"Saved re-rendered RGB frames to {output_dir}.") + + +if __name__ == "__main__": + app() diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py new file mode 100644 index 00000000..2befdc91 --- /dev/null +++ b/python/tests/test_sim_state_record_replay.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import importlib.util +import sys +import xml.etree.ElementTree as ET +from dataclasses import dataclass +from pathlib import Path + +import gymnasium as gym +import mujoco as mj +import numpy as np +import pyarrow.dataset as ds +from rcs._core.common import RobotPlatform +from rcs._core.sim import SimConfig +from rcs.camera.interface import CameraFrame, DataFrame, Frame, FrameSet +from rcs.envs.base import ControlMode, JointsDictType +from rcs.envs.creators import SimMultiEnvCreator +from rcs.envs.storage_wrapper import StorageWrapper +from rcs.envs.utils import default_sim_gripper_cfg, default_sim_robot_cfg + +import rcs + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def _load_local_module(module_name: str, relative_path: str): + module_path = REPO_ROOT / relative_path + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + msg = f"Could not create an import spec for {module_name} from {module_path}." + raise ImportError(msg) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + parent_name, _, child_name = module_name.rpartition(".") + if parent_name: + parent_module = sys.modules[parent_name] + setattr(parent_module, child_name, module) + spec.loader.exec_module(module) + return module + + +local_sim_module = _load_local_module("rcs.sim.sim", "python/rcs/sim/sim.py") +rcs.sim.__dict__["Sim"] = local_sim_module.Sim +_load_local_module("rcs.envs.sim", "python/rcs/envs/sim.py") +_load_local_module("rcs.sim_state_replay", "python/rcs/sim_state_replay.py") + +from rcs.envs.sim import SimStateObservationWrapper # noqa: E402 +from rcs.sim.sim import Sim # noqa: E402 +from rcs.sim_state_replay import ( # noqa: E402 + load_trajectory, + replay_trajectory, + restore_sim_step, +) + +XML = """ + + + + + + + + + +""" + + +@dataclass +class DummyCameraSet: + sim: Sim + + def get_latest_frames(self) -> FrameSet: + color_value = int(np.clip(round((self.sim.data.qpos[0] + 1.0) * 80.0), 0, 255)) + rgb = np.full((8, 8, 3), color_value, dtype=np.uint8) + return FrameSet( + frames={ + "main": Frame( + camera=CameraFrame( + color=DataFrame(data=rgb), + depth=None, + ), + ) + }, + avg_timestamp=None, + ) + + +class DummySimEnv(gym.Env): + PLATFORM = RobotPlatform.SIMULATION + + def __init__(self, sim: Sim, camera_set: DummyCameraSet | None = None): + super().__init__() + self.sim = sim + self.camera_set = camera_set + self.action_space = gym.spaces.Dict( + { + "delta": gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float64), + } + ) + self.observation_space = gym.spaces.Dict( + { + "qpos": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sim.model.nq,), dtype=np.float64), + "qvel": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sim.model.nv,), dtype=np.float64), + } + ) + + def _obs(self) -> dict[str, np.ndarray]: + return { + "qpos": self.sim.data.qpos.copy(), + "qvel": self.sim.data.qvel.copy(), + } + + def get_wrapper_attr(self, name: str): + return getattr(self, name) + + def reset(self, *, seed: int | None = None, options: dict | None = None): + super().reset(seed=seed) + mj.mj_resetData(self.sim.model, self.sim.data) + mj.mj_forward(self.sim.model, self.sim.data) + return self._obs(), {} + + def step(self, action: dict[str, np.ndarray]): + self.sim.data.qpos[0] += float(action["delta"][0]) + self.sim.data.qvel[:] = 0.0 + mj.mj_forward(self.sim.model, self.sim.data) + return self._obs(), 0.0, False, False, {} + + def close(self): + return None + + +def test_record_and_replay_sim_state(tmp_path: Path): + model_path = tmp_path / "dummy.xml" + model_path.write_text(XML) + + dataset_path = tmp_path / "dataset" + record_env: gym.Env = DummySimEnv(Sim(model_path)) + record_env = SimStateObservationWrapper(record_env) + record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) + + obs, _ = record_env.reset() + assert SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY in obs + + record_env.step({"delta": np.array([0.125], dtype=np.float64)}) + record_env.close() + + table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")]) + rows = table.to_pylist() + assert len(rows) == 1 + + recorded_obs = rows[0]["obs"] + assert SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY in recorded_obs + assert SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY in recorded_obs + assert SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY in recorded_obs + + recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) + assert len(recorded_steps) == 1 + assert recorded_steps[0].dynamic_joint_schema is not None + assert np.allclose( + recorded_steps[0].dynamic_joint_state["qpos"], # type: ignore[index] + np.asarray(recorded_obs[SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY]), + ) + + replay_sim = Sim(model_path) + replay_env: gym.Env = DummySimEnv(replay_sim, camera_set=DummyCameraSet(replay_sim)) + replay_env = SimStateObservationWrapper(replay_env) + render_dir = tmp_path / "rendered" + + replay_env.reset() + restore_sim_step(replay_env, recorded_steps[0], dynamic_joint_schema=recorded_steps[0].dynamic_joint_schema) + assert np.allclose( + replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 + ) + assert np.allclose( + replay_env.get_wrapper_attr("sim").data.qvel, np.asarray(recorded_obs["qvel"]), atol=1e-9, rtol=0 + ) + + replay_trajectory(replay_env, recorded_steps, output_dir=render_dir) + + rendered_files = sorted(path.name for path in render_dir.glob("*.png")) + assert rendered_files == ["step-000000-main.png"] + + +def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path): + tree = ET.parse(src) + root = tree.getroot() + for include in root.findall("include"): + include_file = include.get("file") + if include_file is not None and not Path(include_file).is_absolute(): + include.set("file", str((src.parent / include_file).resolve())) + + worldbody = root.find("worldbody") + assert worldbody is not None + + worldbody.append( + ET.Element( + "camera", + { + "name": "replay_extra_cam", + "pos": "1.4 0.0 0.9", + "xyaxes": "0 1 0 -0.3 0 1", + }, + ) + ) + body = ET.SubElement(worldbody, "body", {"name": "replay_extra_bg", "pos": "3 3 3"}) + ET.SubElement(body, "geom", {"name": "replay_extra_bg_geom", "type": "box", "size": "0.1 0.1 0.1"}) + tree.write(dst) + + +def _record_dummy_trajectory(dataset_path: Path, model_path: Path) -> tuple[list, dict[str, object]]: + record_env: gym.Env = DummySimEnv(Sim(model_path)) + record_env = SimStateObservationWrapper(record_env) + record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) + record_env.reset() + record_env.step({"delta": np.array([0.125], dtype=np.float64)}) + record_env.close() + + table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")]) + rows = table.to_pylist() + recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) + return recorded_steps, rows[0]["obs"] + + +def test_dynamic_joint_replay_tolerates_added_and_removed_fixed_scene_elements(tmp_path: Path): + base_model_path = tmp_path / "base.xml" + base_model_path.write_text(XML) + modified_model_path = tmp_path / "modified.xml" + _write_scene_with_extra_fixed_body_and_camera(base_model_path, modified_model_path) + + for record_model_path, replay_model_path in ( + (base_model_path, modified_model_path), + (modified_model_path, base_model_path), + ): + dataset_path = tmp_path / f"dataset-{record_model_path.stem}-to-{replay_model_path.stem}" + recorded_steps, recorded_obs = _record_dummy_trajectory(dataset_path, record_model_path) + + replay_sim = Sim(replay_model_path) + replay_env: gym.Env = DummySimEnv(replay_sim) + replay_env = SimStateObservationWrapper(replay_env) + replay_env.reset() + dynamic_joint_schema = next( + step.dynamic_joint_schema for step in recorded_steps if step.dynamic_joint_schema is not None + ) + restore_sim_step(replay_env, recorded_steps[0], dynamic_joint_schema=dynamic_joint_schema) + + assert np.allclose( + replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 + ) + assert np.allclose( + replay_env.get_wrapper_attr("sim").data.qvel, np.asarray(recorded_obs["qvel"]), atol=1e-9, rtol=0 + ) + + +DUAL_ARM_ROBOT2ID = {"left": "0", "right": "1"} + + +def _create_dual_arm_env(scene_name: str): + robot_cfg = default_sim_robot_cfg(scene_name, idx="") + sim_cfg = SimConfig() + sim_cfg.async_control = False + return SimMultiEnvCreator()( + name2id=DUAL_ARM_ROBOT2ID, + robot_cfg=robot_cfg, + control_mode=ControlMode.JOINTS, + gripper_cfg=default_sim_gripper_cfg(idx=""), + sim_cfg=sim_cfg, + max_relative_movement=None, + ) + + +def test_dynamic_joint_state_roundtrip_on_fr3_dual_arm_scene(tmp_path: Path): + source_scene_path = REPO_ROOT / "assets/scenes/fr3_dual_arm/scene.xml" + source_robot_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.xml" + source_urdf_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.urdf" + modified_scene_path = source_scene_path.parent / "scene_dynamic_joint_test.xml" + _write_scene_with_extra_fixed_body_and_camera(source_scene_path, modified_scene_path) + + base_scene_name = "fr3_dual_arm_dynamic_joint_base_test" + test_scene_name = "fr3_dual_arm_dynamic_joint_test" + scene_kwargs = { + "mjcf_robot": str(source_robot_path), + "urdf": str(source_urdf_path), + "robot_type": rcs.scenes["fr3_dual_arm"].robot_type, + "mjb": None, + } + rcs.scenes[base_scene_name] = rcs.Scene(mjcf_scene=str(source_scene_path), **scene_kwargs) + rcs.scenes[test_scene_name] = rcs.Scene(mjcf_scene=str(modified_scene_path), **scene_kwargs) + + base_env = _create_dual_arm_env(base_scene_name) + modified_env = _create_dual_arm_env(test_scene_name) + try: + base_env.reset() + base_sim = base_env.get_wrapper_attr("sim") + dynamic_joint_schema = base_sim.get_dynamic_joint_schema() + dynamic_joint_state = base_sim.get_dynamic_joint_state() + + modified_env.reset() + modified_sim = modified_env.get_wrapper_attr("sim") + modified_sim.set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) + restored_dynamic_joint_state = modified_sim.get_dynamic_joint_state() + + assert dynamic_joint_schema == modified_sim.get_dynamic_joint_schema() + assert np.allclose(restored_dynamic_joint_state["qpos"], dynamic_joint_state["qpos"], atol=1e-9, rtol=0) + assert np.allclose(restored_dynamic_joint_state["qvel"], dynamic_joint_state["qvel"], atol=1e-9, rtol=0) + finally: + base_env.close() + modified_env.close() + del rcs.scenes[test_scene_name] + del rcs.scenes[base_scene_name] + modified_scene_path.unlink(missing_ok=True) diff --git a/src/pybind/rcs.cpp b/src/pybind/rcs.cpp index ab7772c0..9be15edf 100644 --- a/src/pybind/rcs.cpp +++ b/src/pybind/rcs.cpp @@ -730,6 +730,18 @@ PYBIND11_MODULE(_core, m) { return rcs::sim::SimConfig(self); }); + py::class_(sim, "DynamicJointSchema") + .def(py::init<>()) + .def_readwrite("joint_names", &rcs::sim::DynamicJointSchema::joint_names) + .def_readwrite("joint_types", &rcs::sim::DynamicJointSchema::joint_types) + .def_readwrite("qpos_sizes", &rcs::sim::DynamicJointSchema::qpos_sizes) + .def_readwrite("qvel_sizes", &rcs::sim::DynamicJointSchema::qvel_sizes); + + py::class_(sim, "DynamicJointState") + .def(py::init<>()) + .def_readwrite("qpos", &rcs::sim::DynamicJointState::qpos) + .def_readwrite("qvel", &rcs::sim::DynamicJointState::qvel); + py::class_>(sim, "Sim") .def(py::init([](long m, long d) { return std::make_shared((mjModel*)m, (mjData*)d); @@ -743,6 +755,10 @@ PYBIND11_MODULE(_core, m) { .def("step", &rcs::sim::Sim::step, py::arg("k")) .def("reset", &rcs::sim::Sim::reset) .def("sync_gui", &rcs::sim::Sim::sync_gui) + .def("get_dynamic_joint_schema", &rcs::sim::Sim::get_dynamic_joint_schema) + .def("get_dynamic_joint_state", &rcs::sim::Sim::get_dynamic_joint_state) + .def("set_dynamic_joint_state", &rcs::sim::Sim::set_dynamic_joint_state, + py::arg("schema"), py::arg("state")) .def("_start_gui_server", &rcs::sim::Sim::start_gui_server, py::arg("id")) .def("_stop_gui_server", &rcs::sim::Sim::stop_gui_server); diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index 8facd93d..54804cf6 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include #include namespace rcs { @@ -25,7 +28,65 @@ bool get_last_return_value(ConditionCallback cb) { return cb.last_return_value; } -Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) {}; +int Sim::get_joint_qpos_size(int joint_type) { + switch (joint_type) { + case mjJNT_FREE: + return 7; + case mjJNT_BALL: + return 4; + case mjJNT_SLIDE: + case mjJNT_HINGE: + return 1; + default: + throw std::runtime_error("Unsupported MuJoCo joint type for qpos size."); + } +} + +int Sim::get_joint_qvel_size(int joint_type) { + switch (joint_type) { + case mjJNT_FREE: + return 6; + case mjJNT_BALL: + return 3; + case mjJNT_SLIDE: + case mjJNT_HINGE: + return 1; + default: + throw std::runtime_error("Unsupported MuJoCo joint type for qvel size."); + } +} + +void Sim::init_dynamic_joint_specs() { + this->dynamic_joint_specs.clear(); + this->dynamic_joint_name_to_index.clear(); + + for (int joint_id = 0; joint_id < this->m->njnt; ++joint_id) { + const char* joint_name = mj_id2name(this->m, mjOBJ_JOINT, joint_id); + if (joint_name == nullptr || joint_name[0] == '\0') { + std::ostringstream msg; + msg << "Dynamic joint state requires all joints to be named. Joint id " + << joint_id << " is unnamed."; + throw std::runtime_error(msg.str()); + } + + DynamicJointSpec spec{ + .name = joint_name, + .type = this->m->jnt_type[joint_id], + .qpos_adr = this->m->jnt_qposadr[joint_id], + .qvel_adr = this->m->jnt_dofadr[joint_id], + .qpos_size = get_joint_qpos_size(this->m->jnt_type[joint_id]), + .qvel_size = get_joint_qvel_size(this->m->jnt_type[joint_id]), + }; + + this->dynamic_joint_name_to_index[spec.name] = + this->dynamic_joint_specs.size(); + this->dynamic_joint_specs.push_back(spec); + } +} + +Sim::Sim(mjModel* m, mjData* d) : m(m), d(d), renderer(m) { + this->init_dynamic_joint_specs(); +}; bool Sim::set_config(const SimConfig& cfg) { this->cfg = cfg; @@ -118,6 +179,109 @@ void Sim::reset() { this->reset_callbacks(); } +DynamicJointSchema Sim::get_dynamic_joint_schema() const { + DynamicJointSchema schema; + schema.joint_names.reserve(this->dynamic_joint_specs.size()); + schema.joint_types.reserve(this->dynamic_joint_specs.size()); + schema.qpos_sizes.reserve(this->dynamic_joint_specs.size()); + schema.qvel_sizes.reserve(this->dynamic_joint_specs.size()); + + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + schema.joint_names.push_back(spec.name); + schema.joint_types.push_back(spec.type); + schema.qpos_sizes.push_back(spec.qpos_size); + schema.qvel_sizes.push_back(spec.qvel_size); + } + return schema; +} + +DynamicJointState Sim::get_dynamic_joint_state() const { + DynamicJointState state; + int total_qpos = 0; + int total_qvel = 0; + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + total_qpos += spec.qpos_size; + total_qvel += spec.qvel_size; + } + + state.qpos = rcs::common::VectorXd(total_qpos); + state.qvel = rcs::common::VectorXd(total_qvel); + + int qpos_offset = 0; + int qvel_offset = 0; + for (const DynamicJointSpec& spec : this->dynamic_joint_specs) { + for (int i = 0; i < spec.qpos_size; ++i) { + state.qpos[qpos_offset + i] = this->d->qpos[spec.qpos_adr + i]; + } + for (int i = 0; i < spec.qvel_size; ++i) { + state.qvel[qvel_offset + i] = this->d->qvel[spec.qvel_adr + i]; + } + qpos_offset += spec.qpos_size; + qvel_offset += spec.qvel_size; + } + + return state; +} + +void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, + const DynamicJointState& state) { + size_t joint_count = schema.joint_names.size(); + if (schema.joint_types.size() != joint_count || + schema.qpos_sizes.size() != joint_count || + schema.qvel_sizes.size() != joint_count) { + throw std::invalid_argument( + "Dynamic joint schema fields must all have the same length."); + } + + int expected_qpos_size = std::accumulate(schema.qpos_sizes.begin(), + schema.qpos_sizes.end(), 0); + int expected_qvel_size = std::accumulate(schema.qvel_sizes.begin(), + schema.qvel_sizes.end(), 0); + if (state.qpos.size() != expected_qpos_size) { + std::ostringstream msg; + msg << "Dynamic joint qpos size mismatch. Expected " + << expected_qpos_size << ", got " << state.qpos.size() << "."; + throw std::invalid_argument(msg.str()); + } + if (state.qvel.size() != expected_qvel_size) { + std::ostringstream msg; + msg << "Dynamic joint qvel size mismatch. Expected " + << expected_qvel_size << ", got " << state.qvel.size() << "."; + throw std::invalid_argument(msg.str()); + } + + int qpos_offset = 0; + int qvel_offset = 0; + for (size_t i = 0; i < joint_count; ++i) { + auto spec_iter = + this->dynamic_joint_name_to_index.find(schema.joint_names[i]); + if (spec_iter != this->dynamic_joint_name_to_index.end()) { + const DynamicJointSpec& target_spec = + this->dynamic_joint_specs[spec_iter->second]; + if (target_spec.type != schema.joint_types[i] || + target_spec.qpos_size != schema.qpos_sizes[i] || + target_spec.qvel_size != schema.qvel_sizes[i]) { + std::ostringstream msg; + msg << "Dynamic joint schema mismatch for joint '" + << schema.joint_names[i] << "'."; + throw std::invalid_argument(msg.str()); + } + + for (int j = 0; j < target_spec.qpos_size; ++j) { + this->d->qpos[target_spec.qpos_adr + j] = state.qpos[qpos_offset + j]; + } + for (int j = 0; j < target_spec.qvel_size; ++j) { + this->d->qvel[target_spec.qvel_adr + j] = state.qvel[qvel_offset + j]; + } + } + + qpos_offset += schema.qpos_sizes[i]; + qvel_offset += schema.qvel_sizes[i]; + } + + mj_forward(this->m, this->d); +} + void Sim::reset_callbacks() { for (size_t i = 0; i < std::size(this->callbacks); ++i) { this->callbacks[i].last_call_timestamp = 0; diff --git a/src/sim/sim.h b/src/sim/sim.h index 4ed35e60..6cae264c 100644 --- a/src/sim/sim.h +++ b/src/sim/sim.h @@ -5,10 +5,13 @@ #include #include #include +#include +#include #include "boost/interprocess/managed_shared_memory.hpp" #include "gui.h" #include "mujoco/mujoco.h" +#include "rcs/utils.h" namespace rcs { namespace sim { @@ -55,16 +58,42 @@ struct RenderingCallback { mjtNum last_call_timestamp; // in seconds }; +struct DynamicJointSchema { + std::vector joint_names; + std::vector joint_types; + std::vector qpos_sizes; + std::vector qvel_sizes; +}; + +struct DynamicJointState { + rcs::common::VectorXd qpos; + rcs::common::VectorXd qvel; +}; + class Sim { private: + struct DynamicJointSpec { + std::string name; + int type; + int qpos_adr; + int qvel_adr; + int qpos_size; + int qvel_size; + }; + SimConfig cfg; std::vector callbacks; std::vector any_callbacks; std::vector all_callbacks; std::vector rendering_callbacks; + std::vector dynamic_joint_specs; + std::unordered_map dynamic_joint_name_to_index; void invoke_callbacks(); bool invoke_condition_callbacks(); void invoke_rendering_callbacks(); + void init_dynamic_joint_specs(); + static int get_joint_qpos_size(int joint_type); + static int get_joint_qvel_size(int joint_type); size_t convergence_steps = 0; bool converged = true; std::optional gui; @@ -83,6 +112,10 @@ class Sim { void step(size_t k); void reset_callbacks(); void reset(); + DynamicJointSchema get_dynamic_joint_schema() const; + DynamicJointState get_dynamic_joint_state() const; + void set_dynamic_joint_state(const DynamicJointSchema& schema, + const DynamicJointState& state); /* NOTE: IMPORTANT, the callback is not necessarily called at exactly the * the requested interval. We invoke a callback if the elapsed simulation time * since the last call of the callback is greater than the requested time. From ad4c4ebb278a35279508c6ad5d482d2b105b22c1 Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Sat, 2 May 2026 15:39:27 +0200 Subject: [PATCH 02/13] Remove legacy replay fallback --- src/sim/sim.cpp | 60 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index 54804cf6..e8d4a697 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -250,35 +250,59 @@ void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, throw std::invalid_argument(msg.str()); } + std::vector matched_target_joints(this->dynamic_joint_specs.size(), + false); int qpos_offset = 0; int qvel_offset = 0; for (size_t i = 0; i < joint_count; ++i) { auto spec_iter = this->dynamic_joint_name_to_index.find(schema.joint_names[i]); - if (spec_iter != this->dynamic_joint_name_to_index.end()) { - const DynamicJointSpec& target_spec = - this->dynamic_joint_specs[spec_iter->second]; - if (target_spec.type != schema.joint_types[i] || - target_spec.qpos_size != schema.qpos_sizes[i] || - target_spec.qvel_size != schema.qvel_sizes[i]) { - std::ostringstream msg; - msg << "Dynamic joint schema mismatch for joint '" - << schema.joint_names[i] << "'."; - throw std::invalid_argument(msg.str()); - } - - for (int j = 0; j < target_spec.qpos_size; ++j) { - this->d->qpos[target_spec.qpos_adr + j] = state.qpos[qpos_offset + j]; - } - for (int j = 0; j < target_spec.qvel_size; ++j) { - this->d->qvel[target_spec.qvel_adr + j] = state.qvel[qvel_offset + j]; - } + if (spec_iter == this->dynamic_joint_name_to_index.end()) { + std::cerr << "WARNING: Recorded dynamic joint '" << schema.joint_names[i] + << "' is missing in the replay model. Skipping it." + << std::endl; + qpos_offset += schema.qpos_sizes[i]; + qvel_offset += schema.qvel_sizes[i]; + continue; + } + + const DynamicJointSpec& target_spec = + this->dynamic_joint_specs[spec_iter->second]; + matched_target_joints[spec_iter->second] = true; + if (target_spec.type != schema.joint_types[i] || + target_spec.qpos_size != schema.qpos_sizes[i] || + target_spec.qvel_size != schema.qvel_sizes[i]) { + std::ostringstream msg; + msg << "Dynamic joint schema mismatch for joint '" + << schema.joint_names[i] << "': expected type=" << target_spec.type + << ", qpos_size=" << target_spec.qpos_size + << ", qvel_size=" << target_spec.qvel_size << " but got type=" + << schema.joint_types[i] << ", qpos_size=" << schema.qpos_sizes[i] + << ", qvel_size=" << schema.qvel_sizes[i] << "."; + throw std::invalid_argument(msg.str()); + } + + for (int j = 0; j < target_spec.qpos_size; ++j) { + this->d->qpos[target_spec.qpos_adr + j] = state.qpos[qpos_offset + j]; + } + for (int j = 0; j < target_spec.qvel_size; ++j) { + this->d->qvel[target_spec.qvel_adr + j] = state.qvel[qvel_offset + j]; } qpos_offset += schema.qpos_sizes[i]; qvel_offset += schema.qvel_sizes[i]; } + for (size_t i = 0; i < this->dynamic_joint_specs.size(); ++i) { + if (!matched_target_joints[i]) { + std::cerr << "WARNING: Replay model dynamic joint '" + << this->dynamic_joint_specs[i].name + << "' is missing in the recorded schema. Leaving it at its " + "current value." + << std::endl; + } + } + mj_forward(this->m, this->d); } From e718d15d6d98122de576e6c6a34425e96524ffd3 Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Sun, 3 May 2026 10:36:18 +0200 Subject: [PATCH 03/13] Refactor replay onto existing sim_state flow --- python/rcs/envs/sim.py | 25 ++-- python/rcs/sim/sim.py | 47 +++--- python/rcs/sim_state_replay.py | 43 ++---- python/tests/test_sim_state_record_replay.py | 147 +++++++++---------- 4 files changed, 126 insertions(+), 136 deletions(-) diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index b5348455..287e5b90 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -44,33 +44,32 @@ def reset( class SimStateObservationWrapper(ActObsInfoWrapper): - DYNAMIC_JOINT_SCHEMA_KEY = "dynamic_joint_schema" - DYNAMIC_JOINT_QPOS_KEY = "dynamic_joint_qpos" - DYNAMIC_JOINT_QVEL_KEY = "dynamic_joint_qvel" + STATE_KEY = "sim_state" + STATE_SPEC_KEY = "sim_state_spec" + STATE_SIZE_KEY = "sim_state_size" def __init__(self, env): super().__init__(env) assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation." self.sim = cast(sim.Sim, self.get_wrapper_attr("sim")) - self._dynamic_joint_schema = self.sim.get_dynamic_joint_schema() - self._include_schema_in_next_step = True + self._state_spec = self.sim.get_state_spec() + self._include_state_spec_in_next_step = True def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: observation = dict(observation) - dynamic_joint_state = self.sim.get_dynamic_joint_state() - observation[self.DYNAMIC_JOINT_QPOS_KEY] = dynamic_joint_state["qpos"] - observation[self.DYNAMIC_JOINT_QVEL_KEY] = dynamic_joint_state["qvel"] - if self._include_schema_in_next_step: - observation[self.DYNAMIC_JOINT_SCHEMA_KEY] = self._dynamic_joint_schema - self._include_schema_in_next_step = False + sim_state = self.sim.get_state() + observation[self.STATE_KEY] = sim_state + observation[self.STATE_SIZE_KEY] = sim_state.shape[0] + if self._include_state_spec_in_next_step: + observation[self.STATE_SPEC_KEY] = self._state_spec + self._include_state_spec_in_next_step = False return observation, info def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[dict[str, Any], dict[str, Any]]: obs, info = super().reset(seed=seed, options=options) - # Re-emit the schema on the first recorded step after each reset. - self._include_schema_in_next_step = True + self._include_state_spec_in_next_step = True return obs, info diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index 80e4b13a..ae24969b 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -47,8 +47,6 @@ def gui_loop(gui_uuid: str, close_event): class Sim(_Sim): - STATE_SPEC = mj.mjtState.mjSTATE_INTEGRATION - def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None = None): if isinstance(mjmdl, ModelComposer): self.model = mjmdl.get_model() @@ -73,31 +71,38 @@ def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None if cfg is not None: self.set_config(cfg) - def get_state_spec(self) -> int: - return int(self.STATE_SPEC) + def get_state_spec(self) -> dict[str, list[str] | list[int]]: + return self.get_dynamic_joint_schema() - def get_state_size(self, spec: int | None = None) -> int: - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) - return mj.mj_stateSize(self.model, state_spec) + def get_state_size(self, spec: dict[str, list[str] | list[int]] | None = None) -> int: + state_spec = self.get_state_spec() if spec is None else spec + qpos_size = sum(int(value) for value in state_spec["qpos_sizes"]) + qvel_size = sum(int(value) for value in state_spec["qvel_sizes"]) + return qpos_size + qvel_size - def get_state(self, spec: int | None = None) -> np.ndarray: - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) - state = np.empty(self.get_state_size(int(state_spec)), dtype=np.float64) - mj.mj_getState(self.model, self.data, state, state_spec) - return state + def get_state(self, spec: dict[str, list[str] | list[int]] | None = None) -> np.ndarray: + del spec + dynamic_joint_state = self.get_dynamic_joint_state() + return np.concatenate((dynamic_joint_state["qpos"], dynamic_joint_state["qvel"])) - def set_state(self, state: np.ndarray, spec: int | None = None): - state_spec = self.STATE_SPEC if spec is None else mj.mjtState(spec) + def set_state( + self, + state: np.ndarray, + spec: dict[str, list[str] | list[int]] | None = None, + ): + state_spec = self.get_state_spec() if spec is None else spec state_array = np.asarray(state, dtype=np.float64) - expected_size = self.get_state_size(int(state_spec)) + expected_size = self.get_state_size(state_spec) if state_array.shape != (expected_size,): - msg = ( - f"Expected MuJoCo state with shape ({expected_size},), " - f"got {state_array.shape} for spec {int(state_spec)}." - ) + msg = f"Expected state with shape ({expected_size},), got {state_array.shape}." raise ValueError(msg) - mj.mj_setState(self.model, self.data, state_array, state_spec) - mj.mj_forward(self.model, self.data) + + qpos_size = sum(int(value) for value in state_spec["qpos_sizes"]) + dynamic_joint_state = { + "qpos": state_array[:qpos_size], + "qvel": state_array[qpos_size:], + } + self.set_dynamic_joint_state(state_spec, dynamic_joint_state) def get_dynamic_joint_schema(self) -> dict[str, list[str] | list[int]]: schema = super().get_dynamic_joint_schema() diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py index 4a0c7c22..0f2a29b5 100644 --- a/python/rcs/sim_state_replay.py +++ b/python/rcs/sim_state_replay.py @@ -39,21 +39,13 @@ class RecordedSimStep: observation: dict[str, Any] @property - def dynamic_joint_schema(self) -> dict[str, Any] | None: - schema = self.observation.get(SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY) - return dict(schema) if schema is not None else None + def sim_state(self) -> np.ndarray: + return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64) @property - def dynamic_joint_state(self) -> dict[str, np.ndarray] | None: - if ( - SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY not in self.observation - or SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY not in self.observation - ): - return None - return { - "qpos": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY], dtype=np.float64), - "qvel": np.asarray(self.observation[SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY], dtype=np.float64), - } + def sim_state_spec(self) -> dict[str, Any] | None: + schema = self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY) + return dict(schema) if schema is not None else None class DuckDBUnavailableError(RuntimeError): @@ -161,19 +153,15 @@ def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, pre def restore_sim_step( env: gym.Env, recorded_step: RecordedSimStep, - dynamic_joint_schema: dict[str, Any] | None = None, + sim_state_spec: dict[str, Any] | None = None, ): - sim = env.get_wrapper_attr("sim") - dynamic_joint_state = recorded_step.dynamic_joint_state - if dynamic_joint_state is None: - msg = "Recorded step is missing dynamic joint state data." + resolved_spec = sim_state_spec or recorded_step.sim_state_spec + if resolved_spec is None: + msg = "Recorded sim state is missing its schema." raise ValueError(msg) - resolved_schema = dynamic_joint_schema or recorded_step.dynamic_joint_schema - if resolved_schema is None: - msg = "Recorded dynamic joint state is missing its schema." - raise ValueError(msg) - sim.set_dynamic_joint_state(resolved_schema, dynamic_joint_state) + sim = env.get_wrapper_attr("sim") + sim.set_state(recorded_step.sim_state, spec=resolved_spec) def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]: @@ -209,17 +197,16 @@ def replay_trajectory( output_dir: Path | None = None, ): if not recorded_steps: - msg = "No recorded dynamic joint states found in the requested trajectory." + msg = "No recorded sim states found in the requested trajectory." raise ValueError(msg) - dynamic_joint_schema = next( - (recorded_step.dynamic_joint_schema for recorded_step in recorded_steps if recorded_step.dynamic_joint_schema), - None, + sim_state_spec = next( + (recorded_step.sim_state_spec for recorded_step in recorded_steps if recorded_step.sim_state_spec), None ) env.reset() for recorded_step in recorded_steps: - restore_sim_step(env, recorded_step, dynamic_joint_schema=dynamic_joint_schema) + restore_sim_step(env, recorded_step, sim_state_spec=sim_state_spec) if output_dir is not None: save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) if sleep_s > 0: diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py index 2befdc91..1d7b00b7 100644 --- a/python/tests/test_sim_state_record_replay.py +++ b/python/tests/test_sim_state_record_replay.py @@ -11,12 +11,8 @@ import numpy as np import pyarrow.dataset as ds from rcs._core.common import RobotPlatform -from rcs._core.sim import SimConfig from rcs.camera.interface import CameraFrame, DataFrame, Frame, FrameSet -from rcs.envs.base import ControlMode, JointsDictType -from rcs.envs.creators import SimMultiEnvCreator from rcs.envs.storage_wrapper import StorageWrapper -from rcs.envs.utils import default_sim_gripper_cfg, default_sim_robot_cfg import rcs @@ -117,13 +113,13 @@ def reset(self, *, seed: int | None = None, options: dict | None = None): super().reset(seed=seed) mj.mj_resetData(self.sim.model, self.sim.data) mj.mj_forward(self.sim.model, self.sim.data) - return self._obs(), {} + return self._obs(), {"collision": False} def step(self, action: dict[str, np.ndarray]): self.sim.data.qpos[0] += float(action["delta"][0]) self.sim.data.qvel[:] = 0.0 mj.mj_forward(self.sim.model, self.sim.data) - return self._obs(), 0.0, False, False, {} + return self._obs(), 0.0, False, False, {"collision": False} def close(self): return None @@ -139,7 +135,7 @@ def test_record_and_replay_sim_state(tmp_path: Path): record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) obs, _ = record_env.reset() - assert SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY in obs + assert SimStateObservationWrapper.STATE_KEY in obs record_env.step({"delta": np.array([0.125], dtype=np.float64)}) record_env.close() @@ -149,17 +145,18 @@ def test_record_and_replay_sim_state(tmp_path: Path): assert len(rows) == 1 recorded_obs = rows[0]["obs"] - assert SimStateObservationWrapper.DYNAMIC_JOINT_SCHEMA_KEY in recorded_obs - assert SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY in recorded_obs - assert SimStateObservationWrapper.DYNAMIC_JOINT_QVEL_KEY in recorded_obs + assert SimStateObservationWrapper.STATE_KEY in recorded_obs + assert SimStateObservationWrapper.STATE_SPEC_KEY in recorded_obs + assert SimStateObservationWrapper.STATE_SIZE_KEY in recorded_obs + assert ( + len(recorded_obs[SimStateObservationWrapper.STATE_KEY]) + == recorded_obs[SimStateObservationWrapper.STATE_SIZE_KEY] + ) recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) assert len(recorded_steps) == 1 - assert recorded_steps[0].dynamic_joint_schema is not None - assert np.allclose( - recorded_steps[0].dynamic_joint_state["qpos"], # type: ignore[index] - np.asarray(recorded_obs[SimStateObservationWrapper.DYNAMIC_JOINT_QPOS_KEY]), - ) + assert recorded_steps[0].sim_state_spec is not None + assert np.allclose(recorded_steps[0].sim_state, np.asarray(recorded_obs[SimStateObservationWrapper.STATE_KEY])) replay_sim = Sim(model_path) replay_env: gym.Env = DummySimEnv(replay_sim, camera_set=DummyCameraSet(replay_sim)) @@ -167,7 +164,7 @@ def test_record_and_replay_sim_state(tmp_path: Path): render_dir = tmp_path / "rendered" replay_env.reset() - restore_sim_step(replay_env, recorded_steps[0], dynamic_joint_schema=recorded_steps[0].dynamic_joint_schema) + restore_sim_step(replay_env, recorded_steps[0], sim_state_spec=recorded_steps[0].sim_state_spec) assert np.allclose( replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 ) @@ -221,7 +218,7 @@ def _record_dummy_trajectory(dataset_path: Path, model_path: Path) -> tuple[list return recorded_steps, rows[0]["obs"] -def test_dynamic_joint_replay_tolerates_added_and_removed_fixed_scene_elements(tmp_path: Path): +def test_sim_state_replay_tolerates_added_and_removed_fixed_scene_elements(tmp_path: Path): base_model_path = tmp_path / "base.xml" base_model_path.write_text(XML) modified_model_path = tmp_path / "modified.xml" @@ -238,10 +235,8 @@ def test_dynamic_joint_replay_tolerates_added_and_removed_fixed_scene_elements(t replay_env: gym.Env = DummySimEnv(replay_sim) replay_env = SimStateObservationWrapper(replay_env) replay_env.reset() - dynamic_joint_schema = next( - step.dynamic_joint_schema for step in recorded_steps if step.dynamic_joint_schema is not None - ) - restore_sim_step(replay_env, recorded_steps[0], dynamic_joint_schema=dynamic_joint_schema) + sim_state_spec = next(step.sim_state_spec for step in recorded_steps if step.sim_state_spec is not None) + restore_sim_step(replay_env, recorded_steps[0], sim_state_spec=sim_state_spec) assert np.allclose( replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 @@ -251,60 +246,64 @@ def test_dynamic_joint_replay_tolerates_added_and_removed_fixed_scene_elements(t ) -DUAL_ARM_ROBOT2ID = {"left": "0", "right": "1"} - +def _write_repo_scene_with_dynamic_body(src: Path, dst: Path, *, add_extra_fixed_scene_elements: bool = False): + tree = ET.parse(src) + root = tree.getroot() + worldbody = root.find("worldbody") + assert worldbody is not None -def _create_dual_arm_env(scene_name: str): - robot_cfg = default_sim_robot_cfg(scene_name, idx="") - sim_cfg = SimConfig() - sim_cfg.async_control = False - return SimMultiEnvCreator()( - name2id=DUAL_ARM_ROBOT2ID, - robot_cfg=robot_cfg, - control_mode=ControlMode.JOINTS, - gripper_cfg=default_sim_gripper_cfg(idx=""), - sim_cfg=sim_cfg, - max_relative_movement=None, + dynamic_body = ET.SubElement(worldbody, "body", {"name": "replay_dynamic_box", "pos": "0 0 0.1"}) + ET.SubElement(dynamic_body, "freejoint", {"name": "replay_dynamic_box_free"}) + ET.SubElement( + dynamic_body, + "geom", + { + "name": "replay_dynamic_box_geom", + "type": "box", + "size": "0.05 0.05 0.05", + "rgba": "0.2 0.6 0.9 1", + }, ) + if add_extra_fixed_scene_elements: + worldbody.append( + ET.Element( + "camera", + { + "name": "replay_extra_cam", + "pos": "1.4 0.0 0.9", + "xyaxes": "0 1 0 -0.3 0 1", + }, + ) + ) + fixed_body = ET.SubElement(worldbody, "body", {"name": "replay_extra_bg", "pos": "3 3 3"}) + ET.SubElement( + fixed_body, + "geom", + {"name": "replay_extra_bg_geom", "type": "box", "size": "0.1 0.1 0.1"}, + ) + + tree.write(dst) + + +def test_sim_state_roundtrip_on_repo_scene_layout(tmp_path: Path): + source_scene_path = REPO_ROOT / "assets/scenes/empty_world/scene.xml" + base_scene_path = tmp_path / "empty_world_dynamic.xml" + modified_scene_path = tmp_path / "empty_world_dynamic_modified.xml" + _write_repo_scene_with_dynamic_body(source_scene_path, base_scene_path) + _write_repo_scene_with_dynamic_body(source_scene_path, modified_scene_path, add_extra_fixed_scene_elements=True) + + base_sim = Sim(base_scene_path) + sim_state_spec = base_sim.get_state_spec() + sim_state = base_sim.get_state().copy() + num_seed_values = min(8, sim_state.shape[0]) + sim_state[:num_seed_values] = np.linspace(0.01, 0.01 * num_seed_values, num_seed_values) + base_sim.set_state(sim_state, sim_state_spec) + seeded_sim_state = base_sim.get_state() + + modified_sim = Sim(modified_scene_path) + modified_sim.set_state(seeded_sim_state, sim_state_spec) + restored_sim_state = modified_sim.get_state() -def test_dynamic_joint_state_roundtrip_on_fr3_dual_arm_scene(tmp_path: Path): - source_scene_path = REPO_ROOT / "assets/scenes/fr3_dual_arm/scene.xml" - source_robot_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.xml" - source_urdf_path = REPO_ROOT / "assets/scenes/fr3_empty_world/robot.urdf" - modified_scene_path = source_scene_path.parent / "scene_dynamic_joint_test.xml" - _write_scene_with_extra_fixed_body_and_camera(source_scene_path, modified_scene_path) - - base_scene_name = "fr3_dual_arm_dynamic_joint_base_test" - test_scene_name = "fr3_dual_arm_dynamic_joint_test" - scene_kwargs = { - "mjcf_robot": str(source_robot_path), - "urdf": str(source_urdf_path), - "robot_type": rcs.scenes["fr3_dual_arm"].robot_type, - "mjb": None, - } - rcs.scenes[base_scene_name] = rcs.Scene(mjcf_scene=str(source_scene_path), **scene_kwargs) - rcs.scenes[test_scene_name] = rcs.Scene(mjcf_scene=str(modified_scene_path), **scene_kwargs) - - base_env = _create_dual_arm_env(base_scene_name) - modified_env = _create_dual_arm_env(test_scene_name) - try: - base_env.reset() - base_sim = base_env.get_wrapper_attr("sim") - dynamic_joint_schema = base_sim.get_dynamic_joint_schema() - dynamic_joint_state = base_sim.get_dynamic_joint_state() - - modified_env.reset() - modified_sim = modified_env.get_wrapper_attr("sim") - modified_sim.set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) - restored_dynamic_joint_state = modified_sim.get_dynamic_joint_state() - - assert dynamic_joint_schema == modified_sim.get_dynamic_joint_schema() - assert np.allclose(restored_dynamic_joint_state["qpos"], dynamic_joint_state["qpos"], atol=1e-9, rtol=0) - assert np.allclose(restored_dynamic_joint_state["qvel"], dynamic_joint_state["qvel"], atol=1e-9, rtol=0) - finally: - base_env.close() - modified_env.close() - del rcs.scenes[test_scene_name] - del rcs.scenes[base_scene_name] - modified_scene_path.unlink(missing_ok=True) + assert sim_state_spec == modified_sim.get_state_spec() + assert np.allclose(restored_sim_state, seeded_sim_state, atol=1e-9, rtol=0) From e34a207744f1f66efa658cc27a60738decb178a2 Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Sun, 3 May 2026 11:46:29 +0200 Subject: [PATCH 04/13] Refactor replay onto existing replayer flow --- python/rcs/envs/base.py | 5 +- python/rcs/envs/sim.py | 214 ------------- python/rcs/sim/replayer.py | 17 +- python/rcs/sim_state_replay.py | 251 --------------- python/tests/test_replayer.py | 119 ++++++- python/tests/test_sim_state_record_replay.py | 309 ------------------- 6 files changed, 133 insertions(+), 782 deletions(-) delete mode 100644 python/rcs/sim_state_replay.py delete mode 100644 python/tests/test_sim_state_record_replay.py diff --git a/python/rcs/envs/base.py b/python/rcs/envs/base.py index 8dc97d54..a0c2c5f6 100644 --- a/python/rcs/envs/base.py +++ b/python/rcs/envs/base.py @@ -170,6 +170,7 @@ class ArmObsType(TQuatDictType, JointsDictType, TRPYDictType): ... CartOrJointContType: TypeAlias = TQuatDictType | JointsDictType | TRPYDictType LimitedCartOrJointContType: TypeAlias = LimitedTQuatRelDictType | LimitedJointsRelDictType | LimitedTRPYRelDictType +SimStateSpec: TypeAlias = dict[str, list[str] | list[int]] class ArmWithGripper(TQuatDictType, GripperDictType): ... @@ -212,9 +213,9 @@ def __init__(self, sim: simulation.Sim, return_state=True) -> None: self.frame_rate = SimpleFrameRate(cfg.frequency, "MoJoCo Simulation Loop") self.main_greenlet: greenlet | None = None self.return_state = return_state - self._replay_state: tuple[np.ndarray, int | None] | None = None + self._replay_state: tuple[np.ndarray, SimStateSpec | None] | None = None - def set_replay_state(self, state: np.ndarray, spec: int | None = None): + def set_replay_state(self, state: np.ndarray, spec: SimStateSpec | None = None): self._replay_state = (state, spec) def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict]: diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index 287e5b90..77354065 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -2,12 +2,9 @@ from typing import Any, cast import gymnasium as gym -import numpy as np from rcs._core.common import RobotPlatform -from rcs.envs.base import GripperWrapper from rcs.envs.space_utils import ActObsInfoWrapper -import rcs from rcs import sim logger = logging.getLogger(__name__) @@ -43,36 +40,6 @@ def reset( return super().reset(seed=seed, options=options) -class SimStateObservationWrapper(ActObsInfoWrapper): - STATE_KEY = "sim_state" - STATE_SPEC_KEY = "sim_state_spec" - STATE_SIZE_KEY = "sim_state_size" - - def __init__(self, env): - super().__init__(env) - assert self.env.get_wrapper_attr("PLATFORM") == RobotPlatform.SIMULATION, "Base environment must be simulation." - self.sim = cast(sim.Sim, self.get_wrapper_attr("sim")) - self._state_spec = self.sim.get_state_spec() - self._include_state_spec_in_next_step = True - - def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: - observation = dict(observation) - sim_state = self.sim.get_state() - observation[self.STATE_KEY] = sim_state - observation[self.STATE_SIZE_KEY] = sim_state.shape[0] - if self._include_state_spec_in_next_step: - observation[self.STATE_SPEC_KEY] = self._state_spec - self._include_state_spec_in_next_step = False - return observation, info - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - obs, info = super().reset(seed=seed, options=options) - self._include_state_spec_in_next_step = True - return obs, info - - class GripperWrapperSim(ActObsInfoWrapper): def __init__(self, env): super().__init__(env) @@ -127,187 +94,6 @@ def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tupl return observation, info -class RandomObjectPos(gym.Wrapper): - """ - Wrapper to randomly re-place an object in the lab environments. - Given the object's joint name and initial pose, its x, y coordinates are randomized, while z remains fixed. - If include_rotation is true, the object's z-axis rotation (yaw) is also randomized. - - Args: - env (gym.Env): The environment to wrap. - simulation (sim.Sim): The simulation instance. - joint_name (str): The name of the free joint attached to the object to manipulate. - init_object_pose (rcs.common.Pose): The initial pose of the object. - include_rotation (bool): Whether to include rotation in the randomization. - """ - - def __init__( - self, - env: gym.Env, - joint_name: str, - init_object_pose: rcs.common.Pose, - include_position: bool = True, - include_rotation: bool = False, - x_scale: float = 0.2, - y_scale: float = 0.2, - x_offset: float = 0.1, - y_offset: float = 0.1, - ): - super().__init__(env) - self.joint_name = joint_name - self.init_object_pose = init_object_pose - self.include_position = include_position - self.include_rotation = include_rotation - self.x_scale = x_scale - self.y_scale = y_scale - self.x_offset = x_offset - self.y_offset = y_offset - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - if options is not None and "RandomObjectPos.init_object_pose" in options: - assert isinstance( - options["RandomObjectPos.init_object_pose"], rcs.common.Pose - ), "RandomObjectPos.init_object_pose must be a rcs.common.Pose" - - self.init_object_pose = options["RandomObjectPos.init_object_pose"] - print("Got random object pos!\n", self.init_object_pose) - del options["RandomObjectPos.init_object_pose"] - obs, info = super().reset(seed=seed, options=options) - - pos_z = self.init_object_pose.translation()[2] - if self.include_position: - pos_x = self.init_object_pose.translation()[0] + np.random.random() * self.x_scale + self.x_offset - pos_y = self.init_object_pose.translation()[1] + np.random.random() * self.y_scale + self.y_offset - else: - pos_x = self.init_object_pose.translation()[0] - pos_y = self.init_object_pose.translation()[1] - - quat = self.init_object_pose.rotation_q() # xyzw format - if self.include_rotation: - random_z_rotation = (np.random.random() - 0.5) * (0.7071068 * 2) - self.get_wrapper_attr("sim").data.joint(self.joint_name).qpos = [ - pos_x, - pos_y, - pos_z, - quat[3] + random_z_rotation, - quat[0], - quat[1], - quat[2] + random_z_rotation, - ] - else: - self.get_wrapper_attr("sim").data.joint(self.joint_name).qpos = [ - pos_x, - pos_y, - pos_z, - quat[3], - quat[0], - quat[1], - quat[2], - ] - - return obs, info - - -class RandomCubePos(gym.Wrapper): - """Wrapper to randomly place cube in the lab environments. - - Works only for single robot - """ - - def __init__(self, env: gym.Env, include_rotation: bool = False, cube_joint_name="box_joint"): - super().__init__(env) - self.include_rotation = include_rotation - self.cube_joint_name = cube_joint_name - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - obs, info = super().reset(seed=seed, options=options) - - iso_cube = np.array([0.498, 0.0, 0.226]) - iso_cube_pose = rcs.common.Pose(translation=np.array(iso_cube), rpy_vector=np.array([0, 0, 0])) # type: ignore - iso_cube = self.get_wrapper_attr("robot").to_pose_in_world_coordinates(iso_cube_pose).translation() - pos_z = 0.0288 - pos_x = iso_cube[0] + np.random.random() * 0.2 - 0.1 - pos_y = iso_cube[1] + np.random.random() * 0.2 - 0.1 - - if self.include_rotation: - self.get_wrapper_attr("sim").data.joint(self.cube_joint_name).qpos = [ - pos_x, - pos_y, - pos_z, - 2 * np.random.random() - 1, - 0, - 0, - 1, - ] - else: - self.get_wrapper_attr("sim").data.joint(self.cube_joint_name).qpos = [pos_x, pos_y, pos_z, 0, 0, 0, 1] - - return obs, info - - -class PickCubeSuccessWrapper(gym.Wrapper): - """ - Wrapper to check if the cube is successfully picked up in the FR3SimplePickUpSim environment. - Cube must be lifted 10 cm above the robot base. - Computes a reward between 0 and 1 based on: - - TCP to object distance - - cube z height - - whether the arm is standing still once the task is solved. - """ - - def __init__(self, env, cube_geom_name="box_geom"): - super().__init__(env) - assert isinstance(self.get_wrapper_attr("robot"), sim.SimRobot), "Robot must be a sim.SimRobot instance." - self._robot = cast(sim.SimRobot, self.get_wrapper_attr("robot")) - self.sim = self.env.get_wrapper_attr("sim") - self.cube_geom_name = cube_geom_name - self.home_pose = self._robot.get_cartesian_position() - self._gripper_closing = 0 - self._gripper = self.get_wrapper_attr("gripper") - - def step(self, action: dict[str, Any]): # type: ignore - obs, reward, _, truncated, info = super().step(action) - if ( - self._gripper.get_normalized_width() > 0.01 - and self._gripper.get_normalized_width() < 0.99 - and obs["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED - ): - self._gripper_closing += 1 - else: - self._gripper_closing = 0 - cube_pose = rcs.common.Pose(translation=self.sim.data.geom(self.cube_geom_name).xpos) - cube_pose = self._robot.to_pose_in_robot_coordinates(cube_pose) - tcp_to_obj_dist = np.linalg.norm(cube_pose.translation() - self._robot.get_cartesian_position().translation()) - obj_to_goal_dist = 0.10 - min(cube_pose.translation()[-1], 0.10) - obj_to_goal_dist = np.linalg.norm(cube_pose.translation() - self.home_pose.translation()) - # NOTE: 4 depends on the time passing between each step. - is_grasped = ( - self._gripper_closing >= 4 # gripper is closing since more than 4 steps - and obs["gripper"] == GripperWrapper.BINARY_GRIPPER_CLOSED # command is still close - and tcp_to_obj_dist <= 0.01 # tcp to cube center is max 1cm - ) - success = obj_to_goal_dist <= 0.022 and info["is_grasped"] - movement = np.linalg.norm(self.sim.data.qvel) - - reaching_reward = 1 - np.tanh(5 * tcp_to_obj_dist) - place_reward = 1 - np.tanh(5 * obj_to_goal_dist) - static_reward = 1 - np.tanh(5 * movement) - info["is_grasped"] = is_grasped - info["success"] = success - reward = reaching_reward + place_reward * is_grasped + static_reward * success - reward /= 3 # type: ignore - return obs, reward, success, truncated, info - - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): - obs, info = super().reset() - self.home_pose = self._robot.get_cartesian_position() - return obs, info - - class DigitalTwin(gym.Wrapper): def __init__(self, env, twin_env): super().__init__(env) diff --git a/python/rcs/sim/replayer.py b/python/rcs/sim/replayer.py index f8ec6832..0ca28371 100644 --- a/python/rcs/sim/replayer.py +++ b/python/rcs/sim/replayer.py @@ -9,11 +9,20 @@ import rcs.envs.configs as env_configs import rcs.envs.tasks as env_tasks from rcs._core.sim import SimConfig -from rcs.envs.base import RelativeTo, SimEnv +from rcs.envs.base import RelativeTo, SimEnv, SimStateSpec from rcs.envs.scenes import SimEnvCreator from rcs.envs.storage_wrapper import StorageWrapper +def _normalize_sim_state_spec(value: Any) -> SimStateSpec: + return { + "joint_names": [str(item) for item in value["joint_names"]], + "joint_types": [int(item) for item in value["joint_types"]], + "qpos_sizes": [int(item) for item in value["qpos_sizes"]], + "qvel_sizes": [int(item) for item in value["qvel_sizes"]], + } + + @dataclass(frozen=True) class RecordedSimStep: step: int @@ -38,13 +47,13 @@ def sim_state(self) -> np.ndarray: raise KeyError(msg) @property - def sim_state_spec(self) -> int | None: + def sim_state_spec(self) -> SimStateSpec | None: if SimEnv.STATE_SPEC_KEY in self.info: - return int(self.info[SimEnv.STATE_SPEC_KEY]) + return _normalize_sim_state_spec(self.info[SimEnv.STATE_SPEC_KEY]) for value in self.info.values(): if isinstance(value, dict) and SimEnv.STATE_SPEC_KEY in value: - return int(value[SimEnv.STATE_SPEC_KEY]) + return _normalize_sim_state_spec(value[SimEnv.STATE_SPEC_KEY]) return None diff --git a/python/rcs/sim_state_replay.py b/python/rcs/sim_state_replay.py deleted file mode 100644 index 0f2a29b5..00000000 --- a/python/rcs/sim_state_replay.py +++ /dev/null @@ -1,251 +0,0 @@ -from __future__ import annotations - -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Annotated, Any - -import gymnasium as gym -import numpy as np -import pyarrow.compute as pc -import pyarrow.dataset as ds -import typer -from PIL import Image -from rcs.envs.base import ControlMode -from rcs.envs.sim import SimStateObservationWrapper - -import rcs # noqa: F401 - -app = typer.Typer(help="Replay recorded MuJoCo trajectories from a parquet dataset.") - -DATASET_ARGUMENT = typer.Argument(..., exists=True, file_okay=False, dir_okay=True) -ENV_ID_OPTION = typer.Option("rcs/FR3SimplePickUpSim-v0", help="Gymnasium env id used for replay.") -TRAJECTORY_UUID_OPTION = typer.Option(None, help="UUID of the recorded trajectory to replay.") -CAMERA_OPTION = typer.Option([], "--camera", help="Camera names to enable on the replay env.") -RESOLUTION_OPTION = typer.Option((256, 256), help="Replay camera resolution as WIDTH HEIGHT.") -FRAME_RATE_OPTION = typer.Option(0, help="Replay camera frame rate.") -RENDER_MODE_OPTION = typer.Option("human", help="Gym render mode for the replay env.") -CONTROL_MODE_OPTION = typer.Option(ControlMode.CARTESIAN_TRPY.name, help="Control mode name for env creation.") -SLEEP_OPTION = typer.Option(0.0, help="Optional delay between restored states.") -OUTPUT_DIR_OPTION = typer.Option(None, help="Optional directory for re-rendered RGB frames.") -PREFER_DUCKDB_OPTION = typer.Option(True, help="Use duckdb for parquet loading when it is available.") - - -@dataclass(frozen=True) -class RecordedSimStep: - step: int - uuid: str - timestamp: float | None - observation: dict[str, Any] - - @property - def sim_state(self) -> np.ndarray: - return np.asarray(self.observation[SimStateObservationWrapper.STATE_KEY], dtype=np.float64) - - @property - def sim_state_spec(self) -> dict[str, Any] | None: - schema = self.observation.get(SimStateObservationWrapper.STATE_SPEC_KEY) - return dict(schema) if schema is not None else None - - -class DuckDBUnavailableError(RuntimeError): - pass - - -def _get_duckdb_module(): - try: - import duckdb - except ModuleNotFoundError as exc: - msg = ( - "duckdb is required for the preferred parquet read path but is not installed. " - "Install the 'duckdb' Python package or rely on the pyarrow fallback in library calls." - ) - raise DuckDBUnavailableError(msg) from exc - return duckdb - - -def _load_distinct_uuids_with_duckdb(dataset_path: Path) -> list[str]: - duckdb = _get_duckdb_module() - connection = duckdb.connect() - try: - rows = connection.execute( - "SELECT DISTINCT uuid FROM read_parquet(?) ORDER BY uuid", - [str(dataset_path)], - ).fetchall() - finally: - connection.close() - return [row[0] for row in rows] - - -def _load_distinct_uuids_with_pyarrow(dataset_path: Path) -> list[str]: - dataset = ds.dataset(str(dataset_path), format="parquet") - uuids = dataset.to_table(columns=["uuid"])["uuid"] - return sorted(str(uuid) for uuid in pc.unique(uuids).to_pylist()) - - -def list_trajectory_ids(dataset_path: Path, prefer_duckdb: bool = True) -> list[str]: - if prefer_duckdb: - try: - return _load_distinct_uuids_with_duckdb(dataset_path) - except DuckDBUnavailableError: - pass - return _load_distinct_uuids_with_pyarrow(dataset_path) - - -def _load_trajectory_with_duckdb(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]: - duckdb = _get_duckdb_module() - connection = duckdb.connect() - try: - table = connection.execute( - "SELECT uuid, step, timestamp, obs FROM read_parquet(?) WHERE uuid = ? ORDER BY step", - [str(dataset_path), trajectory_uuid], - ).to_arrow_table() - finally: - connection.close() - return [ - RecordedSimStep( - step=int(row["step"]), - uuid=str(row["uuid"]), - timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None, - observation=row["obs"], - ) - for row in table.to_pylist() - ] - - -def _load_trajectory_with_pyarrow(dataset_path: Path, trajectory_uuid: str) -> list[RecordedSimStep]: - dataset = ds.dataset(str(dataset_path), format="parquet") - table = dataset.to_table(filter=pc.field("uuid") == trajectory_uuid, columns=["uuid", "step", "timestamp", "obs"]) - rows = table.sort_by([("step", "ascending")]).to_pylist() - return [ - RecordedSimStep( - step=int(row["step"]), - uuid=str(row["uuid"]), - timestamp=float(row["timestamp"]) if row["timestamp"] is not None else None, - observation=row["obs"], - ) - for row in rows - ] - - -def load_trajectory(dataset_path: Path, trajectory_uuid: str, prefer_duckdb: bool = True) -> list[RecordedSimStep]: - if prefer_duckdb: - try: - return _load_trajectory_with_duckdb(dataset_path, trajectory_uuid) - except DuckDBUnavailableError: - pass - return _load_trajectory_with_pyarrow(dataset_path, trajectory_uuid) - - -def resolve_trajectory_uuid(dataset_path: Path, trajectory_uuid: str | None, prefer_duckdb: bool = True) -> str: - if trajectory_uuid is not None: - return trajectory_uuid - available_uuids = list_trajectory_ids(dataset_path, prefer_duckdb=prefer_duckdb) - if len(available_uuids) == 1: - return available_uuids[0] - msg = ( - f"Dataset {dataset_path} contains {len(available_uuids)} trajectories. " - f"Pass --trajectory-uuid and choose one of: {available_uuids}" - ) - raise ValueError(msg) - - -def restore_sim_step( - env: gym.Env, - recorded_step: RecordedSimStep, - sim_state_spec: dict[str, Any] | None = None, -): - resolved_spec = sim_state_spec or recorded_step.sim_state_spec - if resolved_spec is None: - msg = "Recorded sim state is missing its schema." - raise ValueError(msg) - - sim = env.get_wrapper_attr("sim") - sim.set_state(recorded_step.sim_state, spec=resolved_spec) - - -def collect_rgb_frames(env: gym.Env) -> dict[str, np.ndarray]: - try: - camera_set = env.get_wrapper_attr("camera_set") - except AttributeError: - return {} - - frameset = camera_set.get_latest_frames() - if frameset is None: - return {} - - rgb_frames: dict[str, np.ndarray] = {} - for camera_name, frame in frameset.frames.items(): - lower_name = camera_name.lower() - if "digit" in lower_name or "tactile" in lower_name: - continue - rgb_frames[camera_name] = np.asarray(frame.camera.color.data) - return rgb_frames - - -def save_rgb_frames(output_dir: Path, recorded_step: RecordedSimStep, rgb_frames: dict[str, np.ndarray]): - output_dir.mkdir(parents=True, exist_ok=True) - for camera_name, rgb_frame in rgb_frames.items(): - Image.fromarray(rgb_frame).save(output_dir / f"step-{recorded_step.step:06d}-{camera_name}.png") - - -def replay_trajectory( - env: gym.Env, - recorded_steps: list[RecordedSimStep], - *, - sleep_s: float = 0.0, - output_dir: Path | None = None, -): - if not recorded_steps: - msg = "No recorded sim states found in the requested trajectory." - raise ValueError(msg) - - sim_state_spec = next( - (recorded_step.sim_state_spec for recorded_step in recorded_steps if recorded_step.sim_state_spec), None - ) - - env.reset() - for recorded_step in recorded_steps: - restore_sim_step(env, recorded_step, sim_state_spec=sim_state_spec) - if output_dir is not None: - save_rgb_frames(output_dir, recorded_step, collect_rgb_frames(env)) - if sleep_s > 0: - time.sleep(sleep_s) - - -@app.command() -def replay( - dataset: Annotated[Path, DATASET_ARGUMENT], - env_id: Annotated[str, ENV_ID_OPTION], - trajectory_uuid: Annotated[str | None, TRAJECTORY_UUID_OPTION], - camera: Annotated[list[str], CAMERA_OPTION], - resolution: Annotated[tuple[int, int], RESOLUTION_OPTION], - frame_rate: Annotated[int, FRAME_RATE_OPTION], - render_mode: Annotated[str, RENDER_MODE_OPTION], - control_mode: Annotated[str, CONTROL_MODE_OPTION], - sleep_s: Annotated[float, SLEEP_OPTION], - output_dir: Annotated[Path | None, OUTPUT_DIR_OPTION], - prefer_duckdb: Annotated[bool, PREFER_DUCKDB_OPTION], -): - resolved_uuid = resolve_trajectory_uuid(dataset, trajectory_uuid, prefer_duckdb=prefer_duckdb) - env = gym.make( - env_id, - render_mode=render_mode, - control_mode=ControlMode[control_mode], - resolution=resolution, - frame_rate=frame_rate, - cam_list=camera, - ) - try: - recorded_steps = load_trajectory(dataset, resolved_uuid, prefer_duckdb=prefer_duckdb) - replay_trajectory(env, recorded_steps, sleep_s=sleep_s, output_dir=output_dir) - finally: - env.close() - - typer.echo(f"Replayed {len(recorded_steps)} steps from trajectory {resolved_uuid}.") - if output_dir is not None: - typer.echo(f"Saved re-rendered RGB frames to {output_dir}.") - - -if __name__ == "__main__": - app() diff --git a/python/tests/test_replayer.py b/python/tests/test_replayer.py index c819c6b2..f17fc3e5 100644 --- a/python/tests/test_replayer.py +++ b/python/tests/test_replayer.py @@ -1,14 +1,23 @@ +import xml.etree.ElementTree as ET from pathlib import Path from typing import Any import duckdb +import gymnasium as gym +import mujoco as mj import numpy as np from rcs._core.sim import SimConfig -from rcs.envs.base import RelativeTo +from rcs.envs.base import RelativeTo, SimEnv from rcs.envs.configs import EmptyWorldFR3Duo from rcs.envs.storage_wrapper import StorageWrapper from rcs.envs.tasks import PickTaskConfig -from rcs.sim.replayer import load_distinct_uuids, load_trajectory, replay_trajectory +from rcs.sim.replayer import ( + RecordedSimStep, + load_distinct_uuids, + load_trajectory, + replay_trajectory, +) +from rcs.sim.sim import Sim def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") -> StorageWrapper: @@ -106,6 +115,94 @@ def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int) -> None: env.close() +MINIMAL_XML = """ + + + + + + + + + +""" + + +class DummyReplayEnv(gym.Env): + def __init__(self, sim: Sim): + super().__init__() + self.sim = sim + self._replay_state = None + + def get_wrapper_attr(self, name: str): + return getattr(self, name) + + def set_replay_state(self, state: np.ndarray, spec=None): + self._replay_state = (np.asarray(state, dtype=np.float64), spec) + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + super().reset(seed=seed) + mj.mj_resetData(self.sim.model, self.sim.data) + mj.mj_forward(self.sim.model, self.sim.data) + return {}, {} + + def step(self, action: dict[str, np.ndarray]): + if self._replay_state is not None: + state, spec = self._replay_state + self.sim.set_state(state, spec) + self._replay_state = None + self.sim.data.qpos[0] += float(action["delta"][0]) + self.sim.data.qvel[:] = 0.0 + mj.mj_forward(self.sim.model, self.sim.data) + return {}, 0.0, False, False, {} + + +def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path): + tree = ET.parse(src) + root = tree.getroot() + for include in root.findall("include"): + include_file = include.get("file") + if include_file is not None and not Path(include_file).is_absolute(): + include.set("file", str((src.parent / include_file).resolve())) + + worldbody = root.find("worldbody") + assert worldbody is not None + + worldbody.append( + ET.Element( + "camera", + { + "name": "replay_extra_cam", + "pos": "1.4 0.0 0.9", + "xyaxes": "0 1 0 -0.3 0 1", + }, + ) + ) + body = ET.SubElement(worldbody, "body", {"name": "replay_extra_bg", "pos": "3 3 3"}) + ET.SubElement(body, "geom", {"name": "replay_extra_bg_geom", "type": "box", "size": "0.1 0.1 0.1"}) + tree.write(dst) + + +def _recorded_dummy_step(model_path: Path) -> RecordedSimStep: + sim = Sim(model_path) + state = sim.get_state().copy() + state[0] = 0.125 + sim.set_state(state, sim.get_state_spec()) + return RecordedSimStep( + step=0, + uuid="dummy-trajectory", + timestamp=None, + observation={}, + info={ + SimEnv.STATE_KEY: sim.get_state(), + SimEnv.STATE_SPEC_KEY: sim.get_state_spec(), + }, + action={"delta": np.array([0.0], dtype=np.float64)}, + instruction="", + success=False, + ) + + def _assert_nested_close(actual: Any, expected: Any, *, atol: float = 1e-6): if isinstance(expected, dict): assert isinstance(actual, dict) @@ -198,6 +295,24 @@ def test_replayer_reproduces_existing_parquet_prefix_without_cameras(tmp_path: P _assert_nested_close(replay_instruction, source_instruction) +def test_replayer_restores_sim_state_across_fixed_scene_changes(tmp_path: Path): + base_model_path = tmp_path / "base.xml" + base_model_path.write_text(MINIMAL_XML) + modified_model_path = tmp_path / "modified.xml" + _write_scene_with_extra_fixed_body_and_camera(base_model_path, modified_model_path) + + for record_model_path, replay_model_path in ( + (base_model_path, modified_model_path), + (modified_model_path, base_model_path), + ): + recorded_step = _recorded_dummy_step(record_model_path) + replay_env = DummyReplayEnv(Sim(replay_model_path)) + + replay_trajectory(replay_env, [recorded_step], True) + + assert np.allclose(replay_env.sim.get_state(), recorded_step.sim_state, atol=1e-9, rtol=0) + + def test_replayer_adds_cameras_to_existing_episode_without_cameras(tmp_path: Path): source_dir = tmp_path / "source" replay_dir = tmp_path / "replayed_with_cameras" diff --git a/python/tests/test_sim_state_record_replay.py b/python/tests/test_sim_state_record_replay.py deleted file mode 100644 index 1d7b00b7..00000000 --- a/python/tests/test_sim_state_record_replay.py +++ /dev/null @@ -1,309 +0,0 @@ -from __future__ import annotations - -import importlib.util -import sys -import xml.etree.ElementTree as ET -from dataclasses import dataclass -from pathlib import Path - -import gymnasium as gym -import mujoco as mj -import numpy as np -import pyarrow.dataset as ds -from rcs._core.common import RobotPlatform -from rcs.camera.interface import CameraFrame, DataFrame, Frame, FrameSet -from rcs.envs.storage_wrapper import StorageWrapper - -import rcs - -REPO_ROOT = Path(__file__).resolve().parents[2] - - -def _load_local_module(module_name: str, relative_path: str): - module_path = REPO_ROOT / relative_path - spec = importlib.util.spec_from_file_location(module_name, module_path) - if spec is None or spec.loader is None: - msg = f"Could not create an import spec for {module_name} from {module_path}." - raise ImportError(msg) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - parent_name, _, child_name = module_name.rpartition(".") - if parent_name: - parent_module = sys.modules[parent_name] - setattr(parent_module, child_name, module) - spec.loader.exec_module(module) - return module - - -local_sim_module = _load_local_module("rcs.sim.sim", "python/rcs/sim/sim.py") -rcs.sim.__dict__["Sim"] = local_sim_module.Sim -_load_local_module("rcs.envs.sim", "python/rcs/envs/sim.py") -_load_local_module("rcs.sim_state_replay", "python/rcs/sim_state_replay.py") - -from rcs.envs.sim import SimStateObservationWrapper # noqa: E402 -from rcs.sim.sim import Sim # noqa: E402 -from rcs.sim_state_replay import ( # noqa: E402 - load_trajectory, - replay_trajectory, - restore_sim_step, -) - -XML = """ - - - - - - - - - -""" - - -@dataclass -class DummyCameraSet: - sim: Sim - - def get_latest_frames(self) -> FrameSet: - color_value = int(np.clip(round((self.sim.data.qpos[0] + 1.0) * 80.0), 0, 255)) - rgb = np.full((8, 8, 3), color_value, dtype=np.uint8) - return FrameSet( - frames={ - "main": Frame( - camera=CameraFrame( - color=DataFrame(data=rgb), - depth=None, - ), - ) - }, - avg_timestamp=None, - ) - - -class DummySimEnv(gym.Env): - PLATFORM = RobotPlatform.SIMULATION - - def __init__(self, sim: Sim, camera_set: DummyCameraSet | None = None): - super().__init__() - self.sim = sim - self.camera_set = camera_set - self.action_space = gym.spaces.Dict( - { - "delta": gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float64), - } - ) - self.observation_space = gym.spaces.Dict( - { - "qpos": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sim.model.nq,), dtype=np.float64), - "qvel": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.sim.model.nv,), dtype=np.float64), - } - ) - - def _obs(self) -> dict[str, np.ndarray]: - return { - "qpos": self.sim.data.qpos.copy(), - "qvel": self.sim.data.qvel.copy(), - } - - def get_wrapper_attr(self, name: str): - return getattr(self, name) - - def reset(self, *, seed: int | None = None, options: dict | None = None): - super().reset(seed=seed) - mj.mj_resetData(self.sim.model, self.sim.data) - mj.mj_forward(self.sim.model, self.sim.data) - return self._obs(), {"collision": False} - - def step(self, action: dict[str, np.ndarray]): - self.sim.data.qpos[0] += float(action["delta"][0]) - self.sim.data.qvel[:] = 0.0 - mj.mj_forward(self.sim.model, self.sim.data) - return self._obs(), 0.0, False, False, {"collision": False} - - def close(self): - return None - - -def test_record_and_replay_sim_state(tmp_path: Path): - model_path = tmp_path / "dummy.xml" - model_path.write_text(XML) - - dataset_path = tmp_path / "dataset" - record_env: gym.Env = DummySimEnv(Sim(model_path)) - record_env = SimStateObservationWrapper(record_env) - record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) - - obs, _ = record_env.reset() - assert SimStateObservationWrapper.STATE_KEY in obs - - record_env.step({"delta": np.array([0.125], dtype=np.float64)}) - record_env.close() - - table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")]) - rows = table.to_pylist() - assert len(rows) == 1 - - recorded_obs = rows[0]["obs"] - assert SimStateObservationWrapper.STATE_KEY in recorded_obs - assert SimStateObservationWrapper.STATE_SPEC_KEY in recorded_obs - assert SimStateObservationWrapper.STATE_SIZE_KEY in recorded_obs - assert ( - len(recorded_obs[SimStateObservationWrapper.STATE_KEY]) - == recorded_obs[SimStateObservationWrapper.STATE_SIZE_KEY] - ) - - recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) - assert len(recorded_steps) == 1 - assert recorded_steps[0].sim_state_spec is not None - assert np.allclose(recorded_steps[0].sim_state, np.asarray(recorded_obs[SimStateObservationWrapper.STATE_KEY])) - - replay_sim = Sim(model_path) - replay_env: gym.Env = DummySimEnv(replay_sim, camera_set=DummyCameraSet(replay_sim)) - replay_env = SimStateObservationWrapper(replay_env) - render_dir = tmp_path / "rendered" - - replay_env.reset() - restore_sim_step(replay_env, recorded_steps[0], sim_state_spec=recorded_steps[0].sim_state_spec) - assert np.allclose( - replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 - ) - assert np.allclose( - replay_env.get_wrapper_attr("sim").data.qvel, np.asarray(recorded_obs["qvel"]), atol=1e-9, rtol=0 - ) - - replay_trajectory(replay_env, recorded_steps, output_dir=render_dir) - - rendered_files = sorted(path.name for path in render_dir.glob("*.png")) - assert rendered_files == ["step-000000-main.png"] - - -def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path): - tree = ET.parse(src) - root = tree.getroot() - for include in root.findall("include"): - include_file = include.get("file") - if include_file is not None and not Path(include_file).is_absolute(): - include.set("file", str((src.parent / include_file).resolve())) - - worldbody = root.find("worldbody") - assert worldbody is not None - - worldbody.append( - ET.Element( - "camera", - { - "name": "replay_extra_cam", - "pos": "1.4 0.0 0.9", - "xyaxes": "0 1 0 -0.3 0 1", - }, - ) - ) - body = ET.SubElement(worldbody, "body", {"name": "replay_extra_bg", "pos": "3 3 3"}) - ET.SubElement(body, "geom", {"name": "replay_extra_bg_geom", "type": "box", "size": "0.1 0.1 0.1"}) - tree.write(dst) - - -def _record_dummy_trajectory(dataset_path: Path, model_path: Path) -> tuple[list, dict[str, object]]: - record_env: gym.Env = DummySimEnv(Sim(model_path)) - record_env = SimStateObservationWrapper(record_env) - record_env = StorageWrapper(record_env, str(dataset_path), "test sim replay", batch_size=1, always_record=True) - record_env.reset() - record_env.step({"delta": np.array([0.125], dtype=np.float64)}) - record_env.close() - - table = ds.dataset(str(dataset_path), format="parquet").to_table().sort_by([("step", "ascending")]) - rows = table.to_pylist() - recorded_steps = load_trajectory(dataset_path, rows[0]["uuid"], prefer_duckdb=True) - return recorded_steps, rows[0]["obs"] - - -def test_sim_state_replay_tolerates_added_and_removed_fixed_scene_elements(tmp_path: Path): - base_model_path = tmp_path / "base.xml" - base_model_path.write_text(XML) - modified_model_path = tmp_path / "modified.xml" - _write_scene_with_extra_fixed_body_and_camera(base_model_path, modified_model_path) - - for record_model_path, replay_model_path in ( - (base_model_path, modified_model_path), - (modified_model_path, base_model_path), - ): - dataset_path = tmp_path / f"dataset-{record_model_path.stem}-to-{replay_model_path.stem}" - recorded_steps, recorded_obs = _record_dummy_trajectory(dataset_path, record_model_path) - - replay_sim = Sim(replay_model_path) - replay_env: gym.Env = DummySimEnv(replay_sim) - replay_env = SimStateObservationWrapper(replay_env) - replay_env.reset() - sim_state_spec = next(step.sim_state_spec for step in recorded_steps if step.sim_state_spec is not None) - restore_sim_step(replay_env, recorded_steps[0], sim_state_spec=sim_state_spec) - - assert np.allclose( - replay_env.get_wrapper_attr("sim").data.qpos, np.asarray(recorded_obs["qpos"]), atol=1e-9, rtol=0 - ) - assert np.allclose( - replay_env.get_wrapper_attr("sim").data.qvel, np.asarray(recorded_obs["qvel"]), atol=1e-9, rtol=0 - ) - - -def _write_repo_scene_with_dynamic_body(src: Path, dst: Path, *, add_extra_fixed_scene_elements: bool = False): - tree = ET.parse(src) - root = tree.getroot() - worldbody = root.find("worldbody") - assert worldbody is not None - - dynamic_body = ET.SubElement(worldbody, "body", {"name": "replay_dynamic_box", "pos": "0 0 0.1"}) - ET.SubElement(dynamic_body, "freejoint", {"name": "replay_dynamic_box_free"}) - ET.SubElement( - dynamic_body, - "geom", - { - "name": "replay_dynamic_box_geom", - "type": "box", - "size": "0.05 0.05 0.05", - "rgba": "0.2 0.6 0.9 1", - }, - ) - - if add_extra_fixed_scene_elements: - worldbody.append( - ET.Element( - "camera", - { - "name": "replay_extra_cam", - "pos": "1.4 0.0 0.9", - "xyaxes": "0 1 0 -0.3 0 1", - }, - ) - ) - fixed_body = ET.SubElement(worldbody, "body", {"name": "replay_extra_bg", "pos": "3 3 3"}) - ET.SubElement( - fixed_body, - "geom", - {"name": "replay_extra_bg_geom", "type": "box", "size": "0.1 0.1 0.1"}, - ) - - tree.write(dst) - - -def test_sim_state_roundtrip_on_repo_scene_layout(tmp_path: Path): - source_scene_path = REPO_ROOT / "assets/scenes/empty_world/scene.xml" - base_scene_path = tmp_path / "empty_world_dynamic.xml" - modified_scene_path = tmp_path / "empty_world_dynamic_modified.xml" - _write_repo_scene_with_dynamic_body(source_scene_path, base_scene_path) - _write_repo_scene_with_dynamic_body(source_scene_path, modified_scene_path, add_extra_fixed_scene_elements=True) - - base_sim = Sim(base_scene_path) - sim_state_spec = base_sim.get_state_spec() - sim_state = base_sim.get_state().copy() - num_seed_values = min(8, sim_state.shape[0]) - sim_state[:num_seed_values] = np.linspace(0.01, 0.01 * num_seed_values, num_seed_values) - base_sim.set_state(sim_state, sim_state_spec) - seeded_sim_state = base_sim.get_state() - - modified_sim = Sim(modified_scene_path) - modified_sim.set_state(seeded_sim_state, sim_state_spec) - restored_sim_state = modified_sim.get_state() - - assert sim_state_spec == modified_sim.get_state_spec() - assert np.allclose(restored_sim_state, seeded_sim_state, atol=1e-9, rtol=0) From 64a25ff7dec7caa2756f4cd9881db24a3fe13bdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Sun, 3 May 2026 12:36:53 +0200 Subject: [PATCH 05/13] style: stubs with generic classes and fix formatting --- Makefile | 6 +++--- python/rcs/_core/common.pyi | 8 ++++---- python/rcs/_core/sim.pyi | 36 ++++++++++++++++++++---------------- src/sim/sim.cpp | 21 +++++++++++---------- 4 files changed, 38 insertions(+), 33 deletions(-) diff --git a/Makefile b/Makefile index e3830ac8..67d68544 100644 --- a/Makefile +++ b/Makefile @@ -32,9 +32,9 @@ stubgen: find ./python -not -path "./python/rcs/_core/*" -name '*.pyi' -delete find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/tuple\[typing\.Literal\[\([0-9]\+\)\], typing\.Literal\[1\]\]/tuple\[typing\.Literal[\1]\]/g' find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/tuple\[\([M|N]\), typing\.Literal\[1\]\]/tuple\[\1\]/g' - sed -i 's/ q_home: numpy\.ndarray\[tuple\[M\], numpy\.dtype\[numpy\.float64\]\] | None/ q_home: numpy.ndarray | None/' python/rcs/_core/common.pyi - python -c "from pathlib import Path; p=Path('python/rcs/_core/common.pyi'); t=p.read_text(); t=t.replace('numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]]', 'numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]]'); p.write_text(t)" - python -c "from pathlib import Path; p=Path('python/rcs/_core/sim.pyi'); t=p.read_text(); t=t.replace('numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]]', 'numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]]'); t=t.replace(', max_buffer_frames: int = 100', ''); p.write_text(t)" + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class RobotConfig/class RobotConfig(typing.Generic[M, N])/g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class SimRobotConfig(rcs._core.common.RobotConfig)/class SimRobotConfig(rcs._core.common.RobotConfig, typing.Generic[N])/g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class DynamicJointState/class DynamicJointState(typing.Generic[M])/g' python ci_scripts/generate_common_typing.py ruff check --fix python/rcs/_core python/rcs/common_typing.py isort python/rcs/_core python/rcs/common_typing.py diff --git a/python/rcs/_core/common.pyi b/python/rcs/_core/common.pyi index cf9d1d5d..e3c8b2d5 100644 --- a/python/rcs/_core/common.pyi +++ b/python/rcs/_core/common.pyi @@ -239,12 +239,12 @@ class Robot: def to_pose_in_robot_coordinates(self, pose_in_world_coordinates: Pose) -> Pose: ... def to_pose_in_world_coordinates(self, pose_in_robot_coordinates: Pose) -> Pose: ... -class RobotConfig: +class RobotConfig(typing.Generic[M, N]): attachment_site: str dof: int - joint_limits: numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]] + joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] kinematic_model_path: str - q_home: numpy.ndarray | None + q_home: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] | None robot_platform: RobotPlatform robot_type: RobotType tcp_offset: Pose @@ -252,7 +252,7 @@ class RobotConfig: self, robot_type: RobotType = ..., dof: int = 7, - joint_limits: numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]] = ..., + joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] = ..., robot_platform: RobotPlatform = ..., tcp_offset: Pose = ..., attachment_site: str = "attachment_site", diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi index d419aeee..f76c69a3 100644 --- a/python/rcs/_core/sim.pyi +++ b/python/rcs/_core/sim.pyi @@ -11,6 +11,8 @@ import rcs._core.common __all__: list[str] = [ "CameraType", + "DynamicJointSchema", + "DynamicJointState", "FrameSet", "GuiClient", "Sim", @@ -69,6 +71,18 @@ class CameraType: @property def value(self) -> int: ... +class DynamicJointSchema: + joint_names: list[str] + joint_types: list[int] + qpos_sizes: list[int] + qvel_sizes: list[int] + def __init__(self) -> None: ... + +class DynamicJointState(typing.Generic[M]): + qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] + def __init__(self) -> None: ... + class FrameSet: def __init__( self, @@ -89,18 +103,6 @@ class GuiClient: def set_model_and_data(self, arg0: int, arg1: int) -> None: ... def sync(self) -> None: ... -class DynamicJointSchema: - joint_names: list[str] - joint_types: list[int] - qpos_sizes: list[int] - qvel_sizes: list[int] - def __init__(self) -> None: ... - -class DynamicJointState: - qpos: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] - qvel: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] - def __init__(self) -> None: ... - class Sim: def __init__(self, mjmdl: int, mjdata: int) -> None: ... def _start_gui_server(self, id: str) -> None: ... @@ -125,7 +127,9 @@ class SimCameraConfig(rcs._core.common.BaseCameraConfig): ) -> None: ... class SimCameraSet: - def __init__(self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True) -> None: ... + def __init__( + self, sim: Sim, cameras: dict[str, SimCameraConfig], render_on_demand: bool = True, max_buffer_frames: int = 100 + ) -> None: ... def buffer_size(self) -> int: ... def clear_buffer(self) -> None: ... def get_latest_frameset(self) -> FrameSet | None: ... @@ -210,12 +214,12 @@ class SimRobot(rcs._core.common.Robot): def set_config(self, cfg: SimRobotConfig) -> bool: ... def set_joints_hard(self, q: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]]) -> None: ... -class SimRobotConfig(rcs._core.common.RobotConfig): +class SimRobotConfig(rcs._core.common.RobotConfig, typing.Generic[N]): actuators: list[str] arm_collision_geoms: list[str] base: str dof: int - joint_limits: numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]] + joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] joint_rotational_tolerance: float joints: list[str] seconds_between_callbacks: float @@ -262,7 +266,7 @@ class SimRobotConfig(rcs._core.common.RobotConfig): ], base: str = "base", dof: int = 7, - joint_limits: numpy.ndarray[tuple[typing.Literal[2], typing.Any], numpy.dtype[numpy.float64]] = ..., + joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] = ..., ) -> None: ... def add_prefix(self, id: str) -> None: ... diff --git a/src/sim/sim.cpp b/src/sim/sim.cpp index e8d4a697..2cbafbb3 100644 --- a/src/sim/sim.cpp +++ b/src/sim/sim.cpp @@ -233,20 +233,20 @@ void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, "Dynamic joint schema fields must all have the same length."); } - int expected_qpos_size = std::accumulate(schema.qpos_sizes.begin(), - schema.qpos_sizes.end(), 0); - int expected_qvel_size = std::accumulate(schema.qvel_sizes.begin(), - schema.qvel_sizes.end(), 0); + int expected_qpos_size = + std::accumulate(schema.qpos_sizes.begin(), schema.qpos_sizes.end(), 0); + int expected_qvel_size = + std::accumulate(schema.qvel_sizes.begin(), schema.qvel_sizes.end(), 0); if (state.qpos.size() != expected_qpos_size) { std::ostringstream msg; - msg << "Dynamic joint qpos size mismatch. Expected " - << expected_qpos_size << ", got " << state.qpos.size() << "."; + msg << "Dynamic joint qpos size mismatch. Expected " << expected_qpos_size + << ", got " << state.qpos.size() << "."; throw std::invalid_argument(msg.str()); } if (state.qvel.size() != expected_qvel_size) { std::ostringstream msg; - msg << "Dynamic joint qvel size mismatch. Expected " - << expected_qvel_size << ", got " << state.qvel.size() << "."; + msg << "Dynamic joint qvel size mismatch. Expected " << expected_qvel_size + << ", got " << state.qvel.size() << "."; throw std::invalid_argument(msg.str()); } @@ -276,8 +276,9 @@ void Sim::set_dynamic_joint_state(const DynamicJointSchema& schema, msg << "Dynamic joint schema mismatch for joint '" << schema.joint_names[i] << "': expected type=" << target_spec.type << ", qpos_size=" << target_spec.qpos_size - << ", qvel_size=" << target_spec.qvel_size << " but got type=" - << schema.joint_types[i] << ", qpos_size=" << schema.qpos_sizes[i] + << ", qvel_size=" << target_spec.qvel_size + << " but got type=" << schema.joint_types[i] + << ", qpos_size=" << schema.qpos_sizes[i] << ", qvel_size=" << schema.qvel_sizes[i] << "."; throw std::invalid_argument(msg.str()); } From e4e1e685c0d5a60efd35fa4cc93d4c56e362882b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Sun, 3 May 2026 12:58:13 +0200 Subject: [PATCH 06/13] chore(replayer): simplify sim state methods and rename spec to schema --- python/rcs/envs/base.py | 12 +++--- python/rcs/sim/replayer.py | 14 +++---- python/rcs/sim/sim.py | 79 +++++++++++++---------------------- python/tests/test_replayer.py | 6 +-- 4 files changed, 46 insertions(+), 65 deletions(-) diff --git a/python/rcs/envs/base.py b/python/rcs/envs/base.py index a0c2c5f6..f339b23b 100644 --- a/python/rcs/envs/base.py +++ b/python/rcs/envs/base.py @@ -170,7 +170,7 @@ class ArmObsType(TQuatDictType, JointsDictType, TRPYDictType): ... CartOrJointContType: TypeAlias = TQuatDictType | JointsDictType | TRPYDictType LimitedCartOrJointContType: TypeAlias = LimitedTQuatRelDictType | LimitedJointsRelDictType | LimitedTRPYRelDictType -SimStateSpec: TypeAlias = dict[str, list[str] | list[int]] +SimStateSchema: TypeAlias = dict[str, list[str] | list[int]] class ArmWithGripper(TQuatDictType, GripperDictType): ... @@ -205,7 +205,7 @@ class HardwareEnv(BaseEnv): class SimEnv(BaseEnv): PLATFORM = RobotPlatform.SIMULATION STATE_KEY = "sim_state" - STATE_SPEC_KEY = "sim_state_spec" + STATE_SCHEMA_KEY = "sim_state_schema" def __init__(self, sim: simulation.Sim, return_state=True) -> None: self.sim = sim @@ -213,10 +213,10 @@ def __init__(self, sim: simulation.Sim, return_state=True) -> None: self.frame_rate = SimpleFrameRate(cfg.frequency, "MoJoCo Simulation Loop") self.main_greenlet: greenlet | None = None self.return_state = return_state - self._replay_state: tuple[np.ndarray, SimStateSpec | None] | None = None + self._replay_state: tuple[np.ndarray, SimStateSchema | None] | None = None - def set_replay_state(self, state: np.ndarray, spec: SimStateSpec | None = None): - self._replay_state = (state, spec) + def set_replay_state(self, state: np.ndarray, schema: SimStateSchema | None = None): + self._replay_state = (state, schema) def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict]: if self.main_greenlet is not None: @@ -256,7 +256,7 @@ def reset( def observation(self, observation: dict[str, Any], info: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: sim_state = self.sim.get_state() info[self.STATE_KEY] = sim_state - info[self.STATE_SPEC_KEY] = self.sim.get_state_spec() + info[self.STATE_SCHEMA_KEY] = self.sim.get_state_schema() return observation, info diff --git a/python/rcs/sim/replayer.py b/python/rcs/sim/replayer.py index 0ca28371..419f5245 100644 --- a/python/rcs/sim/replayer.py +++ b/python/rcs/sim/replayer.py @@ -9,12 +9,12 @@ import rcs.envs.configs as env_configs import rcs.envs.tasks as env_tasks from rcs._core.sim import SimConfig -from rcs.envs.base import RelativeTo, SimEnv, SimStateSpec +from rcs.envs.base import RelativeTo, SimEnv, SimStateSchema from rcs.envs.scenes import SimEnvCreator from rcs.envs.storage_wrapper import StorageWrapper -def _normalize_sim_state_spec(value: Any) -> SimStateSpec: +def _normalize_sim_state_schema(value: Any) -> SimStateSchema: return { "joint_names": [str(item) for item in value["joint_names"]], "joint_types": [int(item) for item in value["joint_types"]], @@ -47,13 +47,13 @@ def sim_state(self) -> np.ndarray: raise KeyError(msg) @property - def sim_state_spec(self) -> SimStateSpec | None: - if SimEnv.STATE_SPEC_KEY in self.info: - return _normalize_sim_state_spec(self.info[SimEnv.STATE_SPEC_KEY]) + def sim_state_spec(self) -> SimStateSchema | None: + if SimEnv.STATE_SCHEMA_KEY in self.info: + return _normalize_sim_state_schema(self.info[SimEnv.STATE_SCHEMA_KEY]) for value in self.info.values(): - if isinstance(value, dict) and SimEnv.STATE_SPEC_KEY in value: - return _normalize_sim_state_spec(value[SimEnv.STATE_SPEC_KEY]) + if isinstance(value, dict) and SimEnv.STATE_SCHEMA_KEY in value: + return _normalize_sim_state_schema(value[SimEnv.STATE_SCHEMA_KEY]) return None diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index ae24969b..e030f7f1 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -1,6 +1,7 @@ import atexit import contextlib import multiprocessing as mp +import typing import uuid from logging import getLogger from multiprocessing.synchronize import Event as EventClass @@ -12,8 +13,8 @@ import mujoco as mj import mujoco.viewer import numpy as np -from rcs._core.sim import DynamicJointSchema as _DynamicJointSchema -from rcs._core.sim import DynamicJointState as _DynamicJointState +from rcs._core.sim import DynamicJointSchema +from rcs._core.sim import DynamicJointState from rcs._core.sim import GuiClient as _GuiClient from rcs._core.sim import Sim as _Sim from rcs.sim import SimConfig, egl_bootstrap @@ -71,69 +72,49 @@ def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None if cfg is not None: self.set_config(cfg) - def get_state_spec(self) -> dict[str, list[str] | list[int]]: - return self.get_dynamic_joint_schema() + def get_state_schema(self) -> dict[str, list[str] | list[int]]: + schema = super().get_dynamic_joint_schema() + return { + "joint_names": list(schema.joint_names), + "joint_types": list(schema.joint_types), + "qpos_sizes": list(schema.qpos_sizes), + "qvel_sizes": list(schema.qvel_sizes), + } - def get_state_size(self, spec: dict[str, list[str] | list[int]] | None = None) -> int: - state_spec = self.get_state_spec() if spec is None else spec - qpos_size = sum(int(value) for value in state_spec["qpos_sizes"]) - qvel_size = sum(int(value) for value in state_spec["qvel_sizes"]) + def get_state_size(self, schema: dict[str, list[str] | list[int]] | None = None) -> int: + state_schema = self.get_state_schema() if schema is None else schema + qpos_size = sum(int(value) for value in state_schema["qpos_sizes"]) + qvel_size = sum(int(value) for value in state_schema["qvel_sizes"]) return qpos_size + qvel_size - def get_state(self, spec: dict[str, list[str] | list[int]] | None = None) -> np.ndarray: - del spec - dynamic_joint_state = self.get_dynamic_joint_state() - return np.concatenate((dynamic_joint_state["qpos"], dynamic_joint_state["qvel"])) + def get_state(self) -> np.ndarray: + state = super().get_dynamic_joint_state() + return np.concatenate((state.qpos, state.qvel)) def set_state( self, state: np.ndarray, - spec: dict[str, list[str] | list[int]] | None = None, + schema: dict[str, list[str] | list[int]] | None = None, ): - state_spec = self.get_state_spec() if spec is None else spec + state_schema = self.get_state_schema() if schema is None else schema state_array = np.asarray(state, dtype=np.float64) - expected_size = self.get_state_size(state_spec) + expected_size = self.get_state_size(state_schema) if state_array.shape != (expected_size,): msg = f"Expected state with shape ({expected_size},), got {state_array.shape}." raise ValueError(msg) - qpos_size = sum(int(value) for value in state_spec["qpos_sizes"]) - dynamic_joint_state = { - "qpos": state_array[:qpos_size], - "qvel": state_array[qpos_size:], - } - self.set_dynamic_joint_state(state_spec, dynamic_joint_state) + qpos_size = sum(int(value) for value in state_schema["qpos_sizes"]) - def get_dynamic_joint_schema(self) -> dict[str, list[str] | list[int]]: - schema = super().get_dynamic_joint_schema() - return { - "joint_names": list(schema.joint_names), - "joint_types": list(schema.joint_types), - "qpos_sizes": list(schema.qpos_sizes), - "qvel_sizes": list(schema.qvel_sizes), - } - def get_dynamic_joint_state(self) -> dict[str, np.ndarray]: - state = super().get_dynamic_joint_state() - return { - "qpos": np.asarray(state.qpos, dtype=np.float64), - "qvel": np.asarray(state.qvel, dtype=np.float64), - } + dynamic_joint_schema = DynamicJointSchema() + dynamic_joint_schema.joint_names = typing.cast(list[str], list(state_schema["joint_names"])) + dynamic_joint_schema.joint_types = [int(value) for value in state_schema["joint_types"]] + dynamic_joint_schema.qpos_sizes = [int(value) for value in state_schema["qpos_sizes"]] + dynamic_joint_schema.qvel_sizes = [int(value) for value in state_schema["qvel_sizes"]] - def set_dynamic_joint_state( - self, - schema: dict[str, list[str] | list[int]], - state: dict[str, np.ndarray], - ): - dynamic_joint_schema = _DynamicJointSchema() - dynamic_joint_schema.joint_names = list(schema["joint_names"]) - dynamic_joint_schema.joint_types = [int(value) for value in schema["joint_types"]] - dynamic_joint_schema.qpos_sizes = [int(value) for value in schema["qpos_sizes"]] - dynamic_joint_schema.qvel_sizes = [int(value) for value in schema["qvel_sizes"]] - - dynamic_joint_state = _DynamicJointState() - dynamic_joint_state.qpos = np.asarray(state["qpos"], dtype=np.float64) - dynamic_joint_state.qvel = np.asarray(state["qvel"], dtype=np.float64) + dynamic_joint_state = DynamicJointState() # type: ignore + dynamic_joint_state.qpos = state_array[:qpos_size] + dynamic_joint_state.qvel = state_array[qpos_size:] super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) def close_gui(self): diff --git a/python/tests/test_replayer.py b/python/tests/test_replayer.py index f17fc3e5..337bb7a5 100644 --- a/python/tests/test_replayer.py +++ b/python/tests/test_replayer.py @@ -138,7 +138,7 @@ def get_wrapper_attr(self, name: str): return getattr(self, name) def set_replay_state(self, state: np.ndarray, spec=None): - self._replay_state = (np.asarray(state, dtype=np.float64), spec) + self._replay_state = (state, spec) def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): super().reset(seed=seed) @@ -187,7 +187,7 @@ def _recorded_dummy_step(model_path: Path) -> RecordedSimStep: sim = Sim(model_path) state = sim.get_state().copy() state[0] = 0.125 - sim.set_state(state, sim.get_state_spec()) + sim.set_state(state, sim.get_state_schema()) return RecordedSimStep( step=0, uuid="dummy-trajectory", @@ -195,7 +195,7 @@ def _recorded_dummy_step(model_path: Path) -> RecordedSimStep: observation={}, info={ SimEnv.STATE_KEY: sim.get_state(), - SimEnv.STATE_SPEC_KEY: sim.get_state_spec(), + SimEnv.STATE_SCHEMA_KEY: sim.get_state_schema(), }, action={"delta": np.array([0.0], dtype=np.float64)}, instruction="", From 1ff1d308d5fea9b5c6ed0d436a0aa9abfb8d627f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Sun, 3 May 2026 13:21:15 +0200 Subject: [PATCH 07/13] style: fix generic dof typing --- Makefile | 6 ++++-- python/rcs/_core/common.pyi | 7 +++---- python/rcs/_core/sim.pyi | 7 +++---- python/rcs/envs/configs.py | 6 +++--- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index 67d68544..9d5908b6 100644 --- a/Makefile +++ b/Makefile @@ -32,9 +32,11 @@ stubgen: find ./python -not -path "./python/rcs/_core/*" -name '*.pyi' -delete find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/tuple\[typing\.Literal\[\([0-9]\+\)\], typing\.Literal\[1\]\]/tuple\[typing\.Literal[\1]\]/g' find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/tuple\[\([M|N]\), typing\.Literal\[1\]\]/tuple\[\1\]/g' - find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class RobotConfig/class RobotConfig(typing.Generic[M, N])/g' - find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class SimRobotConfig(rcs._core.common.RobotConfig)/class SimRobotConfig(rcs._core.common.RobotConfig, typing.Generic[N])/g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class RobotConfig/class RobotConfig(typing.Generic[M])/g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class SimRobotConfig(rcs._core.common.RobotConfig)/class SimRobotConfig(rcs._core.common.RobotConfig[M])/g' find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/class DynamicJointState/class DynamicJointState(typing.Generic[M])/g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/N = typing.TypeVar("N", bound=int)//g' + find ./python/rcs/_core -name '*.pyi' -print | xargs sed -i 's/, N/, M/g' python ci_scripts/generate_common_typing.py ruff check --fix python/rcs/_core python/rcs/common_typing.py isort python/rcs/_core python/rcs/common_typing.py diff --git a/python/rcs/_core/common.pyi b/python/rcs/_core/common.pyi index e3c8b2d5..b2749517 100644 --- a/python/rcs/_core/common.pyi +++ b/python/rcs/_core/common.pyi @@ -41,7 +41,6 @@ __all__: list[str] = [ "TRIPOD_GRASP", ] M = typing.TypeVar("M", bound=int) -N = typing.TypeVar("N", bound=int) class BaseCameraConfig: frame_rate: int @@ -239,10 +238,10 @@ class Robot: def to_pose_in_robot_coordinates(self, pose_in_world_coordinates: Pose) -> Pose: ... def to_pose_in_world_coordinates(self, pose_in_robot_coordinates: Pose) -> Pose: ... -class RobotConfig(typing.Generic[M, N]): +class RobotConfig(typing.Generic[M]): attachment_site: str dof: int - joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] + joint_limits: numpy.ndarray[tuple[typing.Literal[2], M], numpy.dtype[numpy.float64]] kinematic_model_path: str q_home: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]] | None robot_platform: RobotPlatform @@ -252,7 +251,7 @@ class RobotConfig(typing.Generic[M, N]): self, robot_type: RobotType = ..., dof: int = 7, - joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] = ..., + joint_limits: numpy.ndarray[tuple[typing.Literal[2], M], numpy.dtype[numpy.float64]] = ..., robot_platform: RobotPlatform = ..., tcp_offset: Pose = ..., attachment_site: str = "attachment_site", diff --git a/python/rcs/_core/sim.pyi b/python/rcs/_core/sim.pyi index f76c69a3..8ea8ae40 100644 --- a/python/rcs/_core/sim.pyi +++ b/python/rcs/_core/sim.pyi @@ -34,7 +34,6 @@ __all__: list[str] = [ "tracking", ] M = typing.TypeVar("M", bound=int) -N = typing.TypeVar("N", bound=int) class CameraType: """ @@ -214,12 +213,12 @@ class SimRobot(rcs._core.common.Robot): def set_config(self, cfg: SimRobotConfig) -> bool: ... def set_joints_hard(self, q: numpy.ndarray[tuple[M], numpy.dtype[numpy.float64]]) -> None: ... -class SimRobotConfig(rcs._core.common.RobotConfig, typing.Generic[N]): +class SimRobotConfig(rcs._core.common.RobotConfig[M]): actuators: list[str] arm_collision_geoms: list[str] base: str dof: int - joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] + joint_limits: numpy.ndarray[tuple[typing.Literal[2], M], numpy.dtype[numpy.float64]] joint_rotational_tolerance: float joints: list[str] seconds_between_callbacks: float @@ -266,7 +265,7 @@ class SimRobotConfig(rcs._core.common.RobotConfig, typing.Generic[N]): ], base: str = "base", dof: int = 7, - joint_limits: numpy.ndarray[tuple[typing.Literal[2], N], numpy.dtype[numpy.float64]] = ..., + joint_limits: numpy.ndarray[tuple[typing.Literal[2], M], numpy.dtype[numpy.float64]] = ..., ) -> None: ... def add_prefix(self, id: str) -> None: ... diff --git a/python/rcs/envs/configs.py b/python/rcs/envs/configs.py index afbd9f95..05e63cde 100644 --- a/python/rcs/envs/configs.py +++ b/python/rcs/envs/configs.py @@ -1,6 +1,6 @@ import copy import time -from typing import ClassVar +from typing import ClassVar, Literal import numpy as np from rcs._core.common import FrankaHandTCPOffset, GripperType, RobotType @@ -37,7 +37,7 @@ class EmptyWorldFR3(SimEnvCreator): def config(self) -> SimEnvCreatorConfig: q_home = rcs.ROBOTS[RobotType.FR3].q_home q_home[-1] = np.pi / 4 - robot_cfg = SimRobotConfig( + robot_cfg = SimRobotConfig[Literal[7]]( robot_type=RobotType.FR3, tcp_offset=GRIPPER_OFFSETS[rcs.common.GripperType.FrankaHand], attachment_site=rcs.ROBOTS[RobotType.FR3].attachment_site, @@ -183,7 +183,7 @@ class EmptyWorldFR3Duo(SimEnvCreator): gripper_mesh_quaternion_offset: ClassVar[list[float]] = [0, 0, 0.7071068, 0.7071068] def config(self) -> SimEnvCreatorConfig: - robot_cfg = SimRobotConfig( + robot_cfg = SimRobotConfig[Literal[7]]( tcp_offset=GRIPPER_OFFSETS[rcs.common.GripperType("Robotiq2F85")], robot_type=RobotType.FR3, attachment_site=rcs.ROBOTS[RobotType.FR3].attachment_site, From 28262edd4ff5e3da49922588bc9efcf95a80dc93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Sun, 3 May 2026 13:43:49 +0200 Subject: [PATCH 08/13] test: refactor replayer tests and fix type annotation --- python/rcs/envs/configs.py | 6 +- python/rcs/sim/replayer.py | 6 +- python/rcs/sim/sim.py | 6 +- python/tests/test_replayer.py | 135 +++++++++++----------------------- 4 files changed, 52 insertions(+), 101 deletions(-) diff --git a/python/rcs/envs/configs.py b/python/rcs/envs/configs.py index 05e63cde..f29165b8 100644 --- a/python/rcs/envs/configs.py +++ b/python/rcs/envs/configs.py @@ -37,7 +37,7 @@ class EmptyWorldFR3(SimEnvCreator): def config(self) -> SimEnvCreatorConfig: q_home = rcs.ROBOTS[RobotType.FR3].q_home q_home[-1] = np.pi / 4 - robot_cfg = SimRobotConfig[Literal[7]]( + robot_cfg: SimRobotConfig[Literal[7]] = SimRobotConfig( robot_type=RobotType.FR3, tcp_offset=GRIPPER_OFFSETS[rcs.common.GripperType.FrankaHand], attachment_site=rcs.ROBOTS[RobotType.FR3].attachment_site, @@ -183,7 +183,7 @@ class EmptyWorldFR3Duo(SimEnvCreator): gripper_mesh_quaternion_offset: ClassVar[list[float]] = [0, 0, 0.7071068, 0.7071068] def config(self) -> SimEnvCreatorConfig: - robot_cfg = SimRobotConfig[Literal[7]]( + robot_cfg: SimRobotConfig[Literal[7]] = SimRobotConfig( tcp_offset=GRIPPER_OFFSETS[rcs.common.GripperType("Robotiq2F85")], robot_type=RobotType.FR3, attachment_site=rcs.ROBOTS[RobotType.FR3].attachment_site, @@ -224,7 +224,7 @@ def config(self) -> SimEnvCreatorConfig: joint_limits=rcs.ROBOTS[RobotType.FR3].joint_limits, q_home=rcs.HOME_POSITIONS["FR3_DUO_LEFT"], ) - robot_cfg_right = copy.deepcopy(robot_cfg) + robot_cfg_right: SimRobotConfig[Literal[7]] = copy.deepcopy(robot_cfg) robot_cfg_right.q_home = rcs.HOME_POSITIONS["FR3_DUO_RIGHT"] robot_cfgs: dict[str, SimRobotConfig] = {"left": robot_cfg, "right": robot_cfg_right} diff --git a/python/rcs/sim/replayer.py b/python/rcs/sim/replayer.py index 419f5245..cd7a23ca 100644 --- a/python/rcs/sim/replayer.py +++ b/python/rcs/sim/replayer.py @@ -47,7 +47,7 @@ def sim_state(self) -> np.ndarray: raise KeyError(msg) @property - def sim_state_spec(self) -> SimStateSchema | None: + def sim_state_schema(self) -> SimStateSchema | None: if SimEnv.STATE_SCHEMA_KEY in self.info: return _normalize_sim_state_schema(self.info[SimEnv.STATE_SCHEMA_KEY]) @@ -103,9 +103,9 @@ def restore_sim_step(env: gym.Env, recorded_step: RecordedSimStep): lead_env = None if lead_env is not None: - lead_env.set_replay_state(recorded_step.sim_state, spec=recorded_step.sim_state_spec) + lead_env.set_replay_state(recorded_step.sim_state, schema=recorded_step.sim_state_schema) else: - env.get_wrapper_attr("set_replay_state")(recorded_step.sim_state, spec=recorded_step.sim_state_spec) + env.get_wrapper_attr("set_replay_state")(recorded_step.sim_state, schema=recorded_step.sim_state_schema) def replay_trajectory(env: gym.Env, recorded_steps: list[RecordedSimStep], headless: bool): diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index e030f7f1..38ce779d 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -13,8 +13,7 @@ import mujoco as mj import mujoco.viewer import numpy as np -from rcs._core.sim import DynamicJointSchema -from rcs._core.sim import DynamicJointState +from rcs._core.sim import DynamicJointSchema, DynamicJointState from rcs._core.sim import GuiClient as _GuiClient from rcs._core.sim import Sim as _Sim from rcs.sim import SimConfig, egl_bootstrap @@ -105,14 +104,13 @@ def set_state( qpos_size = sum(int(value) for value in state_schema["qpos_sizes"]) - dynamic_joint_schema = DynamicJointSchema() dynamic_joint_schema.joint_names = typing.cast(list[str], list(state_schema["joint_names"])) dynamic_joint_schema.joint_types = [int(value) for value in state_schema["joint_types"]] dynamic_joint_schema.qpos_sizes = [int(value) for value in state_schema["qpos_sizes"]] dynamic_joint_schema.qvel_sizes = [int(value) for value in state_schema["qvel_sizes"]] - dynamic_joint_state = DynamicJointState() # type: ignore + dynamic_joint_state = DynamicJointState() # type: ignore dynamic_joint_state.qpos = state_array[:qpos_size] dynamic_joint_state.qvel = state_array[qpos_size:] super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) diff --git a/python/tests/test_replayer.py b/python/tests/test_replayer.py index 337bb7a5..32ba3b9b 100644 --- a/python/tests/test_replayer.py +++ b/python/tests/test_replayer.py @@ -3,24 +3,22 @@ from typing import Any import duckdb -import gymnasium as gym -import mujoco as mj import numpy as np from rcs._core.sim import SimConfig -from rcs.envs.base import RelativeTo, SimEnv +from rcs.envs.base import RelativeTo from rcs.envs.configs import EmptyWorldFR3Duo from rcs.envs.storage_wrapper import StorageWrapper from rcs.envs.tasks import PickTaskConfig -from rcs.sim.replayer import ( - RecordedSimStep, - load_distinct_uuids, - load_trajectory, - replay_trajectory, -) -from rcs.sim.sim import Sim +from rcs.sim.replayer import load_distinct_uuids, load_trajectory, replay_trajectory -def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") -> StorageWrapper: +def _build_env( + output_dir: Path, + *, + with_cameras: bool, + instruction: str = "", + scene_path: Path | None = None, +) -> StorageWrapper: scene = EmptyWorldFR3Duo() cfg = scene.config() cfg.sim_cfg = SimConfig(async_control=True, realtime=False, frequency=30, max_convergence_steps=500) @@ -29,6 +27,8 @@ def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") - if cfg.root_frame_objects is None: cfg.root_frame_objects = {} cfg.task_cfg = PickTaskConfig(robot_name="right") + if scene_path is not None: + cfg.scene = str(scene_path) if not with_cameras: cfg.camera_cfgs = {} else: @@ -50,8 +50,14 @@ def _build_env(output_dir: Path, *, with_cameras: bool, instruction: str = "") - ) -def _record_source_dataset(dataset_dir: Path, *, limit: int, instruction: str) -> None: - env = _build_env(dataset_dir, with_cameras=False, instruction=instruction) +def _record_source_dataset( + dataset_dir: Path, + *, + limit: int, + instruction: str, + scene_path: Path | None = None, +) -> None: + env = _build_env(dataset_dir, with_cameras=False, instruction=instruction, scene_path=scene_path) try: env.reset() action = { @@ -103,9 +109,9 @@ def _replay_rows(dataset_dir: Path): connection.close() -def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int) -> None: +def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int, scene_path: Path | None = None) -> None: source_dir = output_dir.parent / "source" - env = _build_env(output_dir, with_cameras=with_cameras) + env = _build_env(output_dir, with_cameras=with_cameras, scene_path=scene_path) try: uuid = load_distinct_uuids(source_dir)[0] recorded_steps = load_trajectory(source_dir, uuid)[:limit] @@ -115,48 +121,6 @@ def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int) -> None: env.close() -MINIMAL_XML = """ - - - - - - - - - -""" - - -class DummyReplayEnv(gym.Env): - def __init__(self, sim: Sim): - super().__init__() - self.sim = sim - self._replay_state = None - - def get_wrapper_attr(self, name: str): - return getattr(self, name) - - def set_replay_state(self, state: np.ndarray, spec=None): - self._replay_state = (state, spec) - - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): - super().reset(seed=seed) - mj.mj_resetData(self.sim.model, self.sim.data) - mj.mj_forward(self.sim.model, self.sim.data) - return {}, {} - - def step(self, action: dict[str, np.ndarray]): - if self._replay_state is not None: - state, spec = self._replay_state - self.sim.set_state(state, spec) - self._replay_state = None - self.sim.data.qpos[0] += float(action["delta"][0]) - self.sim.data.qvel[:] = 0.0 - mj.mj_forward(self.sim.model, self.sim.data) - return {}, 0.0, False, False, {} - - def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path): tree = ET.parse(src) root = tree.getroot() @@ -183,26 +147,6 @@ def _write_scene_with_extra_fixed_body_and_camera(src: Path, dst: Path): tree.write(dst) -def _recorded_dummy_step(model_path: Path) -> RecordedSimStep: - sim = Sim(model_path) - state = sim.get_state().copy() - state[0] = 0.125 - sim.set_state(state, sim.get_state_schema()) - return RecordedSimStep( - step=0, - uuid="dummy-trajectory", - timestamp=None, - observation={}, - info={ - SimEnv.STATE_KEY: sim.get_state(), - SimEnv.STATE_SCHEMA_KEY: sim.get_state_schema(), - }, - action={"delta": np.array([0.0], dtype=np.float64)}, - instruction="", - success=False, - ) - - def _assert_nested_close(actual: Any, expected: Any, *, atol: float = 1e-6): if isinstance(expected, dict): assert isinstance(actual, dict) @@ -296,21 +240,30 @@ def test_replayer_reproduces_existing_parquet_prefix_without_cameras(tmp_path: P def test_replayer_restores_sim_state_across_fixed_scene_changes(tmp_path: Path): - base_model_path = tmp_path / "base.xml" - base_model_path.write_text(MINIMAL_XML) - modified_model_path = tmp_path / "modified.xml" - _write_scene_with_extra_fixed_body_and_camera(base_model_path, modified_model_path) - - for record_model_path, replay_model_path in ( - (base_model_path, modified_model_path), - (modified_model_path, base_model_path), - ): - recorded_step = _recorded_dummy_step(record_model_path) - replay_env = DummyReplayEnv(Sim(replay_model_path)) + source_scene_path = Path(EmptyWorldFR3Duo().config().scene) + modified_scene_path = tmp_path / "modified_scene.xml" + _write_scene_with_extra_fixed_body_and_camera(source_scene_path, modified_scene_path) - replay_trajectory(replay_env, [recorded_step], True) - - assert np.allclose(replay_env.sim.get_state(), recorded_step.sim_state, atol=1e-9, rtol=0) + for record_scene_path, replay_scene_path in ( + (source_scene_path, modified_scene_path), + (modified_scene_path, source_scene_path), + ): + case_dir = tmp_path / f"{record_scene_path.stem}-to-{replay_scene_path.stem}" + source_dir = case_dir / "source" + replay_dir = case_dir / "replayed" + + _record_source_dataset(source_dir, limit=3, instruction="pick up cube", scene_path=record_scene_path) + _replay_prefix(replay_dir, with_cameras=False, limit=3, scene_path=replay_scene_path) + + source_uuid = load_distinct_uuids(source_dir)[0] + replay_uuid = load_distinct_uuids(replay_dir)[0] + source_steps = load_trajectory(source_dir, source_uuid) + replay_steps = load_trajectory(replay_dir, replay_uuid) + + assert len(source_steps) == len(replay_steps) == 3 + for replay_step, source_step in zip(replay_steps, source_steps, strict=True): + assert replay_step.sim_state_schema == source_step.sim_state_schema + assert np.allclose(replay_step.sim_state, source_step.sim_state, atol=1e-5, rtol=0) def test_replayer_adds_cameras_to_existing_episode_without_cameras(tmp_path: Path): From 89d9e8acd1e5d4025e57ba902e80eb3f662ae44a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Sun, 3 May 2026 13:44:14 +0200 Subject: [PATCH 09/13] bump(python): min python version 3.11 --- extensions/rcs_fr3/pyproject.toml | 2 +- extensions/rcs_panda/pyproject.toml | 2 +- extensions/rcs_robotics_library/pyproject.toml | 2 +- extensions/rcs_robotiq2f85/pyproject.toml | 2 +- extensions/rcs_so101/pyproject.toml | 2 +- extensions/rcs_tacto/pyproject.toml | 2 +- extensions/rcs_ur5e/pyproject.toml | 2 +- extensions/rcs_usb_cam/pyproject.toml | 2 +- extensions/rcs_xarm7/pyproject.toml | 2 +- pyproject.toml | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/extensions/rcs_fr3/pyproject.toml b/extensions/rcs_fr3/pyproject.toml index 249ca89a..1127acf9 100644 --- a/extensions/rcs_fr3/pyproject.toml +++ b/extensions/rcs_fr3/pyproject.toml @@ -19,7 +19,7 @@ dependencies = ["rcs>=0.6.3", "frankik"] readme = "README.md" maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] authors = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.scikit-build] diff --git a/extensions/rcs_panda/pyproject.toml b/extensions/rcs_panda/pyproject.toml index 99319869..ed122a0c 100644 --- a/extensions/rcs_panda/pyproject.toml +++ b/extensions/rcs_panda/pyproject.toml @@ -18,7 +18,7 @@ dependencies = ["rcs>=0.6.3"] readme = "README.md" maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] authors = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.scikit-build] diff --git a/extensions/rcs_robotics_library/pyproject.toml b/extensions/rcs_robotics_library/pyproject.toml index d2fe0da7..17a2c09b 100644 --- a/extensions/rcs_robotics_library/pyproject.toml +++ b/extensions/rcs_robotics_library/pyproject.toml @@ -21,7 +21,7 @@ authors = [ { name = "Tobias Jülg", email = "tobias.juelg@utn.de" }, { name = "Pierre Krack", email = "pierre.krack@utn.de" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.scikit-build] diff --git a/extensions/rcs_robotiq2f85/pyproject.toml b/extensions/rcs_robotiq2f85/pyproject.toml index 835dfe02..ddb09018 100644 --- a/extensions/rcs_robotiq2f85/pyproject.toml +++ b/extensions/rcs_robotiq2f85/pyproject.toml @@ -17,5 +17,5 @@ maintainers = [ authors = [ { name = "Tobias Jülg", email = "tobias.juelg@utn.de" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" license = { text = "AGPL-3.0-or-later" } diff --git a/extensions/rcs_so101/pyproject.toml b/extensions/rcs_so101/pyproject.toml index 06d22ba3..0dbf0130 100644 --- a/extensions/rcs_so101/pyproject.toml +++ b/extensions/rcs_so101/pyproject.toml @@ -18,7 +18,7 @@ dependencies = ["rcs>=0.6.3", "lerobot==0.3.3"] readme = "README.md" maintainers = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] authors = [{ name = "Tobias Jülg", email = "tobias.juelg@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.scikit-build] build.verbose = true diff --git a/extensions/rcs_tacto/pyproject.toml b/extensions/rcs_tacto/pyproject.toml index f69dd2cb..7c94e292 100644 --- a/extensions/rcs_tacto/pyproject.toml +++ b/extensions/rcs_tacto/pyproject.toml @@ -17,7 +17,7 @@ maintainers = [ { name = "Seongjin Bien", email = "seongjin.bien@utn.de" }, ] authors = [{ name = "Seongjin Bien", email = "seongjin.bien@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.black] line-length = 120 diff --git a/extensions/rcs_ur5e/pyproject.toml b/extensions/rcs_ur5e/pyproject.toml index 1cb50fde..84ddac07 100644 --- a/extensions/rcs_ur5e/pyproject.toml +++ b/extensions/rcs_ur5e/pyproject.toml @@ -15,7 +15,7 @@ authors = [ { name = "Tobias Jülg", email = "tobias.juelg@utn.de" }, { name = "Johannes Hechtl", email = "johannes.hechtl@siemens.com" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.black] line-length = 120 diff --git a/extensions/rcs_usb_cam/pyproject.toml b/extensions/rcs_usb_cam/pyproject.toml index 8d42c626..4a39396b 100644 --- a/extensions/rcs_usb_cam/pyproject.toml +++ b/extensions/rcs_usb_cam/pyproject.toml @@ -12,7 +12,7 @@ maintainers = [ { name = "Seongjin Bien", email = "seongjin.bien@utn.de" }, ] authors = [{ name = "Seongjin Bien", email = "seongjin.bien@utn.de" }] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.black] line-length = 120 diff --git a/extensions/rcs_xarm7/pyproject.toml b/extensions/rcs_xarm7/pyproject.toml index a024385b..3432f3e3 100644 --- a/extensions/rcs_xarm7/pyproject.toml +++ b/extensions/rcs_xarm7/pyproject.toml @@ -16,7 +16,7 @@ authors = [ { name = "Tobias Jülg", email = "tobias.juelg@utn.de" }, { name = "Ken Nakahara", email = "knakahara@lasr.org" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" [tool.black] line-length = 120 diff --git a/pyproject.toml b/pyproject.toml index 6a9213d3..ed210c40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ authors = [ { name = "Pierre Krack", email = "pierre.krack@utn.de" }, { name = "Seongjin Bien", email = "seongjin.bien@utn.de" }, ] -requires-python = ">=3.10" +requires-python = ">=3.11" license = { file = "LICENSE" } [dependency-groups] From 3cb7b5202ec15dc425543428e4f1c9524442f952 Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Sun, 3 May 2026 22:29:34 +0200 Subject: [PATCH 10/13] fix(sim): encode replay free joints relative to root frame --- python/rcs/envs/scenes.py | 25 ++++++- python/rcs/envs/tasks.py | 3 +- python/rcs/sim/composer.py | 31 +++++++-- python/rcs/sim/replayer.py | 5 +- python/rcs/sim/sim.py | 115 ++++++++++++++++++++++++++++++++- python/tests/test_replayer.py | 118 ++++++++++++++++++++++++++++++++-- 6 files changed, 278 insertions(+), 19 deletions(-) diff --git a/python/rcs/envs/scenes.py b/python/rcs/envs/scenes.py index b661204d..ea11f3b3 100644 --- a/python/rcs/envs/scenes.py +++ b/python/rcs/envs/scenes.py @@ -239,7 +239,13 @@ def create_model(self, cfg: SimEnvCreatorConfig) -> MjModel: if cfg.root_frame_objects is not None: for object_id, (object_xml, object2root_frame) in cfg.root_frame_objects.items(): object2world = cfg.root_frame_to_world * object2root_frame - self.add_object_mujoco(composer, object_id, object_xml, object2world) + self.add_object_mujoco( + composer, + object_id, + object_xml, + object2world, + register_root_relative_replay_free_joints=True, + ) # add external objects if cfg.world_frame_objects is not None: for object_id, (object_xml, object2world) in cfg.world_frame_objects.items(): @@ -324,6 +330,11 @@ def create_env_from_model(self, cfg: SimEnvCreatorConfig, mjmodel: MjModel) -> g prefixed_cfg = self.prefixed_cfg(cfg) simulation = Sim(mjmodel, prefixed_cfg.sim_cfg) + if isinstance(mjmodel, ModelComposer): + simulation.configure_state_encodings( + root_frame_to_world=cfg.root_frame_to_world, + root_relative_free_joints=mjmodel.root_relative_replay_free_joints, + ) envs: dict[str, gym.Env] = {} env: gym.Env @@ -373,14 +384,22 @@ def add_task_env( return env def add_object_mujoco( - self, composer: ModelComposer, object_id: str, object_xml: str, object2world: rcs.common.Pose + self, + composer: ModelComposer, + object_id: str, + object_xml: str, + object2world: rcs.common.Pose, + *, + register_root_relative_replay_free_joints: bool = False, ): """Add an object to the Mujoco scene.""" - composer.add_object_world_frame( + added_object = composer.add_object_world_frame( object_xml, object_prefix=object_id + "_", pose=object2world, ) + if register_root_relative_replay_free_joints: + composer.register_root_relative_replay_free_joints(added_object.prefixed_free_joint_names) def add_object_robot_frame_mujoco( self, diff --git a/python/rcs/envs/tasks.py b/python/rcs/envs/tasks.py index fde0a7d7..37773dc0 100644 --- a/python/rcs/envs/tasks.py +++ b/python/rcs/envs/tasks.py @@ -163,11 +163,12 @@ def add_task_mujoco(cfg: PickTaskConfig, composer: ModelComposer, env_cfg: SimEn """Add task-specific elements to the Mujoco scene.""" object2world = cfg.object_center_to_root_frame * env_cfg.root_frame_to_world - composer.add_object_world_frame( + added_object = composer.add_object_world_frame( cfg.object_xml, object_prefix=cfg.prefix, pose=object2world, ) + composer.register_root_relative_replay_free_joints(added_object.prefixed_free_joint_names) @staticmethod def add_task_env(cfg: PickTaskConfig, env: gym.Env, _simulation: Sim, env_cfg: SimEnvCreatorConfig) -> gym.Env: diff --git a/python/rcs/sim/composer.py b/python/rcs/sim/composer.py index 27f73d48..504894c1 100644 --- a/python/rcs/sim/composer.py +++ b/python/rcs/sim/composer.py @@ -1,10 +1,17 @@ import os +from dataclasses import dataclass from typing import Optional import mujoco from rcs._core.common import Pose +@dataclass(frozen=True) +class AddedObject: + root_body: mujoco._specs.MjsBody + prefixed_free_joint_names: list[str] + + class ModelComposer: """ Composes MuJoCo scenes using mjSpec with flexible positioning and prefixing. @@ -20,6 +27,7 @@ def __init__( self.spec.compiler.autolimits = True self.add_gravcomp = add_gravcomp self._gravcomp_prefixes: set[str] = set() + self._root_relative_replay_free_joints: set[str] = set() def _resolve_asset_paths(self, spec: mujoco.MjSpec, xml_path: str): """Resolves relative paths to absolute ones.""" @@ -67,6 +75,17 @@ def _apply_pose(self, body: mujoco._specs.MjsBody, pose: Pose): body.pos = list(pose.translation()) body.quat = list(pose.rotation_q_wxyz()) + def _prefixed_free_joint_names(self, spec: mujoco.MjSpec, prefix: str) -> list[str]: + free_joint_type = int(mujoco.mjtJoint.mjJNT_FREE) + return [f"{prefix}{joint.name}" for joint in spec.joints if joint.name and int(joint.type) == free_joint_type] + + def register_root_relative_replay_free_joints(self, joint_names: list[str]): + self._root_relative_replay_free_joints.update(joint_names) + + @property + def root_relative_replay_free_joints(self) -> set[str]: + return set(self._root_relative_replay_free_joints) + def add_camera( self, resolution: tuple[int, int], @@ -230,7 +249,7 @@ def add_object_robot_frame( object_prefix: str, attachment_site_name: str, pose: Pose | None = None, - ) -> mujoco._specs.MjsBody: + ) -> AddedObject: """Attaches an object to a robot attachment site with an optional local pose offset.""" if pose is None: pose = Pose() @@ -243,15 +262,14 @@ def add_object_robot_frame( object_spec = mujoco.MjSpec.from_file(xml_path) self._resolve_asset_paths(object_spec, xml_path) + prefixed_free_joint_names = self._prefixed_free_joint_names(object_spec, object_prefix) object_root = object_spec.worldbody.first_body() object_root = attachment_site.attach(object_root, object_prefix, "") self._apply_pose(object_root, pose) - return object_root + return AddedObject(root_body=object_root, prefixed_free_joint_names=prefixed_free_joint_names) - def add_object_world_frame( - self, xml_path: str, object_prefix: str, pose: Pose | None = None - ) -> mujoco._specs.MjsBody: + def add_object_world_frame(self, xml_path: str, object_prefix: str, pose: Pose | None = None) -> AddedObject: """ Attaches a single object MJCF at a specific pose. Assumes the XML contains only one root body in the worldbody. @@ -265,6 +283,7 @@ def add_object_world_frame( # Load the child spec child_spec = mujoco.MjSpec.from_file(xml_path) self._resolve_asset_paths(child_spec, xml_path) + prefixed_free_joint_names = self._prefixed_free_joint_names(child_spec, object_prefix) # Attach using a frame frame = self.spec.worldbody.add_frame() @@ -281,7 +300,7 @@ def add_object_world_frame( # Apply the pose self._apply_pose(obj_root, pose) - return obj_root + return AddedObject(root_body=obj_root, prefixed_free_joint_names=prefixed_free_joint_names) def save_mjcf(self, output_path: str): """Compiles and saves the MJCF.""" diff --git a/python/rcs/sim/replayer.py b/python/rcs/sim/replayer.py index cd7a23ca..05a091a1 100644 --- a/python/rcs/sim/replayer.py +++ b/python/rcs/sim/replayer.py @@ -12,14 +12,17 @@ from rcs.envs.base import RelativeTo, SimEnv, SimStateSchema from rcs.envs.scenes import SimEnvCreator from rcs.envs.storage_wrapper import StorageWrapper +from rcs.sim.sim import RAW_STATE_ENCODING def _normalize_sim_state_schema(value: Any) -> SimStateSchema: + joint_names = [str(item) for item in value["joint_names"]] return { - "joint_names": [str(item) for item in value["joint_names"]], + "joint_names": joint_names, "joint_types": [int(item) for item in value["joint_types"]], "qpos_sizes": [int(item) for item in value["qpos_sizes"]], "qvel_sizes": [int(item) for item in value["qvel_sizes"]], + "encodings": [str(item) for item in value.get("encodings", [RAW_STATE_ENCODING] * len(joint_names))], } diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index 38ce779d..af5dd72e 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -13,6 +13,7 @@ import mujoco as mj import mujoco.viewer import numpy as np +from rcs._core import common from rcs._core.sim import DynamicJointSchema, DynamicJointState from rcs._core.sim import GuiClient as _GuiClient from rcs._core.sim import Sim as _Sim @@ -26,6 +27,8 @@ # Target frames per second FPS = 60 +RAW_STATE_ENCODING = "raw" +ROOT_RELATIVE_FREE_STATE_ENCODING = "root_relative_free" def gui_loop(gui_uuid: str, close_event): @@ -68,9 +71,20 @@ def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None self._gui_process: Optional[mp.context.SpawnProcess] = None self._stop_event: Optional[EventClass] = None self._gui_atexit_registered = False + self._root_frame_to_world = common.Pose() + self._root_relative_replay_free_joints: set[str] = set() if cfg is not None: self.set_config(cfg) + def configure_state_encodings( + self, + *, + root_frame_to_world: common.Pose, + root_relative_free_joints: typing.Iterable[str] = (), + ): + self._root_frame_to_world = common.Pose(root_frame_to_world) + self._root_relative_replay_free_joints = set(root_relative_free_joints) + def get_state_schema(self) -> dict[str, list[str] | list[int]]: schema = super().get_dynamic_joint_schema() return { @@ -78,6 +92,14 @@ def get_state_schema(self) -> dict[str, list[str] | list[int]]: "joint_types": list(schema.joint_types), "qpos_sizes": list(schema.qpos_sizes), "qvel_sizes": list(schema.qvel_sizes), + "encodings": [ + ( + ROOT_RELATIVE_FREE_STATE_ENCODING + if joint_name in self._root_relative_replay_free_joints + else RAW_STATE_ENCODING + ) + for joint_name in schema.joint_names + ], } def get_state_size(self, schema: dict[str, list[str] | list[int]] | None = None) -> int: @@ -88,7 +110,10 @@ def get_state_size(self, schema: dict[str, list[str] | list[int]] | None = None) def get_state(self) -> np.ndarray: state = super().get_dynamic_joint_state() - return np.concatenate((state.qpos, state.qvel)) + qpos = np.array(state.qpos, copy=True) + qvel = np.array(state.qvel, copy=True) + self._encode_state_in_place(qpos, qvel, self.get_state_schema()) + return np.concatenate((qpos, qvel)) def set_state( self, @@ -103,6 +128,9 @@ def set_state( raise ValueError(msg) qpos_size = sum(int(value) for value in state_schema["qpos_sizes"]) + qpos = np.array(state_array[:qpos_size], copy=True) + qvel = np.array(state_array[qpos_size:], copy=True) + self._decode_state_in_place(qpos, qvel, state_schema) dynamic_joint_schema = DynamicJointSchema() dynamic_joint_schema.joint_names = typing.cast(list[str], list(state_schema["joint_names"])) @@ -111,10 +139,91 @@ def set_state( dynamic_joint_schema.qvel_sizes = [int(value) for value in state_schema["qvel_sizes"]] dynamic_joint_state = DynamicJointState() # type: ignore - dynamic_joint_state.qpos = state_array[:qpos_size] - dynamic_joint_state.qvel = state_array[qpos_size:] + dynamic_joint_state.qpos = qpos + dynamic_joint_state.qvel = qvel super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) + def _encode_state_in_place( + self, + qpos: np.ndarray, + qvel: np.ndarray, + schema: dict[str, list[str] | list[int]], + ): + self._transform_state_in_place(qpos, qvel, schema, encode=True) + + def _decode_state_in_place( + self, + qpos: np.ndarray, + qvel: np.ndarray, + schema: dict[str, list[str] | list[int]], + ): + self._transform_state_in_place(qpos, qvel, schema, encode=False) + + def _transform_state_in_place( + self, + qpos: np.ndarray, + qvel: np.ndarray, + schema: dict[str, list[str] | list[int]], + *, + encode: bool, + ): + joint_names = typing.cast(list[str], list(schema["joint_names"])) + joint_types = [int(value) for value in schema["joint_types"]] + qpos_sizes = [int(value) for value in schema["qpos_sizes"]] + qvel_sizes = [int(value) for value in schema["qvel_sizes"]] + encodings = typing.cast(list[str], list(schema.get("encodings", [RAW_STATE_ENCODING] * len(joint_names)))) + + qpos_offset = 0 + qvel_offset = 0 + free_joint_type = int(mj.mjtJoint.mjJNT_FREE) + for joint_name, joint_type, joint_qpos_size, joint_qvel_size, encoding in zip( + joint_names, joint_types, qpos_sizes, qvel_sizes, encodings, strict=True + ): + if encoding == RAW_STATE_ENCODING: + pass + elif encoding == ROOT_RELATIVE_FREE_STATE_ENCODING: + if joint_type != free_joint_type or joint_qpos_size != 7 or joint_qvel_size != 6: + msg = ( + f"Joint '{joint_name}' uses encoding '{ROOT_RELATIVE_FREE_STATE_ENCODING}' " + "but is not a free joint." + ) + raise ValueError(msg) + joint_qpos = qpos[qpos_offset : qpos_offset + joint_qpos_size] + joint_qvel = qvel[qvel_offset : qvel_offset + joint_qvel_size] + transformed_qpos, transformed_qvel = self._transform_root_relative_free_joint( + joint_qpos, joint_qvel, encode=encode + ) + qpos[qpos_offset : qpos_offset + joint_qpos_size] = transformed_qpos + qvel[qvel_offset : qvel_offset + joint_qvel_size] = transformed_qvel + else: + msg = f"Unsupported sim state encoding '{encoding}' for joint '{joint_name}'." + raise ValueError(msg) + + qpos_offset += joint_qpos_size + qvel_offset += joint_qvel_size + + def _transform_root_relative_free_joint( + self, + joint_qpos: np.ndarray, + joint_qvel: np.ndarray, + *, + encode: bool, + ) -> tuple[np.ndarray, np.ndarray]: + joint_pose_world = common.Pose( + translation=np.asarray(joint_qpos[:3], dtype=np.float64), + quaternion=np.asarray([joint_qpos[4], joint_qpos[5], joint_qpos[6], joint_qpos[3]], dtype=np.float64), + ) + root_inverse = self._root_frame_to_world.inverse() + joint_pose = root_inverse * joint_pose_world if encode else self._root_frame_to_world * joint_pose_world + + rotation = root_inverse.rotation_m() if encode else self._root_frame_to_world.rotation_m() + transformed_qvel = np.concatenate((rotation @ joint_qvel[:3], rotation @ joint_qvel[3:6])) + + return ( + np.concatenate((joint_pose.translation(), joint_pose.rotation_q_wxyz())), + transformed_qvel, + ) + def close_gui(self): if self._stop_event is not None: self._stop_event.set() diff --git a/python/tests/test_replayer.py b/python/tests/test_replayer.py index 32ba3b9b..1c299f25 100644 --- a/python/tests/test_replayer.py +++ b/python/tests/test_replayer.py @@ -1,6 +1,6 @@ import xml.etree.ElementTree as ET from pathlib import Path -from typing import Any +from typing import Any, cast import duckdb import numpy as np @@ -9,7 +9,15 @@ from rcs.envs.configs import EmptyWorldFR3Duo from rcs.envs.storage_wrapper import StorageWrapper from rcs.envs.tasks import PickTaskConfig -from rcs.sim.replayer import load_distinct_uuids, load_trajectory, replay_trajectory +from rcs.sim.replayer import ( + load_distinct_uuids, + load_trajectory, + replay_trajectory, + restore_sim_step, +) +from rcs.sim.sim import ROOT_RELATIVE_FREE_STATE_ENCODING + +import rcs def _build_env( @@ -18,6 +26,7 @@ def _build_env( with_cameras: bool, instruction: str = "", scene_path: Path | None = None, + root_frame_to_world: rcs.common.Pose | None = None, ) -> StorageWrapper: scene = EmptyWorldFR3Duo() cfg = scene.config() @@ -29,6 +38,8 @@ def _build_env( cfg.task_cfg = PickTaskConfig(robot_name="right") if scene_path is not None: cfg.scene = str(scene_path) + if root_frame_to_world is not None: + cfg.root_frame_to_world = root_frame_to_world if not with_cameras: cfg.camera_cfgs = {} else: @@ -56,8 +67,15 @@ def _record_source_dataset( limit: int, instruction: str, scene_path: Path | None = None, + root_frame_to_world: rcs.common.Pose | None = None, ) -> None: - env = _build_env(dataset_dir, with_cameras=False, instruction=instruction, scene_path=scene_path) + env = _build_env( + dataset_dir, + with_cameras=False, + instruction=instruction, + scene_path=scene_path, + root_frame_to_world=root_frame_to_world, + ) try: env.reset() action = { @@ -109,9 +127,21 @@ def _replay_rows(dataset_dir: Path): connection.close() -def _replay_prefix(output_dir: Path, *, with_cameras: bool, limit: int, scene_path: Path | None = None) -> None: +def _replay_prefix( + output_dir: Path, + *, + with_cameras: bool, + limit: int, + scene_path: Path | None = None, + root_frame_to_world: rcs.common.Pose | None = None, +) -> None: source_dir = output_dir.parent / "source" - env = _build_env(output_dir, with_cameras=with_cameras, scene_path=scene_path) + env = _build_env( + output_dir, + with_cameras=with_cameras, + scene_path=scene_path, + root_frame_to_world=root_frame_to_world, + ) try: uuid = load_distinct_uuids(source_dir)[0] recorded_steps = load_trajectory(source_dir, uuid)[:limit] @@ -192,6 +222,31 @@ def _strip_frames(obs: dict[str, Any]) -> dict[str, Any]: return {key: value for key, value in obs.items() if key != "frames"} +def _tilted_root_frame_to_world() -> rcs.common.Pose: + return rcs.common.Pose( + translation=np.array([0.35, -0.2, 0.15]), + quaternion=np.array([0.0, 0.0, 0.38268343, 0.92387953]), + ) + + +def _joint_qpos_from_state(state: np.ndarray, schema: dict[str, list[str] | list[int]], joint_name: str) -> np.ndarray: + joint_names = cast(list[str], schema["joint_names"]) + joint_index = joint_names.index(joint_name) + qpos_offset = sum(int(size) for size in schema["qpos_sizes"][:joint_index]) + qpos_size = int(schema["qpos_sizes"][joint_index]) + return np.asarray(state[qpos_offset : qpos_offset + qpos_size], dtype=np.float64) + + +def _joint_qpos_in_root_frame(env: StorageWrapper, joint_name: str, root_frame_to_world: rcs.common.Pose) -> np.ndarray: + joint_qpos_world = np.asarray(env.get_wrapper_attr("sim").data.joint(joint_name).qpos, dtype=np.float64) + joint_pose_world = rcs.common.Pose( + translation=joint_qpos_world[:3], + quaternion=np.array([joint_qpos_world[4], joint_qpos_world[5], joint_qpos_world[6], joint_qpos_world[3]]), + ) + joint_pose_root = root_frame_to_world.inverse() * joint_pose_world + return np.concatenate((joint_pose_root.translation(), joint_pose_root.rotation_q_wxyz())) + + def test_replayer_reproduces_existing_parquet_prefix_without_cameras(tmp_path: Path): source_dir = tmp_path / "source" replay_dir = tmp_path / "replayed" @@ -314,3 +369,56 @@ def test_replayer_adds_cameras_to_existing_episode_without_cameras(tmp_path: Pat _assert_nested_close(replay_action, source_action, atol=1e-8) _assert_nested_close(replay_env_action, source_env_action, atol=1e-8) _assert_nested_close(replay_instruction, source_instruction) + + +def test_replayer_restores_root_relative_free_joint_state_across_root_frame_changes(tmp_path: Path): + source_dir = tmp_path / "source" + replay_dir = tmp_path / "replayed" + default_root = rcs.common.Pose() + shifted_root = _tilted_root_frame_to_world() + object_joint_name = "PickTask_box_joint" + + _record_source_dataset( + source_dir, + limit=3, + instruction="pick up cube", + root_frame_to_world=default_root, + ) + _replay_prefix( + replay_dir, + with_cameras=False, + limit=3, + root_frame_to_world=shifted_root, + ) + + source_uuid = load_distinct_uuids(source_dir)[0] + replay_uuid = load_distinct_uuids(replay_dir)[0] + source_steps = load_trajectory(source_dir, source_uuid) + replay_steps = load_trajectory(replay_dir, replay_uuid) + + assert len(source_steps) == len(replay_steps) == 3 + for replay_step, source_step in zip(replay_steps, source_steps, strict=True): + assert replay_step.sim_state_schema == source_step.sim_state_schema + assert replay_step.sim_state_schema is not None + schema = replay_step.sim_state_schema + joint_names = cast(list[str], schema["joint_names"]) + encodings = cast(list[str], schema["encodings"]) + object_joint_index = joint_names.index(object_joint_name) + assert encodings[object_joint_index] == ROOT_RELATIVE_FREE_STATE_ENCODING + assert np.allclose(replay_step.sim_state, source_step.sim_state, atol=1e-5, rtol=0) + + replay_env = _build_env(replay_dir / "inspection", with_cameras=False, root_frame_to_world=shifted_root) + try: + replay_env.reset() + lead_env = replay_env.get_wrapper_attr("lead_env") + for source_step in source_steps: + restore_sim_step(replay_env, source_step) + lead_env.step_sim() + assert source_step.sim_state_schema is not None + expected_joint_qpos = _joint_qpos_from_state( + source_step.sim_state, source_step.sim_state_schema, object_joint_name + ) + actual_joint_qpos = _joint_qpos_in_root_frame(replay_env, object_joint_name, shifted_root) + assert np.allclose(actual_joint_qpos, expected_joint_qpos, atol=1e-5, rtol=0) + finally: + replay_env.close() From ca534ad3f28d4f1f5272eef1c6e29c83d27d7cbd Mon Sep 17 00:00:00 2001 From: Tobias Juelg Date: Mon, 4 May 2026 07:11:33 +0200 Subject: [PATCH 11/13] refactor(sim): simplify replay free-joint plumbing --- python/rcs/envs/scenes.py | 5 ++- python/rcs/envs/tasks.py | 4 +-- python/rcs/sim/composer.py | 30 +++++++++------- python/rcs/sim/sim.py | 73 +++++++++++++++++--------------------- 4 files changed, 53 insertions(+), 59 deletions(-) diff --git a/python/rcs/envs/scenes.py b/python/rcs/envs/scenes.py index ea11f3b3..8b554f09 100644 --- a/python/rcs/envs/scenes.py +++ b/python/rcs/envs/scenes.py @@ -393,13 +393,12 @@ def add_object_mujoco( register_root_relative_replay_free_joints: bool = False, ): """Add an object to the Mujoco scene.""" - added_object = composer.add_object_world_frame( + composer.add_object_world_frame( object_xml, object_prefix=object_id + "_", pose=object2world, + register_root_relative_replay_free_joints=register_root_relative_replay_free_joints, ) - if register_root_relative_replay_free_joints: - composer.register_root_relative_replay_free_joints(added_object.prefixed_free_joint_names) def add_object_robot_frame_mujoco( self, diff --git a/python/rcs/envs/tasks.py b/python/rcs/envs/tasks.py index 37773dc0..5444d461 100644 --- a/python/rcs/envs/tasks.py +++ b/python/rcs/envs/tasks.py @@ -163,12 +163,12 @@ def add_task_mujoco(cfg: PickTaskConfig, composer: ModelComposer, env_cfg: SimEn """Add task-specific elements to the Mujoco scene.""" object2world = cfg.object_center_to_root_frame * env_cfg.root_frame_to_world - added_object = composer.add_object_world_frame( + composer.add_object_world_frame( cfg.object_xml, object_prefix=cfg.prefix, pose=object2world, + register_root_relative_replay_free_joints=True, ) - composer.register_root_relative_replay_free_joints(added_object.prefixed_free_joint_names) @staticmethod def add_task_env(cfg: PickTaskConfig, env: gym.Env, _simulation: Sim, env_cfg: SimEnvCreatorConfig) -> gym.Env: diff --git a/python/rcs/sim/composer.py b/python/rcs/sim/composer.py index 504894c1..2a6ebe6e 100644 --- a/python/rcs/sim/composer.py +++ b/python/rcs/sim/composer.py @@ -1,17 +1,10 @@ import os -from dataclasses import dataclass from typing import Optional import mujoco from rcs._core.common import Pose -@dataclass(frozen=True) -class AddedObject: - root_body: mujoco._specs.MjsBody - prefixed_free_joint_names: list[str] - - class ModelComposer: """ Composes MuJoCo scenes using mjSpec with flexible positioning and prefixing. @@ -249,7 +242,9 @@ def add_object_robot_frame( object_prefix: str, attachment_site_name: str, pose: Pose | None = None, - ) -> AddedObject: + *, + register_root_relative_replay_free_joints: bool = False, + ) -> mujoco._specs.MjsBody: """Attaches an object to a robot attachment site with an optional local pose offset.""" if pose is None: pose = Pose() @@ -262,14 +257,22 @@ def add_object_robot_frame( object_spec = mujoco.MjSpec.from_file(xml_path) self._resolve_asset_paths(object_spec, xml_path) - prefixed_free_joint_names = self._prefixed_free_joint_names(object_spec, object_prefix) + if register_root_relative_replay_free_joints: + self.register_root_relative_replay_free_joints(self._prefixed_free_joint_names(object_spec, object_prefix)) object_root = object_spec.worldbody.first_body() object_root = attachment_site.attach(object_root, object_prefix, "") self._apply_pose(object_root, pose) - return AddedObject(root_body=object_root, prefixed_free_joint_names=prefixed_free_joint_names) + return object_root - def add_object_world_frame(self, xml_path: str, object_prefix: str, pose: Pose | None = None) -> AddedObject: + def add_object_world_frame( + self, + xml_path: str, + object_prefix: str, + pose: Pose | None = None, + *, + register_root_relative_replay_free_joints: bool = False, + ) -> mujoco._specs.MjsBody: """ Attaches a single object MJCF at a specific pose. Assumes the XML contains only one root body in the worldbody. @@ -283,7 +286,8 @@ def add_object_world_frame(self, xml_path: str, object_prefix: str, pose: Pose | # Load the child spec child_spec = mujoco.MjSpec.from_file(xml_path) self._resolve_asset_paths(child_spec, xml_path) - prefixed_free_joint_names = self._prefixed_free_joint_names(child_spec, object_prefix) + if register_root_relative_replay_free_joints: + self.register_root_relative_replay_free_joints(self._prefixed_free_joint_names(child_spec, object_prefix)) # Attach using a frame frame = self.spec.worldbody.add_frame() @@ -300,7 +304,7 @@ def add_object_world_frame(self, xml_path: str, object_prefix: str, pose: Pose | # Apply the pose self._apply_pose(obj_root, pose) - return AddedObject(root_body=obj_root, prefixed_free_joint_names=prefixed_free_joint_names) + return obj_root def save_mjcf(self, output_path: str): """Compiles and saves the MJCF.""" diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index af5dd72e..d37fdcf0 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -110,9 +110,12 @@ def get_state_size(self, schema: dict[str, list[str] | list[int]] | None = None) def get_state(self) -> np.ndarray: state = super().get_dynamic_joint_state() - qpos = np.array(state.qpos, copy=True) - qvel = np.array(state.qvel, copy=True) - self._encode_state_in_place(qpos, qvel, self.get_state_schema()) + qpos, qvel = self._transform_state( + np.array(state.qpos, copy=True), + np.array(state.qvel, copy=True), + self.get_state_schema(), + encode=True, + ) return np.concatenate((qpos, qvel)) def set_state( @@ -128,9 +131,12 @@ def set_state( raise ValueError(msg) qpos_size = sum(int(value) for value in state_schema["qpos_sizes"]) - qpos = np.array(state_array[:qpos_size], copy=True) - qvel = np.array(state_array[qpos_size:], copy=True) - self._decode_state_in_place(qpos, qvel, state_schema) + qpos, qvel = self._transform_state( + np.array(state_array[:qpos_size], copy=True), + np.array(state_array[qpos_size:], copy=True), + state_schema, + encode=False, + ) dynamic_joint_schema = DynamicJointSchema() dynamic_joint_schema.joint_names = typing.cast(list[str], list(state_schema["joint_names"])) @@ -143,39 +149,25 @@ def set_state( dynamic_joint_state.qvel = qvel super().set_dynamic_joint_state(dynamic_joint_schema, dynamic_joint_state) - def _encode_state_in_place( - self, - qpos: np.ndarray, - qvel: np.ndarray, - schema: dict[str, list[str] | list[int]], - ): - self._transform_state_in_place(qpos, qvel, schema, encode=True) - - def _decode_state_in_place( - self, - qpos: np.ndarray, - qvel: np.ndarray, - schema: dict[str, list[str] | list[int]], - ): - self._transform_state_in_place(qpos, qvel, schema, encode=False) - - def _transform_state_in_place( + def _transform_state( self, qpos: np.ndarray, qvel: np.ndarray, schema: dict[str, list[str] | list[int]], *, encode: bool, - ): + ) -> tuple[np.ndarray, np.ndarray]: joint_names = typing.cast(list[str], list(schema["joint_names"])) joint_types = [int(value) for value in schema["joint_types"]] qpos_sizes = [int(value) for value in schema["qpos_sizes"]] qvel_sizes = [int(value) for value in schema["qvel_sizes"]] encodings = typing.cast(list[str], list(schema.get("encodings", [RAW_STATE_ENCODING] * len(joint_names)))) + root_transform = self._root_frame_to_world.inverse() if encode else self._root_frame_to_world + root_rotation = root_transform.rotation_m() + free_joint_type = int(mj.mjtJoint.mjJNT_FREE) qpos_offset = 0 qvel_offset = 0 - free_joint_type = int(mj.mjtJoint.mjJNT_FREE) for joint_name, joint_type, joint_qpos_size, joint_qvel_size, encoding in zip( joint_names, joint_types, qpos_sizes, qvel_sizes, encodings, strict=True ): @@ -188,13 +180,14 @@ def _transform_state_in_place( "but is not a free joint." ) raise ValueError(msg) - joint_qpos = qpos[qpos_offset : qpos_offset + joint_qpos_size] - joint_qvel = qvel[qvel_offset : qvel_offset + joint_qvel_size] - transformed_qpos, transformed_qvel = self._transform_root_relative_free_joint( - joint_qpos, joint_qvel, encode=encode + qpos[qpos_offset : qpos_offset + joint_qpos_size], qvel[qvel_offset : qvel_offset + joint_qvel_size] = ( + self._transform_root_relative_free_joint( + qpos[qpos_offset : qpos_offset + joint_qpos_size], + qvel[qvel_offset : qvel_offset + joint_qvel_size], + root_transform=root_transform, + root_rotation=root_rotation, + ) ) - qpos[qpos_offset : qpos_offset + joint_qpos_size] = transformed_qpos - qvel[qvel_offset : qvel_offset + joint_qvel_size] = transformed_qvel else: msg = f"Unsupported sim state encoding '{encoding}' for joint '{joint_name}'." raise ValueError(msg) @@ -202,26 +195,24 @@ def _transform_state_in_place( qpos_offset += joint_qpos_size qvel_offset += joint_qvel_size + return qpos, qvel + def _transform_root_relative_free_joint( self, joint_qpos: np.ndarray, joint_qvel: np.ndarray, *, - encode: bool, + root_transform: common.Pose, + root_rotation: np.ndarray, ) -> tuple[np.ndarray, np.ndarray]: - joint_pose_world = common.Pose( + joint_pose = common.Pose( translation=np.asarray(joint_qpos[:3], dtype=np.float64), quaternion=np.asarray([joint_qpos[4], joint_qpos[5], joint_qpos[6], joint_qpos[3]], dtype=np.float64), ) - root_inverse = self._root_frame_to_world.inverse() - joint_pose = root_inverse * joint_pose_world if encode else self._root_frame_to_world * joint_pose_world - - rotation = root_inverse.rotation_m() if encode else self._root_frame_to_world.rotation_m() - transformed_qvel = np.concatenate((rotation @ joint_qvel[:3], rotation @ joint_qvel[3:6])) - + transformed_pose = root_transform * joint_pose return ( - np.concatenate((joint_pose.translation(), joint_pose.rotation_q_wxyz())), - transformed_qvel, + np.concatenate((transformed_pose.translation(), transformed_pose.rotation_q_wxyz())), + np.concatenate((root_rotation @ joint_qvel[:3], root_rotation @ joint_qvel[3:6])), ) def close_gui(self): From 7339a6d6f69734b9b87485680dbc7026b9b72b54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Mon, 4 May 2026 07:46:49 +0200 Subject: [PATCH 12/13] perf(sim): cache record schema computation --- python/rcs/sim/sim.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/rcs/sim/sim.py b/python/rcs/sim/sim.py index d37fdcf0..43a49a20 100644 --- a/python/rcs/sim/sim.py +++ b/python/rcs/sim/sim.py @@ -75,6 +75,7 @@ def __init__(self, mjmdl: str | PathLike | ModelComposer, cfg: SimConfig | None self._root_relative_replay_free_joints: set[str] = set() if cfg is not None: self.set_config(cfg) + self._state_schema = self._compute_state_schema() def configure_state_encodings( self, @@ -84,8 +85,12 @@ def configure_state_encodings( ): self._root_frame_to_world = common.Pose(root_frame_to_world) self._root_relative_replay_free_joints = set(root_relative_free_joints) + self._state_schema = self._compute_state_schema() def get_state_schema(self) -> dict[str, list[str] | list[int]]: + return self._state_schema + + def _compute_state_schema(self) -> dict[str, list[str] | list[int]]: schema = super().get_dynamic_joint_schema() return { "joint_names": list(schema.joint_names), From 532729d6aaafd38cf2daeb579994e11ca91cd86e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Mon, 4 May 2026 07:53:25 +0200 Subject: [PATCH 13/13] feat(sim): added camera depth option to cfg --- python/rcs/envs/scenes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/rcs/envs/scenes.py b/python/rcs/envs/scenes.py index 8b554f09..05e04060 100644 --- a/python/rcs/envs/scenes.py +++ b/python/rcs/envs/scenes.py @@ -49,6 +49,7 @@ def __call__(self, **kwargs) -> gym.Env: class WrapperConfig: binary_gripper: bool = True home_on_reset: bool = True + include_depth: bool = False #### SIM SPECIFIC #### @@ -364,7 +365,7 @@ def create_env_from_model(self, cfg: SimEnvCreatorConfig, mjmodel: MjModel) -> g BaseCameraSet, SimCameraSet(simulation, prefixed_cfg.camera_cfgs, physical_units=True, render_on_demand=True), ) - env = CameraSetWrapper(env, camera_set, include_depth=True) + env = CameraSetWrapper(env, camera_set, include_depth=cfg.wrapper_cfg.include_depth) env = self.add_task_env(prefixed_cfg.task_cfg, env, simulation, cfg) if not prefixed_cfg.headless: env.get_wrapper_attr("sim").open_gui()