diff --git a/causalpy/pymc_experiments.py b/causalpy/pymc_experiments.py index c07bb1a6..c32b62d9 100644 --- a/causalpy/pymc_experiments.py +++ b/causalpy/pymc_experiments.py @@ -463,18 +463,6 @@ class DifferenceInDifferences(ExperimentalDesign): ... } ... ) ... ) - >>> result.summary() # doctest: +NUMBER - ===========================Difference in Differences============================ - Formula: y ~ 1 + group*post_treatment - - Results: - Causal impact = 0.5, $CI_{94%}$[0.4, 0.6] - Model coefficients: - Intercept 1.0, 94% HDI [1.0, 1.1] - post_treatment[T.True] 0.9, 94% HDI [0.9, 1.0] - group 0.1, 94% HDI [0.0, 0.2] - group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6] - sigma 0.0, 94% HDI [0.0, 0.1] """ def __init__( @@ -726,7 +714,7 @@ def _plot_causal_impact_arrow(self, ax): def _causal_impact_summary_stat(self) -> str: """Computes the mean and 94% credible interval bounds for the causal impact.""" percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values - ci = "$CI_{94%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]" + ci = "$CI_{94\\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]" causal_impact = f"{self.causal_impact.mean():.2f}, " return f"Causal impact = {causal_impact + ci}" diff --git a/causalpy/tests/test_pymc_experiments.py b/causalpy/tests/test_pymc_experiments.py new file mode 100644 index 00000000..51a6b5f7 --- /dev/null +++ b/causalpy/tests/test_pymc_experiments.py @@ -0,0 +1,21 @@ +""" +Unit tests for pymc_experiments.py +""" + +import causalpy as cp + +sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2} + + +def test_did_summary(): + """Test that the summary stat function returns a string.""" + df = cp.load_data("did") + result = cp.pymc_experiments.DifferenceInDifferences( + df, + formula="y ~ 1 + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs), + ) + print(type(result._causal_impact_summary_stat())) + assert isinstance(result._causal_impact_summary_stat(), str) diff --git a/docs/source/_static/interrogate_badge.svg b/docs/source/_static/interrogate_badge.svg index 3cd04abb..915a4635 100644 --- a/docs/source/_static/interrogate_badge.svg +++ b/docs/source/_static/interrogate_badge.svg @@ -1,5 +1,5 @@ - interrogate: 97.2% + interrogate: 97.3% @@ -12,8 +12,8 @@ interrogate interrogate - 97.2% - 97.2% + 97.3% + 97.3%