diff --git a/gufe/protocols/errors.py b/gufe/protocols/errors.py index b9256e46..3dddf9d6 100644 --- a/gufe/protocols/errors.py +++ b/gufe/protocols/errors.py @@ -26,3 +26,11 @@ class MissingUnitResultError(ProtocolDAGResultError): class ProtocolUnitFailureError(ProtocolDAGResultError): """Error when a ProtocolDAGResult has only ProtocolUnitFailure(s) for a given ProtocolUnit.""" + + +class ExecutionInterrupt(BaseException): + """Exception for unrecoverable execution failures that are unrelated to user inputs. + + Will not be caught by ``ProtocolUnit.execute()``. + + """ diff --git a/gufe/protocols/protocolunit.py b/gufe/protocols/protocolunit.py index fba47adb..56dcd750 100644 --- a/gufe/protocols/protocolunit.py +++ b/gufe/protocols/protocolunit.py @@ -21,6 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from ..tokenization import TOKENIZABLE_REGISTRY, GufeKey, GufeTokenizable +from .errors import ExecutionInterrupt @dataclass @@ -327,9 +328,8 @@ def execute( start_time=start, end_time=datetime.datetime.now(), ) - - except KeyboardInterrupt: - # if we "fail" due to a KeyboardInterrupt, we always want to raise + except (KeyboardInterrupt, ExecutionInterrupt): + # NOTE: this statement is for clarity, these Interrupts will raise regardless. raise except Exception as e: if raise_error: diff --git a/gufe/tests/test_protocolunit.py b/gufe/tests/test_protocolunit.py index 51a37a47..c3f153fe 100644 --- a/gufe/tests/test_protocolunit.py +++ b/gufe/tests/test_protocolunit.py @@ -3,6 +3,7 @@ import pytest +from gufe.protocols.errors import ExecutionInterrupt from gufe.protocols.protocolunit import Context, ProtocolUnit, ProtocolUnitFailure, ProtocolUnitResult from gufe.tests.test_tokenization import GufeTokenizableTestsMixin @@ -27,6 +28,16 @@ def _execute(ctx: Context, an_input=2, **inputs): return {"foo": "bar"} +class DummyExecutionInterruptUnit(ProtocolUnit): + @staticmethod + def _execute(ctx: Context, an_input=2, **inputs): + + if an_input != 2: + raise ExecutionInterrupt + + return {"foo": "bar"} + + @pytest.fixture def dummy_unit(): return DummyUnit(name="qux") @@ -77,6 +88,26 @@ def test_execute(self, tmpdir): with pytest.raises(ValueError, match="should always be 2"): unit.execute(context=ctx, raise_error=True, an_input=3) + def test_execute_ExecutionInterrupt(self, tmpdir): + with tmpdir.as_cwd(): + + unit = DummyExecutionInterruptUnit() + + shared = Path("shared") / str(unit.key) + shared.mkdir(parents=True) + + scratch = Path("scratch") / str(unit.key) + scratch.mkdir(parents=True) + + ctx = Context(shared=shared, scratch=scratch) + + with pytest.raises(ExecutionInterrupt): + unit.execute(context=ctx, an_input=3) + + u: ProtocolUnitResult = unit.execute(context=ctx, an_input=2) + + assert u.outputs == {"foo": "bar"} + def test_execute_KeyboardInterrupt(self, tmpdir): with tmpdir.as_cwd(): diff --git a/news/add_execution_interrupt.rst b/news/add_execution_interrupt.rst new file mode 100644 index 00000000..b3826fd5 --- /dev/null +++ b/news/add_execution_interrupt.rst @@ -0,0 +1,23 @@ +**Added:** + +* Added ExecutionInterrupt, a special exception that does not get handled as a ``ProtocolUnitFailure``. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +*