JaxSim as a hardware-accelerated parallel physics engine
This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.
Prepare the environment
# @title Imports and setup
import sys

from IPython.display import clear_output

IS_COLAB = "google.colab" in sys.modules

# On Colab, install JaxSim (which pulls in JAX) and the Gazebo SDFormat library and tools.
if IS_COLAB:
    !{sys.executable} -m pip install --pre -qU jaxsim
    !apt install -qq lsb-release wget gnupg
    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
    !echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
    !apt -qq update
    !apt install -qq --no-install-recommends libsdformat13 gz-tools2

    clear_output()

# Disable XLA's GPU memory preallocation to reduce out-of-memory errors.
%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false

# ================
# Notebook imports
# ================

import pathlib
import urllib.request

import jax
import jax.numpy as jnp
import jaxsim.api as js
from jaxsim import logging

logging.set_logging_level(logging.LoggingLevel.WARNING)
print(f"Running on {jax.devices()}")
env: XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false
Running on [CpuDevice(id=0)]
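`%env` is an IPython magic, so it only works inside a notebook. In a plain Python script, the same setting can be applied from Python, as long as it happens before JAX initializes its backends (in practice, before `import jax`). A minimal sketch:

```python
import os

# Disable XLA's default GPU memory preallocation.
# This must run before JAX initializes its backends, so keep it above `import jax`.
os.environ["XLA_PYTHON_CLIENT_MEM_PREALLOCATE"] = "false"

print(os.environ["XLA_PYTHON_CLIENT_MEM_PREALLOCATE"])
```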
Prepare the simulation
JaxSim supports loading robot descriptions from both SDF and URDF files. In this example, we load the URDF description of the ergoCub humanoid robot.
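Both formats are XML: a URDF file has a `<robot>` root element, while an SDF file has an `<sdf>` root element. The hypothetical helper below (not part of JaxSim) uses this to choose the value of the `is_urdf` flag when the file name carries no useful extension, as can happen with downloaded temporary files.

```python
import xml.etree.ElementTree as ET


def is_urdf_description(xml_text: str) -> bool:
    """Hypothetical helper: return True for URDF and False for SDF.

    URDF documents have a <robot> root element, SDF documents an <sdf> root.
    """
    root = ET.fromstring(xml_text)
    if root.tag == "robot":
        return True
    if root.tag == "sdf":
        return False
    raise ValueError(f"Unrecognized root element: <{root.tag}>")


urdf_snippet = '<robot name="toy"><link name="base"/></robot>'
sdf_snippet = '<sdf version="1.7"><model name="toy"><link name="base"/></model></sdf>'

print(is_urdf_description(urdf_snippet))
print(is_urdf_description(sdf_snippet))
```

For a downloaded file, the helper could be applied to `pathlib.Path(model_path).read_text()`.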
Create the model and its data
To define a simulation we need two main objects:
- model: an object that defines the dynamics of the system.
- data: an object that contains the state of the system.
The JaxSimModel object contains the simulation time step, the integrator, and the contact model. For more advanced usage, see the advanced example, which shows how to explicitly pass an integrator class and state to the model object and how to change the contact model.
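Separating a model that describes the dynamics (including the time step) from a data object that holds the state makes it natural to compile the same step function with `jax.jit` and to batch it with `jax.vmap`. The toy sketch below mimics that split for a point mass falling under gravity. `ToyModel`, `ToyData`, and `toy_step` are hypothetical names used only for illustration; the semi-implicit Euler update is a stand-in, not JaxSim's integrator, and there is no contact model.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyModel(NamedTuple):
    """Hypothetical stand-in for the model: parameters of the dynamics."""
    time_step: float
    gravity: float


class ToyData(NamedTuple):
    """Hypothetical stand-in for the data: the state of the system."""
    position: jax.Array  # height of the point mass [m]
    velocity: jax.Array  # vertical velocity [m/s]


@jax.jit
def toy_step(model: ToyModel, data: ToyData) -> ToyData:
    # Semi-implicit Euler: update the velocity first, then the position.
    velocity = data.velocity - model.gravity * model.time_step
    position = data.position + velocity * model.time_step
    # Build a new state; the input `data` is left untouched.
    return ToyData(position=position, velocity=velocity)


toy_model = ToyModel(time_step=0.0001, gravity=9.81)
toy_data_0 = ToyData(position=jnp.array(1.0), velocity=jnp.array(0.0))
toy_data_1 = toy_step(toy_model, toy_data_0)
print(toy_data_1.position, toy_data_1.velocity)
```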
Create the model
# Create the JaxSim model.
url = "https://raw.githubusercontent.com/icub-tech-iit/ergocub-software/refs/heads/master/urdf/ergoCub/robots/ergoCubSN001/model.urdf"
# Retrieve the file
model_path, _ = urllib.request.urlretrieve(url)
model_description_path = pathlib.Path(model_path)
full_model = js.model.JaxSimModel.build_from_model_description(
model_description=model_description_path,
time_step=0.0001,
is_urdf=True
)
# Joints kept in the reduced model; the remaining joints of the full model are removed.
joints_list = (
    "l_shoulder_pitch", "l_shoulder_roll", "l_shoulder_yaw", "l_elbow",
    "r_shoulder_pitch", "r_shoulder_roll", "r_shoulder_yaw", "r_elbow",
    "l_hip_pitch", "l_hip_roll", "l_hip_yaw", "l_knee", "l_ankle_pitch", "l_ankle_roll",
    "r_hip_pitch", "r_hip_roll", "r_hip_yaw", "r_knee", "r_ankle_pitch", "r_ankle_roll",
)

model = js.model.reduce(
    model=full_model,
    considered_joints=joints_list,
)
Create the data object
The data object is never modified in place. Any method that modifies the data, such as reset_base_position, returns a new data object with the updated attributes and leaves the original data unchanged.
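This is the usual functional-update pattern in JAX: the state is a pytree of arrays, and an "update" builds a new pytree instead of mutating the old one. The sketch below reproduces the behaviour with a plain `NamedTuple` and its `_replace` method. `ToyState` is a hypothetical stand-in, not the JaxSim data class; `_replace` is only analogous to what the text describes for `reset_base_position`.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for an immutable simulation state."""
    base_position: jnp.ndarray


state = ToyState(base_position=jnp.array([0.0, 0.0, 1.0]))

# `_replace` returns a new tuple with the updated field...
new_state = state._replace(base_position=jnp.array([0.0, 0.0, 2.0]))

# ...while the original object still holds its old value.
print(state.base_position)      # [0. 0. 1.]
print(new_state.base_position)  # [0. 0. 2.]
```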
# Create the data of a single model.
data = js.data.JaxSimModelData.build(model=model, base_position=jnp.array([0.0, 0.0, 1.0]))
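Simulation
# Create a random JAX key (not used in this single-trajectory example).
key = jax.random.PRNGKey(seed=0)
# Time instants to simulate: from 0 s to 0.3 s, spaced by the model time step.
T = jnp.arange(start=0, stop=0.3, step=model.time_step)
# Simulate: every call to `js.model.step` returns a new data object.
for _t in T:
    data = js.model.step(
        model=model,
        data=data,
        link_forces=None,
        joint_force_references=None,
    )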
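The Python `for` loop above dispatches one call per time step, roughly 3000 calls for 0.3 s at 0.1 ms. A common alternative in JAX is `jax.lax.scan`, which traces the step function once and runs the whole loop inside a single compiled call. The sketch below shows the pattern on a hypothetical toy step (a point mass falling under gravity, not `js.model.step`); it also computes the number of steps as an integer, which avoids floating-point rounding in the length of `jnp.arange`. Whether scanning `js.model.step` is faster on a given device is an assumption to verify, not something this notebook measures.

```python
import jax
import jax.numpy as jnp

time_step = 0.0001
duration = 0.3
# Integer number of steps, computed once to avoid float rounding in arange.
n_steps = int(round(duration / time_step))


def falling_mass_step(carry, _):
    """Hypothetical step: semi-implicit Euler for a point mass under gravity."""
    position, velocity = carry
    velocity = velocity - 9.81 * time_step
    position = position + velocity * time_step
    # scan expects (new_carry, per_step_output); here we record the height.
    return (position, velocity), position


@jax.jit
def rollout(initial_state):
    # Run all steps inside one compiled loop.
    return jax.lax.scan(falling_mass_step, initial_state, xs=None, length=n_steps)


final_state, heights = rollout((jnp.float32(1.0), jnp.float32(0.0)))
print(n_steps, heights.shape)  # 3000 (3000,)
```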
Vectorized simulation
We now vectorize the simulation over a batch of initial states using jax.vmap, so that all trajectories advance with a single call per time step.
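In `jax.vmap(..., in_axes=(None, 0))`, `None` means the first argument (the model) is shared by every batch element, while `0` means the second argument (the data) is split along its leading axis. Because the data is a pytree, every leaf must carry the batch dimension first. The toy sketch below uses a hypothetical `scaled_step` function (not a JaxSim API) and checks that the vmapped call matches stepping each element separately.

```python
import jax
import jax.numpy as jnp


def scaled_step(gain: jax.Array, state: dict) -> dict:
    """Hypothetical step acting on a pytree state with two leaves."""
    return {
        "position": state["position"] + gain * state["velocity"],
        "velocity": state["velocity"],
    }


# Shared parameter (not batched) and a batch of 5 states (batch axis first).
gain = jnp.array(0.1)
batch = {"position": jnp.zeros((5, 3)), "velocity": jnp.ones((5, 3))}

batched_step = jax.jit(jax.vmap(scaled_step, in_axes=(None, 0)))
out = batched_step(gain, batch)

# Same result as stepping each element of the batch on its own.
for i in range(5):
    single = scaled_step(gain, {k: v[i] for k, v in batch.items()})
    assert jnp.allclose(out["position"][i], single["position"])

print(out["position"].shape)  # (5, 3)
```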
# First, wrap the step function and vectorize it over the data.
import functools


@jax.jit
def step_single(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
) -> js.data.JaxSimModelData:
    # Step a single model without external link forces or joint force references.
    return js.model.step(
        model=model,
        data=data,
        link_forces=None,
        joint_force_references=None,
    )


# in_axes=(None, 0): share the model, batch the data along its leading axis.
@jax.jit
@functools.partial(jax.vmap, in_axes=(None, 0))
def step_parallel(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
) -> js.data.JaxSimModelData:
    return step_single(model=model, data=data)
# Then, create the batch of initial states (one per simulated robot).
batch_size = 5
data_batch_t0 = jax.vmap(
    lambda pos: js.data.JaxSimModelData.build(model=model, base_position=pos)
)(jnp.tile(jnp.array([0.0, 0.0, 1.0]), (batch_size, 1)))

data = data_batch_t0
for _t in T:
    data = step_parallel(model, data)
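The batch above starts every robot from the same base position, so with identical inputs all trajectories are the same. To obtain different trajectories, the initial states can be randomized, for example with a JAX PRNG key. The sketch below only builds the batched base positions with `jax.random`; the resulting array could then replace the `jnp.tile(...)` argument of the `jax.vmap(...)` call above. The ±0.1 m range on the horizontal coordinates is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

batch_size = 5
rng_key = jax.random.PRNGKey(seed=0)

# Nominal base position of every robot in the batch: 1 m above the ground.
nominal = jnp.tile(jnp.array([0.0, 0.0, 1.0]), (batch_size, 1))

# Random horizontal offsets in [-0.1, 0.1) m; the height is left unchanged.
xy_offsets = jax.random.uniform(
    rng_key, shape=(batch_size, 2), minval=-0.1, maxval=0.1
)
base_positions = nominal.at[:, :2].add(xy_offsets)

print(base_positions.shape)  # (5, 3)
```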