// rustc_mir_transform/check_enums.rs

1use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
2use rustc_hir::LangItem;
3use rustc_index::IndexVec;
4use rustc_middle::bug;
5use rustc_middle::mir::visit::Visitor;
6use rustc_middle::mir::*;
7use rustc_middle::ty::layout::PrimitiveExt;
8use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
9use rustc_session::Session;
10use tracing::debug;
11
12/// This pass inserts checks for a valid enum discriminant where they are most
13/// likely to find UB, because checking everywhere like Miri would generate too
14/// much MIR.
15pub(super) struct CheckEnums;
16
17impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
18    fn is_enabled(&self, sess: &Session) -> bool {
19        sess.ub_checks()
20    }
21
22    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
23        // This pass emits new panics. If for whatever reason we do not have a panic
24        // implementation, running this pass may cause otherwise-valid code to not compile.
25        if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
26            return;
27        }
28
29        let typing_env = body.typing_env(tcx);
30        let basic_blocks = body.basic_blocks.as_mut();
31        let local_decls = &mut body.local_decls;
32
33        // This operation inserts new blocks. Each insertion changes the Location for all
34        // statements/blocks after. Iterating or visiting the MIR in order would require updating
35        // our current location after every insertion. By iterating backwards, we dodge this issue:
36        // The only Locations that an insertion changes have already been handled.
37        for block in basic_blocks.indices().rev() {
38            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
39                let location = Location { block, statement_index };
40                let statement = &basic_blocks[block].statements[statement_index];
41                let source_info = statement.source_info;
42
43                let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
44                finder.visit_statement(statement, location);
45
46                for check in finder.into_found_enums() {
47                    debug!("Inserting enum check");
48                    let new_block = split_block(basic_blocks, location);
49
50                    match check {
51                        EnumCheckType::Direct { op_size, .. }
52                        | EnumCheckType::WithNiche { op_size, .. }
53                            if op_size.bytes() == 0 =>
54                        {
55                            // It is never valid to use a ZST as a discriminant for an inhabited enum, but that will
56                            // have been caught by the type checker. Do nothing but ensure that a bug has been signaled.
57                            tcx.dcx().span_delayed_bug(
58                                source_info.span,
59                                "cannot build enum discriminant from zero-sized type",
60                            );
61                            basic_blocks[block].terminator = Some(Terminator {
62                                source_info,
63                                kind: TerminatorKind::Goto { target: new_block },
64                            });
65                        }
66                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
67                            insert_direct_enum_check(
68                                tcx,
69                                local_decls,
70                                basic_blocks,
71                                block,
72                                source_op,
73                                discr,
74                                op_size,
75                                valid_discrs,
76                                source_info,
77                                new_block,
78                            )
79                        }
80                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
81                            tcx,
82                            local_decls,
83                            &mut basic_blocks[block],
84                            source_info,
85                            new_block,
86                        ),
87                        EnumCheckType::WithNiche {
88                            source_op,
89                            discr,
90                            op_size,
91                            offset,
92                            valid_range,
93                        } => insert_niche_check(
94                            tcx,
95                            local_decls,
96                            &mut basic_blocks[block],
97                            source_op,
98                            valid_range,
99                            discr,
100                            op_size,
101                            offset,
102                            source_info,
103                            new_block,
104                        ),
105                    }
106                }
107            }
108        }
109    }
110
111    fn is_required(&self) -> bool {
112        true
113    }
114}
115
116/// Represent the different kind of enum checks we can insert.
117enum EnumCheckType<'tcx> {
118    /// We know we try to create an uninhabited enum from an inhabited variant.
119    Uninhabited,
120    /// We know the enum does no niche optimizations and can thus easily compute
121    /// the valid discriminants.
122    Direct {
123        source_op: Operand<'tcx>,
124        discr: TyAndSize<'tcx>,
125        op_size: Size,
126        valid_discrs: Vec<u128>,
127    },
128    /// We try to construct an enum that has a niche.
129    WithNiche {
130        source_op: Operand<'tcx>,
131        discr: TyAndSize<'tcx>,
132        op_size: Size,
133        offset: Size,
134        valid_range: WrappingRange,
135    },
136}
137
138#[derive(Debug, Copy, Clone)]
139struct TyAndSize<'tcx> {
140    pub ty: Ty<'tcx>,
141    pub size: Size,
142}
143
144/// A [Visitor] that finds the construction of enums and evaluates which checks
145/// we should apply.
146struct EnumFinder<'a, 'tcx> {
147    tcx: TyCtxt<'tcx>,
148    local_decls: &'a mut LocalDecls<'tcx>,
149    typing_env: TypingEnv<'tcx>,
150    enums: Vec<EnumCheckType<'tcx>>,
151}
152
153impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
154    fn new(
155        tcx: TyCtxt<'tcx>,
156        local_decls: &'a mut LocalDecls<'tcx>,
157        typing_env: TypingEnv<'tcx>,
158    ) -> Self {
159        EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
160    }
161
162    /// Returns the found enum creations and which checks should be inserted.
163    fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
164        self.enums
165    }
166}
167
168impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
169    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
170        if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
171            let ty::Adt(adt_def, _) = ty.kind() else {
172                return;
173            };
174            if !adt_def.is_enum() {
175                return;
176            }
177
178            let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
179                return;
180            };
181            let Ok(op_layout) = self
182                .tcx
183                .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
184            else {
185                return;
186            };
187
188            match enum_layout.variants {
189                Variants::Empty if op_layout.is_uninhabited() => return,
190                // An empty enum that tries to be constructed from an inhabited value, this
191                // is never correct.
192                Variants::Empty => {
193                    // The enum layout is uninhabited but we construct it from sth inhabited.
194                    // This is always UB.
195                    self.enums.push(EnumCheckType::Uninhabited);
196                }
197                // Construction of Single value enums is always fine.
198                Variants::Single { .. } => {}
199                // Construction of an enum with multiple variants but no niche optimizations.
200                Variants::Multiple {
201                    tag_encoding: TagEncoding::Direct,
202                    tag: Scalar::Initialized { value, .. },
203                    ..
204                } => {
205                    let valid_discrs =
206                        adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
207
208                    let discr =
209                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
210                    self.enums.push(EnumCheckType::Direct {
211                        source_op: op.to_copy(),
212                        discr,
213                        op_size: op_layout.size,
214                        valid_discrs,
215                    });
216                }
217                // Construction of an enum with multiple variants and niche optimizations.
218                Variants::Multiple {
219                    tag_encoding: TagEncoding::Niche { .. },
220                    tag: Scalar::Initialized { value, valid_range, .. },
221                    tag_field,
222                    ..
223                } => {
224                    let discr =
225                        TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
226                    self.enums.push(EnumCheckType::WithNiche {
227                        source_op: op.to_copy(),
228                        discr,
229                        op_size: op_layout.size,
230                        offset: enum_layout.fields.offset(tag_field.as_usize()),
231                        valid_range,
232                    });
233                }
234                _ => return,
235            }
236
237            self.super_rvalue(rvalue, location);
238        }
239    }
240}
241
242fn split_block(
243    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
244    location: Location,
245) -> BasicBlock {
246    let block_data = &mut basic_blocks[location.block];
247
248    // Drain every statement after this one and move the current terminator to a new basic block.
249    let new_block = BasicBlockData::new_stmts(
250        block_data.statements.split_off(location.statement_index),
251        block_data.terminator.take(),
252        block_data.is_cleanup,
253    );
254
255    basic_blocks.push(new_block)
256}
257
258/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value.
259fn insert_discr_cast_to_u128<'tcx>(
260    tcx: TyCtxt<'tcx>,
261    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
262    block_data: &mut BasicBlockData<'tcx>,
263    source_op: Operand<'tcx>,
264    discr: TyAndSize<'tcx>,
265    op_size: Size,
266    offset: Option<Size>,
267    source_info: SourceInfo,
268) -> Place<'tcx> {
269    let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
270        match size.bytes() {
271            1 => tcx.types.u8,
272            2 => tcx.types.u16,
273            4 => tcx.types.u32,
274            8 => tcx.types.u64,
275            16 => tcx.types.u128,
276            invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
277        }
278    };
279
280    let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
281        // The discriminant is less wide than the operand, cast the operand into
282        // [MaybeUninit; N] and then index into it.
283        let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
284        let array_len = op_size.bytes();
285        let mu_array_ty = Ty::new_array(tcx, mu, array_len);
286        let mu_array =
287            local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
288        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
289        block_data
290            .statements
291            .push(Statement::new(source_info, StatementKind::Assign(Box::new((mu_array, rvalue)))));
292
293        // Index into the array of MaybeUninit to get something that is actually
294        // as wide as the discriminant.
295        let offset = offset.unwrap_or(Size::ZERO);
296        let smaller_mu_array = mu_array.project_deeper(
297            &[ProjectionElem::Subslice {
298                from: offset.bytes(),
299                to: offset.bytes() + discr.size.bytes(),
300                from_end: false,
301            }],
302            tcx,
303        );
304
305        (CastKind::Transmute, Operand::Copy(smaller_mu_array))
306    } else {
307        let operand_int_ty = get_ty_for_size(tcx, op_size);
308
309        let op_as_int =
310            local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
311        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
312        block_data.statements.push(Statement::new(
313            source_info,
314            StatementKind::Assign(Box::new((op_as_int, rvalue))),
315        ));
316
317        (CastKind::IntToInt, Operand::Copy(op_as_int))
318    };
319
320    // Cast the resulting value to the actual discriminant integer type.
321    let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
322    let discr_in_discr_ty =
323        local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
324    block_data.statements.push(Statement::new(
325        source_info,
326        StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
327    ));
328
329    // Cast the discriminant to a u128 (base for comparisons of enum discriminants).
330    let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
331    let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
332    let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
333    block_data
334        .statements
335        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr, rvalue)))));
336
337    discr
338}
339
340fn insert_direct_enum_check<'tcx>(
341    tcx: TyCtxt<'tcx>,
342    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
343    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
344    current_block: BasicBlock,
345    source_op: Operand<'tcx>,
346    discr: TyAndSize<'tcx>,
347    op_size: Size,
348    discriminants: Vec<u128>,
349    source_info: SourceInfo,
350    new_block: BasicBlock,
351) {
352    // Insert a new target block that is branched to in case of an invalid discriminant.
353    let invalid_discr_block_data = BasicBlockData::new(None, false);
354    let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
355    let block_data = &mut basic_blocks[current_block];
356    let discr_place = insert_discr_cast_to_u128(
357        tcx,
358        local_decls,
359        block_data,
360        source_op,
361        discr,
362        op_size,
363        None,
364        source_info,
365    );
366
367    // Mask out the bits of the discriminant type.
368    let mask = discr.size.unsigned_int_max();
369    let discr_masked =
370        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
371    let rvalue = Rvalue::BinaryOp(
372        BinOp::BitAnd,
373        Box::new((
374            Operand::Copy(discr_place),
375            Operand::Constant(Box::new(ConstOperand {
376                span: source_info.span,
377                user_ty: None,
378                const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
379            })),
380        )),
381    );
382    block_data
383        .statements
384        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
385
386    // Branch based on the discriminant value.
387    block_data.terminator = Some(Terminator {
388        source_info,
389        kind: TerminatorKind::SwitchInt {
390            discr: Operand::Copy(discr_masked),
391            targets: SwitchTargets::new(
392                discriminants
393                    .into_iter()
394                    .map(|discr_val| (discr.size.truncate(discr_val), new_block)),
395                invalid_discr_block,
396            ),
397        },
398    });
399
400    // Abort in case of an invalid enum discriminant.
401    basic_blocks[invalid_discr_block].terminator = Some(Terminator {
402        source_info,
403        kind: TerminatorKind::Assert {
404            cond: Operand::Constant(Box::new(ConstOperand {
405                span: source_info.span,
406                user_ty: None,
407                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
408            })),
409            expected: true,
410            target: new_block,
411            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
412            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
413            // We never want to insert an unwind into unsafe code, because unwinding could
414            // make a failing UB check turn into much worse UB when we start unwinding.
415            unwind: UnwindAction::Unreachable,
416        },
417    });
418}
419
420fn insert_uninhabited_enum_check<'tcx>(
421    tcx: TyCtxt<'tcx>,
422    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
423    block_data: &mut BasicBlockData<'tcx>,
424    source_info: SourceInfo,
425    new_block: BasicBlock,
426) {
427    let is_ok: Place<'_> =
428        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
429    block_data.statements.push(Statement::new(
430        source_info,
431        StatementKind::Assign(Box::new((
432            is_ok,
433            Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
434                span: source_info.span,
435                user_ty: None,
436                const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
437            }))),
438        ))),
439    ));
440
441    block_data.terminator = Some(Terminator {
442        source_info,
443        kind: TerminatorKind::Assert {
444            cond: Operand::Copy(is_ok),
445            expected: true,
446            target: new_block,
447            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
448                ConstOperand {
449                    span: source_info.span,
450                    user_ty: None,
451                    const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
452                },
453            )))),
454            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
455            // We never want to insert an unwind into unsafe code, because unwinding could
456            // make a failing UB check turn into much worse UB when we start unwinding.
457            unwind: UnwindAction::Unreachable,
458        },
459    });
460}
461
462fn insert_niche_check<'tcx>(
463    tcx: TyCtxt<'tcx>,
464    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
465    block_data: &mut BasicBlockData<'tcx>,
466    source_op: Operand<'tcx>,
467    valid_range: WrappingRange,
468    discr: TyAndSize<'tcx>,
469    op_size: Size,
470    offset: Size,
471    source_info: SourceInfo,
472    new_block: BasicBlock,
473) {
474    let discr = insert_discr_cast_to_u128(
475        tcx,
476        local_decls,
477        block_data,
478        source_op,
479        discr,
480        op_size,
481        Some(offset),
482        source_info,
483    );
484
485    // Compare the discriminant against the valid_range.
486    let start_const = Operand::Constant(Box::new(ConstOperand {
487        span: source_info.span,
488        user_ty: None,
489        const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
490    }));
491    let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
492        span: source_info.span,
493        user_ty: None,
494        const_: Const::Val(
495            ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
496            tcx.types.u128,
497        ),
498    }));
499
500    let discr_diff: Place<'_> =
501        local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
502    block_data.statements.push(Statement::new(
503        source_info,
504        StatementKind::Assign(Box::new((
505            discr_diff,
506            Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))),
507        ))),
508    ));
509
510    let is_ok: Place<'_> =
511        local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
512    block_data.statements.push(Statement::new(
513        source_info,
514        StatementKind::Assign(Box::new((
515            is_ok,
516            Rvalue::BinaryOp(
517                // This is a `WrappingRange`, so make sure to get the wrapping right.
518                BinOp::Le,
519                Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
520            ),
521        ))),
522    ));
523
524    block_data.terminator = Some(Terminator {
525        source_info,
526        kind: TerminatorKind::Assert {
527            cond: Operand::Copy(is_ok),
528            expected: true,
529            target: new_block,
530            msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
531            // This calls panic_invalid_enum_construction, which is #[rustc_nounwind].
532            // We never want to insert an unwind into unsafe code, because unwinding could
533            // make a failing UB check turn into much worse UB when we start unwinding.
534            unwind: UnwindAction::Unreachable,
535        },
536    });
537}