Source code for jaxsim.api.data

from __future__ import annotations

import dataclasses
import functools
from collections.abc import Sequence

try:
    from typing import Self, override
except ImportError:
    from typing_extensions import override, Self

import jax
import jax.numpy as jnp
import jax.scipy.spatial.transform
import jax_dataclasses

import jaxsim.api as js
import jaxsim.math
import jaxsim.rbda
import jaxsim.typing as jtp

from . import common
from .common import VelRepr


@jax_dataclasses.pytree_dataclass
class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
    """
    Class storing the state of the physics model dynamics.

    The base velocities are always stored in inertial-fixed representation;
    they are converted to the active velocity representation on access.

    Attributes:
        joint_positions: The vector of joint positions.
        joint_velocities: The vector of joint velocities.
        base_position: The 3D position of the base link.
        base_quaternion: The quaternion defining the orientation of the base link.
        base_linear_velocity:
            The linear velocity of the base link in inertial-fixed representation.
        base_angular_velocity:
            The angular velocity of the base link in inertial-fixed representation.
        base_transform: The base transform.
        joint_transforms: The joint transforms.
        link_transforms: The link transforms.
        link_velocities: The link velocities in inertial-fixed representation.
    """

    # Joint state
    _joint_positions: jtp.Vector
    _joint_velocities: jtp.Vector

    # Base state
    _base_quaternion: jtp.Vector
    _base_linear_velocity: jtp.Vector
    _base_angular_velocity: jtp.Vector
    _base_position: jtp.Vector

    # Cached computations.
    # They are filled by `build` and refreshed by `replace`.
    _base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
    _joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
    _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
    _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)

    # Extended state for soft and rigid contact models.
    contact_state: dict[str, jtp.Array] = dataclasses.field(default=None)

    @staticmethod
    def build(
        model: js.model.JaxSimModel,
        base_position: jtp.VectorLike | None = None,
        base_quaternion: jtp.VectorLike | None = None,
        joint_positions: jtp.VectorLike | None = None,
        base_linear_velocity: jtp.VectorLike | None = None,
        base_angular_velocity: jtp.VectorLike | None = None,
        joint_velocities: jtp.VectorLike | None = None,
        contact_state: dict[str, jtp.Array] | None = None,
        velocity_representation: VelRepr = VelRepr.Mixed,
    ) -> JaxSimModelData:
        """
        Create a `JaxSimModelData` object with the given state.

        Args:
            model: The model for which to create the state.
            base_position: The base position (zeros if not provided).
            base_quaternion:
                The base orientation as a quaternion
                (``[1, 0, 0, 0]`` if not provided).
            joint_positions: The joint positions (zeros if not provided).
            base_linear_velocity:
                The base linear velocity in the selected representation
                (zeros if not provided).
            base_angular_velocity:
                The base angular velocity in the selected representation
                (zeros if not provided).
            joint_velocities: The joint velocities (zeros if not provided).
            contact_state:
                The optional contact state. The passed dictionary is copied
                and never modified in place.
            velocity_representation:
                The velocity representation to use. It defaults to mixed if
                not provided.

        Returns:
            A `JaxSimModelData` initialized with the given state.

        Raises:
            ValueError: If the built state is not compatible with the model.
        """

        base_position = jnp.array(
            base_position if base_position is not None else jnp.zeros(3),
            dtype=float,
        ).squeeze()

        base_quaternion = jnp.array(
            (
                base_quaternion
                if base_quaternion is not None
                else jnp.array([1.0, 0, 0, 0])
            ),
            dtype=float,
        ).squeeze()

        base_linear_velocity = jnp.array(
            base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),
            dtype=float,
        ).squeeze()

        base_angular_velocity = jnp.array(
            (
                base_angular_velocity
                if base_angular_velocity is not None
                else jnp.zeros(3)
            ),
            dtype=float,
        ).squeeze()

        # The joint quantities are kept 1D also for single-DoF models.
        joint_positions = jnp.atleast_1d(
            jnp.array(
                (
                    joint_positions
                    if joint_positions is not None
                    else jnp.zeros(model.dofs())
                ),
                dtype=float,
            ).squeeze()
        )

        joint_velocities = jnp.atleast_1d(
            jnp.array(
                (
                    joint_velocities
                    if joint_velocities is not None
                    else jnp.zeros(model.dofs())
                ),
                dtype=float,
            ).squeeze()
        )

        W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
            translation=base_position, quaternion=base_quaternion
        )

        # Convert the input base velocity to the inertial-fixed representation
        # used for storage.
        W_v_WB = JaxSimModelData.other_representation_to_inertial(
            array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
            other_representation=velocity_representation,
            transform=W_H_B,
            is_force=False,
        ).astype(float)

        joint_transforms = model.kin_dyn_parameters.joint_transforms(
            joint_positions=joint_positions, base_transform=W_H_B
        )

        link_transforms, link_velocities_inertial = (
            jaxsim.rbda.forward_kinematics_model(
                model=model,
                base_position=base_position,
                base_quaternion=base_quaternion,
                joint_positions=joint_positions,
                base_linear_velocity_inertial=W_v_WB[0:3],
                base_angular_velocity_inertial=W_v_WB[3:6],
                joint_velocities=joint_velocities,
            )
        )

        # Work on a private copy so that the dictionary owned by the caller
        # is never altered by the `setdefault` below.
        contact_state = dict(contact_state) if contact_state else {}

        if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
            contact_state.setdefault(
                "tangential_deformation",
                jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
            )

        model_data = JaxSimModelData(
            velocity_representation=velocity_representation,
            _base_quaternion=base_quaternion,
            _base_position=base_position,
            _joint_positions=joint_positions,
            _base_linear_velocity=W_v_WB[0:3],
            _base_angular_velocity=W_v_WB[3:6],
            _joint_velocities=joint_velocities,
            _base_transform=W_H_B,
            _joint_transforms=joint_transforms,
            _link_transforms=link_transforms,
            _link_velocities=link_velocities_inertial,
            contact_state=contact_state,
        )

        if not model_data.valid(model=model):
            raise ValueError(
                "The built state is not compatible with the model.", model_data
            )

        return model_data

    @staticmethod
    def zero(
        model: js.model.JaxSimModel,
        velocity_representation: VelRepr = VelRepr.Mixed,
    ) -> JaxSimModelData:
        """
        Create a `JaxSimModelData` object with zero state.

        Args:
            model: The model for which to create the state.
            velocity_representation:
                The velocity representation to use. It defaults to mixed if
                not provided.

        Returns:
            A `JaxSimModelData` initialized with zero state.
        """

        return JaxSimModelData.build(
            model=model, velocity_representation=velocity_representation
        )

    # ==================
    # Extract quantities
    # ==================

    @property
    def joint_positions(self) -> jtp.Vector:
        """
        Get the joint positions.

        Returns:
            The joint positions.
        """

        return self._joint_positions

    @property
    def joint_velocities(self) -> jtp.Vector:
        """
        Get the joint velocities.

        Returns:
            The joint velocities.
        """

        return self._joint_velocities

    @property
    def base_quaternion(self) -> jtp.Vector:
        """
        Get the base quaternion.

        Returns:
            The base quaternion as stored in the state (not re-normalized).
        """

        return self._base_quaternion

    @property
    def base_position(self) -> jtp.Vector:
        """
        Get the base position.

        Returns:
            The base position.
        """

        return self._base_position

    @property
    def base_orientation(self) -> jtp.Vector:
        """
        Get the base orientation as a normalized quaternion.

        Returns:
            The base quaternion scaled to unit norm, with the same component
            ordering as `base_quaternion`.
        """

        # Extract the base quaternion.
        W_Q_B = self.base_quaternion

        # Always normalize the quaternion to avoid numerical issues.
        # If the active scheme does not integrate the quaternion on its manifold,
        # we introduce a Baumgarte stabilization to let the quaternion converge to
        # a unit quaternion. In this case, it is not guaranteed that the quaternion
        # stored in the state is a unit quaternion.
        norm = jaxsim.math.safe_norm(W_Q_B, axis=-1, keepdims=True)

        # The epsilon is added only where the norm is exactly zero, so that the
        # division is always well-defined.
        W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

        return W_Q_B

    @property
    def base_velocity(self) -> jtp.Vector:
        """
        Get the base 6D velocity.

        Returns:
            The base 6D velocity in the active representation.
        """

        # The stored velocity is expressed in inertial-fixed representation.
        W_v_WB = jnp.concatenate(
            [self._base_linear_velocity, self._base_angular_velocity], axis=-1
        )

        W_H_B = self._base_transform

        return (
            JaxSimModelData.inertial_to_other_representation(
                array=W_v_WB,
                other_representation=self.velocity_representation,
                transform=W_H_B,
                is_force=False,
            )
            .squeeze()
            .astype(float)
        )

    @property
    def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
        r"""
        Get the generalized position
        :math:`\mathbf{q} = ({}^W \mathbf{H}_B, \mathbf{s}) \in \text{SE}(3) \times \mathbb{R}^n`.

        Returns:
            A tuple containing the base transform and the joint positions.
        """

        return self._base_transform, self.joint_positions

    @property
    def generalized_velocity(self) -> jtp.Vector:
        r"""
        Get the generalized velocity.

        :math:`\boldsymbol{\nu} = (\boldsymbol{v}_{W,B};\, \boldsymbol{\omega}_{W,B};\, \dot{\mathbf{s}}) \in \mathbb{R}^{6+n}`

        Returns:
            The generalized velocity in the active representation.
        """

        return (
            jnp.hstack([self.base_velocity, self.joint_velocities])
            .squeeze()
            .astype(float)
        )

    @property
    def base_transform(self) -> jtp.Matrix:
        """
        Get the base transform.

        Returns:
            The base transform.
        """

        return self._base_transform

    # ================
    # Store quantities
    # ================

    @js.common.named_scope
    @jax.jit
    def reset_base_quaternion(
        self, model: js.model.JaxSimModel, base_quaternion: jtp.VectorLike
    ) -> Self:
        """
        Reset the base quaternion.

        Args:
            model: The JaxSim model to use.
            base_quaternion: The base orientation as a quaternion.

        Returns:
            The updated `JaxSimModelData` object.
        """

        W_Q_B = jnp.array(base_quaternion, dtype=float)

        # Keep the reduced axis so that the division broadcasts correctly
        # regardless of leading batch dimensions.
        norm = jaxsim.math.safe_norm(W_Q_B, axis=-1, keepdims=True)
        W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0))

        return self.replace(model=model, base_quaternion=W_Q_B)

    @js.common.named_scope
    @jax.jit
    def reset_base_pose(
        self, model: js.model.JaxSimModel, base_pose: jtp.MatrixLike
    ) -> Self:
        """
        Reset the base pose.

        Args:
            model: The JaxSim model to use.
            base_pose: The base pose as an SE(3) matrix.

        Returns:
            The updated `JaxSimModelData` object.
        """

        base_pose = jnp.array(base_pose)

        # Split the homogeneous matrix into translation and rotation.
        W_p_B = base_pose[0:3, 3]
        W_Q_B = jaxsim.math.Quaternion.from_dcm(dcm=base_pose[0:3, 0:3])

        return self.replace(
            model=model,
            base_position=W_p_B,
            base_quaternion=W_Q_B,
        )

    @override
    def replace(
        self,
        model: js.model.JaxSimModel,
        joint_positions: jtp.Vector | None = None,
        joint_velocities: jtp.Vector | None = None,
        base_quaternion: jtp.Vector | None = None,
        base_linear_velocity: jtp.Vector | None = None,
        base_angular_velocity: jtp.Vector | None = None,
        base_position: jtp.Vector | None = None,
        *,
        contact_state: dict[str, jtp.Array] | None = None,
        validate: bool = False,
    ) -> Self:
        """
        Replace the attributes of the `JaxSimModelData` object.

        Every quantity left to `None` keeps its current value. The cached
        base, joint and link kinematics are always recomputed from the
        resulting state.

        Args:
            model: The JaxSim model to use.
            joint_positions: The new joint positions.
            joint_velocities: The new joint velocities.
            base_quaternion:
                The new base quaternion. It is normalized before being stored.
            base_linear_velocity:
                The new base linear velocity in the active representation.
            base_angular_velocity:
                The new base angular velocity in the active representation.
            base_position: The new base position.
            contact_state:
                The new contact state. The passed dictionary is copied before
                any default entry is added to it.
            validate: Forwarded to the parent `replace` implementation.

        Returns:
            The updated `JaxSimModelData` object.
        """

        # Extract the batch size.
        batch_size = (
            self._base_transform.shape[0] if self._base_transform.ndim > 2 else 1
        )

        if joint_positions is None:
            joint_positions = self.joint_positions

        if joint_velocities is None:
            joint_velocities = self.joint_velocities

        if base_quaternion is None:
            base_quaternion = self.base_quaternion

        if base_position is None:
            base_position = self.base_position

        if contact_state is None:
            contact_state = self.contact_state

        if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts):
            # Copy before inserting defaults: neither the dictionary stored in
            # `self` nor the one owned by the caller must be altered in place.
            contact_state = dict(contact_state) if contact_state is not None else {}
            contact_state.setdefault(
                "tangential_deformation",
                jnp.zeros_like(model.kin_dyn_parameters.contact_parameters.point),
            )

        # Accept any array-like input before calling array methods on it.
        joint_positions = jnp.asarray(joint_positions)
        joint_velocities = jnp.asarray(joint_velocities)
        base_quaternion = jnp.asarray(base_quaternion)
        base_position = jnp.asarray(base_position)

        # Normalize the quaternion to avoid numerical issues.
        base_quaternion_norm = jaxsim.math.safe_norm(
            base_quaternion, axis=-1, keepdims=True
        )
        base_quaternion = base_quaternion / jnp.where(
            base_quaternion_norm == 0, 1.0, base_quaternion_norm
        )

        joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
        joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
        base_quaternion = jnp.atleast_1d(base_quaternion.squeeze()).astype(float)
        base_position = jnp.atleast_1d(base_position.squeeze()).astype(float)

        base_transform = jaxsim.math.Transform.from_quaternion_and_translation(
            translation=base_position, quaternion=base_quaternion
        )

        # The kinematics is always computed in batched form; the unbatched
        # case is handled as a batch of size one and reshaped afterwards.
        joint_transforms = jax.vmap(model.kin_dyn_parameters.joint_transforms)(
            joint_positions=jnp.broadcast_to(
                joint_positions, (batch_size, model.dofs())
            ),
            base_transform=jnp.broadcast_to(base_transform, (batch_size, 4, 4)),
        )

        if base_linear_velocity is None and base_angular_velocity is None:
            # Nothing to convert: reuse the stored inertial-fixed velocities.
            base_linear_velocity_inertial = self._base_linear_velocity
            base_angular_velocity_inertial = self._base_angular_velocity
        else:
            # Complete the missing half with the current velocity expressed in
            # the active representation.
            if base_linear_velocity is None:
                base_linear_velocity = self.base_velocity[:3]
            if base_angular_velocity is None:
                base_angular_velocity = self.base_velocity[3:]

            base_linear_velocity = jnp.atleast_1d(
                jnp.asarray(base_linear_velocity).squeeze()
            )
            base_angular_velocity = jnp.atleast_1d(
                jnp.asarray(base_angular_velocity).squeeze()
            )

            W_v_WB = JaxSimModelData.other_representation_to_inertial(
                array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
                other_representation=self.velocity_representation,
                transform=base_transform,
                is_force=False,
            ).astype(float)

            base_linear_velocity_inertial, base_angular_velocity_inertial = (
                W_v_WB[..., :3],
                W_v_WB[..., 3:],
            )

        link_transforms, link_velocities = jax.vmap(
            jaxsim.rbda.forward_kinematics_model, in_axes=(None,)
        )(
            model,
            base_position=jnp.broadcast_to(base_position, (batch_size, 3)),
            base_quaternion=jnp.broadcast_to(base_quaternion, (batch_size, 4)),
            joint_positions=jnp.broadcast_to(
                joint_positions, (batch_size, model.dofs())
            ),
            joint_velocities=jnp.broadcast_to(
                joint_velocities, (batch_size, model.dofs())
            ),
            base_linear_velocity_inertial=jnp.broadcast_to(
                base_linear_velocity_inertial, (batch_size, 3)
            ),
            base_angular_velocity_inertial=jnp.broadcast_to(
                base_angular_velocity_inertial, (batch_size, 3)
            ),
        )

        # Adjust the output shapes.
        if batch_size == 1:
            link_transforms = link_transforms.reshape(self._link_transforms.shape)
            link_velocities = link_velocities.reshape(self._link_velocities.shape)
            joint_transforms = joint_transforms.reshape(self._joint_transforms.shape)

        return super().replace(
            _joint_positions=joint_positions,
            _joint_velocities=joint_velocities,
            _base_quaternion=base_quaternion,
            _base_linear_velocity=base_linear_velocity_inertial,
            _base_angular_velocity=base_angular_velocity_inertial,
            _base_position=base_position,
            _base_transform=base_transform,
            _joint_transforms=joint_transforms,
            _link_transforms=link_transforms,
            _link_velocities=link_velocities,
            contact_state=contact_state,
            validate=validate,
        )

    def valid(self, model: js.model.JaxSimModel) -> bool:
        """
        Check if the `JaxSimModelData` is valid for a given `JaxSimModel`.

        Only the shapes of the joint and base state are checked; the cached
        kinematics and the contact state are not inspected.

        Args:
            model: The `JaxSimModel` to validate the `JaxSimModelData` against.

        Returns:
            `True` if the `JaxSimModelData` is valid for the given model,
            `False` otherwise.
        """

        dofs = model.dofs()

        expected_shapes = (
            (self._joint_positions, (dofs,)),
            (self._joint_velocities, (dofs,)),
            (self._base_position, (3,)),
            (self._base_quaternion, (4,)),
            (self._base_linear_velocity, (3,)),
            (self._base_angular_velocity, (3,)),
        )

        return all(array.shape == shape for array, shape in expected_shapes)
@functools.partial(
    jax.jit, static_argnames=["velocity_representation", "base_rpy_seq"]
)
def random_model_data(
    model: js.model.JaxSimModel,
    *,
    key: jax.Array | None = None,
    velocity_representation: VelRepr | None = None,
    base_pos_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = ((-1, -1, 0.5), 1.0),
    base_rpy_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-jnp.pi, jnp.pi),
    base_rpy_seq: str = "XYZ",
    joint_pos_bounds: (
        tuple[
            jtp.FloatLike | Sequence[jtp.FloatLike],
            jtp.FloatLike | Sequence[jtp.FloatLike],
        ]
        | None
    ) = None,
    base_vel_lin_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    base_vel_ang_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
    joint_vel_bounds: tuple[
        jtp.FloatLike | Sequence[jtp.FloatLike],
        jtp.FloatLike | Sequence[jtp.FloatLike],
    ] = (-1.0, 1.0),
) -> JaxSimModelData:
    """
    Randomly generate a `JaxSimModelData` object.

    Each bound is a ``(lower, upper)`` pair whose two entries are converted to
    arrays independently, so a scalar can be paired with a per-component
    sequence.

    Args:
        model: The target model for the random data.
        key: The random key. A key seeded with 0 is used if not provided.
        velocity_representation:
            The velocity representation to use. If not provided, the default
            of `JaxSimModelData.build` is used.
        base_pos_bounds: The bounds for the base position.
        base_rpy_bounds:
            The bounds for the euler angles used to build the base orientation.
        base_rpy_seq:
            The sequence of axes for rotation (using `Rotation` from scipy).
        joint_pos_bounds:
            The bounds for the joint positions (reading the joint limits if None).
        base_vel_lin_bounds: The bounds for the base linear velocity.
        base_vel_ang_bounds: The bounds for the base angular velocity.
        joint_vel_bounds: The bounds for the joint velocities.

    Returns:
        A `JaxSimModelData` object with random data.
    """

    key = key if key is not None else jax.random.PRNGKey(seed=0)

    # One independent sub-key per sampled quantity.
    k1, k2, k3, k4, k5, k6 = jax.random.split(key, num=6)

    p_min = jnp.array(base_pos_bounds[0], dtype=float)
    p_max = jnp.array(base_pos_bounds[1], dtype=float)
    rpy_min = jnp.array(base_rpy_bounds[0], dtype=float)
    rpy_max = jnp.array(base_rpy_bounds[1], dtype=float)
    v_min = jnp.array(base_vel_lin_bounds[0], dtype=float)
    v_max = jnp.array(base_vel_lin_bounds[1], dtype=float)
    w_min = jnp.array(base_vel_ang_bounds[0], dtype=float)
    w_max = jnp.array(base_vel_ang_bounds[1], dtype=float)
    sd_min = jnp.array(joint_vel_bounds[0], dtype=float)
    sd_max = jnp.array(joint_vel_bounds[1], dtype=float)

    base_position = jax.random.uniform(key=k1, shape=(3,), minval=p_min, maxval=p_max)

    # Sample random euler angles and convert the resulting rotation to a
    # quaternion with the ordering produced by `Quaternion.to_wxyz`.
    base_quaternion = jaxsim.math.Quaternion.to_wxyz(
        xyzw=jax.scipy.spatial.transform.Rotation.from_euler(
            seq=base_rpy_seq,
            angles=jax.random.uniform(
                key=k2, shape=(3,), minval=rpy_min, maxval=rpy_max
            ),
        ).as_quat()
    )

    # Quantities left to None fall back to the defaults of `build`.
    joint_positions = None
    joint_velocities = None
    base_linear_velocity = None
    base_angular_velocity = None

    if model.number_of_joints() > 0:

        if joint_pos_bounds is None:
            # No explicit bounds: delegate the sampling to the joint API.
            joint_positions = js.joint.random_joint_positions(model=model, key=k3)
        else:
            # Convert each bound on its own, so that the lower and upper
            # bounds are not required to share the same shape.
            s_min = jnp.array(joint_pos_bounds[0], dtype=float)
            s_max = jnp.array(joint_pos_bounds[1], dtype=float)

            joint_positions = jax.random.uniform(
                key=k3, shape=(model.dofs(),), minval=s_min, maxval=s_max
            )

        joint_velocities = jax.random.uniform(
            key=k4, shape=(model.dofs(),), minval=sd_min, maxval=sd_max
        )

    if model.floating_base():

        base_linear_velocity = jax.random.uniform(
            key=k5, shape=(3,), minval=v_min, maxval=v_max
        )

        base_angular_velocity = jax.random.uniform(
            key=k6, shape=(3,), minval=w_min, maxval=w_max
        )

    return JaxSimModelData.build(
        model=model,
        base_position=base_position,
        base_quaternion=base_quaternion,
        joint_positions=joint_positions,
        joint_velocities=joint_velocities,
        base_linear_velocity=base_linear_velocity,
        base_angular_velocity=base_angular_velocity,
        **(
            {"velocity_representation": velocity_representation}
            if velocity_representation is not None
            else {}
        ),
    )