diff --git a/tests/test_api.py b/tests/test_api.py index 74c05de..8facdc9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -10,6 +10,7 @@ from zigpy_xbee import api as xbee_api, types as xbee_t, uart import zigpy_xbee.config +from zigpy_xbee.exceptions import ATCommandError, ATCommandException, InvalidCommand from zigpy_xbee.zigbee.application import ControllerApplication import tests.async_mock as mock @@ -327,7 +328,16 @@ def test_handle_at_response_error(api): status, response = 1, 0x23 fut = _handle_at_response(api, tsn, status, [response]) assert fut.done() is True - assert fut.exception() is not None + assert isinstance(fut.exception(), ATCommandError) + + +def test_handle_at_response_invalid_command(api): + """Test invalid AT command response.""" + tsn = 123 + status, response = 2, 0x23 + fut = _handle_at_response(api, tsn, status, [response]) + assert fut.done() is True + assert isinstance(fut.exception(), InvalidCommand) def test_handle_at_response_undef_error(api): @@ -336,7 +346,7 @@ def test_handle_at_response_undef_error(api): status, response = 0xEE, 0x23 fut = _handle_at_response(api, tsn, status, [response]) assert fut.done() is True - assert fut.exception() is not None + assert isinstance(fut.exception(), ATCommandException) def test_handle_remote_at_rsp(api): diff --git a/tests/test_application.py b/tests/test_application.py index a508caf..4ecb643 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -11,6 +11,7 @@ from zigpy_xbee.api import XBee import zigpy_xbee.config as config +from zigpy_xbee.exceptions import InvalidCommand import zigpy_xbee.types as xbee_t from zigpy_xbee.zigbee import application @@ -305,7 +306,7 @@ async def test_write_network_info(app, node_info, network_info, legacy_module): def _mock_queued_at(name, *args): if legacy_module and name == "CE": - raise RuntimeError("Legacy module") + raise InvalidCommand("Legacy module") return "OK" app._api._queued_at = mock.AsyncMock( @@ -346,7 +347,7 @@ def _at_command_mock(cmd, *args): if not api_mode: raise asyncio.TimeoutError if cmd == "CE" and legacy_module: - raise RuntimeError + raise InvalidCommand ai_tries -= 1 if cmd == "AI" else 0 return { diff --git a/zigpy_xbee/api.py b/zigpy_xbee/api.py index 32c5c81..1076e9c 100644 --- a/zigpy_xbee/api.py +++ b/zigpy_xbee/api.py @@ -2,7 +2,6 @@ import asyncio import binascii -import enum import functools import logging from typing import Any, Dict, Optional @@ -13,6 +12,13 @@ import zigpy_xbee from zigpy_xbee.config import CONF_DEVICE_BAUDRATE, CONF_DEVICE_PATH, SCHEMA_DEVICE +from zigpy_xbee.exceptions import ( + ATCommandError, + ATCommandException, + InvalidCommand, + InvalidParameter, + TransmissionFailure, +) from . import types as xbee_t, uart @@ -261,14 +267,12 @@ } -class ATCommandResult(enum.IntEnum): - """AT Command Result.""" - - OK = 0 - ERROR = 1 - INVALID_COMMAND = 2 - INVALID_PARAMETER = 3 - TX_FAILURE = 4 +AT_COMMAND_RESULT = { + 1: ATCommandError, + 2: InvalidCommand, + 3: InvalidParameter, + 4: TransmissionFailure, +} class XBee: @@ -444,13 +448,13 @@ def frame_received(self, data): def _handle_at_response(self, frame_id, cmd, status, value): """Local AT command response.""" (fut,) = self._awaiting.pop(frame_id) - try: - status = ATCommandResult(status) - except ValueError: - status = ATCommandResult.ERROR if status: - fut.set_exception(RuntimeError(f"AT Command response: {status.name}")) + try: + exception = AT_COMMAND_RESULT[status] + except KeyError: + exception = ATCommandException + fut.set_exception(exception(f"AT Command response: {status}")) return response_type = AT_COMMANDS[cmd.decode("ascii")] diff --git a/zigpy_xbee/exceptions.py b/zigpy_xbee/exceptions.py new file mode 100644 index 0000000..32c38d9 --- /dev/null +++ b/zigpy_xbee/exceptions.py @@ -0,0 +1,21 @@ +"""Additional exceptions for XBee.""" + + +class ATCommandException(Exception): + """Base exception class for AT Command exceptions.""" + + +class ATCommandError(ATCommandException): + """Exception for AT Command Status 1 (ERROR).""" + + +class InvalidCommand(ATCommandException): + """Exception for AT Command Status 2 (Invalid command).""" + + +class InvalidParameter(ATCommandException): + """Exception for AT Command Status 3 (Invalid parameter).""" + + +class TransmissionFailure(ATCommandException): + """Exception for Remote AT Command Status 4 (Transmission failure).""" diff --git a/zigpy_xbee/zigbee/application.py b/zigpy_xbee/zigbee/application.py index 5733fc9..18969ac 100644 --- a/zigpy_xbee/zigbee/application.py +++ b/zigpy_xbee/zigbee/application.py @@ -23,6 +23,7 @@ import zigpy_xbee import zigpy_xbee.api from zigpy_xbee.config import CONF_DEVICE, CONFIG_SCHEMA, SCHEMA_DEVICE +from zigpy_xbee.exceptions import InvalidCommand from zigpy_xbee.types import EUI64, UNKNOWN_IEEE, UNKNOWN_NWK, TXOptions, TXStatus # how long coordinator would hold message for an end device in 10ms units @@ -131,7 +132,7 @@ async def load_network_info(self, *, load_devices=False): node_info.logical_type = zdo_t.LogicalType.Coordinator else: node_info.logical_type = zdo_t.LogicalType.EndDevice - except RuntimeError: + except InvalidCommand: LOGGER.warning("CE command failed, assuming node is coordinator") node_info.logical_type = zdo_t.LogicalType.Coordinator @@ -171,7 +172,7 @@ async def write_network_info(self, *, network_info, node_info): try: await self._api._queued_at("CE", 1) - except RuntimeError: + except InvalidCommand: pass await self._api._at_command("WR")