Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1098 cant instantiate halft prior #1104

Merged
merged 3 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Changelog

## Unreleased
* update prior `__new__` methods #1098 [MartinBubel]

* fix invalid escape sequence #1011 [janmayer]

## v1.13.2 (2024-07-21)
Expand Down
13 changes: 11 additions & 2 deletions GPy/core/parameterization/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,11 @@ class DGPLVM(Prior):
domain = _REAL

def __new__(cls, sigma2, lbl, x_shape):
return super(Prior, cls).__new__(cls, sigma2, lbl, x_shape)
newfunc = super(Prior, cls).__new__
if newfunc is object.__new__:
return newfunc(cls)
else:
return newfunc(cls, sigma2, lbl, x_shape)

def __init__(self, sigma2, lbl, x_shape):
self.sigma2 = sigma2
Expand Down Expand Up @@ -1275,7 +1279,12 @@ def __new__(cls, A, nu): # Singleton:
for instance in cls._instances:
if instance().A == A and instance().nu == nu:
return instance()
o = super(Prior, cls).__new__(cls, A, nu)

newfunc = super(Prior, cls).__new__
if newfunc is object.__new__:
o = newfunc(cls)
else:
o = newfunc(cls, A, nu)
cls._instances.append(weakref.ref(o))
return cls._instances[-1]()

Expand Down
102 changes: 102 additions & 0 deletions GPy/testing/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@
import pytest
import numpy as np
import GPy
from GPy.core.parameterization.priors import (
Gaussian,
Uniform,
LogGaussian,
MultivariateGaussian,
Gamma,
InverseGamma,
DGPLVM,
DGPLVM_KFDA,
DGPLVM_Lamda,
DGPLVM_T,
HalfT,
Exponential,
StudentT,
)


class TestPrior:
Expand Down Expand Up @@ -178,3 +193,90 @@ def test_fixed_domain_check1(self):
# should raise an assertionerror.
with pytest.raises(AssertionError):
m.rbf.set_prior(gaussian)


def initialize_gaussian_prior() -> None:
return Gaussian(0, 1)


def initialize_uniform_prior() -> None:
return Uniform(0, 1)


def initialize_log_gaussian_prior() -> None:
return LogGaussian(0, 1)


def initialize_multivariate_gaussian_prior() -> None:
return MultivariateGaussian(np.zeros(2), np.eye(2))


def initialize_gamma_prior() -> None:
return Gamma(1, 1)


def initialize_inverse_gamma_prior() -> None:
return InverseGamma(1, 1)


def initialize_dgplvm_prior() -> None:
# return DGPLVM(...)
raise NotImplementedError("No idea how to initialize this prior")


def initialize_dgplvm_kfda_prior() -> None:
# return DGPLVM_KFDA(...)
raise NotImplementedError("No idea how to initialize this prior")


def initialize_dgplvm_lamda_prior() -> None:
# return DGPLVM_Lamda(...)
raise NotImplementedError("No idea how to initialize this prior")


def initialize_dgplvm_t_prior() -> None:
# return DGPLVM_T(1, 1, (1, 1))
raise NotImplementedError("No idea how to initialize this prior")


def initialize_half_t_prior() -> None:
return HalfT(1, 1)


def initialize_exponential_prior() -> None:
return Exponential(1)


def initialize_student_t_prior() -> None:
return StudentT(1, 1, 1)


PRIORS = {
"Gaussian": initialize_gaussian_prior,
"Uniform": initialize_uniform_prior,
"LogGaussian": initialize_log_gaussian_prior,
"MultivariateGaussian": initialize_multivariate_gaussian_prior,
"Gamma": initialize_gamma_prior,
"InverseGamma": initialize_inverse_gamma_prior,
# "DGPLVM": initialize_dgplvm_prior,
# "DGPLVM_KFDA": initialize_dgplvm_kfda_prior,
# "DGPLVM_Lamda": initialize_dgplvm_lamda_prior,
# "DGPLVM_T": initialize_dgplvm_t_prior,
"HalfT": initialize_half_t_prior,
"Exponential": initialize_exponential_prior,
"StudentT": initialize_student_t_prior,
}


def check_prior(prior_getter: str) -> None:
prior_getter()


def test_priors() -> None:
for prior_name, prior_getter in PRIORS.items():
try:
check_prior(prior_getter)
except Exception as e:
raise RuntimeError(
f"Failed to initialize {prior_name} prior"
) from e # noqa E501
Loading