diff --git a/bellows/__init__.py b/bellows/__init__.py index ed1828e6..12cae532 100644 --- a/bellows/__init__.py +++ b/bellows/__init__.py @@ -1,5 +1,5 @@ MAJOR_VERSION = 0 -MINOR_VERSION = 8 -PATCH_VERSION = '2' +MINOR_VERSION = 9 +PATCH_VERSION = '0' __short_version__ = '{}.{}'.format(MAJOR_VERSION, MINOR_VERSION) __version__ = '{}.{}'.format(__short_version__, PATCH_VERSION) diff --git a/bellows/thread.py b/bellows/thread.py new file mode 100644 index 00000000..6f39d48d --- /dev/null +++ b/bellows/thread.py @@ -0,0 +1,98 @@ +import asyncio +import logging + +import sys + +import functools +from concurrent.futures import ThreadPoolExecutor + +LOGGER = logging.getLogger(__name__) + + +class EventLoopThread: + ''' Run a parallel event loop in a separate thread ''' + def __init__(self): + self.loop = None + self.thread_complete = None + + def run_coroutine_threadsafe(self, coroutine): + current_loop = asyncio.get_event_loop() + future = asyncio.run_coroutine_threadsafe(coroutine, self.loop) + return asyncio.wrap_future(future, loop=current_loop) + + def _thread_main(self, init_task): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + try: + self.loop.run_until_complete(init_task) + self.loop.run_forever() + finally: + self.loop.close() + self.loop = None + + async def start(self): + current_loop = asyncio.get_event_loop() + if self.loop is not None and not self.loop.is_closed(): + return + + executor_opts = {'max_workers': 1} + if sys.version_info[:2] >= (3, 6): + executor_opts['thread_name_prefix'] = __name__ + executor = ThreadPoolExecutor(**executor_opts) + + thread_started_future = current_loop.create_future() + + async def init_task(): + current_loop.call_soon_threadsafe(thread_started_future.set_result, None) + + # Use current loop so current loop has a reference to the long-running thread as one of its tasks + thread_complete = current_loop.run_in_executor(executor, self._thread_main, init_task()) + self.thread_complete = thread_complete + current_loop.call_soon(executor.shutdown, False) + await thread_started_future + return thread_complete + + def force_stop(self): + if self.loop is not None: + self.loop.call_soon_threadsafe(self.loop.stop) + + +class ThreadsafeProxy: + ''' Proxy class which enforces threadsafe non-blocking calls + This class can be used to wrap an object to ensure any calls + using that object's methods are done on a particular event loop + ''' + def __init__(self, obj, obj_loop): + self._obj = obj + self._obj_loop = obj_loop + + def __getattr__(self, name): + func = getattr(self._obj, name) + if not callable(func): + raise TypeError("Can only use ThreadsafeProxy with callable attributes: {}.{}".format( + self._obj.__class__.__name__, name)) + + def func_wrapper(*args, **kwargs): + loop = self._obj_loop + curr_loop = asyncio.get_event_loop() + call = functools.partial(func, *args, **kwargs) + if loop == curr_loop: + return call() + if loop.is_closed(): + # Disconnected + LOGGER.warning("Attempted to use a closed event loop") + return + if asyncio.iscoroutinefunction(func): + future = asyncio.run_coroutine_threadsafe(call(), loop) + return asyncio.wrap_future(future, loop=curr_loop) + else: + def check_result_wrapper(): + result = call() + if result is not None: + raise TypeError("ThreadsafeProxy can only wrap functions with no return value \ + \nUse an async method to return values: {}.{}".format( + self._obj.__class__.__name__, name)) + + loop.call_soon_threadsafe(check_result_wrapper) + return func_wrapper diff --git a/bellows/uart.py b/bellows/uart.py index 51a01e78..5e980fd7 100644 --- a/bellows/uart.py +++ b/bellows/uart.py @@ -5,6 +5,8 @@ import serial import serial_asyncio +from bellows.thread import EventLoopThread, ThreadsafeProxy + import bellows.types as t LOGGER = logging.getLogger(__name__) @@ -27,7 +29,7 @@ class Gateway(asyncio.Protocol): class Terminator: pass - def __init__(self, application, connected_future=None): + def __init__(self, application, connected_future=None, connection_done_future=None): self._send_seq = 0 self._rec_seq = 0 self._buffer = b'' @@ -36,6 +38,7 @@ def __init__(self, application, connected_future=None): self._connected_future = connected_future self._sendq = asyncio.Queue() self._pending = (-1, None) + self._connection_done_future = connection_done_future def connection_made(self, transport): """Callback when the uart is connected""" @@ -173,6 +176,8 @@ def _reset_cleanup(self, future): def connection_lost(self, exc): """Port was closed unexpectedly.""" + if self._connection_done_future: + self._connection_done_future.set_result(exc) if exc is None: LOGGER.debug("Closed serial connection") return @@ -180,13 +185,13 @@ def connection_lost(self, exc): LOGGER.error("Lost serial connection: %s", exc) self._application.connection_lost(exc) - def reset(self): + async def reset(self): """Send a reset frame and init internal state.""" LOGGER.debug("Resetting ASH") if self._reset_future is not None: LOGGER.error(("received new reset request while an existing " "one is in progress")) - return self._reset_future + return await self._reset_future self._send_seq = 0 self._rec_seq = 0 @@ -197,10 +202,10 @@ def reset(self): self._pending[1].set_result(True) self._pending = (-1, None) - self._reset_future = asyncio.Future() + self._reset_future = asyncio.get_event_loop().create_future() self._reset_future.add_done_callback(self._reset_cleanup) self.write(self._rst_frame()) - return asyncio.wait_for(self._reset_future, timeout=RESET_TIMEOUT) + return await asyncio.wait_for(self._reset_future, timeout=RESET_TIMEOUT) async def _send_task(self): """Send queue handler""" @@ -212,7 +217,7 @@ async def _send_task(self): success = False rxmit = 0 while not success: - self._pending = (seq, asyncio.Future()) + self._pending = (seq, asyncio.get_event_loop().create_future()) self.write(self._data_frame(data, seq, rxmit)) rxmit = 1 success = await self._pending[1] @@ -305,12 +310,12 @@ def _unstuff(self, s): return out -async def connect(port, baudrate, application, loop=None): - if loop is None: - loop = asyncio.get_event_loop() +async def _connect(port, baudrate, application): + loop = asyncio.get_event_loop() - connection_future = asyncio.Future() - protocol = Gateway(application, connection_future) + connection_future = loop.create_future() + connection_done_future = loop.create_future() + protocol = Gateway(application, connection_future, connection_done_future) transport, protocol = await serial_asyncio.create_serial_connection( loop, @@ -324,4 +329,17 @@ async def connect(port, baudrate, application, loop=None): await connection_future + thread_safe_protocol = ThreadsafeProxy(protocol, loop) + return thread_safe_protocol, connection_done_future + + +async def connect(port, baudrate, application, use_thread=True): + if use_thread: + application = ThreadsafeProxy(application, asyncio.get_event_loop()) + thread = EventLoopThread() + await thread.start() + protocol, connection_done = await thread.run_coroutine_threadsafe(_connect(port, baudrate, application)) + connection_done.add_done_callback(lambda _: thread.force_stop()) + else: + protocol, _ = await _connect(port, baudrate, application) return protocol diff --git a/tests/test_application.py b/tests/test_application.py index ba36f000..44ffd037 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -343,7 +343,7 @@ async def mocksend(method, nwk, aps_frame, seq, data): return [returnvals.pop(0)] def mock_get_device(*args, **kwargs): - dev = Device(app, mock.sentinel.ieee, mock.sentinel.nwk) + dev = Device(app, mock.sentinel.ieee, 0xaa55) dev.node_desc = mock.MagicMock() dev.node_desc.is_end_device = is_an_end_dev return dev diff --git a/tests/test_thread.py b/tests/test_thread.py new file mode 100644 index 00000000..a2d122ae --- /dev/null +++ b/tests/test_thread.py @@ -0,0 +1,180 @@ +import asyncio +from unittest import mock + +import pytest +import threading +import sys + +from bellows.thread import ThreadsafeProxy, EventLoopThread + +from async_generator import yield_, async_generator + + +@pytest.mark.asyncio +async def test_thread_start(monkeypatch): + current_loop = asyncio.get_event_loop() + loopmock = mock.MagicMock() + + monkeypatch.setattr( + asyncio, + 'new_event_loop', + lambda: loopmock + ) + monkeypatch.setattr( + asyncio, + 'set_event_loop', + lambda loop: None + ) + + def mockrun(task): + future = asyncio.run_coroutine_threadsafe(task, loop=current_loop) + return future.result(1) + + loopmock.run_until_complete.side_effect = mockrun + thread = EventLoopThread() + thread_complete = await thread.start() + await thread_complete + + assert loopmock.run_until_complete.call_count == 1 + assert loopmock.run_forever.call_count == 1 + assert loopmock.close.call_count == 1 + + +class ExceptionCollector: + def __init__(self): + self.exceptions = [] + + def __call__(self, thread_loop, context): + exc = context.get('exception') or Exception(context['message']) + self.exceptions.append(exc) + + +@pytest.fixture +@async_generator # Remove when Python 3.5 is no longer supported +async def thread(): + thread = EventLoopThread() + await thread.start() + thread.loop.call_soon_threadsafe(thread.loop.set_exception_handler, ExceptionCollector()) + await yield_(thread) + thread.force_stop() + if thread.thread_complete is not None: + await asyncio.wait_for(thread.thread_complete, 1) + [t.join(1) for t in threading.enumerate() if 'bellows' in t.name] + threads = [t for t in threading.enumerate() if 'bellows' in t.name] + assert len(threads) == 0 + + +async def yield_other_thread(thread): + await thread.run_coroutine_threadsafe(asyncio.sleep(0)) + + exception_collector = thread.loop.get_exception_handler() + if exception_collector.exceptions: + raise exception_collector.exceptions[0] + + +@pytest.mark.asyncio +async def test_thread_loop(thread): + async def test_coroutine(): + return mock.sentinel.result + + future = asyncio.run_coroutine_threadsafe(test_coroutine(), loop=thread.loop) + result = await asyncio.wrap_future(future, loop=asyncio.get_event_loop()) + assert result is mock.sentinel.result + + +@pytest.mark.asyncio +async def test_thread_double_start(thread): + previous_loop = thread.loop + await thread.start() + if sys.version_info[:2] >= (3, 6): + threads = [t for t in threading.enumerate() if 'bellows' in t.name] + assert len(threads) == 1 + assert thread.loop is previous_loop + + +@pytest.mark.asyncio +async def test_thread_already_stopped(thread): + thread.force_stop() + thread.force_stop() + + +@pytest.mark.asyncio +async def test_thread_run_coroutine_threadsafe(thread): + inner_loop = None + + async def test_coroutine(): + nonlocal inner_loop + inner_loop = asyncio.get_event_loop() + return mock.sentinel.result + + result = await thread.run_coroutine_threadsafe(test_coroutine()) + assert result is mock.sentinel.result + assert inner_loop is thread.loop + + +@pytest.mark.asyncio +async def test_proxy_callback(thread): + obj = mock.MagicMock() + proxy = ThreadsafeProxy(obj, thread.loop) + obj.test.return_value = None + proxy.test() + await yield_other_thread(thread) + assert obj.test.call_count == 1 + + +@pytest.mark.asyncio +async def test_proxy_async(thread): + obj = mock.MagicMock() + proxy = ThreadsafeProxy(obj, thread.loop) + call_count = 0 + + async def magic(): + nonlocal thread, call_count + assert asyncio.get_event_loop() == thread.loop + call_count += 1 + return mock.sentinel.result + + obj.test = magic + result = await proxy.test() + + assert call_count == 1 + assert result == mock.sentinel.result + + +@pytest.mark.asyncio +async def test_proxy_bad_function(thread): + obj = mock.MagicMock() + proxy = ThreadsafeProxy(obj, thread.loop) + obj.test.return_value = mock.sentinel.value + + with pytest.raises(TypeError): + proxy.test() + await yield_other_thread(thread) + + +@pytest.mark.asyncio +async def test_proxy_not_function(): + loop = asyncio.get_event_loop() + obj = mock.MagicMock() + proxy = ThreadsafeProxy(obj, loop) + obj.test = mock.sentinel.value + with pytest.raises(TypeError): + proxy.test + + +@pytest.mark.asyncio +async def test_proxy_no_thread(): + loop = asyncio.get_event_loop() + obj = mock.MagicMock() + proxy = ThreadsafeProxy(obj, loop) + proxy.test() + assert obj.test.call_count == 1 + + +def test_proxy_loop_closed(): + loop = asyncio.new_event_loop() + obj = mock.MagicMock() + proxy = ThreadsafeProxy(obj, loop) + loop.close() + proxy.test() + assert obj.test.call_count == 0 diff --git a/tests/test_uart.py b/tests/test_uart.py index 2319a914..6b1b9c60 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -3,17 +3,20 @@ import serial_asyncio import pytest +import threading from bellows import uart -def test_connect(monkeypatch): +@pytest.mark.asyncio +async def test_connect(monkeypatch): portmock = mock.MagicMock() appmock = mock.MagicMock() + transport = mock.MagicMock() async def mockconnect(loop, protocol_factory, **kwargs): protocol = protocol_factory() - loop.call_soon(protocol.connection_made, None) + loop.call_soon(protocol.connection_made, transport) return None, protocol monkeypatch.setattr( @@ -21,8 +24,43 @@ async def mockconnect(loop, protocol_factory, **kwargs): 'create_serial_connection', mockconnect, ) - loop = asyncio.get_event_loop() - loop.run_until_complete(uart.connect(portmock, 115200, appmock)) + await uart.connect(portmock, 115200, appmock, use_thread=False) + + threads = [t for t in threading.enumerate() if 'bellows' in t.name] + assert len(threads) == 0 + + +@pytest.mark.asyncio +async def test_connect_threaded(monkeypatch): + + portmock = mock.MagicMock() + appmock = mock.MagicMock() + transport = mock.MagicMock() + + async def mockconnect(loop, protocol_factory, **kwargs): + protocol = protocol_factory() + loop.call_soon(protocol.connection_made, transport) + return None, protocol + + monkeypatch.setattr( + serial_asyncio, + 'create_serial_connection', + mockconnect, + ) + + def on_transport_close(): + gw.connection_lost(None) + + transport.close.side_effect = on_transport_close + gw = await uart.connect(portmock, 115200, appmock) + + # Need to close to release thread + gw.close() + + # Ensure all threads are cleaned up + [t.join(1) for t in threading.enumerate() if 'bellows' in t.name] + threads = [t for t in threading.enumerate() if 'bellows' in t.name] + assert len(threads) == 0 @pytest.fixture @@ -167,14 +205,14 @@ def test_close(gw): @pytest.mark.asyncio async def test_reset(gw): + gw._loop = asyncio.get_event_loop() gw._sendq.put_nowait(mock.sentinel.queue_item) fut = asyncio.Future() gw._pending = (mock.sentinel.seq, fut) gw._transport.write.side_effect = lambda *args: gw._reset_future.set_result( mock.sentinel.reset_result) - ret = gw.reset() + reset_result = await gw.reset() - assert asyncio.iscoroutine(ret) assert gw._transport.write.call_count == 1 assert gw._send_seq == 0 assert gw._rec_seq == 0 @@ -183,21 +221,25 @@ async def test_reset(gw): assert fut.done() assert gw._pending == (-1, None) - reset_result = await ret assert reset_result is mock.sentinel.reset_result @pytest.mark.asyncio async def test_reset_timeout(gw, monkeypatch): + gw._loop = asyncio.get_event_loop() monkeypatch.setattr(uart, 'RESET_TIMEOUT', 0.1) with pytest.raises(asyncio.TimeoutError): await gw.reset() -def test_reset_old(gw): - gw._reset_future = mock.sentinel.future - ret = gw.reset() - assert ret == mock.sentinel.future +@pytest.mark.asyncio +async def test_reset_old(gw): + gw._loop = asyncio.get_event_loop() + future = asyncio.get_event_loop().create_future() + future.set_result(mock.sentinel.result) + gw._reset_future = future + ret = await gw.reset() + assert ret == mock.sentinel.result gw._transport.write.assert_not_called() diff --git a/tox.ini b/tox.ini index 08a4bd5c..f1c6f0dc 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ deps = pytest-cov pytest-asyncio zigpy-homeassistant + async_generator [testenv:lint] basepython = python3