Skip to content

Commit

Permalink
Update jaxsim.typing module
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jul 2, 2024
1 parent b8c1f45 commit 97da89c
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 64 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix
).astype(float)

@jax.jit
def base_transform(self) -> jtp.MatrixJax:
def base_transform(self) -> jtp.Matrix:
"""
Get the base transform.
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,7 +1610,7 @@ def other_representation_to_inertial(
# not remove gravity during the propagation.

# Initialize the loop.
Carry = tuple[jtp.MatrixJax, jtp.MatrixJax]
Carry = tuple[jtp.Matrix, jtp.Matrix]
carry0: Carry = (L_v_WL, L_v̇_WL)

def propagate_accelerations(carry: Carry, i: jtp.Int) -> tuple[Carry, None]:
Expand Down
16 changes: 8 additions & 8 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class ODEInput(JaxsimDataclass):
@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
joint_forces: jtp.VectorJax | None = None,
link_forces: jtp.MatrixJax | None = None,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
) -> ODEInput:
"""
Build an `ODEInput` from a `JaxSimModel`.
Expand Down Expand Up @@ -501,14 +501,14 @@ class PhysicsModelInput(JaxsimDataclass):
f_ext: The matrix of external forces applied to the links.
"""

tau: jtp.VectorJax
f_ext: jtp.MatrixJax
tau: jtp.Vector
f_ext: jtp.Matrix

@staticmethod
def build_from_jaxsim_model(
model: js.model.JaxSimModel | None = None,
joint_forces: jtp.VectorJax | None = None,
link_forces: jtp.MatrixJax | None = None,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
) -> PhysicsModelInput:
"""
Build a `PhysicsModelInput` from a `JaxSimModel`.
Expand All @@ -535,8 +535,8 @@ def build_from_jaxsim_model(

@staticmethod
def build(
joint_forces: jtp.VectorJax | None = None,
link_forces: jtp.MatrixJax | None = None,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
number_of_dofs: jtp.Int | None = None,
number_of_links: jtp.Int | None = None,
) -> PhysicsModelInput:
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
# Generic types
# =============

Time = jtp.Float
TimeStep = jtp.Float
Time = jtp.FloatLike
TimeStep = jtp.FloatLike
State = NextState = TypeVar("State")
StateDerivative = TypeVar("StateDerivative")
PyTreeType = TypeVar("PyTreeType", bound=jtp.PyTree)
Expand Down
16 changes: 5 additions & 11 deletions src/jaxsim/rbda/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,7 @@ def aba(
# Pass 1
# ======

Pass1Carry = tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]

Pass1Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
pass_1_carry: Pass1Carry = (v, c, MA, pA, i_X_0)

# Propagate kinematics and initialize AB inertia and AB bias forces.
Expand Down Expand Up @@ -178,10 +175,7 @@ def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> tuple[Pass1Carry, None]:
d = jnp.zeros(shape=(model.number_of_links(), 1))
u = jnp.zeros(shape=(model.number_of_links(), 1))

Pass2Carry = tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
]

Pass2Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
pass_2_carry: Pass2Carry = (U, d, u, MA, pA)

def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
Expand All @@ -204,8 +198,8 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:

# Propagate them to the parent, handling the base link.
def propagate(
MA_pA: tuple[jtp.MatrixJax, jtp.MatrixJax]
) -> tuple[jtp.MatrixJax, jtp.MatrixJax]:
MA_pA: tuple[jtp.Matrix, jtp.Matrix]
) -> tuple[jtp.Matrix, jtp.Matrix]:

MA, pA = MA_pA

Expand Down Expand Up @@ -248,7 +242,7 @@ def propagate(
= jnp.zeros_like(s)
a = jnp.zeros_like(v).at[0].set(a0)

Pass3Carry = tuple[jtp.MatrixJax, jtp.VectorJax]
Pass3Carry = tuple[jtp.Matrix, jtp.Vector]
pass_3_carry = (a, )

def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:
Expand Down
7 changes: 4 additions & 3 deletions src/jaxsim/rbda/collidable_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def collidable_points_pos_vel(
# Propagate kinematics
# ====================

PropagateTransformsCarry = tuple[jtp.MatrixJax, jtp.Matrix]
PropagateTransformsCarry = tuple[jtp.Matrix, jtp.Matrix]
propagate_transforms_carry: PropagateTransformsCarry = (W_X_i, W_v_Wi)

def propagate_kinematics(
Expand Down Expand Up @@ -118,8 +118,9 @@ def propagate_kinematics(
# ==================================================

def process_point_kinematics(
Li_p_C: jtp.VectorJax, parent_body: jtp.Int
) -> tuple[jtp.VectorJax, jtp.VectorJax]:
Li_p_C: jtp.Vector, parent_body: jtp.Int
) -> tuple[jtp.Vector, jtp.Vector]:

# Compute the position of the collidable point.
W_p_Ci = (
Adjoint.to_transform(adjoint=W_X_i[parent_body]) @ jnp.hstack([Li_p_C, 1])
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/rbda/crba.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def crba(model: js.model.JaxSimModel, *, joint_positions: jtp.Vector) -> jtp.Mat
# Propagate kinematics
# ====================

ForwardPassCarry = tuple[jtp.MatrixJax]
ForwardPassCarry = tuple[jtp.Matrix]
forward_pass_carry: ForwardPassCarry = (i_X_0,)

def propagate_kinematics(
Expand All @@ -71,7 +71,7 @@ def propagate_kinematics(

M = jnp.zeros(shape=(6 + model.dofs(), 6 + model.dofs()))

BackwardPassCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
BackwardPassCarry = tuple[jtp.Matrix, jtp.Matrix]
backward_pass_carry: BackwardPassCarry = (Mc, M)

def backward_pass(
Expand All @@ -90,7 +90,7 @@ def backward_pass(

j = i

CarryInnerFn = tuple[jtp.Int, jtp.MatrixJax, jtp.MatrixJax]
CarryInnerFn = tuple[jtp.Int, jtp.Matrix, jtp.Matrix]
carry_inner_fn = (j, Fi, M)

def while_loop_body(carry: CarryInnerFn) -> CarryInnerFn:
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/rbda/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def forward_kinematics_model(
# Propagate the kinematics
# ========================

PropagateKinematicsCarry = tuple[jtp.MatrixJax]
PropagateKinematicsCarry = tuple[jtp.Matrix]
propagate_kinematics_carry: PropagateKinematicsCarry = (W_X_i,)

def propagate_kinematics(
Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/rbda/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def jacobian(
# Propagate kinematics
# ====================

PropagateKinematicsCarry = tuple[jtp.MatrixJax]
PropagateKinematicsCarry = tuple[jtp.Matrix]
propagate_kinematics_carry: PropagateKinematicsCarry = (i_X_0,)

def propagate_kinematics(
Expand Down Expand Up @@ -86,9 +86,9 @@ def propagate_kinematics(
# Checking if j ∈ κ(i) is equivalent to: κ_bool(j) is True.
κ_bool = model.kin_dyn_parameters.support_body_array_bool[link_index]

def compute_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> tuple[jtp.MatrixJax, None]:
def compute_jacobian(J: jtp.Matrix, i: jtp.Int) -> tuple[jtp.Matrix, None]:

def update_jacobian(J: jtp.MatrixJax, i: jtp.Int) -> jtp.MatrixJax:
def update_jacobian(J: jtp.Matrix, i: jtp.Int) -> jtp.Matrix:

ii = i - 1

Expand Down Expand Up @@ -164,7 +164,7 @@ def jacobian_full_doubly_left(
J = jnp.zeros(shape=(6, 6 + model.dofs()))
J = J.at[0:6, 0:6].set(jnp.eye(6))

ComputeFullJacobianCarry = tuple[jtp.MatrixJax, jtp.MatrixJax]
ComputeFullJacobianCarry = tuple[jtp.Matrix, jtp.Matrix]
compute_full_jacobian_carry: ComputeFullJacobianCarry = (B_X_i, J)

def compute_full_jacobian(
Expand Down Expand Up @@ -261,7 +261,7 @@ def A_Ẋ_B(A_X_B: jtp.Matrix, B_v_AB: jtp.Vector) -> jtp.Matrix:
= jnp.zeros(shape=(6, 6 + model.dofs()))

ComputeFullJacobianDerivativeCarry = tuple[
jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax
jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix
]

compute_full_jacobian_derivative_carry: ComputeFullJacobianDerivativeCarry = (
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/rbda/rnea.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def rnea(
# Pass 1
# ======

ForwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax]
ForwardPassCarry = Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix, jtp.Matrix]
forward_pass_carry: ForwardPassCarry = (v, a, i_X_0, f)

def forward_pass(
Expand Down Expand Up @@ -186,7 +186,7 @@ def forward_pass(

τ = jnp.zeros_like(s)

BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax]
BackwardPassCarry = Tuple[jtp.Vector, jtp.Matrix]
backward_pass_carry: BackwardPassCarry = (τ, f)

def backward_pass(
Expand All @@ -201,7 +201,7 @@ def backward_pass(
τ = τ.at[ii].set(τ_i.squeeze())

# Propagate the force to the parent link.
def update_f(f: jtp.MatrixJax) -> jtp.MatrixJax:
def update_f(f: jtp.Matrix) -> jtp.Matrix:

f_λi = f[λ[i]] + i_X_λi[i].T @ f[i]
f = f.at[λ[i]].set(f_λi)
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/rbda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def process_inputs(
joint_accelerations: jtp.VectorLike | None = None,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
standard_gravity: jtp.VectorLike | None = None,
standard_gravity: jtp.ScalarLike | None = None,
) -> tuple[
jtp.Vector,
jtp.Vector,
Expand Down
36 changes: 14 additions & 22 deletions src/jaxsim/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
# JAX types
# =========

ScalarJax = jax.Array
IntJax = ScalarJax
BoolJax = ScalarJax
FloatJax = ScalarJax
Array = jax.Array
Scalar = Array
Vector = Array
Matrix = Array

ArrayJax = jax.Array
VectorJax = ArrayJax
MatrixJax = ArrayJax
Int = Scalar
Bool = Scalar
Float = Scalar

PyTree = (
dict[Hashable, "PyTree"] | list["PyTree"] | tuple["PyTree"] | None | jax.Array | Any
Expand All @@ -24,19 +24,11 @@
# Mixed JAX / NumPy types
# =======================

Array = jax.typing.ArrayLike
Scalar = Array
Vector = Array
Matrix = Array
ArrayLike = jax.typing.ArrayLike | tuple
ScalarLike = int | float | Scalar | ArrayLike
VectorLike = Vector | ArrayLike | tuple
MatrixLike = Matrix | ArrayLike

Int = int | IntJax
Bool = bool | ArrayJax
Float = float | FloatJax

ScalarLike = Scalar | int | float
ArrayLike = Array
VectorLike = Vector
MatrixLike = Matrix
IntLike = Int
BoolLike = Bool
FloatLike = Float
IntLike = int | Int | jax.typing.ArrayLike
BoolLike = bool | Bool | jax.typing.ArrayLike
FloatLike = float | Float | jax.typing.ArrayLike
6 changes: 3 additions & 3 deletions src/jaxsim/utils/jaxsim_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def replace(self: Self, validate: bool = True, **kwargs) -> Self:

return obj

def flatten(self) -> jtp.VectorJax:
def flatten(self) -> jtp.Vector:
"""
Flatten the object into a 1D vector.
Expand All @@ -337,7 +337,7 @@ def flatten(self) -> jtp.VectorJax:
return self.flatten_fn()(self)

@classmethod
def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]:
def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.Vector]:
"""
Return a function to flatten the object into a 1D vector.
Expand All @@ -347,7 +347,7 @@ def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]:

return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0]

def unflatten_fn(self: Self) -> Callable[[jtp.VectorJax], Self]:
def unflatten_fn(self: Self) -> Callable[[jtp.Vector], Self]:
"""
Return a function to unflatten a 1D vector into the object.
Expand Down

0 comments on commit 97da89c

Please sign in to comment.