Source code for jaxsim.api.ode

import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Quaternion, Skew

from .common import VelRepr

# ==================================
# Functions defining system dynamics
# ==================================


[docs] def system_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.MatrixLike | None = None, joint_torques: jtp.VectorLike | None = None, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the system acceleration in the active representation. Args: model: The model to consider. data: The data of the considered model. link_forces: The 6D forces to apply to the links expressed in the same velocity representation of data. joint_torques: The joint torques applied to the joints. Returns: A tuple containing the base 6D acceleration in the active representation and the joint accelerations. """ # ==================== # Validate input data # ==================== # Build link forces if not provided. f_L = ( jnp.atleast_2d(link_forces.squeeze()) if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ).astype(float) # ====================== # Compute contact forces # ====================== W_f_L_terrain = jnp.zeros_like(f_L) contact_state_derivative = {} if len(model.kin_dyn_parameters.contact_parameters.body) > 0: # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. W_f_L_terrain, contact_state_derivative = js.contact.link_contact_forces( model=model, data=data, link_forces=f_L, joint_torques=joint_torques, ) W_f_L_total = f_L + W_f_L_terrain # Update the contact state data. This is necessary only for the contact models # that require propagation and integration of contact state. contact_state = model.contact_model.update_contact_state(contact_state_derivative) # Store the link forces in a references object. references = js.references.JaxSimModelReferences.build( model=model, data=data, velocity_representation=data.velocity_representation, link_forces=W_f_L_total, ) # Compute forward dynamics. # # - Joint accelerations: s̈ ∈ ℝⁿ # - Base acceleration: v̇_WB ∈ ℝ⁶ # # Note that ABA returns the base acceleration in the velocity representation # stored in the `data` object. v̇_WB, = js.model.forward_dynamics_aba( model=model, data=data, joint_forces=joint_torques, link_forces=references.link_forces(model=model, data=data), ) return v̇_WB, , contact_state
[docs] @jax.jit @js.common.named_scope def system_position_dynamics( data: js.data.JaxSimModelData, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> tuple[jtp.Vector, jtp.Vector, jtp.Vector]: r""" Compute the dynamics of the system position. Args: data: The data of the considered model. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient for adjusting the quaternion norm. Returns: A tuple containing the derivative of the base position, the derivative of the base quaternion, and the derivative of the joint positions. Note: In inertial-fixed representation, the linear component of the base velocity is not the derivative of the base position. In fact, the base velocity is defined as: :math:`{} ^W v_{W, B} = \begin{bmatrix} {} ^W \dot{p}_B S({} ^W \omega_{W, B}) {} ^W p _B\\ {} ^W \omega_{W, B} \end{bmatrix}`. Where :math:`S(\cdot)` is the skew-symmetric matrix operator. """ = data.joint_velocities W_Q_B = data.base_orientation W_ω_WB = data.base_velocity[3:6] W_ṗ_B = data.base_velocity[0:3] + Skew.wedge(W_ω_WB) @ data.base_position W_Q̇_B = Quaternion.derivative( quaternion=W_Q_B, omega=W_ω_WB, omega_in_body_fixed=False, K=baumgarte_quaternion_regularization, ).squeeze() return W_ṗ_B, W_Q̇_B,
[docs] @jax.jit @js.common.named_scope def system_dynamics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, link_forces: jtp.Vector | None = None, joint_torques: jtp.Vector | None = None, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> dict[str, jtp.Vector]: """ Compute the dynamics of the system. Args: model: The model to consider. data: The data of the considered model. link_forces: The 6D forces to apply to the links expressed in the frame corresponding to the velocity representation of `data`. joint_torques: The joint torques acting on the joints. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient used to adjust the norm of the quaternion (only used in integrators not operating on the SO(3) manifold). Returns: A dictionary containing the derivatives of the base position, the base quaternion, the joint positions, the base linear velocity, the base angular velocity, and the joint velocities. """ with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): W_v̇_WB, , contact_state_derivative = system_acceleration( model=model, data=data, joint_torques=joint_torques, link_forces=link_forces, ) W_ṗ_B, W_Q̇_B, = system_position_dynamics( data=data, baumgarte_quaternion_regularization=baumgarte_quaternion_regularization, ) return dict( base_position=W_ṗ_B, base_quaternion=W_Q̇_B, joint_positions=, base_linear_velocity=W_v̇_WB[0:3], base_angular_velocity=W_v̇_WB[3:6], joint_velocities=, contact_state=contact_state_derivative, )