Skip to content

Commit

Permalink
Removed theta_is_not_zero function
Browse files Browse the repository at this point in the history
  • Loading branch information
ConnorTingley committed Jan 15, 2025
1 parent 5b9feb7 commit 83664f0
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions src/jaxsim/math/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,18 @@ def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix:

vector = vector.squeeze()

def theta_is_not_zero(axis: jtp.Vector) -> jtp.Matrix:
theta = safe_norm(vector)

v = axis
theta = safe_norm(v)
s = jnp.sin(theta)
c = jnp.cos(theta)

s = jnp.sin(theta)
c = jnp.cos(theta)
c1 = 2 * jnp.sin(theta / 2.0) ** 2

c1 = 2 * jnp.sin(theta / 2.0) ** 2
safe_theta = jnp.where(theta == 0, 1.0, theta)
u = vector / safe_theta
u = jnp.vstack(u.squeeze())

safe_theta = jnp.where(theta == 0, 1.0, theta)
u = v / safe_theta
u = jnp.vstack(u.squeeze())
R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T


R = c * jnp.eye(3) - s * Skew.wedge(u) + c1 * u @ u.T

return R.transpose()

return jnp.where(
jnp.allclose(vector, 0.0),
# Return an identity rotation matrix when the input vector is zero.
jnp.eye(3),
theta_is_not_zero(axis=vector),
)
return R.transpose()

0 comments on commit 83664f0

Please sign in to comment.