Skip to content

Commit

Permalink
watchman_client: don't spawn off reads on their own task
Browse files Browse the repository at this point in the history
Summary:
There are a couple issues with watchman_client as it exists today:

First, if the reader task fails (e.g. because Watchman crashed), then we just
hang forever. What happens then is that the reader stops sending messages, but
the client doesn't notice because it's just waiting on a channel that is shared
by the reader and clients that want to send requests.

Second, even if the client is dropped and its channel is closed, the reader
is left around because nothing stops it.

Finally, there is a bit of a bug here in the sense that we assume the first read
we do will always have the full PDU header. In practice since we talk to a local
socket that is true, but I updated this slightly so that we continue trying to decode
the header after getting more data, unless we've definitely seen too much data
(as per the comment that was there already).

This diff fixes those issues by removing the dedicated task we had for the reader, and
instead having the client task do it. This requires rewriting the logic we use
for reading a bit using a Decoder instead of doing it imperatively in an `async
fn`.

Reviewed By: markbt

Differential Revision: D28742178

fbshipit-source-id: e3a2e3cc7269e45bf9d7865a9015f80f81fabfd3
  • Loading branch information
krallin authored and facebook-github-bot committed Jun 2, 2021
1 parent 6206b68 commit 9c897ce
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 99 deletions.
12 changes: 4 additions & 8 deletions rust/watchman_client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,13 @@ structopt = "0.3"

[dependencies]
maplit = "1.0"
futures = { version = "0.3.13", features = ["async-await", "compat"] }
bytes = { version = "1.0", features = ["serde"] }
serde = { version = "1.0.102", features = ["derive"] }
serde_bser = { version = "0.2", path = "../serde_bser" }
thiserror = ">=1.0.6"
tokio = { version = "1.0", features = [
"io-util",
"macros",
"net",
"process",
"rt",
"sync",
] }
tokio = { version = "1.5", features = ["full", "test-util"] }
tokio-util = { version = "0.6", features = ["full"] }

[target."cfg(windows)".dependencies]
mio-named-pipes = "0.1"
Expand Down
230 changes: 139 additions & 91 deletions rust/watchman_client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,23 @@ pub mod expr;
pub mod fields;
mod named_pipe;
pub mod pdu;
use serde_bser::de::{Bunser, PduInfo, SliceRead};
use bytes::{Bytes, BytesMut};
use futures::{future::FutureExt, stream::StreamExt};
use serde_bser::de::{Bunser, SliceRead};
use serde_bser::value::Value;
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::process::Command;
use tokio::sync::mpsc::{Receiver, Sender, UnboundedReceiver, UnboundedSender};
use tokio::sync::Mutex;
use tokio_util::codec::{Decoder, FramedRead};

/// The next id number to use when generating a subscription name
static SUB_ID: AtomicUsize = AtomicUsize::new(1);
Expand Down Expand Up @@ -220,18 +223,9 @@ impl Connector {

let (request_tx, request_rx) = tokio::sync::mpsc::channel(128);

let mut reader_task = ReaderTask {
reader,
request_tx: request_tx.clone(),
};
tokio::spawn(async move {
if let Err(err) = reader_task.run().await {
eprintln!("watchman reader task failed: {}", err);
}
});

let mut task = ClientTask {
writer,
reader: FramedRead::new(reader, BserSplitter),
request_rx,
request_queue: VecDeque::new(),
waiting_response: false,
Expand Down Expand Up @@ -351,11 +345,11 @@ struct SendRequest {
/// The serialized request to send to the server
buf: Vec<u8>,
/// to pass the response back to the requestor
tx: tokio::sync::oneshot::Sender<Result<Vec<u8>, String>>,
tx: tokio::sync::oneshot::Sender<Result<Bytes, String>>,
}

impl SendRequest {
fn respond(self, result: Result<Vec<u8>, String>) -> Result<(), Error> {
fn respond(self, result: Result<Bytes, String>) -> Result<(), Error> {
self.tx
.send(result)
.map_err(|_| Error::generic("requestor has dropped its receiver"))
Expand All @@ -364,90 +358,70 @@ impl SendRequest {

enum TaskItem {
QueueRequest(SendRequest),
ProcessReceivedPdu(Vec<u8>),
RegisterSubscription(String, UnboundedSender<Vec<u8>>),
RegisterSubscription(String, UnboundedSender<Bytes>),
}

/// A live connection to a watchman server.
/// Use [Connector](struct.Connector.html) to establish a connection.
pub struct Client {
inner: Arc<Mutex<ClientInner>>,
}
/// Splits BSER messages out of a stream. Does not attempt to actually decode them.
struct BserSplitter;

/// The reader task lives to read a PDU and send it to the ClientTask
struct ReaderTask {
reader: tokio::io::ReadHalf<Box<dyn ReadWriteStream>>,
request_tx: Sender<TaskItem>,
}
impl Decoder for BserSplitter {
type Item = Bytes;
type Error = Error;

impl ReaderTask {
async fn run(&mut self) -> Result<(), Error> {
loop {
let pdu = self.read_pdu_vec().await?;
self.request_tx
.send(TaskItem::ProcessReceivedPdu(pdu))
.await
.map_err(Error::generic)?;
}
}
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
let mut bunser = Bunser::new(SliceRead::new(buf.as_ref()));

/// Sniffs out the BSER PDU header to determine the length of data that
/// needs to be read in order to decode the full PDU
async fn read_bser_pdu_length(&mut self) -> Result<PduHeader, Error> {
// We know that the smallest full PDU returned by the server
// won't ever be smaller than this size
const BUF_SIZE: usize = 16;
let mut buf = [0u8; BUF_SIZE];

let pos = self.reader.read(&mut buf).await?;
if pos == 0 {
return Err(Error::Eof);
}
let pdu = match bunser.read_pdu() {
Ok(pdu) => pdu,
Err(source) => {
// We know that the smallest full PDU returned by the server won't ever be smaller
// than this size. So, if we have less than BUF_SIZE bytes, ask for more data.
const BUF_SIZE: usize = 16;

let buf = &buf[..pos];

let mut bunser = Bunser::new(SliceRead::new(buf));
let pdu = bunser.read_pdu().map_err(|source| Error::Deserialize {
source: Box::new(source),
data: buf.to_vec(),
})?;
let buf = buf.to_vec();
Ok(PduHeader { buf, pdu })
}
let missing = BUF_SIZE.saturating_sub(buf.len());

/// Read the bytes that comprise a BSER encoded PDU
async fn read_pdu_vec(&mut self) -> Result<Vec<u8>, Error> {
let header = self.read_bser_pdu_length().await?;
let total_size = (header.pdu.start + header.pdu.len) as usize;
let mut buf = header.buf;
if missing > 0 {
buf.reserve(missing);
return Ok(None);
}

let mut end = buf.len();
// We have at least BUF_SIZE bytes, so the PDU header should have decoded by now,
// but it didn't. Return an error.
return Err(Error::Deserialize {
source: Box::new(source),
data: buf.to_vec(),
});
}
};

buf.resize(total_size, 0);
let total_size = (pdu.start + pdu.len) as usize;

while end != total_size {
let n = self
.reader
.read(&mut buf.as_mut_slice()[end..total_size])
.await?;
if n == 0 {
return Err(Error::Eof);
}
end += n;
let missing = total_size.saturating_sub(buf.len());
if missing > 0 {
buf.reserve(missing);
return Ok(None);
}

Ok(buf)
let ret = buf.split_to(total_size);
Ok(Some(ret.freeze()))
}
}

/// A live connection to a watchman server.
/// Use [Connector](struct.Connector.html) to establish a connection.
pub struct Client {
inner: Arc<Mutex<ClientInner>>,
}

/// The client task coordinates sending requests with processing
/// unilateral results
struct ClientTask {
writer: tokio::io::WriteHalf<Box<dyn ReadWriteStream>>,
reader: FramedRead<tokio::io::ReadHalf<Box<dyn ReadWriteStream>>, BserSplitter>,
request_rx: Receiver<TaskItem>,
request_queue: VecDeque<SendRequest>,
waiting_response: bool,
subscriptions: HashMap<String, UnboundedSender<Vec<u8>>>,
subscriptions: HashMap<String, UnboundedSender<Bytes>>,
}

impl Drop for ClientTask {
Expand All @@ -471,19 +445,29 @@ impl ClientTask {

async fn run_loop(&mut self) -> Result<(), Error> {
loop {
match self.request_rx.recv().await {
Some(TaskItem::QueueRequest(request)) => self.queue_request(request).await?,
Some(TaskItem::ProcessReceivedPdu(pdu)) => self.process_pdu(pdu).await?,
Some(TaskItem::RegisterSubscription(name, tx)) => {
self.register_subscription(name, tx)
futures::select_biased! {
pdu = self.reader.next().fuse() => {
match pdu {
Some(pdu) => self.process_pdu(pdu?).await?,
None => return Err(Error::Eof),
}
}
task = self.request_rx.recv().fuse() => {
match task {
Some(TaskItem::QueueRequest(request)) => self.queue_request(request).await?,
Some(TaskItem::RegisterSubscription(name, tx)) => {
self.register_subscription(name, tx)
}
None => break,
}
}
None => break,
};
}
}

Ok(())
}

fn register_subscription(&mut self, name: String, tx: UnboundedSender<Vec<u8>>) {
fn register_subscription(&mut self, name: String, tx: UnboundedSender<Bytes>) {
self.subscriptions.insert(name, tx);
}

Expand Down Expand Up @@ -525,7 +509,7 @@ impl ClientTask {
}

/// Dispatch a PDU that we just read to the appropriate client code.
async fn process_pdu(&mut self, pdu: Vec<u8>) -> Result<(), Error> {
async fn process_pdu(&mut self, pdu: Bytes) -> Result<(), Error> {
use serde::Deserialize;
#[derive(Deserialize, Debug)]
pub struct Unilateral {
Expand Down Expand Up @@ -560,11 +544,6 @@ impl ClientTask {
}
}

struct PduHeader {
buf: Vec<u8>,
pdu: PduInfo,
}

fn bunser<T>(buf: &[u8]) -> Result<T, Error>
where
T: serde::de::DeserializeOwned,
Expand Down Expand Up @@ -702,7 +681,7 @@ where
name: String,
inner: Arc<Mutex<ClientInner>>,
root: ResolvedRoot,
responses: UnboundedReceiver<Vec<u8>>,
responses: UnboundedReceiver<Bytes>,
_phantom: PhantomData<F>,
}

Expand Down Expand Up @@ -1021,9 +1000,78 @@ impl Client {
mod tests {
use super::*;

use futures::stream::{self, TryStreamExt};
use serde::{Deserialize, Serialize};
use std::io;
use tokio_util::io::StreamReader;

/// Minimal serde payload used by the decoder test: values of this type are
/// serialized to BSER, split back into frames, and deserialized again.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct TestStruct {
// The only field; gives each test message a distinguishable value.
value: i32,
}

/// The path handed to `unix_domain_socket` must be stored on the connector.
#[test]
fn connection_builder_paths() {
    let socket_path = "/some/path";
    let connector = Connector::new().unix_domain_socket(socket_path);

    assert_eq!(connector.unix_domain, Some(PathBuf::from(socket_path)));
}

#[tokio::test]
async fn test_decoder() {
async fn read_bser(buf: &[u8], chunk_size: usize) -> Vec<TestStruct> {
let chunks = buf
.chunks(chunk_size)
.map(|c| Result::<_, io::Error>::Ok(Bytes::copy_from_slice(c)));

let reader = StreamReader::new(stream::iter(chunks));

let decoded = FramedRead::new(reader, BserSplitter)
.map_err(Error::from)
.and_then(|bytes| async move {
// We unwrap this since a) this is a test and b) serde_bser's errors aren't
// easily propagated into en error type like anyhow::Error without losing the
// message.
Ok(serde_bser::from_slice::<TestStruct>(&bytes).unwrap())
})
.try_collect()
.await
.unwrap();

decoded
}

let msgs = vec![
TestStruct { value: 1 },
TestStruct { value: 2 },
TestStruct { value: 3 },
];

let mut buf = vec![];

for msg in msgs.iter() {
serde_bser::ser::serialize(&mut buf, msg).expect("Failed to write to a Vec");
}

// Read it with various sizes
assert_eq!(msgs, read_bser(&buf, 1).await);
assert_eq!(msgs, read_bser(&buf, 2).await);
assert_eq!(msgs, read_bser(&buf, 10).await);
assert_eq!(msgs, read_bser(&buf, buf.len()).await);
}

/// Undecodable input is only reported as an error once enough bytes have been
/// buffered; before that the splitter just asks for more data.
#[test]
fn test_decoder_err() {
    let mut pending = BytesMut::new();

    // We don't error if there isn't much data yet
    pending.extend_from_slice(&[0; 10]);
    let too_short = BserSplitter.decode(&mut pending);
    assert!(matches!(too_short, Ok(None)));

    // We do if there is enough
    pending.extend_from_slice(&[0; 10]);
    let long_enough = BserSplitter.decode(&mut pending);
    assert!(long_enough.is_err());
}
}

0 comments on commit 9c897ce

Please sign in to comment.