Skip to content

Commit

Permalink
Fix 14
Browse files Browse the repository at this point in the history
  • Loading branch information
qqaatw committed Dec 29, 2023
1 parent ad6083b commit ca03a11
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 90 deletions.
11 changes: 8 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@ repos:
args: [--markdown-linebreak-ext=md]
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/psf/black
rev: 23.7.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.8
hooks:
- id: black
# Run the linter.
#- id: ruff
# args: [ --fix ]
# Run the formatter.
- id: ruff-format
53 changes: 15 additions & 38 deletions JciHitachi/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,13 @@ def task_id(self) -> int:
Returns
-------
int
Serial number counted from 0.
Serial number counted from 0, with maximum 999.
"""

self._task_id += 1
if self._task_id >= 1000:
self._task_id = 1

return self._task_id

def _sync_peripherals_availablity(self) -> None:
Expand Down Expand Up @@ -914,7 +917,6 @@ def __init__(
self._things: dict[str, AWSThing] = {}
self._aws_tokens: Optional[aws_connection.AWSTokens] = None
self._aws_identity: Optional[aws_connection.AWSIdentity] = None
self._host_identity_id: Optional[str] = None
self._task_id: int = 0

@property
Expand Down Expand Up @@ -980,24 +982,6 @@ def login(self) -> None:
self._aws_tokens = conn.aws_tokens
conn_status, self._aws_identity = conn.get_data()

conn = aws_connection.ListSubUser(
self._aws_tokens, print_response=self.print_response
)
conn_status, conn_json = conn.get_data()

if conn_status == "OK":
for user in conn_json["results"]["FamilyMemberList"]:
if user["isHost"]:
self._host_identity_id = user["userId"]
break
assert (
self._host_identity_id is not None
), "Host is not found in the user list"
else:
raise RuntimeError(
f"An error occurred when listing account users: {conn_status}"
)

conn = aws_connection.GetAllDevice(
self._aws_tokens, print_response=self.print_response
)
Expand Down Expand Up @@ -1028,13 +1012,13 @@ def get_credential_callable():
self._mqtt = aws_connection.JciHitachiAWSMqttConnection(
get_credential_callable, print_response=self.print_response
)
self._mqtt.configure()
self._mqtt.configure(self._aws_identity.identity_id)

if not self._mqtt.connect(
self._host_identity_id, self._shadow_names, thing_names
self._aws_identity.host_identity_id, self._shadow_names, thing_names
):
raise RuntimeError(
f"An error occurred when connecting to MQTT endpoint."
"An error occurred when connecting to MQTT endpoint."
)

# status
Expand Down Expand Up @@ -1175,7 +1159,7 @@ def refresh_status(

if refresh_support_code:
self._mqtt.publish(
self._host_identity_id,
self._aws_identity.host_identity_id,
thing.thing_name,
"support",
self._mqtt_timeout,
Expand All @@ -1184,7 +1168,10 @@ def refresh_status(
self._mqtt.publish_shadow(thing.thing_name, "get", shadow_name="info")

self._mqtt.publish(
self._host_identity_id, thing.thing_name, "status", self._mqtt_timeout
self._aws_identity.host_identity_id,
thing.thing_name,
"status",
self._mqtt_timeout,
)

# execute
Expand Down Expand Up @@ -1324,9 +1311,7 @@ def set_status(
"enableQAMode": "qa",
}

if (
False
): # status_name in shadow_publish_mapping: # TODO: replace False cond after shadow function is completed.
if False: # status_name in shadow_publish_mapping: # TODO: replace False cond after shadow function is completed.
shadow_publish_schema = {}
if (
shadow_publish_mapping[status_name] == "filter"
Expand Down Expand Up @@ -1356,22 +1341,14 @@ def set_status(
return False

self._mqtt.publish(
self._host_identity_id,
self._aws_identity.host_identity_id,
thing.thing_name,
"control",
self._mqtt_timeout,
{
"Condition": {
"ThingName": thing.thing_name,
"Index": 0,
"Geofencing": {
"Arrive": None,
"Leave": None,
},
},
status_name: status_value,
"TaskID": self.task_id,
"Timestamp": time.time(),
"Timestamp": int(time.time()),
},
)

Expand Down
43 changes: 22 additions & 21 deletions JciHitachi/aws_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import logging
import threading
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from random import random
from random import random, choices
from typing import Callable, Optional, Union

import awscrt
Expand Down Expand Up @@ -43,6 +42,7 @@ class AWSTokens:
@dataclass
class AWSIdentity:
identity_id: str
host_identity_id: str
user_name: str
user_attributes: dict

Expand Down Expand Up @@ -204,7 +204,7 @@ def login(self, use_refresh_token: bool = False) -> tuple(str, AWSTokens):
"""

# https://docs.aws.amazon.com/cognito-user-identity-pools/latest/APIReference/API_InitiateAuth.html
if use_refresh_token and self._aws_tokens != None:
if use_refresh_token and self._aws_tokens is not None:
login_json_data = {
"AuthFlow": "REFRESH_TOKEN_AUTH",
"AuthParameters": {
Expand Down Expand Up @@ -312,6 +312,7 @@ def get_data(self):
}
aws_identity = AWSIdentity(
identity_id=user_attributes["custom:cognito_identity_id"],
host_identity_id=user_attributes["custom:host_identity_id"],
user_name=response["Username"],
user_attributes=user_attributes,
)
Expand Down Expand Up @@ -653,7 +654,7 @@ def _on_message(self, topic, payload, dup, qos, retain, **kwargs):
return

def _on_connection_interrupted(self, connection, error, **kwargs):
_LOGGER.error("MQTT connection was interrupted with exception {error}.")
_LOGGER.error(f"MQTT connection was interrupted with exception {error}")
self._mqtt_events.mqtt_error = error.__class__.__name__
self._mqtt_events.mqtt_error_event.set()

Expand Down Expand Up @@ -681,11 +682,11 @@ def on_resubscribe_complete(resubscribe_future):
_LOGGER.info("Resubscribed successfully.")
return

async def _wrap_async(self, identifier: str, fn: Callable, timeout: float) -> str:
async def _wrap_async(self, identifier: str, fn: Callable) -> str:
await asyncio.sleep(
random() / 2
) # randomly wait 0~0.5 seconds to prevent messages flooding to the broker.
await asyncio.wait_for(to_thread(fn), timeout)
await to_thread(fn)
return identifier

def disconnect(self) -> None:
Expand All @@ -694,7 +695,7 @@ def disconnect(self) -> None:
if self._mqttc is not None:
self._mqttc.disconnect()

def configure(self) -> None:
def configure(self, identity_id) -> None:
"""Configure MQTT."""

cred_provider = awscrt.auth.AwsCredentialsProvider.new_delegate(
Expand All @@ -708,7 +709,7 @@ def configure(self) -> None:
cred_provider,
client_bootstrap=client_bootstrap,
endpoint=AWS_MQTT_ENDPOINT,
client_id=str(uuid.uuid4()),
client_id=f"{identity_id}_{''.join(choices('abcdef0123456789', k=16))}", # {identityid}_{64bit_hex}
on_connection_interrupted=self._on_connection_interrupted,
on_connection_resumed=self._on_connection_resumed,
)
Expand Down Expand Up @@ -750,7 +751,7 @@ def connect(

try:
subscribe_future, _ = self._mqttc.subscribe(
f"{host_identity_id}/#", QOS, callback=self._on_publish
f"{host_identity_id}/+/+/response", QOS, callback=self._on_publish
)
subscribe_future.result()

Expand Down Expand Up @@ -861,11 +862,11 @@ def fn():
publish_future, _ = self._mqttc.publish(
support_topic, json.dumps(default_payload), QOS
)
publish_future.result()
self._mqtt_events.device_support_event[thing_name].wait()
publish_future.result(timeout)
self._mqtt_events.device_support_event[thing_name].wait(timeout)

self._execution_pools.support_execution_pool.append(
self._wrap_async(thing_name, fn, timeout)
self._wrap_async(thing_name, fn)
)
elif publish_type == "status":
status_topic = f"{host_identity_id}/{thing_name}/status/request"
Expand All @@ -878,11 +879,11 @@ def fn():
publish_future, _ = self._mqttc.publish(
status_topic, json.dumps(default_payload), QOS
)
publish_future.result()
self._mqtt_events.device_status_event[thing_name].wait()
publish_future.result(timeout)
self._mqtt_events.device_status_event[thing_name].wait(timeout)

self._execution_pools.status_execution_pool.append(
self._wrap_async(thing_name, fn, timeout)
self._wrap_async(thing_name, fn)
)
elif publish_type == "control":
control_topic = f"{host_identity_id}/{thing_name}/control/request"
Expand All @@ -895,11 +896,11 @@ def fn():
publish_future, _ = self._mqttc.publish(
control_topic, json.dumps(payload), QOS
)
publish_future.result()
self._mqtt_events.device_control_event[thing_name].wait()
publish_future.result(timeout)
self._mqtt_events.device_control_event[thing_name].wait(timeout)

self._execution_pools.control_execution_pool.append(
self._wrap_async(thing_name, fn, timeout)
self._wrap_async(thing_name, fn)
)

else:
Expand Down Expand Up @@ -995,11 +996,11 @@ def fn():
),
qos=QOS,
)
publish_future.result()
self._mqtt_events.device_shadow_event[thing_name].wait()
publish_future.result(timeout)
self._mqtt_events.device_shadow_event[thing_name].wait(timeout)

self._execution_pools.shadow_execution_pool.append(
self._wrap_async(thing_name, fn, timeout)
self._wrap_async(thing_name, fn)
)

def execute(
Expand Down
1 change: 0 additions & 1 deletion JciHitachi/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import os
import ssl

import httpx

Expand Down
11 changes: 4 additions & 7 deletions JciHitachi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2196,13 +2196,6 @@ class JciHitachiAWSStatusSupport:
Status retrieved from `JciHitachiAWSMqttConnection` _on_publish() callback.
"""

extended_mapping = {
"FirmwareId": None,
"Model": "model",
"Brand": "brand",
"FindMe": None,
}

device_type_mapping = JciHitachiAWSStatus.device_type_mapping

def __init__(self, raw_status: dict) -> None:
Expand All @@ -2217,6 +2210,10 @@ def __repr__(self) -> str:

def _preprocess(self, status):
status = status.copy()

if status.get("Error", 0) != 0:
return status

# device type
status["DeviceType"] = self.device_type_mapping[status["DeviceType"]]

Expand Down
8 changes: 2 additions & 6 deletions JciHitachi/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,9 @@ def extract_bytes(v, start, end): # pragma: no cover
Extracted value.
"""

assert (
start > end and end >= 0
), "Starting byte must be greater than ending byte, \
assert start > end and end >= 0, "Starting byte must be greater than ending byte, \
and ending byte must be greater than zero : \
{}, {}".format(
start, end
)
{}, {}".format(start, end)
return cast_bytes(v >> end * 8, start - end)


Expand Down
Loading

0 comments on commit ca03a11

Please sign in to comment.