Skip to content

Commit d0c24da

Browse files
authored
Add an example of SIMD-powered hex encoding (rust-lang#291)
This is lifted from an example elsewhere I found and shows off runtime dispatching along with a lot of intrinsics being used in a bunch.
1 parent 37b4f63 commit d0c24da

File tree

2 files changed

+336
-0
lines changed

2 files changed

+336
-0
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ coresimd = { version = "0.0.3", path = "coresimd/" }
2626

2727
[dev-dependencies]
2828
auxv = "0.3.3"
29+
quickcheck = "0.6"
30+
rand = "0.4"
2931

3032
[profile.release]
3133
debug = true

examples/hex.rs

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
//! An example showing runtime dispatch to an architecture-optimized
2+
//! implementation.
3+
//!
4+
//! This program implements hex encoding a slice into a predetermined
5+
//! destination using various different instruction sets. This selects at
6+
//! runtime the most optimized implementation and uses that rather than being
7+
//! required to be compiled differently.
8+
//!
9+
//! You can test out this program via:
10+
//!
11+
//! echo test | cargo +nightly run --release --example hex
12+
//!
13+
//! and you should see `746573740a` get printed out.
14+
15+
#![feature(cfg_target_feature, target_feature)]
16+
#![cfg_attr(test, feature(test))]
17+
18+
#[macro_use]
19+
extern crate stdsimd;
20+
21+
#[cfg(test)]
22+
#[macro_use]
23+
extern crate quickcheck;
24+
25+
use std::str;
26+
use std::io::{self, Read};
27+
28+
use stdsimd::vendor::*;
29+
30+
fn main() {
31+
let mut input = Vec::new();
32+
io::stdin().read_to_end(&mut input).unwrap();
33+
let mut dst = vec![0; 2 * input.len()];
34+
let s = hex_encode(&input, &mut dst).unwrap();
35+
println!("{}", s);
36+
}
37+
38+
fn hex_encode<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
39+
let len = src.len().checked_mul(2).unwrap();
40+
if dst.len() < len {
41+
return Err(len)
42+
}
43+
44+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
45+
{
46+
if cfg_feature_enabled!("avx2") {
47+
return unsafe { hex_encode_avx2(src, dst) }
48+
}
49+
if cfg_feature_enabled!("sse4.1") {
50+
return unsafe { hex_encode_sse41(src, dst) }
51+
}
52+
}
53+
54+
hex_encode_fallback(src, dst)
55+
}
56+
57+
#[target_feature(enable = "avx2")]
58+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
59+
unsafe fn hex_encode_avx2<'a>(mut src: &[u8], dst: &'a mut [u8])
60+
-> Result<&'a str, usize>
61+
{
62+
let ascii_zero = _mm256_set1_epi8(b'0' as i8);
63+
let nines = _mm256_set1_epi8(9);
64+
let ascii_a = _mm256_set1_epi8((b'a' - 9 - 1) as i8);
65+
let and4bits = _mm256_set1_epi8(0xf);
66+
67+
let mut i = 0isize;
68+
while src.len() >= 32 {
69+
let invec = _mm256_loadu_si256(src.as_ptr() as *const _);
70+
71+
let masked1 = _mm256_and_si256(invec, and4bits);
72+
let masked2 = _mm256_and_si256(_mm256_srli_epi64(invec, 4), and4bits);
73+
74+
// return 0xff corresponding to the elements > 9, or 0x00 otherwise
75+
let cmpmask1 = _mm256_cmpgt_epi8(masked1, nines);
76+
let cmpmask2 = _mm256_cmpgt_epi8(masked2, nines);
77+
78+
// add '0' or the offset depending on the masks
79+
let masked1 = _mm256_add_epi8(masked1, _mm256_blendv_epi8(ascii_zero, ascii_a, cmpmask1));
80+
let masked2 = _mm256_add_epi8(masked2, _mm256_blendv_epi8(ascii_zero, ascii_a, cmpmask2));
81+
82+
// interleave masked1 and masked2 bytes
83+
let res1 = _mm256_unpacklo_epi8(masked2, masked1);
84+
let res2 = _mm256_unpackhi_epi8(masked2, masked1);
85+
86+
// Store everything into the right destination now
87+
let base = dst.as_mut_ptr().offset(i * 2);
88+
let base1 = base.offset(0) as *mut _;
89+
let base2 = base.offset(16) as *mut _;
90+
let base3 = base.offset(32) as *mut _;
91+
let base4 = base.offset(48) as *mut _;
92+
_mm256_storeu2_m128i(base3, base1, res1);
93+
_mm256_storeu2_m128i(base4, base2, res2);
94+
src = &src[32..];
95+
i += 32;
96+
}
97+
98+
let i = i as usize;
99+
drop(hex_encode_sse41(src, &mut dst[i * 2..]));
100+
101+
return Ok(str::from_utf8_unchecked(&dst[..src.len() * 2 + i * 2]))
102+
}
103+
104+
// copied from https://github.com/Matherunner/bin2hex-sse/blob/master/base16_sse4.cpp
105+
#[target_feature(enable = "sse4.1")]
106+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
107+
unsafe fn hex_encode_sse41<'a>(mut src: &[u8], dst: &'a mut [u8])
108+
-> Result<&'a str, usize>
109+
{
110+
let ascii_zero = _mm_set1_epi8(b'0' as i8);
111+
let nines = _mm_set1_epi8(9);
112+
let ascii_a = _mm_set1_epi8((b'a' - 9 - 1) as i8);
113+
let and4bits = _mm_set1_epi8(0xf);
114+
115+
let mut i = 0isize;
116+
while src.len() >= 16 {
117+
let invec = _mm_loadu_si128(src.as_ptr() as *const _);
118+
119+
let masked1 = _mm_and_si128(invec, and4bits);
120+
let masked2 = _mm_and_si128(_mm_srli_epi64(invec, 4), and4bits);
121+
122+
// return 0xff corresponding to the elements > 9, or 0x00 otherwise
123+
let cmpmask1 = _mm_cmpgt_epi8(masked1, nines);
124+
let cmpmask2 = _mm_cmpgt_epi8(masked2, nines);
125+
126+
// add '0' or the offset depending on the masks
127+
let masked1 = _mm_add_epi8(masked1, _mm_blendv_epi8(ascii_zero, ascii_a, cmpmask1));
128+
let masked2 = _mm_add_epi8(masked2, _mm_blendv_epi8(ascii_zero, ascii_a, cmpmask2));
129+
130+
// interleave masked1 and masked2 bytes
131+
let res1 = _mm_unpacklo_epi8(masked2, masked1);
132+
let res2 = _mm_unpackhi_epi8(masked2, masked1);
133+
134+
_mm_storeu_si128(dst.as_mut_ptr().offset(i * 2) as *mut _, res1);
135+
_mm_storeu_si128(dst.as_mut_ptr().offset(i * 2 + 16) as *mut _, res2);
136+
src = &src[16..];
137+
i += 16;
138+
}
139+
140+
let i = i as usize;
141+
drop(hex_encode_fallback(src, &mut dst[i * 2..]));
142+
143+
return Ok(str::from_utf8_unchecked(&dst[..src.len() * 2 + i * 2]))
144+
}
145+
146+
fn hex_encode_fallback<'a>(src: &[u8], dst: &'a mut [u8]) -> Result<&'a str, usize> {
147+
for (byte, slots) in src.iter().zip(dst.chunks_mut(2)) {
148+
slots[0] = hex((*byte >> 4) & 0xf);
149+
slots[1] = hex((*byte >> 0) & 0xf);
150+
}
151+
152+
unsafe {
153+
return Ok(str::from_utf8_unchecked(&dst[..src.len() * 2]))
154+
}
155+
156+
fn hex(byte: u8) -> u8 {
157+
static TABLE: &[u8] = b"0123456789abcdef";
158+
TABLE[byte as usize]
159+
}
160+
}
161+
162+
// Run these with `cargo +nightly test --example hex`
163+
#[cfg(test)]
164+
mod tests {
165+
use std::iter;
166+
167+
use super::*;
168+
169+
fn test(input: &[u8], output: &str) {
170+
let tmp = || vec![0; input.len() * 2];
171+
172+
assert_eq!(hex_encode_fallback(input, &mut tmp()).unwrap(), output);
173+
assert_eq!(hex_encode(input, &mut tmp()).unwrap(), output);
174+
175+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
176+
unsafe {
177+
if cfg_feature_enabled!("avx2") {
178+
assert_eq!(hex_encode_avx2(input, &mut tmp()).unwrap(), output);
179+
}
180+
if cfg_feature_enabled!("sse4.1") {
181+
assert_eq!(hex_encode_sse41(input, &mut tmp()).unwrap(), output);
182+
}
183+
}
184+
}
185+
186+
#[test]
187+
fn empty() {
188+
test(b"", "");
189+
}
190+
191+
#[test]
192+
fn big() {
193+
test(&[0; 1024], &iter::repeat('0').take(2048).collect::<String>());
194+
}
195+
196+
#[test]
197+
fn odd() {
198+
test(&[0; 313], &iter::repeat('0').take(313 * 2).collect::<String>());
199+
}
200+
201+
#[test]
202+
fn avx_works() {
203+
let mut input = [0; 33];
204+
input[4] = 3;
205+
input[16] = 3;
206+
input[17] = 0x30;
207+
input[21] = 1;
208+
input[31] = 0x24;
209+
test(&input, "\
210+
0000000003000000\
211+
0000000000000000\
212+
0330000000010000\
213+
0000000000000024\
214+
00\
215+
");
216+
}
217+
218+
quickcheck! {
219+
fn encode_equals_fallback(input: Vec<u8>) -> bool {
220+
let mut space1 = vec![0; input.len() * 2];
221+
let mut space2 = vec![0; input.len() * 2];
222+
let a = hex_encode(&input, &mut space1).unwrap();
223+
let b = hex_encode_fallback(&input, &mut space2).unwrap();
224+
a == b
225+
}
226+
227+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
228+
fn avx_equals_fallback(input: Vec<u8>) -> bool {
229+
if !cfg_feature_enabled!("avx2") {
230+
return true
231+
}
232+
let mut space1 = vec![0; input.len() * 2];
233+
let mut space2 = vec![0; input.len() * 2];
234+
let a = unsafe { hex_encode_avx2(&input, &mut space1).unwrap() };
235+
let b = hex_encode_fallback(&input, &mut space2).unwrap();
236+
a == b
237+
}
238+
239+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
240+
fn sse41_equals_fallback(input: Vec<u8>) -> bool {
241+
if !cfg_feature_enabled!("avx2") {
242+
return true
243+
}
244+
let mut space1 = vec![0; input.len() * 2];
245+
let mut space2 = vec![0; input.len() * 2];
246+
let a = unsafe { hex_encode_sse41(&input, &mut space1).unwrap() };
247+
let b = hex_encode_fallback(&input, &mut space2).unwrap();
248+
a == b
249+
}
250+
}
251+
}
252+
253+
// Run these with `cargo +nightly bench --example hex`
254+
#[cfg(test)]
255+
mod benches {
256+
extern crate test;
257+
extern crate rand;
258+
259+
use self::rand::Rng;
260+
261+
use super::*;
262+
263+
const SMALL_LEN: usize = 117;
264+
const LARGE_LEN: usize = 1 * 1024 * 1024;
265+
266+
fn doit(b: &mut test::Bencher,
267+
len: usize,
268+
f: for<'a> unsafe fn(&[u8], &'a mut [u8]) -> Result<&'a str, usize>)
269+
{
270+
let input = rand::thread_rng()
271+
.gen_iter::<u8>()
272+
.take(len)
273+
.collect::<Vec<_>>();
274+
let mut dst = vec![0; input.len() * 2];
275+
b.bytes = len as u64;
276+
b.iter(|| unsafe {
277+
f(&input, &mut dst).unwrap();
278+
dst[0]
279+
});
280+
}
281+
282+
#[bench]
283+
fn small_default(b: &mut test::Bencher) {
284+
doit(b, SMALL_LEN, hex_encode);
285+
}
286+
287+
#[bench]
288+
fn small_fallback(b: &mut test::Bencher) {
289+
doit(b, SMALL_LEN, hex_encode_fallback);
290+
}
291+
292+
#[bench]
293+
fn large_default(b: &mut test::Bencher) {
294+
doit(b, LARGE_LEN, hex_encode);
295+
}
296+
297+
#[bench]
298+
fn large_fallback(b: &mut test::Bencher) {
299+
doit(b, LARGE_LEN, hex_encode_fallback);
300+
}
301+
302+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
303+
mod x86 {
304+
use super::*;
305+
306+
#[bench]
307+
fn small_avx2(b: &mut test::Bencher) {
308+
if cfg_feature_enabled!("avx2") {
309+
doit(b, SMALL_LEN, hex_encode_avx2);
310+
}
311+
}
312+
313+
#[bench]
314+
fn small_sse41(b: &mut test::Bencher) {
315+
if cfg_feature_enabled!("sse4.1") {
316+
doit(b, SMALL_LEN, hex_encode_sse41);
317+
}
318+
}
319+
320+
#[bench]
321+
fn large_avx2(b: &mut test::Bencher) {
322+
if cfg_feature_enabled!("avx2") {
323+
doit(b, LARGE_LEN, hex_encode_avx2);
324+
}
325+
}
326+
327+
#[bench]
328+
fn large_sse41(b: &mut test::Bencher) {
329+
if cfg_feature_enabled!("sse4.1") {
330+
doit(b, LARGE_LEN, hex_encode_sse41);
331+
}
332+
}
333+
}
334+
}

0 commit comments

Comments
 (0)