1use rustc_index::{Idx, IndexVec};
2use rustc_middle::mir::*;
3use rustc_middle::ty::Ty;
4use rustc_span::Span;
5use tracing::debug;
6
7pub(crate) struct MirPatch<'tcx> {
12 term_patch_map: IndexVec<BasicBlock, Option<TerminatorKind<'tcx>>>,
13 new_blocks: Vec<BasicBlockData<'tcx>>,
14 new_statements: Vec<(Location, StatementKind<'tcx>)>,
15 new_locals: Vec<LocalDecl<'tcx>>,
16 resume_block: Option<BasicBlock>,
17 unreachable_cleanup_block: Option<BasicBlock>,
19 unreachable_no_cleanup_block: Option<BasicBlock>,
21 terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
23 body_span: Span,
24 next_local: usize,
25}
26
27impl<'tcx> MirPatch<'tcx> {
28 pub(crate) fn new(body: &Body<'tcx>) -> Self {
30 let mut result = MirPatch {
31 term_patch_map: IndexVec::from_elem(None, &body.basic_blocks),
32 new_blocks: vec![],
33 new_statements: vec![],
34 new_locals: vec![],
35 next_local: body.local_decls.len(),
36 resume_block: None,
37 unreachable_cleanup_block: None,
38 unreachable_no_cleanup_block: None,
39 terminate_block: None,
40 body_span: body.span,
41 };
42
43 for (bb, block) in body.basic_blocks.iter_enumerated() {
44 if matches!(block.terminator().kind, TerminatorKind::UnwindResume)
46 && block.statements.is_empty()
47 {
48 result.resume_block = Some(bb);
49 continue;
50 }
51
52 if matches!(block.terminator().kind, TerminatorKind::Unreachable)
54 && block.statements.is_empty()
55 {
56 if block.is_cleanup {
57 result.unreachable_cleanup_block = Some(bb);
58 } else {
59 result.unreachable_no_cleanup_block = Some(bb);
60 }
61 continue;
62 }
63
64 if let TerminatorKind::UnwindTerminate(reason) = block.terminator().kind
66 && block.statements.is_empty()
67 {
68 result.terminate_block = Some((bb, reason));
69 continue;
70 }
71 }
72
73 result
74 }
75
76 pub(crate) fn resume_block(&mut self) -> BasicBlock {
77 if let Some(bb) = self.resume_block {
78 return bb;
79 }
80
81 let bb = self.new_block(BasicBlockData {
82 statements: vec![],
83 terminator: Some(Terminator {
84 source_info: SourceInfo::outermost(self.body_span),
85 kind: TerminatorKind::UnwindResume,
86 }),
87 is_cleanup: true,
88 });
89 self.resume_block = Some(bb);
90 bb
91 }
92
93 pub(crate) fn unreachable_cleanup_block(&mut self) -> BasicBlock {
94 if let Some(bb) = self.unreachable_cleanup_block {
95 return bb;
96 }
97
98 let bb = self.new_block(BasicBlockData {
99 statements: vec![],
100 terminator: Some(Terminator {
101 source_info: SourceInfo::outermost(self.body_span),
102 kind: TerminatorKind::Unreachable,
103 }),
104 is_cleanup: true,
105 });
106 self.unreachable_cleanup_block = Some(bb);
107 bb
108 }
109
110 pub(crate) fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
111 if let Some(bb) = self.unreachable_no_cleanup_block {
112 return bb;
113 }
114
115 let bb = self.new_block(BasicBlockData {
116 statements: vec![],
117 terminator: Some(Terminator {
118 source_info: SourceInfo::outermost(self.body_span),
119 kind: TerminatorKind::Unreachable,
120 }),
121 is_cleanup: false,
122 });
123 self.unreachable_no_cleanup_block = Some(bb);
124 bb
125 }
126
127 pub(crate) fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
128 if let Some((cached_bb, cached_reason)) = self.terminate_block
129 && reason == cached_reason
130 {
131 return cached_bb;
132 }
133
134 let bb = self.new_block(BasicBlockData {
135 statements: vec![],
136 terminator: Some(Terminator {
137 source_info: SourceInfo::outermost(self.body_span),
138 kind: TerminatorKind::UnwindTerminate(reason),
139 }),
140 is_cleanup: true,
141 });
142 self.terminate_block = Some((bb, reason));
143 bb
144 }
145
146 pub(crate) fn is_term_patched(&self, bb: BasicBlock) -> bool {
148 self.term_patch_map[bb].is_some()
149 }
150
151 pub(crate) fn block<'a>(
153 &'a self,
154 body: &'a Body<'tcx>,
155 bb: BasicBlock,
156 ) -> &'a BasicBlockData<'tcx> {
157 match bb.index().checked_sub(body.basic_blocks.len()) {
158 Some(new) => &self.new_blocks[new],
159 None => &body[bb],
160 }
161 }
162
163 pub(crate) fn terminator_loc(&self, body: &Body<'tcx>, bb: BasicBlock) -> Location {
164 let offset = self.block(body, bb).statements.len();
165 Location { block: bb, statement_index: offset }
166 }
167
168 pub(crate) fn new_local_with_info(
170 &mut self,
171 ty: Ty<'tcx>,
172 span: Span,
173 local_info: LocalInfo<'tcx>,
174 ) -> Local {
175 let index = self.next_local;
176 self.next_local += 1;
177 let mut new_decl = LocalDecl::new(ty, span);
178 **new_decl.local_info.as_mut().unwrap_crate_local() = local_info;
179 self.new_locals.push(new_decl);
180 Local::new(index)
181 }
182
183 pub(crate) fn new_temp(&mut self, ty: Ty<'tcx>, span: Span) -> Local {
185 let index = self.next_local;
186 self.next_local += 1;
187 self.new_locals.push(LocalDecl::new(ty, span));
188 Local::new(index)
189 }
190
191 pub(crate) fn local_ty(&self, local: Local) -> Ty<'tcx> {
193 let local = local.as_usize();
194 assert!(local < self.next_local);
195 let new_local_idx = self.new_locals.len() - (self.next_local - local);
196 self.new_locals[new_local_idx].ty
197 }
198
199 pub(crate) fn new_block(&mut self, data: BasicBlockData<'tcx>) -> BasicBlock {
201 let block = self.term_patch_map.next_index();
202 debug!("MirPatch: new_block: {:?}: {:?}", block, data);
203 self.new_blocks.push(data);
204 self.term_patch_map.push(None);
205 block
206 }
207
208 pub(crate) fn patch_terminator(&mut self, block: BasicBlock, new: TerminatorKind<'tcx>) {
210 assert!(self.term_patch_map[block].is_none());
211 debug!("MirPatch: patch_terminator({:?}, {:?})", block, new);
212 self.term_patch_map[block] = Some(new);
213 }
214
215 pub(crate) fn add_statement(&mut self, loc: Location, stmt: StatementKind<'tcx>) {
229 debug!("MirPatch: add_statement({:?}, {:?})", loc, stmt);
230 self.new_statements.push((loc, stmt));
231 }
232
233 pub(crate) fn add_assign(&mut self, loc: Location, place: Place<'tcx>, rv: Rvalue<'tcx>) {
235 self.add_statement(loc, StatementKind::Assign(Box::new((place, rv))));
236 }
237
238 pub(crate) fn apply(self, body: &mut Body<'tcx>) {
240 debug!(
241 "MirPatch: {:?} new temps, starting from index {}: {:?}",
242 self.new_locals.len(),
243 body.local_decls.len(),
244 self.new_locals
245 );
246 debug!(
247 "MirPatch: {} new blocks, starting from index {}",
248 self.new_blocks.len(),
249 body.basic_blocks.len()
250 );
251 let bbs = if self.term_patch_map.is_empty() && self.new_blocks.is_empty() {
252 body.basic_blocks.as_mut_preserves_cfg()
253 } else {
254 body.basic_blocks.as_mut()
255 };
256 bbs.extend(self.new_blocks);
257 body.local_decls.extend(self.new_locals);
258 for (src, patch) in self.term_patch_map.into_iter_enumerated() {
259 if let Some(patch) = patch {
260 debug!("MirPatch: patching block {:?}", src);
261 bbs[src].terminator_mut().kind = patch;
262 }
263 }
264
265 let mut new_statements = self.new_statements;
266
267 new_statements.sort_by_key(|s| s.0);
270
271 let mut delta = 0;
272 let mut last_bb = START_BLOCK;
273 for (mut loc, stmt) in new_statements {
274 if loc.block != last_bb {
275 delta = 0;
276 last_bb = loc.block;
277 }
278 debug!("MirPatch: adding statement {:?} at loc {:?}+{}", stmt, loc, delta);
279 loc.statement_index += delta;
280 let source_info = Self::source_info_for_index(&body[loc.block], loc);
281 body[loc.block]
282 .statements
283 .insert(loc.statement_index, Statement { source_info, kind: stmt });
284 delta += 1;
285 }
286 }
287
288 fn source_info_for_index(data: &BasicBlockData<'_>, loc: Location) -> SourceInfo {
289 match data.statements.get(loc.statement_index) {
290 Some(stmt) => stmt.source_info,
291 None => data.terminator().source_info,
292 }
293 }
294
295 pub(crate) fn source_info_for_location(&self, body: &Body<'tcx>, loc: Location) -> SourceInfo {
296 let data = self.block(body, loc.block);
297 Self::source_info_for_index(data, loc)
298 }
299}