1use rustc_abi::FieldIdx;
2use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
3use rustc_hir::LangItem;
4use rustc_index::IndexVec;
5use rustc_index::bit_set::{DenseBitSet, GrowableBitSet};
6use rustc_middle::bug;
7use rustc_middle::mir::visit::*;
8use rustc_middle::mir::*;
9use rustc_middle::ty::{self, Ty, TyCtxt};
10use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
11use tracing::{debug, instrument};
12
13use crate::patch::MirPatch;
14
15pub(super) struct ScalarReplacementOfAggregates;
16
17impl<'tcx> crate::MirPass<'tcx> for ScalarReplacementOfAggregates {
18 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19 sess.mir_opt_level() >= 2
20 }
21
22 #[instrument(level = "debug", skip(self, tcx, body))]
23 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
24 debug!(def_id = ?body.source.def_id());
25
26 if tcx.type_of(body.source.def_id()).instantiate_identity().is_coroutine() {
28 return;
29 }
30
31 let mut excluded = excluded_locals(body);
32 let typing_env = body.typing_env(tcx);
33 loop {
34 debug!(?excluded);
35 let escaping = escaping_locals(tcx, &excluded, body);
36 debug!(?escaping);
37 let replacements = compute_flattening(tcx, typing_env, body, escaping);
38 debug!(?replacements);
39 let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
40 if !all_dead_locals.is_empty() {
41 excluded.union(&all_dead_locals);
42 excluded = {
43 let mut growable = GrowableBitSet::from(excluded);
44 growable.ensure(body.local_decls.len());
45 growable.into()
46 };
47 } else {
48 break;
49 }
50 }
51 }
52
53 fn is_required(&self) -> bool {
54 false
55 }
56}
57
58fn escaping_locals<'tcx>(
66 tcx: TyCtxt<'tcx>,
67 excluded: &DenseBitSet<Local>,
68 body: &Body<'tcx>,
69) -> DenseBitSet<Local> {
70 let is_excluded_ty = |ty: Ty<'tcx>| {
71 if ty.is_union() || ty.is_enum() {
72 return true;
73 }
74 if let ty::Adt(def, _args) = ty.kind()
75 && (def.repr().simd() || tcx.is_lang_item(def.did(), LangItem::DynMetadata))
76 {
77 return true;
84 }
85 false
87 };
88
89 let mut set = DenseBitSet::new_empty(body.local_decls.len());
90 set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
91 for (local, decl) in body.local_decls().iter_enumerated() {
92 if excluded.contains(local) || is_excluded_ty(decl.ty) {
93 set.insert(local);
94 }
95 }
96 let mut visitor = EscapeVisitor { set };
97 visitor.visit_body(body);
98 return visitor.set;
99
100 struct EscapeVisitor {
101 set: DenseBitSet<Local>,
102 }
103
104 impl<'tcx> Visitor<'tcx> for EscapeVisitor {
105 fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
106 self.set.insert(local);
107 }
108
109 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
110 if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
112 return;
113 }
114 self.super_place(place, context, location);
115 }
116
117 fn visit_assign(
118 &mut self,
119 lvalue: &Place<'tcx>,
120 rvalue: &Rvalue<'tcx>,
121 location: Location,
122 ) {
123 if lvalue.as_local().is_some() {
124 match rvalue {
125 Rvalue::Aggregate(..) | Rvalue::Use(..) => {
127 self.visit_rvalue(rvalue, location);
128 return;
129 }
130 _ => {}
131 }
132 }
133 self.super_assign(lvalue, rvalue, location)
134 }
135
136 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
137 match statement.kind {
138 StatementKind::StorageLive(..)
140 | StatementKind::StorageDead(..)
141 | StatementKind::Deinit(..) => return,
142 _ => self.super_statement(statement, location),
143 }
144 }
145
146 fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
149 }
150}
151
152#[derive(Default, Debug)]
153struct ReplacementMap<'tcx> {
154 fragments: IndexVec<Local, Option<IndexVec<FieldIdx, Option<(Ty<'tcx>, Local)>>>>,
157}
158
159impl<'tcx> ReplacementMap<'tcx> {
160 fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
161 let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else {
162 return None;
163 };
164 let fields = self.fragments[place.local].as_ref()?;
165 let (_, new_local) = fields[f]?;
166 Some(Place { local: new_local, projection: tcx.mk_place_elems(rest) })
167 }
168
169 fn place_fragments(
170 &self,
171 place: Place<'tcx>,
172 ) -> Option<impl Iterator<Item = (FieldIdx, Ty<'tcx>, Local)>> {
173 let local = place.as_local()?;
174 let fields = self.fragments[local].as_ref()?;
175 Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
176 let (ty, local) = opt_ty_local?;
177 Some((field, ty, local))
178 }))
179 }
180}
181
182fn compute_flattening<'tcx>(
187 tcx: TyCtxt<'tcx>,
188 typing_env: ty::TypingEnv<'tcx>,
189 body: &mut Body<'tcx>,
190 escaping: DenseBitSet<Local>,
191) -> ReplacementMap<'tcx> {
192 let mut fragments = IndexVec::from_elem(None, &body.local_decls);
193
194 for local in body.local_decls.indices() {
195 if escaping.contains(local) {
196 continue;
197 }
198 let decl = body.local_decls[local].clone();
199 let ty = decl.ty;
200 iter_fields(ty, tcx, typing_env, |variant, field, field_ty| {
201 if variant.is_some() {
202 return;
204 };
205 let new_local =
206 body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
207 fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
208 });
209 }
210 ReplacementMap { fragments }
211}
212
213fn replace_flattened_locals<'tcx>(
215 tcx: TyCtxt<'tcx>,
216 body: &mut Body<'tcx>,
217 replacements: ReplacementMap<'tcx>,
218) -> DenseBitSet<Local> {
219 let mut all_dead_locals = DenseBitSet::new_empty(replacements.fragments.len());
220 for (local, replacements) in replacements.fragments.iter_enumerated() {
221 if replacements.is_some() {
222 all_dead_locals.insert(local);
223 }
224 }
225 debug!(?all_dead_locals);
226 if all_dead_locals.is_empty() {
227 return all_dead_locals;
228 }
229
230 let mut visitor = ReplacementVisitor {
231 tcx,
232 local_decls: &body.local_decls,
233 replacements: &replacements,
234 all_dead_locals,
235 patch: MirPatch::new(body),
236 };
237 for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
238 visitor.visit_basic_block_data(bb, data);
239 }
240 for scope in &mut body.source_scopes {
241 visitor.visit_source_scope_data(scope);
242 }
243 for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
244 visitor.visit_user_type_annotation(index, annotation);
245 }
246 visitor.expand_var_debug_info(&mut body.var_debug_info);
247 let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
248 patch.apply(body);
249 all_dead_locals
250}
251
252struct ReplacementVisitor<'tcx, 'll> {
253 tcx: TyCtxt<'tcx>,
254 local_decls: &'ll LocalDecls<'tcx>,
256 replacements: &'ll ReplacementMap<'tcx>,
258 all_dead_locals: DenseBitSet<Local>,
260 patch: MirPatch<'tcx>,
261}
262
263impl<'tcx> ReplacementVisitor<'tcx, '_> {
264 #[instrument(level = "trace", skip(self))]
265 fn expand_var_debug_info(&mut self, var_debug_info: &mut Vec<VarDebugInfo<'tcx>>) {
266 var_debug_info.flat_map_in_place(|mut var_debug_info| {
267 let place = match var_debug_info.value {
268 VarDebugInfoContents::Const(_) => return vec![var_debug_info],
269 VarDebugInfoContents::Place(ref mut place) => place,
270 };
271
272 if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
273 *place = repl;
274 return vec![var_debug_info];
275 }
276
277 let Some(parts) = self.replacements.place_fragments(*place) else {
278 return vec![var_debug_info];
279 };
280
281 let ty = place.ty(self.local_decls, self.tcx).ty;
282
283 parts
284 .map(|(field, field_ty, replacement_local)| {
285 let mut var_debug_info = var_debug_info.clone();
286 let composite = var_debug_info.composite.get_or_insert_with(|| {
287 Box::new(VarDebugInfoFragment { ty, projection: Vec::new() })
288 });
289 composite.projection.push(PlaceElem::Field(field, field_ty));
290
291 var_debug_info.value = VarDebugInfoContents::Place(replacement_local.into());
292 var_debug_info
293 })
294 .collect()
295 });
296 }
297}
298
impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    /// Redirects `local.f.rest..` to `new_local.rest..` when field `f` of `local` has a
    /// replacement local.
    ///
    /// Any other place is walked normally, which eventually reaches `visit_local`.
    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
        if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
            *place = repl
        } else {
            self.super_place(place, context, location)
        }
    }

    /// Expands statements that use a flattened local as a whole into per-field statements.
    ///
    /// New statements are queued in `self.patch`, and a statement that has been fully replaced
    /// is turned into a nop. Arms that do not `return` fall through to `super_statement`, where
    /// `visit_place` redirects field accesses.
    #[instrument(level = "trace", skip(self))]
    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
        match statement.kind {
            // `StorageLive(a)` becomes one `StorageLive` per fragment local of `a`.
            // Storage statements of other locals are left as-is and not visited further.
            StatementKind::StorageLive(l) => {
                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
                    for (_, _, fl) in final_locals {
                        self.patch.add_statement(location, StatementKind::StorageLive(fl));
                    }
                    statement.make_nop();
                }
                return;
            }
            // Same expansion for `StorageDead`.
            StatementKind::StorageDead(l) => {
                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
                    for (_, _, fl) in final_locals {
                        self.patch.add_statement(location, StatementKind::StorageDead(fl));
                    }
                    statement.make_nop();
                }
                return;
            }
            // `Deinit(a)` of a flattened local becomes one `Deinit` per fragment local.
            // Any other `Deinit` falls through to `super_statement`.
            StatementKind::Deinit(box place) => {
                if let Some(final_locals) = self.replacements.place_fragments(place) {
                    for (_, _, fl) in final_locals {
                        self.patch
                            .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
                    }
                    statement.make_nop();
                    return;
                }
            }

            // `a = Aggregate(x, y, ..)` with `a` flattened becomes `a_0 = x; a_1 = y; ..`.
            // Operands for fields without a fragment local are discarded.
            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
                if let Some(local) = place.as_local()
                    && let Some(final_locals) = &self.replacements.fragments[local]
                {
                    // Moving the operands out is fine: this statement becomes a nop below.
                    let operands = std::mem::take(operands);
                    for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
                        if let Some((_, new_local)) = opt_ty_local {
                            // The operand may itself mention flattened locals; rewrite it first.
                            self.visit_operand(&mut operand, location);

                            let rvalue = Rvalue::Use(operand);
                            self.patch.add_statement(
                                location,
                                StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                            );
                        }
                    }
                    statement.make_nop();
                    return;
                }
            }

            // `a = const` with `a` flattened: the original assignment is kept, and
            // `a_0 = move a.0; a_1 = move a.1; ..` is added right after it.
            StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
                if let Some(final_locals) = self.replacements.place_fragments(place) {
                    // Insert *after* the original statement, because the new ones read from it.
                    let location = location.successor_within_block();
                    for (field, ty, new_local) in final_locals {
                        let rplace = self.tcx.mk_place_field(place, field, ty);
                        let rvalue = Rvalue::Use(Operand::Move(rplace));
                        self.patch.add_statement(
                            location,
                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                        );
                    }
                    // `a` must still exist, so the original statement is not made a nop.
                    // It is also not visited, so `visit_local` never checks `a` here.
                    return;
                }
            }

            // `a = copy/move p` with `a` flattened becomes `a_i = copy/move p.i` for each
            // fragment. `p.i` is itself redirected if `p` was flattened too.
            StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
                let (rplace, copy) = match *op {
                    Operand::Copy(rplace) => (rplace, true),
                    Operand::Move(rplace) => (rplace, false),
                    // Constant right-hand sides are always matched by the previous arm.
                    Operand::Constant(_) => bug!(),
                };
                if let Some(final_locals) = self.replacements.place_fragments(lhs) {
                    for (field, ty, new_local) in final_locals {
                        let rplace = self.tcx.mk_place_field(rplace, field, ty);
                        debug!(?rplace);
                        let rplace = self
                            .replacements
                            .replace_place(self.tcx, rplace.as_ref())
                            .unwrap_or(rplace);
                        debug!(?rplace);
                        // Preserve the copy/move kind of the original operand.
                        let rvalue = if copy {
                            Rvalue::Use(Operand::Copy(rplace))
                        } else {
                            Rvalue::Use(Operand::Move(rplace))
                        };
                        self.patch.add_statement(
                            location,
                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                        );
                    }
                    statement.make_nop();
                    return;
                }
            }

            _ => {}
        }
        self.super_statement(statement, location)
    }

    /// Checks that no flattened local is still referenced.
    ///
    /// Any local reaching this point was not redirected by `visit_place`. If it is a
    /// flattened local, some use of it was left behind.
    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
        assert!(!self.all_dead_locals.contains(*local));
    }
}
445}