Use jax.block_until_ready in benchmarks #344

Merged 2 commits on Jan 15, 2025

8 changes: 6 additions & 2 deletions tests/test_benchmark.py
@@ -9,13 +9,14 @@

 def vectorize_data(model: js.model.JaxSimModel, batch_size: int):
     key = jax.random.PRNGKey(seed=0)
+    keys = jax.random.split(key, num=batch_size)

     return jax.vmap(
         lambda key: js.data.random_model_data(
             model=model,
             key=key,
         )
-    )(jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0))
+    )(keys)


 def benchmark_test_function(
@@ -26,7 +27,10 @@ def benchmark_test_function(

     # Warm-up call to avoid including compilation time
     jax.vmap(func, in_axes=(None, 0))(model, data)
-    benchmark(jax.vmap(func, in_axes=(None, 0)), model, data)
+
+    # Benchmark the function call
+    # Note: jax.block_until_ready is used to ensure that the benchmark is not measuring only the asynchronous dispatch
+    benchmark(jax.block_until_ready(jax.vmap(func, in_axes=(None, 0))), model, data)


 @pytest.mark.benchmark
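
Note on the first hunk: replacing the repeated PRNG key with jax.random.split gives every batch element an independent subkey, so the vmapped sampling produces distinct random model data instead of identical copies. A minimal standalone sketch of the difference, using jax.random.normal as a stand-in for js.data.random_model_data:

import jax

key = jax.random.PRNGKey(seed=0)
batch_size = 4

# Repeating the same key: every vmapped call sees an identical key,
# so all batch elements receive the same random sample.
repeated = jax.numpy.repeat(key[None, :], repeats=batch_size, axis=0)
same = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(repeated)

# Splitting the key: every vmapped call sees an independent subkey,
# so the batch elements receive different random samples.
keys = jax.random.split(key, num=batch_size)
different = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)

assert bool(jax.numpy.all(same[0] == same[1]))
assert not bool(jax.numpy.all(different[0] == different[1]))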
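
Note on the second hunk: jax.block_until_ready waits on the arrays of a pytree, so another common way to keep JAX's asynchronous dispatch out of the measurement is to block on the outputs of the benchmarked call inside a small wrapper. The following is only a sketch of that pattern, not the PR's code; it assumes benchmark is the pytest-benchmark fixture and that func, model, data are as in the test above:

import jax


def benchmark_with_blocking(func, model, data, benchmark):
    vmapped_func = jax.vmap(func, in_axes=(None, 0))

    # Warm-up call to avoid including compilation time; blocking here also
    # ensures compilation has actually finished before timing starts.
    jax.block_until_ready(vmapped_func(model, data))

    # Time the full computation by blocking on the returned arrays,
    # rather than only on the asynchronous dispatch of the call.
    benchmark(lambda: jax.block_until_ready(vmapped_func(model, data)))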