Skip to content

Commit

Permalink
Merge pull request #112 from ken-lauer/enh_line_width_scale_factor
Browse files Browse the repository at this point in the history
ENH: add line width scaling factor for matplotlib/bokeh
  • Loading branch information
ChristopherMayes authored Jan 8, 2025
2 parents 4c36f00 + cb34279 commit 3fac35c
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 41 deletions.
71 changes: 58 additions & 13 deletions pytao/plotting/bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class _Defaults:
max_data_points: int = 10_000
variables_per_row: int = 2
show_sliders: bool = True
line_width_scale: float = 0.5
floor_line_width_scale: float = 0.5

@classmethod
def get_size_for_class(
Expand Down Expand Up @@ -149,6 +151,8 @@ def set_defaults(
max_data_points: Optional[int] = None,
variables_per_row: Optional[int] = None,
show_sliders: Optional[bool] = None,
line_width_scale: Optional[float] = None,
floor_line_width_scale: Optional[float] = None,
):
"""
Change defaults used for Bokeh plots.
Expand Down Expand Up @@ -188,6 +192,10 @@ def set_defaults(
Variables to list per row when in single mode (i.e., `vars=True`).
show_sliders : bool, default=True
Show sliders alongside the spinners in single mode.
line_width_scale : float, default=1.0
Plot line width scaling factor applied to Tao's line width.
floor_line_width_scale : float, default=0.5
Floor plan line width scaling factor applied to Tao's line width.
"""

if width is not None:
Expand Down Expand Up @@ -224,6 +232,10 @@ def set_defaults(
_Defaults.variables_per_row = int(variables_per_row)
if show_sliders is not None:
_Defaults.show_sliders = bool(show_sliders)
if line_width_scale is not None:
_Defaults.line_width_scale = float(line_width_scale)
if floor_line_width_scale is not None:
_Defaults.floor_line_width_scale = float(floor_line_width_scale)
return {
key: value
for key, value in vars(_Defaults).items()
Expand Down Expand Up @@ -369,6 +381,7 @@ def _plot_curve_line(
line: PlotCurveLine,
name: Optional[str] = None,
source: Optional[ColumnDataSource] = None,
line_width_scale: float = 1.0,
):
if source is None:
source = ColumnDataSource(data={})
Expand All @@ -382,18 +395,29 @@ def _plot_curve_line(
return fig.line(
"x",
"y",
line_width=line.linewidth,
line_width=line.linewidth * line_width_scale,
source=source,
color=bokeh_color(line.color),
name=name,
**kw,
)


def _plot_curve(fig: figure, curve: PlotCurve, source: CurveData) -> None:
def _plot_curve(
fig: figure,
curve: PlotCurve,
source: CurveData,
line_width_scale: float = 1.0,
) -> None:
name = pgplot.mathjax_string(curve.info["name"])
if "line" in source and curve.line is not None:
_plot_curve_line(fig=fig, line=curve.line, name=name, source=source["line"])
_plot_curve_line(
fig=fig,
line=curve.line,
name=name,
source=source["line"],
line_width_scale=line_width_scale,
)

if "symbol" in source and curve.symbol is not None:
legend = None if "line" in source else name
Expand Down Expand Up @@ -461,6 +485,7 @@ def _draw_layout_elems(
fig: figure,
elems: List[LatticeLayoutElement],
skip_labels: bool = True,
line_width_scale: float = 1.0,
):
line_data = {
"xs": [],
Expand Down Expand Up @@ -492,15 +517,15 @@ def _draw_layout_elems(
line_data["name"].extend([elem.name] * len(lines))
line_data["s_start"].extend([elem.info["ele_s_start"]] * len(lines))
line_data["s_end"].extend([elem.info["ele_s_end"]] * len(lines))
line_data["line_width"].extend([shape.line_width] * len(lines))
line_data["line_width"].extend([shape.line_width * line_width_scale] * len(lines))
line_data["color"].extend([color] * len(lines))

if isinstance(shape, LayoutShape):
for patch in shape.to_patches():
if isinstance(patch, PlotPatchRectangle):
rectangles.append((elem, shape, patch))
else:
_plot_patch(fig, patch, line_width=shape.line_width)
_plot_patch(fig, patch, line_width=shape.line_width * line_width_scale)

if rectangles:
source = ColumnDataSource(
Expand All @@ -509,7 +534,9 @@ def _draw_layout_elems(
"ys": [[[_patch_rect_to_points(patch)[1]]] for _, _, patch in rectangles],
"name": [shape.name for _, shape, _ in rectangles],
"color": [bokeh_color(shape.color) for _, shape, _ in rectangles],
"line_width": [shape.line_width for _, shape, _ in rectangles],
"line_width": [
shape.line_width * line_width_scale for _, shape, _ in rectangles
],
"s_start": [elem.info["ele_s_start"] for elem, _, _ in rectangles],
"s_end": [elem.info["ele_s_end"] for elem, _, _ in rectangles],
}
Expand Down Expand Up @@ -586,6 +613,7 @@ def _draw_annotations(
def _draw_floor_plan_shapes(
fig: figure,
elems: List[FloorPlanElement],
line_width_scale: float = 1.0,
):
polygon_data = {
"xs": [],
Expand Down Expand Up @@ -620,19 +648,23 @@ def _draw_floor_plan_shapes(
polygon_data["xs"].append([[vx]])
polygon_data["ys"].append([[vy]])
polygon_data["name"].append(shape.name)
polygon_data["line_width"].append(shape.line_width)
polygon_data["line_width"].append(shape.line_width * line_width_scale)
polygon_data["color"].append(bokeh_color(shape.color))
else:
for patch in shape.to_patches():
assert not isinstance(patch, (PlotPatchRectangle, PlotPatchPolygon))
_plot_patch(fig, patch, line_width=shape.line_width)
_plot_patch(
fig, patch, line_width=shape.line_width * _Defaults.floor_line_width_scale
)

lines = shape.to_lines()
if lines:
line_data["xs"].extend([line.xs for line in lines])
line_data["ys"].extend([line.ys for line in lines])
line_data["name"].extend([shape.name] * len(lines))
line_data["line_width"].extend([line.linewidth for line in lines])
line_data["line_width"].extend(
[line.linewidth * line_width_scale for line in lines]
)
line_data["color"].extend([bokeh_color(line.color) for line in lines])

if line_data["xs"]:
Expand Down Expand Up @@ -695,7 +727,6 @@ def _plot_patch(
if source is None:
source = ColumnDataSource()

line_width = line_width if line_width is not None else patch.linewidth
if isinstance(patch, PlotPatchRectangle):
cx, cy = patch.center
source.data["x"] = [cx]
Expand Down Expand Up @@ -894,7 +925,12 @@ def create_figure(
fig.yaxis.ticker = []
fig.yaxis.visible = False

_draw_layout_elems(fig, self.graph.elements, skip_labels=True)
_draw_layout_elems(
fig,
self.graph.elements,
skip_labels=True,
line_width_scale=_Defaults.line_width_scale,
)

if add_named_hover_tool:
hover = bokeh.models.HoverTool(
Expand Down Expand Up @@ -1105,7 +1141,11 @@ def create_figure(
if box_zoom is not None:
box_zoom.match_aspect = True

_draw_floor_plan_shapes(fig, self.graph.elements)
_draw_floor_plan_shapes(
fig,
self.graph.elements,
line_width_scale=_Defaults.floor_line_width_scale,
)

if add_named_hover_tool:
hover = bokeh.models.HoverTool(
Expand All @@ -1121,7 +1161,12 @@ def create_figure(
for line in self.graph.building_walls.lines:
_plot_curve_line(fig, line)
for patch in self.graph.building_walls.patches:
_plot_patch(fig, patch)
if patch.linewidth is None:
line_width = 1.0
else:
line_width = patch.linewidth * _Defaults.floor_line_width_scale

_plot_patch(fig, patch, line_width=line_width)
orbits = self.graph.floor_orbits
if orbits is not None:
_plot_curve_symbols(fig, orbits.curve, name="floor_orbits")
Expand Down
Loading

0 comments on commit 3fac35c

Please sign in to comment.