Skip to content

Commit

Permalink
Code coverage and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
msainsburydale committed Dec 17, 2023
1 parent 8bc02d1 commit e27e7f5
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 64 deletions.
45 changes: 14 additions & 31 deletions src/bootstrap.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
"""
interval(θ̃, θ̂ = nothing; type::String, probs = [0.05, 0.95], parameter_names)
interval(θ̃::Matrix; probs = [0.05, 0.95], parameter_names = nothing)
interval(estimator::IntervalEstimator, Z; parameter_names = nothing, use_gpu = true)
Compute a confidence interval using the p × B matrix of bootstrap samples, `θ̃`,
where p is the number of parameters in the model.
Compute a confidence interval based on a p × B matrix of bootstrap estimates, `θ̃`,
where p is the number of parameters in the model, or from an `IntervalEstimator`
and data `Z`.
If `type = "quantile"`, the interval is constructed by simply taking the quantiles of
`θ̃`, and if `type = "reverse-quantile"`, the so-called
[reverse-quantile](https://en.wikipedia.org/wiki/Bootstrapping_(statistics)#Methods_for_bootstrap_confidence_intervals)
method is used. In both cases, the quantile levels are controlled by the argument `probs`.
The rows can be named with a vector of strings `parameter_names`.
The bootstrap-based interval is constructed by taking the quantiles of `θ̃`,
where the quantile levels are controlled by the keyword argument `probs`.
The return type is a p × 2 matrix, whose first and second columns respectively
contain the lower and upper bounds of the interval.
contain the lower and upper bounds of the interval. The rows of this matrix can
be named by passing a vector of strings to the keyword argument `parameter_names`.
# Examples
```
Expand All @@ -22,25 +21,14 @@ B = 50
θ̃ = rand(p, B)
θ̂ = rand(p)
interval(θ̃)
interval(θ̃, θ̂, type = "basic")
```
"""
function interval(θ̃, θ̂ = nothing; type::String = "percentile", probs = [0.05, 0.95], parameter_names = ["θ$i" for i 1:size(θ̃, 1)])
function interval(θ̃; probs = [0.05, 0.95], parameter_names = ["θ$i" for i 1:size(θ̃, 1)])

#TODO add assertions and add type on θ̃
p, B = size(θ̃)
type = lowercase(type)

if type ["percentile", "quantile"]
ci = mapslices(x -> quantile(x, probs), θ̃, dims = 2)
elseif type ["basic", "reverse-percentile", "reverse-quantile"]
isnothing(θ̂) && error("`θ̂` must be provided if `type` is 'basic', 'reverse-percentile', or 'reverse-quantile'")
q = mapslices(x -> quantile(x, probs), θ̃, dims = 2)
ci = [[2θ̂[i] - q[i, 2], 2θ̂[i] - q[i, 1]] for i 1:p]
ci = hcat(ci...)'
else
error("argument `type` not matched: it should be one of 'percentile', 'basic', 'studentised', or 'bca'.")
end

# Compute the quantiles
ci = mapslices(x -> quantile(x, probs), θ̃, dims = 2)

# Add labels to the confidence intervals
l = ci[:, 1]
Expand All @@ -49,7 +37,6 @@ function interval(θ̃, θ̂ = nothing; type::String = "percentile", probs = [0.
end


#TODO need to document this method
function interval(estimator::Union{IntervalEstimator, IntervalEstimatorCompactPrior, PointIntervalEstimator}, Z; parameter_names = nothing, use_gpu::Bool = true)

ci = estimateinbatches(estimator, Z, use_gpu = use_gpu)
Expand Down Expand Up @@ -94,11 +81,7 @@ function labelinterval(ci::M, parameter_names = ["θ$i" for i ∈ (size(ci, 1)
p = size(ci, 1) ÷ 2
K = size(ci, 2)

map(1:K) do k
lₖ = ci[1:p, k]
uₖ = ci[(p+1):end, k]
labelinterval(lₖ, uₖ, parameter_names)
end
[labelinterval(ci[:, k], parameter_names) for k 1:K]
end

# ---- Parametric bootstrap ----
Expand Down
70 changes: 37 additions & 33 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,6 @@ m = 10 # default sample size
# interval
θ̃ = bootstrap(θ̂, pars, simulator, m; use_gpu = use_gpu)
@test size(interval(θ̃)) == (p, 2)
# @test size(interval(θ̃, θ̂(Z), type = "basic")) == (p, 2) #FIXME broken on the GPU
@test_throws Exception interval(θ̃, type = "basic")
@test_throws Exception interval(θ̃, type = "zxcvbnm")
end
end
end
Expand Down Expand Up @@ -674,6 +671,22 @@ end

# ---- Estimators ----

@testset "initialise_estimator" begin
    p = 2

    # Smoke checks: building each supported architecture should not throw.
    initialise_estimator(p, architecture = "DNN")
    initialise_estimator(p, architecture = "GNN")
    initialise_estimator(
        p,
        architecture = "CNN",
        kernel_size = [(10, 10), (5, 5), (3, 3)],
    )

    # With estimator_type = "interval", every architecture should produce an
    # IntervalEstimator.
    @test initialise_estimator(
        p, architecture = "DNN", estimator_type = "interval"
    ) isa IntervalEstimator
    @test initialise_estimator(
        p, architecture = "GNN", estimator_type = "interval"
    ) isa IntervalEstimator
    @test initialise_estimator(
        p,
        architecture = "CNN",
        kernel_size = [(10, 10), (5, 5), (3, 3)],
        estimator_type = "interval",
    ) isa IntervalEstimator

    # Calls that are expected to be rejected: p = 0, d = 0, a CNN without
    # kernel_size, and a CNN given only two kernel sizes.
    @test_throws Exception initialise_estimator(0, architecture = "DNN")
    @test_throws Exception initialise_estimator(p, d = 0, architecture = "DNN")
    @test_throws Exception initialise_estimator(p, architecture = "CNN")
    @test_throws Exception initialise_estimator(
        p,
        architecture = "CNN",
        kernel_size = [(10, 10), (5, 5)],
    )
end

@testset "PiecewiseEstimator" begin
@test_throws Exception PiecewiseEstimator((MLE, MLE), (30, 50))
@test_throws Exception PiecewiseEstimator((MLE, MLE, MLE), (50, 30))
Expand All @@ -686,44 +699,35 @@ end
end


@testset "IntervalEstimator" begin
# Generate some toy data
n = 2 # bivariate data
m = 256 # number of independent replicates
Z = rand(n, m)

# Create an architecture
@testset "IntervalEstimators" begin
# Generate some toy data and a basic architecture
d = 2 # bivariate data
m = 64 # number of independent replicates
Z = rand(d, m)
parameter_names = ["ρ", "σ", "τ"]
p = length(parameter_names)
w = 8 # width of each layer
ψ = Chain(Dense(n, w, relu), Dense(w, w, relu));
ϕ = Chain(Dense(w, w, relu), Dense(w, p));
architecture = DeepSet(ψ, ϕ)

# Initialise the interval estimator
estimator = IntervalEstimator(architecture)
arch = initialise_estimator(p, architecture = "DNN", d = d, width = 8)

# IntervalEstimator
estimator = IntervalEstimator(arch)
θ̂ = estimator(Z)
@test size(θ̂) == (2p, 1)
@test all(θ̂[1:p] .< θ̂[(p+1):end])
ci = interval(estimator, Z)
ci = interval(estimator, Z, parameter_names = parameter_names)
@test size(ci[1]) == (p, 2)

# Apply the interval estimator
estimator(Z)
# PointIntervalEstimator
estimator = PointIntervalEstimator(arch)
θ̂ = estimator(Z)
@test size(θ̂) == (3p, 1)
@test all(θ̂[(p+1):2p] .< θ̂[1:p] .< θ̂[(2p+1):end])
ci = interval(estimator, Z)
ci = interval(estimator, Z, parameter_names = parameter_names)
@test size(ci[1]) == (p, 2)
end

@testset "initialise_estimator" begin
    p = 2

    # Smoke checks: building each supported architecture should not throw.
    initialise_estimator(p, architecture = "DNN")
    initialise_estimator(p, architecture = "GNN")
    initialise_estimator(
        p,
        architecture = "CNN",
        kernel_size = [(10, 10), (5, 5), (3, 3)],
    )

    # With estimator_type = "interval", every architecture should produce an
    # IntervalEstimator.
    @test initialise_estimator(
        p, architecture = "DNN", estimator_type = "interval"
    ) isa IntervalEstimator
    @test initialise_estimator(
        p, architecture = "GNN", estimator_type = "interval"
    ) isa IntervalEstimator
    @test initialise_estimator(
        p,
        architecture = "CNN",
        kernel_size = [(10, 10), (5, 5), (3, 3)],
        estimator_type = "interval",
    ) isa IntervalEstimator

    # Calls that are expected to be rejected: p = 0, d = 0, a CNN without
    # kernel_size, and a CNN given only two kernel sizes.
    @test_throws Exception initialise_estimator(0, architecture = "DNN")
    @test_throws Exception initialise_estimator(p, d = 0, architecture = "DNN")
    @test_throws Exception initialise_estimator(p, architecture = "CNN")
    @test_throws Exception initialise_estimator(
        p,
        architecture = "CNN",
        kernel_size = [(10, 10), (5, 5)],
    )
end

@testset "NeuralEM" begin

Expand Down

0 comments on commit e27e7f5

Please sign in to comment.