Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dep warnings for dict mapping #296

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion gufe/protocols/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import abc
from typing import Optional, Iterable, Any, Union
from openff.units import Quantity
import warnings

from ..settings import Settings, SettingsBaseModel
from ..tokenization import GufeTokenizable, GufeKey
Expand Down Expand Up @@ -176,7 +177,7 @@ def create(
*,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]],
mapping: Optional[Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]]],
extends: Optional[ProtocolDAGResult] = None,
name: Optional[str] = None,
transformation_key: Optional[GufeKey] = None
Expand Down Expand Up @@ -219,6 +220,12 @@ def create(
ProtocolDAG
A directed, acyclic graph that can be executed by a `Scheduler`.
"""
if isinstance(mapping, dict):
warnings.warn(("mapping input as a dict is deprecated, "
"instead use either a single Mapping or list"),
DeprecationWarning)
mapping = list(mapping.values())

return ProtocolDAG(
name=name,
protocol_units=self._create(
Expand Down
9 changes: 9 additions & 0 deletions gufe/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,15 @@ def test_create_execute_gather(self, protocol_dag):

assert protocolresult.get_estimate() == 95500.0

def test_deprecation_warning_on_dict_mapping(self, instance, vacuum_ligand, solvated_ligand):
lig = solvated_ligand.components['ligand']
mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={})

with pytest.warns(DeprecationWarning,
match="mapping input as a dict is deprecated"):
instance.create(stateA=solvated_ligand, stateB=vacuum_ligand,
mapping={'ligand': mapping})

class ProtocolDAGTestsMixin(GufeTokenizableTestsMixin):

def test_protocol_units(self, instance):
Expand Down
14 changes: 14 additions & 0 deletions gufe/tests/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io
import pathlib

import gufe
from gufe.transformations import Transformation, NonTransformation
from gufe.protocols.protocoldag import execute_DAG

Expand Down Expand Up @@ -123,6 +124,19 @@ def test_dump_load_roundtrip(self, absolute_transformation):
recreated = Transformation.load(string)
assert absolute_transformation == recreated

def test_deprecation_warning_on_dict_mapping(self, solvated_ligand, solvated_complex):
lig = solvated_complex.components['ligand']
# this mapping makes no sense, but it'll trigger the dep warning we want
mapping = gufe.LigandAtomMapping(lig, lig, componentA_to_componentB={})

with pytest.warns(DeprecationWarning,
match="mapping input as a dict is deprecated"):
Transformation(
solvated_complex, solvated_ligand,
protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
mapping={'ligand': mapping},
)


class TestNonTransformation(GufeTokenizableTestsMixin):

Expand Down
9 changes: 8 additions & 1 deletion gufe/transformations/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from typing import Optional, Iterable, Union
import json
import warnings

from ..tokenization import GufeTokenizable, JSON_HANDLER
from ..utils import ensure_filelike
Expand All @@ -24,7 +25,7 @@ def __init__(
stateA: ChemicalSystem,
stateB: ChemicalSystem,
protocol: Protocol,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping], dict[str, ComponentMapping]]] = None,
name: Optional[str] = None,
):
"""Two chemical states with a method for estimating free energy difference
Expand All @@ -47,6 +48,12 @@ def __init__(
name : str, optional
a human-readable tag for this transformation
"""
if isinstance(mapping, dict):
warnings.warn(("mapping input as a dict is deprecated, "
"instead use either a single Mapping or list"),
DeprecationWarning)
mapping = list(mapping.values())

self._stateA = stateA
self._stateB = stateB
self._mapping = mapping
Expand Down
Loading