rustc_mir_transform/early_otherwise_branch.rs
1use std::fmt::Debug;
2
3use rustc_middle::mir::*;
4use rustc_middle::ty::{Ty, TyCtxt};
5use tracing::trace;
6
7use super::simplify::simplify_cfg;
8use crate::patch::MirPatch;
9
/// This pass optimizes something like
/// ```ignore (syntax-highlighting-only)
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
///     (Some(_), Some(_)) => {0},
///     (None, None) => {2},
///     _ => {1}
/// }
/// ```
/// into something like
/// ```ignore (syntax-highlighting-only)
/// let x: Option<()>;
/// let y: Option<()>;
/// let discriminant_x = std::mem::discriminant(&x);
/// let discriminant_y = std::mem::discriminant(&y);
/// if discriminant_x == discriminant_y {
///     match x {
///         Some(_) => 0,
///         None => 2,
///     }
/// } else {
///     1
/// }
/// ```
///
/// Specifically, it looks for instances of control flow like this:
/// ```text
///
///      =================
///      |      BB1      |
///      |---------------|                  ============================
///      |     ...       |         /------> |            BBC           |
///      |---------------|         |        |--------------------------|
///      |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
///      |       c       | --------/        |--------------------------|
///      |       d       | -------\         |       switchInt(_cl)     |
///      |      ...      |        |         |            c             | ---> BBC.2
///      |   otherwise   | --\    |    /--- |         otherwise        |
///      =================   |    |    |    ============================
///                          |    |    |
///      =================   |    |    |
///      |      BBU      | <-|    |    |    ============================
///      |---------------|        \-------> |            BBD           |
///      |---------------|             |    |--------------------------|
///      |  unreachable  |             |    |   _dl = discriminant(P)  |
///      =================             |    |--------------------------|
///                                    |    |       switchInt(_dl)     |
///      =================             |    |            d             | ---> BBD.2
///      |      BB9      | <--------------- |         otherwise        |
///      |---------------|                  ============================
///      |      ...      |
///      =================
/// ```
/// The `otherwise` branch on `BB1` is permitted to go either to `BBU` (an empty block ending
/// in `unreachable`) or to `BB9`. When the children compute `discriminant(P)` as drawn above,
/// it must go to `BBU`. Hoisting that computation into `BB1` would also evaluate it on the
/// `otherwise` path, where `P` may not hold a valid value (see `FIXME(#95162)` in
/// `evaluate_candidate`). The children may instead switch on `P` directly, with no
/// `discriminant` statement. In that case `_s` below is not created, and `Q` is compared
/// with `P` itself.
///
/// In the code:
/// - `BB1` is `parent` and `BBC, BBD` are children
/// - `P` is `child_place`
/// - `child_ty` is the type of `_cl`.
/// - `Q` is `parent_op`.
/// - `parent_ty` is the type of `Q`.
/// - `BB9` is `destination`
///
/// All this is then transformed into:
/// ```text
///
///      =======================
///      |         BB1         |
///      |---------------------|                  ============================
///      |          ...        |         /------> |           BBEq           |
///      | _s = discriminant(P)|         |        |--------------------------|
///      | _t = Ne(Q, _s)      |         |        |--------------------------|
///      |---------------------|         |        |       switchInt(Q)       |
///      |     switchInt(_t)   |         |        |            c             | ---> BBC.2
///      |        false        | --------/        |            d             | ---> BBD.2
///      |       otherwise     |       /--------- |         otherwise        |
///      =======================       |          ============================
///                                    |
///         =================          |
///         |      BB9      | <--------/
///         |---------------|
///         |      ...      |
///         =================
/// ```
/// `_t` is true exactly when the two discriminants differ. In that case every original path
/// would have ended up in `BB9`, so `BB1` jumps there directly. Otherwise `BBEq` dispatches on
/// `Q` straight to the targets the children would have chosen. Afterwards the pass runs
/// `simplify_cfg` to clean up blocks left unused.
pub(super) struct EarlyOtherwiseBranch;
94
impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
    // The pass only runs at MIR optimization level 2 and above.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 2
    }

    /// Rewrites every block accepted by `evaluate_candidate` into the shape shown in the
    /// documentation of [`EarlyOtherwiseBranch`], then runs `simplify_cfg` if anything changed.
    ///
    /// Each rewrite builds a fresh `MirPatch` and applies it before the next iteration, so
    /// later candidates are evaluated against the already-patched body.
    /// NOTE(review): statements added at `parent_end` are presumably inserted in call order,
    /// which is what makes the StorageLive / assign / compare / StorageDead sequence below
    /// valid. Confirm against `MirPatch` before reordering anything here.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        trace!("running EarlyOtherwiseBranch on {:?}", body.source);

        // Becomes true once any block is rewritten; gates the final `simplify_cfg`.
        let mut should_cleanup = false;

        // The index iterator is created once, when the loop starts. `patch.apply(body)` below
        // needs `body` mutably, so this iterator cannot be borrowing the block list. Blocks
        // appended by a patch (the new `BBEq` blocks) are therefore not visited in this pass.
        // NOTE(review): an earlier comment here claimed newly generated blocks are also
        // considered; that does not match the loop as written — confirm.
        for parent in body.basic_blocks.indices() {
            let bbs = &*body.basic_blocks;
            let Some(opt_data) = evaluate_candidate(tcx, body, parent) else { continue };

            trace!("SUCCESS: found optimization possibility to apply: {opt_data:?}");

            should_cleanup = true;

            // `evaluate_candidate` only returns `Some` when `parent` ends in a `SwitchInt`.
            let TerminatorKind::SwitchInt { discr: parent_op, targets: parent_targets } =
                &bbs[parent].terminator().kind
            else {
                unreachable!()
            };
            // Always correct since we can only switch on `Copy` types. `Q` is used twice below
            // (in the comparison and as the operand of `BBEq`), so it must not be moved.
            let parent_op = match parent_op {
                Operand::Move(x) => Operand::Copy(*x),
                Operand::Copy(x) => Operand::Copy(*x),
                Operand::Constant(x) => Operand::Constant(x.clone()),
            };
            let parent_ty = parent_op.ty(body.local_decls(), tcx);
            // New statements go after all of the parent's existing statements.
            let statements_before = bbs[parent].statements.len();
            let parent_end = Location { block: parent, statement_index: statements_before };

            let mut patch = MirPatch::new(body);

            // If the children compute `discriminant(P)`, hoist it into the parent as `_s`.
            // If they switch on `P` directly, compare against `copy P` and create no temp.
            let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
                // create temp to store second discriminant in, `_s` in example above
                let second_discriminant_temp =
                    patch.new_temp(opt_data.child_ty, opt_data.child_source.span);

                patch.add_statement(
                    parent_end,
                    StatementKind::StorageLive(second_discriminant_temp),
                );

                // create assignment of discriminant
                patch.add_assign(
                    parent_end,
                    Place::from(second_discriminant_temp),
                    Rvalue::Discriminant(opt_data.child_place),
                );
                (
                    Some(second_discriminant_temp),
                    Operand::Move(Place::from(second_discriminant_temp)),
                )
            } else {
                (None, Operand::Copy(opt_data.child_place))
            };

            // create temp to store inequality comparison between the two discriminants, `_t` in
            // example above
            let nequal = BinOp::Ne;
            let comp_res_type = nequal.ty(tcx, parent_ty, opt_data.child_ty);
            let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
            patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));

            // create inequality comparison
            let comp_rvalue =
                Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
            patch.add_statement(
                parent_end,
                StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
            );

            // Map each parent value `v` to the target the child for `v` takes when its own
            // operand is also `v`. This target is what `BBEq` should jump to.
            let eq_new_targets = parent_targets.iter().map(|(value, child)| {
                let TerminatorKind::SwitchInt { targets, .. } = &bbs[child].terminator().kind
                else {
                    unreachable!()
                };
                (value, targets.target_for_value(value))
            });
            // The parent's `otherwise` is either `destination` or an empty unreachable block
            // (see `evaluate_candidate`), so it can be reused unchanged.
            let eq_targets = SwitchTargets::new(eq_new_targets, parent_targets.otherwise());

            // Create `bbEq` in example above
            let eq_switch = BasicBlockData::new(
                Some(Terminator {
                    source_info: bbs[parent].terminator().source_info,
                    kind: TerminatorKind::SwitchInt {
                        // switch on the first discriminant, so we can mark the second one as dead
                        discr: parent_op,
                        targets: eq_targets,
                    },
                }),
                bbs[parent].is_cleanup,
            );

            let eq_bb = patch.new_block(eq_switch);

            // Jump to it on the basis of the inequality comparison: `_t == true` (the values
            // differ) goes straight to `destination`, `false` goes to `BBEq`.
            let true_case = opt_data.destination;
            let false_case = eq_bb;
            patch.patch_terminator(
                parent,
                TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
            );

            if let Some(second_discriminant_temp) = second_discriminant_temp {
                // generate StorageDead for the second_discriminant_temp not in use anymore
                // (it was moved into the comparison above)
                patch.add_statement(
                    parent_end,
                    StatementKind::StorageDead(second_discriminant_temp),
                );
            }

            // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
            // the switch
            for bb in [false_case, true_case].iter() {
                patch.add_statement(
                    Location { block: *bb, statement_index: 0 },
                    StatementKind::StorageDead(comp_temp),
                );
            }

            patch.apply(body);
        }

        // Since this optimization adds new basic blocks and invalidates others,
        // clean up the cfg to make it nicer for other passes
        if should_cleanup {
            simplify_cfg(tcx, body);
        }
    }

    // This pass is an optimization, not something the pipeline needs for correctness.
    fn is_required(&self) -> bool {
        false
    }
}
234
/// Everything `run_pass` needs to rewrite one candidate `parent` block. Produced by
/// `evaluate_candidate`.
#[derive(Debug)]
struct OptimizationData<'tcx> {
    /// Block taken when the parent's operand and the children's operand differ (`BB9` in the
    /// pass documentation). Every child's `otherwise` target must be this block.
    destination: BasicBlock,
    /// The place the children switch on, or take the discriminant of (`P`).
    child_place: Place<'tcx>,
    /// Type of the children's switch operand (`_cl`). Equal to the type of the parent's operand.
    child_ty: Ty<'tcx>,
    /// Source info of the first child's terminator. Its span is given to the new temporaries.
    child_source: SourceInfo,
    /// `true` if each child contains `_d = discriminant(P)`, which must be hoisted into the
    /// parent. `false` if each child switches on `copy P` directly.
    need_hoist_discriminant: bool,
}
243
244fn evaluate_candidate<'tcx>(
245 tcx: TyCtxt<'tcx>,
246 body: &Body<'tcx>,
247 parent: BasicBlock,
248) -> Option<OptimizationData<'tcx>> {
249 let bbs = &body.basic_blocks;
250 // NB: If this BB is a cleanup, we may need to figure out what else needs to be handled.
251 if bbs[parent].is_cleanup {
252 return None;
253 }
254 let TerminatorKind::SwitchInt { targets, discr: parent_discr } = &bbs[parent].terminator().kind
255 else {
256 return None;
257 };
258 let parent_ty = parent_discr.ty(body.local_decls(), tcx);
259 let (_, child) = targets.iter().next()?;
260
261 let Terminator {
262 kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
263 source_info,
264 } = bbs[child].terminator()
265 else {
266 return None;
267 };
268 let child_ty = child_discr.ty(body.local_decls(), tcx);
269 if child_ty != parent_ty {
270 return None;
271 }
272
273 // We only handle:
274 // ```
275 // bb4: {
276 // _8 = discriminant((_3.1: Enum1));
277 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
278 // }
279 // ```
280 // and
281 // ```
282 // bb2: {
283 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
284 // }
285 // ```
286 if bbs[child].statements.len() > 1 {
287 return None;
288 }
289
290 // When thie BB has exactly one statement, this statement should be discriminant.
291 let need_hoist_discriminant = bbs[child].statements.len() == 1;
292 let child_place = if need_hoist_discriminant {
293 if !bbs[targets.otherwise()].is_empty_unreachable() {
294 // Someone could write code like this:
295 // ```rust
296 // let Q = val;
297 // if discriminant(P) == otherwise {
298 // let ptr = &mut Q as *mut _ as *mut u8;
299 // // It may be difficult for us to effectively determine whether values are valid.
300 // // Invalid values can come from all sorts of corners.
301 // unsafe { *ptr = 10; }
302 // }
303 //
304 // match P {
305 // A => match Q {
306 // A => {
307 // // code
308 // }
309 // _ => {
310 // // don't use Q
311 // }
312 // }
313 // _ => {
314 // // don't use Q
315 // }
316 // };
317 // ```
318 //
319 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
320 // invalid value, which is UB.
321 // In order to fix this, **we would either need to show that the discriminant computation of
322 // `place` is computed in all branches**.
323 // FIXME(#95162) For the moment, we adopt a conservative approach and
324 // consider only the `otherwise` branch has no statements and an unreachable terminator.
325 return None;
326 }
327 // Handle:
328 // ```
329 // bb4: {
330 // _8 = discriminant((_3.1: Enum1));
331 // switchInt(move _8) -> [2: bb7, otherwise: bb1];
332 // }
333 // ```
334 let [
335 Statement {
336 kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
337 ..
338 },
339 ] = bbs[child].statements.as_slice()
340 else {
341 return None;
342 };
343 *child_place
344 } else {
345 // Handle:
346 // ```
347 // bb2: {
348 // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
349 // }
350 // ```
351 let Operand::Copy(child_place) = child_discr else {
352 return None;
353 };
354 *child_place
355 };
356 let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
357 {
358 child_targets.otherwise()
359 } else {
360 targets.otherwise()
361 };
362
363 // Verify that the optimization is legal for each branch
364 for (value, child) in targets.iter() {
365 if !verify_candidate_branch(
366 &bbs[child],
367 value,
368 child_place,
369 destination,
370 need_hoist_discriminant,
371 ) {
372 return None;
373 }
374 }
375 Some(OptimizationData {
376 destination,
377 child_place,
378 child_ty,
379 child_source: *source_info,
380 need_hoist_discriminant,
381 })
382}
383
384fn verify_candidate_branch<'tcx>(
385 branch: &BasicBlockData<'tcx>,
386 value: u128,
387 place: Place<'tcx>,
388 destination: BasicBlock,
389 need_hoist_discriminant: bool,
390) -> bool {
391 // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
392 let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
393 return false;
394 };
395 if need_hoist_discriminant {
396 // If we need hoist discriminant, the branch must have exactly one statement.
397 let [statement] = branch.statements.as_slice() else {
398 return false;
399 };
400 // The statement must assign the discriminant of `place`.
401 let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
402 statement.kind
403 else {
404 return false;
405 };
406 if from_place != place {
407 return false;
408 }
409 // The assignment must invalidate a local that terminate on a `SwitchInt`.
410 if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
411 return false;
412 }
413 } else {
414 // If we don't need hoist discriminant, the branch must not have any statements.
415 if !branch.statements.is_empty() {
416 return false;
417 }
418 // The place on `SwitchInt` must be the same.
419 if *switch_op != Operand::Copy(place) {
420 return false;
421 }
422 }
423 // It must fall through to `destination` if the switch misses.
424 if destination != targets.otherwise() {
425 return false;
426 }
427 // It must have exactly one branch for value `value` and have no more branches.
428 let mut iter = targets.iter();
429 let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
430 return false;
431 };
432 target_value == value
433}