-
Notifications
You must be signed in to change notification settings - Fork 2.9k
146 lines (137 loc) · 5.27 KB
/
build_artifacts.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# CI - Build JAX Artifacts
# This workflow builds JAX wheels (jax, jaxlib, jax-cuda-plugin, and jax-cuda-pjrt) with a set of
# configuration options (platform, Python version, whether to use the latest XLA, etc.). It can be
# triggered manually via workflow_dispatch or called by other workflows via workflow_call. When it
# is called via workflow_call with upload_artifacts_to_gcs enabled, the built wheels (dist/*.whl)
# are uploaded to a GCS bucket so that other workflows (e.g. Pytest workflows) can use them. The
# upload location is exposed as the gcs_upload_uri workflow output.
# Manual workflow_dispatch runs do not define upload_artifacts_to_gcs, so they only build the wheels.
name: CI - Build JAX Artifacts
on:
  # Manual runs from the Actions UI. Choice inputs restrict values to known runners/artifacts.
  workflow_dispatch:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: choice
        required: true
        default: "linux-x86-n2-16"
        options:
          - "linux-x86-n2-16"
          - "linux-arm64-c4a-64"
          - "windows-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: choice
        required: true
        default: "jaxlib"
        options:
          - "jax"
          - "jaxlib"
          - "jax-cuda-plugin"
          - "jax-cuda-pjrt"
      python:
        description: "Which python version should the artifact be built for?"
        type: choice
        required: false
        default: "3.12"
        options:
          # Quoted so "3.10" is not parsed as the float 3.1.
          - "3.10"
          - "3.11"
          - "3.12"
          - "3.13"
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: choice
        required: false
        default: "0"
        options:
          - "1"
          - "0"
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: false
        default: 'no'
        options:
          - 'yes'
          - 'no'
  # Calls from other workflows (`uses: ./.github/workflows/build_artifacts.yml`).
  workflow_call:
    inputs:
      runner:
        description: "Which runner should the workflow run on?"
        type: string
        required: true
        default: "linux-x86-n2-16"
      artifact:
        description: "Which JAX artifact to build?"
        type: string
        required: true
        default: "jaxlib"
      python:
        description: "Which python version should the artifact be built for?"
        type: string
        required: false
        default: "3.12"
      clone_main_xla:
        description: "Should latest XLA be used?"
        type: string
        required: false
        default: "0"
      upload_artifacts_to_gcs:
        description: "Should the artifacts be uploaded to a GCS bucket?"
        # Optional so the default below can take effect: a required input must always be passed
        # by the caller, so its default would never be used.
        required: false
        default: true
        type: boolean
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts should be uploaded"
        # Optional for the same reason as upload_artifacts_to_gcs; the default builds a per-run
        # prefix from the workflow name, run number and run attempt.
        required: false
        default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'
        type: string
    outputs:
      gcs_upload_uri:
        description: "GCS location prefix to where the artifacts were uploaded"
        value: ${{ jobs.build-artifacts.outputs.gcs_upload_uri }}
# Restrict the workflow's GITHUB_TOKEN to read-only access to repository contents.
permissions:
  contents: read
jobs:
  build-artifacts:
    defaults:
      run:
        # Explicitly set the shell to bash to override the Windows default (pwsh).
        shell: bash
    runs-on: ${{ inputs.runner }}
    # Linux runners build inside the matching ml-build image; Windows runners use no container.
    container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
                   (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') ||
                   (contains(inputs.runner, 'windows-x86') && null) }}
    env:
      JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
      JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}"
    name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }})
    # Map the job outputs to step outputs
    outputs:
      gcs_upload_uri: ${{ steps.store-gcs-upload-uri.outputs.gcs_upload_uri }}
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
      - name: Enable RBE if building on Linux x86 or Windows x86
        if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86')
        run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> "$GITHUB_ENV"
      # Optionally pause so a developer can connect to the runner; controlled by the
      # halt-for-connection input (only defined for workflow_dispatch runs).
      # NOTE(review): this action is referenced by branch (@main) while actions/checkout is pinned
      # to a commit SHA; consider pinning it to a SHA as well.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Build ${{ inputs.artifact }}
        timeout-minutes: 30
        # Inputs are passed through the environment rather than spliced into the script text,
        # so a workflow_call caller cannot inject shell syntax via the input value.
        env:
          ARTIFACT: ${{ inputs.artifact }}
        run: ./ci/build_artifacts.sh "$ARTIFACT"
      - name: Upload artifacts to a GCS bucket (non-Windows runs)
        if: >-
          ${{ inputs.upload_artifacts_to_gcs && !contains(inputs.runner, 'windows-x86') }}
        env:
          GCS_UPLOAD_URI: ${{ inputs.gcs_upload_uri }}
        # The *.whl glob is quoted so gsutil, not the shell, expands it.
        run: gsutil -m cp -r "$(pwd)/dist/*.whl" "$GCS_UPLOAD_URI"/
      # Set shell to cmd to avoid path errors when using gcloud commands on Windows
      - name: Upload artifacts to a GCS bucket (Windows runs)
        if: >-
          ${{ inputs.upload_artifacts_to_gcs && contains(inputs.runner, 'windows-x86') }}
        shell: cmd
        env:
          GCS_UPLOAD_URI: ${{ inputs.gcs_upload_uri }}
        run: gsutil -m cp -r "dist/*.whl" "%GCS_UPLOAD_URI%"/
      - name: Store the GCS upload URI as an output
        id: store-gcs-upload-uri
        if: ${{ inputs.upload_artifacts_to_gcs }}
        env:
          GCS_UPLOAD_URI: ${{ inputs.gcs_upload_uri }}
        run: echo "gcs_upload_uri=${GCS_UPLOAD_URI}" >> "$GITHUB_OUTPUT"