diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..961af050 --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +[flake8] +max-line-length = 300 +exclude = + .git + docs/build + docs/source/gallery + development + +count = true +per-file-ignores = + __init__.py: F403, F401 diff --git a/SuPyMode/directories.py b/SuPyMode/directories.py index ca9bcffe..6b3be0f6 100644 --- a/SuPyMode/directories.py +++ b/SuPyMode/directories.py @@ -9,7 +9,6 @@ 'root_path', 'project_path', 'test_path', - 'instance_directory', 'version_path', 'validation_data_path', 'doc_path', diff --git a/SuPyMode/helper.py b/SuPyMode/helper.py new file mode 100644 index 00000000..58983199 --- /dev/null +++ b/SuPyMode/helper.py @@ -0,0 +1,78 @@ +from typing import Callable +from MPSPlots.styles import mps +import matplotlib.pyplot as plt +from SuPyMode.utils import interpret_mode_of_interest + + +def singular_plot_helper(function: Callable) -> Callable: + def wrapper(self, ax: plt.Axes = None, show: bool = True, mode_of_interest: str = 'all', **kwargs) -> plt.Figure: + if ax is None: + with plt.style.context(mps): + figure, ax = plt.subplots(1, 1) + + mode_of_interest = interpret_mode_of_interest(superset=self, mode_of_interest=mode_of_interest) + + function(self, ax=ax, mode_of_interest=mode_of_interest, **kwargs) + + _, labels = ax.get_legend_handles_labels() + + # Only add a legend if there are labels + if labels: + ax.legend() + + if show: + plt.show() + + return figure + + return wrapper + + +def combination_plot_helper(function: Callable) -> Callable: + def wrapper(self, ax: plt.Axes = None, show: bool = True, mode_of_interest: str = 'all', combination: str = 'pairs', **kwargs) -> plt.Figure: + if ax is None: + with plt.style.context(mps): + figure, ax = plt.subplots(1, 1) + + mode_of_interest = interpret_mode_of_interest(superset=self, mode_of_interest=mode_of_interest) + + combination = self.interpret_combination(mode_of_interest=mode_of_interest, combination=combination) + + function(self, ax=ax, mode_of_interest=mode_of_interest, combination=combination, **kwargs) + + _, labels = ax.get_legend_handles_labels() + + # Only add a legend if there are labels + if labels: + ax.legend() + + if show: + plt.show() + + return figure + + return wrapper + + +def parse_mode_of_interest(plot_function: Callable) -> Callable: + def wrapper(self, *args, mode_of_interest='all', **kwargs): + mode_of_interest = interpret_mode_of_interest( + superset=self, + mode_of_interest=mode_of_interest + ) + + return plot_function(self, *args, mode_of_interest=mode_of_interest, **kwargs) + + return wrapper + + +def parse_combination(plot_function: Callable) -> Callable: + def wrapper(self, *args, mode_of_interest='all', combination: str = 'pairs', **kwargs): + combination = self.interpret_combination( + mode_of_interest=mode_of_interest, + combination=combination + ) + + return plot_function(self, *args, mode_of_interest=mode_of_interest, combination=combination, **kwargs) + + return wrapper diff --git a/SuPyMode/mode_label.py b/SuPyMode/mode_label.py index c7c8b6f3..1f5e6001 100644 --- a/SuPyMode/mode_label.py +++ b/SuPyMode/mode_label.py @@ -2,58 +2,92 @@ # -*- coding: utf-8 -*- from PyFinitDiff.finite_difference_2D import Boundaries +from enum import Enum -mode_dict = [ - {'mode': (0, 1, ""), 'x': 'symmetric', 'y': 'symmetric'}, - {'mode': (1, 1, r"_a"), 'x': 'symmetric', 'y': 'anti-symmetric'}, - {'mode': (1, 1, "_b"), 'x': 'anti-symmetric', 'y': 'symmetric'}, - {'mode': (2, 1, "_a"), 'x': 'symmetric', 'y': 'symmetric'}, - {'mode': (2, 1, "_b"), 'x': 'anti-symmetric', 'y': 'anti-symmetric'}, - {'mode': (0, 2, ""), 'x': 'symmetric', 'y': 'symmetric'}, - {'mode': (3, 1, "_a"), 'x': 'symmetric', 'y': 'anti-symmetric'}, - {'mode': (3, 1, "_b"), 'x': 'anti-symmetric', 'y': 'symmetric'}, - {'mode': (1, 2, "_a"), 'x': 'symmetric', 'y': 'anti-symmetric'}, - {'mode': (1, 2, "_b"), 'x': 'anti-symmetric', 'y': 'symmetric'}, - {'mode': (4, 1, "_a"), 'x': 'symmetric', 'y': 'symmetric'}, - {'mode': (4, 1, "_b"), 'x': 'anti-symmetric', 'y': 'anti-symmetric'}, - {'mode': (2, 2, "_a"), 'x': 'symmetric', 'y': 'symmetric'}, - {'mode': (2, 2, "_b"), 'x': 'anti-symmetric', 'y': 'anti-symmetric'}, - {'mode': (0, 3, ""), 'x': 'symmetric', 'y': 'symmetric'}, - {'mode': (5, 1, "_a"), 'x': 'symmetric', 'y': 'anti-symmetric'}, - {'mode': (5, 1, "_b"), 'x': 'anti-symmetric', 'y': 'symmetric'}, - {'mode': (3, 2, "_a"), 'x': 'symmetric', 'y': 'anti-symmetric'}, - {'mode': (3, 2, "_b"), 'x': 'anti-symmetric', 'y': 'symmetric'}, - {'mode': (1, 3, "_a"), 'x': 'symmetric', 'y': 'anti-symmetric'}, - {'mode': (1, 3, "_b"), 'x': 'anti-symmetric', 'y': 'symmetric'}, - {'mode': (6, 1, "_a"), 'x': 'symmetric', 'y': 'symmetric'}, - {'mode': (6, 1, "_b"), 'x': 'anti-symmetric', 'y': 'anti-symmetric'}, - {'mode': (4, 2, "_a"), 'x': 'symmetric', 'y': 'symmetric'}, - {'mode': (4, 2, "_b"), 'x': 'anti-symmetric', 'y': 'anti-symmetric'}, -] +class Parity(Enum): + SYMMETRIC = 'symmetric' + ANTI_SYMMETRIC = 'anti-symmetric' + ZERO = 'zero' + + +mode_dict = [] + + +for azimuthal, radial in [(0, 1), (1, 1), (2, 1), (0, 2), (3, 1), (1, 2), (4, 1), (2, 2), (0, 3), (5, 1), (3, 2), (1, 3), (6, 1), ]: + + if azimuthal == 0: + parities = [[Parity.SYMMETRIC.value, Parity.SYMMETRIC.value]] + sublabels = [''] + + elif azimuthal % 2 == 1: + parities = [ + [Parity.ANTI_SYMMETRIC.value, Parity.SYMMETRIC.value], + [Parity.SYMMETRIC.value, Parity.ANTI_SYMMETRIC.value], + ] + sublabels = ['_a', '_b'] + + elif azimuthal % 2 == 0: + parities = [ + [Parity.SYMMETRIC.value, Parity.SYMMETRIC.value], + [Parity.ANTI_SYMMETRIC.value, Parity.ANTI_SYMMETRIC.value], + ] + sublabels = ['_a', '_b'] + + for (x_parity, y_parity), sublabel in zip(parities, sublabels): + + mode = dict(mode=(azimuthal, radial, sublabel), x=x_parity, y=y_parity) + + mode_dict.append(mode) class ModeLabel: """ - A class to represent the LP mode label of an optical fiber based on boundary conditions and mode number. - - Attributes: - boundaries (Boundaries): The boundary conditions for the mode. - mode_number (int): The mode number. - x_parity (str): The parity in the x direction (symmetric, anti-symmetric, or zero). - y_parity (str): The parity in the y direction (symmetric, anti-symmetric, or zero). - azimuthal (int): The azimuthal mode number. - radial (int): The radial mode number. - sub_label (str): The sub-label for the mode. + Represents the LP mode label of an optical fiber based on boundary conditions and mode number. + + The `ModeLabel` class encapsulates information about the optical mode based on the boundary conditions + of an optical fiber. It calculates and assigns labels such as `LP01`, `LP11_a`, etc., depending on the + given boundary conditions and mode number. + + Parameters + ---------- + boundaries : Boundaries + The boundary conditions for the mode, indicating symmetries in different directions. + mode_number : int + The mode number to label, corresponding to the specific optical mode. + + Attributes + ---------- + boundaries : Boundaries + The boundary conditions for the mode. + mode_number : int + The mode number to label. + x_parity : str + The parity in the x direction (symmetric, anti-symmetric, or zero). + y_parity : str + The parity in the y direction (symmetric, anti-symmetric, or zero). + azimuthal : int or None + The azimuthal mode number, extracted from the filtered mode list. + radial : int or None + The radial mode number, extracted from the filtered mode list. + sub_label : str or None + The sub-label for the mode, such as `_a` or `_b`. + raw_label : str + The base LP label without the sub-label. + label : str + The full LP label for the mode, including the sub-label. """ def __init__(self, boundaries: Boundaries, mode_number: int): """ - Initializes the ModeLabel with given boundary conditions and mode number. - - Args: - boundaries (Boundaries): The boundary conditions. - mode_number (int): The mode number. + Initializes the `ModeLabel` instance with given boundary conditions and mode number. + + Parameters + ---------- + boundaries : Boundaries + The boundary conditions indicating symmetries in the left, right, top, and bottom directions. + mode_number : int + The mode number, used to identify the optical mode. """ self.boundaries = boundaries self.mode_number = mode_number @@ -61,67 +95,77 @@ def __init__(self, boundaries: Boundaries, mode_number: int): self.initialize() def initialize(self) -> None: - self.x_parity = self._get_x_parity() - self.y_parity = self._get_y_parity() + """ + Initialize and calculate the mode label based on boundary conditions and mode number. + + This method sets the `x_parity` and `y_parity` based on the boundary conditions, filters the + mode list, and assigns appropriate labels for azimuthal and radial modes. + """ + self.x_parity = self._get_parity(self.boundaries.left, self.boundaries.right) + self.y_parity = self._get_parity(self.boundaries.top, self.boundaries.bottom) filtered_modes = self.get_filtered_mode_list() if self.mode_number >= len(filtered_modes): - self.azimuthal, self.radial, self.sub_label = None, None, None - self.raw_label = f"Mode{self.mode_number}" - self.label = self.raw_label + raise ValueError(f"Mode number {self.mode_number} exceeds available modes. Max allowed: {len(filtered_modes) - 1}") - else: - self.azimuthal, self.radial, self.sub_label = filtered_modes[self.mode_number] - self.raw_label = f"LP{self.azimuthal}{self.radial}" - self.label = f"{self.raw_label}{self.sub_label}" + self.azimuthal, self.radial, self.sub_label = filtered_modes[self.mode_number] + self.raw_label = f"LP{self.azimuthal}{self.radial}" + self.label = f"{self.raw_label}{self.sub_label}" def get_filtered_mode_list(self) -> list[tuple]: - return [m['mode'] for m in mode_dict if self.x_parity in [m['x'], 'zero'] and self.y_parity in [m['y'], 'zero']] - - def _get_x_parity(self) -> str: """ - Determines the parity in the x direction based on boundary conditions. + Filters the list of available modes based on the x and y parity conditions. - Returns: - str: The x parity (symmetric, anti-symmetric, or zero). + Returns + ------- + list of tuple + A list of filtered modes that match the specified x and y parity conditions. """ - if self.boundaries.left == 'symmetric' or self.boundaries.right == 'symmetric': - return 'symmetric' - elif self.boundaries.left == 'anti-symmetric' or self.boundaries.right == 'anti-symmetric': - return 'anti-symmetric' - else: - return 'zero' + return [m['mode'] for m in mode_dict if self.x_parity in [m['x'], Parity.ZERO.value] and self.y_parity in [m['y'], Parity.ZERO.value]] - def _get_y_parity(self) -> str: + def _get_parity(self, boundary_1: str, boundary_2: str) -> str: """ - Determines the parity in the y direction based on boundary conditions. - - Returns: - str: The y parity (symmetric, anti-symmetric, or zero). + Determines the parity in a direction based on boundary conditions. + + Parameters + ---------- + boundary_1 : str + The boundary condition for one side (e.g., left or top). + boundary_2 : str + The boundary condition for the opposite side (e.g., right or bottom). + + Returns + ------- + str + The parity ('symmetric', 'anti-symmetric', or 'zero'). """ - if self.boundaries.top == 'symmetric' or self.boundaries.bottom == 'symmetric': - return 'symmetric' - elif self.boundaries.top == 'anti-symmetric' or self.boundaries.bottom == 'anti-symmetric': - return 'anti-symmetric' + if Parity.SYMMETRIC.value in (boundary_1, boundary_2): + return Parity.SYMMETRIC.value + elif Parity.ANTI_SYMMETRIC.value in (boundary_1, boundary_2): + return Parity.ANTI_SYMMETRIC.value else: - return 'zero' + return Parity.ZERO.value def __repr__(self) -> str: """ - Returns the string representation of the ModeLabel. + Returns the string representation of the `ModeLabel`. - Returns: - str: The LP mode label. + Returns + ------- + str + The LP mode label, including the azimuthal and radial information. """ return self.label def __str__(self) -> str: """ - Returns the string representation of the ModeLabel. + Returns the string representation of the `ModeLabel`. - Returns: - str: The LP mode label. + Returns + ------- + str + The LP mode label. """ return self.__repr__() diff --git a/SuPyMode/profiles.py b/SuPyMode/profiles.py index 1b318b5a..0a240343 100644 --- a/SuPyMode/profiles.py +++ b/SuPyMode/profiles.py @@ -19,7 +19,6 @@ strict=True, arbitrary_types_allowed=True, kw_only=True, - frozen=True ) diff --git a/SuPyMode/representation/adiabatic.py b/SuPyMode/representation/adiabatic.py index 84458467..5b0e3ab3 100644 --- a/SuPyMode/representation/adiabatic.py +++ b/SuPyMode/representation/adiabatic.py @@ -7,7 +7,6 @@ from SuPyMode.supermode import SuperMode import numpy -from typing import NoReturn from SuPyMode.representation.base import InheritFromSuperMode, BaseMultiModePlot import matplotlib.pyplot as plt @@ -16,42 +15,40 @@ class Adiabatic(InheritFromSuperMode, BaseMultiModePlot): """ Represents the adiabatic criterion between modes of different supermodes in optical fiber simulations. - This class extends from `InheritFromSuperMode` for accessing supermode-related data and `BaseMultiModePlot` + This class extends from `InheritFromSuperMode` for accessing supermode-related data and from `BaseMultiModePlot` for plotting functionalities tailored to visualize adiabatic transition measurements. - Class Attributes: - plot_style (dict): A dictionary defining the default style settings for plots generated by this class. - """ - - plot_style = dict( - show_legend=True, - x_label='Inverse taper ratio', - y_label=r'Adiabatic criterion [$\mu$m$^{-1}$]', - y_scale='log', - y_scale_factor=1e-6, - y_limits=[1e-5, 1], - line_width=2 - ) + Attributes + ---------- + parent_supermode : SuperMode + The parent supermode object that provides the base mode data. + """ def __init__(self, parent_supermode: SuperMode): """ - Initializes an Adiabatic object with a reference to a parent supermode. + Initialize an Adiabatic object with a reference to a parent supermode. - Args: - parent_supermode (SuperMode): The parent supermode object that provides the base mode data. + Parameters + ---------- + parent_supermode : SuperMode + The parent supermode object that provides the base mode data. """ self.parent_supermode = parent_supermode def get_values(self, other_supermode: SuperMode) -> numpy.ndarray: """ - Calculates the adiabatic transition measure between the parent supermode and another specified supermode. - - Args: - other_supermode (SuperMode): The supermode with which to compare the parent supermode. - - Returns: - numpy.ndarray: An array of adiabatic transition measures calculated between the two supermodes, - possibly adjusted by compatibility considerations. + Calculate the adiabatic transition measure between the parent supermode and another specified supermode. + + Parameters + ---------- + other_supermode : SuperMode + The supermode with which to compare the parent supermode. + + Returns + ------- + numpy.ndarray + An array of adiabatic transition measures calculated between the two supermodes, adjusted as needed for + computational compatibility. """ output = self.parent_supermode.binding.get_adiabatic_with_mode(other_supermode.binding) @@ -60,8 +57,19 @@ def get_values(self, other_supermode: SuperMode) -> numpy.ndarray: return abs(output) - def _dress_ax(self, ax: plt.Axes) -> NoReturn: - ax.set_xlabel('Inverse taper ratio') - ax.set_ylabel(r'Adiabatic criterion [$\mu$m$^{-1}$]') + def _dress_ax(self, ax: plt.Axes) -> None: + """ + Set axis labels for the adiabatic criterion plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis object on which to set the labels. + """ + ax.set( + xlabel='Inverse taper ratio', + ylabel=r'Adiabatic criterion [$\mu$m$^{-1}$]', + ylim=[1e-5, 1], + ) # - diff --git a/SuPyMode/representation/base.py b/SuPyMode/representation/base.py index b3303875..bb63febe 100644 --- a/SuPyMode/representation/base.py +++ b/SuPyMode/representation/base.py @@ -1,28 +1,49 @@ -# #!/usr/bin/env python -# # -*- coding: utf-8 -*- - from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: from SuPyMode.supermode import SuperMode -from typing import NoReturn import matplotlib.pyplot as plt +from MPSPlots.styles import mps class BaseMultiModePlot(): - def render_on_ax(self, ax: plt.Axes, other_supermode: SuperMode) -> None: + def plot(self, other_supermode: SuperMode, ax: plt.Axes = None, show: bool = True) -> plt.Figure: + """ + Plot the coupling between the parent supermode and another supermode. + + This method generates a plot of specific parameter couplings between the parent supermode and the specified + `other_supermode`, using a single-axis matplotlib plot. The plot shows the normalized coupling as a function of + the inverse taper ratio (ITR), formatted according to the predefined plot style. + + Parameters + ---------- + other_supermode : SuperMode + The supermode to compare against. + ax : matplotlib.axes.Axes, optional + The axis on which to plot. If `None`, a new axis is created (default is `None`). + show : bool, optional + Whether to display the plot immediately (default is `True`). + + Returns + ------- + matplotlib.figure.Figure + The figure object containing the generated plot. + + Examples + -------- + >>> fig, ax = plt.subplots() + >>> base_multi_mode_plot.plot(other_supermode=mode2, ax=ax, show=True) + >>> plt.show() """ - Renders normalized mode coupling data as a line plot on the provided Axes object, comparing the parent supermode - with another supermode. + if ax is None: + with plt.style.context(mps): + figure, ax = plt.subplots(1, 1) + else: + figure = ax.figure - Args: - ax (plt.Axes): The Axes object on which to plot the normalized mode coupling. - other_supermode (SuperMode): The other supermode to compare against. + self._dress_ax(ax) - Note: - This method is conditioned on computational compatibility between the supermodes. - """ if not self.parent_supermode.is_computation_compatible(other_supermode): return @@ -32,120 +53,251 @@ def render_on_ax(self, ax: plt.Axes, other_supermode: SuperMode) -> None: ax.plot(self.itr_list, abs(y), label=label, linewidth=2) - def plot(self, other_supermode: SuperMode) -> NoReturn: - """ - Generates a plot of specific parameter coupling between the parent supermode and another specified supermode using a SceneList. - - This method creates a single-axis plot showing the comparative mode couplings as a function of the inverse taper ratio, - formatted according to the predefined plot style. - - Args: - other_supermode (SuperMode): The supermode to compare against. - - Returns: - SceneList: A scene list containing the plot of normalized mode couplings. - """ - figure, ax = plt.subplots(1, 1) - - self._dress_ax(ax) - - self.render_on_ax(ax=ax, other_supermode=other_supermode) - ax.legend() - figure.tight_layout() - plt.show() + if show: + plt.show() + + return figure class BaseSingleModePlot(): def __getitem__(self, idx: int): return self._data[idx] - def render_on_ax(self, ax: plt.Axis) -> None: - """ - Renders the eigenvalues as a line plot on the provided Axis object. - - Args: - ax (Axis): The Axis object on which the eigenvalues will be plotted. - - Note: - This method utilizes the plotting configuration set on the Axis to define the appearance of the plot. + def plot(self, ax: plt.Axes = None, show: bool = True) -> plt.Figure: """ - ax.plot(self.itr_list, self.data, label=f'{self.stylized_label}', linewidth=2) - - def plot(self) -> NoReturn: - """ - Generates a plot of the using matplotlib. + Plot the propagation constant for a single mode. - This method creates a single-axis plot showing the propagation constants as a function of the inverse taper ratio, + This method generates a plot of the propagation constants as a function of the inverse taper ratio (ITR), formatted according to the predefined plot style. + Parameters + ---------- + ax : matplotlib.axes.Axes, optional + The axis on which to plot. If `None`, a new axis is created (default is `None`). + show : bool, optional + Whether to display the plot immediately (default is `True`). + + Returns + ------- + matplotlib.figure.Figure + The figure object containing the generated plot. + + Examples + -------- + >>> fig, ax = plt.subplots() + >>> base_single_mode_plot.plot(ax=ax, show=True) + >>> plt.show() """ - figure, ax = plt.subplots(1, 1) + if ax is None: + with plt.style.context(mps): + figure, ax = plt.subplots(1, 1) + else: + figure = ax.figure self._dress_ax(ax) - self.render_on_ax(ax=ax) + ax.plot(self.itr_list, self.data, label=f'{self.stylized_label}', linewidth=2) ax.legend() - figure.tight_layout() - plt.show() + if show: + plt.show() + + return figure class InheritFromSuperMode(): def _set_axis_(self, ax: plt.Axes): + """ + Set the axis properties according to the predefined plot style. + + This method applies various properties (e.g., labels, limits) to the given axis based on the `plot_style` dictionary. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis to which the plot style will be applied. + """ for element, value in self.plot_style.items(): setattr(ax, element, value) def __getitem__(self, idx): + """ + Get the data value at the specified index. + + Parameters + ---------- + idx : int + The index of the data value to retrieve. + + Returns + ------- + Any + The data value at the specified index. + """ return self.data[idx] @property def mode_number(self) -> int: + """ + Get the mode number of the parent supermode. + + Returns + ------- + int + The mode number of the parent supermode. + """ return self.parent_supermode.mode_number @property def solver_number(self) -> int: + """ + Get the solver number of the parent supermode. + + Returns + ------- + int + The solver number of the parent supermode. + """ return self.parent_supermode.solver_number @property def axes(self): + """ + Get the axes of the parent supermode. + + Returns + ------- + Any + The axes associated with the parent supermode. + """ return self.parent_supermode.axes @property def boundaries(self): + """ + Get the boundary conditions of the parent supermode. + + Returns + ------- + Any + The boundaries of the parent supermode. + """ return self.parent_supermode.boundaries @property def itr_list(self): + """ + Get the list of inverse taper ratio (ITR) values. + + Returns + ------- + list + The list of ITR values associated with the parent supermode. + """ return self.parent_supermode.itr_list @property def ID(self): + """ + Get the identifier (ID) of the parent supermode. + + Returns + ------- + Any + The identifier of the parent supermode. + """ return self.parent_supermode.ID @property def label(self): + """ + Get the label of the parent supermode. + + Returns + ------- + str + The label of the parent supermode. + """ return self.parent_supermode.label @property def stylized_label(self): + """ + Get the stylized label of the parent supermode. + + Returns + ------- + str + The stylized label of the parent supermode. + """ return self.parent_supermode.stylized_label def slice_to_itr(self, slice: list = []): + """ + Convert slice indices to inverse taper ratio (ITR) values. + + Parameters + ---------- + slice : list of int, optional + A list of slice indices to convert (default is an empty list). + + Returns + ------- + list + A list of ITR values corresponding to the provided slice indices. + """ return self.parent_supermode.parent_set.slice_to_itr(slice) def itr_to_slice(self, itr: list = []): + """ + Convert inverse taper ratio (ITR) values to slice indices. + + Parameters + ---------- + itr : list of float, optional + A list of ITR values to convert (default is an empty list). + + Returns + ------- + list + A list of slice indices corresponding to the provided ITR values. + """ return self.parent_supermode.parent_set.itr_to_slice(itr) def _get_symmetrize_vector(self, *args, **kwargs): + """ + Get the symmetrization vector from the parent supermode. + + Returns + ------- + Any + The symmetrization vector computed by the parent supermode. + """ return self.parent_supermode._get_symmetrize_vector(*args, **kwargs) def _get_axis_vector(self, *args, **kwargs): + """ + Get the axis vector from the parent supermode. + + Returns + ------- + Any + The axis vector computed by the parent supermode. + """ return self.parent_supermode._get_axis_vector(*args, **kwargs) def get_axis(self, *args, **kwargs): + """ + Get the axis information from the parent supermode. + + Returns + ------- + Any + The axis information retrieved from the parent supermode. + """ return self.parent_supermode.get_axis(*args, **kwargs) # - diff --git a/SuPyMode/representation/beating_length.py b/SuPyMode/representation/beating_length.py index 47cd5fe3..5718a75e 100644 --- a/SuPyMode/representation/beating_length.py +++ b/SuPyMode/representation/beating_length.py @@ -7,7 +7,6 @@ from SuPyMode.supermode import SuperMode import numpy -from typing import NoReturn from SuPyMode.representation.base import InheritFromSuperMode, BaseMultiModePlot import matplotlib.pyplot as plt @@ -19,40 +18,47 @@ class BeatingLength(InheritFromSuperMode, BaseMultiModePlot): This class extends from `InheritFromSuperMode` to utilize supermode-related data and from `BaseMultiModePlot` for advanced plotting functionalities tailored to visualize beating length comparisons. - Class Attributes: - plot_style (dict): Default style settings for plots generated by this class. + Attributes + ---------- + parent_supermode : SuperMode + The parent supermode object that provides the base mode data. """ - - plot_style = dict( - show_legend=True, - x_label='Inverse taper ratio', - y_label='Beating length [m]', - y_scale="log", - line_width=2 - ) - def __init__(self, parent_supermode: SuperMode): """ - Initializes a BeatingLength object with a reference to a parent supermode. + Initialize a BeatingLength object with a reference to a parent supermode. - Args: - parent_supermode (SuperMode): The parent supermode object that provides the base mode data. + Parameters + ---------- + parent_supermode : SuperMode + The parent supermode object that provides the base mode data. """ self.parent_supermode = parent_supermode def get_values(self, other_supermode: SuperMode) -> numpy.ndarray: """ - Calculates the beating length between the parent supermode and another specified supermode. + Calculate the beating length between the parent supermode and another specified supermode. - Args: - other_supermode (SuperMode): The supermode with which to compare the parent supermode. + Parameters + ---------- + other_supermode : SuperMode + The supermode with which to compare the parent supermode. - Returns: - numpy.ndarray: An array of beating lengths calculated between the two supermodes. + Returns + ------- + numpy.ndarray + An array of beating lengths calculated between the two supermodes. """ return self.parent_supermode.binding.get_beating_length_with_mode(other_supermode.binding) - def _dress_ax(self, ax: plt.Axes) -> NoReturn: + def _dress_ax(self, ax: plt.Axes) -> None: + """ + Set axis labels for the beating length plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis object on which to set the labels. + """ ax.set_xlabel('Inverse taper ratio') ax.set_ylabel('Beating length [m]') diff --git a/SuPyMode/representation/beta.py b/SuPyMode/representation/beta.py index 3218d652..5375c420 100644 --- a/SuPyMode/representation/beta.py +++ b/SuPyMode/representation/beta.py @@ -6,7 +6,6 @@ if TYPE_CHECKING: from SuPyMode.supermode import SuperMode -from typing import NoReturn from SuPyMode.representation.base import InheritFromSuperMode, BaseSingleModePlot import matplotlib.pyplot as plt @@ -18,26 +17,35 @@ class Beta(InheritFromSuperMode, BaseSingleModePlot): This class utilizes inheritance from `InheritFromSuperMode` for accessing supermode-related data and `BaseSingleModePlot` for plotting functionalities tailored to propagation constant visualization. - Class Attributes: - plot_style (dict): A dictionary defining the default style settings for plots generated by this class. - - Attributes: - parent_supermode (InheritFromSuperMode): A reference to the parent supermode object from which beta data is sourced. + Attributes + ---------- + parent_supermode : SuperMode + A reference to the parent supermode object from which beta data is sourced. + data : numpy.ndarray + The propagation constant (beta) data retrieved from the parent supermode binding. """ - def __init__(self, parent_supermode: SuperMode): """ - Initializes a Beta object with a reference to a parent supermode. + Initialize a Beta object with a reference to a parent supermode. - Args: - parent_supermode (InheritFromSuperMode): The parent supermode object. + Parameters + ---------- + parent_supermode : SuperMode + The parent supermode object that provides the base beta data. """ self.parent_supermode = parent_supermode self.data = self.parent_supermode.binding.get_betas() - def _dress_ax(self, ax: plt.Axes) -> NoReturn: + def _dress_ax(self, ax: plt.Axes) -> None: + """ + Set axis labels for the propagation constant plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis object on which to set the labels. + """ ax.set_xlabel('Inverse taper ratio') ax.set_ylabel('Propagation constant [rad/M]') - # - diff --git a/SuPyMode/representation/eigen_value.py b/SuPyMode/representation/eigen_value.py index 19c492c4..361403ea 100644 --- a/SuPyMode/representation/eigen_value.py +++ b/SuPyMode/representation/eigen_value.py @@ -6,7 +6,6 @@ if TYPE_CHECKING: from SuPyMode.supermode import SuperMode -from typing import NoReturn from SuPyMode.representation.base import InheritFromSuperMode, BaseSingleModePlot import matplotlib.pyplot as plt @@ -18,23 +17,35 @@ class EigenValue(InheritFromSuperMode, BaseSingleModePlot): This class extends from `InheritFromSuperMode` to access supermode-related data and from `BaseSingleModePlot` to provide plotting capabilities tailored to eigenvalue visualization. - Attributes: - parent_supermode (InheritFromSuperMode): The parent supermode object from which eigenvalue data is derived. + Attributes + ---------- + parent_supermode : SuperMode + The parent supermode object from which eigenvalue data is derived. + data : numpy.ndarray + The eigenvalue data retrieved from the parent supermode binding. """ - def __init__(self, parent_supermode: SuperMode): """ - Initializes an EigenValue object with a parent supermode reference. + Initialize an EigenValue object with a parent supermode reference. - Args: - parent_supermode (InheritFromSuperMode): A reference to the parent supermode object. + Parameters + ---------- + parent_supermode : SuperMode + A reference to the parent supermode object that provides the base eigenvalue data. """ self.parent_supermode = parent_supermode self.data = self.parent_supermode.binding.get_eigen_value() - def _dress_ax(self, ax: plt.Axis) -> NoReturn: - ax.set_xlabel('Inverse taper ratio') - ax.set_ylabel('Mode eigen values') + def _dress_ax(self, ax: plt.Axes) -> None: + """ + Set axis labels for the eigenvalue plot. + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis object on which to set the labels. + """ + ax.set_xlabel('Inverse taper ratio') + ax.set_ylabel('Mode eigenvalues') # - diff --git a/SuPyMode/representation/field.py b/SuPyMode/representation/field.py index 511be35c..4b59dd5e 100644 --- a/SuPyMode/representation/field.py +++ b/SuPyMode/representation/field.py @@ -6,7 +6,6 @@ if TYPE_CHECKING: from SuPyMode.supermode import SuperMode -from typing import NoReturn import numpy from MPSPlots import colormaps import matplotlib.pyplot as plt @@ -24,16 +23,21 @@ class Field(InheritFromSuperMode): This class extends functionality from a parent supermode class to manage field data operations, including retrieving and processing field data for visualization and analysis. - Attributes: - parent_supermode (InheritFromSuperMode): Reference to the parent supermode object that provides source data. + Attributes + ---------- + parent_supermode : SuperMode + Reference to the parent supermode object that provides source data. + data : numpy.ndarray + The field data retrieved from the parent supermode binding. """ - def __init__(self, parent_supermode: SuperMode): """ Initialize the Field object with a reference to a parent supermode. - Args: - parent_supermode (InheritFromSuperMode): The parent supermode from which this field is derived. + Parameters + ---------- + parent_supermode : SuperMode + The parent supermode from which this field is derived. """ self.parent_supermode = parent_supermode self.data = self.parent_supermode.binding.get_fields() @@ -42,11 +46,15 @@ def get_norm(self, slice_number: int) -> float: """ Calculate the norm of the field for a specific slice. - Args: - slice_number (int): The slice number for which to calculate the norm. + Parameters + ---------- + slice_number : int + The slice number for which to calculate the norm. - Returns: - float: The norm of the field. + Returns + ------- + float + The norm of the field for the specified slice. """ return self.parent_supermode.binding.get_norm(slice_number) @@ -55,8 +63,10 @@ def itr_list(self) -> numpy.ndarray: """ Provides a list of iteration indices available for the fields. - Returns: - numpy.ndarray: An array of iteration indices. + Returns + ------- + numpy.ndarray + An array of iteration indices. """ return self.parent_supermode.binding.model_parameters.itr_list @@ -65,8 +75,10 @@ def parent_superset(self) -> object: """ Access the parent set of the supermode. - Returns: - object: The parent set object. + Returns + ------- + object + The parent set object. """ return self.parent_supermode.parent_set @@ -74,16 +86,24 @@ def get_field(self, slice_number: int = None, itr: float = None, add_symmetries: """ Retrieve a specific field adjusted for boundary conditions and optionally add symmetries. - Args: - slice_number (int): The slice number to retrieve. - itr (float): The iteration to use for retrieving the field. - add_symmetries (bool): Whether to add boundary symmetries to the field. - - Returns: - numpy.ndarray: The requested field as a numpy array. - - Raises: - AssertionError: If neither or both of slice_number and itr are defined. + Parameters + ---------- + slice_number : int, optional + The slice number to retrieve. + itr : float, optional + The iteration to use for retrieving the field. + add_symmetries : bool, optional, default=True + Whether to add boundary symmetries to the field. + + Returns + ------- + numpy.ndarray + The requested field as a numpy array. + + Raises + ------ + AssertionError + If neither or both of `slice_number` and `itr` are defined. """ slice_list, itr_list = interpret_slice_number_and_itr( itr_baseline=self.itr_list, @@ -103,15 +123,19 @@ def normalize_field(self, field: numpy.ndarray, itr: float, norm_type: str = 'L2 """ Normalize a field array based on a specified normalization method. - Currently, this method is deprecated. - - Args: - field (numpy.ndarray): The field to normalize. - itr (float): The iteration value for normalization scaling. - norm_type (str): The type of normalization ('max', 'center', 'L2', 'cmt'). - - Returns: - numpy.ndarray: The normalized field. + Parameters + ---------- + field : numpy.ndarray + The field to normalize. + itr : float + The iteration value for normalization scaling. + norm_type : str, optional, default='L2' + The type of normalization ('max', 'center', 'L2', 'cmt'). + + Returns + ------- + numpy.ndarray + The normalized field. """ match norm_type.lower(): case 'max': @@ -136,14 +160,20 @@ def _get_symmetrized_field_and_axis(self, field: numpy.ndarray) -> tuple: """ Generate a symmetrical version of the input field mesh according to defined boundary conditions. - Args: - field (numpy.ndarray): The 2D field mesh to be symmetrized. + Parameters + ---------- + field : numpy.ndarray + The 2D field mesh to be symmetrized. - Returns: - numpy.ndarray: The symmetrized field mesh. + Returns + ------- + tuple + A tuple containing the x-axis, y-axis, and the symmetrized field. - Raises: - AssertionError: If the input is not a 2D array. + Raises + ------ + AssertionError + If the input is not a 2D array. """ x_axis, y_axis = self._get_axis_vector(add_symmetries=True) @@ -157,11 +187,20 @@ def _get_symmetrized_field(self, field: numpy.ndarray) -> numpy.ndarray: This method generates symmetrized versions of the field and its corresponding axis vectors. - Args: - field (numpy.ndarray): The field data array to be symmetrized. + Parameters + ---------- + field : numpy.ndarray + The field data array to be symmetrized. - Returns: - tuple: A tuple containing the x-axis, y-axis, and the symmetrized field data. + Returns + ------- + numpy.ndarray + The symmetrized field data. + + Raises + ------ + AssertionError + If the input field is not a 2D array. """ field = field.squeeze() assert field.ndim == 2, f"Expected a 2-dimensional array, but got {field.ndim}-dimensional." @@ -206,18 +245,27 @@ def plot( show_mode_label: bool = True, show_itr: bool = True, show_colorbar: bool = False, - show_slice: bool = True) -> NoReturn: + show_slice: bool = True, + show: bool = True) -> None: """ Plot the field for specified iterations or slice numbers. - Args: - itr_list (list[float]): List of iterations to evaluate the field. - slice_list (list[int]): List of slices to evaluate the field. - add_symmetries (bool): Whether to include boundary symmetries in the plot. - show_mode_label (bool): Whether to show the mode label. - show_itr (bool): Whether to show the iteration value. - show_colorbar (bool): Whether to show the colorbar. - show_slice (bool): Whether to show the slice number. + Parameters + ---------- + itr_list : list of float, optional + List of iterations to evaluate the field. + slice_list : list of int, optional, default=[0, -1] + List of slices to evaluate the field. + add_symmetries : bool, optional, default=True + Whether to include boundary symmetries in the plot. + show_mode_label : bool, optional, default=True + Whether to show the mode label. + show_itr : bool, optional, default=True + Whether to show the iteration value. + show_colorbar : bool, optional, default=False + Whether to show the colorbar. + show_slice : bool, optional, default=True + Whether to show the slice number. """ slice_list, itr_list = interpret_slice_number_and_itr( itr_baseline=self.itr_list, @@ -237,7 +285,9 @@ def plot( ) plt.tight_layout() - plt.show() + + if show: + plt.show() def render_on_ax( self, @@ -247,7 +297,7 @@ def render_on_ax( show_itr: bool = True, show_slice: bool = True, show_colorbar: bool = False, - add_symmetries: bool = True) -> NoReturn: + add_symmetries: bool = True) -> None: """ Render the mode field at the given slice number into the input axis. diff --git a/SuPyMode/representation/index.py b/SuPyMode/representation/index.py index d0160800..303e885a 100644 --- a/SuPyMode/representation/index.py +++ b/SuPyMode/representation/index.py @@ -6,7 +6,6 @@ if TYPE_CHECKING: from SuPyMode.supermode import SuperMode -from typing import NoReturn from SuPyMode.representation.base import InheritFromSuperMode, BaseSingleModePlot import matplotlib.pyplot as plt @@ -18,23 +17,38 @@ class Index(InheritFromSuperMode, BaseSingleModePlot): This class extends from `InheritFromSuperMode` for accessing supermode-related data and `BaseSingleModePlot` for plotting functionalities tailored to visualize the effective refractive index. - Class Attributes: - plot_style (dict): A dictionary defining the default style settings for plots generated by this class. + Attributes + ---------- + parent_supermode : SuperMode + The supermode instance to which this Index object is linked. + data : numpy.ndarray + The effective refractive index data for the mode derived from the parent supermode. """ - def __init__(self, parent_supermode: SuperMode): """ - Initializes an Index object with a reference to a parent supermode. + Initialize an Index object with a reference to a parent supermode. - Args: - parent_supermode (SuperMode): The parent supermode object that provides the base mode data. + Parameters + ---------- + parent_supermode : SuperMode + The parent supermode object that provides the base mode data. """ self.parent_supermode = parent_supermode self.data = self.parent_supermode.binding.get_index() - def _dress_ax(self, ax: plt.Axes) -> NoReturn: - ax.set_xlabel('Inverse taper ratio') - ax.set_ylabel('Effective refraction index') - ax.set_ylim([1.44, 1.455]) + def _dress_ax(self, ax: plt.Axes) -> None: + """ + Set axis labels and limits for the effective refractive index plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis object on which to set the labels and limits. + """ + ax.set( + xlabel='Inverse taper ratio', + ylabel='Effective refractive index', + ylim=[1.44, 1.455] + ) # - diff --git a/SuPyMode/representation/normalized_coupling.py b/SuPyMode/representation/normalized_coupling.py index 2709f884..93a2b0a1 100644 --- a/SuPyMode/representation/normalized_coupling.py +++ b/SuPyMode/representation/normalized_coupling.py @@ -7,7 +7,6 @@ from SuPyMode.supermode import SuperMode import numpy -from typing import NoReturn from SuPyMode.representation.base import InheritFromSuperMode, BaseMultiModePlot import matplotlib.pyplot as plt @@ -19,28 +18,36 @@ class NormalizedCoupling(InheritFromSuperMode, BaseMultiModePlot): This class extends from `InheritFromSuperMode` for accessing supermode-related data and `BaseMultiModePlot` for plotting functionalities tailored to visualize mode coupling comparisons. - Class Attributes: - plot_style (dict): A dictionary defining the default style settings for plots generated by this class. - """ + Attributes + ---------- + parent_supermode : SuperMode + The supermode instance to which this NormalizedCoupling object is linked. + """ def __init__(self, parent_supermode: SuperMode): """ - Initializes a NormalizedCoupling object with a reference to a parent supermode. + Initialize a NormalizedCoupling object with a reference to a parent supermode. - Args: - parent_supermode (SuperMode): The parent supermode object that provides the base mode data. + Parameters + ---------- + parent_supermode : SuperMode + The parent supermode object that provides the base mode data. """ self.parent_supermode = parent_supermode def get_values(self, other_supermode: SuperMode) -> numpy.ndarray: """ - Calculates the normalized mode coupling between the parent supermode and another specified supermode. + Calculate the normalized mode coupling between the parent supermode and another specified supermode. - Args: - other_supermode (SuperMode): The supermode with which to compare the parent supermode. + Parameters + ---------- + other_supermode : SuperMode + The supermode with which to compare the parent supermode. - Returns: - numpy.ndarray: An array of normalized mode coupling values, adjusted for computational compatibility. + Returns + ------- + numpy.ndarray + An array of normalized mode coupling values, adjusted for computational compatibility. """ output = self.parent_supermode.binding.get_normalized_coupling_with_mode(other_supermode.binding) @@ -49,7 +56,15 @@ def get_values(self, other_supermode: SuperMode) -> numpy.ndarray: return output - def _dress_ax(self, ax: plt.Axes) -> NoReturn: + def _dress_ax(self, ax: plt.Axes) -> None: + """ + Set axis labels for the normalized coupling plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis object on which to set the labels. + """ ax.set_xlabel('Inverse taper ratio') ax.set_ylabel('Mode coupling') diff --git a/SuPyMode/solver.py b/SuPyMode/solver.py index 9ade39d9..6e1aa011 100644 --- a/SuPyMode/solver.py +++ b/SuPyMode/solver.py @@ -21,17 +21,28 @@ @dataclass() class SuPySolver(object): """ - Solver class integrating a C++ eigensolver to compute eigenvalues for optical fiber geometries. - This class manages the eigenvalue problems and returns collections of computed SuperModes. - - Attributes: - geometry (Geometry | np.ndarray): The refractive index geometry of the optical structure. - tolerance (float): Absolute tolerance for the propagation constant computation. - max_iter (int): Maximum iterations for the C++ eigensolver. - accuracy (int): Accuracy level of the finite difference method. - extrapolation_order (int): Order of Taylor series used to extrapolate eigenvalues. - debug_mode (int): Debug output level from the C++ binding (0, 1, 2). - coordinate_system (Optional[CoordinateSystem]): The coordinate system linked with the geometry. + A solver for computing eigenvalues and supermodes of optical fiber geometries using a C++ eigensolver. + + This class manages the eigenvalue problem for optical structures and returns computed supermodes. The solver utilizes + a C++ backend for efficient eigenvalue computation and integrates with finite difference methods to solve for various + boundary conditions. + + Parameters + ---------- + geometry : Geometry or numpy.ndarray + The refractive index geometry of the optical structure. + tolerance : float, optional + Absolute tolerance for the propagation constant computation (default is 1e-8). + max_iter : int, optional + Maximum iterations for the C++ eigensolver (default is 10,000). + accuracy : int, optional + Accuracy level of the finite difference method (default is 2). + extrapolation_order : int, optional + Order of Taylor series used to extrapolate eigenvalues (default is 2). + debug_mode : int, optional + Debug output level from the C++ binding, where 0 is no output and higher values provide more detail (default is 1). + coordinate_system : CoordinateSystem, optional + The coordinate system linked with the geometry. Must be provided if geometry is given as an array. """ geometry: Geometry | numpy.ndarray = field(repr=False) tolerance: float = 1e-8 @@ -46,8 +57,8 @@ def __post_init__(self): assert self.coordinate_system is not None, "Geometry provided without its coordinate system" self.mesh = self.geometry else: - self.geometry.generate_coordinate_system() - self.mesh = self.geometry.generate_mesh() + self.geometry.generate_coordinate_mesh() + self.mesh = self.geometry.mesh self.coordinate_system = self.geometry.coordinate_system self.mode_number = 0 @@ -55,15 +66,21 @@ def __post_init__(self): def initialize_binding(self, n_sorted_mode: int, boundaries: Boundaries, n_added_mode: int) -> CppSolver: """ - Initializes and configures the C++ solver binding for eigenvalue computations. - - Args: - n_sorted_mode (int): Number of modes to sort and retrieve from the solver. - boundaries (Boundaries): Boundary conditions for the finite difference system. - n_added_mode (int): Number of extra modes calculated for accuracy and reliability. - - Returns: - CppSolver: Configured C++ solver instance. + Initialize and configure the C++ solver binding for eigenvalue computations. + + Parameters + ---------- + n_sorted_mode : int + Number of modes to sort and retrieve from the solver. + boundaries : Boundaries + Boundary conditions for the finite difference system. + n_added_mode : int + Number of extra modes calculated for accuracy and reliability. + + Returns + ------- + CppSolver + Configured C++ solver instance. """ self.FD = FiniteDifference( n_x=self.coordinate_system.nx, @@ -102,13 +119,18 @@ def initialize_binding(self, n_sorted_mode: int, boundaries: Boundaries, n_added def init_superset(self, wavelength: float, n_step: int = 300, itr_initial: float = 1.0, itr_final: float = 0.1) -> None: """ - Initializes a SuperSet instance containing computed supermodes over a range of inverse taper ratios (ITR). - - Args: - wavelength (float): Wavelength for the mode computation. - n_step (int): Number of steps for the ITR interpolation. - itr_initial (float): Initial ITR value. - itr_final (float): Final ITR value. + Initialize a SuperSet instance containing computed supermodes over a range of inverse taper ratios (ITR). + + Parameters + ---------- + wavelength : float + Wavelength for the mode computation. + n_step : int, optional + Number of steps for the ITR interpolation (default is 300). + itr_initial : float, optional + Initial ITR value (default is 1.0). + itr_final : float, optional + Final ITR value (default is 0.1). """ self.wavelength = wavelength self.wavenumber = 2 * numpy.pi / wavelength @@ -129,39 +151,53 @@ def init_superset(self, wavelength: float, n_step: int = 300, itr_initial: float def index_to_eigen_value(self, index: float) -> float: """ - Converts a refractive index to the corresponding eigenvalue for the solver. + Convert a refractive index to the corresponding eigenvalue for the solver. - Args: - index (float): Refractive index to convert. + Parameters + ---------- + index : float + Refractive index to convert. - Returns: - float: Calculated eigenvalue based on the given index and the wavenumber. + Returns + ------- + float + Calculated eigenvalue based on the given index and the wavenumber. """ return -(index * self.wavenumber)**2 def eigen_value_to_index(self, eigen_value: float) -> float: """ - Converts an eigenvalue from the solver to the corresponding refractive index. + Convert an eigenvalue from the solver to the corresponding refractive index. - Args: - eigen_value (float): Eigenvalue to convert. + Parameters + ---------- + eigen_value : float + Eigenvalue to convert. - Returns: - float: Equivalent refractive index calculated from the eigenvalue and the wavenumber. + Returns + ------- + float + Equivalent refractive index calculated from the eigenvalue and the wavenumber. """ return numpy.sqrt(eigen_value) / self.wavenumber def get_supermode_labels(self, n_modes: int, boundaries: Boundaries, auto_label: bool) -> list: """ - Generates labels for supermodes based on boundary conditions and whether auto-labeling is enabled. - - Args: - n_modes (int): Number of modes for which labels are needed. - boundaries (Boundaries): Boundary conditions that affect mode symmetries. - auto_label (bool): If True, automatically generates labels based on mode symmetries; otherwise, generates generic labels. - - Returns: - list: List of labels for the supermodes. + Generate labels for supermodes based on boundary conditions and whether auto-labeling is enabled. + + Parameters + ---------- + n_modes : int + Number of modes for which labels are needed. + boundaries : Boundaries + Boundary conditions that affect mode symmetries. + auto_label : bool + If True, automatically generates labels based on mode symmetries; otherwise, generates generic labels. + + Returns + ------- + list + List of labels for the supermodes. """ if auto_label: return [ModeLabel(boundaries=boundaries, mode_number=n).label for n in range(n_modes)] @@ -170,17 +206,25 @@ def get_supermode_labels(self, n_modes: int, boundaries: Boundaries, auto_label: def add_modes(self, n_sorted_mode: int, boundaries: Boundaries, n_added_mode: int = 4, index_guess: float = 0., auto_label: bool = True) -> None: """ - Computes and adds a specified number of supermodes to the solver's collection, using given boundary conditions and mode sorting criteria. - - Args: - n_sorted_mode (int): Number of modes to output and sort from the solver. - boundaries (Boundaries): Boundary conditions for the finite difference calculations. - n_added_mode (int): Additional modes computed to ensure mode matching accuracy. - index_guess (float): Starting guess for the refractive index used in calculations (if 0, auto evaluated). - auto_label (bool): If True, enables automatic labeling of modes based on symmetry. - - Returns: - None: This method updates the solver's internal state but does not return any value. + Compute and add a specified number of supermodes to the solver's collection. + + Parameters + ---------- + n_sorted_mode : int + Number of modes to output and sort from the solver. + boundaries : Boundaries + Boundary conditions for the finite difference calculations. + n_added_mode : int, optional + Additional modes computed to ensure mode matching accuracy (default is 4). + index_guess : float, optional + Starting guess for the refractive index used in calculations (default is 0., auto-evaluated if set to 0). + auto_label : bool, optional + If True, enables automatic labeling of modes based on symmetry (default is True). + + Returns + ------- + None + This method updates the solver's internal state but does not return any value. """ alpha = self.index_to_eigen_value(index_guess) @@ -220,5 +264,3 @@ def add_modes(self, n_sorted_mode: int, boundaries: Boundaries, n_added_mode: in self.mode_number += 1 self.solver_number += 1 - -# --- diff --git a/SuPyMode/supermode.py b/SuPyMode/supermode.py index d99f9bed..36931a7b 100644 --- a/SuPyMode/supermode.py +++ b/SuPyMode/supermode.py @@ -1,6 +1,3 @@ -# #!/usr/bin/env python -# # -*- coding: utf-8 -*- - # Built-in imports import numpy from dataclasses import dataclass, field as field_arg @@ -20,13 +17,39 @@ class SuperMode(): the SuPySolver. Instances of this class belong to a SuperSet, and each supermode is uniquely identified within its symmetry set by a mode number. - Attributes: - parent_set (None): The SuperSet instance associated with this supermode. - binding (None): The corresponding C++ bound supermode object. - solver_number (int): Identifier linking this supermode to a specific Python solver. - mode_number (int): Unique identifier for this mode within a symmetry set. - boundaries (dict): Specifications of the boundary conditions for the supermode. - label (str, optional): An arbitrary descriptive label for the supermode. + Parameters + ---------- + parent_set : object + The SuperSet instance associated with this supermode. + binding : object + The corresponding C++ bound supermode object. + solver_number : int + Identifier linking this supermode to a specific Python solver. + mode_number : int + Unique identifier for this mode within a symmetry set. + boundaries : dict + Specifications of the boundary conditions for the supermode. + label : str, optional + An arbitrary descriptive label for the supermode. + + Attributes + ---------- + ID : list + Unique identifier for the solver and binding number. + field : Field + Field representation associated with the supermode. + index : Index + Index representation associated with the supermode. + beta : Beta + Beta representation associated with the supermode. + normalized_coupling : NormalizedCoupling + Normalized coupling representation of the supermode. + adiabatic : Adiabatic + Adiabatic representation of the supermode. + eigen_value : EigenValue + Eigenvalue representation of the supermode. + beating_length : BeatingLength + Beating length representation of the supermode. """ parent_set: object binding: object @@ -56,17 +79,28 @@ def __post_init__(self): def __hash__(self): """ - Returns a hash based on the binded supermode object, allowing this class - instance to be used in hash-based collections like sets and dictionaries. + Returns a hash value based on the bound supermode object. + + This allows instances of this class to be used in hash-based collections + such as sets and dictionaries. - Returns: - int: The hash value of the binded supermode object. + Returns + ------- + int + The hash value of the bound supermode object. """ return hash(self.binding) @property def binding_number(self) -> int: - """Retrieves the binding number specific to the linked C++ solver.""" + """ + Retrieves the binding number specific to the linked C++ solver. + + Returns + ------- + int + The binding number from the associated C++ solver. + """ return self.binding.binding_number @property @@ -74,8 +108,10 @@ def geometry(self) -> object: """ Provides access to the geometric configuration associated with the supermode. - Returns: - object: The geometry of the parent SuperSet. + Returns + ------- + object + The geometry of the parent SuperSet. """ return self.parent_set.geometry @@ -84,14 +120,23 @@ def coordinate_system(self) -> object: """ Accesses the coordinate system used by the supermode. - Returns: - object: The coordinate system of the parent SuperSet. + Returns + ------- + object + The coordinate system of the parent SuperSet. """ return self.parent_set.coordinate_system @property def itr_list(self) -> numpy.ndarray: - """Provides a list of iteration parameters from the model.""" + """ + Provides a list of iteration parameters from the model. + + Returns + ------- + numpy.ndarray + Array of iteration parameters from the model. + """ return self.binding.model_parameters.itr_list @property @@ -99,14 +144,23 @@ def model_parameters(self) -> ModelParameters: """ Retrieves parameters defining the model's computational aspects. - Returns: - ModelParameters: Computational parameters from the binded supermode. + Returns + ------- + ModelParameters + Computational parameters from the bound supermode. """ return self.binding.model_parameters @property def mesh_gradient(self) -> numpy.ndarray: - """Accesses the gradient mesh associated with the supermode.""" + """ + Accesses the gradient mesh associated with the supermode. + + Returns + ------- + numpy.ndarray + The gradient mesh of the supermode. + """ return self.binding.mesh_gradient @property @@ -115,8 +169,10 @@ def amplitudes(self) -> numpy.ndarray: Computes the amplitude array for this supermode, setting its own mode number to 1 and all others to 0. - Returns: - numpy.ndarray: Array of complex numbers representing amplitudes. + Returns + ------- + numpy.ndarray + Array of complex numbers representing amplitudes. """ n_mode = len(self.parent_set.supermodes) amplitudes = numpy.zeros(n_mode, dtype=complex) @@ -129,8 +185,10 @@ def stylized_label(self) -> str: Provides a stylized label for the supermode. If no custom label is provided, it defaults to a generic label with the mode ID. - Returns: - str: The stylized or default label. + Returns + ------- + str + The stylized or default label. """ if self.label is None: return f"Mode: {self.ID}" @@ -142,40 +200,55 @@ def is_computation_compatible(self, other: 'SuperMode') -> bool: Determines if another supermode is compatible for computation, based on unique identifiers and boundary conditions. - Parameters: - other (SuperMode): The other supermode to compare. + Parameters + ---------- + other : SuperMode + The other supermode to compare. - Returns: - bool: True if the supermodes are compatible for computation, False otherwise. + Returns + ------- + bool + True if the supermodes are compatible for computation, False otherwise. """ return self.binding.is_computation_compatible(other.binding) def is_symmetry_compatible(self, other: 'SuperMode') -> bool: """ - Determines whether the specified other supermode has same symmetry. + Determines whether the specified other supermode has the same symmetry. - :param other: The other supermode - :type other: SuperMode + Parameters + ---------- + other : SuperMode + The other supermode to compare. - :returns: True if the specified other is symmetry compatible, False otherwise. - :rtype: bool + Returns + ------- + bool + True if the specified other is symmetry compatible, False otherwise. """ return self.boundaries == other.boundaries def get_field_interpolation(self, itr: float = None, slice_number: int = None) -> RectBivariateSpline: """ - Computes the field interpolation for a given iteration or slice number. Requires - exactly one of the parameters to be specified. - - Parameters: - itr (float, optional): The iteration number for which to compute the interpolation. - slice_number (int, optional): The slice number for which to compute the interpolation. - - Returns: - RectBivariateSpline: Interpolated field values over a grid. - - Raises: - ValueError: If neither or both parameters are specified. + Computes the field interpolation for a given iteration or slice number. + Requires exactly one of the parameters to be specified. + + Parameters + ---------- + itr : float, optional + The iteration number for which to compute the interpolation. + slice_number : int, optional + The slice number for which to compute the interpolation. + + Returns + ------- + RectBivariateSpline + Interpolated field values over a grid. + + Raises + ------ + ValueError + If neither or both parameters are specified. """ if (itr is None) == (slice_number is None): raise ValueError("Exactly one of itr or slice_number must be provided.") @@ -189,7 +262,7 @@ def get_field_interpolation(self, itr: float = None, slice_number: int = None) - slice_list=slice_number ) - field = self.field.get_field(slice_number=slice_number, add_symmetries=True) + field = self.field.get_field(slice_number=slice_number, add_symmetries=True).T x_axis, y_axis = self.get_axis(slice_number=slice_number, add_symmetries=True) @@ -202,6 +275,19 @@ def get_field_interpolation(self, itr: float = None, slice_number: int = None) - return field_interpolation def _get_axis_vector(self, add_symmetries: bool = True) -> tuple: + """ + Computes the full axis vectors, optionally including symmetries. + + Parameters + ---------- + add_symmetries : bool, optional, default=True + Whether to include symmetries when computing the axis vectors. + + Returns + ------- + tuple + A tuple containing the full x-axis and y-axis vectors. + """ full_x_axis = self.coordinate_system.x_vector full_y_axis = self.coordinate_system.y_vector @@ -227,29 +313,57 @@ def _get_axis_vector(self, add_symmetries: bool = True) -> tuple: return full_x_axis, full_y_axis def get_axis(self, slice_number: int, add_symmetries: bool = True) -> tuple: + """ + Computes the scaled axis vectors for a specific slice, optionally including symmetries. + + Parameters + ---------- + slice_number : int + The slice index for which to compute the axis vectors. + add_symmetries : bool, optional, default=True + Whether to include symmetries in the computed axis vectors. + + Returns + ------- + tuple + A tuple containing the scaled x-axis and y-axis vectors for the given slice. + """ itr = self.model_parameters.itr_list[slice_number] - x_axis, y_axis = self._get_axis_vector(add_symmetries=add_symmetries) - return (x_axis * itr, y_axis * itr) def __repr__(self) -> str: + """ + Provides a string representation of the supermode. + + Returns + ------- + str + The label of the supermode. + """ return self.label def plot(self, plot_type: str, **kwargs): """ - Plots various properties of the supermode based on specified type. - - Parameters: - plot_type (str): The type of plot to generate (e.g., 'field', 'beta'). - *args: Additional positional arguments for the plot function. - **kwargs: Additional keyword arguments for the plot function. - - Returns: + Plots various properties of the supermode based on the specified type. + + Parameters + ---------- + plot_type : str + The type of plot to generate. Options include 'field', 'beta', 'index', 'eigen-value', + 'beating-length', 'adiabatic', and 'normalized-coupling'. + **kwargs : dict + Additional keyword arguments to pass to the plotting function. + + Returns + ------- + object The result of the plotting function, typically a plot object. - Raises: - ValueError: If an invalid plot type is specified. + Raises + ------ + ValueError + If an invalid plot type is specified. """ match plot_type.lower(): case 'field': @@ -269,5 +383,4 @@ def plot(self, plot_type: str, **kwargs): case _: raise ValueError(f'Invalid plot type: {plot_type}. Options are: index, beta, eigen-value, field, beating-length, adiabatic, normalized-coupling') - # - diff --git a/SuPyMode/superset.py b/SuPyMode/superset.py index 3a3fc7b1..f908537b 100644 --- a/SuPyMode/superset.py +++ b/SuPyMode/superset.py @@ -11,7 +11,6 @@ from typing import Optional, List from FiberFusing.geometry import Geometry - # Third-party imports from scipy.interpolate import interp1d from scipy.integrate import solve_ivp @@ -22,18 +21,26 @@ from SuPyMode.profiles import AlphaProfile from SuPyMode.propagation import Propagation from SuPyMode.superset_plots import SuperSetPlots -from SuPyMode.utils import test_valid_input, parse_mode_of_interest, parse_combination, parse_filename +from SuPyMode.utils import test_valid_input, parse_filename +from SuPyMode.helper import parse_mode_of_interest, parse_combination + @dataclass class SuperSet(SuperSetPlots): """ - A class representing a set of supermodes calculated for a specific optical fiber configuration. - It facilitates operations on supermodes like sorting, plotting, and computations related to fiber optics simulations. - - Attributes: - model_parameters (ModelParameters): - wavelength (float): The wavelength used in the solver, in meters. - geometry (object): + A set of supermodes calculated for a specific optical fiber configuration. + + This class manages operations on supermodes, including sorting, plotting, and calculations + related to fiber optics simulations. + + Parameters + ---------- + model_parameters : ModelParameters + Parameters defining the model for simulation. + wavelength : float + The wavelength used in the solver, in meters. + geometry : Geometry + The geometry of the optical structure. """ model_parameters: ModelParameters wavelength: float @@ -53,40 +60,48 @@ def __setitem__(self, idx: int, value: SuperMode) -> None: @property def coordinate_system(self): """ - Return axes object of the geometry + Returns the coordinate system associated with the geometry. + + Returns + ------- + CoordinateSystem + The coordinate system of the geometry. """ return self.geometry.coordinate_system @property def fundamental_supermodes(self) -> list[SuperMode]: """ - Identifies and returns fundamental supermodes based on the highest beta values and minimal spatial overlap. + Returns the fundamental supermodes based on the highest beta values and minimal spatial overlap. - Args: - tolerance (float): The spatial overlap tolerance for mode distinction. - - Returns: - list[SuperMode]: A list of fundamental supermodes. + Returns + ------- + list of SuperMode + A list of fundamental supermodes. """ return self.get_fundamental_supermodes(tolerance=1e-2) @property def non_fundamental_supermodes(self) -> list[SuperMode]: """ - Identifies and returns non-fundamental supermodes based on the specified spatial overlap tolerance. - - Args: - tolerance (float): The spatial overlap tolerance for distinguishing between fundamental and other modes. + Returns the non-fundamental supermodes based on the specified spatial overlap tolerance. - Returns: - list[SuperMode]: A list of non-fundamental supermodes. + Returns + ------- + list of SuperMode + A list of non-fundamental supermodes. """ return self.get_non_fundamental_supermodes(tolerance=1e-2) @property def transmission_matrix(self) -> numpy.ndarray: """ - Return the supermode transmission matrix. + Returns the supermode transmission matrix. + + Returns + ------- + numpy.ndarray + The transmission matrix for the supermodes. """ if self._transmission_matrix is None: self.compute_transmission_matrix() @@ -97,11 +112,15 @@ def itr_to_slice(self, itr_list: list[float]) -> list[int]: """ Convert ITR values to corresponding slice numbers. - Args: - itr_list (list[float]): Inverse taper ratio values. + Parameters + ---------- + itr_list : list of float + Inverse taper ratio values. - Returns: - list[int]: List of slice numbers corresponding to the ITR values. + Returns + ------- + list of int + List of slice numbers corresponding to the ITR values. """ itr_list = numpy.asarray(itr_list) @@ -111,11 +130,15 @@ def get_fundamental_supermodes(self, *, tolerance: float = 0.1) -> list[SuperMod """ Returns a list of fundamental supermodes with the highest propagation constant values and minimal spatial overlap. - Args: - tolerance (float): Tolerance for spatial overlap. + Parameters + ---------- + tolerance : float, optional + Tolerance for spatial overlap (default is 0.1). - Returns: - list[SuperMode]: List of fundamental supermodes. + Returns + ------- + list of SuperMode + List of fundamental supermodes. """ self.sort_modes_by_beta() @@ -152,11 +175,15 @@ def get_non_fundamental_supermodes(self, *, tolerance: float = 0.1) -> list[Supe """ Returns a list of non-fundamental supermodes that do not overlap with the fundamental modes. - Args: - tolerance (float): Tolerance for spatial overlap. + Parameters + ---------- + tolerance : float, optional + Tolerance for spatial overlap (default is 0.1). - Returns: - list[SuperMode]: List of non-fundamental supermodes. + Returns + ------- + list of SuperMode + List of non-fundamental supermodes. """ non_fundamental_supermodes = self.supermodes @@ -169,8 +196,10 @@ def get_mode_solver_classification(self) -> list[list[SuperMode]]: """ Returns a list of modes classified by solver number. - Returns: - list[list[SuperMode]]: List of lists containing modes classified by solver number. + Returns + ------- + list of list of SuperMode + List of lists containing modes classified by solver number. """ solver_numbers = [mode.solver_number for mode in self] @@ -187,10 +216,12 @@ def get_mode_solver_classification(self) -> list[list[SuperMode]]: def label_supermodes(self, *label_list) -> None: """ - Assigns labels to the supermodes. + Assign labels to the supermodes. - Args: - label_list (tuple): Labels to assign to the supermodes. + Parameters + ---------- + label_list : tuple + Labels to assign to the supermodes. """ for n, label in enumerate(label_list): self[n].label = label @@ -199,14 +230,14 @@ def label_supermodes(self, *label_list) -> None: def reset_labels(self) -> None: """ - Resets labels for all supermodes to default values. + Reset labels for all supermodes to default values. """ for n, super_mode in enumerate(self.supermodes): super_mode.label = f'mode_{n}' def compute_transmission_matrix(self) -> None: """ - Calculates the transmission matrix with only the propagation constant included. + Calculate the transmission matrix including only the propagation constant. """ shape = [ len(self.supermodes), @@ -221,14 +252,19 @@ def compute_transmission_matrix(self) -> None: def add_coupling_to_t_matrix(self, *, t_matrix: numpy.ndarray, adiabatic_factor: numpy.ndarray) -> numpy.ndarray: """ - Add the coupling coefficients to the transmission matrix. + Add coupling coefficients to the transmission matrix. - Args: - t_matrix (np.ndarray): Transmission matrix to which coupling values are added. - adiabatic_factor (np.ndarray): Adiabatic factor, set to one if None (normalized coupling). + Parameters + ---------- + t_matrix : numpy.ndarray + Transmission matrix to which coupling values are added. + adiabatic_factor : numpy.ndarray + Adiabatic factor, set to one if None (normalized coupling). - Returns: - np.ndarray: Updated transmission matrix with coupling values. + Returns + ------- + numpy.ndarray + Updated transmission matrix with coupling values. """ size = t_matrix.shape[-1] @@ -252,18 +288,18 @@ def add_coupling_to_t_matrix(self, *, t_matrix: numpy.ndarray, adiabatic_factor: def compute_coupling_factor(self, *, coupler_length: float) -> numpy.ndarray: """ - Compute the coupling factor defined as: - - .. math:: - f_c = \frac{1}{\rho} \frac{d \rho}{d z} + Compute the coupling factor defined by the derivative of the inverse taper ratio. - Args: - coupler_length (float): Length of the coupler. + Parameters + ---------- + coupler_length : float + Length of the coupler. - Returns: - np.ndarray: Coupling factor as a function of distance in the coupler. + Returns + ------- + numpy.ndarray + Coupling factor as a function of distance in the coupler. """ - dx = coupler_length / (self.model_parameters.n_slice) ditr = numpy.gradient(numpy.log(self.model_parameters.itr_list), axis=0) @@ -272,14 +308,19 @@ def compute_coupling_factor(self, *, coupler_length: float) -> numpy.ndarray: def get_transmision_matrix_from_profile(self, *, profile: AlphaProfile, add_coupling: bool = True) -> tuple: """ - Get the transmission matrix from the profile. + Get the transmission matrix from the given profile. - Args: - profile (AlphaProfile): Z-profile of the coupler. - add_coupling (bool): Add coupling to the transmission matrix. Defaults to True. + Parameters + ---------- + profile : AlphaProfile + Z-profile of the coupler. + add_coupling : bool, optional + Whether to add coupling to the transmission matrix (default is True). - Returns: - tuple: Distance, ITR vector, and transmission matrix. + Returns + ------- + tuple + Tuple containing distance, ITR vector, and transmission matrix. """ profile.initialize() @@ -309,21 +350,29 @@ def propagate( method: str = 'RK45', **kwargs: dict) -> Propagation: """ - Propagates the amplitudes of the supermodes in a coupler based on a given profile. - - Args: - profile (AlphaProfile): The z-profile of the coupler. - initial_amplitude (list): The initial amplitude as a list. - max_step (float, optional): The maximum step size used by the solver. Defaults to None. - n_step (int, optional): Number of steps used by the solver (not currently used in this method). - add_coupling (bool): Flag to add coupling to the transmission matrix. Defaults to True. - method (str): Integration method to be used by the solver. Defaults to 'RK45'. - **kwargs (Dict[str, Any]): Additional keyword arguments to be passed to the solver. - - Returns: - Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: A tuple containing the times of the solution, - the solution array of amplitudes, and the interpolated - index of refraction at those times. + Propagate the amplitudes of the supermodes in a coupler based on the given profile. + + Parameters + ---------- + profile : AlphaProfile + The z-profile of the coupler. + initial_amplitude : list + The initial amplitude as a list of complex numbers. + max_step : float, optional + The maximum step size used by the solver (default is None). + n_step : int, optional + Number of steps used by the solver (default is None). + add_coupling : bool, optional + Whether to add coupling to the transmission matrix (default is True). + method : str, optional + Integration method to be used by the solver (default is 'RK45'). + **kwargs : dict + Additional keyword arguments to be passed to the solver. + + Returns + ------- + Propagation + A Propagation object containing the results of the propagation. """ initial_amplitude = numpy.asarray(initial_amplitude, dtype=complex) @@ -367,16 +416,22 @@ def model(z, y): def interpret_initial_input(self, initial_amplitude: list | SuperMode) -> numpy.ndarray: """ - Interprets the initial amplitude input, ensuring compatibility with the expected number of supermodes. + Interpret the initial amplitude input to ensure compatibility with the expected number of supermodes. - Args: - initial_amplitude (list | SuperMode): The initial amplitude as either a list of complex numbers or a SuperMode object. + Parameters + ---------- + initial_amplitude : list or SuperMode + The initial amplitude as either a list of complex numbers or a SuperMode object. - Returns: - np.ndarray: The initial amplitudes as a NumPy array of complex numbers. + Returns + ------- + numpy.ndarray + The initial amplitudes as a NumPy array of complex numbers. - Raises: - ValueError: If the length of the initial amplitude list does not match the number of supermodes. + Raises + ------ + ValueError + If the length of the initial amplitude list does not match the number of supermodes. """ if isinstance(initial_amplitude, SuperMode): amplitudes = initial_amplitude.amplitudes @@ -393,13 +448,17 @@ def interpret_initial_input(self, initial_amplitude: list | SuperMode) -> numpy. def _sort_modes(self, ordering_keys) -> List[SuperMode]: """ - Sorts supermodes using specified keys provided as tuples in ordering_keys. + Sort supermodes using specified keys provided as tuples in `ordering_keys`. - Args: - ordering_keys (tuple): Tuple containing keys to sort by. + Parameters + ---------- + ordering_keys : tuple + Tuple containing keys to sort by. - Returns: - List[SuperMode]: Sorted list of supermodes. + Returns + ------- + list of SuperMode + Sorted list of supermodes. """ order = numpy.lexsort(ordering_keys) sorted_supermodes = [self.supermodes[idx] for idx in order] @@ -409,7 +468,7 @@ def _sort_modes(self, ordering_keys) -> List[SuperMode]: def sort_modes_by_beta(self) -> None: """ - Sorts supermodes in descending order of their propagation constants (beta). + Sort supermodes in descending order of their propagation constants (beta). """ lexort_index = ([-mode.beta.data[-1] for mode in self.supermodes], ) @@ -417,14 +476,19 @@ def sort_modes_by_beta(self) -> None: def sort_modes(self, sorting_method: str = "beta", keep_only: Optional[int] = None) -> None: """ - Sorts supermodes according to the specified method, optionally limiting the number of modes retained. + Sort supermodes according to the specified method, optionally limiting the number of modes retained. - Args: - sorting_method (str): Sorting method to use, either "beta" or "symmetry+beta". - keep_only (int, optional): Number of supermodes to retain after sorting. + Parameters + ---------- + sorting_method : str, optional + Sorting method to use, either "beta" or "symmetry+beta" (default is "beta"). + keep_only : int, optional + Number of supermodes to retain after sorting (default is None). - Raises: - ValueError: If an unrecognized sorting method is provided. + Raises + ------ + ValueError + If an unrecognized sorting method is provided. """ match sorting_method.lower(): case 'beta': @@ -438,7 +502,7 @@ def sort_modes(self, sorting_method: str = "beta", keep_only: Optional[int] = No def sort_modes_by_solver_and_beta(self) -> None: """ - Sorts supermodes primarily by solver number and secondarily by descending propagation constant (beta). + Sort supermodes primarily by solver number and secondarily by descending propagation constant (beta). """ lexort_index = ( [mode.solver_number for mode in self.supermodes], @@ -449,26 +513,34 @@ def sort_modes_by_solver_and_beta(self) -> None: def is_compute_compatible(self, pair_of_mode: tuple) -> bool: """ - Determines whether the specified pair of mode is compatible for computation. + Determine whether the specified pair of modes is compatible for computation. - Args: - pair_of_mode (tuple): The pair of modes. + Parameters + ---------- + pair_of_mode : tuple + The pair of modes to be checked. - Returns: - bool: True if the pair of modes is compute compatible, False otherwise. + Returns + ------- + bool + True if the pair of modes is compatible for computation, False otherwise. """ mode_0, mode_1 = pair_of_mode return mode_0.is_computation_compatible(mode_1) def remove_duplicate_combination(self, supermodes_list: list) -> list[SuperMode]: """ - Removes duplicate combinations in the mode combination list irrespective of the order. + Remove duplicate combinations from the mode combination list irrespective of the order. - Args: - supermodes_list (list): List of mode combinations. + Parameters + ---------- + supermodes_list : list of tuple + List of mode combinations. - Returns: - list: Reduced list of unique supermode combinations. + Returns + ------- + list of tuple + Reduced list of unique supermode combinations. """ output_list = [] @@ -482,12 +554,17 @@ def interpret_combination(self, mode_of_interest: list, combination: str) -> set """ Interpret user input for mode selection and return the combination of modes to consider. - Args: - mode_of_interest (list): List of modes of interest. - mode_selection (str): Mode selection method. + Parameters + ---------- + mode_of_interest : list + List of modes of interest. + combination : str + Mode selection method ('pairs' or 'specific'). - Returns: - set: Set of mode combinations. + Returns + ------- + set of tuple + Set of mode combinations. """ test_valid_input( variable_name='combination', @@ -510,14 +587,17 @@ def interpret_combination(self, mode_of_interest: list, combination: str) -> set @parse_filename def save_instance(self, filename: str) -> Path: """ - Saves the SuperSet instance as a serialized pickle file. + Save the SuperSet instance as a serialized pickle file. - Args: - filename (str): Filename for the serialized instance. - directory (str): Directory to save the file, 'auto' means the instance_directory. + Parameters + ---------- + filename : str + Filename for the serialized instance. - Returns: - Path: The path to the saved instance file. + Returns + ------- + Path + The path to the saved instance file. """ with open(filename.with_suffix('.pickle'), 'wb') as output_file: pickle.dump(self, output_file, pickle.HIGHEST_PROTOCOL) @@ -540,19 +620,31 @@ def export_data(self, """ Export the SuperSet data as CSV files, saving specific attributes of the modes or combinations of modes. - Args: - filename (str): The directory where the files will be saved. - mode_of_interest (list): List of modes to be exported. Defaults to 'all'. - combination (list): List of mode combinations to be exported. Defaults to None. - export_index (bool): Whether to export the 'index' attribute. Defaults to True. - export_beta (bool): Whether to export the 'beta' attribute. Defaults to True. - export_eigen_value (bool): Whether to export the 'eigen_value' attribute. Defaults to False. - export_adiabatic (bool): Whether to export the 'adiabatic' attribute for combinations. Defaults to True. - export_beating_length (bool): Whether to export the 'beating_length' attribute for combinations. Defaults to True. - export_normalized_coupling (bool): Whether to export the 'normalized_coupling' attribute for combinations. Defaults to True. - - Returns: - Path: The path to the directory where the files were saved. + Parameters + ---------- + filename : str + The directory where the files will be saved. + mode_of_interest : list, optional + List of modes to be exported (default is 'all'). + combination : list, optional + List of mode combinations to be exported (default is None). + export_index : bool, optional + Whether to export the 'index' attribute (default is True). + export_beta : bool, optional + Whether to export the 'beta' attribute (default is True). + export_eigen_value : bool, optional + Whether to export the 'eigen_value' attribute (default is False). + export_adiabatic : bool, optional + Whether to export the 'adiabatic' attribute for combinations (default is True). + export_beating_length : bool, optional + Whether to export the 'beating_length' attribute for combinations (default is True). + export_normalized_coupling : bool, optional + Whether to export the 'normalized_coupling' attribute for combinations (default is True). + + Returns + ------- + Path + The path to the directory where the files were saved. """ from pathlib import Path import numpy as np @@ -592,7 +684,3 @@ def _export_combination_data(attribute_name: str): _export_combination_data('beating_length') if export_normalized_coupling: _export_combination_data('normalized_coupling') - - - -# - diff --git a/SuPyMode/superset_plots.py b/SuPyMode/superset_plots.py index 5c0311e4..6d577584 100644 --- a/SuPyMode/superset_plots.py +++ b/SuPyMode/superset_plots.py @@ -3,275 +3,254 @@ # Built-in imports import numpy -from typing import NoReturn -from functools import wraps -# Local imports from SuPyMode.supermode import SuperMode from SuPyMode.utils import get_intersection, interpret_mode_of_interest, interpret_slice_number_and_itr, parse_filename from SuPyMode.profiles import AlphaProfile import matplotlib.pyplot as plt from matplotlib.backends.backend_pdf import PdfPages -from SuPyMode.utils import parse_mode_of_interest, parse_combination from MPSPlots.styles import gg_plot as plot_style +from SuPyMode.helper import singular_plot_helper, combination_plot_helper, parse_mode_of_interest -class SuperSetPlots(object): - # EFFECTIVE INDEX ------------------------------------------------------- - @parse_mode_of_interest - def _logic_index(self, - ax: plt.Axes, - mode_of_interest: list[SuperMode], - show_crossings: bool = False) -> NoReturn: - """ - Plot effective index for each mode as a function of itr. - Args: - show_crossings (bool): Whether to show crossings in the plot. - mode_of_interest (str | list[SuperMode]): The mode of interest. +class SuperSetPlots(object): - Returns: - NoReturn + def _plot_attribute(self, ax: plt.Axes, mode_of_interest: list[SuperMode], attribute: str, show_crossings: bool = False) -> None: + """ + Generalized function to plot a given attribute for each mode. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes object on which to plot. + mode_of_interest : list of SuperMode + List of modes to be plotted. + attribute : str + The attribute to plot ('index', 'beta', 'eigen_value'). + show_crossings : bool, optional + Whether to show crossings in the plot (default is False). + + Returns + ------- + None """ for mode in mode_of_interest: - mode.index.render_on_ax(ax=ax) - mode.index._dress_ax(ax=ax) + getattr(mode, attribute).plot(ax=ax, show=False) if show_crossings: - self.add_crossings_to_ax(ax=ax, mode_of_interest=mode_of_interest, data_type='index') - - plt.legend() - - @wraps(_logic_index) - def get_figure_index(self, *args, **kwargs): - figure, ax = plt.subplots(1, 1) - self._logic_index(ax=ax, *args, **kwargs) - return figure + self.add_crossings_to_ax(ax=ax, mode_of_interest=mode_of_interest, data_type=attribute) - @wraps(_logic_index) - def plot_index(self, *args, **kwargs): - with plt.style.context(plot_style): - figure, ax = plt.subplots(1, 1) - self._logic_index(ax=ax, *args, **kwargs) - plt.show() + ax.legend() - # PROPAGATION CONSTANT ------------------------------------------------------- - @parse_mode_of_interest - def _logic_beta(self, - ax: plt.Axes, - mode_of_interest: list[SuperMode], - show_crossings: bool = False) -> NoReturn: + @singular_plot_helper + def plot_index(self, ax: plt.Axes, mode_of_interest: list[SuperMode], show_crossings: bool = False) -> plt.Figure: """ - Plot propagation constant for each mode as a function of itr. - - Args: - show_crossings (bool): Whether to show crossings in the plot. - mode_of_interest (str | list[SuperMode]): The mode of interest. - - Returns: - NoReturn + Plot the effective index for each mode as a function of inverse taper ratio (ITR). + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes object on which to plot. + mode_of_interest : list of SuperMode + List of modes to be plotted. + show_crossings : bool, optional + Whether to show crossings in the plot (default is False). + + Examples + -------- + >>> fig, ax = plt.subplots() + >>> superset_plots.plot_index(ax=ax, mode_of_interest=[mode1, mode2], show_crossings=True) + >>> plt.show() + + Returns + ------- + plt.Figure + The figure object containing the generated plots. This can be customized further or saved to a file. """ - for mode in mode_of_interest: - mode.beta.render_on_ax(ax=ax) - mode.beta._dress_ax(ax=ax) - - if show_crossings: - self.add_crossings_to_ax(ax=ax, mode_of_interest=mode_of_interest, data_type='beta') - - plt.legend() + self._plot_attribute(ax, mode_of_interest, 'index', show_crossings) - @wraps(_logic_beta) - def get_figure_beta(self, *args, **kwargs): - figure, ax = plt.subplots(1, 1) - self._logic_beta(ax=ax, *args, **kwargs) - return figure - - @wraps(_logic_beta) - def plot_beta(self, *args, **kwargs): - with plt.style.context(plot_style): - figure, ax = plt.subplots(1, 1) - self._logic_beta(ax=ax, *args, **kwargs) - plt.show() - - # EIGEN-VALUE ------------------------------------------------------- - @parse_mode_of_interest - def _logic_eigen_value(self, - ax: plt.Axes, - mode_of_interest: list[SuperMode], - show_crossings: bool = False) -> NoReturn: + @singular_plot_helper + def plot_beta(self, ax: plt.Axes, mode_of_interest: list[SuperMode], show_crossings: bool = False) -> plt.Figure: """ - Plot propagation constant for each mode as a function of itr. - - Args: - mode_of_interest (str | list[SuperMode]): The mode of interest. - show_crossings (bool): Whether to show crossings in the plot. - - Returns: - None + Plot the effective propagation constant for each mode as a function of inverse taper ratio (ITR). + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes object on which to plot. + mode_of_interest : list of SuperMode + List of modes to be plotted. + show_crossings : bool, optional + Whether to show crossings in the plot (default is False). + + Examples + -------- + >>> fig, ax = plt.subplots() + >>> superset_plots.plot_index(ax=ax, mode_of_interest=[mode1, mode2], show_crossings=True) + >>> plt.show() + + Returns + ------- + plt.Figure + The figure object containing the generated plots. This can be customized further or saved to a file. """ - for mode in mode_of_interest: - mode.eigen_value.render_on_ax(ax=ax) - mode.eigen_value._dress_ax(ax=ax) + self._plot_attribute(ax, mode_of_interest, 'beta', show_crossings) - if show_crossings: - self.add_crossings_to_ax(ax=ax, mode_of_interest=mode_of_interest, data_type='eigen_value') - - plt.legend() - - @wraps(_logic_eigen_value) - def get_figure_eigen_value(self, *args, **kwargs): - figure, ax = plt.subplots(1, 1) - self._logic_eigen_value(ax=ax, *args, **kwargs) - return figure - - @wraps(_logic_eigen_value) - def plot_eigen_value(self, *args, **kwargs): - with plt.style.context(plot_style): - figure, ax = plt.subplots(1, 1) - self._logic_eigen_value(ax=ax, *args, **kwargs) - plt.show() + @singular_plot_helper + def plot_eigen_value(self, ax: plt.Axes, mode_of_interest: list[SuperMode], show_crossings: bool = False) -> plt.Figure: + """ + Plot the computed eigen-values for each mode as a function of inverse taper ratio (ITR). + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes object on which to plot. + mode_of_interest : list of SuperMode + List of modes to be plotted. + show_crossings : bool, optional + Whether to show crossings in the plot (default is False). + + Examples + -------- + >>> fig, ax = plt.subplots() + >>> superset_plots.plot_index(ax=ax, mode_of_interest=[mode1, mode2], show_crossings=True) + >>> plt.show() + + Returns + ------- + plt.Figure + The figure object containing the generated plots. This can be customized further or saved to a file. + """ + self._plot_attribute(ax, mode_of_interest, 'eigen_value', show_crossings) - # BEATING LENGTH------------------------------------------------------- - @parse_mode_of_interest - @parse_combination - def _logic_beating_length(self, + @combination_plot_helper + def plot_beating_length( + self, ax: plt.Axes, mode_of_interest: list[SuperMode], combination: list, - add_profile: list[AlphaProfile] = []) -> NoReturn: + add_profile: list[AlphaProfile] = []) -> plt.Figure: """ - Render a figure representing beating_length for each mode as a function of itr. - - Args: - mode_of_interest (list[SuperMode]): The mode of interest. - combination (list): The mode combinations. - add_profile (list[AlphaProfile]): List of profiles to add to the plot. - - Returns: - None + Render a figure representing the beating length for each mode combination as a function of inverse taper ratio (ITR). + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes object on which to plot. + mode_of_interest : list of SuperMode + List of modes to be plotted. + combination : list + List of mode combinations. + add_profile : list of AlphaProfile, optional + List of profiles to add to the plot (default is an empty list). + + Returns + ------- + plt.Figure + The figure object containing the generated plots. This can be customized further or saved to a file. """ for mode_0, mode_1 in combination: - mode_0.beating_length.render_on_ax(ax=ax, other_supermode=mode_1) - mode_0.beating_length._dress_ax(ax=ax) - - plt.legend() - - @wraps(_logic_beating_length) - def get_figure_normalized_coupling(self, *args, **kwargs): - figure, ax = plt.subplots(1, 1) - self._logic_beating_length(ax=ax, *args, **kwargs) - return figure - - @wraps(_logic_beating_length) - def plot_beating_length(self, *args, **kwargs): - with plt.style.context(plot_style): - figure, ax = plt.subplots(1, 1) - self._logic_beating_length(ax=ax, *args, **kwargs) - plt.show() + mode_0.beating_length.plot(ax=ax, other_supermode=mode_1) + ax.legend() - # NORMALIZED COUPLING------------------------------------------------------- - @parse_mode_of_interest - @parse_combination - def _logic_normalized_coupling(self, + @combination_plot_helper + def plot_normalized_coupling( + self, ax: plt.Axes, mode_of_interest: list[SuperMode], combination: list, - add_profile: list[AlphaProfile] = []) -> NoReturn: + add_profile: list[AlphaProfile] = []) -> plt.Figure: """ - Render a figure representing normalized coupling for each mode as a function of itr. - - Args: - mode_of_interest (list[SuperMode]): The mode of interest. - combination (list): The mode combinations. - add_profile (list[AlphaProfile]): List of profiles to add to the plot. - - Returns: - None + Render a figure representing the normalized coupling for each mode combination as a function of inverse taper ratio (ITR). + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes object on which to plot. + mode_of_interest : list of SuperMode + List of modes to be plotted. + combination : list + List of mode combinations. + add_profile : list of AlphaProfile, optional + List of profiles to add to the plot (default is an empty list). + + Returns + ------- + plt.Figure + The figure object containing the generated plots. This can be customized further or saved to a file. """ for mode_0, mode_1 in combination: - mode_0.normalized_coupling.render_on_ax(ax=ax, other_supermode=mode_1) - mode_0.normalized_coupling._dress_ax(ax=ax) - - plt.legend() - - @wraps(_logic_normalized_coupling) - def get_figure_normalized_coupling(self, *args, **kwargs): - figure, ax = plt.subplots(1, 1) - self._logic_normalized_coupling(ax=ax, *args, **kwargs) - return figure - - @wraps(_logic_normalized_coupling) - def plot_normalized_coupling(self, *args, **kwargs): - with plt.style.context(plot_style): - figure, ax = plt.subplots(1, 1) - self._logic_normalized_coupling(ax=ax, *args, **kwargs) - plt.show() + mode_0.normalized_coupling.plot(ax=ax, other_supermode=mode_1, show=False) + ax.legend() - # ADIABATIC CRITERION------------------------------------------------------- - @parse_mode_of_interest - @parse_combination - def _logic_adiabatic(self, + @combination_plot_helper + def plot_adiabatic( + self, ax: plt.Axes, mode_of_interest: list[SuperMode], combination: list, - add_profile: list[AlphaProfile] = []) -> NoReturn: + add_profile: list[AlphaProfile] = []) -> plt.Figure: """ - Render a figure representing adiabatic criterion for each mode as a function of itr. - - Args: - mode_of_interest (list[SuperMode]): The mode of interest. - combination (list): The mode combinations. - add_profile (list[AlphaProfile]): List of profiles to add to the plot. - - Returns: - None + Render a figure representing the adiabatic criterion for each mode combination as a function of inverse taper ratio (ITR). + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes object on which to plot. + mode_of_interest : list of SuperMode + List of modes to be plotted. + combination : list + List of mode combinations. + add_profile : list of AlphaProfile, optional + List of profiles to add to the plot (default is an empty list). + + Returns + ------- + plt.Figure + The figure object containing the generated plots. This can be customized further or saved to a file. """ for mode_0, mode_1 in combination: - mode_0.adiabatic.render_on_ax(ax=ax, other_supermode=mode_1) - mode_0.adiabatic._dress_ax(ax=ax) + mode_0.adiabatic.plot(ax=ax, other_supermode=mode_1, show=False) for profile in numpy.atleast_1d(add_profile): profile.render_adiabatic_factor_vs_itr_on_ax(ax=ax, line_style='--') - plt.legend() - - @wraps(_logic_adiabatic) - def get_figure_adiabatic(self, *args, **kwargs): - figure, ax = plt.subplots(1, 1) - self._logic_adiabatic(ax=ax, *args, **kwargs) - return figure - - @wraps(_logic_adiabatic) - def plot_adiabatic(self, *args, **kwargs): - with plt.style.context(plot_style): - figure, ax = plt.subplots(1, 1) - self._logic_adiabatic(ax=ax, *args, **kwargs) - plt.show() + ax.legend() - # FIELD ------------------------------------------------------- @parse_mode_of_interest - def _logic_field( + def plot_field( self, mode_of_interest: list = 'all', itr_list: list[float] = None, slice_list: list[int] = None, show_mode_label: bool = True, show_itr: bool = True, - show_slice: bool = True) -> plt.Figure: + show_slice: bool = True, + show: bool = True) -> plt.Figure: """ Render the mode field for different ITR values or slice numbers. - Args: - mode_of_interest (list): List of modes to be plotted. Default is 'all'. - itr_list (list): List of ITR values for plotting. Default is None. - slice_list (list): List of slice numbers for plotting. Default is None. - show_mode_label (bool): Flag to display mode labels. Default is True. - show_itr (bool): Flag to display ITR values. Default is True. - show_slice (bool): Flag to display slice numbers. Default is True. - - Returns: - plt.Figure: The figure object containing the generated plots. + Parameters + ---------- + mode_of_interest : list, optional + List of modes to be plotted (default is 'all'). + itr_list : list of float, optional + List of ITR values for plotting (default is None). + slice_list : list of int, optional + List of slice numbers for plotting (default is None). + show_mode_label : bool, optional + Flag to display mode labels (default is True). + show_itr : bool, optional + Flag to display ITR values (default is True). + show_slice : bool, optional + Flag to display slice numbers (default is True). + + Returns + ------- + plt.Figure + The figure object containing the generated plots. This can be customized further or saved to a file. """ # Interpret input lists slice_list, itr_list = interpret_slice_number_and_itr( @@ -285,9 +264,12 @@ def _logic_field( mode_of_interest=mode_of_interest ) - # Determine the grid size for subplots - grid_size = numpy.array([len(slice_list), len(mode_of_interest)]) - figure, axes = plt.subplots(*grid_size, figsize=3 * numpy.flip(grid_size), squeeze=False) + n_mode = len(mode_of_interest) + n_slice = len(slice_list) + grid_size = numpy.array([n_slice, n_mode]) + + with plt.style.context(plot_style): + figure, axes = plt.subplots(*grid_size, figsize=3 * numpy.flip(grid_size), squeeze=False, sharex='row', sharey='row') # Plot each mode field on the grid for m, mode in enumerate(mode_of_interest): @@ -301,29 +283,35 @@ def _logic_field( ) figure.tight_layout() - return figure - @wraps(_logic_field) - def get_figure_field(self, *args, **kwargs): - figure = self._logic_field(*args, **kwargs) - return figure - - @wraps(_logic_field) - def plot_field(self, *args, **kwargs): - with plt.style.context(plot_style): - figure = self._logic_field(*args, **kwargs) + if show: plt.show() - def plot(self, plot_type: str, **kwargs) -> NoReturn: + return figure + + def plot(self, plot_type: str, **kwargs) -> None: """ General plotting function to handle different types of supermode plots. - Args: - plot_type (str): The type of plot to generate. Options include 'index', 'beta', 'eigen-value', etc. - **kwargs: Additional keyword arguments for specific plot configurations. - - Raises: - ValueError: If an unrecognized plot type is specified. + Parameters + ---------- + plot_type : str + The type of plot to generate. Options include 'index', 'beta', 'eigen-value', etc. + **kwargs : dict + Additional keyword arguments for specific plot configurations. + + Raises + ------ + ValueError + If an unrecognized plot type is specified. + + Examples + -------- + >>> superset_plots.plot(plot_type='index', ax=ax) + Generates an effective index plot. + + >>> superset_plots.plot(plot_type='invalid') + ValueError: Invalid plot type: invalid. Options are: index, beta, eigen-value, adiabatic, normalized-adiabatic, normalized-coupling, field, beating-length. """ match plot_type.lower(): case 'index': @@ -358,28 +346,39 @@ def generate_pdf_report( mode_of_interest: list = 'all', combination: str = 'specific') -> None: """ - Generate a full report of the coupler properties as a .pdf file. - - Args: - filename (str): Name of the report file to be output. - directory (str): Directory to save the report. - itr_list (List[float]): List of ITR values to evaluate the mode field. - slice_list (List[int]): List of slice values to evaluate the mode field. - dpi (int): Pixel density for the images included in the report. - mode_of_interest (List): List of modes to consider in the adiabatic criterion plotting. - combination (str): Method for selecting mode combinations. - - Returns: - None + Generate a full report of the coupler properties as a PDF file. + + Parameters + ---------- + filename : str, optional, default="auto" + Name of the report file to be output. If "auto", a default name will be generated based on the current timestamp. + directory : str, optional, default='.' + Directory to save the report. + itr_list : list of float, optional, default=None + List of ITR values to evaluate the mode field. If None, all available ITR values will be used. + slice_list : list of int, optional, default=None + List of slice values to evaluate the mode field. If None, all available slices will be used. + dpi : int, optional, default=200 + Pixel density for the images included in the report. + mode_of_interest : list, optional, default='all' + List of modes to consider in the report. If 'all', all available modes will be included. + combination : str, optional, default='specific' + Method for selecting mode combinations ('specific' or 'pairs'). + + Examples + -------- + >>> superset_plots.generate_pdf_report(filename="coupler_report", itr_list=[0.1, 0.2, 0.5], mode_of_interest=['LP01', 'LP11']) + This will generate a PDF report named 'coupler_report.pdf' containing plots for the specified modes at given ITR values. """ + kwargs = dict(show=False, mode_of_interest=mode_of_interest) figure_list = [ - self.geometry.render_plot(), - self.get_figure_field(itr_list=itr_list, slice_list=slice_list, mode_of_interest=mode_of_interest), - self.get_figure_index(mode_of_interest=mode_of_interest), - self.get_figure_beta(mode_of_interest=mode_of_interest), - self.get_figure_normalized_coupling(mode_of_interest=mode_of_interest, combination=combination), - self.get_figure_adiabatic(mode_of_interest=mode_of_interest, combination=combination) + # self.geometry.plot(show=False), + self.plot_field(itr_list=itr_list, slice_list=slice_list, **kwargs), + self.plot_index(**kwargs), + self.plot_beta(**kwargs), + self.plot_normalized_coupling(**kwargs, combination=combination), + self.plot_adiabatic(show=False, combination=combination) ] pp = PdfPages(filename.with_suffix('.pdf')) @@ -392,6 +391,19 @@ def generate_pdf_report( plt.close() def add_crossings_to_ax(self, ax: plt.Axes, mode_of_interest: list, data_type: str) -> None: + """ + Add mode crossings to the given axis. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes object on which to plot. + mode_of_interest : list of SuperMode + List of modes of interest. + data_type : str + The type of data for which to find crossings (e.g., 'index', 'beta', etc.). + + """ combination = self.interpret_combination( mode_of_interest=mode_of_interest, combination='pairs' diff --git a/SuPyMode/utils.py b/SuPyMode/utils.py index aa46046a..dcf35a31 100644 --- a/SuPyMode/utils.py +++ b/SuPyMode/utils.py @@ -17,29 +17,6 @@ import logging -def parse_mode_of_interest(plot_function: Callable) -> Callable: - def wrapper(self, *args, mode_of_interest='all', **kwargs): - mode_of_interest = interpret_mode_of_interest( - superset=self, - mode_of_interest=mode_of_interest - ) - - return plot_function(self, *args, mode_of_interest=mode_of_interest, **kwargs) - - return wrapper - - -def parse_combination(plot_function: Callable) -> Callable: - def wrapper(self, *args, mode_of_interest='all', combination: str = 'pairs', **kwargs): - combination = self.interpret_combination( - mode_of_interest=mode_of_interest, - combination=combination - ) - - return plot_function(self, *args, mode_of_interest=mode_of_interest, combination=combination, **kwargs) - - return wrapper - def get_auto_generated_filename(superset) -> str: """ Generates a filename based on the simulation parameters. @@ -56,6 +33,7 @@ def get_auto_generated_filename(superset) -> str: ) return filename.replace('.', '_') + def parse_filename(save_function: Callable) -> Callable: @wraps(save_function) def wrapper(superset, filename: str = 'auto', directory: str = 'auto', **kwargs): diff --git a/SuPyMode/workflow.py b/SuPyMode/workflow.py index ad9bb142..1c27cfb3 100644 --- a/SuPyMode/workflow.py +++ b/SuPyMode/workflow.py @@ -1,18 +1,16 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +from FiberFusing.fiber import catalogue as fiber_catalogue # noqa: +from SuPyMode.profiles import AlphaProfile # noqa: F401 +from FiberFusing import configuration # noqa: F401 -from typing import List, Union, Optional, Tuple, Callable +from typing import List, Union, Optional, Tuple from pathlib import Path - from FiberFusing import Geometry, BackGround from FiberFusing.fiber.generic_fiber import GenericFiber from SuPyMode.solver import SuPySolver -from FiberFusing.fiber import catalogue as fiber_catalogue -from SuPyMode.profiles import AlphaProfile # noqa: F401 -from FiberFusing import configuration # noqa: F401 - +from PyOptik import MaterialBank from PyFinitDiff.finite_difference_2D import Boundaries -from pathvalidate import sanitize_filepath from pydantic.dataclasses import dataclass from pydantic import ConfigDict @@ -21,7 +19,6 @@ strict=True, arbitrary_types_allowed=True, kw_only=True, - frozen=True ) @@ -74,7 +71,7 @@ def prepare_simulation_geometry( def get_clad_index(clad_index: float, wavelength: float): """Retrieve the cladding index based on the input type.""" if isinstance(clad_index, str) and clad_index.lower() == 'silica': - return fiber_catalogue.get_silica_index(wavelength=wavelength) + return MaterialBank.fused_silica.compute_refractive_index(wavelength) elif isinstance(clad_index, (float, int)): return clad_index else: diff --git a/docs/examples/basic/plot_workflow_02.py b/docs/examples/basic/plot_workflow_02.py index 87132cb3..9d8c582b 100644 --- a/docs/examples/basic/plot_workflow_02.py +++ b/docs/examples/basic/plot_workflow_02.py @@ -37,13 +37,13 @@ clad_structure=clad_structure, # Cladding structure, if None provided then no cladding is set. fusion_degree=0.9, # Degree of fusion of the structure if applicable. wavelength=wavelength, # Wavelength used for the mode computation. - resolution=50, # Number of point in the x and y axis [is divided by half if symmetric or anti-symmetric boundaries]. + resolution=20, # Number of point in the x and y axis [is divided by half if symmetric or anti-symmetric boundaries]. x_bounds="left", # Mesh x-boundary structure. y_bounds="bottom", # Mesh y-boundary structure. boundaries=boundaries, # Set of symmetries to be evaluated, each symmetry add a round of simulation - n_sorted_mode=4, # Total computed and sorted mode. + n_sorted_mode=2, # Total computed and sorted mode. n_added_mode=2, # Additional computed mode that are not considered later except for field comparison [the higher the better but the slower]. - plot_geometry=True, # Plot the geometry mesh before computation. + # plot_geometry=True, # Plot the geometry mesh before computation. debug_mode=0, # Print the iteration step for the solver plus some other important steps. auto_label=True, # Auto labeling the mode. Label are not always correct and should be verified afterwards. itr_final=0.1, # Final value of inverse taper ratio to simulate @@ -53,25 +53,29 @@ superset = workflow.get_superset() +print(superset.supermodes) + +# superset[0].adiabatic.plot(superset[1]) + # %% # Field computation: :math:`E_{i,j}` # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -_ = superset.plot(plot_type='field', itr_list=[1.0, 0.1]) +# _ = superset.plot(plot_type='field', itr_list=[1.0, 0.1]) # %% # Effective index: :math:`n^{eff}_{i,j}` # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -_ = superset.plot(plot_type='index') +# _ = superset.plot(plot_type='beta') -# %% -# Modal normalized coupling: :math:`C_{i,j}` -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -_ = superset.plot(plot_type='normalized-coupling') +# # %% +# # Modal normalized coupling: :math:`C_{i,j}` +# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# _ = superset.plot(plot_type='normalized-coupling') # %% # Adiabatic criterion: :math:`\tilde{C}_{i,j}` # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -_ = superset.plot(plot_type='adiabatic') +# _ = superset.plot(plot_type='adiabatic') # - diff --git a/tests/test_api.py b/tests/test_api.py index 228163be..e33999ba 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,28 +5,17 @@ from SuPyMode.workflow import configuration, Workflow, fiber_catalogue, Boundaries -def test_superset_plot(): +@pytest.fixture(scope="module") +def precomputed_workflow(): """ - Test the creation and manipulation of a mode superset in a SuPyMode workflow. - - This function performs the following: - - Loads fibers from the fiber catalogue. - - Configures a fused structure for the fiber cladding. - - Initializes a workflow with specific parameters. - - Accesses and manipulates various solver and mode properties. - - Tests the labeling and resetting of supermodes. + Fixture to initialize the SuPyMode workflow for reuse across multiple tests. """ - - # Load two identical fibers from the catalogue fibers = [ fiber_catalogue.load_fiber('SMF28', wavelength=1550e-9), fiber_catalogue.load_fiber('SMF28', wavelength=1550e-9), ] - - # Define the cladding structure fused_structure = configuration.ring.FusedProfile_02x02 - # Initialize the workflow workflow = Workflow( fiber_list=fibers, clad_structure=fused_structure, @@ -40,40 +29,65 @@ def test_superset_plot(): n_sorted_mode=2, n_added_mode=2, ) + return workflow + + +def test_load_fibers(precomputed_workflow): + """ + Test loading fibers from the fiber catalogue. + """ + fibers = precomputed_workflow.fiber_list + assert fibers is not None, "Failed to load fibers from the catalogue." + assert len(fibers) == 2, "Incorrect number of fibers loaded." - # Access the solver - solver = workflow.solver - # Example operation: convert eigenvalue to index - _ = solver.eigen_value_to_index(3e6) +def test_initialize_workflow(precomputed_workflow): + """ + Test initializing a SuPyMode workflow with specific parameters. + """ + assert precomputed_workflow is not None, "Workflow initialization failed." - # Access the solver's coordinate system - _ = solver.coordinate_system - # Access the first mode in the superset - mode = workflow.superset[0] +def test_solver_properties(precomputed_workflow): + """ + Test accessing solver properties and performing eigenvalue conversion. + """ + solver = precomputed_workflow.solver + assert solver.eigen_value_to_index(3e6) is not None, "Eigenvalue conversion failed." + assert solver.coordinate_system is not None, "Solver coordinate system not accessible." - # Access various mode properties - _ = mode.geometry - _ = mode.coordinate_system - _ = mode.itr_list - _ = mode.model_parameters - _ = mode.binding_number - # Perform field interpolation for specific iterations and slices - _ = mode.get_field_interpolation(itr=1.0) - _ = mode.get_field_interpolation(slice_number=3) +def test_mode_properties(precomputed_workflow): + """ + Test accessing and validating properties of the first mode in the superset. + """ + mode = precomputed_workflow.superset[0] + assert mode.geometry is not None, "Mode geometry not accessible." + assert mode.coordinate_system is not None, "Mode coordinate system not accessible." + assert mode.itr_list is not None, "Mode ITR list not accessible." + assert mode.model_parameters is not None, "Mode model parameters not accessible." + assert mode.binding_number is not None, "Mode binding number not accessible." - # Label and reset labels for the supermodes - workflow.superset.label_supermodes('a', 'b') - workflow.superset.reset_labels() - workflow.superset.sort_modes('beta') +def test_field_interpolation(precomputed_workflow): + """ + Test field interpolation for a mode using ITR and slice number. + """ + mode = precomputed_workflow.superset[0] + assert mode.get_field_interpolation(itr=1.0) is not None, "Field interpolation by ITR failed." + assert mode.get_field_interpolation(slice_number=3) is not None, "Field interpolation by slice number failed." - workflow.superset.sort_modes('symmetry+beta') - workflow.superset.export_data(filename='test_data') +def test_superset_operations(precomputed_workflow): + """ + Test labeling, resetting labels, sorting, and exporting supermodes. + """ + precomputed_workflow.superset.label_supermodes('a', 'b') + precomputed_workflow.superset.reset_labels() + precomputed_workflow.superset.sort_modes('beta') + precomputed_workflow.superset.sort_modes('symmetry+beta') + precomputed_workflow.superset.export_data(filename='test_data') -if __name__ == '__main__': - pytest.main([__file__]) +if __name__ == "__main__": + pytest.main(["-W error", __file__]) diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 823333b3..8b1511fe 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -50,4 +50,5 @@ def test_null_gradient(x_vector, y_vector, rho_mesh): assert numpy.all(condition), f"Error in constant (expected value: 0.0) gradient computation. Mean gradient value: {condition.mean()}" -# - +if __name__ == "__main__": + pytest.main(["-W error", __file__]) diff --git a/tests/test_mode_label.py b/tests/test_mode_label.py index 69b25457..71a9eace 100644 --- a/tests/test_mode_label.py +++ b/tests/test_mode_label.py @@ -23,4 +23,6 @@ def test_configurations(configuration: dict): assert mode_name == mode_label.raw_label, f"Mismatch between expected mode_label for auto-labeler: {mode_name} vs {mode_label.label}" -# - + +if __name__ == "__main__": + pytest.main(["-W error", __file__]) diff --git a/tests/test_plots.py b/tests/test_plots.py index 6cf493d4..06dfcff0 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -7,9 +7,25 @@ from SuPyMode.workflow import configuration, Workflow, fiber_catalogue, Boundaries -@pytest.fixture +PARAMETER_LIST = [ + 'index', 'beta', 'eigen-value', 'normalized-coupling', 'adiabatic', 'field' +] + + +@pytest.fixture(scope="module") def setup_workflow(): - """ Fixture to set up the workflow with common settings for tests. """ + """ + Set up a common workflow instance for testing purposes. + + This fixture initializes a Workflow instance using standard fibers and a specific fused structure. + The workflow is set up with common parameters such as resolution, wavelength, boundaries, etc. + It is used across multiple tests to avoid recomputation and ensure consistent test conditions. + + Returns + ------- + Workflow + An instance of the Workflow class configured with predefined parameters for testing. + """ fibers = [fiber_catalogue.load_fiber('SMF28', wavelength=1550e-9) for _ in range(2)] fused_structure = configuration.ring.FusedProfile_02x02 @@ -27,40 +43,49 @@ def setup_workflow(): ) -parameter_list = [ - 'index', 'beta', 'eigen-value', 'normalized-coupling', 'adiabatic', 'field' -] - - -@pytest.mark.parametrize("plot_type", parameter_list) +@pytest.mark.parametrize("plot_type", PARAMETER_LIST) @patch("matplotlib.pyplot.show") def test_superset_plot(mock_show, setup_workflow, plot_type): """ - Tests plotting functionalities for various superset properties using mocked display to verify plots are invoked without display. - - Args: - mock_show (MagicMock): Mock for matplotlib.pyplot.show to prevent GUI display during testing. - setup_workflow (Workflow): A Workflow instance set up via a fixture to standardize test setup. - plot_type (str): Type of plot to generate and test. + Test plotting functionalities for various superset properties. + + Uses a mocked display to verify that plots for different properties of the superset can be + invoked without actually displaying the GUI, ensuring that the plot functions are called correctly. + + Parameters + ---------- + mock_show : MagicMock + Mock for `matplotlib.pyplot.show` to prevent GUI display during testing. + setup_workflow : Workflow + A Workflow instance set up via a fixture to standardize test setup. + plot_type : str + Type of plot to generate and test (e.g., 'index', 'beta'). """ superset = setup_workflow.superset superset.plot(plot_type=plot_type, mode_of_interest='fundamental') mock_show.assert_called_once() mock_show.reset_mock() + plt.close() -@pytest.mark.parametrize("plot_type", parameter_list) +@pytest.mark.parametrize("plot_type", PARAMETER_LIST) @patch("matplotlib.pyplot.show") def test_representation_plot(mock_show, setup_workflow, plot_type): """ - Tests individual mode plotting functionalities within the superset to ensure each plot type can be generated. - - This function verifies that plots for each mode's specific attribute can be called and displayed (mocked). - - Args: - mock_show (MagicMock): Mock for matplotlib.pyplot.show to prevent GUI display during testing. - setup_workflow (Workflow): A Workflow instance set up via a fixture. - plot_type (str): Type of mode-specific plot to generate and test. + Test plotting functionalities for individual modes within the superset. + + This function ensures that each mode-specific attribute plot can be called and displayed. + For properties like 'normalized-coupling' and 'adiabatic', which require comparing two modes, + the test uses a second mode for completeness. + + Parameters + ---------- + mock_show : MagicMock + Mock for `matplotlib.pyplot.show` to prevent GUI display during testing. + setup_workflow : Workflow + A Workflow instance set up via a fixture. + plot_type : str + Type of mode-specific plot to generate and test (e.g., 'index', 'beta', 'normalized-coupling'). """ mode = setup_workflow.superset[0] # Use the first mode from the superset @@ -77,5 +102,4 @@ def test_representation_plot(mock_show, setup_workflow, plot_type): if __name__ == "__main__": - pytest.main([__file__]) -# - + pytest.main(["-W error", __file__]) diff --git a/tests/test_profile.py b/tests/test_profile.py index 448cd932..0e53357e 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -3,17 +3,46 @@ from SuPyMode.profiles import AlphaProfile -@pytest.fixture +@pytest.fixture(scope="module") def alpha_profile(): - """Fixture to create an AlphaProfile instance.""" - profile = AlphaProfile(initial_radius=1) - return profile + """ + Fixture to create an AlphaProfile instance with initial parameters. + + This fixture is shared across multiple tests to avoid redundant reinitialization, + improving test performance. + + Returns + ------- + AlphaProfile + An initialized AlphaProfile instance. + """ + return AlphaProfile(initial_radius=1) + + +@pytest.fixture(scope="module") +def asymmetric_alpha_profile(): + """ + Fixture to create an asymmetric AlphaProfile instance. + + Returns + ------- + AlphaProfile + An initialized asymmetric AlphaProfile instance. + """ + return AlphaProfile(initial_radius=1, symmetric=False) @patch("matplotlib.pyplot.show") def test_build_single_segment_profile(mock_show, alpha_profile): """ - Test building a profile with a single taper segment. + Test creating and plotting a profile with a single taper segment. + + Parameters + ---------- + mock_show : MagicMock + Mock for `matplotlib.pyplot.show` to prevent actual plot display. + alpha_profile : AlphaProfile + The profile fixture to use in the test. """ alpha_profile.add_taper_segment( alpha=0, @@ -29,7 +58,14 @@ def test_build_single_segment_profile(mock_show, alpha_profile): @patch("matplotlib.pyplot.show") def test_build_two_segment_profile(mock_show, alpha_profile): """ - Test building a profile with two taper segments. + Test creating and plotting a profile with two taper segments. + + Parameters + ---------- + mock_show : MagicMock + Mock for `matplotlib.pyplot.show` to prevent actual plot display. + alpha_profile : AlphaProfile + The profile fixture to use in the test. """ alpha_profile.add_taper_segment( alpha=0, @@ -48,25 +84,36 @@ def test_build_two_segment_profile(mock_show, alpha_profile): @patch("matplotlib.pyplot.show") -def test_build_asymmetric_profile(mock_show): +def test_build_asymmetric_profile(mock_show, asymmetric_alpha_profile): """ - Test building an asymmetric profile with a single taper segment. + Test creating and plotting an asymmetric profile with a single taper segment. + + Parameters + ---------- + mock_show : MagicMock + Mock for `matplotlib.pyplot.show` to prevent actual plot display. + asymmetric_alpha_profile : AlphaProfile + The asymmetric profile fixture to use in the test. """ - asymmetric_profile = AlphaProfile(initial_radius=1, symmetric=False) - asymmetric_profile.add_taper_segment( + asymmetric_alpha_profile.add_taper_segment( alpha=0, initial_heating_length=10e-3, stretching_length=0.2e-3 * 200 ) - asymmetric_profile.initialize() - asymmetric_profile.plot() + asymmetric_alpha_profile.initialize() + asymmetric_alpha_profile.plot() mock_show.assert_called_once() def test_generate_propagation_gif(alpha_profile): """ - Test generating a GIF from the profile data. + Test generating a propagation GIF from the profile data. + + Parameters + ---------- + alpha_profile : AlphaProfile + The profile fixture to use in the test. """ alpha_profile.add_taper_segment( alpha=0, @@ -76,4 +123,8 @@ def test_generate_propagation_gif(alpha_profile): alpha_profile.initialize() alpha_profile.generate_propagation_gif(number_of_frames=10) - # Assertions can be added here if generate_propagation_gif outputs testable results + # Add assertions if generate_propagation_gif outputs any testable results + + +if __name__ == "__main__": + pytest.main(["-W error", __file__]) diff --git a/tests/test_validation_normalized_coupling.py b/tests/test_validation_normalized_coupling.py index 2634701a..73917115 100644 --- a/tests/test_validation_normalized_coupling.py +++ b/tests/test_validation_normalized_coupling.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - +import pytest import numpy from SuPyMode.workflow import Workflow, fiber_catalogue, Boundaries from PyFiberModes.__future__ import get_normalized_LP_coupling @@ -78,4 +78,6 @@ def test_normalized_coupling( if mean_relative_error > 0.1: raise AssertionError(f"Discrepancy between computed and analytical normalized coupling: [Mean Error: {error.mean()}, Mean Relative Error: {mean_relative_error}]") -# - + +if __name__ == "__main__": + pytest.main(["-W error", __file__]) diff --git a/tests/test_validation_propagation_constant.py b/tests/test_validation_propagation_constant.py index 709e6ba7..b7c29a6c 100644 --- a/tests/test_validation_propagation_constant.py +++ b/tests/test_validation_propagation_constant.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - +import pytest import numpy from SuPyMode.workflow import Workflow, fiber_catalogue, Boundaries import PyFiberModes @@ -69,4 +69,7 @@ def test_propagation_constant( error = numpy.abs(analytical - simulation) relative_error = error / numpy.abs(analytical) raise AssertionError(f"Discrepancy between computed and analytical propagation constants. Mean Error: {error.mean()}, Mean Relative Error: {relative_error.mean()}") -# - + + +if __name__ == "__main__": + pytest.main(["-W error", __file__]) diff --git a/tests/test_validation_symmetry.py b/tests/test_validation_symmetry.py index cf669b1d..f1169fa5 100644 --- a/tests/test_validation_symmetry.py +++ b/tests/test_validation_symmetry.py @@ -1,11 +1,19 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - +import pytest import numpy from SuPyMode.workflow import Workflow, configuration, fiber_catalogue, Boundaries +BOUNDARIES_LIST = [ + dict(x_bounds='left', boundaries=[Boundaries(right='symmetric')]), + dict(x_bounds='right', boundaries=[Boundaries(left='symmetric')]), + dict(y_bounds='top', boundaries=[Boundaries(bottom='symmetric')]), + dict(y_bounds='bottom', boundaries=[Boundaries(top='symmetric')]), +] + -def test_symmetry(fiber_name: str = 'DCF1300S_33', wavelength: float = 1.55e-6, resolution: int = 10): +@pytest.mark.parametrize('boundaries', BOUNDARIES_LIST) +def test_symmetry(boundaries, fiber_name: str = 'DCF1300S_33', wavelength: float = 1.55e-6, resolution: int = 21): """ Tests the effect of symmetric and asymmetric boundary conditions on the computed indices in a fiber modeling workflow. @@ -26,42 +34,40 @@ def test_symmetry(fiber_name: str = 'DCF1300S_33', wavelength: float = 1.55e-6, n_added_mode=2, ) + fiber_list = [fiber_catalogue.load_fiber(fiber_name, wavelength=wavelength)] + clad_structure = configuration.ring.FusedProfile_01x01 + # Setup for asymmetric boundary conditions reference_workflow = Workflow( - fiber_list=[fiber_catalogue.load_fiber(fiber_name, wavelength=wavelength)], - clad_structure=configuration.ring.FusedProfile_01x01, + fiber_list=fiber_list, + clad_structure=clad_structure, **kwargs, - x_bounds="centering", # Centered, implying no special treatment to symmetry + x_bounds="centering", y_bounds="centering", - boundaries=[Boundaries()], # No special symmetry boundaries + boundaries=[Boundaries()], ) - boundaries_dict_list = [ - dict(x_bounds='left', boundaries=[Boundaries(right='symmetric')]), - dict(x_bounds='right', boundaries=[Boundaries(left='symmetric')]), - dict(y_bounds='top', boundaries=[Boundaries(bottom='symmetric')]), - dict(y_bounds='bottom', boundaries=[Boundaries(top='symmetric')]), - ] + # Setup for symmetric boundary conditions + left_workflow = Workflow( + fiber_list=fiber_list, + clad_structure=clad_structure, + **boundaries, + **kwargs + ) - for boundaries_dict in boundaries_dict_list: + # Compare the effective index data between the two configurations + discrepancy = numpy.isclose( + reference_workflow.superset[0].index.data, + left_workflow.superset[0].index.data, + atol=1e-10, + rtol=1e-10 + ) - # Setup for symmetric boundary conditions - left_workflow = Workflow( - fiber_list=[fiber_catalogue.load_fiber(fiber_name, wavelength=wavelength)], - clad_structure=configuration.ring.FusedProfile_01x01, - **boundaries_dict, - **kwargs - ) + difference = abs(reference_workflow.superset[0].index.data - left_workflow.superset[0].index.data) - # Compare the effective index data between the two configurations - discrepancy = numpy.isclose( - reference_workflow.superset[0].index.data, - left_workflow.superset[0].index.data, - atol=1e-10, - rtol=1e-10 - ) + if numpy.mean(discrepancy) <= 0.9: + raise ValueError(f"Mismatch [{numpy.mean(difference):.5e}] between: non-symmetric and symmetric symmetry-based formulation of the numerical problem.") - if numpy.mean(discrepancy) <= 0.9: - raise ValueError("Mismatch between: non-symmetric and symmetric symmetry-based formulation of the numerical problem.") -# - +if __name__ == "__main__": + pytest.main(["-W error", __file__]) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index ccb7c96f..f9f1256a 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -4,13 +4,14 @@ import pytest from SuPyMode.workflow import configuration, Workflow, fiber_catalogue, Boundaries - +import matplotlib.pyplot as plt # Define a list of fused fiber structures from the configuration module fused_structure_list = [ configuration.ring.FusedProfile_01x01, configuration.ring.FusedProfile_02x02, ] + @pytest.mark.parametrize('fused_structure', fused_structure_list, ids=lambda x: f"n_fiber: {x.number_of_fibers}") def test_workflow(fused_structure): """ @@ -33,7 +34,7 @@ def test_workflow(fused_structure): fiber_list=fibers, clad_structure=fused_structure, wavelength=1550e-9, - resolution=30, + resolution=10, x_bounds="left", y_bounds="centering", boundaries=[Boundaries(right='symmetric')], @@ -45,11 +46,13 @@ def test_workflow(fused_structure): workflow.generate_pdf_report(filename='test_0') - workflow.save_superset_instance(filename='test_0') + plt.close() + + # workflow.save_superset_instance(filename='test_0') - # Assert that the workflow instance has been successfully created (basic check) - assert workflow is not None, "Workflow should be successfully instantiated with the given configurations." + # # Assert that the workflow instance has been successfully created (basic check) + # assert workflow is not None, "Workflow should be successfully instantiated with the given configurations." -if __name__ == '__main__': - pytest.main([__file__]) \ No newline at end of file +if __name__ == "__main__": + pytest.main(["-W error", __file__])