Skip to content

Commit

Permalink
completed visualization + minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MarvinTaterra committed Jul 26, 2024
1 parent 52bf92d commit 463069d
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 21 deletions.
2 changes: 1 addition & 1 deletion mdpath/mdpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from mdpath.src.visualization import (
residue_CA_coordinates,
apply_backtracking,
cluster_prep_for_visualisaton,
cluster_prep_for_visualisation,
format_dict,
visualise_graph,
precompute_path_properties,
Expand Down
37 changes: 17 additions & 20 deletions mdpath/src/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def residue_CA_coordinates(pdb_file: str, end: int) -> dict:
return residue_coordinates_dict


def cluster_prep_for_visualisaton(
def cluster_prep_for_visualisation(
cluster: list[list[int]], pdb_file: str
) -> list[list[tuple[float]]]:
"""Prepares patway clusters for visualisation.
"""Prepares pathway clusters for visualisation.
Args:
cluster (list[list[int]]): Cluster of pathways.
Expand All @@ -57,23 +57,24 @@ def cluster_prep_for_visualisaton(
Returns:
cluster (list[list[tuple[float]]]): Cluster of pathways with CA atom coordinates.
"""
cluster = []
new_cluster = []
parser = PDB.PDBParser(QUIET=True)
structure = parser.get_structure("pdb_structure", pdb_file)

for pathway in cluster:
pathways = []
for residue in pathway:
parser = PDB.PDBParser(QUIET=True)
structure = parser.get_structure("pdb_structure", pdb_file)
res_id = ("", residue, "")
try:
res = structure[0][res_id]
atom = res["CA"]
coord = atom.get_coord()
coord = tuple(atom.get_coord())
pathways.append(coord)
except KeyError:
print(res + " not found.")
cluster.append(pathways)
return cluster

print(f"Residue {res_id} not found.")
new_cluster.append(pathways)
return new_cluster

def apply_backtracking(original_dict: dict, translation_dict: dict) -> dict:
"""Backtracks the original dictionary with a translation dictionary.
Expand Down Expand Up @@ -104,28 +105,24 @@ def format_dict(updated_dict: dict) -> dict:
Returns:
transformed_dict (dict): Reformatted dictionary.
"""

def transform_list(nested_list):
transformed = []
for item in nested_list:
if isinstance(item, np.ndarray):
transformed.append(item.tolist())
elif isinstance(item, list):
transformed.append(transform_list(item)) # Recursively transform lists
transformed.append(transform_list(item)) # Append instead of extend
else:
transformed.append(item)
return transformed

transformed_dict = {}
for key, value in updated_dict.items():
if isinstance(value, np.ndarray):
transformed_dict[key] = value.tolist()
elif isinstance(value, list):
transformed_dict[key] = transform_list(value)
else:
transformed_dict[key] = value

transformed_dict = {
key: transform_list(value) for key, value in updated_dict.items()
}
return transformed_dict


def visualise_graph(graph: nx.Graph, k=0.1, node_size=200) -> None:
"""Draws residue graph to PNG file.
Expand Down
146 changes: 146 additions & 0 deletions mdpath/tests/test_mdpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from MDAnalysis.analysis.dihedrals import Dihedral
from tqdm import tqdm
from Bio import PDB
import os


import mdpath.src.visualization

Expand Down Expand Up @@ -671,3 +673,147 @@ def test_apply_backtracking():
result = mdpath.src.visualization.apply_backtracking(original_dict, translation_dict)
assert result == expected

def test_cluster_prep_for_visualisation():
pdb_file = "mock_pdb_file.pdb"
input_cluster = [[1, 2], [3]]
mock_coordinates = {
1: (1.0, 1.0, 1.0),
2: (2.0, 2.0, 2.0),
3: (3.0, 3.0, 3.0)
}

with patch('Bio.PDB.PDBParser') as mock_parser:
mock_structure = MagicMock()

def get_structure(name, file):
return mock_structure

mock_parser.return_value.get_structure.side_effect = get_structure

mock_residues = {}
for residue_id, coord in mock_coordinates.items():
mock_residue = MagicMock()
mock_atom = MagicMock()
mock_atom.get_coord.return_value = coord
mock_residue.__getitem__.return_value = mock_atom
mock_residues[("", residue_id, "")] = mock_residue

def getitem(res_id):
if res_id in mock_residues:
return mock_residues[res_id]
else:
raise KeyError

mock_structure[0].__getitem__.side_effect = getitem

result = mdpath.src.visualization.cluster_prep_for_visualisation(input_cluster, pdb_file)

expected_result = [
[(1.0, 1.0, 1.0), (2.0, 2.0, 2.0)],
[(3.0, 3.0, 3.0)]
]

assert result == expected_result

def test_format_dict():
input_dict = {
'array': np.array([1, 2, 3]),
'nested_list': [1, 2, np.array([3, 4])]
}
expected_output = {
'array': [1, 2, 3],
'nested_list': [1, 2, [3, 4]]
}
assert mdpath.src.visualization.format_dict(input_dict) == expected_output
assert mdpath.src.visualization.format_dict({}) == {}
input_dict = {
'nested': [1, [2, 3], np.array([4, 5])]
}
expected_output = {
'nested': [1, [2, 3], [4, 5]]
}
assert mdpath.src.visualization.format_dict(input_dict) == expected_output
input_dict = {
'mixed': [1, 'string', np.array([6, 7])]
}
expected_output = {
'mixed': [1, 'string', [6, 7]]
}
assert mdpath.src.visualization.format_dict(input_dict) == expected_output


def test_visualise_graph():
G = nx.Graph()
G.add_edges_from([(1, 2), (2, 3), (3, 1)])

mdpath.src.visualization.visualise_graph(G)

assert os.path.exists("graph.png"), "graph.png file was not created."

if os.path.exists("graph.png"):
os.remove("graph.png")
def test_precompute_path_properties():
json_data = {
"cluster1": [
[[[1, 2, 3]], [[4, 5, 6]]],
[[[7, 8, 9]], [[10, 11, 12]]],
],
"cluster2": [
[[[13, 14, 15]], [[16, 17, 18]]],
]
}

expected_output = [
{
"clusterid": "cluster1",
"pathway_index": 0,
"path_segment_index": 0,
"coord1": [1, 2, 3],
"coord2": [4, 5, 6],
"color": [1, 0, 0],
"radius": 0.015,
"path_number": 1,
},
{
"clusterid": "cluster1",
"pathway_index": 1,
"path_segment_index": 0,
"coord1": [7, 8, 9],
"coord2": [10, 11, 12],
"color": [1, 0, 0],
"radius": 0.015,
"path_number": 2,
},
{
"clusterid": "cluster2",
"pathway_index": 0,
"path_segment_index": 0,
"coord1": [13, 14, 15],
"coord2": [16, 17, 18],
"color": [0, 1, 0],
"radius": 0.015,
"path_number": 1,
},
]

result = mdpath.src.visualization.precompute_path_properties(json_data)
assert result == expected_output

def test_precompute_cluster_properties_quick():
json_data = {
'cluster1': [
[[[1, 2, 3]], [[4, 5, 6]]],
[[[1, 2, 3]], [[4, 5, 6]]]
],
'cluster2': [
[[[7, 8, 9]], [[10, 11, 12]]]
]
}

expected_output = [
{'clusterid': 'cluster1', 'coord1': [1, 2, 3], 'coord2': [4, 5, 6], 'color': [1, 0, 0], 'radius': 0.015},
{'clusterid': 'cluster1', 'coord1': [1, 2, 3], 'coord2': [4, 5, 6], 'color': [1, 0, 0], 'radius': 0.03},
{'clusterid': 'cluster2', 'coord1': [7, 8, 9], 'coord2': [10, 11, 12], 'color': [0, 1, 0], 'radius': 0.015}
]
actual_output = mdpath.src.visualization.precompute_cluster_properties_quick(json_data)
assert actual_output == expected_output, f"Expected {expected_output}, but got {actual_output}"

0 comments on commit 463069d

Please sign in to comment.