|  | 
| 16 | 16 | 
 | 
| 17 | 17 | use std::net::TcpStream; | 
| 18 | 18 | use std::str; | 
|  | 19 | +use std::time::Duration; | 
| 19 | 20 | 
 | 
| 20 | 21 | use clarity::vm::types::QualifiedContractIdentifier; | 
| 21 | 22 | use libstackerdb::{ | 
| @@ -103,22 +104,34 @@ pub struct StackerDBSession { | 
| 103 | 104 |     pub stackerdb_contract_id: QualifiedContractIdentifier, | 
| 104 | 105 |     /// connection to the replica | 
| 105 | 106 |     sock: Option<TcpStream>, | 
|  | 107 | +    /// The timeout applied to HTTP read and write operations | 
|  | 108 | +    socket_timeout: Duration, | 
| 106 | 109 | } | 
| 107 | 110 | 
 | 
| 108 | 111 | impl StackerDBSession { | 
| 109 | 112 |     /// instantiate but don't connect | 
| 110 |  | -    pub fn new(host: &str, stackerdb_contract_id: QualifiedContractIdentifier) -> StackerDBSession { | 
|  | 113 | +    pub fn new( | 
|  | 114 | +        host: &str, | 
|  | 115 | +        stackerdb_contract_id: QualifiedContractIdentifier, | 
|  | 116 | +        socket_timeout: Duration, | 
|  | 117 | +    ) -> StackerDBSession { | 
| 111 | 118 |         StackerDBSession { | 
| 112 | 119 |             host: host.to_owned(), | 
| 113 | 120 |             stackerdb_contract_id, | 
| 114 | 121 |             sock: None, | 
|  | 122 | +            socket_timeout, | 
| 115 | 123 |         } | 
| 116 | 124 |     } | 
| 117 | 125 | 
 | 
| 118 | 126 |     /// connect or reconnect to the node | 
| 119 | 127 |     fn connect_or_reconnect(&mut self) -> Result<(), RPCError> { | 
| 120 | 128 |         debug!("connect to {}", &self.host); | 
| 121 |  | -        self.sock = Some(TcpStream::connect(&self.host)?); | 
|  | 129 | +        let sock = TcpStream::connect(&self.host)?; | 
|  | 130 | +        // Make sure we don't hang forever if for some reason our node does not | 
|  | 131 | +        // respond as expected such as failing to properly close the connection | 
|  | 132 | +        sock.set_read_timeout(Some(self.socket_timeout))?; | 
|  | 133 | +        sock.set_write_timeout(Some(self.socket_timeout))?; | 
|  | 134 | +        self.sock = Some(sock); | 
| 122 | 135 |         Ok(()) | 
| 123 | 136 |     } | 
| 124 | 137 | 
 | 
| @@ -251,11 +264,49 @@ impl SignerSession for StackerDBSession { | 
| 251 | 264 |     /// upload a chunk | 
| 252 | 265 |     fn put_chunk(&mut self, chunk: &StackerDBChunkData) -> Result<StackerDBChunkAckData, RPCError> { | 
| 253 | 266 |         let body = | 
| 254 |  | -            serde_json::to_vec(chunk).map_err(|e| RPCError::Deserialize(format!("{:?}", &e)))?; | 
|  | 267 | +            serde_json::to_vec(chunk).map_err(|e| RPCError::Deserialize(format!("{e:?}")))?; | 
| 255 | 268 |         let path = stackerdb_post_chunk_path(self.stackerdb_contract_id.clone()); | 
| 256 | 269 |         let resp_bytes = self.rpc_request("POST", &path, Some("application/json"), &body)?; | 
| 257 | 270 |         let ack: StackerDBChunkAckData = serde_json::from_slice(&resp_bytes) | 
| 258 |  | -            .map_err(|e| RPCError::Deserialize(format!("{:?}", &e)))?; | 
|  | 271 | +            .map_err(|e| RPCError::Deserialize(format!("{e:?}")))?; | 
| 259 | 272 |         Ok(ack) | 
| 260 | 273 |     } | 
| 261 | 274 | } | 
|  | 275 | + | 
|  | 276 | +#[cfg(test)] | 
|  | 277 | +mod tests { | 
|  | 278 | +    use std::io::Write; | 
|  | 279 | +    use std::net::TcpListener; | 
|  | 280 | +    use std::thread; | 
|  | 281 | + | 
|  | 282 | +    use super::*; | 
|  | 283 | + | 
|  | 284 | +    #[test] | 
|  | 285 | +    fn socket_timeout_works_as_expected() { | 
|  | 286 | +        let listener = TcpListener::bind("127.0.0.1:0").expect("bind failed"); | 
|  | 287 | +        let addr = listener.local_addr().unwrap(); | 
|  | 288 | + | 
|  | 289 | +        let short_timeout = Duration::from_millis(200); | 
|  | 290 | +        thread::spawn(move || { | 
|  | 291 | +            if let Ok((mut stream, _)) = listener.accept() { | 
|  | 292 | +                // Sleep long enough so the client should hit its timeout | 
|  | 293 | +                std::thread::sleep(short_timeout * 2); | 
|  | 294 | +                let _ = stream.write_all(b"HTTP/1.1 200 OK\r\n\r\n"); | 
|  | 295 | +            } | 
|  | 296 | +        }); | 
|  | 297 | + | 
|  | 298 | +        let contract_id = QualifiedContractIdentifier::transient(); | 
|  | 299 | +        let mut session = StackerDBSession::new(&addr.to_string(), contract_id, short_timeout); | 
|  | 300 | + | 
|  | 301 | +        session.connect_or_reconnect().expect("connect failed"); | 
|  | 302 | + | 
|  | 303 | +        // This should fail due to the timeout | 
|  | 304 | +        let result = session.rpc_request("GET", "/", None, &[]); | 
|  | 305 | +        match result { | 
|  | 306 | +            Err(RPCError::IO(e)) => { | 
|  | 307 | +                assert_eq!(e.kind(), std::io::ErrorKind::WouldBlock); | 
|  | 308 | +            } | 
|  | 309 | +            other => panic!("expected timeout error, got {other:?}"), | 
|  | 310 | +        } | 
|  | 311 | +    } | 
|  | 312 | +} | 
0 commit comments