rustc_mir_transform/
simplify.rs

1//! A number of passes which remove various redundancies in the CFG.
2//!
3//! The `SimplifyCfg` pass gets rid of unnecessary blocks in the CFG, whereas the `SimplifyLocals`
4//! gets rid of all the unnecessary local variable declarations.
5//!
6//! The `SimplifyLocals` pass is kinda expensive and therefore not very suitable to be run often.
7//! Most of the passes should not care or be impacted in meaningful ways due to extra locals
8//! either, so running the pass once, right before codegen, should suffice.
9//!
10//! On the other side of the spectrum, the `SimplifyCfg` pass is considerably cheap to run, thus
11//! one should run it after every pass which may modify CFG in significant ways. This pass must
12//! also be run before any analysis passes because it removes dead blocks, and some of these can be
13//! ill-typed.
14//!
//! The cause of this typing issue is that typeck allows most blocks whose end is not reachable to
//! have an arbitrary return type, rather than the usual `()` return type (as a note, typeck's
17//! notion of reachability is in fact slightly weaker than MIR CFG reachability - see #31617). A
18//! standard example of the situation is:
19//!
20//! ```rust
21//!   fn example() {
22//!       let _a: char = { return; };
23//!   }
24//! ```
25//!
26//! Here the block (`{ return; }`) has the return type `char`, rather than `()`, but the MIR we
27//! naively generate still contains the `_a = ()` write in the unreachable block "after" the
28//! return.
29//!
30//! **WARNING**: This is one of the few optimizations that runs on built and analysis MIR, and
31//! so its effects may affect the type-checking, borrow-checking, and other analysis of MIR.
32//! We must be extremely careful to only apply optimizations that preserve UB and all
33//! non-determinism, since changes here can affect which programs compile in an insta-stable way.
34//! The normal logic that a program with UB can be changed to do anything does not apply to
35//! pre-"runtime" MIR!
36
37use rustc_index::{Idx, IndexSlice, IndexVec};
38use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
39use rustc_middle::mir::*;
40use rustc_middle::ty::TyCtxt;
41use rustc_span::DUMMY_SP;
42use smallvec::SmallVec;
43use tracing::{debug, trace};
44
/// Selects which invocation of the CFG simplification pass is being run.
///
/// Every variant performs the same simplification (see the `MirPass` impl for this type); the
/// variant only determines the pass name returned by `name`.
pub(super) enum SimplifyCfg {
    Initial,
    PromoteConsts,
    RemoveFalseEdges,
    /// Runs at the beginning of "analysis to runtime" lowering, *before* drop elaboration.
    PostAnalysis,
    /// Runs at the end of "analysis to runtime" lowering, *after* drop elaboration.
    /// This is before the main optimization passes on runtime MIR kick in.
    PreOptimizations,
    Final,
    MakeShim,
    AfterUnreachableEnumBranching,
}
58
59impl SimplifyCfg {
60    fn name(&self) -> &'static str {
61        match self {
62            SimplifyCfg::Initial => "SimplifyCfg-initial",
63            SimplifyCfg::PromoteConsts => "SimplifyCfg-promote-consts",
64            SimplifyCfg::RemoveFalseEdges => "SimplifyCfg-remove-false-edges",
65            SimplifyCfg::PostAnalysis => "SimplifyCfg-post-analysis",
66            SimplifyCfg::PreOptimizations => "SimplifyCfg-pre-optimizations",
67            SimplifyCfg::Final => "SimplifyCfg-final",
68            SimplifyCfg::MakeShim => "SimplifyCfg-make_shim",
69            SimplifyCfg::AfterUnreachableEnumBranching => {
70                "SimplifyCfg-after-unreachable-enum-branching"
71            }
72        }
73    }
74}
75
76pub(super) fn simplify_cfg<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
77    CfgSimplifier::new(tcx, body).simplify();
78    remove_dead_blocks(body);
79
80    // FIXME: Should probably be moved into some kind of pass manager
81    body.basic_blocks_mut().raw.shrink_to_fit();
82}
83
84impl<'tcx> crate::MirPass<'tcx> for SimplifyCfg {
85    fn name(&self) -> &'static str {
86        self.name()
87    }
88
89    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
90        debug!("SimplifyCfg({:?}) - simplifying {:?}", self.name(), body.source);
91        simplify_cfg(tcx, body);
92    }
93
94    fn is_required(&self) -> bool {
95        false
96    }
97}
98
/// Working state for simplifying the control-flow graph of a single MIR body.
struct CfgSimplifier<'a, 'tcx> {
    /// When set, `simplify_branch` leaves `SwitchInt` terminators untouched so that the reads
    /// they perform are preserved.
    preserve_switch_reads: bool,
    /// The basic blocks of the body being simplified; they are modified in place.
    basic_blocks: &'a mut IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
    /// Number of incoming edges of each block, counting only edges that originate in blocks
    /// visited by the traversal in `new`. The start block is given one extra count there.
    pred_count: IndexVec<BasicBlock, u32>,
}
104
105impl<'a, 'tcx> CfgSimplifier<'a, 'tcx> {
106    fn new(tcx: TyCtxt<'tcx>, body: &'a mut Body<'tcx>) -> Self {
107        let mut pred_count = IndexVec::from_elem(0u32, &body.basic_blocks);
108
109        // we can't use mir.predecessors() here because that counts
110        // dead blocks, which we don't want to.
111        pred_count[START_BLOCK] = 1;
112
113        for (_, data) in traversal::preorder(body) {
114            if let Some(ref term) = data.terminator {
115                for tgt in term.successors() {
116                    pred_count[tgt] += 1;
117                }
118            }
119        }
120
121        // Preserve `SwitchInt` reads on built and analysis MIR, or if `-Zmir-preserve-ub`.
122        let preserve_switch_reads = matches!(body.phase, MirPhase::Built | MirPhase::Analysis(_))
123            || tcx.sess.opts.unstable_opts.mir_preserve_ub;
124        let basic_blocks = body.basic_blocks_mut();
125
126        CfgSimplifier { preserve_switch_reads, basic_blocks, pred_count }
127    }
128
129    fn simplify(mut self) {
130        self.strip_nops();
131
132        // Vec of the blocks that should be merged. We store the indices here, instead of the
133        // statements itself to avoid moving the (relatively) large statements twice.
134        // We do not push the statements directly into the target block (`bb`) as that is slower
135        // due to additional reallocations
136        let mut merged_blocks = Vec::new();
137        loop {
138            let mut changed = false;
139
140            for bb in self.basic_blocks.indices() {
141                if self.pred_count[bb] == 0 {
142                    continue;
143                }
144
145                debug!("simplifying {:?}", bb);
146
147                let mut terminator =
148                    self.basic_blocks[bb].terminator.take().expect("invalid terminator state");
149
150                terminator
151                    .successors_mut(|successor| self.collapse_goto_chain(successor, &mut changed));
152
153                let mut inner_changed = true;
154                merged_blocks.clear();
155                while inner_changed {
156                    inner_changed = false;
157                    inner_changed |= self.simplify_branch(&mut terminator);
158                    inner_changed |= self.merge_successor(&mut merged_blocks, &mut terminator);
159                    changed |= inner_changed;
160                }
161
162                let statements_to_merge =
163                    merged_blocks.iter().map(|&i| self.basic_blocks[i].statements.len()).sum();
164
165                if statements_to_merge > 0 {
166                    let mut statements = std::mem::take(&mut self.basic_blocks[bb].statements);
167                    statements.reserve(statements_to_merge);
168                    for &from in &merged_blocks {
169                        statements.append(&mut self.basic_blocks[from].statements);
170                    }
171                    self.basic_blocks[bb].statements = statements;
172                }
173
174                self.basic_blocks[bb].terminator = Some(terminator);
175            }
176
177            if !changed {
178                break;
179            }
180        }
181    }
182
183    /// This function will return `None` if
184    /// * the block has statements
185    /// * the block has a terminator other than `goto`
186    /// * the block has no terminator (meaning some other part of the current optimization stole it)
187    fn take_terminator_if_simple_goto(&mut self, bb: BasicBlock) -> Option<Terminator<'tcx>> {
188        match self.basic_blocks[bb] {
189            BasicBlockData {
190                ref statements,
191                terminator:
192                    ref mut terminator @ Some(Terminator { kind: TerminatorKind::Goto { .. }, .. }),
193                ..
194            } if statements.is_empty() => terminator.take(),
195            // if `terminator` is None, this means we are in a loop. In that
196            // case, let all the loop collapse to its entry.
197            _ => None,
198        }
199    }
200
201    /// Collapse a goto chain starting from `start`
202    fn collapse_goto_chain(&mut self, start: &mut BasicBlock, changed: &mut bool) {
203        // Using `SmallVec` here, because in some logs on libcore oli-obk saw many single-element
204        // goto chains. We should probably benchmark different sizes.
205        let mut terminators: SmallVec<[_; 1]> = Default::default();
206        let mut current = *start;
207        while let Some(terminator) = self.take_terminator_if_simple_goto(current) {
208            let Terminator { kind: TerminatorKind::Goto { target }, .. } = terminator else {
209                unreachable!();
210            };
211            terminators.push((current, terminator));
212            current = target;
213        }
214        let last = current;
215        *start = last;
216        while let Some((current, mut terminator)) = terminators.pop() {
217            let Terminator { kind: TerminatorKind::Goto { ref mut target }, .. } = terminator
218            else {
219                unreachable!();
220            };
221            *changed |= *target != last;
222            *target = last;
223            debug!("collapsing goto chain from {:?} to {:?}", current, target);
224
225            if self.pred_count[current] == 1 {
226                // This is the last reference to current, so the pred-count to
227                // to target is moved into the current block.
228                self.pred_count[current] = 0;
229            } else {
230                self.pred_count[*target] += 1;
231                self.pred_count[current] -= 1;
232            }
233            self.basic_blocks[current].terminator = Some(terminator);
234        }
235    }
236
237    // merge a block with 1 `goto` predecessor to its parent
238    fn merge_successor(
239        &mut self,
240        merged_blocks: &mut Vec<BasicBlock>,
241        terminator: &mut Terminator<'tcx>,
242    ) -> bool {
243        let target = match terminator.kind {
244            TerminatorKind::Goto { target } if self.pred_count[target] == 1 => target,
245            _ => return false,
246        };
247
248        debug!("merging block {:?} into {:?}", target, terminator);
249        *terminator = match self.basic_blocks[target].terminator.take() {
250            Some(terminator) => terminator,
251            None => {
252                // unreachable loop - this should not be possible, as we
253                // don't strand blocks, but handle it correctly.
254                return false;
255            }
256        };
257
258        merged_blocks.push(target);
259        self.pred_count[target] = 0;
260
261        true
262    }
263
264    // turn a branch with all successors identical to a goto
265    fn simplify_branch(&mut self, terminator: &mut Terminator<'tcx>) -> bool {
266        // Removing a `SwitchInt` terminator may remove reads that result in UB,
267        // so we must not apply this optimization before borrowck or when
268        // `-Zmir-preserve-ub` is set.
269        if self.preserve_switch_reads {
270            return false;
271        }
272
273        let TerminatorKind::SwitchInt { .. } = terminator.kind else {
274            return false;
275        };
276
277        let first_succ = {
278            if let Some(first_succ) = terminator.successors().next() {
279                if terminator.successors().all(|s| s == first_succ) {
280                    let count = terminator.successors().count();
281                    self.pred_count[first_succ] -= (count - 1) as u32;
282                    first_succ
283                } else {
284                    return false;
285                }
286            } else {
287                return false;
288            }
289        };
290
291        debug!("simplifying branch {:?}", terminator);
292        terminator.kind = TerminatorKind::Goto { target: first_succ };
293        true
294    }
295
296    fn strip_nops(&mut self) {
297        for blk in self.basic_blocks.iter_mut() {
298            blk.statements.retain(|stmt| !matches!(stmt.kind, StatementKind::Nop))
299        }
300    }
301}
302
303pub(super) fn simplify_duplicate_switch_targets(terminator: &mut Terminator<'_>) {
304    if let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind {
305        let otherwise = targets.otherwise();
306        if targets.iter().any(|t| t.1 == otherwise) {
307            *targets = SwitchTargets::new(
308                targets.iter().filter(|t| t.1 != otherwise),
309                targets.otherwise(),
310            );
311        }
312    }
313}
314
315pub(super) fn remove_dead_blocks(body: &mut Body<'_>) {
316    let should_deduplicate_unreachable = |bbdata: &BasicBlockData<'_>| {
317        // CfgSimplifier::simplify leaves behind some unreachable basic blocks without a
318        // terminator. Those blocks will be deleted by remove_dead_blocks, but we run just
319        // before then so we need to handle missing terminators.
320        // We also need to prevent confusing cleanup and non-cleanup blocks. In practice we
321        // don't emit empty unreachable cleanup blocks, so this simple check suffices.
322        bbdata.terminator.is_some() && bbdata.is_empty_unreachable() && !bbdata.is_cleanup
323    };
324
325    let reachable = traversal::reachable_as_bitset(body);
326    let empty_unreachable_blocks = body
327        .basic_blocks
328        .iter_enumerated()
329        .filter(|(bb, bbdata)| should_deduplicate_unreachable(bbdata) && reachable.contains(*bb))
330        .count();
331
332    let num_blocks = body.basic_blocks.len();
333    if num_blocks == reachable.count() && empty_unreachable_blocks <= 1 {
334        return;
335    }
336
337    let basic_blocks = body.basic_blocks.as_mut();
338
339    let mut replacements: Vec<_> = (0..num_blocks).map(BasicBlock::new).collect();
340    let mut orig_index = 0;
341    let mut used_index = 0;
342    let mut kept_unreachable = None;
343    let mut deduplicated_unreachable = false;
344    basic_blocks.raw.retain(|bbdata| {
345        let orig_bb = BasicBlock::new(orig_index);
346        if !reachable.contains(orig_bb) {
347            orig_index += 1;
348            return false;
349        }
350
351        let used_bb = BasicBlock::new(used_index);
352        if should_deduplicate_unreachable(bbdata) {
353            let kept_unreachable = *kept_unreachable.get_or_insert(used_bb);
354            if kept_unreachable != used_bb {
355                replacements[orig_index] = kept_unreachable;
356                deduplicated_unreachable = true;
357                orig_index += 1;
358                return false;
359            }
360        }
361
362        replacements[orig_index] = used_bb;
363        used_index += 1;
364        orig_index += 1;
365        true
366    });
367
368    // If we deduplicated unreachable blocks we erase their source_info as we
369    // can no longer attribute their code to a particular location in the
370    // source.
371    if deduplicated_unreachable {
372        basic_blocks[kept_unreachable.unwrap()].terminator_mut().source_info =
373            SourceInfo { span: DUMMY_SP, scope: OUTERMOST_SOURCE_SCOPE };
374    }
375
376    for block in basic_blocks {
377        block.terminator_mut().successors_mut(|target| *target = replacements[target.index()]);
378    }
379}
380
/// Selects which invocation of the `SimplifyLocals` pass is being run.
///
/// The variant only determines the pass name reported by the `MirPass` impl for this type.
pub(super) enum SimplifyLocals {
    BeforeConstProp,
    AfterGVN,
    Final,
}
386
387impl<'tcx> crate::MirPass<'tcx> for SimplifyLocals {
388    fn name(&self) -> &'static str {
389        match &self {
390            SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop",
391            SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering",
392            SimplifyLocals::Final => "SimplifyLocals-final",
393        }
394    }
395
396    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
397        sess.mir_opt_level() > 0
398    }
399
400    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
401        trace!("running SimplifyLocals on {:?}", body.source);
402
403        // First, we're going to get a count of *actual* uses for every `Local`.
404        let mut used_locals = UsedLocals::new(body);
405
406        // Next, we're going to remove any `Local` with zero actual uses. When we remove those
407        // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
408        // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
409        // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
410        // fixedpoint where there are no more unused locals.
411        remove_unused_definitions_helper(&mut used_locals, body);
412
413        // Finally, we'll actually do the work of shrinking `body.local_decls` and remapping the
414        // `Local`s.
415        let map = make_local_map(&mut body.local_decls, &used_locals);
416
417        // Only bother running the `LocalUpdater` if we actually found locals to remove.
418        if map.iter().any(Option::is_none) {
419            // Update references to all vars and tmps now
420            let mut updater = LocalUpdater { map, tcx };
421            updater.visit_body_preserves_cfg(body);
422
423            body.local_decls.shrink_to_fit();
424        }
425    }
426
427    fn is_required(&self) -> bool {
428        false
429    }
430}
431
432pub(super) fn remove_unused_definitions<'tcx>(body: &mut Body<'tcx>) {
433    // First, we're going to get a count of *actual* uses for every `Local`.
434    let mut used_locals = UsedLocals::new(body);
435
436    // Next, we're going to remove any `Local` with zero actual uses. When we remove those
437    // `Locals`, we're also going to subtract any uses of other `Locals` from the `used_locals`
438    // count. For example, if we removed `_2 = discriminant(_1)`, then we'll subtract one from
439    // `use_counts[_1]`. That in turn might make `_1` unused, so we loop until we hit a
440    // fixedpoint where there are no more unused locals.
441    remove_unused_definitions_helper(&mut used_locals, body);
442}
443
444/// Construct the mapping while swapping out unused stuff out from the `vec`.
445fn make_local_map<V>(
446    local_decls: &mut IndexVec<Local, V>,
447    used_locals: &UsedLocals,
448) -> IndexVec<Local, Option<Local>> {
449    let mut map: IndexVec<Local, Option<Local>> = IndexVec::from_elem(None, local_decls);
450    let mut used = Local::ZERO;
451
452    for alive_index in local_decls.indices() {
453        // `is_used` treats the `RETURN_PLACE` and arguments as used.
454        if !used_locals.is_used(alive_index) {
455            continue;
456        }
457
458        map[alive_index] = Some(used);
459        if alive_index != used {
460            local_decls.swap(alive_index, used);
461        }
462        used.increment_by(1);
463    }
464    local_decls.truncate(used.index());
465    map
466}
467
/// Keeps track of used & unused locals.
struct UsedLocals {
    /// Direction of the next update: `visit_local` adds one to the use count while this is
    /// `true` and subtracts one while it is `false`.
    increment: bool,
    /// Copied from `body.arg_count`; `is_used` reports every local whose index is not greater
    /// than this as used.
    arg_count: u32,
    /// Number of counted uses of each local.
    use_count: IndexVec<Local, u32>,
}
474
475impl UsedLocals {
476    /// Determines which locals are used & unused in the given body.
477    fn new(body: &Body<'_>) -> Self {
478        let mut this = Self {
479            increment: true,
480            arg_count: body.arg_count.try_into().unwrap(),
481            use_count: IndexVec::from_elem(0, &body.local_decls),
482        };
483        this.visit_body(body);
484        this
485    }
486
487    /// Checks if local is used.
488    ///
489    /// Return place and arguments are always considered used.
490    fn is_used(&self, local: Local) -> bool {
491        trace!("is_used({:?}): use_count: {:?}", local, self.use_count[local]);
492        local.as_u32() <= self.arg_count || self.use_count[local] != 0
493    }
494
495    /// Updates the use counts to reflect the removal of given statement.
496    fn statement_removed(&mut self, statement: &Statement<'_>) {
497        self.increment = false;
498
499        // The location of the statement is irrelevant.
500        let location = Location::START;
501        self.visit_statement(statement, location);
502    }
503
504    /// Visits a left-hand side of an assignment.
505    fn visit_lhs(&mut self, place: &Place<'_>, location: Location) {
506        if place.is_indirect() {
507            // A use, not a definition.
508            self.visit_place(place, PlaceContext::MutatingUse(MutatingUseContext::Store), location);
509        } else {
510            // A definition. The base local itself is not visited, so this occurrence is not counted
511            // toward its use count. There might be other locals still, used in an indexing
512            // projection.
513            self.super_projection(
514                place.as_ref(),
515                PlaceContext::MutatingUse(MutatingUseContext::Projection),
516                location,
517            );
518        }
519    }
520}
521
522impl<'tcx> Visitor<'tcx> for UsedLocals {
523    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
524        match statement.kind {
525            StatementKind::Intrinsic(..)
526            | StatementKind::Retag(..)
527            | StatementKind::Coverage(..)
528            | StatementKind::FakeRead(..)
529            | StatementKind::PlaceMention(..)
530            | StatementKind::AscribeUserType(..) => {
531                self.super_statement(statement, location);
532            }
533
534            StatementKind::ConstEvalCounter | StatementKind::Nop => {}
535
536            StatementKind::StorageLive(_local) | StatementKind::StorageDead(_local) => {}
537
538            StatementKind::Assign(box (ref place, ref rvalue)) => {
539                if rvalue.is_safe_to_remove() {
540                    self.visit_lhs(place, location);
541                    self.visit_rvalue(rvalue, location);
542                } else {
543                    self.super_statement(statement, location);
544                }
545            }
546
547            StatementKind::SetDiscriminant { ref place, variant_index: _ }
548            | StatementKind::Deinit(ref place)
549            | StatementKind::BackwardIncompatibleDropHint { ref place, reason: _ } => {
550                self.visit_lhs(place, location);
551            }
552        }
553    }
554
555    fn visit_local(&mut self, local: Local, _ctx: PlaceContext, _location: Location) {
556        if self.increment {
557            self.use_count[local] += 1;
558        } else {
559            assert_ne!(self.use_count[local], 0);
560            self.use_count[local] -= 1;
561        }
562    }
563}
564
565/// Removes unused definitions. Updates the used locals to reflect the changes made.
566fn remove_unused_definitions_helper(used_locals: &mut UsedLocals, body: &mut Body<'_>) {
567    // The use counts are updated as we remove the statements. A local might become unused
568    // during the retain operation, leading to a temporary inconsistency (storage statements or
569    // definitions referencing the local might remain). For correctness it is crucial that this
570    // computation reaches a fixed point.
571
572    let mut modified = true;
573    while modified {
574        modified = false;
575
576        for data in body.basic_blocks.as_mut_preserves_cfg() {
577            // Remove unnecessary StorageLive and StorageDead annotations.
578            data.statements.retain(|statement| {
579                let keep = match &statement.kind {
580                    StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
581                        used_locals.is_used(*local)
582                    }
583                    StatementKind::Assign(box (place, _)) => used_locals.is_used(place.local),
584
585                    StatementKind::SetDiscriminant { place, .. }
586                    | StatementKind::BackwardIncompatibleDropHint { place, reason: _ }
587                    | StatementKind::Deinit(place) => used_locals.is_used(place.local),
588                    StatementKind::Nop => false,
589                    _ => true,
590                };
591
592                if !keep {
593                    trace!("removing statement {:?}", statement);
594                    modified = true;
595                    used_locals.statement_removed(statement);
596                }
597
598                keep
599            });
600        }
601    }
602}
603
/// Rewrites every local it visits according to `map`.
struct LocalUpdater<'tcx> {
    /// The replacement for each local; `visit_local` panics on a `None` entry.
    map: IndexVec<Local, Option<Local>>,
    /// Returned by the `tcx` method of the `MutVisitor` impl.
    tcx: TyCtxt<'tcx>,
}
608
609impl<'tcx> MutVisitor<'tcx> for LocalUpdater<'tcx> {
610    fn tcx(&self) -> TyCtxt<'tcx> {
611        self.tcx
612    }
613
614    fn visit_local(&mut self, l: &mut Local, _: PlaceContext, _: Location) {
615        *l = self.map[*l].unwrap();
616    }
617}