diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 089c78b4e..641e0f243 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -9,13 +9,14 @@ def vectorize_data(model: js.model.JaxSimModel, batch_size: int): key = jax.random.PRNGKey(seed=0) + keys = jax.random.split(key, num=batch_size) return jax.vmap( lambda key: js.data.random_model_data( model=model, key=key, ) - )(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0)) + )(keys) def benchmark_test_function(