diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 4f78b08c4..1ef55158b 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -389,7 +389,7 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix ).astype(float) @jax.jit - def base_transform(self) -> jtp.MatrixJax: + def base_transform(self) -> jtp.Matrix: """ Get the base transform. diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index c8fb6ae17..d7506d712 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1610,7 +1610,7 @@ def other_representation_to_inertial( # not remove gravity during the propagation. # Initialize the loop. - Carry = tuple[jtp.MatrixJax, jtp.MatrixJax] + Carry = tuple[jtp.Matrix, jtp.Matrix] carry0: Carry = (L_v_WL, L_v̇_WL) def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]: diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index 950310399..766f7f926 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -31,8 +31,8 @@ class ODEInput(JaxsimDataclass): @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, - joint_forces: jtp.VectorJax | None = None, - link_forces: jtp.MatrixJax | None = None, + joint_forces: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, ) -> ODEInput: """ Build an `ODEInput` from a `JaxSimModel`. @@ -501,14 +501,14 @@ class PhysicsModelInput(JaxsimDataclass): f_ext: The matrix of external forces applied to the links. """ - tau: jtp.VectorJax - f_ext: jtp.MatrixJax + tau: jtp.Vector + f_ext: jtp.Matrix @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, - joint_forces: jtp.VectorJax | None = None, - link_forces: jtp.MatrixJax | None = None, + joint_forces: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, ) -> PhysicsModelInput: """ Build a `PhysicsModelInput` from a `JaxSimModel`. @@ -535,8 +535,8 @@ def build_from_jaxsim_model( @staticmethod def build( - joint_forces: jtp.VectorJax | None = None, - link_forces: jtp.MatrixJax | None = None, + joint_forces: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, number_of_dofs: jtp.Int | None = None, number_of_links: jtp.Int | None = None, ) -> PhysicsModelInput: diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index 03dde9262..8ace25c7d 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -28,8 +28,8 @@ # Generic types # ============= -Time = jtp.Float -TimeStep = jtp.Float +Time = jtp.FloatLike +TimeStep = jtp.FloatLike State = NextState = TypeVar("State") StateDerivative = TypeVar("StateDerivative") PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree) diff --git a/src/jaxsim/rbda/aba.py b/src/jaxsim/rbda/aba.py index b98685f32..1e8a9c510 100644 --- a/src/jaxsim/rbda/aba.py +++ b/src/jaxsim/rbda/aba.py @@ -121,10 +121,7 @@ def aba( # Pass 1 # ====== - Pass1Carry = tuple[ - jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax - ] - + Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0) # Propagate kinematics and initialize AB inertia and AB bias forces. @@ -178,10 +175,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]: d = jnp.zeros(shape=(model.number_of_links(), 1)) u = jnp.zeros(shape=(model.number_of_links(), 1)) - Pass2Carry = tuple[ - jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax - ] - + Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] pass_2_carry: Pass2Carry = (U, d, u, MA, pA) def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: @@ -204,8 +198,8 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]: # Propagate them to the parent, handling the base link. def propagate( - MA_pA: tuple[jtp.MatrixJax, jtp.MatrixJax] - ) -> tuple[jtp.MatrixJax, jtp.MatrixJax]: + MA_pA: tuple[jtp.Matrix, jtp.Matrix] + ) -> tuple[jtp.Matrix, jtp.Matrix]: MA, pA = MA_pA @@ -248,7 +242,7 @@ def propagate( s̈ = jnp.zeros_like(s) a = jnp.zeros_like(v).at[0].set(a0) - Pass3Carry = tuple[jtp.MatrixJax, jtp.VectorJax] + Pass3Carry = tuple[jtp.Matrix, jtp.Vector] pass_3_carry = (a, s̈) def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]: diff --git a/src/jaxsim/rbda/collidable_points.py b/src/jaxsim/rbda/collidable_points.py index 20f8e8701..4e6800782 100644 --- a/src/jaxsim/rbda/collidable_points.py +++ b/src/jaxsim/rbda/collidable_points.py @@ -80,7 +80,7 @@ def collidable_points_pos_vel( # Propagate kinematics # ==================== - PropagateTransformsCarry = tuple[jtp.MatrixJax, jtp.Matrix] + PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix] propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi) def propagate_kinematics( @@ -118,8 +118,9 @@ def propagate_kinematics( # ================================================== def process_point_kinematics( - Li_p_C: jtp.VectorJax, parent_body: jtp.Int - ) -> tuple[jtp.VectorJax, jtp.VectorJax]: + Li_p_C: jtp.Vector, parent_body: jtp.Int + ) -> tuple[jtp.Vector, jtp.Vector]: + # Compute the position of the collidable point. W_p_Ci = ( Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1]) diff --git a/src/jaxsim/rbda/crba.py b/src/jaxsim/rbda/crba.py index 27ee83042..904048832 100644 --- a/src/jaxsim/rbda/crba.py +++ b/src/jaxsim/rbda/crba.py @@ -45,7 +45,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat # Propagate kinematics # ==================== - ForwardPassCarry = tuple[jtp.MatrixJax] + ForwardPassCarry = tuple[jtp.Matrix] forward_pass_carry: ForwardPassCarry = (i_X_0,) def propagate_kinematics( @@ -71,7 +71,7 @@ def propagate_kinematics( M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs())) - BackwardPassCarry = tuple[jtp.MatrixJax, jtp.MatrixJax] + BackwardPassCarry = tuple[jtp.Matrix, jtp.Matrix] backward_pass_carry: BackwardPassCarry = (Mc, M) def backward_pass( @@ -90,7 +90,7 @@ def backward_pass( j = i - CarryInnerFn = tuple[jtp.Int, jtp.MatrixJax, jtp.MatrixJax] + CarryInnerFn = tuple[jtp.Int, jtp.Matrix, jtp.Matrix] carry_inner_fn = (j, Fi, M) def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn: diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 8bcab038a..cdfbc35a3 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -61,7 +61,7 @@ def forward_kinematics_model( # Propagate the kinematics # ======================== - PropagateKinematicsCarry = tuple[jtp.MatrixJax] + PropagateKinematicsCarry = tuple[jtp.Matrix] propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i,) def propagate_kinematics( diff --git a/src/jaxsim/rbda/jacobian.py b/src/jaxsim/rbda/jacobian.py index a595c7360..197a45ee2 100644 --- a/src/jaxsim/rbda/jacobian.py +++ b/src/jaxsim/rbda/jacobian.py @@ -50,7 +50,7 @@ def jacobian( # Propagate kinematics # ==================== - PropagateKinematicsCarry = tuple[jtp.MatrixJax] + PropagateKinematicsCarry = tuple[jtp.Matrix] propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,) def propagate_kinematics( @@ -86,9 +86,9 @@ def propagate_kinematics( # Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True. κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index] - def compute_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> tuple[jtp.MatrixJax, None]: + def compute_jacobian(J: jtp.Matrix, i: jtp.Int) -> tuple[jtp.Matrix, None]: - def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax: + def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix: ii = i - 1 @@ -164,7 +164,7 @@ def jacobian_full_doubly_left( J = jnp.zeros(shape=(6, 6 + model.dofs())) J = J.at[0:6, 0:6].set(jnp.eye(6)) - ComputeFullJacobianCarry = tuple[jtp.MatrixJax, jtp.MatrixJax] + ComputeFullJacobianCarry = tuple[jtp.Matrix, jtp.Matrix] compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J) def compute_full_jacobian( @@ -261,7 +261,7 @@ def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix: J̇ = jnp.zeros(shape=(6, 6 + model.dofs())) ComputeFullJacobianDerivativeCarry = tuple[ - jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax + jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix ] compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = ( diff --git a/src/jaxsim/rbda/rnea.py b/src/jaxsim/rbda/rnea.py index b5f927d1e..625f8fede 100644 --- a/src/jaxsim/rbda/rnea.py +++ b/src/jaxsim/rbda/rnea.py @@ -132,7 +132,7 @@ def rnea( # Pass 1 # ====== - ForwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax] + ForwardPassCarry = Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix] forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f) def forward_pass( @@ -186,7 +186,7 @@ def forward_pass( τ = jnp.zeros_like(s) - BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax] + BackwardPassCarry = Tuple[jtp.Vector, jtp.Matrix] backward_pass_carry: BackwardPassCarry = (τ, f) def backward_pass( @@ -201,7 +201,7 @@ def backward_pass( τ = τ.at[ii].set(τ_i.squeeze()) # Propagate the force to the parent link. - def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax: + def update_f(f: jtp.Matrix) -> jtp.Matrix: f_λi = f[λ[i]] + i_X_λi[i].T @ f[i] f = f.at[λ[i]].set(f_λi) diff --git a/src/jaxsim/rbda/utils.py b/src/jaxsim/rbda/utils.py index 19e4e87f3..0f209db78 100644 --- a/src/jaxsim/rbda/utils.py +++ b/src/jaxsim/rbda/utils.py @@ -19,7 +19,7 @@ def process_inputs( joint_accelerations: jtp.VectorLike | None = None, joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, - standard_gravity: jtp.VectorLike | None = None, + standard_gravity: jtp.ScalarLike | None = None, ) -> tuple[ jtp.Vector, jtp.Vector, diff --git a/src/jaxsim/typing.py b/src/jaxsim/typing.py index 5d56467c0..9b392836d 100644 --- a/src/jaxsim/typing.py +++ b/src/jaxsim/typing.py @@ -7,14 +7,14 @@ # JAX types # ========= -ScalarJax = jax.Array -IntJax = ScalarJax -BoolJax = ScalarJax -FloatJax = ScalarJax +Array = jax.Array +Scalar = Array +Vector = Array +Matrix = Array -ArrayJax = jax.Array -VectorJax = ArrayJax -MatrixJax = ArrayJax +Int = Scalar +Bool = Scalar +Float = Scalar PyTree = ( dict[Hashable, "PyTree"] | list["PyTree"] | tuple["PyTree"] | None | jax.Array | Any @@ -24,19 +24,11 @@ # Mixed JAX / NumPy types # ======================= -Array = jax.typing.ArrayLike -Scalar = Array -Vector = Array -Matrix = Array +ArrayLike = jax.typing.ArrayLike | tuple +ScalarLike = int | float | Scalar | ArrayLike +VectorLike = Vector | ArrayLike | tuple +MatrixLike = Matrix | ArrayLike -Int = int | IntJax -Bool = bool | ArrayJax -Float = float | FloatJax - -ScalarLike = Scalar | int | float -ArrayLike = Array -VectorLike = Vector -MatrixLike = Matrix -IntLike = Int -BoolLike = Bool -FloatLike = Float +IntLike = int | Int | jax.typing.ArrayLike +BoolLike = bool | Bool | jax.typing.ArrayLike +FloatLike = float | Float | jax.typing.ArrayLike diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 0c80cda2c..0b3801b1e 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -326,7 +326,7 @@ def replace(self: Self, validate: bool = True, **kwargs) -> Self: return obj - def flatten(self) -> jtp.VectorJax: + def flatten(self) -> jtp.Vector: """ Flatten the object into a 1D vector. @@ -337,7 +337,7 @@ def flatten(self) -> jtp.VectorJax: return self.flatten_fn()(self) @classmethod - def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]: + def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.Vector]: """ Return a function to flatten the object into a 1D vector. @@ -347,7 +347,7 @@ def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]: return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0] - def unflatten_fn(self: Self) -> Callable[[jtp.VectorJax], Self]: + def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]: """ Return a function to unflatten a 1D vector into the object.