diff --git a/src/simple_http.rs b/src/simple_http.rs index 09c85cff..dde55cc2 100644 --- a/src/simple_http.rs +++ b/src/simple_http.rs @@ -4,10 +4,10 @@ #[cfg(feature = "proxy")] use socks::Socks5Stream; -use std::io::{BufRead, BufReader, Write}; -#[cfg(not(feature = "proxy"))] +use std::io::{BufRead, BufReader, Read, Write}; use std::net::TcpStream; use std::net::{SocketAddr, ToSocketAddrs}; +use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use std::{error, fmt, io, net, thread}; @@ -38,6 +38,7 @@ pub struct SimpleHttpTransport { proxy_addr: net::SocketAddr, #[cfg(feature = "proxy")] proxy_auth: Option<(String, String)>, + sock: Arc>>, } impl Default for SimpleHttpTransport { @@ -57,6 +58,7 @@ impl Default for SimpleHttpTransport { ), #[cfg(feature = "proxy")] proxy_auth: None, + sock: Arc::new(Mutex::new(None)), } } } @@ -73,29 +75,58 @@ impl SimpleHttpTransport { } fn request(&self, req: impl serde::Serialize) -> Result + where + R: for<'a> serde::de::Deserialize<'a>, + { + // `try_request` should not panic, so the mutex shouldn't be poisoned + // and unwrapping should be safe + let mut sock = self.sock.lock().expect("poisoned mutex"); + match self.try_request(req, &mut sock) { + Ok(response) => Ok(response), + Err(err) => { + *sock = None; + Err(err) + } + } + } + + fn try_request( + &self, + req: impl serde::Serialize, + sock: &mut Option, + ) -> Result where R: for<'a> serde::de::Deserialize<'a>, { // Open connection let request_deadline = Instant::now() + self.timeout; - #[cfg(feature = "proxy")] - let mut sock = if let Some((username, password)) = &self.proxy_auth { - Socks5Stream::connect_with_password( - self.proxy_addr, - self.addr, - username.as_str(), - password.as_str(), - )? - .into_inner() - } else { - Socks5Stream::connect(self.proxy_addr, self.addr)?.into_inner() - }; - - #[cfg(not(feature = "proxy"))] - let mut sock = TcpStream::connect_timeout(&self.addr, self.timeout)?; + if sock.is_none() { + *sock = Some({ + #[cfg(feature = "proxy")] + { + if let Some((username, password)) = &self.proxy_auth { + Socks5Stream::connect_with_password( + self.proxy_addr, + self.addr, + username.as_str(), + password.as_str(), + )? + .into_inner() + } else { + Socks5Stream::connect(self.proxy_addr, self.addr)?.into_inner() + } + } - sock.set_read_timeout(Some(self.timeout))?; - sock.set_write_timeout(Some(self.timeout))?; + #[cfg(not(feature = "proxy"))] + { + let stream = TcpStream::connect_timeout(&self.addr, self.timeout)?; + stream.set_read_timeout(Some(self.timeout))?; + stream.set_write_timeout(Some(self.timeout))?; + stream + } + }) + }; + let sock = sock.as_mut().unwrap(); // Serialize the body first so we can set the Content-Length header. let body = serde_json::to_vec(&req)?; @@ -105,7 +136,6 @@ impl SimpleHttpTransport { sock.write_all(self.path.as_bytes())?; sock.write_all(b" HTTP/1.1\r\n")?; // Write headers - sock.write_all(b"Connection: Close\r\n")?; sock.write_all(b"Content-Type: application/json\r\n")?; sock.write_all(b"Content-Length: ")?; sock.write_all(body.len().to_string().as_bytes())?; @@ -133,18 +163,39 @@ impl SimpleHttpTransport { Err(_) => return Err(Error::HttpParseError), }; - // Skip response header fields - while get_line(&mut reader, request_deadline)? != "\r\n" {} + // Parse response header fields + let mut content_length = None; + loop { + let line = get_line(&mut reader, request_deadline)?; + + if line == "\r\n" { + break; + } + + const CONTENT_LENGTH: &str = "content-length: "; + if line.to_lowercase().starts_with(CONTENT_LENGTH) { + content_length = Some( + line[CONTENT_LENGTH.len()..] + .trim() + .parse::() + .map_err(|_| Error::HttpParseError)?, + ); + } + } if response_code == 401 { // There is no body in a 401 response, so don't try to read it return Err(Error::HttpErrorCode(response_code)); } + let content_length = content_length.ok_or(Error::HttpParseError)?; + + let mut buffer = vec![0; content_length]; + // Even if it's != 200, we parse the response as we may get a JSONRPC error instead // of the less meaningful HTTP error code. - let resp_body = get_lines(&mut reader)?; - match serde_json::from_str(&resp_body) { + reader.read_exact(&mut buffer)?; + match serde_json::from_slice(&buffer) { Ok(s) => Ok(s), Err(e) => { if response_code != 200 { @@ -261,23 +312,6 @@ fn get_line(reader: &mut R, deadline: Instant) -> Result(reader: &mut R) -> Result { - let mut body: String = String::new(); - - for line in reader.lines() { - match line { - Ok(l) => body.push_str(&l), - // io error occurred, abort - Err(e) => return Err(Error::SocketError(e)), - } - } - // remove whitespace - body.retain(|c| !c.is_whitespace()); - - Ok(body) -} - /// Do some very basic manual URL parsing because the uri/url crates /// all have unicode-normalization as a dependency and that's broken. fn check_url(url: &str) -> Result<(SocketAddr, String), Error> {