Skip to content

Commit

Permalink
fixed failing tests after the new nmi calculation algo was implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
NDoering99 committed Jan 20, 2025
1 parent d99d45c commit 2b22ab1
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 12 deletions.
2 changes: 1 addition & 1 deletion mdpath/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_pathways_cluster(setup_clustering):
clustering = setup_clustering

with patch("matplotlib.pyplot.savefig") as mock_savefig:
clusters = clustering.pathways_cluster(n_top_clust=0)
clusters = clustering.pathways_cluster(n_top_clust=3)

assert isinstance(clusters, dict)
assert all(isinstance(k, int) and isinstance(v, list) for k, v in clusters.items())
Expand Down
4 changes: 2 additions & 2 deletions mdpath/tests/test_mdpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,14 @@ def test_mdpath_output_files():

topology = os.path.join(script_dir, "test_topology.pdb")
trajectory = os.path.join(script_dir, "test_trajectory.dcd")
numpath = "25"
numpath = "10"
bootstrap = "1"
assert os.path.exists(topology), f"Topology file {topology} does not exist."
assert os.path.exists(trajectory), f"Trajectory file {trajectory} does not exist."

expected_files = [
os.path.join(script_dir, "first_frame.pdb"),
os.path.join(script_dir, "nmi_df.csv.csv"),
os.path.join(script_dir, "nmi_df.csv"),
os.path.join(script_dir, "output.txt"),
os.path.join(script_dir, "residue_coordinates.pkl"),
os.path.join(script_dir, "cluster_pathways_dict.pkl"),
Expand Down
2 changes: 1 addition & 1 deletion mdpath/tests/test_mdpath_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def test_multitraj_analysis():
topology,
"-multitraj",
multitraj_1,
multitraj_1,
multitraj_1
]
with pytest.raises(SystemExit) as exc_info:
multitraj_analysis()
Expand Down
59 changes: 51 additions & 8 deletions mdpath/tests/test_mutual_information.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import sys
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import mutual_info_score
from scipy.stats import entropy
import pytest
import tempfile
from unittest.mock import MagicMock, Mock, patch, call


from mdpath.src.mutual_information import NMICalculator


def test_nmi_calculator(mocker):
def test_nmi_calculator_invert():
np.random.seed(0) # Set a fixed seed for reproducibility
data = {
"Residue1": np.random.uniform(-180, 180, size=100),
Expand All @@ -20,7 +15,7 @@ def test_nmi_calculator(mocker):
}
df_all_residues = pd.DataFrame(data)

calculator = NMICalculator(df_all_residues)
calculator = NMICalculator(df_all_residues, invert=True)
result = calculator.nmi_df

assert isinstance(result, pd.DataFrame), "Result should be a DataFrame"
Expand Down Expand Up @@ -57,3 +52,51 @@ def test_nmi_calculator(mocker):
assert np.isclose(
actual, expected, atol=1e-5
), f"For pair {pair}, expected {expected} but got {actual}"


def test_nmi_calculator():
np.random.seed(0) # Set a fixed seed for reproducibility
data = {
"Residue1": np.random.uniform(-180, 180, size=100),
"Residue2": np.random.uniform(-180, 180, size=100),
"Residue3": np.random.uniform(-180, 180, size=100),
}
df_all_residues = pd.DataFrame(data)

calculator = NMICalculator(df_all_residues)
result = calculator.nmi_df

assert isinstance(result, pd.DataFrame), "Result should be a DataFrame"

assert (
"Residue Pair" in result.columns
), "DataFrame should contain 'Residue Pair' column"
assert (
"MI Difference" in result.columns
), "DataFrame should contain 'MI Difference' column"

assert not result.empty, "DataFrame should not be empty"

expected_shape = (
len(df_all_residues.columns) * (len(df_all_residues.columns) - 1),
2,
)
assert (
result.shape == expected_shape
), f"Expected shape {expected_shape}, but got {result.shape}"

assert (
result["MI Difference"] >= 0
).all(), "MI Difference values should be non-negative"

# Add your expected values here
expected_values = {
("Residue1", "Residue2"): 0.6729976399990001,
("Residue1", "Residue3"): 0.6667438140756161,
("Residue2", "Residue3"): 0.6667102733802279,
}
for pair, expected in expected_values.items():
actual = result.loc[result["Residue Pair"] == pair, "MI Difference"].values[0]
assert np.isclose(
actual, expected, atol=1e-5
), f"For pair {pair}, expected {expected} but got {actual}"

0 comments on commit 2b22ab1

Please sign in to comment.