Skip to content

Commit

Permalink
Merge pull request #344 from xela-95/test/benchmark-patch
Browse files Browse the repository at this point in the history
Use `jax.block_until_ready` in benchmarks
  • Loading branch information
xela-95 authored Jan 15, 2025
2 parents d075d48 + 709ceed commit b9eadd3
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@

def vectorize_data(model: js.model.JaxSimModel, batch_size: int):
key = jax.random.PRNGKey(seed=0)
keys = jax.random.split(key, num=batch_size)

return jax.vmap(
lambda key: js.data.random_model_data(
model=model,
key=key,
)
)(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0))
)(keys)


def benchmark_test_function(
Expand All @@ -26,7 +27,10 @@ def benchmark_test_function(

# Warm-up call to avoid including compilation time
jax.vmap(func, in_axes=(None, 0))(model, data)
benchmark(jax.vmap(func, in_axes=(None, 0)), model, data)

# Benchmark the function call
# Note: jax.block_until_ready is used to ensure that the benchmark is not measuring only the asynchronous dispatch
benchmark(jax.block_until_ready(jax.vmap(func, in_axes=(None, 0))), model, data)


@pytest.mark.benchmark
Expand Down

0 comments on commit b9eadd3

Please sign in to comment.