Skip to content

Commit

Permalink
Merge pull request #5 from mattwthompson/api-helpers
Browse files Browse the repository at this point in the history
Add lookups by molecule and QCArchive ID
  • Loading branch information
mattwthompson authored Aug 30, 2023
2 parents cd963cb + 2e9c4ba commit 9dbfb72
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 3 deletions.
5 changes: 2 additions & 3 deletions ibstore/_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,15 @@ def _minimize_blob(
)

with Pool(processes=n_processes) as pool:
for result in tqdm(
yield from tqdm(
pool.imap(
_run_openmm,
inputs,
chunksize=chunksize,
),
desc=f"Building and minimizing systems with {force_field}",
total=len(inputs),
):
yield result
)


class MinimizationInput(ImmutableModel):
Expand Down
41 changes: 41 additions & 0 deletions ibstore/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,19 @@ def store_conformer(
else:
db.store_mm_conformer_record(record)

def get_molecule_ids(self) -> list[int]:
    """
    Return the ID of every molecule record in the store.

    The IDs will usually be consecutive integers counting up from 1, but
    callers should not rely on that.
    """
    with self._get_session() as db:
        # Each row of a single-column query is a 1-tuple; unwrap it.
        id_rows = db.db.query(DBMoleculeRecord.id).distinct()
        return [row[0] for row in id_rows]

def get_smiles(self) -> List[str]:
"""Get the (mapped) smiles of all records in the store."""
with self._get_session() as db:
Expand Down Expand Up @@ -214,6 +227,15 @@ def get_qcarchive_ids_by_molecule_id(self, id: int) -> list[str]:
.all()
]

def get_molecule_id_by_qcarchive_id(self, id: str) -> int:
    """
    Look up the molecule ID associated with a QCArchive ID.

    Queries the ``parent_id`` of the QM conformer records whose
    ``qcarchive_id`` equals ``id`` and returns the first match.

    Raises
    ------
    IndexError
        If no QM conformer record with this QCArchive ID is found.
    """
    with self._get_session() as db:
        matches = (
            db.db.query(DBQMConformerRecord.parent_id)
            .filter_by(qcarchive_id=id)
            .all()
        )

        # Still an IndexError (what indexing an empty result used to raise),
        # but with a message that names the ID that could not be found.
        if not matches:
            raise IndexError(
                f"No QM conformer with QCArchive ID {id!r} found in the store.",
            )

        (molecule_id,) = matches[0]
        return molecule_id

def get_qm_conformers_by_molecule_id(self, id: int) -> list:
with self._get_session() as db:
return [
Expand All @@ -239,6 +261,25 @@ def get_mm_conformers_by_molecule_id(
.all()
]

def get_qm_conformer_by_qcarchive_id(self, id: str):
    """
    Get the stored QM coordinates of the conformer with a given QCArchive ID.

    Queries the ``coordinates`` of the QM conformer records whose
    ``qcarchive_id`` equals ``id`` and returns the first match.

    Raises
    ------
    IndexError
        If no QM conformer record with this QCArchive ID is found.
    """
    with self._get_session() as db:
        matches = (
            db.db.query(DBQMConformerRecord.coordinates)
            .filter_by(qcarchive_id=id)
            .all()
        )

        # Still an IndexError (what indexing an empty result used to raise),
        # but with a message that names the ID that could not be found.
        if not matches:
            raise IndexError(
                f"No QM conformer with QCArchive ID {id!r} found in the store.",
            )

        (conformer,) = matches[0]
        return conformer

def get_mm_conformer_by_qcarchive_id(self, id: str, force_field: str):
    """
    Get the stored MM coordinates for a QCArchive ID and force field.

    Queries the ``coordinates`` of the MM conformer records whose
    ``qcarchive_id`` equals ``id`` and whose ``force_field`` equals
    ``force_field``, and returns the first match.

    Raises
    ------
    IndexError
        If no MM conformer record matches this QCArchive ID and force field.
    """
    with self._get_session() as db:
        matches = (
            db.db.query(DBMMConformerRecord.coordinates)
            .filter_by(qcarchive_id=id)
            .filter_by(force_field=force_field)
            .all()
        )

        # Still an IndexError (what indexing an empty result used to raise),
        # but with a message that names the selectors that found nothing.
        if not matches:
            raise IndexError(
                f"No MM conformer with QCArchive ID {id!r} and force field "
                f"{force_field!r} found in the store.",
            )

        (conformer,) = matches[0]
        return conformer

# TODO: Allow lookup by multiple IDs at once (id: list[int])
def get_qm_energies_by_molecule_id(self, id: int) -> list[float]:
with self._get_session() as db:
Expand Down
12 changes: 12 additions & 0 deletions ibstore/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,21 @@
from openff.qcsubmit.results import OptimizationResultCollection
from openff.utilities.utilities import get_data_file_path

from ibstore._store import MoleculeStore


@pytest.fixture()
def small_collection() -> OptimizationResultCollection:
    """Parse the ``01-processed-qm-ch.json`` test data file into a collection."""
    data_path = get_data_file_path("_tests/data/01-processed-qm-ch.json", "ibstore")
    return OptimizationResultCollection.parse_file(data_path)


@pytest.fixture()
def small_store() -> MoleculeStore:
    """Build a ``MoleculeStore`` from the ``ch.sqlite`` test data file."""
    database_path = get_data_file_path(
        "_tests/data/ch.sqlite",
        package_name="ibstore",
    )
    return MoleculeStore(database_path)
41 changes: 41 additions & 0 deletions ibstore/_tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tempfile

import numpy
import pytest
from openff.utilities import get_data_file_path, temporary_cd

Expand Down Expand Up @@ -38,3 +39,43 @@ def test_load_existing_databse():
)

assert len(store) == 40


def test_get_molecule_ids(small_store):
    """The small store should report 40 unique molecule IDs spanning 1 to 40."""
    found_ids = small_store.get_molecule_ids()
    unique_ids = set(found_ids)

    # No duplicates, and exactly 40 molecules.
    assert len(found_ids) == len(unique_ids) == 40

    assert min(found_ids) == 1
    assert max(found_ids) == 40


def test_get_molecule_id_by_qcarchive_id(small_store):
    """Looking up a molecule's last QCArchive ID should return that molecule."""
    expected_molecule_id = 40
    qcarchive_ids = small_store.get_qcarchive_ids_by_molecule_id(expected_molecule_id)
    last_qcarchive_id = qcarchive_ids[-1]

    found_molecule_id = small_store.get_molecule_id_by_qcarchive_id(last_qcarchive_id)

    assert found_molecule_id == expected_molecule_id


def test_get_conformers(small_store):
    """Conformer lookup by QCArchive ID should agree with lookup by molecule ID."""
    force_field = "openff-2.0.0"
    molecule_id = 40
    qcarchive_id = small_store.get_qcarchive_ids_by_molecule_id(molecule_id)[-1]

    # QM: the conformer found by QCArchive ID vs. the molecule's last conformer.
    qm_by_qcarchive_id = small_store.get_qm_conformer_by_qcarchive_id(qcarchive_id)
    qm_by_molecule_id = small_store.get_qm_conformers_by_molecule_id(molecule_id)[-1]
    numpy.testing.assert_allclose(qm_by_qcarchive_id, qm_by_molecule_id)

    # MM: same comparison, restricted to a single force field.
    mm_by_qcarchive_id = small_store.get_mm_conformer_by_qcarchive_id(
        qcarchive_id,
        force_field=force_field,
    )
    mm_by_molecule_id = small_store.get_mm_conformers_by_molecule_id(
        molecule_id,
        force_field=force_field,
    )[-1]
    numpy.testing.assert_allclose(mm_by_qcarchive_id, mm_by_molecule_id)

0 comments on commit 9dbfb72

Please sign in to comment.