Skip to content

Commit

Permalink
Extend tests to non-zero-mean cases
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Apr 20, 2022
1 parent 17862a3 commit cbeabe7
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 58 deletions.
16 changes: 8 additions & 8 deletions tests/model/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@


def test_summation_with_itself():
p = GP(EQ())
p = GP(1, EQ())
p_many = p + p + p + p + p

x = B.linspace(0, 10, 5)
approx(p_many(x).var, 25 * p(x).var)
approx(p_many(x).mean, B.zeros(5, 1))
approx(p_many(x).mean, 5 * B.ones(5, 1))

y = B.randn(5, 1)
post = p.measure | (p(x), y)
Expand All @@ -21,8 +21,8 @@ def test_summation_with_itself():

def test_additive_model():
m = Measure()
p1 = GP(EQ(), measure=m)
p2 = GP(EQ(), measure=m)
p1 = GP(1, EQ(), measure=m)
p2 = GP(2, EQ(), measure=m)
p_sum = p1 + p2

x = B.linspace(0, 5, 10)
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_fd_derivative():


def test_reflection():
p = GP(EQ())
p = GP(1, EQ())
p2 = 5 - p

x = B.linspace(0, 5, 10)
Expand All @@ -79,7 +79,7 @@ def test_reflection():


def test_negation():
p = GP(EQ())
p = GP(1, EQ())
p2 = -p

x = B.linspace(0, 5, 10)
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_batched():
x1 = B.randn(16, 10, 1)
x2 = B.randn(16, 5, 1)

p = GP(2 * EQ().stretch(0.5))
p = GP(1, 2 * EQ().stretch(0.5))
y1, y2 = p.measure.sample(p(x1), p(x2))
logpdf = p.measure.logpdf((p(x1, 0.1), y1), (p(x2, 0.1), y2))

Expand All @@ -159,7 +159,7 @@ def test_mo_batched():
x = B.randn(16, 10, 1)

with Measure():
p = cross(GP(2 * EQ().stretch(0.5)), GP(2 * EQ().stretch(0.5)))
p = cross(GP(1, 2 * EQ().stretch(0.5)), GP(2, 2 * EQ().stretch(0.5)))
y = p(x).sample()
logpdf = p(x, 0.1).logpdf(y)

Expand Down
14 changes: 7 additions & 7 deletions tests/model/test_fdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def check(noise, dtype, n, asserted_type):


def test_fdd():
p = GP(EQ())
p = GP(1, EQ())

# Test specification without noise.
for fdd in [p(1), FDD(p, 1)]:
Expand All @@ -41,7 +41,7 @@ def test_fdd():
assert isinstance(fdd.noise, matrix.Zero)
rep = (
"<FDD:\n"
" process=GP(0, EQ()),\n"
" process=GP(1, EQ()),\n"
" input=1,\n"
" noise=<zero matrix: batch=(), shape=(1, 1), dtype=int>>"
)
Expand All @@ -60,13 +60,13 @@ def test_fdd():
assert isinstance(fdd.noise, matrix.Diagonal)
assert str(fdd) == (
"<FDD:\n"
" process=GP(0, EQ()),\n"
" process=GP(1, EQ()),\n"
" input=1.0,\n"
" noise=<diagonal matrix: batch=(), shape=(2, 2), dtype=int64>>"
)
assert repr(fdd) == (
"<FDD:\n"
" process=GP(0, EQ()),\n"
" process=GP(1, EQ()),\n"
" input=1.0,\n"
" noise=<diagonal matrix: batch=(), shape=(2, 2), dtype=int64\n"
" diag=[1 2]>>"
Expand All @@ -83,8 +83,8 @@ def test_fdd():

def test_fdd_take():
with Measure():
f1 = GP(EQ())
f2 = GP(Exp())
f1 = GP(1, EQ())
f2 = GP(2, Exp())
f = cross(f1, f2)

x = B.linspace(0, 3, 5)
Expand All @@ -108,7 +108,7 @@ def test_fdd_take():


def test_fdd_properties():
p = GP(EQ())
p = GP(1, EQ())

# Sample observations.
x = B.linspace(0, 5, 5)
Expand Down
12 changes: 6 additions & 6 deletions tests/model/test_gp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from time import time

import numpy as np
import pytest
from time import time
from lab import B
from mlkernels import (
Linear,
Expand All @@ -14,10 +15,9 @@
OneMean,
)
from plum import NotFoundLookupError

from stheno.model import Measure, GP
from stheno.random import Normal
from .util import assert_equal_gps

from ..util import approx


Expand Down Expand Up @@ -93,7 +93,7 @@ def test_construction():


def test_sum_other():
p = GP(TensorProductMean(lambda x: x**2), EQ())
p = GP(lambda x: x**2, EQ())

def five(y):
return 5 * B.ones(B.shape(y)[0], 1)
Expand Down Expand Up @@ -124,7 +124,7 @@ def five(y):


def test_mul_other():
p = GP(TensorProductMean(lambda x: x**2), EQ())
p = GP(lambda x: x**2, EQ())

def five(y):
return 5 * B.ones(B.shape(y)[0], 1)
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_stationarity():


def test_marginals():
p = GP(TensorProductMean(lambda x: x**2), EQ())
p = GP(lambda x: x**2, EQ())
x = B.linspace(0, 5, 10)

# Check that `marginals` outputs the right thing.
Expand Down
61 changes: 28 additions & 33 deletions tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
EQ,
Delta,
Exp,
TensorProductMean,
)

from stheno.model import (
Measure,
GP,
Expand All @@ -24,6 +22,7 @@
cross,
FDD,
)

from .util import assert_equal_normals, assert_equal_measures
from ..util import approx

Expand Down Expand Up @@ -140,8 +139,8 @@ def doubly_assign():
)
def test_conditioning(generate_noise_tuple):
m = Measure()
p1 = GP(EQ(), measure=m)
p2 = GP(Exp(), measure=m)
p1 = GP(1, EQ(), measure=m)
p2 = GP(2, Exp(), measure=m)
p_sum = p1 + p2

# Sample some data to condition on.
Expand Down Expand Up @@ -188,7 +187,7 @@ def test_conditioning(generate_noise_tuple):

def test_conditioning_consistency():
m = Measure()
p = GP(EQ(), measure=m)
p = GP(1, EQ(), measure=m)
e = GP(0.1 * Delta(), measure=m)
e2 = GP(e.kernel, measure=m)

Expand All @@ -205,7 +204,7 @@ def test_conditioning_consistency():

@pytest.mark.parametrize("shape", [(0,), (0, 1), (4, 0, 1)])
def test_conditioning_empty_observations(shape):
p = GP(EQ())
p = GP(1, EQ())

x = B.randn(*shape)
y = p(x).sample()
Expand All @@ -217,7 +216,7 @@ def test_conditioning_empty_observations(shape):


def test_conditioning_shorthand():
p = GP(EQ())
p = GP(1, EQ())

# Test conditioning once.
x = B.linspace(0, 5, 10)
Expand All @@ -237,7 +236,7 @@ def test_conditioning_shorthand():


def test_conditioning_missing_data():
p = GP(EQ())
p = GP(1, EQ())
x = B.linspace(0, 5, 10)
y = p(x).sample()
y[:3] = B.nan
Expand All @@ -247,7 +246,7 @@ def test_conditioning_missing_data():


def test_conditioning_shape_check():
f = GP(EQ())
f = GP(1, EQ())
x = B.randn(2)
f | (f(x), B.randn(2, 1))
with pytest.raises(ValueError):
Expand All @@ -265,8 +264,8 @@ def test_conditioning_shape_check():
@pytest.mark.parametrize("PseudoObs", [PseudoObs, PseudoObsFITC, PseudoObsDTC])
def test_pseudoobs_and_elbo(generate_noise_tuple, PseudoObs):
m = Measure()
p1 = GP(EQ(), measure=m)
p2 = GP(Exp(), measure=m)
p1 = GP(1, EQ(), measure=m)
p2 = GP(2, Exp(), measure=m)
p_sum = p1 + p2

# Sample some data to condition on.
Expand Down Expand Up @@ -364,7 +363,7 @@ def elwise_(k: TrackingEQ, x: B.Numeric, y: B.Numeric):
x_new = B.randn(1)

# Perform a pseudo-point approximation
p = GP(TrackingEQ())
p = GP(1, TrackingEQ())
p_post = p | PseudoObs(p(x_ind), (p(x_obs, 0.1), y_obs))
mean, var = p_post(x_new).marginals()

Expand All @@ -383,8 +382,8 @@ def test_backward_compatibility():
@pytest.mark.parametrize("PseudoObs", [PseudoObs, PseudoObsFITC, PseudoObsDTC])
def test_logpdf(PseudoObs):
m = Measure()
p1 = GP(EQ(), measure=m)
p2 = GP(Exp(), measure=m)
p1 = GP(1, EQ(), measure=m)
p2 = GP(2, Exp(), measure=m)
p3 = p1 + p2

x1 = B.linspace(0, 2, 5)
Expand Down Expand Up @@ -414,8 +413,8 @@ def test_logpdf(PseudoObs):

def test_manual_new_gp():
m = Measure()
p1 = GP(EQ(), measure=m)
p2 = GP(EQ(), measure=m)
p1 = GP(1, EQ(), measure=m)
p2 = GP(2, EQ(), measure=m)
p_sum = p1 + p2

p1_equivalent = m.add_gp(
Expand All @@ -436,67 +435,64 @@ def test_manual_new_gp():

def test_stretching():
# Test construction:
p = GP(TensorProductMean(lambda x: x**2), EQ())
assert str(p.stretch(1)) == "GP(<lambda> > 1, EQ() > 1)"
p = GP(lambda x: x**2, Linear())
assert str(p.stretch(1)) == "GP(<lambda> > 1, Linear() > 1)"

# Test case:
p = GP(EQ())
p_stretched = p.stretch(5)

x = B.linspace(0, 5, 10)
y = p_stretched(x).sample()

post = p.measure | (p_stretched(x), y)
post = p.measure | (p_stretched(x, B.epsilon), y)
assert_equal_normals(post(p(x / 5)), post(p_stretched(x)))
assert_equal_normals(post(p(x)), post(p_stretched(x * 5)))


def test_shifting():
# Test construction:
p = GP(TensorProductMean(lambda x: x**2), Linear())
p = GP(lambda x: x**2, Linear())
assert str(p.shift(1)) == "GP(<lambda> shift 1, Linear() shift 1)"

# Test case:
p = GP(EQ())
p_shifted = p.shift(5)

x = B.linspace(0, 5, 10)
y = p_shifted(x).sample()

post = p.measure | (p_shifted(x), y)
post = p.measure | (p_shifted(x, B.epsilon), y)
assert_equal_normals(post(p(x - 5)), post(p_shifted(x)))
assert_equal_normals(post(p(x)), post(p_shifted(x + 5)))


def test_input_transform():
# Test construction:
p = GP(TensorProductMean(lambda x: x**2), EQ())
p = GP(lambda x: x**2, Linear())
assert (
str(p.transform(lambda x: x))
== "GP(<lambda> transform <lambda>, EQ() transform <lambda>)"
== "GP(<lambda> transform <lambda>, Linear() transform <lambda>)"
)

# Test case:
p = GP(EQ())
p_transformed = p.transform(lambda x: B.sqrt(x))

x = B.linspace(0, 5, 10)
y = p_transformed(x).sample()

post = p.measure | (p_transformed(x), y)
post = p.measure | (p_transformed(x, B.epsilon), y)
assert_equal_normals(post(p(B.sqrt(x))), post(p_transformed(x)))
assert_equal_normals(post(p(x)), post(p_transformed(x * x)))


def test_selection():
# Test construction:
p = GP(TensorProductMean(lambda x: x**2), EQ())
p = GP(lambda x: x**2, EQ())
assert str(p.select(1)) == "GP(<lambda> : [1], EQ() : [1])"
assert str(p.select(1, 2)) == "GP(<lambda> : [1, 2], EQ() : [1, 2])"

# Test case:
p = GP(EQ()) # 1D
p2 = p.select(0) # 2D
# `p` is 1D.
p2 = p.select(0) # `p2` is 2D.

x = B.linspace(0, 5, 10)
x21 = B.stack(x, B.randn(10), axis=1)
Expand All @@ -520,11 +516,10 @@ def test_selection():

def test_derivative():
# Test construction:
p = GP(TensorProductMean(lambda x: x**2), EQ())
p = GP(lambda x: x**2, EQ())
assert str(p.diff(1)) == "GP(d(1) <lambda>, d(1) EQ())"

# Test case:
p = GP(EQ())
dp = p.diff()

x = B.linspace(tf.float64, 0, 1, 100)
Expand Down Expand Up @@ -574,7 +569,7 @@ def test_multi_sample():

def test_sample_correct_measure():
m = Measure()
p1 = GP(EQ(), measure=m)
p1 = GP(1, EQ(), measure=m)

post = m | (p1(0), 1)

Expand Down
Loading

0 comments on commit cbeabe7

Please sign in to comment.