Skip to content

Commit

Permalink
Merge pull request #121 from choderalab/retro-tab
Browse files Browse the repository at this point in the history
Add "Retrospective" tab to dashboard
  • Loading branch information
dotsdl authored Jun 15, 2021
2 parents 7933e26 + 8c0d75f commit 81e3639
Show file tree
Hide file tree
Showing 11 changed files with 442 additions and 63 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies:
- openmmtools
- pandas >= 1.1
- perses
- plotly
- pydantic
- pymbar
- python >= 3.6
Expand Down
167 changes: 158 additions & 9 deletions fah_xchem/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,44 @@
import multiprocessing
import os
from typing import List, Optional
import networkx as nx
import numpy as np

from ..fah_utils import list_results
from ..schema import (
AnalysisConfig,
CompoundSeries,
CompoundSeriesAnalysis,
CompoundMicrostate,
FahConfig,
GenAnalysis,
PhaseAnalysis,
PointEstimate,
ProjectPair,
Transformation,
TransformationAnalysis,
WorkPair,
FragalysisConfig,
RunStatus
)
from .diffnet import combine_free_energies
from .constants import KT_KCALMOL
from .diffnet import combine_free_energies, pIC50_to_DG
from .exceptions import AnalysisError, DataValidationError
from .extract_work import extract_work_pair
from .free_energy import compute_relative_free_energy
from .free_energy import compute_relative_free_energy, InsufficientDataError
from .plots import generate_plots
from .report import generate_report, gens_are_consistent
from .structures import generate_representative_snapshots
from .website import generate_website


# Standard error attached to each experimental free energy estimate built in
# calc_exp_ddg (passed as `stderr` for both the initial and the final compound).
# NOTE(review): units presumably match the output of pIC50_to_DG -- confirm.
EXP_DDG_IJ_ERR = 0.2 # TODO check this is correct


def analyze_phase(server: FahConfig, run: int, project: int, config: AnalysisConfig):

paths = list_results(config=server, run=run, project=project)

if not paths:
raise AnalysisError(f"No data found for project {project}, RUN {run}")

Expand Down Expand Up @@ -62,14 +71,24 @@ def get_gen_analysis(gen: int, works: List[WorkPair]) -> GenAnalysis:
# TODO: round raw work output?
return GenAnalysis(gen=gen, works=filtered_works, free_energy=free_energy)

# Analyze gens, omitting incomplete gens
gens = list()
for gen, works in works_by_gen.items():
try:
gens.append( get_gen_analysis(gen, works) )
except InsufficientDataError as e:
# It's OK if we don't have sufficient data here
pass

return PhaseAnalysis(
free_energy=free_energy,
gens=[get_gen_analysis(gen, works) for gen, works in works_by_gen.items()],
gens=gens,
)


def analyze_transformation(
transformation: Transformation,
compounds: CompoundSeries,
projects: ProjectPair,
server: FahConfig,
config: AnalysisConfig,
Expand All @@ -86,11 +105,37 @@ def analyze_transformation(
complex_phase.free_energy.delta_f - solvent_phase.free_energy.delta_f
)

# get associated DDGs between compounds, if experimentally known
exp_ddg = calc_exp_ddg(transformation=transformation, compounds=compounds)
absolute_error = (
abs(binding_free_energy - exp_ddg) if (exp_ddg.point is not None) else None
)

# Check for consistency across GENS, if requested
consistent_bool = None
if filter_gen_consistency:
consistent_bool = gens_are_consistent(
complex_phase=complex_phase, solvent_phase=solvent_phase, nsigma=3
complex_phase=complex_phase, solvent_phase=solvent_phase, nsigma=1
)

return TransformationAnalysis(
transformation=transformation,
reliable_transformation=consistent_bool,
binding_free_energy=binding_free_energy,
complex_phase=complex_phase,
solvent_phase=solvent_phase,
exp_ddg=exp_ddg,
absolute_error=absolute_error,
)

else:

return TransformationAnalysis(
transformation=transformation,
binding_free_energy=binding_free_energy,
complex_phase=complex_phase,
solvent_phase=solvent_phase,
exp_ddg=exp_ddg,
)

return TransformationAnalysis(
Expand All @@ -101,45 +146,142 @@ def analyze_transformation(
solvent_phase=solvent_phase,
)


def calc_exp_ddg(transformation: Transformation, compounds: CompoundSeries) -> PointEstimate:
    """
    Compute experimental free energy difference between two compounds, if available.

    NOTE: This method makes the approximation that each microstate has the same
    affinity as the parent compound.

    TODO: Instead, solve DiffNet without experimental data and use derived DDGs
    between compounds (not transformations).

    Parameters
    ----------
    transformation : Transformation
        The transformation of interest. Its ``initial_microstate`` and
        ``final_microstate`` are looked up by ``microstate_id``.
    compounds : CompoundSeries
        Data for the compound series. It is iterated directly, and every item
        must expose ``microstates`` and ``metadata.experimental_data``.
        NOTE(review): callers in this module pass ``series.compounds`` here;
        confirm whether the annotation should be a list of compounds instead.

    Returns
    -------
    ddg : PointEstimate
        Point estimate of free energy difference for this transformation
        (final minus initial), or PointEstimate(None, None) if a "pIC50" entry
        is missing for either compound.

    Raises
    ------
    KeyError
        If either microstate of the transformation is not found among the
        microstates of ``compounds``.
    """
    # Map every microstate id onto its parent compound so that the parent's
    # experimental data can stand in for the microstate (see NOTE above).
    compounds_by_microstate = {
        microstate.microstate_id: compound
        for compound in compounds
        for microstate in compound.microstates
    }

    initial_experimental_data = compounds_by_microstate[
        transformation.initial_microstate.microstate_id
    ].metadata.experimental_data
    final_experimental_data = compounds_by_microstate[
        transformation.final_microstate.microstate_id
    ].metadata.experimental_data

    # Both endpoints need a measured pIC50; otherwise there is nothing to compare.
    if ("pIC50" not in initial_experimental_data) or (
        "pIC50" not in final_experimental_data
    ):
        return PointEstimate(point=None, stderr=None)

    initial_dg = PointEstimate(
        point=pIC50_to_DG(initial_experimental_data["pIC50"]), stderr=EXP_DDG_IJ_ERR
    )
    final_dg = PointEstimate(
        point=pIC50_to_DG(final_experimental_data["pIC50"]), stderr=EXP_DDG_IJ_ERR
    )

    # This is the experimental DDG itself, not an error estimate.
    return final_dg - initial_dg


def analyze_transformation_or_warn(
    transformation: Transformation, **kwargs
) -> Optional[TransformationAnalysis]:
    """
    Run ``analyze_transformation`` and downgrade analysis failures to a warning.

    Parameters
    ----------
    transformation : Transformation
        The transformation to analyze; its ``run_id`` is used in the warning.
    **kwargs
        Forwarded unchanged to ``analyze_transformation``.

    Returns
    -------
    Optional[TransformationAnalysis]
        The analysis result, or None if an ``AnalysisError`` was raised.
        Any other exception type propagates to the caller.
    """
    try:
        analysis = analyze_transformation(transformation, **kwargs)
    except AnalysisError as failure:
        logging.warning("Failed to analyze RUN%d: %s", transformation.run_id, failure)
        analysis = None
    return analysis


def analyze_compound_series(
def analyze_compound_series(
series: CompoundSeries,
config: AnalysisConfig,
server: FahConfig,
num_procs: Optional[int] = None,
) -> CompoundSeriesAnalysis:
"""
Analyze a compound series to generate JSON.
"""
from rich.progress import track

# TODO: Cache results and only update RUNs for which we have received new data

# Pre-filter based on which transformations have any work data
logging.info(f'Pre-filtering {len(series.transformations)} transformations to identify those with work data...')
available_transformations = [
transformation for transformation in series.transformations
if len(list_results(config=server, run=transformation.run_id, project=series.metadata.fah_projects.complex_phase)) > 0
and len(list_results(config=server, run=transformation.run_id, project=series.metadata.fah_projects.solvent_phase)) > 0
]
#available_transformations = series.transformations[:50]

# Process compound series in parallel
logging.info(f'Processing {len(available_transformations)} / {len(series.transformations)} available transformations in parallel...')
with multiprocessing.Pool(num_procs) as pool:
results_iter = pool.imap_unordered(
partial(
analyze_transformation_or_warn,
projects=series.metadata.fah_projects,
server=server,
config=config,
compounds=series.compounds,
),
series.transformations,
available_transformations,
)
transformations = [
result
for result in track(
results_iter,
total=len(series.transformations),
total=len(available_transformations),
description="Computing transformation free energies",
)
if result is not None
]

# Reprocess transformation experimental errors to only include most favorable transformation
# NOTE: This is a hack, and should be replaced by a more robust method for accounting for racemic mixtures
# Compile list of all microstate transformations for each compound
compound_ddgs = dict()
for transformation in transformations:
compound_id = transformation.transformation.final_microstate.compound_id
if compound_id in compound_ddgs:
compound_ddgs[compound_id].append(transformation.binding_free_energy.point)
else:
compound_ddgs[compound_id] = [transformation.binding_free_energy.point]
# Collapse to a single estimate
from scipy.special import logsumexp
for compound_id, ddgs in compound_ddgs.items():
compound_ddgs[compound_id] = -logsumexp(-np.array(ddgs)) + np.log(len(ddgs))
# Regenerate list of transformations
for index, t in enumerate(transformations):
if (t.exp_ddg is None) or (t.exp_ddg.point is None):
continue
compound_id = t.transformation.final_microstate.compound_id
absolute_error_point = abs(t.exp_ddg.point - compound_ddgs[compound_id])
transformations[index] = TransformationAnalysis(
transformation=t.transformation,
reliable_transformation=t.reliable_transformation,
binding_free_energy=t.binding_free_energy,
complex_phase=t.complex_phase,
solvent_phase=t.solvent_phase,
exp_ddg=t.exp_ddg,
absolute_error=PointEstimate(point=absolute_error_point, stderr=t.absolute_error.stderr),
)

# Sort transformations by RUN
# transformations.sort(key=lambda transformation_analysis : transformation_analysis.transformation.run_id)
# Sort transformations by free energy difference
Expand Down Expand Up @@ -190,10 +332,17 @@ def generate_artifacts(
data_dir, f"PROJ{series.metadata.fah_projects.complex_phase}"
)

# Pre-filter based on which transformations have any data
available_transformations = [
transformation for transformation in series.transformations
if transformation.binding_free_energy is not None
and transformation.binding_free_energy.point is not None
]

if snapshots:
logging.info("Generating representative snapshots")
generate_representative_snapshots(
transformations=series.transformations,
transformations=available_transformations,
project_dir=complex_project_dir,
project_data_dir=complex_data_dir,
output_dir=os.path.join(output_dir, "transformations"),
Expand Down
2 changes: 1 addition & 1 deletion fah_xchem/analysis/free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def compute_relative_free_energy(

# TODO: Flag problematic RUN/CLONE/GEN trajectories for further analysis and debugging
works = _filter_work_values(all_works)

if len(works) < (min_num_work_values or 1):
raise InsufficientDataError(
f"Need at least {min_num_work_values} good work values for analysis, "
Expand Down
59 changes: 59 additions & 0 deletions fah_xchem/analysis/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from matplotlib.font_manager import FontProperties
import multiprocessing
import numpy as np
import networkx as nx
import pandas as pd
from pymbar import BAR
from typing import Generator, Iterable, List, Optional
Expand All @@ -19,6 +20,37 @@
TransformationAnalysis,
)
from .constants import KT_KCALMOL
from arsenic import plotting


def plot_retrospective(
    transformations: List[TransformationAnalysis],
    output_dir: str,
    filename: str = "retrospective",
) -> None:
    """
    Plot calculated versus experimental relative free energies with Arsenic.

    Builds a directed graph with one edge per transformation that has both a
    calculated binding free energy and an experimental DDG, then passes the
    graph to ``arsenic.plotting.plot_DDGs``.

    Parameters
    ----------
    transformations : List[TransformationAnalysis]
        Analyzed transformations. Entries lacking a calculated or an
        experimental point estimate are skipped.
    output_dir : str
        Directory in which the figure is written.
    filename : str
        Base name of the output file; ".png" is appended.
    """
    graph = nx.DiGraph()

    # TODO this loop can be sped up
    for analysis in transformations:
        calc_ddg = analysis.binding_free_energy
        exp_ddg = analysis.exp_ddg

        # Only interested if the transformation has a calculated DDG ...
        if calc_ddg is None or calc_ddg.point is None:
            continue
        # ... and the compounds have an experimental DDG. `exp_ddg` itself may
        # be absent, not only its point value.
        if exp_ddg is None or exp_ddg.point is None:
            continue

        transformation = analysis.transformation

        # All values are scaled by KT_KCALMOL (presumably converting from kT to
        # kcal/mol -- confirm against .constants).
        graph.add_edge(
            transformation.initial_microstate,
            transformation.final_microstate,
            exp_DDG=exp_ddg.point * KT_KCALMOL,
            exp_dDDG=exp_ddg.stderr * KT_KCALMOL,
            calc_DDG=calc_ddg.point * KT_KCALMOL,
            calc_dDDG=calc_ddg.stderr * KT_KCALMOL,
        )

    filename_png = filename + ".png"

    plotting.plot_DDGs(graph, filename=os.path.join(output_dir, filename_png))


def plot_work_distributions(
Expand Down Expand Up @@ -721,6 +753,8 @@ def generate_plots(
"""
from rich.progress import track

# TODO: Cache results and only update RUNs for which we have received new data

binding_delta_fs = [
transformation.binding_free_energy.point
for transformation in series.transformations
Expand Down Expand Up @@ -763,3 +797,28 @@ def generate_plots(
description="Generating plots",
):
pass

#
# Retrospective plots
#

# NOTE this is handled by Arsenic
# this needs to be plotted last as the figure isn't cleared by default in Arsenic
# TODO generate time stamp

# All transformations
plot_retrospective(output_dir=output_dir, transformations=series.transformations, filename='retrospective-transformations-all')

# Reliable subset of transformations
plot_retrospective(output_dir=output_dir, transformations=[transformation for transformation in series.transformations if transformation.reliable_transformation], filename='retrospective-transformations-reliable')

# Transformations not involving racemates
# TODO: Find a simpler way to filter non-racemates
nmicrostates = { compound.metadata.compound_id : len(compound.microstates) for compound in series.compounds }
def is_racemate(microstate):
return True if (nmicrostates[microstate.compound_id] > 1) else False
plot_retrospective(
output_dir=output_dir,
transformations=[transformation for transformation in series.transformations if (not is_racemate(transformation.transformation.initial_microstate) and not is_racemate(transformation.transformation.final_microstate))],
filename='retrospective-transformations-noracemates'
)
Loading

0 comments on commit 81e3639

Please sign in to comment.