Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

refactor: cleanup the internals #35

Merged
merged 5 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions .buildkite/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ steps:
agents:
queue: "juliagpu"
cuda: "*"
env:
RETESTITEMS_NWORKERS: 2
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test"
timeout_in_minutes: 60
matrix:
Expand Down Expand Up @@ -98,7 +96,6 @@ steps:
JULIA_AMDGPU_CORE_MUST_LOAD: "1"
JULIA_AMDGPU_HIP_MUST_LOAD: "1"
JULIA_AMDGPU_DISABLE_ARTIFACTS: "1"
RETESTITEMS_NWORKERS: 2
if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.pull_request.labels includes "run downstream test"
timeout_in_minutes: 60
matrix:
Expand Down Expand Up @@ -159,9 +156,4 @@ steps:
- "1"

env:
RETESTITEMS_NWORKERS: 8
RETESTITEMS_NWORKER_THREADS: 2
RETESTITEMS_TESTITEM_TIMEOUT: 3600
JULIA_PKG_SERVER: ""
JULIA_NUM_THREADS: 4
SECRET_CODECOV_TOKEN: "DpNKbuKYRX40vpyJCfTvQmxwls1hlCUWiZX4pnsukt9E8u4pf0WUcIroRv2UDDbGYjuk5izmZ9yAhZZhiGMhjFF/TIji3JiYe1sXWdfSrNk0N2+CNoXo+CIi3JvS7mB+YAIUTEi2Xph+L7R0d+It079PEispqVv4bdRuqgSbY7Rn3NSsoV1cB8uUaVFBJH4EewC6Hceg80QW7q+CBru+QECudKbAWnRVLoizRsgzIld+gTUqsI1PhR+vSpD+AfZzhVxmff55ttVcMUFGnL3w4L74qoLVPET52/GPLCOi3RLGSzBJjebSBqqKOwesT9xJ4yaZ21AEzyeOm86YRc2WYg==;U2FsdGVkX1/eBwyJ7Of++vKyAWDSBvSdJeiKmVmlaVKFU5CHejM+sDlSZWH/WmoBatLcqH+eUVEGXC+oWl5riw=="
2 changes: 0 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,3 @@ jobs:

env:
BACKEND_GROUP: "CPU"
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
24 changes: 2 additions & 22 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.1"
version = "1.0.2"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down Expand Up @@ -29,36 +29,16 @@ WeightInitializersoneAPIExt = ["oneAPI", "GPUArrays"]

[compat]
AMDGPU = "0.9.6, 1"
Aqua = "0.8.7"
ArgCheck = "2.3.0"
CUDA = "5.3.2"
ChainRulesCore = "1.23"
ConcreteStructs = "0.2.3"
Documenter = "1.5.0"
ExplicitImports = "1.9.0"
GPUArrays = "10.2"
GPUArraysCore = "0.1.6"
GPUArrays = "10.2"
LinearAlgebra = "1.10"
Metal = "1.1.0"
Pkg = "1.10"
Random = "1.10"
ReTestItems = "1.24.0"
SpecialFunctions = "2.4"
StableRNGs = "1"
Statistics = "1.10"
Test = "1.10"
julia = "1.10"
oneAPI = "1.5.0"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Documenter", "ExplicitImports", "GPUArrays", "Pkg", "ReTestItems", "StableRNGs", "Test"]
14 changes: 7 additions & 7 deletions ext/WeightInitializersAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,32 @@ module WeightInitializersAMDGPUExt
using AMDGPU: AMDGPU, ROCArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: WeightInitializers
using WeightInitializers: DeviceAgnostic

@inline function WeightInitializers.__zeros(
function DeviceAgnostic.zeros(
::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number}
return AMDGPU.zeros(T, dims...)
end
@inline function WeightInitializers.__ones(
function DeviceAgnostic.ones(
::AMDGPU.rocRAND.RNG, ::Type{T}, dims::Integer...) where {T <: Number}
return AMDGPU.ones(T, dims...)
end

@inline function WeightInitializers.__zeros(
function DeviceAgnostic.zeros(
::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
return AMDGPU.zeros(T, dims...)
end
@inline function WeightInitializers.__ones(
function DeviceAgnostic.ones(
::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
return AMDGPU.ones(T, dims...)
end
@inline function WeightInitializers.__rand(
function DeviceAgnostic.rand(
rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = ROCArray{T}(undef, dims...)
Random.rand!(rng, y)
return y
end
@inline function WeightInitializers.__randn(
function DeviceAgnostic.randn(
rng::RNG, ::ROCArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = ROCArray{T}(undef, dims...)
Random.randn!(rng, y)
Expand Down
14 changes: 7 additions & 7 deletions ext/WeightInitializersCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,34 @@ module WeightInitializersCUDAExt
using CUDA: CUDA, CURAND, CuArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: WeightInitializers
using WeightInitializers: DeviceAgnostic

const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG}

@inline function WeightInitializers.__zeros(
function DeviceAgnostic.zeros(
::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return CUDA.zeros(T, dims...)
end
@inline function WeightInitializers.__ones(
function DeviceAgnostic.ones(
::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T <: Number}
return CUDA.ones(T, dims...)
end

@inline function WeightInitializers.__zeros(
function DeviceAgnostic.zeros(
::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
return CUDA.zeros(T, dims...)
end
@inline function WeightInitializers.__ones(
function DeviceAgnostic.ones(
::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
return CUDA.ones(T, dims...)
end
@inline function WeightInitializers.__rand(
function DeviceAgnostic.rand(
rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = CuArray{T}(undef, dims...)
Random.rand!(rng, y)
return y
end
@inline function WeightInitializers.__randn(
function DeviceAgnostic.randn(
rng::RNG, ::CuArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = CuArray{T}(undef, dims...)
Random.randn!(rng, y)
Expand Down
16 changes: 8 additions & 8 deletions ext/WeightInitializersGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
module WeightInitializersGPUArraysExt

using GPUArrays: RNG
using WeightInitializers: WeightInitializers
using WeightInitializers: DeviceAgnostic

for f in (:__zeros, :__ones, :__rand, :__randn)
@eval @inline function WeightInitializers.$(f)(
for f in (:zeros, :ones, :rand, :randn)
@eval function DeviceAgnostic.$(f)(
rng::RNG, ::Type{T}, dims::Integer...) where {T <: Number}
return WeightInitializers.$(f)(rng, rng.state, T, dims...)
return DeviceAgnostic.$(f)(rng, rng.state, T, dims...)
end
end

## Certain backends don't support sampling Complex numbers, so we avoid hitting those
## dispatches
for f in (:__rand, :__randn)
@eval @inline function WeightInitializers.$(f)(
for f in (:rand, :randn)
@eval function DeviceAgnostic.$(f)(
rng::RNG, ::Type{<:Complex{T}}, args::Integer...) where {T <: Number}
real_part = WeightInitializers.$(f)(rng, rng.state, T, args...)
imag_part = WeightInitializers.$(f)(rng, rng.state, T, args...)
real_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...)
imag_part = DeviceAgnostic.$(f)(rng, rng.state, T, args...)
return Complex{T}.(real_part, imag_part)
end
end
Expand Down
10 changes: 5 additions & 5 deletions ext/WeightInitializersMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
using Metal: Metal, MtlArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: WeightInitializers
using WeightInitializers: DeviceAgnostic

@inline function WeightInitializers.__zeros(
function DeviceAgnostic.zeros(

Check warning on line 8 in ext/WeightInitializersMetalExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/WeightInitializersMetalExt.jl#L8

Added line #L8 was not covered by tests
::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
return Metal.zeros(T, dims...)
end
@inline function WeightInitializers.__ones(
function DeviceAgnostic.ones(

Check warning on line 12 in ext/WeightInitializersMetalExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/WeightInitializersMetalExt.jl#L12

Added line #L12 was not covered by tests
::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
return Metal.ones(T, dims...)
end
@inline function WeightInitializers.__rand(
function DeviceAgnostic.rand(

Check warning on line 16 in ext/WeightInitializersMetalExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/WeightInitializersMetalExt.jl#L16

Added line #L16 was not covered by tests
rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = MtlArray{T}(undef, dims...)
Random.rand!(rng, y)
return y
end
@inline function WeightInitializers.__randn(
function DeviceAgnostic.randn(

Check warning on line 22 in ext/WeightInitializersMetalExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/WeightInitializersMetalExt.jl#L22

Added line #L22 was not covered by tests
rng::RNG, ::MtlArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = MtlArray{T}(undef, dims...)
Random.randn!(rng, y)
Expand Down
10 changes: 5 additions & 5 deletions ext/WeightInitializersoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@ module WeightInitializersoneAPIExt
using oneAPI: oneAPI, oneArray
using GPUArrays: RNG
using Random: Random
using WeightInitializers: WeightInitializers
using WeightInitializers: DeviceAgnostic

@inline function WeightInitializers.__zeros(
function DeviceAgnostic.zeros(
::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
return oneAPI.zeros(T, dims...)
end
@inline function WeightInitializers.__ones(
function DeviceAgnostic.ones(
::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
return oneAPI.ones(T, dims...)
end
@inline function WeightInitializers.__rand(
function DeviceAgnostic.rand(
rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = oneArray{T}(undef, dims...)
Random.rand!(rng, y)
return y
end
@inline function WeightInitializers.__randn(
function DeviceAgnostic.randn(
rng::RNG, ::oneArray, ::Type{T}, dims::Integer...) where {T <: Number}
y = oneArray{T}(undef, dims...)
Random.randn!(rng, y)
Expand Down
19 changes: 12 additions & 7 deletions src/WeightInitializers.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
module WeightInitializers

using ArgCheck: @argcheck
using ChainRulesCore: ChainRulesCore
using ConcreteStructs: @concrete
using ChainRulesCore: @non_differentiable
using GPUArraysCore: @allowscalar
using LinearAlgebra: LinearAlgebra, Diagonal, qr
using Random: Random, AbstractRNG, Xoshiro, shuffle
using SpecialFunctions: SpecialFunctions, erf, erfinv
using Random: Random, AbstractRNG, shuffle
using SpecialFunctions: SpecialFunctions, erfinv # TODO: Move to Ext in v2.0
using Statistics: Statistics, std

const CRC = ChainRulesCore

include("partial.jl")
include("utils.jl")
include("initializers.jl")
include("autodiff.jl")

# Mark the functions as non-differentiable
for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32,
:zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64,
:randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16,
:randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal,
:kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init]
@eval @non_differentiable $(f)(::Any...)
end

export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16,
rand16, randn16
Expand Down
13 changes: 0 additions & 13 deletions src/autodiff.jl

This file was deleted.

Loading
Loading