rustc_mir_transform/
early_otherwise_branch.rs

1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
10/// This pass optimizes something like
11/// ```ignore (syntax-highlighting-only)
12/// let x: Option<()>;
13/// let y: Option<()>;
14/// match (x,y) {
15///     (Some(_), Some(_)) => {0},
16///     (None, None) => {2},
17///     _ => {1}
18/// }
19/// ```
20/// into something like
21/// ```ignore (syntax-highlighting-only)
22/// let x: Option<()>;
23/// let y: Option<()>;
24/// let discriminant_x = std::mem::discriminant(x);
25/// let discriminant_y = std::mem::discriminant(y);
26/// if discriminant_x == discriminant_y {
27///     match x {
28///         Some(_) => 0,
29///         None => 2,
30///     }
31/// } else {
32///     1
33/// }
34/// ```
35///
36/// Specifically, it looks for instances of control flow like this:
37/// ```text
38///
39///     =================
40///     |      BB1      |
41///     |---------------|                  ============================
42///     |     ...       |         /------> |            BBC           |
43///     |---------------|         |        |--------------------------|
44///     |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
45///     |       c       | --------/        |--------------------------|
46///     |       d       | -------\         |       switchInt(_cl)     |
47///     |      ...      |        |         |            c             | ---> BBC.2
48///     |    otherwise  | --\    |    /--- |         otherwise        |
49///     =================   |    |    |    ============================
50///                         |    |    |
51///     =================   |    |    |
52///     |      BBU      | <-|    |    |    ============================
53///     |---------------|        \-------> |            BBD           |
54///     |---------------|             |    |--------------------------|
55///     |  unreachable  |             |    |   _dl = discriminant(P)  |
56///     =================             |    |--------------------------|
57///                                   |    |       switchInt(_dl)     |
58///     =================             |    |            d             | ---> BBD.2
59///     |      BB9      | <--------------- |         otherwise        |
60///     |---------------|                  ============================
61///     |      ...      |
62///     =================
63/// ```
64/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU`. In the
65/// code:
66///  - `BB1` is `parent` and `BBC, BBD` are children
67///  - `P` is `child_place`
68///  - `child_ty` is the type of `_cl`.
69///  - `Q` is `parent_op`.
70///  - `parent_ty` is the type of `Q`.
71///  - `BB9` is `destination`
72/// All this is then transformed into:
73/// ```text
74///
75///     =======================
76///     |          BB1        |
77///     |---------------------|                  ============================
78///     |          ...        |         /------> |           BBEq           |
79///     | _s = discriminant(P)|         |        |--------------------------|
80///     | _t = Ne(Q, _s)      |         |        |--------------------------|
81///     |---------------------|         |        |       switchInt(Q)       |
82///     |     switchInt(_t)   |         |        |            c             | ---> BBC.2
83///     |        false        | --------/        |            d             | ---> BBD.2
84///     |       otherwise     |       /--------- |         otherwise        |
85///     =======================       |          ============================
86///                                   |
87///     =================             |
88///     |      BB9      | <-----------/
89///     |---------------|
90///     |      ...      |
91///     =================
92/// ```
93pub(super) struct EarlyOtherwiseBranch;
94
95impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
96    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
97        sess.mir_opt_level() >= 2
98    }
99
100    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
101        trace!("running EarlyOtherwiseBranch on {:?}", body.source);
102
103        let mut should_cleanup = false;
104
105        // Also consider newly generated bbs in the same pass
106        for parent in body.basic_blocks.indices() {
107            let bbs = &*body.basic_blocks;
108            let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };
109
110            trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");
111
112            should_cleanup = true;
113
114            let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
115                &bbs[parent].terminator().kind
116            else {
117                unreachable!()
118            };
119            // Always correct since we can only switch on `Copy` types
120            let parent_op = match parent_op {
121                Operand::Move(x) => Operand::Copy(*x),
122                Operand::Copy(x) => Operand::Copy(*x),
123                Operand::Constant(x) => Operand::Constant(x.clone()),
124            };
125            let parent_ty = parent_op.ty(body.local_decls(), tcx);
126            let statements_before = bbs[parent].statements.len();
127            let parent_end = Location { block: parent, statement_index: statements_before };
128
129            let mut patch = MirPatch::new(body);
130
131            let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
132                // create temp to store second discriminant in, `_s` in example above
133                let second_discriminant_temp =
134                    patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
135
136                patch.add_statement(
137                    parent_end,
138                    StatementKind::StorageLive(second_discriminant_temp),
139                );
140
141                // create assignment of discriminant
142                patch.add_assign(
143                    parent_end,
144                    Place::from(second_discriminant_temp),
145                    Rvalue::Discriminant(opt_data.child_place),
146                );
147                (
148                    Some(second_discriminant_temp),
149                    Operand::Move(Place::from(second_discriminant_temp)),
150                )
151            } else {
152                (None, Operand::Copy(opt_data.child_place))
153            };
154
155            // create temp to store inequality comparison between the two discriminants, `_t` in
156            // example above
157            let nequal = BinOp::Ne;
158            let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
159            let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
160            patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
161
162            // create inequality comparison
163            let comp_rvalue =
164                Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
165            patch.add_statement(
166                parent_end,
167                StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
168            );
169
170            let eq_new_targets = parent_targets.iter().map(|(value, child)| {
171                let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
172                else {
173                    unreachable!()
174                };
175                (value, targets.target_for_value(value))
176            });
177            // The otherwise either is the same target branch or an unreachable.
178            let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());
179
180            // Create `bbEq` in example above
181            let eq_switch = BasicBlockData::new(
182                Some(Terminator {
183                    source_info: bbs[parent].terminator().source_info,
184                    kind: TerminatorKind::SwitchInt {
185                        // switch on the first discriminant, so we can mark the second one as dead
186                        discr: parent_op,
187                        targets: eq_targets,
188                    },
189                }),
190                bbs[parent].is_cleanup,
191            );
192
193            let eq_bb = patch.new_block(eq_switch);
194
195            // Jump to it on the basis of the inequality comparison
196            let true_case = opt_data.destination;
197            let false_case = eq_bb;
198            patch.patch_terminator(
199                parent,
200                TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
201            );
202
203            if let Some(second_discriminant_temp) = second_discriminant_temp {
204                // generate StorageDead for the second_discriminant_temp not in use anymore
205                patch.add_statement(
206                    parent_end,
207                    StatementKind::StorageDead(second_discriminant_temp),
208                );
209            }
210
211            // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
212            // the switch
213            for bb in [false_case, true_case].iter() {
214                patch.add_statement(
215                    Location { block: *bb, statement_index: 0 },
216                    StatementKind::StorageDead(comp_temp),
217                );
218            }
219
220            patch.apply(body);
221        }
222
223        // Since this optimization adds new basic blocks and invalidates others,
224        // clean up the cfg to make it nicer for other passes
225        if should_cleanup {
226            simplify_cfg(tcx, body);
227        }
228    }
229
230    fn is_required(&self) -> bool {
231        false
232    }
233}
234
235#[derive(Debug)]
236struct OptimizationData<'tcx> {
237    destination: BasicBlock,
238    child_place: Place<'tcx>,
239    child_ty: Ty<'tcx>,
240    child_source: SourceInfo,
241    need_hoist_discriminant: bool,
242}
243
244fn evaluate_candidate<'tcx>(
245    tcx: TyCtxt<'tcx>,
246    body: &Body<'tcx>,
247    parent: BasicBlock,
248) -> Option<OptimizationData<'tcx>> {
249    let bbs = &body.basic_blocks;
250    // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
251    if bbs[parent].is_cleanup {
252        return None;
253    }
254    let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
255    else {
256        return None;
257    };
258    let parent_ty = parent_discr.ty(body.local_decls(), tcx);
259    let (_, child) = targets.iter().next()?;
260
261    let Terminator {
262        kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
263        source_info,
264    } = bbs[child].terminator()
265    else {
266        return None;
267    };
268    let child_ty = child_discr.ty(body.local_decls(), tcx);
269    if child_ty != parent_ty {
270        return None;
271    }
272
273    // We only handle:
274    // ```
275    // bb4: {
276    //     _8 = discriminant((_3.1: Enum1));
277    //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
278    // }
279    // ```
280    // and
281    // ```
282    // bb2: {
283    //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
284    // }
285    // ```
286    if bbs[child].statements.len() > 1 {
287        return None;
288    }
289
290    // When thie BB has exactly one statement, this statement should be discriminant.
291    let need_hoist_discriminant = bbs[child].statements.len() == 1;
292    let child_place = if need_hoist_discriminant {
293        if !bbs[targets.otherwise()].is_empty_unreachable() {
294            // Someone could write code like this:
295            // ```rust
296            // let Q = val;
297            // if discriminant(P) == otherwise {
298            //     let ptr = &mut Q as *mut _ as *mut u8;
299            //     // It may be difficult for us to effectively determine whether values are valid.
300            //     // Invalid values can come from all sorts of corners.
301            //     unsafe { *ptr = 10; }
302            // }
303            //
304            // match P {
305            //    A => match Q {
306            //        A => {
307            //            // code
308            //        }
309            //        _ => {
310            //            // don't use Q
311            //        }
312            //    }
313            //    _ => {
314            //        // don't use Q
315            //    }
316            // };
317            // ```
318            //
319            // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
320            // invalid value, which is UB.
321            // In order to fix this, **we would either need to show that the discriminant computation of
322            // `place` is computed in all branches**.
323            // FIXME(#95162) For the moment, we adopt a conservative approach and
324            // consider only the `otherwise` branch has no statements and an unreachable terminator.
325            return None;
326        }
327        // Handle:
328        // ```
329        // bb4: {
330        //     _8 = discriminant((_3.1: Enum1));
331        //    switchInt(move _8) -> [2: bb7, otherwise: bb1];
332        // }
333        // ```
334        let [
335            Statement {
336                kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
337                ..
338            },
339        ] = bbs[child].statements.as_slice()
340        else {
341            return None;
342        };
343        *child_place
344    } else {
345        // Handle:
346        // ```
347        // bb2: {
348        //     switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
349        // }
350        // ```
351        let Operand::Copy(child_place) = child_discr else {
352            return None;
353        };
354        *child_place
355    };
356    let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
357    {
358        child_targets.otherwise()
359    } else {
360        targets.otherwise()
361    };
362
363    // Verify that the optimization is legal for each branch
364    for (value, child) in targets.iter() {
365        if !verify_candidate_branch(
366            &bbs[child],
367            value,
368            child_place,
369            destination,
370            need_hoist_discriminant,
371        ) {
372            return None;
373        }
374    }
375    Some(OptimizationData {
376        destination,
377        child_place,
378        child_ty,
379        child_source: *source_info,
380        need_hoist_discriminant,
381    })
382}
383
384fn verify_candidate_branch<'tcx>(
385    branch: &BasicBlockData<'tcx>,
386    value: u128,
387    place: Place<'tcx>,
388    destination: BasicBlock,
389    need_hoist_discriminant: bool,
390) -> bool {
391    // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
392    let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
393        return false;
394    };
395    if need_hoist_discriminant {
396        // If we need hoist discriminant, the branch must have exactly one statement.
397        let [statement] = branch.statements.as_slice() else {
398            return false;
399        };
400        // The statement must assign the discriminant of `place`.
401        let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
402            statement.kind
403        else {
404            return false;
405        };
406        if from_place != place {
407            return false;
408        }
409        // The assignment must invalidate a local that terminate on a `SwitchInt`.
410        if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
411            return false;
412        }
413    } else {
414        // If we don't need hoist discriminant, the branch must not have any statements.
415        if !branch.statements.is_empty() {
416            return false;
417        }
418        // The place on `SwitchInt` must be the same.
419        if *switch_op != Operand::Copy(place) {
420            return false;
421        }
422    }
423    // It must fall through to `destination` if the switch misses.
424    if destination != targets.otherwise() {
425        return false;
426    }
427    // It must have exactly one branch for value `value` and have no more branches.
428    let mut iter = targets.iter();
429    let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
430        return false;
431    };
432    target_value == value
433}