Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
default_device_rng
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 12, 2024
1 parent ac28e86 commit aa4d84c
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true
LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()

# Default RNG
device_default_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()
LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()

# Device Transfer
## To GPU
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true
LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()

# Default RNG
device_default_rng(::LuxCUDADevice) = CUDA.default_rng()
LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()

# Device Transfer
## To GPU
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxDeviceUtilsMetalGPUArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true
LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()

# Default RNG
device_default_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray)
LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray)

# Device Transfer
## To GPU
Expand Down
1 change: 0 additions & 1 deletion ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@ function Adapt.adapt_structure(to::LuxDeviceUtils.AbstractLuxDeviceAdaptor,
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t)
end


end
16 changes: 8 additions & 8 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
import Adapt: adapt, adapt_storage

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export device_default_rng
export default_device_rng
export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice
export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor

Expand Down Expand Up @@ -209,20 +209,20 @@ Return a `LuxCPUDevice` object which can be used to transfer data to CPU.
@inline cpu_device() = LuxCPUDevice()

"""
device_default_rng(::AbstractLuxDevice)
default_device_rng(::AbstractLuxDevice)
Returns the default RNG for the device. This can be used to directly generate parameters
and states on the device using
[WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl).
"""
function device_default_rng(D::AbstractLuxDevice)
error("""`device_default_rng` not implemented for $(typeof(D)). This is either because:
function default_device_rng(D::AbstractLuxDevice)
return error("""`default_device_rng` not implemented for $(typeof(D)). This is either because:
1. The default RNG for this device is not known / officially provided.
2. The trigger package for the device is not loaded.
""")
1. The default RNG for this device is not known / officially provided.
2. The trigger package for the device is not loaded.
""")
end
device_default_rng(::LuxCPUDevice) = Random.default_rng()
default_device_rng(::LuxCPUDevice) = Random.default_rng()

# Dispatches for Different Data Structures
# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
Expand Down

0 comments on commit aa4d84c

Please sign in to comment.