Skip to content

Commit

Permalink
Re-factor build CLI to a subcommand based approach
Browse files Browse the repository at this point in the history
This commit reworks the JAX build CLI to a subcommand based approach where artifacts are now defined as CLI subcommands. The new structure offers a clear and organized CLI that enables users to execute specific build tasks without having to navigate through a monolithic script.

Each subcommand has specific options and arguments that apply to its respective build process. This allows users to execute targeted build commands with relevant options (e.g CUDA arguments only apply to CUDA subcommands, ROCM arguments only apply to ROCM subcommands, etc.). This would also reduce the complexity and the potential for errors during the build process.  Segregating functionalities into distinct subcommands also simplifies the code which should help with the maintenance and future extensions.

Usage:
* Building `jaxlib`:
```
python build/build.py jaxlib --python_version=3.10
```
* Building `jax-cuda-plugin`:
```
python build/build.py jax-cuda-plugin --cuda_version=12.3.2 --cudnn_version=9.1.1 --python_version=3.10
```
* Building `jax-rocm-pjrt`:
```
python build/build.py jax-rocm-pjrt --rocm_version=60 --rocm_path=/path/to/rocm
```
* Using a local XLA path:
```
python build/build.py jaxlib --local_xla_path=/path/to/xla
```
* Updating requirements_lock.txt files:
```
python build/build.py requirements_update --python_version=3.10
```
PiperOrigin-RevId: 691466647
  • Loading branch information
nitins17 authored and Google-ML-Automation committed Nov 7, 2024
1 parent 04a6652 commit 4c2635a
Show file tree
Hide file tree
Showing 9 changed files with 805 additions and 520 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/asan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ jobs:
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
python build/build.py \
--bazel_options=--color=yes \
--bazel_options=--copt=-fsanitize=address \
python build/build.py jaxlib --verbose \
--bazel_build_options='--verbose_failures=true' \
--bazel_build_options='--copt=-fsanitize=address' \
--clang_path=/usr/bin/clang-18
pip install dist/jaxlib-*.whl
pip install -e .
Expand Down
7 changes: 3 additions & 4 deletions .github/workflows/wheel_win_x64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ jobs:
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py `
--bazel_options=--color=yes `
--bazel_options=--config=win_clang `
--verbose
python.exe build\build.py jaxlib --verbose `
--bazel_build_options='--verbose_failures=true' `
--bazel_build_options='--config=win_clang'
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
with:
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/windows_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ jobs:
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py `
--bazel_options=--color=yes `
--bazel_options=--config=win_clang
python.exe build\build.py jaxlib --verbose `
--bazel_build_options='--verbose_failures=true' `
--bazel_build_options='--config=win_clang'
- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
with:
Expand Down
69 changes: 69 additions & 0 deletions .github/workflows/windows_presubmit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
name: Presubmit - Windows CPU
on:
# TODO(DO_NOT_SUBMIT): temporary check
push:
branches:
- main
pull_request:
branches:
- main

permissions:
contents: read # to fetch code
actions: write # to cancel previous workflows

env:
DISTUTILS_USE_SDK: 1
MSSdk: 1

jobs:
presubmit-win-wheels:
if: ${{ (github.event.action != 'labeled') || (github.event.label.name == 'windows:force-run')}}
strategy:
fail-fast: true
matrix:
os: [windows-2019-32core]
arch: [AMD64]
pyver: ['3.10']
name: ${{ matrix.os }} Windows build
runs-on: ${{ matrix.os }}

steps:
- name: Cancel Previous Runs
uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/[email protected]
with:
access_token: ${{ github.token }}

- name: Install LLVM/Clang
run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade

- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.pyver }}
cache: 'pip'

- name: Build wheels
env:
BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC"
JAXLIB_RELEASE: true
run: |
python -m pip install -r build/test-requirements.txt
python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1
"C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH
python.exe build\build.py jaxlib --verbose `
--bazel_build_options='--verbose_failures=true' `
--bazel_build_options="--config=win_clang" `
--bazel_build_options='--color=yes'
- name: Run tests
env:
JAX_ENABLE_CHECKS: true
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
python -m pip install --find-links ${{ github.workspace }}\dist jaxlib
python -m pip install -e ${{ github.workspace }}
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
pytest -n auto --tb=short tests examples
Loading

0 comments on commit 4c2635a

Please sign in to comment.