From e74b0586c9bcb264a5f14175e190d84142e37591 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 26 Jun 2024 21:30:58 -0700
Subject: [PATCH] Generalize the code

---
 Project.toml                     |  2 ++
 ext/WeightInitializersCUDAExt.jl | 56 +++--------------------------
 src/WeightInitializers.jl        |  2 +-
 src/initializers.jl              | 62 +++++++++++++++-----------------
 4 files changed, 36 insertions(+), 86 deletions(-)

diff --git a/Project.toml b/Project.toml
index afbc7c1..be3e84a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -24,7 +24,9 @@ Aqua = "0.8.7"
 ArgCheck = "2.3.0"
 CUDA = "5.3.2"
 ChainRulesCore = "1.23"
+Documenter = "1.5.0"
 ExplicitImports = "1.6.0"
+GPUArraysCore = "0.1.6"
 LinearAlgebra = "1.10"
 PartialFunctions = "1.2"
 Random = "1.10"
diff --git a/ext/WeightInitializersCUDAExt.jl b/ext/WeightInitializersCUDAExt.jl
index ad1bd50..e97f268 100644
--- a/ext/WeightInitializersCUDAExt.jl
+++ b/ext/WeightInitializersCUDAExt.jl
@@ -6,59 +6,11 @@ using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply
 
 const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG}
 
-for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros)
-    name = Symbol(fname, T)
-    TP = NUM_TO_FPOINT[Symbol(T)]
-    @eval begin
-        function WeightInitializers.$(name)(rng::AbstractCuRNG, dims::Integer...; kwargs...)
-            return CUDA.$(fname)($TP, dims...; kwargs...)
-        end
-    end
-
-    @eval function WeightInitializers.$(name)(rng::AbstractCuRNG; kwargs...)
-        return __partial_apply($name, (rng, (; kwargs...)))
-    end
-end
-
-function WeightInitializers.sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
-        sparsity::Number, std::Number=T(0.01)) where {T <: Number}
-    if length(dims) != 2
-        throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
-    end
-
-    rows, cols = dims
-    prop_zero = min(1.0, sparsity)
-    num_zeros = ceil(Integer, prop_zero * rows)
-    sparse_array = randn(rng, T, dims...) .* T(std)
-    sparse_array[1:num_zeros, :] .= CUDA.zero(T)
-
-    return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1)
+function WeightInitializers.__zeros(::AbstractCuRNG, T::Type, dims::Integer...)
+    return CUDA.zeros(T, dims...)
 end
-
-function WeightInitializers.identity_init(::AbstractCuRNG, ::Type{T}, dims::Integer...;
-        gain::Number=1, shift::Integer=0) where {T <: Number}
-    if length(dims) == 1
-        # Bias initialization
-        return CUDA.zeros(T, dims...)
-    elseif length(dims) == 2
-        # Matrix multiplication
-        rows, cols = dims
-        mat = CUDA.zeros(T, rows, cols)
-        diag_indices = 1:min(rows, cols)
-        CUDA.fill!(view(mat, diag_indices, diag_indices), T(gain))
-        return CUDA.circshift(mat, shift)
-    else
-        # Convolution or more dimensions
-        nin, nout = dims[end - 1], dims[end]
-        centers = map(d -> cld(d, 2), dims[1:(end - 2)])
-        weights = CUDA.zeros(T, dims...)
-        #we should really find a better way to do this
-        CUDA.@allowscalar for i in 1:min(nin, nout)
-            index = (centers..., i, i)
-            weights[index...] = T(gain)
-        end
-        return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
-    end
+function WeightInitializers.__ones(::AbstractCuRNG, T::Type, dims::Integer...)
+    return CUDA.ones(T, dims...)
 end
 
 end
diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl
index 6b485a8..8838112 100644
--- a/src/WeightInitializers.jl
+++ b/src/WeightInitializers.jl
@@ -2,7 +2,7 @@ module WeightInitializers
 
 #! format: off
 using ChainRulesCore: ChainRulesCore
-using GPUArraysCore: GPUArraysCore
+using GPUArraysCore: @allowscalar
 using LinearAlgebra: LinearAlgebra, Diagonal, qr
 using PartialFunctions: :$
 using Random: Random, AbstractRNG, Xoshiro, shuffle
diff --git a/src/initializers.jl b/src/initializers.jl
index 65071f3..7877d2b 100644
--- a/src/initializers.jl
+++ b/src/initializers.jl
@@ -1,18 +1,15 @@
+__zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T} = zeros(T, dims...)
+__ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T} = ones(T, dims...)
+
 for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand, :randn)
     name = Symbol(fname, T)
     docstring = __generic_docstring(string(name))
     TP = NUM_TO_FPOINT[Symbol(T)]
-    if fname in (:ones, :zeros)
-        @eval begin
-            @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...)
-                return $(fname)($TP, dims...; kwargs...)
-            end
-        end
-    else
-        @eval begin
-            @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...)
-                return $(fname)(rng, $TP, dims...; kwargs...)
-            end
+    __fname = fname in (:ones, :zeros) ? Symbol("__", fname) : fname
+
+    @eval begin
+        @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...)
+            return $__fname(rng, $TP, dims...; kwargs...)
         end
     end
 end
@@ -222,9 +219,11 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
     rows, cols = dims
     prop_zero = min(1.0, sparsity)
     num_zeros = ceil(Integer, prop_zero * rows)
+
     sparse_array = randn(rng, T, dims...) .* T(std)
-    sparse_array[1:num_zeros, :] .= zero(T)
-    return mapslices(shuffle, sparse_array; dims=1)
+    fill!(view(sparse_array, 1:num_zeros, :), zero(T))
+
+    return @allowscalar mapslices(shuffle, sparse_array; dims=1)
 end
 
 """
@@ -284,30 +283,27 @@ identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias ini
     5; gain=1.5, shift=(1, 0))
 ```
 """
-function identity_init(::AbstractRNG, ::Type{T}, dims::Integer...;
+function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
         gain::Number=1, shift::Integer=0) where {T <: Number}
-    if length(dims) == 1
-        # Bias initialization
-        return zeros(T, dims...)
-    elseif length(dims) == 2
-        # Matrix multiplication
+    length(dims) == 1 && return __zeros(rng, T, dims...) # Bias initialization
+
+    if length(dims) == 2
         rows, cols = dims
-        mat = zeros(T, rows, cols)
-        for i in 1:min(rows, cols)
-            mat[i, i] = T(gain)
-        end
+        mat = __zeros(rng, T, rows, cols)
+        diag_indices = 1:min(rows, cols)
+        fill!(view(mat, diag_indices, diag_indices), T(gain))
         return circshift(mat, shift)
-    else
-        # Convolution or more dimensions
-        nin, nout = dims[end - 1], dims[end]
-        centers = map(d -> cld(d, 2), dims[1:(end - 2)])
-        weights = zeros(T, dims...)
-        for i in 1:min(nin, nout)
-            index = (centers..., i, i)
-            weights[index...] = T(gain)
-        end
-        return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
     end
+
+    # Convolution or more dimensions
+    nin, nout = dims[end - 1], dims[end]
+    centers = map(d -> cld(d, 2), dims[1:(end - 2)])
+    weights = __zeros(rng, T, dims...)
+    @allowscalar for i in 1:min(nin, nout)
+        index = (centers..., i, i)
+        weights[index...] = T(gain)
+    end
+    return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
 end
 
 # Default Fallbacks for all functions