diff --git a/miio/device.py b/miio/device.py index 14912dc2c..682a127a8 100644 --- a/miio/device.py +++ b/miio/device.py @@ -5,6 +5,7 @@ import click from .click_common import DeviceGroupMeta, LiteralParamType, command, format_output +from .exceptions import DeviceInfoUnavailableException, PayloadDecodeException from .miioprotocol import MiIOProtocol _LOGGER = logging.getLogger(__name__) @@ -161,7 +162,7 @@ def raw_command(self, command, parameters): :param str command: Command to send :param dict parameters: Parameters to send""" - return self._protocol.send(command, parameters) + return self.send(command, parameters) @command( default_output=format_output( @@ -177,7 +178,12 @@ def info(self) -> DeviceInfo: """Get miIO protocol information from the device. This includes information about connected wlan network, and hardware and software versions.""" - return DeviceInfo(self._protocol.send("miIO.info")) + try: + return DeviceInfo(self.send("miIO.info")) + except PayloadDecodeException as ex: + raise DeviceInfoUnavailableException( + "Unable to request miIO.info from the device" + ) from ex def update(self, url: str, md5: str): """Start an OTA update.""" @@ -188,15 +194,15 @@ def update(self, url: str, md5: str): "file_md5": md5, "proc": "dnld install", } - return self._protocol.send("miIO.ota", payload)[0] == "ok" + return self.send("miIO.ota", payload)[0] == "ok" def update_progress(self) -> int: """Return current update progress [0-100].""" - return self._protocol.send("miIO.get_ota_progress")[0] + return self.send("miIO.get_ota_progress")[0] def update_state(self): """Return current update state.""" - return UpdateState(self._protocol.send("miIO.get_ota_state")[0]) + return UpdateState(self.send("miIO.get_ota_state")[0]) def configure_wifi(self, ssid, password, uid=0, extra_params=None): """Configure the wifi settings.""" @@ -204,7 +210,7 @@ def configure_wifi(self, ssid, password, uid=0, extra_params=None): extra_params = {} params = {"ssid": ssid, "passwd": password, "uid": uid, **extra_params} - return self._protocol.send("miIO.config_router", params)[0] + return self.send("miIO.config_router", params)[0] def get_properties(self, properties, *, max_properties=None): """Request properties in slices based on given max_properties. diff --git a/miio/exceptions.py b/miio/exceptions.py index 7305b0ba7..90c7ee04f 100644 --- a/miio/exceptions.py +++ b/miio/exceptions.py @@ -2,8 +2,28 @@ class DeviceException(Exception): """Exception wrapping any communication errors with the device.""" +class PayloadDecodeException(DeviceException): + """Exception for failures in payload decoding. + + This is raised when the json payload cannot be decoded, + indicating invalid response from a device. + """ + + +class DeviceInfoUnavailableException(DeviceException): + """Exception raised when requesting miio.info fails. + + This allows users to gracefully handle cases where the information unavailable. + This can happen, for instance, when the device has no cloud access. + """ + + class DeviceError(DeviceException): - """Exception communicating an error delivered by the target device.""" + """Exception communicating an error delivered by the target device. + + The device given error code and message can be accessed with + `code` and `message` variables. + """ def __init__(self, error): self.code = error.get("code") diff --git a/miio/protocol.py b/miio/protocol.py index 28e115fb4..2c4b3f99c 100644 --- a/miio/protocol.py +++ b/miio/protocol.py @@ -38,6 +38,8 @@ from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from miio.exceptions import PayloadDecodeException + _LOGGER = logging.getLogger(__name__) @@ -193,7 +195,10 @@ def _decode(self, obj, context, path): # log the error when decrypted bytes couldn't be loaded # after trying all quirk adaptions if i == len(decrypted_quirks) - 1: - _LOGGER.error("unable to parse json '%s': %s", decoded, ex) + _LOGGER.debug("Unable to parse json '%s': %s", decoded, ex) + raise PayloadDecodeException( + "Unable to parse message payload" + ) from ex return None diff --git a/miio/tests/test_device.py b/miio/tests/test_device.py index 5e5f6ef6a..1aeede968 100644 --- a/miio/tests/test_device.py +++ b/miio/tests/test_device.py @@ -3,6 +3,7 @@ import pytest from miio import Device +from miio.exceptions import DeviceInfoUnavailableException, PayloadDecodeException @pytest.mark.parametrize("max_properties", [None, 1, 15]) @@ -16,3 +17,13 @@ def test_get_properties_splitting(mocker, max_properties): if max_properties is None: max_properties = len(properties) assert send.call_count == math.ceil(len(properties) / max_properties) + + +def test_unavailable_device_info_raises(mocker): + send = mocker.patch("miio.Device.send", side_effect=PayloadDecodeException) + d = Device("127.0.0.1", "68ffffffffffffffffffffffffffffff") + + with pytest.raises(DeviceInfoUnavailableException): + d.info() + + assert send.call_count == 1 diff --git a/miio/tests/test_protocol.py b/miio/tests/test_protocol.py index ddae2db01..095c2e9ab 100644 --- a/miio/tests/test_protocol.py +++ b/miio/tests/test_protocol.py @@ -2,7 +2,7 @@ import pytest -from miio.exceptions import DeviceError, RecoverableError +from miio.exceptions import DeviceError, PayloadDecodeException, RecoverableError from .. import Utils from ..miioprotocol import MiIOProtocol @@ -17,6 +17,28 @@ def proto() -> MiIOProtocol: return MiIOProtocol() +@pytest.fixture +def token() -> bytes: + return bytes.fromhex(32 * "0") + + +def build_msg(data, token): + encrypted_data = Utils.encrypt(data, token) + + # header + magic = binascii.unhexlify(b"2131") + length = (32 + len(encrypted_data)).to_bytes(2, byteorder="big") + unknown = binascii.unhexlify(b"00000000") + did = binascii.unhexlify(b"01234567") + epoch = binascii.unhexlify(b"00000000") + + checksum = Utils.md5( + magic + length + unknown + did + epoch + token + encrypted_data + ) + + return magic + length + unknown + did + epoch + checksum + encrypted_data + + def test_incrementing_id(proto): old_id = proto.raw_id proto._create_request("dummycmd", "dummy") @@ -62,18 +84,16 @@ def test_device_error_handling(proto: MiIOProtocol): proto._handle_error({"code": 1234}) -def test_non_bytes_payload(): +def test_non_bytes_payload(token): payload = "hello world" - valid_token = 32 * b"0" with pytest.raises(TypeError): - Utils.encrypt(payload, valid_token) + Utils.encrypt(payload, token) with pytest.raises(TypeError): - Utils.decrypt(payload, valid_token) + Utils.decrypt(payload, token) -def test_encrypt(): +def test_encrypt(token): payload = b"hello world" - token = bytes.fromhex(32 * "0") encrypted = Utils.encrypt(payload, token) decrypted = Utils.decrypt(encrypted, token) @@ -95,46 +115,46 @@ def test_invalid_token(): Utils.decrypt(payload, wrong_length) -def test_decode_json_payload(): - token = bytes.fromhex(32 * "0") +def test_decode_json_payload(token): ctx = {"token": token} - def build_msg(data): - encrypted_data = Utils.encrypt(data, token) - - # header - magic = binascii.unhexlify(b"2131") - length = (32 + len(encrypted_data)).to_bytes(2, byteorder="big") - unknown = binascii.unhexlify(b"00000000") - did = binascii.unhexlify(b"01234567") - epoch = binascii.unhexlify(b"00000000") - - checksum = Utils.md5( - magic + length + unknown + did + epoch + token + encrypted_data - ) - - return magic + length + unknown + did + epoch + checksum + encrypted_data - # can parse message with valid json - serialized_msg = build_msg(b'{"id": 123456}') + serialized_msg = build_msg(b'{"id": 123456}', token) parsed_msg = Message.parse(serialized_msg, **ctx) assert parsed_msg.data.value assert isinstance(parsed_msg.data.value, dict) assert parsed_msg.data.value["id"] == 123456 + +def test_decode_json_quirk_powerstrip(token): + ctx = {"token": token} + # can parse message with invalid json for edge case powerstrip # when not connected to cloud - serialized_msg = build_msg(b'{"id": 123456,,"otu_stat":0}') + serialized_msg = build_msg(b'{"id": 123456,,"otu_stat":0}', token) parsed_msg = Message.parse(serialized_msg, **ctx) assert parsed_msg.data.value assert isinstance(parsed_msg.data.value, dict) assert parsed_msg.data.value["id"] == 123456 assert parsed_msg.data.value["otu_stat"] == 0 + +def test_decode_json_quirk_cloud(token): + ctx = {"token": token} + # can parse message with invalid json for edge case xiaomi cloud # reply to _sync.batch_gen_room_up_url - serialized_msg = build_msg(b'{"id": 123456}\x00k') + serialized_msg = build_msg(b'{"id": 123456}\x00k', token) parsed_msg = Message.parse(serialized_msg, **ctx) assert parsed_msg.data.value assert isinstance(parsed_msg.data.value, dict) assert parsed_msg.data.value["id"] == 123456 + + +def test_decode_json_raises_for_invalid_json(token): + ctx = {"token": token} + + # make sure PayloadDecodeDexception is raised for invalid json + serialized_msg = build_msg(b'{"id": 123456,,"otu_stat":0', token) + with pytest.raises(PayloadDecodeException): + Message.parse(serialized_msg, **ctx)