Skip to content

Commit

Permalink
Add GitHub action workflow for Bazel CUDA continuous tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705238265
  • Loading branch information
nitins17 authored and Google-ML-Automation committed Jan 18, 2025
1 parent 9fb2976 commit dc7a9b3
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 3 deletions.
81 changes: 81 additions & 0 deletions .github/workflows/bazel_cuda_non_rbe.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# CI - Bazel CUDA tests (Non-RBE)
#
# This workflow runs the CUDA tests with Bazel. It can only be triggered by other workflows via
# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Bazel CUDA tests.
#
# It consists of the following job:
# run-tests:
#   - Downloads the jaxlib and CUDA artifacts from a GCS bucket.
#   - Executes the `run_bazel_test_cuda_non_rbe.sh` script, which performs the following actions:
#     - Installs the downloaded wheel artifacts.
#     - Runs the CUDA tests with Bazel.
name: CI - Bazel CUDA tests (Non-RBE)

# Reusable workflow: the only trigger is `workflow_call`, so it runs solely
# when another workflow invokes it and passes the inputs declared below.
on:
  workflow_call:
    # NOTE(review): several inputs are `required: true` yet also declare a
    # `default`. Confirm whether callers are allowed to omit them; if so they
    # should be `required: false` so the defaults can actually apply.
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: true
        default: "linux-x86-n2-16"
      python:
        description: "Which python version to test?"
        type: string
        required: true
        # Quoted so the version stays a string (an unquoted 3.10 would parse as 3.1).
        default: "3.12"
      enable-x64:
        description: "Should x64 mode be enabled?"
        type: string
        required: true
        default: "0"
      gcs_download_uri:
        description: "GCS location URI from where the artifacts should be downloaded"
        type: string
        required: true
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax-fork/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: boolean
        required: false
        default: false

jobs:
  # Downloads the jaxlib / CUDA plugin / CUDA PJRT wheels from GCS and then
  # runs the Bazel CUDA test script against them.
  run-tests:
    name: "Bazel single accelerator and multi-accelerator CUDA tests (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
    runs-on: ${{ inputs.runner }}
    # NOTE(review): the image is referenced through the mutable `latest` tag;
    # consider pinning a digest so runs are reproducible.
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"

    env:
      JAXCI_HERMETIC_PYTHON_VERSION: ${{ inputs.python }}
      JAXCI_ENABLE_X64: ${{ inputs.enable-x64 }}

    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Set env vars for use in artifact download URL
        run: |
          os=$(uname -s | awk '{print tolower($0)}')
          arch=$(uname -m)
          # Get the major and minor version of Python.
          # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310
          python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.')
          # Export the values to later steps; they are used to build the wheel globs below.
          echo "OS=${os}" >> "$GITHUB_ENV"
          echo "ARCH=${arch}" >> "$GITHUB_ENV"
          echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> "$GITHUB_ENV"
      - name: Download the wheel artifacts from GCS
        env:
          # Hand the caller-supplied URI to the shell as an environment variable
          # instead of expanding `${{ ... }}` inside the script text, so its
          # contents can never be interpreted as shell syntax.
          GCS_DOWNLOAD_URI: ${{ inputs.gcs_download_uri }}
        # Folded scalar: the lines below form a single `&&`-chained command, so
        # a failed copy stops the remaining copies. The wildcards stay inside
        # double quotes so they are passed to gsutil rather than expanded by the shell.
        run: >-
          mkdir -p "$(pwd)/dist" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" "$(pwd)/dist/" &&
          gsutil -m cp -r "${GCS_DOWNLOAD_URI}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" "$(pwd)/dist/"
      # Halt for testing
      # NOTE(review): this action is referenced by branch (`@main`) while
      # `actions/checkout` above is pinned to a commit SHA; consider pinning
      # this one as well.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Run Bazel CUDA tests (Non-RBE)
        timeout-minutes: 60
        run: ./ci/run_bazel_test_cuda_non_rbe.sh
2 changes: 1 addition & 1 deletion .github/workflows/pytest_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# run-tests:
# - Downloads the jaxlib and CUDA artifacts from a GCS bucket.
# - Executes the `run_pytest_cuda.sh` script, which performs the following actions:
# - Installs the downloaded jaxlib wheel.
# - Installs the downloaded wheel artifacts.
# - Runs the CUDA tests with Pytest.
name: CI - Pytest CUDA

Expand Down
27 changes: 26 additions & 1 deletion .github/workflows/wheel_tests_continuous.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,19 @@
# uploads them to a GCS bucket.
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA artifacts
# that were built in the previous steps and runs the CUDA tests.
# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow to download the jaxlib and
# CUDA artifacts that were built in the previous steps and runs the CUDA
# tests using Bazel.

name: CI - Wheel Tests (Continuous)

# Triggers: scheduled runs only. The temporary `pull_request` trigger that was
# flagged "DO NOT SUBMIT" has been removed so this continuous workflow does not
# run on every pull request against main.
on:
  schedule:
    - cron: "0 */2 * * *"  # Run once every 2 hours

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
Expand Down Expand Up @@ -44,7 +52,7 @@ jobs:
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
# Python values need to match the matrix stategy in the GPU tests job below
# Python values need to match the matrix strategy in the CUDA tests job below
runner: ["linux-x86-n2-16"]
artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
python: ["3.10",]
Expand Down Expand Up @@ -98,4 +106,21 @@ jobs:
cuda: ${{ matrix.cuda }}
enable-x64: ${{ matrix.enable-x64 }}
# GCS upload URI is the same for both artifact build jobs
gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

# Calls the reusable Bazel CUDA workflow once per matrix combination, after
# both artifact build jobs have finished.
run-bazel-test-cuda:
  needs: [build-jaxlib-artifact, build-cuda-artifacts]
  uses: ./.github/workflows/bazel_cuda_non_rbe.yml
  strategy:
    fail-fast: false  # don't cancel all jobs on failure
    matrix:
      # Python values need to match the matrix strategy in the build artifacts job above
      runner: ["linux-x86-g2-48-l4-4gpu"]
      python: ["3.10"]
      # Quoted because the called workflow declares `enable-x64` as a string input.
      enable-x64: ["1", "0"]
  with:
    runner: ${{ matrix.runner }}
    python: ${{ matrix.python }}
    enable-x64: ${{ matrix.enable-x64 }}
    # GCS upload URI is the same for both artifact build jobs
    gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# -x: log all commands
# -o history: record shell history
# -o allexport: export all functions and variables to be available to subscripts
set -exu -o history -o allexport
set -xu -o history -o allexport

# Source default JAXCI environment variables.
source ci/envs/default.env
Expand Down

0 comments on commit dc7a9b3

Please sign in to comment.