From 709ceedf42dc28d2ca3f3082749ea4b635141d86 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Wed, 15 Jan 2025 17:48:26 +0100 Subject: [PATCH] Update `test_benchmark.py` to use jax.block_until_ready for more accurate measurement of computation time For more info see https://jax.readthedocs.io/en/latest/async_dispatch.html#async-dispatch --- tests/test_benchmark.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 641e0f243..cf6898f18 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -27,7 +27,10 @@ def benchmark_test_function( # Warm-up call to avoid including compilation time jax.vmap(func, in_axes=(None, 0))(model, data) - benchmark(jax.vmap(func, in_axes=(None, 0)), model, data) + + # Benchmark the function call + # Note: jax.block_until_ready is used to ensure that the benchmark is not measuring only the asynchronous dispatch + benchmark(jax.block_until_ready(jax.vmap(func, in_axes=(None, 0))), model, data) @pytest.mark.benchmark