Skip to content

Commit

Permalink
Typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mattwthompson committed Nov 19, 2024
1 parent b3a8150 commit d8ecfdb
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 27 deletions.
3 changes: 1 addition & 2 deletions yammbs/_tests/unit_tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy
import qcelemental
from openff.units import unit

from yammbs.models import MoleculeRecord, QMConformerRecord

Expand Down Expand Up @@ -30,5 +29,5 @@ def test_load_from_qcsubmit(small_qcsubmit_collection):
assert qm_conformer.qcarchive_id == qc_record.id
assert numpy.allclose(
qm_conformer.coordinates,
molecule.conformers[0].m_as(unit.angstrom),
molecule.conformers[0].m_as("angstrom"),
)
3 changes: 1 addition & 2 deletions yammbs/_tests/unit_tests/test_molecule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy
from openff.toolkit import Molecule
from openff.units import unit

from yammbs._molecule import _to_geometric_molecule

Expand All @@ -15,6 +14,6 @@ def test_to_geometric_molecule():
assert molecule.n_bonds == len(geometric_molecule.Data["bonds"])

assert numpy.allclose(
molecule.conformers[0].m_as(unit.angstrom),
molecule.conformers[0].m_as("angstrom"),
geometric_molecule.Data["xyzs"][0],
)
16 changes: 6 additions & 10 deletions yammbs/analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import TYPE_CHECKING

import numpy
from openff.toolkit import Molecule
from openff.units import Quantity, unit
from openff.toolkit import Molecule, Quantity

from yammbs._base.array import Array
from yammbs._base.base import ImmutableModel
Expand Down Expand Up @@ -102,7 +101,6 @@ def get_rmsd(
) -> float:
"""Compute the RMSD between two sets of coordinates."""
from openeye import oechem
from openff.units import Quantity, unit

molecule1 = Molecule(molecule)
molecule2 = Molecule(molecule)
Expand All @@ -111,9 +109,9 @@ def get_rmsd(
if molecule.conformers is not None:
molecule.conformers.clear()

molecule1.add_conformer(Quantity(reference, unit.angstrom)) # type: ignore[call-overload]
molecule1.add_conformer(Quantity(reference, "angstrom"))

molecule2.add_conformer(Quantity(target, unit.angstrom)) # type: ignore[call-overload]
molecule2.add_conformer(Quantity(target, "angstrom"))

# oechem appears to not support named arguments, but it's hard to tell
# since the Python API is not documented
Expand Down Expand Up @@ -145,10 +143,10 @@ def get_internal_coordinate_rmsds(
from yammbs._molecule import _to_geometric_molecule

if isinstance(reference, Quantity):
reference = reference.m_as(unit.angstrom)
reference = reference.m_as("angstrom")

if isinstance(target, Quantity):
target = target.m_as(unit.angstrom)
target = target.m_as("angstrom")

_generator = PrimitiveInternalCoordinates(
_to_geometric_molecule(molecule=molecule, coordinates=target),
Expand Down Expand Up @@ -211,13 +209,11 @@ def _rdmol(
molecule: Molecule,
conformer: Array,
):
from openff.units import Quantity, unit

molecule = Molecule(molecule)
if molecule.conformers is not None:
molecule.conformers.clear()

molecule.add_conformer(Quantity(conformer, unit.angstrom)) # type: ignore[call-overload]
molecule.add_conformer(Quantity(conformer, "angstrom"))

return molecule.to_rdkit()

Expand Down
4 changes: 2 additions & 2 deletions yammbs/torsion/_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class DBTorsionRecord(DBBase): # type: ignore
dihedral_indices = Column(PickleType, nullable=False)


class DBTorsionProfileRecord(DBBase):
class DBTorsionProfileRecord(DBBase): # type: ignore
__tablename__ = "torsion_points"

id = Column(Integer, primary_key=True, index=True)

# is this like a molecule ID or like a QCArchive ID?
parent_id: int = Column(Integer, ForeignKey("molecules.id"), nullable=False, index=True)
parent_id = Column(Integer, ForeignKey("molecules.id"), nullable=False, index=True)

# TODO: Store QCArchive ID

Expand Down
14 changes: 6 additions & 8 deletions yammbs/torsion/_minimize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from multiprocessing import Pool
from typing import Generator

from numpy.typing import NDArray
from pydantic import Field
Expand All @@ -17,7 +18,7 @@ class ConstrainedMinimizationInput(ImmutableModel):
...,
description="The SMILES of the molecule",
)
dihedral_indices: list[int] = Field(
dihedral_indices: tuple[int, int, int, int] = Field(
...,
description="The indices of the atoms which define the driven dihedral angle",
)
Expand All @@ -44,16 +45,13 @@ class ConstrainedMinimizationResult(ConstrainedMinimizationInput):

def _minimize_torsions(
mapped_smiles: str,
dihedral_indices: list[int],
dihedral_indices: tuple[int, int, int, int],
qm_data: tuple[float, NDArray, float], # (grid_id, coordinates, energy)
force_field: str,
n_processes: int = 2,
chunksize=32,
):
# ) -> Iterator["MinimizationResult"]:
inputs = list()

inputs = [
) -> Generator[ConstrainedMinimizationResult, None, None]:
inputs = [ # type: ignore[misc]
ConstrainedMinimizationInput(
mapped_smiles=mapped_smiles,
dihedral_indices=dihedral_indices,
Expand Down Expand Up @@ -167,7 +165,7 @@ def _minimize_constrained(
for index in range(simulation.system.getNumParticles())
},
)
logging(input.dihedral_indices, input.mapped_smiles)
logging.error(input.dihedral_indices, input.mapped_smiles)

raise ConstrainedMinimizationError("Minimization failed, see logger") from e

Expand Down
7 changes: 4 additions & 3 deletions yammbs/torsion/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_force_fields(
).distinct()
]

def get_dihedral_indices_by_molecule_id(self, id: int) -> list[int]:
def get_dihedral_indices_by_molecule_id(self, id: int) -> tuple[int, int, int, int]:
with self._get_session() as db:
return next(
dihedral_indices
Expand Down Expand Up @@ -249,7 +249,7 @@ def optimize_mm(
for molecule_id in self.get_molecule_ids():
with self._get_session() as db:
# TODO: Implement "seen" behavior to short-circuit already-optimized torsions
qm_data = tuple(
qm_data: tuple[float, NDArray, float] = tuple( # type: ignore[assignment]
(grid_id, coordinates, energy)
for (grid_id, coordinates, energy) in db.db.query(
DBQMTorsionPointRecord.grid_id,
Expand Down Expand Up @@ -306,7 +306,8 @@ def get_log_sse(
continue

_qm = dict(sorted(_qm.items()))
qm_minimum_index = min(_qm, key=_qm.get)
qm_minimum_index = min(_qm, key=_qm.get) # type: ignore[arg-type]

qm = {key: _qm[key] - _qm[qm_minimum_index] for key in _qm}
mm = {key: _mm[key] - _mm[qm_minimum_index] for key in _mm}

Expand Down

0 comments on commit d8ecfdb

Please sign in to comment.