diff --git a/pyproject.toml b/pyproject.toml index 1544f76..e76e655 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ readme = "README.md" license = {text = "GPL-3.0"} requires-python = ">=3.8" dependencies = [ - "zigpy>=0.60.0", + "zigpy>=0.70.0", ] [tool.setuptools.packages.find] @@ -43,6 +43,7 @@ ignore_errors = true [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" [tool.flake8] exclude = [".venv", ".git", ".tox", "docs", "venv", "bin", "lib", "deps", "build"] diff --git a/tests/async_mock.py b/tests/async_mock.py deleted file mode 100644 index 8257ddd..0000000 --- a/tests/async_mock.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Mock utilities that are async aware.""" -import sys - -if sys.version_info[:2] < (3, 8): - from asynctest.mock import * # noqa - - AsyncMock = CoroutineMock # noqa: F405 -else: - from unittest.mock import * # noqa diff --git a/tests/test_api.py b/tests/test_api.py index 259d11a..48768e1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,19 +1,17 @@ """Tests for API.""" import asyncio +from unittest import mock import pytest -import serial import zigpy.config import zigpy.exceptions import zigpy.types as t -from zigpy_xbee import api as xbee_api, types as xbee_t, uart +from zigpy_xbee import api as xbee_api, types as xbee_t from zigpy_xbee.exceptions import ATCommandError, ATCommandException, InvalidCommand from zigpy_xbee.zigbee.application import ControllerApplication -import tests.async_mock as mock - DEVICE_CONFIG = zigpy.config.SCHEMA_DEVICE( { zigpy.config.CONF_DEVICE_PATH: "/dev/null", @@ -26,24 +24,49 @@ def api(): """Sample XBee API fixture.""" api = xbee_api.XBee(DEVICE_CONFIG) - api._uart = mock.MagicMock() + api._uart = mock.AsyncMock() return api -async def test_connect(monkeypatch): +async def test_connect(): """Test connect.""" api = xbee_api.XBee(DEVICE_CONFIG) - monkeypatch.setattr(uart, "connect", mock.AsyncMock()) - await api.connect() + api._command = mock.AsyncMock(spec=api._command) + + with mock.patch("zigpy_xbee.uart.connect"): + await api.connect() + + +async def test_connect_initial_timeout_success(): + """Test connect, initial command times out.""" + api = xbee_api.XBee(DEVICE_CONFIG) + api._at_command = mock.AsyncMock(side_effect=asyncio.TimeoutError) + api.init_api_mode = mock.AsyncMock(return_value=True) + + with mock.patch("zigpy_xbee.uart.connect"): + await api.connect() + + +async def test_connect_initial_timeout_failure(): + """Test connect, initial command times out.""" + api = xbee_api.XBee(DEVICE_CONFIG) + api._at_command = mock.AsyncMock(side_effect=asyncio.TimeoutError) + api.init_api_mode = mock.AsyncMock(return_value=False) + + with mock.patch("zigpy_xbee.uart.connect") as mock_connect: + with pytest.raises(zigpy.exceptions.APIException): + await api.connect() + assert mock_connect.return_value.disconnect.mock_calls == [mock.call()] -def test_close(api): + +async def test_disconnect(api): """Test connection close.""" uart = api._uart - api.close() + await api.disconnect() assert api._uart is None - assert uart.close.call_count == 1 + assert uart.disconnect.call_count == 1 def test_commands(): @@ -599,97 +622,10 @@ def test_handle_many_to_one_rri(api): api._handle_many_to_one_rri(ieee, nwk, 0) -@mock.patch.object(xbee_api.XBee, "_at_command", new_callable=mock.AsyncMock) -@mock.patch.object(uart, "connect", return_value=mock.MagicMock()) -async def test_probe_success(mock_connect, mock_at_cmd): - """Test device probing.""" - - res = await xbee_api.XBee.probe(DEVICE_CONFIG) - assert res is True - assert mock_connect.call_count == 1 - assert mock_connect.await_count == 1 - assert mock_connect.call_args[0][0] == DEVICE_CONFIG - assert mock_at_cmd.call_count == 1 - assert mock_connect.return_value.close.call_count == 1 - - -@mock.patch.object(xbee_api.XBee, "init_api_mode", return_value=True) -@mock.patch.object(xbee_api.XBee, "_at_command", side_effect=asyncio.TimeoutError) -@mock.patch.object(uart, "connect", return_value=mock.MagicMock()) -async def test_probe_success_api_mode(mock_connect, mock_at_cmd, mock_api_mode): - """Test device probing.""" - - res = await xbee_api.XBee.probe(DEVICE_CONFIG) - assert res is True - assert mock_connect.call_count == 1 - assert mock_connect.await_count == 1 - assert mock_connect.call_args[0][0] == DEVICE_CONFIG - assert mock_at_cmd.call_count == 1 - assert mock_api_mode.call_count == 1 - assert mock_connect.return_value.close.call_count == 1 - - -@mock.patch.object(xbee_api.XBee, "init_api_mode") -@mock.patch.object(xbee_api.XBee, "_at_command", side_effect=asyncio.TimeoutError) -@mock.patch.object(uart, "connect", return_value=mock.MagicMock()) -@pytest.mark.parametrize( - "exception", - (asyncio.TimeoutError, serial.SerialException, zigpy.exceptions.APIException), -) -async def test_probe_fail(mock_connect, mock_at_cmd, mock_api_mode, exception): - """Test device probing fails.""" - - mock_api_mode.side_effect = exception - mock_api_mode.reset_mock() - mock_at_cmd.reset_mock() - mock_connect.reset_mock() - res = await xbee_api.XBee.probe(DEVICE_CONFIG) - assert res is False - assert mock_connect.call_count == 1 - assert mock_connect.await_count == 1 - assert mock_connect.call_args[0][0] == DEVICE_CONFIG - assert mock_at_cmd.call_count == 1 - assert mock_api_mode.call_count == 1 - assert mock_connect.return_value.close.call_count == 1 - - -@mock.patch.object(xbee_api.XBee, "init_api_mode", return_value=False) -@mock.patch.object(xbee_api.XBee, "_at_command", side_effect=asyncio.TimeoutError) -@mock.patch.object(uart, "connect", return_value=mock.MagicMock()) -async def test_probe_fail_api_mode(mock_connect, mock_at_cmd, mock_api_mode): - """Test device probing fails.""" - - mock_api_mode.reset_mock() - mock_at_cmd.reset_mock() - mock_connect.reset_mock() - res = await xbee_api.XBee.probe(DEVICE_CONFIG) - assert res is False - assert mock_connect.call_count == 1 - assert mock_connect.await_count == 1 - assert mock_connect.call_args[0][0] == DEVICE_CONFIG - assert mock_at_cmd.call_count == 1 - assert mock_api_mode.call_count == 1 - assert mock_connect.return_value.close.call_count == 1 - - -@mock.patch.object(xbee_api.XBee, "connect", return_value=mock.MagicMock()) -async def test_xbee_new(conn_mck): - """Test new class method.""" - api = await xbee_api.XBee.new(mock.sentinel.application, DEVICE_CONFIG) - assert isinstance(api, xbee_api.XBee) - assert conn_mck.call_count == 1 - assert conn_mck.await_count == 1 - - -@mock.patch.object(xbee_api.XBee, "connect", return_value=mock.MagicMock()) -async def test_connection_lost(conn_mck): +async def test_connection_lost(api): """Test `connection_lost` propagataion.""" - api = await xbee_api.XBee.new(mock.sentinel.application, DEVICE_CONFIG) - await api.connect() - - app = api._app = mock.MagicMock() + api.set_application(mock.AsyncMock()) err = RuntimeError() api.connection_lost(err) - - app.connection_lost.assert_called_once_with(err) + api._app.connection_lost.assert_called_once_with(err) diff --git a/tests/test_application.py b/tests/test_application.py index 00467b0..e4c363f 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,6 +1,7 @@ """Tests for ControllerApplication.""" import asyncio +from unittest import mock import pytest import zigpy.config as config @@ -15,8 +16,6 @@ import zigpy_xbee.types as xbee_t from zigpy_xbee.zigbee import application -import tests.async_mock as mock - APP_CONFIG = { config.CONF_DEVICE: { config.CONF_DEVICE_PATH: "/dev/null", @@ -374,13 +373,12 @@ def init_api_mode_mock(): api_mode = api_config_succeeds return api_config_succeeds - with mock.patch("zigpy_xbee.api.XBee") as XBee_mock: - api_mock = mock.MagicMock() - api_mock._at_command = mock.AsyncMock(side_effect=_at_command_mock) - api_mock.init_api_mode = mock.AsyncMock(side_effect=init_api_mode_mock) - - XBee_mock.new = mock.AsyncMock(return_value=api_mock) + api_mock = mock.MagicMock() + api_mock._at_command = mock.AsyncMock(side_effect=_at_command_mock) + api_mock.init_api_mode = mock.AsyncMock(side_effect=init_api_mode_mock) + api_mock.connect = mock.AsyncMock() + with mock.patch("zigpy_xbee.api.XBee", return_value=api_mock): await app.connect() app.form_network = mock.AsyncMock() @@ -418,23 +416,17 @@ async def test_start_network(app): async def test_start_network_no_api_mode(app): """Test start network when not in API mode.""" - await _test_start_network(app, ai_status=0x00, api_mode=False) - assert app.state.node_info.nwk == 0x0000 - assert app.state.node_info.ieee == t.EUI64(range(1, 9)) - assert app._api.init_api_mode.call_count == 1 - assert app._api._at_command.call_count >= 16 + with pytest.raises(asyncio.TimeoutError): + await _test_start_network(app, ai_status=0x00, api_mode=False) async def test_start_network_api_mode_config_fails(app): """Test start network when not when API config fails.""" - with pytest.raises(zigpy.exceptions.ControllerException): + with pytest.raises(asyncio.TimeoutError): await _test_start_network( app, ai_status=0x00, api_mode=False, api_config_succeeds=False ) - assert app._api.init_api_mode.call_count == 1 - assert app._api._at_command.call_count == 1 - async def test_permit(app): """Test permit joins.""" @@ -559,11 +551,11 @@ async def test_force_remove(app): async def test_shutdown(app): """Test application shutdown.""" - mack_close = mock.MagicMock() - app._api.close = mack_close - await app.shutdown() + mock_disconnect = mock.AsyncMock() + app._api.disconnect = mock_disconnect + await app.disconnect() assert app._api is None - assert mack_close.call_count == 1 + assert mock_disconnect.call_count == 1 async def test_remote_at_cmd(app, device): diff --git a/tests/test_uart.py b/tests/test_uart.py index 93362dc..a235bd8 100644 --- a/tests/test_uart.py +++ b/tests/test_uart.py @@ -68,10 +68,12 @@ def test_command_mode_send(gw): gw._transport.write.assert_called_once_with(data) -def test_close(gw): +async def test_disconnect(gw): """Test closing connection.""" - gw.close() - assert gw._transport.close.call_count == 1 + transport = gw._transport + asyncio.get_running_loop().call_soon(gw.connection_lost, None) + await gw.disconnect() + assert transport.close.call_count == 1 def test_data_received_chunk_frame(gw): @@ -228,22 +230,12 @@ def test_unescape_underflow(gw): def test_connection_lost_exc(gw): """Test cannection lost callback is called.""" - gw._connected_future = asyncio.Future() - - gw.connection_lost(ValueError()) - - conn_lost = gw._api.connection_lost - assert conn_lost.call_count == 1 - assert isinstance(conn_lost.call_args[0][0], Exception) - assert gw._connected_future.done() - assert gw._connected_future.exception() + err = RuntimeError() + gw.connection_lost(err) + assert gw._api.connection_lost.mock_calls == [mock.call(err)] def test_connection_closed(gw): """Test connection closed.""" - gw._connected_future = asyncio.Future() gw.connection_lost(None) - - assert gw._api.connection_lost.call_count == 0 - assert gw._connected_future.done() - assert gw._connected_future.result() is True + assert gw._api.connection_lost.mock_calls == [mock.call(None)] diff --git a/zigpy_xbee/api.py b/zigpy_xbee/api.py index b4a73da..55faeae 100644 --- a/zigpy_xbee/api.py +++ b/zigpy_xbee/api.py @@ -6,12 +6,9 @@ import logging from typing import Any, Dict, Optional -import serial -from zigpy.config import CONF_DEVICE_PATH, SCHEMA_DEVICE from zigpy.exceptions import APIException, DeliveryError import zigpy.types as t -import zigpy_xbee from zigpy_xbee.exceptions import ( ATCommandError, ATCommandException, @@ -26,7 +23,6 @@ AT_COMMAND_TIMEOUT = 3 REMOTE_AT_COMMAND_TIMEOUT = 30 -PROBE_TIMEOUT = 45 # https://www.digi.com/resources/documentation/digidocs/PDFs/90000976.pdf @@ -305,32 +301,31 @@ def is_running(self): """Return true if coordinator is running.""" return self.coordinator_started_event.is_set() - @classmethod - async def new( - cls, - application: "zigpy_xbee.zigbee.application.ControllerApplication", - config: Dict[str, Any], - ) -> "XBee": - """Create new instance.""" - xbee_api = cls(config) - await xbee_api.connect() - xbee_api.set_application(application) - return xbee_api - async def connect(self) -> None: """Connect to the device.""" assert self._uart is None self._uart = await uart.connect(self._config, self) + try: + try: + # Ensure we have escaped commands + await self._at_command("AP", 2) + except asyncio.TimeoutError: + if not await self.init_api_mode(): + raise APIException("Failed to configure XBee for API mode") + except Exception: + await self.disconnect() + raise + def connection_lost(self, exc: Exception) -> None: """Lost serial connection.""" if self._app is not None: self._app.connection_lost(exc) - def close(self): + async def disconnect(self): """Close the connection.""" if self._uart: - self._uart.close() + await self._uart.disconnect() self._uart = None def _command(self, name, *args, mask_frame_id=False): @@ -568,36 +563,6 @@ async def init_api_mode(self): ) return False - @classmethod - async def probe(cls, device_config: Dict[str, Any]) -> bool: - """Probe port for the device presence.""" - api = cls(SCHEMA_DEVICE(device_config)) - try: - await asyncio.wait_for(api._probe(), timeout=PROBE_TIMEOUT) - return True - except (asyncio.TimeoutError, serial.SerialException, APIException) as exc: - LOGGER.debug( - "Unsuccessful radio probe of '%s' port", - device_config[CONF_DEVICE_PATH], - exc_info=exc, - ) - finally: - api.close() - - return False - - async def _probe(self) -> None: - """Open port and try sending a command.""" - await self.connect() - try: - # Ensure we have escaped commands - await self._at_command("AP", 2) - except asyncio.TimeoutError: - if not await self.init_api_mode(): - raise APIException("Failed to configure XBee for API mode") - finally: - self.close() - def __getattr__(self, item): """Handle supported command requests.""" if item in COMMAND_REQUESTS: diff --git a/zigpy_xbee/uart.py b/zigpy_xbee/uart.py index 2dbacd1..61a3dcf 100644 --- a/zigpy_xbee/uart.py +++ b/zigpy_xbee/uart.py @@ -10,7 +10,7 @@ LOGGER = logging.getLogger(__name__) -class Gateway(asyncio.Protocol): +class Gateway(zigpy.serial.SerialProtocol): """Class implementing the UART protocol.""" START = b"\x7E" @@ -21,10 +21,9 @@ class Gateway(asyncio.Protocol): RESERVED = START + ESCAPE + XON + XOFF THIS_ONE = True - def __init__(self, api, connected_future=None): + def __init__(self, api): """Initialize instance.""" - self._buffer = b"" - self._connected_future = connected_future + super().__init__() self._api = api self._in_command_mode = False @@ -54,24 +53,10 @@ def baudrate(self, baudrate): def connection_lost(self, exc) -> None: """Port was closed expectedly or unexpectedly.""" - if self._connected_future and not self._connected_future.done(): - if exc is None: - self._connected_future.set_result(True) - else: - self._connected_future.set_exception(exc) - if exc is None: - LOGGER.debug("Closed serial connection") - return - - LOGGER.error("Lost serial connection: %s", exc) - self._api.connection_lost(exc) + super().connection_lost(exc) - def connection_made(self, transport): - """Handle UART connection callback.""" - LOGGER.debug("Connection made") - self._transport = transport - if self._connected_future: - self._connected_future.set_result(True) + if self._api is not None: + self._api.connection_lost(exc) def command_mode_rsp(self, data): """Handle AT command mode response.""" @@ -87,14 +72,15 @@ def command_mode_send(self, data): def data_received(self, data): """Handle data received from the UART callback.""" - self._buffer += data + super().data_received(data) while self._buffer: frame = self._extract_frame() if frame is None: break self.frame_received(frame) if self._in_command_mode and self._buffer[-1:] == b"\r": - rsp, self._buffer = (self._buffer[:-1], b"") + rsp = self._buffer[:-1] + self._buffer.clear() self.command_mode_rsp(rsp) def frame_received(self, frame): @@ -102,10 +88,6 @@ def frame_received(self, frame): LOGGER.debug("Frame received: %s", frame) self._api.frame_received(frame) - def close(self): - """Close the connection.""" - self._transport.close() - def reset_command_mode(self): r"""Reset command mode and ignore '\r' character as command mode response.""" self._in_command_mode = False @@ -166,22 +148,16 @@ def _checksum(self, data): return 0xFF - (sum(data) % 0x100) -async def connect(device_config: Dict[str, Any], api, loop=None) -> Gateway: +async def connect(device_config: Dict[str, Any], api) -> Gateway: """Connect to the device.""" - if loop is None: - loop = asyncio.get_event_loop() - - connected_future = asyncio.Future() - protocol = Gateway(api, connected_future) - transport, protocol = await zigpy.serial.create_serial_connection( - loop, - lambda: protocol, + loop=asyncio.get_running_loop(), + protocol_factory=lambda: Gateway(api), url=device_config[zigpy.config.CONF_DEVICE_PATH], baudrate=device_config[zigpy.config.CONF_DEVICE_BAUDRATE], - xonxoff=False, + xonxoff=device_config[zigpy.config.CONF_DEVICE_BAUDRATE], ) - await connected_future + await protocol.wait_until_connected() return protocol diff --git a/zigpy_xbee/zigbee/application.py b/zigpy_xbee/zigbee/application.py index fd8b53b..aedb799 100644 --- a/zigpy_xbee/zigbee/application.py +++ b/zigpy_xbee/zigbee/application.py @@ -54,21 +54,16 @@ def __init__(self, config: dict[str, Any]): async def disconnect(self): """Shutdown application.""" if self._api: - self._api.close() + await self._api.disconnect() self._api = None async def connect(self): """Connect to the device.""" - self._api = await zigpy_xbee.api.XBee.new(self, self._config[CONF_DEVICE]) - try: - # Ensure we have escaped commands - await self._api._at_command("AP", 2) - except asyncio.TimeoutError: - LOGGER.debug("No response to API frame. Configure API mode") - if not await self._api.init_api_mode(): - raise zigpy.exceptions.ControllerException( - "Failed to configure XBee API mode." - ) + api = zigpy_xbee.api.XBee(self._config[CONF_DEVICE]) + await api.connect() + api.set_application(self) + + self._api = api async def start_network(self): """Configure the module to work with Zigpy."""