Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Generalize the code
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 27, 2024
1 parent 7fd4a42 commit e74b058
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 86 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ Aqua = "0.8.7"
ArgCheck = "2.3.0"
CUDA = "5.3.2"
ChainRulesCore = "1.23"
Documenter = "1.5.0"
ExplicitImports = "1.6.0"
GPUArraysCore = "0.1.6"
LinearAlgebra = "1.10"
PartialFunctions = "1.2"
Random = "1.10"
Expand Down
56 changes: 4 additions & 52 deletions ext/WeightInitializersCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,59 +6,11 @@ using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply

const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG}

# Generate CUDA-specific methods of the `ones*`/`zeros*` initializers (`ones16`,
# `zeros32`, `onesC64`, ...) so that a CUDA RNG produces an array built by
# `CUDA.ones`/`CUDA.zeros`.
for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros)
    name = Symbol(fname, T)
    # Element type registered for this suffix in `NUM_TO_FPOINT`.
    TP = NUM_TO_FPOINT[Symbol(T)]

    # `rng` only selects this method; the array is filled with a constant.
    @eval function WeightInitializers.$(name)(
            rng::AbstractCuRNG, dims::Integer...; kwargs...)
        return CUDA.$(fname)($TP, dims...; kwargs...)
    end

    # Form without dimensions: hand the RNG and keyword arguments to
    # `__partial_apply`. The initializer is referenced through its owning module so
    # the name resolves even when it has not been imported into this module.
    @eval function WeightInitializers.$(name)(rng::AbstractCuRNG; kwargs...)
        return __partial_apply(WeightInitializers.$(name), (rng, (; kwargs...)))
    end
end

function WeightInitializers.sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
sparsity::Number, std::Number=T(0.01)) where {T <: Number}
if length(dims) != 2
throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
end

rows, cols = dims
prop_zero = min(1.0, sparsity)
num_zeros = ceil(Integer, prop_zero * rows)
sparse_array = randn(rng, T, dims...) .* T(std)
sparse_array[1:num_zeros, :] .= CUDA.zero(T)

return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1)
# CUDA method of `__zeros`: a CUDA RNG selects this method and the zero-filled array is
# created with `CUDA.zeros`. The RNG value itself is not used.
# The element type is bound as `::Type{T}` so that Julia specializes on it; a
# pass-through `T::Type` argument is not specialized automatically.
function WeightInitializers.__zeros(::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T}
    return CUDA.zeros(T, dims...)
end

function WeightInitializers.identity_init(::AbstractCuRNG, ::Type{T}, dims::Integer...;
gain::Number=1, shift::Integer=0) where {T <: Number}
if length(dims) == 1
# Bias initialization
return CUDA.zeros(T, dims...)
elseif length(dims) == 2
# Matrix multiplication
rows, cols = dims
mat = CUDA.zeros(T, rows, cols)
diag_indices = 1:min(rows, cols)
CUDA.fill!(view(mat, diag_indices, diag_indices), T(gain))
return CUDA.circshift(mat, shift)
else
# Convolution or more dimensions
nin, nout = dims[end - 1], dims[end]
centers = map(d -> cld(d, 2), dims[1:(end - 2)])
weights = CUDA.zeros(T, dims...)
#we should really find a better way to do this
CUDA.@allowscalar for i in 1:min(nin, nout)
index = (centers..., i, i)
weights[index...] = T(gain)
end
return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
end
# CUDA method of `__ones`: a CUDA RNG selects this method and the one-filled array is
# created with `CUDA.ones`. The RNG value itself is not used.
# The element type is bound as `::Type{T}` so that Julia specializes on it; a
# pass-through `T::Type` argument is not specialized automatically.
function WeightInitializers.__ones(::AbstractCuRNG, ::Type{T}, dims::Integer...) where {T}
    return CUDA.ones(T, dims...)
end

end
2 changes: 1 addition & 1 deletion src/WeightInitializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module WeightInitializers

#! format: off
using ChainRulesCore: ChainRulesCore
using GPUArraysCore: GPUArraysCore
using GPUArraysCore: @allowscalar
using LinearAlgebra: LinearAlgebra, Diagonal, qr
using PartialFunctions: :$
using Random: Random, AbstractRNG, Xoshiro, shuffle
Expand Down
62 changes: 29 additions & 33 deletions src/initializers.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
# Zero-filled array with element type `T` and size `dims`. The RNG argument is unused
# here; it is part of the signature so methods can dispatch on the RNG type.
function __zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T}
    return zeros(T, dims...)
end
# One-filled array with element type `T` and size `dims`. The RNG argument is unused
# here; it is part of the signature so methods can dispatch on the RNG type.
function __ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T}
    return ones(T, dims...)
end

for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand, :randn)
name = Symbol(fname, T)
docstring = __generic_docstring(string(name))
TP = NUM_TO_FPOINT[Symbol(T)]
if fname in (:ones, :zeros)
@eval begin
@doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $(fname)($TP, dims...; kwargs...)
end
end
else
@eval begin
@doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $(fname)(rng, $TP, dims...; kwargs...)
end
__fname = fname in (:ones, :zeros) ? Symbol("__", fname) : fname

@eval begin
@doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...)
return $__fname(rng, $TP, dims...; kwargs...)
end
end
end
Expand Down Expand Up @@ -222,9 +219,11 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
rows, cols = dims
prop_zero = min(1.0, sparsity)
num_zeros = ceil(Integer, prop_zero * rows)

sparse_array = randn(rng, T, dims...) .* T(std)
sparse_array[1:num_zeros, :] .= zero(T)
return mapslices(shuffle, sparse_array; dims=1)
fill!(view(sparse_array, 1:num_zeros, :), zero(T))

return @allowscalar mapslices(shuffle, sparse_array; dims=1)
end

"""
Expand Down Expand Up @@ -284,30 +283,27 @@ identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias ini
5; gain=1.5, shift=(1, 0))
```
"""
function identity_init(::AbstractRNG, ::Type{T}, dims::Integer...;
function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=1, shift::Integer=0) where {T <: Number}
if length(dims) == 1
# Bias initialization
return zeros(T, dims...)
elseif length(dims) == 2
# Matrix multiplication
length(dims) == 1 && return __zeros(rng, T, dims...) # Bias initialization

if length(dims) == 2
rows, cols = dims
mat = zeros(T, rows, cols)
for i in 1:min(rows, cols)
mat[i, i] = T(gain)
end
mat = __zeros(rng, T, rows, cols)
diag_indices = 1:min(rows, cols)
fill!(view(mat, diag_indices, diag_indices), T(gain))
return circshift(mat, shift)
else
# Convolution or more dimensions
nin, nout = dims[end - 1], dims[end]
centers = map(d -> cld(d, 2), dims[1:(end - 2)])
weights = zeros(T, dims...)
for i in 1:min(nin, nout)
index = (centers..., i, i)
weights[index...] = T(gain)
end
return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
end

# Convolution or more dimensions
nin, nout = dims[end - 1], dims[end]
centers = map(d -> cld(d, 2), dims[1:(end - 2)])
weights = __zeros(rng, T, dims...)
@allowscalar for i in 1:min(nin, nout)
index = (centers..., i, i)
weights[index...] = T(gain)
end
return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
end

# Default Fallbacks for all functions
Expand Down

0 comments on commit e74b058

Please sign in to comment.