Documentation; removed IntervalEstimatorCompactPrior, QuantileEstimator, and PointIntervalEstimator
msainsburydale committed Dec 17, 2023
1 parent e27e7f5 commit 385a85b
Showing 7 changed files with 82 additions and 234 deletions.
6 changes: 0 additions & 6 deletions docs/src/API/core.md
@@ -31,12 +31,6 @@ PointEstimator
IntervalEstimator
IntervalEstimatorCompactPrior
PointIntervalEstimator
QuantileEstimator
PiecewiseEstimator
```

218 changes: 38 additions & 180 deletions src/Estimators.jl
@@ -12,124 +12,67 @@ abstract type NeuralEstimator end
"""
PointEstimator(arch)
A simple point estimator, that is, a mapping from the sample space to the
parameter space, defined by the given architecture `arch`.
A neural point estimator, that is, a mapping from the sample space to the
parameter space, defined by the given neural-network architecture `arch`.
"""
struct PointEstimator{F} <: NeuralEstimator
arch::F
# PointEstimator(arch) = isa(arch, PointEstimator) ? error("Please do not construct PointEstimator objects with another PointEstimator") : new(arch)
end
@functor PointEstimator (arch,)
(est::PointEstimator)(Z) = est.arch(Z)
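
For orientation, here is a minimal sketch of constructing and applying a point estimator, following the pattern of the examples later in this file (the dimensions `n`, `p`, and `w` are illustrative):

```julia
using NeuralEstimators
using Flux

n = 2   # dimension of each data replicate
p = 3   # number of parameters in the statistical model
w = 8   # width of each hidden layer
m = 100 # number of independent replicates

ψ = Chain(Dense(n, w, relu), Dense(w, w, relu))  # summary network
ϕ = Chain(Dense(w, w, relu), Dense(w, p))        # inference network
θ̂ = PointEstimator(DeepSet(ψ, ϕ))

Z = rand(n, m)  # toy data: m replicates of n-dimensional observations
θ̂(Z)            # point estimate of the p parameters
```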


# ---- IntervalEstimator: credible intervals ----
# ---- IntervalEstimator for amortised credible intervals ----

"""
IntervalEstimator(arch_lower, arch_upper)
IntervalEstimator(arch)
A neural interval estimator that jointly estimates credible intervals constructed as,
IntervalEstimator(u)
IntervalEstimator(u, v)
IntervalEstimator(u, v, g::Compress)
IntervalEstimator(u, v, min_supp::Vector, max_supp::Vector)
```math
[l(Z), l(Z) + \\mathrm{exp}(u(Z))],
```
where ``l(⋅)`` and ``u(⋅)`` are the neural networks `arch_lower` and
`arch_upper`, both of which should transform data into ``p``-dimensional vectors,
where ``p`` is the number of parameters in the statistical model. If only a
single neural network architecture `arch` is provided, it will be used for both
`arch_lower` and `arch_upper`.
The returned value is a matrix with ``2p`` rows, where the first and second ``p``
rows correspond to estimates of the lower and upper bound, respectively.
See also [`IntervalEstimatorCompactPrior`](@ref).
# Examples
```
using NeuralEstimators
using Flux
# Generate some toy data
n = 2 # bivariate data
m = 100 # number of independent replicates
Z = rand(n, m)
# Create an architecture
p = 3 # number of parameters in the statistical model
w = 8 # width of each layer
ψ = Chain(Dense(n, w, relu), Dense(w, w, relu));
ϕ = Chain(Dense(w, w, relu), Dense(w, p));
architecture = DeepSet(ψ, ϕ)
# Initialise the interval estimator
estimator = IntervalEstimator(architecture)
# Apply the interval estimator
estimator(Z)
interval(estimator, Z)
```
"""
struct IntervalEstimator{F, G} <: NeuralEstimator
l::F
u::G
end
IntervalEstimator(l) = IntervalEstimator(l, deepcopy(l))
@functor IntervalEstimator
function (est::IntervalEstimator)(Z)
l = est.l(Z)
vcat(l, l .+ exp.(est.u(Z)))
end
# Ensure that IntervalEstimator objects are not constructed with PointEstimator:
#TODO find a neater way to do this; don't want to write so many methods, especially for PointIntervalEstimator
IntervalEstimator(l::PointEstimator, u::PointEstimator) = IntervalEstimator(l.arch, u.arch)
IntervalEstimator(l, u::PointEstimator) = IntervalEstimator(l, u.arch)
IntervalEstimator(l::PointEstimator, u) = IntervalEstimator(l.arch, u)

"""
IntervalEstimatorCompactPrior(u, v, min_supp::Vector, max_supp::Vector)
IntervalEstimatorCompactPrior(u, v, compress::Compress)
Uses the neural networks `u` and `v` to jointly estimate credible intervals
that are guaranteed to be within the support of the prior distribution. This
support is defined by the ``p``-dimensional vectors `min_supp` and `max_supp`
(or a single ``p``-dimensional object of type `Compress`), where ``p`` is the
number of parameters in the statistical model.
Given data ``Z``, the intervals are constructed as
A neural interval estimator which, given data ``Z``, jointly estimates credible
intervals in the form,
```math
[g(u(Z)), g(u(Z) + \\mathrm{exp}(v(Z)))],
```
where
- ``u(⋅)`` and ``v(⋅)`` are neural networks, both of which should transform data into ``p``-dimensional vectors;
- ``u(⋅)`` and ``v(⋅)`` are neural networks, both of which should transform data into ``p``-dimensional vectors (with ``p`` the number of parameters in the statistical model);
- ``g(⋅)`` is a logistic function that maps its input to the prior support.
The prior support is defined either by the ``p``-dimensional vectors `min_supp`
and `max_supp`, or a single ``p``-dimensional object of type [`Compress`](@ref).
If these objects are not given, the range of the intervals will be unrestricted.
Note that, in addition to ensuring that the interval remains in the prior support,
this construction also ensures that the intervals are valid (i.e., it prevents
quantile crossing, in the sense that the upper bound is always greater than the
lower bound).
If only a single neural-network architecture is provided, it will be used
for both `u` and `v`.
The returned value is a matrix with ``2p`` rows, where the first and second ``p``
rows correspond to estimates of the lower and upper bound, respectively.
See also [`IntervalEstimator`](@ref) and [`Compress`](@ref).
# Examples
```
using NeuralEstimators
using Flux
# prior support
min_supp = [25, 0.5, -pi/2]
max_supp = [500, 2.5, 0]
p = length(min_supp) # number of parameters in the statistical model
# Generate some toy data
n = 2 # bivariate data
m = 100 # number of independent replicates
Z = rand(n, m)
# prior support
min_supp = [25, 0.5, -pi/2]
max_supp = [500, 2.5, 0]
p = length(min_supp) # number of parameters in the statistical model
# Create an architecture
w = 8 # width of each layer
ψ = Chain(Dense(n, w, relu), Dense(w, w, relu));
@@ -138,116 +81,31 @@ u = DeepSet(ψ, ϕ)
v = deepcopy(u) # use the same architecture for both u and v
# Initialise the interval estimator
estimator = IntervalEstimatorCompactPrior(u, v, min_supp, max_supp)
estimator = IntervalEstimator(u, v, min_supp, max_supp)
# Apply the interval estimator
estimator(Z)
interval(estimator, Z)
```
"""
struct IntervalEstimatorCompactPrior{F, G} <: NeuralEstimator
struct IntervalEstimator{F, G} <: NeuralEstimator
u::F
v::G
c::Compress
end
IntervalEstimatorCompactPrior(u, v, min_supp, max_supp) = IntervalEstimatorCompactPrior(u, v, Compress(min_supp, max_supp))
@functor IntervalEstimatorCompactPrior
Flux.trainable(est::IntervalEstimatorCompactPrior) = (est.u, est.v)

function (est::IntervalEstimatorCompactPrior)(Z)
x = est.u(Z)
y = x .+ exp.(est.v(Z))
c = est.c
vcat(c(x), c(y))
end

"""
PointIntervalEstimator(arch_point, arch_lower, arch_upper)
PointIntervalEstimator(arch_point, arch_bound)
PointIntervalEstimator(arch)
A neural estimator that jointly produces point estimates, ``θ̂(Z)``, where ``θ̂(⋅)`` is a
neural point estimator with architecture `arch_point`, and credible intervals constructed as,
```math
[θ̂(Z) - \\mathrm{exp}(l(Z)), θ̂(Z) + \\mathrm{exp}(u(Z))],
```
where ``l(⋅)`` and ``u(⋅)`` are the neural networks `arch_lower` and
`arch_upper`, both of which should transform data into ``p``-dimensional vectors,
where ``p`` is the number of parameters in the statistical model.
If only a single neural network architecture `arch` is provided, it will be used
for all architectures; similarly, if two architectures are provided, the second
will be used for both `arch_lower` and `arch_upper`.
Internally, the point estimates, lower-bound estimates, and upper-bound estimates are concatenated, so
that `PointIntervalEstimator` objects transform data into matrices with ``3p`` rows.
# Examples
```
using NeuralEstimators
using Flux
# Generate some toy data
n = 2 # bivariate data
m = 100 # number of independent replicates
Z = rand(n, m)
# Create an architecture
p = 3 # number of parameters in the statistical model
w = 8 # width of each layer
ψ = Chain(Dense(n, w, relu), Dense(w, w, relu));
ϕ = Chain(Dense(w, w, relu), Dense(w, p));
architecture = DeepSet(ψ, ϕ)
# Initialise the estimator
estimator = PointIntervalEstimator(architecture)
# Apply the estimator
estimator(Z)
interval(estimator, Z)
```
"""
struct PointIntervalEstimator{H, F, G} <: NeuralEstimator
θ̂::H
l::F
u::G
g::Union{Function,Compress}
# IntervalEstimator(u, v, g) = any(isa.([u, v], PointEstimator)) ? error("Please do not construct IntervalEstimator objects with PointEstimators") : new(u, v, g)
end
PointIntervalEstimator(θ̂) = PointIntervalEstimator(θ̂, deepcopy(θ̂), deepcopy(θ̂))
PointIntervalEstimator(θ̂, l) = PointIntervalEstimator(θ̂, deepcopy(l), deepcopy(l))
@functor PointIntervalEstimator
function (est::PointIntervalEstimator)(Z)
θ̂ = est.θ̂(Z)
vcat(θ̂, θ̂ .- exp.(est.l(Z)), θ̂ .+ exp.(est.u(Z)))
end
# Ensure that IntervalEstimator objects are not constructed as a wrapper of PointEstimator:
PointIntervalEstimator(θ̂::PointEstimator, l::PointEstimator, u::PointEstimator) = PointIntervalEstimator(θ̂.arch, l.arch, u.arch)
PointIntervalEstimator(θ̂::PointEstimator, l, u) = PointIntervalEstimator(θ̂.arch, l, u)
PointIntervalEstimator(θ̂, l::PointEstimator, u::PointEstimator) = PointIntervalEstimator(θ̂, l.arch, u.arch)


# ---- QuantileEstimator: estimating arbitrary quantiles of the posterior distribution ----

# Should follow up on this point from Gneiting's paper:
# 9.2 Quantile Estimation
# Koenker and Bassett (1978) proposed quantile regression using an optimum
# score estimator based on the proper scoring rule (41).


#TODO this is a topic of ongoing research with Jordan
"""
QuantileEstimator()
Coming soon: this structure will allow for the simultaneous estimation of an
arbitrary number of marginal quantiles of the posterior distribution.
"""
struct QuantileEstimator{F, G} <: NeuralEstimator
l::F
u::G
IntervalEstimator(u) = IntervalEstimator(u, deepcopy(u), identity)
IntervalEstimator(u, v) = IntervalEstimator(u, v, identity)
IntervalEstimator(u, g::Compress) = IntervalEstimator(u, deepcopy(u), g)
IntervalEstimator(u, min_supp, max_supp) = IntervalEstimator(u, Compress(min_supp, max_supp))
IntervalEstimator(u, v, min_supp, max_supp) = IntervalEstimator(u, v, Compress(min_supp, max_supp))
@functor IntervalEstimator
Flux.trainable(est::IntervalEstimator) = (est.u, est.v)
function (est::IntervalEstimator)(Z)
bₗ = est.u(Z) # lower bound
bᵤ = bₗ .+ exp.(est.v(Z)) # upper bound
vcat(est.g(bₗ), est.g(bᵤ))
end
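
To illustrate the role of `g(⋅)`, the following sketch uses a hand-rolled logistic map in place of `Compress` (an illustration of the idea only, not the package's implementation; all variable names are hypothetical):

```julia
# Element-wise logistic map from ℝ into (a, b); Compress applies the same idea.
squash(x, a, b) = a .+ (b .- a) ./ (1 .+ exp.(-x))

min_supp = [25, 0.5, -pi/2]
max_supp = [500, 2.5, 0]

u = randn(3)        # stand-in for the lower-bound network output u(Z)
v = randn(3)        # stand-in for the log-width network output v(Z)
bₗ = u              # lower bound
bᵤ = u .+ exp.(v)   # upper bound; exp(⋅) > 0, so bᵤ > bₗ and quantile crossing cannot occur

# The logistic map is monotone, so the ordering is preserved and both bounds
# lie strictly within the prior support:
squash(bₗ, min_supp, max_supp)
squash(bᵤ, min_supp, max_supp)
```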
# @functor QuantileEstimator
# (c::QuantileEstimator)(Z) = vcat(c.l(Z), c.l(Z) .+ exp.(c.u(Z)))



# ---- PiecewiseEstimator ----
46 changes: 29 additions & 17 deletions src/Graphs.jl
@@ -458,7 +458,7 @@ hidden-feature graphs; the `readout` module aggregates these feature graphs into
a single hidden feature vector of fixed length; the function `a`(⋅) is a
permutation-invariant aggregation function, and `ϕ` is a neural network.
The data should be stored as a `GNNGraph` or `AbstractVector{GNNGraph}`, where
The data should be stored as a `GNNGraph` or `Vector{GNNGraph}`, where
each graph is associated with a single parameter vector. The graphs may contain
sub-graphs corresponding to independent replicates from the model.
@@ -586,22 +586,34 @@ function (est::GNN)(g::GNNGraph, m::AbstractVector{I}) where {I <: Integer}
return est.deepset(h̃)
end

# Methods needed to accommodate above method of GNN. They are exactly the same as
# the standard methods defined in Estimators.jl, but also pass through m.
#TODO unit testing for these methods
#TODO Not ideal that there's so much code repetition... we're just replacing f(Z) with f(Z, m). Tried with the g(x...) = sum(x) approach; it almost worked, might be worth trying again.
# Also need a custom method for _updatebatch!()
# NB Surely there is a more generic way to dispatch here (e.g., any structure that contains GNN)
function _updatebatch!(θ̂::Union{GNN, PointEstimator{<:GNN}, IntervalEstimator{<:GNN}}, Z, θ, device, loss, γ, optimiser)

m = numberreplicates(Z)
Z = Flux.batch(Z)
Z, θ = Z |> device, θ |> device

# Compute gradients in such a way that the training loss is also saved.
# This is equivalent to: gradients = gradient(() -> loss(θ̂(Z), θ), γ)
ls, back = Zygote.pullback(() -> loss(θ̂(Z, m), θ), γ) # NB here we also pass m to θ̂, since Flux.batch() cannot be differentiated
gradients = back(one(ls))
update!(optimiser, γ, gradients)

# Assuming that loss returns an average, convert it to a sum.
ls = ls * size(θ)[end]
return ls
end
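
The pullback pattern above returns the training loss together with the gradients. A self-contained toy sketch of the same pattern (the model, data, and loss here are illustrative, not part of the package):

```julia
using Flux, Zygote

model = Dense(2, 1)
x, y  = rand(Float32, 2, 8), rand(Float32, 1, 8)
γ     = Flux.params(model)

# Forward pass: the loss value is captured in `ls`, and `back` performs the reverse pass.
ls, back  = Zygote.pullback(() -> Flux.Losses.mse(model(x), y), γ)
gradients = back(one(ls))

# Equivalent, but the loss value is discarded:
# gradients = gradient(() -> Flux.Losses.mse(model(x), y), γ)
```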


# Higher-level methods needed to accommodate above methods for GNN. They are
# exactly the same as the standard methods defined in Estimators.jl, but we
# also pass through m.
#NB Not ideal that there's so much code repetition... we're just replacing
# f(Z) with f(Z, m). Tried with the g(x...) = sum(x) approach; it almost worked, might be worth trying again.
(est::PointEstimator{<:GNN})(Z::GNNGraph, m::AbstractVector{I}) where {I <: Integer} = est.arch(Z, m)
function (est::IntervalEstimator{<:GNN})(Z::GNNGraph, m::AbstractVector{I}) where {I <: Integer}
l = est.l(Z, m)
vcat(l, l .+ exp.(est.u(Z, m)))
end
function (est::IntervalEstimatorCompactPrior)(Z::GNNGraph, m::AbstractVector{I})
x = est.u(Z, m)
y = x .+ exp.(est.v(Z, m))
c = est.c
vcat(c(x), c(y))
end
function (est::PointIntervalEstimator{<:GNN})(Z::GNNGraph, m::AbstractVector{I}) where {I <: Integer}
θ̂ = est.θ̂(Z, m)
vcat(θ̂, θ̂ .- exp.(est.l(Z, m)), θ̂ .+ exp.(est.u(Z, m)))
bₗ = est.u(Z, m) # lower bound
bᵤ = bₗ .+ exp.(est.v(Z, m)) # upper bound
vcat(est.g(bₗ), est.g(bᵤ))
end
2 changes: 1 addition & 1 deletion src/NeuralEstimators.jl
@@ -40,7 +40,7 @@ export CholeskyCovariance, CovarianceMatrix, CorrelationMatrix
export vectotril, vectotriu
include("Architectures.jl")

export NeuralEstimator, PointEstimator, IntervalEstimator, IntervalEstimatorCompactPrior, PointIntervalEstimator, QuantileEstimator, PiecewiseEstimator, initialise_estimator
export NeuralEstimator, PointEstimator, IntervalEstimator, PiecewiseEstimator, initialise_estimator
include("Estimators.jl")

export GNN, UniversalPool, adjacencymatrix, WeightedGraphConv, maternclusterprocess
8 changes: 2 additions & 6 deletions src/bootstrap.jl
@@ -37,18 +37,14 @@ function interval(θ̃; probs = [0.05, 0.95], parameter_names = ["θ$i" for i
end


function interval(estimator::Union{IntervalEstimator, IntervalEstimatorCompactPrior, PointIntervalEstimator}, Z; parameter_names = nothing, use_gpu::Bool = true)
function interval(estimator::IntervalEstimator, Z; parameter_names = nothing, use_gpu::Bool = true)

ci = estimateinbatches(estimator, Z, use_gpu = use_gpu)
ci = cpu(ci)

if typeof(estimator) <: IntervalEstimator || typeof(estimator) <: IntervalEstimatorCompactPrior
if typeof(estimator) <: IntervalEstimator
@assert size(ci, 1) % 2 == 0
p = size(ci, 1) ÷ 2
elseif typeof(estimator) <: PointIntervalEstimator
@assert size(ci, 1) % 3 == 0
p = size(ci, 1) ÷ 3
ci = ci[p+1:end, :]
end

if isnothing(parameter_names)
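
With `PointIntervalEstimator` removed, `interval` only needs to split a `2p`-row matrix into lower and upper bounds. A small sketch of that bookkeeping (the dimensions are illustrative):

```julia
ci = rand(6, 10)        # IntervalEstimator output for p = 3 parameters and 10 data sets
p  = size(ci, 1) ÷ 2
lower = ci[1:p, :]      # first p rows: lower bounds
upper = ci[p+1:end, :]  # last p rows: upper bounds
```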