rustc_mir_transform/
sroa.rs

1use rustc_abi::FieldIdx;
2use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
3use rustc_hir::LangItem;
4use rustc_index::IndexVec;
5use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
6use rustc_middle::bug;
7use rustc_middle::mir::visit::*;
8use rustc_middle::mir::*;
9use rustc_middle::ty::{self, Ty, TyCtxt};
10use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
11use tracing::{debug, instrument};
12
13use crate::patch::MirPatch;
14
15pub(super) struct ScalarReplacementOfAggregates;
16
17impl<'tcx> crate::MirPass<'tcx> for ScalarReplacementOfAggregates {
18    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19        sess.mir_opt_level() >= 2
20    }
21
22    #[instrument(level = "debug", skip(self, tcx, body))]
23    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
24        debug!(def_id = ?body.source.def_id());
25
26        // Avoid query cycles (coroutines require optimized MIR for layout).
27        if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() {
28            return;
29        }
30
31        let mut excluded = excluded_locals(body);
32        let typing_env = body.typing_env(tcx);
33        loop {
34            debug!(?excluded);
35            let escaping = escaping_locals(tcx, &excluded, body);
36            debug!(?escaping);
37            let replacements = compute_flattening(tcx, typing_env, body, escaping);
38            debug!(?replacements);
39            let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
40            if !all_dead_locals.is_empty() {
41                excluded.union(&all_dead_locals);
42                excluded = {
43                    let mut growable = GrowableBitSet::from(excluded);
44                    growable.ensure(body.local_decls.len());
45                    growable.into()
46                };
47            } else {
48                break;
49            }
50        }
51    }
52
53    fn is_required(&self) -> bool {
54        false
55    }
56}
57
58/// Identify all locals that are not eligible for SROA.
59///
60/// There are 3 cases:
61/// - the aggregated local is used or passed to other code (function parameters and arguments);
62/// - the locals is a union or an enum;
63/// - the local's address is taken, and thus the relative addresses of the fields are observable to
64///   client code.
65fn escaping_locals<'tcx>(
66    tcx: TyCtxt<'tcx>,
67    excluded: &DenseBitSet<Local>,
68    body: &Body<'tcx>,
69) -> DenseBitSet<Local> {
70    let is_excluded_ty = |ty: Ty<'tcx>| {
71        if ty.is_union() || ty.is_enum() {
72            return true;
73        }
74        if let ty::Adt(def, _args) = ty.kind()
75            && (def.repr().simd() || tcx.is_lang_item(def.did(), LangItem::DynMetadata))
76        {
77            // Exclude #[repr(simd)] types so that they are not de-optimized into an array
78            // (MCP#838 banned projections into SIMD types, but if the value is unused
79            // this pass sees "all the uses are of the fields" and expands it.)
80
81            // codegen wants to see the `DynMetadata<T>`,
82            // not the inner reference-to-opaque-type.
83            return true;
84        }
85        // Default for non-ADTs
86        false
87    };
88
89    let mut set = DenseBitSet::new_empty(body.local_decls.len());
90    set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
91    for (local, decl) in body.local_decls().iter_enumerated() {
92        if excluded.contains(local) || is_excluded_ty(decl.ty) {
93            set.insert(local);
94        }
95    }
96    let mut visitor = EscapeVisitor { set };
97    visitor.visit_body(body);
98    return visitor.set;
99
100    struct EscapeVisitor {
101        set: DenseBitSet<Local>,
102    }
103
104    impl<'tcx> Visitor<'tcx> for EscapeVisitor {
105        fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
106            self.set.insert(local);
107        }
108
109        fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
110            // Mirror the implementation in PreFlattenVisitor.
111            if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
112                return;
113            }
114            self.super_place(place, context, location);
115        }
116
117        fn visit_assign(
118            &mut self,
119            lvalue: &Place<'tcx>,
120            rvalue: &Rvalue<'tcx>,
121            location: Location,
122        ) {
123            if lvalue.as_local().is_some() {
124                match rvalue {
125                    // Aggregate assignments are expanded in run_pass.
126                    Rvalue::Aggregate(..) | Rvalue::Use(..) => {
127                        self.visit_rvalue(rvalue, location);
128                        return;
129                    }
130                    _ => {}
131                }
132            }
133            self.super_assign(lvalue, rvalue, location)
134        }
135
136        fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
137            match statement.kind {
138                // Storage statements are expanded in run_pass.
139                StatementKind::StorageLive(..)
140                | StatementKind::StorageDead(..)
141                | StatementKind::Deinit(..) => return,
142                _ => self.super_statement(statement, location),
143            }
144        }
145
146        // We ignore anything that happens in debuginfo, since we expand it using
147        // `VarDebugInfoFragment`.
148        fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
149    }
150}
151
152#[derive(Default, Debug)]
153struct ReplacementMap<'tcx> {
154    /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
155    /// and deinit statement and debuginfo.
156    fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<(Ty<'tcx>, Local)>>>>,
157}
158
159impl<'tcx> ReplacementMap<'tcx> {
160    fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
161        let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else {
162            return None;
163        };
164        let fields = self.fragments[place.local].as_ref()?;
165        let (_, new_local) = fields[f]?;
166        Some(Place { local: new_local, projection: tcx.mk_place_elems(rest) })
167    }
168
169    fn place_fragments(
170        &self,
171        place: Place<'tcx>,
172    ) -> Option<impl Iterator<Item = (FieldIdx, Ty<'tcx>, Local)>> {
173        let local = place.as_local()?;
174        let fields = self.fragments[local].as_ref()?;
175        Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
176            let (ty, local) = opt_ty_local?;
177            Some((field, ty, local))
178        }))
179    }
180}
181
182/// Compute the replacement of flattened places into locals.
183///
184/// For each eligible place, we assign a new local to each accessed field.
185/// The replacement will be done later in `ReplacementVisitor`.
186fn compute_flattening<'tcx>(
187    tcx: TyCtxt<'tcx>,
188    typing_env: ty::TypingEnv<'tcx>,
189    body: &mut Body<'tcx>,
190    escaping: DenseBitSet<Local>,
191) -> ReplacementMap<'tcx> {
192    let mut fragments = IndexVec::from_elem(None, &body.local_decls);
193
194    for local in body.local_decls.indices() {
195        if escaping.contains(local) {
196            continue;
197        }
198        let decl = body.local_decls[local].clone();
199        let ty = decl.ty;
200        iter_fields(ty, tcx, typing_env, |variant, field, field_ty| {
201            if variant.is_some() {
202                // Downcasts are currently not supported.
203                return;
204            };
205            let new_local =
206                body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
207            fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
208        });
209    }
210    ReplacementMap { fragments }
211}
212
213/// Perform the replacement computed by `compute_flattening`.
214fn replace_flattened_locals<'tcx>(
215    tcx: TyCtxt<'tcx>,
216    body: &mut Body<'tcx>,
217    replacements: ReplacementMap<'tcx>,
218) -> DenseBitSet<Local> {
219    let mut all_dead_locals = DenseBitSet::new_empty(replacements.fragments.len());
220    for (local, replacements) in replacements.fragments.iter_enumerated() {
221        if replacements.is_some() {
222            all_dead_locals.insert(local);
223        }
224    }
225    debug!(?all_dead_locals);
226    if all_dead_locals.is_empty() {
227        return all_dead_locals;
228    }
229
230    let mut visitor = ReplacementVisitor {
231        tcx,
232        local_decls: &body.local_decls,
233        replacements: &replacements,
234        all_dead_locals,
235        patch: MirPatch::new(body),
236    };
237    for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
238        visitor.visit_basic_block_data(bb, data);
239    }
240    for scope in &mut body.source_scopes {
241        visitor.visit_source_scope_data(scope);
242    }
243    for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
244        visitor.visit_user_type_annotation(index, annotation);
245    }
246    visitor.expand_var_debug_info(&mut body.var_debug_info);
247    let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
248    patch.apply(body);
249    all_dead_locals
250}
251
252struct ReplacementVisitor<'tcx, 'll> {
253    tcx: TyCtxt<'tcx>,
254    /// This is only used to compute the type for `VarDebugInfoFragment`.
255    local_decls: &'ll LocalDecls<'tcx>,
256    /// Work to do.
257    replacements: &'ll ReplacementMap<'tcx>,
258    /// This is used to check that we are not leaving references to replaced locals behind.
259    all_dead_locals: DenseBitSet<Local>,
260    patch: MirPatch<'tcx>,
261}
262
263impl<'tcx> ReplacementVisitor<'tcx, '_> {
264    #[instrument(level = "trace", skip(self))]
265    fn expand_var_debug_info(&mut self, var_debug_info: &mut Vec<VarDebugInfo<'tcx>>) {
266        var_debug_info.flat_map_in_place(|mut var_debug_info| {
267            let place = match var_debug_info.value {
268                VarDebugInfoContents::Const(_) => return vec![var_debug_info],
269                VarDebugInfoContents::Place(ref mut place) => place,
270            };
271
272            if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
273                *place = repl;
274                return vec![var_debug_info];
275            }
276
277            let Some(parts) = self.replacements.place_fragments(*place) else {
278                return vec![var_debug_info];
279            };
280
281            let ty = place.ty(self.local_decls, self.tcx).ty;
282
283            parts
284                .map(|(field, field_ty, replacement_local)| {
285                    let mut var_debug_info = var_debug_info.clone();
286                    let composite = var_debug_info.composite.get_or_insert_with(|| {
287                        Box::new(VarDebugInfoFragment { ty, projection: Vec::new() })
288                    });
289                    composite.projection.push(PlaceElem::Field(field, field_ty));
290
291                    var_debug_info.value = VarDebugInfoContents::Place(replacement_local.into());
292                    var_debug_info
293                })
294                .collect()
295        });
296    }
297}
298
299impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
300    fn tcx(&self) -> TyCtxt<'tcx> {
301        self.tcx
302    }
303
304    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
305        if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
306            *place = repl
307        } else {
308            self.super_place(place, context, location)
309        }
310    }
311
312    #[instrument(level = "trace", skip(self))]
313    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
314        match statement.kind {
315            // Duplicate storage and deinit statements, as they pretty much apply to all fields.
316            StatementKind::StorageLive(l) => {
317                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
318                    for (_, _, fl) in final_locals {
319                        self.patch.add_statement(location, StatementKind::StorageLive(fl));
320                    }
321                    statement.make_nop();
322                }
323                return;
324            }
325            StatementKind::StorageDead(l) => {
326                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
327                    for (_, _, fl) in final_locals {
328                        self.patch.add_statement(location, StatementKind::StorageDead(fl));
329                    }
330                    statement.make_nop();
331                }
332                return;
333            }
334            StatementKind::Deinit(box place) => {
335                if let Some(final_locals) = self.replacements.place_fragments(place) {
336                    for (_, _, fl) in final_locals {
337                        self.patch
338                            .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
339                    }
340                    statement.make_nop();
341                    return;
342                }
343            }
344
345            // We have `a = Struct { 0: x, 1: y, .. }`.
346            // We replace it by
347            // ```
348            // a_0 = x
349            // a_1 = y
350            // ...
351            // ```
352            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
353                if let Some(local) = place.as_local()
354                    && let Some(final_locals) = &self.replacements.fragments[local]
355                {
356                    // This is ok as we delete the statement later.
357                    let operands = std::mem::take(operands);
358                    for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
359                        if let Some((_, new_local)) = opt_ty_local {
360                            // Replace mentions of SROA'd locals that appear in the operand.
361                            self.visit_operand(&mut operand, location);
362
363                            let rvalue = Rvalue::Use(operand);
364                            self.patch.add_statement(
365                                location,
366                                StatementKind::Assign(Box::new((new_local.into(), rvalue))),
367                            );
368                        }
369                    }
370                    statement.make_nop();
371                    return;
372                }
373            }
374
375            // We have `a = some constant`
376            // We add the projections.
377            // ```
378            // a_0 = a.0
379            // a_1 = a.1
380            // ...
381            // ```
382            // ConstProp will pick up the pieces and replace them by actual constants.
383            StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
384                if let Some(final_locals) = self.replacements.place_fragments(place) {
385                    // Put the deaggregated statements *after* the original one.
386                    let location = location.successor_within_block();
387                    for (field, ty, new_local) in final_locals {
388                        let rplace = self.tcx.mk_place_field(place, field, ty);
389                        let rvalue = Rvalue::Use(Operand::Move(rplace));
390                        self.patch.add_statement(
391                            location,
392                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
393                        );
394                    }
395                    // We still need `place.local` to exist, so don't make it nop.
396                    return;
397                }
398            }
399
400            // We have `a = move? place`
401            // We replace it by
402            // ```
403            // a_0 = move? place.0
404            // a_1 = move? place.1
405            // ...
406            // ```
407            StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
408                let (rplace, copy) = match *op {
409                    Operand::Copy(rplace) => (rplace, true),
410                    Operand::Move(rplace) => (rplace, false),
411                    Operand::Constant(_) => bug!(),
412                };
413                if let Some(final_locals) = self.replacements.place_fragments(lhs) {
414                    for (field, ty, new_local) in final_locals {
415                        let rplace = self.tcx.mk_place_field(rplace, field, ty);
416                        debug!(?rplace);
417                        let rplace = self
418                            .replacements
419                            .replace_place(self.tcx, rplace.as_ref())
420                            .unwrap_or(rplace);
421                        debug!(?rplace);
422                        let rvalue = if copy {
423                            Rvalue::Use(Operand::Copy(rplace))
424                        } else {
425                            Rvalue::Use(Operand::Move(rplace))
426                        };
427                        self.patch.add_statement(
428                            location,
429                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
430                        );
431                    }
432                    statement.make_nop();
433                    return;
434                }
435            }
436
437            _ => {}
438        }
439        self.super_statement(statement, location)
440    }
441
442    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
443        assert!(!self.all_dead_locals.contains(*local));
444    }
445}