diff --git a/ospd_openvas/messaging/mqtt.py b/ospd_openvas/messaging/mqtt.py index e7854638..7beadc49 100644 --- a/ospd_openvas/messaging/mqtt.py +++ b/ospd_openvas/messaging/mqtt.py @@ -86,7 +86,13 @@ def publish(self, message: Message) -> None: class MQTTSubscriber(Subscriber): def __init__(self, client: MQTTClient): - self._client = client + self.client = client + # Save the active subscriptions on subscribe() so we can resubscribe + # after reconnect + self.subscriptions: dict = {} + + self.client.on_connect = self.on_connect + self.client.user_data_set(self.subscriptions) def subscribe( self, message_class: Type[Message], callback: Callable[[Message], None] @@ -96,8 +102,21 @@ def subscribe( logger.debug("Subscribing to topic %s", message_class.topic) - self._client.subscribe(message_class.topic, qos=QOS_AT_LEAST_ONCE) - self._client.message_callback_add(message_class.topic, func) + self.client.subscribe(message_class.topic, qos=QOS_AT_LEAST_ONCE) + self.client.message_callback_add(message_class.topic, func) + + self.subscriptions[message_class.topic] = func + + @staticmethod + def on_connect(_client, _userdata, _flags, rc, _properties): + if rc == 0: + # If we previously had active subscription we subscribe to them + # again because they got lost after a broker disconnect. + # Userdata is set in __init__() and filled in subscribe() + if _userdata: + for topic, func in _userdata.items(): + _client.subscribe(topic, qos=QOS_AT_LEAST_ONCE) + _client.message_callback_add(topic, func) @staticmethod def _handle_message(