diff --git a/seaborn/regression.py b/seaborn/regression.py index 5e5503a422..e936ef598a 100644 --- a/seaborn/regression.py +++ b/seaborn/regression.py @@ -76,9 +76,9 @@ class _RegressionPlotter(_LinearPlotter): def __init__(self, x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None, seed=None, order=1, logistic=False, lowess=False, - robust=False, logx=False, x_partial=None, y_partial=None, - truncate=False, dropna=True, x_jitter=None, y_jitter=None, - color=None, label=None): + lowess_kws=None, robust=False, logx=False, + x_partial=None, y_partial=None, truncate=False, dropna=True, + x_jitter=None, y_jitter=None, color=None, label=None): # Set member attributes self.x_estimator = x_estimator @@ -91,6 +91,7 @@ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None, self.order = order self.logistic = logistic self.lowess = lowess + self.lowess_kws = {} if lowess_kws is None else lowess_kws self.robust = robust self.logx = logx self.truncate = truncate @@ -126,13 +127,20 @@ def __init__(self, x, y, data=None, x_estimator=None, x_bins=None, self.x_discrete = self.x # Disable regression in case of singleton inputs - if len(self.x) <= 1: + if len(self.x) <= 1 or self.lowess: self.fit_reg = False # Save the range of the x variable for the grid later if self.fit_reg: self.x_range = self.x.min(), self.x.max() + # Check lowess_kws + if self.lowess: + allowed_lowess_kws = ("frac", "it", "delta") + for k in self.lowess_kws: + if k not in allowed_lowess_kws: + raise ValueError(f"Unsupported parameter '{k}' for lowess.") + @property def scatter_data(self): """Data where each observation is a point.""" @@ -306,7 +314,7 @@ def reg_func(_x, _y): def fit_lowess(self): """Fit a locally-weighted regression, which returns its own grid.""" from statsmodels.nonparametric.smoothers_lowess import lowess - grid, yhat = lowess(self.y, self.x).T + grid, yhat = lowess(self.y, self.x, **self.lowess_kws).T return grid, yhat def fit_logx(self, grid): @@ -533,6 +541,11 @@ def lineplot(self, ax, kws): model (locally weighted linear regression). Note that confidence intervals cannot currently be drawn for this kind of model.\ """), + lowess_kws=dedent("""\ + lowess_kws : dict, optional + Additional keyword arguments to pass to ``lowess()`` function \ + from ``statsmodels``. + """), robust=dedent("""\ robust : bool, optional If ``True``, use ``statsmodels`` to estimate a robust regression. This @@ -581,7 +594,7 @@ def lmplot( legend=True, legend_out=None, x_estimator=None, x_bins=None, x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None, seed=None, order=1, logistic=False, lowess=False, - robust=False, logx=False, x_partial=None, y_partial=None, + lowess_kws=None, robust=False, logx=False, x_partial=None, y_partial=None, truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None, line_kws=None, facet_kws=None, ): @@ -646,7 +659,7 @@ def update_datalim(data, x, y, ax, **kws): seed=seed, order=order, logistic=logistic, lowess=lowess, robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial, truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter, - scatter_kws=scatter_kws, line_kws=line_kws, + scatter_kws=scatter_kws, line_kws=line_kws, lowess_kws=lowess_kws ) facets.map_dataframe(regplot, x=x, y=y, **regplot_kws) facets.set_axis_labels(x, y) @@ -720,6 +733,7 @@ def update_datalim(data, x, y, ax, **kws): {order} {logistic} {lowess} + {lowess_kws} {robust} {logx} {xy_partial} @@ -753,8 +767,8 @@ def regplot( data=None, *, x=None, y=None, x_estimator=None, x_bins=None, x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None, - seed=None, order=1, logistic=False, lowess=False, robust=False, - logx=False, x_partial=None, y_partial=None, + seed=None, order=1, logistic=False, lowess=False, lowess_kws=None, + robust=False, logx=False, x_partial=None, y_partial=None, truncate=True, dropna=True, x_jitter=None, y_jitter=None, label=None, color=None, marker="o", scatter_kws=None, line_kws=None, ax=None @@ -762,8 +776,8 @@ def regplot( plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci, scatter, fit_reg, ci, n_boot, units, seed, - order, logistic, lowess, robust, logx, - x_partial, y_partial, truncate, dropna, + order, logistic, lowess, lowess_kws, robust, + logx, x_partial, y_partial, truncate, dropna, x_jitter, y_jitter, color, label) if ax is None: @@ -800,6 +814,7 @@ def regplot( {order} {logistic} {lowess} + {lowess_kws} {robust} {logx} {xy_partial} @@ -853,7 +868,7 @@ def regplot( def residplot( data=None, *, x=None, y=None, - x_partial=None, y_partial=None, lowess=False, + x_partial=None, y_partial=None, lowess=False, lowess_kws=None, order=1, robust=False, dropna=True, label=None, color=None, scatter_kws=None, line_kws=None, ax=None ): @@ -877,6 +892,8 @@ def residplot( the `x` or `y` variables before plotting. lowess : boolean, optional Fit a lowess smoother to the residual scatterplot. + lowess_kws : dict, optional + Additional keyword arguments passed to lowess() from statsmodels. order : int, optional Order of the polynomial to fit when calculating the residuals. robust : boolean, optional @@ -915,6 +932,7 @@ def residplot( plotter = _RegressionPlotter(x, y, data, ci=None, order=order, robust=robust, x_partial=x_partial, y_partial=y_partial, + lowess=lowess, lowess_kws=lowess_kws, dropna=dropna, color=color, label=label) if ax is None: @@ -924,12 +942,6 @@ def residplot( _, yhat, _ = plotter.fit_regression(grid=plotter.x) plotter.y = plotter.y - yhat - # Set the regression option on the plotter - if lowess: - plotter.lowess = True - else: - plotter.fit_reg = False - # Plot a horizontal line at 0 ax.axhline(0, ls=":", c=".2") diff --git a/tests/test_regression.py b/tests/test_regression.py index 368f6c50a6..7e9ba48a4e 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -431,6 +431,25 @@ def test_lowess_regression(self): assert len(grid) == len(yhat) assert err_bands is None + @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels") + def test_lowess_regression_with_kws(self): + lowess_kws = dict(frac=2 / 3, it=1, delta=0.0) + p = lm._RegressionPlotter("x", "y", data=self.df, lowess=True, + lowess_kws=lowess_kws) + grid, yhat, err_bands = p.fit_regression(x_range=(-3, 3)) + + assert len(grid) == len(yhat) + assert err_bands is None + + @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels") + def test_lowess_regression_with_bad_kw(self): + + lowess_kws = dict(frac=2 / 3, it=3, delta=0.0, bad_kw=-1) + with pytest.raises(ValueError, match="Unsupported parameter " + "'bad_kw' for lowess\\."): + lm._RegressionPlotter("x", "y", data=self.df, lowess=True, + lowess_kws=lowess_kws) + def test_regression_options(self): with pytest.raises(ValueError): @@ -666,6 +685,26 @@ def test_residplot_lowess(self): x, y = ax.lines[1].get_xydata().T npt.assert_array_equal(x, np.sort(self.df.x)) + @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels") + def test_residplot_lowess_with_kws(self): + + lowess_kws = dict(frac=2 / 3, it=3, delta=0.0) + ax = lm.residplot(x="x", y="y", data=self.df, lowess=True, + lowess_kws=lowess_kws) + assert len(ax.lines) == 2 + + x, y = ax.lines[1].get_xydata().T + npt.assert_array_equal(x, np.sort(self.df.x)) + + @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels") + def test_residplot_lowess_bad_kw(self): + + lowess_kws = dict(frac=2 / 3, it=3, delta=0.0, bad_kw=-1) + with pytest.raises(ValueError, match="Unsupported parameter" + " 'bad_kw' for lowess\\."): + lm.residplot(x="x", y="y", data=self.df, lowess=True, + lowess_kws=lowess_kws) + @pytest.mark.parametrize("option", ["robust", "lowess"]) @pytest.mark.skipif(not _no_statsmodels, reason="statsmodels installed") def test_residplot_statsmodels_missing_errors(self, long_df, option):