rustc_builtin_macros/deriving/
coerce_pointee.rs1use ast::HasAttrs;
2use rustc_ast::mut_visit::MutVisitor;
3use rustc_ast::visit::BoundKind;
4use rustc_ast::{
5 self as ast, GenericArg, GenericBound, GenericParamKind, Generics, ItemKind, MetaItem,
6 TraitBoundModifiers, VariantData, WherePredicate,
7};
8use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
9use rustc_errors::E0802;
10use rustc_expand::base::{Annotatable, ExtCtxt};
11use rustc_macros::Diagnostic;
12use rustc_span::{Ident, Span, Symbol, sym};
13use thin_vec::{ThinVec, thin_vec};
14
15use crate::errors;
16
// Builds a `Vec<Ident>` from a `::`-separated list of names in `sym`, e.g.
// `path!(span, core::marker::Unsize)`. Every `Ident` gets the span `$span`; the `$span`
// expression is expanded once per segment.
macro_rules! path {
    ($span:expr, $($part:ident)::*) => {
        vec![$(Ident::new(sym::$part, $span)),*]
    };
}
20
21pub(crate) fn expand_deriving_coerce_pointee(
22 cx: &ExtCtxt<'_>,
23 span: Span,
24 _mitem: &MetaItem,
25 item: &Annotatable,
26 push: &mut dyn FnMut(Annotatable),
27 _is_const: bool,
28) {
29 item.visit_with(&mut DetectNonGenericPointeeAttr { cx });
30
31 let (name_ident, generics) = if let Annotatable::Item(aitem) = item
32 && let ItemKind::Struct(ident, g, struct_data) = &aitem.kind
33 {
34 if !matches!(
35 struct_data,
36 VariantData::Struct { fields, recovered: _ } | VariantData::Tuple(fields, _)
37 if !fields.is_empty())
38 {
39 cx.dcx().emit_err(RequireOneField { span });
40 return;
41 }
42 (*ident, g)
43 } else {
44 cx.dcx().emit_err(RequireTransparent { span });
45 return;
46 };
47
48 let self_params: Vec<_> = generics
50 .params
51 .iter()
52 .map(|p| match p.kind {
53 GenericParamKind::Lifetime => GenericArg::Lifetime(cx.lifetime(p.span(), p.ident)),
54 GenericParamKind::Type { .. } => GenericArg::Type(cx.ty_ident(p.span(), p.ident)),
55 GenericParamKind::Const { .. } => GenericArg::Const(cx.const_ident(p.span(), p.ident)),
56 })
57 .collect();
58 let type_params: Vec<_> = generics
59 .params
60 .iter()
61 .enumerate()
62 .filter_map(|(idx, p)| {
63 if let GenericParamKind::Type { .. } = p.kind {
64 Some((idx, p.span(), p.attrs().iter().any(|attr| attr.has_name(sym::pointee))))
65 } else {
66 None
67 }
68 })
69 .collect();
70
71 let pointee_param_idx = if type_params.is_empty() {
72 cx.dcx().emit_err(RequireOneGeneric { span });
74 return;
75 } else if type_params.len() == 1 {
76 type_params[0].0
78 } else {
79 let mut pointees = type_params
80 .iter()
81 .filter_map(|&(idx, span, is_pointee)| is_pointee.then_some((idx, span)));
82 match (pointees.next(), pointees.next()) {
83 (Some((idx, _span)), None) => idx,
84 (None, _) => {
85 cx.dcx().emit_err(RequireOnePointee { span });
86 return;
87 }
88 (Some((_, one)), Some((_, another))) => {
89 cx.dcx().emit_err(TooManyPointees { one, another });
90 return;
91 }
92 }
93 };
94
95 let path = cx.path_all(span, false, vec![name_ident], self_params.clone());
97 let self_type = cx.ty_path(path);
98
99 let attrs = thin_vec![cx.attr_word(sym::automatically_derived, span),];
102 {
104 let trait_path =
105 cx.path_all(span, true, path!(span, core::marker::CoercePointeeValidated), vec![]);
106 let trait_ref = cx.trait_ref(trait_path);
107 push(Annotatable::Item(
108 cx.item(
109 span,
110 attrs.clone(),
111 ast::ItemKind::Impl(ast::Impl {
112 generics: Generics {
113 params: generics
114 .params
115 .iter()
116 .map(|p| match &p.kind {
117 GenericParamKind::Lifetime => {
118 cx.lifetime_param(p.span(), p.ident, p.bounds.clone())
119 }
120 GenericParamKind::Type { default: _ } => {
121 cx.typaram(p.span(), p.ident, p.bounds.clone(), None)
122 }
123 GenericParamKind::Const { ty, span: _, default: _ } => cx
124 .const_param(
125 p.span(),
126 p.ident,
127 p.bounds.clone(),
128 ty.clone(),
129 None,
130 ),
131 })
132 .collect(),
133 where_clause: generics.where_clause.clone(),
134 span: generics.span,
135 },
136 of_trait: Some(Box::new(ast::TraitImplHeader {
137 safety: ast::Safety::Default,
138 polarity: ast::ImplPolarity::Positive,
139 defaultness: ast::Defaultness::Final,
140 constness: ast::Const::No,
141 trait_ref,
142 })),
143 self_ty: self_type.clone(),
144 items: ThinVec::new(),
145 }),
146 ),
147 ));
148 }
149 let mut add_impl_block = |generics, trait_symbol, trait_args| {
150 let mut parts = path!(span, core::ops);
151 parts.push(Ident::new(trait_symbol, span));
152 let trait_path = cx.path_all(span, true, parts, trait_args);
153 let trait_ref = cx.trait_ref(trait_path);
154 let item = cx.item(
155 span,
156 attrs.clone(),
157 ast::ItemKind::Impl(ast::Impl {
158 generics,
159 of_trait: Some(Box::new(ast::TraitImplHeader {
160 safety: ast::Safety::Default,
161 polarity: ast::ImplPolarity::Positive,
162 defaultness: ast::Defaultness::Final,
163 constness: ast::Const::No,
164 trait_ref,
165 })),
166 self_ty: self_type.clone(),
167 items: ThinVec::new(),
168 }),
169 );
170 push(Annotatable::Item(item));
171 };
172
173 let s_ty = cx.ty_ident(span, Ident::new(sym::__S, span));
176 let mut alt_self_params = self_params;
177 alt_self_params[pointee_param_idx] = GenericArg::Type(s_ty.clone());
178 let alt_self_type = cx.ty_path(cx.path_all(span, false, vec![name_ident], alt_self_params));
179
180 let mut impl_generics = generics.clone();
184 let pointee_ty_ident = generics.params[pointee_param_idx].ident;
185 let mut self_bounds;
186 {
187 let pointee = &mut impl_generics.params[pointee_param_idx];
188 self_bounds = pointee.bounds.clone();
189 if !contains_maybe_sized_bound(&self_bounds)
190 && !contains_maybe_sized_bound_on_pointee(
191 &generics.where_clause.predicates,
192 pointee_ty_ident.name,
193 )
194 {
195 cx.dcx().emit_err(RequiresMaybeSized {
196 span: pointee_ty_ident.span,
197 name: pointee_ty_ident,
198 });
199 return;
200 }
201 let arg = GenericArg::Type(s_ty.clone());
202 let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
203 pointee.bounds.push(cx.trait_bound(unsize, false));
204 pointee.attrs.retain(|attr| !attr.has_name(sym::pointee));
206 }
207
208 for (idx, (params, orig_params)) in
228 impl_generics.params.iter_mut().zip(&generics.params).enumerate()
229 {
230 match &mut params.kind {
233 ast::GenericParamKind::Const { default, .. } => *default = None,
234 ast::GenericParamKind::Type { default } => *default = None,
235 ast::GenericParamKind::Lifetime => {}
236 }
237 if idx != pointee_param_idx {
242 for bound in &orig_params.bounds {
243 let mut bound = bound.clone();
244 let mut substitution = TypeSubstitution {
245 from_name: pointee_ty_ident.name,
246 to_ty: &s_ty,
247 rewritten: false,
248 };
249 substitution.visit_param_bound(&mut bound, BoundKind::Bound);
250 if substitution.rewritten {
251 params.bounds.push(bound);
254 }
255 }
256 }
257 }
258
259 {
264 let mut substitution =
265 TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
266 for bound in &mut self_bounds {
267 substitution.visit_param_bound(bound, BoundKind::Bound);
268 }
269 }
270
271 for predicate in &generics.where_clause.predicates {
298 if let ast::WherePredicateKind::BoundPredicate(bound) = &predicate.kind {
299 let mut substitution = TypeSubstitution {
300 from_name: pointee_ty_ident.name,
301 to_ty: &s_ty,
302 rewritten: false,
303 };
304 let mut kind = ast::WherePredicateKind::BoundPredicate(bound.clone());
305 substitution.visit_where_predicate_kind(&mut kind);
306 if substitution.rewritten {
307 let predicate = ast::WherePredicate {
308 attrs: predicate.attrs.clone(),
309 kind,
310 span: predicate.span,
311 id: ast::DUMMY_NODE_ID,
312 is_placeholder: false,
313 };
314 impl_generics.where_clause.predicates.push(predicate);
315 }
316 }
317 }
318
319 let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
320 impl_generics.params.insert(pointee_param_idx + 1, extra_param);
321
322 let gen_args = vec![GenericArg::Type(alt_self_type)];
324 add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
325 add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args);
326}
327
328fn contains_maybe_sized_bound_on_pointee(predicates: &[WherePredicate], pointee: Symbol) -> bool {
329 for bound in predicates {
330 if let ast::WherePredicateKind::BoundPredicate(bound) = &bound.kind
331 && bound.bounded_ty.kind.is_simple_path().is_some_and(|name| name == pointee)
332 {
333 for bound in &bound.bounds {
334 if is_maybe_sized_bound(bound) {
335 return true;
336 }
337 }
338 }
339 }
340 false
341}
342
343fn is_maybe_sized_bound(bound: &GenericBound) -> bool {
344 if let GenericBound::Trait(trait_ref) = bound
345 && let TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. } =
346 trait_ref.modifiers
347 && is_sized_marker(&trait_ref.trait_ref.path)
348 {
349 true
350 } else {
351 false
352 }
353}
354
355fn contains_maybe_sized_bound(bounds: &[GenericBound]) -> bool {
356 bounds.iter().any(is_maybe_sized_bound)
357}
358
359fn path_segment_is_exact_match(path_segments: &[ast::PathSegment], syms: &[Symbol]) -> bool {
360 path_segments.iter().zip(syms).all(|(segment, &symbol)| segment.ident.name == symbol)
361}
362
363fn is_sized_marker(path: &ast::Path) -> bool {
364 const CORE_UNSIZE: [Symbol; 3] = [sym::core, sym::marker, sym::Sized];
365 const STD_UNSIZE: [Symbol; 3] = [sym::std, sym::marker, sym::Sized];
366 if path.segments.len() == 4 && path.is_global() {
367 path_segment_is_exact_match(&path.segments[1..], &CORE_UNSIZE)
368 || path_segment_is_exact_match(&path.segments[1..], &STD_UNSIZE)
369 } else if path.segments.len() == 3 {
370 path_segment_is_exact_match(&path.segments, &CORE_UNSIZE)
371 || path_segment_is_exact_match(&path.segments, &STD_UNSIZE)
372 } else {
373 *path == sym::Sized
374 }
375}
376
377struct TypeSubstitution<'a> {
378 from_name: Symbol,
379 to_ty: &'a ast::Ty,
380 rewritten: bool,
381}
382
383impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
384 fn visit_ty(&mut self, ty: &mut ast::Ty) {
385 if let Some(name) = ty.kind.is_simple_path()
386 && name == self.from_name
387 {
388 *ty = self.to_ty.clone();
389 self.rewritten = true;
390 } else {
391 ast::mut_visit::walk_ty(self, ty);
392 }
393 }
394
395 fn visit_where_predicate_kind(&mut self, kind: &mut ast::WherePredicateKind) {
396 match kind {
397 rustc_ast::WherePredicateKind::BoundPredicate(bound) => {
398 bound
399 .bound_generic_params
400 .flat_map_in_place(|param| self.flat_map_generic_param(param));
401 self.visit_ty(&mut bound.bounded_ty);
402 for bound in &mut bound.bounds {
403 self.visit_param_bound(bound, BoundKind::Bound)
404 }
405 }
406 rustc_ast::WherePredicateKind::RegionPredicate(_)
407 | rustc_ast::WherePredicateKind::EqPredicate(_) => {}
408 }
409 }
410}
411
412struct DetectNonGenericPointeeAttr<'a, 'b> {
413 cx: &'a ExtCtxt<'b>,
414}
415
416impl<'a, 'b> rustc_ast::visit::Visitor<'a> for DetectNonGenericPointeeAttr<'a, 'b> {
417 fn visit_attribute(&mut self, attr: &'a rustc_ast::Attribute) -> Self::Result {
418 if attr.has_name(sym::pointee) {
419 self.cx.dcx().emit_err(errors::NonGenericPointee { span: attr.span });
420 }
421 }
422
423 fn visit_generic_param(&mut self, param: &'a rustc_ast::GenericParam) -> Self::Result {
424 let mut error_on_pointee = AlwaysErrorOnGenericParam { cx: self.cx };
425
426 match ¶m.kind {
427 GenericParamKind::Type { default } => {
428 rustc_ast::visit::visit_opt!(error_on_pointee, visit_ty, default);
446 }
447
448 GenericParamKind::Const { .. } | GenericParamKind::Lifetime => {
449 rustc_ast::visit::walk_generic_param(&mut error_on_pointee, param);
450 }
451 }
452 }
453
454 fn visit_ty(&mut self, t: &'a rustc_ast::Ty) -> Self::Result {
455 let mut error_on_pointee = AlwaysErrorOnGenericParam { cx: self.cx };
456 error_on_pointee.visit_ty(t)
457 }
458}
459
460struct AlwaysErrorOnGenericParam<'a, 'b> {
461 cx: &'a ExtCtxt<'b>,
462}
463
464impl<'a, 'b> rustc_ast::visit::Visitor<'a> for AlwaysErrorOnGenericParam<'a, 'b> {
465 fn visit_attribute(&mut self, attr: &'a rustc_ast::Attribute) -> Self::Result {
466 if attr.has_name(sym::pointee) {
467 self.cx.dcx().emit_err(errors::NonGenericPointee { span: attr.span });
468 }
469 }
470}
471
// E0802: emitted when the `CoercePointee` derive target is not a struct item.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_transparent, code = E0802)]
struct RequireTransparent {
    // Span of the derive invocation.
    #[primary_span]
    span: Span,
}
478
// E0802: emitted when the target struct has no named or tuple fields.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_one_field, code = E0802)]
struct RequireOneField {
    // Span of the derive invocation.
    #[primary_span]
    span: Span,
}
485
// E0802: emitted when the target struct has no generic type parameter.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_one_generic, code = E0802)]
struct RequireOneGeneric {
    // Span of the derive invocation.
    #[primary_span]
    span: Span,
}
492
// E0802: emitted when the struct has several type parameters and none is marked `#[pointee]`.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_one_pointee, code = E0802)]
struct RequireOnePointee {
    // Span of the derive invocation.
    #[primary_span]
    span: Span,
}
499
// E0802: emitted when more than one type parameter is marked `#[pointee]`.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_too_many_pointees, code = E0802)]
struct TooManyPointees {
    // Span of the first `#[pointee]` type parameter.
    #[primary_span]
    one: Span,
    // Span of the second `#[pointee]` type parameter, shown as a label.
    #[label]
    another: Span,
}
508
// E0802: emitted when the pointee type parameter has no `?Sized` bound, neither inline nor in
// a `where` predicate.
#[derive(Diagnostic)]
#[diag(builtin_macros_coerce_pointee_requires_maybe_sized, code = E0802)]
struct RequiresMaybeSized {
    // Span of the pointee parameter's identifier.
    #[primary_span]
    span: Span,
    // The pointee parameter; presumably interpolated into the message — confirm against the
    // Fluent entry.
    name: Ident,
}
515}