// rustc_builtin_macros/deriving/coerce_pointee.rs

1use ast::HasAttrs;
2use rustc_ast::mut_visit::MutVisitor;
3use rustc_ast::visit::BoundKind;
4use rustc_ast::{
5    self as ast, GenericArg, GenericBound, GenericParamKind, Generics, ItemKind, MetaItem,
6    TraitBoundModifiers, VariantData, WherePredicate,
7};
8use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
9use rustc_errors::E0802;
10use rustc_expand::base::{Annotatable, ExtCtxt};
11use rustc_macros::Diagnostic;
12use rustc_span::{Ident, Span, Symbol, sym};
13use thin_vec::{ThinVec, thin_vec};
14
15use crate::errors;
16
// Builds a `Vec<Ident>` from a `::`-separated list of names, looking each one up in `sym` and
// giving every resulting `Ident` the same span.
// For example, `path!(span, core::ops)` yields `[Ident(core, span), Ident(ops, span)]`.
macro_rules! path {
    ($span:expr, $($part:ident)::*) => {
        vec![$(Ident::new(sym::$part, $span)),*]
    };
}
20
/// Expands `#[derive(CoercePointee)]` applied to `item`.
///
/// # Input requirements
///
/// The derive input must be a struct with at least one field and at least one generic type
/// parameter. The "pointee" (the type parameter that gets coerced) is chosen as follows:
/// - if there is exactly one type parameter, it is the pointee;
/// - otherwise, it is the single type parameter marked `#[pointee]`.
///
/// The pointee must carry a `?Sized` bound, either inline or in the `where` clause.
///
/// # Generated impls
///
/// On success, three `#[automatically_derived]` impls are passed to `push`, in this order:
///
/// 1. `core::marker::CoercePointeeValidated` for the struct type. It reuses the struct's
///    generic parameters (with defaults dropped) and its `where` clause.
/// 2. `core::ops::DispatchFromDyn<Struct<.., __S, ..>>`.
/// 3. `core::ops::CoerceUnsized<Struct<.., __S, ..>>`.
///
/// In impls 2 and 3:
/// - `__S` is a fresh type parameter inserted right after the pointee and substituted for it
///   in the target type.
/// - The pointee gains an `Unsize<__S>` bound.
/// - Bounds and `where` predicates mentioning the pointee are duplicated with `__S` in its
///   place.
///
/// # Errors
///
/// Diagnostics go through `cx.dcx()`.
///
/// `DetectNonGenericPointeeAttr` is run over the item first to report misplaced `#[pointee]`
/// attributes; that pass does not stop expansion.
///
/// Each of the remaining checks emits its diagnostic and returns early:
/// - `RequireTransparent`: the input is not a struct.
/// - `RequireOneField`
/// - `RequireOneGeneric`
/// - `RequireOnePointee`
/// - `TooManyPointees`
/// - `RequiresMaybeSized`
///
/// `RequiresMaybeSized` is only checked after the `CoercePointeeValidated` impl has already been
/// pushed.
///
/// `_mitem` and `_is_const` are unused.
pub(crate) fn expand_deriving_coerce_pointee(
    cx: &ExtCtxt<'_>,
    span: Span,
    _mitem: &MetaItem,
    item: &Annotatable,
    push: &mut dyn FnMut(Annotatable),
    _is_const: bool,
) {
    // Report `#[pointee]` attributes in disallowed positions. Errors emitted by this pass do
    // not abort the rest of the expansion.
    item.visit_with(&mut DetectNonGenericPointeeAttr { cx });

    // Only structs with at least one (named or tuple) field are accepted. Despite its name,
    // `RequireTransparent` is what is reported here for any input that is not a struct.
    let (name_ident, generics) = if let Annotatable::Item(aitem) = item
        && let ItemKind::Struct(ident, g, struct_data) = &aitem.kind
    {
        if !matches!(
            struct_data,
            VariantData::Struct { fields, recovered: _ } | VariantData::Tuple(fields, _)
                if !fields.is_empty())
        {
            cx.dcx().emit_err(RequireOneField { span });
            return;
        }
        (*ident, g)
    } else {
        cx.dcx().emit_err(RequireTransparent { span });
        return;
    };

    // Convert generic parameters (from the struct) into generic args.
    let self_params: Vec<_> = generics
        .params
        .iter()
        .map(|p| match p.kind {
            GenericParamKind::Lifetime => GenericArg::Lifetime(cx.lifetime(p.span(), p.ident)),
            GenericParamKind::Type { .. } => GenericArg::Type(cx.ty_ident(p.span(), p.ident)),
            GenericParamKind::Const { .. } => GenericArg::Const(cx.const_ident(p.span(), p.ident)),
        })
        .collect();
    // One entry per *type* parameter: (index into `generics.params`, which also counts
    // lifetime and const params; the param's span; whether it carries `#[pointee]`).
    let type_params: Vec<_> = generics
        .params
        .iter()
        .enumerate()
        .filter_map(|(idx, p)| {
            if let GenericParamKind::Type { .. } = p.kind {
                Some((idx, p.span(), p.attrs().iter().any(|attr| attr.has_name(sym::pointee))))
            } else {
                None
            }
        })
        .collect();

    // Index of the pointee within `generics.params` (not within `type_params`).
    let pointee_param_idx = if type_params.is_empty() {
        // `#[derive(CoercePointee)]` requires at least one generic type on the target `struct`
        cx.dcx().emit_err(RequireOneGeneric { span });
        return;
    } else if type_params.len() == 1 {
        // Whether or not the only type param is marked `#[pointee]`, we can use it as such.
        type_params[0].0
    } else {
        let mut pointees = type_params
            .iter()
            .filter_map(|&(idx, span, is_pointee)| is_pointee.then_some((idx, span)));
        // Only the first two `#[pointee]` params are inspected; any further ones are not
        // mentioned in the `TooManyPointees` diagnostic.
        match (pointees.next(), pointees.next()) {
            (Some((idx, _span)), None) => idx,
            (None, _) => {
                cx.dcx().emit_err(RequireOnePointee { span });
                return;
            }
            (Some((_, one)), Some((_, another))) => {
                cx.dcx().emit_err(TooManyPointees { one, another });
                return;
            }
        }
    };

    // Create the type of `self`.
    let path = cx.path_all(span, false, vec![name_ident], self_params.clone());
    let self_type = cx.ty_path(path);

    // Attributes attached to every generated impl.
    // FIXME(dingxiangfei2009): Investigate the set of attributes on target struct to be propagated to impls
    let attrs = thin_vec![cx.attr_word(sym::automatically_derived, span),];
    // # Validity assertion which will be checked later in `rustc_hir_analysis::coherence::builtins`.
    //
    // The impl reuses the struct's own generic params (defaults dropped, bounds kept) and its
    // `where` clause unchanged.
    {
        let trait_path =
            cx.path_all(span, true, path!(span, core::marker::CoercePointeeValidated), vec![]);
        let trait_ref = cx.trait_ref(trait_path);
        push(Annotatable::Item(
            cx.item(
                span,
                attrs.clone(),
                ast::ItemKind::Impl(ast::Impl {
                    generics: Generics {
                        params: generics
                            .params
                            .iter()
                            .map(|p| match &p.kind {
                                GenericParamKind::Lifetime => {
                                    cx.lifetime_param(p.span(), p.ident, p.bounds.clone())
                                }
                                GenericParamKind::Type { default: _ } => {
                                    cx.typaram(p.span(), p.ident, p.bounds.clone(), None)
                                }
                                GenericParamKind::Const { ty, span: _, default: _ } => cx
                                    .const_param(
                                        p.span(),
                                        p.ident,
                                        p.bounds.clone(),
                                        ty.clone(),
                                        None,
                                    ),
                            })
                            .collect(),
                        where_clause: generics.where_clause.clone(),
                        span: generics.span,
                    },
                    of_trait: Some(Box::new(ast::TraitImplHeader {
                        safety: ast::Safety::Default,
                        polarity: ast::ImplPolarity::Positive,
                        defaultness: ast::Defaultness::Final,
                        constness: ast::Const::No,
                        trait_ref,
                    })),
                    self_ty: self_type.clone(),
                    items: ThinVec::new(),
                }),
            ),
        ));
    }
    // Helper closure that pushes an empty
    // `impl<generics> core::ops::<trait_symbol><trait_args> for <self_type>` item.
    let mut add_impl_block = |generics, trait_symbol, trait_args| {
        let mut parts = path!(span, core::ops);
        parts.push(Ident::new(trait_symbol, span));
        let trait_path = cx.path_all(span, true, parts, trait_args);
        let trait_ref = cx.trait_ref(trait_path);
        let item = cx.item(
            span,
            attrs.clone(),
            ast::ItemKind::Impl(ast::Impl {
                generics,
                of_trait: Some(Box::new(ast::TraitImplHeader {
                    safety: ast::Safety::Default,
                    polarity: ast::ImplPolarity::Positive,
                    defaultness: ast::Defaultness::Final,
                    constness: ast::Const::No,
                    trait_ref,
                })),
                self_ty: self_type.clone(),
                items: ThinVec::new(),
            }),
        );
        push(Annotatable::Item(item));
    };

    // Create the coercion target type, i.e. `self` with the `#[pointee]` type arg replaced by
    // `__S`. For example, instead of `MyType<'a, T>`, it will be `MyType<'a, __S>`.
    let s_ty = cx.ty_ident(span, Ident::new(sym::__S, span));
    let mut alt_self_params = self_params;
    alt_self_params[pointee_param_idx] = GenericArg::Type(s_ty.clone());
    let alt_self_type = cx.ty_path(cx.path_all(span, false, vec![name_ident], alt_self_params));

    // # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location
    //
    // Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
    // `self_bounds` snapshots the pointee's original bounds (before `Unsize<__S>` is added);
    // they become the bounds of `__S` further down.
    let mut impl_generics = generics.clone();
    let pointee_ty_ident = generics.params[pointee_param_idx].ident;
    let mut self_bounds;
    {
        let pointee = &mut impl_generics.params[pointee_param_idx];
        self_bounds = pointee.bounds.clone();
        // The pointee must be `?Sized`, either inline or via a `where` predicate on it.
        if !contains_maybe_sized_bound(&self_bounds)
            && !contains_maybe_sized_bound_on_pointee(
                &generics.where_clause.predicates,
                pointee_ty_ident.name,
            )
        {
            cx.dcx().emit_err(RequiresMaybeSized {
                span: pointee_ty_ident.span,
                name: pointee_ty_ident,
            });
            return;
        }
        let arg = GenericArg::Type(s_ty.clone());
        let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
        pointee.bounds.push(cx.trait_bound(unsize, false));
        // Drop `#[pointee]` attribute since it should not be recognized outside `derive(CoercePointee)`
        pointee.attrs.retain(|attr| !attr.has_name(sym::pointee));
    }

    // # Rewrite generic parameter bounds
    // For each bound `U: ..` in `struct<U: ..>`, make a new bound with `__S` in place of `#[pointee]`
    // Example:
    // ```
    // struct<
    //     U: Trait<T>,
    //     #[pointee] T: Trait<T> + ?Sized,
    //     V: Trait<T>> ...
    // ```
    // ... generates this `impl` generic parameters
    // ```
    // impl<
    //     U: Trait<T> + Trait<__S>,
    //     T: Trait<T> + ?Sized + Unsize<__S>, // (**)
    //     __S: Trait<__S> + ?Sized, // (*)
    //     V: Trait<T> + Trait<__S>> ...
    // ```
    // The new bound marked with (*) has to be done separately.
    // See next section
    for (idx, (params, orig_params)) in
        impl_generics.params.iter_mut().zip(&generics.params).enumerate()
    {
        // Defaults on type and const parameters are not allowed on an `impl` block,
        // so we drop them now.
        match &mut params.kind {
            ast::GenericParamKind::Const { default, .. } => *default = None,
            ast::GenericParamKind::Type { default } => *default = None,
            ast::GenericParamKind::Lifetime => {}
        }
        // We CANNOT rewrite `#[pointee]` type parameter bounds.
        // This has been set in stone. (**)
        // So we skip over it.
        // Otherwise, we push extra bounds involving `__S`.
        if idx != pointee_param_idx {
            for bound in &orig_params.bounds {
                let mut bound = bound.clone();
                let mut substitution = TypeSubstitution {
                    from_name: pointee_ty_ident.name,
                    to_ty: &s_ty,
                    rewritten: false,
                };
                substitution.visit_param_bound(&mut bound, BoundKind::Bound);
                if substitution.rewritten {
                    // We found use of `#[pointee]` somewhere,
                    // so we make a new bound using `__S` in place of `#[pointee]`
                    params.bounds.push(bound);
                }
            }
        }
    }

    // # Prepare the bounds of the `__S` type parameter
    //
    // `__S` receives the pointee's original bounds (marked (*) above) with the pointee
    // replaced by `__S`. The parameter itself is inserted after the `where` clause rewrite.
    {
        let mut substitution =
            TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
        for bound in &mut self_bounds {
            substitution.visit_param_bound(bound, BoundKind::Bound);
        }
    }

    // # Rewrite `where` clauses
    //
    // Move on to `where` clauses.
    // Example:
    // ```
    // struct MyPointer<#[pointee] T, ..>
    // where
    //   U: Trait<V> + Trait<T>,
    //   Companion<T>: Trait<T>,
    //   T: Trait<T> + ?Sized,
    // { .. }
    // ```
    // ... will have an impl header like so
    // ```
    // impl<..> ..
    // where
    //   U: Trait<V> + Trait<T>,
    //   U: Trait<__S>,
    //   Companion<T>: Trait<T>,
    //   Companion<__S>: Trait<__S>,
    //   T: Trait<T> + ?Sized,
    //   __S: Trait<__S> + ?Sized,
    // ```
    //
    // We should also write a few new `where` bounds from `#[pointee] T` to `__S`
    // as well as any bound that indirectly involves the `#[pointee] T` type.
    // Note that a rewritten predicate contains the *whole* substituted predicate (e.g.
    // `U: Trait<__S> + Trait<V>`), and only bound predicates are considered.
    for predicate in &generics.where_clause.predicates {
        if let ast::WherePredicateKind::BoundPredicate(bound) = &predicate.kind {
            let mut substitution = TypeSubstitution {
                from_name: pointee_ty_ident.name,
                to_ty: &s_ty,
                rewritten: false,
            };
            let mut kind = ast::WherePredicateKind::BoundPredicate(bound.clone());
            substitution.visit_where_predicate_kind(&mut kind);
            if substitution.rewritten {
                let predicate = ast::WherePredicate {
                    attrs: predicate.attrs.clone(),
                    kind,
                    span: predicate.span,
                    id: ast::DUMMY_NODE_ID,
                    is_placeholder: false,
                };
                impl_generics.where_clause.predicates.push(predicate);
            }
        }
    }

    // Insert `__S` directly after the pointee in the impl's generic parameter list.
    let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
    impl_generics.params.insert(pointee_param_idx + 1, extra_param);

    // Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
    let gen_args = vec![GenericArg::Type(alt_self_type)];
    add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
    add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args);
}
327
328fn contains_maybe_sized_bound_on_pointee(predicates: &[WherePredicate], pointee: Symbol) -> bool {
329    for bound in predicates {
330        if let ast::WherePredicateKind::BoundPredicate(bound) = &bound.kind
331            && bound.bounded_ty.kind.is_simple_path().is_some_and(|name| name == pointee)
332        {
333            for bound in &bound.bounds {
334                if is_maybe_sized_bound(bound) {
335                    return true;
336                }
337            }
338        }
339    }
340    false
341}
342
343fn is_maybe_sized_bound(bound: &GenericBound) -> bool {
344    if let GenericBound::Trait(trait_ref) = bound
345        && let TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. } =
346            trait_ref.modifiers
347        && is_sized_marker(&trait_ref.trait_ref.path)
348    {
349        true
350    } else {
351        false
352    }
353}
354
355fn contains_maybe_sized_bound(bounds: &[GenericBound]) -> bool {
356    bounds.iter().any(is_maybe_sized_bound)
357}
358
359fn path_segment_is_exact_match(path_segments: &[ast::PathSegment], syms: &[Symbol]) -> bool {
360    path_segments.iter().zip(syms).all(|(segment, &symbol)| segment.ident.name == symbol)
361}
362
363fn is_sized_marker(path: &ast::Path) -> bool {
364    const CORE_UNSIZE: [Symbol; 3] = [sym::core, sym::marker, sym::Sized];
365    const STD_UNSIZE: [Symbol; 3] = [sym::std, sym::marker, sym::Sized];
366    if path.segments.len() == 4 && path.is_global() {
367        path_segment_is_exact_match(&path.segments[1..], &CORE_UNSIZE)
368            || path_segment_is_exact_match(&path.segments[1..], &STD_UNSIZE)
369    } else if path.segments.len() == 3 {
370        path_segment_is_exact_match(&path.segments, &CORE_UNSIZE)
371            || path_segment_is_exact_match(&path.segments, &STD_UNSIZE)
372    } else {
373        *path == sym::Sized
374    }
375}
376
377struct TypeSubstitution<'a> {
378    from_name: Symbol,
379    to_ty: &'a ast::Ty,
380    rewritten: bool,
381}
382
383impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
384    fn visit_ty(&mut self, ty: &mut ast::Ty) {
385        if let Some(name) = ty.kind.is_simple_path()
386            && name == self.from_name
387        {
388            *ty = self.to_ty.clone();
389            self.rewritten = true;
390        } else {
391            ast::mut_visit::walk_ty(self, ty);
392        }
393    }
394
395    fn visit_where_predicate_kind(&mut self, kind: &mut ast::WherePredicateKind) {
396        match kind {
397            rustc_ast::WherePredicateKind::BoundPredicate(bound) => {
398                bound
399                    .bound_generic_params
400                    .flat_map_in_place(|param| self.flat_map_generic_param(param));
401                self.visit_ty(&mut bound.bounded_ty);
402                for bound in &mut bound.bounds {
403                    self.visit_param_bound(bound, BoundKind::Bound)
404                }
405            }
406            rustc_ast::WherePredicateKind::RegionPredicate(_)
407            | rustc_ast::WherePredicateKind::EqPredicate(_) => {}
408        }
409    }
410}
411
412struct DetectNonGenericPointeeAttr<'a, 'b> {
413    cx: &'a ExtCtxt<'b>,
414}
415
416impl<'a, 'b> rustc_ast::visit::Visitor<'a> for DetectNonGenericPointeeAttr<'a, 'b> {
417    fn visit_attribute(&mut self, attr: &'a rustc_ast::Attribute) -> Self::Result {
418        if attr.has_name(sym::pointee) {
419            self.cx.dcx().emit_err(errors::NonGenericPointee { span: attr.span });
420        }
421    }
422
423    fn visit_generic_param(&mut self, param: &'a rustc_ast::GenericParam) -> Self::Result {
424        let mut error_on_pointee = AlwaysErrorOnGenericParam { cx: self.cx };
425
426        match &param.kind {
427            GenericParamKind::Type { default } => {
428                // The `default` may end up containing a block expression.
429                // The problem is block expressions  may define structs with generics.
430                // A user may attach a #[pointee] attribute to one of these generics
431                // We want to catch that. The simple solution is to just
432                // always raise a `NonGenericPointee` error when this happens.
433                //
434                // This solution does reject valid rust programs but,
435                // such a code would have to, in order:
436                // - Define a smart pointer struct.
437                // - Somewhere in this struct definition use a type with a const generic argument.
438                // - Calculate this const generic in a expression block.
439                // - Define a new smart pointer type in this block.
440                // - Have this smart pointer type have more than 1 generic type.
441                // In this case, the inner smart pointer derive would be complaining that it
442                // needs a pointer attribute. Meanwhile, the outer macro would be complaining
443                // that we attached a #[pointee] to a generic type argument while helpfully
444                // informing the user that #[pointee] can only be attached to generic pointer arguments
445                rustc_ast::visit::visit_opt!(error_on_pointee, visit_ty, default);
446            }
447
448            GenericParamKind::Const { .. } | GenericParamKind::Lifetime => {
449                rustc_ast::visit::walk_generic_param(&mut error_on_pointee, param);
450            }
451        }
452    }
453
454    fn visit_ty(&mut self, t: &'a rustc_ast::Ty) -> Self::Result {
455        let mut error_on_pointee = AlwaysErrorOnGenericParam { cx: self.cx };
456        error_on_pointee.visit_ty(t)
457    }
458}
459
460struct AlwaysErrorOnGenericParam<'a, 'b> {
461    cx: &'a ExtCtxt<'b>,
462}
463
464impl<'a, 'b> rustc_ast::visit::Visitor<'a> for AlwaysErrorOnGenericParam<'a, 'b> {
465    fn visit_attribute(&mut self, attr: &'a rustc_ast::Attribute) -> Self::Result {
466        if attr.has_name(sym::pointee) {
467            self.cx.dcx().emit_err(errors::NonGenericPointee { span: attr.span });
468        }
469    }
470}
471
472#[derive(Diagnostic)]
473#[diag(builtin_macros_coerce_pointee_requires_transparent, code = E0802)]
474struct RequireTransparent {
475    #[primary_span]
476    span: Span,
477}
478
479#[derive(Diagnostic)]
480#[diag(builtin_macros_coerce_pointee_requires_one_field, code = E0802)]
481struct RequireOneField {
482    #[primary_span]
483    span: Span,
484}
485
486#[derive(Diagnostic)]
487#[diag(builtin_macros_coerce_pointee_requires_one_generic, code = E0802)]
488struct RequireOneGeneric {
489    #[primary_span]
490    span: Span,
491}
492
493#[derive(Diagnostic)]
494#[diag(builtin_macros_coerce_pointee_requires_one_pointee, code = E0802)]
495struct RequireOnePointee {
496    #[primary_span]
497    span: Span,
498}
499
500#[derive(Diagnostic)]
501#[diag(builtin_macros_coerce_pointee_too_many_pointees, code = E0802)]
502struct TooManyPointees {
503    #[primary_span]
504    one: Span,
505    #[label]
506    another: Span,
507}
508
509#[derive(Diagnostic)]
510#[diag(builtin_macros_coerce_pointee_requires_maybe_sized, code = E0802)]
511struct RequiresMaybeSized {
512    #[primary_span]
513    span: Span,
514    name: Ident,
515}