Source code for jaxsim.math.rotation

import jax.numpy as jnp
import jaxlie

import jaxsim.typing as jtp

from .skew import Skew
from .utils import safe_norm


class Rotation:
    """
    Collection of static helpers that build 3D rotation matrices.
    """

    @staticmethod
    def x(theta: jtp.Float) -> jtp.Matrix:
        """
        Build the matrix of an elementary rotation about the X-axis.

        Args:
            theta: Rotation angle in radians.

        Returns:
            The 3D rotation matrix.
        """

        rotation = jaxlie.SO3.from_x_radians(theta=theta)
        return rotation.as_matrix()

    @staticmethod
    def y(theta: jtp.Float) -> jtp.Matrix:
        """
        Build the matrix of an elementary rotation about the Y-axis.

        Args:
            theta: Rotation angle in radians.

        Returns:
            The 3D rotation matrix.
        """

        rotation = jaxlie.SO3.from_y_radians(theta=theta)
        return rotation.as_matrix()

    @staticmethod
    def z(theta: jtp.Float) -> jtp.Matrix:
        """
        Build the matrix of an elementary rotation about the Z-axis.

        Args:
            theta: Rotation angle in radians.

        Returns:
            The 3D rotation matrix.
        """

        rotation = jaxlie.SO3.from_z_radians(theta=theta)
        return rotation.as_matrix()

    @staticmethod
    def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:
        """
        Convert an axis-angle (rotation vector) into a rotation matrix.

        The norm of the input is used as the rotation angle in radians and
        its direction as the rotation axis. When the computed angle is zero,
        the result is the identity matrix.

        Args:
            vector: Axis-angle representation of the rotation as a 3D vector.

        Returns:
            The SO(3) rotation matrix.
        """

        rotvec = vector.squeeze()
        angle = safe_norm(rotvec)

        # Swap a zero angle for 1 so the normalization never divides by zero.
        divisor = jnp.where(angle == 0, 1.0, angle)

        # Unit rotation axis, stacked into a column so that `axis @ axis.T`
        # is an outer product.
        axis = rotvec / divisor
        axis = jnp.vstack(axis.squeeze())

        sin_angle = jnp.sin(angle)
        cos_angle = jnp.cos(angle)

        # 2 * sin^2(angle / 2) is the half-angle form of (1 - cos(angle)).
        one_minus_cos = 2 * jnp.sin(angle / 2.0) ** 2

        # Rodrigues' formula with the sign of the skew-symmetric term flipped:
        # this is the transpose of the desired matrix, undone on return.
        transposed = (
            cos_angle * jnp.eye(3)
            - sin_angle * Skew.wedge(axis)
            + one_minus_cos * axis @ axis.T
        )

        return transposed.transpose()