rapx/analysis/core/dataflow/
graph.rs

1use std::cell::Cell;
2use std::collections::HashSet;
3
4use rustc_hir::def_id::DefId;
5use rustc_index::IndexVec;
6use rustc_middle::{
7    mir::{
8        AggregateKind, BorrowKind, Const, Local, Operand, Place, PlaceElem, Rvalue, Statement,
9        StatementKind, Terminator, TerminatorKind,
10    },
11    ty::TyKind,
12};
13use rustc_span::{DUMMY_SP, Span};
14
15use crate::{analysis::core::dataflow::*, utils::log::relative_pos_range};
16
17impl GraphNode {
18    pub fn new() -> Self {
19        Self {
20            ops: vec![NodeOp::Nop],
21            span: DUMMY_SP,
22            seq: 0,
23            out_edges: vec![],
24            in_edges: vec![],
25        }
26    }
27}
28
/// Graph of nodes and edges assembled from MIR statements and call terminators.
///
/// Nodes with an index below `n_locals` are created up front by `Graph::new`;
/// every node pushed later (projection markers, constants) gets a higher index
/// and is reported as a marker by `is_marker`.
#[derive(Clone)]
pub struct Graph {
    /// `DefId` handed to `Graph::new` (presumably the analyzed function; confirm with callers).
    pub def_id: DefId,
    /// Span handed to `Graph::new` (presumably the span of the analyzed body; confirm with callers).
    pub span: Span,
    /// Argument count; `param_return_deps` inspects locals `0..=argc`.
    pub argc: usize,
    pub nodes: GraphNodes, // consists of locals in mir and newly created markers
    /// All edges of the graph; nodes refer to them by `EdgeIdx`.
    pub edges: GraphEdges,
    /// Number of nodes allocated by `Graph::new`; the first marker index.
    pub n_locals: usize,
    /// `DefId`s of closures encountered in `AggregateKind::Closure` rvalues.
    pub closures: HashSet<DefId>,
}
39
40impl From<Graph> for DataFlowGraph {
41    fn from(graph: Graph) -> Self {
42        let param_ret_deps = graph.param_return_deps();
43        DataFlowGraph {
44            nodes: graph.nodes,
45            edges: graph.edges,
46            param_ret_deps: param_ret_deps,
47        }
48    }
49}
50
51impl Graph {
52    pub fn new(def_id: DefId, span: Span, argc: usize, n_locals: usize) -> Self {
53        Self {
54            def_id,
55            span,
56            argc,
57            nodes: GraphNodes::from_elem_n(GraphNode::new(), n_locals),
58            edges: GraphEdges::new(),
59            n_locals,
60            closures: HashSet::new(),
61        }
62    }
63
64    // add an edge into an existing node
65    pub fn add_node_edge(&mut self, src: Local, dst: Local, op: EdgeOp) -> EdgeIdx {
66        let seq = self.nodes[dst].seq;
67        let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
68        self.nodes[dst].in_edges.push(edge_idx);
69        self.nodes[src].out_edges.push(edge_idx);
70        edge_idx
71    }
72
73    // add an edge into an existing node with const value as src
74    pub fn add_const_edge(
75        &mut self,
76        src_desc: String,
77        src_ty: String,
78        dst: Local,
79        op: EdgeOp,
80    ) -> EdgeIdx {
81        let seq = self.nodes[dst].seq;
82        let mut const_node = GraphNode::new();
83        const_node.ops[0] = NodeOp::Const(src_desc, src_ty);
84        let src = self.nodes.push(const_node);
85        let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
86        self.nodes[dst].in_edges.push(edge_idx);
87        edge_idx
88    }
89
90    pub fn add_operand(&mut self, operand: &Operand, dst: Local) {
91        match operand {
92            Operand::Copy(place) => {
93                let src = self.parse_place(place);
94                self.add_node_edge(src, dst, EdgeOp::Copy);
95            }
96            Operand::Move(place) => {
97                let src = self.parse_place(place);
98                self.add_node_edge(src, dst, EdgeOp::Move);
99            }
100            Operand::Constant(boxed_const_op) => {
101                let src_desc = boxed_const_op.const_.to_string();
102                let src_ty = match boxed_const_op.const_ {
103                    Const::Val(_, ty) => ty.to_string(),
104                    Const::Unevaluated(_, ty) => ty.to_string(),
105                    Const::Ty(ty, _) => ty.to_string(),
106                };
107                self.add_const_edge(src_desc, src_ty, dst, EdgeOp::Const);
108            }
109        }
110    }
111
112    pub fn parse_place(&mut self, place: &Place) -> Local {
113        fn parse_one_step(graph: &mut Graph, src: Local, place_elem: PlaceElem) -> Local {
114            let dst = graph.nodes.push(GraphNode::new());
115            match place_elem {
116                PlaceElem::Deref => {
117                    graph.add_node_edge(src, dst, EdgeOp::Deref);
118                }
119                PlaceElem::Field(field_idx, _) => {
120                    graph.add_node_edge(src, dst, EdgeOp::Field(field_idx.as_usize()));
121                }
122                PlaceElem::Downcast(symbol, _) => {
123                    graph.add_node_edge(src, dst, EdgeOp::Downcast(symbol.unwrap().to_string()));
124                }
125                PlaceElem::Index(idx) => {
126                    graph.add_node_edge(src, dst, EdgeOp::Index);
127                    graph.add_node_edge(idx, dst, EdgeOp::Nop);
128                }
129                PlaceElem::ConstantIndex { .. } => {
130                    graph.add_node_edge(src, dst, EdgeOp::ConstIndex);
131                }
132                PlaceElem::Subslice { .. } => {
133                    graph.add_node_edge(src, dst, EdgeOp::SubSlice);
134                }
135                _ => {
136                    println!("{:?}", place_elem);
137                    todo!()
138                }
139            }
140            dst
141        }
142        let mut ret = place.local;
143        for place_elem in place.projection {
144            // if there are projections, then add marker nodes
145            ret = parse_one_step(self, ret, place_elem);
146        }
147        ret
148    }
149
150    pub fn add_statm_to_graph(&mut self, statement: &Statement) {
151        if let StatementKind::Assign(boxed_statm) = &statement.kind {
152            let place = boxed_statm.0;
153            let dst = self.parse_place(&place);
154            self.nodes[dst].span = statement.source_info.span;
155            let rvalue = &boxed_statm.1;
156            let seq = self.nodes[dst].seq;
157            if seq == self.nodes[dst].ops.len() {
158                //warning: we do not check whether seq > len
159                self.nodes[dst].ops.push(NodeOp::Nop);
160            }
161            match rvalue {
162                Rvalue::Use(op) => {
163                    self.add_operand(op, dst);
164                    self.nodes[dst].ops[seq] = NodeOp::Use;
165                }
166                Rvalue::Repeat(op, _) => {
167                    self.add_operand(op, dst);
168                    self.nodes[dst].ops[seq] = NodeOp::Repeat;
169                }
170                Rvalue::Ref(_, borrow_kind, place) => {
171                    let op = match borrow_kind {
172                        BorrowKind::Shared => EdgeOp::Immut,
173                        BorrowKind::Mut { .. } => EdgeOp::Mut,
174                        BorrowKind::Fake(_) => EdgeOp::Nop, // todo
175                    };
176                    let src = self.parse_place(place);
177                    self.add_node_edge(src, dst, op);
178                    self.nodes[dst].ops[seq] = NodeOp::Ref;
179                }
180                Rvalue::Cast(_cast_kind, operand, _) => {
181                    self.add_operand(operand, dst);
182                    self.nodes[dst].ops[seq] = NodeOp::Cast;
183                }
184                Rvalue::BinaryOp(_, operands) => {
185                    self.add_operand(&operands.0, dst);
186                    self.add_operand(&operands.1, dst);
187                    self.nodes[dst].ops[seq] = NodeOp::CheckedBinaryOp;
188                }
189                Rvalue::Aggregate(boxed_kind, operands) => {
190                    for operand in operands.iter() {
191                        self.add_operand(operand, dst);
192                    }
193                    match **boxed_kind {
194                        AggregateKind::Array(_) => {
195                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Array)
196                        }
197                        AggregateKind::Tuple => {
198                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Tuple)
199                        }
200                        AggregateKind::Adt(def_id, ..) => {
201                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Adt(def_id))
202                        }
203                        AggregateKind::Closure(def_id, ..) => {
204                            self.closures.insert(def_id);
205                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Closure(def_id))
206                        }
207                        AggregateKind::Coroutine(def_id, ..) => {
208                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Coroutine(def_id))
209                        }
210                        AggregateKind::RawPtr(_, _mutability) => {
211                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::RawPtr)
212                            // We temporarily have not taken mutability into account
213                        }
214                        _ => {
215                            println!("{:?}", boxed_kind);
216                            todo!()
217                        }
218                    }
219                }
220                Rvalue::UnaryOp(_, operand) => {
221                    self.add_operand(operand, dst);
222                    self.nodes[dst].ops[seq] = NodeOp::UnaryOp;
223                }
224                Rvalue::NullaryOp(_) => {
225                    self.nodes[dst].ops[seq] = NodeOp::NullaryOp;
226                }
227                Rvalue::ThreadLocalRef(_) => {
228                    //todo!()
229                }
230                Rvalue::Discriminant(place) => {
231                    let src = self.parse_place(place);
232                    self.add_node_edge(src, dst, EdgeOp::Nop);
233                    self.nodes[dst].ops[seq] = NodeOp::Discriminant;
234                }
235                Rvalue::ShallowInitBox(operand, _) => {
236                    self.add_operand(operand, dst);
237                    self.nodes[dst].ops[seq] = NodeOp::ShallowInitBox;
238                }
239                Rvalue::CopyForDeref(place) => {
240                    let src = self.parse_place(place);
241                    self.add_node_edge(src, dst, EdgeOp::Nop);
242                    self.nodes[dst].ops[seq] = NodeOp::CopyForDeref;
243                }
244                Rvalue::RawPtr(_, place) => {
245                    let src = self.parse_place(place);
246                    self.add_node_edge(src, dst, EdgeOp::Nop); // Mutability?
247                    self.nodes[dst].ops[seq] = NodeOp::RawPtr;
248                }
249                _ => todo!(),
250            };
251            self.nodes[dst].seq = seq + 1;
252        }
253    }
254
255    pub fn add_terminator_to_graph(&mut self, terminator: &Terminator) {
256        if let TerminatorKind::Call {
257            func,
258            args,
259            destination,
260            ..
261        } = &terminator.kind
262        {
263            let dst = destination.local;
264            let seq = self.nodes[dst].seq;
265            if seq == self.nodes[dst].ops.len() {
266                self.nodes[dst].ops.push(NodeOp::Nop);
267            }
268            match func {
269                Operand::Constant(boxed_cnst) => {
270                    if let Const::Val(_, ty) = boxed_cnst.const_ {
271                        if let TyKind::FnDef(def_id, _) = ty.kind() {
272                            for op in args.iter() {
273                                //rustc version related
274                                self.add_operand(&op.node, dst);
275                            }
276                            self.nodes[dst].ops[seq] = NodeOp::Call(*def_id);
277                        }
278                    }
279                }
280                Operand::Move(_) => {
281                    self.add_operand(func, dst); //the func is a place
282                    for op in args.iter() {
283                        //rustc version related
284                        self.add_operand(&op.node, dst);
285                    }
286                    self.nodes[dst].ops[seq] = NodeOp::CallOperand;
287                }
288                _ => {
289                    println!("{:?}", func);
290                    todo!();
291                }
292            }
293            self.nodes[dst].span = terminator.source_info.span;
294            self.nodes[dst].seq = seq + 1;
295        }
296    }
297
298    // Because a node(local) may have multiple ops, we need to decide whether to strictly collect equivalent locals or not
299    // For the former, all the ops should meet the equivalent condition.
300    // For the later, if only one op meets the condition, we still take it into consideration.
301    pub fn collect_equivalent_locals(&self, local: Local, strict: bool) -> HashSet<Local> {
302        let mut set = HashSet::new();
303        let root = Cell::new(local);
304        let reduce_func = if strict {
305            DFSStatus::and
306        } else {
307            DFSStatus::or
308        };
309        let mut find_root_operator = |graph: &Graph, idx: Local| -> DFSStatus {
310            let node = &graph.nodes[idx];
311            node.ops
312                .iter()
313                .map(|op| {
314                    match op {
315                        NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
316                            //Nop means an orphan node or a parameter
317                            root.set(idx);
318                            DFSStatus::Continue
319                        }
320                        NodeOp::Call(_) => {
321                            //We are moving towards upside. Thus we can record the call node and stop dfs.
322                            //We stop because the return value does not equal to parameters
323                            root.set(idx);
324                            DFSStatus::Stop
325                        }
326                        _ => DFSStatus::Stop,
327                    }
328                })
329                .reduce(reduce_func)
330                .unwrap()
331        };
332        let mut find_equivalent_operator = |graph: &Graph, idx: Local| -> DFSStatus {
333            let node = &graph.nodes[idx];
334            if set.contains(&idx) {
335                return DFSStatus::Stop;
336            }
337            node.ops
338                .iter()
339                .map(|op| match op {
340                    NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
341                        set.insert(idx);
342                        DFSStatus::Continue
343                    }
344                    NodeOp::Call(_) => {
345                        if idx == root.get() {
346                            set.insert(idx);
347                            DFSStatus::Continue
348                        } else {
349                            // We are moving towards downside. Thus we stop dfs right now.
350                            DFSStatus::Stop
351                        }
352                    }
353                    _ => DFSStatus::Stop,
354                })
355                .reduce(reduce_func)
356                .unwrap()
357        };
358        // Algorithm: dfs along upside to find the root node, and then dfs along downside to collect equivalent locals
359        let mut seen = HashSet::new();
360        self.dfs(
361            local,
362            Direction::Upside,
363            &mut find_root_operator,
364            &mut Self::equivalent_edge_validator,
365            true,
366            &mut seen,
367        );
368        seen.clear();
369        self.dfs(
370            root.get(),
371            Direction::Downside,
372            &mut find_equivalent_operator,
373            &mut Self::equivalent_edge_validator,
374            true,
375            &mut seen,
376        );
377        set
378    }
379
380    pub fn collect_ancestor_locals(&self, local: Local, self_included: bool) -> HashSet<Local> {
381        let mut ret = HashSet::new();
382        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
383            ret.insert(idx);
384            DFSStatus::Continue
385        };
386        let mut seen = HashSet::new();
387        self.dfs(
388            local,
389            Direction::Upside,
390            &mut node_operator,
391            &mut Graph::always_true_edge_validator,
392            true,
393            &mut seen,
394        );
395        if !self_included {
396            ret.remove(&local);
397        }
398        ret
399    }
400
401    pub fn collect_descending_locals(&self, local: Local, self_included: bool) -> HashSet<Local> {
402        let mut ret = HashSet::new();
403        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
404            ret.insert(idx);
405            DFSStatus::Continue
406        };
407        let mut seen = HashSet::new();
408        self.dfs(
409            local,
410            Direction::Downside,
411            &mut node_operator,
412            &mut Graph::always_true_edge_validator,
413            true,
414            &mut seen,
415        );
416        if !self_included {
417            ret.remove(&local);
418        }
419        ret
420    }
421
422    pub fn get_field_sequence(&self, local: Local) -> Option<(Local, Vec<usize>)> {
423        let mut fields = vec![];
424        let var = Cell::new(local);
425        let mut node_operator = |graph: &Graph, idx: Local| -> DFSStatus {
426            if graph.is_marker(idx) {
427                DFSStatus::Continue
428            } else {
429                var.set(idx);
430                DFSStatus::Stop
431            }
432        };
433        let mut edge_validator = |graph: &Graph, idx: EdgeIdx| -> DFSStatus {
434            if let EdgeOp::Field(field) = graph.edges[idx].op {
435                fields.insert(0, field);
436                DFSStatus::Continue
437            } else {
438                DFSStatus::Stop
439            }
440        };
441        let mut seen = HashSet::new();
442        self.dfs(
443            local,
444            Direction::Upside,
445            &mut node_operator,
446            &mut edge_validator,
447            false,
448            &mut seen,
449        );
450        if fields.is_empty() {
451            None
452        } else {
453            Some((var.get(), fields))
454        }
455    }
456
457    pub fn is_connected(&self, idx_1: Local, idx_2: Local) -> bool {
458        let target = idx_2;
459        let find = Cell::new(false);
460        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
461            find.set(idx == target);
462            if find.get() {
463                DFSStatus::Stop
464            } else {
465                // if not found, move on
466                DFSStatus::Continue
467            }
468        };
469        let mut seen = HashSet::new();
470        self.dfs(
471            idx_1,
472            Direction::Downside,
473            &mut node_operator,
474            &mut Self::always_true_edge_validator,
475            false,
476            &mut seen,
477        );
478        seen.clear();
479        if !find.get() {
480            self.dfs(
481                idx_1,
482                Direction::Upside,
483                &mut node_operator,
484                &mut Self::always_true_edge_validator,
485                false,
486                &mut seen,
487            );
488        }
489        find.get()
490    }
491
492    // Whether there exists dataflow between each parameter and the return value
493    pub fn param_return_deps(&self) -> IndexVec<Local, bool> {
494        let _0 = Local::from_usize(0);
495        let deps = (0..self.argc + 1) //the length is argc + 1, because _0 depends on _0 itself.
496            .map(|i| {
497                let _i = Local::from_usize(i);
498                self.is_connected(_i, _0)
499            })
500            .collect();
501        deps
502    }
503
504    // This function uses precedence traversal.
505    // The node operator and edge validator decide how far the traversal can reach.
506    // `traverse_all` decides if a branch finds the target successfully, whether the traversal will continue or not.
507    // For example, if you need to instantly stop the traversal once finding a certain node, then set `traverse_all` to false.
508    // If you want to traverse all the reachable nodes which are decided by the operator and validator, then set `traverse_all` to true.
509    pub fn dfs<F, G>(
510        &self,
511        now: Local,
512        direction: Direction,
513        node_operator: &mut F,
514        edge_validator: &mut G,
515        traverse_all: bool,
516        seen: &mut HashSet<Local>,
517    ) -> (DFSStatus, bool)
518    where
519        F: FnMut(&Graph, Local) -> DFSStatus,
520        G: FnMut(&Graph, EdgeIdx) -> DFSStatus,
521    {
522        if seen.contains(&now) {
523            return (DFSStatus::Stop, false);
524        }
525        seen.insert(now);
526        macro_rules! traverse {
527            ($edges: ident, $field: ident) => {
528                for edge_idx in self.nodes[now].$edges.iter() {
529                    let edge = &self.edges[*edge_idx];
530                    if matches!(edge_validator(self, *edge_idx), DFSStatus::Continue) {
531                        let (dfs_status, result) = self.dfs(
532                            edge.$field,
533                            direction,
534                            node_operator,
535                            edge_validator,
536                            traverse_all,
537                            seen,
538                        );
539
540                        if matches!(dfs_status, DFSStatus::Stop) && result && !traverse_all {
541                            return (DFSStatus::Stop, true);
542                        }
543                    }
544                }
545            };
546        }
547
548        if matches!(node_operator(self, now), DFSStatus::Continue) {
549            match direction {
550                Direction::Upside => {
551                    traverse!(in_edges, src);
552                }
553                Direction::Downside => {
554                    traverse!(out_edges, dst);
555                }
556                Direction::Both => {
557                    traverse!(in_edges, src);
558                    traverse!(out_edges, dst);
559                }
560            };
561            (DFSStatus::Continue, false)
562        } else {
563            (DFSStatus::Stop, true)
564        }
565    }
566
567    pub fn get_upside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
568        if let Some(edge_idx) = self.nodes[node_idx].in_edges.get(order) {
569            Some(self.edges[*edge_idx].src)
570        } else {
571            None
572        }
573    }
574
575    pub fn get_downside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
576        if let Some(edge_idx) = self.nodes[node_idx].out_edges.get(order) {
577            Some(self.edges[*edge_idx].dst)
578        } else {
579            None
580        }
581    }
582
583    // if strict is set to false, we return the first node that wraps the target span and at least one end overlaps
584    pub fn query_node_by_span(&self, span: Span, strict: bool) -> Option<(Local, &GraphNode)> {
585        for (node_idx, node) in self.nodes.iter_enumerated() {
586            if strict {
587                if node.span == span {
588                    return Some((node_idx, node));
589                }
590            } else {
591                if !relative_pos_range(node.span, span).eq(0..0)
592                    && (node.span.lo() == span.lo() || node.span.hi() == span.hi())
593                {
594                    return Some((node_idx, node));
595                }
596            }
597        }
598        None
599    }
600
601    pub fn is_marker(&self, idx: Local) -> bool {
602        idx >= Local::from_usize(self.n_locals)
603    }
604}
605
606impl Graph {
607    pub fn equivalent_edge_validator(graph: &Graph, idx: EdgeIdx) -> DFSStatus {
608        match graph.edges[idx].op {
609            EdgeOp::Copy | EdgeOp::Move | EdgeOp::Mut | EdgeOp::Immut | EdgeOp::Deref => {
610                DFSStatus::Continue
611            }
612            EdgeOp::Nop
613            | EdgeOp::Const
614            | EdgeOp::Downcast(_)
615            | EdgeOp::Field(_)
616            | EdgeOp::Index
617            | EdgeOp::ConstIndex
618            | EdgeOp::SubSlice => DFSStatus::Stop,
619        }
620    }
621
622    pub fn always_true_edge_validator(_: &Graph, _: EdgeIdx) -> DFSStatus {
623        DFSStatus::Continue
624    }
625}
626
/// Direction in which `Graph::dfs` walks the edges of a node.
#[derive(Clone, Copy)]
pub enum Direction {
    /// Follow incoming edges towards their source nodes.
    Upside,
    /// Follow outgoing edges towards their destination nodes.
    Downside,
    /// Follow incoming edges first, then outgoing edges.
    Both,
}
633
/// Two-valued verdict returned by node operators and edge validators.
pub enum DFSStatus {
    Continue, // true
    Stop,     // false
}

impl DFSStatus {
    /// Logical AND with `Continue` as true: the result is `Continue` only when
    /// both inputs are `Continue`.
    pub fn and(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
        match (s1, s2) {
            (DFSStatus::Continue, DFSStatus::Continue) => DFSStatus::Continue,
            _ => DFSStatus::Stop,
        }
    }

    /// Logical OR with `Continue` as true: the result is `Stop` only when both
    /// inputs are `Stop`.
    pub fn or(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
        match (s1, s2) {
            (DFSStatus::Stop, DFSStatus::Stop) => DFSStatus::Stop,
            _ => DFSStatus::Continue,
        }
    }
}