Skip to content

Commit 75943f7

Browse files
committed
Simplify bitmasks
1 parent 6b1e7f6 commit 75943f7

File tree

6 files changed

+183
-189
lines changed

6 files changed

+183
-189
lines changed

crates/core_simd/src/masks.rs

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
)]
1313
mod mask_impl;
1414

15-
mod to_bitmask;
16-
pub use to_bitmask::{ToBitMask, ToBitMaskArray};
17-
1815
use crate::simd::{
1916
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
2017
};
@@ -262,6 +259,45 @@ where
262259
pub fn all(self) -> bool {
263260
self.0.all()
264261
}
262+
263+
/// Create a bitmask from a mask.
264+
///
265+
/// Each bit is set if the corresponding element in the mask is `true`.
266+
/// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
267+
#[inline]
268+
#[must_use = "method returns a new integer and does not mutate the original value"]
269+
pub fn to_bitmask(self) -> u64 {
270+
self.0.to_bitmask_integer()
271+
}
272+
273+
/// Create a mask from a bitmask.
274+
///
275+
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
276+
/// If the mask contains more than 64 elements, the remainder are set to `false`.
277+
#[inline]
278+
#[must_use = "method returns a new mask and does not mutate the original value"]
279+
pub fn from_bitmask(bitmask: u64) -> Self {
280+
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
281+
}
282+
283+
/// Create a bitmask vector from a mask.
284+
///
285+
/// Each bit is set if the corresponding element in the mask is `true`.
286+
/// The remaining bits are unset.
287+
#[inline]
288+
#[must_use = "method returns a new integer and does not mutate the original value"]
289+
pub fn to_bitmask_vector(self) -> Simd<T, N> {
290+
self.0.to_bitmask_vector()
291+
}
292+
293+
/// Create a mask from a bitmask vector.
294+
///
295+
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
296+
#[inline]
297+
#[must_use = "method returns a new mask and does not mutate the original value"]
298+
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
299+
Self(mask_impl::Mask::from_bitmask_vector(bitmask))
300+
}
265301
}
266302

267303
// vector/array conversion

crates/core_simd/src/masks/bitmask.rs

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![allow(unused_imports)]
22
use super::MaskElement;
33
use crate::simd::intrinsics;
4-
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
4+
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
55
use core::marker::PhantomData;
66

77
/// A mask where each lane is represented by a single bit.
@@ -120,39 +120,64 @@ where
120120
}
121121

122122
#[inline]
123-
#[must_use = "method returns a new array and does not mutate the original value"]
124-
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M] {
125-
assert!(core::mem::size_of::<Self>() == M);
123+
#[must_use = "method returns a new vector and does not mutate the original value"]
124+
pub fn to_bitmask_vector(self) -> Simd<T, N> {
125+
let mut bitmask = Self::splat(false).to_int();
126+
127+
assert!(
128+
core::mem::size_of::<Simd<T, N>>()
129+
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
130+
);
126131

127-
// Safety: converting an integer to an array of bytes of the same size is safe
128-
unsafe { core::mem::transmute_copy(&self.0) }
132+
// Safety: the bitmask vector is big enough to hold the bitmask
133+
unsafe {
134+
core::ptr::copy_nonoverlapping(
135+
self.0.as_ref().as_ptr(),
136+
bitmask.as_mut_array().as_mut_ptr() as _,
137+
self.0.as_ref().len(),
138+
);
139+
}
140+
141+
bitmask
129142
}
130143

131144
#[inline]
132145
#[must_use = "method returns a new mask and does not mutate the original value"]
133-
pub fn from_bitmask_array<const M: usize>(bitmask: [u8; M]) -> Self {
134-
assert!(core::mem::size_of::<Self>() == M);
146+
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
147+
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
148+
149+
assert!(
150+
core::mem::size_of::<Simd<T, N>>()
151+
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
152+
);
135153

136-
// Safety: converting an array of bytes to an integer of the same size is safe
137-
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
154+
// Safety: the bitmask vector is big enough to hold the bitmask
155+
unsafe {
156+
core::ptr::copy_nonoverlapping(
157+
bitmask.as_array().as_ptr() as _,
158+
bytes.as_mut().as_mut_ptr(),
159+
bytes.as_ref().len(),
160+
);
161+
}
162+
163+
Self(bytes, PhantomData)
138164
}
139165

140166
#[inline]
141-
pub fn to_bitmask_integer<U>(self) -> U
142-
where
143-
super::Mask<T, N>: ToBitMask<BitMask = U>,
144-
{
145-
// Safety: these are the same types
146-
unsafe { core::mem::transmute_copy(&self.0) }
167+
pub fn to_bitmask_integer(self) -> u64 {
168+
let mut bitmask = [0u8; 8];
169+
bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
170+
u64::from_ne_bytes(bitmask)
147171
}
148172

149173
#[inline]
150-
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
151-
where
152-
super::Mask<T, N>: ToBitMask<BitMask = U>,
153-
{
154-
// Safety: these are the same types
155-
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
174+
pub fn from_bitmask_integer(bitmask: u64) -> Self {
175+
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
176+
let len = bytes.as_mut().len();
177+
bytes
178+
.as_mut()
179+
.copy_from_slice(&bitmask.to_ne_bytes()[..len]);
180+
Self(bytes, PhantomData)
156181
}
157182

158183
#[inline]

crates/core_simd/src/masks/full_masks.rs

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
//! Masks that take up full SIMD vector registers.
22
3-
use super::{to_bitmask::ToBitMaskArray, MaskElement};
43
use crate::simd::intrinsics;
5-
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
4+
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};
65

76
#[repr(transparent)]
87
pub struct Mask<T, const N: usize>(Simd<T, N>)
@@ -143,95 +142,105 @@ where
143142
}
144143

145144
#[inline]
146-
#[must_use = "method returns a new array and does not mutate the original value"]
147-
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
148-
where
149-
super::Mask<T, N>: ToBitMaskArray,
150-
{
145+
#[must_use = "method returns a new vector and does not mutate the original value"]
146+
pub fn to_bitmask_vector(self) -> Simd<T, N> {
147+
let mut bitmask = Self::splat(false).to_int();
148+
151149
// Safety: Bytes is the right size array
152150
unsafe {
153151
// Compute the bitmask
154-
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
152+
let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
155153
intrinsics::simd_bitmask(self.0);
156154

157-
// Transmute to the return type
158-
let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);
159-
160155
// LLVM assumes bit order should match endianness
161156
if cfg!(target_endian = "big") {
162-
for x in bitmask.as_mut() {
163-
*x = x.reverse_bits();
157+
for x in bytes.as_mut() {
158+
*x = x.reverse_bits()
164159
}
165-
};
160+
}
166161

167-
bitmask
162+
assert!(
163+
core::mem::size_of::<Simd<T, N>>()
164+
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
165+
);
166+
core::ptr::copy_nonoverlapping(
167+
bytes.as_ref().as_ptr(),
168+
bitmask.as_mut_array().as_mut_ptr() as _,
169+
bytes.as_ref().len(),
170+
);
168171
}
172+
173+
bitmask
169174
}
170175

171176
#[inline]
172177
#[must_use = "method returns a new mask and does not mutate the original value"]
173-
pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
174-
where
175-
super::Mask<T, N>: ToBitMaskArray,
176-
{
178+
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
179+
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
180+
177181
// Safety: Bytes is the right size array
178182
unsafe {
183+
assert!(
184+
core::mem::size_of::<Simd<T, N>>()
185+
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
186+
);
187+
core::ptr::copy_nonoverlapping(
188+
bitmask.as_array().as_ptr() as _,
189+
bytes.as_mut().as_mut_ptr(),
190+
bytes.as_mut().len(),
191+
);
192+
179193
// LLVM assumes bit order should match endianness
180194
if cfg!(target_endian = "big") {
181-
for x in bitmask.as_mut() {
195+
for x in bytes.as_mut() {
182196
*x = x.reverse_bits();
183197
}
184198
}
185199

186-
// Transmute to the bitmask
187-
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
188-
core::mem::transmute_copy(&bitmask);
189-
190200
// Compute the regular mask
191201
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
192-
bitmask,
202+
bytes,
193203
Self::splat(true).to_int(),
194204
Self::splat(false).to_int(),
195205
))
196206
}
197207
}
198208

199209
#[inline]
200-
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
201-
where
202-
super::Mask<T, N>: ToBitMask<BitMask = U>,
203-
{
204-
// Safety: U is required to be the appropriate bitmask type
205-
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
210+
pub(crate) fn to_bitmask_integer(self) -> u64 {
211+
let resized = self.to_int().extend::<64>(T::FALSE);
212+
213+
// SAFETY: `resized` is an integer vector with length 64
214+
let bitmask: u64 = unsafe { intrinsics::simd_bitmask(resized) };
206215

207216
// LLVM assumes bit order should match endianness
208217
if cfg!(target_endian = "big") {
209-
bitmask.reverse_bits(N)
218+
bitmask.reverse_bits()
210219
} else {
211220
bitmask
212221
}
213222
}
214223

215224
#[inline]
216-
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
217-
where
218-
super::Mask<T, N>: ToBitMask<BitMask = U>,
219-
{
225+
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
220226
// LLVM assumes bit order should match endianness
221227
let bitmask = if cfg!(target_endian = "big") {
222-
bitmask.reverse_bits(N)
228+
bitmask.reverse_bits()
223229
} else {
224230
bitmask
225231
};
226232

227-
// Safety: U is required to be the appropriate bitmask type
228-
unsafe {
229-
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
233+
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
234+
let mask: Simd<T, 64> = unsafe {
235+
intrinsics::simd_select_bitmask(
230236
bitmask,
231-
Self::splat(true).to_int(),
232-
Self::splat(false).to_int(),
233-
))
234-
}
237+
Simd::<T, 64>::splat(T::TRUE),
238+
Simd::<T, 64>::splat(T::FALSE),
239+
)
240+
};
241+
242+
// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
243+
unsafe { Self::from_int_unchecked(mask.extend::<N>(T::FALSE)) }
235244
}
236245

237246
#[inline]

0 commit comments

Comments
 (0)