Skip to content

Commit

Permalink
Shared Subscription tests (#540)
Browse files Browse the repository at this point in the history
  • Loading branch information
alfred2g authored Jan 9, 2024
1 parent 2981db9 commit 5742290
Showing 1 changed file with 162 additions and 0 deletions.
162 changes: 162 additions & 0 deletions test/test_mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from concurrent.futures import Future
from awscrt import mqtt5, io, http, exceptions
from test import NativeResourceTest
from threading import Lock
import os
import unittest
import uuid
Expand Down Expand Up @@ -1008,6 +1009,167 @@ def test_operation_sub_unsub(self):
client.stop()
callbacks.future_stopped.result(TIMEOUT)

sub1_callbacks = False
sub2_callbacks = False
total_callbacks = 0
all_packets_received = Future()
mutex = Lock()
received_subscriptions = [0] * 10

def subscriber1_callback(self, publish_received_data: mqtt5.PublishReceivedData):
self.mutex.acquire()
var = publish_received_data.publish_packet.payload
self.received_subscriptions[int(var)] = 1
self.sub1_callbacks = True
self.total_callbacks = self.total_callbacks + 1
if self.total_callbacks == 10:
self.all_packets_received.set_result(None)
self.mutex.release()

def subscriber2_callback(self, publish_received_data: mqtt5.PublishReceivedData):
self.mutex.acquire()
var = publish_received_data.publish_packet.payload
self.received_subscriptions[int(var)] = 1
self.sub2_callbacks = True
self.total_callbacks = self.total_callbacks + 1
if self.total_callbacks == 10:
self.all_packets_received.set_result(None)
self.mutex.release()

def test_operation_shared_subscription(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST")
input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT")
input_key = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_KEY")

client_id_subscriber1 = create_client_id()
client_id_subscriber2 = create_client_id()
client_id_publisher = create_client_id()

testTopic = "test/MQTT5_Binding_Python_" + client_id_publisher
sharedTopicfilter = "$share/crttest/test/MQTT5_Binding_Python_" + client_id_publisher

tls_ctx_options = io.TlsContextOptions.create_client_with_mtls_from_path(
input_cert,
input_key
)

# subscriber 1
connect_subscriber1_options = mqtt5.ConnectPacket(client_id=client_id_subscriber1)
subscriber1_generic_callback = Mqtt5TestCallbacks()
subscriber1_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=8883,
tls_ctx=io.ClientTlsContext(tls_ctx_options),
connect_options=connect_subscriber1_options,
on_publish_callback_fn=self.subscriber1_callback,
on_lifecycle_event_stopped_fn=subscriber1_generic_callback.on_lifecycle_stopped,
on_lifecycle_event_attempting_connect_fn=subscriber1_generic_callback.on_lifecycle_attempting_connect,
on_lifecycle_event_connection_success_fn=subscriber1_generic_callback.on_lifecycle_connection_success,
on_lifecycle_event_connection_failure_fn=subscriber1_generic_callback.on_lifecycle_connection_failure
)
subscriber1_client = mqtt5.Client(client_options=subscriber1_options)

# subscriber 2
connect_subscriber2_options = mqtt5.ConnectPacket(client_id=client_id_subscriber2)
subscriber2_generic_callback = Mqtt5TestCallbacks()
subscriber2_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=8883,
tls_ctx=io.ClientTlsContext(tls_ctx_options),
connect_options=connect_subscriber2_options,
on_publish_callback_fn=self.subscriber2_callback,
on_lifecycle_event_stopped_fn=subscriber2_generic_callback.on_lifecycle_stopped,
on_lifecycle_event_attempting_connect_fn=subscriber2_generic_callback.on_lifecycle_attempting_connect,
on_lifecycle_event_connection_success_fn=subscriber2_generic_callback.on_lifecycle_connection_success,
on_lifecycle_event_connection_failure_fn=subscriber2_generic_callback.on_lifecycle_connection_failure
)
subscriber2_client = mqtt5.Client(client_options=subscriber2_options)

# publisher
connect_publisher_options = mqtt5.ConnectPacket(client_id=client_id_publisher)
publisher_generic_callback = Mqtt5TestCallbacks()

publisher_options = mqtt5.ClientOptions(
host_name=input_host_name,
port=8883,
tls_ctx=io.ClientTlsContext(tls_ctx_options),
connect_options=connect_publisher_options,
on_lifecycle_event_stopped_fn=publisher_generic_callback.on_lifecycle_stopped,
on_lifecycle_event_attempting_connect_fn=publisher_generic_callback.on_lifecycle_attempting_connect,
on_lifecycle_event_connection_success_fn=publisher_generic_callback.on_lifecycle_connection_success,
on_lifecycle_event_connection_failure_fn=publisher_generic_callback.on_lifecycle_connection_failure
)
publisher_client = mqtt5.Client(client_options=publisher_options)

print("Connecting all 3 clients\n")
subscriber1_client.start()
subscriber1_generic_callback.future_connection_success.result(TIMEOUT)

subscriber2_client.start()
subscriber2_generic_callback.future_connection_success.result(TIMEOUT)

publisher_client.start()
publisher_generic_callback.future_connection_success.result(TIMEOUT)
print("All clients connected\n")

# Subscriber 1
subscriptions = []
subscriptions.append(mqtt5.Subscription(topic_filter=sharedTopicfilter, qos=mqtt5.QoS.AT_LEAST_ONCE))
subscribe_packet = mqtt5.SubscribePacket(
subscriptions=subscriptions)
subscribe_future = subscriber1_client.subscribe(subscribe_packet=subscribe_packet)
suback_packet1 = subscribe_future.result(TIMEOUT)
self.assertIsInstance(suback_packet1, mqtt5.SubackPacket)

# Subscriber 2
subscriptions2 = []
subscriptions2.append(mqtt5.Subscription(topic_filter=sharedTopicfilter, qos=mqtt5.QoS.AT_LEAST_ONCE))
subscribe_packet2 = mqtt5.SubscribePacket(
subscriptions=subscriptions2)
subscribe_future2 = subscriber2_client.subscribe(subscribe_packet=subscribe_packet2)
suback_packet2 = subscribe_future2.result(TIMEOUT)
self.assertIsInstance(suback_packet2, mqtt5.SubackPacket)

publishes = 10
for x in range(0, publishes):
packet = mqtt5.PublishPacket(
payload=f"{x}",
qos=mqtt5.QoS.AT_LEAST_ONCE,
topic=testTopic
)
publish_future = publisher_client.publish(packet)
publish_future.result(TIMEOUT)

self.all_packets_received.result(TIMEOUT)

topic_filters = []
topic_filters.append(testTopic)
unsubscribe_packet = mqtt5.UnsubscribePacket(topic_filters=testTopic)

unsubscribe_future = subscriber1_client.unsubscribe(unsubscribe_packet)
unsuback_packet = unsubscribe_future.result(TIMEOUT)
self.assertIsInstance(unsuback_packet, mqtt5.UnsubackPacket)

unsubscribe_future = subscriber2_client.unsubscribe(unsubscribe_packet)
unsuback_packet = unsubscribe_future.result(TIMEOUT)
self.assertIsInstance(unsuback_packet, mqtt5.UnsubackPacket)

self.assertEqual(self.sub1_callbacks, True)
self.assertEqual(self.sub2_callbacks, True)
self.assertEqual(self.total_callbacks, 10)

for e in self.received_subscriptions:
self.assertEqual(e, 1)

subscriber1_client.stop()
subscriber1_generic_callback.future_stopped.result(TIMEOUT)

subscriber2_client.stop()
subscriber2_generic_callback.future_stopped.result(TIMEOUT)

publisher_client.stop()
publisher_generic_callback.future_stopped.result(TIMEOUT)

def test_operation_will(self):
input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST")
input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT")
Expand Down

0 comments on commit 5742290

Please sign in to comment.