Define MixedMultivariateDistribution type #27

Closed
itsdfish opened this issue Jun 21, 2023 · 12 comments

@itsdfish
Owner

I would like to explore the possibility of defining a type for multialternative sequential sampling models (SSMs), in the hope that this package will play well with Turing and other packages. The package currently works well in most cases. However, it does not work well with predict from Turing, as discussed here.

Here is one idea:

using Distributions
using Turing 
import Distributions: logpdf
import Distributions: loglikelihood
import Base: length

abstract type Mixed <: ValueSupport end 

const MixedMultivariateDistribution = Distribution{Multivariate, Mixed}

abstract type SSM1D <: ContinuousUnivariateDistribution end 

abstract type SSM2D <: MixedMultivariateDistribution end 

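This defines MixedMultivariateDistribution, a multivariate distribution with mixed value support (for example, a discrete choice paired with a continuous response time), which could potentially be used outside of SSMs. The type system is then split into two abstract types: SSM1D for 1D SSMs, which are ordinary continuous univariate distributions, and SSM2D for 2D SSMs.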
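A quick way to confirm where these abstract types sit in the Distributions.jl hierarchy is to query subtyping directly. This is a minimal sketch that repeats the definitions above; it assumes only Distributions.jl and that the name `Mixed` does not clash with anything else in scope.

```julia
using Distributions

# value support for distributions that mix discrete and continuous components
abstract type Mixed <: ValueSupport end

const MixedMultivariateDistribution = Distribution{Multivariate, Mixed}

abstract type SSM1D <: ContinuousUnivariateDistribution end
abstract type SSM2D <: MixedMultivariateDistribution end

# 2D SSMs count as multivariate, but are neither continuous nor discrete
println(SSM2D <: MultivariateDistribution)            # true
println(SSM2D <: ContinuousMultivariateDistribution)  # false
println(SSM2D <: DiscreteMultivariateDistribution)    # false

# 1D SSMs remain ordinary continuous univariate distributions
println(SSM1D <: ContinuousUnivariateDistribution)    # true
```

Because the value-support parameter of `Distribution` is invariant, `Mixed` keeps 2D SSMs out of methods written specifically for continuous or discrete multivariate distributions, while generic multivariate methods still apply.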

The code below shows that this works for basic MCMC sampling. The question is how to get it to work with predict and friends.

# not really 2D, just for illustration
struct MyType{T<:Real} <: SSM2D
    n::Int
    x::T
end

logpdf(d::MyType,data::Int) = logpdf(Binomial(d.n, d.x), data)

loglikelihood(d::MyType,data::Int) = loglikelihood(Binomial(d.n, d.x), data)


@model function my_model(n, k)
    θ ~ Beta(1, 1)
    return k ~ MyType(n, θ)
end

chain = sample(my_model(10, 5), NUTS(), 3_000)
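Before sampling, the delegation in the illustrative type can be checked outside Turing. This is a minimal sketch using only Distributions.jl; it repeats the toy `MyType` above, and the parameter values are arbitrary.

```julia
using Distributions
import Distributions: logpdf, loglikelihood

abstract type Mixed <: ValueSupport end
const MixedMultivariateDistribution = Distribution{Multivariate, Mixed}
abstract type SSM2D <: MixedMultivariateDistribution end

# same illustrative type as above: it simply delegates to a Binomial
struct MyType{T<:Real} <: SSM2D
    n::Int
    x::T
end

logpdf(d::MyType, data::Int) = logpdf(Binomial(d.n, d.x), data)
loglikelihood(d::MyType, data::Int) = loglikelihood(Binomial(d.n, d.x), data)

d = MyType(10, 0.3)
# the custom density should match the Binomial it wraps
println(logpdf(d, 5) ≈ logpdf(Binomial(10, 0.3), 5))  # true
```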
@kiante-fernandez
Contributor

Putting some discussion we've had here for reference:

@itsdfish, did we end up getting predict to work in the case discussed on Discourse?

@DominiqueMakowski

This comment was marked as off-topic.

@itsdfish
Owner Author

@DominiqueMakowski, that is an interesting idea. I would have to learn how the @formula macro works and understand what algorithm is used to fit the model. I suspect that the algorithm for fitting a linear model would not be suitable for an SSM. This package might be promising: https://github.com/TuringLang/TuringGLM.jl

In the meantime, you should be able to add covariates to your model via Turing. For example,

@model function cool_model(choices, rts, covariate)
    β0 ~ Normal(0, 1)
    β1 ~ Normal(0, 1)
    # drift rate; broadcasting gives one value per trial when covariate is a vector
    ν = β0 .+ β1 .* covariate
    # ...
end
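For a fully runnable illustration of the covariate pattern, the snippet below is hypothetical: it swaps the unspecified SSM likelihood for a simple LogNormal response-time model, with the LogNormal location standing in for the drift rate ν. It is a sketch of the linear-predictor idea only, not a model from this package.

```julia
using Turing
using Random

# hypothetical stand-in for an SSM: LogNormal rts whose location depends on a covariate
@model function covariate_model(rts, covariate)
    β0 ~ Normal(0, 1)
    β1 ~ Normal(0, 1)
    σ ~ truncated(Normal(0, 1), 0, Inf)
    # linear predictor: one location value per trial
    μ = β0 .+ β1 .* covariate
    for i in eachindex(rts)
        rts[i] ~ LogNormal(μ[i], σ)
    end
end

Random.seed!(1)
covariate = randn(50)
# simulated data with β0 = 0.2, β1 = 0.5, σ = 0.5
rts = [rand(LogNormal(0.2 + 0.5 * c, 0.5)) for c in covariate]

chain = sample(covariate_model(rts, covariate), NUTS(), 200; progress=false)
```

In a real SSM, `μ` would instead be passed as the drift rate to the SSM's distribution, with `choices` and `rts` observed together.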

@DominiqueMakowski
Contributor

The question is how to get it to work with predict and friends.

What issues do you foresee? What would make this new distribution class incompatible with methods such as predict()?

@itsdfish
Owner Author

@DominiqueMakowski, good question. Turing extends the function predict here. Strangely, when we run predict with a 2D SSM, it returns an empty chain rather than crashing. My plan is to fork a copy and figure out which parts do not work. Presumably, there are functions (or they could add functions) that we could extend for SSMs. For example, maybe AbstractMCMC.bundle_samples does not know how to handle the data from a 2D SSM and fails silently. If that is the case, we could create our own method for MixedMultivariateDistribution, which would allow AbstractMCMC.bundle_samples to store the predictions properly in a chain. At least, that is my hope.

@DominiqueMakowski
Contributor

Don't want to sound pushy at all, just curious, but any news on this front 😁 ?

@itsdfish
Owner Author

@DominiqueMakowski, no problem! I appreciate your interest.

I will have some time tomorrow morning to look through Turing and reach out to the developers to see whether they would be willing to set up their code so that others can extend the functionality. I'll cc you on any issues I open for your situational awareness.

@itsdfish
Owner Author

itsdfish commented Jul 1, 2023

@DominiqueMakowski, I think I have a proof of concept working. Would you be able to give the following code a try on your system to see if it seems reasonable?

using Distributions
using Turing 
using Random
import Distributions: logpdf
import Distributions: loglikelihood
import Distributions: rand
import DynamicPPL: vectorize
import Base: length

abstract type Mixed <: ValueSupport end 

const MixedMultivariateDistribution = Distribution{Multivariate, Mixed}

abstract type SSM1D <: ContinuousUnivariateDistribution end 

abstract type SSM2D <: MixedMultivariateDistribution end 

struct MyType{T<:Real} <: SSM2D
    μ::T 
    σ::T
end

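# treat a MyType instance as a scalar when broadcasting, e.g. logpdf.(d, choices, rts)
Base.broadcastable(x::MyType) = Ref(x)

# flatten a sampled (choice, rt) NamedTuple into a vector so predictions can be stored in a Chain
vectorize(d::MixedMultivariateDistribution, r::NamedTuple) = [r...]

# each observation has two components: choice and rt
Base.length(d::MixedMultivariateDistribution) = 2

rand(d::MixedMultivariateDistribution) = rand(Random.default_rng(), d)
rand(d::MixedMultivariateDistribution, n::Int) = rand(Random.default_rng(), d, n)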

function rand(rng::AbstractRNG, d::MyType)
    # toy model: choice is uniform on {1, 2}; rt is LogNormal
    choice = rand(rng, 1:2)
    rt = rand(rng, LogNormal(d.μ, d.σ))
    return (;choice, rt)
end

# draw N trials, returned as a NamedTuple of vectors (choice = ..., rt = ...)
function rand(rng::AbstractRNG, d::MyType, N::Int)
    choice = fill(0, N)
    rt = fill(0.0, N)
    for i in 1:N
        choice[i],rt[i] = rand(rng, d)
    end
    return (choice=choice,rt=rt)
end

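# toy likelihood: only rt contributes; choice is accepted but ignored
function logpdf(d::MyType, choice::Integer, rt::Real)
    return logpdf(LogNormal(d.μ, d.σ), rt)
end

logpdf(d::MyType, data::NamedTuple) = logpdf(d, data.choice, data.rt)

# sum the trial-wise log densities; relies on Base.broadcastable above
loglikelihood(d::MyType, data::NamedTuple) = sum(logpdf.(d, data...))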

@model function my_model(data)
    μ ~ Normal(0, 1)
    σ ~ truncated(Normal(0, 1), 0, Inf)
    data ~ MyType(μ, σ)
    return (;data, μ, σ)
end

data = rand(MyType(0, 1), 10)
chain = sample(my_model(data), NUTS(), 1_000)

predictions = predict(my_model(missing), chain)

@DominiqueMakowski
Contributor

It does seem to work ☺️ here's the output!

Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
parameters        = data[1], data[2]
internals         =

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Missing

     data[1]    1.4970    0.5002    0.0164   928.3307        NaN    0.9992       missing
     data[2]    2.7176   10.2072    0.3206   898.7900   901.0663    0.9993       missing

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64

     data[1]    1.0000    1.0000    1.0000    2.0000    2.0000
     data[2]    0.1605    0.5916    1.2593    2.5420   11.0056

@itsdfish
Owner Author

itsdfish commented Jul 1, 2023

Awesome. Evidently, the main thing I needed to define was

vectorize(d::MixedMultivariateDistribution, r::NamedTuple) = [r...]
which I believe allows Turing to save the predictions into a Chain.
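To see what the one-line `vectorize` method does to a single prediction, here is the splat in plain Julia (no Turing needed); the values are hypothetical. Because the array literal promotes its elements to a common type, the integer choice is stored as a `Float64`, which would be consistent with `data[1]` being summarized as floating-point values (mean 1.4970) in the output above.

```julia
# a hypothetical prediction from a 2D model, as returned by rand(rng, d)
r = (choice = 2, rt = 0.734)

# same body as the vectorize method: splat the NamedTuple's values into a vector;
# Julia promotes the Int choice and the Float64 rt to a common element type
v = [r...]

println(v)          # [2.0, 0.734]
println(eltype(v))  # Float64
```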

I'll work on integrating the new type system either today or tomorrow and push a new release.

itsdfish closed this as completed Jul 1, 2023
@kiante-fernandez
Contributor

@itsdfish Nice work on this! I was away at a summer school. I will try to finish up the Ratcliff DDM this weekend. I will look more closely at what you implemented and make sure the Ratcliff distribution uses the new framework.

@itsdfish
Owner Author

itsdfish commented Jul 1, 2023

@kiante-fernandez, thanks!

I think integrating with the new type system should be fairly simple. Basically, you can make the model a subtype of SSM2D. As far as I can tell, the generic methods defined in utilities will work with the new model. One nice thing about the new system is that I was able to remove 3-4 methods per model by using generic methods.
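The generic-method idea comes down to multiple dispatch: methods written once against the abstract `MixedMultivariateDistribution` are inherited by every `SSM2D` subtype, so each model only needs its own sampler (and density). In the sketch below, `ToyModelA` and `ToyModelB` are hypothetical stand-ins, not models from the package, and the two generic methods are modeled on the proof of concept earlier in the thread.

```julia
using Distributions
using Random
import Distributions: rand

abstract type Mixed <: ValueSupport end
const MixedMultivariateDistribution = Distribution{Multivariate, Mixed}
abstract type SSM2D <: MixedMultivariateDistribution end

# generic methods, written once for every mixed multivariate distribution
Base.length(d::MixedMultivariateDistribution) = 2
rand(d::MixedMultivariateDistribution) = rand(Random.default_rng(), d)

# two hypothetical models; each supplies only its own rng-based sampler
struct ToyModelA{T<:Real} <: SSM2D
    μ::T
end

struct ToyModelB{T<:Real} <: SSM2D
    ν::T
end

rand(rng::AbstractRNG, d::ToyModelA) = (choice = rand(rng, 1:2), rt = rand(rng, LogNormal(d.μ, 1.0)))
rand(rng::AbstractRNG, d::ToyModelB) = (choice = 1, rt = rand(rng, Exponential(1 / d.ν)))

# both models pick up length and the rng-free rand without defining them
println(length(ToyModelA(0.0)), " ", length(ToyModelB(2.0)))  # 2 2
println(keys(rand(ToyModelA(0.0))))                            # (:choice, :rt)
```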
