Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add residual plot overloading the - operator #275

Closed
wants to merge 12 commits into from
5 changes: 5 additions & 0 deletions src/plopp/backends/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ def __truediv__(self, other):
from .tiled import vstack

return vstack(self, other)

def __sub__(self, other):
from .residuals import residuals

return residuals(self, other)
106 changes: 106 additions & 0 deletions src/plopp/backends/matplotlib/residuals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

import matplotlib.pyplot as plt

from ..protocols import FigureLike
from .utils import copy_figure, make_figure


class ResidualPlot:
def __init__(self, main_panel: FigureLike, res_panel: FigureLike):
self.main_panel = main_panel
self.res_panel = res_panel
self.panels = [self.main_panel, self.res_panel]

def __getitem__(self, key):
return self.panels[key]

def __repr__(self):
return f"ResidualPlot(main_panel={self.main_panel}, res_panel={self.res_panel})"

def _repr_mimebundle_(self, *args, **kwargs) -> dict:
return self.main_panel._repr_mimebundle_(*args, **kwargs)


def residuals(main_fig: FigureLike, reference: FigureLike) -> ResidualPlot:
"""
Create a residual plot from two figures, using the data from the second figure as
the reference the residuals are computed from.

Parameters
----------
main_fig:
The main figure.
reference:
The reference figure.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed this earlier, but wouldn't it make more sense to accept a data array instead of a figure? Why would you require the user to plot the reference data first?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm starting to think that the whole approach using operators is a bad idea. I also ran into problems if, for example, the user made a figure where they changed the line style or the line color. Then I had to make sure that was copied over.

I think we need to go back to some of the suggestions in #201 and use a dedicated function, and drop the operators.


Returns
-------
:
A figure with a main panel showing the data from both the main and reference
figures, and a smaller 'residuals' panel at the bottom displaying the difference
between the data from the main figure with the data from the reference figure.
"""
# If there is a colormapper, we are dealing with a 2d figure
if hasattr(main_fig._view, 'colormapper') or hasattr(
reference._view, 'colormapper'
):
raise TypeError("The residual plot only supports 1d figures.")
if len(reference.artists) != 1:
raise TypeError(
"The reference figure must contain exactly one line to "
"compute residuals."
)

fig = make_figure(figsize=(6.0, 4.0))
gs = fig.add_gridspec(
2,
1,
height_ratios=(4, 1),
hspace=0.0,
)

main_ax = fig.add_subplot(gs[0])
res_ax = fig.add_subplot(gs[1], sharex=main_ax)
main_panel = copy_figure(main_fig, ax=main_ax)
main_canvas = main_panel.canvas
if main_canvas.is_widget():
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
ref_node = next(iter(reference.graph_nodes.values()))
data = ref_node()
if not data.name:
data.name = "reference"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modifies the input, right? Seems like you should make a copy of reference first.

ref_node.add_view(main_panel)
main_panel._view.render()
# main_view._view.artists[ref_node.id]._line.set_zorder(-10)
if main_canvas._legend:
main_ax.legend()
diff_nodes = [n - ref_node for n in main_fig.graph_nodes.values()]
res_panel = reference.__class__(reference._view.__class__, *diff_nodes, ax=res_ax)

main_ax.tick_params(
top=True, labeltop=True, bottom=False, labelbottom=False, direction='out'
)
main_ax.secondary_xaxis("bottom").tick_params(
axis="x",
direction="in",
top=False,
labeltop=False,
bottom=True,
labelbottom=False,
)
res_ax.tick_params(
top=False, labeltop=False, bottom=True, labelbottom=True, direction='out'
)
res_ax.secondary_xaxis("top").tick_params(
axis="x",
direction="in",
top=True,
labeltop=False,
bottom=False,
labelbottom=False,
)

return ResidualPlot(main_panel=main_panel, res_panel=res_panel)
6 changes: 4 additions & 2 deletions src/plopp/data/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ def scatter(npoints=500, scale=10.0, seed=1) -> sc.DataArray:
)


def random(shape, dtype='float64', unit='', dims=None, seed=None) -> sc.DataArray:
def random(
shape, dtype='float64', unit='', dims=None, seed=None, binedges=False
) -> sc.DataArray:
"""
Generate a data array containing random data values.

Expand Down Expand Up @@ -213,7 +215,7 @@ def random(shape, dtype='float64', unit='', dims=None, seed=None) -> sc.DataArra
dtype=dtype,
),
coords={
dim: sc.arange(dim, shape[i], unit='m', dtype='float64')
dim: sc.arange(dim, shape[i] + int(binedges), unit='m', dtype='float64')
for i, dim in enumerate(dims)
},
)
Expand Down
101 changes: 101 additions & 0 deletions tests/backends/matplotlib/mpl_residuals_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

import numpy as np
import pytest

import plopp as pp
from plopp.backends.matplotlib.residuals import residuals
from plopp.data.testing import data_array


def test_single_line():
ref = data_array(ndim=1)
a = ref.copy()
a.values += np.random.uniform(-0.25, 0.25, size=len(a))
fig1 = ref.plot()
fig2 = a.plot()
fig = residuals(fig2, fig1)
assert len(fig.main_panel.fig.get_axes()) == 2
assert len(fig.main_panel.artists) == 2
assert len(fig.res_panel.artists) == 1


def test_three_lines():
ref = data_array(ndim=1)
a = ref.copy()
a.values += np.random.uniform(-0.25, 0.25, size=len(a))
b = ref.copy()
b.values += np.random.uniform(-1.0, 1.0, size=len(b))
c = ref.copy()
fig1 = ref.plot()
fig2 = pp.plot({'a': a, 'b': b, 'c': c})
fig = residuals(fig2, fig1)
assert len(fig.main_panel.artists) == 4
assert len(fig.res_panel.artists) == 3


def test_with_bin_edges():
ref = data_array(ndim=1, binedges=True)
a = ref.copy()
a.values += np.random.uniform(-0.25, 0.25, size=len(a))
fig1 = ref.plot()
fig2 = a.plot()
fig = residuals(fig2, fig1)
for artist in fig.main_panel.artists.values():
assert artist._line.get_linestyle() == '-'
for artist in fig.res_panel.artists.values():
assert artist._line.get_linestyle() == '-'


def test_raises_when_given_2d_plots():
da1d = data_array(ndim=1)
da2d = data_array(ndim=2)
msg = "The residual plot only supports 1d figures."
with pytest.raises(TypeError, match=msg):
residuals(da2d.plot(), da1d.plot())
with pytest.raises(TypeError, match=msg):
residuals(da1d.plot(), da2d.plot())
with pytest.raises(TypeError, match=msg):
residuals(da2d.plot(), da2d.plot())


def test_raises_when_reference_contains_multiple_lines():
ref = data_array(ndim=1)
a = ref.copy()
a.values += np.random.uniform(-0.25, 0.25, size=len(a))
b = ref.copy()
b.values += np.random.uniform(-1.0, 1.0, size=len(b))
fig1 = pp.plot({'a': a, 'b': b})
fig2 = ref.plot()
with pytest.raises(
TypeError,
match="The reference figure must contain exactly one line",
):
residuals(fig2, fig1)


def test_single_line_operator_minus():
ref = data_array(ndim=1)
a = ref.copy()
a.values += np.random.uniform(-0.25, 0.25, size=len(a))
fig1 = ref.plot()
fig2 = a.plot()
fig = fig2 - fig1
assert len(fig.main_panel.fig.get_axes()) == 2
assert len(fig.main_panel.artists) == 2
assert len(fig.res_panel.artists) == 1


def test_three_lines_operator_minus():
ref = data_array(ndim=1)
a = ref.copy()
a.values += np.random.uniform(-0.25, 0.25, size=len(a))
b = ref.copy()
b.values += np.random.uniform(-1.0, 1.0, size=len(b))
c = ref.copy()
fig1 = ref.plot()
fig2 = pp.plot({'a': a, 'b': b, 'c': c})
fig = fig2 - fig1
assert len(fig.main_panel.artists) == 4
assert len(fig.res_panel.artists) == 3
Loading