From dc458fa74862f85cd90182f10fcbaf311d284bcc Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 8 Jan 2025 14:14:57 +0100 Subject: [PATCH] Cache `base_velocity`, `base_orientation` and `base_transform` --- examples/jaxsim_for_robot_controllers.ipynb | 8 +- src/jaxsim/api/com.py | 8 +- src/jaxsim/api/contact.py | 8 +- src/jaxsim/api/data.py | 168 ++++++++++---------- src/jaxsim/api/frame.py | 8 +- src/jaxsim/api/link.py | 8 +- src/jaxsim/api/model.py | 111 +++++++------ src/jaxsim/api/ode.py | 6 +- src/jaxsim/mujoco/utils.py | 8 +- src/jaxsim/parsers/rod/parser.py | 5 +- src/jaxsim/rbda/contacts/visco_elastic.py | 6 +- tests/test_api_contact.py | 6 +- tests/test_api_data.py | 16 +- tests/test_api_frame.py | 8 +- tests/test_api_link.py | 8 +- tests/test_api_model.py | 113 +++++++------ tests/test_automatic_differentiation.py | 28 ++-- tests/test_simulations.py | 6 +- tests/utils_idyntree.py | 6 +- 19 files changed, 274 insertions(+), 261 deletions(-) diff --git a/examples/jaxsim_for_robot_controllers.ipynb b/examples/jaxsim_for_robot_controllers.ipynb index 6cd884ae9..118a0e080 100644 --- a/examples/jaxsim_for_robot_controllers.ipynb +++ b/examples/jaxsim_for_robot_controllers.ipynb @@ -256,7 +256,7 @@ "\n", " # Update the MuJoCo data.\n", " mj_model_helper.set_joint_positions(\n", - " positions=data.joint_positions(), joint_names=model.joint_names()\n", + " positions=data.joint_positions, joint_names=model.joint_names()\n", " )\n", "\n", " # Record a new video frame.\n", @@ -345,8 +345,8 @@ " Mss = js.model.free_floating_mass_matrix(model=model, data=data)[6:, 6:]\n", "\n", " # Get the current joint positions and velocities.\n", - " s = data.joint_positions()\n", - " ṡ = data.joint_velocities()\n", + " s = data.joint_positions\n", + " ṡ = data.joint_velocities\n", "\n", " # Compute the actuated joint torques.\n", " s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)\n", @@ -405,7 +405,7 @@ "\n", " # Update the MuJoCo data.\n", " mj_model_helper.set_joint_positions(\n", - " positions=data.joint_positions(), joint_names=model.joint_names()\n", + " positions=data.joint_positions, joint_names=model.joint_names()\n", " )\n", "\n", " # Record a new video frame.\n", diff --git a/src/jaxsim/api/com.py b/src/jaxsim/api/com.py index fdc3fe8ce..7277acbdf 100644 --- a/src/jaxsim/api/com.py +++ b/src/jaxsim/api/com.py @@ -27,7 +27,7 @@ def com_position( m = js.model.total_mass(model=model) W_H_L = data.kyn_dyn.forward_kinematics - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B) def B_p̃_LCoM(i) -> jtp.Vector: @@ -134,7 +134,7 @@ def centroidal_momentum_jacobian( model=model, data=data, output_vel_repr=VelRepr.Body ) - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform B_H_W = jaxsim.math.Transform.inverse(W_H_B) W_p_CoM = com_position(model=model, data=data) @@ -172,7 +172,7 @@ def locked_centroidal_spatial_inertia( with data.switch_velocity_representation(VelRepr.Body): B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data) - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_p_CoM = com_position(model=model, data=data) match data.velocity_representation: @@ -411,7 +411,7 @@ def bias_momentum_derivative_term( case VelRepr.Body: GB_Xf_W = jaxsim.math.Adjoint.from_transform( - transform=data.base_transform().at[0:3].set(W_p_CoM) + transform=data.kyn_dyn.base_transform.at[0:3].set(W_p_CoM) ).T GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index f47b8068c..1f7c32418 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -636,9 +636,9 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) case VelRepr.Body: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity() + B_v_WB = data.kyn_dyn.base_velocity B_vx_WB = Cross.vx(B_v_WB) W_Ẋ_B = W_X_B @ B_vx_WB @@ -646,10 +646,10 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) case VelRepr.Mixed: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity() + BW_v_WB = data.kyn_dyn.base_velocity BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_vx_W_BW = Cross.vx(BW_v_W_BW) W_Ẋ_BW = W_X_BW @ BW_vx_W_BW diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 8a484a6c1..c691c9452 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -53,28 +53,17 @@ def __convert_attribute(self, value, name): "link_body_transforms", "collidable_point_positions", "collidable_point_velocities", + "base_transform", + "base_orientation", ]: return value - W_R_B = jaxsim.math.Quaternion.to_dcm( - self._data.state.physics_model.base_quaternion - ) - W_p_B = jnp.vstack(self._data.state.physics_model.base_position) - - W_H_B = jnp.vstack( - [ - jnp.block([W_R_B, W_p_B]), - jnp.array([0, 0, 0, 1]), - ] - ) + W_H_B = self._data.kyn_dyn.base_transform match name: case "jacobian_full_doubly_left": - if ( - self._data.velocity_representation - != self._kyn_dyn.velocity_representation - ): + if self._data.velocity_representation != VelRepr.Body: value = convert_jacobian( J=value, dofs=len(self._data.state.physics_model.joint_positions), @@ -83,10 +72,7 @@ def __convert_attribute(self, value, name): ) case "jacobian_derivative_full_doubly_left": - if ( - self._data.velocity_representation - != self._kyn_dyn.velocity_representation - ): + if self._data.velocity_representation != VelRepr.Body: value = convert_jacobian_derivative( Jd=value, dofs=len(self._data.state.physics_model.joint_positions), @@ -95,10 +81,7 @@ def __convert_attribute(self, value, name): ) case "mass_matrix": - if ( - self._data.velocity_representation - != self._kyn_dyn.velocity_representation - ): + if self._data.velocity_representation != VelRepr.Body: value = convert_mass_matrix( M=value, dofs=len(self._data.state.physics_model.joint_positions), @@ -106,6 +89,15 @@ def __convert_attribute(self, value, name): velocity_representation=self._data.velocity_representation, ) + case "base_velocity": + if self._data.velocity_representation != VelRepr.Inertial: + value = JaxSimModelData.inertial_to_other_representation( + array=value, + other_representation=self._data.velocity_representation, + transform=W_H_B, + is_force=False, + ) + case _: raise AttributeError( f"'{type(self._kyn_dyn).__name__}' object has no attribute '{name}'" @@ -127,8 +119,6 @@ def __setattr__(self, name, value): if name in ["_data", "_kyn_dyn"]: return super().__setattr__(name, value) - value = self.__convert_attribute(value=value, name=name) - # Push the update to JaxSimModelData. self._data._update_kyn_dyn(name, value) @@ -153,6 +143,12 @@ class KynDynComputation(common.ModelDataWithVelocityRepresentation): mass_matrix: jtp.Matrix + base_transform: jtp.Matrix + + base_orientation: jtp.Vector + + base_velocity: jtp.Vector + @jax_dataclasses.pytree_dataclass class JaxSimModelData(common.ModelDataWithVelocityRepresentation): @@ -173,7 +169,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): ) @property - def kyn_dyn(self): + def kyn_dyn(self) -> KynDynComputation | KynDynProxy: jaxsim.exceptions.raise_runtime_error_if( self._kyn_dyn_stale, @@ -457,6 +453,9 @@ def build( forward_kinematics=W_H_LL, collidable_point_positions=W_p_Ci, collidable_point_velocities=W_ṗ_Ci, + base_transform=base_transform, + base_orientation=jaxsim.math.Quaternion.to_dcm(base_quaternion), + base_velocity=v_WB, ) return JaxSimModelData( @@ -522,9 +521,8 @@ def base_position(self) -> jtp.Vector: return self.state.physics_model.base_position - @js.common.named_scope - @functools.partial(jax.jit, static_argnames=["dcm"]) - def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix: + @property + def base_orientation(self) -> jtp.Vector | jtp.Matrix: """ Get the base orientation. @@ -536,39 +534,8 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix """ # Extract the base quaternion. - W_Q_B = self.state.physics_model.base_quaternion.squeeze() - - # Always normalize the quaternion to avoid numerical issues. - # If the active scheme does not integrate the quaternion on its manifold, - # we introduce a Baumgarte stabilization to let the quaternion converge to - # a unit quaternion. In this case, it is not guaranteed that the quaternion - # stored in the state is a unit quaternion. - norm = jaxsim.math.safe_norm(W_Q_B) - W_Q_B = W_Q_B / (norm + jnp.finfo(float).eps * (norm == 0)) - - return (W_Q_B if not dcm else jaxsim.math.Quaternion.to_dcm(W_Q_B)).astype( - float - ) - - @js.common.named_scope - @jax.jit - def base_transform(self) -> jtp.Matrix: - """ - Get the base transform. - - Returns: - The base transform as an SE(3) matrix. - """ - - W_R_B = self.base_orientation(dcm=True) - W_p_B = jnp.vstack(self.base_position) - - return jnp.vstack( - [ - jnp.block([W_R_B, W_p_B]), - jnp.array([0, 0, 0, 1]), - ] - ) + W_Q_B = self.state.physics_model.base_quaternion + return W_Q_B @js.common.named_scope @jax.jit @@ -587,7 +554,7 @@ def base_velocity(self) -> jtp.Vector: ] ) - W_H_B = self.base_transform() + W_H_B = self.kyn_dyn.base_transform return JaxSimModelData.inertial_to_other_representation( array=W_v_WB, @@ -607,7 +574,7 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]: A tuple containing the base transform and the joint positions. """ - return self.base_transform(), self.joint_positions() + return self.kyn_dyn.base_transform, self.joint_positions @js.common.named_scope @jax.jit @@ -622,7 +589,7 @@ def generalized_velocity(self) -> jtp.Vector: """ return ( - jnp.hstack([self.base_velocity(), self.joint_velocities]) + jnp.hstack([self.kyn_dyn.base_velocity, self.joint_velocities]) .squeeze() .astype(float) ) @@ -831,7 +798,7 @@ def reset_base_linear_velocity( base_velocity=jnp.hstack( [ linear_velocity.squeeze(), - self.base_velocity()[3:6], + self.kyn_dyn.base_velocity[3:6], ] ), velocity_representation=velocity_representation, @@ -862,7 +829,7 @@ def reset_base_angular_velocity( return self.reset_base_velocity( base_velocity=jnp.hstack( [ - self.base_velocity()[0:3], + self.kyn_dyn.base_velocity[0:3], angular_velocity.squeeze(), ] ), @@ -897,10 +864,25 @@ def reset_base_velocity( else self.velocity_representation ) + # Recompute the base transform since we cannot rely on the cached value. + W_H_B = jnp.vstack( + [ + jnp.block( + [ + jaxsim.math.Quaternion.to_dcm( + self.state.physics_model.base_quaternion + ), + jnp.vstack(self.state.physics_model.base_position), + ] + ), + jnp.array([0, 0, 0, 1]), + ] + ) + W_v_WB = self.other_representation_to_inertial( array=jnp.atleast_1d(base_velocity.squeeze()).astype(float), other_representation=velocity_representation, - transform=self.base_transform(), + transform=W_H_B, is_force=False, ) @@ -931,11 +913,22 @@ def update_kyn_dyn( An instance of `JaxSimModelData` with the updated `kyn_dyn` attribute. """ - base_quaternion = self.base_orientation(dcm=False) - base_transform = self.base_transform() + base_quaternion = self.state.physics_model.base_quaternion joint_positions = self.joint_positions joint_velocities = self.joint_velocities + base_transform = jnp.vstack( + [ + jnp.block( + [ + jaxsim.math.Quaternion.to_dcm(base_quaternion), + jnp.vstack(self.state.physics_model.base_position), + ] + ), + jnp.array([0, 0, 0, 1]), + ] + ) + i_X_λ, S = model.kin_dyn_parameters.joint_transforms_and_motion_subspaces( joint_positions=joint_positions, base_transform=base_transform ) @@ -968,20 +961,24 @@ def update_kyn_dyn( joint_transforms=i_X_λ, ) - with self.switch_velocity_representation(VelRepr.Inertial) as data: - base_velocity = data.base_velocity() - - W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( - model=model, - base_position=base_transform[:3, 3], - base_quaternion=base_quaternion, - joint_positions=joint_positions, - base_linear_velocity=base_velocity[0:3], - base_angular_velocity=base_velocity[3:6], - joint_velocities=joint_velocities, - joint_transforms=i_X_λ, - motion_subspaces=S, - ) + base_velocity = jnp.hstack( + [ + self.state.physics_model.base_linear_velocity, + self.state.physics_model.base_angular_velocity, + ] + ) + + W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( + model=model, + base_position=base_transform[:3, 3], + base_quaternion=base_quaternion, + joint_positions=joint_positions, + base_linear_velocity=base_velocity[0:3], + base_angular_velocity=base_velocity[3:6], + joint_velocities=joint_velocities, + joint_transforms=i_X_λ, + motion_subspaces=S, + ) data = self.replace(_kyn_dyn_stale=jnp.array(0, dtype=bool)) @@ -994,6 +991,9 @@ def update_kyn_dyn( data.kyn_dyn.forward_kinematics = W_H_LL data.kyn_dyn.collidable_point_positions = W_p_Ci data.kyn_dyn.collidable_point_velocities = W_ṗ_Ci + data.kyn_dyn.base_transform = base_transform + data.kyn_dyn.base_orientation = jaxsim.math.Quaternion.to_dcm(base_quaternion) + data.kyn_dyn.base_velocity = base_velocity return data diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index 39c7092fd..aa7c7ce60 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -401,9 +401,9 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W) case VelRepr.Body: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_X_B = Adjoint.from_transform(transform=W_H_B) - B_v_WB = data.base_velocity() + B_v_WB = data.kyn_dyn.base_velocity B_vx_WB = Cross.vx(B_v_WB) W_Ẋ_B = W_X_B @ B_vx_WB @@ -411,10 +411,10 @@ def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix: Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B) case VelRepr.Mixed: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_X_BW = Adjoint.from_transform(transform=W_H_BW) - BW_v_WB = data.base_velocity() + BW_v_WB = data.kyn_dyn.base_velocity BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_vx_W_BW = Cross.vx(BW_v_W_BW) W_Ẋ_BW = W_X_BW @ BW_vx_W_BW diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 2719da5d1..2fc13d820 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -285,7 +285,7 @@ def jacobian( # Adjust the input representation such that `J_WL_I @ I_ν`. match data.velocity_representation: case VelRepr.Inertial: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_WL_I = B_J_WL_W = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 B_X_W, jnp.eye(model.dofs()) @@ -295,7 +295,7 @@ def jacobian( B_J_WL_I = B_J_WL_B case VelRepr.Mixed: - W_R_B = data.base_orientation(dcm=True) + W_R_B = data.kyn_dyn.base_transform[:3, :3] BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_J_WL_I = B_J_WL_BW = B_J_WL_B @ jax.scipy.linalg.block_diag( # noqa: F841 @@ -310,7 +310,7 @@ def jacobian( # Adjust the output representation such that `O_v_WL_I = O_J_WL_I @ I_ν`. match output_vel_repr: case VelRepr.Inertial: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_X_B = Adjoint.from_transform(transform=W_H_B) O_J_WL_I = W_J_WL_I = W_X_B @ B_J_WL_I # noqa: F841 @@ -320,7 +320,7 @@ def jacobian( O_J_WL_I = L_J_WL_I case VelRepr.Mixed: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_H_L = W_H_B @ B_H_L LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3)) LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 062512db2..17a10bfcb 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -59,6 +59,17 @@ class JaxSimModel(JaxsimDataclass): _description: Static[ModelDescription] = dataclasses.field(default=None, repr=False) + @property + def description(self) -> ModelDescription: + """ + Return the intermediate model description of the model. + + Returns: + The intermediate model description of the model. + """ + + return self._description + # ======================== # Initialization and state # ======================== @@ -563,7 +574,7 @@ def generalized_free_floating_jacobian( case VelRepr.Inertial: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform B_X_W = Adjoint.from_transform(transform=W_H_B, inverse=True) B_J_full_WX_I = B_J_full_WX_W = ( # noqa: F841 @@ -577,7 +588,7 @@ def generalized_free_floating_jacobian( case VelRepr.Mixed: - W_R_B = data.base_orientation(dcm=True) + W_R_B = data.kyn_dyn.base_transform[:3, :3] BW_H_B = jnp.eye(4).at[0:3, 0:3].set(W_R_B) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) @@ -611,7 +622,7 @@ def generalized_free_floating_jacobian( case VelRepr.Inertial: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_X_B = jaxsim.math.Adjoint.from_transform(W_H_B) O_J_WL_I = W_J_WL_I = jax.vmap( # noqa: F841 @@ -629,7 +640,7 @@ def generalized_free_floating_jacobian( case VelRepr.Mixed: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform LW_H_L = jax.vmap( lambda B_H_L: (W_H_B @ B_H_L).at[0:3, 3].set(jnp.zeros(3)) @@ -676,19 +687,24 @@ def generalized_free_floating_jacobian_derivative( ) # Compute the derivative of the doubly-left free-floating full jacobian. - B_J̇_full_WX_B, B_H_L = ( - data.kyn_dyn.jacobian_derivative_full_doubly_left, - data.kyn_dyn.link_body_transforms, + B_J̇_full_WX_B, B_H_L = jaxsim.rbda.jacobian_derivative_full_doubly_left( + model=model, + joint_positions=data.joint_positions, + joint_velocities=data.joint_velocities, ) # The derivative of the equation to change the input and output representations # of the Jacobian derivative needs the computation of the plain link Jacobian. - B_J_full_WL_B = data.kyn_dyn.jacobian_full_doubly_left + B_J_full_WL_B, _ = jaxsim.rbda.jacobian_full_doubly_left( + model=model, + joint_positions=data.joint_positions, + ) # Compute the actual doubly-left free-floating jacobian derivative of the link # by zeroing the columns not in the path π_B(L) using the boolean κ(i). κb = model.kin_dyn_parameters.support_body_array_bool + # Compute the base transform. W_H_B = data.kyn_dyn.base_transform # We add the 5 columns of ones to the Jacobian derivative to account for the @@ -713,7 +729,7 @@ def generalized_free_floating_jacobian_derivative( B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True) - W_v_WB = data.base_velocity() + W_v_WB = data.kyn_dyn.base_velocity B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) # Compute the operator to change the representation of ν, and its @@ -739,7 +755,7 @@ def generalized_free_floating_jacobian_derivative( BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) - BW_v_WB = data.base_velocity() + BW_v_WB = data.kyn_dyn.base_velocity BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_v_BW_B = BW_v_WB - BW_v_W_BW @@ -764,7 +780,7 @@ def generalized_free_floating_jacobian_derivative( O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B) with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() + B_v_WB = data.kyn_dyn.base_velocity O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841 @@ -777,7 +793,7 @@ def generalized_free_floating_jacobian_derivative( B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B) with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() + B_v_WB = data.kyn_dyn.base_velocity L_v_WL = jnp.einsum( "b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity() ) @@ -797,7 +813,7 @@ def generalized_free_floating_jacobian_derivative( B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B) with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() + B_v_WB = data.kyn_dyn.base_velocity with data.switch_velocity_representation(VelRepr.Mixed): BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3)) @@ -931,8 +947,8 @@ def forward_dynamics_aba( # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): W_p_B = data.base_position - W_v_WB = data.base_velocity() - W_Q_B = data.base_orientation(dcm=False) + W_v_WB = data.kyn_dyn.base_velocity + W_Q_B = data.base_orientation s = data.joint_positions ṡ = data.joint_velocities @@ -985,14 +1001,14 @@ def to_active( case VelRepr.Body: # In this case C=B - W_H_C = W_H_B = data.base_transform() + W_H_C = W_H_B = data.kyn_dyn.base_transform W_v_WC = W_v_WB case VelRepr.Mixed: # In this case C=B[W] - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 - W_ṗ_B = data.base_velocity()[0:3] + W_ṗ_B = data.kyn_dyn.base_velocity[0:3] W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 case _: @@ -1134,7 +1150,7 @@ def free_floating_mass_matrix( case VelRepr.Inertial: B_X_W = Adjoint.from_transform( - transform=data.base_transform(), inverse=True + transform=data.kyn_dyn.base_transform, inverse=True ) invT = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) @@ -1142,7 +1158,7 @@ def free_floating_mass_matrix( case VelRepr.Mixed: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + BW_H_B = data.kyn_dyn.base_transform.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) invT = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) @@ -1225,12 +1241,12 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: case VelRepr.Inertial: n = model.dofs() - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform B_X_W = jaxsim.math.Adjoint.from_transform(W_H_B, inverse=True) B_T_W = jax.scipy.linalg.block_diag(B_X_W, jnp.eye(n)) with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WB = data.base_velocity() + W_v_WB = data.kyn_dyn.base_velocity B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB) B_Ṫ_W = jax.scipy.linalg.block_diag(B_Ẋ_W, jnp.zeros(shape=(n, n))) @@ -1245,12 +1261,12 @@ def compute_link_contribution(M, v, J, J̇) -> jtp.Array: case VelRepr.Mixed: n = model.dofs() - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + BW_H_B = data.kyn_dyn.base_transform.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True) B_T_BW = jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(n)) with data.switch_velocity_representation(VelRepr.Mixed): - BW_v_WB = data.base_velocity() + BW_v_WB = data.kyn_dyn.base_velocity BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3)) BW_v_BW_B = BW_v_WB - BW_v_W_BW @@ -1344,14 +1360,14 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 case VelRepr.Body: - W_H_C = W_H_B = data.base_transform() + W_H_C = W_H_B = data.kyn_dyn.base_transform with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() + W_v_WC = W_v_WB = data.kyn_dyn.base_velocity case VelRepr.Mixed: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) # noqa: F841 - W_ṗ_B = data.base_velocity()[0:3] + W_ṗ_B = data.kyn_dyn.base_velocity[0:3] W_v_WC = W_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) # noqa: F841 case _: @@ -1363,7 +1379,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): W_v̇_WB = to_inertial( C_v̇_WB=v̇_WB, W_H_C=W_H_C, - C_v_WB=data.base_velocity(), + C_v_WB=data.kyn_dyn.base_velocity, W_v_WC=W_v_WC, ) @@ -1377,15 +1393,14 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): # Extract the link and joint serializations. link_names = model.link_names() - joint_names = model.joint_names() # Extract the state in inertial-fixed representation. with data.switch_velocity_representation(VelRepr.Inertial): W_p_B = data.base_position - W_v_WB = data.base_velocity() - W_Q_B = data.base_orientation(dcm=False) - s = data.joint_positions(model=model, joint_names=joint_names) - ṡ = data.joint_velocities(model=model, joint_names=joint_names) + W_v_WB = data.kyn_dyn.base_velocity + W_Q_B = data.base_orientation + s = data.joint_positions + ṡ = data.joint_velocities # Extract the inputs in inertial-fixed representation. with references.switch_velocity_representation(VelRepr.Inertial): @@ -1420,7 +1435,7 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_v_WC): f_B = js.data.JaxSimModelData.inertial_to_other_representation( array=W_f_B, other_representation=data.velocity_representation, - transform=data.base_transform(), + transform=data.kyn_dyn.base_transform, is_force=True, ).squeeze() @@ -1625,12 +1640,12 @@ def total_momentum_jacobian( case VelRepr.Inertial: B_X_W = Adjoint.from_transform( - transform=data.base_transform(), inverse=True + transform=data.kyn_dyn.base_transform, inverse=True ) B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_W, jnp.eye(model.dofs())) case VelRepr.Mixed: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + BW_H_B = data.kyn_dyn.base_transform.at[0:3, 3].set(jnp.zeros(3)) B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) B_Jh = B_Jh_B @ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs())) @@ -1642,14 +1657,14 @@ def total_momentum_jacobian( return B_Jh case VelRepr.Inertial: - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform B_Xv_W = Adjoint.from_transform(transform=W_H_B, inverse=True) W_Xf_B = B_Xv_W.T W_Jh = W_Xf_B @ B_Jh return W_Jh case VelRepr.Mixed: - BW_H_B = data.base_transform().at[0:3, 3].set(jnp.zeros(3)) + BW_H_B = data.kyn_dyn.base_transform.at[0:3, 3].set(jnp.zeros(3)) B_Xv_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True) BW_Xf_B = B_Xv_BW.T BW_Jh = BW_Xf_B @ B_Jh @@ -1724,7 +1739,7 @@ def average_velocity_jacobian( GB_J = G_J W_p_B = data.base_position W_p_CoM = js.com.com_position(model=model, data=data) - B_R_W = data.base_orientation(dcm=True).transpose() + B_R_W = data.kyn_dyn.base_transform[:3, :3].transpose() B_H_GB = jnp.eye(4).at[0:3, 3].set(B_R_W @ (W_p_CoM - W_p_B)) B_X_GB = Adjoint.from_transform(transform=B_H_GB) @@ -1775,7 +1790,7 @@ def link_bias_accelerations( # ================================================ # Compute the base transform. - W_H_B = data.base_transform() + W_H_B = data.kyn_dyn.base_transform def other_representation_to_inertial( C_v̇_WB: jtp.Vector, C_v_WB: jtp.Vector, W_H_C: jtp.Matrix, W_v_WC: jtp.Vector @@ -1802,25 +1817,25 @@ def other_representation_to_inertial( W_H_C = W_H_W = jnp.eye(4) # noqa: F841 W_v_WC = W_v_WW = jnp.zeros(6) # noqa: F841 with data.switch_velocity_representation(VelRepr.Inertial): - C_v_WB = W_v_WB = data.base_velocity() + C_v_WB = W_v_WB = data.kyn_dyn.base_velocity case VelRepr.Body: W_H_C = W_H_B with data.switch_velocity_representation(VelRepr.Inertial): - W_v_WC = W_v_WB = data.base_velocity() # noqa: F841 + W_v_WC = W_v_WB = data.kyn_dyn.base_velocity # noqa: F841 with data.switch_velocity_representation(VelRepr.Body): - C_v_WB = B_v_WB = data.base_velocity() + C_v_WB = B_v_WB = data.kyn_dyn.base_velocity case VelRepr.Mixed: W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) W_H_C = W_H_BW with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity()[0:3] + W_ṗ_B = data.kyn_dyn.base_velocity[0:3] BW_v_W_BW = jnp.zeros(6).at[0:3].set(W_ṗ_B) W_X_BW = jaxsim.math.Adjoint.from_transform(transform=W_H_BW) W_v_WC = W_v_W_BW = W_X_BW @ BW_v_W_BW # noqa: F841 with data.switch_velocity_representation(VelRepr.Mixed): - C_v_WB = BW_v_WB = data.base_velocity() # noqa: F841 + C_v_WB = BW_v_WB = data.kyn_dyn.base_velocity # noqa: F841 case _: raise ValueError(data.velocity_representation) @@ -1853,11 +1868,11 @@ def other_representation_to_inertial( # Store the base velocity. with data.switch_velocity_representation(VelRepr.Body): - B_v_WB = data.base_velocity() + B_v_WB = data.kyn_dyn.base_velocity L_v_WL = L_v_WL.at[0].set(B_v_WB) # Get the joint velocities. - ṡ = data.joint_velocities(model=model, joint_names=model.joint_names()) + ṡ = data.joint_velocities # Allocate the buffer to store the body-fixed link accelerations, # and initialize the base acceleration. diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 4c7074b89..0b20b06ce 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -374,13 +374,13 @@ def system_position_dynamics( """ ṡ = data.joint_velocities - W_Q_B = data.base_orientation(dcm=False) + W_Q_B = data.base_orientation with data.switch_velocity_representation(velocity_representation=VelRepr.Mixed): - W_ṗ_B = data.base_velocity()[0:3] + W_ṗ_B = data.kyn_dyn.base_velocity[0:3] with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): - W_ω_WB = data.base_velocity()[3:6] + W_ω_WB = data.kyn_dyn.base_velocity[3:6] W_Q̇_B = Quaternion.derivative( quaternion=W_Q_B, diff --git a/src/jaxsim/mujoco/utils.py b/src/jaxsim/mujoco/utils.py index 7519164cd..07280ace3 100644 --- a/src/jaxsim/mujoco/utils.py +++ b/src/jaxsim/mujoco/utils.py @@ -63,7 +63,7 @@ def mujoco_data_from_jaxsim( # Set the model orientation. model_helper.set_base_orientation( - orientation=np.array(jaxsim_data.base_orientation()) + orientation=np.array(jaxsim_data.base_orientation) ) # Set the joint positions. @@ -71,11 +71,7 @@ def mujoco_data_from_jaxsim( model_helper.set_joint_positions( joint_names=list(jaxsim_model.joint_names()), - positions=np.array( - jaxsim_data.joint_positions( - model=jaxsim_model, joint_names=jaxsim_model.joint_names() - ) - ), + positions=np.array(jaxsim_data.joint_positions), ) # Updating these joints is not necessary after the first time. diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index adf767651..865c3f41c 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -192,8 +192,9 @@ def extract_model_data( # Combine the pose of the base link (child of the found fixed joint) # with the pose of the fixed joint connecting with the world. # Note: we assume it's a fixed joint and ignore any joint angle. - links_dict[base_link_name].mutable(validate=False).pose = ( - joints_with_world_parent[0].pose @ links_dict[base_link_name].pose + links_dict[joints_with_world_parent[0].child.name].pose = ( + joints_with_world_parent[0].pose + @ links_dict[joints_with_world_parent[0].child.name].pose ) # ============ diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 74fdee346..abf464df1 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -808,12 +808,12 @@ def integrate_data_with_average_contact_forces( s_t0 = data.joint_positions W_p_B_t0 = data.base_position - W_Q_B_t0 = data.base_orientation(dcm=False) + W_Q_B_t0 = data.base_orientation ṡ_t0 = data.joint_velocities with data.switch_velocity_representation(jaxsim.VelRepr.Mixed): - W_ṗ_B_t0 = data.base_velocity()[0:3] - W_ω_WB_t0 = data.base_velocity()[3:6] + W_ṗ_B_t0 = data.kyn_dyn.base_velocity[0:3] + W_ω_WB_t0 = data.kyn_dyn.base_velocity[3:6] with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): W_ν_t0 = data.generalized_velocity() diff --git a/tests/test_api_contact.py b/tests/test_api_contact.py index 9c819f72c..e0bb5da90 100644 --- a/tests/test_api_contact.py +++ b/tests/test_api_contact.py @@ -134,10 +134,10 @@ def test_contact_jacobian_derivative( data_with_frames = js.data.JaxSimModelData.build( model=model_with_frames, base_position=data.base_position, - base_quaternion=data.base_orientation(dcm=False), + base_quaternion=data.base_orientation, joint_positions=data.joint_positions, - base_linear_velocity=data.base_velocity()[0:3], - base_angular_velocity=data.base_velocity()[3:6], + base_linear_velocity=data.kyn_dyn.base_velocity[0:3], + base_angular_velocity=data.kyn_dyn.base_velocity[3:6], joint_velocities=data.joint_velocities, velocity_representation=velocity_representation, ) diff --git a/tests/test_api_data.py b/tests/test_api_data.py index 417ce2f7c..cdc38a833 100644 --- a/tests/test_api_data.py +++ b/tests/test_api_data.py @@ -93,25 +93,27 @@ def test_data_change_velocity_representation( model=model, data=data ) - assert data.base_velocity() == pytest.approx(kin_dyn_inertial.base_velocity()) + assert data.kyn_dyn.base_velocity == pytest.approx(kin_dyn_inertial.base_velocity()) if not model.floating_base(): return with data.switch_velocity_representation(VelRepr.Mixed): - assert data.base_velocity() == pytest.approx(kin_dyn_mixed.base_velocity()) - assert data.base_velocity()[0:3] != pytest.approx( + assert data.kyn_dyn.base_velocity == pytest.approx( + kin_dyn_mixed.base_velocity() + ) + assert data.kyn_dyn.base_velocity[0:3] != pytest.approx( data.state.physics_model.base_linear_velocity ) - assert data.base_velocity()[3:6] == pytest.approx( + assert data.kyn_dyn.base_velocity[3:6] == pytest.approx( data.state.physics_model.base_angular_velocity ) with data.switch_velocity_representation(VelRepr.Body): - assert data.base_velocity() == pytest.approx(kin_dyn_body.base_velocity()) - assert data.base_velocity()[0:3] != pytest.approx( + assert data.kyn_dyn.base_velocity == pytest.approx(kin_dyn_body.base_velocity()) + assert data.kyn_dyn.base_velocity[0:3] != pytest.approx( data.state.physics_model.base_linear_velocity ) - assert data.base_velocity()[3:6] != pytest.approx( + assert data.kyn_dyn.base_velocity[3:6] != pytest.approx( data.state.physics_model.base_angular_velocity ) diff --git a/tests/test_api_frame.py b/tests/test_api_frame.py index 4e7690913..c02eee410 100644 --- a/tests/test_api_frame.py +++ b/tests/test_api_frame.py @@ -253,7 +253,7 @@ def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( [ data.base_position, - data.base_orientation(), + data.base_orientation, data.joint_positions, ] ) @@ -262,13 +262,13 @@ def compute_q(data: js.data.JaxSimModelData) -> jax.Array: def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: with data.switch_velocity_representation(VelRepr.Body): - B_ω_WB = data.base_velocity()[3:6] + B_ω_WB = data.kyn_dyn.base_velocity[3:6] with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity()[0:3] + W_ṗ_B = data.kyn_dyn.base_velocity[0:3] W_Q̇_B = Quaternion.derivative( - quaternion=data.base_orientation(), + quaternion=data.base_orientation, omega=B_ω_WB, omega_in_body_fixed=True, K=0.0, diff --git a/tests/test_api_link.py b/tests/test_api_link.py index fbc027e57..09079051c 100644 --- a/tests/test_api_link.py +++ b/tests/test_api_link.py @@ -344,7 +344,7 @@ def J(q) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( - [data.base_position, data.base_orientation(), data.joint_positions] + [data.base_position, data.base_orientation, data.joint_positions] ) return q @@ -352,13 +352,13 @@ def compute_q(data: js.data.JaxSimModelData) -> jax.Array: def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: with data.switch_velocity_representation(VelRepr.Body): - B_ω_WB = data.base_velocity()[3:6] + B_ω_WB = data.kyn_dyn.base_velocity[3:6] with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity()[0:3] + W_ṗ_B = data.kyn_dyn.base_velocity[0:3] W_Q̇_B = jaxsim.math.Quaternion.derivative( - quaternion=data.base_orientation(), + quaternion=data.base_orientation, omega=B_ω_WB, omega_in_body_fixed=True, K=0.0, diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 90fdc1742..9bb0f1d60 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -81,9 +81,7 @@ def test_model_creation_and_reduction( locked_joint_positions=dict( zip( model_full.joint_names(), - data_full.joint_positions( - model=model_full, joint_names=model_full.joint_names() - ).tolist(), + data_full.joint_positions, strict=True, ) ), @@ -107,19 +105,19 @@ def test_model_creation_and_reduction( # Check that the reduced model maintains the same integrator of the full model. assert model_full.integrator == model_reduced.integrator + joint_idxs_reduced = js.joint.names_to_idxs( + model=model_reduced, joint_names=model_reduced.joint_names() + ) + # Build the data of the reduced model. data_reduced = js.data.JaxSimModelData.build( model=model_reduced, base_position=data_full.base_position, - base_quaternion=data_full.base_orientation(dcm=False), - joint_positions=data_full.joint_positions( - model=model_full, joint_names=model_reduced.joint_names() - ), - base_linear_velocity=data_full.base_velocity()[0:3], - base_angular_velocity=data_full.base_velocity()[3:6], - joint_velocities=data_full.joint_velocities( - model=model_full, joint_names=model_reduced.joint_names() - ), + base_quaternion=data_full.base_orientation, + joint_positions=data_full.joint_positions[joint_idxs_reduced], + base_linear_velocity=data_full.kyn_dyn.base_velocity[0:3], + base_angular_velocity=data_full.kyn_dyn.base_velocity[3:6], + joint_velocities=data_full.joint_velocities[joint_idxs_reduced], velocity_representation=data_full.velocity_representation, ) @@ -138,12 +136,12 @@ def test_model_creation_and_reduction( ) # Check that joint serialization works. - assert data_full.joint_positions( - model=model_full, joint_names=model_reduced.joint_names() - ) == pytest.approx(data_reduced.joint_positions()) - assert data_full.joint_velocities( - model=model_full, joint_names=model_reduced.joint_names() - ) == pytest.approx(data_reduced.joint_velocities()) + assert data_full.joint_positions[joint_idxs_reduced] == pytest.approx( + data_reduced.joint_positions + ) + assert data_full.joint_velocities[joint_idxs_reduced] == pytest.approx( + data_reduced.joint_velocities + ) # Check that link transforms are preserved. for link_name in model_reduced.link_names(): @@ -285,49 +283,50 @@ def test_model_rbda( _, subkey = jax.random.split(prng_key, num=2) - data = js.data.random_model_data( - model=model, key=subkey, velocity_representation=velocity_representation - ) + with jax.disable_jit(): + data = js.data.random_model_data( + model=model, key=subkey, velocity_representation=velocity_representation + ) - kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( - model=model, data=data - ) + kin_dyn = utils_idyntree.build_kindyncomputations_from_jaxsim_model( + model=model, data=data + ) - # ===== - # Tests - # ===== + # ===== + # Tests + # ===== - # Support both fixed-base and floating-base models by slicing the first six rows. - sl = np.s_[0:] if model.floating_base() else np.s_[6:] + # Support both fixed-base and floating-base models by slicing the first six rows. + sl = np.s_[0:] if model.floating_base() else np.s_[6:] - # Mass matrix - M_idt = kin_dyn.mass_matrix() - M_js = js.model.free_floating_mass_matrix(model=model, data=data) - assert pytest.approx(M_idt[sl, sl]) == M_js[sl, sl] + # Mass matrix + M_idt = kin_dyn.mass_matrix() + M_js = js.model.free_floating_mass_matrix(model=model, data=data) + assert pytest.approx(M_idt[sl, sl]) == M_js[sl, sl] - # Gravity forces - g_idt = kin_dyn.gravity_forces() - g_js = js.model.free_floating_gravity_forces(model=model, data=data) - assert pytest.approx(g_idt[sl]) == g_js[sl] + # Gravity forces + g_idt = kin_dyn.gravity_forces() + g_js = js.model.free_floating_gravity_forces(model=model, data=data) + assert pytest.approx(g_idt[sl]) == g_js[sl] - # Bias forces - h_idt = kin_dyn.bias_forces() - h_js = js.model.free_floating_bias_forces(model=model, data=data) - assert pytest.approx(h_idt[sl]) == h_js[sl] + # Bias forces + h_idt = kin_dyn.bias_forces() + h_js = js.model.free_floating_bias_forces(model=model, data=data) + assert pytest.approx(h_idt[sl]) == h_js[sl] - # Forward kinematics - HH_js = js.model.forward_kinematics(model=model, data=data) - HH_idt = jnp.stack( - [kin_dyn.frame_transform(frame_name=name) for name in model.link_names()] - ) - assert pytest.approx(HH_idt) == HH_js + # Forward kinematics + HH_js = js.model.forward_kinematics(model=model, data=data) + HH_idt = jnp.stack( + [kin_dyn.frame_transform(frame_name=name) for name in model.link_names()] + ) + assert pytest.approx(HH_idt) == HH_js - # Bias accelerations - Jν_js = js.model.link_bias_accelerations(model=model, data=data) - Jν_idt = jnp.stack( - [kin_dyn.frame_bias_acc(frame_name=name) for name in model.link_names()] - ) - assert pytest.approx(Jν_idt) == Jν_js + # Bias accelerations + Jν_js = js.model.link_bias_accelerations(model=model, data=data) + Jν_idt = jnp.stack( + [kin_dyn.frame_bias_acc(frame_name=name) for name in model.link_names()] + ) + assert pytest.approx(Jν_idt) == Jν_js def test_model_jacobian( @@ -451,7 +450,7 @@ def M(q) -> jax.Array: def compute_q(data: js.data.JaxSimModelData) -> jax.Array: q = jnp.hstack( - [data.base_position, data.base_orientation(), data.joint_positions] + [data.base_position, data.base_orientation, data.joint_positions] ) return q @@ -459,13 +458,13 @@ def compute_q(data: js.data.JaxSimModelData) -> jax.Array: def compute_q̇(data: js.data.JaxSimModelData) -> jax.Array: with data.switch_velocity_representation(VelRepr.Body): - B_ω_WB = data.base_velocity()[3:6] + B_ω_WB = data.kyn_dyn.base_velocity[3:6] with data.switch_velocity_representation(VelRepr.Mixed): - W_ṗ_B = data.base_velocity()[0:3] + W_ṗ_B = data.kyn_dyn.base_velocity[0:3] W_Q̇_B = jaxsim.math.Quaternion.derivative( - quaternion=data.base_orientation(), + quaternion=data.base_orientation, omega=B_ω_WB, omega_in_body_fixed=True, K=0.0, diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 54811a09a..60206e203 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -77,9 +77,9 @@ def test_ad_aba( # State in VelRepr.Inertial representation. W_p_B = data.base_position - W_Q_B = data.base_orientation(dcm=False) + W_Q_B = data.base_orientation s = data.joint_positions - W_v_WB = data.base_velocity() + W_v_WB = data.kyn_dyn.base_velocity ṡ = data.joint_velocities i_X_λ = data.kyn_dyn.joint_transforms S = data.kyn_dyn.motion_subspaces @@ -135,9 +135,9 @@ def test_ad_rnea( # State in VelRepr.Inertial representation. W_p_B = data.base_position - W_Q_B = data.base_orientation(dcm=False) + W_Q_B = data.base_orientation s = data.joint_positions - W_v_WB = data.base_velocity() + W_v_WB = data.kyn_dyn.base_velocity ṡ = data.joint_velocities i_X_λ = data.kyn_dyn.joint_transforms S = data.kyn_dyn.motion_subspaces @@ -204,7 +204,7 @@ def test_ad_crba( ) # State in VelRepr.Inertial representation. - s = data.joint_positions(model=model) + s = data.joint_positions i_X_λ = data.kyn_dyn.joint_transforms S = data.kyn_dyn.motion_subspaces @@ -244,8 +244,8 @@ def test_ad_fk( # State in VelRepr.Inertial representation. W_p_B = data.base_position - W_Q_B = data.base_orientation(dcm=False) - s = data.joint_positions(model=model) + W_Q_B = data.base_orientation + s = data.joint_positions i_X_λ = data.kyn_dyn.joint_transforms # ==== @@ -284,7 +284,7 @@ def test_ad_jacobian( ) # State in VelRepr.Inertial representation. - s = data.joint_positions(model=model) + s = data.joint_positions i_X_λ = data.kyn_dyn.joint_transforms S = data.kyn_dyn.motion_subspaces @@ -379,9 +379,9 @@ def test_ad_integration( # State in VelRepr.Inertial representation. W_p_B = data.base_position - W_Q_B = data.base_orientation(dcm=False) + W_Q_B = data.base_orientation s = data.joint_positions - W_v_WB = data.base_velocity() + W_v_WB = data.kyn_dyn.base_velocity ṡ = data.joint_velocities m = data.state.extended["tangential_deformation"] @@ -434,10 +434,10 @@ def step( ) xf_W_p_B = data_xf.base_position - xf_W_Q_B = data_xf.base_orientation(dcm=False) - xf_s = data_xf.joint_positions(model=model) - xf_W_v_WB = data_xf.base_velocity() - xf_ṡ = data_xf.joint_velocities(model=model) + xf_W_Q_B = data_xf.state.physics_model.base_quaternion + xf_s = data_xf.joint_positions + xf_W_v_WB = data_xf.kyn_dyn.base_velocity + xf_ṡ = data_xf.joint_velocities xf_m = data_xf.state.extended["tangential_deformation"] return xf_W_p_B, xf_W_Q_B, xf_s, xf_W_v_WB, xf_ṡ, xf_m diff --git a/tests/test_simulations.py b/tests/test_simulations.py index b2eac6091..dc1f59b44 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -80,7 +80,7 @@ def test_box_with_external_forces( # Check that the box didn't move. assert data.base_position == pytest.approx(data0.base_position) - assert data.base_orientation() == pytest.approx(data0.base_orientation()) + assert data.base_orientation == pytest.approx(data0.base_orientation) def test_box_with_zero_gravity( @@ -461,7 +461,7 @@ def test_joint_limits( data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.005, tf=3.0) assert ( - np.min(np.array(data_tf.joint_positions()), axis=0) + tolerance + np.min(np.array(data_tf.joint_positions), axis=0) + tolerance >= position_limits_min ) @@ -474,6 +474,6 @@ def test_joint_limits( data_tf = run_simulation(model=model, data_t0=data_t0, dt=0.001, tf=3.0) assert ( - np.max(np.array(data_tf.joint_positions()), axis=0) - tolerance + np.max(np.array(data_tf.joint_positions), axis=0) - tolerance <= position_limits_max ) diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index c371dcdb2..5aea83ba1 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -64,7 +64,7 @@ def build_kindyncomputations_from_jaxsim_model( else dict( zip( model.joint_names(), - data.joint_positions(model=model, joint_names=model.joint_names()), + data.joint_positions, strict=True, ) ) @@ -109,8 +109,8 @@ def store_jaxsim_data_in_kindyncomputations( kin_dyn.set_robot_state( joint_positions=np.array(data.joint_positions), joint_velocities=np.array(data.joint_velocities), - base_transform=np.array(data.base_transform()), - base_velocity=np.array(data.base_velocity()), + base_transform=np.array(data.kyn_dyn.base_transform), + base_velocity=np.array(data.kyn_dyn.base_velocity), ) return kin_dyn