Skip to content

Commit

Permalink
Update test_benchmark.py to use different data for batch processing
Browse files Browse the repository at this point in the history
  • Loading branch information
xela-95 committed Jan 15, 2025
1 parent a4f6889 commit df81dfa
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@

def vectorize_data(model: js.model.JaxSimModel, batch_size: int):
key = jax.random.PRNGKey(seed=0)
keys = jax.random.split(key, num=batch_size)

return jax.vmap(
lambda key: js.data.random_model_data(
model=model,
key=key,
)
)(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0))
)(keys)


def benchmark_test_function(
Expand Down

0 comments on commit df81dfa

Please sign in to comment.