From cb34279804f65a727ddf1c12fa86d0c11e8063e5 Mon Sep 17 00:00:00 2001 From: Ken Lauer <152229072+ken-lauer@users.noreply.github.com> Date: Tue, 10 Dec 2024 09:31:16 -0800 Subject: [PATCH] ENH: add line width scaling factor for matplotlib/bokeh --- pytao/plotting/bokeh.py | 71 ++++++++++++++++++++----- pytao/plotting/mpl.py | 113 ++++++++++++++++++++++++++++++---------- 2 files changed, 143 insertions(+), 41 deletions(-) diff --git a/pytao/plotting/bokeh.py b/pytao/plotting/bokeh.py index 6aad0fa6..3f39e057 100644 --- a/pytao/plotting/bokeh.py +++ b/pytao/plotting/bokeh.py @@ -115,6 +115,8 @@ class _Defaults: max_data_points: int = 10_000 variables_per_row: int = 2 show_sliders: bool = True + line_width_scale: float = 0.5 + floor_line_width_scale: float = 0.5 @classmethod def get_size_for_class( @@ -149,6 +151,8 @@ def set_defaults( max_data_points: Optional[int] = None, variables_per_row: Optional[int] = None, show_sliders: Optional[bool] = None, + line_width_scale: Optional[float] = None, + floor_line_width_scale: Optional[float] = None, ): """ Change defaults used for Bokeh plots. @@ -188,6 +192,10 @@ def set_defaults( Variables to list per row when in single mode (i.e., `vars=True`). show_sliders : bool, default=True Show sliders alongside the spinners in single mode. + line_width_scale : float, default=1.0 + Plot line width scaling factor applied to Tao's line width. + floor_line_width_scale : float, default=0.5 + Floor plan line width scaling factor applied to Tao's line width. """ if width is not None: @@ -224,6 +232,10 @@ def set_defaults( _Defaults.variables_per_row = int(variables_per_row) if show_sliders is not None: _Defaults.show_sliders = bool(show_sliders) + if line_width_scale is not None: + _Defaults.line_width_scale = float(line_width_scale) + if floor_line_width_scale is not None: + _Defaults.floor_line_width_scale = float(floor_line_width_scale) return { key: value for key, value in vars(_Defaults).items() @@ -369,6 +381,7 @@ def _plot_curve_line( line: PlotCurveLine, name: Optional[str] = None, source: Optional[ColumnDataSource] = None, + line_width_scale: float = 1.0, ): if source is None: source = ColumnDataSource(data={}) @@ -382,7 +395,7 @@ def _plot_curve_line( return fig.line( "x", "y", - line_width=line.linewidth, + line_width=line.linewidth * line_width_scale, source=source, color=bokeh_color(line.color), name=name, @@ -390,10 +403,21 @@ def _plot_curve_line( ) -def _plot_curve(fig: figure, curve: PlotCurve, source: CurveData) -> None: +def _plot_curve( + fig: figure, + curve: PlotCurve, + source: CurveData, + line_width_scale: float = 1.0, +) -> None: name = pgplot.mathjax_string(curve.info["name"]) if "line" in source and curve.line is not None: - _plot_curve_line(fig=fig, line=curve.line, name=name, source=source["line"]) + _plot_curve_line( + fig=fig, + line=curve.line, + name=name, + source=source["line"], + line_width_scale=line_width_scale, + ) if "symbol" in source and curve.symbol is not None: legend = None if "line" in source else name @@ -461,6 +485,7 @@ def _draw_layout_elems( fig: figure, elems: List[LatticeLayoutElement], skip_labels: bool = True, + line_width_scale: float = 1.0, ): line_data = { "xs": [], @@ -492,7 +517,7 @@ def _draw_layout_elems( line_data["name"].extend([elem.name] * len(lines)) line_data["s_start"].extend([elem.info["ele_s_start"]] * len(lines)) line_data["s_end"].extend([elem.info["ele_s_end"]] * len(lines)) - line_data["line_width"].extend([shape.line_width] * len(lines)) + line_data["line_width"].extend([shape.line_width * line_width_scale] * len(lines)) line_data["color"].extend([color] * len(lines)) if isinstance(shape, LayoutShape): @@ -500,7 +525,7 @@ def _draw_layout_elems( if isinstance(patch, PlotPatchRectangle): rectangles.append((elem, shape, patch)) else: - _plot_patch(fig, patch, line_width=shape.line_width) + _plot_patch(fig, patch, line_width=shape.line_width * line_width_scale) if rectangles: source = ColumnDataSource( @@ -509,7 +534,9 @@ def _draw_layout_elems( "ys": [[[_patch_rect_to_points(patch)[1]]] for _, _, patch in rectangles], "name": [shape.name for _, shape, _ in rectangles], "color": [bokeh_color(shape.color) for _, shape, _ in rectangles], - "line_width": [shape.line_width for _, shape, _ in rectangles], + "line_width": [ + shape.line_width * line_width_scale for _, shape, _ in rectangles + ], "s_start": [elem.info["ele_s_start"] for elem, _, _ in rectangles], "s_end": [elem.info["ele_s_end"] for elem, _, _ in rectangles], } @@ -586,6 +613,7 @@ def _draw_annotations( def _draw_floor_plan_shapes( fig: figure, elems: List[FloorPlanElement], + line_width_scale: float = 1.0, ): polygon_data = { "xs": [], @@ -620,19 +648,23 @@ def _draw_floor_plan_shapes( polygon_data["xs"].append([[vx]]) polygon_data["ys"].append([[vy]]) polygon_data["name"].append(shape.name) - polygon_data["line_width"].append(shape.line_width) + polygon_data["line_width"].append(shape.line_width * line_width_scale) polygon_data["color"].append(bokeh_color(shape.color)) else: for patch in shape.to_patches(): assert not isinstance(patch, (PlotPatchRectangle, PlotPatchPolygon)) - _plot_patch(fig, patch, line_width=shape.line_width) + _plot_patch( + fig, patch, line_width=shape.line_width * _Defaults.floor_line_width_scale + ) lines = shape.to_lines() if lines: line_data["xs"].extend([line.xs for line in lines]) line_data["ys"].extend([line.ys for line in lines]) line_data["name"].extend([shape.name] * len(lines)) - line_data["line_width"].extend([line.linewidth for line in lines]) + line_data["line_width"].extend( + [line.linewidth * line_width_scale for line in lines] + ) line_data["color"].extend([bokeh_color(line.color) for line in lines]) if line_data["xs"]: @@ -695,7 +727,6 @@ def _plot_patch( if source is None: source = ColumnDataSource() - line_width = line_width if line_width is not None else patch.linewidth if isinstance(patch, PlotPatchRectangle): cx, cy = patch.center source.data["x"] = [cx] @@ -894,7 +925,12 @@ def create_figure( fig.yaxis.ticker = [] fig.yaxis.visible = False - _draw_layout_elems(fig, self.graph.elements, skip_labels=True) + _draw_layout_elems( + fig, + self.graph.elements, + skip_labels=True, + line_width_scale=_Defaults.line_width_scale, + ) if add_named_hover_tool: hover = bokeh.models.HoverTool( @@ -1105,7 +1141,11 @@ def create_figure( if box_zoom is not None: box_zoom.match_aspect = True - _draw_floor_plan_shapes(fig, self.graph.elements) + _draw_floor_plan_shapes( + fig, + self.graph.elements, + line_width_scale=_Defaults.floor_line_width_scale, + ) if add_named_hover_tool: hover = bokeh.models.HoverTool( @@ -1121,7 +1161,12 @@ def create_figure( for line in self.graph.building_walls.lines: _plot_curve_line(fig, line) for patch in self.graph.building_walls.patches: - _plot_patch(fig, patch) + if patch.linewidth is None: + line_width = 1.0 + else: + line_width = patch.linewidth * _Defaults.floor_line_width_scale + + _plot_patch(fig, patch, line_width=line_width) orbits = self.graph.floor_orbits if orbits is not None: _plot_curve_symbols(fig, orbits.curve, name="floor_orbits") diff --git a/pytao/plotting/mpl.py b/pytao/plotting/mpl.py index 7880482e..598cef0e 100644 --- a/pytao/plotting/mpl.py +++ b/pytao/plotting/mpl.py @@ -48,21 +48,52 @@ class _Defaults: layout_height: float = 0.5 + line_width_scale: float = 0.5 + floor_line_width_scale: float = 0.5 colormap: str = "PRGn_r" def set_defaults( layout_height: Optional[float] = None, colormap: Optional[str] = None, + line_width_scale: Optional[float] = None, + floor_line_width_scale: Optional[float] = None, figsize: Optional[Tuple[float, float]] = None, - width: Optional[int] = None, - height: Optional[int] = None, + width: Optional[float] = None, + height: Optional[float] = None, dpi: Optional[int] = None, ): + """ + Set default values for Matplotlib plot settings. + + Parameters + ---------- + layout_height : float, optional + Height of the layout. Default is 0.5. + colormap : str, optional + Colormap to use for plotting. Default is "PRGn_r". + line_width_scale : float, optional + Scale factor for line widths, excluding floor plan lines. Default is 0.5. + floor_line_width_scale : float, optional + Scale factor for floor plan line widths. Default is 0.5. + figsize : tuple of float, optional + Size of the figure (width, height). Default is as-configured in matplotlib rcParams. + width : float, optional + Width of the figure in inches. Default is as-configured in matplotlib rcParams. + height : float, optional + Height of the figure in inches. Default is as-configured in matplotlib rcParams. + dpi : int, optional + Dots per inch for the figure. Default is as-configured in matplotlib rcParams. + """ + if layout_height is not None: _Defaults.layout_height = layout_height if colormap is not None: _Defaults.colormap = colormap + if line_width_scale is not None: + _Defaults.line_width_scale = line_width_scale + if floor_line_width_scale is not None: + _Defaults.floor_line_width_scale = floor_line_width_scale if figsize is not None: matplotlib.rcParams["figure.figsize"] = figsize if width and height: @@ -176,13 +207,14 @@ def plot_curve_line( curve: PlotCurveLine, ax: matplotlib.axes.Axes, label: Optional[str] = None, + line_width_scale: float = 1.0, ): return ax.plot( curve.xs, curve.ys, color=pgplot.mpl_color(curve.color or "black"), linestyle=curve.linestyle, - linewidth=curve.linewidth, + linewidth=curve.linewidth * line_width_scale, label=label, ) @@ -218,7 +250,7 @@ def plot_histogram( ) -def plot_curve(curve: PlotCurve, ax: matplotlib.axes.Axes): +def plot_curve(curve: PlotCurve, ax: matplotlib.axes.Axes, line_width_scale: float = 1.0): res = [] if curve.line is not None: res.append( @@ -239,11 +271,15 @@ def plot_curve(curve: PlotCurve, ax: matplotlib.axes.Axes): if curve.histogram is not None: res.append(plot_histogram(curve.histogram, ax)) for patch in curve.patches or []: - res.append(plot_patch(patch, ax)) + res.append(plot_patch(patch, ax, line_width_scale=line_width_scale)) return res -def patch_to_mpl(patch: PlotPatch): +def patch_to_mpl(patch: PlotPatch, line_width_scale: float = 1.0): + patch_args = patch._patch_args + if patch_args["linewidth"] is not None: + patch_args["linewidth"] *= line_width_scale + if isinstance(patch, PlotPatchRectangle): return matplotlib.patches.Rectangle( xy=patch.xy, @@ -251,7 +287,7 @@ def patch_to_mpl(patch: PlotPatch): height=patch.height, angle=patch.angle, rotation_point=patch.rotation_point, - **patch._patch_args, + **patch_args, ) if isinstance(patch, PlotPatchArc): return matplotlib.patches.Arc( @@ -261,18 +297,18 @@ def patch_to_mpl(patch: PlotPatch): angle=patch.angle, theta1=patch.theta1, theta2=patch.theta2, - **patch._patch_args, + **patch_args, ) if isinstance(patch, PlotPatchCircle): return matplotlib.patches.Circle( xy=patch.xy, radius=patch.radius, - **patch._patch_args, + **patch_args, ) if isinstance(patch, PlotPatchPolygon): return matplotlib.patches.Polygon( xy=patch.vertices, - **patch._patch_args, + **patch_args, ) if isinstance(patch, PlotPatchEllipse): @@ -281,7 +317,7 @@ def patch_to_mpl(patch: PlotPatch): width=patch.width, height=patch.height, angle=patch.angle, - **patch._patch_args, + **patch_args, ) if isinstance(patch, PlotPatchSbend): codes = [ @@ -304,26 +340,34 @@ def patch_to_mpl(patch: PlotPatch): ] return matplotlib.patches.PathPatch( matplotlib.path.Path(vertices, codes), - facecolor="green", - alpha=0.5, + # facecolor="green", + # alpha=0.5, + **patch_args, ) raise NotImplementedError(f"Unsupported patch type: {type(patch).__name__}") -def plot_patch(patch: PlotPatch, ax: matplotlib.axes.Axes): - mpl = patch_to_mpl(patch) +def plot_patch(patch: PlotPatch, ax: matplotlib.axes.Axes, line_width_scale: float = 1.0): + mpl = patch_to_mpl(patch, line_width_scale=line_width_scale) ax.add_patch(mpl) return mpl -def plot_layout_shape(shape: layout_shapes.AnyLayoutShape, ax: matplotlib.axes.Axes): +def plot_layout_shape( + shape: layout_shapes.AnyLayoutShape, + ax: matplotlib.axes.Axes, + line_width_scale: Optional[float] = None, +): + if line_width_scale is None: + line_width_scale = _Defaults.line_width_scale + if isinstance(shape, layout_shapes.LayoutWrappedShape): ax.add_collection( matplotlib.collections.LineCollection( [[(x, y) for x, y in zip(line[0], line[1])] for line in shape.lines], colors=pgplot.mpl_color(shape.color), - linewidths=shape.line_width, + linewidths=shape.line_width * line_width_scale, ) ) else: @@ -333,19 +377,26 @@ def plot_layout_shape(shape: layout_shapes.AnyLayoutShape, ax: matplotlib.axes.A matplotlib.collections.LineCollection( lines, colors=pgplot.mpl_color(shape.color), - linewidths=shape.line_width, + linewidths=shape.line_width * line_width_scale, ) ) for patch in shape.to_patches(): - plot_patch(patch, ax) + plot_patch(patch, ax, line_width_scale=line_width_scale) + +def plot_floor_plan_shape( + shape: floor_plan_shapes.Shape, + ax: matplotlib.axes.Axes, + line_width_scale: Optional[float] = None, +): + if line_width_scale is None: + line_width_scale = _Defaults.floor_line_width_scale -def plot_floor_plan_shape(shape: floor_plan_shapes.Shape, ax: matplotlib.axes.Axes): for line in shape.to_lines(): - plot_curve_line(line, ax) + plot_curve_line(line, ax, line_width_scale=line_width_scale) if not isinstance(shape, floor_plan_shapes.Box): for patch in shape.to_patches(): - plot_patch(patch, ax) + plot_patch(patch, ax, line_width_scale=line_width_scale) def plot(graph: AnyGraph, ax: Optional[matplotlib.axes.Axes] = None) -> matplotlib.axes.Axes: @@ -356,8 +407,10 @@ def plot(graph: AnyGraph, ax: Optional[matplotlib.axes.Axes] = None) -> matplotl if isinstance(graph, BasicGraph): for curve in graph.curves: - assert not curve.info["use_y2"], "TODO: y2 support" - plot_curve(curve, ax) + if curve.info["use_y2"]: + raise NotImplementedError("y2 support") + + plot_curve(curve, ax, line_width_scale=_Defaults.line_width_scale) if graph.draw_legend and any(curve.legend_label for curve in graph.curves): ax.legend() @@ -367,7 +420,7 @@ def plot(graph: AnyGraph, ax: Optional[matplotlib.axes.Axes] = None) -> matplotl for elem in graph.elements: if elem.shape is not None: - plot_layout_shape(elem.shape, ax) + plot_layout_shape(elem.shape, ax, line_width_scale=_Defaults.line_width_scale) # ax.add_collection( # matplotlib.collections.LineCollection( # elem.lines, @@ -392,14 +445,18 @@ def plot(graph: AnyGraph, ax: Optional[matplotlib.axes.Axes] = None) -> matplotl ax.set_aspect("equal") for elem in graph.elements: if elem.shape is not None: - plot_floor_plan_shape(elem.shape, ax) + plot_floor_plan_shape( + elem.shape, + ax, + line_width_scale=_Defaults.floor_line_width_scale, + ) for annotation in elem.annotations: plot_annotation(annotation, ax) for line in graph.building_walls.lines: - plot_curve_line(line, ax) + plot_curve_line(line, ax, line_width_scale=_Defaults.floor_line_width_scale) for patch in graph.building_walls.patches: - plot_patch(patch, ax) + plot_patch(patch, ax, line_width_scale=_Defaults.floor_line_width_scale) if graph.floor_orbits is not None: plot_curve_symbols(graph.floor_orbits.curve, ax) else: