Skip to content

Commit 187f3a3

Browse files
committed
feat(server): Allow keep alive to be turned off for a connection
Closes hyperium#1365
1 parent cecef9d commit 187f3a3

File tree

4 files changed

+136
-3
lines changed

4 files changed

+136
-3
lines changed

Diff for: src/proto/conn.rs

+13-1
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,14 @@ where I: AsyncRead + AsyncWrite,
453453
pub fn close_write(&mut self) {
454454
self.state.close_write();
455455
}
456+
457+
pub fn disable_keep_alive(&mut self) {
458+
if self.state.is_idle() {
459+
self.state.close_read();
460+
} else {
461+
self.state.disable_keep_alive();
462+
}
463+
}
456464
}
457465

458466
// ==== tokio_proto impl ====
@@ -700,6 +708,10 @@ impl<B, K: KeepAlive> State<B, K> {
700708
}
701709
}
702710

711+
fn disable_keep_alive(&mut self) {
712+
self.keep_alive.disable()
713+
}
714+
703715
fn busy(&mut self) {
704716
if let KA::Disabled = self.keep_alive.status() {
705717
return;
@@ -869,7 +881,7 @@ mod tests {
869881
other => panic!("unexpected frame: {:?}", other)
870882
}
871883

872-
// client
884+
// client
873885
let io = AsyncIo::new_buf(vec![], 1);
874886
let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default());
875887
conn.state.busy();

Diff for: src/proto/dispatch.rs

+4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ where
5454
}
5555
}
5656

57+
pub fn disable_keep_alive(&mut self) {
58+
self.conn.disable_keep_alive()
59+
}
60+
5761
fn poll_read(&mut self) -> Poll<(), ::Error> {
5862
loop {
5963
if self.conn.can_read_head() {

Diff for: src/server/mod.rs

+12
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,18 @@ where
536536
}
537537
}
538538

539+
impl<I, B, S> Connection<I, S>
540+
where S: Service<Request = Request, Response = Response<B>, Error = ::Error> + 'static,
541+
I: AsyncRead + AsyncWrite + 'static,
542+
B: Stream<Error=::Error> + 'static,
543+
B::Item: AsRef<[u8]>,
544+
{
545+
/// Disables keep-alive for this connection.
546+
pub fn disable_keep_alive(&mut self) {
547+
self.conn.disable_keep_alive()
548+
}
549+
}
550+
539551
mod unnameable {
540552
// This type is specifically not exported outside the crate,
541553
// so no one can actually name the type. With no methods, we make no

Diff for: tests/server.rs

+107-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ extern crate pretty_env_logger;
66
extern crate tokio_core;
77

88
use futures::{Future, Stream};
9-
use futures::future::{self, FutureResult};
9+
use futures::future::{self, FutureResult, Either};
1010
use futures::sync::oneshot;
1111

1212
use tokio_core::net::TcpListener;
@@ -551,6 +551,106 @@ fn pipeline_enabled() {
551551
assert_eq!(n, 0);
552552
}
553553

554+
#[test]
555+
fn disable_keep_alive_mid_request() {
556+
let mut core = Core::new().unwrap();
557+
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
558+
let addr = listener.local_addr().unwrap();
559+
560+
let (tx1, rx1) = oneshot::channel();
561+
let (tx2, rx2) = oneshot::channel();
562+
563+
let child = thread::spawn(move || {
564+
let mut req = connect(&addr);
565+
req.write_all(b"GET / HTTP/1.1\r\n").unwrap();
566+
tx1.send(()).unwrap();
567+
rx2.wait().unwrap();
568+
req.write_all(b"Host: localhost\r\n\r\n").unwrap();
569+
let mut buf = vec![];
570+
req.read_to_end(&mut buf).unwrap();
571+
});
572+
573+
let fut = listener.incoming()
574+
.into_future()
575+
.map_err(|_| unreachable!())
576+
.and_then(|(item, _incoming)| {
577+
let (socket, _) = item.unwrap();
578+
Http::<hyper::Chunk>::new().serve_connection(socket, HelloWorld)
579+
.select2(rx1)
580+
.then(|r| {
581+
match r {
582+
Ok(Either::A(_)) => panic!("expected rx first"),
583+
Ok(Either::B(((), mut conn))) => {
584+
conn.disable_keep_alive();
585+
tx2.send(()).unwrap();
586+
conn
587+
}
588+
Err(Either::A((e, _))) => panic!("unexpected error {}", e),
589+
Err(Either::B((e, _))) => panic!("unexpected error {}", e),
590+
}
591+
})
592+
});
593+
594+
core.run(fut).unwrap();
595+
child.join().unwrap();
596+
}
597+
598+
#[test]
599+
fn disable_keep_alive_post_request() {
600+
let mut core = Core::new().unwrap();
601+
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
602+
let addr = listener.local_addr().unwrap();
603+
604+
let (tx1, rx1) = oneshot::channel();
605+
606+
let child = thread::spawn(move || {
607+
let mut req = connect(&addr);
608+
req.write_all(b"\
609+
GET / HTTP/1.1\r\n\
610+
Host: localhost\r\n\
611+
\r\n\
612+
").unwrap();
613+
614+
let mut buf = [0; 1024 * 8];
615+
loop {
616+
let n = req.read(&mut buf).expect("reading 1");
617+
if n < buf.len() {
618+
if &buf[n - HELLO.len()..n] == HELLO.as_bytes() {
619+
break;
620+
}
621+
}
622+
}
623+
624+
tx1.send(()).unwrap();
625+
626+
let nread = req.read(&mut buf).unwrap();
627+
assert_eq!(nread, 0);
628+
});
629+
630+
let fut = listener.incoming()
631+
.into_future()
632+
.map_err(|_| unreachable!())
633+
.and_then(|(item, _incoming)| {
634+
let (socket, _) = item.unwrap();
635+
Http::<hyper::Chunk>::new().serve_connection(socket, HelloWorld)
636+
.select2(rx1)
637+
.then(|r| {
638+
match r {
639+
Ok(Either::A(_)) => panic!("expected rx first"),
640+
Ok(Either::B(((), mut conn))) => {
641+
conn.disable_keep_alive();
642+
conn
643+
}
644+
Err(Either::A((e, _))) => panic!("unexpected error {}", e),
645+
Err(Either::B((e, _))) => panic!("unexpected error {}", e),
646+
}
647+
})
648+
});
649+
650+
core.run(fut).unwrap();
651+
child.join().unwrap();
652+
}
653+
554654
#[test]
555655
fn no_proto_empty_parse_eof_does_not_return_error() {
556656
let mut core = Core::new().unwrap();
@@ -719,6 +819,8 @@ impl Service for TestService {
719819

720820
}
721821

822+
const HELLO: &'static str = "hello";
823+
722824
struct HelloWorld;
723825

724826
impl Service for HelloWorld {
@@ -728,7 +830,10 @@ impl Service for HelloWorld {
728830
type Future = FutureResult<Self::Response, Self::Error>;
729831

730832
fn call(&self, _req: Request) -> Self::Future {
731-
future::ok(Response::new())
833+
let mut response = Response::new();
834+
response.headers_mut().set(hyper::header::ContentLength(HELLO.len() as u64));
835+
response.set_body(HELLO);
836+
future::ok(response)
732837
}
733838
}
734839

0 commit comments

Comments
 (0)