Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concurrency cancellation #2090

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import types
import uuid
import warnings
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from enum import Enum, auto, unique
from multiprocessing.connection import Connection
Expand Down Expand Up @@ -95,6 +96,10 @@ class PredictionState:


class Worker:
@property
def uses_concurrency(self) -> bool:
"""Return True when `self._max_concurrency` is greater than 1."""
return self._max_concurrency > 1

def __init__(
self, child: "_ChildWorker", events: Connection, max_concurrency: int = 1
) -> None:
Expand Down Expand Up @@ -289,7 +294,7 @@ def _consume_events_inner(self) -> None:
with self._predictions_lock:
predict_state = self._predictions_in_flight.get(ev.tag)
if predict_state and not predict_state.cancel_sent:
self._child.send_cancel()
self._child.send_cancel_signal()
self._events.send(Envelope(event=Cancel(), tag=ev.tag))
predict_state.cancel_sent = True
else:
Expand Down Expand Up @@ -415,7 +420,7 @@ def run(self) -> None:
redirector,
)

def send_cancel(self) -> None:
def send_cancel_signal(self) -> None:
if self.is_alive() and self.pid:
os.kill(self.pid, signal.SIGUSR1)

Expand Down Expand Up @@ -516,7 +521,9 @@ def _loop(
while True:
e = cast(Envelope, self._events.recv())
if isinstance(e.event, Cancel):
continue # Ignored in sync predictors.
# for sync predictors, this is handled via SIGUSR1 signals from
# the parent via send_cancel_signal
continue
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
Expand All @@ -533,17 +540,27 @@ async def _aloop(
assert isinstance(self._events, LockedConnection)
self._events = AsyncConnection(self._events.connection)

task = None

async with asyncio.TaskGroup() as tg:
tasks = weakref.WeakValueDictionary[str | None, asyncio.Task[Any]]()
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel) and task and self._cancelable:
if isinstance(e.event, Cancel):
# NOTE: We don't check the _cancelable flag here, instead we rely
# on the presence of the value in the weakmap to determine if
# a prediction is actively being processed.
task = tasks.get(e.tag)
if not task:
print(
"Got cancel event for unrecognized prediction",
file=sys.stderr,
)
continue

task.cancel()
elif isinstance(e.event, Shutdown):
break
elif isinstance(e.event, PredictionInput):
task = tg.create_task(
tasks[e.tag] = tg.create_task(
self._apredict(e.tag, e.event.payload, predict, redirector)
)
else:
Expand Down
114 changes: 91 additions & 23 deletions python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,24 @@
# Worker configurations used by the sleep-based tests: the "sleep" fixture,
# the "sleep_async" fixture, and "sleep_async" again with max_concurrency=10.
SLEEP_FIXTURES = [
WorkerConfig("sleep"),
WorkerConfig("sleep_async", min_python=(3, 11), is_async=True),
WorkerConfig(
"sleep_async",
min_python=(3, 11),
is_async=True,
max_concurrency=10,
),
]

# Same three configurations as SLEEP_FIXTURES, but each is created with
# setup=False.
SLEEP_NO_SETUP_FIXTURES = [
WorkerConfig("sleep", setup=False),
WorkerConfig("sleep_async", min_python=(3, 11), setup=False, is_async=True),
WorkerConfig(
"sleep_async",
min_python=(3, 11),
setup=False,
is_async=True,
max_concurrency=10,
),
]


Expand Down Expand Up @@ -452,30 +465,40 @@ def test_predict_logging(worker, expected_stdout, expected_stderr):


@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES)
def test_cancel_is_safe(worker):
def test_cancel_is_safe(worker: Worker):
"""
Calls to cancel at any time should not result in unexpected things
happening or the cancelation of unexpected predictions.
"""

tag = None
if worker.uses_concurrency:
tag = "p1"

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result = _process(worker, worker.setup)
assert not result.done.error

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result1 = _process(
worker, lambda: worker.predict({"sleep": 0.5}), swallow_exceptions=True
worker,
lambda: worker.predict({"sleep": 0.5}, tag),
swallow_exceptions=True,
tag=tag,
)

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result2 = _process(
worker, lambda: worker.predict({"sleep": 0.1}), swallow_exceptions=True
worker,
lambda: worker.predict({"sleep": 0.1}, tag),
swallow_exceptions=True,
tag=tag,
)

assert not result1.exception
Expand All @@ -486,68 +509,113 @@ def test_cancel_is_safe(worker):


@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES)
def test_cancel_idempotency(worker):
def test_cancel_idempotency(worker: Worker):
"""
Multiple calls to cancel within the same prediction, while not necessary or
recommended, should still only result in a single cancelled prediction, and
should not affect subsequent predictions.
"""

def cancel_a_bunch(_):
for _ in range(100):
worker.cancel()
tag = None
if worker.uses_concurrency:
tag = "p1"

result = _process(worker, worker.setup)
assert not result.done.error

fut = worker.predict({"sleep": 0.5})
fut = worker.predict({"sleep": 0.5}, tag)
# We call cancel a WHOLE BUNCH to make sure that we don't propagate any
# of those cancelations to subsequent predictions, regardless of the
# internal implementation of exceptions raised inside signal handlers.
for _ in range(5):
time.sleep(0.05)
for _ in range(100):
worker.cancel()
worker.cancel(tag)
result1 = fut.result()
assert result1.canceled

result2 = _process(worker, lambda: worker.predict({"sleep": 0.1}))
tag = None
if worker.uses_concurrency:
tag = "p2"
result2 = _process(worker, lambda: worker.predict({"sleep": 0.1}, tag))

assert not result2.done.canceled
assert result2.output == "done in 0.1 seconds"


@uses_worker_configs(SLEEP_FIXTURES)
def test_cancel_multiple_predictions(worker):
@uses_worker_configs(
[
WorkerConfig("sleep"),
WorkerConfig("sleep_async", min_python=(3, 11), is_async=True),
WorkerConfig(
"sleep_async", min_python=(3, 11), is_async=True, max_concurrency=5
),
]
)
def test_cancel_multiple_predictions(worker: Worker):
"""
Multiple predictions cancelled in a row shouldn't be a problem. This test
is mainly ensuring that the _allow_cancel latch in Worker is correctly
reset every time a prediction starts.
"""
dones: list[Done] = []
for _ in range(5):
fut = worker.predict({"sleep": 1})
for i in range(5):
tag = None
if worker._max_concurrency > 1:
tag = f"p{i}"
fut = worker.predict({"sleep": 0.2}, tag)
time.sleep(0.1)
worker.cancel()
worker.cancel(tag)
dones.append(fut.result())

assert dones == [Done(canceled=True)] * 5

assert not worker.predict({"sleep": 0}).result().canceled
assert not worker.predict({"sleep": 0}, "p6").result().canceled


@uses_worker_configs(
[
WorkerConfig(
"sleep_async", min_python=(3, 11), is_async=True, max_concurrency=5
),
]
)
def test_cancel_some_predictions_async_with_concurrency(worker: Worker):
"""
Cancelling one of several concurrently submitted predictions should cancel
only the prediction with the matching tag; the other predictions should
finish without being marked as canceled.
"""
# Submit three tagged predictions before waiting on any of their results.
fut1 = worker.predict({"sleep": 0.2}, "p1")
fut2 = worker.predict({"sleep": 0.2}, "p2")
fut3 = worker.predict({"sleep": 0.2}, "p3")

# NOTE(review): presumably this pause lets the predictions start running
# before the cancel is sent, while staying under their 0.2s sleep - confirm.
time.sleep(0.1)

# Cancel only the middle prediction, addressed by its tag.
worker.cancel("p2")

assert not fut1.result().canceled
assert fut2.result().canceled
assert not fut3.result().canceled


@uses_worker_configs(SLEEP_FIXTURES)
def test_graceful_shutdown(worker):
def test_graceful_shutdown(worker: Worker):
"""
On shutdown, the worker should finish running the current prediction, and
then exit.
"""

tag = None
if worker.uses_concurrency:
tag = "p1"

saw_first_event = threading.Event()

# When we see the first event, we'll start the shutdown process.
worker.subscribe(lambda event: saw_first_event.set())
worker.subscribe(lambda event: saw_first_event.set(), tag=tag)

fut = worker.predict({"sleep": 1})
fut = worker.predict({"sleep": 1}, tag)

saw_first_event.wait(timeout=1)
worker.shutdown(timeout=2)
Expand Down Expand Up @@ -587,7 +655,7 @@ def start(self):
def is_alive(self):
"""Return the value of this object's `alive` attribute."""
return self.alive

def send_cancel(self):
def send_cancel_signal(self):
pass

def terminate(self):
Expand Down
Loading