Source code for jaxsim.api.actuation_model

import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.typing as jtp


def compute_resultant_torques(
    model: js.model.JaxSimModel,
    data: js.data.JaxSimModelData,
    *,
    joint_force_references: jtp.Vector | None = None,
) -> jtp.Vector:
    """
    Compute the resultant torques acting on the joints.

    The resultant torque is the sum of three contributions:

    - the joint force references (zeros when not provided);
    - a spring-damper torque that pushes joints back inside their position
      limits; it is non-zero only for joints currently outside their limits;
    - a joint friction torque made of a static term ``kc * sign(ṡ)`` and a
      viscous term ``kv * ṡ``.

    Args:
        model: The model to consider.
        data: The data of the considered model.
        joint_force_references: The joint force references to apply.

    Returns:
        The resultant torques acting on the joints.
    """

    # Build joint torques if not provided. `asarray` lets callers pass any
    # array-like, and `squeeze` + `atleast_1d` accepts column vectors as well.
    τ_references = (
        jnp.atleast_1d(jnp.asarray(joint_force_references).squeeze())
        if joint_force_references is not None
        else jnp.zeros_like(data.joint_positions)
    ).astype(float)

    τ_position_limit = jnp.zeros_like(τ_references)
    τ_friction = jnp.zeros_like(τ_references)

    if model.dofs() > 0:

        joint_parameters = model.kin_dyn_parameters.joint_parameters
        joint_positions = data.joint_positions
        joint_velocities = data.joint_velocities

        # ====================
        # Enforce joint limits
        # ====================

        # Stiffness and damper parameters for the joint position limits.
        k_j = jnp.asarray(joint_parameters.position_limit_spring).astype(float)
        d_j = jnp.asarray(joint_parameters.position_limit_damper).astype(float)

        # Signed limit violation: negative below the lower limit, positive
        # above the upper limit, zero inside the admissible range.
        violation = jnp.clip(
            joint_positions - joint_parameters.position_limits_min, max=0.0
        ) + jnp.clip(joint_positions - joint_parameters.position_limits_max, min=0.0)

        # Spring term pushes the joint back towards the admissible range.
        # The damper is masked so it acts only on joints that violate a limit;
        # it always opposes the joint velocity and therefore dissipates energy.
        # (Diagonal gains are applied element-wise instead of via `diag(k) @ x`.)
        τ_position_limit = -k_j * violation - jnp.where(
            violation != 0.0, d_j * joint_velocities, 0.0
        )

        # ====================
        # Joint friction model
        # ====================

        # Static and viscous joint friction parameters.
        kc = jnp.asarray(joint_parameters.friction_static).astype(float)
        kv = jnp.asarray(joint_parameters.friction_viscous).astype(float)

        # Compute the joint friction torque.
        τ_friction = -(kc * jnp.sign(joint_velocities) + kv * joint_velocities)

    # ===============================
    # Compute the total joint forces.
    # ===============================

    τ_total = τ_references + τ_friction + τ_position_limit

    return τ_total