Skip to content

Commit

Permalink
Fix missing attribute '_from_state' in 'EmpiricalCovariance' (#1471)
Browse files Browse the repository at this point in the history
* ADD: from_state allows int as covariance

* FIX: emp _from_state

* FIX: broken statefulness of update_many

* FIX: remove unexpected kwarg

* FIX: if statement nesting + redundant docstring entry

* FIX: wrong indentation + sort keys for shuffled samples
  • Loading branch information
MarekWadinger authored Feb 24, 2024
1 parent af6d7c5 commit 87941ba
Showing 1 changed file with 54 additions and 5 deletions.
59 changes: 54 additions & 5 deletions river/covariance/emp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class SymmetricMatrix(abc.ABC):

@property
@abc.abstractmethod
def matrix(self):
def matrix(self) -> dict:
...

def __getitem__(self, key):
Expand Down Expand Up @@ -183,16 +183,38 @@ def update_many(self, X: pd.DataFrame):
)
}

self._update_from_state(n=n, mean=mean, cov=cov)

def _update_from_state(self, n: int, mean: dict, cov: float | dict):
"""Update from state information.
Parameters
----------
n
The number of data points.
mean
A dictionary of variable means.
cov
A dictionary of covariance or variance values.
Raises
----------
KeyError: If an element in `mean` or `cov` is missing.
"""
for i, j in itertools.combinations(sorted(mean.keys()), r=2):
try:
self[i, j]
except KeyError:
self._cov[i, j] = stats.Cov(self.ddof)
if isinstance(cov, dict):
cov_ = cov.get((i, j), cov.get((j, i)))
else:
cov_ = cov
self._cov[i, j] += stats.Cov._from_state(
n=n,
mean_x=mean[i],
mean_y=mean[j],
cov=cov.get((i, j), cov.get((j, i))),
cov=cov_,
ddof=self.ddof,
)

Expand All @@ -201,9 +223,36 @@ def update_many(self, X: pd.DataFrame):
self[i, i]
except KeyError:
self._cov[i, i] = stats.Var(self.ddof)
self._cov[i, i] += stats.Var._from_state(
n=n, m=mean[i], sig=cov[i, i], ddof=self.ddof
)
if isinstance(cov, dict):
cov_ = cov[i, i]
else:
cov_ = cov
self._cov[i, i] += stats.Var._from_state(n=n, m=mean[i], sig=cov_, ddof=self.ddof)

@classmethod
def _from_state(cls, n: int, mean: dict, cov: float | dict, *, ddof=1):
    """Build a fresh instance and seed it with precomputed statistics.

    An empty object is created with the requested ``ddof`` and the given
    state is then merged into it through ``_update_from_state``.

    Parameters
    ----------
    n
        The number of data points the state summarises.
    mean
        A dictionary of variable means.
    cov
        Covariance or variance values: either a dictionary or a single
        number, both of which ``_update_from_state`` accepts.
    ddof
        Delta degrees of freedom forwarded to the constructor. Defaults to 1.

    Returns
    -------
    A new instance of ``cls`` whose covariance matrix reflects the state.

    """
    instance = cls(ddof=ddof)
    instance._update_from_state(n=n, mean=mean, cov=cov)
    return instance


class EmpiricalPrecision(SymmetricMatrix):
Expand Down

0 comments on commit 87941ba

Please sign in to comment.