Skip to content

Commit

Permalink
Support cancellation of concurrent tasks
Browse files Browse the repository at this point in the history
Currently a cancellation event sent to a worker with `concurrency` > 1
will always cancel the most recent task.

This commit refactors the `_ChildWorker` class to exclusively send the
`Cancel` event when receiving a `CancelRequest` and then handle the
cancellation in the event loop.

For the async `_aloop` we now use a weak map to track all in flight
prediction tasks. On cancel, we retrieve the task from the map and
cancel it.

For the standard `_loop` method we send the cancellation signal to the
process as before. It's not entirely clear why the child needs to send
the signal to itself vs raising the `CancellationException` directly
but this commit leaves it as-is for simplicity.
  • Loading branch information
Aron Carroll authored and aron committed Dec 18, 2024
1 parent 4161a39 commit d4f5e70
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 30 deletions.
31 changes: 24 additions & 7 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import types
import uuid
import warnings
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from enum import Enum, auto, unique
from multiprocessing.connection import Connection
Expand Down Expand Up @@ -95,6 +96,10 @@ class PredictionState:


class Worker:
@property
def uses_concurrency(self) -> bool:
return self._max_concurrency > 1

def __init__(
self, child: "_ChildWorker", events: Connection, max_concurrency: int = 1
) -> None:
Expand Down Expand Up @@ -289,7 +294,7 @@ def _consume_events_inner(self) -> None:
with self._predictions_lock:
predict_state = self._predictions_in_flight.get(ev.tag)
if predict_state and not predict_state.cancel_sent:
self._child.send_cancel()
self._child.send_cancel_signal()
self._events.send(Envelope(event=Cancel(), tag=ev.tag))
predict_state.cancel_sent = True
else:
Expand Down Expand Up @@ -415,7 +420,7 @@ def run(self) -> None:
redirector,
)

def send_cancel(self) -> None:
def send_cancel_signal(self) -> None:
if self.is_alive() and self.pid:
os.kill(self.pid, signal.SIGUSR1)

Expand Down Expand Up @@ -516,7 +521,9 @@ def _loop(
while True:
e = cast(Envelope, self._events.recv())
if isinstance(e.event, Cancel):
continue # Ignored in sync predictors.
# for sync predictors, this is handled via SIGUSR1 signals from
# the parent via send_cancel_signal
continue
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
Expand All @@ -533,17 +540,27 @@ async def _aloop(
assert isinstance(self._events, LockedConnection)
self._events = AsyncConnection(self._events.connection)

task = None

async with asyncio.TaskGroup() as tg:
tasks = weakref.WeakValueDictionary[str | None, asyncio.Task[Any]]()
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel) and task and self._cancelable:
if isinstance(e.event, Cancel):
# NOTE: We don't check the _cancelable flag here, instead we rely
# on the presence of the value in the weakmap to determine if
# a prediction is actively being processed.
task = tasks.get(e.tag)
if not task:
print(
"Got cancel event for unrecognized prediction",
file=sys.stderr,
)
continue

task.cancel()
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
task = tg.create_task(
tasks[e.tag] = tg.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
)
else:
Expand Down
114 changes: 91 additions & 23 deletions python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,24 @@
SLEEP_FIXTURES = [
WorkerConfig("sleep"),
WorkerConfig("sleep_async", min_python=(3, 11), is_async=True),
WorkerConfig(
"sleep_async",
min_python=(3, 11),
is_async=True,
max_concurrency=10,
),
]

SLEEP_NO_SETUP_FIXTURES = [
WorkerConfig("sleep", setup=False),
WorkerConfig("sleep_async", min_python=(3, 11), setup=False, is_async=True),
WorkerConfig(
"sleep_async",
min_python=(3, 11),
setup=False,
is_async=True,
max_concurrency=10,
),
]


Expand Down Expand Up @@ -452,30 +465,40 @@ def test_predict_logging(worker, expected_stdout, expected_stderr):


@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES)
def test_cancel_is_safe(worker):
def test_cancel_is_safe(worker: Worker):
"""
Calls to cancel at any time should not result in unexpected things
happening or the cancelation of unexpected predictions.
"""

tag = None
if worker.uses_concurrency:
tag = "p1"

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result = _process(worker, worker.setup)
assert not result.done.error

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result1 = _process(
worker, lambda: worker.predict({"sleep": 0.5}), swallow_exceptions=True
worker,
lambda: worker.predict({"sleep": 0.5}, tag),
swallow_exceptions=True,
tag=tag,
)

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result2 = _process(
worker, lambda: worker.predict({"sleep": 0.1}), swallow_exceptions=True
worker,
lambda: worker.predict({"sleep": 0.1}, tag),
swallow_exceptions=True,
tag=tag,
)

assert not result1.exception
Expand All @@ -486,68 +509,113 @@ def test_cancel_is_safe(worker):


@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES)
def test_cancel_idempotency(worker):
def test_cancel_idempotency(worker: Worker):
"""
Multiple calls to cancel within the same prediction, while not necessary or
recommended, should still only result in a single cancelled prediction, and
should not affect subsequent predictions.
"""

def cancel_a_bunch(_):
for _ in range(100):
worker.cancel()
tag = None
if worker.uses_concurrency:
tag = "p1"

result = _process(worker, worker.setup)
assert not result.done.error

fut = worker.predict({"sleep": 0.5})
fut = worker.predict({"sleep": 0.5}, tag)
# We call cancel a WHOLE BUNCH to make sure that we don't propagate any
# of those cancelations to subsequent predictions, regardless of the
# internal implementation of exceptions raised inside signal handlers.
for _ in range(5):
time.sleep(0.05)
for _ in range(100):
worker.cancel()
worker.cancel(tag)
result1 = fut.result()
assert result1.canceled

result2 = _process(worker, lambda: worker.predict({"sleep": 0.1}))
tag = None
if worker.uses_concurrency:
tag = "p2"
result2 = _process(worker, lambda: worker.predict({"sleep": 0.1}, tag))

assert not result2.done.canceled
assert result2.output == "done in 0.1 seconds"


@uses_worker_configs(SLEEP_FIXTURES)
def test_cancel_multiple_predictions(worker):
@uses_worker_configs(
[
WorkerConfig("sleep"),
WorkerConfig("sleep_async", min_python=(3, 11), is_async=True),
WorkerConfig(
"sleep_async", min_python=(3, 11), is_async=True, max_concurrency=5
),
]
)
def test_cancel_multiple_predictions(worker: Worker):
"""
Multiple predictions cancelled in a row shouldn't be a problem. This test
is mainly ensuring that the _allow_cancel latch in Worker is correctly
reset every time a prediction starts.
"""
dones: list[Done] = []
for _ in range(5):
fut = worker.predict({"sleep": 1})
for i in range(5):
tag = None
if worker._max_concurrency > 1:
tag = f"p{i}"
fut = worker.predict({"sleep": 0.2}, tag)
time.sleep(0.1)
worker.cancel()
worker.cancel(tag)
dones.append(fut.result())

assert dones == [Done(canceled=True)] * 5

assert not worker.predict({"sleep": 0}).result().canceled
assert not worker.predict({"sleep": 0}, "p6").result().canceled


@uses_worker_configs(
[
WorkerConfig(
"sleep_async", min_python=(3, 11), is_async=True, max_concurrency=5
),
]
)
def test_cancel_some_predictions_async_with_concurrency(worker: Worker):
"""
Multiple predictions cancelled in a row shouldn't be a problem. This test
is mainly ensuring that the _allow_cancel latch in Worker is correctly
reset every time a prediction starts.
"""
fut1 = worker.predict({"sleep": 0.2}, "p1")
fut2 = worker.predict({"sleep": 0.2}, "p2")
fut3 = worker.predict({"sleep": 0.2}, "p3")

time.sleep(0.1)

worker.cancel("p2")

assert not fut1.result().canceled
assert fut2.result().canceled
assert not fut3.result().canceled


@uses_worker_configs(SLEEP_FIXTURES)
def test_graceful_shutdown(worker):
def test_graceful_shutdown(worker: Worker):
"""
On shutdown, the worker should finish running the current prediction, and
then exit.
"""

tag = None
if worker.uses_concurrency:
tag = "p1"

saw_first_event = threading.Event()

# When we see the first event, we'll start the shutdown process.
worker.subscribe(lambda event: saw_first_event.set())
worker.subscribe(lambda event: saw_first_event.set(), tag=tag)

fut = worker.predict({"sleep": 1})
fut = worker.predict({"sleep": 1}, tag)

saw_first_event.wait(timeout=1)
worker.shutdown(timeout=2)
Expand Down Expand Up @@ -587,7 +655,7 @@ def start(self):
def is_alive(self):
return self.alive

def send_cancel(self):
def send_cancel_signal(self):
pass

def terminate(self):
Expand Down

0 comments on commit d4f5e70

Please sign in to comment.