rustc_builtin_macros/deriving/cmp/partial_eq.rs
1use rustc_ast::{BinOpKind, BorrowKind, Expr, ExprKind, MetaItem, Mutability};
2use rustc_expand::base::{Annotatable, ExtCtxt};
3use rustc_span::{Span, sym};
4use thin_vec::thin_vec;
5
6use crate::deriving::generic::ty::*;
7use crate::deriving::generic::*;
8use crate::deriving::{path_local, path_std};
9
10/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
11/// target item.
12pub(crate) fn expand_deriving_partial_eq(
13 cx: &ExtCtxt<'_>,
14 span: Span,
15 mitem: &MetaItem,
16 item: &Annotatable,
17 push: &mut dyn FnMut(Annotatable),
18 is_const: bool,
19) {
20 let structural_trait_def = TraitDef {
21 span,
22 path: path_std!(marker::StructuralPartialEq),
23 skip_path_as_bound: true, // crucial!
24 needs_copy_as_bound_if_packed: false,
25 additional_bounds: Vec::new(),
26 // We really don't support unions, but that's already checked by the impl generated below;
27 // a second check here would lead to redundant error messages.
28 supports_unions: true,
29 methods: Vec::new(),
30 associated_types: Vec::new(),
31 is_const: false,
32 is_staged_api_crate: cx.ecfg.features.staged_api(),
33 };
34 structural_trait_def.expand(cx, mitem, item, push);
35
36 // No need to generate `ne`, the default suffices, and not generating it is
37 // faster.
38 let methods = vec![MethodDef {
39 name: sym::eq,
40 generics: Bounds::empty(),
41 explicit_self: true,
42 nonself_args: vec![(self_ref(), sym::other)],
43 ret_ty: Path(path_local!(bool)),
44 attributes: thin_vec![cx.attr_word(sym::inline, span)],
45 fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
46 combine_substructure: combine_substructure(Box::new(|a, b, c| {
47 BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
48 })),
49 }];
50
51 let trait_def = TraitDef {
52 span,
53 path: path_std!(cmp::PartialEq),
54 skip_path_as_bound: false,
55 needs_copy_as_bound_if_packed: true,
56 additional_bounds: Vec::new(),
57 supports_unions: false,
58 methods,
59 associated_types: Vec::new(),
60 is_const,
61 is_staged_api_crate: cx.ecfg.features.staged_api(),
62 };
63 trait_def.expand(cx, mitem, item, push)
64}
65
66/// Generates the equality expression for a struct or enum variant when deriving
67/// `PartialEq`.
68///
69/// This function generates an expression that checks if all fields of a struct
70/// or enum variant are equal.
71/// - Scalar fields are compared first for efficiency, followed by compound
72/// fields.
73/// - If there are no fields, returns `true` (fieldless types are always equal).
74///
75/// Whether a field is considered "scalar" is determined by comparing the symbol
76/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
77/// This check is based on the type's symbol.
78///
79/// ### Example 1
80/// ```
81/// #[derive(PartialEq)]
82/// struct i32;
83///
84/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
85/// // the primitive), it will not be treated as scalar. The function will still
86/// // check equality of `field_2` first because the symbol matches `i32`.
87/// #[derive(PartialEq)]
88/// struct Struct {
89/// field_1: &'static str,
90/// field_2: i32,
91/// }
92/// ```
93///
94/// ### Example 2
95/// ```
96/// mod ty {
97/// pub type i32 = i32;
98/// }
99///
100/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
101/// // However, the function will not reorder the fields because the symbol for
102/// // `ty::i32` does not match the symbol for the primitive `i32`
103/// // ("ty::i32" != "i32").
104/// #[derive(PartialEq)]
105/// struct Struct {
106/// field_1: &'static str,
107/// field_2: ty::i32,
108/// }
109/// ```
110///
111/// For enums, the discriminant is compared first, then the rest of the fields.
112///
113/// # Panics
114///
115/// If called on static or all-fieldless enums/structs, which should not occur
116/// during derive expansion.
117fn get_substructure_equality_expr(
118 cx: &ExtCtxt<'_>,
119 span: Span,
120 substructure: &Substructure<'_>,
121) -> Box<Expr> {
122 use SubstructureFields::*;
123
124 match substructure.fields {
125 EnumMatching(.., fields) | Struct(.., fields) => {
126 let combine = move |acc, field| {
127 let rhs = get_field_equality_expr(cx, field);
128 if let Some(lhs) = acc {
129 // Combine the previous comparison with the current field
130 // using logical AND.
131 return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
132 }
133 // Start the chain with the first field's comparison.
134 Some(rhs)
135 };
136
137 // First compare scalar fields, then compound fields, combining all
138 // with logical AND.
139 return fields
140 .iter()
141 .filter(|field| !field.maybe_scalar)
142 .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine)
143 // If there are no fields, treat as always equal.
144 .unwrap_or_else(|| cx.expr_bool(span, true));
145 }
146 EnumDiscr(disc, match_expr) => {
147 let lhs = get_field_equality_expr(cx, disc);
148 let Some(match_expr) = match_expr else {
149 return lhs;
150 };
151 // Compare the discriminant first (cheaper), then the rest of the
152 // fields.
153 return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
154 }
155 StaticEnum(..) => cx.dcx().span_bug(
156 span,
157 "unexpected static enum encountered during `derive(PartialEq)` expansion",
158 ),
159 StaticStruct(..) => cx.dcx().span_bug(
160 span,
161 "unexpected static struct encountered during `derive(PartialEq)` expansion",
162 ),
163 AllFieldlessEnum(..) => cx.dcx().span_bug(
164 span,
165 "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
166 ),
167 }
168}
169
170/// Generates an equality comparison expression for a single struct or enum
171/// field.
172///
173/// This function produces an AST expression that compares the `self` and
174/// `other` values for a field using `==`. It removes any leading references
175/// from both sides for readability. If the field is a block expression, it is
176/// wrapped in parentheses to ensure valid syntax.
177///
178/// # Panics
179///
180/// Panics if there are not exactly two arguments to compare (should be `self`
181/// and `other`).
182fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> Box<Expr> {
183 let [rhs] = &field.other_selflike_exprs[..] else {
184 cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
185 };
186
187 cx.expr_binary(
188 field.span,
189 BinOpKind::Eq,
190 wrap_block_expr(cx, peel_refs(&field.self_expr)),
191 wrap_block_expr(cx, peel_refs(rhs)),
192 )
193}
194
195/// Removes all leading immutable references from an expression.
196///
197/// This is used to strip away any number of leading `&` from an expression
198/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
199/// references are preserved.
200fn peel_refs(mut expr: &Box<Expr>) -> Box<Expr> {
201 while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
202 expr = &inner;
203 }
204 expr.clone()
205}
206
207/// Wraps a block expression in parentheses to ensure valid AST in macro
208/// expansion output.
209///
210/// If the given expression is a block, it is wrapped in parentheses; otherwise,
211/// it is returned unchanged.
212fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: Box<Expr>) -> Box<Expr> {
213 if matches!(&expr.kind, ExprKind::Block(..)) {
214 return cx.expr_paren(expr.span, expr);
215 }
216 expr
217}