Skip to content

Commit

Permalink
Remove usage of kwargs in jax.numpy.clip
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Jul 2, 2024
1 parent 97da89c commit 7f5aae2
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/jaxsim/integrators/variable_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,9 @@ def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:

# Clip the estimated initial step size to the given bounds, if necessary.
self.params["dt0"] = jnp.clip(
a=self.params["dt0"],
a_min=jnp.minimum(self.dt_min, self.params["dt0"]),
a_max=jnp.minimum(self.dt_max, self.params["dt0"]),
self.params["dt0"],
jnp.minimum(self.dt_min, self.params["dt0"]),
jnp.minimum(self.dt_max, self.params["dt0"]),
)

# =========================================================
Expand Down Expand Up @@ -371,7 +371,7 @@ def while_loop_body(carry: Carry) -> Carry:

# Shrink the Δt every time by the safety factor (even when accepted).
# The β parameters define the bounds of the timestep update factor.
safety = jnp.clip(self.safety, a_min=0.0, a_max=1.0)
safety = jnp.clip(self.safety, 0.0, 1.0)
β_min = jnp.maximum(0.0, self.beta_min)
β_max = jnp.maximum(β_min, self.beta_max)

Expand All @@ -383,9 +383,9 @@ def while_loop_body(carry: Carry) -> Carry:
# In case of acceptance, Δt_next could either be larger than Δt0,
# or slightly smaller than Δt0 depending on the safety factor.
Δt_next = Δt0 * jnp.clip(
a=safety * jnp.power(1 / local_error, 1 / (q + 1)),
a_min=β_min,
a_max=β_max,
safety * jnp.power(1 / local_error, 1 / (q + 1)),
β_min,
β_max,
)

def accept_step():
Expand Down

0 comments on commit 7f5aae2

Please sign in to comment.