Skip to content

Commit 2787031

Browse files
Replace InputBuffer with a faster alternative
We're also deprecating the usage of `input_buffer` crate, see: snapview/input_buffer#6 (comment)
1 parent 8c3172c commit 2787031

File tree

8 files changed

+135
-41
lines changed

8 files changed

+135
-41
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ byteorder = "1.3.2"
2727
bytes = "1.0"
2828
http = "0.2"
2929
httparse = "1.3.4"
30-
input_buffer = "0.4.0"
3130
log = "0.4.8"
3231
rand = "0.8.0"
3332
sha-1 = "0.9"
@@ -53,5 +52,6 @@ optional = true
5352
version = "0.5.0"
5453

5554
[dev-dependencies]
55+
input_buffer = "0.5.0"
5656
env_logger = "0.8.1"
5757
net2 = "0.2.33"

src/buffer.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
//! A buffer for reading data from the network.
2+
//!
3+
//! The `ReadBuffer` is a buffer of bytes similar to a first-in, first-out queue.
4+
//! It is filled by reading from a stream supporting `Read` and is then
5+
//! accessible as a cursor for reading bytes.
6+
7+
use std::io::{Cursor, Read, Result as IoResult};
8+
9+
use bytes::Buf;
10+
11+
/// A FIFO buffer for reading packets from the network.
12+
#[derive(Debug)]
13+
pub struct ReadBuffer<const CHUNK_SIZE: usize> {
14+
storage: Cursor<Vec<u8>>,
15+
chunk: [u8; CHUNK_SIZE],
16+
}
17+
18+
impl<const CHUNK_SIZE: usize> ReadBuffer<CHUNK_SIZE> {
19+
/// Create a new empty input buffer.
20+
pub fn new() -> Self {
21+
Self::with_capacity(CHUNK_SIZE)
22+
}
23+
24+
/// Create a new empty input buffer with a given `capacity`.
25+
pub fn with_capacity(capacity: usize) -> Self {
26+
Self::from_partially_read(Vec::with_capacity(capacity))
27+
}
28+
29+
/// Create a input buffer filled with previously read data.
30+
pub fn from_partially_read(part: Vec<u8>) -> Self {
31+
Self { storage: Cursor::new(part), chunk: [0; CHUNK_SIZE] }
32+
}
33+
34+
/// Get a cursor to the data storage.
35+
pub fn as_cursor(&self) -> &Cursor<Vec<u8>> {
36+
&self.storage
37+
}
38+
39+
/// Get a cursor to the mutable data storage.
40+
pub fn as_cursor_mut(&mut self) -> &mut Cursor<Vec<u8>> {
41+
&mut self.storage
42+
}
43+
44+
/// Consume the `ReadBuffer` and get the internal storage.
45+
pub fn into_vec(mut self) -> Vec<u8> {
46+
// Current implementation of `tungstenite-rs` expects that the `into_vec()` drains
47+
// the data from the container that has already been read by the cursor.
48+
let pos = self.storage.position() as usize;
49+
self.storage.get_mut().drain(0..pos).count();
50+
self.storage.set_position(0);
51+
52+
// Now we can safely return the internal container.
53+
self.storage.into_inner()
54+
}
55+
56+
/// Read next portion of data from the given input stream.
57+
pub fn read_from<S: Read>(&mut self, stream: &mut S) -> IoResult<usize> {
58+
let size = stream.read(&mut self.chunk)?;
59+
self.storage.get_mut().extend_from_slice(&self.chunk[..size]);
60+
Ok(size)
61+
}
62+
}
63+
64+
impl<const CHUNK_SIZE: usize> Buf for ReadBuffer<CHUNK_SIZE> {
65+
fn remaining(&self) -> usize {
66+
Buf::remaining(self.as_cursor())
67+
}
68+
69+
fn chunk(&self) -> &[u8] {
70+
Buf::chunk(self.as_cursor())
71+
}
72+
73+
fn advance(&mut self, cnt: usize) {
74+
Buf::advance(self.as_cursor_mut(), cnt)
75+
}
76+
}
77+
78+
#[cfg(test)]
79+
mod tests {
80+
use super::*;
81+
82+
#[test]
83+
fn simple_reading() {
84+
let mut input = Cursor::new(b"Hello World!".to_vec());
85+
let mut buffer = ReadBuffer::<4096>::new();
86+
let size = buffer.read_from(&mut input).unwrap();
87+
assert_eq!(size, 12);
88+
assert_eq!(buffer.chunk(), b"Hello World!");
89+
}
90+
91+
#[test]
92+
fn reading_in_chunks() {
93+
let mut inp = Cursor::new(b"Hello World!".to_vec());
94+
let mut buf = ReadBuffer::<4>::new();
95+
96+
let size = buf.read_from(&mut inp).unwrap();
97+
assert_eq!(size, 4);
98+
assert_eq!(buf.chunk(), b"Hell");
99+
100+
buf.advance(2);
101+
assert_eq!(buf.chunk(), b"ll");
102+
103+
let size = buf.read_from(&mut inp).unwrap();
104+
assert_eq!(size, 4);
105+
assert_eq!(buf.chunk(), b"llo Wo");
106+
107+
let size = buf.read_from(&mut inp).unwrap();
108+
assert_eq!(size, 4);
109+
assert_eq!(buf.chunk(), b"llo World!");
110+
}
111+
}

src/client.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ mod encryption {
7272
Mode::Tls => {
7373
let config = {
7474
let mut config = ClientConfig::new();
75-
config.root_store = rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?;
75+
config.root_store =
76+
rustls_native_certs::load_native_certs().map_err(|(_, err)| err)?;
7677

7778
Arc::new(config)
7879
};

src/error.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,6 @@ pub enum CapacityError {
127127
#[error("Too many headers")]
128128
TooManyHeaders,
129129
/// Received header is too long.
130-
#[error("Header too long")]
131-
HeaderTooLong,
132130
/// Message is bigger than the maximum allowed size.
133131
#[error("Message too long: {size} > {max_size}")]
134132
MessageTooLong {
@@ -137,9 +135,6 @@ pub enum CapacityError {
137135
/// The maximum allowed message size.
138136
max_size: usize,
139137
},
140-
/// TCP buffer is full.
141-
#[error("Incoming TCP buffer is full")]
142-
TcpBufferFull,
143138
}
144139

145140
/// Indicates the specific type/cause of a protocol error.

src/handshake/machine.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ use log::*;
33
use std::io::{Cursor, Read, Write};
44

55
use crate::{
6-
error::{CapacityError, Error, ProtocolError, Result},
6+
error::{Error, ProtocolError, Result},
77
util::NonBlockingResult,
8+
ReadBuffer,
89
};
9-
use input_buffer::{InputBuffer, MIN_READ};
1010

1111
/// A generic handshake state machine.
1212
#[derive(Debug)]
@@ -18,10 +18,7 @@ pub struct HandshakeMachine<Stream> {
1818
impl<Stream> HandshakeMachine<Stream> {
1919
/// Start reading data from the peer.
2020
pub fn start_read(stream: Stream) -> Self {
21-
HandshakeMachine {
22-
stream,
23-
state: HandshakeState::Reading(InputBuffer::with_capacity(MIN_READ)),
24-
}
21+
HandshakeMachine { stream, state: HandshakeState::Reading(ReadBuffer::new()) }
2522
}
2623
/// Start writing data to the peer.
2724
pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
@@ -43,12 +40,7 @@ impl<Stream: Read + Write> HandshakeMachine<Stream> {
4340
trace!("Doing handshake round.");
4441
match self.state {
4542
HandshakeState::Reading(mut buf) => {
46-
let read = buf
47-
.prepare_reserve(MIN_READ)
48-
.with_limit(usize::max_value()) // TODO limit size
49-
.map_err(|_| Error::Capacity(CapacityError::HeaderTooLong))?
50-
.read_from(&mut self.stream)
51-
.no_block()?;
43+
let read = buf.read_from(&mut self.stream).no_block()?;
5244
match read {
5345
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
5446
Some(_) => Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
@@ -124,7 +116,7 @@ pub trait TryParse: Sized {
124116
#[derive(Debug)]
125117
enum HandshakeState {
126118
/// Reading data from the peer.
127-
Reading(InputBuffer),
119+
Reading(ReadBuffer),
128120
/// Sending data to the peer.
129121
Writing(Cursor<Vec<u8>>),
130122
}

src/handshake/mod.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,6 @@ mod tests {
131131
#[test]
132132
fn key_conversion() {
133133
// example from RFC 6455
134-
assert_eq!(
135-
derive_accept_key(b"dGhlIHNhbXBsZSBub25jZQ=="),
136-
"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
137-
);
134+
assert_eq!(derive_accept_key(b"dGhlIHNhbXBsZSBub25jZQ=="), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
138135
}
139136
}

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
pub use http;
1616

17+
mod buffer;
1718
pub mod client;
1819
pub mod error;
1920
pub mod handshake;
@@ -22,6 +23,9 @@ pub mod server;
2223
pub mod stream;
2324
pub mod util;
2425

26+
const READ_BUFFER_CHUNK_SIZE: usize = 4096;
27+
type ReadBuffer = buffer::ReadBuffer<READ_BUFFER_CHUNK_SIZE>;
28+
2529
pub use crate::{
2630
client::{client, connect},
2731
error::{Error, Result},

src/protocol/frame/mod.rs

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ pub mod coding;
66
mod frame;
77
mod mask;
88

9-
pub use self::frame::{CloseFrame, Frame, FrameHeader};
9+
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
1010

11-
use crate::error::{CapacityError, Error, Result};
12-
use input_buffer::{InputBuffer, MIN_READ};
1311
use log::*;
14-
use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
12+
13+
pub use self::frame::{CloseFrame, Frame, FrameHeader};
14+
use crate::{
15+
error::{CapacityError, Error, Result},
16+
ReadBuffer,
17+
};
1518

1619
/// A reader and writer for WebSocket frames.
1720
#[derive(Debug)]
@@ -82,7 +85,7 @@ where
8285
#[derive(Debug)]
8386
pub(super) struct FrameCodec {
8487
/// Buffer to read data from the stream.
85-
in_buffer: InputBuffer,
88+
in_buffer: ReadBuffer,
8689
/// Buffer to send packets to the network.
8790
out_buffer: Vec<u8>,
8891
/// Header and remaining size of the incoming packet being processed.
@@ -92,17 +95,13 @@ pub(super) struct FrameCodec {
9295
impl FrameCodec {
9396
/// Create a new frame codec.
9497
pub(super) fn new() -> Self {
95-
Self {
96-
in_buffer: InputBuffer::with_capacity(MIN_READ),
97-
out_buffer: Vec::new(),
98-
header: None,
99-
}
98+
Self { in_buffer: ReadBuffer::new(), out_buffer: Vec::new(), header: None }
10099
}
101100

102101
/// Create a new frame codec from partially read data.
103102
pub(super) fn from_partially_read(part: Vec<u8>) -> Self {
104103
Self {
105-
in_buffer: InputBuffer::from_partially_read(part),
104+
in_buffer: ReadBuffer::from_partially_read(part),
106105
out_buffer: Vec::new(),
107106
header: None,
108107
}
@@ -152,12 +151,7 @@ impl FrameCodec {
152151
}
153152

154153
// Not enough data in buffer.
155-
let size = self
156-
.in_buffer
157-
.prepare_reserve(MIN_READ)
158-
.with_limit(usize::max_value())
159-
.map_err(|_| Error::Capacity(CapacityError::TcpBufferFull))?
160-
.read_from(stream)?;
154+
let size = self.in_buffer.read_from(stream)?;
161155
if size == 0 {
162156
trace!("no frame received");
163157
return Ok(None);

0 commit comments

Comments
 (0)