This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Type check for kwargs #22

Merged 6 commits on Mar 10, 2024
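For context, a minimal sketch (an illustration, not code from the PR) of the eltype leak that the kwarg conversions in this PR guard against:

```julia
# Keyword arguments such as gain=1.0 arrive as Float64 by default, and
# mixed-precision arithmetic silently promotes the requested eltype.
x = rand(Float32, 2, 5)
gain = 1.0                                      # e.g. the gain kwarg
@assert eltype(x .* gain) == Float64            # requested Float32 is lost
@assert eltype(x .* Float32(gain)) == Float32   # conversion preserves it
```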
Changes from 3 commits
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.6"
version = "0.1.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ext/WeightInitializersCUDAExt.jl (2 changes: 2 additions & 0 deletions)
@@ -27,6 +27,7 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
        throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
    end

+    std = std isa T ? std : convert(T, std)
    rows, cols = dims
    prop_zero = min(1.0, sparsity)
    num_zeros = ceil(Integer, prop_zero * rows)
@@ -38,6 +39,7 @@ end

function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
        gain::Number=1, shift::Integer=0) where {T <: Number}
+    gain = gain isa T ? gain : convert(T, gain)
Member: I don't think you need to check for this. The compiler is smart enough to get rid of these.

Member Author (@MartinuzziFrancesco, Mar 9, 2024): should we also get rid of this then?

Member Author (@MartinuzziFrancesco, Mar 9, 2024): the `T(gain)` specifically

Member Author: sorry, scratch that, I misunderstood

Member: I meant that `gain isa T ? gain : convert(T, gain)` can be written as `T(gain)`, and the compiler should be smart enough to remove it if the types match.

Member Author: ah I see, should I take that approach with all of the checks?

Member: yes
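A minimal sketch (not part of the PR) of the equivalence discussed above: for numeric types, `T(x)` performs the same conversion as the explicit branch, and it is a no-op the compiler can elide when `x` is already a `T`.

```julia
# Sketch only: both forms convert a kwarg to the target numeric type T.
explicit(::Type{T}, x::Number) where {T <: Number} = x isa T ? x : convert(T, x)
concise(::Type{T}, x::Number) where {T <: Number} = T(x)

@assert explicit(Float32, 1.0) === 1.0f0   # Float64 -> Float32
@assert concise(Float32, 1.0) === 1.0f0    # same result via T(x)
@assert concise(Float32, 2.0f0) === 2.0f0  # already a Float32: no conversion
```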

    if length(dims) == 1
        # Bias initialization
        return CUDA.zeros(T, dims...)
src/initializers.jl (13 changes: 12 additions & 1 deletion)
@@ -36,7 +36,8 @@ artificial intelligence and statistics_. 2010.
"""
function glorot_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=1) where {T <: Number}
scale = T(gain) * sqrt(T(24) / sum(_nfan(dims...)))
gain = gain isa T ? gain : convert(T, gain)
scale = gain * sqrt(T(24) / sum(_nfan(dims...)))
return (rand(rng, T, dims...) .- T(1 // 2)) .* scale
end

@@ -56,6 +57,7 @@ artificial intelligence and statistics_. 2010.
"""
function glorot_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=1) where {T <: Number}
gain = gain isa T ? gain : convert(T, gain)
std = T(gain) * sqrt(T(2) / sum(_nfan(dims...)))
return randn(rng, T, dims...) .* std
end
@@ -75,6 +77,7 @@ vision_. 2015.
"""
function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=√T(2)) where {T <: Number}
gain = gain isa T ? gain : convert(T, gain)
bound = √T(3) * gain / sqrt(T(first(_nfan(dims...))))
return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound
end
@@ -94,6 +97,7 @@ vision_. 2015.
"""
function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=√T(2)) where {T <: Number}
gain = gain isa T ? gain : convert(T, gain)
std = gain / sqrt(T(first(_nfan(dims...))))
return randn(rng, T, dims...) .* std
end
@@ -111,6 +115,10 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(
    if (mean < lo - 2 * std) || (mean > hi + 2 * std)
        @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate."
    end
+    mean = mean isa T ? mean : convert(T, mean)
+    std = std isa T ? std : convert(T, std)
+    lo = lo isa T ? lo : convert(T, lo)
+    hi = hi isa T ? hi : convert(T, hi)
    l = _norm_cdf((lo - mean) / std)
    u = _norm_cdf((hi - mean) / std)
    xs = rand(rng, T, dims...)
@@ -153,6 +161,7 @@ ICLR 2014, https://arxiv.org/abs/1312.6120
function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
        gain::Number=T(1.0)) where {T <: Number}
    @assert length(dims)>1 "Creating vectors (length(dims) == 1) is not allowed"
+    gain = gain isa T ? gain : convert(T, gain)

    if length(dims) == 2
        rows, cols = dims
@@ -233,6 +242,7 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
        throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
    end

+    std = std isa T ? std : convert(T, std)
    rows, cols = dims
    prop_zero = min(1.0, sparsity)
    num_zeros = ceil(Integer, prop_zero * rows)
@@ -305,6 +315,7 @@ identity_tensor = identity_init(MersenneTwister(123),
"""
function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=1, shift::Integer=0) where {T <: Number}
gain = gain isa T ? gain : convert(T, gain)
if length(dims) == 1
# Bias initialization
return zeros(T, dims...)
test/runtests.jl (14 changes: 14 additions & 0 deletions)
@@ -114,6 +114,20 @@ const GROUP = get(ENV, "GROUP", "All")
        @test eltype(cl(rng, 4, 2)) == Float32
    end

@testset "Kwargs types" for T in (
Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
if (T <: Real)
@test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T
@test eltype(orthogonal(T, 2, 5; gain=1.0)) == T
end
@test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T
@test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T
@test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T
@test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T
@test eltype(identity_init(T, 2, 5; gain=1.0)) == T
@test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T
end

@testset "kaiming" begin
# kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)]
# and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out)
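A quick end-to-end check in the spirit of the new testset (a sketch, not part of the PR):

```julia
using WeightInitializers

# A Float64 gain kwarg no longer promotes a Float32-requested output;
# this mirrors the kaiming_normal case exercised in the tests above.
W = kaiming_normal(Float32, 4, 4; gain=sqrt(2))  # sqrt(2) isa Float64
@assert eltype(W) == Float32
```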