JaxSim for developing closed-loop robot controllers
Originally developed as a hardware-accelerated physics engine, JaxSim has expanded its capabilities to become a full-featured JAX-based multibody dynamics library.
In this notebook, you’ll explore how to combine these two core features. Specifically, you’ll learn how to load a robot model and design a model-based controller for closed-loop simulations.
```python
# @title Prepare the environment
from IPython.display import clear_output
import sys

IS_COLAB = "google.colab" in sys.modules

# Install JAX, sdformat, and other notebook dependencies.
if IS_COLAB:
    !{sys.executable} -m pip install --pre -qU jaxsim[viz]
    !apt install -qq lsb-release wget gnupg
    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
    !echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
    !apt -qq update
    !apt install -qq --no-install-recommends libsdformat13 gz-tools2

    # Install dependencies for visualization on Colab and ReadTheDocs.
    !apt -qq update
    !apt install libosmesa6-dev
    clear_output()
```
```python
# ================
# Notebook imports
# ================

import os

# Select the OSMesa (software) rendering backend before MuJoCo is imported.
os.environ["MUJOCO_GL"] = "osmesa"

import jax
import jax.numpy as jnp
import jaxsim.mujoco
from jaxsim import logging

logging.set_logging_level(logging.LoggingLevel.WARNING)
print(f"Running on {jax.devices()}")
```
Running on [CpuDevice(id=0)]
We will use a simple cartpole model for this example. The cartpole is a planar system made of a cart that moves horizontally and a pole hinged to the cart, free to rotate about its pivot. The state of the cartpole consists of the cart position, the pole angle, the cart velocity, and the pole angular velocity. In the classic formulation, the only control input is the horizontal force applied to the cart. The computed torque controller designed later in this notebook instead commands a torque on both joints.
Prepare the simulation
JaxSim supports loading robot models from both SDF and URDF files, using the ami-iit/rod library to process these formats.
The rod library reads URDF files and validates them internally using gazebosim/sdformat. In this example, we’ll load a cart-pole model, which will be used to create the JaxSim simulation model.
```python
# @title Load the URDF model
import pathlib
import urllib.request

# Download the URDF file and read its content as a string.
url = "https://raw.githubusercontent.com/ami-iit/jaxsim/refs/heads/main/examples/assets/cartpole.urdf"
model_path, _ = urllib.request.urlretrieve(url)
model_urdf_string = pathlib.Path(model_path).read_text()
```
```python
# @title Create the model and its data
import jaxsim.api as js

# Create the model from the model description.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_urdf_string,
    time_step=0.010,
)

# Create the data storing the simulation state.
data_zero = js.data.JaxSimModelData.zero(model=model)
```
```python
# @title Define simulation parameters

# Create the simulated time instants, from 0 s up to (but excluding) 5 s.
T = jnp.arange(start=0, stop=5.0, step=model.time_step)
```
Prepare the MuJoCo renderer
For visualization purposes, we use the passive viewer of the MuJoCo simulator. It can either open an interactive window when used locally or record a video when used in notebooks.
```python
# Create the MJCF resources from the URDF.
mjcf_string, assets = jaxsim.mujoco.UrdfToMjcf.convert(
    urdf=model.built_from,
    # Create the camera used by the recorder.
    cameras=jaxsim.mujoco.loaders.MujocoCamera.build_from_target_view(
        camera_name="cartpole_camera",
        lookat=js.link.com_position(
            model=model,
            data=data_zero,
            link_index=js.link.name_to_idx(model=model, link_name="cart"),
            in_link_frame=False,
        ),
        distance=3,
        azimuth=150,
        elevation=-10,
    ),
)

# Create a helper to operate on the MuJoCo model and data.
mj_model_helper = jaxsim.mujoco.MujocoModelHelper.build_from_xml(
    mjcf_description=mjcf_string, assets=assets
)

# Create the video recorder.
recorder = jaxsim.mujoco.MujocoVideoRecorder(
    model=mj_model_helper.model,
    data=mj_model_helper.data,
    fps=int(1 / model.time_step),
    width=320 * 2,
    height=240 * 2,
)
```
jaxsim[3862] WARNING This method is deprecated. Use 'ModelToMjcf.convert' instead.
Open-loop simulation
Now, let’s run a simulation to demonstrate the open-loop dynamics of the system.
```python
import mediapy as media

# Sample random joint positions.
# For a random full state, you can use jaxsim.api.data.random_model_data.
random_joint_positions = jax.random.uniform(
    minval=-1.0,
    maxval=1.0,
    shape=(model.dofs(),),
    key=jax.random.PRNGKey(0),
)

# Reset the state to the random joint positions.
data = js.data.JaxSimModelData.build(model=model, joint_positions=random_joint_positions)

for _ in T:

    # Step the JaxSim simulation without any joint torque or external force.
    data = js.model.step(
        model=model,
        data=data,
        joint_force_references=None,
        link_forces=None,
    )

    # Update the MuJoCo data.
    mj_model_helper.set_joint_positions(
        positions=data.joint_positions, joint_names=model.joint_names()
    )

    # Record a new video frame.
    recorder.record_frame(camera_name="cartpole_camera")

# Play the video, then clear the recorded frames.
media.show_video(recorder.frames, fps=recorder.fps)
recorder.frames = []
```
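Closed-loop simulation
Next, let’s design a simple computed torque controller. The equations of motion for the cart-pole system are given by:

\[
M_{ss}(\mathbf{s}) \, \ddot{\mathbf{s}} + \mathbf{h}_s(\mathbf{s}, \dot{\mathbf{s}}) = \boldsymbol{\tau}
\]

where: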
\(\mathbf{s} \in \mathbb{R}^n\) are the joint positions.
\(\dot{\mathbf{s}} \in \mathbb{R}^n\) are the joint velocities.
\(\ddot{\mathbf{s}} \in \mathbb{R}^n\) are the joint accelerations.
\(\boldsymbol{\tau} \in \mathbb{R}^n\) are the joint torques.
\(M_{ss} \in \mathbb{R}^{n \times n}\) is the mass matrix.
\(\mathbf{h}_s \in \mathbb{R}^n\) is the vector of bias forces.
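The mass matrix \(M_{ss}\) is symmetric positive definite, so the equation of motion can be solved for the joint accelerations: \(\ddot{\mathbf{s}} = M_{ss}^{-1}(\boldsymbol{\tau} - \mathbf{h}_s)\). The minimal NumPy sketch below performs this step. The values of the mass matrix, bias forces, and torques are made-up placeholders for a 2-DoF system, not quantities computed by JaxSim for the cartpole.

```python
import numpy as np


def joint_accelerations(M_ss: np.ndarray, h_s: np.ndarray, tau: np.ndarray) -> np.ndarray:
    """Solve M_ss @ s_ddot + h_s = tau for the joint accelerations s_ddot."""
    # np.linalg.solve avoids forming the explicit inverse of M_ss.
    return np.linalg.solve(M_ss, tau - h_s)


# Placeholder values for a 2-DoF system (illustrative only).
M_ss = np.array([[2.0, 0.3],
                 [0.3, 0.5]])  # symmetric positive definite
h_s = np.array([0.0, -1.2])    # bias forces
tau = np.array([1.0, 0.0])     # joint torques

s_ddot = joint_accelerations(M_ss, h_s, tau)

# The computed accelerations satisfy the equations of motion.
print(np.allclose(M_ss @ s_ddot + h_s, tau))  # True
```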
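JaxSim computes these quantities for floating-base systems, whose generalized velocity stacks the 6D base velocity on top of the joint velocities. We therefore focus on the joint-related blocks, marked with the subscript \(s\). In the controller code, these blocks are extracted by discarding the first six rows (and columns).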
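To make the block structure concrete, the sketch below partitions a generic floating-base mass matrix of size \((6+n) \times (6+n)\) and a bias-force vector of size \(6+n\). It uses the same slicing as the notebook's controller: `[6:, 6:]` for the mass matrix and `[6:]` for the bias forces. The matrices are random NumPy placeholders, not outputs of `js.model.free_floating_mass_matrix` or `js.model.free_floating_bias_forces`.

```python
import numpy as np

n = 2  # number of joints (the cartpole has two)
rng = np.random.default_rng(seed=0)

# Placeholder floating-base mass matrix: symmetric positive definite, (6 + n) x (6 + n).
A = rng.standard_normal((6 + n, 6 + n))
M = A @ A.T + (6 + n) * np.eye(6 + n)

# Placeholder floating-base bias forces: 6 base entries followed by n joint entries.
h = rng.standard_normal(6 + n)

# Joint-related blocks, marked with the subscript s in the text.
M_ss = M[6:, 6:]
h_s = h[6:]

print(M_ss.shape, h_s.shape)  # (2, 2) (2,)
```

Any principal block of a symmetric positive definite matrix is itself symmetric positive definite. This is why \(M_{ss}\) can still be inverted in the joint-space equations.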
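Since no external forces or joint friction are present, we can extend a PD controller with two terms. One is a feed-forward term that compensates the bias forces (Coriolis, centrifugal, and gravity effects). The other shapes the PD feedback through the mass matrix:

\[
\boldsymbol{\tau} = M_{ss} \left( -k_p \, \tilde{\mathbf{s}} - k_d \, \dot{\tilde{\mathbf{s}}} \right) + \mathbf{h}_s
\]

where \(\tilde{\mathbf{s}} = \left(\mathbf{s} - \mathbf{s}^\text{des}\right)\) is the joint position error and \(\dot{\tilde{\mathbf{s}}} = \left(\dot{\mathbf{s}} - \dot{\mathbf{s}}^\text{des}\right)\) is the joint velocity error. The scalars \(k_p, k_d > 0\) are the proportional and derivative gains.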
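The control law can be checked numerically. Substituting it into the equations of motion should give joint accelerations equal to \(-k_p \tilde{\mathbf{s}} - k_d \dot{\tilde{\mathbf{s}}}\), whatever the (invertible) mass matrix and bias forces are. The NumPy sketch below does this with placeholder values for \(M_{ss}\), \(\mathbf{h}_s\), and the state. Only the gains match the ones used later in the notebook, and the function name `computed_torque` is illustrative.

```python
import numpy as np

kp, kd = 10.0, 6.0  # same gains as the notebook's controller


def computed_torque(M_ss, h_s, s, s_dot, s_des, s_dot_des):
    """tau = M_ss @ (-kp * s_tilde - kd * s_tilde_dot) + h_s."""
    s_star = -kp * (s - s_des) - kd * (s_dot - s_dot_des)
    return M_ss @ s_star + h_s


# Placeholder dynamics quantities and state for a 2-DoF system.
M_ss = np.array([[2.0, 0.3],
                 [0.3, 0.5]])
h_s = np.array([0.1, -1.2])
s, s_dot = np.array([-0.25, 2.8]), np.array([3.0, 0.2])
s_des, s_dot_des = np.zeros(2), np.zeros(2)

tau = computed_torque(M_ss, h_s, s, s_dot, s_des, s_dot_des)

# Forward dynamics with this torque: M_ss @ s_ddot + h_s = tau.
s_ddot = np.linalg.solve(M_ss, tau - h_s)

# The nonlinear terms cancel and only the PD error dynamics remain.
expected = -kp * (s - s_des) - kd * (s_dot - s_dot_des)
print(np.allclose(s_ddot, expected))  # True
```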
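With this control law and a constant set point \(\mathbf{s}^\text{des}\) (so that \(\dot{\mathbf{s}}^\text{des} = \ddot{\mathbf{s}}^\text{des} = \mathbf{0}\)), the inertial and bias terms cancel. Since \(M_{ss}\) is invertible, the closed-loop dynamics simplifies to the linear error dynamics:

\[
\ddot{\tilde{\mathbf{s}}} + k_d \, \dot{\tilde{\mathbf{s}}} + k_p \, \tilde{\mathbf{s}} = \mathbf{0}
\]

For any \(k_p, k_d > 0\), these error dynamics are exponentially stable, so the joint error \(\tilde{\mathbf{s}}\) converges to zero.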
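Each joint error obeys the scalar equation \(\ddot{e} + k_d \dot{e} + k_p e = 0\), whose characteristic polynomial is \(\lambda^2 + k_d \lambda + k_p\). With the gains used in this notebook (\(k_p = 10\), \(k_d = 6\)), the roots are \(\lambda = -3 \pm i\). The error therefore decays like \(e^{-3t}\), with damping ratio \(\zeta = k_d / (2\sqrt{k_p}) \approx 0.95\). The NumPy sketch below checks the roots and integrates the error dynamics from an arbitrary initial error. It uses a semi-implicit Euler scheme, an assumption made for illustration and not necessarily the integrator JaxSim uses.

```python
import numpy as np

kp, kd = 10.0, 6.0

# Roots of the characteristic polynomial lambda^2 + kd * lambda + kp (about -3 +/- 1j).
roots = np.roots([1.0, kd, kp])
print(np.all(roots.real < 0))  # True: the error dynamics are stable

# Integrate e_ddot = -kp * e - kd * e_dot with semi-implicit Euler.
dt, duration = 0.01, 5.0
e, e_dot = 1.0, 0.0  # arbitrary initial error
for _ in range(int(round(duration / dt))):
    e_ddot = -kp * e - kd * e_dot
    e_dot += dt * e_ddot  # update the velocity first...
    e += dt * e_dot       # ...then the position with the new velocity

print(abs(e) < 1e-4)  # True: the error has essentially vanished after 5 s
```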
```python
# @title Create the computed torque controller

# Define the PD gains
kp = 10.0
kd = 6.0


def computed_torque_controller(
    data: js.data.JaxSimModelData,
    s_des: jax.Array,
    s_dot_des: jax.Array,
) -> jax.Array:

    # Compute the bias forces (Coriolis, centrifugal, and gravity terms),
    # keeping only the joint-related entries (the first six refer to the base).
    hs = js.model.free_floating_bias_forces(model=model, data=data)[6:]

    # Compute the joint-related block of the floating-base mass matrix.
    Mss = js.model.free_floating_mass_matrix(model=model, data=data)[6:, 6:]

    # Get the current joint positions and velocities.
    s = data.joint_positions
    ṡ = data.joint_velocities

    # Compute the actuated joint torques.
    s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)
    τ = Mss @ s_star + hs

    return τ
```
Now, we can use the computed_torque_controller function to compute the joint torques to apply to the cartpole. Our aim is to stabilize the cartpole in the upright position, so we set the desired position s_des to 0 and the desired velocity s_dot_des to 0.
```python
# @title Run the simulation

# Initialize the data with the given joint positions and velocities.
data = js.data.JaxSimModelData.build(
    model=model,
    joint_positions=jnp.array([-0.25, jnp.deg2rad(160)]),
    joint_velocities=jnp.array([3.00, jnp.deg2rad(10) / model.time_step]),
)

for _ in T:

    # Get the actuated torques from the computed torque controller.
    τ = computed_torque_controller(
        data=data,
        s_des=jnp.array([0.0, 0.0]),
        s_dot_des=jnp.array([0.0, 0.0]),
    )

    # Step the JaxSim simulation.
    data = js.model.step(
        model=model,
        data=data,
        joint_force_references=τ,
    )

    # Update the MuJoCo data.
    mj_model_helper.set_joint_positions(
        positions=data.joint_positions, joint_names=model.joint_names()
    )

    # Record a new video frame.
    recorder.record_frame(camera_name="cartpole_camera")

# Play the video, then clear the recorded frames.
media.show_video(recorder.frames, fps=recorder.fps)
recorder.frames = []
```
Conclusions
In this notebook, we explored how to use JaxSim for developing a closed-loop controller for a robot model. Key takeaways include:
We performed an open-loop simulation to understand the dynamics of the system without control.
We implemented a computed torque controller with PD feedback and a feed-forward term compensating the bias forces (including gravity), enabling the stabilization of the system by controlling joint torques.
The closed-loop simulation can leverage hardware acceleration on GPUs and TPUs, with the ability to use jax.vmap for parallel sampling through automatic vectorization.
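As a minimal illustration of `jax.vmap`, the sketch below rolls out the closed-loop joint-error dynamics \(\ddot{\tilde{\mathbf{s}}} = -k_p \tilde{\mathbf{s}} - k_d \dot{\tilde{\mathbf{s}}}\) for a batch of random initial conditions in parallel. It deliberately uses plain JAX on these linear error dynamics rather than `js.model.step`, so it runs without a robot model. The function name `rollout`, the batch size, and the semi-implicit Euler integrator are illustrative assumptions. The same vectorization idea could be applied to a function wrapping a full JaxSim simulation.

```python
import jax
import jax.numpy as jnp

kp, kd = 10.0, 6.0
dt, num_steps = 0.01, 500  # 5 s of simulated time


def rollout(e0: jax.Array, e_dot0: jax.Array) -> jax.Array:
    """Integrate the joint-error dynamics and return the final error."""

    def step(carry, _):
        e, e_dot = carry
        e_ddot = -kp * e - kd * e_dot
        e_dot = e_dot + dt * e_ddot  # semi-implicit Euler
        e = e + dt * e_dot
        return (e, e_dot), None

    (e_final, _), _ = jax.lax.scan(step, (e0, e_dot0), xs=None, length=num_steps)
    return e_final


# Batch of 1024 random initial errors for a 2-DoF system.
key_pos, key_vel = jax.random.split(jax.random.PRNGKey(0))
E0 = jax.random.uniform(key_pos, shape=(1024, 2), minval=-1.0, maxval=1.0)
E_dot0 = jax.random.uniform(key_vel, shape=(1024, 2), minval=-1.0, maxval=1.0)

# vmap maps rollout over the leading (batch) axis; jit compiles the batched function.
final_errors = jax.jit(jax.vmap(rollout))(E0, E_dot0)
print(final_errors.shape)  # (1024, 2)
```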
JaxSim’s closed-loop support can be extended to more advanced, model-based reactive controllers and planners for trajectory optimization. To explore optimization-based methods, consider the following JAX-based projects for hardware-accelerated control and planning:
Additionally, if your controllers or planners require the derivatives of the dynamics with respect to the state or inputs, you can obtain them using automatic differentiation directly through JaxSim’s API.
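As a hedged sketch of this idea, the example below differentiates a simple closed-loop update in two ways. It uses `jax.jacfwd` with respect to the state and `jax.grad` with respect to the gain \(k_p\). The update is the PD-controlled error dynamics integrated with semi-implicit Euler, not a JaxSim model. The function `closed_loop_step` and the loss are hypothetical; with JaxSim, the same transformations could be applied to a function built around `js.model.step`.

```python
import jax
import jax.numpy as jnp

dt = 0.01


def closed_loop_step(x: jax.Array, kp: float, kd: float) -> jax.Array:
    """One semi-implicit Euler step of e_ddot = -kp * e - kd * e_dot, with x = [e, e_dot]."""
    e, e_dot = x
    e_dot_next = e_dot + dt * (-kp * e - kd * e_dot)
    e_next = e + dt * e_dot_next
    return jnp.array([e_next, e_dot_next])


x0 = jnp.array([1.0, 0.0])
kp, kd = 10.0, 6.0

# Jacobian of the next state with respect to the current state.
A = jax.jacfwd(closed_loop_step, argnums=0)(x0, kp, kd)

# Analytical Jacobian of the same update, for comparison.
A_expected = jnp.array([
    [1.0 - dt**2 * kp, dt * (1.0 - dt * kd)],
    [-dt * kp, 1.0 - dt * kd],
])
print(jnp.allclose(A, A_expected))  # True


# Gradient of a (hypothetical) loss on the error after 100 steps with respect to kp.
def loss(kp: float) -> jax.Array:
    x = x0
    for _ in range(100):
        x = closed_loop_step(x, kp, kd)
    return x[0] ** 2


print(jax.grad(loss)(kp))
```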