Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add UI functionality to reset inputs #778

Merged
merged 5 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from trame.widgets import vuetify

from ...Input.trameFunctions import TrameFunctions
from ...trame_setup import setup_server
from ..generalFunctions import generalFunctions

Expand Down Expand Up @@ -39,6 +40,9 @@ def card():
with vuetify.VCard(v_show="csr", style="width: 170px;"):
with vuetify.VCardTitle("CSR"):
vuetify.VSpacer()
TrameFunctions.create_refresh_button(
lambda: generalFunctions.reset_inputs("csr")
)
vuetify.VIcon(
"mdi-information",
classes="ml-2",
Expand Down
14 changes: 11 additions & 3 deletions src/python/impactx/dashboard/Input/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ class DashboardDefaults:
# Inputs by section
# -------------------------------------------------------------------------

INPUT_PARAMETERS = {
SELECTION = {
"space_charge": False,
"csr": False,
}

INPUT_PARAMETERS = {
"charge_qe": -1,
"mass_MeV": 0.51099895,
"npart": 1000,
Expand All @@ -33,8 +36,12 @@ class DashboardDefaults:
"poisson_solver": "fft",
"particle_shape": 2,
"max_level": 0,
"n_cell": 32,
"blocking_factor": 16,
"n_cell_x": 32,
"n_cell_y": 32,
"n_cell_z": 32,
"blocking_factor_x": 16,
"blocking_factor_y": 16,
"blocking_factor_z": 16,
"prob_relative_first_value_fft": 1.1,
"prob_relative_first_value_multigrid": 3.1,
"mlmg_relative_tolerance": 1.0e-7,
Expand All @@ -61,6 +68,7 @@ class DashboardDefaults:
# -------------------------------------------------------------------------

DEFAULT_VALUES = {
**SELECTION,
**INPUT_PARAMETERS,
**DISTRIBUTION,
**LATTICE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def convert_distribution_parameters_to_valid_type():
param["parameter_name"]: float(param["parameter_default_value"])
if param_is_valid
else 0.0
for param in state.selectedDistributionParameters
for param in state.selected_distribution_parameters
if (param_is_valid := param["parameter_error_message"] == [])
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from impactx import distribution

from ...Input.trameFunctions import TrameFunctions
from ...trame_setup import setup_server
from ..generalFunctions import generalFunctions
from .distributionFunctions import DistributionFunctions
Expand All @@ -34,30 +35,30 @@
# Defaults
# -----------------------------------------------------------------------------

state.selectedDistribution = generalFunctions.get_default(
state.selected_distribution = generalFunctions.get_default(
"selected_distribution", "default_values"
)
state.selectedDistributionType = generalFunctions.get_default(
state.selected_distribution_type = generalFunctions.get_default(
"selected_distribution_type", "default_values"
)
state.selectedDistributionParameters = []
state.selected_distribution_parameters = []
state.distributionTypeDisabled = False

# -----------------------------------------------------------------------------
# Main Functions
# -----------------------------------------------------------------------------


def populate_distribution_parameters(selectedDistribution):
def populate_distribution_parameters(selected_distribution):
Fixed Show fixed Hide fixed
"""
Populates distribution parameters based on the selected distribution.
:param selectedDistribution (str): The name of the selected distribution
:param selected_distribution (str): The name of the selected distribution
whose parameters need to be populated.
"""

if state.selectedDistributionType == "Twiss":
if state.selected_distribution_type == "Twiss":
sig = inspect.signature(twiss)
state.selectedDistributionParameters = [
state.selected_distribution_parameters = [
{
"parameter_name": param.name,
"parameter_default_value": param.default
Expand All @@ -76,13 +77,13 @@ def populate_distribution_parameters(selectedDistribution):
]

else: # when type == 'Quadratic Form'
selectedDistributionParameters = (
selected_distribution_parameters = (
state.listOfDistributionsAndParametersAndDefault.get(
selectedDistribution, []
selected_distribution, []
)
)

state.selectedDistributionParameters = [
state.selected_distribution_parameters = [
{
"parameter_name": parameter[0],
"parameter_default_value": parameter[1],
Expand All @@ -95,11 +96,11 @@ def populate_distribution_parameters(selectedDistribution):
else "",
"parameter_step": generalFunctions.get_default(parameter[0], "steps"),
}
for parameter in selectedDistributionParameters
for parameter in selected_distribution_parameters
]

generalFunctions.update_simulation_validation_status()
return selectedDistributionParameters
generalFunctions.update_simulation_validation_status()
return state.selected_distribution_parameters


def update_distribution_parameters(
Expand All @@ -113,13 +114,13 @@ def update_distribution_parameters(
:param parameterErrorMessage: The error message related to the parameter's value.
"""

for param in state.selectedDistributionParameters:
for param in state.selected_distribution_parameters:
if param["parameter_name"] == parameterName:
param["parameter_default_value"] = parameterValue
param["parameter_error_message"] = parameterErrorMessage

generalFunctions.update_simulation_validation_status()
state.dirty("selectedDistributionParameters")
state.dirty("selected_distribution_parameters")


# -----------------------------------------------------------------------------
Expand All @@ -133,10 +134,10 @@ def distribution_parameters():
initialized with the appropriate parameters provided by the user.
"""

distribution_name = state.selectedDistribution
distribution_name = state.selected_distribution
parameters = DistributionFunctions.convert_distribution_parameters_to_valid_type()

if state.selectedDistributionType == "Twiss":
if state.selected_distribution_type == "Twiss":
twiss_params = twiss(**parameters)
distr = getattr(distribution, distribution_name)(**twiss_params)
else:
Expand All @@ -150,20 +151,20 @@ def distribution_parameters():
# -----------------------------------------------------------------------------


@state.change("selectedDistribution")
def on_distribution_name_change(selectedDistribution, **kwargs):
if selectedDistribution == "Thermal":
state.selectedDistributionType = "Quadratic Form"
@state.change("selected_distribution")
def on_distribution_name_change(selected_distribution, **kwargs):
if selected_distribution == "Thermal":
state.selected_distribution_type = "Quadratic Form"
state.distributionTypeDisabled = True
state.dirty("selectedDistributionType")
state.dirty("selected_distribution_type")
else:
state.distributionTypeDisabled = False
populate_distribution_parameters(selectedDistribution)
populate_distribution_parameters(selected_distribution)


@state.change("selectedDistributionType")
@state.change("selected_distribution_type")
def on_distribution_type_change(**kwargs):
populate_distribution_parameters(state.selectedDistribution)
populate_distribution_parameters(state.selected_distribution)


@ctrl.add("updateDistributionParameters")
Expand Down Expand Up @@ -193,6 +194,9 @@ def card():
with vuetify.VCard(style="width: 340px; height: 300px"):
with vuetify.VCardTitle("Distribution Parameters"):
vuetify.VSpacer()
TrameFunctions.create_refresh_button(
lambda: generalFunctions.reset_inputs("distribution")
)
vuetify.VIcon(
"mdi-information",
style="color: #00313C;",
Expand All @@ -204,13 +208,13 @@ def card():
with vuetify.VCol(cols=6):
vuetify.VCombobox(
label="Select Distribution",
v_model=("selectedDistribution",),
v_model=("selected_distribution",),
items=("listOfDistributions",),
dense=True,
)
with vuetify.VCol(cols=6):
vuetify.VSelect(
v_model=("selectedDistributionType",),
v_model=("selected_distribution_type",),
label="Type",
items=(
generalFunctions.get_default(
Expand All @@ -224,7 +228,7 @@ def card():
for i in range(3):
with vuetify.VCol(cols=4, classes="py-0"):
with vuetify.VRow(
v_for="(parameter, index) in selectedDistributionParameters"
v_for="(parameter, index) in selected_distribution_parameters"
):
with vuetify.VCol(
v_if=f"index % 3 == {i}", classes="py-1"
Expand Down
35 changes: 32 additions & 3 deletions src/python/impactx/dashboard/Input/generalFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,14 @@ def update_simulation_validation_status():
error_details = []

# Check for errors in distribution parameters
for param in state.selectedDistributionParameters:
for param in state.selected_distribution_parameters:
if param["parameter_error_message"]:
error_details.append(
f"{param['parameter_name']}: {param['parameter_error_message']}"
)

# Check for errors in lattice parameters
for lattice in state.selectedLatticeList:
for lattice in state.selected_lattice_list:
for param in lattice["parameters"]:
if param["parameter_error_message"]:
error_details.append(
Expand All @@ -165,7 +165,7 @@ def update_simulation_validation_status():
if state.mass_MeV_validation:
error_details.append(f"Ref. Particle Mass: {state.mass_MeV}")

if state.selectedLatticeList == []:
if state.selected_lattice_list == []:
error_details.append("LatticeListIsEmpty")

# Check for errors in CSR parameters
Expand Down Expand Up @@ -326,3 +326,32 @@ def convert_to_correct_type(value, desired_type):
return str(value)
else:
raise ValueError("Unknown type")

@staticmethod
def reset_inputs(input_section):
"""
Resets dashboard inputs to default values.

:param input_section: The input section to reset.
"""

possible_section_names = []
for name in vars(DashboardDefaults):
if name != "DEFAULT_VALUES" and name.isupper():
possible_section_names.append(name)

if input_section.upper() in possible_section_names:
state.update(getattr(DashboardDefaults, input_section.upper()))

if input_section == "distribution":
state.dirty("selected_distribution_type")
elif input_section == "lattice":
state.selected_lattice_list = []
elif input_section == "space_charge":
state.dirty("max_level")

elif input_section == "all":
state.update(DashboardDefaults.DEFAULT_VALUES)
state.dirty("selected_distribution_type")
state.selected_lattice_list = []
state.dirty("max_level")
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from trame.widgets import vuetify

from ...Input.trameFunctions import TrameFunctions
from ...trame_setup import setup_server
from ..generalFunctions import generalFunctions
from .inputFunctions import InputFunctions
Expand Down Expand Up @@ -97,6 +98,9 @@ def card(self):
with vuetify.VCard(style="width: 340px; height: 350px"):
with vuetify.VCardTitle("Input Parameters"):
vuetify.VSpacer()
TrameFunctions.create_refresh_button(
lambda: generalFunctions.reset_inputs("input_parameters")
)
vuetify.VIcon(
"mdi-information",
style="color: #00313C;",
Expand Down
Loading
Loading