Skip to content

Commit 9482614

Browse files
committed
fix(server): use a timeout for Server keep-alive
Server keep-alive is now **off** by default. In order to turn it on, the `keep_alive` method must be called on the `Server` object. Closes #368
1 parent 388ddf6 commit 9482614

File tree

2 files changed

+104
-60
lines changed

2 files changed

+104
-60
lines changed

src/server/mod.rs

Lines changed: 90 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,6 @@ use std::fmt;
111111
use std::io::{self, ErrorKind, BufWriter, Write};
112112
use std::net::{SocketAddr, ToSocketAddrs};
113113
use std::thread::{self, JoinHandle};
114-
115-
#[cfg(feature = "timeouts")]
116114
use std::time::Duration;
117115

118116
use num_cpus;
@@ -146,20 +144,16 @@ mod listener;
146144
#[derive(Debug)]
147145
pub struct Server<L = HttpListener> {
148146
listener: L,
149-
_timeouts: Timeouts,
147+
timeouts: Timeouts,
150148
}
151149

152-
#[cfg(feature = "timeouts")]
153150
#[derive(Clone, Copy, Default, Debug)]
154151
struct Timeouts {
155152
read: Option<Duration>,
156153
write: Option<Duration>,
154+
keep_alive: Option<Duration>,
157155
}
158156

159-
#[cfg(not(feature = "timeouts"))]
160-
#[derive(Clone, Copy, Default, Debug)]
161-
struct Timeouts;
162-
163157
macro_rules! try_option(
164158
($e:expr) => {{
165159
match $e {
@@ -175,18 +169,30 @@ impl<L: NetworkListener> Server<L> {
175169
pub fn new(listener: L) -> Server<L> {
176170
Server {
177171
listener: listener,
178-
_timeouts: Timeouts::default(),
172+
timeouts: Timeouts::default(),
179173
}
180174
}
181175

176+
/// Enables keep-alive for this server.
177+
///
178+
/// The timeout duration passed will be used to determine how long
179+
/// to keep the connection alive before dropping it.
180+
///
181+
/// **NOTE**: The timeout will only be used when the `timeouts` feature
182+
/// is enabled for hyper, and rustc is 1.4 or greater.
183+
#[inline]
184+
pub fn keep_alive(&mut self, timeout: Duration) {
185+
self.timeouts.keep_alive = Some(timeout);
186+
}
187+
182188
#[cfg(feature = "timeouts")]
183189
pub fn set_read_timeout(&mut self, dur: Option<Duration>) {
184-
self._timeouts.read = dur;
190+
self.timeouts.read = dur;
185191
}
186192

187193
#[cfg(feature = "timeouts")]
188194
pub fn set_write_timeout(&mut self, dur: Option<Duration>) {
189-
self._timeouts.write = dur;
195+
self.timeouts.write = dur;
190196
}
191197

192198

@@ -228,7 +234,7 @@ L: NetworkListener + Send + 'static {
228234

229235
debug!("threads = {:?}", threads);
230236
let pool = ListenerPool::new(server.listener);
231-
let worker = Worker::new(handler, server._timeouts);
237+
let worker = Worker::new(handler, server.timeouts);
232238
let work = move |mut stream| worker.handle_connection(&mut stream);
233239

234240
let guard = thread::spawn(move || pool.accept(work, threads));
@@ -241,15 +247,15 @@ L: NetworkListener + Send + 'static {
241247

242248
struct Worker<H: Handler + 'static> {
243249
handler: H,
244-
_timeouts: Timeouts,
250+
timeouts: Timeouts,
245251
}
246252

247253
impl<H: Handler + 'static> Worker<H> {
248254

249255
fn new(handler: H, timeouts: Timeouts) -> Worker<H> {
250256
Worker {
251257
handler: handler,
252-
_timeouts: timeouts,
258+
timeouts: timeouts,
253259
}
254260
}
255261

@@ -258,7 +264,7 @@ impl<H: Handler + 'static> Worker<H> {
258264

259265
self.handler.on_connection_start();
260266

261-
if let Err(e) = self.set_timeouts(stream) {
267+
if let Err(e) = self.set_timeouts(&(stream as &mut NetworkStream)) {
262268
error!("set_timeouts error: {:?}", e);
263269
return;
264270
}
@@ -273,73 +279,97 @@ impl<H: Handler + 'static> Worker<H> {
273279

274280
// FIXME: Use Type ascription
275281
let stream_clone: &mut NetworkStream = &mut stream.clone();
276-
let rdr = BufReader::new(stream_clone);
277-
let wrt = BufWriter::new(stream);
282+
let mut rdr = BufReader::new(stream_clone);
283+
let mut wrt = BufWriter::new(stream);
278284

279-
self.keep_alive_loop(rdr, wrt, addr);
285+
while self.keep_alive_loop(&mut rdr, &mut wrt, addr) {
286+
if let Err(e) = self.set_read_timeout(rdr.get_mut(), self.timeouts.keep_alive) {
287+
error!("set_read_timeout keep_alive {:?}", e);
288+
break;
289+
}
290+
}
280291

281292
self.handler.on_connection_end();
282293

283294
debug!("keep_alive loop ending for {}", addr);
284295
}
285296

297+
fn set_timeouts(&self, s: & &mut NetworkStream) -> io::Result<()> {
298+
try!(self.set_read_timeout(s, self.timeouts.read));
299+
self.set_write_timeout(s, self.timeouts.write)
300+
}
301+
302+
286303
#[cfg(not(feature = "timeouts"))]
287-
fn set_timeouts<S>(&self, _: &mut S) -> io::Result<()> where S: NetworkStream {
304+
fn set_write_timeout(&self, _s: & &mut NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
288305
Ok(())
289306
}
290307

291308
#[cfg(feature = "timeouts")]
292-
fn set_timeouts<S>(&self, s: &mut S) -> io::Result<()> where S: NetworkStream {
293-
try!(s.set_read_timeout(self._timeouts.read));
294-
s.set_write_timeout(self._timeouts.write)
309+
fn set_write_timeout(&self, s: & &mut NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
310+
s.set_write_timeout(timeout)
295311
}
296312

297-
fn keep_alive_loop<W: Write>(&self, mut rdr: BufReader<&mut NetworkStream>,
298-
mut wrt: W, addr: SocketAddr) {
299-
let mut keep_alive = true;
300-
while keep_alive {
301-
let req = match Request::new(&mut rdr, addr) {
302-
Ok(req) => req,
303-
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
304-
trace!("tcp closed, cancelling keep-alive loop");
305-
break;
306-
}
307-
Err(Error::Io(e)) => {
308-
debug!("ioerror in keepalive loop = {:?}", e);
309-
break;
310-
}
311-
Err(e) => {
312-
//TODO: send a 400 response
313-
error!("request error = {:?}", e);
314-
break;
315-
}
316-
};
313+
#[cfg(not(feature = "timeouts"))]
314+
fn set_read_timeout(&self, _s: & &mut NetworkStream, _timeout: Option<Duration>) -> io::Result<()> {
315+
Ok(())
316+
}
317317

318+
#[cfg(feature = "timeouts")]
319+
fn set_read_timeout(&self, s: & &mut NetworkStream, timeout: Option<Duration>) -> io::Result<()> {
320+
s.set_read_timeout(timeout)
321+
}
318322

319-
if !self.handle_expect(&req, &mut wrt) {
320-
break;
323+
fn keep_alive_loop<W: Write>(&self, mut rdr: &mut BufReader<&mut NetworkStream>,
324+
wrt: &mut W, addr: SocketAddr) -> bool {
325+
let req = match Request::new(rdr, addr) {
326+
Ok(req) => req,
327+
Err(Error::Io(ref e)) if e.kind() == ErrorKind::ConnectionAborted => {
328+
trace!("tcp closed, cancelling keep-alive loop");
329+
return false;
321330
}
322-
323-
keep_alive = http::should_keep_alive(req.version, &req.headers);
324-
let version = req.version;
325-
let mut res_headers = Headers::new();
326-
if !keep_alive {
327-
res_headers.set(Connection::close());
331+
Err(Error::Io(e)) => {
332+
debug!("ioerror in keepalive loop = {:?}", e);
333+
return false;
328334
}
329-
{
330-
let mut res = Response::new(&mut wrt, &mut res_headers);
331-
res.version = version;
332-
self.handler.handle(req, res);
335+
Err(e) => {
336+
//TODO: send a 400 response
337+
error!("request error = {:?}", e);
338+
return false;
333339
}
340+
};
334341

335-
// if the request was keep-alive, we need to check that the server agrees
336-
// if it wasn't, then the server cannot force it to be true anyways
337-
if keep_alive {
338-
keep_alive = http::should_keep_alive(version, &res_headers);
339-
}
340342

341-
debug!("keep_alive = {:?} for {}", keep_alive, addr);
343+
if !self.handle_expect(&req, wrt) {
344+
return false;
345+
}
346+
347+
if let Err(e) = req.set_read_timeout(self.timeouts.read) {
348+
error!("set_read_timeout {:?}", e);
349+
return false;
350+
}
351+
352+
let mut keep_alive = self.timeouts.keep_alive.is_some() &&
353+
http::should_keep_alive(req.version, &req.headers);
354+
let version = req.version;
355+
let mut res_headers = Headers::new();
356+
if !keep_alive {
357+
res_headers.set(Connection::close());
342358
}
359+
{
360+
let mut res = Response::new(wrt, &mut res_headers);
361+
res.version = version;
362+
self.handler.handle(req, res);
363+
}
364+
365+
// if the request was keep-alive, we need to check that the server agrees
366+
// if it wasn't, then the server cannot force it to be true anyways
367+
if keep_alive {
368+
keep_alive = http::should_keep_alive(version, &res_headers);
369+
}
370+
371+
debug!("keep_alive = {:?} for {}", keep_alive, addr);
372+
keep_alive
343373
}
344374

345375
fn handle_expect<W: Write>(&self, req: &Request, wrt: &mut W) -> bool {

src/server/request.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
//! target URI, headers, and message body.
55
use std::io::{self, Read};
66
use std::net::SocketAddr;
7+
use std::time::Duration;
78

89
use buffer::BufReader;
910
use net::NetworkStream;
@@ -64,6 +65,19 @@ impl<'a, 'b: 'a> Request<'a, 'b> {
6465
})
6566
}
6667

68+
/// Set the read timeout of the underlying NetworkStream.
69+
#[cfg(feature = "timeouts")]
70+
#[inline]
71+
pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
72+
self.body.get_ref().get_ref().set_read_timeout(timeout)
73+
}
74+
75+
/// Set the read timeout of the underlying NetworkStream.
76+
#[cfg(not(feature = "timeouts"))]
77+
#[inline]
78+
pub fn set_read_timeout(&self, _timeout: Option<Duration>) -> io::Result<()> {
79+
Ok(())
80+
}
6781
/// Get a reference to the underlying `NetworkStream`.
6882
#[inline]
6983
pub fn downcast_ref<T: NetworkStream>(&self) -> Option<&T> {

0 commit comments

Comments
 (0)