diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index db5ad7e..9b50f92 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,12 @@ repos: args: [--markdown-linebreak-ext=md] - id: check-yaml - id: check-added-large-files -- repo: https://github.com/psf/black - rev: 23.7.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.8 hooks: - - id: black \ No newline at end of file + # Run the linter. + #- id: ruff + # args: [ --fix ] + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/JciHitachi/api.py b/JciHitachi/api.py index 347d39f..5eab3b1 100644 --- a/JciHitachi/api.py +++ b/JciHitachi/api.py @@ -347,10 +347,13 @@ def task_id(self) -> int: Returns ------- int - Serial number counted from 0. + Serial number counted from 0, with maximum 999. """ self._task_id += 1 + if self._task_id >= 1000: + self._task_id = 1 + return self._task_id def _sync_peripherals_availablity(self) -> None: @@ -914,7 +917,6 @@ def __init__( self._things: dict[str, AWSThing] = {} self._aws_tokens: Optional[aws_connection.AWSTokens] = None self._aws_identity: Optional[aws_connection.AWSIdentity] = None - self._host_identity_id: Optional[str] = None self._task_id: int = 0 @property @@ -980,24 +982,6 @@ def login(self) -> None: self._aws_tokens = conn.aws_tokens conn_status, self._aws_identity = conn.get_data() - conn = aws_connection.ListSubUser( - self._aws_tokens, print_response=self.print_response - ) - conn_status, conn_json = conn.get_data() - - if conn_status == "OK": - for user in conn_json["results"]["FamilyMemberList"]: - if user["isHost"]: - self._host_identity_id = user["userId"] - break - assert ( - self._host_identity_id is not None - ), "Host is not found in the user list" - else: - raise RuntimeError( - f"An error occurred when listing account users: {conn_status}" - ) - conn = aws_connection.GetAllDevice( self._aws_tokens, print_response=self.print_response ) @@ -1028,13 +1012,13 @@ def get_credential_callable(): self._mqtt = aws_connection.JciHitachiAWSMqttConnection( get_credential_callable, print_response=self.print_response ) - self._mqtt.configure() + self._mqtt.configure(self._aws_identity.identity_id) if not self._mqtt.connect( - self._host_identity_id, self._shadow_names, thing_names + self._aws_identity.host_identity_id, self._shadow_names, thing_names ): raise RuntimeError( - f"An error occurred when connecting to MQTT endpoint." + "An error occurred when connecting to MQTT endpoint." ) # status @@ -1175,7 +1159,7 @@ def refresh_status( if refresh_support_code: self._mqtt.publish( - self._host_identity_id, + self._aws_identity.host_identity_id, thing.thing_name, "support", self._mqtt_timeout, @@ -1184,7 +1168,10 @@ def refresh_status( self._mqtt.publish_shadow(thing.thing_name, "get", shadow_name="info") self._mqtt.publish( - self._host_identity_id, thing.thing_name, "status", self._mqtt_timeout + self._aws_identity.host_identity_id, + thing.thing_name, + "status", + self._mqtt_timeout, ) # execute @@ -1324,9 +1311,7 @@ def set_status( "enableQAMode": "qa", } - if ( - False - ): # status_name in shadow_publish_mapping: # TODO: replace False cond after shadow function is completed. + if False: # status_name in shadow_publish_mapping: # TODO: replace False cond after shadow function is completed. shadow_publish_schema = {} if ( shadow_publish_mapping[status_name] == "filter" @@ -1356,22 +1341,14 @@ def set_status( return False self._mqtt.publish( - self._host_identity_id, + self._aws_identity.host_identity_id, thing.thing_name, "control", self._mqtt_timeout, { - "Condition": { - "ThingName": thing.thing_name, - "Index": 0, - "Geofencing": { - "Arrive": None, - "Leave": None, - }, - }, status_name: status_value, "TaskID": self.task_id, - "Timestamp": time.time(), + "Timestamp": int(time.time()), }, ) diff --git a/JciHitachi/aws_connection.py b/JciHitachi/aws_connection.py index 4ef3615..6a16f6c 100644 --- a/JciHitachi/aws_connection.py +++ b/JciHitachi/aws_connection.py @@ -5,10 +5,9 @@ import logging import threading import time -import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field -from random import random +from random import random, choices from typing import Callable, Optional, Union import awscrt @@ -43,6 +42,7 @@ class AWSTokens: @dataclass class AWSIdentity: identity_id: str + host_identity_id: str user_name: str user_attributes: dict @@ -204,7 +204,7 @@ def login(self, use_refresh_token: bool = False) -> tuple(str, AWSTokens): """ # https://docs.aws.amazon.com/cognito-user-identity-pools/latest/APIReference/API_InitiateAuth.html - if use_refresh_token and self._aws_tokens != None: + if use_refresh_token and self._aws_tokens is not None: login_json_data = { "AuthFlow": "REFRESH_TOKEN_AUTH", "AuthParameters": { @@ -312,6 +312,7 @@ def get_data(self): } aws_identity = AWSIdentity( identity_id=user_attributes["custom:cognito_identity_id"], + host_identity_id=user_attributes["custom:host_identity_id"], user_name=response["Username"], user_attributes=user_attributes, ) @@ -653,7 +654,7 @@ def _on_message(self, topic, payload, dup, qos, retain, **kwargs): return def _on_connection_interrupted(self, connection, error, **kwargs): - _LOGGER.error("MQTT connection was interrupted with exception {error}.") + _LOGGER.error(f"MQTT connection was interrupted with exception {error}") self._mqtt_events.mqtt_error = error.__class__.__name__ self._mqtt_events.mqtt_error_event.set() @@ -681,11 +682,11 @@ def on_resubscribe_complete(resubscribe_future): _LOGGER.info("Resubscribed successfully.") return - async def _wrap_async(self, identifier: str, fn: Callable, timeout: float) -> str: + async def _wrap_async(self, identifier: str, fn: Callable) -> str: await asyncio.sleep( random() / 2 ) # randomly wait 0~0.5 seconds to prevent messages flooding to the broker. - await asyncio.wait_for(to_thread(fn), timeout) + await to_thread(fn) return identifier def disconnect(self) -> None: @@ -694,7 +695,7 @@ def disconnect(self) -> None: if self._mqttc is not None: self._mqttc.disconnect() - def configure(self) -> None: + def configure(self, identity_id) -> None: """Configure MQTT.""" cred_provider = awscrt.auth.AwsCredentialsProvider.new_delegate( @@ -708,7 +709,7 @@ def configure(self) -> None: cred_provider, client_bootstrap=client_bootstrap, endpoint=AWS_MQTT_ENDPOINT, - client_id=str(uuid.uuid4()), + client_id=f"{identity_id}_{''.join(choices('abcdef0123456789', k=16))}", # {identityid}_{64bit_hex} on_connection_interrupted=self._on_connection_interrupted, on_connection_resumed=self._on_connection_resumed, ) @@ -750,7 +751,7 @@ def connect( try: subscribe_future, _ = self._mqttc.subscribe( - f"{host_identity_id}/#", QOS, callback=self._on_publish + f"{host_identity_id}/+/+/response", QOS, callback=self._on_publish ) subscribe_future.result() @@ -861,11 +862,11 @@ def fn(): publish_future, _ = self._mqttc.publish( support_topic, json.dumps(default_payload), QOS ) - publish_future.result() - self._mqtt_events.device_support_event[thing_name].wait() + publish_future.result(timeout) + self._mqtt_events.device_support_event[thing_name].wait(timeout) self._execution_pools.support_execution_pool.append( - self._wrap_async(thing_name, fn, timeout) + self._wrap_async(thing_name, fn) ) elif publish_type == "status": status_topic = f"{host_identity_id}/{thing_name}/status/request" @@ -878,11 +879,11 @@ def fn(): publish_future, _ = self._mqttc.publish( status_topic, json.dumps(default_payload), QOS ) - publish_future.result() - self._mqtt_events.device_status_event[thing_name].wait() + publish_future.result(timeout) + self._mqtt_events.device_status_event[thing_name].wait(timeout) self._execution_pools.status_execution_pool.append( - self._wrap_async(thing_name, fn, timeout) + self._wrap_async(thing_name, fn) ) elif publish_type == "control": control_topic = f"{host_identity_id}/{thing_name}/control/request" @@ -895,11 +896,11 @@ def fn(): publish_future, _ = self._mqttc.publish( control_topic, json.dumps(payload), QOS ) - publish_future.result() - self._mqtt_events.device_control_event[thing_name].wait() + publish_future.result(timeout) + self._mqtt_events.device_control_event[thing_name].wait(timeout) self._execution_pools.control_execution_pool.append( - self._wrap_async(thing_name, fn, timeout) + self._wrap_async(thing_name, fn) ) else: @@ -995,11 +996,11 @@ def fn(): ), qos=QOS, ) - publish_future.result() - self._mqtt_events.device_shadow_event[thing_name].wait() + publish_future.result(timeout) + self._mqtt_events.device_shadow_event[thing_name].wait(timeout) self._execution_pools.shadow_execution_pool.append( - self._wrap_async(thing_name, fn, timeout) + self._wrap_async(thing_name, fn) ) def execute( diff --git a/JciHitachi/connection.py b/JciHitachi/connection.py index 34d4d46..5f290ed 100644 --- a/JciHitachi/connection.py +++ b/JciHitachi/connection.py @@ -1,6 +1,5 @@ import json import os -import ssl import httpx diff --git a/JciHitachi/model.py b/JciHitachi/model.py index 6c5fc26..4dfcd2e 100644 --- a/JciHitachi/model.py +++ b/JciHitachi/model.py @@ -2196,13 +2196,6 @@ class JciHitachiAWSStatusSupport: Status retrieved from `JciHitachiAWSMqttConnection` _on_publish() callback. """ - extended_mapping = { - "FirmwareId": None, - "Model": "model", - "Brand": "brand", - "FindMe": None, - } - device_type_mapping = JciHitachiAWSStatus.device_type_mapping def __init__(self, raw_status: dict) -> None: @@ -2217,6 +2210,10 @@ def __repr__(self) -> str: def _preprocess(self, status): status = status.copy() + + if status.get("Error", 0) != 0: + return status + # device type status["DeviceType"] = self.device_type_mapping[status["DeviceType"]] diff --git a/JciHitachi/utility.py b/JciHitachi/utility.py index 8adb034..21b8ee9 100644 --- a/JciHitachi/utility.py +++ b/JciHitachi/utility.py @@ -111,13 +111,9 @@ def extract_bytes(v, start, end): # pragma: no cover Extracted value. """ - assert ( - start > end and end >= 0 - ), "Starting byte must be greater than ending byte, \ + assert start > end and end >= 0, "Starting byte must be greater than ending byte, \ and ending byte must be greater than zero : \ - {}, {}".format( - start, end - ) + {}, {}".format(start, end) return cast_bytes(v >> end * 8, start - end) diff --git a/tests/test_api.py b/tests/test_api.py index f5f4d41..1edc78f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -141,9 +141,11 @@ def test_login(self, fixture_aws_mock_api): api._aws_tokens = None mock_get_data_get_user.return_value = ( "OK", - AWSIdentity("id", "username", {"attr", "attr"}), + AWSIdentity("id", "host_id", "username", {"attr", "attr"}), ) mock_login_get_user.return_value = ("OK", aws_tokens) + + # mock_get_data_list_subuser currently unused. mock_get_data_list_subuser.return_value = ( "OK", { @@ -160,7 +162,6 @@ def test_login(self, fixture_aws_mock_api): api.refresh_status = MagicMock() assert api._aws_tokens is None assert api._aws_identity is None - assert api._host_identity_id is None assert len(api._things) == 3 assert api.device_names is None @@ -168,7 +169,8 @@ def test_login(self, fixture_aws_mock_api): api.login() assert api._aws_tokens == aws_tokens assert api._aws_identity is not None - assert api._host_identity_id == "uid2" + assert api._aws_identity.identity_id == "id" + assert api._aws_identity.host_identity_id == "host_id" assert len(api._things) == 2 assert len(api.device_names) == 2 @@ -186,7 +188,7 @@ def test_change_password(self, fixture_aws_mock_api): mock_get_data_1.return_value = ("Not OK", "") with pytest.raises( RuntimeError, - match=f"An error occurred when changing AWS Cognito password: Not OK", + match="An error occurred when changing AWS Cognito password: Not OK", ): api.change_password("new_password") @@ -194,7 +196,7 @@ def test_change_password(self, fixture_aws_mock_api): mock_get_data_2.return_value = ("Not OK", "") with pytest.raises( RuntimeError, - match=f"An error occurred when changing Hitachi password: Not OK", + match="An error occurred when changing Hitachi password: Not OK", ): api.change_password("new_password") @@ -264,6 +266,8 @@ def test_get_status(self, fixture_aws_mock_api): def test_refresh_status(self, fixture_aws_mock_api): api = fixture_aws_mock_api + api._aws_identity = AWSIdentity("id", "host_id", "username", {"attr", "attr"}) + thing_name = api.things[MOCK_DEVICE_AC].thing_name with patch.object(api, "_mqtt") as mock_mqtt: mock_mqtt.publish.return_value = None @@ -421,7 +425,7 @@ def test_refresh_monthly_data(self, fixture_aws_mock_api): }, ) - assert api.things[MOCK_DEVICE_AC].monthly_data == None + assert api.things[MOCK_DEVICE_AC].monthly_data is None api.refresh_monthly_data(2, MOCK_DEVICE_AC) assert api.things[MOCK_DEVICE_AC].monthly_data == [ {"Timestamp": current_time}, @@ -432,7 +436,7 @@ def test_refresh_monthly_data(self, fixture_aws_mock_api): with pytest.raises( RuntimeError, - match=f"An error occurred when getting monthly data: Not OK", + match="An error occurred when getting monthly data: Not OK", ): api.refresh_monthly_data(2, MOCK_DEVICE_AC) diff --git a/tests/test_aws_connection.py b/tests/test_aws_connection.py index 62a1cdc..1b7b103 100644 --- a/tests/test_aws_connection.py +++ b/tests/test_aws_connection.py @@ -23,7 +23,6 @@ GetHistoryEventByUser, GetUser, JciHitachiAWSCognitoConnection, - JciHitachiAWSIoTConnection, JciHitachiAWSMqttConnection, ListSubUser, ) @@ -139,7 +138,7 @@ def test_configure(self, fixture_aws_mock_mqtt_connection): assert mqtt._mqttc is None assert mqtt._shadow_mqttc is None - mqtt.configure() + mqtt.configure(identity_id="identity_id") assert isinstance(mqtt._mqttc, awscrt.mqtt.Connection) assert isinstance(mqtt._shadow_mqttc, awsiot.iotshadow.IotShadowClient) @@ -191,7 +190,7 @@ def test_publish(self, fixture_aws_mock_mqtt_connection): assert thing_name in mqtt._mqtt_events.device_support_event assert len(mqtt._execution_pools.support_execution_pool) == 1 - # test clearning event + # test clearing event mqtt._mqtt_events.device_support_event[thing_name] = threading.Event() mqtt._mqtt_events.device_support_event[thing_name].set() with patch.object(mqtt, "_mqttc") as mock_mqttc: @@ -207,9 +206,11 @@ def test_publish(self, fixture_aws_mock_mqtt_connection): publish_future = concurrent.futures.Future() publish_future.set_exception(ValueError()) mock_mqttc.publish.return_value = (publish_future, None) - with pytest.raises(ValueError, match=f"Invalid publish_type: others"): + with pytest.raises(ValueError, match="Invalid publish_type: others"): mqtt.publish("", thing_name, "others") + # TODO: test timeout + @pytest.mark.parametrize("raise_exception", [False, True]) def test_publish_shadow(self, fixture_aws_mock_mqtt_connection, raise_exception): mqtt = fixture_aws_mock_mqtt_connection @@ -254,7 +255,7 @@ def publish_shadow_func(request, qos): # Test invalid command name. with pytest.raises( - ValueError, match=f"command_name must be one of `get` or `update`." + ValueError, match="command_name must be one of `get` or `update`." ): mqtt.publish_shadow(thing_name, "delete") @@ -298,14 +299,18 @@ class TestJciHitachiAWSCognitoConnection: { "Name": "custom:cognito_identity_id", "Value": "identity_id", - } + }, + { + "Name": "custom:host_identity_id", + "Value": "host_identity_id", + }, ], }, AWSIdentity, ), ( GetCredentials, - {"aws_identity": AWSIdentity("", "", {})}, + {"aws_identity": AWSIdentity("", "", "", {})}, "AWSCognitoIdentityService.GetCredentialsForIdentity", { "Credentials": {