// rapx/analysis/core/dataflow/graph.rs

1use std::cell::Cell;
2use std::collections::HashSet;
3
4use rustc_hir::def_id::DefId;
5use rustc_index::IndexVec;
6use rustc_middle::{
7    mir::{
8        AggregateKind, BorrowKind, Const, Local, Operand, Place, PlaceElem, Rvalue, Statement,
9        StatementKind, Terminator, TerminatorKind,
10    },
11    ty::TyKind,
12};
13use rustc_span::{Span, DUMMY_SP};
14
15use crate::{analysis::core::dataflow::*, utils::log::relative_pos_range};
16
17impl GraphNode {
18    pub fn new() -> Self {
19        Self {
20            ops: vec![NodeOp::Nop],
21            span: DUMMY_SP,
22            seq: 0,
23            out_edges: vec![],
24            in_edges: vec![],
25        }
26    }
27}
28
29#[derive(Clone)]
30pub struct Graph {
31    pub def_id: DefId,
32    pub span: Span,
33    pub argc: usize,
34    pub nodes: GraphNodes, //constsis of locals in mir and newly created markers
35    pub edges: GraphEdges,
36    pub n_locals: usize,
37    pub closures: HashSet<DefId>,
38}
39
40impl From<Graph> for DataFlowGraph {
41    fn from(graph: Graph) -> Self {
42        let param_ret_deps = graph.param_return_deps();
43        DataFlowGraph {
44            nodes: graph.nodes,
45            edges: graph.edges,
46            param_ret_deps: param_ret_deps,
47        }
48    }
49}
50
51impl Graph {
52    pub fn new(def_id: DefId, span: Span, argc: usize, n_locals: usize) -> Self {
53        Self {
54            def_id,
55            span,
56            argc,
57            nodes: GraphNodes::from_elem_n(GraphNode::new(), n_locals),
58            edges: GraphEdges::new(),
59            n_locals,
60            closures: HashSet::new(),
61        }
62    }
63
64    // add an edge into an existing node
65    pub fn add_node_edge(&mut self, src: Local, dst: Local, op: EdgeOp) -> EdgeIdx {
66        let seq = self.nodes[dst].seq;
67        let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
68        self.nodes[dst].in_edges.push(edge_idx);
69        self.nodes[src].out_edges.push(edge_idx);
70        edge_idx
71    }
72
73    // add an edge into an existing node with const value as src
74    pub fn add_const_edge(
75        &mut self,
76        src_desc: String,
77        src_ty: String,
78        dst: Local,
79        op: EdgeOp,
80    ) -> EdgeIdx {
81        let seq = self.nodes[dst].seq;
82        let mut const_node = GraphNode::new();
83        const_node.ops[0] = NodeOp::Const(src_desc, src_ty);
84        let src = self.nodes.push(const_node);
85        let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
86        self.nodes[dst].in_edges.push(edge_idx);
87        edge_idx
88    }
89
90    pub fn add_operand(&mut self, operand: &Operand, dst: Local) {
91        match operand {
92            Operand::Copy(place) => {
93                let src = self.parse_place(place);
94                self.add_node_edge(src, dst, EdgeOp::Copy);
95            }
96            Operand::Move(place) => {
97                let src = self.parse_place(place);
98                self.add_node_edge(src, dst, EdgeOp::Move);
99            }
100            Operand::Constant(boxed_const_op) => {
101                let src_desc = boxed_const_op.const_.to_string();
102                let src_ty = match boxed_const_op.const_ {
103                    Const::Val(_, ty) => ty.to_string(),
104                    Const::Unevaluated(_, ty) => ty.to_string(),
105                    Const::Ty(ty, _) => ty.to_string(),
106                };
107                self.add_const_edge(src_desc, src_ty, dst, EdgeOp::Const);
108            }
109        }
110    }
111
112    pub fn parse_place(&mut self, place: &Place) -> Local {
113        fn parse_one_step(graph: &mut Graph, src: Local, place_elem: PlaceElem) -> Local {
114            let dst = graph.nodes.push(GraphNode::new());
115            match place_elem {
116                PlaceElem::Deref => {
117                    graph.add_node_edge(src, dst, EdgeOp::Deref);
118                }
119                PlaceElem::Field(field_idx, _) => {
120                    graph.add_node_edge(src, dst, EdgeOp::Field(format!("{:?}", field_idx)));
121                }
122                PlaceElem::Downcast(symbol, _) => {
123                    graph.add_node_edge(src, dst, EdgeOp::Downcast(symbol.unwrap().to_string()));
124                }
125                PlaceElem::Index(idx) => {
126                    graph.add_node_edge(src, dst, EdgeOp::Index);
127                    graph.add_node_edge(idx, dst, EdgeOp::Nop);
128                }
129                PlaceElem::ConstantIndex { .. } => {
130                    graph.add_node_edge(src, dst, EdgeOp::ConstIndex);
131                }
132                PlaceElem::Subslice { .. } => {
133                    graph.add_node_edge(src, dst, EdgeOp::SubSlice);
134                }
135                PlaceElem::Subtype(..) => {
136                    graph.add_node_edge(src, dst, EdgeOp::SubType);
137                }
138                _ => {
139                    println!("{:?}", place_elem);
140                    todo!()
141                }
142            }
143            dst
144        }
145        let mut ret = place.local;
146        for place_elem in place.projection {
147            // if there are projections, then add marker nodes
148            ret = parse_one_step(self, ret, place_elem);
149        }
150        ret
151    }
152
153    pub fn add_statm_to_graph(&mut self, statement: &Statement) {
154        if let StatementKind::Assign(boxed_statm) = &statement.kind {
155            let place = boxed_statm.0;
156            let dst = self.parse_place(&place);
157            self.nodes[dst].span = statement.source_info.span;
158            let rvalue = &boxed_statm.1;
159            let seq = self.nodes[dst].seq;
160            if seq == self.nodes[dst].ops.len() {
161                //warning: we do not check whether seq > len
162                self.nodes[dst].ops.push(NodeOp::Nop);
163            }
164            match rvalue {
165                Rvalue::Use(op) => {
166                    self.add_operand(op, dst);
167                    self.nodes[dst].ops[seq] = NodeOp::Use;
168                }
169                Rvalue::Repeat(op, _) => {
170                    self.add_operand(op, dst);
171                    self.nodes[dst].ops[seq] = NodeOp::Repeat;
172                }
173                Rvalue::Ref(_, borrow_kind, place) => {
174                    let op = match borrow_kind {
175                        BorrowKind::Shared => EdgeOp::Immut,
176                        BorrowKind::Mut { .. } => EdgeOp::Mut,
177                        BorrowKind::Fake(_) => EdgeOp::Nop, // todo
178                    };
179                    let src = self.parse_place(place);
180                    self.add_node_edge(src, dst, op);
181                    self.nodes[dst].ops[seq] = NodeOp::Ref;
182                }
183                Rvalue::Len(place) => {
184                    let src = self.parse_place(place);
185                    self.add_node_edge(src, dst, EdgeOp::Nop);
186                    self.nodes[dst].ops[seq] = NodeOp::Len;
187                }
188                Rvalue::Cast(_cast_kind, operand, _) => {
189                    self.add_operand(operand, dst);
190                    self.nodes[dst].ops[seq] = NodeOp::Cast;
191                }
192                Rvalue::BinaryOp(_, operands) => {
193                    self.add_operand(&operands.0, dst);
194                    self.add_operand(&operands.1, dst);
195                    self.nodes[dst].ops[seq] = NodeOp::CheckedBinaryOp;
196                }
197                Rvalue::Aggregate(boxed_kind, operands) => {
198                    for operand in operands.iter() {
199                        self.add_operand(operand, dst);
200                    }
201                    match **boxed_kind {
202                        AggregateKind::Array(_) => {
203                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Array)
204                        }
205                        AggregateKind::Tuple => {
206                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Tuple)
207                        }
208                        AggregateKind::Adt(def_id, ..) => {
209                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Adt(def_id))
210                        }
211                        AggregateKind::Closure(def_id, ..) => {
212                            self.closures.insert(def_id);
213                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Closure(def_id))
214                        }
215                        AggregateKind::Coroutine(def_id, ..) => {
216                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Coroutine(def_id))
217                        }
218                        AggregateKind::RawPtr(_, _mutability) => {
219                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::RawPtr)
220                            // We temporarily have not taken mutability into account
221                        }
222                        _ => {
223                            println!("{:?}", boxed_kind);
224                            todo!()
225                        }
226                    }
227                }
228                Rvalue::UnaryOp(_, operand) => {
229                    self.add_operand(operand, dst);
230                    self.nodes[dst].ops[seq] = NodeOp::UnaryOp;
231                }
232                Rvalue::NullaryOp(_, ty) => {
233                    self.add_const_edge(ty.to_string(), ty.to_string(), dst, EdgeOp::Nop);
234                    self.nodes[dst].ops[seq] = NodeOp::NullaryOp;
235                }
236                Rvalue::ThreadLocalRef(_) => {
237                    //todo!()
238                }
239                Rvalue::Discriminant(place) => {
240                    let src = self.parse_place(place);
241                    self.add_node_edge(src, dst, EdgeOp::Nop);
242                    self.nodes[dst].ops[seq] = NodeOp::Discriminant;
243                }
244                Rvalue::ShallowInitBox(operand, _) => {
245                    self.add_operand(operand, dst);
246                    self.nodes[dst].ops[seq] = NodeOp::ShallowInitBox;
247                }
248                Rvalue::CopyForDeref(place) => {
249                    let src = self.parse_place(place);
250                    self.add_node_edge(src, dst, EdgeOp::Nop);
251                    self.nodes[dst].ops[seq] = NodeOp::CopyForDeref;
252                }
253                Rvalue::RawPtr(_, place) => {
254                    let src = self.parse_place(place);
255                    self.add_node_edge(src, dst, EdgeOp::Nop); // Mutability?
256                    self.nodes[dst].ops[seq] = NodeOp::RawPtr;
257                }
258                _ => todo!(),
259            };
260            self.nodes[dst].seq = seq + 1;
261        }
262    }
263
264    pub fn add_terminator_to_graph(&mut self, terminator: &Terminator) {
265        if let TerminatorKind::Call {
266            func,
267            args,
268            destination,
269            ..
270        } = &terminator.kind
271        {
272            let dst = destination.local;
273            let seq = self.nodes[dst].seq;
274            if seq == self.nodes[dst].ops.len() {
275                self.nodes[dst].ops.push(NodeOp::Nop);
276            }
277            match func {
278                Operand::Constant(boxed_cnst) => {
279                    if let Const::Val(_, ty) = boxed_cnst.const_ {
280                        if let TyKind::FnDef(def_id, _) = ty.kind() {
281                            for op in args.iter() {
282                                //rustc version related
283                                self.add_operand(&op.node, dst);
284                            }
285                            self.nodes[dst].ops[seq] = NodeOp::Call(*def_id);
286                        }
287                    }
288                }
289                Operand::Move(_) => {
290                    self.add_operand(func, dst); //the func is a place
291                    for op in args.iter() {
292                        //rustc version related
293                        self.add_operand(&op.node, dst);
294                    }
295                    self.nodes[dst].ops[seq] = NodeOp::CallOperand;
296                }
297                _ => {
298                    println!("{:?}", func);
299                    todo!();
300                }
301            }
302            self.nodes[dst].span = terminator.source_info.span;
303            self.nodes[dst].seq = seq + 1;
304        }
305    }
306
307    // Because a node(local) may have multiple ops, we need to decide whether to strictly collect equivalent locals or not
308    // For the former, all the ops should meet the equivalent condition.
309    // For the later, if only one op meets the condition, we still take it into consideration.
310    pub fn collect_equivalent_locals(&self, local: Local, strict: bool) -> HashSet<Local> {
311        let mut set = HashSet::new();
312        let root = Cell::new(local);
313        let reduce_func = if strict {
314            DFSStatus::and
315        } else {
316            DFSStatus::or
317        };
318        let mut find_root_operator = |graph: &Graph, idx: Local| -> DFSStatus {
319            let node = &graph.nodes[idx];
320            node.ops
321                .iter()
322                .map(|op| {
323                    match op {
324                        NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
325                            //Nop means an orphan node or a parameter
326                            root.set(idx);
327                            DFSStatus::Continue
328                        }
329                        NodeOp::Call(_) => {
330                            //We are moving towards upside. Thus we can record the call node and stop dfs.
331                            //We stop because the return value does not equal to parameters
332                            root.set(idx);
333                            DFSStatus::Stop
334                        }
335                        _ => DFSStatus::Stop,
336                    }
337                })
338                .reduce(reduce_func)
339                .unwrap()
340        };
341        let mut find_equivalent_operator = |graph: &Graph, idx: Local| -> DFSStatus {
342            let node = &graph.nodes[idx];
343            if set.contains(&idx) {
344                return DFSStatus::Stop;
345            }
346            node.ops
347                .iter()
348                .map(|op| match op {
349                    NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
350                        set.insert(idx);
351                        DFSStatus::Continue
352                    }
353                    NodeOp::Call(_) => {
354                        if idx == root.get() {
355                            set.insert(idx);
356                            DFSStatus::Continue
357                        } else {
358                            // We are moving towards downside. Thus we stop dfs right now.
359                            DFSStatus::Stop
360                        }
361                    }
362                    _ => DFSStatus::Stop,
363                })
364                .reduce(reduce_func)
365                .unwrap()
366        };
367        // Algorithm: dfs along upside to find the root node, and then dfs along downside to collect equivalent locals
368        let mut seen = HashSet::new();
369        self.dfs(
370            local,
371            Direction::Upside,
372            &mut find_root_operator,
373            &mut Self::equivalent_edge_validator,
374            true,
375            &mut seen,
376        );
377        seen.clear();
378        self.dfs(
379            root.get(),
380            Direction::Downside,
381            &mut find_equivalent_operator,
382            &mut Self::equivalent_edge_validator,
383            true,
384            &mut seen,
385        );
386        set
387    }
388
389    pub fn collect_ancestor_locals(&self, local: Local, self_included: bool) -> HashSet<Local> {
390        let mut ret = HashSet::new();
391        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
392            ret.insert(idx);
393            DFSStatus::Continue
394        };
395        let mut seen = HashSet::new();
396        self.dfs(
397            local,
398            Direction::Upside,
399            &mut node_operator,
400            &mut Graph::always_true_edge_validator,
401            true,
402            &mut seen,
403        );
404        if !self_included {
405            ret.remove(&local);
406        }
407        ret
408    }
409
410    pub fn is_connected(&self, idx_1: Local, idx_2: Local) -> bool {
411        let target = idx_2;
412        let find = Cell::new(false);
413        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
414            find.set(idx == target);
415            if find.get() {
416                DFSStatus::Stop
417            } else {
418                // if not found, move on
419                DFSStatus::Continue
420            }
421        };
422        let mut seen = HashSet::new();
423        self.dfs(
424            idx_1,
425            Direction::Downside,
426            &mut node_operator,
427            &mut Self::always_true_edge_validator,
428            false,
429            &mut seen,
430        );
431        seen.clear();
432        if !find.get() {
433            self.dfs(
434                idx_1,
435                Direction::Upside,
436                &mut node_operator,
437                &mut Self::always_true_edge_validator,
438                false,
439                &mut seen,
440            );
441        }
442        find.get()
443    }
444
445    // Whether there exists dataflow between each parameter and the return value
446    pub fn param_return_deps(&self) -> IndexVec<Local, bool> {
447        let _0 = Local::from_usize(0);
448        let deps = (0..self.argc + 1) //the length is argc + 1, because _0 depends on _0 itself.
449            .map(|i| {
450                let _i = Local::from_usize(i);
451                self.is_connected(_i, _0)
452            })
453            .collect();
454        deps
455    }
456
457    // This function uses precedence traversal.
458    // The node operator and edge validator decide how far the traversal can reach.
459    // `traverse_all` decides if a branch finds the target successfully, whether the traversal will continue or not.
460    // For example, if you need to instantly stop the traversal once finding a certain node, then set `traverse_all` to false.
461    // If you want to traverse all the reachable nodes which are decided by the operator and validator, then set `traverse_all` to true.
462    pub fn dfs<F, G>(
463        &self,
464        now: Local,
465        direction: Direction,
466        node_operator: &mut F,
467        edge_validator: &mut G,
468        traverse_all: bool,
469        seen: &mut HashSet<Local>,
470    ) -> (DFSStatus, bool)
471    where
472        F: FnMut(&Graph, Local) -> DFSStatus,
473        G: FnMut(&Graph, EdgeIdx) -> DFSStatus,
474    {
475        if seen.contains(&now) {
476            return (DFSStatus::Stop, false);
477        }
478        seen.insert(now);
479        macro_rules! traverse {
480            ($edges: ident, $field: ident) => {
481                for edge_idx in self.nodes[now].$edges.iter() {
482                    let edge = &self.edges[*edge_idx];
483                    if matches!(edge_validator(self, *edge_idx), DFSStatus::Continue) {
484                        let (dfs_status, result) = self.dfs(
485                            edge.$field,
486                            direction,
487                            node_operator,
488                            edge_validator,
489                            traverse_all,
490                            seen,
491                        );
492
493                        if matches!(dfs_status, DFSStatus::Stop) && result && !traverse_all {
494                            return (DFSStatus::Stop, true);
495                        }
496                    }
497                }
498            };
499        }
500
501        if matches!(node_operator(self, now), DFSStatus::Continue) {
502            match direction {
503                Direction::Upside => {
504                    traverse!(in_edges, src);
505                }
506                Direction::Downside => {
507                    traverse!(out_edges, dst);
508                }
509                Direction::Both => {
510                    traverse!(in_edges, src);
511                    traverse!(out_edges, dst);
512                }
513            };
514            (DFSStatus::Continue, false)
515        } else {
516            (DFSStatus::Stop, true)
517        }
518    }
519
520    pub fn get_upside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
521        if let Some(edge_idx) = self.nodes[node_idx].in_edges.get(order) {
522            Some(self.edges[*edge_idx].src)
523        } else {
524            None
525        }
526    }
527
528    pub fn get_downside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
529        if let Some(edge_idx) = self.nodes[node_idx].out_edges.get(order) {
530            Some(self.edges[*edge_idx].dst)
531        } else {
532            None
533        }
534    }
535
536    // if strict is set to false, we return the first node that wraps the target span and at least one end overlaps
537    pub fn query_node_by_span(&self, span: Span, strict: bool) -> Option<(Local, &GraphNode)> {
538        for (node_idx, node) in self.nodes.iter_enumerated() {
539            if strict {
540                if node.span == span {
541                    return Some((node_idx, node));
542                }
543            } else {
544                if !relative_pos_range(node.span, span).eq(0..0)
545                    && (node.span.lo() == span.lo() || node.span.hi() == span.hi())
546                {
547                    return Some((node_idx, node));
548                }
549            }
550        }
551        None
552    }
553
554    pub fn is_marker(&self, idx: Local) -> bool {
555        idx >= Local::from_usize(self.n_locals)
556    }
557}
558
559impl Graph {
560    pub fn equivalent_edge_validator(graph: &Graph, idx: EdgeIdx) -> DFSStatus {
561        match graph.edges[idx].op {
562            EdgeOp::Copy | EdgeOp::Move | EdgeOp::Mut | EdgeOp::Immut | EdgeOp::Deref => {
563                DFSStatus::Continue
564            }
565            EdgeOp::Nop
566            | EdgeOp::Const
567            | EdgeOp::Downcast(_)
568            | EdgeOp::Field(_)
569            | EdgeOp::Index
570            | EdgeOp::ConstIndex
571            | EdgeOp::SubSlice
572            | EdgeOp::SubType => DFSStatus::Stop,
573        }
574    }
575
576    pub fn always_true_edge_validator(_: &Graph, _: EdgeIdx) -> DFSStatus {
577        DFSStatus::Continue
578    }
579}
580
/// Which edges `Graph::dfs` follows from a node.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Direction {
    /// Follow in-edges, towards edge sources.
    Upside,
    /// Follow out-edges, towards edge destinations.
    Downside,
    /// Follow in-edges first, then out-edges.
    Both,
}
587
/// Verdict returned by DFS node operators and edge validators.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DFSStatus {
    /// Keep traversing (acts as `true` in `DFSStatus::and` / `DFSStatus::or`).
    Continue,
    /// Stop traversing (acts as `false`).
    Stop,
}
592
593impl DFSStatus {
594    pub fn and(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
595        if matches!(s1, DFSStatus::Stop) || matches!(s2, DFSStatus::Stop) {
596            DFSStatus::Stop
597        } else {
598            DFSStatus::Continue
599        }
600    }
601
602    pub fn or(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
603        if matches!(s1, DFSStatus::Continue) || matches!(s2, DFSStatus::Continue) {
604            DFSStatus::Continue
605        } else {
606            DFSStatus::Stop
607        }
608    }
609}