Skip to content

Commit

Permalink
Require dihedral indices when looking up by SMILES
Browse files Browse the repository at this point in the history
Because different torsion drives can drive the same mapped SMILES about
different atom quartets, the mapped SMILES alone cannot uniquely identify
a "molecule ID" (although in this model, "molecule ID" really serves
as an identifier of an individual torsion drive, not a molecule).
  • Loading branch information
mattwthompson committed Dec 16, 2024
1 parent 91a27ec commit b7df922
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 16 deletions.
38 changes: 26 additions & 12 deletions yammbs/torsion/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,36 @@ def get_molecule_ids(self) -> list[int]:
with self._get_session() as db:
return [molecule_id for (molecule_id,) in db.db.query(DBTorsionRecord.id).distinct()]

# TODO: Allow by multiple selectors (smiles: list[str])
def get_molecule_id_by_smiles(self, smiles: str) -> int:
# TODO: Allow by multiple selectors (how to do with multiple args? 1-arg case is smiles: list[str])
def get_molecule_id_by_smiles_and_dihedral_indices(
self,
smiles: str,
dihedral_indices: tuple[int, int, int, int],
) -> int:
with self._get_session() as db:
return next(id for (id,) in db.db.query(DBTorsionRecord.id).filter_by(mapped_smiles=smiles).all())
return next(
id
for (id,) in db.db.query(DBTorsionRecord.id)
.filter_by(
mapped_smiles=smiles,
dihedral_indices=dihedral_indices,
)
.all()
)

# TODO: Allow by multiple selectors (id: list[int])
def get_smiles_by_molecule_id(self, id: int) -> str:
    """Return the mapped SMILES stored for the torsion record with this ID.

    Raises ``StopIteration`` when the query yields no rows.
    """
    with self._get_session() as db:
        matches = db.db.query(DBTorsionRecord.mapped_smiles).filter_by(id=id).all()

        # Rows are 1-tuples; only the first match is used.
        (mapped_smiles,) = next(iter(matches))
        return mapped_smiles

# TODO: Allow by multiple selectors (id: list[int])
def get_dihedral_indices_by_molecule_id(self, id: int) -> tuple[int, int, int, int]:
    """Return the dihedral indices stored for the torsion record with this ID.

    Raises ``StopIteration`` when the query yields no rows.
    """
    with self._get_session() as db:
        matches = db.db.query(DBTorsionRecord.dihedral_indices).filter_by(id=id).all()

        # Rows are 1-tuples; only the first match is used.
        (indices,) = next(iter(matches))
        return indices

def get_force_fields(
self,
) -> list[str]:
Expand All @@ -131,13 +151,6 @@ def get_force_fields(
).distinct()
]

def get_dihedral_indices_by_molecule_id(self, id: int) -> tuple[int, int, int, int]:
    """Return the dihedral indices stored for the torsion record with this ID.

    Raises ``StopIteration`` when the query yields no rows.
    """
    with self._get_session() as db:
        matches = db.db.query(DBTorsionRecord.dihedral_indices).filter_by(id=id).all()

        # Rows are 1-tuples; only the first match is used.
        (indices,) = next(iter(matches))
        return indices

def get_qm_points_by_molecule_id(self, id: int) -> dict[float, NDArray]:
with self._get_session() as db:
return {
Expand Down Expand Up @@ -221,8 +234,9 @@ def from_torsion_dataset(

for angle in qm_torsion.coordinates:
qm_point_record = QMTorsionPointRecord(
molecule_id=store.get_molecule_id_by_smiles(
torsion_record.mapped_smiles,
molecule_id=store.get_molecule_id_by_smiles_and_dihedral_indices(
smiles=torsion_record.mapped_smiles,
dihedral_indices=torsion_record.dihedral_indices,
),
grid_id=angle, # TODO: This needs to be a tuple later
coordinates=qm_torsion.coordinates[angle],
Expand Down
3 changes: 1 addition & 2 deletions yammbs/torsion/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ def _normalize(qm: dict[float, float], mm: dict[float, float]) -> tuple[dict[flo
"""Normalize, after sorting, a pair of QM and MM profiles to the values at the QM minimum."""
if len(mm) == 0:
LOGGER.warning(
"no mm data, returning empty dicts; "
f"length of qm dict is {len(qm)=}",
f"no mm data, returning empty dicts; length of qm dict is {len(qm)=}",
)
return dict(), dict()

Expand Down
2 changes: 1 addition & 1 deletion yammbs/torsion/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TorsionDataset(ImmutableModel):

class TorsionProfile(ImmutableModel):
mapped_smiles: str
dihedral_indices: list[int] = Field(
dihedral_indices: tuple[int, int, int, int] = Field(
...,
description="The indices, 0-indexed, of the atoms which define the driven dihedral angle",
)
Expand Down
2 changes: 1 addition & 1 deletion yammbs/torsion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TorsionRecord(MoleculeRecord):
coordinates of the molecule in different conformers, and partial charges / WBOs
computed for those conformers."""

dihedral_indices: list[int] = Field(
dihedral_indices: tuple[int, int, int, int] = Field(
...,
description="The indices of the atoms which define the driven dihedral angle",
)
Expand Down

0 comments on commit b7df922

Please sign in to comment.