From 665e9f064c5981dab83cce801aebb656c656aa73 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 20:09:05 -0700 Subject: [PATCH 1/4] Run formatter --- .JuliaFormatter.toml | 1 + Project.toml | 5 ++-- ext/WeightInitializersCUDAExt.jl | 4 +-- src/WeightInitializers.jl | 46 +++++--------------------------- src/initializers.jl | 33 +++++++++-------------- src/utils.jl | 9 ++++--- test/runtests.jl | 39 +++++++++++++-------------- 7 files changed, 48 insertions(+), 89 deletions(-) diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index dbc3116..547dbee 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -5,4 +5,5 @@ margin = 92 indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true +join_lines_based_on_source = false always_for_in = true diff --git a/Project.toml b/Project.toml index 67384d9..6a42882 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ version = "0.1.7" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -24,7 +23,6 @@ CUDA = "5" ChainRulesCore = "1.21" LinearAlgebra = "1.9" PartialFunctions = "1.2" -PrecompileTools = "1.2" Random = "1.9" SpecialFunctions = "2" StableRNGs = "1" @@ -36,9 +34,10 @@ julia = "1.9" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "StableRNGs", "Random", "Statistics", "CUDA"] +test = ["Aqua", "CUDA", "Random", "ReTestItems", "StableRNGs", "Statistics", "Test"] diff --git a/ext/WeightInitializersCUDAExt.jl b/ext/WeightInitializersCUDAExt.jl index ac07b42..105ae57 100644 --- a/ext/WeightInitializersCUDAExt.jl +++ b/ext/WeightInitializersCUDAExt.jl @@ -70,8 +70,8 @@ for initializer in (:sparse_init, :identity_init) @eval function ($initializer)(rng::AbstractCuRNG; kwargs...) return __partial_apply($initializer, (rng, (; kwargs...))) end - @eval function ($initializer)(rng::AbstractCuRNG, - ::Type{T}; kwargs...) where {T <: Number} + @eval function ($initializer)( + rng::AbstractCuRNG, ::Type{T}; kwargs...) 
where {T <: Number} return __partial_apply($initializer, ((rng, T), (; kwargs...))) end end diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index 6b17bd5..bac261e 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,50 +1,16 @@ module WeightInitializers -import PrecompileTools: @recompile_invalidations - -@recompile_invalidations begin - using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, - LinearAlgebra -end +using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra include("utils.jl") include("initializers.jl") # Mark the functions as non-differentiable -for f in [ - :zeros64, - :ones64, - :rand64, - :randn64, - :zeros32, - :ones32, - :rand32, - :randn32, - :zeros16, - :ones16, - :rand16, - :randn16, - :zerosC64, - :onesC64, - :randC64, - :randnC64, - :zerosC32, - :onesC32, - :randC32, - :randnC32, - :zerosC16, - :onesC16, - :randC16, - :randnC16, - :glorot_normal, - :glorot_uniform, - :kaiming_normal, - :kaiming_uniform, - :truncated_normal, - :orthogonal, - :sparse_init, - :identity_init -] +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] @eval @non_differentiable $(f)(::Any...) end diff --git a/src/initializers.jl b/src/initializers.jl index fd31046..50deec2 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -4,15 +4,13 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand TP = NUM_TO_FPOINT[Symbol(T)] if fname in (:ones, :zeros) @eval begin - @doc $docstring - function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) return $(fname)($TP, dims...; kwargs...) end end else @eval begin - @doc $docstring - function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) return $(fname)(rng, $TP, dims...; kwargs...) end end @@ -34,8 +32,8 @@ Xavier initialization. feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Number=1) where {T <: Number} +function glorot_uniform( + rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...))) return (rand(rng, T, dims...) .- T(1 // 2)) .* scale end @@ -54,8 +52,8 @@ method is described in [1] and also known as Xavier initialization. feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ -function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; - gain::Number=1) where {T <: Number} +function glorot_normal( + rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1) where {T <: Number} std = T(gain) * sqrt(T(2) / sum(_nfan(dims...))) return randn(rng, T, dims...) 
.* std end @@ -293,14 +291,9 @@ using Random identity_matrix = identity_init(MersenneTwister(123), Float32, 5, 5) # Identity tensor for convolutional layer -identity_tensor = identity_init(MersenneTwister(123), - Float32, # Bias initialization - 3, - 3, - 5, # Matrix multiplication - 5; - gain=1.5, - shift=(1, 0)) +identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias initialization + 3, 3, 5, # Matrix multiplication + 5; gain=1.5, shift=(1, 0)) ``` """ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; @@ -339,15 +332,15 @@ for initializer in (:glorot_uniform, :glorot_normal, :kaiming_uniform, :kaiming_ @eval function ($initializer)(rng::AbstractRNG, dims::Integer...; kwargs...) return $initializer(rng, Float32, dims...; kwargs...) end - @eval function ($initializer)(::Type{T}, - dims::Integer...; kwargs...) where {T <: $NType} + @eval function ($initializer)( + ::Type{T}, dims::Integer...; kwargs...) where {T <: $NType} return $initializer(_default_rng(), T, dims...; kwargs...) end @eval function ($initializer)(rng::AbstractRNG; kwargs...) return __partial_apply($initializer, (rng, (; kwargs...))) end - @eval function ($initializer)(rng::AbstractRNG, - ::Type{T}; kwargs...) where {T <: $NType} + @eval function ($initializer)( + rng::AbstractRNG, ::Type{T}; kwargs...) where {T <: $NType} return __partial_apply($initializer, ((rng, T), (; kwargs...))) end @eval ($initializer)(; kwargs...) = __partial_apply($initializer, (; kwargs...)) diff --git a/src/utils.jl b/src/utils.jl index 765890c..6a933d6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -16,12 +16,13 @@ end # This is needed if using `PartialFunctions.$` inside @eval block __partial_apply(fn, inp) = fn$inp -const NAME_TO_DIST = Dict(:zeros => "an AbstractArray of zeros", - :ones => "an AbstractArray of ones", +const NAME_TO_DIST = Dict( + :zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones", :randn => "random numbers from a standard normal distribution", :rand => "random numbers from a uniform distribution") -const NUM_TO_FPOINT = Dict(Symbol(16) => Float16, Symbol(32) => Float32, - Symbol(64) => Float64, :C16 => ComplexF16, :C32 => ComplexF32, :C64 => ComplexF64) +const NUM_TO_FPOINT = Dict( + Symbol(16) => Float16, Symbol(32) => Float32, Symbol(64) => Float64, + :C16 => ComplexF16, :C32 => ComplexF32, :C64 => ComplexF64) @inline function __funcname(fname::String) fp = fname[(end - 2):end] diff --git a/test/runtests.jl b/test/runtests.jl index aca13c8..a620753 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,10 +32,9 @@ const GROUP = get(ENV, "GROUP", "All") end @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes - @testset "Sizes and Types: $init" for init in [zeros32, ones32, rand32, randn32, - kaiming_uniform, kaiming_normal, glorot_uniform, glorot_normal, - truncated_normal, identity_init - ] + @testset "Sizes and Types: $init" for init in [ + zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, + glorot_uniform, glorot_normal, truncated_normal, identity_init] # Sizes @test size(init(3)) == (3,) @test size(init(rng, 3)) == (3,) @@ -52,15 +51,15 @@ const GROUP = get(ENV, "GROUP", "All") @test cl(3, 5) isa arrtype{Float32, 2} end - @testset "Sizes and Types: $init" for (init, fp) in [(zeros16, Float16), - (zerosC16, ComplexF16), (zeros32, Float32), (zerosC32, ComplexF32), - (zeros64, Float64), (zerosC64, ComplexF64), (ones16, Float16), - (onesC16, ComplexF16), (ones32, Float32), (onesC32, ComplexF32), - (ones64, 
Float64), (onesC64, ComplexF64), (rand16, Float16), - (randC16, ComplexF16), (rand32, Float32), (randC32, ComplexF32), - (rand64, Float64), (randC64, ComplexF64), (randn16, Float16), - (randnC16, ComplexF16), (randn32, Float32), (randnC32, ComplexF32), - (randn64, Float64), (randnC64, ComplexF64)] + @testset "Sizes and Types: $init" for (init, fp) in [ + (zeros16, Float16), (zerosC16, ComplexF16), (zeros32, Float32), + (zerosC32, ComplexF32), (zeros64, Float64), (zerosC64, ComplexF64), + (ones16, Float16), (onesC16, ComplexF16), (ones32, Float32), + (onesC32, ComplexF32), (ones64, Float64), (onesC64, ComplexF64), + (rand16, Float16), (randC16, ComplexF16), (rand32, Float32), + (randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64), + (randn16, Float16), (randnC16, ComplexF16), (randn32, Float32), + (randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)] # Sizes @test size(init(3)) == (3,) @test size(init(rng, 3)) == (3,) @@ -77,11 +76,10 @@ const GROUP = get(ENV, "GROUP", "All") @test cl(3, 5) isa arrtype{fp, 2} end - @testset "AbstractArray Type: $init $T" for init in [kaiming_uniform, - kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal, identity_init], - T in (Float16, Float32, - Float64, ComplexF16, ComplexF32, ComplexF64) + @testset "AbstractArray Type: $init $T" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init], + T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) init === truncated_normal && !(T <: Real) && continue @@ -99,8 +97,9 @@ const GROUP = get(ENV, "GROUP", "All") @test cl(3, 5) isa arrtype{T, 2} end - @testset "Closure: $init" for init in [kaiming_uniform, kaiming_normal, - glorot_uniform, glorot_normal, truncated_normal, identity_init] + @testset "Closure: $init" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init] cl = init(;) # Sizes @test size(cl(3)) == (3,) From 7fd4a422b9f4ff82d5607a56230e39723b4b5c0a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 21:08:08 -0700 Subject: [PATCH 2/4] Minor cleanups --- .github/workflows/CI.yml | 2 ++ .github/workflows/Downgrade.yml | 2 +- Project.toml | 29 +++++++++++++++++------------ README.md | 4 ++-- ext/WeightInitializersCUDAExt.jl | 25 +++++-------------------- src/WeightInitializers.jl | 27 +++++++++++++-------------- src/autodiff.jl | 8 ++++++++ src/initializers.jl | 20 ++++---------------- src/utils.jl | 19 +++++-------------- 9 files changed, 57 insertions(+), 79 deletions(-) create mode 100644 src/autodiff.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2200a35..2ad20de 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -38,6 +38,8 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: GROUP: "CPU" + RETESTITEMS_NWORKERS: 4 + RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 with: directories: src,ext diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index c57d5e3..269275e 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - version: ['1.9'] + version: ['1'] steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/Project.toml b/Project.toml index 6a42882..afbc7c1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,12 @@ name = "WeightInitializers" uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" authors = 
["Avik Pal and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -18,26 +20,29 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" WeightInitializersCUDAExt = "CUDA" [compat] -Aqua = "0.8" -CUDA = "5" -ChainRulesCore = "1.21" -LinearAlgebra = "1.9" +Aqua = "0.8.7" +ArgCheck = "2.3.0" +CUDA = "5.3.2" +ChainRulesCore = "1.23" +ExplicitImports = "1.6.0" +LinearAlgebra = "1.10" PartialFunctions = "1.2" -Random = "1.9" +Random = "1.10" +ReTestItems = "1.24.0" SpecialFunctions = "2" StableRNGs = "1" -Statistics = "1.9" -Test = "1.9" -julia = "1.9" +Statistics = "1.10" +Test = "1.10" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "CUDA", "Random", "ReTestItems", "StableRNGs", "Statistics", "Test"] +test = ["Aqua", "CUDA", "Documenter", "ExplicitImports", "ReTestItems", "StableRNGs", "Test"] diff --git a/README.md b/README.md index a730522..edede1c 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # WeightInitializers [![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning) -[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/) -[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/) +[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Building_Blocks/WeightInitializers) +[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Building_Blocks/WeightInitializers) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl) diff --git a/ext/WeightInitializersCUDAExt.jl b/ext/WeightInitializersCUDAExt.jl index 105ae57..ad1bd50 100644 --- a/ext/WeightInitializersCUDAExt.jl +++ b/ext/WeightInitializersCUDAExt.jl @@ -1,9 +1,8 @@ module WeightInitializersCUDAExt -using WeightInitializers, CUDA -using Random -import WeightInitializers: __partial_apply, NUM_TO_FPOINT, identity_init, sparse_init, - orthogonal +using CUDA: CUDA, CURAND +using Random: Random, shuffle +using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} @@ -21,7 +20,7 @@ for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) end end -function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; +function WeightInitializers.sparse_init(rng::AbstractCuRNG, 
::Type{T}, dims::Integer...; sparsity::Number, std::Number=T(0.01)) where {T <: Number} if length(dims) != 2 throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) @@ -36,7 +35,7 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1) end -function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; +function WeightInitializers.identity_init(::AbstractCuRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} if length(dims) == 1 # Bias initialization @@ -62,18 +61,4 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; end end -for initializer in (:sparse_init, :identity_init) - @eval function ($initializer)(rng::AbstractCuRNG, dims::Integer...; kwargs...) - return $initializer(rng, Float32, dims...; kwargs...) - end - - @eval function ($initializer)(rng::AbstractCuRNG; kwargs...) - return __partial_apply($initializer, (rng, (; kwargs...))) - end - @eval function ($initializer)( - rng::AbstractCuRNG, ::Type{T}; kwargs...) where {T <: Number} - return __partial_apply($initializer, ((rng, T), (; kwargs...))) - end -end - end diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index bac261e..6b485a8 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -1,18 +1,20 @@ module WeightInitializers -using ChainRulesCore, PartialFunctions, Random, SpecialFunctions, Statistics, LinearAlgebra +#! format: off +using ChainRulesCore: ChainRulesCore +using GPUArraysCore: GPUArraysCore +using LinearAlgebra: LinearAlgebra, Diagonal, qr +using PartialFunctions: :$ +using Random: Random, AbstractRNG, Xoshiro, shuffle +using SpecialFunctions: SpecialFunctions, erf, erfinv +using Statistics: Statistics, std +#! format: on + +const CRC = ChainRulesCore include("utils.jl") include("initializers.jl") - -# Mark the functions as non-differentiable -for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, - :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, - :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, - :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, - :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] - @eval @non_differentiable $(f)(::Any...) -end +include("autodiff.jl") export zeros64, ones64, rand64, randn64, zeros32, ones32, rand32, randn32, zeros16, ones16, rand16, randn16 @@ -20,9 +22,6 @@ export zerosC64, onesC64, randC64, randnC64, zerosC32, onesC32, randC32, randnC3 onesC16, randC16, randnC16 export glorot_normal, glorot_uniform export kaiming_normal, kaiming_uniform -export truncated_normal -export orthogonal -export sparse_init -export identity_init +export truncated_normal, orthogonal, sparse_init, identity_init end diff --git a/src/autodiff.jl b/src/autodiff.jl new file mode 100644 index 0000000..cd9e7d6 --- /dev/null +++ b/src/autodiff.jl @@ -0,0 +1,8 @@ +# Mark the functions as non-differentiable +for f in [:zeros64, :ones64, :rand64, :randn64, :zeros32, :ones32, :rand32, :randn32, + :zeros16, :ones16, :rand16, :randn16, :zerosC64, :onesC64, :randC64, + :randnC64, :zerosC32, :onesC32, :randC32, :randnC32, :zerosC16, :onesC16, + :randC16, :randnC16, :glorot_normal, :glorot_uniform, :kaiming_normal, + :kaiming_uniform, :truncated_normal, :orthogonal, :sparse_init, :identity_init] + @eval CRC.@non_differentiable $(f)(::Any...) 
+end diff --git a/src/initializers.jl b/src/initializers.jl index 50deec2..65071f3 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -152,26 +152,14 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed" - if length(dims) == 2 - rows, cols = dims - else - rows = prod(dims[1:(end - 1)]) - cols = dims[end] - end - - if rows < cols - return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) - end + rows, cols = length(dims) == 2 ? dims : (prod(dims[1:(end - 1)]), dims[end]) + rows < cols && return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain))) mat = randn(rng, T, rows, cols) Q, R = qr(mat) mat .= Q * sign.(Diagonal(R)) .* T(gain) - if length(dims) > 2 - return reshape(mat, dims) - else - return mat - end + return length(dims) > 2 ? reshape(mat, dims) : mat end """ @@ -296,7 +284,7 @@ identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias ini 5; gain=1.5, shift=(1, 0)) ``` """ -function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; +function identity_init(::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} if length(dims) == 1 # Bias initialization diff --git a/src/utils.jl b/src/utils.jl index 6a933d6..6dbc6b7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,18 +3,12 @@ @inline _nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices @inline _nfan(dims::Tuple) = _nfan(dims...) @inline _nfan(dims...) = prod(dims[1:(end - 2)]) .* (dims[end - 1], dims[end]) # In case of convolution kernels -_norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) +@inline _norm_cdf(x::T) where {T} = T(0.5) * (1 + erf(x / √2)) -function _default_rng() - @static if VERSION >= v"1.7" - return Xoshiro(1234) - else - return MersenneTwister(1234) - end -end +@inline _default_rng() = Xoshiro(1234) # This is needed if using `PartialFunctions.$` inside @eval block -__partial_apply(fn, inp) = fn$inp +@inline __partial_apply(fn, inp) = fn$inp const NAME_TO_DIST = Dict( :zeros => "an AbstractArray of zeros", :ones => "an AbstractArray of ones", @@ -26,11 +20,8 @@ const NUM_TO_FPOINT = Dict( @inline function __funcname(fname::String) fp = fname[(end - 2):end] - if Symbol(fp) in keys(NUM_TO_FPOINT) - return fname[1:(end - 3)], fp - else - return fname[1:(end - 2)], fname[(end - 1):end] - end + Symbol(fp) in keys(NUM_TO_FPOINT) && return fname[1:(end - 3)], fp + return fname[1:(end - 2)], fname[(end - 1):end] end @inline function __generic_docstring(fname::String) From e74b0586c9bcb264a5f14175e190d84142e37591 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 21:30:58 -0700 Subject: [PATCH 3/4] Generalize the code --- Project.toml | 2 ++ ext/WeightInitializersCUDAExt.jl | 56 +++-------------------------- src/WeightInitializers.jl | 2 +- src/initializers.jl | 62 +++++++++++++++----------------- 4 files changed, 36 insertions(+), 86 deletions(-) diff --git a/Project.toml b/Project.toml index afbc7c1..be3e84a 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,9 @@ Aqua = "0.8.7" ArgCheck = "2.3.0" CUDA = "5.3.2" ChainRulesCore = "1.23" +Documenter = "1.5.0" ExplicitImports = "1.6.0" +GPUArraysCore = "0.1.6" LinearAlgebra = "1.10" PartialFunctions = "1.2" Random = "1.10" diff --git a/ext/WeightInitializersCUDAExt.jl b/ext/WeightInitializersCUDAExt.jl index ad1bd50..e97f268 100644 --- a/ext/WeightInitializersCUDAExt.jl +++ 
b/ext/WeightInitializersCUDAExt.jl @@ -6,59 +6,11 @@ using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} -for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros) - name = Symbol(fname, T) - TP = NUM_TO_FPOINT[Symbol(T)] - @eval begin - function WeightInitializers.$(name)(rng::AbstractCuRNG, dims::Integer...; kwargs...) - return CUDA.$(fname)($TP, dims...; kwargs...) - end - end - - @eval function WeightInitializers.$(name)(rng::AbstractCuRNG; kwargs...) - return __partial_apply($name, (rng, (; kwargs...))) - end -end - -function WeightInitializers.sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...; - sparsity::Number, std::Number=T(0.01)) where {T <: Number} - if length(dims) != 2 - throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) - end - - rows, cols = dims - prop_zero = min(1.0, sparsity) - num_zeros = ceil(Integer, prop_zero * rows) - sparse_array = randn(rng, T, dims...) .* T(std) - sparse_array[1:num_zeros, :] .= CUDA.zero(T) - - return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1) +function WeightInitializers.__zeros(::AbstractCuRNG, T::Type, dims::Integer...) + return CUDA.zeros(T, dims...) end - -function WeightInitializers.identity_init(::AbstractCuRNG, ::Type{T}, dims::Integer...; - gain::Number=1, shift::Integer=0) where {T <: Number} - if length(dims) == 1 - # Bias initialization - return CUDA.zeros(T, dims...) - elseif length(dims) == 2 - # Matrix multiplication - rows, cols = dims - mat = CUDA.zeros(T, rows, cols) - diag_indices = 1:min(rows, cols) - CUDA.fill!(view(mat, diag_indices, diag_indices), T(gain)) - return CUDA.circshift(mat, shift) - else - # Convolution or more dimensions - nin, nout = dims[end - 1], dims[end] - centers = map(d -> cld(d, 2), dims[1:(end - 2)]) - weights = CUDA.zeros(T, dims...) - #we should really find a better way to do this - CUDA.@allowscalar for i in 1:min(nin, nout) - index = (centers..., i, i) - weights[index...] = T(gain) - end - return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) - end +function WeightInitializers.__ones(::AbstractCuRNG, T::Type, dims::Integer...) + return CUDA.ones(T, dims...) end end diff --git a/src/WeightInitializers.jl b/src/WeightInitializers.jl index 6b485a8..8838112 100644 --- a/src/WeightInitializers.jl +++ b/src/WeightInitializers.jl @@ -2,7 +2,7 @@ module WeightInitializers #! format: off using ChainRulesCore: ChainRulesCore -using GPUArraysCore: GPUArraysCore +using GPUArraysCore: @allowscalar using LinearAlgebra: LinearAlgebra, Diagonal, qr using PartialFunctions: :$ using Random: Random, AbstractRNG, Xoshiro, shuffle diff --git a/src/initializers.jl b/src/initializers.jl index 65071f3..7877d2b 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -1,18 +1,15 @@ +__zeros(::AbstractRNG, ::Type{T}, dims::Integer...) where {T} = zeros(T, dims...) +__ones(::AbstractRNG, ::Type{T}, dims::Integer...) where {T} = ones(T, dims...) + for T in ("16", "32", "64", "C16", "C32", "C64"), fname in (:ones, :zeros, :rand, :randn) name = Symbol(fname, T) docstring = __generic_docstring(string(name)) TP = NUM_TO_FPOINT[Symbol(T)] - if fname in (:ones, :zeros) - @eval begin - @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) - return $(fname)($TP, dims...; kwargs...) - end - end - else - @eval begin - @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) 
- return $(fname)(rng, $TP, dims...; kwargs...) - end + __fname = fname in (:ones, :zeros) ? Symbol("__", fname) : fname + + @eval begin + @doc $docstring function $(name)(rng::AbstractRNG, dims::Integer...; kwargs...) + return $__fname(rng, $TP, dims...; kwargs...) end end end @@ -222,9 +219,11 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; rows, cols = dims prop_zero = min(1.0, sparsity) num_zeros = ceil(Integer, prop_zero * rows) + sparse_array = randn(rng, T, dims...) .* T(std) - sparse_array[1:num_zeros, :] .= zero(T) - return mapslices(shuffle, sparse_array; dims=1) + fill!(view(sparse_array, 1:num_zeros, :), zero(T)) + + return @allowscalar mapslices(shuffle, sparse_array; dims=1) end """ @@ -284,30 +283,27 @@ identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias ini 5; gain=1.5, shift=(1, 0)) ``` """ -function identity_init(::AbstractRNG, ::Type{T}, dims::Integer...; +function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=1, shift::Integer=0) where {T <: Number} - if length(dims) == 1 - # Bias initialization - return zeros(T, dims...) - elseif length(dims) == 2 - # Matrix multiplication + length(dims) == 1 && return __zeros(rng, T, dims...) # Bias initialization + + if length(dims) == 2 rows, cols = dims - mat = zeros(T, rows, cols) - for i in 1:min(rows, cols) - mat[i, i] = T(gain) - end + mat = __zeros(rng, T, rows, cols) + diag_indices = 1:min(rows, cols) + fill!(view(mat, diag_indices, diag_indices), T(gain)) return circshift(mat, shift) - else - # Convolution or more dimensions - nin, nout = dims[end - 1], dims[end] - centers = map(d -> cld(d, 2), dims[1:(end - 2)]) - weights = zeros(T, dims...) - for i in 1:min(nin, nout) - index = (centers..., i, i) - weights[index...] = T(gain) - end - return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) end + + # Convolution or more dimensions + nin, nout = dims[end - 1], dims[end] + centers = map(d -> cld(d, 2), dims[1:(end - 2)]) + weights = __zeros(rng, T, dims...) + @allowscalar for i in 1:min(nin, nout) + index = (centers..., i, i) + weights[index...] 
= T(gain) + end + return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift)) end # Default Fallbacks for all functions From 4f5b4eaf188b6d238a477e3f0f42caff7d860aa0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 26 Jun 2024 23:11:37 -0700 Subject: [PATCH 4/4] Finish rewriting the tests --- .buildkite/pipeline.yml | 6 +- .github/workflows/CI.yml | 2 +- .github/workflows/Downgrade.yml | 2 +- .github/workflows/Downstream.yml | 2 +- .github/workflows/FormatCheck.yml | 40 ---- .github/workflows/QualityCheck.yml | 19 ++ .typos.toml | 2 + Project.toml | 2 - README.md | 1 - ext/WeightInitializersCUDAExt.jl | 3 +- src/initializers.jl | 96 +++++----- test/initializers_tests.jl | 267 +++++++++++++++++++++++++++ test/qa_tests.jl | 23 +++ test/runtests.jl | 287 +---------------------------- test/shared_testsetup.jl | 20 ++ test/utils_tests.jl | 9 + 16 files changed, 397 insertions(+), 384 deletions(-) delete mode 100644 .github/workflows/FormatCheck.yml create mode 100644 .github/workflows/QualityCheck.yml create mode 100644 .typos.toml create mode 100644 test/initializers_tests.jl create mode 100644 test/qa_tests.jl create mode 100644 test/shared_testsetup.jl create mode 100644 test/utils_tests.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index a625b0f..565e58f 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -16,7 +16,7 @@ steps: queue: "juliagpu" cuda: "*" env: - GROUP: "CUDA" + BACKEND_GROUP: "CUDA" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 240 matrix: @@ -61,7 +61,7 @@ steps: queue: "juliagpu" cuda: "*" env: - GROUP: "CUDA" + BACKEND_GROUP: "CUDA" DOWNSTREAM_TEST_REPO: "{{matrix.repo}}" if: build.message !~ /\[skip tests\]/ || build.message !~ /\[skip downstream\]/ timeout_in_minutes: 240 @@ -111,7 +111,7 @@ steps: rocm: "*" rocmgpu: "*" env: - GROUP: "AMDGPU" + BACKEND_GROUP: "AMDGPU" JULIA_AMDGPU_CORE_MUST_LOAD: "1" JULIA_AMDGPU_HIP_MUST_LOAD: "1" JULIA_AMDGPU_DISABLE_ARTIFACTS: "1" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2ad20de..6596d9d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -37,7 +37,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: "CPU" + BACKEND_GROUP: "CPU" RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index 269275e..5a5bcb1 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -27,7 +27,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - GROUP: "CPU" + BACKEND_GROUP: "CPU" RETESTITEMS_NWORKERS: 4 RETESTITEMS_NWORKER_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index b215b2b..bf579cb 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -16,7 +16,7 @@ jobs: name: ${{ matrix.package.repo }}/${{ matrix.package.group }} runs-on: ${{ matrix.os }} env: - GROUP: ${{ matrix.package.group }} + BACKEND_GROUP: ${{ matrix.package.group }} strategy: fail-fast: false matrix: diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml deleted file mode 100644 index ac75c52..0000000 --- a/.github/workflows/FormatCheck.yml +++ /dev/null @@ -1,40 +0,0 @@ -name: FormatCheck - -on: - push: - branches: - - 'main' - - 'release-' - tags: ['*'] - 
pull_request: - -jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - julia-version: ["1"] - julia-arch: [x86] - os: [ubuntu-latest] - steps: - - uses: julia-actions/setup-julia@latest - with: - version: ${{ matrix.julia-version }} - - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' - - name: Format check - run: | - julia -e ' - out = Cmd(`git diff --name-only`) |> read |> String - if out == "" - exit(0) - else - @error "Some files have not been formatted !!!" - write(stdout, out) - exit(1) - end' - \ No newline at end of file diff --git a/.github/workflows/QualityCheck.yml b/.github/workflows/QualityCheck.yml new file mode 100644 index 0000000..3bfa611 --- /dev/null +++ b/.github/workflows/QualityCheck.yml @@ -0,0 +1,19 @@ +name: Code Quality Check + +on: [pull_request] + +jobs: + code-style: + name: Format Suggestions + runs-on: ubuntu-latest + steps: + - uses: julia-actions/julia-format@v3 + + typos-check: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + - name: Check spelling + uses: crate-ci/typos@v1.22.9 diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 0000000..4b87229 --- /dev/null +++ b/.typos.toml @@ -0,0 +1,2 @@ +[default.extend-words] +nin = "nin" diff --git a/Project.toml b/Project.toml index be3e84a..6981002 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["Avik Pal and contributors"] version = "0.1.8" [deps] -ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -21,7 +20,6 @@ WeightInitializersCUDAExt = "CUDA" [compat] Aqua = "0.8.7" -ArgCheck = "2.3.0" CUDA = "5.3.2" ChainRulesCore = "1.23" Documenter = "1.5.0" diff --git a/README.md b/README.md index edede1c..4dc182c 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,6 @@ [![Build status](https://badge.buildkite.com/ffa2c8c3629cd58322446cddd3e8dcc4f121c28a574ee3e626.svg?branch=main)](https://buildkite.com/julialang/weightinitializers-dot-jl) [![CI](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/WeightInitializers.jl/actions/workflows/CI.yml) [![codecov](https://codecov.io/gh/LuxDL/WeightInitializers.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/WeightInitializers.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/WeightInitializers)](https://pkgs.genieframework.com?packages=WeightInitializers) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) [![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle) diff --git a/ext/WeightInitializersCUDAExt.jl b/ext/WeightInitializersCUDAExt.jl index e97f268..ac2d391 100644 --- a/ext/WeightInitializersCUDAExt.jl +++ b/ext/WeightInitializersCUDAExt.jl @@ -1,8 +1,7 @@ module WeightInitializersCUDAExt using CUDA: CUDA, CURAND -using Random: Random, shuffle -using WeightInitializers: WeightInitializers, NUM_TO_FPOINT, __partial_apply +using WeightInitializers: 
WeightInitializers const AbstractCuRNG = Union{CUDA.RNG, CURAND.RNG} diff --git a/src/initializers.jl b/src/initializers.jl index 7877d2b..2a5e4c8 100644 --- a/src/initializers.jl +++ b/src/initializers.jl @@ -104,7 +104,8 @@ truncated normal distribution. The numbers are distributed like function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(0), std=T(1), lo=-T(2), hi=T(2)) where {T <: Real} if (mean < lo - 2 * std) || (mean > hi + 2 * std) - @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." + @warn "Mean is more than 2 std outside the limits in truncated_normal, so the \ + distribution of values may be inaccurate." end l = _norm_cdf((T(lo) - T(mean)) / T(std)) u = _norm_cdf((T(hi) - T(mean)) / T(std)) @@ -122,13 +123,12 @@ end gain = 1) -> AbstractArray{T, length(dims)} Return an `AbstractArray{T}` of the given dimensions (`dims`) which is a -(semi) orthogonal matrix, as described in [^Saxe14] +(semi) orthogonal matrix, as described in [1]. The function constructs an orthogonal or semi-orthogonal matrix depending on the specified -dimensions. For two dimensions, it returns a matrix where `dims = (rows, cols)`. -For more than two dimensions, it computes an orthogonal matrix of -size `prod(dims[1:(end - 1)])` by `dims[end]` before reshaping it to -the original dimensions. +dimensions. For two dimensions, it returns a matrix where `dims = (rows, cols)`. For more +than two dimensions, it computes an orthogonal matrix of size `prod(dims[1:(end - 1)])` by +`dims[end]` before reshaping it to the original dimensions. Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. @@ -141,9 +141,8 @@ Cannot construct a vector, i.e., `length(dims) == 1` is forbidden. # References -[^Saxe14] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of -learning in deep linear neural networks", -ICLR 2014, https://arxiv.org/abs/1312.6120 +[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in +deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...; gain::Number=T(1.0)) where {T <: Number} @@ -164,12 +163,15 @@ end sparsity::Number, std::Number=0.01) -> AbstractArray{T} Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, -using random numbers drawn from a normal distribution for the non-zero elements. -This method is introduced in [^Martens2010]. -Note: The sparsity parameter controls the proportion of the matrix that will be zeroed. -For example, a sparsity of 0.3 means that approximately 30% of the elements will be -set to zero. The non-zero elements are distributed according to a normal distribution, -scaled by the std parameter. +using random numbers drawn from a normal distribution for the non-zero elements. This method +was introduced in [1]. + +!!! note + + The sparsity parameter controls the proportion of the matrix that will be zeroed. For + example, a sparsity of 0.3 means that approximately 30% of the elements will be set to + zero. The non-zero elements are distributed according to a normal distribution, scaled + by the std parameter. # Arguments @@ -177,43 +179,36 @@ scaled by the std parameter. - `T::Type{<:Number}`: The numeric type of the elements in the returned array. - `dims::Integer...`: The dimensions of the weight matrix to be generated. - `sparsity::Number`: The proportion of elements to be zeroed. 
Must be between 0 and 1. - - `std::Number=0.01`: The standard deviation of the normal distribution - before applying `gain`. + - `std::Number=0.01`: The standard deviation of the normal distribution before applying + `gain`. # Returns - - `AbstractArray{T}`: A sparsely initialized weight matrix of dimensions `dims` - and type `T`. + - `AbstractArray{T}`: A sparsely initialized weight matrix of dimensions `dims` and type + `T`. # Examples -```julia -using Random +```jldoctest +julia> y = sparse_init(Xoshiro(123), Float32, 5, 5; sparsity=0.3, std=0.01); -# Initialize a 5x5 sparsely initialized matrix with 30% sparsity -rng = MersenneTwister(123) -matrix = sparse_init(rng, Float32, 5, 5; sparsity=0.3, std=0.01) -``` +julia> y isa Matrix{Float32} +true -``` -5×5 Matrix{Float64}: - 0.0 0.00273815 0.00592403 0.0 0.0 - 0.00459416 -0.000754831 -0.00888936 -0.0077507 0.0 - 0.0 -0.00194229 0.0 0.0 -0.00468489 - 0.0114265 0.0 0.0 -0.00734886 0.00277726 - -0.00396679 0.0 0.00327215 -0.0071741 -0.00880897 +julia> size(y) == (5, 5) +true ``` # References -[^Martens2010] Martens, J, "Deep learning via Hessian-free optimization" -_Proceedings of the 27th International Conference on International Conference -on Machine Learning_. 2010. +[1] Martens, J, "Deep learning via Hessian-free optimization" Proceedings of the 27th +International Conference on International Conference on Machine Learning. 2010. """ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; sparsity::Number, std::Number=T(0.01)) where {T <: Number} if length(dims) != 2 - throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization.")) + throw(ArgumentError("Only 2-dimensional outputs are supported for sparse \ + initialization.")) end rows, cols = dims @@ -250,8 +245,8 @@ most layers of a neural network. The identity mapping is scaled by the `gain` pa - Layers must have `input_size == output_size` for a perfect identity mapping. In cases where this condition is not met, the function pads extra dimensions with zeros. - For convolutional layers to achieve an identity mapping, kernel sizes must be odd, - and appropriate padding must be applied to ensure the output - feature maps are the same size as the input feature maps. + and appropriate padding must be applied to ensure the output feature maps are the same + size as the input feature maps. # Arguments @@ -271,16 +266,21 @@ most layers of a neural network. 
The identity mapping is scaled by the `gain` pa # Examples -```julia -using Random - -# Identity matrix for fully connected layer -identity_matrix = identity_init(MersenneTwister(123), Float32, 5, 5) - -# Identity tensor for convolutional layer -identity_tensor = identity_init(MersenneTwister(123), Float32, # Bias initialization - 3, 3, 5, # Matrix multiplication - 5; gain=1.5, shift=(1, 0)) +```jldoctest +julia> identity_init(Xoshiro(123), Float32, 5, 5) +5×5 Matrix{Float32}: + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 1.0 + +julia> identity_init(Xoshiro(123), Float32, 3, 3, 1, 1; gain=1.5) +3×3×1×1 Array{Float32, 4}: +[:, :, 1, 1] = + 0.0 0.0 0.0 + 0.0 1.5 0.0 + 0.0 0.0 0.0 ``` """ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...; diff --git a/test/initializers_tests.jl b/test/initializers_tests.jl new file mode 100644 index 0000000..202e10d --- /dev/null +++ b/test/initializers_tests.jl @@ -0,0 +1,267 @@ +@testitem "Warning: truncated_normal" begin + @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \ + the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0) +end + +@testitem "Identity Initialization" begin + @testset "Non-identity sizes" begin + @test identity_init(2, 3)[:, end] == zeros(Float32, 2) + @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2) + @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3) + @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3) + @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3) + end +end + +@testitem "Orthogonal Initialization" setup=[SharedTestSetup] begin + using GPUArraysCore, LinearAlgebra + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + # A matrix of dim = (m,n) with m > n should produce a QR decomposition. + # In the other case, the transpose should be taken to compute the QR decomposition. + for (rows, cols) in [(5, 3), (3, 5)] + v = orthogonal(rng, rows, cols) + GPUArraysCore.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) : + (@test v' * v ≈ I(cols)) + end + + for mat in [(3, 4, 5), (2, 2, 5)] + v = orthogonal(rng, mat...) + cols = mat[end] + rows = div(prod(mat), cols) + v = reshape(v, (rows, cols)) + GPUArraysCore.@allowscalar rows < cols ? 
(@test v * v' ≈ I(rows)) : + (@test v' * v ≈ I(cols)) + end + + @testset "Orthogonal Types $T" for T in (Float32, Float64) + @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T + @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T + end + + @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64) + @test orthogonal(rng, T, 3, 5) isa AbstractArray{T, 2} + @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2} + + cl = orthogonal(rng) + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = orthogonal(rng, T) + @test cl(3, 5) isa arrtype{T, 2} + end + + @testset "Orthogonal Closure" begin + cl = orthogonal(;) + + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end +end + +@testitem "Sparse Initialization" setup=[SharedTestSetup] begin + using Statistics + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + # sparse_init should yield an error for non 2-d dimensions + # sparse_init should yield no zero elements if sparsity < 0 + # sparse_init should yield all zero elements if sparsity > 1 + # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for + # other sparsity values + # sparse_init should yield a kernel in its non-zero elements consistent with the std + # parameter + + @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1) + @test_throws ArgumentError sparse_init(3, sparsity=0.1) + v = sparse_init(100, 100; sparsity=-0.1) + @test sum(v .== 0) == 0 + v = sparse_init(100, 100; sparsity=1.1) + @test sum(v .== 0) == length(v) + + for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)] + expected_zeros = ceil(Integer, n_in * sparsity) + v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ) + @test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out]) + @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ + end + + @testset "sparse_init Types $T" for T in (Float16, Float32, Float64) + @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T + end + + @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64) + @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2} + @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2} + + cl = sparse_init(rng; sparsity=0.5) + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = sparse_init(rng, T; sparsity=0.5) + @test cl(3, 5) isa arrtype{T, 2} + end + + @testset "sparse_init Closure" begin + cl = sparse_init(; sparsity=0.5) + # Sizes + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + end +end + +@testitem "Basic Initializations" setup=[SharedTestSetup] begin + using LinearAlgebra, Statistics + + @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in RNGS_ARRTYPES + @testset "Sizes and Types: $init" for init in [ + zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal, + glorot_uniform, glorot_normal, truncated_normal, identity_init] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == Float32 + @test eltype(init(4, 2)) == Float32 + # RNG 
Closure + cl = init(rng) + @test cl(3) isa arrtype{Float32, 1} + @test cl(3, 5) isa arrtype{Float32, 2} + end + + @testset "Sizes and Types: $init" for (init, fp) in [ + (zeros16, Float16), (zerosC16, ComplexF16), (zeros32, Float32), + (zerosC32, ComplexF32), (zeros64, Float64), (zerosC64, ComplexF64), + (ones16, Float16), (onesC16, ComplexF16), (ones32, Float32), + (onesC32, ComplexF32), (ones64, Float64), (onesC64, ComplexF64), + (rand16, Float16), (randC16, ComplexF16), (rand32, Float32), + (randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64), + (randn16, Float16), (randnC16, ComplexF16), (randn32, Float32), + (randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)] + # Sizes + @test size(init(3)) == (3,) + @test size(init(rng, 3)) == (3,) + @test size(init(3, 4)) == (3, 4) + @test size(init(rng, 3, 4)) == (3, 4) + @test size(init(3, 4, 5)) == (3, 4, 5) + @test size(init(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(init(rng, 4, 2)) == fp + @test eltype(init(4, 2)) == fp + # RNG Closure + cl = init(rng) + @test cl(3) isa arrtype{fp, 1} + @test cl(3, 5) isa arrtype{fp, 2} + end + + @testset "AbstractArray Type: $init $T" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init], + T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + + init === truncated_normal && !(T <: Real) && continue + + @test init(T, 3) isa AbstractArray{T, 1} + @test init(rng, T, 3) isa arrtype{T, 1} + @test init(T, 3, 5) isa AbstractArray{T, 2} + @test init(rng, T, 3, 5) isa arrtype{T, 2} + + cl = init(rng) + @test cl(T, 3) isa arrtype{T, 1} + @test cl(T, 3, 5) isa arrtype{T, 2} + + cl = init(rng, T) + @test cl(3) isa arrtype{T, 1} + @test cl(3, 5) isa arrtype{T, 2} + end + + @testset "Closure: $init" for init in [ + kaiming_uniform, kaiming_normal, glorot_uniform, + glorot_normal, truncated_normal, identity_init] + cl = init(;) + # Sizes + @test size(cl(3)) == (3,) + @test size(cl(rng, 3)) == (3,) + @test size(cl(3, 4)) == (3, 4) + @test size(cl(rng, 3, 4)) == (3, 4) + @test size(cl(3, 4, 5)) == (3, 4, 5) + @test size(cl(rng, 3, 4, 5)) == (3, 4, 5) + # Type + @test eltype(cl(4, 2)) == Float32 + @test eltype(cl(rng, 4, 2)) == Float32 + end + + @testset "Kwargs types" for T in ( + Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64) + if (T <: Real) + @test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T + @test eltype(orthogonal(T, 2, 5; gain=1.0)) == T + end + @test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T + @test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T + @test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T + @test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T + @test eltype(identity_init(T, 2, 5; gain=1.0)) == T + @test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T + end + + @testset "kaiming" begin + # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)] + # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out) + for (n_in, n_out) in [(100, 100), (100, 400)] + v = kaiming_uniform(rng, n_in, n_out) + σ2 = sqrt(6 / n_out) + @test -1σ2 < minimum(v) < -0.9σ2 + @test 0.9σ2 < maximum(v) < 1σ2 + + v = kaiming_normal(rng, n_in, n_out) + σ2 = sqrt(2 / n_out) + @test 0.9σ2 < std(v) < 1.1σ2 + end + # Type + @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32 + @test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32 + end + + @testset "glorot: $init" for init in [glorot_uniform, glorot_normal] + # 
glorot_uniform and glorot_normal should both yield a kernel with
+            # variance ≈ 2/(fan_in + fan_out)
+            for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
+                v = init(dims...)
+                fan_in, fan_out = WeightInitializers._nfan(dims...)
+                σ2 = 2 / (fan_in + fan_out)
+                @test 0.9σ2 < var(v) < 1.1σ2
+            end
+            @test eltype(init(3, 4; gain=1.5)) == Float32
+        end
+
+        @testset "orthogonal" begin
+            # A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition.
+            for (rows, cols) in [(5, 3), (3, 5)]
+                v = orthogonal(rows, cols)
+                rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
+            end
+            for mat in [(3, 4, 5), (2, 2, 5)]
+                v = orthogonal(mat...)
+                cols = mat[end]
+                rows = div(prod(mat), cols)
+                v = reshape(v, (rows, cols))
+                rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
+            end
+            @test eltype(orthogonal(3, 4; gain=1.5)) == Float32
+        end
+    end
+end
diff --git a/test/qa_tests.jl b/test/qa_tests.jl
new file mode 100644
index 0000000..c5c93c2
--- /dev/null
+++ b/test/qa_tests.jl
@@ -0,0 +1,23 @@
+@testitem "Aqua: Quality Assurance" begin
+    using Aqua
+
+    Aqua.test_all(WeightInitializers; ambiguities=false)
+    Aqua.test_ambiguities(WeightInitializers; recursive=false)
+end
+
+@testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] begin
+    using CUDA, ExplicitImports
+
+    @test check_no_implicit_imports(WeightInitializers) === nothing
+    @test check_no_stale_explicit_imports(WeightInitializers) === nothing
+    @test check_no_self_qualified_accesses(WeightInitializers) === nothing
+end
+
+@testitem "doctests: Quality Assurance" begin
+    using Documenter
+
+    doctestexpr = :(using Random, WeightInitializers)
+
+    DocMeta.setdocmeta!(WeightInitializers, :DocTestSetup, doctestexpr; recursive=true)
+    doctest(WeightInitializers; manual=false)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index a620753..8ba7978 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,286 +1,3 @@
-using Aqua
-using WeightInitializers, Test, Statistics
-using StableRNGs, Random, CUDA, LinearAlgebra
+using ReTestItems
 
-CUDA.allowscalar(false)
-
-const GROUP = get(ENV, "GROUP", "All")
-
-@testset "WeightInitializers.jl Tests" begin
-    rngs_arrtypes = []
-
-    if GROUP == "All" || GROUP == "CPU"
-        append!(rngs_arrtypes,
-            [(StableRNG(12345), AbstractArray), (Random.default_rng(), AbstractArray)])
-    end
-
-    if GROUP == "All" || GROUP == "CUDA"
-        append!(rngs_arrtypes, [(CUDA.default_rng(), CuArray)])
-    end
-
-    @testset "_nfan" begin
-        # Fallback
-        @test WeightInitializers._nfan() == (1, 1)
-        # Vector
-        @test WeightInitializers._nfan(4) == (1, 4)
-        # Matrix
-        @test WeightInitializers._nfan(4, 5) == (5, 4)
-        # Tuple
-        @test WeightInitializers._nfan((4, 5, 6)) == WeightInitializers._nfan(4, 5, 6)
-        # Convolution
-        @test WeightInitializers._nfan(4, 5, 6) == 4 .* (5, 6)
-    end
-
-    @testset "rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes
-        @testset "Sizes and Types: $init" for init in [
-            zeros32, ones32, rand32, randn32, kaiming_uniform, kaiming_normal,
-            glorot_uniform, glorot_normal, truncated_normal, identity_init]
-            # Sizes
-            @test size(init(3)) == (3,)
-            @test size(init(rng, 3)) == (3,)
-            @test size(init(3, 4)) == (3, 4)
-            @test size(init(rng, 3, 4)) == (3, 4)
-            @test size(init(3, 4, 5)) == (3, 4, 5)
-            @test size(init(rng, 3, 4, 5)) == (3, 4, 5)
-            # Type
-            @test eltype(init(rng, 4, 2)) == Float32
-            @test eltype(init(4, 2)) == Float32
-            # RNG Closure
-            cl = init(rng)
-            @test cl(3) isa arrtype{Float32, 1}
-            @test cl(3, 5) isa arrtype{Float32, 2}
-        end
-
-        @testset "Sizes and Types: $init" for (init, fp) in [
-            (zeros16, Float16), (zerosC16, ComplexF16), (zeros32, Float32),
-            (zerosC32, ComplexF32), (zeros64, Float64), (zerosC64, ComplexF64),
-            (ones16, Float16), (onesC16, ComplexF16), (ones32, Float32),
-            (onesC32, ComplexF32), (ones64, Float64), (onesC64, ComplexF64),
-            (rand16, Float16), (randC16, ComplexF16), (rand32, Float32),
-            (randC32, ComplexF32), (rand64, Float64), (randC64, ComplexF64),
-            (randn16, Float16), (randnC16, ComplexF16), (randn32, Float32),
-            (randnC32, ComplexF32), (randn64, Float64), (randnC64, ComplexF64)]
-            # Sizes
-            @test size(init(3)) == (3,)
-            @test size(init(rng, 3)) == (3,)
-            @test size(init(3, 4)) == (3, 4)
-            @test size(init(rng, 3, 4)) == (3, 4)
-            @test size(init(3, 4, 5)) == (3, 4, 5)
-            @test size(init(rng, 3, 4, 5)) == (3, 4, 5)
-            # Type
-            @test eltype(init(rng, 4, 2)) == fp
-            @test eltype(init(4, 2)) == fp
-            # RNG Closure
-            cl = init(rng)
-            @test cl(3) isa arrtype{fp, 1}
-            @test cl(3, 5) isa arrtype{fp, 2}
-        end
-
-        @testset "AbstractArray Type: $init $T" for init in [
-                kaiming_uniform, kaiming_normal, glorot_uniform,
-                glorot_normal, truncated_normal, identity_init],
-            T in (Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
-
-            init === truncated_normal && !(T <: Real) && continue
-
-            @test init(T, 3) isa AbstractArray{T, 1}
-            @test init(rng, T, 3) isa arrtype{T, 1}
-            @test init(T, 3, 5) isa AbstractArray{T, 2}
-            @test init(rng, T, 3, 5) isa arrtype{T, 2}
-
-            cl = init(rng)
-            @test cl(T, 3) isa arrtype{T, 1}
-            @test cl(T, 3, 5) isa arrtype{T, 2}
-
-            cl = init(rng, T)
-            @test cl(3) isa arrtype{T, 1}
-            @test cl(3, 5) isa arrtype{T, 2}
-        end
-
-        @testset "Closure: $init" for init in [
-            kaiming_uniform, kaiming_normal, glorot_uniform,
-            glorot_normal, truncated_normal, identity_init]
-            cl = init(;)
-            # Sizes
-            @test size(cl(3)) == (3,)
-            @test size(cl(rng, 3)) == (3,)
-            @test size(cl(3, 4)) == (3, 4)
-            @test size(cl(rng, 3, 4)) == (3, 4)
-            @test size(cl(3, 4, 5)) == (3, 4, 5)
-            @test size(cl(rng, 3, 4, 5)) == (3, 4, 5)
-            # Type
-            @test eltype(cl(4, 2)) == Float32
-            @test eltype(cl(rng, 4, 2)) == Float32
-        end
-
-        @testset "Kwargs types" for T in (
-            Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
-            if (T <: Real)
-                @test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T
-                @test eltype(orthogonal(T, 2, 5; gain=1.0)) == T
-            end
-            @test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T
-            @test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T
-            @test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T
-            @test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T
-            @test eltype(identity_init(T, 2, 5; gain=1.0)) == T
-            @test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T
-        end
-
-        @testset "kaiming" begin
-            # kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)]
-            # and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out)
-            for (n_in, n_out) in [(100, 100), (100, 400)]
-                v = kaiming_uniform(rng, n_in, n_out)
-                σ2 = sqrt(6 / n_out)
-                @test -1σ2 < minimum(v) < -0.9σ2
-                @test 0.9σ2 < maximum(v) < 1σ2
-
-                v = kaiming_normal(rng, n_in, n_out)
-                σ2 = sqrt(2 / n_out)
-                @test 0.9σ2 < std(v) < 1.1σ2
-            end
-            # Type
-            @test eltype(kaiming_uniform(rng, 3, 4; gain=1.5f0)) == Float32
-            @test eltype(kaiming_normal(rng, 3, 4; gain=1.5f0)) == Float32
-        end
-
-        @testset "glorot: $init" for init in [glorot_uniform, glorot_normal]
-            # glorot_uniform and glorot_normal should both yield a kernel with
-            # variance ≈ 2/(fan_in + fan_out)
-            for dims in [(1000,), (100, 100), (100, 400), (2, 3, 32, 64), (2, 3, 4, 32, 64)]
-                v = init(dims...)
-                fan_in, fan_out = WeightInitializers._nfan(dims...)
-                σ2 = 2 / (fan_in + fan_out)
-                @test 0.9σ2 < var(v) < 1.1σ2
-            end
-            @test eltype(init(3, 4; gain=1.5)) == Float32
-        end
-
-        @testset "orthogonal" begin
-            # A matrix of dim = (m,n) with m > n should produce a QR decomposition. In the other case, the transpose should be taken to compute the QR decomposition.
-            for (rows, cols) in [(5, 3), (3, 5)]
-                v = orthogonal(rows, cols)
-                rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
-            end
-            for mat in [(3, 4, 5), (2, 2, 5)]
-                v = orthogonal(mat...)
-                cols = mat[end]
-                rows = div(prod(mat), cols)
-                v = reshape(v, (rows, cols))
-                rows < cols ? (@test v * v' ≈ I(rows)) : (@test v' * v ≈ I(cols))
-            end
-            @test eltype(orthogonal(3, 4; gain=1.5)) == Float32
-        end
-    end
-
-    @testset "Orthogonal rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes
-        # A matrix of dim = (m,n) with m > n should produce a QR decomposition.
-        # In the other case, the transpose should be taken to compute the QR decomposition.
-        for (rows, cols) in [(5, 3), (3, 5)]
-            v = orthogonal(rng, rows, cols)
-            CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) :
-                              (@test v' * v ≈ I(cols))
-        end
-        for mat in [(3, 4, 5), (2, 2, 5)]
-            v = orthogonal(rng, mat...)
-            cols = mat[end]
-            rows = div(prod(mat), cols)
-            v = reshape(v, (rows, cols))
-            CUDA.@allowscalar rows < cols ? (@test v * v' ≈ I(rows)) :
-                              (@test v' * v ≈ I(cols))
-        end
-        # Type
-        @testset "Orthogonal Types $T" for T in (Float32, Float64)#(Float16, Float32, Float64)
-            @test eltype(orthogonal(rng, T, 3, 4; gain=1.5)) == T
-            @test eltype(orthogonal(rng, T, 3, 4, 5; gain=1.5)) == T
-        end
-        @testset "Orthogonal AbstractArray Type $T" for T in (Float32, Float64)#(Float16, Float32, Float64)
-            @test orthogonal(T, 3, 5) isa AbstractArray{T, 2}
-            @test orthogonal(rng, T, 3, 5) isa arrtype{T, 2}
-
-            cl = orthogonal(rng)
-            @test cl(T, 3, 5) isa arrtype{T, 2}
-
-            cl = orthogonal(rng, T)
-            @test cl(3, 5) isa arrtype{T, 2}
-        end
-        @testset "Orthogonal Closure" begin
-            cl = orthogonal(;)
-            # Sizes
-            @test size(cl(3, 4)) == (3, 4)
-            @test size(cl(rng, 3, 4)) == (3, 4)
-            @test size(cl(3, 4, 5)) == (3, 4, 5)
-            @test size(cl(rng, 3, 4, 5)) == (3, 4, 5)
-            # Type
-            @test eltype(cl(4, 2)) == Float32
-            @test eltype(cl(rng, 4, 2)) == Float32
-        end
-    end
-
-    @testset "sparse_init rng = $(typeof(rng)) & arrtype = $arrtype" for (rng, arrtype) in rngs_arrtypes
-        # sparse_init should yield an error for non 2-d dimensions
-        # sparse_init should yield no zero elements if sparsity < 0
-        # sparse_init should yield all zero elements if sparsity > 1
-        # sparse_init should yield exactly ceil(n_in * sparsity) elements in each column for other sparsity values
-        # sparse_init should yield a kernel in its non-zero elements consistent with the std parameter
-
-        @test_throws ArgumentError sparse_init(3, 4, 5, sparsity=0.1)
-        @test_throws ArgumentError sparse_init(3, sparsity=0.1)
-        v = sparse_init(100, 100; sparsity=-0.1)
-        @test sum(v .== 0) == 0
-        v = sparse_init(100, 100; sparsity=1.1)
-        @test sum(v .== 0) == length(v)
-
-        for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)]
-            expected_zeros = ceil(Integer, n_in * sparsity)
-            v = sparse_init(n_in, n_out; sparsity=sparsity, std=σ)
-            @test all([sum(v[:, col] .== 0) == expected_zeros for col in 1:n_out])
-            @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ
-        end
-
-        # Type
-        @testset "sparse_init Types $T" for T in (Float16, Float32, Float64)
-            @test eltype(sparse_init(rng, T, 3, 4; sparsity=0.5)) == T
-        end
-        @testset "sparse_init AbstractArray Type $T" for T in (Float16, Float32, Float64)
-            @test sparse_init(T, 3, 5; sparsity=0.5) isa AbstractArray{T, 2}
-            @test sparse_init(rng, T, 3, 5; sparsity=0.5) isa arrtype{T, 2}
-
-            cl = sparse_init(rng; sparsity=0.5)
-            @test cl(T, 3, 5) isa arrtype{T, 2}
-
-            cl = sparse_init(rng, T; sparsity=0.5)
-            @test cl(3, 5) isa arrtype{T, 2}
-        end
-        @testset "sparse_init Closure" begin
-            cl = sparse_init(; sparsity=0.5)
-            # Sizes
-            @test size(cl(3, 4)) == (3, 4)
-            @test size(cl(rng, 3, 4)) == (3, 4)
-            # Type
-            @test eltype(cl(4, 2)) == Float32
-            @test eltype(cl(rng, 4, 2)) == Float32
-        end
-    end
-
-    @testset "identity_init" begin
-        @testset "Non-identity sizes" begin
-            @test identity_init(2, 3)[:, end] == zeros(Float32, 2)
-            @test identity_init(3, 2; shift=1)[1, :] == zeros(Float32, 2)
-            @test identity_init(1, 1, 3, 4)[:, :, :, end] == zeros(Float32, 1, 1, 3)
-            @test identity_init(2, 1, 3, 3)[end, :, :, :] == zeros(Float32, 1, 3, 3)
-            @test identity_init(1, 2, 3, 3)[:, end, :, :] == zeros(Float32, 1, 3, 3)
-        end
-    end
-
-    @testset "Warning: truncated_normal" begin
-        @test_warn "Mean is more than 2 std outside the limits in truncated_normal, so \
-            the distribution of values may be inaccurate." truncated_normal(2; mean=-5.0f0)
-    end
-
-    @testset "Aqua: Quality Assurance" begin
-        Aqua.test_all(WeightInitializers; ambiguities=false)
-        Aqua.test_ambiguities(WeightInitializers; recursive=false)
-    end
-end
+ReTestItems.runtests(@__DIR__)
diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl
new file mode 100644
index 0000000..5b18e59
--- /dev/null
+++ b/test/shared_testsetup.jl
@@ -0,0 +1,20 @@
+@testsetup module SharedTestSetup
+
+using CUDA, Random, StableRNGs
+
+CUDA.allowscalar(false)
+
+const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))
+
+RNGS_ARRTYPES = []
+if BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
+    append!(RNGS_ARRTYPES,
+        [(StableRNG(12345), AbstractArray), (Random.GLOBAL_RNG, AbstractArray)])
+end
+if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda"
+    push!(RNGS_ARRTYPES, (CUDA.default_rng(), CuArray))
+end
+
+export StableRNG, RNGS_ARRTYPES
+
+end
diff --git a/test/utils_tests.jl b/test/utils_tests.jl
new file mode 100644
index 0000000..c6c2b62
--- /dev/null
+++ b/test/utils_tests.jl
@@ -0,0 +1,9 @@
+@testitem "_nfan" begin
+    using WeightInitializers: _nfan
+
+    @test _nfan() == (1, 1) # Fallback
+    @test _nfan(4) == (1, 4) # Vector
+    @test _nfan(4, 5) == (5, 4) # Matrix
+    @test _nfan((4, 5, 6)) == _nfan(4, 5, 6) # Tuple
+    @test _nfan(4, 5, 6) == 4 .* (5, 6) # Convolution
+end