JaxSim as a hardware-accelerated parallel physics engine

This notebook shows how to use the key APIs to load a robot model and simulate multiple trajectories simultaneously.


Prepare the environment

# @title Imports and setup
import sys
from IPython.display import clear_output

IS_COLAB = "google.colab" in sys.modules

# Install JaxSim and the Gazebo SDFormat library and tools
if IS_COLAB:
    !{sys.executable} -m pip install --pre -qU jaxsim
    !apt install -qq lsb-release wget gnupg
    !wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
    !echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
    !apt -qq update
    !apt install -qq --no-install-recommends libsdformat13 gz-tools2

    clear_output()

# Disable XLA's GPU memory preallocation to avoid out-of-memory errors
%env XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false


# ================
# Notebook imports
# ================

import jax
import jax.numpy as jnp
import jaxsim.api as js
from jaxsim import logging
import pathlib
import urllib.request

logging.set_logging_level(logging.LoggingLevel.WARNING)
print(f"Running on {jax.devices()}")
env: XLA_PYTHON_CLIENT_MEM_PREALLOCATE=false
Running on [CpuDevice(id=0)]
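The `%env` magic above only works inside IPython. Below is a minimal sketch for plain Python scripts. It assumes only that the variable must be set before JAX initializes its backend. Setting it before `import jax` is the simplest way to guarantee that.

```python
import os

# Disable XLA's upfront GPU memory preallocation. This must take effect before
# JAX initializes its backend, so set it before importing jax.
os.environ["XLA_PYTHON_CLIENT_MEM_PREALLOCATE"] = "false"

import jax  # noqa: E402  (imported after setting the environment variable)

print(f"Running on {jax.devices()}")
```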

Prepare the simulation

JaxSim supports loading robot descriptions from both SDF and URDF files. In this example, we load the URDF model of the ergoCub humanoid robot.

Create the model and its data

To define a simulation, we need two main objects:

  • model: an object that defines the dynamics of the system.

  • data: an object that contains the state of the system.
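The split between model (constant dynamics parameters) and data (evolving state) can be illustrated with a toy example. This is a minimal sketch, not JaxSim code: `ToyModel`, `ToyData` and `toy_step` are hypothetical names, and the dynamics are just a free-falling point mass integrated with semi-implicit Euler. The step function is pure: it takes the model and the current data and returns new data.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyModel(NamedTuple):
    """Hypothetical stand-in for the model: constant parameters of the dynamics."""
    time_step: float
    gravity: jax.Array


class ToyData(NamedTuple):
    """Hypothetical stand-in for the data: the evolving state of the system."""
    position: jax.Array
    velocity: jax.Array


@jax.jit
def toy_step(model: ToyModel, data: ToyData) -> ToyData:
    # Semi-implicit Euler: update the velocity first, then the position
    # using the new velocity. A new ToyData is returned; data is untouched.
    velocity = data.velocity + model.time_step * model.gravity
    position = data.position + model.time_step * velocity
    return ToyData(position=position, velocity=velocity)


toy_model = ToyModel(time_step=0.01, gravity=jnp.array([0.0, 0.0, -9.81]))
toy_data0 = ToyData(position=jnp.array([0.0, 0.0, 1.0]), velocity=jnp.zeros(3))
toy_data1 = toy_step(toy_model, toy_data0)
```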

The JaxSimModel object contains the simulation time step, the integrator and the contact model. For advanced usage, see the advanced example, which shows how to explicitly pass an integrator class and its state to the model object and how to change the contact model.

Create the model

# Download the ergoCub URDF to a temporary file.
url = "https://raw.githubusercontent.com/icub-tech-iit/ergocub-software/refs/heads/master/urdf/ergoCub/robots/ergoCubSN001/model.urdf"
model_path, _ = urllib.request.urlretrieve(url)
model_description_path = pathlib.Path(model_path)

# Create the JaxSim model from the full robot description.
# The downloaded file is a URDF, so state the format explicitly.
full_model = js.model.JaxSimModel.build_from_model_description(
    model_description=model_description_path,
    time_step=0.0001,  # 0.1 ms integration step
    is_urdf=True,
)

# Arm and leg joints to keep in the simulated model.
joints_list = (
    "l_shoulder_pitch", "l_shoulder_roll", "l_shoulder_yaw", "l_elbow",
    "r_shoulder_pitch", "r_shoulder_roll", "r_shoulder_yaw", "r_elbow",
    "l_hip_pitch", "l_hip_roll", "l_hip_yaw", "l_knee", "l_ankle_pitch", "l_ankle_roll",
    "r_hip_pitch", "r_hip_roll", "r_hip_yaw", "r_knee", "r_ankle_pitch", "r_ankle_roll",
)

# Reduce the model to the selected joints; the links connected by the
# removed joints are lumped together.
model = js.model.reduce(
    model=full_model,
    considered_joints=joints_list,
)

Create the data object

The data object is never modified in place. Any method that updates it, such as reset_base_position, returns a new data object with the updated attributes and leaves the original data unchanged.
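The same functional-update pattern can be sketched with a hypothetical immutable container. `ToyData` and its `reset_base_position` method below only illustrate the idea and are not JaxSim's `JaxSimModelData`. The method returns a new instance built with `NamedTuple._replace`, and the original keeps its old values. Unchanged fields can safely be shared between the two instances because JAX arrays are themselves immutable.

```python
from typing import NamedTuple

import jax.numpy as jnp


class ToyData(NamedTuple):
    """Hypothetical immutable state container (not JaxSim's JaxSimModelData)."""
    base_position: jnp.ndarray
    joint_positions: jnp.ndarray

    def reset_base_position(self, base_position) -> "ToyData":
        # _replace builds a new ToyData; self is left untouched.
        return self._replace(base_position=jnp.asarray(base_position))


toy_data = ToyData(base_position=jnp.zeros(3), joint_positions=jnp.zeros(2))
new_toy_data = toy_data.reset_base_position([0.0, 0.0, 1.0])
```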

# Create the data of a single model.
data = js.data.JaxSimModelData.build(model=model, base_position=jnp.array([0.0, 0.0, 1.0]))

Simulation

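# Create a random JAX key (not used below, but handy for randomizing initial conditions).

key = jax.random.PRNGKey(seed=0)

# Create the vector of simulated time instants, spaced by the model time step.
T = jnp.arange(start=0, stop=0.3, step=model.time_step)

# Simulate by advancing the data by one time step per time instant.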
for _t in T:
    data = js.model.step(
        model=model,
        data=data,
        link_forces=None,
        joint_force_references=None,
    )
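The Python loop above dispatches one compiled call per time instant. An alternative is `jax.lax.scan`, which traces the step once and compiles the whole loop into a single XLA computation. This is a minimal sketch on a toy falling point mass, so it runs without JaxSim. `toy_step`, `simulate`, `TIME_STEP` and `GRAVITY` are hypothetical. We assume the same pattern could wrap a jitted step function like `js.model.step`, but that is not shown here.

```python
import functools

import jax
import jax.numpy as jnp

TIME_STEP = 0.001  # hypothetical time step of the toy model
GRAVITY = jnp.array([0.0, 0.0, -9.81])


def toy_step(state, _):
    # One semi-implicit Euler step of a falling point mass.
    position, velocity = state
    velocity = velocity + TIME_STEP * GRAVITY
    position = position + TIME_STEP * velocity
    # Return the new carry, plus the position to stack into the trajectory.
    return (position, velocity), position


@functools.partial(jax.jit, static_argnames="num_steps")
def simulate(position0, velocity0, num_steps):
    # scan runs toy_step num_steps times inside one compiled computation.
    (position, velocity), trajectory = jax.lax.scan(
        toy_step, (position0, velocity0), xs=None, length=num_steps
    )
    return position, velocity, trajectory


final_position, final_velocity, trajectory = simulate(
    jnp.array([0.0, 0.0, 1.0]), jnp.zeros(3), num_steps=300
)
```

The trajectory has one row per step, so it can be inspected or plotted without collecting states in a Python list.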

Vectorized simulation

We now vectorize the simulation over a batch of initial states using jax.vmap.
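The key detail is `in_axes=(None, 0)`. The first argument is shared by every batch element, and the second is mapped over its leading axis. Below is a minimal sketch of that behavior on a toy point-mass step. `toy_step_parallel` and `toy_params` are hypothetical names, and the heights are illustrative values.

```python
import functools

import jax
import jax.numpy as jnp


@jax.jit
@functools.partial(jax.vmap, in_axes=(None, 0))
def toy_step_parallel(params, state):
    # params is shared by every batch element (in_axes=None).
    # state is mapped over its leading axis (in_axes=0), so inside the
    # function each position/velocity has shape (3,), not (batch_size, 3).
    position, velocity = state
    velocity = velocity + params["time_step"] * params["gravity"]
    position = position + params["time_step"] * velocity
    return position, velocity


toy_params = {"time_step": 0.001, "gravity": jnp.array([0.0, 0.0, -9.81])}
batch_size = 5
heights = jnp.linspace(1.0, 2.0, batch_size)
positions = jnp.zeros((batch_size, 3)).at[:, 2].set(heights)
velocities = jnp.zeros((batch_size, 3))

new_positions, new_velocities = toy_step_parallel(toy_params, (positions, velocities))
```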

# First, wrap the step function and vmap it over the data.

import functools


@jax.jit
def step_single(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
) -> js.data.JaxSimModelData:

    # Fix the optional inputs: no external link forces, no joint force references.
    return js.model.step(
        model=model,
        data=data,
        link_forces=None,
        joint_force_references=None,
    )


@jax.jit
@functools.partial(jax.vmap, in_axes=(None, 0))
def step_parallel(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
) -> js.data.JaxSimModelData:

    # The model is shared by the whole batch (in_axes=None), while the data
    # is mapped over its leading (batch) axis (in_axes=0).
    return step_single(model=model, data=data)


# Then, create the batch of initial states.
batch_size = 5
base_positions = jnp.tile(jnp.array([0.0, 0.0, 1.0]), (batch_size, 1))
data_batch_t0 = jax.vmap(
    lambda pos: js.data.JaxSimModelData.build(model=model, base_position=pos)
)(base_positions)

# Simulate all the trajectories in parallel.
data = data_batch_t0
for _t in T:
    data = step_parallel(model, data)
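In the batch above, all five robots start from the same base position, and the random key created in the Simulation section is never used. Below is a minimal sketch of one way to use it, assuming you want each batch element to start from a different height. `sample_base_positions` and its `min_height`/`max_height` bounds are hypothetical. The commented lines show how the result could replace the tiled array when building the batched data.

```python
import jax
import jax.numpy as jnp


def sample_base_positions(key, batch_size, min_height=0.9, max_height=1.1):
    """Hypothetical helper: draw batch_size base positions with random heights."""
    heights = jax.random.uniform(
        key, shape=(batch_size,), minval=min_height, maxval=max_height
    )
    # Keep x = y = 0 and place the sampled heights on the z axis.
    return jnp.zeros((batch_size, 3)).at[:, 2].set(heights)


rng_key = jax.random.PRNGKey(seed=0)
rng_key, subkey = jax.random.split(rng_key)
random_base_positions = sample_base_positions(subkey, batch_size=5)

# In the notebook, these could replace the tiled positions, e.g.:
# data_batch_t0 = jax.vmap(
#     lambda pos: js.data.JaxSimModelData.build(model=model, base_position=pos)
# )(random_base_positions)
```

Splitting the key before sampling keeps `rng_key` fresh for later draws, following the usual JAX convention of never reusing a key.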