Skip to content

Commit

Permalink
Merge pull request #579 from hschilling/I523-aviary-optimization-hist…
Browse files Browse the repository at this point in the history
…ory-plot

New optimization history tab in dashboard
  • Loading branch information
jkirk5 authored Nov 19, 2024
2 parents 8fc550d + f201563 commit f66d8bd
Showing 1 changed file with 212 additions and 57 deletions.
269 changes: 212 additions & 57 deletions aviary/visualization/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@

import numpy as np

import bokeh.palettes as bp
from bokeh.models import Legend, CheckboxGroup, CustomJS
import pandas as pd

from bokeh.models import Legend, LegendItem, CheckboxGroup, CustomJS, TextInput, ColumnDataSource, CustomJS, Div, Range1d, LinearAxis, PrintfTickFormatter
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from bokeh.layouts import column
from bokeh.palettes import Category10, Category20, d3

import hvplot.pandas # noqa # need this ! Otherwise hvplot using DataFrames does not work
import pandas as pd
import panel as pn
from panel.theme import DefaultTheme

import openmdao.api as om
from openmdao.utils.general_utils import env_truthy
Expand Down Expand Up @@ -474,7 +473,7 @@ def create_aviary_variables_table_data_nested(script_name, recorder_file):
return table_data_nested


def convert_case_recorder_file_to_df(recorder_file_name):
def convert_driver_case_recorder_file_to_df(recorder_file_name):
"""
Convert a case recorder file into a Pandas data frame.
Expand Down Expand Up @@ -588,7 +587,202 @@ def _get_interactive_plot_sources(data_by_varname_and_phase, x_varname, y_varnam
return [], []


def create_optimization_history_plot(case_recorder, df):

# Create a ColumnDataSource
source = ColumnDataSource(df)

# Create a Bokeh figure
plotting_figure = figure(title='Optimization History',
width=1000,
height=600,
)
plotting_figure.title.align = 'center'
plotting_figure.yaxis.visible = False
plotting_figure.xaxis.axis_label = 'Iterations'
plotting_figure.yaxis.formatter = PrintfTickFormatter(format="%5.2e")
plotting_figure.title.text_font_size = "25px"

# Choose a palette
palette = Category20[20]

# Plot each time series and keep references to the renderers
renderers = {}
variable_names = list(df.columns)[1:]
for i, variable_name in enumerate(variable_names):
color = palette[i % 20]

renderers[variable_name] = plotting_figure.line(
x='iter_count',
y=variable_name,
source=source,
y_range_name=f"extra_y_{variable_name}",
color=color,
line_width=2,
visible=False, # hide them all initially. clicking checkboxes makes them visible
)

# create axes both to the right and left of the plot.
# hide them initially
# as the user selects/deselects variables to be plotted, they get turned on/off
extra_y_axis = LinearAxis(y_range_name=f"extra_y_{variable_name}",
axis_label=f"{variable_name}",
axis_label_text_color=color)
plotting_figure.add_layout(extra_y_axis, 'right')
plotting_figure.right[i].visible = False

extra_y_axis = LinearAxis(y_range_name=f"extra_y_{variable_name}",
axis_label=f"{variable_name}",
axis_label_text_color=color)
plotting_figure.add_layout(extra_y_axis, 'left')
plotting_figure.left[i + 1].visible = False

# set the range
y_min = df[variable_name].min()
y_max = df[variable_name].max()
# if the range is zero, the axis will not be displayed. Plus need some range to make it
# look good. Some other code seems to do +- 1 for the range in this case.
if y_min == y_max:
y_min = y_min - 1
y_max = y_max + 1
plotting_figure.extra_y_ranges[f"extra_y_{variable_name}"] = Range1d(
y_min, y_max)

# Make a Legend with no items in it. those will be added in JavaScript
# as users select variables to be plotted
legend = Legend(items=[], location=(-50, -5), border_line_width=0)

# make the legend items in Python. Pass them to JavaScript where they can be added to the Legend
legend_items = []
for variable_name in variable_names:
units = case_recorder.problem_metadata['variables'][variable_name]['units']
legend_item = LegendItem(label=f"{variable_name} ({units})", renderers=[
renderers[variable_name]])
legend_items.append(legend_item)

plotting_figure.add_layout(legend, 'below')

# make the list of variables with checkboxes
data_source = ColumnDataSource(
data=dict(options=variable_names, checked=[False]*len(variable_names)))
# Create a Div to act as a scrollable container
variable_scroll_box = Div(
styles={
'overflow-y': 'scroll',
'height': '500px',
'border': '1px solid #ddd',
'padding': '10px'
}
)

# make the text box used to filter variables
filter_variables_text_box = TextInput(placeholder='Variable name filter')

# CustomJS callback for checkbox changes
variable_checkbox_callback = CustomJS(args=dict(data_source=data_source,
plotting_figure=plotting_figure,
renderers=renderers,
legend=legend,
legend_items=legend_items),
code="""
// Three things happen in this code.
// 1. turn on/off the plot lines
// 2. show the legend items for the items being plotted
// 3. show the y axes for each of the lines being plotted
// The incoming Legend is empty. The items are passed in separately
// 1. Plots
// turn off or on the line plot for the clicked on variable
const checkedIndex = cb_obj.index;
const isChecked = cb_obj.checked;
data_source.data['checked'][checkedIndex] = isChecked;
renderers[data_source.data['options'][checkedIndex]].visible = isChecked;
// 2. Legend
// empty the Legend items and then add in the ones for the variables that are checked
legend.items = [];
for (let i =0; i < legend_items.length; i++){
if ( data_source.data['checked'][i] ) {
legend.items.push(legend_items[i]);
}
}
// 3. Y axes
// first hide all of them
for (let i =0; i < legend_items.length; i++){
var extra_y_axis = plotting_figure.left[i + 1];
extra_y_axis.visible = false ;
var extra_y_axis = plotting_figure.right[i];
extra_y_axis.visible = false ;
}
// alternate between making visible the axes on the left and the right to make it more even.
// this variable keeps track of which side to add the axes to.
let put_on_left_side = true;
for (let i =0; i < legend_items.length; i++){
if (data_source.data['checked'][i]){
if (put_on_left_side){
plotting_figure.left[i + 1].visible = true;
} else {
plotting_figure.right[i].visible = true;
}
put_on_left_side = ! put_on_left_side ;
}
}
data_source.change.emit();
""")

# CustomJS callback for the variable filtering
filter_variables_callback = CustomJS(args=dict(data_source=data_source,
variable_scroll_box=variable_scroll_box,
variable_checkbox_callback=variable_checkbox_callback),
code="""
const filter_text = cb_obj.value.toLowerCase();
const all_options = data_source.data['options'];
const checked_states = data_source.data['checked'];
// Filter options
const filtered_options = all_options.filter(option =>
option.toLowerCase().includes(filter_text)
);
// Update the scroll box content
let checkboxes_html = '';
filtered_options.forEach((label) => {
const index = all_options.indexOf(label);
checkboxes_html += `
<label style="display:block; margin-bottom:5px;">
<input type="checkbox" value="${label}" ${checked_states[index] ? 'checked' : ''}
onchange="Bokeh.documents[0].get_model_by_id('${variable_checkbox_callback.id}').execute({index: ${index}, checked: this.checked})">
${label}
</label>
`;
});
variable_scroll_box.text = checkboxes_html;
""")

filter_variables_text_box.js_on_change('value', filter_variables_callback)

# Initial population of the scroll box
initial_html = ''.join(f"""
<label style="display:block; margin-bottom:5px;">
<input type="checkbox" value="{variable_name}"
onchange="Bokeh.documents[0].get_model_by_id('{variable_checkbox_callback.id}').execute({{index: {i}, checked: this.checked}})">
{variable_name}
</label>
""" for i, variable_name in enumerate(variable_names))
variable_scroll_box.text = initial_html

# Arrange the layout using Panel
layout = pn.Row(pn.Column(filter_variables_text_box,
variable_scroll_box), plotting_figure)

return layout

# The main script that generates all the tabs in the dashboard


def dashboard(script_name, problem_recorder, driver_recorder, port, run_in_background=False):
"""
Generate the dashboard app display.
Expand Down Expand Up @@ -669,9 +863,6 @@ def dashboard(script_name, problem_recorder, driver_recorder, port, run_in_backg
)
model_tabs_list.append(("Trajectory Linkage", traj_linkage_report_pane))

####### Optimization Tab #######
optimization_tabs_list = []

# Driver scaling
driver_scaling_report_pane = create_report_frame(
"html", f"{reports_dir}/driver_scaling_report.html", '''
Expand All @@ -683,53 +874,16 @@ def dashboard(script_name, problem_recorder, driver_recorder, port, run_in_backg
)
model_tabs_list.append(("Driver Scaling", driver_scaling_report_pane))

# Desvars, cons, opt interactive plot
####### Optimization Tab #######
optimization_tabs_list = []

# Optimization History Plot
if driver_recorder:
if os.path.isfile(driver_recorder):
df = convert_case_recorder_file_to_df(f"{driver_recorder}")
if df is not None:
variables = pn.widgets.CheckBoxGroup(
name="Variables",
options=list(df.columns),
# just so all of them aren't plotted from the beginning. Skip the iter count
value=list(df.columns)[1:2],
)
ipipeline = df.interactive()
ihvplot = ipipeline.hvplot(
y=variables,
responsive=True,
min_height=400,
color=list(bp.Category10[10]),
yformatter="%.0f",
title="Model Optimization using OpenMDAO",
)
optimization_plot_pane = pn.Column(
pn.Row(
pn.Column(
variables,
pn.VSpacer(height=30),
pn.VSpacer(height=30),
width=300,
),
ihvplot.panel(),
)
)
else:
optimization_plot_pane = pn.pane.Markdown(
f"# Recorder file '{driver_recorder}' does not have Driver case recordings."
)
else:
optimization_plot_pane = pn.pane.Markdown(
f"# Recorder file containing optimization history,'{driver_recorder}', not found.")

optimization_plot_pane_with_doc = pn.Column(
pn.pane.HTML(f"<p>Plot of design variables, constraints, and objectives.</p>",
styles={'text-align': documentation_text_align}),
optimization_plot_pane
)
optimization_tabs_list.append(
("History", optimization_plot_pane_with_doc)
)
df = convert_driver_case_recorder_file_to_df(f"{driver_recorder}")
cr = om.CaseReader(f"{driver_recorder}")
opt_history_pane = create_optimization_history_plot(cr, df)
optimization_tabs_list.append(("Optimization History", opt_history_pane))

# IPOPT report
if os.path.isfile(f"{reports_dir}/IPOPT.out"):
Expand Down Expand Up @@ -951,7 +1105,7 @@ def dashboard(script_name, problem_recorder, driver_recorder, port, run_in_backg
],
)

colors = bp.d3['Category20'][20][0::2] + bp.d3['Category20'][20][1::2]
colors = d3['Category20'][20][0::2] + d3['Category20'][20][1::2]
legend_data = []
phases = sorted(phases, key=str.casefold)
for i, phase in enumerate(phases):
Expand Down Expand Up @@ -1086,7 +1240,7 @@ def save_dashboard(event):
header_background="rgb(0, 212, 169)",
header=header,
background_color="white",
theme=DefaultTheme,
theme=pn.theme.DefaultTheme,
theme_toggle=False,
main_layout=None,
css_files=["assets/aviary_styles.css"],
Expand All @@ -1109,6 +1263,7 @@ def save_dashboard(event):
home_dir = "."
if port == 0:
port = get_free_port()

server = pn.serve(
template,
port=port,
Expand Down

0 comments on commit f66d8bd

Please sign in to comment.