rapx/analysis/senryx/dominated_graph.rs

1use crate::{
2    analysis::{
3        senryx::contracts::{
4            contract,
5            property::{CisRangeItem, ContractualInvariantState, PropertyContract},
6        },
7        utils::fn_info::{display_hashmap, get_pointee, is_ptr, is_ref, is_slice, reverse_op},
8    },
9    rap_debug, rap_warn,
10};
11use rustc_hir::def_id::DefId;
12use rustc_middle::mir::BinOp;
13use rustc_middle::mir::Local;
14use rustc_middle::ty::TyKind;
15use rustc_middle::ty::{Ty, TyCtxt};
16use serde::de;
17use std::collections::{HashMap, HashSet, VecDeque};
18
/// Per-node summary of the properties tracked by the analysis.
///
/// Each flag is `true` when the property is known to hold and `false` when it is not
/// known to hold. [`States::merge_states`] combines two summaries with a logical AND, so a
/// property survives a merge only if both sides have it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct States {
    pub nonnull: bool,
    pub allocator_consistency: bool,
    pub init: bool,
    pub align: bool,
    pub valid_string: bool,
    pub valid_cstr: bool,
}

impl States {
    /// Returns a summary in which every property is assumed to hold (all flags `true`).
    pub fn new() -> Self {
        Self::with_all(true)
    }

    /// Returns a summary in which no property is known to hold (all flags `false`).
    pub fn new_unknown() -> Self {
        Self::with_all(false)
    }

    /// Builds a summary whose flags are all equal to `flag`.
    fn with_all(flag: bool) -> Self {
        States {
            nonnull: flag,
            allocator_consistency: flag,
            init: flag,
            align: flag,
            valid_string: flag,
            valid_cstr: flag,
        }
    }

    /// Intersects `self` with `other` in place: a flag remains `true` only when it is
    /// `true` in both summaries.
    pub fn merge_states(&mut self, other: &States) {
        let merged = States {
            nonnull: self.nonnull && other.nonnull,
            allocator_consistency: self.allocator_consistency && other.allocator_consistency,
            init: self.init && other.init,
            align: self.align && other.align,
            valid_string: self.valid_string && other.valid_string,
            valid_cstr: self.valid_cstr && other.valid_cstr,
        };
        *self = merged;
    }
}
61
62#[derive(Debug, Clone)]
63pub struct InterResultNode<'tcx> {
64    pub point_to: Option<Box<InterResultNode<'tcx>>>,
65    pub fields: HashMap<usize, InterResultNode<'tcx>>,
66    pub ty: Option<Ty<'tcx>>,
67    pub states: States,
68    pub const_value: usize,
69}
70
71impl<'tcx> InterResultNode<'tcx> {
72    pub fn new_default(ty: Option<Ty<'tcx>>) -> Self {
73        Self {
74            point_to: None,
75            fields: HashMap::new(),
76            ty,
77            states: States::new(),
78            const_value: 0, // To be modified
79        }
80    }
81
82    pub fn construct_from_var_node(chain: DominatedGraph<'tcx>, var_id: usize) -> Self {
83        let var_node = chain.get_var_node(var_id).unwrap();
84        let point_node = if var_node.points_to.is_none() {
85            None
86        } else {
87            Some(Box::new(Self::construct_from_var_node(
88                chain.clone(),
89                var_node.points_to.unwrap(),
90            )))
91        };
92        let fields = var_node
93            .field
94            .iter()
95            .map(|(k, v)| (*k, Self::construct_from_var_node(chain.clone(), *v)))
96            .collect();
97        Self {
98            point_to: point_node,
99            fields,
100            ty: var_node.ty.clone(),
101            states: var_node.ots.clone(),
102            const_value: var_node.const_value,
103        }
104    }
105
106    pub fn merge(&mut self, other: InterResultNode<'tcx>) {
107        if self.ty != other.ty {
108            return;
109        }
110        // merge current node's states
111        self.states.merge_states(&other.states);
112
113        // merge node it points to
114        match (&mut self.point_to, other.point_to) {
115            (Some(self_ptr), Some(other_ptr)) => self_ptr.merge(*other_ptr),
116            (None, Some(other_ptr)) => {
117                self.point_to = Some(other_ptr.clone());
118            }
119            _ => {}
120        }
121        // merge the fields nodess
122        for (field_id, other_node) in &other.fields {
123            match self.fields.get_mut(field_id) {
124                Some(self_node) => self_node.merge(other_node.clone()),
125                None => {
126                    self.fields.insert(*field_id, other_node.clone());
127                }
128            }
129        }
130        // TODO: merge into a range
131        self.const_value = std::cmp::max(self.const_value, other.const_value);
132    }
133}
134
135#[derive(Debug, Clone)]
136pub struct VariableNode<'tcx> {
137    pub id: usize,
138    pub alias_set: HashSet<usize>,
139    points_to: Option<usize>,
140    pointed_by: HashSet<usize>,
141    pub field: HashMap<usize, usize>,
142    pub ty: Option<Ty<'tcx>>,
143    pub is_dropped: bool,
144    pub ots: States,
145    pub const_value: usize,
146    pub cis: ContractualInvariantState<'tcx>,
147}
148
149impl<'tcx> VariableNode<'tcx> {
150    pub fn new(
151        id: usize,
152        points_to: Option<usize>,
153        pointed_by: HashSet<usize>,
154        ty: Option<Ty<'tcx>>,
155        ots: States,
156    ) -> Self {
157        VariableNode {
158            id,
159            alias_set: HashSet::from([id]),
160            points_to,
161            pointed_by,
162            field: HashMap::new(),
163            ty,
164            is_dropped: false,
165            ots,
166            const_value: 0,
167            cis: ContractualInvariantState::new_default(),
168        }
169    }
170
171    pub fn new_default(id: usize, ty: Option<Ty<'tcx>>) -> Self {
172        VariableNode {
173            id,
174            alias_set: HashSet::from([id]),
175            points_to: None,
176            pointed_by: HashSet::new(),
177            field: HashMap::new(),
178            ty,
179            is_dropped: false,
180            ots: States::new(),
181            const_value: 0,
182            cis: ContractualInvariantState::new_default(),
183        }
184    }
185
186    pub fn new_with_states(id: usize, ty: Option<Ty<'tcx>>, ots: States) -> Self {
187        VariableNode {
188            id,
189            alias_set: HashSet::from([id]),
190            points_to: None,
191            pointed_by: HashSet::new(),
192            field: HashMap::new(),
193            ty,
194            is_dropped: false,
195            ots,
196            const_value: 0,
197            cis: ContractualInvariantState::new_default(),
198        }
199    }
200}
201
202#[derive(Clone)]
203pub struct DominatedGraph<'tcx> {
204    pub tcx: TyCtxt<'tcx>,
205    pub def_id: DefId,
206    pub local_len: usize,
207    pub variables: HashMap<usize, VariableNode<'tcx>>,
208}
209
210impl<'tcx> DominatedGraph<'tcx> {
211    // This constructor will init all the local arguments' node states.
212    // If input argument is ptr or ref, it will point to a corresponding obj node.
213    pub fn new(tcx: TyCtxt<'tcx>, def_id: DefId) -> Self {
214        let body = tcx.optimized_mir(def_id);
215        let locals = body.local_decls.clone();
216        let fn_sig = tcx.fn_sig(def_id).skip_binder();
217        let param_len = fn_sig.inputs().skip_binder().len();
218        let mut var_map: HashMap<usize, VariableNode<'_>> = HashMap::new();
219        let mut obj_cnt = 0;
220        for (idx, local) in locals.iter().enumerate() {
221            let local_ty = local.ty;
222            let mut node = VariableNode::new_default(idx, Some(local_ty));
223            if local_ty.to_string().contains("MaybeUninit") {
224                node.ots.init = false;
225            }
226            var_map.insert(idx, node);
227        }
228        Self {
229            tcx,
230            def_id,
231            local_len: locals.len(),
232            variables: var_map,
233        }
234    }
235
236    pub fn init_self_with_inter(&mut self, inter_result: InterResultNode<'tcx>) {
237        let self_node = self.get_var_node(1).unwrap().clone();
238        if self_node.ty.unwrap().is_ref() {
239            let obj_node = self.get_var_node(self.get_point_to_id(1)).unwrap();
240            self.dfs_insert_inter_results(inter_result, obj_node.id);
241        } else {
242            self.dfs_insert_inter_results(inter_result, self_node.id);
243        }
244    }
245
246    pub fn dfs_insert_inter_results(&mut self, inter_result: InterResultNode<'tcx>, local: usize) {
247        let new_id = self.generate_node_id();
248        let node = self.get_var_node_mut(local).unwrap();
249        // node.ty = inter_result.ty;
250        node.ots = inter_result.states;
251        node.const_value = inter_result.const_value;
252        if inter_result.point_to.is_some() {
253            let new_node = inter_result.point_to.unwrap();
254            node.points_to = Some(new_id);
255            self.insert_node(
256                new_id,
257                new_node.ty.clone(),
258                local,
259                None,
260                new_node.states.clone(),
261            );
262            self.dfs_insert_inter_results(*new_node, new_id);
263        }
264        for (field_idx, field_inter) in inter_result.fields {
265            let field_node_id = self.insert_field_node(local, field_idx, field_inter.ty.clone());
266            self.dfs_insert_inter_results(field_inter, field_node_id);
267        }
268    }
269
270    pub fn init_arg(&mut self) {
271        // init arg nodes' point to nodes.
272        let body = self.tcx.optimized_mir(self.def_id);
273        let locals = body.local_decls.clone();
274        let fn_sig = self.tcx.fn_sig(self.def_id).skip_binder();
275        let param_len = fn_sig.inputs().skip_binder().len();
276        for idx in 1..param_len + 1 {
277            let local_ty = locals[Local::from(idx)].ty;
278            self.generate_ptr_with_obj_node(local_ty, idx);
279        }
280        // init args' cis
281        let cis_results = crate::analysis::utils::fn_info::generate_contract_from_annotation(
282            self.tcx,
283            self.def_id,
284        );
285        for (base, fields, contract) in cis_results {
286            if fields.len() == 0 {
287                self.insert_cis_for_arg(base, contract);
288            } else {
289                let mut cur_base = base;
290                let mut field_node = base;
291                for field in fields {
292                    field_node = self.insert_field_node(cur_base, field.0, Some(field.1));
293                    // check if field's type is ptr or ref: yes -> create shadow var
294                    self.generate_ptr_with_obj_node(field.1, field_node);
295                    cur_base = field_node;
296                }
297                self.insert_cis_for_arg(field_node, contract);
298            }
299        }
300    }
301
302    fn insert_cis_for_arg(&mut self, local: usize, contract: PropertyContract<'tcx>) {
303        let node = self.get_var_node_mut(local).unwrap();
304        node.cis.add_contract(contract);
305    }
306
307    /// When generate obj node, this function will add InBound Sp automatically.
308    pub fn generate_ptr_with_obj_node(&mut self, local_ty: Ty<'tcx>, idx: usize) -> usize {
309        let new_id = self.generate_node_id();
310        if is_ptr(local_ty) {
311            // modify ptr node pointed
312            self.get_var_node_mut(idx).unwrap().points_to = Some(new_id);
313            // insert pointed object node
314            self.insert_node(
315                new_id,
316                Some(get_pointee(local_ty)),
317                idx,
318                None,
319                States::new_unknown(),
320            );
321            self.add_bound_for_obj(new_id, local_ty);
322        } else if is_ref(local_ty) {
323            // modify ptr node pointed
324            self.get_var_node_mut(idx).unwrap().points_to = Some(new_id);
325            // insert ref object node
326            self.insert_node(
327                new_id,
328                Some(get_pointee(local_ty)),
329                idx,
330                None,
331                States::new(),
332            );
333            self.add_bound_for_obj(new_id, local_ty);
334        }
335        new_id
336    }
337
338    fn add_bound_for_obj(&mut self, new_id: usize, local_ty: Ty<'tcx>) {
339        let new_node = self.get_var_node_mut(new_id).unwrap();
340        let new_node_ty = get_pointee(local_ty);
341        let contract = if is_slice(new_node_ty).is_some() {
342            let inner_ty = is_slice(new_node_ty).unwrap();
343            PropertyContract::new_obj_boundary(inner_ty, CisRangeItem::new_unknown())
344        } else {
345            PropertyContract::new_obj_boundary(new_node_ty, CisRangeItem::new_value(1))
346        };
347        new_node.cis.add_contract(contract);
348    }
349
350    // if current node is ptr or ref, then return the new node pointed by it.
351    pub fn check_ptr(&mut self, arg: usize) -> usize {
352        if self.get_var_node_mut(arg).unwrap().ty.is_none() {
353            display_hashmap(&self.variables, 1);
354        };
355        let node_ty = self.get_var_node_mut(arg).unwrap().ty.unwrap();
356        if is_ptr(node_ty) || is_ref(node_ty) {
357            return self.generate_ptr_with_obj_node(node_ty, arg);
358        }
359        arg
360    }
361
362    pub fn get_local_ty_by_place(&self, arg: usize) -> Option<Ty<'tcx>> {
363        let body = self.tcx.optimized_mir(self.def_id);
364        let locals = body.local_decls.clone();
365        if arg < locals.len() {
366            return Some(locals[Local::from(arg)].ty);
367        } else {
368            // If the arg is a field of some place, we search the whole map for it.
369            return self.get_var_node(arg).unwrap().ty;
370        }
371    }
372
373    pub fn get_obj_ty_through_chain(&self, arg: usize) -> Option<Ty<'tcx>> {
374        let var = self.get_var_node(arg).unwrap();
375        // If the var is ptr or ref, then find its pointed obj.
376        if let Some(pointed_idx) = var.points_to {
377            // let pointed_var = self.get_var_node(pointed_idx).unwrap();
378            // pointed_var.ty
379            self.get_obj_ty_through_chain(pointed_idx)
380        } else {
381            var.ty
382        }
383    }
384
385    pub fn get_point_to_id(&self, arg: usize) -> usize {
386        // display_hashmap(&self.variables,1);
387        // println!("{:?}",self.def_id);
388        let var = self.get_var_node(arg).unwrap();
389        if let Some(pointed_idx) = var.points_to {
390            pointed_idx
391        } else {
392            arg
393        }
394    }
395
396    pub fn is_local(&self, node_id: usize) -> bool {
397        self.local_len > node_id
398    }
399}
400
401// This implementation has the auxiliary functions of DominatedGraph,
402// including c/r/u/d nodes and printing chains' structure.
403impl<'tcx> DominatedGraph<'tcx> {
404    // Only for inserting field obj node or pointed obj node.
405    pub fn generate_node_id(&self) -> usize {
406        if self.variables.len() == 0 || *self.variables.keys().max().unwrap() < self.local_len {
407            return self.local_len;
408        }
409        *self.variables.keys().max().unwrap() + 1
410    }
411
412    pub fn get_field_node_id(
413        &mut self,
414        local: usize,
415        field_idx: usize,
416        ty: Option<Ty<'tcx>>,
417    ) -> usize {
418        let node = self.get_var_node(local).unwrap();
419        if let Some(alias_local) = node.field.get(&field_idx) {
420            *alias_local
421        } else {
422            self.insert_field_node(local, field_idx, ty)
423        }
424    }
425
426    // Insert the responding field node of one local, then return its genrated node_id.
427    pub fn insert_field_node(
428        &mut self,
429        local: usize,
430        field_idx: usize,
431        ty: Option<Ty<'tcx>>,
432    ) -> usize {
433        let new_id = self.generate_node_id();
434        self.variables
435            .insert(new_id, VariableNode::new_default(new_id, ty));
436        let mut_node = self.get_var_node_mut(local).unwrap();
437        mut_node.field.insert(field_idx, new_id);
438        return new_id;
439    }
440
441    pub fn find_var_id_with_fields_seq(&mut self, local: usize, fields: Vec<usize>) -> usize {
442        let mut cur = local;
443        for field in fields.clone() {
444            let mut cur_node = self.get_var_node(cur).unwrap();
445            if let TyKind::Ref(_, ty, _) = cur_node.ty.unwrap().kind() {
446                let point_to = self.get_point_to_id(cur);
447                cur_node = self.get_var_node(point_to).unwrap();
448            }
449            // If there exist a field node, then get it as cur node
450            if cur_node.field.get(&field).is_some() {
451                cur = *cur_node.field.get(&field).unwrap();
452                continue;
453            }
454            // Otherwise, insert a new field node.
455            match cur_node.ty.unwrap().kind() {
456                TyKind::Adt(adt_def, substs) => {
457                    if adt_def.is_struct() {
458                        for (idx, field_def) in adt_def.all_fields().enumerate() {
459                            if idx == field {
460                                cur = self.get_field_node_id(
461                                    cur,
462                                    field,
463                                    Some(field_def.ty(self.tcx, substs)),
464                                );
465                            }
466                        }
467                    }
468                }
469                // TODO: maybe unsafe here for setting ty as None!
470                _ => {
471                    rap_warn!("ty {:?}, field: {:?}", cur_node.ty.unwrap(), field);
472                    rap_warn!("set field type as None! --- src: Dominated Graph / find_var_id_with_fields_seq");
473                    cur = self.get_field_node_id(cur, field, None);
474                }
475            }
476        }
477        return cur;
478    }
479
480    pub fn point(&mut self, lv: usize, rv: usize) {
481        // rap_warn!("{lv} = & or * {rv}");
482        let rv_node = self.get_var_node_mut(rv).unwrap();
483        rv_node.pointed_by.insert(lv);
484        let lv_node = self.get_var_node_mut(lv).unwrap();
485        let ori_to = lv_node.points_to.clone();
486        lv_node.points_to = Some(rv);
487        // Delete lv from the origin pointed node's pointed_by.
488        if let Some(to) = ori_to {
489            let ori_to_node = self.get_var_node_mut(to).unwrap();
490            ori_to_node.pointed_by.remove(&lv);
491        }
492    }
493
494    pub fn get_var_nod_id(&self, local_id: usize) -> usize {
495        self.get_var_node(local_id).unwrap().id
496    }
497
498    pub fn get_map_idx_node(&self, local_id: usize) -> &VariableNode<'tcx> {
499        self.variables.get(&local_id).unwrap()
500    }
501
502    pub fn get_var_node(&self, local_id: usize) -> Option<&VariableNode<'tcx>> {
503        for (_idx, var_node) in &self.variables {
504            if var_node.alias_set.contains(&local_id) {
505                return Some(var_node);
506            }
507        }
508        rap_warn!("def id:{:?}, local_id: {local_id}", self.def_id);
509        display_hashmap(&self.variables, 1);
510        None
511    }
512
513    pub fn get_var_node_mut(&mut self, local_id: usize) -> Option<&mut VariableNode<'tcx>> {
514        let va = self.variables.clone();
515        for (_idx, var_node) in &mut self.variables {
516            if var_node.alias_set.contains(&local_id) {
517                return Some(var_node);
518            }
519        }
520        rap_warn!("def id:{:?}, local_id: {local_id}", self.def_id);
521        display_hashmap(&va, 1);
522        None
523    }
524
525    // Merge node when (lv = move rv);
526    // In this case, lv will be the same with rv.
527    // And the nodes pointing to lv originally will re-point to rv.
528    pub fn merge(&mut self, lv: usize, rv: usize) {
529        let lv_node = self.get_var_node_mut(lv).unwrap().clone();
530        if lv_node.alias_set.contains(&rv) {
531            return;
532        }
533        for lv_pointed_by in lv_node.pointed_by.clone() {
534            self.point(lv_pointed_by, rv);
535        }
536        let lv_node = self.get_var_node_mut(lv).unwrap();
537        lv_node.alias_set.remove(&lv);
538        let lv_ty = lv_node.ty;
539        let lv_states = lv_node.ots.clone();
540        let rv_node = self.get_var_node_mut(rv).unwrap();
541        rv_node.alias_set.insert(lv);
542        // rv_node.states.merge_states(&lv_states);
543        if rv_node.ty.is_none() {
544            rv_node.ty = lv_ty;
545        }
546    }
547
548    // Called when (lv = copy rv);
549    pub fn copy_node(&mut self, lv: usize, rv: usize) {
550        let rv_node = self.get_var_node_mut(rv).unwrap().clone();
551        let lv_node = self.get_var_node_mut(lv).unwrap();
552        let lv_ty = lv_node.ty.unwrap();
553        lv_node.ots = rv_node.ots;
554        lv_node.cis = rv_node.cis;
555        lv_node.is_dropped = rv_node.is_dropped;
556        let lv_id = lv_node.id;
557        // if is_ptr(rv_node.ty.unwrap()) && is_ptr(lv_ty) {
558        //     // println!("++++{lv}--{rv}");
559        //     self.merge(lv, rv);
560        // }
561        if rv_node.points_to.is_some() {
562            self.point(lv_id, rv_node.points_to.unwrap());
563        }
564    }
565
566    fn break_node_connection(&mut self, lv: usize, rv: usize) {
567        let rv_node = self.get_var_node_mut(rv).unwrap();
568        rv_node.pointed_by.remove(&lv);
569        let lv_node = self.get_var_node_mut(lv).unwrap();
570        lv_node.points_to = None;
571    }
572
573    fn insert_node(
574        &mut self,
575        dv: usize,
576        ty: Option<Ty<'tcx>>,
577        parent_id: usize,
578        child_id: Option<usize>,
579        state: States,
580    ) {
581        self.variables.insert(
582            dv,
583            VariableNode::new(dv, child_id, HashSet::from([parent_id]), ty, state),
584        );
585    }
586
587    fn delete_node(&mut self, idx: usize) {
588        let node = self.get_var_node(idx).unwrap().clone();
589        for pre_idx in &node.pointed_by.clone() {
590            let pre_node = self.get_var_node_mut(*pre_idx).unwrap();
591            pre_node.points_to = None;
592        }
593        if let Some(to) = &node.points_to.clone() {
594            let next_node = self.get_var_node_mut(*to).unwrap();
595            next_node.pointed_by.remove(&idx);
596        }
597        self.variables.remove(&idx);
598    }
599
600    pub fn set_drop(&mut self, idx: usize) -> bool {
601        if let Some(ori_node) = self.get_var_node_mut(idx) {
602            if ori_node.is_dropped == true {
603                // rap_warn!("Double free detected!"); // todo: update reports
604                return false;
605            }
606            ori_node.is_dropped = true;
607        }
608        true
609    }
610
611    pub fn update_value(&mut self, arg: usize, value: usize) {
612        let node = self.get_var_node_mut(arg).unwrap();
613        node.const_value = value;
614        node.ots.init = true;
615    }
616
617    pub fn insert_patial_op(&mut self, p1: usize, p2: usize, op: &BinOp) {
618        let p1_node = self.get_var_node_mut(p1).unwrap();
619        p1_node
620            .cis
621            .add_contract(PropertyContract::new_patial_order(p2, *op));
622        let p2_node = self.get_var_node_mut(p2).unwrap();
623        p2_node
624            .cis
625            .add_contract(PropertyContract::new_patial_order(p1, reverse_op(*op)));
626    }
627
628    pub fn print_graph(&self) {
629        let mut visited = HashSet::new();
630        let mut subgraphs = Vec::new();
631
632        for &node_id in self.variables.keys() {
633            if !visited.contains(&node_id) {
634                let mut queue = VecDeque::new();
635                let mut subgraph = Vec::new();
636
637                queue.push_back(node_id);
638                visited.insert(node_id);
639
640                while let Some(current_id) = queue.pop_front() {
641                    subgraph.push(current_id);
642
643                    if let Some(node) = self.get_var_node(current_id) {
644                        if let Some(next_id) = node.points_to {
645                            if !visited.contains(&next_id) {
646                                visited.insert(next_id);
647                                queue.push_back(next_id);
648                            }
649                        }
650
651                        for &pointer_id in &node.pointed_by {
652                            if !visited.contains(&pointer_id) {
653                                visited.insert(pointer_id);
654                                queue.push_back(pointer_id);
655                            }
656                        }
657                    }
658                }
659
660                subgraphs.push(subgraph);
661            }
662        }
663
664        for (i, mut subgraph) in subgraphs.into_iter().enumerate() {
665            subgraph.sort_unstable();
666            println!("Connected Subgraph {}: {:?}", i + 1, subgraph);
667
668            for node_id in subgraph {
669                if let Some(node) = self.get_var_node(node_id) {
670                    println!("  Node {} → {:?}", node_id, node.points_to);
671                }
672            }
673            println!();
674        }
675    }
676}