Skip to content

Commit

Permalink
Merge pull request #1108 from olamarre/bugfix/ep-pickling
Browse files Browse the repository at this point in the history
fix: pickle and deep copy classification models with EP
  • Loading branch information
MartinBubel authored Jan 15, 2025
2 parents 9a31886 + b6efa8e commit acdd03d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Changelog

## Unreleased
* fix pickle and deep copy for classification models inheriting from EP #1108 [olamarre]

* update prior `__new__` methods #1098 [MartinBubel]

* fix invalid escape sequence #1011 [janmayer]
Expand Down
27 changes: 10 additions & 17 deletions GPy/inference/latent_function_inference/expectation_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,24 +229,17 @@ def _stop_criteria(self, ga_approx):
v_diff = np.mean(np.square(ga_approx.v-self.ga_approx_old.v))
return ((tau_diff < self.epsilon) and (v_diff < self.epsilon))

def __setstate__(self, state):
super(EPBase, self).__setstate__(state[0])
self.epsilon, self.eta, self.delta = state[1]
self.reset()

def __getstate__(self):
return [super(EPBase, self).__getstate__() , [self.epsilon, self.eta, self.delta]]

def _save_to_input_dict(self):
input_dict = super(EPBase, self)._save_to_input_dict()
input_dict["epsilon"]=self.epsilon
input_dict["eta"]=self.eta
input_dict["delta"]=self.delta
input_dict["always_reset"]=self.always_reset
input_dict["max_iters"]=self.max_iters
input_dict["ep_mode"]=self.ep_mode
input_dict["parallel_updates"]=self.parallel_updates
input_dict["loading"]=True
input_dict = {
"epsilon": self.epsilon,
"eta": self.eta,
"delta": self.delta,
"always_reset": self.always_reset,
"max_iters": self.max_iters,
"ep_mode": self.ep_mode,
"parallel_updates": self.parallel_updates,
"loading": True
}
return input_dict

class EP(EPBase, ExactGaussianInference):
Expand Down
28 changes: 28 additions & 0 deletions GPy/testing/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
The test cases for various inference algorithms
"""

import copy
import pickle
import numpy as np
import GPy

Expand Down Expand Up @@ -146,6 +148,32 @@ def test_inference_EP(self):
< 1e6
)

def test_pickle_copy_EP(self):
"""Pickling and deep-copying a classification model employing EP"""

# Dummy binary classification dataset
X = np.array([0, 1, 2, 3]).reshape(-1, 1)
Y = np.array([0, 0, 1, 1]).reshape(-1, 1)

# Some classification model
inf = GPy.inference.latent_function_inference.expectation_propagation.EP(
max_iters=30, delta=0.5
)
m = GPy.core.GP(
X=X,
Y=Y,
kernel=GPy.kern.RBF(input_dim=1, variance=1.0, lengthscale=1.0),
inference_method = inf,
likelihood=GPy.likelihoods.Bernoulli(),
mean_function=None
)
m.optimize()

m_pickled = pickle.dumps(m)
assert pickle.loads(m_pickled) is not None

assert copy.deepcopy(m) is not None

# NOTE: adding a test like above for parameterized likelihood- the above test is
# only for probit likelihood which does not have any tunable hyperparameter which is why
# the term in dictionary of gradients: dL_dthetaL will always be zero. So here we repeat tests for
Expand Down

0 comments on commit acdd03d

Please sign in to comment.