Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up loading a MoleculeStore from a QCArchiveDataset #81

Merged
merged 5 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 44 additions & 17 deletions yammbs/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pandas
from numpy.typing import NDArray
from openff.qcsubmit.results import OptimizationResultCollection
from openff.toolkit import Molecule, Quantity
from openff.toolkit import Molecule
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from typing_extensions import Self
Expand Down Expand Up @@ -533,7 +533,7 @@ def from_qcarchive_dataset(
"""
Create a new MoleculeStore dataset from YAMMBS's QCArchiveDataset model.

Largely adopted from `from_qcsubmit_collection`.
Largely adopted from `from_cached_result_collection`.
"""
from tqdm import tqdm

Expand All @@ -542,24 +542,51 @@ def from_qcarchive_dataset(

store = cls(database_name)

for qm_molecule in tqdm(dataset.qm_molecules, desc="Storing molecules"):
molecule = Molecule.from_mapped_smiles(qm_molecule.mapped_smiles, allow_undefined_stereo=True)
molecule.add_conformer(Quantity(qm_molecule.coordinates, "angstrom"))
mattwthompson marked this conversation as resolved.
Show resolved Hide resolved
# adapted from MoleculeRecord.from_molecule, MoleculeStore.store, and
# DBSessionManager.store_molecule_record
with store._get_session() as db:
# instead of DBSessionManager._smiles_already_exists
seen = set(db.db.query(DBMoleculeRecord.mapped_smiles))
for qm_molecule in tqdm(dataset.qm_molecules, desc="Storing molecules"):
if qm_molecule.mapped_smiles in seen:
continue
seen.add(qm_molecule.mapped_smiles)
molecule = Molecule.from_mapped_smiles(qm_molecule.mapped_smiles, allow_undefined_stereo=True)
db_record = DBMoleculeRecord(
mapped_smiles=qm_molecule.mapped_smiles,
inchi_key=molecule.to_inchi(fixed_hydrogens=True),
)
db.db.add(db_record)

molecule_record = MoleculeRecord.from_molecule(molecule)
store.store(molecule_record)
# close the session here and re-open to make sure all of the molecule
# IDs have been flushed to the db

store.store_qcarchive(
QMConformerRecord(
molecule_id=store.get_molecule_id_by_smiles(
molecule_record.mapped_smiles,
# adapted from MoleculeStore.store_qcarchive,
# QMConformerRecord.from_qcarchive_record, and
# DBSessionManager.store_qm_conformer_record
with store._get_session() as db:
# Iterate in reverse so that, for duplicate SMILES, the first record encountered
# wins (later dict entries overwrite earlier ones). This matches
# the behavior of the version that queries the db each time
smiles_to_id = {
smi: id
for id, smi in reversed(
db.db.query(
DBMoleculeRecord.id,
DBMoleculeRecord.mapped_smiles,
).all(),
)
}
for record in tqdm(dataset.qm_molecules, desc="Storing Records"):
mol_id = smiles_to_id[record.mapped_smiles]
db.db.add(
DBQMConformerRecord(
parent_id=mol_id,
qcarchive_id=record.qcarchive_id,
mapped_smiles=record.mapped_smiles,
coordinates=record.coordinates,
energy=record.final_energy,
),
qcarchive_id=qm_molecule.qcarchive_id,
mapped_smiles=qm_molecule.mapped_smiles,
coordinates=qm_molecule.coordinates,
energy=qm_molecule.final_energy,
),
)
)

return store

Expand Down
2 changes: 2 additions & 0 deletions yammbs/_tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def test_from_qcarchive_dataset(small_qcsubmit_collection):
# Ensure a new object can be created from the same database
assert len(MoleculeStore(db)) == len(store)

assert len(store.get_smiles()) == small_qcsubmit_collection.n_molecules


def test_from_qcarchive_dataset_undefined_stereo():
"""Test loading from YAMMBS's QCArchive model with undefined stereochemistry"""
Expand Down