Source code for jaxsim.api.integrators

import dataclasses
from collections.abc import Callable

import jax
import jax.numpy as jnp

import jaxsim
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.api.data import JaxSimModelData
from jaxsim.math import Skew


[docs] def semi_implicit_euler_integration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, link_forces: jtp.Vector, joint_torques: jtp.Vector, ) -> JaxSimModelData: """Integrate the system state using the semi-implicit Euler method.""" with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): # Compute the system acceleration W_v̇_WB, , contact_state_derivative = js.ode.system_acceleration( model=model, data=data, link_forces=link_forces, joint_torques=joint_torques, ) dt = model.time_step # Compute the new generalized velocity. new_generalized_acceleration = jnp.hstack([W_v̇_WB, ]) new_generalized_velocity = ( data.generalized_velocity + dt * new_generalized_acceleration ) # Extract the new base and joint velocities. W_v_B = new_generalized_velocity[0:6] = new_generalized_velocity[6:] # Compute the new base position and orientation. W_ω_WB = new_generalized_velocity[3:6] # To obtain the derivative of the base position, we need to subtract # the skew-symmetric matrix of the base angular velocity times the base position. # See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9 W_ṗ_B = new_generalized_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position W_Q̇_B = jaxsim.math.Quaternion.derivative( quaternion=data.base_orientation, omega=W_ω_WB, omega_in_body_fixed=False, ).squeeze() W_p_B = data.base_position + dt * W_ṗ_B W_Q_B = data.base_orientation + dt * W_Q̇_B base_quaternion_norm = jaxsim.math.safe_norm(W_Q_B, axis=-1) W_Q_B = W_Q_B / jnp.where(base_quaternion_norm == 0, 1.0, base_quaternion_norm) s = data.joint_positions + dt * integrated_contact_state = jax.tree.map( lambda x, x_dot: x + dt * x_dot, data.contact_state, contact_state_derivative, ) # TODO: Avoid double replace, e.g. by computing cached value here data = dataclasses.replace( data, _base_quaternion=W_Q_B, _base_position=W_p_B, _joint_positions=s, _joint_velocities=, _base_linear_velocity=W_v_B[0:3], _base_angular_velocity=W_ω_WB, contact_state=integrated_contact_state, ) # Update the cached computations. data = data.replace(model=model) return data
[docs] def rk4_integration( model: js.model.JaxSimModel, data: JaxSimModelData, link_forces: jtp.Vector, joint_torques: jtp.Vector, ) -> JaxSimModelData: """Integrate the system state using the Runge-Kutta 4 method.""" dt = model.time_step def f(x) -> dict[str, jtp.Matrix]: with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): data_ti = data.replace(model=model, **x) return js.ode.system_dynamics( model=model, data=data_ti, link_forces=link_forces, joint_torques=joint_torques, ) base_quaternion_norm = jaxsim.math.safe_norm(data._base_quaternion, axis=-1) base_quaternion = data._base_quaternion / jnp.where( base_quaternion_norm == 0, 1.0, base_quaternion_norm ) x_t0 = dict( base_position=data._base_position, base_quaternion=base_quaternion, joint_positions=data._joint_positions, base_linear_velocity=data._base_linear_velocity, base_angular_velocity=data._base_angular_velocity, joint_velocities=data._joint_velocities, contact_state=data.contact_state, ) euler_mid = lambda x, dxdt: x + (0.5 * dt) * dxdt euler_fin = lambda x, dxdt: x + dt * dxdt k1 = f(x_t0) k2 = f(jax.tree.map(euler_mid, x_t0, k1)) k3 = f(jax.tree.map(euler_mid, x_t0, k2)) k4 = f(jax.tree.map(euler_fin, x_t0, k3)) # Average the slopes and compute the RK4 state derivative. average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6 dxdt = jax.tree_util.tree_map(average, k1, k2, k3, k4) # Integrate the dynamics x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt) data_tf = dataclasses.replace( data, _base_position=x_tf["base_position"], _base_quaternion=x_tf["base_quaternion"], _joint_positions=x_tf["joint_positions"], _base_linear_velocity=x_tf["base_linear_velocity"], _base_angular_velocity=x_tf["base_angular_velocity"], _joint_velocities=x_tf["joint_velocities"], contact_state=x_tf["contact_state"], ) return data_tf.replace(model=model)
_INTEGRATORS_MAP: dict[ js.model.IntegratorType, Callable[..., js.data.JaxSimModelData] ] = { js.model.IntegratorType.SemiImplicitEuler: semi_implicit_euler_integration, js.model.IntegratorType.RungeKutta4: rk4_integration, }