1use crate::{
2 analysis::utils::fn_info::{display_hashmap, get_pointee, is_ptr, is_ref},
3 rap_warn,
4};
5use rustc_hir::def_id::DefId;
6use rustc_middle::mir::Local;
7use rustc_middle::ty::TyKind;
8use rustc_middle::ty::{Ty, TyCtxt};
9use std::collections::{HashMap, HashSet, VecDeque};
10
/// Abstract per-node properties tracked by the analysis.
///
/// Each flag is `true` when the property is believed to hold. Combining two
/// `States` with [`States::merge_states`] is a field-wise logical AND, so a
/// property survives a merge only if it holds on both sides.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct States {
    pub nonnull: bool,
    pub allocator_consistency: bool,
    pub init: bool,
    pub align: bool,
    pub valid_string: bool,
    pub valid_cstr: bool,
}

impl States {
    /// Creates a state in which every property is assumed to hold (`true`).
    pub fn new() -> Self {
        Self {
            nonnull: true,
            allocator_consistency: true,
            init: true,
            align: true,
            valid_string: true,
            valid_cstr: true,
        }
    }

    /// Creates a state in which no property is known to hold (all `false`).
    pub fn new_unknown() -> Self {
        Self {
            nonnull: false,
            allocator_consistency: false,
            init: false,
            align: false,
            valid_string: false,
            valid_cstr: false,
        }
    }

    /// Intersects `self` with `other` in place: each flag stays `true` only if
    /// it is `true` in both.
    pub fn merge_states(&mut self, other: &States) {
        self.nonnull &= other.nonnull;
        self.allocator_consistency &= other.allocator_consistency;
        self.init &= other.init;
        self.align &= other.align;
        self.valid_string &= other.valid_string;
        self.valid_cstr &= other.valid_cstr;
    }
}

impl Default for States {
    /// Same as [`States::new`]: every property is assumed to hold.
    fn default() -> Self {
        Self::new()
    }
}
53
54#[derive(Debug, Clone)]
55pub struct InterResultNode<'tcx> {
56 pub point_to: Option<Box<InterResultNode<'tcx>>>,
57 pub fields: HashMap<usize, InterResultNode<'tcx>>,
58 pub ty: Option<Ty<'tcx>>,
59 pub states: States,
60 pub const_value: usize,
61}
62
63impl<'tcx> InterResultNode<'tcx> {
64 pub fn new_default(ty: Option<Ty<'tcx>>) -> Self {
65 Self {
66 point_to: None,
67 fields: HashMap::new(),
68 ty,
69 states: States::new(),
70 const_value: 0, }
72 }
73
74 pub fn construct_from_var_node(chain: DominatedGraph<'tcx>, var_id: usize) -> Self {
75 let var_node = chain.get_var_node(var_id).unwrap();
76 let point_node = if var_node.points_to.is_none() {
77 None
78 } else {
79 Some(Box::new(Self::construct_from_var_node(
80 chain.clone(),
81 var_node.points_to.unwrap(),
82 )))
83 };
84 let fields = var_node
85 .field
86 .iter()
87 .map(|(k, v)| (*k, Self::construct_from_var_node(chain.clone(), *v)))
88 .collect();
89 Self {
90 point_to: point_node,
91 fields,
92 ty: var_node.ty.clone(),
93 states: var_node.states.clone(),
94 const_value: var_node.const_value,
95 }
96 }
97
98 pub fn merge(&mut self, other: InterResultNode<'tcx>) {
99 if self.ty != other.ty {
100 return;
101 }
102 self.states.merge_states(&other.states);
104
105 match (&mut self.point_to, other.point_to) {
107 (Some(self_ptr), Some(other_ptr)) => self_ptr.merge(*other_ptr),
108 (None, Some(other_ptr)) => {
109 self.point_to = Some(other_ptr.clone());
110 }
111 _ => {}
112 }
113 for (field_id, other_node) in &other.fields {
115 match self.fields.get_mut(field_id) {
116 Some(self_node) => self_node.merge(other_node.clone()),
117 None => {
118 self.fields.insert(*field_id, other_node.clone());
119 }
120 }
121 }
122 self.const_value = std::cmp::max(self.const_value, other.const_value);
124 }
125}
126
127#[derive(Debug, Clone)]
128pub struct VariableNode<'tcx> {
129 pub id: usize,
130 pub alias_set: HashSet<usize>,
131 points_to: Option<usize>,
132 pointed_by: HashSet<usize>,
133 pub field: HashMap<usize, usize>,
134 pub ty: Option<Ty<'tcx>>,
135 pub is_dropped: bool,
136 pub states: States,
137 pub const_value: usize,
138}
139
140impl<'tcx> VariableNode<'tcx> {
141 pub fn new(
142 id: usize,
143 points_to: Option<usize>,
144 pointed_by: HashSet<usize>,
145 ty: Option<Ty<'tcx>>,
146 states: States,
147 ) -> Self {
148 VariableNode {
149 id,
150 alias_set: HashSet::from([id]),
151 points_to,
152 pointed_by,
153 field: HashMap::new(),
154 ty,
155 is_dropped: false,
156 states,
157 const_value: 0,
158 }
159 }
160
161 pub fn new_default(id: usize, ty: Option<Ty<'tcx>>) -> Self {
162 VariableNode {
163 id,
164 alias_set: HashSet::from([id]),
165 points_to: None,
166 pointed_by: HashSet::new(),
167 field: HashMap::new(),
168 ty,
169 is_dropped: false,
170 states: States::new(),
171 const_value: 0,
172 }
173 }
174
175 pub fn new_with_states(id: usize, ty: Option<Ty<'tcx>>, states: States) -> Self {
176 VariableNode {
177 id,
178 alias_set: HashSet::from([id]),
179 points_to: None,
180 pointed_by: HashSet::new(),
181 field: HashMap::new(),
182 ty,
183 is_dropped: false,
184 states,
185 const_value: 0,
186 }
187 }
188}
189
190#[derive(Clone)]
191pub struct DominatedGraph<'tcx> {
192 pub tcx: TyCtxt<'tcx>,
193 pub def_id: DefId,
194 pub local_len: usize,
195 pub variables: HashMap<usize, VariableNode<'tcx>>,
196}
197
198impl<'tcx> DominatedGraph<'tcx> {
199 pub fn new(tcx: TyCtxt<'tcx>, def_id: DefId) -> Self {
202 let body = tcx.optimized_mir(def_id);
203 let locals = body.local_decls.clone();
204 let fn_sig = tcx.fn_sig(def_id).skip_binder();
205 let param_len = fn_sig.inputs().skip_binder().len();
206 let mut var_map: HashMap<usize, VariableNode<'_>> = HashMap::new();
207 let mut obj_cnt = 0;
208 for (idx, local) in locals.iter().enumerate() {
209 let local_ty = local.ty;
210 let mut node = VariableNode::new_default(idx, Some(local_ty));
211 if local_ty.to_string().contains("MaybeUninit") {
212 node.states.init = false;
213 }
214 var_map.insert(idx, node);
215 }
216 Self {
217 tcx,
218 def_id,
219 local_len: locals.len(),
220 variables: var_map,
221 }
222 }
223
224 pub fn init_self_with_inter(&mut self, inter_result: InterResultNode<'tcx>) {
225 let self_node = self.get_var_node(1).unwrap().clone();
226 if self_node.ty.unwrap().is_ref() {
227 let obj_node = self.get_var_node(self.get_point_to_id(1)).unwrap();
228 self.dfs_insert_inter_results(inter_result, obj_node.id);
229 } else {
230 self.dfs_insert_inter_results(inter_result, self_node.id);
231 }
232 }
233
234 pub fn dfs_insert_inter_results(&mut self, inter_result: InterResultNode<'tcx>, local: usize) {
235 let new_id = self.generate_node_id();
236 let node = self.get_var_node_mut(local).unwrap();
237 node.states = inter_result.states;
239 node.const_value = inter_result.const_value;
240 if inter_result.point_to.is_some() {
241 let new_node = inter_result.point_to.unwrap();
242 node.points_to = Some(new_id);
243 self.insert_node(
244 new_id,
245 new_node.ty.clone(),
246 local,
247 None,
248 new_node.states.clone(),
249 );
250 self.dfs_insert_inter_results(*new_node, new_id);
251 }
252 for (field_idx, field_inter) in inter_result.fields {
253 let field_node_id = self.insert_field_node(local, field_idx, field_inter.ty.clone());
254 self.dfs_insert_inter_results(field_inter, field_node_id);
255 }
256 }
257
258 pub fn init_arg(&mut self) {
259 let body = self.tcx.optimized_mir(self.def_id);
260 let locals = body.local_decls.clone();
261 let fn_sig = self.tcx.fn_sig(self.def_id).skip_binder();
262 let param_len = fn_sig.inputs().skip_binder().len();
263 for idx in 1..param_len + 1 {
264 let local_ty = locals[Local::from(idx)].ty;
265 self.generate_ptr_with_obj_node(local_ty, idx);
266 }
267 }
268
269 pub fn generate_ptr_with_obj_node(&mut self, local_ty: Ty<'tcx>, idx: usize) -> usize {
270 let new_id = self.generate_node_id();
271 if is_ptr(local_ty) {
272 self.get_var_node_mut(idx).unwrap().points_to = Some(new_id);
274 self.insert_node(
276 new_id,
277 Some(get_pointee(local_ty)),
278 idx,
279 None,
280 States::new_unknown(),
281 );
282 } else if is_ref(local_ty) {
283 self.get_var_node_mut(idx).unwrap().points_to = Some(new_id);
285 self.insert_node(
287 new_id,
288 Some(get_pointee(local_ty)),
289 idx,
290 None,
291 States::new(),
292 );
293 }
294 new_id
295 }
296
297 pub fn check_ptr(&mut self, arg: usize) -> usize {
299 if self.get_var_node_mut(arg).unwrap().ty.is_none() {
300 display_hashmap(&self.variables, 1);
301 };
302 let node_ty = self.get_var_node_mut(arg).unwrap().ty.unwrap();
303 if is_ptr(node_ty) || is_ref(node_ty) {
304 return self.generate_ptr_with_obj_node(node_ty, arg);
305 }
306 arg
307 }
308
309 pub fn get_local_ty_by_place(&self, arg: usize) -> Option<Ty<'tcx>> {
310 let body = self.tcx.optimized_mir(self.def_id);
311 let locals = body.local_decls.clone();
312 if arg < locals.len() {
313 return Some(locals[Local::from(arg)].ty);
314 } else {
315 return self.get_var_node(arg).unwrap().ty;
317 }
318 }
319
320 pub fn get_obj_ty_through_chain(&self, arg: usize) -> Option<Ty<'tcx>> {
321 let var = self.get_var_node(arg).unwrap();
322 if let Some(pointed_idx) = var.points_to {
324 self.get_obj_ty_through_chain(pointed_idx)
327 } else {
328 var.ty
329 }
330 }
331
332 pub fn get_point_to_id(&self, arg: usize) -> usize {
333 let var = self.get_var_node(arg).unwrap();
336 if let Some(pointed_idx) = var.points_to {
337 pointed_idx
338 } else {
339 arg
340 }
341 }
342
343 pub fn is_local(&self, node_id: usize) -> bool {
344 self.local_len > node_id
345 }
346}
347
348impl<'tcx> DominatedGraph<'tcx> {
351 pub fn generate_node_id(&self) -> usize {
353 if self.variables.len() == 0 || *self.variables.keys().max().unwrap() < self.local_len {
354 return self.local_len;
355 }
356 *self.variables.keys().max().unwrap() + 1
357 }
358
359 pub fn get_field_node_id(
360 &mut self,
361 local: usize,
362 field_idx: usize,
363 ty: Option<Ty<'tcx>>,
364 ) -> usize {
365 let node = self.get_var_node(local).unwrap();
366 if let Some(alias_local) = node.field.get(&field_idx) {
367 *alias_local
368 } else {
369 self.insert_field_node(local, field_idx, ty)
370 }
371 }
372
373 pub fn insert_field_node(
375 &mut self,
376 local: usize,
377 field_idx: usize,
378 ty: Option<Ty<'tcx>>,
379 ) -> usize {
380 let new_id = self.generate_node_id();
381 self.variables
382 .insert(new_id, VariableNode::new_default(new_id, ty));
383 let mut_node = self.get_var_node_mut(local).unwrap();
384 mut_node.field.insert(field_idx, new_id);
385 return new_id;
386 }
387
388 pub fn find_var_id_with_fields_seq(&mut self, local: usize, fields: Vec<usize>) -> usize {
389 let mut cur = self.get_point_to_id(local);
390 for field in fields {
391 cur = self.get_point_to_id(cur);
392 let cur_node = self.get_var_node(cur).unwrap();
393 match cur_node.ty.unwrap().kind() {
394 TyKind::Adt(adt_def, substs) => {
395 if adt_def.is_struct() {
396 for (idx, field_def) in adt_def.all_fields().enumerate() {
397 if idx == field {
398 cur = self.get_field_node_id(
399 cur,
400 field,
401 Some(field_def.ty(self.tcx, substs)),
402 );
403 }
404 }
405 }
406 }
407 _ => {
409 cur = self.get_field_node_id(cur, field, None);
410 }
411 }
412 }
413 return cur;
414 }
415
416 pub fn point(&mut self, lv: usize, rv: usize) {
417 let rv_node = self.get_var_node_mut(rv).unwrap();
419 rv_node.pointed_by.insert(lv);
420 let lv_node = self.get_var_node_mut(lv).unwrap();
421 let ori_to = lv_node.points_to.clone();
422 lv_node.points_to = Some(rv);
423 if let Some(to) = ori_to {
425 let ori_to_node = self.get_var_node_mut(to).unwrap();
426 ori_to_node.pointed_by.remove(&lv);
427 }
428 }
429
430 pub fn get_var_nod_id(&self, local_id: usize) -> usize {
431 self.get_var_node(local_id).unwrap().id
432 }
433
434 pub fn get_map_idx_node(&self, local_id: usize) -> &VariableNode<'tcx> {
435 self.variables.get(&local_id).unwrap()
436 }
437
438 pub fn get_var_node(&self, local_id: usize) -> Option<&VariableNode<'tcx>> {
439 for (_idx, var_node) in &self.variables {
440 if var_node.alias_set.contains(&local_id) {
441 return Some(var_node);
442 }
443 }
444 rap_warn!("def id:{:?}, local_id: {local_id}", self.def_id);
445 display_hashmap(&self.variables, 1);
446 None
447 }
448
449 pub fn get_var_node_mut(&mut self, local_id: usize) -> Option<&mut VariableNode<'tcx>> {
450 let va = self.variables.clone();
451 for (_idx, var_node) in &mut self.variables {
452 if var_node.alias_set.contains(&local_id) {
453 return Some(var_node);
454 }
455 }
456 rap_warn!("def id:{:?}, local_id: {local_id}", self.def_id);
457 display_hashmap(&va, 1);
458 None
459 }
460
461 pub fn merge(&mut self, lv: usize, rv: usize) {
465 let lv_node = self.get_var_node_mut(lv).unwrap().clone();
466 if lv_node.alias_set.contains(&rv) {
467 return;
468 }
469 for lv_pointed_by in lv_node.pointed_by.clone() {
470 self.point(lv_pointed_by, rv);
471 }
472 let lv_node = self.get_var_node_mut(lv).unwrap();
473 lv_node.alias_set.remove(&lv);
474 let lv_ty = lv_node.ty;
475 let lv_states = lv_node.states.clone();
476 let rv_node = self.get_var_node_mut(rv).unwrap();
477 rv_node.alias_set.insert(lv);
478 if rv_node.ty.is_none() {
480 rv_node.ty = lv_ty;
481 }
482 }
483
484 pub fn copy_node(&mut self, lv: usize, rv: usize) {
486 let rv_node = self.get_var_node_mut(rv).unwrap().clone();
487 let lv_node = self.get_var_node_mut(lv).unwrap();
488 let lv_ty = lv_node.ty.unwrap();
489 lv_node.states = rv_node.states;
490 lv_node.is_dropped = rv_node.is_dropped;
491 if is_ptr(rv_node.ty.unwrap()) && is_ptr(lv_ty) {
492 self.merge(lv, rv);
494 }
495 }
496
497 pub fn break_node_connection(&mut self, lv: usize, rv: usize) {
498 let rv_node = self.get_var_node_mut(rv).unwrap();
499 rv_node.pointed_by.remove(&lv);
500 let lv_node = self.get_var_node_mut(lv).unwrap();
501 lv_node.points_to = None;
502 }
503
504 pub fn insert_node(
505 &mut self,
506 dv: usize,
507 ty: Option<Ty<'tcx>>,
508 parent_id: usize,
509 child_id: Option<usize>,
510 state: States,
511 ) {
512 self.variables.insert(
513 dv,
514 VariableNode::new(dv, child_id, HashSet::from([parent_id]), ty, state),
515 );
516 }
517
518 pub fn delete_node(&mut self, idx: usize) {
519 let node = self.get_var_node(idx).unwrap().clone();
520 for pre_idx in &node.pointed_by.clone() {
521 let pre_node = self.get_var_node_mut(*pre_idx).unwrap();
522 pre_node.points_to = None;
523 }
524 if let Some(to) = &node.points_to.clone() {
525 let next_node = self.get_var_node_mut(*to).unwrap();
526 next_node.pointed_by.remove(&idx);
527 }
528 self.variables.remove(&idx);
529 }
530
531 pub fn set_drop(&mut self, idx: usize) -> bool {
532 if let Some(ori_node) = self.get_var_node_mut(idx) {
533 if ori_node.is_dropped == true {
534 return false;
536 }
537 ori_node.is_dropped = true;
538 }
539 true
540 }
541
542 pub fn update_value(&mut self, arg: usize, value: usize) {
543 let node = self.get_var_node_mut(arg).unwrap();
544 node.const_value = value;
545 node.states.init = true;
546 }
547
548 pub fn print_graph(&self) {
549 let mut visited = HashSet::new();
550 let mut subgraphs = Vec::new();
551
552 for &node_id in self.variables.keys() {
553 if !visited.contains(&node_id) {
554 let mut queue = VecDeque::new();
555 let mut subgraph = Vec::new();
556
557 queue.push_back(node_id);
558 visited.insert(node_id);
559
560 while let Some(current_id) = queue.pop_front() {
561 subgraph.push(current_id);
562
563 if let Some(node) = self.get_var_node(current_id) {
564 if let Some(next_id) = node.points_to {
565 if !visited.contains(&next_id) {
566 visited.insert(next_id);
567 queue.push_back(next_id);
568 }
569 }
570
571 for &pointer_id in &node.pointed_by {
572 if !visited.contains(&pointer_id) {
573 visited.insert(pointer_id);
574 queue.push_back(pointer_id);
575 }
576 }
577 }
578 }
579
580 subgraphs.push(subgraph);
581 }
582 }
583
584 for (i, mut subgraph) in subgraphs.into_iter().enumerate() {
585 subgraph.sort_unstable();
586 println!("Connected Subgraph {}: {:?}", i + 1, subgraph);
587
588 for node_id in subgraph {
589 if let Some(node) = self.get_var_node(node_id) {
590 println!(" Node {} → {:?}", node_id, node.points_to);
591 }
592 }
593 println!();
594 }
595 }
596}