1use rustc_index::{Idx, IndexSlice, IndexVec};
38use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
39use rustc_middle::mir::*;
40use rustc_middle::ty::TyCtxt;
41use rustc_span::DUMMY_SP;
42use smallvec::SmallVec;
43use tracing::{debug, trace};
44
/// CFG-simplification pass, instantiated at several points of the MIR pipeline.
///
/// The variants differ only in the pass name reported by `SimplifyCfg::name`; `run_pass`
/// performs the same `simplify_cfg` work regardless of which variant is used.
pub(super) enum SimplifyCfg {
    Initial,
    PromoteConsts,
    RemoveFalseEdges,
    PostAnalysis,
    PreOptimizations,
    Final,
    MakeShim,
    AfterUnreachableEnumBranching,
}
58
59impl SimplifyCfg {
60 fn name(&self) -> &'static str {
61 match self {
62 SimplifyCfg::Initial => "SimplifyCfg-initial",
63 SimplifyCfg::PromoteConsts => "SimplifyCfg-promote-consts",
64 SimplifyCfg::RemoveFalseEdges => "SimplifyCfg-remove-false-edges",
65 SimplifyCfg::PostAnalysis => "SimplifyCfg-post-analysis",
66 SimplifyCfg::PreOptimizations => "SimplifyCfg-pre-optimizations",
67 SimplifyCfg::Final => "SimplifyCfg-final",
68 SimplifyCfg::MakeShim => "SimplifyCfg-make_shim",
69 SimplifyCfg::AfterUnreachableEnumBranching => {
70 "SimplifyCfg-after-unreachable-enum-branching"
71 }
72 }
73 }
74}
75
76pub(super) fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
77 CfgSimplifier::new(tcx, body).simplify();
78 remove_dead_blocks(body);
79
80 body.basic_blocks_mut().raw.shrink_to_fit();
82}
83
84impl<'tcx> crate::MirPass<'tcx> for SimplifyCfg {
85 fn name(&self) -> &'static str {
86 self.name()
87 }
88
89 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
90 debug!("SimplifyCfg({:?}) - simplifying {:?}", self.name(), body.source);
91 simplify_cfg(tcx, body);
92 }
93
94 fn is_required(&self) -> bool {
95 false
96 }
97}
98
/// State for one run of CFG simplification over a body's basic blocks.
struct CfgSimplifier<'a, 'tcx> {
    /// When set, `SwitchInt` terminators are never rewritten into `Goto` (see `simplify_branch`).
    preserve_switch_reads: bool,
    /// The body's basic blocks, rewritten in place.
    basic_blocks: &'a mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
    /// Incoming-edge count per block.
    ///
    /// It is seeded from the blocks visited by the preorder traversal in `new`, and kept up to
    /// date as edges are rewritten. A count of 0 means the block is skipped by `simplify`.
    pred_count: IndexVec<BasicBlock, u32>,
}
104
impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
    /// Creates a simplifier over `body` and computes the initial predecessor counts.
    ///
    /// `START_BLOCK` is given one implicit predecessor. As a result it is always processed by
    /// `simplify`, and any block that jumps to it sees a count of at least 2, so it is never
    /// merged away by `merge_successor`. Only the terminators of blocks yielded by
    /// `traversal::preorder` contribute edges. Blocks that traversal never yields keep a count
    /// of zero and are ignored by `simplify`.
    fn new(tcx: TyCtxt<'tcx>, body: &'a mut Body<'tcx>) -> Self {
        let mut pred_count = IndexVec::from_elem(0u32, &body.basic_blocks);

        // Implicit predecessor for the entry block (see the doc comment above).
        pred_count[START_BLOCK] = 1;

        for (_, data) in traversal::preorder(body) {
            if let Some(ref term) = data.terminator {
                for tgt in term.successors() {
                    pred_count[tgt] += 1;
                }
            }
        }

        // `SwitchInt` reads are kept on `Built` / `Analysis` phase MIR, or whenever
        // `-Zmir-preserve-ub` is set; `simplify_branch` honours this flag.
        let preserve_switch_reads = matches!(body.phase, MirPhase::Built | MirPhase::Analysis(_))
            || tcx.sess.opts.unstable_opts.mir_preserve_ub;
        let basic_blocks = body.basic_blocks_mut();

        CfgSimplifier { preserve_switch_reads, basic_blocks, pred_count }
    }

    /// Runs goto-chain collapsing, branch simplification and block merging to a fixed point.
    ///
    /// Each outer iteration visits every block that still has predecessors. While a block is
    /// being visited, its terminator is taken out of the block, rewritten, and put back at the
    /// end. Statements of the successor blocks merged into it are appended after its own
    /// statements, in merge order. The outer loop repeats until an iteration changes nothing.
    fn simplify(mut self) {
        self.strip_nops();

        // Indices of the blocks merged into the block currently being visited. Their
        // statements are moved in one pass afterwards, so the destination vector is reserved
        // once instead of growing repeatedly.
        let mut merged_blocks = Vec::new();
        loop {
            let mut changed = false;

            for bb in self.basic_blocks.indices() {
                // Zero predecessors: never reached, or already merged into another block.
                if self.pred_count[bb] == 0 {
                    continue;
                }

                debug!("simplifying {:?}", bb);

                // Taking the terminator leaves `None` behind. `take_terminator_if_simple_goto`
                // and `merge_successor` therefore treat `bb` as "in use" while it is processed.
                let mut terminator =
                    self.basic_blocks[bb].terminator.take().expect("invalid terminator state");

                terminator
                    .successors_mut(|successor| self.collapse_goto_chain(successor, &mut changed));

                // Alternate the two local rewrites until neither applies: merging may expose a
                // uniform `SwitchInt`, and simplifying a switch may expose a mergeable `Goto`.
                let mut inner_changed = true;
                merged_blocks.clear();
                while inner_changed {
                    inner_changed = false;
                    inner_changed |= self.simplify_branch(&mut terminator);
                    inner_changed |= self.merge_successor(&mut merged_blocks, &mut terminator);
                    changed |= inner_changed;
                }

                let statements_to_merge =
                    merged_blocks.iter().map(|&i| self.basic_blocks[i].statements.len()).sum();

                if statements_to_merge > 0 {
                    let mut statements = std::mem::take(&mut self.basic_blocks[bb].statements);
                    statements.reserve(statements_to_merge);
                    for &from in &merged_blocks {
                        statements.append(&mut self.basic_blocks[from].statements);
                    }
                    self.basic_blocks[bb].statements = statements;
                }

                self.basic_blocks[bb].terminator = Some(terminator);
            }

            if !changed {
                break;
            }
        }
    }

    /// Takes and returns the terminator of `bb` if `bb` is an empty block ending in `Goto`.
    ///
    /// On success the block is left with no terminator. Returns `None` (leaving the block
    /// untouched) if the block has statements, has a non-`Goto` terminator, or currently has
    /// no terminator at all.
    fn take_terminator_if_simple_goto(&mut self, bb: BasicBlock) -> Option<Terminator<'tcx>> {
        match self.basic_blocks[bb] {
            BasicBlockData {
                ref statements,
                terminator:
                    ref mut terminator @ Some(Terminator { kind: TerminatorKind::Goto { .. }, .. }),
                ..
            } if statements.is_empty() => terminator.take(),
            // A missing terminator means the in-progress simplification already took it
            // (e.g. the block being visited, or an earlier link of the current chain); bail out.
            _ => None,
        }
    }

    /// Redirects `*start` past any chain of empty `Goto` blocks to the chain's final target.
    ///
    /// Every block in the chain is also retargeted directly at that final block, and
    /// `pred_count` is updated for each rewritten edge. `*changed` is set if any chain block's
    /// target actually moved. Each chain block's terminator is taken while the chain is
    /// walked, so a cycle of empty gotos ends the walk when it revisits a block instead of
    /// looping forever.
    fn collapse_goto_chain(&mut self, start: &mut BasicBlock, changed: &mut bool) {
        // One inline slot keeps single-block chains off the heap.
        let mut terminators: SmallVec<[_; 1]> = Default::default();
        let mut current = *start;
        while let Some(terminator) = self.take_terminator_if_simple_goto(current) {
            let Terminator { kind: TerminatorKind::Goto { target }, .. } = terminator else {
                unreachable!();
            };
            terminators.push((current, terminator));
            current = target;
        }
        let last = current;
        *start = last;
        // Walk the chain backwards, pointing every link straight at `last` and putting each
        // taken terminator back.
        while let Some((current, mut terminator)) = terminators.pop() {
            let Terminator { kind: TerminatorKind::Goto { ref mut target }, .. } = terminator
            else {
                unreachable!();
            };
            *changed |= *target != last;
            *target = last;
            debug!("collapsing goto chain from {:?} to {:?}", current, target);

            if self.pred_count[current] == 1 {
                // The edge that used to reach `current` now bypasses it, so `current` is dead.
                // Its own edge to `target` is effectively replaced by the redirected one,
                // leaving `pred_count[target]` unchanged.
                self.pred_count[current] = 0;
            } else {
                // `current` stays alive: `target` gains the redirected edge and `current`
                // loses it.
                self.pred_count[*target] += 1;
                self.pred_count[current] -= 1;
            }
            self.basic_blocks[current].terminator = Some(terminator);
        }
    }

    /// Merges the `Goto` target of `terminator` into the block being visited, if that target
    /// has exactly one predecessor.
    ///
    /// On success `*terminator` becomes the target's terminator, the target is appended to
    /// `merged_blocks` (its statements are moved later by `simplify`), its predecessor count is
    /// set to 0, and `true` is returned.
    fn merge_successor(
        &mut self,
        merged_blocks: &mut Vec<BasicBlock>,
        terminator: &mut Terminator<'tcx>,
    ) -> bool {
        let target = match terminator.kind {
            TerminatorKind::Goto { target } if self.pred_count[target] == 1 => target,
            _ => return false,
        };

        debug!("merging block {:?} into {:?}", target, terminator);
        *terminator = match self.basic_blocks[target].terminator.take() {
            Some(terminator) => terminator,
            None => {
                // The target's terminator is currently taken (for example, the target is the
                // block being visited); leave the `Goto` in place.
                return false;
            }
        };

        merged_blocks.push(target);
        self.pred_count[target] = 0;

        true
    }

    /// Replaces a `SwitchInt` whose successors are all the same block with a `Goto` to it.
    ///
    /// Does nothing when `preserve_switch_reads` is set, so the switch (and its read) is kept.
    /// The switch contributed one edge per successor, and the `Goto` contributes only one, so
    /// the target's predecessor count drops by `count - 1`.
    fn simplify_branch(&mut self, terminator: &mut Terminator<'tcx>) -> bool {
        if self.preserve_switch_reads {
            return false;
        }

        let TerminatorKind::SwitchInt { .. } = terminator.kind else {
            return false;
        };

        let first_succ = {
            if let Some(first_succ) = terminator.successors().next() {
                if terminator.successors().all(|s| s == first_succ) {
                    let count = terminator.successors().count();
                    self.pred_count[first_succ] -= (count - 1) as u32;
                    first_succ
                } else {
                    return false;
                }
            } else {
                return false;
            }
        };

        debug!("simplifying branch {:?}", terminator);
        terminator.kind = TerminatorKind::Goto { target: first_succ };
        true
    }

    /// Removes every `Nop` statement from every block.
    fn strip_nops(&mut self) {
        for blk in self.basic_blocks.iter_mut() {
            blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop))
        }
    }
}
302
303pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) {
304 if let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind {
305 let otherwise = targets.otherwise();
306 if targets.iter().any(|t| t.1 == otherwise) {
307 *targets = SwitchTargets::new(
308 targets.iter().filter(|t| t.1 != otherwise),
309 targets.otherwise(),
310 );
311 }
312 }
313}
314
/// Deletes unreachable basic blocks and merges duplicate empty `unreachable` blocks,
/// renumbering the surviving blocks and rewriting every successor edge accordingly.
///
/// Reachability comes from `traversal::reachable_as_bitset`. Consider the reachable,
/// non-cleanup blocks that have a terminator and report `is_empty_unreachable()`. Only the
/// first of these (in block order) is kept, and references to the others are redirected to it.
/// Surviving blocks keep their relative order. The function returns without modifying `body`
/// when every block is reachable and there is at most one such empty-unreachable block.
pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
    // Blocks eligible for deduplication.
    // - `terminator.is_some()` is checked first, presumably because `is_empty_unreachable`
    //   needs a terminator (TODO confirm).
    // - Cleanup blocks are excluded, so cleanup and non-cleanup blocks are never merged
    //   together.
    let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
        bbdata.terminator.is_some() && bbdata.is_empty_unreachable() && !bbdata.is_cleanup
    };

    let reachable = traversal::reachable_as_bitset(body);
    let empty_unreachable_blocks = body
        .basic_blocks
        .iter_enumerated()
        .filter(|(bb, bbdata)| should_deduplicate_unreachable(bbdata) && reachable.contains(*bb))
        .count();

    let num_blocks = body.basic_blocks.len();
    // Nothing to delete and nothing to deduplicate.
    if num_blocks == reachable.count() && empty_unreachable_blocks <= 1 {
        return;
    }

    let basic_blocks = body.basic_blocks.as_mut();

    // `replacements[old_index]` is the new index of the block that used to be `old_index`.
    // Entries for deleted unreachable blocks are never read, since no surviving block can
    // point at an unreachable one.
    let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
    // Index of the block `retain` is currently looking at, in the old numbering.
    let mut orig_index = 0;
    // Number of blocks kept so far, i.e. the next index in the new numbering.
    let mut used_index = 0;
    // New index of the single empty-unreachable block that is kept, once one has been seen.
    let mut kept_unreachable = None;
    let mut deduplicated_unreachable = false;
    basic_blocks.raw.retain(|bbdata| {
        let orig_bb = BasicBlock::new(orig_index);
        if !reachable.contains(orig_bb) {
            orig_index += 1;
            return false;
        }

        let used_bb = BasicBlock::new(used_index);
        if should_deduplicate_unreachable(bbdata) {
            // The first eligible block becomes the representative; later ones are dropped
            // and redirected to it.
            let kept_unreachable = *kept_unreachable.get_or_insert(used_bb);
            if kept_unreachable != used_bb {
                replacements[orig_index] = kept_unreachable;
                deduplicated_unreachable = true;
                orig_index += 1;
                return false;
            }
        }

        replacements[orig_index] = used_bb;
        used_index += 1;
        orig_index += 1;
        true
    });

    // The kept block now stands in for several original blocks, so its terminator is given a
    // dummy span and the outermost scope instead of one particular block's source info.
    if deduplicated_unreachable {
        basic_blocks[kept_unreachable.unwrap()].terminator_mut().source_info =
            SourceInfo { span: DUMMY_SP, scope: OUTERMOST_SOURCE_SCOPE };
    }

    for block in basic_blocks {
        block.terminator_mut().successors_mut(|target| *target = replacements[target.index()]);
    }
}
380
/// Pass that removes definitions of unused locals, then drops and renumbers those locals.
///
/// The variants only select the pass name reported by `MirPass::name`; `run_pass` does the
/// same work for all of them.
pub(super) enum SimplifyLocals {
    BeforeConstProp,
    AfterGVN,
    Final,
}
386
387impl<'tcx> crate::MirPass<'tcx> for SimplifyLocals {
388 fn name(&self) -> &'static str {
389 match &self {
390 SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop",
391 SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering",
392 SimplifyLocals::Final => "SimplifyLocals-final",
393 }
394 }
395
396 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
397 sess.mir_opt_level() > 0
398 }
399
400 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
401 trace!("running SimplifyLocals on {:?}", body.source);
402
403 let mut used_locals = UsedLocals::new(body);
405
406 remove_unused_definitions_helper(&mut used_locals, body);
412
413 let map = make_local_map(&mut body.local_decls, &used_locals);
416
417 if map.iter().any(Option::is_none) {
419 let mut updater = LocalUpdater { map, tcx };
421 updater.visit_body_preserves_cfg(body);
422
423 body.local_decls.shrink_to_fit();
424 }
425 }
426
427 fn is_required(&self) -> bool {
428 false
429 }
430}
431
432pub(super) fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) {
433 let mut used_locals = UsedLocals::new(body);
435
436 remove_unused_definitions_helper(&mut used_locals, body);
442}
443
444fn make_local_map<V>(
446 local_decls: &mut IndexVec<Local, V>,
447 used_locals: &UsedLocals,
448) -> IndexVec<Local, Option<Local>> {
449 let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, local_decls);
450 let mut used = Local::ZERO;
451
452 for alive_index in local_decls.indices() {
453 if !used_locals.is_used(alive_index) {
455 continue;
456 }
457
458 map[alive_index] = Some(used);
459 if alive_index != used {
460 local_decls.swap(alive_index, used);
461 }
462 used.increment_by(1);
463 }
464 local_decls.truncate(used.index());
465 map
466}
467
/// Per-local use counters, shared by `SimplifyLocals` and `remove_unused_definitions`.
struct UsedLocals {
    /// `true` while counting uses up. `statement_removed` clears it while it subtracts the
    /// uses of a deleted statement.
    increment: bool,
    /// Number of function arguments. Locals with index `<= arg_count` are always kept.
    arg_count: u32,
    /// Number of counted uses of each local.
    use_count: IndexVec<Local, u32>,
}
474
475impl UsedLocals {
476 fn new(body: &Body<'_>) -> Self {
478 let mut this = Self {
479 increment: true,
480 arg_count: body.arg_count.try_into().unwrap(),
481 use_count: IndexVec::from_elem(0, &body.local_decls),
482 };
483 this.visit_body(body);
484 this
485 }
486
487 fn is_used(&self, local: Local) -> bool {
491 trace!("is_used({:?}): use_count: {:?}", local, self.use_count[local]);
492 local.as_u32() <= self.arg_count || self.use_count[local] != 0
493 }
494
495 fn statement_removed(&mut self, statement: &Statement<'_>) {
497 self.increment = false;
498
499 let location = Location::START;
501 self.visit_statement(statement, location);
502 }
503
504 fn visit_lhs(&mut self, place: &Place<'_>, location: Location) {
506 if place.is_indirect() {
507 self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), location);
509 } else {
510 self.super_projection(
514 place.as_ref(),
515 PlaceContext::MutatingUse(MutatingUseContext::Projection),
516 location,
517 );
518 }
519 }
520}
521
522impl<'tcx> Visitor<'tcx> for UsedLocals {
523 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
524 match statement.kind {
525 StatementKind::Intrinsic(..)
526 | StatementKind::Retag(..)
527 | StatementKind::Coverage(..)
528 | StatementKind::FakeRead(..)
529 | StatementKind::PlaceMention(..)
530 | StatementKind::AscribeUserType(..) => {
531 self.super_statement(statement, location);
532 }
533
534 StatementKind::ConstEvalCounter | StatementKind::Nop => {}
535
536 StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {}
537
538 StatementKind::Assign(box (ref place, ref rvalue)) => {
539 if rvalue.is_safe_to_remove() {
540 self.visit_lhs(place, location);
541 self.visit_rvalue(rvalue, location);
542 } else {
543 self.super_statement(statement, location);
544 }
545 }
546
547 StatementKind::SetDiscriminant { ref place, variant_index: _ }
548 | StatementKind::Deinit(ref place)
549 | StatementKind::BackwardIncompatibleDropHint { ref place, reason: _ } => {
550 self.visit_lhs(place, location);
551 }
552 }
553 }
554
555 fn visit_local(&mut self, local: Local, _ctx: PlaceContext, _location: Location) {
556 if self.increment {
557 self.use_count[local] += 1;
558 } else {
559 assert_ne!(self.use_count[local], 0);
560 self.use_count[local] -= 1;
561 }
562 }
563}
564
565fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) {
567 let mut modified = true;
573 while modified {
574 modified = false;
575
576 for data in body.basic_blocks.as_mut_preserves_cfg() {
577 data.statements.retain(|statement| {
579 let keep = match &statement.kind {
580 StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
581 used_locals.is_used(*local)
582 }
583 StatementKind::Assign(box (place, _)) => used_locals.is_used(place.local),
584
585 StatementKind::SetDiscriminant { place, .. }
586 | StatementKind::BackwardIncompatibleDropHint { place, reason: _ }
587 | StatementKind::Deinit(place) => used_locals.is_used(place.local),
588 StatementKind::Nop => false,
589 _ => true,
590 };
591
592 if !keep {
593 trace!("removing statement {:?}", statement);
594 modified = true;
595 used_locals.statement_removed(statement);
596 }
597
598 keep
599 });
600 }
601 }
602}
603
/// Rewrites every local in a body according to `map`.
struct LocalUpdater<'tcx> {
    /// Old local -> new local. Visiting a local whose entry is `None` panics, so such locals
    /// must no longer appear in the visited body.
    map: IndexVec<Local, Option<Local>>,
    /// Type context returned from `MutVisitor::tcx`.
    tcx: TyCtxt<'tcx>,
}
608
609impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> {
610 fn tcx(&self) -> TyCtxt<'tcx> {
611 self.tcx
612 }
613
614 fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) {
615 *l = self.map[*l].unwrap();
616 }
617}