---
# CI - Build JAX Artifacts
# This workflow builds JAX wheels (jax, jaxlib, jax-cuda-plugin, and jax-cuda-pjrt) with a set of
# configuration options (platform, python version, whether to use latest XLA, etc). It can be
# triggered manually via workflow_dispatch or called by other workflows via workflow_call. When a
# workflow call is made, this workflow will build the artifacts and upload them to a GCS bucket so
# that other workflows (e.g. Pytest workflows) can use them.
# Display name of the workflow.
name: CI - Build JAX Artifacts
# Triggers. The workflow can be started by hand (workflow_dispatch) or invoked
# from another workflow (workflow_call). Both triggers expose the same core
# inputs: runner, artifact, python and clone_main_xla.
on:
  # Manual trigger: every input is a fixed choice list.
  workflow_dispatch:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: choice
        required: true
        default: "linux-x86-n2-16"
        options:
          - "linux-x86-n2-16"
          - "linux-arm64-c4a-64"
          - "windows-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: choice
        required: true
        default: "jaxlib"
        options:
          - "jax"
          - "jaxlib"
          - "jax-cuda-plugin"
          - "jax-cuda-pjrt"
      python:
        description: "Which python version should the artifact be built for?"
        type: choice
        required: false
        default: "3.12"
        # Versions stay quoted so that "3.10" is not read as the float 3.1.
        options:
          - "3.10"
          - "3.11"
          - "3.12"
          - "3.13"
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: choice
        required: false
        default: "0"
        options:
          - "1"
          - "0"
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: false
        default: 'no'
        # Quoted so that yes/no stay strings instead of becoming booleans.
        options:
          - 'yes'
          - 'no'
  # Reusable-workflow trigger: inputs are free-form strings supplied by the
  # caller, plus the GCS upload settings. Every input declares a default, so
  # none of them is marked required; a required input must always be supplied
  # by the caller, which would leave its default unused.
  workflow_call:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: false
        default: "linux-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: string
        required: false
        default: "jaxlib"
      python:
        description: "Which python version should the artifact be built for?"
        type: string
        required: false
        default: "3.12"
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: string
        required: false
        default: "0"
      upload_artifacts_to_gcs:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        type: boolean
        required: false
        default: true
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts should be uploaded"
        type: string
        required: false
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
    outputs:
      # Echoes the upload prefix back to the caller.
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts were uploaded"
        value: ${{ inputs.gcs_upload_uri }}
# Grant the workflow token read-only access to repository contents.
permissions:
  contents: read
jobs:
  # Builds the requested artifact on the requested runner and, when
  # upload_artifacts_to_gcs is set, copies the resulting wheels to GCS.
  build-artifacts:
    name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }})
    defaults:
      run:
        # Explicitly set the shell to bash so that Windows runners do not fall
        # back to their own default shell.
        shell: bash
    runs-on: ${{ inputs.runner }}
    # Linux runners build inside the ml-build image that matches their
    # architecture. For any other runner the expression evaluates to a falsy
    # value, so no container image is selected.
    container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
                   (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
                   (contains(inputs.runner, 'windows-x86') && null) }}
    # Job-wide environment derived from the inputs.
    # Presumably read by ./ci/build_artifacts.sh; confirm against that script.
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Enable RBE if building on Linux x86 or Windows x86
        if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
        # Appending to GITHUB_ENV exposes the variable to all later steps.
        run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> "$GITHUB_ENV"
      # Halt for testing
      - name: Wait For Connection
        # NOTE(review): this action tracks the mutable `main` branch, unlike the
        # SHA-pinned checkout step; consider pinning it to a commit.
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Build ${{ inputs.artifact }}
        # The input reaches the script through an environment variable rather
        # than being spliced into the shell command text.
        env:
          BUILD_ARTIFACT_NAME: "${{ inputs.artifact }}"
        run: ./ci/build_artifacts.sh "$BUILD_ARTIFACT_NAME"
      - name: Upload artifacts to a GCS bucket (non-Windows runs)
        if: >-
          ${{ inputs.upload_artifacts_to_gcs && !contains(inputs.runner, 'windows-x86') }}
        env:
          GCS_UPLOAD_URI: "${{ inputs.gcs_upload_uri }}"
        # The wildcard is quoted, so bash passes it through unexpanded and
        # gsutil receives the pattern itself.
        run: gsutil -m cp -r "$(pwd)/dist/*.whl" "$GCS_UPLOAD_URI"/
      # Set shell to cmd to avoid path errors when using gcloud commands on Windows
      - name: Upload artifacts to a GCS bucket (Windows runs)
        if: >-
          ${{ inputs.upload_artifacts_to_gcs && contains(inputs.runner, 'windows-x86') }}
        shell: cmd
        env:
          GCS_UPLOAD_URI: "${{ inputs.gcs_upload_uri }}"
        # cmd expands environment variables with %NAME% syntax.
        run: gsutil -m cp -r "dist/*.whl" "%GCS_UPLOAD_URI%"/