Commit 9535a2d

updated nmi calculation to remove nmi inversion and added digamma correction for entropy calculation as an option
NDoering99 committed Jan 15, 2025
1 parent 21c16df commit 9535a2d
Showing 4 changed files with 161 additions and 95 deletions.
7 changes: 4 additions & 3 deletions mdpath/mdpath_tools.py
@@ -20,7 +20,6 @@
from mdpath.src.visualization import MDPathVisualize



def edit_3D_visualization_json():
"""Edit the 3D visualization JSONS to your visualization needs from the command line.
@@ -433,6 +432,8 @@ def gpcr_2D_vis():
"Topology and cluster pathways are required for creating a 2D visualization of GPCR paths."
)
exit(1)


def spline():
"""
Create a 3D Visualization of Paths through a protein using accurate spline representations.
@@ -445,7 +446,7 @@ def spline():
-json (str): Json file of the MDPath analysis -> "quick_precomputed_clusters_paths"
Example usage:
$ mdpath_spline -json <path_to_json>
$ mdpath_spline -json <path_to_json>
"""
parser = argparse.ArgumentParser(
prog="mdpath_spline",
@@ -458,7 +459,7 @@
help="quick_precomputed_clusters_paths file of your MDPath analysis",
required=True,
)

args = parser.parse_args()
json_file = args.json
MDPathVisualize.create_splines(json_file)
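The new spline() entry point is a thin wrapper around MDPathVisualize.create_splines, so the CLI call from the docstring reduces to a one-liner. A minimal sketch, assuming a hypothetical file name for the precomputed-clusters JSON:

```python
# Equivalent of: mdpath_spline -json quick_precomputed_clusters_paths.json
# (the file name is illustrative; pass your own MDPath analysis output)
from mdpath.src.visualization import MDPathVisualize

MDPathVisualize.create_splines("quick_precomputed_clusters_paths.json")
```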
147 changes: 100 additions & 47 deletions mdpath/src/mutual_information.py
@@ -17,6 +17,7 @@
from sklearn.metrics import mutual_info_score
from sklearn.mixture import GaussianMixture
from scipy.stats import entropy
from scipy.special import digamma


class NMICalculator:
@@ -29,30 +30,40 @@ class NMICalculator:
GMM (optional): Option to switch between histogram method and Gaussian Mixture Model for binning before NMI calculation. Default is False.
mi_diff_df (pd.DataFrame): DataFrame containing the mutual information differences. Is calculated using either GMM or histogram method.
nmi_df (pd.DataFrame): DataFrame containing the mutual information differences. Is calculated using either GMM or histogram method.
entropy_df (pd.DataFrame): Pandas dataframe with residue and entropy values. Is calculated using either GMM or histogram method.
"""

def __init__(
self, df_all_residues: pd.DataFrame, num_bins: int = 35, GMM=False
self,
df_all_residues: pd.DataFrame,
digamma_correction=False,
num_bins: int = 35,
GMM=False,
) -> None:
self.df_all_residues = df_all_residues
self.num_bins = num_bins
self.digamma_correction = digamma_correction
self.GMM = GMM
if GMM:
self.mi_diff_df, self.entropy_df = self.NMI_calcs_with_GMM()
self.nmi_df, self.entropy_df = self.NMI_calcs_with_GMM()
else:
self.mi_diff_df, self.entropy_df = self.NMI_calcs()

def NMI_calcs(self) -> pd.DataFrame:
"""Nornmalized Mutual Information and Entropy calculation for all residue pairs.
Returns:
mi_diff_df (pd.DataFrame): Pandas dataframe with residue pair and mutual information difference.
entropy_df (pd.DataFrame): Pandas dataframe with residue and entropy values.
"""
self.nmi_df, self.entropy_df = self.NMI_calcs()

def calculate_corrected_entropy(self, hist, total_points, num_bins):
"""Calculate corrected entropy for a histogram."""
probabilities = hist / total_points
non_zero_probs = probabilities[probabilities > 0]
base_entropy = -np.sum(non_zero_probs * np.log(non_zero_probs))
correction = (num_bins - 1) / (2 * total_points)
digamma_correction = 1 / total_points * np.sum(
hist * digamma(hist + 1)
) - digamma(total_points)
return base_entropy + correction + digamma_correction
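For reference, this estimator combines the plug-in entropy with a Miller-Madow-style bias term and a digamma term. With B bins, N samples, bin counts n_i, and empirical probabilities p̂_i = n_i / N, the method above computes

$$
\hat{H} \;=\; -\sum_{i} \hat{p}_i \ln \hat{p}_i \;+\; \frac{B-1}{2N} \;+\; \left( \frac{1}{N} \sum_{i} n_i\, \psi(n_i + 1) \;-\; \psi(N) \right)
$$

where ψ is the digamma function. Each correction term on its own targets the small-sample bias of histogram entropies; stacking both on the plug-in estimate is a design choice of this commit rather than a single standard estimator.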

def NMI_calcs(self):
"""Extended Normalized Mutual Information and Entropy calculation."""
entropys = {}
normalized_mutual_info = {}
total_iterations = len(self.df_all_residues.columns) ** 2
@@ -63,36 +74,78 @@ def NMI_calcs(self) -> pd.DataFrame:
for col1 in self.df_all_residues.columns:
for col2 in self.df_all_residues.columns:
if col1 != col2:
hist_col1, _ = np.histogram(
self.df_all_residues[col1], bins=self.num_bins
)
hist_col2, _ = np.histogram(
self.df_all_residues[col2], bins=self.num_bins
)
hist_joint, _, _ = np.histogram2d(
self.df_all_residues[col1],
self.df_all_residues[col2],
bins=self.num_bins,
)
mi = mutual_info_score(
hist_col1, hist_col2, contingency=hist_joint
)
entropy_col1 = entropy(hist_col1)
entropy_col2 = entropy(hist_col2)
entropys[col1] = entropy_col1
entropys[col2] = entropy_col2
nmi = mi / np.sqrt(entropy_col1 * entropy_col2)
normalized_mutual_info[(col1, col2)] = nmi
progress_bar.update(1)
if self.digamma_correction:
# Adaptive binning
data_col1 = self.df_all_residues[col1].values
data_col2 = self.df_all_residues[col2].values
hist_col1, bin_edges1 = np.histogram(
data_col1, bins=self.num_bins
)
hist_col2, bin_edges2 = np.histogram(
data_col2, bins=self.num_bins
)
hist_joint, _, _ = np.histogram2d(
data_col1,
data_col2,
bins=(self.num_bins, self.num_bins),
)

# Total data points
total_points = len(data_col1)

# Corrected entropy estimates
entropy_col1 = self.calculate_corrected_entropy(
hist_col1, total_points, self.num_bins
)
entropy_col2 = self.calculate_corrected_entropy(
hist_col2, total_points, self.num_bins
)
joint_entropy = self.calculate_corrected_entropy(
hist_joint.flatten(), total_points, self.num_bins**2
)

# Mutual Information
mi = entropy_col1 + entropy_col2 - joint_entropy
entropys[col1] = entropy_col1
entropys[col2] = entropy_col2

# Normalized MI
nmi = mi / np.sqrt(entropy_col1 * entropy_col2)
normalized_mutual_info[(col1, col2)] = nmi

progress_bar.update(1)
else:
hist_col1, _ = np.histogram(
self.df_all_residues[col1], bins=self.num_bins
)
hist_col2, _ = np.histogram(
self.df_all_residues[col2], bins=self.num_bins
)
hist_joint, _, _ = np.histogram2d(
self.df_all_residues[col1],
self.df_all_residues[col2],
bins=self.num_bins,
)
mi = mutual_info_score(
hist_col1, hist_col2, contingency=hist_joint
)
entropy_col1 = entropy(hist_col1)
entropy_col2 = entropy(hist_col2)
entropys[col1] = entropy_col1
entropys[col2] = entropy_col2
nmi = mi / np.sqrt(entropy_col1 * entropy_col2)
normalized_mutual_info[(col1, col2)] = nmi
progress_bar.update(1)

entropy_df = pd.DataFrame(entropys.items(), columns=["Residue", "Entropy"])
mi_diff_df = pd.DataFrame(
nmi_df = pd.DataFrame(
normalized_mutual_info.items(), columns=["Residue Pair", "MI Difference"]
)
max_mi_diff = mi_diff_df["MI Difference"].max()
mi_diff_df["MI Difference"] = (
max_mi_diff - mi_diff_df["MI Difference"]
        ) # Calculate the weights
return mi_diff_df, entropy_df
# max_mi_diff = mi_diff_df["MI Difference"].max()
# mi_diff_df["MI Difference"] = (
# max_mi_diff - mi_diff_df["MI Difference"]
        # ) # Calculate the weights
return nmi_df, entropy_df
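In the digamma branch the mutual information is recovered from the three corrected entropies, and both branches normalize by the geometric mean of the marginal entropies:

$$
I(X;Y) \;=\; \hat{H}(X) + \hat{H}(Y) - \hat{H}(X,Y), \qquad \mathrm{NMI}(X,Y) \;=\; \frac{I(X;Y)}{\sqrt{\hat{H}(X)\,\hat{H}(Y)}}
$$

Note that with the inversion step commented out, nmi_df now carries the raw NMI values, even though its column is still labeled "MI Difference".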

def select_n_components(data: pd.DataFrame, max_components: int = 10) -> int:
"""Select the optimal number of GMM components using BIC
@@ -123,7 +176,7 @@ def NMI_calcs_with_GMM(self) -> pd.DataFrame:
"""Nornmalized Mutual Information and Entropy calculation for all residue pairs using Gaussian Mixture Models (GMM) for binning.
Returns:
mi_diff_df (pd.DataFrame): Pandas dataframe with residue pair and mutual information difference.
nmi_df (pd.DataFrame): Pandas dataframe with residue pair and mutual information difference.
entropy_df (pd.DataFrame): Pandas dataframe with residue and entropy values.
"""
@@ -168,13 +221,13 @@ def NMI_calcs_with_GMM(self) -> pd.DataFrame:
progress_bar.update(1)

entropy_df = pd.DataFrame(entropys.items(), columns=["Residue", "Entropy"])
mi_diff_df = pd.DataFrame(
nmi_df = pd.DataFrame(
normalized_mutual_info.items(), columns=["Residue Pair", "MI Difference"]
)

max_mi_diff = mi_diff_df["MI Difference"].max()
mi_diff_df["MI Difference"] = (
max_mi_diff - mi_diff_df["MI Difference"]
) # Calculate the weights
# max_mi_diff = mi_diff_df["MI Difference"].max()
# mi_diff_df["MI Difference"] = (
# max_mi_diff - mi_diff_df["MI Difference"]
# ) # Calculate the weights

return mi_diff_df, entropy_df
return nmi_df, entropy_df
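A minimal usage sketch of the class as changed here, on hypothetical toy data; it assumes only the constructor signature and the nmi_df/entropy_df attributes shown in this diff:

```python
import numpy as np
import pandas as pd

from mdpath.src.mutual_information import NMICalculator

# Toy stand-in for the dihedral-movement table: 100 frames x 3 residues.
rng = np.random.default_rng(0)
df_all_residues = pd.DataFrame(
    rng.normal(size=(100, 3)), columns=["Res 1", "Res 2", "Res 3"]
)

# Histogram binning with the new digamma-corrected entropies (GMM stays off).
calc = NMICalculator(df_all_residues, digamma_correction=True, num_bins=35)
print(calc.nmi_df.head())      # per-pair NMI (column still named "MI Difference")
print(calc.entropy_df.head())  # per-residue corrected entropy estimates
```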
25 changes: 16 additions & 9 deletions mdpath/src/structure.py
@@ -21,6 +21,7 @@
from itertools import combinations
import logging


class StructureCalculations:
"""Calculate residue surroundings and distances between residues in a PDB structure.
@@ -173,13 +174,15 @@ def calc_dihedral_angle_movement(self, res_id: int) -> tuple:
dihedral_angle_movement = np.diff(dihedrals, axis=0)
return res_id, dihedral_angle_movement
except (TypeError, AttributeError, IndexError) as e:
logging.debug(f"Failed to calculate dihedral for residue {res_id}: {str(e)}")
logging.debug(
f"Failed to calculate dihedral for residue {res_id}: {str(e)}"
)
return None

def calculate_dihedral_movement_parallel(
self,
num_parallel_processes: int,
) -> pd.DataFrame:
) -> pd.DataFrame:
"""Parallel calculation of dihedral angle movement for all residues in the trajectory.
Args:
@@ -189,7 +192,7 @@ def calculate_dihedral_movement_parallel(
pd.DataFrame: DataFrame with all residue dihedral angle movements.
"""
df_all_residues = pd.DataFrame()

try:
with Pool(processes=num_parallel_processes) as pool:
with tqdm(
@@ -201,24 +204,28 @@
self.calc_dihedral_angle_movement,
range(self.first_res_num, self.last_res_num + 1),
)

for result in results:
if result is None:
pbar.update(1)
continue

res_id, dihedral_data = result
try:
df_residue = pd.DataFrame(dihedral_data, columns=[f"Res {res_id}"])
df_residue = pd.DataFrame(
dihedral_data, columns=[f"Res {res_id}"]
)
df_all_residues = pd.concat(
[df_all_residues, df_residue], axis=1
)
except Exception as e:
logging.error(f"\033[1mError processing residue {res_id}: {e}\033[0m")
logging.error(
f"\033[1mError processing residue {res_id}: {e}\033[0m"
)
finally:
pbar.update(1)

except Exception as e:
logging.error(f"Parallel processing failed: {str(e)}")

return df_all_residues
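The restructured loop above follows a standard worker-pool pattern: map a per-residue function over residue IDs, skip workers that return None, and advance the progress bar either way. A self-contained sketch of that pattern with stand-in data (not the project's exact code; the pool call itself is elided in this hunk):

```python
from multiprocessing import Pool

import numpy as np
import pandas as pd
from tqdm import tqdm


def per_residue(res_id: int):
    """Stand-in worker: returns None for residues that cannot be processed."""
    if res_id % 7 == 0:  # pretend these residues lack dihedrals
        return None
    rng = np.random.default_rng(res_id)
    return res_id, rng.normal(size=10)


if __name__ == "__main__":
    res_ids = range(1, 31)
    df_all = pd.DataFrame()
    with Pool(processes=4) as pool, tqdm(total=len(res_ids)) as pbar:
        for result in pool.imap_unordered(per_residue, res_ids):
            if result is None:
                pbar.update(1)
                continue
            res_id, data = result
            df_all = pd.concat(
                [df_all, pd.DataFrame(data, columns=[f"Res {res_id}"])], axis=1
            )
            pbar.update(1)
    print(df_all.shape)  # (10, 26): 30 residues minus the 4 skipped multiples of 7
```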
