1use rustc_hir::lang_items::LangItem;
2use rustc_index::IndexVec;
3use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor};
4use rustc_middle::mir::*;
5use rustc_middle::ty::{self, Ty, TyCtxt};
6use tracing::{debug, trace};
7
/// The check produced by an `on_finding` callback of [`check_pointers`] for one found pointer.
///
/// [`check_pointers`] turns this into a `TerminatorKind::Assert` that ends the block holding
/// the check's statements.
pub(crate) struct PointerCheck<'tcx> {
    /// Operand asserted to be `true`; execution continues to the original statement only if it is.
    pub(crate) cond: Operand<'tcx>,
    /// Message/kind of the assertion, used as the `msg` of the inserted `Assert` terminator.
    pub(crate) assert_kind: Box<AssertKind<Operand<'tcx>>>,
}
14
/// Selects which type is handed to the check when a raw pointer is dereferenced inside a borrow
/// (a shared borrow or a mutable `Borrow` context).
#[derive(Copy, Clone)]
pub(crate) enum BorrowedFieldProjectionMode {
    /// For borrows, use the type of the whole borrowed place (e.g. the field type in
    /// `&(*ptr).field`) instead of the raw pointer's pointee type.
    FollowProjections,
    /// Always use the raw pointer's pointee type, even for borrows of projections.
    NoFollowProjections,
}
23
/// Inserts a runtime check in front of every statement that uses a place dereferencing a raw
/// pointer (as selected by `PointerFinder`).
///
/// For each such pointer use, the enclosing basic block is split immediately before the
/// statement. `on_finding` then appends the statements the check needs to the first half
/// (it may also add locals through `local_decls`). It returns the condition and message
/// for a `TerminatorKind::Assert`. That assert becomes the first half's terminator and
/// continues to the second half (which starts with the original statement) when the
/// condition is `true`.
///
/// - `excluded_pointees`: pointee types (or, for arrays, element types) that are never checked.
/// - `on_finding`: builds the check for one found pointer; argument names are given in the
///   `F` bound below.
/// - `field_projection_mode`: which type is reported for borrows through a raw pointer; see
///   [`BorrowedFieldProjectionMode`].
///
/// Does nothing when the `PanicImpl` lang item is not available.
pub(crate) fn check_pointers<'tcx, F>(
    tcx: TyCtxt<'tcx>,
    body: &mut Body<'tcx>,
    excluded_pointees: &[Ty<'tcx>],
    on_finding: F,
    field_projection_mode: BorrowedFieldProjectionMode,
) where
    F: Fn(
        /* tcx: */ TyCtxt<'tcx>,
        /* pointer: */ Place<'tcx>,
        /* pointee_ty: */ Ty<'tcx>,
        /* context: */ PlaceContext,
        /* local_decls: */ &mut IndexVec<Local, LocalDecl<'tcx>>,
        /* stmts: */ &mut Vec<Statement<'tcx>>,
        /* source_info: */ SourceInfo,
    ) -> PointerCheck<'tcx>,
{
    // The inserted `Assert` terminators panic when their condition fails. Without a panic
    // implementation they cannot be emitted, so leave the body untouched.
    if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
        return;
    }

    let typing_env = body.typing_env(tcx);
    let basic_blocks = body.basic_blocks.as_mut();
    let local_decls = &mut body.local_decls;

    // Iterate backwards over blocks and statements. A split moves the current statement and
    // everything after it into a block pushed at the end. Earlier statements of this block,
    // and all lower-indexed blocks, therefore keep their `Location`s. The block range is
    // computed up front, so the newly pushed blocks (already handled content) are not
    // revisited.
    for block in basic_blocks.indices().rev() {
        for statement_index in (0..basic_blocks[block].statements.len()).rev() {
            let location = Location { block, statement_index };
            let statement = &basic_blocks[block].statements[statement_index];
            let source_info = statement.source_info;

            let mut finder = PointerFinder::new(
                tcx,
                local_decls,
                typing_env,
                excluded_pointees,
                field_projection_mode,
            );
            finder.visit_statement(statement, location);

            // `local` is the pointer as a `Place` (the base local of the dereferenced place).
            for (local, ty, context) in finder.into_found_pointers() {
                debug!("Inserting check for {:?}", ty);
                // Every iteration splits at the same `statement_index`. The check statements
                // appended in a previous iteration are therefore moved into the new block
                // together with that iteration's `Assert`. Multiple checks for one statement
                // end up chained, each jumping to the next, with the original statement last.
                let new_block = split_block(basic_blocks, location);

                // `on_finding` appends the check's statements to the (now truncated) current
                // block and may add locals; it returns the assert's condition and message.
                let block_data = &mut basic_blocks[block];
                let pointer_check = on_finding(
                    tcx,
                    local,
                    ty,
                    context,
                    local_decls,
                    &mut block_data.statements,
                    source_info,
                );
                block_data.terminator = Some(Terminator {
                    source_info,
                    kind: TerminatorKind::Assert {
                        cond: pointer_check.cond,
                        expected: true,
                        target: new_block,
                        msg: pointer_check.assert_kind,
                        // A failing check is marked as never unwinding.
                        // NOTE(review): presumably the panic raised by the check must not
                        // unwind — confirm against the `on_finding` implementations.
                        unwind: UnwindAction::Unreachable,
                    },
                });
            }
        }
    }
}
124
/// MIR visitor that collects raw-pointer dereferences in a single statement which need a check.
struct PointerFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Local declarations of the body; only read here, to compute place types.
    local_decls: &'a mut LocalDecls<'tcx>,
    /// Used to decide whether a pointee type is sized.
    typing_env: ty::TypingEnv<'tcx>,
    /// Found `(pointer place, type to check, use context)` entries, in visit order.
    pointers: Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)>,
    /// Types (or array element types) for which no check is recorded.
    excluded_pointees: &'a [Ty<'tcx>],
    /// Which type is recorded for borrows through a raw pointer.
    field_projection_mode: BorrowedFieldProjectionMode,
}
133
134impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
135 fn new(
136 tcx: TyCtxt<'tcx>,
137 local_decls: &'a mut LocalDecls<'tcx>,
138 typing_env: ty::TypingEnv<'tcx>,
139 excluded_pointees: &'a [Ty<'tcx>],
140 field_projection_mode: BorrowedFieldProjectionMode,
141 ) -> Self {
142 PointerFinder {
143 tcx,
144 local_decls,
145 typing_env,
146 excluded_pointees,
147 pointers: Vec::new(),
148 field_projection_mode,
149 }
150 }
151
152 fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)> {
153 self.pointers
154 }
155
156 fn should_visit_place(&self, context: PlaceContext) -> bool {
161 match context {
162 PlaceContext::MutatingUse(
163 MutatingUseContext::Store
164 | MutatingUseContext::Call
165 | MutatingUseContext::Yield
166 | MutatingUseContext::Drop
167 | MutatingUseContext::Borrow,
168 ) => true,
169 PlaceContext::NonMutatingUse(
170 NonMutatingUseContext::Copy
171 | NonMutatingUseContext::Move
172 | NonMutatingUseContext::SharedBorrow,
173 ) => true,
174 _ => false,
175 }
176 }
177}
178
179impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
180 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
181 if !self.should_visit_place(context) || !place.is_indirect() {
182 return;
183 }
184
185 let pointer = Place::from(place.local);
187 let pointer_ty = pointer.ty(self.local_decls, self.tcx).ty;
188
189 let &ty::RawPtr(mut pointee_ty, _) = pointer_ty.kind() else {
191 trace!("Indirect, but not based on an raw ptr, not checking {:?}", place);
192 return;
193 };
194
195 if matches!(self.field_projection_mode, BorrowedFieldProjectionMode::FollowProjections)
198 && matches!(
199 context,
200 PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow)
201 | PlaceContext::MutatingUse(MutatingUseContext::Borrow)
202 )
203 {
204 pointee_ty = place.ty(self.local_decls, self.tcx).ty;
206 }
207
208 if !pointee_ty.is_sized(self.tcx, self.typing_env) {
210 trace!("Raw pointer, but pointee is not known to be sized: {:?}", pointer_ty);
211 return;
212 }
213
214 let element_ty = match pointee_ty.kind() {
216 ty::Array(ty, _) => *ty,
217 _ => pointee_ty,
218 };
219 if self.excluded_pointees.contains(&element_ty) {
221 trace!("Skipping pointer for type: {:?}", pointee_ty);
222 return;
223 }
224
225 self.pointers.push((pointer, pointee_ty, context));
226
227 self.super_place(place, context, location);
228 }
229}
230
231fn split_block(
232 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
233 location: Location,
234) -> BasicBlock {
235 let block_data = &mut basic_blocks[location.block];
236
237 let new_block = BasicBlockData {
239 statements: block_data.statements.split_off(location.statement_index),
240 terminator: block_data.terminator.take(),
241 is_cleanup: block_data.is_cleanup,
242 };
243
244 basic_blocks.push(new_block)
245}