From 385a85b5aa5d9fdf17ba83f77547364a0dd7972d Mon Sep 17 00:00:00 2001 From: Matthew Sainsbury-Dale Date: Sun, 17 Dec 2023 16:49:38 +1100 Subject: [PATCH] Documentation; removed IntervalEstimatorCompactPrior, QuantileEstimator, and PointIntervalEstimator --- docs/src/API/core.md | 6 -- src/Estimators.jl | 218 +++++++--------------------------------- src/Graphs.jl | 46 +++++---- src/NeuralEstimators.jl | 2 +- src/bootstrap.jl | 8 +- src/train.jl | 20 +--- test/runtests.jl | 16 ++- 7 files changed, 82 insertions(+), 234 deletions(-) diff --git a/docs/src/API/core.md b/docs/src/API/core.md index d0a4413..ddfc838 100644 --- a/docs/src/API/core.md +++ b/docs/src/API/core.md @@ -31,12 +31,6 @@ PointEstimator IntervalEstimator -IntervalEstimatorCompactPrior - -PointIntervalEstimator - -QuantileEstimator - PiecewiseEstimator ``` diff --git a/src/Estimators.jl b/src/Estimators.jl index c111236..f1c2434 100644 --- a/src/Estimators.jl +++ b/src/Estimators.jl @@ -12,89 +12,27 @@ abstract type NeuralEstimator end """ PointEstimator(arch) -A simple point estimator, that is, a mapping from the sample space to the -parameter space, defined by the given architecture `arch`. +A neural point estimator, that is, a mapping from the sample space to the +parameter space, defined by the given neural-network architecture `arch`. """ struct PointEstimator{F} <: NeuralEstimator arch::F + # PointEstimator(arch) = isa(arch, PointEstimator) ? error("Please do not construct PointEstimator objects with another PointEstimator") : new(arch) end @functor PointEstimator (arch,) (est::PointEstimator)(Z) = est.arch(Z) -# ---- IntervalEstimator: credible intervals ---- +# ---- IntervalEstimator for amortised credible intervals ---- """ - IntervalEstimator(arch_lower, arch_upper) - IntervalEstimator(arch) -A neural interval estimator that jointly estimates credible intervals constructed as, + IntervalEstimator(u) + IntervalEstimator(u, v) + IntervalEstimator(u, v, g::Compress) + IntervalEstimator(u, v, min_supp::Vector, max_supp::Vector) -```math -[l(Z), l(Z) + \\mathrm{exp}(u(Z))], -``` - -where ``l(⋅)`` and ``u(⋅)`` are the neural networks `arch_lower` and -`arch_upper`, both of which should transform data into ``p``-dimensional vectors, -where ``p`` is the number of parameters in the statistical model. If only a -single neural network architecture `arch` is provided, it will be used for both -`arch_lower` and `arch_upper`. - -The returned value is a matrix with ``2p`` rows, where the first and second ``p`` -rows correspond to estimates of the lower and upper bound, respectively. - -See also [`IntervalEstimatorCompactPrior`](@ref). 
-
-# Examples
-```
-using NeuralEstimators
-using Flux
-
-# Generate some toy data
-n = 2 # bivariate data
-m = 100 # number of independent replicates
-Z = rand(n, m)
-
-# Create an architecture
-p = 3 # number of parameters in the statistical model
-w = 8 # width of each layer
-ψ = Chain(Dense(n, w, relu), Dense(w, w, relu));
-ϕ = Chain(Dense(w, w, relu), Dense(w, p));
-architecture = DeepSet(ψ, ϕ)
-
-# Initialise the interval estimator
-estimator = IntervalEstimator(architecture)
-
-# Apply the interval estimator
-estimator(Z)
-interval(estimator, Z)
-```
-"""
-struct IntervalEstimator{F, G} <: NeuralEstimator
-	l::F
-	u::G
-end
-IntervalEstimator(l) = IntervalEstimator(l, deepcopy(l))
-@functor IntervalEstimator
-function (est::IntervalEstimator)(Z)
-	l = est.l(Z)
-	vcat(l, l .+ exp.(est.u(Z)))
-end
-# Ensure that IntervalEstimator objects are not constructed with PointEstimator:
-#TODO find a neater way to do this; don't want to write so many methods, especially for PointIntervalEstimator
-IntervalEstimator(l::PointEstimator, u::PointEstimator) = IntervalEstimator(l.arch, u.arch)
-IntervalEstimator(l, u::PointEstimator) = IntervalEstimator(l, u.arch)
-IntervalEstimator(l::PointEstimator, u) = IntervalEstimator(l.arch, u)
-
-"""
-	IntervalEstimatorCompactPrior(u, v, min_supp::Vector, max_supp::Vector)
-	IntervalEstimatorCompactPrior(u, v, compress::Compress)
-Uses the neural networks `u` and `v` to jointly estimate credible intervals
-that are guaranteed to be within the support of the prior distributon. This
-support is defined by the ``p``-dimensional vectors `min_supp` and `max_supp`
-(or a single ``p``-dimensional object of type `Compress`), where ``p`` is the
-number of parameters in the statistical model.
-
-Given data ``Z``, the intervals are constructed as
+A neural interval estimator which, given data ``Z``, jointly estimates credible
+intervals of the form
 
 ```math
-[g(u(Z)), g(u(Z)) + \\mathrm{exp}(v(Z)))],
+[g(u(Z)), g(u(Z) + \\mathrm{exp}(v(Z)))],
 ```
@@ -102,34 +40,39 @@ Given data ``Z``, the intervals are constructed as
 
 where
 
-- ``u(⋅)`` and ``v(⋅)`` are neural networks, both of which should transform data into ``p``-dimensional vectors;
+- ``u(⋅)`` and ``v(⋅)`` are neural networks, both of which should transform data into ``p``-dimensional vectors (with ``p`` the number of parameters in the statistical model);
 - ``g(⋅)`` is a logistic function that maps its input to the prior support.
 
+The prior support is defined either by the ``p``-dimensional vectors `min_supp`
+and `max_supp`, or by a single ``p``-dimensional object of type [`Compress`](@ref).
+If these objects are not given, the range of the intervals will be unrestricted.
+
 Note that, in addition to ensuring that the interval remains in the prior support,
 this construction also ensures that the intervals are valid (i.e., it prevents
 quantile crossing, in the sense that the upper bound is always greater than the
 lower bound).
 
+If only a single neural-network architecture is provided, it will be used
+for both `u` and `v`.
+
 The returned value is a matrix with ``2p`` rows, where the first and second ``p``
 rows correspond to estimates of the lower and upper bound, respectively.
 
-See also [`IntervalEstimator`](@ref) and [`Compress`](@ref).
- # Examples ``` using NeuralEstimators using Flux -# prior support -min_supp = [25, 0.5, -pi/2] -max_supp = [500, 2.5, 0] -p = length(min_supp) # number of parameters in the statistical model - # Generate some toy data n = 2 # bivariate data m = 100 # number of independent replicates Z = rand(n, m) +# prior support +min_supp = [25, 0.5, -pi/2] +max_supp = [500, 2.5, 0] +p = length(min_supp) # number of parameters in the statistical model + # Create an architecture w = 8 # width of each layer ψ = Chain(Dense(n, w, relu), Dense(w, w, relu)); @@ -138,116 +81,31 @@ u = DeepSet(ψ, ϕ) v = deepcopy(u) # use the same architecture for both u and v # Initialise the interval estimator -estimator = IntervalEstimatorCompactPrior(u, v, min_supp, max_supp) +estimator = IntervalEstimator(u, v, min_supp, max_supp) # Apply the interval estimator estimator(Z) interval(estimator, Z) ``` """ -struct IntervalEstimatorCompactPrior{F, G} <: NeuralEstimator +struct IntervalEstimator{F, G} <: NeuralEstimator u::F v::G - c::Compress -end -IntervalEstimatorCompactPrior(u, v, min_supp, max_supp) = IntervalEstimatorCompactPrior(u, v, Compress(min_supp, max_supp)) -@functor IntervalEstimatorCompactPrior -Flux.trainable(est::IntervalEstimatorCompactPrior) = (est.u, est.v) - -function (est::IntervalEstimatorCompactPrior)(Z) - x = est.u(Z) - y = x .+ exp.(est.v(Z)) - c = est.c - vcat(c(x), c(y)) -end - -""" - PointIntervalEstimator(arch_point, arch_lower, arch_upper) - PointIntervalEstimator(arch_point, arch_bound) - PointIntervalEstimator(arch) -A neural estimator that jointly produces point estimates, ``θ̂(Z)``, where ``θ̂(⋅)`` is a -neural point estimator with architecture `arch_point`, and credible intervals constructed as, - -```math -[θ̂(Z) - \\mathrm{exp}(l(Z)), θ̂(Z) + \\mathrm{exp}(u(Z))], -``` - -where ``l(⋅)`` and ``u(⋅)`` are the neural networks `arch_lower` and -`arch_upper`, both of which should transform data into ``p``-dimensional vectors, -where ``p`` is the number of parameters in the statistical model. - -If only a single neural network architecture `arch` is provided, it will be used -for all architectures; similarly, if two architectures are provided, the second -will be used for both `arch_lower` and `arch_upper`. - -Internally, the point estimates, lower-bound estimates, and upper-bound estimates are concatenated, so -that `PointIntervalEstimator` objects transform data into matrices with ``3p`` rows. - -# Examples -``` -using NeuralEstimators -using Flux - -# Generate some toy data -n = 2 # bivariate data -m = 100 # number of independent replicates -Z = rand(n, m) - -# Create an architecture -p = 3 # number of parameters in the statistical model -w = 8 # width of each layer -ψ = Chain(Dense(n, w, relu), Dense(w, w, relu)); -ϕ = Chain(Dense(w, w, relu), Dense(w, p)); -architecture = DeepSet(ψ, ϕ) - -# Initialise the estimator -estimator = PointIntervalEstimator(architecture) - -# Apply the estimator -estimator(Z) -interval(estimator, Z) -``` -""" -struct PointIntervalEstimator{H, F, G} <: NeuralEstimator - θ̂::H - l::F - u::G + g::Union{Function,Compress} + # IntervalEstimator(u, v, g) = any(isa.([u, v], PointEstimator)) ? 
error("Please do not construct IntervalEstimator objects with PointEstimators") : new(u, v, g) end -PointIntervalEstimator(θ̂) = PointIntervalEstimator(θ̂, deepcopy(θ̂), deepcopy(θ̂)) -PointIntervalEstimator(θ̂, l) = PointIntervalEstimator(θ̂, deepcopy(l), deepcopy(l)) -@functor PointIntervalEstimator -function (est::PointIntervalEstimator)(Z) - θ̂ = est.θ̂(Z) - vcat(θ̂, θ̂ .- exp.(est.l(Z)), θ̂ .+ exp.(est.u(Z))) -end -# Ensure that IntervalEstimator objects are not constructed as a wrapper of PointEstimator: -PointIntervalEstimator(θ̂::PointEstimator, l::PointEstimator, u::PointEstimator) = PointIntervalEstimator(θ̂.arch, l.arch, u.arch) -PointIntervalEstimator(θ̂::PointEstimator, l, u) = PointIntervalEstimator(θ̂.arch, l, u) -PointIntervalEstimator(θ̂, l::PointEstimator, u::PointEstimator) = PointIntervalEstimator(θ̂, l.arch, u.arch) - - -# ---- QuantileEstimator: estimating arbitrary quantiles of the posterior distribution ---- - -# Should Follow up with this point from Gnieting's paper: -# 9.2 Quantile Estimation -# Koenker and Bassett (1978) proposed quantile regression using an optimum -# score estimator based on the proper scoring rule (41). - - -#TODO this is a topic of ongoing research with Jordan -""" - QuantileEstimator() - -Coming soon: this structure will allow for the simultaneous estimation of an -arbitrary number of marginal quantiles of the posterior distribution. -""" -struct QuantileEstimator{F, G} <: NeuralEstimator - l::F - u::G +IntervalEstimator(u) = IntervalEstimator(u, deepcopy(u), identity) +IntervalEstimator(u, v) = IntervalEstimator(u, v, identity) +IntervalEstimator(u, g::Compress) = IntervalEstimator(u, deepcopy(u), g) +IntervalEstimator(u, min_supp, max_supp) = IntervalEstimator(u, Compress(min_supp, max_supp)) +IntervalEstimator(u, v, min_supp, max_supp) = IntervalEstimator(u, v, Compress(min_supp, max_supp)) +@functor IntervalEstimator +Flux.trainable(est::IntervalEstimator) = (est.u, est.v) +function (est::IntervalEstimator)(Z) + bₗ = est.u(Z) # lower bound + bᵤ = bₗ .+ exp.(est.v(Z)) # upper bound + vcat(est.g(bₗ), est.g(bᵤ)) end -# @functor QuantileEstimator -# (c::QuantileEstimator)(Z) = vcat(c.l(Z), c.l(Z) .+ exp.(c.u(Z))) - # ---- PiecewiseEstimator ---- diff --git a/src/Graphs.jl b/src/Graphs.jl index f4f68a1..254eeec 100644 --- a/src/Graphs.jl +++ b/src/Graphs.jl @@ -458,7 +458,7 @@ hidden-feature graphs; the `readout` module aggregates these feature graphs into a single hidden feature vector of fixed length; the function `a`(⋅) is a permutation-invariant aggregation function, and `ϕ` is a neural network. -The data should be stored as a `GNNGraph` or `AbstractVector{GNNGraph}`, where +The data should be stored as a `GNNGraph` or `Vector{GNNGraph}`, where each graph is associated with a single parameter vector. The graphs may contain sub-graphs corresponding to independent replicates from the model. @@ -586,22 +586,34 @@ function (est::GNN)(g::GNNGraph, m::AbstractVector{I}) where {I <: Integer} return est.deepset(h̃) end -# Methods needed to accomodate above method of GNN. They are exactly the same as -# the standard methods defined in Estimators.jl, but also pass through m. -#TODO unit testing for these methods -#TODO Not ideal that there's so much code repetition... we're just replacing f(Z) with f(Z, m). Tried with the g(x...) = sum(x) approach; it almost worked, might be worth trying again. 
+# Also need a custom method for _updatebatch!()
+# NB Surely there is a more generic way to dispatch here (e.g., any structure that contains GNN)
+function _updatebatch!(θ̂::Union{GNN, PointEstimator{<:GNN}, IntervalEstimator{<:GNN}}, Z, θ, device, loss, γ, optimiser)
+
+	m = numberreplicates(Z)
+	Z = Flux.batch(Z)
+	Z, θ = Z |> device, θ |> device
+
+	# Compute gradients in such a way that the training loss is also saved.
+	# This is equivalent to: gradients = gradient(() -> loss(θ̂(Z), θ), γ)
+	ls, back = Zygote.pullback(() -> loss(θ̂(Z, m), θ), γ) # NB here we also pass m to θ̂, since Flux.batch() cannot be differentiated
+	gradients = back(one(ls))
+	update!(optimiser, γ, gradients)
+
+	# Assuming that loss returns an average, convert it to a sum.
+	ls = ls * size(θ)[end]
+	return ls
+end
+
+
+# Higher-level methods needed to accommodate the above methods for GNN. They are
+# exactly the same as the standard methods defined in Estimators.jl, but we
+# also pass through m.
+# NB Not ideal that there's so much code repetition... we're just replacing
+# f(Z) with f(Z, m). Tried with the g(x...) = sum(x) approach; it almost worked, might be worth trying again.
 (est::PointEstimator{<:GNN})(Z::GNNGraph, m::AbstractVector{I}) where {I <: Integer} = est.arch(Z, m)
 function (est::IntervalEstimator{<:GNN})(Z::GNNGraph, m::AbstractVector{I}) where {I <: Integer}
-	l = est.l(Z, m)
-	vcat(l, l .+ exp.(est.u(Z, m)))
-end
-function (est::IntervalEstimatorCompactPrior)(Z::GNNGraph, m::AbstractVector{I})
-	x = est.u(Z, m)
-	y = x .+ exp.(est.v(Z, m))
-	c = est.c
-	vcat(c(x), c(y))
-end
-function (est::PointIntervalEstimator{<:GNN})(Z::GNNGraph, m::AbstractVector{I}) where {I <: Integer}
-	θ̂ = est.θ̂(Z, m)
-	vcat(θ̂, θ̂ .- exp.(est.l(Z, m)), θ̂ .+ exp.(est.u(Z, m)))
+	bₗ = est.u(Z, m)             # lower bound
+	bᵤ = bₗ .+ exp.(est.v(Z, m)) # upper bound
+	vcat(est.g(bₗ), est.g(bᵤ))
 end
diff --git a/src/NeuralEstimators.jl b/src/NeuralEstimators.jl
index 07a8d44..2d86455 100644
--- a/src/NeuralEstimators.jl
+++ b/src/NeuralEstimators.jl
@@ -40,7 +40,7 @@ export CholeskyCovariance, CovarianceMatrix, CorrelationMatrix
 export vectotril, vectotriu
 
 include("Architectures.jl")
-export NeuralEstimator, PointEstimator, IntervalEstimator, IntervalEstimatorCompactPrior, PointIntervalEstimator, QuantileEstimator, PiecewiseEstimator, initialise_estimator
+export NeuralEstimator, PointEstimator, IntervalEstimator, PiecewiseEstimator, initialise_estimator
 include("Estimators.jl")
 
 export GNN, UniversalPool, adjacencymatrix, WeightedGraphConv, maternclusterprocess
diff --git a/src/bootstrap.jl b/src/bootstrap.jl
index d1601e7..3d5a009 100644
--- a/src/bootstrap.jl
+++ b/src/bootstrap.jl
@@ -37,18 +37,14 @@ function interval(θ̃; probs = [0.05, 0.95], parameter_names = ["θ$i" for i
 end
 
-function interval(estimator::Union{IntervalEstimator, IntervalEstimatorCompactPrior, PointIntervalEstimator}, Z; parameter_names = nothing, use_gpu::Bool = true)
+function interval(estimator::IntervalEstimator, Z; parameter_names = nothing, use_gpu::Bool = true)
 
 	ci = estimateinbatches(estimator, Z, use_gpu = use_gpu)
 	ci = cpu(ci)
 
-	if typeof(estimator) <: IntervalEstimator || typeof(estimator) <: IntervalEstimatorCompactPrior
+	if typeof(estimator) <: IntervalEstimator
 		@assert size(ci, 1) % 2 == 0
 		p = size(ci, 1) ÷ 2
-	elseif typeof(estimator) <: PointIntervalEstimator
-		@assert size(ci, 1) % 3 == 0
-		p = size(ci, 1) ÷ 3
-		ci = ci[p+1:end, :]
 	end
 
 	if isnothing(parameter_names)
diff --git a/src/train.jl b/src/train.jl
index 3767349..7160c79 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -91,7 +91,7 @@ function train end
 
 function train(θ̂, sampler, simulator;
 	m,
-	ξ = nothing, xi = nothing, 
+	ξ = nothing, xi = nothing,
 	epochs_per_θ_refresh::Integer = 1, epochs_per_theta_refresh::Integer = 1,
 	epochs_per_Z_refresh::Integer = 1,
 	simulate_just_in_time::Bool = false,
@@ -736,24 +736,6 @@ function _updatebatch!(θ̂, Z, θ, device, loss, γ, optimiser)
 	return ls
 end
 
-#TODO Surely there is a better way of dispatching here...
-function _updatebatch!(θ̂::Union{GNN, PointEstimator{<:GNN}, IntervalEstimator{<:GNN}, IntervalEstimatorCompactPrior{<:GNN}, PointIntervalEstimator{<:GNN}}, Z, θ, device, loss, γ, optimiser)
-
-	m = numberreplicates(Z)
-	Z = Flux.batch(Z)
-	Z, θ = Z |> device, θ |> device
-
-	# Compute gradients in such a way that the training loss is also saved.
-	# This is equivalent to: gradients = gradient(() -> loss(θ̂(Z), θ), γ)
-	ls, back = Zygote.pullback(() -> loss(θ̂(Z, m), θ), γ) # NB here we also pass m to θ̂, since Flux.batch() cannot be differentiated
-	gradients = back(one(ls))
-	update!(optimiser, γ, gradients)
-
-	# Assuming that loss returns an average, convert it to a sum.
-	ls = ls * size(θ)[end]
-	return ls
-end
-
 # Wrapper function that returns simulated data and the true parameter values
 _simulate(simulator, params::P, m) where {P <: Union{AbstractMatrix, ParameterConfigurations}} = (simulator(params, m), _extractθ(params))
diff --git a/test/runtests.jl b/test/runtests.jl
index 23d4a68..88f42b1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -699,7 +699,7 @@ end
 end
 
-@testset "IntervalEstimators" begin
+@testset "IntervalEstimator" begin
 	# Generate some toy data and a basic architecture
 	d = 2 # bivariate data
 	m = 64 # number of independent replicates
@@ -711,6 +711,7 @@ end
 
 	# IntervalEstimator
 	estimator = IntervalEstimator(arch)
+	estimator = IntervalEstimator(arch, arch)
 	θ̂ = estimator(Z)
 	@test size(θ̂) == (2p, 1)
 	@test all(θ̂[1:p] .< θ̂[(p+1):end])
 
 	ci = interval(estimator, Z, parameter_names = parameter_names)
 	@test size(ci[1]) == (p, 2)
 
-	# PointIntervalEstimator
-	estimator = PointIntervalEstimator(arch)
+	# IntervalEstimator with a compact prior
+	min_supp = [25, 0.5, -pi/2]
+	max_supp = [500, 2.5, 0]
+	estimator = IntervalEstimator(arch, min_supp, max_supp)
+	estimator = IntervalEstimator(arch, arch, min_supp, max_supp)
 	θ̂ = estimator(Z)
-	@test size(θ̂) == (3p, 1)
-	@test all(θ̂[(p+1):2p] .< θ̂[1:p] .< θ̂[(2p+1):end])
+	@test size(θ̂) == (2p, 1)
+	@test all(θ̂[1:p] .< θ̂[(p+1):end])
+	@test all(min_supp .< θ̂[1:p] .< max_supp)
+	@test all(min_supp .< θ̂[p+1:end] .< max_supp)
 	ci = interval(estimator, Z)
 	ci = interval(estimator, Z, parameter_names = parameter_names)
 	@test size(ci[1]) == (p, 2)
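
For reviewers, a minimal usage sketch of the consolidated `IntervalEstimator` API introduced by this patch. It is assembled from the docstring example and the updated tests above: the toy data, architecture, and prior support are copied from those examples, the unrestricted variant relies on the documented default `g = identity`, and the shape comments restate the docstring and test assertions rather than describing new behaviour.

```julia
using NeuralEstimators
using Flux

# Toy data: m independent replicates of n-dimensional data, stored as an n × m matrix
n = 2   # bivariate data
m = 100 # number of independent replicates
Z = rand(n, m)

# DeepSet architecture mapping data to p-dimensional vectors
p = 3 # number of parameters in the statistical model
w = 8 # width of each layer
ψ = Chain(Dense(n, w, relu), Dense(w, w, relu))
ϕ = Chain(Dense(w, w, relu), Dense(w, p))
u = DeepSet(ψ, ϕ)

# Unrestricted intervals: a single architecture is duplicated for u and v, and g defaults to identity
estimator = IntervalEstimator(u)
θ̂ = estimator(Z) # 2p × 1 matrix: lower bounds stacked over upper bounds

# Intervals constrained to the prior support: g becomes a Compress layer
min_supp = [25, 0.5, -pi/2]
max_supp = [500, 2.5, 0]
estimator = IntervalEstimator(u, deepcopy(u), min_supp, max_supp)
estimator(Z)           # bounds now lie within [min_supp, max_supp]
interval(estimator, Z) # per the tests, elements of the result are p × 2 matrices
```

Either variant prevents quantile crossing by construction: the upper bound is the lower bound plus a strictly positive increment, and the logistic `Compress` mapping is monotone, so the ordering is preserved after compression.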