Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Type check for kwargs #22

Merged
merged 6 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "WeightInitializers"
uuid = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.6"
version = "0.1.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
6 changes: 3 additions & 3 deletions ext/WeightInitializersCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function sparse_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
rows, cols = dims
prop_zero = min(1.0, sparsity)
num_zeros = ceil(Integer, prop_zero * rows)
sparse_array = randn(rng, T, dims...) .* std
sparse_array = randn(rng, T, dims...) .* T(std)
sparse_array[1:num_zeros, :] .= CUDA.zero(T)

return CUDA.@allowscalar mapslices(shuffle, sparse_array, dims=1)
Expand All @@ -46,7 +46,7 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
rows, cols = dims
mat = CUDA.zeros(T, rows, cols)
diag_indices = 1:min(rows, cols)
CUDA.fill!(view(mat, diag_indices, diag_indices), gain)
CUDA.fill!(view(mat, diag_indices, diag_indices), T(gain))
return CUDA.circshift(mat, shift)
else
# Convolution or more dimensions
Expand All @@ -56,7 +56,7 @@ function identity_init(rng::AbstractCuRNG, ::Type{T}, dims::Integer...;
#we should really find a better way to do this
CUDA.@allowscalar for i in 1:min(nin, nout)
index = (centers..., i, i)
weights[index...] = gain
weights[index...] = T(gain)
end
return CUDA.circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
end
Expand Down
18 changes: 9 additions & 9 deletions src/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ vision_. 2015.
"""
function kaiming_uniform(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=√T(2)) where {T <: Number}
bound = √T(3) * gain / sqrt(T(first(_nfan(dims...))))
bound = √T(3) * T(gain) / sqrt(T(first(_nfan(dims...))))
return (rand(rng, T, dims...) .- T(1 // 2)) .* 2 * bound
end

Expand All @@ -94,7 +94,7 @@ vision_. 2015.
"""
function kaiming_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
gain::Number=√T(2)) where {T <: Number}
std = gain / sqrt(T(first(_nfan(dims...))))
std = T(gain) / sqrt(T(first(_nfan(dims...))))
return randn(rng, T, dims...) .* std
end

Expand All @@ -111,13 +111,13 @@ function truncated_normal(rng::AbstractRNG, ::Type{T}, dims::Integer...; mean=T(
if (mean < lo - 2 * std) || (mean > hi + 2 * std)
@warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate."
end
l = _norm_cdf((lo - mean) / std)
u = _norm_cdf((hi - mean) / std)
l = _norm_cdf((T(lo) - T(mean)) / T(std))
u = _norm_cdf((T(hi) - T(mean)) / T(std))
xs = rand(rng, T, dims...)
broadcast!(xs, xs) do x
x = x * 2(u - l) + (2l - 1)
x = erfinv(x)
return clamp(x * std * √2 + mean, lo, hi)
return clamp(x * T(std) * √2 + T(mean), T(lo), T(hi))
end
return xs
end
Expand Down Expand Up @@ -162,7 +162,7 @@ function orthogonal(rng::AbstractRNG, ::Type{T}, dims::Integer...;
end

if rows < cols
return permutedims(orthogonal(rng, T, cols, rows; gain))
return permutedims(orthogonal(rng, T, cols, rows; gain=T(gain)))
end

mat = randn(rng, T, rows, cols)
Expand Down Expand Up @@ -236,7 +236,7 @@ function sparse_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
rows, cols = dims
prop_zero = min(1.0, sparsity)
num_zeros = ceil(Integer, prop_zero * rows)
sparse_array = randn(rng, T, dims...) .* std
sparse_array = randn(rng, T, dims...) .* T(std)
sparse_array[1:num_zeros, :] .= zero(T)
return mapslices(shuffle, sparse_array; dims=1)
end
Expand Down Expand Up @@ -313,7 +313,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
rows, cols = dims
mat = zeros(T, rows, cols)
for i in 1:min(rows, cols)
mat[i, i] = gain
mat[i, i] = T(gain)
end
return circshift(mat, shift)
else
Expand All @@ -323,7 +323,7 @@ function identity_init(rng::AbstractRNG, ::Type{T}, dims::Integer...;
weights = zeros(T, dims...)
for i in 1:min(nin, nout)
index = (centers..., i, i)
weights[index...] = gain
weights[index...] = T(gain)
end
return circshift(weights, (ntuple(d -> 0, length(dims) - 2)..., shift, shift))
end
Expand Down
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,20 @@ const GROUP = get(ENV, "GROUP", "All")
@test eltype(cl(rng, 4, 2)) == Float32
end

@testset "Kwargs types" for T in (
Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
if (T <: Real)
@test eltype(truncated_normal(T, 2, 5; mean=0, std=1, lo=-2, hi=2)) == T
@test eltype(orthogonal(T, 2, 5; gain=1.0)) == T
end
@test eltype(glorot_uniform(T, 2, 5; gain=1.0)) == T
@test eltype(glorot_normal(T, 2, 5; gain=1.0)) == T
@test eltype(kaiming_uniform(T, 2, 5; gain=sqrt(2))) == T
@test eltype(kaiming_normal(T, 2, 5; gain=sqrt(2))) == T
@test eltype(identity_init(T, 2, 5; gain=1.0)) == T
@test eltype(sparse_init(T, 2, 5; sparsity=0.5, std=0.01)) == T
end

@testset "kaiming" begin
# kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)]
# and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out)
Expand Down
Loading