Skip to content

Commit 48421e9

Browse files
derive(SmartPointer): rewrite bounds in where and generic bounds
1 parent a5ee5cb commit 48421e9

File tree

4 files changed

+308
-11
lines changed

4 files changed

+308
-11
lines changed

compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs

Lines changed: 164 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
use std::mem::swap;
22

33
use ast::HasAttrs;
4+
use rustc_ast::mut_visit::MutVisitor;
5+
use rustc_ast::visit::BoundKind;
46
use rustc_ast::{
57
self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem,
68
TraitBoundModifiers, VariantData,
79
};
810
use rustc_attr as attr;
11+
use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
912
use rustc_expand::base::{Annotatable, ExtCtxt};
1013
use rustc_span::symbol::{sym, Ident};
11-
use rustc_span::Span;
14+
use rustc_span::{Span, Symbol};
1215
use smallvec::{smallvec, SmallVec};
1316
use thin_vec::{thin_vec, ThinVec};
1417

18+
type AstTy = ast::ptr::P<ast::Ty>;
19+
1520
macro_rules! path {
1621
($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] }
1722
}
1823

24+
macro_rules! symbols {
25+
($($part:ident)::*) => { [$(sym::$part),*] }
26+
}
27+
1928
pub fn expand_deriving_smart_ptr(
2029
cx: &ExtCtxt<'_>,
2130
span: Span,
@@ -143,31 +152,175 @@ pub fn expand_deriving_smart_ptr(
143152

144153
// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
145154
let mut impl_generics = generics.clone();
155+
let pointee_ty_ident = generics.params[pointee_param_idx].ident;
156+
let mut self_bounds;
146157
{
147158
let p = &mut impl_generics.params[pointee_param_idx];
159+
self_bounds = p.bounds.clone();
148160
let arg = GenericArg::Type(s_ty.clone());
149161
let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
150162
p.bounds.push(cx.trait_bound(unsize, false));
151163
let mut attrs = thin_vec![];
152164
swap(&mut p.attrs, &mut attrs);
153165
p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect();
154166
}
167+
// We should not set default values to constant generic parameters
168+
// and write out bounds that indirectly involves `#[pointee]`.
169+
for (idx, (params, orig_params)) in
170+
impl_generics.params.iter_mut().zip(&generics.params).enumerate()
171+
{
172+
if idx == pointee_param_idx {
173+
continue;
174+
}
175+
match &mut params.kind {
176+
ast::GenericParamKind::Const { default, .. } => *default = None,
177+
ast::GenericParamKind::Type { default } => *default = None,
178+
ast::GenericParamKind::Lifetime => {}
179+
}
180+
for bound in &orig_params.bounds {
181+
let mut bound = bound.clone();
182+
let mut substitution = TypeSubstitution {
183+
from_name: pointee_ty_ident.name,
184+
to_ty: &s_ty,
185+
rewritten: false,
186+
};
187+
substitution.visit_param_bound(&mut bound, BoundKind::Bound);
188+
if substitution.rewritten {
189+
params.bounds.push(bound);
190+
}
191+
}
192+
}
155193

156194
// Add the `__S: ?Sized` extra parameter to the impl block.
195+
// We should also write the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`.
157196
let sized = cx.path_global(span, path!(span, core::marker::Sized));
158-
let bound = GenericBound::Trait(
159-
cx.poly_trait_ref(span, sized),
160-
TraitBoundModifiers {
161-
polarity: ast::BoundPolarity::Maybe(span),
162-
constness: ast::BoundConstness::Never,
163-
asyncness: ast::BoundAsyncness::Normal,
164-
},
165-
);
166-
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None);
167-
impl_generics.params.push(extra_param);
197+
if self_bounds.iter().all(|bound| {
198+
if let GenericBound::Trait(
199+
trait_ref,
200+
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
201+
) = bound
202+
{
203+
!is_sized_marker(&trait_ref.trait_ref.path)
204+
} else {
205+
false
206+
}
207+
}) {
208+
self_bounds.push(GenericBound::Trait(
209+
cx.poly_trait_ref(span, sized),
210+
TraitBoundModifiers {
211+
polarity: ast::BoundPolarity::Maybe(span),
212+
constness: ast::BoundConstness::Never,
213+
asyncness: ast::BoundAsyncness::Normal,
214+
},
215+
));
216+
}
217+
{
218+
let mut substitution =
219+
TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
220+
for bound in &mut self_bounds {
221+
substitution.visit_param_bound(bound, BoundKind::Bound);
222+
}
223+
}
224+
225+
// We should also commute the where bounds from `#[pointee]` to `__S`
226+
// as well as any bound that indirectly involves the `#[pointee]` type.
227+
for bound in &generics.where_clause.predicates {
228+
if let ast::WherePredicate::BoundPredicate(bound) = bound {
229+
let bound_on_pointee = bound
230+
.bounded_ty
231+
.kind
232+
.is_simple_path()
233+
.map_or(false, |name| name == pointee_ty_ident.name);
234+
235+
let bounds: Vec<_> = bound
236+
.bounds
237+
.iter()
238+
.filter(|bound| {
239+
if let GenericBound::Trait(
240+
trait_ref,
241+
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
242+
) = bound
243+
{
244+
!bound_on_pointee || !is_sized_marker(&trait_ref.trait_ref.path)
245+
} else {
246+
true
247+
}
248+
})
249+
.cloned()
250+
.collect();
251+
let mut substitution = TypeSubstitution {
252+
from_name: pointee_ty_ident.name,
253+
to_ty: &s_ty,
254+
rewritten: bounds.len() != bound.bounds.len(),
255+
};
256+
let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate {
257+
span: bound.span,
258+
bound_generic_params: bound.bound_generic_params.clone(),
259+
bounded_ty: bound.bounded_ty.clone(),
260+
bounds,
261+
});
262+
substitution.visit_where_predicate(&mut predicate);
263+
if substitution.rewritten {
264+
impl_generics.where_clause.predicates.push(predicate);
265+
}
266+
}
267+
}
268+
269+
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
270+
impl_generics.params.insert(pointee_param_idx + 1, extra_param);
168271

169272
// Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
170273
let gen_args = vec![GenericArg::Type(alt_self_type.clone())];
171274
add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
172275
add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone());
173276
}
277+
278+
fn is_sized_marker(path: &ast::Path) -> bool {
279+
const CORE_UNSIZE: [Symbol; 3] = symbols!(core::marker::Sized);
280+
const STD_UNSIZE: [Symbol; 3] = symbols!(std::marker::Sized);
281+
if path.segments.len() == 3 {
282+
path.segments.iter().zip(CORE_UNSIZE).all(|(segment, symbol)| segment.ident.name == symbol)
283+
|| path
284+
.segments
285+
.iter()
286+
.zip(STD_UNSIZE)
287+
.all(|(segment, symbol)| segment.ident.name == symbol)
288+
} else {
289+
*path == sym::Sized
290+
}
291+
}
292+
293+
struct TypeSubstitution<'a> {
294+
from_name: Symbol,
295+
to_ty: &'a AstTy,
296+
rewritten: bool,
297+
}
298+
299+
impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
300+
fn visit_ty(&mut self, ty: &mut AstTy) {
301+
if let Some(name) = ty.kind.is_simple_path()
302+
&& name == self.from_name
303+
{
304+
*ty = self.to_ty.clone();
305+
self.rewritten = true;
306+
} else {
307+
ast::mut_visit::walk_ty(self, ty);
308+
}
309+
}
310+
311+
fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) {
312+
match where_predicate {
313+
rustc_ast::WherePredicate::BoundPredicate(bound) => {
314+
bound
315+
.bound_generic_params
316+
.flat_map_in_place(|param| self.flat_map_generic_param(param));
317+
self.visit_ty(&mut bound.bounded_ty);
318+
for bound in &mut bound.bounds {
319+
self.visit_param_bound(bound, BoundKind::Bound)
320+
}
321+
}
322+
rustc_ast::WherePredicate::RegionPredicate(_)
323+
| rustc_ast::WherePredicate::EqPredicate(_) => {}
324+
}
325+
}
326+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//@ check-pass
2+
//@ compile-flags: -Zunpretty=expanded
3+
#![feature(derive_smart_pointer)]
4+
use std::marker::SmartPointer;
5+
6+
pub trait MyTrait<T: ?Sized> {}
7+
8+
#[derive(SmartPointer)]
9+
#[repr(transparent)]
10+
struct MyPointer<'a, #[pointee] T: ?Sized> {
11+
ptr: &'a T,
12+
}
13+
14+
#[derive(core::marker::SmartPointer)]
15+
#[repr(transparent)]
16+
pub struct MyPointer2<'a, Y, Z: MyTrait<T>, #[pointee] T: ?Sized + MyTrait<T>, X: MyTrait<T>>
17+
where
18+
Y: MyTrait<T>,
19+
{
20+
data: &'a mut T,
21+
x: core::marker::PhantomData<X>,
22+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#![feature(prelude_import)]
2+
#![no_std]
3+
//@ check-pass
4+
//@ compile-flags: -Zunpretty=expanded
5+
#![feature(derive_smart_pointer)]
6+
#[prelude_import]
7+
use ::std::prelude::rust_2015::*;
8+
#[macro_use]
9+
extern crate std;
10+
use std::marker::SmartPointer;
11+
12+
pub trait MyTrait<T: ?Sized> {}
13+
14+
#[repr(transparent)]
15+
struct MyPointer<'a, #[pointee] T: ?Sized> {
16+
ptr: &'a T,
17+
}
18+
#[automatically_derived]
19+
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized>
20+
::core::ops::DispatchFromDyn<MyPointer<'a, __S>> for MyPointer<'a, T> {
21+
}
22+
#[automatically_derived]
23+
impl<'a, T: ?Sized + ::core::marker::Unsize<__S>, __S: ?Sized>
24+
::core::ops::CoerceUnsized<MyPointer<'a, __S>> for MyPointer<'a, T> {
25+
}
26+
27+
#[repr(transparent)]
28+
pub struct MyPointer2<'a, Y, Z: MyTrait<T>, #[pointee] T: ?Sized + MyTrait<T>,
29+
X: MyTrait<T>> where Y: MyTrait<T> {
30+
data: &'a mut T,
31+
x: core::marker::PhantomData<X>,
32+
}
33+
#[automatically_derived]
34+
impl<'a, Y, Z: MyTrait<T> + MyTrait<__S>, T: ?Sized + MyTrait<T> +
35+
::core::marker::Unsize<__S>, __S: ?Sized + MyTrait<__S>, X: MyTrait<T> +
36+
MyTrait<__S>> ::core::ops::DispatchFromDyn<MyPointer2<'a, Y, Z, __S, X>>
37+
for MyPointer2<'a, Y, Z, T, X> where Y: MyTrait<T>, Y: MyTrait<__S> {
38+
}
39+
#[automatically_derived]
40+
impl<'a, Y, Z: MyTrait<T> + MyTrait<__S>, T: ?Sized + MyTrait<T> +
41+
::core::marker::Unsize<__S>, __S: ?Sized + MyTrait<__S>, X: MyTrait<T> +
42+
MyTrait<__S>> ::core::ops::CoerceUnsized<MyPointer2<'a, Y, Z, __S, X>> for
43+
MyPointer2<'a, Y, Z, T, X> where Y: MyTrait<T>, Y: MyTrait<__S> {
44+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//@ check-pass
2+
3+
#![feature(derive_smart_pointer)]
4+
5+
#[derive(core::marker::SmartPointer)]
6+
#[repr(transparent)]
7+
pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> {
8+
data: &'a mut T,
9+
x: core::marker::PhantomData<X>,
10+
}
11+
12+
pub trait OnDrop {
13+
fn on_drop(&mut self);
14+
}
15+
16+
#[derive(core::marker::SmartPointer)]
17+
#[repr(transparent)]
18+
pub struct Ptr2<'a, #[pointee] T: ?Sized, X>
19+
where
20+
T: OnDrop,
21+
{
22+
data: &'a mut T,
23+
x: core::marker::PhantomData<X>,
24+
}
25+
26+
pub trait MyTrait<T: ?Sized> {}
27+
28+
#[derive(core::marker::SmartPointer)]
29+
#[repr(transparent)]
30+
pub struct Ptr3<'a, #[pointee] T: ?Sized, X>
31+
where
32+
T: MyTrait<T>,
33+
{
34+
data: &'a mut T,
35+
x: core::marker::PhantomData<X>,
36+
}
37+
38+
#[derive(core::marker::SmartPointer)]
39+
#[repr(transparent)]
40+
pub struct Ptr4<'a, #[pointee] T: MyTrait<T> + ?Sized, X> {
41+
data: &'a mut T,
42+
x: core::marker::PhantomData<X>,
43+
}
44+
45+
#[derive(core::marker::SmartPointer)]
46+
#[repr(transparent)]
47+
pub struct Ptr5<'a, #[pointee] T: ?Sized, X>
48+
where
49+
Ptr5Companion<T>: MyTrait<T>,
50+
Ptr5Companion2: MyTrait<T>,
51+
{
52+
data: &'a mut T,
53+
x: core::marker::PhantomData<X>,
54+
}
55+
56+
pub struct Ptr5Companion<T: ?Sized>(core::marker::PhantomData<T>);
57+
pub struct Ptr5Companion2;
58+
59+
#[derive(core::marker::SmartPointer)]
60+
#[repr(transparent)]
61+
pub struct Ptr6<'a, #[pointee] T: ?Sized, X: MyTrait<T>> {
62+
data: &'a mut T,
63+
x: core::marker::PhantomData<X>,
64+
}
65+
66+
// a reduced example from https://lore.kernel.org/all/[email protected]/
67+
#[repr(transparent)]
68+
#[derive(core::marker::SmartPointer)]
69+
pub struct ListArc<#[pointee] T, const ID: u64 = 0>
70+
where
71+
T: ListArcSafe<ID> + ?Sized,
72+
{
73+
arc: *const T,
74+
}
75+
76+
pub trait ListArcSafe<const ID: u64> {}
77+
78+
fn main() {}

0 commit comments

Comments
 (0)