Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

text as label #2

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from collections.abc import Mapping, Sequence

import cycler
Expand Down Expand Up @@ -72,6 +73,12 @@ def branches(
-------
ax - The axes that the plot was drawn on.
""" # noqa: D205
# Setup
if not ax:
ax = plt.gca()
if (ax.name == "polar" and not polar) or (ax.name != "polar" and polar):
warnings.warn("Polar setting of axes does not match requested type. Creating new axes.", stacklevel=2)
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
kwargs = kwargs if kwargs else {}
if not key:
key = next(iter(tdata.obst.keys()))
Expand Down Expand Up @@ -116,18 +123,20 @@ def branches(
else:
raise ValueError("Invalid linewidth value. Must be int, float, or an str specifying an attribute of the edges.")
# Plot
if not ax:
subplot_kw = {"projection": "polar"} if polar else None
fig, ax = plt.subplots(subplot_kw=subplot_kw)
elif (ax.name == "polar") != polar:
raise ValueError("Provided axis does not match the requested 'polar' setting.")
ax.add_collection(LineCollection(zorder=1, **kwargs))
# Configure plot
lat_lim = (-0.2, depth)
lon_lim = (0, 2 * np.pi)
ax.set_xlim(lon_lim if polar else lat_lim)
ax.set_ylim(lat_lim if polar else lon_lim)
ax.axis("off")
if polar:
ax.set_ylim((-depth * 0.05, depth * 1.05))
ax.spines["polar"].set_visible(False)
else:
ax.set_ylim((-0.03 * np.pi, 2.03 * np.pi))
ax.set_xlim((-depth * 0.05, depth * 1.05))
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.spines["bottom"].set_visible(False)
ax.tick_params(length=0)
ax.set_xticks([])
ax.set_yticks([])
ax._attrs = {
"node_coords": node_coords,
"leaves": leaves,
Expand Down Expand Up @@ -388,19 +397,23 @@ def annotation(
# Plot
if attrs["polar"]:
ax.pcolormesh(lons, lats, rgb_array.swapaxes(0, 1), zorder=2, **kwargs)
ax.set_ylim(-0.2, end_lat)
ax.set_ylim(-attrs["depth"] * 0.05, end_lat)
else:
ax.pcolormesh(lats, lons, rgb_array, zorder=2, **kwargs)
ax.set_xlim(-0.2, end_lat)
labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1)
labels_lats = labels_lats + (end_lat - start_lat) / (len(labels) * 2)
for idx, label in enumerate(labels):
if is_array and len(labels) == 1:
ax.text(labels_lats[idx], -0.1, label, ha="center", va="top")
ax.set_ylim(-0.5, 2 * np.pi)
else:
ax.text(labels_lats[idx], -0.1, label, ha="center", va="top", rotation=90)
ax.set_ylim(-1, 2 * np.pi)
ax.set_xlim(-attrs["depth"] * 0.05, end_lat)
# Add labels
if labels and len(labels) > 0:
labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1)
labels_lats = labels_lats + (end_lat - start_lat) / (len(labels) * 2)
existing_ticks = ax.get_xticks()
existing_labels = [label.get_text() for label in ax.get_xticklabels()]
ax.set_xticks(np.append(existing_ticks, labels_lats[:-1]))
ax.set_xticklabels(existing_labels + labels)
for label in ax.get_xticklabels()[len(existing_ticks) :]:
if is_array and len(labels) == 1:
label.set_rotation(0)
else:
label.set_rotation(90)
ax._attrs.update({"offset": end_lat})
return ax

Expand Down
21 changes: 10 additions & 11 deletions tests/test_plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_polar_with_clades(tdata):
fig, ax = plt.subplots(dpi=600, subplot_kw={"polar": True})
fig, ax = plt.subplots(dpi=300, subplot_kw={"polar": True})
pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax)
pycea.pl.nodes(tdata, color="clade", palette="Set1", style="clade", ax=ax)
pycea.pl.annotation(tdata, keys="clade", ax=ax)
Expand All @@ -18,19 +18,18 @@ def test_polar_with_clades(tdata):


def test_angled_numeric_annotations(tdata):
fig, ax = plt.subplots(dpi=600)
pycea.pl.branches(
tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True, ax=ax
tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True
)
pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20, ax=ax)
pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05, ax=ax)
pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes", ax=ax)
pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20)
pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05)
pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes")
plt.savefig(plot_path / "angled_numeric.png")
plt.close()


def test_matrix_annotation(tdata):
fig, ax = plt.subplots(dpi=600)
fig, ax = plt.subplots(dpi=300)
pycea.pl.tree(
tdata,
key="tree",
Expand All @@ -44,19 +43,19 @@ def test_matrix_annotation(tdata):
plt.close()


def test_branches_invalid_input(tdata):
def test_branches_bad_input(tdata):
fig, ax = plt.subplots()
with pytest.raises(ValueError):
pycea.pl.branches(tdata, key="tree", color=["bad"] * 5)
with pytest.raises(ValueError):
pycea.pl.branches(tdata, key="tree", linewidth=["bad"] * 5)
# Can't plot polar with non-polar axis
with pytest.raises(ValueError):
# Warns about polar
with pytest.warns(match="Polar"):
pycea.pl.branches(tdata, key="tree", polar=True, ax=ax)
plt.close()


def test_annotation_invalid_input(tdata):
def test_annotation_bad_input(tdata):
# Need to plot branches first
fig, ax = plt.subplots()
with pytest.raises(ValueError):
Expand Down
Loading