diff --git a/crates/core_simd/src/intrinsics.rs b/crates/core_simd/src/intrinsics.rs
index b27893bc729..5c688c11991 100644
--- a/crates/core_simd/src/intrinsics.rs
+++ b/crates/core_simd/src/intrinsics.rs
@@ -107,6 +107,15 @@ extern "platform-intrinsic" {
     /// like gather, but more spicy, as it writes instead of reads
     pub(crate) fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
 
+    /// like a loop of reads offset from the same pointer
+    /// val: vector of values to select if a lane is masked off
+    /// ptr: pointer to the start of the contiguous elements to read from
+    /// mask: a "wide" mask of integers, selects as if simd_select(mask, read(ptr), val)
+    /// note, the LLVM intrinsic accepts a mask vector of `<N x i1>`
+    pub(crate) fn simd_masked_load<T, U, V>(mask: V, ptr: U, val: T) -> T;
+    /// like masked_load, but more spicy, as it writes instead of reads
+    pub(crate) fn simd_masked_store<T, U, V>(mask: V, ptr: U, val: T);
+
     // {s,u}add.sat
     pub(crate) fn simd_saturating_add<T>(x: T, y: T) -> T;
diff --git a/crates/core_simd/src/lib.rs b/crates/core_simd/src/lib.rs
index 64ba9705ef5..e974e7aa25a 100644
--- a/crates/core_simd/src/lib.rs
+++ b/crates/core_simd/src/lib.rs
@@ -4,6 +4,7 @@
     const_maybe_uninit_as_mut_ptr,
     const_mut_refs,
     convert_float_to_int,
+    core_intrinsics,
     decl_macro,
     inline_const,
     intra_doc_pointers,
diff --git a/crates/core_simd/src/vector.rs b/crates/core_simd/src/vector.rs
index 18a0bb0a77e..46c6bbc88b2 100644
--- a/crates/core_simd/src/vector.rs
+++ b/crates/core_simd/src/vector.rs
@@ -1,6 +1,7 @@
 use crate::simd::{
     cmp::SimdPartialOrd,
     intrinsics,
+    ToBitMask, ToBitMaskArray,
     ptr::{SimdConstPtr, SimdMutPtr},
     LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
 };
@@ -311,6 +312,104 @@ where
         unsafe { self.store(slice.as_mut_ptr().cast()) }
     }
 
+    /// Loads a vector from the start of `slice`, filling lanes past the end of the slice with
+    /// `T::default()`.
+    #[must_use]
+    #[inline]
+    pub fn load_or_default(slice: &[T]) -> Self
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        T: Default,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
+    {
+        Self::load_or(slice, Default::default())
+    }
+
+    /// Loads a vector from the start of `slice`, filling lanes past the end of the slice with the
+    /// corresponding lane of `or`.
+    #[must_use]
+    #[inline]
+    pub fn load_or(slice: &[T], or: Self) -> Self
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
+    {
+        Self::load_select(slice, Mask::splat(true), or)
+    }
+
+    /// Loads the enabled lanes from `slice`; disabled and out-of-bounds lanes become `T::default()`.
+    #[must_use]
+    #[inline]
+    pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        T: Default,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
+    {
+        Self::load_select(slice, enable, Default::default())
+    }
+
+    /// Loads the enabled lanes from `slice`; disabled and out-of-bounds lanes are taken from the
+    /// corresponding lane of `or`.
+    #[must_use]
+    #[inline]
+    pub fn load_select(slice: &[T], mut enable: Mask<<T as SimdElement>::Mask, N>, or: Self) -> Self
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
+    {
+        if USE_BRANCH {
+            if core::intrinsics::likely(enable.all() && slice.len() > N) {
+                return Self::from_slice(slice);
+            }
+        }
+        enable &= if USE_BITMASK {
+            let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
+            let mask_bytes: [u8; 8] =
+                unsafe { core::mem::transmute(mask) };
+            let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
+            let len = in_bounds_arr.as_ref().len();
+            in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
+            Mask::from_bitmask_array(in_bounds_arr)
+        } else {
+            mask_up_to(enable, slice.len())
+        };
+        unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
+    }
+
+    /// Loads the enabled lanes from `slice` without bounds checking.
+    /// Safety: every enabled lane must be within the bounds of `slice`.
+    #[must_use]
+    #[inline]
+    pub unsafe fn load_select_unchecked(
+        slice: &[T],
+        enable: Mask<<T as SimdElement>::Mask, N>,
+        or: Self,
+    ) -> Self {
+        let ptr = slice.as_ptr();
+        unsafe { Self::load_select_ptr(ptr, enable, or) }
+    }
+
+    /// Loads the enabled lanes from consecutive elements starting at `ptr`.
+    /// Safety: every enabled lane must point to a valid, initialized value.
+    #[must_use]
+    #[inline]
+    pub unsafe fn load_select_ptr(
+        ptr: *const T,
+        enable: Mask<<T as SimdElement>::Mask, N>,
+        or: Self,
+    ) -> Self {
+        unsafe { intrinsics::simd_masked_load(enable.to_int(), ptr, or) }
+    }
+
     /// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
     /// If an index is out-of-bounds, the element is instead selected from the `or` vector.
     ///
@@ -489,6 +588,51 @@ where
         unsafe { intrinsics::simd_gather(or, source, enable.to_int()) }
     }
 
+    /// Stores the enabled lanes of `self` into `slice`; disabled and out-of-bounds lanes are left
+    /// untouched.
+    #[inline]
+    pub fn masked_store(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>)
+    where
+        Mask<<T as SimdElement>::Mask, N>: ToBitMask + ToBitMaskArray,
+        <T as SimdElement>::Mask: Default
+            + core::convert::From<i8>
+            + core::ops::Add<<T as SimdElement>::Mask, Output = <T as SimdElement>::Mask>,
+        Simd<<T as SimdElement>::Mask, N>: SimdPartialOrd,
+        Mask<<T as SimdElement>::Mask, N>: core::ops::BitAnd<Output = Mask<<T as SimdElement>::Mask, N>>
+            + core::convert::From<<Simd<<T as SimdElement>::Mask, N> as SimdPartialEq>::Mask>,
+    {
+        if USE_BRANCH {
+            if core::intrinsics::likely(enable.all() && slice.len() > N) {
+                return self.copy_to_slice(slice);
+            }
+        }
+        enable &= if USE_BITMASK {
+            let mask = bzhi_u64(u64::MAX, core::cmp::min(N, slice.len()) as u32);
+            let mask_bytes: [u8; 8] = unsafe { core::mem::transmute(mask) };
+            let mut in_bounds_arr = Mask::splat(true).to_bitmask_array();
+            let len = in_bounds_arr.as_ref().len();
+            in_bounds_arr.as_mut().copy_from_slice(&mask_bytes[..len]);
+            Mask::from_bitmask_array(in_bounds_arr)
+        } else {
+            mask_up_to(enable, slice.len())
+        };
+        unsafe { self.masked_store_ptr(slice.as_mut_ptr(), enable) }
+    }
+
+    /// Stores the enabled lanes of `self` into `slice` without bounds checking.
+    /// Safety: every enabled lane must be within the bounds of `slice`.
+    #[inline]
+    pub unsafe fn masked_store_unchecked(
+        self,
+        slice: &mut [T],
+        enable: Mask<<T as SimdElement>::Mask, N>,
+    ) {
+        let ptr = slice.as_mut_ptr();
+        unsafe { self.masked_store_ptr(ptr, enable) }
+    }
+
+    /// Stores the enabled lanes of `self` to consecutive elements starting at `ptr`.
+    /// Safety: every enabled lane must point to memory valid for writes.
+    #[inline]
+    pub unsafe fn masked_store_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
+        unsafe { intrinsics::simd_masked_store(enable.to_int(), ptr, self) }
+    }
+
     /// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
     /// If an index is out-of-bounds, the write is suppressed without panicking.
     /// If two elements in the scattered vector would write to the same index
@@ -974,3 +1118,44 @@ where
 {
     type Mask = isize;
 }
+
+const USE_BRANCH: bool = true;
+const USE_BITMASK: bool = false;
+
+/// Returns the vector `[0, 1, 2, ...]`, used to compare lane indices against a slice length.
+#[inline]
+fn index<T, const N: usize>() -> Simd<T, N>
+where
+    T: MaskElement + Default + core::convert::From<i8> + core::ops::Add<T, Output = T>,
+    LaneCount<N>: SupportedLaneCount,
+{
+    let mut index = [T::default(); N];
+    for i in 1..N {
+        index[i] = index[i - 1] + T::from(1);
+    }
+    Simd::from_array(index)
+}
+
+/// Disables every lane of `enable` whose index is not less than `len`.
+#[inline]
+fn mask_up_to<M, const N: usize>(enable: Mask<M, N>, len: usize) -> Mask<M, N>
+where
+    LaneCount<N>: SupportedLaneCount,
+    M: MaskElement + Default + core::convert::From<i8> + core::ops::Add<M, Output = M>,
+    Simd<M, N>: SimdPartialOrd,
+    // <Simd<M, N> as SimdPartialEq>::Mask: Mask<M, N>,
+    Mask<M, N>: core::ops::BitAnd<Output = Mask<M, N>>
+        + core::convert::From<<Simd<M, N> as SimdPartialEq>::Mask>,
+{
+    let index = index::<M, N>();
+    enable & Mask::<M, N>::from(index.simd_lt(Simd::splat(M::from(i8::try_from(len).unwrap_or(i8::MAX)))))
+}
+
+// This function matches the semantics of the `bzhi` instruction on x86 BMI2.
+// TODO: optimize it further if possible
+// https://stackoverflow.com/questions/75179720/how-to-get-rust-compiler-to-emit-bzhi-instruction-without-resorting-to-platform
+#[inline(always)]
+fn bzhi_u64(a: u64, ix: u32) -> u64 {
+    if ix > 63 {
+        a
+    } else {
+        a & ((1u64 << ix) - 1)
+    }
+}
diff --git a/crates/core_simd/tests/masked_load_store.rs b/crates/core_simd/tests/masked_load_store.rs
new file mode 100644
index 00000000000..e830330249c
--- /dev/null
+++ b/crates/core_simd/tests/masked_load_store.rs
@@ -0,0 +1,35 @@
+#![feature(portable_simd)]
+use core_simd::simd::prelude::*;
+
+#[cfg(target_arch = "wasm32")]
+use wasm_bindgen_test::*;
+
+#[cfg(target_arch = "wasm32")]
+wasm_bindgen_test_configure!(run_in_browser);
+
+#[test]
+#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
+fn masked_load_store() {
+    let mut arr = [u8::MAX; 7];
+
+    u8x4::splat(0).masked_store(&mut arr[5..], Mask::from_array([false, true, false, true]));
+    // write to index 8 is OOB and dropped
+    assert_eq!(arr, [255u8, 255, 255, 255, 255, 255, 0]);
+
+    u8x4::from_array([0, 1, 2, 3]).masked_store(&mut arr[1..], Mask::splat(true));
+    assert_eq!(arr, [255u8, 0, 1, 2, 3, 255, 0]);
+
+    // read from index 7 is OOB and dropped
+    assert_eq!(
+        u8x4::load_or(&arr[4..], u8x4::splat(42)),
+        u8x4::from_array([3, 255, 0, 42])
+    );
+    assert_eq!(
+        u8x4::load_select(
+            &arr[4..],
+            Mask::from_array([true, false, true, true]),
+            u8x4::splat(42)
+        ),
+        u8x4::from_array([3, 42, 0, 42])
+    );
+}
diff --git a/crates/core_simd/tests/pointers.rs b/crates/core_simd/tests/pointers.rs
index a90ff928ced..b9f32d16e01 100644
--- a/crates/core_simd/tests/pointers.rs
+++ b/crates/core_simd/tests/pointers.rs
@@ -1,4 +1,4 @@
-#![feature(portable_simd, strict_provenance)]
+#![feature(portable_simd, strict_provenance, exposed_provenance)]
 use core_simd::simd::{
     ptr::{SimdConstPtr, SimdMutPtr},
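
For context, a minimal usage sketch (not part of the patch) of the API added above: it processes a byte slice in whole vectors and handles the tail with `load_or` and `masked_store` instead of a scalar remainder loop. It assumes a nightly toolchain and the `core_simd` crate as used by this repository's tests; the function name `add_one` and the 16-lane width are arbitrary choices for illustration.

```rust
#![feature(portable_simd)]
use core_simd::simd::prelude::*;

// Increment every byte in place. Full 16-lane chunks use plain loads/stores;
// the final partial chunk uses the new masked load/store methods.
fn add_one(data: &mut [u8]) {
    const LANES: usize = 16;
    let mut chunks = data.chunks_exact_mut(LANES);
    for chunk in &mut chunks {
        (u8x16::from_slice(chunk) + u8x16::splat(1)).copy_to_slice(chunk);
    }
    // `load_or` fills out-of-bounds lanes with the `or` argument (0 here), and
    // `masked_store` suppresses the writes that would fall past the end.
    let tail = chunks.into_remainder();
    (u8x16::load_or(tail, u8x16::splat(0)) + u8x16::splat(1))
        .masked_store(tail, Mask::splat(true));
}

fn main() {
    let mut data = [0u8; 37];
    add_one(&mut data);
    assert!(data.iter().all(|&b| b == 1));
}
```

The same pattern works with `load_select` and a partial `enable` mask when some lanes should be skipped outright, as exercised by the new `masked_load_store` test.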