From c434acd83cb8ceed6f64d5688dbedfd9379795ed Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Mon, 3 Apr 2023 12:49:17 -0700 Subject: [PATCH] Avoid some extra bounds checks in `read_{u8,u16}` --- compiler/rustc_serialize/src/lib.rs | 1 + compiler/rustc_serialize/src/opaque.rs | 81 ++++++++++++++--------- compiler/rustc_serialize/src/serialize.rs | 5 ++ compiler/rustc_serialize/tests/opaque.rs | 8 +-- 4 files changed, 61 insertions(+), 34 deletions(-) diff --git a/compiler/rustc_serialize/src/lib.rs b/compiler/rustc_serialize/src/lib.rs index 1f8d2336c4e58..eb288200cf82d 100644 --- a/compiler/rustc_serialize/src/lib.rs +++ b/compiler/rustc_serialize/src/lib.rs @@ -11,6 +11,7 @@ Core encoding and decoding interfaces. )] #![feature(never_type)] #![feature(associated_type_bounds)] +#![feature(iter_advance_by)] #![feature(min_specialization)] #![feature(core_intrinsics)] #![feature(maybe_uninit_slice)] diff --git a/compiler/rustc_serialize/src/opaque.rs b/compiler/rustc_serialize/src/opaque.rs index 0e0ebc79eb2e3..ecd8e26379ae1 100644 --- a/compiler/rustc_serialize/src/opaque.rs +++ b/compiler/rustc_serialize/src/opaque.rs @@ -535,34 +535,55 @@ impl Encoder for FileEncoder { // ----------------------------------------------------------------------------- pub struct MemDecoder<'a> { + // Previously this type stored `position: usize`, but because it's staying + // safe code, that meant that reading `n` bytes meant a bounds check both + // for `position + n` *and* `position`, since there's nothing saying that + // the additions didn't wrap. Storing an iterator like this instead means + // there's no offsetting needed to get to the data, and the iterator instead + // of a slice means only increasing the start pointer on reads, rather than + // also needing to decrease the count in a slice. + // This field is first because it's touched more than `data`. + reader: std::slice::Iter<'a, u8>, pub data: &'a [u8], - position: usize, } impl<'a> MemDecoder<'a> { #[inline] pub fn new(data: &'a [u8], position: usize) -> MemDecoder<'a> { - MemDecoder { data, position } + let reader = data[position..].iter(); + MemDecoder { data, reader } } #[inline] pub fn position(&self) -> usize { - self.position + self.data.len() - self.reader.len() } #[inline] pub fn set_position(&mut self, pos: usize) { - self.position = pos + self.reader = self.data[pos..].iter(); } #[inline] pub fn advance(&mut self, bytes: usize) { - self.position += bytes; + self.reader.advance_by(bytes).unwrap(); + } + + #[cold] + fn panic_insufficient_data(&self) -> ! { + let pos = self.position(); + let len = self.data.len(); + panic!("Insufficient remaining data at position {pos} (length {len})"); } } macro_rules! read_leb128 { - ($dec:expr, $fun:ident) => {{ leb128::$fun($dec.data, &mut $dec.position) }}; + ($dec:expr, $fun:ident) => {{ + let mut position = 0_usize; + let val = leb128::$fun($dec.reader.as_slice(), &mut position); + let _ = $dec.reader.advance_by(position); + val + }}; } impl<'a> Decoder for MemDecoder<'a> { @@ -583,17 +604,14 @@ impl<'a> Decoder for MemDecoder<'a> { #[inline] fn read_u16(&mut self) -> u16 { - let bytes = [self.data[self.position], self.data[self.position + 1]]; - let value = u16::from_le_bytes(bytes); - self.position += 2; - value + let bytes = self.read_raw_bytes_array::<2>(); + u16::from_le_bytes(*bytes) } #[inline] fn read_u8(&mut self) -> u8 { - let value = self.data[self.position]; - self.position += 1; - value + let bytes = self.read_raw_bytes_array::<1>(); + u8::from_le_bytes(*bytes) } #[inline] @@ -618,17 +636,14 @@ impl<'a> Decoder for MemDecoder<'a> { #[inline] fn read_i16(&mut self) -> i16 { - let bytes = [self.data[self.position], self.data[self.position + 1]]; - let value = i16::from_le_bytes(bytes); - self.position += 2; - value + let bytes = self.read_raw_bytes_array::<2>(); + i16::from_le_bytes(*bytes) } #[inline] fn read_i8(&mut self) -> i8 { - let value = self.data[self.position]; - self.position += 1; - value as i8 + let bytes = self.read_raw_bytes_array::<1>(); + i8::from_le_bytes(*bytes) } #[inline] @@ -663,20 +678,26 @@ impl<'a> Decoder for MemDecoder<'a> { #[inline] fn read_str(&mut self) -> &'a str { let len = self.read_usize(); - let sentinel = self.data[self.position + len]; - assert!(sentinel == STR_SENTINEL); - let s = unsafe { - std::str::from_utf8_unchecked(&self.data[self.position..self.position + len]) - }; - self.position += len + 1; - s + + // This cannot reuse `read_raw_bytes` as that runs into lifetime issues + // where the slice gets tied to `'b` instead of just to `'a`. + if self.reader.len() <= len { + self.panic_insufficient_data(); + } + let slice = self.reader.as_slice(); + assert!(slice[len] == STR_SENTINEL); + self.reader.advance_by(len + 1).unwrap(); + unsafe { std::str::from_utf8_unchecked(&slice[..len]) } } #[inline] fn read_raw_bytes(&mut self, bytes: usize) -> &'a [u8] { - let start = self.position; - self.position += bytes; - &self.data[start..self.position] + if self.reader.len() < bytes { + self.panic_insufficient_data(); + } + let slice = self.reader.as_slice(); + self.reader.advance_by(bytes).unwrap(); + &slice[..bytes] } } diff --git a/compiler/rustc_serialize/src/serialize.rs b/compiler/rustc_serialize/src/serialize.rs index 567fe06109b78..75a46b8ef5f7f 100644 --- a/compiler/rustc_serialize/src/serialize.rs +++ b/compiler/rustc_serialize/src/serialize.rs @@ -78,6 +78,11 @@ pub trait Decoder { fn read_char(&mut self) -> char; fn read_str(&mut self) -> &str; fn read_raw_bytes(&mut self, len: usize) -> &[u8]; + + #[inline] + fn read_raw_bytes_array(&mut self) -> &[u8; N] { + self.read_raw_bytes(N).try_into().unwrap() + } } /// Trait for types that can be serialized diff --git a/compiler/rustc_serialize/tests/opaque.rs b/compiler/rustc_serialize/tests/opaque.rs index 3a695d0714ee1..032853ac640cf 100644 --- a/compiler/rustc_serialize/tests/opaque.rs +++ b/compiler/rustc_serialize/tests/opaque.rs @@ -55,7 +55,7 @@ fn test_unit() { #[test] fn test_u8() { let mut vec = vec![]; - for i in u8::MIN..u8::MAX { + for i in u8::MIN..=u8::MAX { vec.push(i); } check_round_trip(vec); @@ -63,7 +63,7 @@ fn test_u8() { #[test] fn test_u16() { - for i in u16::MIN..u16::MAX { + for i in u16::MIN..=u16::MAX { check_round_trip(vec![1, 2, 3, i, i, i]); } } @@ -86,7 +86,7 @@ fn test_usize() { #[test] fn test_i8() { let mut vec = vec![]; - for i in i8::MIN..i8::MAX { + for i in i8::MIN..=i8::MAX { vec.push(i); } check_round_trip(vec); @@ -94,7 +94,7 @@ fn test_i8() { #[test] fn test_i16() { - for i in i16::MIN..i16::MAX { + for i in i16::MIN..=i16::MAX { check_round_trip(vec![-1, 2, -3, i, i, i, 2]); } }