Skip to content

Commit

Permalink
More test fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
Aron Carroll committed Dec 13, 2024
1 parent b7e1931 commit 409b659
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 48 deletions.
10 changes: 5 additions & 5 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class PredictionState:


class Worker:
@property
def uses_concurrency(self) -> bool:
return self._max_concurrency > 1

def __init__(
self, child: "_ChildWorker", events: Connection, max_concurrency: int = 1
) -> None:
Expand Down Expand Up @@ -536,17 +540,13 @@ async def _aloop(
self._events = AsyncConnection(self._events.connection)

async with asyncio.TaskGroup() as tg:
tasks = weakref.WeakValueDictionary[str, asyncio.Task[Any]]()
tasks = weakref.WeakValueDictionary[str | None, asyncio.Task[Any]]()
while True:
e = cast(Envelope, await self._events.recv())
if isinstance(e.event, Cancel):
print("THIS IS A CANCEL EVENT", e.tag, self._cancelable)
if not self._cancelable:
continue

assert (
e.tag
), "expected Cancel event to have a tag when concurrency_max > 1"
task = tasks.get(e.tag)
if not task:
print(
Expand Down
116 changes: 73 additions & 43 deletions python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,25 @@
SLEEP_FIXTURES = [
WorkerConfig("sleep"),
WorkerConfig("sleep_async", min_python=(3, 11), is_async=True),
WorkerConfig(
"sleep_async",
min_python=(3, 11),
setup=False,
is_async=True,
max_concurrency=10,
),
]

SLEEP_NO_SETUP_FIXTURES = [
WorkerConfig("sleep", setup=False),
WorkerConfig("sleep_async", min_python=(3, 11), setup=False, is_async=True),
WorkerConfig(
"sleep_async",
min_python=(3, 11),
setup=False,
is_async=True,
max_concurrency=10,
),
]


Expand Down Expand Up @@ -452,30 +466,40 @@ def test_predict_logging(worker, expected_stdout, expected_stderr):


@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES)
def test_cancel_is_safe(worker):
def test_cancel_is_safe(worker: Worker):
"""
Calls to cancel at any time should not result in unexpected things
happening or the cancelation of unexpected predictions.
"""

tag = None
if worker.uses_concurrency:
tag = "p1"

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result = _process(worker, worker.setup)
assert not result.done.error

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result1 = _process(
worker, lambda: worker.predict({"sleep": 0.5}), swallow_exceptions=True
worker,
lambda: worker.predict({"sleep": 0.5}, tag),
swallow_exceptions=True,
tag=tag,
)

for _ in range(50):
worker.cancel()
worker.cancel(tag)

result2 = _process(
worker, lambda: worker.predict({"sleep": 0.1}), swallow_exceptions=True
worker,
lambda: worker.predict({"sleep": 0.1}, tag),
swallow_exceptions=True,
tag=tag,
)

assert not result1.exception
Expand All @@ -486,82 +510,83 @@ def test_cancel_is_safe(worker):


@uses_worker_configs(SLEEP_NO_SETUP_FIXTURES)
def test_cancel_idempotency(worker):
def test_cancel_idempotency(worker: Worker):
"""
Multiple calls to cancel within the same prediction, while not necessary or
recommended, should still only result in a single cancelled prediction, and
should not affect subsequent predictions.
"""

def cancel_a_bunch(_):
for _ in range(100):
worker.cancel()
tag = None
if worker.uses_concurrency:
tag = "p1"

result = _process(worker, worker.setup)
assert not result.done.error

fut = worker.predict({"sleep": 0.5})
fut = worker.predict({"sleep": 0.5}, tag)
# We call cancel a WHOLE BUNCH to make sure that we don't propagate any
# of those cancelations to subsequent predictions, regardless of the
# internal implementation of exceptions raised inside signal handlers.
for _ in range(5):
time.sleep(0.05)
for _ in range(100):
worker.cancel()
worker.cancel(tag)
result1 = fut.result()
assert result1.canceled

result2 = _process(worker, lambda: worker.predict({"sleep": 0.1}))
tag = None
if worker.uses_concurrency:
tag = "p2"
result2 = _process(worker, lambda: worker.predict({"sleep": 0.1}, tag))

assert not result2.done.canceled
assert result2.output == "done in 0.1 seconds"


@uses_worker_configs([WorkerConfig("sleep")])
@uses_worker_configs(
[
WorkerConfig("sleep"),
WorkerConfig("sleep_async", min_python=(3, 11), is_async=True),
WorkerConfig(
"sleep_async", min_python=(3, 11), is_async=True, max_concurrency=5
),
]
)
def test_cancel_multiple_predictions(worker: Worker):
"""
Multiple predictions cancelled in a row shouldn't be a problem. This test
is mainly ensuring that the _allow_cancel latch in Worker is correctly
reset every time a prediction starts.
"""
dones: list[Done] = []
for _ in range(5):
fut = worker.predict({"sleep": 0.2})
time.sleep(0.1)
worker.cancel()
dones.append(fut.result())
assert dones == [Done(canceled=True)] * 5

assert not worker.predict({"sleep": 0}).result().canceled


@uses_worker_configs([WorkerConfig("sleep_async", min_python=(3, 11), is_async=True)])
def test_cancel_multiple_predictions_async(worker: Worker):
"""
Multiple predictions cancelled in a row shouldn't be a problem. This test
is mainly ensuring that the _allow_cancel latch in Worker is correctly
reset every time a prediction starts.
"""
dones: list[Done] = []
for i in range(5):
tag = f"p{i}"
tag = None
if worker._max_concurrency > 1:
tag = f"p{i}"
fut = worker.predict({"sleep": 0.2}, tag)
time.sleep(0.1)
worker.cancel(tag)
dones.append(fut.result())

assert dones == [Done(canceled=True)] * 5

assert not worker.predict({"sleep": 0}).result().canceled
assert not worker.predict({"sleep": 0}, "p6").result().canceled


@uses_worker_configs([WorkerConfig("sleep_async", min_python=(3, 11), is_async=True)])
def test_cancel_some_predictions_async(worker: Worker):
@uses_worker_configs(
[
WorkerConfig(
"sleep_async", min_python=(3, 11), is_async=True, max_concurrency=5
),
]
)
def test_cancel_some_predictions_async_with_concurrency(worker: Worker):
"""
Multiple predictions cancelled in a row shouldn't be a problem. This test
is mainly ensuring that the _allow_cancel latch in Worker is correctly
reset every time a prediction starts.
"""
dones: list[Done] = []
fut1 = worker.predict({"sleep": 0.2}, "p1")
fut2 = worker.predict({"sleep": 0.2}, "p2")
fut3 = worker.predict({"sleep": 0.2}, "p3")
Expand All @@ -570,23 +595,28 @@ def test_cancel_some_predictions_async(worker: Worker):

worker.cancel("p2")

assert fut1.result().canceled = False
assert fut1.result().canceled = True
assert fut1.result().canceled = False
assert fut1.result().canceled == False
assert fut2.result().canceled == True
assert fut3.result().canceled == False


@uses_worker_configs(SLEEP_FIXTURES)
def test_graceful_shutdown(worker):
def test_graceful_shutdown(worker: Worker):
"""
On shutdown, the worker should finish running the current prediction, and
then exit.
"""

tag = None
if worker.uses_concurrency:
tag = "p1"

saw_first_event = threading.Event()

# When we see the first event, we'll start the shutdown process.
worker.subscribe(lambda event: saw_first_event.set())

fut = worker.predict({"sleep": 1})
fut = worker.predict({"sleep": 1}, tag)

saw_first_event.wait(timeout=1)
worker.shutdown(timeout=2)
Expand Down Expand Up @@ -626,7 +656,7 @@ def start(self):
def is_alive(self):
return self.alive

def send_cancel(self):
def send_cancel_signal(self):
pass

def terminate(self):
Expand Down

0 comments on commit 409b659

Please sign in to comment.