Skip to content
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# v0.0.6

## Added
* Separation between task-space controls and MuJoCo simulation controls
* Added `task_to_sim_ctrl()` method to `Task` base class (default identity mapping)
* Updated `RolloutBackend` to accept and use `task_to_sim_ctrl` function for control mapping
* Updated `MJSimulation` to use `task_to_sim_ctrl` when setting controls
* Tasks can now override `task_to_sim_ctrl()` to map between different control spaces
Comment on lines +3 to +8
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tag yourself in these updates


# v0.0.5

## Added
Expand Down
6 changes: 5 additions & 1 deletion judo/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def __init__(
self.model = self.task.model
self.model_data_pairs = make_model_data_pairs(self.model, self.optimizer_cfg.num_rollouts)

self.rollout_backend = RolloutBackend(num_threads=self.optimizer_cfg.num_rollouts, backend=rollout_backend)
self.rollout_backend = RolloutBackend(
num_threads=self.optimizer_cfg.num_rollouts,
backend=rollout_backend,
task_to_sim_ctrl=self.task.task_to_sim_ctrl,
)
self.action_normalizer = self._init_action_normalizer()

# a container for any metadata from the system that we want to pass to the task
Expand Down
4 changes: 3 additions & 1 deletion judo/simulation/mj_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def step(self) -> None:
"""Step the simulation forward by one timestep."""
if self.control is not None and not self.paused:
try:
self.task.data.ctrl[:] = self.control(self.task.data.time)
control_input = self.control(self.task.data.time)
processed_control = self.task.task_to_sim_ctrl(control_input)
self.task.data.ctrl[:] = processed_control
self.task.pre_sim_step()
mj_step(self.task.sim_model, self.task.data)
self.task.post_sim_step()
Expand Down
14 changes: 14 additions & 0 deletions judo/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ def dt(self) -> float:
"""Returns Mujoco physics timestep for default physics task."""
return self.model.opt.timestep

def task_to_sim_ctrl(self, controls: np.ndarray) -> np.ndarray:
"""Maps the controls from the optimizer to the controls used in the simulation.

This can be overridden by tasks that have different control mappings. By default, it is the identity
function.

Args:
controls: The controls from the optimizer. Shape=(num_rollouts, T, nu) or (T, nu) or (nu,).

Returns:
mapped_controls: The controls to be used in the simulation. Same shape as input.
"""
return controls

def pre_rollout(self, curr_state: np.ndarray) -> None:
"""Pre-rollout behavior for task (does nothing by default).

Expand Down
22 changes: 15 additions & 7 deletions judo/utils/mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import time
from copy import deepcopy
from typing import Literal
from typing import Callable, Literal

import numpy as np
from mujoco import MjData, MjModel
Expand All @@ -20,9 +20,16 @@ def make_model_data_pairs(model: MjModel, num_pairs: int) -> list[tuple[MjModel,
class RolloutBackend:
"""The backend for conducting multithreaded rollouts."""

def __init__(self, num_threads: int, backend: Literal["mujoco"]) -> None:
"""Initialize the backend with a number of threads."""
def __init__(self, num_threads: int, backend: Literal["mujoco"], task_to_sim_ctrl: Callable) -> None:
"""Initialize the backend with a number of threads.

Args:
num_threads: Number of threads for parallel rollout.
backend: Backend to use ("mujoco" for Python mujoco rollout).
task_to_sim_ctrl: Function to map task controls to simulation controls.
"""
self.backend = backend
self.task_to_sim_ctrl = task_to_sim_ctrl
if self.backend == "mujoco":
self.setup_mujoco_backend(num_threads)
else:
Expand Down Expand Up @@ -57,15 +64,16 @@ def rollout(
# shape = (num_rollouts, num_states + 1)
x0_batched = np.tile(x0, (len(ms), 1))
full_states = np.concatenate([time.time() * np.ones((len(ms), 1)), x0_batched], axis=-1)
processed_controls = self.task_to_sim_ctrl(controls)
assert full_states.shape[-1] == nq + nv + 1
assert full_states.ndim == 2
assert controls.ndim == 3
assert controls.shape[-1] == nu
assert controls.shape[0] == full_states.shape[0]
assert processed_controls.ndim == 3
assert processed_controls.shape[-1] == nu
assert processed_controls.shape[0] == full_states.shape[0]

# rollout
if self.backend == "mujoco":
_states, _out_sensors = self.rollout_func(ms, ds, full_states, controls)
_states, _out_sensors = self.rollout_func(ms, ds, full_states, processed_controls)
else:
raise ValueError(f"Unknown backend: {self.backend}")
out_states = np.array(_states)[..., 1:] # remove time from state
Expand Down
Loading