From c3efbc1a82f8aac35b6efaab1e581bee16ab94ae Mon Sep 17 00:00:00 2001 From: voetberg Date: Tue, 21 May 2024 14:14:14 -0500 Subject: [PATCH] Corner plots version --- src/metrics/local_two_sample.py | 4 +- src/plots/local_two_sample.py | 157 ++++++++++++++++++++------------ 2 files changed, 104 insertions(+), 57 deletions(-) diff --git a/src/metrics/local_two_sample.py b/src/metrics/local_two_sample.py index e078670..0616755 100644 --- a/src/metrics/local_two_sample.py +++ b/src/metrics/local_two_sample.py @@ -81,7 +81,8 @@ def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cro # train classifiers over cv-folds probabilities = [] self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1])) - + self.prior_evaluation = np.zeros_like(p) + kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42) cv_splits = kf.split(p) for cross_trial, (train_index, val_index) in enumerate(cv_splits): @@ -94,6 +95,7 @@ def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cro self.evaluation_data[cross_trial][index] = self.data.simulator.simulate( p_validation, self.evaluation_context[val_index][index] ) + self.prior_evaluation[index] = p_validation probabilities.append(self._eval_model(p_evaluate, self.evaluation_data[cross_trial], trained_nth_classifier)) return probabilities diff --git a/src/plots/local_two_sample.py b/src/plots/local_two_sample.py index 0923763..da67e2c 100644 --- a/src/plots/local_two_sample.py +++ b/src/plots/local_two_sample.py @@ -1,5 +1,6 @@ from typing import Optional, Sequence, Union import matplotlib.pyplot as plt +from matplotlib import cm import numpy as np from matplotlib.colors import Normalize from matplotlib.patches import Rectangle @@ -79,20 +80,22 @@ def lc2st_pairplot(self, subplot, confidence_region_alpha=0.2): pairplot_values = self._make_pairplot_values(prob) subplot.plot(self.cdf_alphas, pairplot_values, label=label, color=color) - def probability_intensity(self, subplot, plot_dims, features, n_bins=20): + def probability_intensity(self, subplot, features, n_bins=20): evaluation_data = self.l2st.evaluation_data - + norm = Normalize(vmin=0, vmax=1) if len(evaluation_data.shape) >=3: # Used the kfold option evaluation_data = evaluation_data.reshape(( evaluation_data.shape[0]*evaluation_data.shape[1], evaluation_data.shape[-1])) self.probability = self.probability.ravel() - if plot_dims==1: + try: + # If there is only one feature + int(features) _, bins, patches = subplot.hist( evaluation_data[:,features], n_bins, weights=self.probability, density=True, color=self.param_colors[features]) - + eval_bins = np.select( [evaluation_data[:,features] <= i for i in bins[1:]], list(range(n_bins)) ) @@ -102,16 +105,18 @@ def probability_intensity(self, subplot, plot_dims, features, n_bins=20): colors = plt.get_cmap(self.colorway) for w, p in zip(weights, patches): - p.set_facecolor(colors(w)) # color is mean predicted proba - - else: + p.set_facecolor(colors(norm(w))) # color is mean predicted proba + - _, x_edges, y_edges, patches = subplot.hist2d( + except TypeError: + _, x_edges, y_edges, image = subplot.hist2d( evaluation_data[:,features[0]], evaluation_data[:,features[1]], n_bins, - density=True, color=self.param_colors[features[0]]) + density=True, color="white") + image.remove() + eval_bins_dim_1 = np.select( [evaluation_data[:,features[0]] <= i for i in x_edges[1:]], list(range(n_bins)) ) @@ -121,21 +126,23 @@ def probability_intensity(self, subplot, plot_dims, features, n_bins=20): colors = plt.get_cmap(self.colorway) - weights = np.empty((n_bins, n_bins)) + weights = np.empty((n_bins, n_bins)) * np.nan + print(weights) for i in range(n_bins): for j in range(n_bins): - try: + local_and = np.logical_and(eval_bins_dim_1==i, eval_bins_dim_2==j) + if local_and.any(): weights[i, j] = self.probability[np.logical_and(eval_bins_dim_1==i, eval_bins_dim_2==j)].mean() - except KeyError: - pass + for i in range(len(x_edges) - 1): for j in range(len(y_edges) - 1): weight = weights[i,j] - facecolor = colors(weight) + facecolor = colors(norm(weight)) # if no sample in bin, set color to white if weight == np.nan: facecolor = "white" + rect = Rectangle( (x_edges[i], y_edges[j]), x_edges[i + 1] - x_edges[i], @@ -144,42 +151,29 @@ def probability_intensity(self, subplot, plot_dims, features, n_bins=20): edgecolor="none", ) subplot.add_patch(rect) - + def _plot(self, use_intensity_plot:bool=True, n_alpha_samples:int=100, confidence_region_alpha:float=0.2, n_intensity_bins:int=20, - intensity_dimension:int=2, - intensity_feature_index:Union[int, Sequence[int]]=[0,1], linear_classifier:Union[str, list[str]]='MLP', cross_evaluate:bool=True, n_null_hypothesis_trials=100, classifier_kwargs:Union[dict, list[dict]]=None, - y_label="Empirical CDF", - x_label="", - title="Local Classifier 2-Sample Test" + pairplot_y_label="Empirical CDF", + pairplot_x_label="", + pairplot_title="Local Classifier PP-Plot", + intensity_plot_ylabel="", + intensity_plot_xlabel="", + intensity_plot_title="Local Classifier Intensity Distribution", ): - if use_intensity_plot: - if intensity_dimension not in (1, 2): - raise NotImplementedError("LC2ST Intensity Plot only implemented in 1D and 2D") - - if intensity_dimension == 1: - try: - int(intensity_feature_index) - except TypeError: - raise ValueError(f"Cannot use {intensity_feature_index} to plot, please supply an integer value index.") - - else: - try: - assert len(intensity_feature_index) == intensity_dimension - int(intensity_feature_index[0]) - int(intensity_feature_index[1]) - except (AssertionError, TypeError): - raise ValueError(f"Cannot use {intensity_feature_index} to plot, please supply a list of 2 integer value indices.") - + # Plots to make - + # pp_plot_lc2st: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L49 + # eval_space_with_proba_intensity: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 + self.l2st(**{ "linear_classifier":linear_classifier, "cross_evaluate": cross_evaluate, @@ -188,25 +182,76 @@ def _plot(self, self.probability, self.null_hypothesis_probability = self.l2st.output["lc2st_probabilities"], self.l2st.output["lc2st_null_hypothesis_probabilities"] - # Plots to make - - # pp_plot_lc2st: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L49 - # eval_space_with_proba_intensity: https://github.com/JuliaLinhart/lc2st/blob/e221cc326480cb0daadfd2ba50df4eefd374793b/lc2st/graphical_diagnostics.py#L133 - - n_plots = 1 if not use_intensity_plot else 2 - figure_size = self.figure_size if n_plots==1 else (int(self.figure_size[0]*1.8),self.figure_size[1]) - fig, subplots = plt.subplots(1, n_plots, figsize=figure_size) + fig, subplots = plt.subplots(1, 1, figsize=self.figure_size) self.cdf_alphas = np.linspace(0, 1, n_alpha_samples) - self.lc2st_pairplot(subplots[0] if n_plots == 2 else subplots, confidence_region_alpha=confidence_region_alpha) - if use_intensity_plot: - self.probability_intensity( - subplots[1], - intensity_dimension, - n_bins=n_intensity_bins, - features=intensity_feature_index - ) + self.lc2st_pairplot(subplots, confidence_region_alpha=confidence_region_alpha) fig.legend() - fig.supylabel(y_label) - fig.supxlabel(x_label) - fig.suptitle(title) \ No newline at end of file + fig.supylabel(pairplot_y_label) + fig.supxlabel(pairplot_x_label) + fig.suptitle(pairplot_title) + + self.plot_name = "local_c2st_pp_plot.png" + self._finish() + + if use_intensity_plot: + + fig, subplots = plt.subplots(len(self.param_names), len(self.param_names), figsize=(self.figure_size[0]*1.2, self.figure_size[1])) + combos_run = [] + for x_index, x_param in enumerate(self.param_names): + for y_index, y_param in enumerate(self.param_names): + + if ({x_index, y_index} not in combos_run) and (x_index>=y_index): + subplot = subplots[x_index][y_index] + + if x_index == y_index: + features = x_index + else: + features = [x_index, y_index] + + self.probability_intensity( + subplot, + features=features, + n_bins=n_intensity_bins + ) + combos_run.append({x_index, y_index}) + + if (x_index None: + try: + self._data_setup() + except NotImplementedError: + pass + try: + self._plot_settings() + except NotImplementedError: + pass + + self._plot(**plot_args)