Source code for jaxsim.rbda.collidable_points
import jax
import jax.numpy as jnp
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Skew
[docs]
def collidable_points_pos_vel(
    model: js.model.JaxSimModel,
    *,
    link_transforms: jtp.Matrix,
    link_velocities: jtp.Matrix,
) -> tuple[jtp.Matrix, jtp.Matrix]:
    """
    Compute the position and linear velocity of the enabled collidable points in the world frame.

    Args:
        model: The model to consider.
        link_transforms: The transforms from the world frame to each link.
            They are indexed by link and applied to homogeneous points.
        link_velocities: The linear and angular velocities of each link.
            They are indexed by link and used as 6D vectors.

    Returns:
        A tuple containing the position and linear velocity of the enabled
        collidable points. Both arrays have one row per enabled point, i.e.
        shape ``(n, 3)``. When no collidable point is enabled, both arrays
        are empty with shape ``(0, 3)``.
    """

    contact_parameters = model.kin_dyn_parameters.contact_parameters

    # Get the indices of the enabled collidable points.
    indices_of_enabled_collidable_points = (
        contact_parameters.indices_of_enabled_collidable_points
    )

    # Nothing to compute: return empty outputs whose trailing dimension matches
    # the regular (n, 3) outputs, so that callers can slice columns uniformly.
    # The explicit length check is kept on purpose since the indices may be
    # stored in an array, whose truthiness is not a reliable emptiness test.
    if len(indices_of_enabled_collidable_points) == 0:
        no_points = jnp.empty((0, 3)).astype(float)
        return no_points, no_points

    # Index of the parent link of each enabled collidable point.
    parent_link_idx_of_enabled_collidable_points = jnp.array(
        contact_parameters.body, dtype=int
    )[indices_of_enabled_collidable_points]

    # Position of each enabled collidable point, expressed in its parent link frame.
    L_p_Ci = contact_parameters.point[indices_of_enabled_collidable_points]

    def process_point_kinematics(
        Li_p_C: jtp.Vector, parent_body: jtp.Int
    ) -> tuple[jtp.Vector, jtp.Vector]:
        """Return the world position and linear velocity of a single point."""

        # Compute the position of the collidable point by applying the parent
        # link transform to the homogeneous coordinates of the point.
        W_p_Ci = (link_transforms[parent_body] @ jnp.hstack([Li_p_C, 1]))[0:3]

        # Compute the linear part of the mixed velocity Ci[W]_v_{W,Ci}.
        # The [I, -S(W_p_Ci)] matrix maps the 6D link velocity to v - p x ω.
        # NOTE(review): this assumes the link velocity stacks the linear part
        # first and is expressed w.r.t. the world origin -- confirm with callers.
        CW_vl_WCi = (
            jnp.block([jnp.eye(3), -Skew.wedge(vector=W_p_Ci).squeeze()])
            @ link_velocities[parent_body].squeeze()
        )

        return W_p_Ci, CW_vl_WCi

    # Process all the collidable points in parallel.
    W_p_Ci, CW_vl_WC = jax.vmap(process_point_kinematics)(
        L_p_Ci,
        parent_link_idx_of_enabled_collidable_points,
    )

    return W_p_Ci, CW_vl_WC