Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 399 roc curve #408

Merged
merged 8 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion metplotpy/plots/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,11 @@ def _get_user_legends(self, legend_label_type: str ) -> list:
# resulting in a zero-sized series_list. In this case,
# the legend label will just be the legend_label_type.
if len(series_list) == 0 and legend_label_unspecified:
return [legend_label_type]
# check if summary_curve is present
if 'summary_curve' in self.parameters.keys() and self.parameters['summary_curve'] != 'none':
return [legend_label_type, self.parameters['summary_curve'] + ' ' + legend_label_type]
else:
return [legend_label_type]

perms = utils.create_permutations(series_list)
for idx,ll in enumerate(legends_list):
Expand All @@ -481,6 +485,8 @@ def _get_user_legends(self, legend_label_type: str ) -> list:
ll_list.append(legend_label)
else:
ll_list.append(ll)
if 'summary_curve' in self.parameters.keys() and self.parameters['summary_curve'] != 'none':
ll_list.append(self.parameters['summary_curve'] + ' ' + legend_label_type)

legends_list_ordered = self.create_list_by_series_ordering(ll_list)
return legends_list_ordered
Expand Down
3 changes: 3 additions & 0 deletions metplotpy/plots/config/roc_diagram_defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,6 @@ reverse_connection_order: False
create_html: True #optional
stat_input: ../../test/roc_diagram/plot_20200507_074426.data #required
plot_filename: ./roc_diagram_default.png #required

summary_curve: none

67 changes: 63 additions & 4 deletions metplotpy/plots/roc_diagram/roc_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from datetime import datetime
import yaml
import re
import sys
import warnings
# with warnings.catch_warnings():
# warnings.simplefilter("ignore", category="DeprecationWarning")
Expand All @@ -32,7 +31,7 @@
from metplotpy.plots.roc_diagram.roc_diagram_config import ROCDiagramConfig
from metplotpy.plots.roc_diagram.roc_diagram_series import ROCDiagramSeries
import metcalcpy.util.utils as calc_util

from metplotpy.plots.util import prepare_pct_roc, prepare_ctc_roc


class ROCDiagram(BasePlot):
Expand Down Expand Up @@ -146,11 +145,70 @@ def _create_series(self, input_data):

# use the list of series ordering values to determine how many series objects we need.
num_series = len(self.config_obj.series_ordering)
if self.config_obj.summary_curve != 'none':
num_series = num_series -1

for i, series in enumerate(range(num_series)):
# Create a ROCDiagramSeries object
series_obj = ROCDiagramSeries(self.config_obj, i, input_data)
series_list.append(series_obj)

if self.config_obj.summary_curve != 'none':
# add Summary Curve based on the summary dataframes of each ROCDiagramSeries
df_sum_main = None
for idx, series in enumerate(series_list):
# create a main summary frame from series summary frames
if self.config_obj.linetype_ctc:
if df_sum_main is None:
df_sum_main = pd.DataFrame(columns=['fcst_thresh', 'fy_oy', 'fy_on', 'fn_oy', 'fn_on'])
elif self.config_obj.linetype_pct:
if df_sum_main is None:
df_sum_main = pd.DataFrame(columns=['thresh_i', 'i_value', 'on_i', 'oy_i'])

df_sum_main = pd.concat([df_sum_main, series.series_points[3]], axis=0)

if self.config_obj.linetype_ctc:
df_summary_curve = pd.DataFrame(columns=['fcst_thresh', 'fy_oy', 'fy_on', 'fn_oy', 'fn_on'])
fcst_thresh_list = df_sum_main['fcst_thresh'].unique()
for thresh in fcst_thresh_list:
if self.config_obj.summary_curve == 'median':
group_stats_fy_oy = df_sum_main['fy_oy'][df_sum_main['fcst_thresh'] == thresh].median()
group_stats_fn_oy = df_sum_main['fn_oy'][df_sum_main['fcst_thresh'] == thresh].median()
group_stats_fy_on = df_sum_main['fy_on'][df_sum_main['fcst_thresh'] == thresh].median()
group_stats_fn_on = df_sum_main['fn_on'][df_sum_main['fcst_thresh'] == thresh].median()
else:
group_stats_fy_oy = df_sum_main['fy_oy'][df_sum_main['fcst_thresh'] == thresh].mean()
group_stats_fn_oy = df_sum_main['fn_oy'][df_sum_main['fcst_thresh'] == thresh].mean()
group_stats_fy_on = df_sum_main['fy_on'][df_sum_main['fcst_thresh'] == thresh].mean()
group_stats_fn_on = df_sum_main['fn_on'][df_sum_main['fcst_thresh'] == thresh].mean()
df_summary_curve.loc[len(df_summary_curve)] = {'fcst_thresh': thresh,
'fy_oy': group_stats_fy_oy,
'fn_oy': group_stats_fn_oy,
'fy_on': group_stats_fy_on,
'fn_on': group_stats_fn_on,
}
df_summary_curve.reset_index()
pody, pofd, thresh = prepare_ctc_roc(df_summary_curve,self.config_obj.ctc_ascending)
else:
df_summary_curve = pd.DataFrame(columns=['thresh_i', 'on_i', 'oy_i'])
thresh_i_list = df_sum_main['thresh_i'].unique()
for index, thresh in enumerate(thresh_i_list):
if self.config_obj.summary_curve == 'median':
on_i_sum = df_sum_main['on_i'][df_sum_main['thresh_i'] == thresh].median()
oy_i_sum = df_sum_main['oy_i'][df_sum_main['thresh_i'] == thresh].median()
else:
on_i_sum = df_sum_main['on_i'][df_sum_main['thresh_i'] == thresh].mean()
oy_i_sum = df_sum_main['oy_i'][df_sum_main['thresh_i'] == thresh].mean()
df_summary_curve.loc[len(df_summary_curve)] = {'thresh_i': thresh, 'on_i': on_i_sum,
'oy_i': oy_i_sum, }
df_summary_curve.reset_index()
pody, pofd, thresh = prepare_pct_roc(df_summary_curve)

series_obj = ROCDiagramSeries(self.config_obj, num_series -1, None)
series_obj.series_points = (pofd, pody, thresh, None)

series_list.append(series_obj)

return series_list

def remove_file(self):
Expand Down Expand Up @@ -316,6 +374,9 @@ def _create_figure(self):

thresh_list = []




# "Dump" False Detection Rate (POFD) and PODY points to an output
# file based on the output image filename (useful in debugging)
# This output file is used by METviewer and not necessary for other uses.
Expand Down Expand Up @@ -346,8 +407,6 @@ def _create_figure(self):
)




def add_trace_copy(trace):
"""Adds separate traces for markers and a legend.
This is a fix for not printing 'Aa' in the legend
Expand Down
7 changes: 7 additions & 0 deletions metplotpy/plots/roc_diagram/roc_diagram_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ def __init__(self, parameters):
self.plot_height = self.calculate_plot_dimension('plot_height', 'pixels')
self.show_legend = self._get_show_legend()

if 'summary_curve' in self.parameters.keys():
self.summary_curve = self.parameters['summary_curve']
else:
self.summary_curve = 'none'


def _get_series_inner_dict(self, index):
"""
Expand Down Expand Up @@ -304,3 +309,5 @@ def _get_markers(self):
markers_list_ordered = self.create_list_by_series_ordering(markers_list)
return markers_list_ordered



69 changes: 43 additions & 26 deletions metplotpy/plots/roc_diagram/roc_diagram_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
import warnings
import pandas as pd
import re
import metcalcpy.util.ctc_statistics as cstats
import metcalcpy.util.pstd_statistics as pstats
import metcalcpy.util.utils as utils
from ..series import Series
from ..util import prepare_pct_roc, prepare_ctc_roc


class ROCDiagramSeries(Series):
Expand All @@ -42,11 +41,12 @@ def _create_series_points(self):
Args:

Returns:
tuple of three lists:
tuple of three lists and a summary dataframe:
pody (Probability of detection) and
pofd (probability of false detection/
false alarm rate)
thresh (threshold value, used to annotate)
df_sum for the summary curve future calculation


"""
Expand All @@ -56,11 +56,15 @@ def _create_series_points(self):
# config file
input_df = self.input_data

# this is for the summary curve
if input_df is None:
return [], [], [], []

# Event equalization can sometimes create an empty data frame. Check for
# an empty data frame and return a tuple of empty lists if this is the case.
if input_df.empty:
print(f"INFO: No points to plot (most likely as a result of event equalization). ")
return [],[],[]
return [],[],[],[]

series_num = self.series_order
perm = utils.create_permutations(self.all_series_vals)
Expand All @@ -71,35 +75,48 @@ def _create_series_points(self):
# no subsetting of data required, no series_val_1 values
# were specified in the config file.
subset_df = input_df.copy()

df_sum = None
if self.config.linetype_ctc:
subset_df = self._add_ctc_columns(subset_df)
df_roc = cstats.calculate_ctc_roc(subset_df, ascending=self.config.ctc_ascending)
pody = df_roc['pody']
pody = pd.concat([pd.Series([1]), pody], ignore_index=True)
pody = pd.concat([pody, pd.Series([0])], ignore_index=True)
pofd = df_roc['pofd']
pofd = pd.concat([pd.Series([1]), pofd], ignore_index=True)
pofd = pd.concat([pofd, pd.Series([0])], ignore_index=True)
thresh = df_roc['thresh']
thresh = pd.concat([pd.Series(['']), thresh], ignore_index=True)
thresh = pd.concat([thresh, pd.Series([''])], ignore_index=True)
return pofd, pody, thresh
pody, pofd, thresh = prepare_ctc_roc(subset_df, self.config.ctc_ascending)

if self.config.summary_curve != 'none':
# calculate sum for each thresh
fcst_thresh_list = subset_df['fcst_thresh'].unique()
df_sum = pd.DataFrame(columns=['fcst_thresh', 'fy_oy', 'fy_on', 'fn_oy', 'fn_on'])
for thresh_val in fcst_thresh_list:
fy_oy_sum = subset_df['fy_oy'][subset_df['fcst_thresh'] == thresh_val].sum()
fy_on_sum = subset_df['fy_on'][subset_df['fcst_thresh'] == thresh_val].sum()
fn_oy_sum = subset_df['fn_oy'][subset_df['fcst_thresh'] == thresh_val].sum()
fn_on_sum = subset_df['fn_on'][subset_df['fcst_thresh'] == thresh_val].sum()
df_sum.loc[len(df_sum)] = {'fcst_thresh': thresh_val,
'fy_oy': fy_oy_sum, 'fy_on': fy_on_sum,
'fn_oy': fn_oy_sum, 'fn_on' : fn_on_sum}
df_sum.reset_index()
return pofd, pody, thresh, df_sum

elif self.config.linetype_pct:
roc_df = pstats._calc_pct_roc(subset_df)
pody = roc_df['pody']
pody = pd.concat([pd.Series([1]), pody], ignore_index=True)
pody = pd.concat([pody, pd.Series([0])])
pofd = roc_df['pofd']
pofd = pd.concat([pd.Series([1]), pofd], ignore_index=True)
pofd = pd.concat([pofd, pd.Series([0])], ignore_index=True)
thresh = roc_df['thresh']
thresh = pd.concat([pd.Series(['']),thresh], ignore_index=True)
thresh = pd.concat([thresh, pd.Series([''])], ignore_index=True)
return pofd, pody, thresh
pody, pofd, thresh = prepare_pct_roc(subset_df)

if self.config.summary_curve != 'none':
# calculate sum for each thresh
thresh_i_list = subset_df['thresh_i'].unique()
i_value_list = subset_df['i_value'].unique()
df_sum = pd.DataFrame(columns=['thresh_i', 'i_value', 'on_i', 'oy_i'])
if len(thresh_i_list) != len(i_value_list):
raise Exception("The size of thresh_i is not the same as the size of i_value")
for index, thresh_val in enumerate(thresh_i_list):
on_i_sum = subset_df['on_i'][(subset_df['thresh_i'] == thresh_val) & (subset_df['i_value'] == i_value_list[index])].sum()
oy_i_sum = subset_df['oy_i'][(subset_df['thresh_i'] == thresh_val) & (subset_df['i_value'] == i_value_list[index])].sum()
df_sum.loc[len(df_sum)] = {'thresh_i': thresh_val, 'i_value': i_value_list[index], 'on_i': on_i_sum, 'oy_i': oy_i_sum,}
df_sum.reset_index()
return pofd, pody, thresh, df_sum
else:
raise ValueError('error neither ctc or pct linetype ')



def _subset_data(self, df_full, permutation):
'''
Subset the input dataframe, iterating over the column and rows of interest
Expand Down
60 changes: 50 additions & 10 deletions metplotpy/plots/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import pandas as pd
from plotly.graph_objects import Figure
from metplotpy.plots.context_filter import ContextFilter as cf
import metcalcpy.util.pstd_statistics as pstats
import metcalcpy.util.ctc_statistics as cstats

COLORSCALES = {
'green_red': ['#E6FFE2', '#B3FAAD', '#74F578', '#30D244', '#00A01E', '#F6A1A2',
Expand Down Expand Up @@ -364,7 +366,6 @@ def is_thresh_column(column_name: str) -> bool:


def filter_by_fixed_vars(input_df: pd.DataFrame, settings_dict: dict) -> pd.DataFrame:

"""
Filter the input data based on values in the settings_dict dictionary.
For each key (corresponding to a column in the input_df dataframe),
Expand Down Expand Up @@ -477,9 +478,9 @@ def filter_by_fixed_vars(input_df: pd.DataFrame, settings_dict: dict) -> pd.Data


elif val_idx > 0 and not is_last_val:
# One of the middle values in the list
query_string = prev_val_string + single_quote + val + \
single_quote + list_sep
# One of the middle values in the list
query_string = prev_val_string + single_quote + val + \
single_quote + list_sep

else:
# The last value in the list
Expand Down Expand Up @@ -507,14 +508,14 @@ def filter_by_fixed_vars(input_df: pd.DataFrame, settings_dict: dict) -> pd.Data

# Only one value in the values list (both first and last element)
if val_idx == 0 and is_last_val:
prev_val_string = prev_val_string + col + in_token + list_start\
+ single_quote + val + single_quote +\
prev_val_string = prev_val_string + col + in_token + list_start \
+ single_quote + val + single_quote + \
list_terminator

elif val_idx == 0 and (not is_last_val):
# First value of a list of values
prev_val_string = prev_val_string + col + in_token +\
list_start + single_quote + val +\
prev_val_string = prev_val_string + col + in_token + \
list_start + single_quote + val + \
single_quote + list_sep

elif val_idx > 0 and not is_last_val:
Expand All @@ -526,15 +527,54 @@ def filter_by_fixed_vars(input_df: pd.DataFrame, settings_dict: dict) -> pd.Data

query_string_list.append(prev_val_string)


# Perform query for each column (key)
for cur_query in query_string_list:
working_df = filtered_df.query(cur_query)
filtered_df = working_df.copy(deep=True)


# clean up
del working_df
gc.collect()

return filtered_df


def prepare_pct_roc(subset_df):
    """
    Initialize the PCT ROC plot data, appends a beginning and end point.

    The ROC statistics come from ``pstats._calc_pct_roc``. Its ``pody`` and
    ``pofd`` columns are each padded with a leading 1 and a trailing 0, and
    its ``thresh`` column is padded with an empty string at both ends so the
    two added points carry no threshold label.

    :param subset_df: PCT data
    :return: PCT ROC plot data as a tuple ``(pody, pofd, thresh)`` of pandas
             Series, each freshly indexed from 0
    """
    roc_df = pstats._calc_pct_roc(subset_df)

    def _pad(values, first, last):
        # ignore_index=True is applied to every series. Previously the
        # trailing 0 was appended to pody without it, which left a duplicated
        # index label 0 at the end of pody and made its index disagree with
        # pofd and thresh.
        return pd.concat(
            [pd.Series([first]), values, pd.Series([last])],
            ignore_index=True,
        )

    pody = _pad(roc_df['pody'], 1, 0)
    pofd = _pad(roc_df['pofd'], 1, 0)
    thresh = _pad(roc_df['thresh'], '', '')

    return pody, pofd, thresh


def prepare_ctc_roc(subset_df, is_ascending):
    """
    Build the CTC ROC plot data and add a starting and an ending point.

    The ROC statistics are obtained from ``cstats.calculate_ctc_roc``; the
    ``pody``, ``pofd`` and ``thresh`` columns of its result are each extended
    by one value at the front and one at the back.

    :param subset_df: CTC data
    :param is_ascending: thresh order, forwarded as ``ascending``
    :return: CTC ROC plot data as a tuple ``(pody, pofd, thresh)``
    """
    roc_stats = cstats.calculate_ctc_roc(subset_df, ascending=is_ascending)

    # (column name, value placed in front, value placed at the end)
    padding_spec = (
        ('pody', 1, 0),
        ('pofd', 1, 0),
        ('thresh', '', ''),
    )

    padded = []
    for column, head, tail in padding_spec:
        with_head = pd.concat([pd.Series([head]), roc_stats[column]], ignore_index=True)
        with_both = pd.concat([with_head, pd.Series([tail])], ignore_index=True)
        padded.append(with_both)

    pody, pofd, thresh = padded
    return pody, pofd, thresh
Loading