This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: XLADevice via Reactant.jl #78

Merged · 9 commits · Oct 3, 2024
7 changes: 5 additions & 2 deletions .buildkite/testing.yml
@@ -1,7 +1,7 @@
 steps:
   - group: ":julia: CUDA GPU"
     steps:
-      - label: ":julia: Julia {{matrix.julia}} + CUDA GPU"
+      - label: ":julia: Julia {{matrix.julia}} + CUDA GPU (Backend Group: {{matrix.group}})"
         plugins:
           - JuliaCI/julia#v1:
               version: "{{matrix.julia}}"
@@ -16,13 +16,16 @@ steps:
           queue: "juliagpu"
           cuda: "*"
         env:
-          BACKEND_GROUP: "CUDA"
+          BACKEND_GROUP: "{{matrix.group}}"
         if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip ci\]/
         timeout_in_minutes: 60
         matrix:
           setup:
             julia:
               - "1"
+            group:
+              - CUDA
+              - XLA

   - group: ":telescope: Downstream CUDA"
     steps:
15 changes: 11 additions & 4 deletions .github/workflows/CI.yml
@@ -21,7 +21,7 @@ concurrency:

 jobs:
   ci:
-    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }}
+    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.group }} - ${{ github.event_name }}
     if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }}
     runs-on: ${{ matrix.os }}
     strategy:
@@ -33,6 +33,12 @@ jobs:
           - ubuntu-latest
           - macos-latest
           - windows-latest
+        group:
+          - CPU
+          - XLA
+        exclude:
+          - os: windows-latest
+            group: XLA
     steps:
       - uses: actions/checkout@v4
       - uses: julia-actions/setup-julia@v2
@@ -50,6 +56,8 @@ jobs:
             ${{ runner.os }}-
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
+        env:
+          BACKEND_GROUP: ${{ matrix.group }}
       - uses: julia-actions/julia-processcoverage@v1
         with:
           directories: src,ext
@@ -133,6 +141,8 @@ jobs:
       - uses: julia-actions/julia-downgrade-compat@v1
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
+        env:
+          BACKEND_GROUP: CPU
       - uses: julia-actions/julia-processcoverage@v1
         with:
           directories: src,ext
@@ -171,6 +181,3 @@ jobs:
       - name: Check if the PR does increase number of invalidations
         if: steps.invs_pr.outputs.total > steps.invs_default.outputs.total
         run: exit 1
-
-env:
-  BACKEND_GROUP: "CPU"
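Both pipelines now select the test backend at runtime through the `BACKEND_GROUP` environment variable rather than a hard-coded value. A minimal sketch of how a `test/runtests.jl` can branch on it — group names mirror the CI matrix, but the file layout below is illustrative, not the package's actual test harness:

```julia
# Sketch: dispatch test groups on the BACKEND_GROUP variable set by CI.
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))

BACKEND_GROUP in ("all", "cpu") && include("cpu_tests.jl")    # hypothetical file
BACKEND_GROUP in ("all", "cuda") && include("cuda_tests.jl")  # hypothetical file
BACKEND_GROUP in ("all", "xla") && include("xla_tests.jl")    # hypothetical file
```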
5 changes: 4 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.1.1"
+version = "1.2.0"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -17,6 +17,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -33,6 +34,7 @@ MLDataDevicesFillArraysExt = "FillArrays"
 MLDataDevicesGPUArraysExt = "GPUArrays"
 MLDataDevicesMLUtilsExt = "MLUtils"
 MLDataDevicesMetalExt = ["GPUArrays", "Metal"]
+MLDataDevicesReactantExt = "Reactant"
 MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools"
 MLDataDevicesReverseDiffExt = "ReverseDiff"
 MLDataDevicesSparseArraysExt = "SparseArrays"
@@ -53,6 +55,7 @@ MLUtils = "0.4.4"
 Metal = "1"
 Preferences = "1.4"
 Random = "1.10"
+Reactant = "0.2"
 RecursiveArrayTools = "3.8"
 ReverseDiff = "1.15"
 SparseArrays = "1.10"
10 changes: 6 additions & 4 deletions README.md
@@ -17,10 +17,12 @@ devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csa

 Currently we provide support for the following backends:

-1. `CUDA.jl` for NVIDIA GPUs.
-2. `AMDGPU.jl` for AMD ROCM GPUs.
-3. `Metal.jl` for Apple Metal GPUs. **(Experimental)**
-4. `oneAPI.jl` for Intel GPUs. **(Experimental)**
+1. `CPUDevice`: for CPUs -- no additional packages required.
+2. `CUDADevice`: `CUDA.jl` for NVIDIA GPUs.
+3. `AMDGPUDevice`: `AMDGPU.jl` for AMD ROCM GPUs.
+4. `MetalDevice`: `Metal.jl` for Apple Metal GPUs. **(Experimental)**
+5. `oneAPIDevice`: `oneAPI.jl` for Intel GPUs. **(Experimental)**
+6. `XLADevice`: `Reactant.jl` for XLA Support. **(Experimental)**

 ## Updating to v1.0
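For reference, a minimal usage sketch of the new backend (assuming `Reactant.jl` is installed and functional; `xla_device` and `XLADevice` are introduced by this PR):

```julia
using MLDataDevices, Reactant

dev = xla_device()       # XLADevice() once Reactant is loaded and functional
x = rand(Float32, 4, 3)
x_ra = dev(x)            # transferred via Reactant.to_rarray
get_device(x_ra)         # XLADevice()
```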
6 changes: 3 additions & 3 deletions ext/MLDataDevicesMLUtilsExt.jl
@@ -1,10 +1,10 @@
 module MLDataDevicesMLUtilsExt

 using MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
-                     MetalDevice, oneAPIDevice, DeviceIterator
+                     MetalDevice, oneAPIDevice, XLADevice, DeviceIterator
 using MLUtils: MLUtils, DataLoader

-for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
+for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, XLADevice)
     @eval function (D::$(dev))(dataloader::DataLoader)
         if dataloader.parallel
             if dataloader.buffer
@@ -22,7 +22,7 @@ for dev in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
                 data
             end

-            return DeviceIterator(D, eachobsparallel(D, data))
+            return DeviceIterator(identity, eachobsparallel(D, data))
         end
         return DeviceIterator(D, dataloader)
     end
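The loop above now also generates a `DataLoader` method for `XLADevice`. Note the second change: for parallel loaders the batches are already moved to the device inside `eachobsparallel`, so the wrapper receives `identity` rather than the device (which the `src/iterator.jl` change below makes legal). A usage sketch:

```julia
using MLDataDevices, MLUtils

X, y = rand(Float32, 2, 64), rand(Float32, 1, 64)
loader = DataLoader((X, y); batchsize=8)

dev = gpu_device()  # falls back to CPUDevice() when no GPU backend is functional
for (xb, yb) in dev(loader)   # each batch is transferred as it is consumed
    # ... run a training step on-device ...
end
```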
26 changes: 26 additions & 0 deletions ext/MLDataDevicesReactantExt.jl
@@ -0,0 +1,26 @@
+module MLDataDevicesReactantExt
+
+using Adapt: Adapt
+using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice
+using Reactant: Reactant, RArray
+
+MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true
+MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true
+
+# Default RNG: Forward to CPU, we will compile it
+function MLDataDevices.default_device_rng(::XLADevice)
+    return MLDataDevices.default_device_rng(CPUDevice())
+end
+
+# Query Device from Array
+Internal.get_device(::RArray) = XLADevice()
+
+Internal.get_device_type(::RArray) = XLADevice
+
+# unsafe_free!
+Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing
+
+# Device Transfer
+Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x)
+
+end
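In use, the extension behaves as follows (a sketch, assuming Reactant is installed). Arrays are converted with `Reactant.to_rarray`; RNGs are deliberately not transferred — the CPU RNG is returned so that random-number generation can be traced and compiled by Reactant:

```julia
using MLDataDevices, Reactant

dev = XLADevice()
w = randn(Float32, 8, 8)
w_ra = dev(w)                    # dispatches to Adapt.adapt_storage → Reactant.to_rarray
w_ra isa Reactant.RArray         # true
get_device_type(w_ra)            # XLADevice

rng = default_device_rng(dev)    # identical to default_device_rng(CPUDevice())
```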
10 changes: 7 additions & 3 deletions src/MLDataDevices.jl
@@ -6,17 +6,21 @@ using Preferences: @delete_preferences!, @load_preference, @set_preferences!
 using Random: AbstractRNG, Random

 abstract type AbstractDevice <: Function end
-abstract type AbstractGPUDevice <: AbstractDevice end
 abstract type AbstractCPUDevice <: AbstractDevice end
+abstract type AbstractAcceleratorDevice <: AbstractDevice end
+abstract type AbstractGPUDevice <: AbstractAcceleratorDevice end

 include("public.jl")
 include("iterator.jl")
 include("internal.jl")

 export gpu_backend!, supported_gpu_backends, reset_gpu_device!
 export default_device_rng
-export gpu_device, cpu_device
+export gpu_device, cpu_device, xla_device

-export CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
+export CPUDevice
+export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
+export XLADevice
 export get_device, get_device_type

 export DeviceIterator
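The resulting type tree, for orientation (a sketch; `XLADevice`'s concrete supertype is declared in `src/public.jl`, which this diff does not show — placing it under the new `AbstractAcceleratorDevice` is an assumption):

```julia
# AbstractDevice <: Function
# ├── AbstractCPUDevice            → CPUDevice
# └── AbstractAcceleratorDevice
#     ├── AbstractGPUDevice        → CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
#     └── XLADevice                (assumed placement)
```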
23 changes: 14 additions & 9 deletions src/internal.jl
@@ -5,8 +5,8 @@ using Preferences: load_preference
 using Random: AbstractRNG

 using ..MLDataDevices: MLDataDevices, AbstractDevice, CPUDevice, CUDADevice, AMDGPUDevice,
-                       MetalDevice, oneAPIDevice, supported_gpu_backends, GPU_DEVICES,
-                       loaded, functional
+                       MetalDevice, oneAPIDevice, XLADevice, supported_gpu_backends,
+                       GPU_DEVICES, loaded, functional

 for dev in (CPUDevice, MetalDevice, oneAPIDevice)
     msg = "`device_id` is not applicable for `$dev`."
@@ -27,18 +27,23 @@ for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
         get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg)
     end
 end
+get_device_name(::XLADevice) = "XLA"
+get_triggerpkg_name(::XLADevice) = "Reactant"

-for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing}, MetalDevice, oneAPIDevice)
+for T in (CPUDevice, CUDADevice{Nothing}, AMDGPUDevice{Nothing},
+          MetalDevice, oneAPIDevice, XLADevice)
     @eval get_device_id(::$(T)) = nothing
 end

-struct DeviceSelectionException <: Exception end
+struct DeviceSelectionException <: Exception
+    dev::String
+end

-function Base.showerror(io::IO, ::DeviceSelectionException)
-    return print(io, "DeviceSelectionException(No functional GPU device found!!)")
+function Base.showerror(io::IO, d::DeviceSelectionException)
+    return print(io, "DeviceSelectionException: No functional $(d.dev) device found!")
 end

-function get_gpu_device(; force_gpu_usage::Bool)
+function get_gpu_device(; force::Bool)
     backend = load_preference(MLDataDevices, "gpu_backend", nothing)

     # If backend set with preferences, use it
@@ -85,7 +90,7 @@ function get_gpu_device(; force_gpu_usage::Bool)
         end
     end

-    force_gpu_usage && throw(DeviceSelectionException())
+    force && throw(DeviceSelectionException("GPU"))
     @warn """No functional GPU backend found! Defaulting to CPU.

              1. If no GPU is available, nothing needs to be done.
@@ -144,7 +149,7 @@ for op in (:get_device, :get_device_type)
         end
     end

-    for T in (Number, AbstractRNG, Val, Symbol, String, Nothing)
+    for T in (Number, AbstractRNG, Val, Symbol, String, Nothing, AbstractRange)
         @eval $(op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
     end
 end
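The exception now carries the name of the device class it failed to select, so it is reusable beyond GPU selection, and the keyword is shortened from `force_gpu_usage` to `force`. Assuming the public `gpu_device` forwards the renamed keyword (the rename here suggests so, but that change is not in this diff), a failed forced selection reads:

```julia
julia> gpu_device(; force=true)   # on a machine with no functional GPU backend
ERROR: DeviceSelectionException: No functional GPU device found!
```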
2 changes: 1 addition & 1 deletion src/iterator.jl
@@ -44,7 +44,7 @@ julia> for (i, x) in enumerate(CUDADevice()(dataloader))
 (i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")
 ```
 """
-struct DeviceIterator{D <: AbstractDevice, I}
+struct DeviceIterator{D <: Function, I}
     dev::D
     iterator::I
 end
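Widening `D` from `AbstractDevice` to `Function` is what allows the MLUtils extension above to pass `identity` when the batches are already on-device; any callable is now accepted, e.g.:

```julia
it = DeviceIterator(identity, [rand(Float32, 2) for _ in 1:3])
first(it)   # identity is applied per element, so batches pass through unchanged
```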