Skip to content

Commit

Permalink
Merge pull request #150 from oremanj/asyncgens
Browse files Browse the repository at this point in the history
Finalize async generators in the correct context
  • Loading branch information
oremanj authored Apr 24, 2024
2 parents 6f2870a + cf7a25f commit 898ab8d
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 5 deletions.
4 changes: 4 additions & 0 deletions newsfragments/92.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
trio-asyncio now properly finalizes asyncio-flavored async generators
upon closure of the event loop. Previously, Trio's async generator finalizers
would try to finalize all async generators in Trio mode, regardless of their
flavor, which could lead to spurious errors.
5 changes: 1 addition & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
@pytest.fixture
async def loop():
async with trio_asyncio.open_loop() as loop:
try:
yield loop
finally:
await loop.stop().wait()
yield loop


# auto-trio-ize all async functions
Expand Down
104 changes: 104 additions & 0 deletions tests/test_trio_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import types
import asyncio
import trio
import trio.testing
import trio_asyncio
import contextlib
import gc


async def use_asyncio():
Expand Down Expand Up @@ -203,3 +205,105 @@ async def main():
asyncio.run(main())

assert scope.value.code == 42


@pytest.mark.trio
@pytest.mark.parametrize("alive_on_exit", (False, True))
@pytest.mark.parametrize("slow_finalizer", (False, True))
@pytest.mark.parametrize("loop_timeout", (0, 1, 20))
async def test_asyncgens(alive_on_exit, slow_finalizer, loop_timeout, autojump_clock):
import sniffio

record = set()
holder = []

async def agen(label, extra):
assert sniffio.current_async_library() == label
if label == "asyncio":
loop = asyncio.get_running_loop()
try:
yield 1
finally:
library = sniffio.current_async_library()
if label == "asyncio":
assert loop is asyncio.get_running_loop()
try:
await sys.modules[library].sleep(5 if slow_finalizer else 0)
except (trio.Cancelled, asyncio.CancelledError):
pass
record.add((label + extra, library))

async def iterate_one(label, extra=""):
ag = agen(label, extra)
await ag.asend(None)
if alive_on_exit:
holder.append(ag)
else:
del ag

sys.unraisablehook, prev_hook = sys.__unraisablehook__, sys.unraisablehook
try:
before_hooks = sys.get_asyncgen_hooks()

start_time = trio.current_time()
with trio.move_on_after(loop_timeout) as scope:
if loop_timeout == 0:
scope.cancel()
async with trio_asyncio.open_loop() as loop, trio_asyncio.open_loop() as loop2:
assert sys.get_asyncgen_hooks() != before_hooks
async with trio.open_nursery() as nursery:
# Make sure the iterate_one aio tasks don't get
# cancelled before they start:
nursery.cancel_scope.shield = True
try:
nursery.start_soon(iterate_one, "trio")
nursery.start_soon(
loop.run_aio_coroutine, iterate_one("asyncio")
)
nursery.start_soon(
loop2.run_aio_coroutine, iterate_one("asyncio", "2")
)
await loop.synchronize()
await loop2.synchronize()
finally:
nursery.cancel_scope.shield = False
if not alive_on_exit and sys.implementation.name == "pypy":
for _ in range(5):
gc.collect()

# Make sure we cleaned up properly once all trio-aio loops were closed
assert sys.get_asyncgen_hooks() == before_hooks

# asyncio agens should be finalized as soon as asyncio loop ends,
# regardless of liveness
assert ("asyncio", "asyncio") in record
assert ("asyncio2", "asyncio") in record

# asyncio agen finalizers should be able to take a cancel
if (slow_finalizer or loop_timeout == 0) and alive_on_exit:
# Each loop finalizes in series, and takes 5 seconds
# if slow_finalizer is true.
assert trio.current_time() == start_time + min(loop_timeout, 10)
assert scope.cancelled_caught == (loop_timeout < 10)
else:
# `not alive_on_exit` implies that the asyncio agen aclose() tasks
# are started before loop shutdown, which means they'll be
# cancelled during loop shutdown; this matches regular asyncio.
#
# `not slow_finalizer and loop_timeout > 0` implies that the agens
# have time to complete before we cancel them.
assert trio.current_time() == start_time
assert not scope.cancelled_caught

# trio asyncgen should eventually be finalized in trio mode
del holder[:]
for _ in range(5):
gc.collect()
await trio.testing.wait_all_tasks_blocked()
assert record == {
("trio", "trio"),
("asyncio", "asyncio"),
("asyncio2", "asyncio"),
}
finally:
sys.unraisablehook = prev_hook
74 changes: 74 additions & 0 deletions trio_asyncio/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,71 @@ def shutdown(self, wait=None):
self._running = False


class AsyncGeneratorDispatcher:
"""Helper object providing async generator hooks that route
finalization to either the correct trio-asyncio event loop or the
outer Trio run, depending on where the generator was first iterated.
"""

def __init__(self, prev_hooks):
self.prev_hooks = prev_hooks
self.refcnt = 1

@classmethod
def install(cls):
current_hooks = sys.get_asyncgen_hooks()

# These hooks should either be our own AsyncGeneratorDispatcher
# (for another trio-asyncio loop) or Trio's hooks. Both of those
# provide both hooks.
assert current_hooks.firstiter is not None
assert current_hooks.finalizer is not None

matches = (
getattr(current_hooks.firstiter, "__func__", None) is cls.firstiter
) + (getattr(current_hooks.finalizer, "__func__", None) is cls.finalizer)
if matches == 0:
# Create a new dispatcher that forwards non-trio-asyncio asyncgens
# to the current_hooks
dispatcher = cls(prev_hooks=current_hooks)
sys.set_asyncgen_hooks(
firstiter=dispatcher.firstiter, finalizer=dispatcher.finalizer
)
else:
# Take a new reference to the dispatcher that the current_hooks
# refer to
assert matches == 2
dispatcher = current_hooks.firstiter.__self__
assert dispatcher is current_hooks.finalizer.__self__
assert isinstance(dispatcher, cls)
dispatcher.refcnt += 1
return dispatcher

def uninstall(self):
self.refcnt -= 1
if self.refcnt <= 0:
sys.set_asyncgen_hooks(*self.prev_hooks)
assert self.refcnt == 0

def firstiter(self, agen):
if sniffio_library.name == "asyncio":
loop = asyncio.get_running_loop()
agen.ag_frame.f_locals["@trio_asyncio_loop"] = loop
return loop._asyncgen_firstiter_hook(agen)
else:
return self.prev_hooks.firstiter(agen)

def finalizer(self, agen):
try:
loop = agen.ag_frame.f_locals.get("@trio_asyncio_loop")
except AttributeError: # pragma: no cover
loop = None
if loop is not None:
return loop._asyncgen_finalizer_hook(agen)
else:
return self.prev_hooks.finalizer(agen)


class BaseTrioEventLoop(asyncio.SelectorEventLoop):
"""An asyncio event loop that runs on top of Trio.
Expand Down Expand Up @@ -135,6 +200,10 @@ class BaseTrioEventLoop(asyncio.SelectorEventLoop):
# (threading) Thread this loop is running in
_thread = None

# An instance of AsyncGeneratorDispatcher for handling asyncio async
# generators; it may be shared by multiple running trio-asyncio loops
_asyncgen_dispatcher = None

def __init__(self, queue_len=None):
if queue_len is None:
queue_len = math.inf
Expand Down Expand Up @@ -629,6 +698,7 @@ async def _main_loop_init(self, nursery):
self._nursery = nursery
self._task = trio.lowlevel.current_task()
self._token = trio.lowlevel.current_trio_token()
self._asyncgen_dispatcher = AsyncGeneratorDispatcher.install()

async def _main_loop(self, task_status=trio.TASK_STATUS_IGNORED):
"""Run the loop by processing its event queue.
Expand Down Expand Up @@ -738,6 +808,10 @@ async def _main_loop_exit(self):
except TrioAsyncioExit:
pass

# Restore previous async generator hooks
self._asyncgen_dispatcher.uninstall()
self._asyncgen_dispatcher = None

# Kill off unprocessed work
self._cancel_fds()
self._cancel_timers()
Expand Down
56 changes: 55 additions & 1 deletion trio_asyncio/_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
import trio
import asyncio
import warnings
import threading
from contextvars import ContextVar
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -560,6 +561,49 @@ async def wait_for_sync():
tasks_nursery.cancel_scope.cancel()

finally:
# If we have any async generators left, finalize them before
# closing the event loop. Make sure that the finalizers have a
# chance to actually start before they're exposed to any
# external cancellation, since asyncio doesn't guarantee that
# cancelled tasks have a chance to start first.

asyncgens_done = trio.Event()
should_warn = False
if len(loop._asyncgens) == 0:
asyncgens_done.set()
elif not loop.is_running():
asyncgens_done.set()
should_warn = True
else:
shield_asyncgen_finalizers = trio.CancelScope(shield=True)

async def sentinel():
try:
yield
finally:
try:
# Open-coded asyncio version of loop.synchronize();
# since we closed the tasks_nursery, we can't do
# any more asyncio-to-trio-mode conversions
w = asyncio.Event()
loop.call_soon(w.set)
await w.wait()
finally:
shield_asyncgen_finalizers.shield = False

async def shutdown_asyncgens_from_aio():
agen = sentinel()
await agen.asend(None)
try:
await loop.shutdown_asyncgens()
finally:
asyncgens_done.set()

@loop_nursery.start_soon
async def shutdown_asyncgens_from_trio():
with shield_asyncgen_finalizers:
await loop.run_aio_coroutine(shutdown_asyncgens_from_aio())

if forwarded_cancellation is not None:
# Now that we're outside the shielded tasks_nursery, we can
# add this cancellation to the set of errors propagating out
Expand All @@ -570,7 +614,17 @@ async def forward_cancellation():
raise forwarded_cancellation

try:
await loop._main_loop_exit()
try:
if should_warn:
warnings.warn(
"trio-asyncio loop was stopped before its async "
"generators were finalized; weird stuff might happen",
RuntimeWarning,
)
finally:
with trio.CancelScope(shield=True):
await asyncgens_done.wait()
await loop._main_loop_exit()
finally:
loop.close()
current_loop.reset(old_loop)
Expand Down

0 comments on commit 898ab8d

Please sign in to comment.