From ffab954c1f273975fc04ddf73b0c425d5a9947ee Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 14 Nov 2024 17:28:26 +0100
Subject: [PATCH] Added CI job with TSAN and free-threading

Add a nightly GitHub Actions workflow that builds CPython v3.13.0 with
--disable-gil and --with-thread-sanitizer, builds jaxlib with
-fsanitize=thread, and runs four multi-threaded tests with PYTHON_GIL=0.

The three tests touched here now capture stderr and raise if a
ThreadSanitizer report shows up in it.
---
Review notes (this section is ignored by `git am`):
- test_different_threads_get_different_tokens now shares a single lock between
  the two worker threads; `with threading.Lock():` created a fresh lock on
  every call and therefore excluded nothing.
- The bare `assert` in that test is replaced by `self.assertEqual`.
- The stderr check runs after the `capture_stderr` block has exited.
- Hunk sizes and the diffstat account for the two added lines; the blob ids on
  the `index` lines were not regenerated.

 .github/workflows/tsan.yaml | 107 ++++++++++++++++++++++++++++++++++++
 tests/api_test.py           |  79 +++++++++++++++-----------
 tests/jaxpr_effects_test.py |  53 +++++++++++-------
 3 files changed, 186 insertions(+), 53 deletions(-)
 create mode 100644 .github/workflows/tsan.yaml

diff --git a/.github/workflows/tsan.yaml b/.github/workflows/tsan.yaml
new file mode 100644
index 000000000000..ffe7fcca546e
--- /dev/null
+++ b/.github/workflows/tsan.yaml
@@ -0,0 +1,107 @@
+name: CI - Free-threading and Thread Sanitizer (nightly)
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: true
+
+on:
+  schedule:
+    - cron: "0 12 * * *"  # Daily at 12:00 UTC
+  workflow_dispatch:  # allows triggering the workflow run manually
+  pull_request:  # Automatically trigger on pull requests affecting this file
+    # branches:
+    #   - main
+    paths:
+      - '**/workflows/tsan.yaml'
+
+jobs:
+  tsan:
+    runs-on: linux-x86-n2-64
+    container:
+      image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b  # ratchet:ubuntu:24.04
+    strategy:
+      fail-fast: false
+    defaults:
+      run:
+        shell: bash -l {0}
+    steps:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
+        with:
+          path: jax
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
+        with:
+          repository: python/cpython
+          path: cpython
+          ref: v3.13.0
+      - name: Install clang 18
+        env:
+          DEBIAN_FRONTEND: noninteractive
+        run: |
+          apt update
+          apt install -y clang-18 libstdc++-14-dev build-essential libssl-dev \
+            zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev curl git \
+            libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev \
+            libffi-dev liblzma-dev
+      - name: Build CPython with TSAN enabled
+        run: |
+          cd cpython
+          mkdir ${GITHUB_WORKSPACE}/cpython-tsan
+          CC=clang-18 CXX=clang++-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpython-tsan --disable-gil --with-thread-sanitizer
+          make -j64
+          make install
+          # Check whether free-threading mode is enabled
+          PYTHON_GIL=0 ${GITHUB_WORKSPACE}/cpython-tsan/bin/python3 -c "import sys; assert not sys._is_gil_enabled()"
+          ${GITHUB_WORKSPACE}/cpython-tsan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv
+      - name: Install JAX test requirements
+        run: |
+          source ${GITHUB_WORKSPACE}/venv/bin/activate
+          cd jax
+          python -m pip install -r build/test-requirements.txt
+      - name: Build and install JAX
+        run: |
+          source ${GITHUB_WORKSPACE}/venv/bin/activate
+          cd jax
+          python build/build.py \
+            --bazel_options=--color=yes \
+            --bazel_options=--copt=-fsanitize=thread \
+            --bazel_options=--linkopt="-fsanitize=thread" \
+            --bazel_options=--@rules_python//python/config_settings:py_freethreaded="yes" \
+            --bazel_options=--@nanobind//:enabled_free_threading=True \
+            --clang_path=/usr/bin/clang-18
+          # We have to manually install nightly scipy, otherwise the default scipy
+          # installation fails to build here: ../meson.build:84:0: ERROR: Unknown compiler(s)
+          python -m pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple scipy
+          python -m pip install dist/jaxlib-*.whl
+          python -m pip install -e .
+      - name: Run tests
+        env:
+          JAX_NUM_GENERATED_CASES: 1
+          JAX_ENABLE_X64: true
+          JAX_SKIP_SLOW_TESTS: true
+          PY_COLORS: 1
+        run: |
+          source ${GITHUB_WORKSPACE}/venv/bin/activate
+          cd jax
+          echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
+          echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
+          echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
+          # As we do not yet have free-threading support
+          # there will be the following warning:
+          # RuntimeWarning: The global interpreter lock (GIL) has been enabled to load module 'jaxlib.utils',
+          # which has not declared that it can run safely without the GIL.
+          # To avoid that we temporarily define PYTHON_GIL
+          export PYTHON_GIL=0
+
+          # Continue running all commands even if they fail
+          set +e
+
+          python -m pytest -s -vvv tests/jaxpr_effects_test.py::EffectOrderingTest::test_different_threads_get_different_tokens
+          exit_code=$?
+          python -m pytest -s -vvv tests/api_test.py::CustomJVPTest::test_concurrent_initial_style
+          exit_code=$(( $exit_code | $? ))
+          python -m pytest -s -vvv tests/api_test.py::APITest::test_concurrent_device_get_and_put
+          exit_code=$(( $exit_code | $? ))
+          python -m pytest -s -vvv tests/api_test.py::JitTest::test_concurrent_jit
+          exit_code=$(( $exit_code | $? ))
+
+          exit $exit_code
diff --git a/tests/api_test.py b/tests/api_test.py
index 49cd33ee464c..0253ccb6e460 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -2984,21 +2984,27 @@ def e(x):
     self.assertIn("stablehlo.sine", stablehlo)
 
   def test_concurrent_device_get_and_put(self):
-    def f(x):
-      for _ in range(100):
-        y = jax.device_put(x)
-        x = jax.device_get(y)
-      return x
+    # Capture ThreadSanitizer warnings and fail the test if anything reported
+    with jtu.capture_stderr() as get_output:
+      def f(x):
+        for _ in range(100):
+          y = jax.device_put(x)
+          x = jax.device_get(y)
+        return x
 
-    xs = [self.rng().randn(i) for i in range(10)]
-    # Make sure JAX backend is initialised on the main thread since some JAX
-    # backends install signal handlers.
-    jax.device_put(0)
-    with concurrent.futures.ThreadPoolExecutor() as executor:
-      futures = [executor.submit(partial(f, x)) for x in xs]
-      ys = [f.result() for f in futures]
-    for x, y in zip(xs, ys):
-      self.assertAllClose(x, y)
+      xs = [self.rng().randn(i) for i in range(10)]
+      # Make sure JAX backend is initialised on the main thread since some JAX
+      # backends install signal handlers.
+      jax.device_put(0)
+      with concurrent.futures.ThreadPoolExecutor() as executor:
+        futures = [executor.submit(partial(f, x)) for x in xs]
+        ys = [f.result() for f in futures]
+      for x, y in zip(xs, ys):
+        self.assertAllClose(x, y)
+
+    captured = get_output()
+    if "ThreadSanitizer" in captured:
+      raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured}")
 
   def test_dtype_from_builtin_types(self):
     for dtype in [bool, int, float, complex]:
@@ -7549,25 +7555,32 @@ def f(x, y):
 
   def test_concurrent_initial_style(self):
     # https://github.com/jax-ml/jax/issues/3843
-    def unroll(param, sequence):
-      def scan_f(prev_state, inputs):
-        return prev_state, jax.nn.sigmoid(param * inputs)
-      return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1])
-
-    def run():
-      return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))
-
-    expected = run()
-
-    # we just don't want this to crash
-    n_workers = 2
-    with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
-      futures = []
-      for _ in range(n_workers):
-        futures.append(e.submit(run))
-      results = [f.result() for f in futures]
-    for ans in results:
-      self.assertAllClose(ans, expected)
+
+    # Capture ThreadSanitizer warnings and fail the test if anything reported
+    with jtu.capture_stderr() as get_output:
+      def unroll(param, sequence):
+        def scan_f(prev_state, inputs):
+          return prev_state, jax.nn.sigmoid(param * inputs)
+        return jnp.sum(jax.lax.scan(scan_f, None, sequence)[1])
+
+      def run():
+        return jax.grad(unroll)(jnp.array(1.0), jnp.array([1.0]))
+
+      expected = run()
+
+      # we just don't want this to crash
+      n_workers = 2
+      with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as e:
+        futures = []
+        for _ in range(n_workers):
+          futures.append(e.submit(run))
+        results = [f.result() for f in futures]
+      for ans in results:
+        self.assertAllClose(ans, expected)
+
+    captured = get_output()
+    if "ThreadSanitizer" in captured:
+      raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured}")
 
   def test_nondiff_argnums_vmap_tracer(self):
     # https://github.com/jax-ml/jax/issues/3964
diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py
index 2e91792aa950..14c852f84e10 100644
--- a/tests/jaxpr_effects_test.py
+++ b/tests/jaxpr_effects_test.py
@@ -575,27 +575,40 @@ def g(x):
   def test_different_threads_get_different_tokens(self):
     if jax.device_count() < 2:
       raise unittest.SkipTest("Test requires >= 2 devices.")
-    tokens = []
-    def _noop(_):
-      return ()
 
-    def f(x):
-      # Runs in a thread.
-      res = jax.jit(
-          lambda x: callback_p.bind(
-              x, callback=_noop, effect=log_effect, out_avals=[])
-      )(x)
-      tokens.append(dispatch.runtime_tokens.current_tokens[log_effect])
-      return res
-
-    t1 = threading.Thread(target=lambda: f(2.))
-    t2 = threading.Thread(target=lambda: f(3.))
-    t1.start()
-    t2.start()
-    t1.join()
-    t2.join()
-    token1, token2 = tokens
-    self.assertIsNot(token1, token2)
+    # Capture ThreadSanitizer warnings and fail the test if anything reported
+    with jtu.capture_stderr() as get_output:
+      tokens = []
+      # Shared by both workers: a lock created inside f() would guard nothing.
+      tokens_lock = threading.Lock()
+      def _noop(_):
+        return ()
+
+      def f(x):
+        # Runs in a thread.
+        res = jax.jit(
+            lambda x: callback_p.bind(
+                x, callback=_noop, effect=log_effect, out_avals=[])
+        )(x)
+        # This is necessary for free-threading mode
+        with tokens_lock:
+          tokens.append(dispatch.runtime_tokens.current_tokens[log_effect])
+        return res
+
+      t1 = threading.Thread(target=lambda: f(2.))
+      t2 = threading.Thread(target=lambda: f(3.))
+      t1.start()
+      t2.start()
+      t1.join()
+      t2.join()
+      self.assertEqual(len(tokens), 2, tokens)
+      token1, token2 = tokens
+      self.assertIsNot(token1, token2)
+
+    captured = get_output()
+    if "ThreadSanitizer" in captured:
+      raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured}")
+
 
 class ParallelEffectsTest(jtu.JaxTestCase):
 