From 79ca9640bb4a6e14a7fb6d2a86700e09e2b7def4 Mon Sep 17 00:00:00 2001 From: Thibaut Lorrain Date: Mon, 6 Aug 2018 17:09:41 +0200 Subject: [PATCH] add unsubscribe --- src/client.rs | 18 ++++++++++++++++-- src/connection.rs | 5 +++++ src/state.rs | 24 ++++++++++++++++++++++++ tests/testsuite.rs | 45 +++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index f301789..5b7ec67 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,6 +11,7 @@ use MqttOptions; pub enum Command { Status(#[debug_stub = ""] ::std::sync::mpsc::Sender<::state::MqttConnectionStatus>), Subscribe(Subscription), + Unsubscribe(TopicPath), Publish(Publish), Connect, Disconnect, @@ -156,8 +157,21 @@ impl<'a> SubscriptionBuilder<'a> { it: Subscription { qos, ..it }, } } - pub fn send(self) -> Result<()> { - self.client.send_command(Command::Subscribe(self.it)) + pub fn send(self) -> Result> { + let token = SubscriptionToken { client: &self.client,topic_path: self.it.topic_path.clone()}; + self.client.send_command(Command::Subscribe(self.it))?; + Ok(token) + } +} + +pub struct SubscriptionToken<'a> { + client: &'a MqttClient, + topic_path: TopicPath +} + +impl<'a> SubscriptionToken<'a> { + pub fn unsubscribe(self) -> Result<()> { + self.client.send_command(Command::Unsubscribe(self.topic_path)) } } diff --git a/src/connection.rs b/src/connection.rs index e0210c1..d13939a 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -389,6 +389,7 @@ impl ConnectionState { self.turn_command()?; } mqtt3::Packet::Suback(suback) => self.mqtt_state.handle_incoming_suback(suback)?, + mqtt3::Packet::Unsuback(packet_identifier) => self.mqtt_state.handle_incoming_unsuback(packet_identifier)?, mqtt3::Packet::Publish(publish) => { let (_, server) = self.mqtt_state.handle_incoming_publish(publish)?; if let Some(server) = server { @@ -437,6 +438,10 @@ impl ConnectionState { let packet = self.mqtt_state.handle_outgoing_subscribe(vec![sub])?; self.send_packet(mqtt3::Packet::Subscribe(packet))? } + Command::Unsubscribe(topic_path) => { + let packet = self.mqtt_state.handle_outgoing_unsubscribe(vec![topic_path])?; + self.send_packet(mqtt3::Packet::Unsubscribe(packet))? + } Command::Status(tx) => { let _ = tx.send(self.state().status()); } diff --git a/src/state.rs b/src/state.rs index 7cd5d63..317c9c3 100644 --- a/src/state.rs +++ b/src/state.rs @@ -302,6 +302,26 @@ impl MqttState { } } + pub fn handle_outgoing_unsubscribe( + &mut self, + topics: Vec<::mqtt3::TopicPath>, + ) -> Result { + let pkid = self.next_pkid(); + let topics: Vec = topics.iter() + .map(|s| s.path.clone()) + .collect(); + self.subscriptions.retain(|it| !topics.contains(&it.topic_path.path)); + + if self.connection_status == MqttConnectionStatus::Connected { + Ok(mqtt3::Unsubscribe { pid: pkid, topics }) + } else { + error!( + "State = {:?}. Shouldn't unsubscribe in this state", + self.connection_status + ); + Err(ErrorKind::InvalidState.into()) + } + } pub fn handle_incoming_suback(&mut self, ack: mqtt3::Suback) -> Result<()> { if ack.return_codes @@ -313,6 +333,10 @@ impl MqttState { Ok(()) } + pub fn handle_incoming_unsuback(&mut self, ack: mqtt3::PacketIdentifier) -> Result<()> { + Ok(()) + } + pub fn handle_socket_disconnect(&mut self) { self.await_pingresp = false; self.set_status_after_error(); diff --git a/tests/testsuite.rs b/tests/testsuite.rs index b796ec0..47c0042 100644 --- a/tests/testsuite.rs +++ b/tests/testsuite.rs @@ -117,6 +117,51 @@ fn basic_publishes_and_subscribes() { assert_eq!(3, final_count.load(Ordering::SeqCst)); } +#[test] +fn publishes_and_subscribes_and_unsubscribes() { + // loggerv::init_with_level(log::LogLevel::Debug); + let client_options = MqttOptions::new("pubsubunsub", MOSQUITTO_ADDR); + let count = Arc::new(AtomicUsize::new(0)); + let final_count = count.clone(); + let count = count.clone(); + + let request = MqttClient::start(client_options).expect("Coudn't start"); + let token = request + .subscribe( + "test/pubsubunsub", + Box::new(move |_| { + count.fetch_add(1, Ordering::SeqCst); + }), + ) + .unwrap() + .send() + .unwrap(); + + let payload = format!("hello rust"); + request + .publish("test/pubsubunsub") + .unwrap() + .payload(payload.clone().into_bytes()) + .send() + .unwrap(); + + thread::sleep(Duration::from_secs(1)); + token.unsubscribe().unwrap(); + thread::sleep(Duration::from_secs(1)); + + request + .publish("test/pubsubunsub") + .unwrap() + .payload(payload.clone().into_bytes()) + .send() + .unwrap(); + + thread::sleep(Duration::from_secs(1)); + + assert_eq!(1, final_count.load(Ordering::SeqCst)); +} + + #[test] fn alive() { // loggerv::init_with_level(log::LogLevel::Debug);