diff --git a/src/jaxsim/integrators/variable_step.py b/src/jaxsim/integrators/variable_step.py index dd07fa40e..89a991440 100644 --- a/src/jaxsim/integrators/variable_step.py +++ b/src/jaxsim/integrators/variable_step.py @@ -312,9 +312,9 @@ def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState: # Clip the estimated initial step size to the given bounds, if necessary. self.params["dt0"] = jnp.clip( - a=self.params["dt0"], - a_min=jnp.minimum(self.dt_min, self.params["dt0"]), - a_max=jnp.minimum(self.dt_max, self.params["dt0"]), + self.params["dt0"], + jnp.minimum(self.dt_min, self.params["dt0"]), + jnp.minimum(self.dt_max, self.params["dt0"]), ) # ========================================================= @@ -371,7 +371,7 @@ def while_loop_body(carry: Carry) -> Carry: # Shrink the Δt every time by the safety factor (even when accepted). # The β parameters define the bounds of the timestep update factor. - safety = jnp.clip(self.safety, a_min=0.0, a_max=1.0) + safety = jnp.clip(self.safety, 0.0, 1.0) β_min = jnp.maximum(0.0, self.beta_min) β_max = jnp.maximum(β_min, self.beta_max) @@ -383,9 +383,9 @@ def while_loop_body(carry: Carry) -> Carry: # In case of acceptance, Δt_next could either be larger than Δt0, # or slightly smaller than Δt0 depending on the safety factor. Δt_next = Δt0 * jnp.clip( - a=safety * jnp.power(1 / local_error, 1 / (q + 1)), - a_min=β_min, - a_max=β_max, + safety * jnp.power(1 / local_error, 1 / (q + 1)), + β_min, + β_max, ) def accept_step():