1use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange};
2use rustc_hir::LangItem;
3use rustc_index::IndexVec;
4use rustc_middle::bug;
5use rustc_middle::mir::visit::Visitor;
6use rustc_middle::mir::*;
7use rustc_middle::ty::layout::PrimitiveExt;
8use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv};
9use rustc_session::Session;
10use tracing::debug;
11
/// MIR pass that follows every `Transmute` cast into an enum type with a runtime
/// validity check of the resulting discriminant / tag. An invalid value makes an
/// `Assert` terminator fail with `AssertKind::InvalidEnumConstruction`.
pub(super) struct CheckEnums;

impl<'tcx> crate::MirPass<'tcx> for CheckEnums {
    /// The pass only runs when UB checks are enabled for the session.
    fn is_enabled(&self, sess: &Session) -> bool {
        sess.ub_checks()
    }

    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        // Every inserted check ends in an `Assert` terminator that can panic. Bail out
        // when no `PanicImpl` lang item exists so the pass never introduces panics into
        // a crate without a panic implementation.
        if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
            return;
        }

        let typing_env = body.typing_env(tcx);
        let basic_blocks = body.basic_blocks.as_mut();
        let local_decls = &mut body.local_decls;

        // Inserting a check splits the current block: the visited statement and
        // everything after it move into a freshly pushed block. Walking blocks and
        // statements in reverse means only already-handled locations get moved. The
        // range of block indices is captured once here, so the newly pushed blocks
        // are never revisited.
        for block in basic_blocks.indices().rev() {
            for statement_index in (0..basic_blocks[block].statements.len()).rev() {
                let location = Location { block, statement_index };
                let statement = &basic_blocks[block].statements[statement_index];
                let source_info = statement.source_info;

                // Collect the checks this statement needs (only transmutes into enums).
                let mut finder = EnumFinder::new(tcx, local_decls, typing_env);
                finder.visit_statement(statement, location);

                for check in finder.into_found_enums() {
                    debug!("Inserting enum check");
                    // `new_block` now starts with the visited statement. `block` is
                    // left without a terminator, which the check below provides.
                    let new_block = split_block(basic_blocks, location);

                    match check {
                        EnumCheckType::Direct { op_size, .. }
                        | EnumCheckType::WithNiche { op_size, .. }
                            if op_size.bytes() == 0 =>
                        {
                            // A zero-sized operand cannot carry a tag. `span_delayed_bug`
                            // only ICEs if no other error gets reported, so an error is
                            // expected elsewhere. No check is emitted; just fall through.
                            tcx.dcx().span_delayed_bug(
                                source_info.span,
                                "cannot build enum discriminant from zero-sized type",
                            );
                            basic_blocks[block].terminator = Some(Terminator {
                                source_info,
                                kind: TerminatorKind::Goto { target: new_block },
                            });
                        }
                        EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => {
                            insert_direct_enum_check(
                                tcx,
                                local_decls,
                                basic_blocks,
                                block,
                                source_op,
                                discr,
                                op_size,
                                valid_discrs,
                                source_info,
                                new_block,
                            )
                        }
                        EnumCheckType::Uninhabited => insert_uninhabited_enum_check(
                            tcx,
                            local_decls,
                            &mut basic_blocks[block],
                            source_info,
                            new_block,
                        ),
                        EnumCheckType::WithNiche {
                            source_op,
                            discr,
                            op_size,
                            offset,
                            valid_range,
                        } => insert_niche_check(
                            tcx,
                            local_decls,
                            &mut basic_blocks[block],
                            source_op,
                            valid_range,
                            discr,
                            op_size,
                            offset,
                            source_info,
                            new_block,
                        ),
                    }
                }
            }
        }
    }

    /// Marks the pass as required.
    fn is_required(&self) -> bool {
        true
    }
}
115
116enum EnumCheckType<'tcx> {
118 Uninhabited,
120 Direct {
123 source_op: Operand<'tcx>,
124 discr: TyAndSize<'tcx>,
125 op_size: Size,
126 valid_discrs: Vec<u128>,
127 },
128 WithNiche {
130 source_op: Operand<'tcx>,
131 discr: TyAndSize<'tcx>,
132 op_size: Size,
133 offset: Size,
134 valid_range: WrappingRange,
135 },
136}
137
138#[derive(Debug, Copy, Clone)]
139struct TyAndSize<'tcx> {
140 pub ty: Ty<'tcx>,
141 pub size: Size,
142}
143
144struct EnumFinder<'a, 'tcx> {
147 tcx: TyCtxt<'tcx>,
148 local_decls: &'a mut LocalDecls<'tcx>,
149 typing_env: TypingEnv<'tcx>,
150 enums: Vec<EnumCheckType<'tcx>>,
151}
152
153impl<'a, 'tcx> EnumFinder<'a, 'tcx> {
154 fn new(
155 tcx: TyCtxt<'tcx>,
156 local_decls: &'a mut LocalDecls<'tcx>,
157 typing_env: TypingEnv<'tcx>,
158 ) -> Self {
159 EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() }
160 }
161
162 fn into_found_enums(self) -> Vec<EnumCheckType<'tcx>> {
164 self.enums
165 }
166}
167
168impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> {
169 fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
170 if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue {
171 let ty::Adt(adt_def, _) = ty.kind() else {
172 return;
173 };
174 if !adt_def.is_enum() {
175 return;
176 }
177
178 let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else {
179 return;
180 };
181 let Ok(op_layout) = self
182 .tcx
183 .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx)))
184 else {
185 return;
186 };
187
188 match enum_layout.variants {
189 Variants::Empty if op_layout.is_uninhabited() => return,
190 Variants::Empty => {
193 self.enums.push(EnumCheckType::Uninhabited);
196 }
197 Variants::Single { .. } => {}
199 Variants::Multiple {
201 tag_encoding: TagEncoding::Direct,
202 tag: Scalar::Initialized { value, .. },
203 ..
204 } => {
205 let valid_discrs =
206 adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect();
207
208 let discr =
209 TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
210 self.enums.push(EnumCheckType::Direct {
211 source_op: op.to_copy(),
212 discr,
213 op_size: op_layout.size,
214 valid_discrs,
215 });
216 }
217 Variants::Multiple {
219 tag_encoding: TagEncoding::Niche { .. },
220 tag: Scalar::Initialized { value, valid_range, .. },
221 tag_field,
222 ..
223 } => {
224 let discr =
225 TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) };
226 self.enums.push(EnumCheckType::WithNiche {
227 source_op: op.to_copy(),
228 discr,
229 op_size: op_layout.size,
230 offset: enum_layout.fields.offset(tag_field.as_usize()),
231 valid_range,
232 });
233 }
234 _ => return,
235 }
236
237 self.super_rvalue(rvalue, location);
238 }
239 }
240}
241
242fn split_block(
243 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
244 location: Location,
245) -> BasicBlock {
246 let block_data = &mut basic_blocks[location.block];
247
248 let new_block = BasicBlockData::new_stmts(
250 block_data.statements.split_off(location.statement_index),
251 block_data.terminator.take(),
252 block_data.is_cleanup,
253 );
254
255 basic_blocks.push(new_block)
256}
257
/// Emits statements into `block_data` that read the tag out of `source_op` and
/// widen it into a fresh `u128` local. The place of that local is returned.
///
/// - If the tag (`discr.size`) is narrower than the operand (`op_size`), the
///   operand is transmuted to `[MaybeUninit<u8>; op_size]`. The `discr.size` bytes
///   starting at `offset` (0 when `None`) are sub-sliced out, and that slice is
///   transmuted to `discr.ty`.
/// - Otherwise the operand is transmuted to the unsigned integer type of
///   `op_size` bytes and `IntToInt`-cast to `discr.ty`. `offset` is not used on
///   this path.
///
/// In both cases the `discr.ty` value is finally `IntToInt`-cast to `u128`. That
/// cast follows `as` semantics, so a signed `discr.ty` is sign-extended. Callers
/// that compare against raw tag bits must mask the result.
///
/// # Panics
/// Via `bug!` when the second path is taken and `op_size` is not 1, 2, 4, 8 or
/// 16 bytes.
fn insert_discr_cast_to_u128<'tcx>(
    tcx: TyCtxt<'tcx>,
    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
    block_data: &mut BasicBlockData<'tcx>,
    source_op: Operand<'tcx>,
    discr: TyAndSize<'tcx>,
    op_size: Size,
    offset: Option<Size>,
    source_info: SourceInfo,
) -> Place<'tcx> {
    // Maps a byte size to the unsigned integer type of that size.
    let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> {
        match size.bytes() {
            1 => tcx.types.u8,
            2 => tcx.types.u16,
            4 => tcx.types.u32,
            8 => tcx.types.u64,
            16 => tcx.types.u128,
            invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid),
        }
    };

    let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() {
        // The tag is narrower than the operand: view the operand as raw bytes.
        let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8);
        let array_len = op_size.bytes();
        let mu_array_ty = Ty::new_array(tcx, mu, array_len);
        let mu_array =
            local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into();
        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty);
        block_data
            .statements
            .push(Statement::new(source_info, StatementKind::Assign(Box::new((mu_array, rvalue)))));

        // Take exactly `discr.size` bytes starting at the tag offset, so the
        // subsequent transmute has matching source and target sizes.
        let offset = offset.unwrap_or(Size::ZERO);
        let smaller_mu_array = mu_array.project_deeper(
            &[ProjectionElem::Subslice {
                from: offset.bytes(),
                to: offset.bytes() + discr.size.bytes(),
                from_end: false,
            }],
            tcx,
        );

        (CastKind::Transmute, Operand::Copy(smaller_mu_array))
    } else {
        // The tag is at least as wide as the operand: reinterpret the operand as an
        // unsigned integer first, then convert with `IntToInt`.
        let operand_int_ty = get_ty_for_size(tcx, op_size);

        let op_as_int =
            local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into();
        let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty);
        block_data.statements.push(Statement::new(
            source_info,
            StatementKind::Assign(Box::new((op_as_int, rvalue))),
        ));

        (CastKind::IntToInt, Operand::Copy(op_as_int))
    };

    // Produce the tag as a value of its own integer type.
    let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty);
    let discr_in_discr_ty =
        local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into();
    block_data.statements.push(Statement::new(
        source_info,
        StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))),
    ));

    // Widen to u128, the common type used for comparisons by the callers.
    // NOTE: sign-extends when `discr.ty` is a signed integer type.
    let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128);
    let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128);
    let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into();
    block_data
        .statements
        .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr, rvalue)))));

    discr
}
339
340fn insert_direct_enum_check<'tcx>(
341 tcx: TyCtxt<'tcx>,
342 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
343 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
344 current_block: BasicBlock,
345 source_op: Operand<'tcx>,
346 discr: TyAndSize<'tcx>,
347 op_size: Size,
348 discriminants: Vec<u128>,
349 source_info: SourceInfo,
350 new_block: BasicBlock,
351) {
352 let invalid_discr_block_data = BasicBlockData::new(None, false);
354 let invalid_discr_block = basic_blocks.push(invalid_discr_block_data);
355 let block_data = &mut basic_blocks[current_block];
356 let discr_place = insert_discr_cast_to_u128(
357 tcx,
358 local_decls,
359 block_data,
360 source_op,
361 discr,
362 op_size,
363 None,
364 source_info,
365 );
366
367 let mask = discr.size.unsigned_int_max();
369 let discr_masked =
370 local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
371 let rvalue = Rvalue::BinaryOp(
372 BinOp::BitAnd,
373 Box::new((
374 Operand::Copy(discr_place),
375 Operand::Constant(Box::new(ConstOperand {
376 span: source_info.span,
377 user_ty: None,
378 const_: Const::Val(ConstValue::from_u128(mask), tcx.types.u128),
379 })),
380 )),
381 );
382 block_data
383 .statements
384 .push(Statement::new(source_info, StatementKind::Assign(Box::new((discr_masked, rvalue)))));
385
386 block_data.terminator = Some(Terminator {
388 source_info,
389 kind: TerminatorKind::SwitchInt {
390 discr: Operand::Copy(discr_masked),
391 targets: SwitchTargets::new(
392 discriminants
393 .into_iter()
394 .map(|discr_val| (discr.size.truncate(discr_val), new_block)),
395 invalid_discr_block,
396 ),
397 },
398 });
399
400 basic_blocks[invalid_discr_block].terminator = Some(Terminator {
402 source_info,
403 kind: TerminatorKind::Assert {
404 cond: Operand::Constant(Box::new(ConstOperand {
405 span: source_info.span,
406 user_ty: None,
407 const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
408 })),
409 expected: true,
410 target: new_block,
411 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr_masked))),
412 unwind: UnwindAction::Unreachable,
416 },
417 });
418}
419
420fn insert_uninhabited_enum_check<'tcx>(
421 tcx: TyCtxt<'tcx>,
422 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
423 block_data: &mut BasicBlockData<'tcx>,
424 source_info: SourceInfo,
425 new_block: BasicBlock,
426) {
427 let is_ok: Place<'_> =
428 local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
429 block_data.statements.push(Statement::new(
430 source_info,
431 StatementKind::Assign(Box::new((
432 is_ok,
433 Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
434 span: source_info.span,
435 user_ty: None,
436 const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool),
437 }))),
438 ))),
439 ));
440
441 block_data.terminator = Some(Terminator {
442 source_info,
443 kind: TerminatorKind::Assert {
444 cond: Operand::Copy(is_ok),
445 expected: true,
446 target: new_block,
447 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new(
448 ConstOperand {
449 span: source_info.span,
450 user_ty: None,
451 const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128),
452 },
453 )))),
454 unwind: UnwindAction::Unreachable,
458 },
459 });
460}
461
462fn insert_niche_check<'tcx>(
463 tcx: TyCtxt<'tcx>,
464 local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
465 block_data: &mut BasicBlockData<'tcx>,
466 source_op: Operand<'tcx>,
467 valid_range: WrappingRange,
468 discr: TyAndSize<'tcx>,
469 op_size: Size,
470 offset: Size,
471 source_info: SourceInfo,
472 new_block: BasicBlock,
473) {
474 let discr = insert_discr_cast_to_u128(
475 tcx,
476 local_decls,
477 block_data,
478 source_op,
479 discr,
480 op_size,
481 Some(offset),
482 source_info,
483 );
484
485 let start_const = Operand::Constant(Box::new(ConstOperand {
487 span: source_info.span,
488 user_ty: None,
489 const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128),
490 }));
491 let end_start_diff_const = Operand::Constant(Box::new(ConstOperand {
492 span: source_info.span,
493 user_ty: None,
494 const_: Const::Val(
495 ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)),
496 tcx.types.u128,
497 ),
498 }));
499
500 let discr_diff: Place<'_> =
501 local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into();
502 block_data.statements.push(Statement::new(
503 source_info,
504 StatementKind::Assign(Box::new((
505 discr_diff,
506 Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))),
507 ))),
508 ));
509
510 let is_ok: Place<'_> =
511 local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
512 block_data.statements.push(Statement::new(
513 source_info,
514 StatementKind::Assign(Box::new((
515 is_ok,
516 Rvalue::BinaryOp(
517 BinOp::Le,
519 Box::new((Operand::Copy(discr_diff), end_start_diff_const)),
520 ),
521 ))),
522 ));
523
524 block_data.terminator = Some(Terminator {
525 source_info,
526 kind: TerminatorKind::Assert {
527 cond: Operand::Copy(is_ok),
528 expected: true,
529 target: new_block,
530 msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))),
531 unwind: UnwindAction::Unreachable,
535 },
536 });
537}