Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid importing tf_routes #118

Merged
merged 10 commits into from
Dec 20, 2023
1 change: 0 additions & 1 deletion devtools/conda-envs/fep_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,5 @@ dependencies:
- pytest-cov
- codecov
- pip:
- git+https://github.com/wiederm/tf_routes.git
- git+https://github.com/ParmEd/ParmEd.git
- git+https://github.com/wiederm/transformato_testsystems.git
253 changes: 220 additions & 33 deletions transformato/mutate.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,49 +526,236 @@ def _match_terminal_dummy_atoms_between_common_cores(

return (lj_default_cc1, lj_default_cc2)

@staticmethod
def change_route_cycles(route, cycledict, degreedict, weightdict, G):
    """
    Reorder a preliminary mutation route using cycle and degree information.

    Used by _calculate_order_of_LJ_mutations_with_bfs after the Dijkstra
    distances have been computed.
    ----
    Args:
        route: original mutation route (list of nodes); it is modified in
            place. Entries are expected to be unique (the caller builds the
            route from dictionary keys).
        cycledict: dict mapping node -> number of cycles the node is part of
        degreedict: dict mapping node -> degree of the node
        weightdict: dict mapping node -> weight (i.e. distance from root)
        G: nx-graph of molecule; only ``G.edges(node)`` is used
    ----
    returns the reordered route (the same list object that was passed in)
    """
    # NOTE(review): a reviewer suggested moving these graph helper functions
    # into a dedicated module (e.g. ``rdkit_functions`` or similar).

    def count_already_removed_neighbours(node, removed):
        # number of edges of ``node`` that touch a node placed earlier in the route
        return sum(
            1 for edge in G.edges(node) if edge[1] in removed or edge[0] in removed
        )

    for i in range(len(route) - 1):
        # Values of the node sitting at position i when this outer iteration
        # starts. They are intentionally NOT refreshed after a swap, which
        # mirrors the established ordering behaviour.
        anchor = route[i]
        anchor_weight = weightdict.get(anchor)
        anchor_cycles = cycledict.get(anchor)
        anchor_degree = degreedict.get(anchor)

        # Positions before i are never touched by the swaps below, so the set
        # of nodes removed earlier is constant for the whole inner loop.
        removed_earlier = set(route[:i])

        for j in range(i, len(route)):
            candidate = route[j]
            if anchor_weight != weightdict[candidate]:
                continue

            candidate_cycles = cycledict[candidate]

            # if nodes have same weight (i.e. distance from root), the node
            # participating in more cycles (or, on a tie, with the higher
            # degree) is removed later
            if anchor_cycles > candidate_cycles or (
                anchor_cycles == candidate_cycles
                and anchor_degree > degreedict[candidate]
            ):
                route[i], route[j] = route[j], route[i]
                continue

            # if nodes have same weight and same cycle participation number,
            # the node which has more neighbours already removed is removed
            # earlier
            if anchor_cycles == candidate_cycles:
                anchor_count = count_already_removed_neighbours(
                    anchor, removed_earlier
                )
                candidate_count = count_already_removed_neighbours(
                    candidate, removed_earlier
                )
                if anchor_count < candidate_count:
                    route[i], route[j] = route[j], route[i]

    return route

@staticmethod
def cycle_checks_nx(G, use_actual_weight_for_mod=False):
    """
    Lower the edge weights of atoms that take part in cycles.

    Used by _calculate_order_of_LJ_mutations_with_bfs when cycle processing
    is requested without the subsequent route reordering.
    ----
    Args:
        G: nx-graph of molecule; every edge touching a cycle atom must carry
            a "weight" attribute. The graph is modified in place.
        use_actual_weight_for_mod: if truthy, the reduction is scaled with the
            square root of the current edge weight, otherwise a fixed value
            of 5 per cycle is subtracted
    ----
    returns nx-graph-object with updated weights (according to cycle
    participation of the atom); this is the same object as ``G``
    """
    # Declared static: the function takes no ``self`` but is invoked through
    # the instance (``self.cycle_checks_nx(...)``).
    from collections import Counter

    # search cycles using networkx
    cycles = nx.cycle_basis(G)

    # number of basis cycles each atom belongs to
    cycle_counts = Counter(node for cycle in cycles for node in set(cycle))

    # modify weighted graph: nodes participating in many cycles get lower
    # weight. An edge whose two ends are both cycle atoms is visited once per
    # end and is therefore reduced twice.
    for node, n_cycles in cycle_counts.items():
        for u, v in G.edges(node):
            current = G[u][v]["weight"]
            if use_actual_weight_for_mod:
                G[u][v]["weight"] = current - n_cycles * (current ** (1 / 2))
            else:
                G[u][v]["weight"] = current - n_cycles * 5

    return G

@staticmethod
def cycle_checks(G):
    """
    Build the cycle-participation and degree dictionaries of a graph.

    The dictionaries drive the preferential removal in change_route_cycles
    (atoms whose neighbours have already been removed are removed earlier).
    ----
    Args:
        G: nx-graph of molecule
    ----
    returns dictionary containing number of cycle participation (cdict) and
    dict containing degree of atom (degreedict)
    """
    # Declared static: the function takes no ``self`` but is invoked through
    # the instance (``self.cycle_checks(...)``).
    from collections import Counter

    # search cycles using networkx
    # (the original author noted RDKit's ring info, mol.GetRingInfo().AtomRings(),
    # as an alternative source of cycles)
    cycles = nx.cycle_basis(G)

    cdict = Counter(node for cycle in cycles for node in set(cycle))

    # add atoms with no cycle participation explicitly, so that plain
    # ``.get(node)`` lookups return 0 instead of None
    for node in G.nodes:
        cdict.setdefault(node, 0)

    degreedict = dict(G.degree())

    return cdict, degreedict

@staticmethod
def exclude_Hs_from_mutations(connected_dummy_regions: list, G: nx.Graph):
    """
    Remove hydrogens from the graph and from the connected dummy regions.

    Both arguments are modified IN PLACE: the hydrogen nodes are deleted from
    ``G`` and the hydrogen indices are removed from every region that
    contains them. No copies are made.
    ----
    Args:
        connected_dummy_regions: list of connected dummy regions (each region
            must support ``in`` and ``.remove``)
        G: nx-graph of molecule; every node needs an "atom_type" attribute
    ----
    returns list of connected dummy regions and networkx-graph without
    hydrogens (the same objects that were passed in)
    """
    # Declared static: the function takes no ``self`` but is invoked through
    # the instance (``self.exclude_Hs_from_mutations(...)``).

    # collect the hydrogens first; the node view must not be consumed while
    # nodes are being deleted
    hydrogens = [
        node for node, attributes in G.nodes(data=True) if attributes["atom_type"] == "H"
    ]

    G.remove_nodes_from(hydrogens)

    for region in connected_dummy_regions:
        for hydrogen in hydrogens:
            if hydrogen in region:
                region.remove(hydrogen)

    return connected_dummy_regions, G

def _calculate_order_of_LJ_mutations_with_bfs(
    self,
    connected_dummy_regions: list,
    match_terminal_atoms: dict,
    G: nx.Graph,
    cyclecheck=True,
    ordercycles=True,
    exclude_Hs=True,
) -> list:
    """
    bfs/Dijkstra algorithm applied once per route (without iterations)
    -----
    Args:
        connected_dummy_regions: list of connected dummy regions
        match_terminal_atoms: dict mapping a real atom to the dummy atoms
            attached to it
        G: nx-graph of molecule
        cyclecheck: updates weights according to cycle participation (should
            always be set to True)
        ordercycles: if there is no possibility to decide between two nodes -
            i.e. the weight would be exactly the same - weight updating
            according to preferential removal decides that the node in whose
            neighbourhood nodes already have been removed is removed next
        exclude_Hs: if True, hydrogens are removed before the mutation
            algorithm is applied - necessary for usual Transformato workflow
    -----
    returns a list with one mutation route (list of nodes, most distant atom
    first) per matching (dummy atom, connected dummy region) pair
    """
    # This implementation is self-contained; the optional ``tf_routes``
    # package is no longer imported.

    if exclude_Hs:
        # NOTE: this strips the hydrogens from the passed-in graph and
        # regions in place.
        connected_dummy_regions, G = self.exclude_Hs_from_mutations(
            connected_dummy_regions, G
        )

    reweight_by_cycles = bool(cyclecheck) and not ordercycles
    reorder_by_cycles = bool(cyclecheck) and bool(ordercycles)

    ordered_LJ_mutations = []

    for dummy_atoms in match_terminal_atoms.values():
        for dummy_atom in dummy_atoms:
            for connected_dummy_region in connected_dummy_regions:
                # stop at connected dummy region with specific dummy_atom in it
                if dummy_atom not in connected_dummy_region:
                    continue

                # restrict a copy of the graph to the nodes of the dummy region
                G_dummy = G.copy()
                G_dummy.remove_nodes_from(
                    [
                        node
                        for node in G.nodes()
                        if node not in connected_dummy_region
                    ]
                )

                # root is the dummy atom that connects the real region with
                # the dummy region
                root = dummy_atom

                # process cycles
                if reweight_by_cycles:
                    G_dummy = self.cycle_checks_nx(G_dummy)

                # process cycles and correct order (according to
                # 'preferential removal')
                if reorder_by_cycles:
                    cycledict, degreedict = self.cycle_checks(G_dummy)

                # dijkstra: first element of the result holds the distances
                distances = nx.single_source_dijkstra(
                    G_dummy, source=root, weight="weight"
                )[0]

                # nodes sorted by increasing distance from the root
                sortedssource = dict(
                    sorted(distances.items(), key=lambda item: item[1])
                )

                # sorted list contains the mutation route; the order has to be
                # reversed - the most distant atom is the first to be removed
                nodes = list(sortedssource)
                nodes.reverse()

                # sort nodes according to degree, cycle participation and
                # removal order
                if reorder_by_cycles:
                    nodes = self.change_route_cycles(
                        nodes, cycledict, degreedict, sortedssource, G
                    )

                logger.info("Final mutation route:")
                logger.info(nodes)
                ordered_LJ_mutations.append(nodes)

    return ordered_LJ_mutations

def _calculate_order_of_LJ_mutations(
    self,
    connected_dummy_regions: list,
    match_terminal_atoms: dict,
    G: nx.Graph,
) -> list:
    """
    Calculate the order in which the LJ terms of the dummy atoms are mutated.

    Thin wrapper that forwards to _calculate_order_of_LJ_mutations_with_bfs
    with its default options.
    ----
    Args:
        connected_dummy_regions: list of connected dummy regions
        match_terminal_atoms: dict mapping a real atom to the dummy atoms
            attached to it
        G: nx-graph of molecule
    ----
    returns the list of mutation routes produced by
    _calculate_order_of_LJ_mutations_with_bfs
    """
    # Regular instance method on purpose: it receives ``self`` and delegates
    # to another instance method, which does not work together with a
    # ``@staticmethod`` decorator when called as ``self.<name>(...)``.
    return self._calculate_order_of_LJ_mutations_with_bfs(
        connected_dummy_regions, match_terminal_atoms, G
    )

def _check_for_lp(
self,
Expand Down
Loading