1use itertools::Itertools as _;
38use rustc_index::{Idx, IndexSlice, IndexVec};
39use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
40use rustc_middle::mir::*;
41use rustc_middle::ty::TyCtxt;
42use rustc_span::DUMMY_SP;
43use smallvec::SmallVec;
44use tracing::{debug, trace};
45
/// Identifies which invocation of the CFG-simplification pass this is.
///
/// Every variant runs the same `simplify_cfg`; the variant only selects the pass name
/// reported by `SimplifyCfg::name` (for example `"SimplifyCfg-initial"`).
pub(super) enum SimplifyCfg {
    Initial,
    PromoteConsts,
    RemoveFalseEdges,
    PostAnalysis,
    PreOptimizations,
    Final,
    MakeShim,
    AfterUnreachableEnumBranching,
}
59
60impl SimplifyCfg {
61 fn name(&self) -> &'static str {
62 match self {
63 SimplifyCfg::Initial => "SimplifyCfg-initial",
64 SimplifyCfg::PromoteConsts => "SimplifyCfg-promote-consts",
65 SimplifyCfg::RemoveFalseEdges => "SimplifyCfg-remove-false-edges",
66 SimplifyCfg::PostAnalysis => "SimplifyCfg-post-analysis",
67 SimplifyCfg::PreOptimizations => "SimplifyCfg-pre-optimizations",
68 SimplifyCfg::Final => "SimplifyCfg-final",
69 SimplifyCfg::MakeShim => "SimplifyCfg-make_shim",
70 SimplifyCfg::AfterUnreachableEnumBranching => {
71 "SimplifyCfg-after-unreachable-enum-branching"
72 }
73 }
74 }
75}
76
77pub(super) fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
78 if CfgSimplifier::new(tcx, body).simplify() {
79 body.basic_blocks.invalidate_cfg_cache();
82 }
83 remove_dead_blocks(body);
84
85 body.basic_blocks.as_mut_preserves_cfg().shrink_to_fit();
87}
88
89impl<'tcx> crate::MirPass<'tcx> for SimplifyCfg {
90 fn name(&self) -> &'static str {
91 self.name()
92 }
93
94 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
95 debug!("SimplifyCfg({:?}) - simplifying {:?}", self.name(), body.source);
96 simplify_cfg(tcx, body);
97 }
98
99 fn is_required(&self) -> bool {
100 false
101 }
102}
103
/// State for one run of the CFG simplifier over a body's basic blocks.
struct CfgSimplifier<'a, 'tcx> {
    /// When set, `SwitchInt` terminators are never turned into `Goto`s (see `simplify_branch`).
    preserve_switch_reads: bool,
    /// The body's blocks, obtained via `as_mut_preserves_cfg`, i.e. without invalidating the
    /// CFG caches.
    basic_blocks: &'a mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
    /// Number of incoming edges from live blocks, with `START_BLOCK` seeded at 1. A block
    /// whose count is 0 is treated as dead and skipped.
    pred_count: IndexVec<BasicBlock, u32>,
}
109
impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
    /// Prepares a simplifier for `body`, computing predecessor counts over the blocks visited
    /// by a preorder traversal from the entry block.
    fn new(tcx: TyCtxt<'tcx>, body: &'a mut Body<'tcx>) -> Self {
        let mut pred_count = IndexVec::from_elem(0u32, &body.basic_blocks);

        // Give the entry block one "virtual" predecessor. This keeps it from ever being treated
        // as dead (count 0) by `simplify`. It also means a single back-edge to it yields a count
        // of 2, so it is never merged into a predecessor. Only edges out of blocks visited by
        // the traversal are counted, so unreachable blocks contribute no predecessors.
        pred_count[START_BLOCK] = 1;

        for (_, data) in traversal::preorder(body) {
            if let Some(ref term) = data.terminator {
                for tgt in term.successors() {
                    pred_count[tgt] += 1;
                }
            }
        }

        // Turning a `SwitchInt` into a `Goto` drops the read of its discriminant. Keep such
        // reads in built and analysis-phase MIR, and whenever `-Z mir-preserve-ub` is set.
        let preserve_switch_reads = matches!(body.phase, MirPhase::Built | MirPhase::Analysis(_))
            || tcx.sess.opts.unstable_opts.mir_preserve_ub;
        // The CFG caches are deliberately not invalidated here; the caller of `simplify` is
        // responsible for that when it reports a change.
        let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();

        CfgSimplifier { preserve_switch_reads, basic_blocks, pred_count }
    }

    /// Repeatedly collapses goto chains, folds `SwitchInt`s whose targets are all the same
    /// block, and merges blocks into their single `Goto` predecessor. Stops when a full sweep
    /// over all blocks changes nothing.
    ///
    /// Returns whether anything changed. Blocks are edited without invalidating the CFG
    /// caches, so when this returns `true` the caller must invalidate them.
    #[must_use]
    fn simplify(mut self) -> bool {
        self.strip_nops();

        // Blocks whose statements will be appended to the block currently being simplified.
        // Only indices are collected, so each statement is moved once, into a buffer reserved
        // for the final size.
        let mut merged_blocks = Vec::new();
        let mut outer_changed = false;
        loop {
            let mut changed = false;

            for bb in self.basic_blocks.indices() {
                // Dead, or already merged into / bypassed by another block.
                if self.pred_count[bb] == 0 {
                    continue;
                }

                debug!("simplifying {:?}", bb);

                // Work on the terminator out-of-place. While `bb`'s terminator is `None`, a
                // goto-chain walk that comes back to `bb` stops there instead of taking it.
                let mut terminator =
                    self.basic_blocks[bb].terminator.take().expect("invalid terminator state");

                terminator
                    .successors_mut(|successor| self.collapse_goto_chain(successor, &mut changed));

                // Alternate branch folding and successor merging until neither applies. A
                // merged successor brings in a new terminator that may be foldable or
                // mergeable in turn.
                let mut inner_changed = true;
                merged_blocks.clear();
                while inner_changed {
                    inner_changed = false;
                    inner_changed |= self.simplify_branch(&mut terminator);
                    inner_changed |= self.merge_successor(&mut merged_blocks, &mut terminator);
                    changed |= inner_changed;
                }

                // Append the statements of all merged blocks, in merge order, after `bb`'s own.
                let statements_to_merge =
                    merged_blocks.iter().map(|&i| self.basic_blocks[i].statements.len()).sum();

                if statements_to_merge > 0 {
                    let mut statements = std::mem::take(&mut self.basic_blocks[bb].statements);
                    statements.reserve(statements_to_merge);
                    for &from in &merged_blocks {
                        statements.append(&mut self.basic_blocks[from].statements);
                    }
                    self.basic_blocks[bb].statements = statements;
                }

                self.basic_blocks[bb].terminator = Some(terminator);
            }

            if !changed {
                break;
            }

            outer_changed = true;
        }

        outer_changed
    }

    /// Takes and returns `bb`'s terminator if `bb` has no statements and ends in a `Goto`.
    ///
    /// Returns `None`, leaving the block untouched, when:
    /// * the block has statements,
    /// * its terminator is not a `Goto`, or
    /// * its terminator has already been taken.
    ///
    /// The last case stops a goto-chain walk that loops back to a block whose terminator is
    /// currently held elsewhere.
    fn take_terminator_if_simple_goto(&mut self, bb: BasicBlock) -> Option<Terminator<'tcx>> {
        match self.basic_blocks[bb] {
            BasicBlockData {
                ref statements,
                terminator:
                    ref mut terminator @ Some(Terminator { kind: TerminatorKind::Goto { .. }, .. }),
                ..
            } if statements.is_empty() => terminator.take(),
            _ => None,
        }
    }

    /// Follows the chain of empty `Goto` blocks starting at `*start`. Then points `*start`, and
    /// every `Goto` along the chain, directly at the block where the chain ends.
    ///
    /// Sets `*changed` if any edge was redirected, and updates predecessor counts to match.
    fn collapse_goto_chain(&mut self, start: &mut BasicBlock, changed: &mut bool) {
        // Inline capacity for one entry; most chains are expected to be short.
        let mut terminators: SmallVec<[_; 1]> = Default::default();
        // Walk the chain, taking each `Goto` terminator out of its block. A taken terminator is
        // `None`, so a cycle of empty gotos ends the walk when it revisits a block.
        let mut current = *start;
        while let Some(terminator) = self.take_terminator_if_simple_goto(current) {
            let Terminator { kind: TerminatorKind::Goto { target }, .. } = terminator else {
                unreachable!();
            };
            terminators.push((current, terminator));
            current = target;
        }
        let last = current;
        *changed |= *start != last;
        *start = last;
        // Put the taken terminators back, last link first, each retargeted to `last`.
        while let Some((current, mut terminator)) = terminators.pop() {
            let Terminator { kind: TerminatorKind::Goto { ref mut target }, .. } = terminator
            else {
                unreachable!();
            };
            *changed |= *target != last;
            *target = last;
            debug!("collapsing goto chain from {:?} to {:?}", current, target);

            if self.pred_count[current] == 1 {
                // The incoming edge that now bypasses `current` was its only predecessor, so
                // `current` is dead. That edge takes over the slot that `current`'s own goto
                // held in `target`'s count, so `target`'s count is unchanged.
                self.pred_count[current] = 0;
            } else {
                // `current` stays live; one of its incoming edges now targets `target` instead.
                self.pred_count[*target] += 1;
                self.pred_count[current] -= 1;
            }
            self.basic_blocks[current].terminator = Some(terminator);
        }
    }

    /// If `terminator` is a `Goto` to a block with exactly one predecessor, absorbs that block.
    ///
    /// On success, `terminator` is replaced with the target's terminator, the target is
    /// recorded in `merged_blocks` so the caller can append its statements, its predecessor
    /// count drops to 0, and `true` is returned.
    fn merge_successor(
        &mut self,
        merged_blocks: &mut Vec<BasicBlock>,
        terminator: &mut Terminator<'tcx>,
    ) -> bool {
        let target = match terminator.kind {
            TerminatorKind::Goto { target } if self.pred_count[target] == 1 => target,
            _ => return false,
        };

        debug!("merging block {:?} into {:?}", target, terminator);
        *terminator = match self.basic_blocks[target].terminator.take() {
            Some(terminator) => terminator,
            None => {
                // The target's terminator has already been taken by an in-progress
                // simplification. Bail out, leaving `terminator` unchanged.
                return false;
            }
        };

        merged_blocks.push(target);
        self.pred_count[target] = 0;

        true
    }

    /// If `terminator` is a `SwitchInt` whose successors are all the same block, replaces it
    /// with a `Goto` to that block and returns `true`.
    ///
    /// The collapsed edges become one, so the target's predecessor count drops by
    /// `successors - 1`.
    fn simplify_branch(&mut self, terminator: &mut Terminator<'tcx>) -> bool {
        // Dropping the `SwitchInt` would drop the read of its discriminant; see `new`.
        if self.preserve_switch_reads {
            return false;
        }

        let TerminatorKind::SwitchInt { .. } = terminator.kind else {
            return false;
        };

        let Ok(first_succ) = terminator.successors().all_equal_value() else {
            return false;
        };

        let count = terminator.successors().count();
        self.pred_count[first_succ] -= (count - 1) as u32;

        debug!("simplifying branch {:?}", terminator);
        terminator.kind = TerminatorKind::Goto { target: first_succ };
        true
    }

    /// Removes every `Nop` statement from every block, including dead ones.
    fn strip_nops(&mut self) {
        for blk in self.basic_blocks.iter_mut() {
            blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop))
        }
    }
}
310
311pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) {
312 if let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind {
313 let otherwise = targets.otherwise();
314 if targets.iter().any(|t| t.1 == otherwise) {
315 *targets = SwitchTargets::new(
316 targets.iter().filter(|t| t.1 != otherwise),
317 targets.otherwise(),
318 );
319 }
320 }
321}
322
/// Removes blocks not contained in `traversal::reachable_as_bitset(body)`.
///
/// Also merges all reachable, non-cleanup blocks for which `is_empty_unreachable()` holds into
/// the first such block. The remaining blocks are renumbered densely and every successor edge
/// is rewritten to match. Returns early, without touching `body`, when there is nothing to
/// remove or merge.
pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
    // Merge candidates: the block still has a terminator, `is_empty_unreachable()` holds
    // (presumably no statements plus an `Unreachable` terminator — see that method), and it is
    // not a cleanup block, so cleanup and non-cleanup blocks are never merged together.
    // `CfgSimplifier` leaves merged-away blocks without a terminator, and the filter below
    // evaluates this predicate before checking reachability, so a missing terminator must be
    // tolerated.
    let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
        bbdata.terminator.is_some() && bbdata.is_empty_unreachable() && !bbdata.is_cleanup
    };

    let reachable = traversal::reachable_as_bitset(body);
    let empty_unreachable_blocks = body
        .basic_blocks
        .iter_enumerated()
        .filter(|(bb, bbdata)| should_deduplicate_unreachable(bbdata) && reachable.contains(*bb))
        .count();

    let num_blocks = body.basic_blocks.len();
    // Nothing to do: every block is reachable and at most one merge candidate exists.
    if num_blocks == reachable.count() && empty_unreachable_blocks <= 1 {
        return;
    }

    // NOTE(review): unlike `as_mut_preserves_cfg`, this presumably invalidates the CFG caches.
    let basic_blocks = body.basic_blocks.as_mut();

    // `replacements[old]` becomes the new index of the block originally at `old`.
    let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
    // `retain` visits each element once, in original order. `orig_index` therefore tracks the
    // original index of the element being examined, and `used_index` the next slot in the
    // compacted vector.
    let mut orig_index = 0;
    let mut used_index = 0;
    // New index of the first merge candidate kept; later candidates are redirected to it.
    let mut kept_unreachable = None;
    let mut deduplicated_unreachable = false;
    basic_blocks.raw.retain(|bbdata| {
        let orig_bb = BasicBlock::new(orig_index);
        // Unreachable blocks are dropped. No reachable block can target them, so their
        // `replacements` entry is never read.
        if !reachable.contains(orig_bb) {
            orig_index += 1;
            return false;
        }

        let used_bb = BasicBlock::new(used_index);
        if should_deduplicate_unreachable(bbdata) {
            let kept_unreachable = *kept_unreachable.get_or_insert(used_bb);
            // A later candidate: drop it and redirect its edges to the kept one.
            if kept_unreachable != used_bb {
                replacements[orig_index] = kept_unreachable;
                deduplicated_unreachable = true;
                orig_index += 1;
                return false;
            }
        }

        replacements[orig_index] = used_bb;
        used_index += 1;
        orig_index += 1;
        true
    });

    // The kept block now stands in for several source locations, so its terminator can no
    // longer be attributed to any one of them. Reset it to a dummy span in the outermost scope.
    if deduplicated_unreachable {
        basic_blocks[kept_unreachable.unwrap()].terminator_mut().source_info =
            SourceInfo { span: DUMMY_SP, scope: OUTERMOST_SOURCE_SCOPE };
    }

    // Point every remaining edge at the new index of its target.
    for block in basic_blocks {
        block.terminator_mut().successors_mut(|target| *target = replacements[target.index()]);
    }
}
388
/// Identifies which invocation of the local-simplification pass this is.
///
/// The variant only selects the pass name reported by the `MirPass` implementation; every
/// variant runs the same transformation.
pub(super) enum SimplifyLocals {
    BeforeConstProp,
    AfterGVN,
    Final,
}
394
395impl<'tcx> crate::MirPass<'tcx> for SimplifyLocals {
396 fn name(&self) -> &'static str {
397 match &self {
398 SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop",
399 SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering",
400 SimplifyLocals::Final => "SimplifyLocals-final",
401 }
402 }
403
404 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
405 sess.mir_opt_level() > 0
406 }
407
408 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
409 trace!("running SimplifyLocals on {:?}", body.source);
410
411 let mut used_locals = UsedLocals::new(body);
413
414 remove_unused_definitions_helper(&mut used_locals, body);
420
421 let map = make_local_map(&mut body.local_decls, &used_locals);
424
425 if map.iter().any(Option::is_none) {
427 let mut updater = LocalUpdater { map, tcx };
429 updater.visit_body_preserves_cfg(body);
430
431 body.local_decls.shrink_to_fit();
432 }
433 }
434
435 fn is_required(&self) -> bool {
436 false
437 }
438}
439
440pub(super) fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) {
441 let mut used_locals = UsedLocals::new(body);
443
444 remove_unused_definitions_helper(&mut used_locals, body);
450}
451
452fn make_local_map<V>(
454 local_decls: &mut IndexVec<Local, V>,
455 used_locals: &UsedLocals,
456) -> IndexVec<Local, Option<Local>> {
457 let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, local_decls);
458 let mut used = Local::ZERO;
459
460 for alive_index in local_decls.indices() {
461 if !used_locals.is_used(alive_index) {
463 continue;
464 }
465
466 map[alive_index] = Some(used);
467 if alive_index != used {
468 local_decls.swap(alive_index, used);
469 }
470 used.increment_by(1);
471 }
472 local_decls.truncate(used.index());
473 map
474}
475
/// Per-local use counts for a MIR body, used to find locals whose definitions can be removed.
struct UsedLocals {
    /// Whether `visit_local` adds to (`true`) or subtracts from (`false`) the counts.
    increment: bool,
    /// The body's argument count; locals `0..=arg_count` are always reported as used.
    arg_count: u32,
    /// Number of counted uses of each local.
    use_count: IndexVec<Local, u32>,
}
482
483impl UsedLocals {
484 fn new(body: &Body<'_>) -> Self {
486 let mut this = Self {
487 increment: true,
488 arg_count: body.arg_count.try_into().unwrap(),
489 use_count: IndexVec::from_elem(0, &body.local_decls),
490 };
491 this.visit_body(body);
492 this
493 }
494
495 fn is_used(&self, local: Local) -> bool {
499 trace!("is_used({:?}): use_count: {:?}", local, self.use_count[local]);
500 local.as_u32() <= self.arg_count || self.use_count[local] != 0
501 }
502
503 fn statement_removed(&mut self, statement: &Statement<'_>) {
505 self.increment = false;
506
507 let location = Location::START;
509 self.visit_statement(statement, location);
510 }
511
512 fn visit_lhs(&mut self, place: &Place<'_>, location: Location) {
514 if place.is_indirect() {
515 self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), location);
517 } else {
518 self.super_projection(
522 place.as_ref(),
523 PlaceContext::MutatingUse(MutatingUseContext::Projection),
524 location,
525 );
526 }
527 }
528}
529
530impl<'tcx> Visitor<'tcx> for UsedLocals {
531 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
532 match statement.kind {
533 StatementKind::Intrinsic(..)
534 | StatementKind::Retag(..)
535 | StatementKind::Coverage(..)
536 | StatementKind::FakeRead(..)
537 | StatementKind::PlaceMention(..)
538 | StatementKind::AscribeUserType(..) => {
539 self.super_statement(statement, location);
540 }
541
542 StatementKind::ConstEvalCounter | StatementKind::Nop => {}
543
544 StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {}
545
546 StatementKind::Assign(box (ref place, ref rvalue)) => {
547 if rvalue.is_safe_to_remove() {
548 self.visit_lhs(place, location);
549 self.visit_rvalue(rvalue, location);
550 } else {
551 self.super_statement(statement, location);
552 }
553 }
554
555 StatementKind::SetDiscriminant { ref place, variant_index: _ }
556 | StatementKind::Deinit(ref place)
557 | StatementKind::BackwardIncompatibleDropHint { ref place, reason: _ } => {
558 self.visit_lhs(place, location);
559 }
560 }
561 }
562
563 fn visit_local(&mut self, local: Local, _ctx: PlaceContext, _location: Location) {
564 if self.increment {
565 self.use_count[local] += 1;
566 } else {
567 assert_ne!(self.use_count[local], 0);
568 self.use_count[local] -= 1;
569 }
570 }
571}
572
573fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) {
575 let mut modified = true;
581 while modified {
582 modified = false;
583
584 for data in body.basic_blocks.as_mut_preserves_cfg() {
585 data.statements.retain(|statement| {
587 let keep = match &statement.kind {
588 StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
589 used_locals.is_used(*local)
590 }
591 StatementKind::Assign(box (place, _)) => used_locals.is_used(place.local),
592
593 StatementKind::SetDiscriminant { place, .. }
594 | StatementKind::BackwardIncompatibleDropHint { place, reason: _ }
595 | StatementKind::Deinit(place) => used_locals.is_used(place.local),
596 StatementKind::Nop => false,
597 _ => true,
598 };
599
600 if !keep {
601 trace!("removing statement {:?}", statement);
602 modified = true;
603 used_locals.statement_removed(statement);
604 }
605
606 keep
607 });
608 }
609 }
610}
611
/// Rewrites every local mentioned in a body according to `map`.
struct LocalUpdater<'tcx> {
    /// Old local -> new local; `None` for locals that were removed.
    map: IndexVec<Local, Option<Local>>,
    tcx: TyCtxt<'tcx>,
}
616
617impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> {
618 fn tcx(&self) -> TyCtxt<'tcx> {
619 self.tcx
620 }
621
622 fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) {
623 *l = self.map[*l].unwrap();
624 }
625}