Skip to content

Commit

Permalink
Merge pull request #13 from mattwthompson/get_records
Browse files Browse the repository at this point in the history
Add record helpers
  • Loading branch information
mattwthompson authored Jan 18, 2024
2 parents bf3dcb4 + 060e312 commit 0d3bbe8
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Run tests
run: |
pytest -v --cov=ibstore/ --cov-report=xml ibstore/
pytest -v -nauto --cov=ibstore/ --cov-report=xml ibstore/
- name: CodeCov
uses: codecov/codecov-action@v3
Expand Down
56 changes: 56 additions & 0 deletions ibstore/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,18 @@ def get_qm_conformers_by_molecule_id(self, id: int) -> list:
.all()
]

def get_force_fields(
    self,
) -> list[str]:
    """Return the distinct force field names found among stored MM conformers."""
    with self._get_session() as session:
        # Each result row is a 1-tuple holding only the force_field column.
        distinct_rows = session.db.query(
            DBMMConformerRecord.force_field,
        ).distinct()

        return [row[0] for row in distinct_rows]

def get_mm_conformers_by_molecule_id(
self,
id: int,
Expand Down Expand Up @@ -309,6 +321,50 @@ def get_mm_energies_by_molecule_id(
.all()
]

def get_qm_conformer_records_by_molecule_id(
    self,
    molecule_id: int,
) -> list[QMConformerRecord]:
    """Return the QM conformer records whose parent is the given molecule.

    Rows are selected by ``parent_id == molecule_id`` and ordered by their
    QCArchive ID. Each row is converted into a ``QMConformerRecord`` whose
    ``molecule_id`` is the ID passed in.
    """
    with self._get_session() as session:
        rows = (
            session.db.query(DBQMConformerRecord)
            .filter_by(parent_id=molecule_id)
            .order_by(DBQMConformerRecord.qcarchive_id)
            .all()
        )

        # Build the records while the session is still open.
        records = [
            QMConformerRecord(
                molecule_id=molecule_id,
                qcarchive_id=row.qcarchive_id,
                mapped_smiles=row.mapped_smiles,
                coordinates=row.coordinates,
                energy=row.energy,
            )
            for row in rows
        ]

    return records

def get_mm_conformer_records_by_molecule_id(
    self,
    molecule_id: int,
    force_field: str,
) -> list[MMConformerRecord]:
    """Return the MM conformer records of a molecule for one force field.

    Rows are selected by ``parent_id == molecule_id`` and
    ``force_field == force_field``, then ordered by their QCArchive ID. Each
    row is converted into an ``MMConformerRecord`` whose ``molecule_id`` is
    the ID passed in.
    """
    with self._get_session() as session:
        # Both criteria are ANDed together, same as chaining two filter_by calls.
        rows = (
            session.db.query(DBMMConformerRecord)
            .filter_by(parent_id=molecule_id, force_field=force_field)
            .order_by(DBMMConformerRecord.qcarchive_id)
            .all()
        )

        # Build the records while the session is still open.
        records = [
            MMConformerRecord(
                molecule_id=molecule_id,
                qcarchive_id=row.qcarchive_id,
                force_field=row.force_field,
                mapped_smiles=row.mapped_smiles,
                coordinates=row.coordinates,
                energy=row.energy,
            )
            for row in rows
        ]

    return records

@classmethod
def from_qcsubmit_collection(
cls,
Expand Down
73 changes: 63 additions & 10 deletions ibstore/_tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,29 @@
import numpy
import pytest
from openff.qcsubmit.results import OptimizationResultCollection
from openff.toolkit import Molecule
from openff.utilities import get_data_file_path, temporary_cd

from ibstore._store import MoleculeStore
from ibstore.exceptions import DatabaseExistsError
from ibstore.models import MMConformerRecord, QMConformerRecord


@pytest.fixture()
def basic_ch_store():
    """Open a ``MoleculeStore`` on the packaged ``ch.sqlite`` test database."""
    # This file was manually generated from data/01-processed-qm-ch.json
    database_path = get_data_file_path(
        "_tests/data/ch.sqlite",
        package_name="ibstore",
    )

    return MoleculeStore(database_path)


@pytest.fixture()
def diphenylvinylbenzene():
    """Build the 1,2-diphenylvinylbenzene molecule from its SMILES string."""
    smiles = "c1ccc(cc1)C=C(c2ccccc2)c3ccccc3"

    return Molecule.from_smiles(smiles)


def test_from_qcsubmit(small_collection):
Expand All @@ -32,16 +51,8 @@ def test_do_not_overwrite(small_collection):
)


def test_load_existing_databse():
    """Opening the packaged SQLite database yields a store with 40 entries."""
    # This file was manually generated from data/01-processed-qm-ch.json
    database_path = get_data_file_path(
        "_tests/data/01-processed-qm-ch.sqlite",
        package_name="ibstore",
    )
    loaded_store = MoleculeStore(database_path)

    assert len(loaded_store) == 40
def test_load_existing_database(basic_ch_store):
    """The store loaded by the ``basic_ch_store`` fixture has length 40."""
    n_entries = len(basic_ch_store)

    assert n_entries == 40


def test_get_molecule_ids(small_store):
Expand Down Expand Up @@ -104,3 +115,45 @@ def test_get_conformers(small_store):
force_field=force_field,
)[-1],
)


def test_get_force_fields(basic_ch_store):
    """``get_force_fields`` reports exactly the force fields in the test store."""
    found = basic_ch_store.get_force_fields()

    assert len(found) == 9

    # Force fields expected to be present in the test store.
    for expected_name in ("openff-2.1.0", "gaff-2.11"):
        assert expected_name in found

    # A force field expected to be absent from the test store.
    assert "openff-3.0.0" not in found


def test_get_mm_conformer_records_by_molecule_id(basic_ch_store, diphenylvinylbenzene):
    """Every MM record returned for molecule 1 is well-formed and matches the fixture."""
    mm_records = basic_ch_store.get_mm_conformer_records_by_molecule_id(
        1,
        force_field="openff-2.1.0",
    )

    for mm_record in mm_records:
        assert isinstance(mm_record, MMConformerRecord)
        assert mm_record.molecule_id == 1
        assert mm_record.force_field == "openff-2.1.0"
        assert mm_record.coordinates.shape == (36, 3)
        assert mm_record.energy is not None

        # The stored mapped SMILES must describe the same molecule as the fixture.
        rebuilt = Molecule.from_mapped_smiles(mm_record.mapped_smiles)
        assert rebuilt.is_isomorphic_with(
            diphenylvinylbenzene,
        )


def test_get_qm_conformer_records_by_molecule_id(basic_ch_store, diphenylvinylbenzene):
    """Every QM record returned for molecule 1 is well-formed and matches the fixture."""
    qm_records = basic_ch_store.get_qm_conformer_records_by_molecule_id(1)

    for qm_record in qm_records:
        assert isinstance(qm_record, QMConformerRecord)
        assert qm_record.molecule_id == 1
        assert qm_record.coordinates.shape == (36, 3)
        assert qm_record.energy is not None

        # The stored mapped SMILES must describe the same molecule as the fixture.
        rebuilt = Molecule.from_mapped_smiles(qm_record.mapped_smiles)
        assert rebuilt.is_isomorphic_with(
            diphenylvinylbenzene,
        )

0 comments on commit 0d3bbe8

Please sign in to comment.