Skip to content

Commit 502de01

Browse files
committed
rustc: SIMD types use pointers in Rust's ABI
This commit changes the ABI of SIMD types in the "Rust" ABI to unconditionally be passed via pointers instead of being passed as immediates. This should fix a longstanding issue, #44367, where SIMD-using programs ended up showing very odd behavior at runtime because the ABI between functions was mismatched. As a bit of a recap, this is sort of an LLVM bug and sort of an LLVM feature (today's behavior). LLVM will generate code for a function solely looking at the function it's generating, including calls to other functions. Let's then say you've got something that looks like: ```llvm define void @foo() { ; no target features enabled call void @bar(<i64 x 4> zeroinitializer) ret void } define void @bar(<i64 x 4>) #0 { ; enables the AVX feature ... } ``` LLVM will codegen the call to `bar` *without* using AVX registers becauase `foo` doesn't have access to these registers. Instead it's generated with emulation that uses two 128-bit registers. The `bar` function, on the other hand, will expect its argument in an AVX register (as it has AVX enabled). This means we've got a codegen problem! Comments on #44367 have some more contexutal information but the crux of the issue is that if we want SIMD to work in general we'll need to ensure that whenever a function calls another they ABI of the arguments being passed is in agreement. One possible solution to this would be to insert "shim functions" where whenever a `target_feature` mismatch is detected the compiler inserts a shim function where you pass arguments via memory to the shim and then the shim loads the values and calls the target function (where the shim and the target have the same target features enabled). This unfortunately is quite nontrivial to implement in rustc today (especially when accounting for function pointers and such). This commit takes a different solution, *always* passing SIMD arguments through memory instead of passing as immediates. This strategy solves the problem at the LLVM layer because the ABI between two functions never uses SIMD registers. This also shouldn't be a hit to performance because SIMD performance is thought to often rely on inlining anyway, where a `call` instruction, even if using SIMD registers, would be disastrous to performance regardless. LLVM should then be more than capable of fixing all our memory usage to use registers instead after enough inlining has been performed. Note that there's a few caveats to this commit though: * The "platform intrinsic" ABI is omitted from "always pass via memory". This ABI is used to define intrinsics like `simd_shuffle4` where LLVM and rustc need to have the arguments as an immediate. * Additionally this commit does *not* fix the `extern` ("C") ABI. This means that the bug in #44367 can still happen when using non-Rust-ABI functions. My hope is that before stabilization we can ban and/or warn about SIMD types in these functions (as AFAIK there's not much motivation to belong there anyway), but I'll leave that for a later commit and if this is merged I'll file a follow-up issue. All in all this... Closes #44367
1 parent a0dcecf commit 502de01

File tree

3 files changed

+207
-3
lines changed

3 files changed

+207
-3
lines changed

src/librustc_trans/abi.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,31 @@ impl<'a, 'tcx> FnType<'tcx> {
871871

872872
match arg.layout.abi {
873873
layout::Abi::Aggregate { .. } => {}
874+
875+
// This is a fun case! The gist of what this is doing is
876+
// that we want callers and callees to always agree on the
877+
// ABI of how they pass SIMD arguments. If we were to *not*
878+
// make these arguments indirect then they'd be immediates
879+
// in LLVM, which means that they'd used whatever the
880+
// appropriate ABI is for the callee and the caller. That
881+
// means, for example, if the caller doesn't have AVX
882+
// enabled but the callee does, then passing an AVX argument
883+
// across this boundary would cause corrupt data to show up.
884+
//
885+
// This problem is fixed by unconditionally passing SIMD
886+
// arguments through memory between callers and callees
887+
// which should get them all to agree on ABI regardless of
888+
// target feature sets. Some more information about this
889+
// issue can be found in #44367.
890+
//
891+
// Note that the platform intrinsic ABI is exempt here as
892+
// that's how we connect up to LLVM and it's unstable
893+
// anyway, we control all calls to it in libstd.
894+
layout::Abi::Vector { .. } if abi != Abi::PlatformIntrinsic => {
895+
arg.make_indirect();
896+
return
897+
}
898+
874899
_ => return
875900
}
876901

src/test/codegen/x86_mmx.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ pub struct i8x8(u64);
2222

2323
#[no_mangle]
2424
pub fn a(a: &mut i8x8, b: i8x8) -> i8x8 {
25-
// CHECK-LABEL: define x86_mmx @a(x86_mmx*{{.*}}, x86_mmx{{.*}})
26-
// CHECK: store x86_mmx %b, x86_mmx* %a
27-
// CHECK: ret x86_mmx %b
25+
// CHECK-LABEL: define void @a(x86_mmx*{{.*}}, x86_mmx*{{.*}}, x86_mmx*{{.*}})
2826
*a = b;
2927
return b
3028
}
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// http://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
#![feature(repr_simd, target_feature, cfg_target_feature)]
12+
13+
use std::process::{Command, ExitStatus};
14+
use std::env;
15+
16+
fn main() {
17+
if let Some(level) = env::args().nth(1) {
18+
return test::main(&level)
19+
}
20+
21+
let me = env::current_exe().unwrap();
22+
for level in ["sse", "avx", "avx512"].iter() {
23+
let status = Command::new(&me).arg(level).status().unwrap();
24+
if status.success() {
25+
println!("success with {}", level);
26+
continue
27+
}
28+
29+
// We don't actually know if our computer has the requisite target features
30+
// for the test below. Testing for that will get added to libstd later so
31+
// for now just asume sigill means this is a machine that can't run this test.
32+
if is_sigill(status) {
33+
println!("sigill with {}, assuming spurious", level);
34+
continue
35+
}
36+
panic!("invalid status at {}: {}", level, status);
37+
}
38+
}
39+
40+
#[cfg(unix)]
41+
fn is_sigill(status: ExitStatus) -> bool {
42+
use std::os::unix::prelude::*;
43+
status.signal() == Some(4)
44+
}
45+
46+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
47+
#[allow(bad_style)]
48+
mod test {
49+
// An SSE type
50+
#[repr(simd)]
51+
#[derive(PartialEq, Debug, Clone, Copy)]
52+
struct __m128i(u64, u64);
53+
54+
// An AVX type
55+
#[repr(simd)]
56+
#[derive(PartialEq, Debug, Clone, Copy)]
57+
struct __m256i(u64, u64, u64, u64);
58+
59+
// An AVX-512 type
60+
#[repr(simd)]
61+
#[derive(PartialEq, Debug, Clone, Copy)]
62+
struct __m512i(u64, u64, u64, u64, u64, u64, u64, u64);
63+
64+
pub fn main(level: &str) {
65+
unsafe {
66+
main_normal(level);
67+
main_sse(level);
68+
if level == "sse" {
69+
return
70+
}
71+
main_avx(level);
72+
if level == "avx" {
73+
return
74+
}
75+
main_avx512(level);
76+
}
77+
}
78+
79+
macro_rules! mains {
80+
($(
81+
$(#[$attr:meta])*
82+
unsafe fn $main:ident(level: &str) {
83+
...
84+
}
85+
)*) => ($(
86+
$(#[$attr])*
87+
unsafe fn $main(level: &str) {
88+
let m128 = __m128i(1, 2);
89+
let m256 = __m256i(3, 4, 5, 6);
90+
let m512 = __m512i(7, 8, 9, 10, 11, 12, 13, 14);
91+
assert_eq!(id_sse_128(m128), m128);
92+
assert_eq!(id_sse_256(m256), m256);
93+
assert_eq!(id_sse_512(m512), m512);
94+
95+
if level == "sse" {
96+
return
97+
}
98+
assert_eq!(id_avx_128(m128), m128);
99+
assert_eq!(id_avx_256(m256), m256);
100+
assert_eq!(id_avx_512(m512), m512);
101+
102+
if level == "avx" {
103+
return
104+
}
105+
assert_eq!(id_avx512_128(m128), m128);
106+
assert_eq!(id_avx512_256(m256), m256);
107+
assert_eq!(id_avx512_512(m512), m512);
108+
}
109+
)*)
110+
}
111+
112+
mains! {
113+
unsafe fn main_normal(level: &str) { ... }
114+
#[target_feature(enable = "sse2")]
115+
unsafe fn main_sse(level: &str) { ... }
116+
#[target_feature(enable = "avx")]
117+
unsafe fn main_avx(level: &str) { ... }
118+
#[target_feature(enable = "avx512bw")]
119+
unsafe fn main_avx512(level: &str) { ... }
120+
}
121+
122+
123+
#[target_feature(enable = "sse2")]
124+
unsafe fn id_sse_128(a: __m128i) -> __m128i {
125+
assert_eq!(a, __m128i(1, 2));
126+
a.clone()
127+
}
128+
129+
#[target_feature(enable = "sse2")]
130+
unsafe fn id_sse_256(a: __m256i) -> __m256i {
131+
assert_eq!(a, __m256i(3, 4, 5, 6));
132+
a.clone()
133+
}
134+
135+
#[target_feature(enable = "sse2")]
136+
unsafe fn id_sse_512(a: __m512i) -> __m512i {
137+
assert_eq!(a, __m512i(7, 8, 9, 10, 11, 12, 13, 14));
138+
a.clone()
139+
}
140+
141+
#[target_feature(enable = "avx")]
142+
unsafe fn id_avx_128(a: __m128i) -> __m128i {
143+
assert_eq!(a, __m128i(1, 2));
144+
a.clone()
145+
}
146+
147+
#[target_feature(enable = "avx")]
148+
unsafe fn id_avx_256(a: __m256i) -> __m256i {
149+
assert_eq!(a, __m256i(3, 4, 5, 6));
150+
a.clone()
151+
}
152+
153+
#[target_feature(enable = "avx")]
154+
unsafe fn id_avx_512(a: __m512i) -> __m512i {
155+
assert_eq!(a, __m512i(7, 8, 9, 10, 11, 12, 13, 14));
156+
a.clone()
157+
}
158+
159+
#[target_feature(enable = "avx512bw")]
160+
unsafe fn id_avx512_128(a: __m128i) -> __m128i {
161+
assert_eq!(a, __m128i(1, 2));
162+
a.clone()
163+
}
164+
165+
#[target_feature(enable = "avx512bw")]
166+
unsafe fn id_avx512_256(a: __m256i) -> __m256i {
167+
assert_eq!(a, __m256i(3, 4, 5, 6));
168+
a.clone()
169+
}
170+
171+
#[target_feature(enable = "avx512bw")]
172+
unsafe fn id_avx512_512(a: __m512i) -> __m512i {
173+
assert_eq!(a, __m512i(7, 8, 9, 10, 11, 12, 13, 14));
174+
a.clone()
175+
}
176+
}
177+
178+
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
179+
mod test {
180+
pub fn main(level: &str) {}
181+
}

0 commit comments

Comments
 (0)