-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Github action workflow for building JAX artifacts
PiperOrigin-RevId: 702497163
- Loading branch information
1 parent
ceeed90
commit 4c3721e
Showing
3 changed files
with
79 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
name: CI - Build JAX Artifacts | ||
|
||
on: | ||
pull_request: | ||
branches: | ||
- main | ||
workflow_dispatch: | ||
inputs: | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: choice | ||
required: true | ||
default: 'no' | ||
options: | ||
- 'yes' | ||
- 'no' | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
build-artifacts: | ||
if: github.event.repository.fork == false | ||
|
||
defaults: | ||
run: | ||
# Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. | ||
shell: bash | ||
|
||
strategy: | ||
fail-fast: false # don't cancel all jobs on failure | ||
matrix: | ||
runner: ["windows-x86-n2-16", "linux-x86-n2-16", "linux-arm64-t2a-16"] | ||
artifact: ["jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] | ||
python: ["3.10", "3.11", "3.12", "3.13"] | ||
exclude: | ||
# Don't build jax-cuda-pjrt and jax-cuda-plugin on windows-x86-n2-16 | ||
- runner: "windows-x86-n2-64" | ||
artifact: "jax-cuda-pjrt" | ||
- runner: "windows-x86-n2-64" | ||
artifact: "jax-cuda-plugin" | ||
# Don't build jax-cuda-pjrt for each python version | ||
- artifact: "jax-cuda-pjrt" | ||
python: 3.10 | ||
- artifact: "jax-cuda-pjrt" | ||
python: 3.11 | ||
- artifact: "jax-cuda-pjrt" | ||
python: 3.12 | ||
|
||
runs-on: ${{ matrix.runner }} | ||
|
||
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || | ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || | ||
(contains(matrix.runner, 'windows-x86') && null) }} | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" | ||
|
||
name: Build ${{ matrix.artifact }} on ${{ matrix.runner }} with Python ${{ matrix.python }} | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Enable RBE on platforms where its supported | ||
run: | | ||
os=$(uname -s | awk '{print tolower($0)}') | ||
arch=$(uname -m) | ||
# Enable RBE if building on Linux x86 or Windows x86 | ||
if [[ ($os == "linux" || $os =~ "msys_nt" ) && $arch == "x86_64" ]]; then | ||
echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV | ||
fi | ||
# Halt for testing | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Build ${{ matrix.artifact }} | ||
run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" |
Empty file.
Empty file.