rapx/analysis/senryx/
dominated_chain.rs

1use crate::{
2    analysis::utils::fn_info::{display_hashmap, get_pointee, is_ptr, is_ref},
3    rap_warn,
4};
5use rustc_hir::def_id::DefId;
6use rustc_middle::mir::Local;
7use rustc_middle::ty::TyKind;
8use rustc_middle::ty::{Ty, TyCtxt};
9use std::collections::{HashMap, HashSet, VecDeque};
10
/// Safety-relevant properties tracked for a value in the dominated graph.
///
/// Each flag is `true` when the property is assumed to hold and `false` when it
/// does not hold or is unknown (see `new_unknown`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct States {
    pub nonnull: bool,
    pub allocator_consistency: bool,
    pub init: bool,
    pub align: bool,
    pub valid_string: bool,
    pub valid_cstr: bool,
}

impl States {
    /// All properties assumed to hold.
    pub fn new() -> Self {
        Self::with_all(true)
    }

    /// No property known to hold.
    pub fn new_unknown() -> Self {
        Self::with_all(false)
    }

    /// Builds a `States` with every flag set to `flag`.
    fn with_all(flag: bool) -> Self {
        Self {
            nonnull: flag,
            allocator_consistency: flag,
            init: flag,
            align: flag,
            valid_string: flag,
            valid_cstr: flag,
        }
    }

    /// Conservatively combines `other` into `self`: a property survives only
    /// if it holds in both.
    pub fn merge_states(&mut self, other: &States) {
        let States {
            nonnull,
            allocator_consistency,
            init,
            align,
            valid_string,
            valid_cstr,
        } = other;
        self.nonnull = self.nonnull && *nonnull;
        self.allocator_consistency = self.allocator_consistency && *allocator_consistency;
        self.init = self.init && *init;
        self.align = self.align && *align;
        self.valid_string = self.valid_string && *valid_string;
        self.valid_cstr = self.valid_cstr && *valid_cstr;
    }
}
53
54#[derive(Debug, Clone)]
55pub struct InterResultNode<'tcx> {
56    pub point_to: Option<Box<InterResultNode<'tcx>>>,
57    pub fields: HashMap<usize, InterResultNode<'tcx>>,
58    pub ty: Option<Ty<'tcx>>,
59    pub states: States,
60    pub const_value: usize,
61}
62
63impl<'tcx> InterResultNode<'tcx> {
64    pub fn new_default(ty: Option<Ty<'tcx>>) -> Self {
65        Self {
66            point_to: None,
67            fields: HashMap::new(),
68            ty,
69            states: States::new(),
70            const_value: 0, // To be modified
71        }
72    }
73
74    pub fn construct_from_var_node(chain: DominatedGraph<'tcx>, var_id: usize) -> Self {
75        let var_node = chain.get_var_node(var_id).unwrap();
76        let point_node = if var_node.points_to.is_none() {
77            None
78        } else {
79            Some(Box::new(Self::construct_from_var_node(
80                chain.clone(),
81                var_node.points_to.unwrap(),
82            )))
83        };
84        let fields = var_node
85            .field
86            .iter()
87            .map(|(k, v)| (*k, Self::construct_from_var_node(chain.clone(), *v)))
88            .collect();
89        Self {
90            point_to: point_node,
91            fields,
92            ty: var_node.ty.clone(),
93            states: var_node.states.clone(),
94            const_value: var_node.const_value,
95        }
96    }
97
98    pub fn merge(&mut self, other: InterResultNode<'tcx>) {
99        if self.ty != other.ty {
100            return;
101        }
102        // merge current node's states
103        self.states.merge_states(&other.states);
104
105        // merge node it points to
106        match (&mut self.point_to, other.point_to) {
107            (Some(self_ptr), Some(other_ptr)) => self_ptr.merge(*other_ptr),
108            (None, Some(other_ptr)) => {
109                self.point_to = Some(other_ptr.clone());
110            }
111            _ => {}
112        }
113        // merge the fields nodess
114        for (field_id, other_node) in &other.fields {
115            match self.fields.get_mut(field_id) {
116                Some(self_node) => self_node.merge(other_node.clone()),
117                None => {
118                    self.fields.insert(*field_id, other_node.clone());
119                }
120            }
121        }
122        // TODO: merge into a range
123        self.const_value = std::cmp::max(self.const_value, other.const_value);
124    }
125}
126
127#[derive(Debug, Clone)]
128pub struct VariableNode<'tcx> {
129    pub id: usize,
130    pub alias_set: HashSet<usize>,
131    points_to: Option<usize>,
132    pointed_by: HashSet<usize>,
133    pub field: HashMap<usize, usize>,
134    pub ty: Option<Ty<'tcx>>,
135    pub is_dropped: bool,
136    pub states: States,
137    pub const_value: usize,
138}
139
140impl<'tcx> VariableNode<'tcx> {
141    pub fn new(
142        id: usize,
143        points_to: Option<usize>,
144        pointed_by: HashSet<usize>,
145        ty: Option<Ty<'tcx>>,
146        states: States,
147    ) -> Self {
148        VariableNode {
149            id,
150            alias_set: HashSet::from([id]),
151            points_to,
152            pointed_by,
153            field: HashMap::new(),
154            ty,
155            is_dropped: false,
156            states,
157            const_value: 0,
158        }
159    }
160
161    pub fn new_default(id: usize, ty: Option<Ty<'tcx>>) -> Self {
162        VariableNode {
163            id,
164            alias_set: HashSet::from([id]),
165            points_to: None,
166            pointed_by: HashSet::new(),
167            field: HashMap::new(),
168            ty,
169            is_dropped: false,
170            states: States::new(),
171            const_value: 0,
172        }
173    }
174
175    pub fn new_with_states(id: usize, ty: Option<Ty<'tcx>>, states: States) -> Self {
176        VariableNode {
177            id,
178            alias_set: HashSet::from([id]),
179            points_to: None,
180            pointed_by: HashSet::new(),
181            field: HashMap::new(),
182            ty,
183            is_dropped: false,
184            states,
185            const_value: 0,
186        }
187    }
188}
189
/// Per-function graph of MIR locals plus synthetic object/field nodes, linked
/// by points-to, pointed-by, field and alias relations.
#[derive(Clone)]
pub struct DominatedGraph<'tcx> {
    pub tcx: TyCtxt<'tcx>,
    /// The function whose optimized MIR this graph models.
    pub def_id: DefId,
    /// Number of MIR locals. Ids below this are MIR locals; ids at or above it
    /// are synthetic nodes (see `is_local` / `generate_node_id`).
    pub local_len: usize,
    /// Nodes keyed by node id. A lookup id may resolve to a node stored under a
    /// different key through that node's `alias_set` (see `get_var_node`).
    pub variables: HashMap<usize, VariableNode<'tcx>>,
}
197
198impl<'tcx> DominatedGraph<'tcx> {
199    // This constructor will init all the local arguments' node states.
200    // If input argument is ptr or ref, it will point to a corresponding obj node.
201    pub fn new(tcx: TyCtxt<'tcx>, def_id: DefId) -> Self {
202        let body = tcx.optimized_mir(def_id);
203        let locals = body.local_decls.clone();
204        let fn_sig = tcx.fn_sig(def_id).skip_binder();
205        let param_len = fn_sig.inputs().skip_binder().len();
206        let mut var_map: HashMap<usize, VariableNode<'_>> = HashMap::new();
207        let mut obj_cnt = 0;
208        for (idx, local) in locals.iter().enumerate() {
209            let local_ty = local.ty;
210            let mut node = VariableNode::new_default(idx, Some(local_ty));
211            if local_ty.to_string().contains("MaybeUninit") {
212                node.states.init = false;
213            }
214            var_map.insert(idx, node);
215        }
216        Self {
217            tcx,
218            def_id,
219            local_len: locals.len(),
220            variables: var_map,
221        }
222    }
223
224    pub fn init_self_with_inter(&mut self, inter_result: InterResultNode<'tcx>) {
225        let self_node = self.get_var_node(1).unwrap().clone();
226        if self_node.ty.unwrap().is_ref() {
227            let obj_node = self.get_var_node(self.get_point_to_id(1)).unwrap();
228            self.dfs_insert_inter_results(inter_result, obj_node.id);
229        } else {
230            self.dfs_insert_inter_results(inter_result, self_node.id);
231        }
232    }
233
234    pub fn dfs_insert_inter_results(&mut self, inter_result: InterResultNode<'tcx>, local: usize) {
235        let new_id = self.generate_node_id();
236        let node = self.get_var_node_mut(local).unwrap();
237        // node.ty = inter_result.ty;
238        node.states = inter_result.states;
239        node.const_value = inter_result.const_value;
240        if inter_result.point_to.is_some() {
241            let new_node = inter_result.point_to.unwrap();
242            node.points_to = Some(new_id);
243            self.insert_node(
244                new_id,
245                new_node.ty.clone(),
246                local,
247                None,
248                new_node.states.clone(),
249            );
250            self.dfs_insert_inter_results(*new_node, new_id);
251        }
252        for (field_idx, field_inter) in inter_result.fields {
253            let field_node_id = self.insert_field_node(local, field_idx, field_inter.ty.clone());
254            self.dfs_insert_inter_results(field_inter, field_node_id);
255        }
256    }
257
258    pub fn init_arg(&mut self) {
259        let body = self.tcx.optimized_mir(self.def_id);
260        let locals = body.local_decls.clone();
261        let fn_sig = self.tcx.fn_sig(self.def_id).skip_binder();
262        let param_len = fn_sig.inputs().skip_binder().len();
263        for idx in 1..param_len + 1 {
264            let local_ty = locals[Local::from(idx)].ty;
265            self.generate_ptr_with_obj_node(local_ty, idx);
266        }
267    }
268
269    pub fn generate_ptr_with_obj_node(&mut self, local_ty: Ty<'tcx>, idx: usize) -> usize {
270        let new_id = self.generate_node_id();
271        if is_ptr(local_ty) {
272            // modify ptr node pointed
273            self.get_var_node_mut(idx).unwrap().points_to = Some(new_id);
274            // insert pointed object node
275            self.insert_node(
276                new_id,
277                Some(get_pointee(local_ty)),
278                idx,
279                None,
280                States::new_unknown(),
281            );
282        } else if is_ref(local_ty) {
283            // modify ptr node pointed
284            self.get_var_node_mut(idx).unwrap().points_to = Some(new_id);
285            // insert ref object node
286            self.insert_node(
287                new_id,
288                Some(get_pointee(local_ty)),
289                idx,
290                None,
291                States::new(),
292            );
293        }
294        new_id
295    }
296
297    // if current node is ptr or ref, then return the new node pointed by it.
298    pub fn check_ptr(&mut self, arg: usize) -> usize {
299        if self.get_var_node_mut(arg).unwrap().ty.is_none() {
300            display_hashmap(&self.variables, 1);
301        };
302        let node_ty = self.get_var_node_mut(arg).unwrap().ty.unwrap();
303        if is_ptr(node_ty) || is_ref(node_ty) {
304            return self.generate_ptr_with_obj_node(node_ty, arg);
305        }
306        arg
307    }
308
309    pub fn get_local_ty_by_place(&self, arg: usize) -> Option<Ty<'tcx>> {
310        let body = self.tcx.optimized_mir(self.def_id);
311        let locals = body.local_decls.clone();
312        if arg < locals.len() {
313            return Some(locals[Local::from(arg)].ty);
314        } else {
315            // If the arg is a field of some place, we search the whole map for it.
316            return self.get_var_node(arg).unwrap().ty;
317        }
318    }
319
320    pub fn get_obj_ty_through_chain(&self, arg: usize) -> Option<Ty<'tcx>> {
321        let var = self.get_var_node(arg).unwrap();
322        // If the var is ptr or ref, then find its pointed obj.
323        if let Some(pointed_idx) = var.points_to {
324            // let pointed_var = self.get_var_node(pointed_idx).unwrap();
325            // pointed_var.ty
326            self.get_obj_ty_through_chain(pointed_idx)
327        } else {
328            var.ty
329        }
330    }
331
332    pub fn get_point_to_id(&self, arg: usize) -> usize {
333        // display_hashmap(&self.variables,1);
334        // println!("{:?}",self.def_id);
335        let var = self.get_var_node(arg).unwrap();
336        if let Some(pointed_idx) = var.points_to {
337            pointed_idx
338        } else {
339            arg
340        }
341    }
342
343    pub fn is_local(&self, node_id: usize) -> bool {
344        self.local_len > node_id
345    }
346}
347
// Auxiliary functions of DominatedGraph: creating, reading, updating and
// deleting (c/r/u/d) nodes, plus printing the graph's structure.
350impl<'tcx> DominatedGraph<'tcx> {
351    // Only for inserting field obj node or pointed obj node.
352    pub fn generate_node_id(&self) -> usize {
353        if self.variables.len() == 0 || *self.variables.keys().max().unwrap() < self.local_len {
354            return self.local_len;
355        }
356        *self.variables.keys().max().unwrap() + 1
357    }
358
359    pub fn get_field_node_id(
360        &mut self,
361        local: usize,
362        field_idx: usize,
363        ty: Option<Ty<'tcx>>,
364    ) -> usize {
365        let node = self.get_var_node(local).unwrap();
366        if let Some(alias_local) = node.field.get(&field_idx) {
367            *alias_local
368        } else {
369            self.insert_field_node(local, field_idx, ty)
370        }
371    }
372
373    // Insert the responding field node of one local, then return its genrated node_id.
374    pub fn insert_field_node(
375        &mut self,
376        local: usize,
377        field_idx: usize,
378        ty: Option<Ty<'tcx>>,
379    ) -> usize {
380        let new_id = self.generate_node_id();
381        self.variables
382            .insert(new_id, VariableNode::new_default(new_id, ty));
383        let mut_node = self.get_var_node_mut(local).unwrap();
384        mut_node.field.insert(field_idx, new_id);
385        return new_id;
386    }
387
388    pub fn find_var_id_with_fields_seq(&mut self, local: usize, fields: Vec<usize>) -> usize {
389        let mut cur = self.get_point_to_id(local);
390        for field in fields {
391            cur = self.get_point_to_id(cur);
392            let cur_node = self.get_var_node(cur).unwrap();
393            match cur_node.ty.unwrap().kind() {
394                TyKind::Adt(adt_def, substs) => {
395                    if adt_def.is_struct() {
396                        for (idx, field_def) in adt_def.all_fields().enumerate() {
397                            if idx == field {
398                                cur = self.get_field_node_id(
399                                    cur,
400                                    field,
401                                    Some(field_def.ty(self.tcx, substs)),
402                                );
403                            }
404                        }
405                    }
406                }
407                // TODO: maybe unsafe here for setting ty as None!
408                _ => {
409                    cur = self.get_field_node_id(cur, field, None);
410                }
411            }
412        }
413        return cur;
414    }
415
416    pub fn point(&mut self, lv: usize, rv: usize) {
417        // rap_warn!("{lv} = & or * {rv}");
418        let rv_node = self.get_var_node_mut(rv).unwrap();
419        rv_node.pointed_by.insert(lv);
420        let lv_node = self.get_var_node_mut(lv).unwrap();
421        let ori_to = lv_node.points_to.clone();
422        lv_node.points_to = Some(rv);
423        // Delete lv from the origin pointed node's pointed_by.
424        if let Some(to) = ori_to {
425            let ori_to_node = self.get_var_node_mut(to).unwrap();
426            ori_to_node.pointed_by.remove(&lv);
427        }
428    }
429
430    pub fn get_var_nod_id(&self, local_id: usize) -> usize {
431        self.get_var_node(local_id).unwrap().id
432    }
433
434    pub fn get_map_idx_node(&self, local_id: usize) -> &VariableNode<'tcx> {
435        self.variables.get(&local_id).unwrap()
436    }
437
438    pub fn get_var_node(&self, local_id: usize) -> Option<&VariableNode<'tcx>> {
439        for (_idx, var_node) in &self.variables {
440            if var_node.alias_set.contains(&local_id) {
441                return Some(var_node);
442            }
443        }
444        rap_warn!("def id:{:?}, local_id: {local_id}", self.def_id);
445        display_hashmap(&self.variables, 1);
446        None
447    }
448
449    pub fn get_var_node_mut(&mut self, local_id: usize) -> Option<&mut VariableNode<'tcx>> {
450        let va = self.variables.clone();
451        for (_idx, var_node) in &mut self.variables {
452            if var_node.alias_set.contains(&local_id) {
453                return Some(var_node);
454            }
455        }
456        rap_warn!("def id:{:?}, local_id: {local_id}", self.def_id);
457        display_hashmap(&va, 1);
458        None
459    }
460
461    // Merge node when (lv = move rv);
462    // In this case, lv will be the same with rv.
463    // And the nodes pointing to lv originally will re-point to rv.
464    pub fn merge(&mut self, lv: usize, rv: usize) {
465        let lv_node = self.get_var_node_mut(lv).unwrap().clone();
466        if lv_node.alias_set.contains(&rv) {
467            return;
468        }
469        for lv_pointed_by in lv_node.pointed_by.clone() {
470            self.point(lv_pointed_by, rv);
471        }
472        let lv_node = self.get_var_node_mut(lv).unwrap();
473        lv_node.alias_set.remove(&lv);
474        let lv_ty = lv_node.ty;
475        let lv_states = lv_node.states.clone();
476        let rv_node = self.get_var_node_mut(rv).unwrap();
477        rv_node.alias_set.insert(lv);
478        // rv_node.states.merge_states(&lv_states);
479        if rv_node.ty.is_none() {
480            rv_node.ty = lv_ty;
481        }
482    }
483
484    // Called when (lv = copy rv);
485    pub fn copy_node(&mut self, lv: usize, rv: usize) {
486        let rv_node = self.get_var_node_mut(rv).unwrap().clone();
487        let lv_node = self.get_var_node_mut(lv).unwrap();
488        let lv_ty = lv_node.ty.unwrap();
489        lv_node.states = rv_node.states;
490        lv_node.is_dropped = rv_node.is_dropped;
491        if is_ptr(rv_node.ty.unwrap()) && is_ptr(lv_ty) {
492            // println!("++++{lv}--{rv}");
493            self.merge(lv, rv);
494        }
495    }
496
497    pub fn break_node_connection(&mut self, lv: usize, rv: usize) {
498        let rv_node = self.get_var_node_mut(rv).unwrap();
499        rv_node.pointed_by.remove(&lv);
500        let lv_node = self.get_var_node_mut(lv).unwrap();
501        lv_node.points_to = None;
502    }
503
504    pub fn insert_node(
505        &mut self,
506        dv: usize,
507        ty: Option<Ty<'tcx>>,
508        parent_id: usize,
509        child_id: Option<usize>,
510        state: States,
511    ) {
512        self.variables.insert(
513            dv,
514            VariableNode::new(dv, child_id, HashSet::from([parent_id]), ty, state),
515        );
516    }
517
518    pub fn delete_node(&mut self, idx: usize) {
519        let node = self.get_var_node(idx).unwrap().clone();
520        for pre_idx in &node.pointed_by.clone() {
521            let pre_node = self.get_var_node_mut(*pre_idx).unwrap();
522            pre_node.points_to = None;
523        }
524        if let Some(to) = &node.points_to.clone() {
525            let next_node = self.get_var_node_mut(*to).unwrap();
526            next_node.pointed_by.remove(&idx);
527        }
528        self.variables.remove(&idx);
529    }
530
531    pub fn set_drop(&mut self, idx: usize) -> bool {
532        if let Some(ori_node) = self.get_var_node_mut(idx) {
533            if ori_node.is_dropped == true {
534                // rap_warn!("Double free detected!"); // todo: update reports
535                return false;
536            }
537            ori_node.is_dropped = true;
538        }
539        true
540    }
541
542    pub fn update_value(&mut self, arg: usize, value: usize) {
543        let node = self.get_var_node_mut(arg).unwrap();
544        node.const_value = value;
545        node.states.init = true;
546    }
547
548    pub fn print_graph(&self) {
549        let mut visited = HashSet::new();
550        let mut subgraphs = Vec::new();
551
552        for &node_id in self.variables.keys() {
553            if !visited.contains(&node_id) {
554                let mut queue = VecDeque::new();
555                let mut subgraph = Vec::new();
556
557                queue.push_back(node_id);
558                visited.insert(node_id);
559
560                while let Some(current_id) = queue.pop_front() {
561                    subgraph.push(current_id);
562
563                    if let Some(node) = self.get_var_node(current_id) {
564                        if let Some(next_id) = node.points_to {
565                            if !visited.contains(&next_id) {
566                                visited.insert(next_id);
567                                queue.push_back(next_id);
568                            }
569                        }
570
571                        for &pointer_id in &node.pointed_by {
572                            if !visited.contains(&pointer_id) {
573                                visited.insert(pointer_id);
574                                queue.push_back(pointer_id);
575                            }
576                        }
577                    }
578                }
579
580                subgraphs.push(subgraph);
581            }
582        }
583
584        for (i, mut subgraph) in subgraphs.into_iter().enumerate() {
585            subgraph.sort_unstable();
586            println!("Connected Subgraph {}: {:?}", i + 1, subgraph);
587
588            for node_id in subgraph {
589                if let Some(node) = self.get_var_node(node_id) {
590                    println!("  Node {} → {:?}", node_id, node.points_to);
591                }
592            }
593            println!();
594        }
595    }
596}