// rapx/analysis/core/dataflow/graph.rs

1use std::cell::Cell;
2use std::collections::HashSet;
3
4use rustc_hir::def_id::DefId;
5use rustc_index::IndexVec;
6use rustc_middle::{
7    mir::{
8        AggregateKind, BorrowKind, Const, Local, Operand, Place, PlaceElem, Rvalue, Statement,
9        StatementKind, Terminator, TerminatorKind,
10    },
11    ty::TyKind,
12};
13use rustc_span::{Span, DUMMY_SP};
14
15use crate::{analysis::core::dataflow::*, utils::log::relative_pos_range};
16
17impl GraphNode {
18    pub fn new() -> Self {
19        Self {
20            ops: vec![NodeOp::Nop],
21            span: DUMMY_SP,
22            seq: 0,
23            out_edges: vec![],
24            in_edges: vec![],
25        }
26    }
27}
28
29#[derive(Clone)]
30pub struct Graph {
31    pub def_id: DefId,
32    pub span: Span,
33    pub argc: usize,
34    pub nodes: GraphNodes, //constsis of locals in mir and newly created markers
35    pub edges: GraphEdges,
36    pub n_locals: usize,
37    pub closures: HashSet<DefId>,
38}
39
40impl From<Graph> for DataFlowGraph {
41    fn from(graph: Graph) -> Self {
42        let param_ret_deps = graph.param_return_deps();
43        DataFlowGraph {
44            nodes: graph.nodes,
45            edges: graph.edges,
46            param_ret_deps: param_ret_deps,
47        }
48    }
49}
50
51impl Graph {
52    pub fn new(def_id: DefId, span: Span, argc: usize, n_locals: usize) -> Self {
53        Self {
54            def_id,
55            span,
56            argc,
57            nodes: GraphNodes::from_elem_n(GraphNode::new(), n_locals),
58            edges: GraphEdges::new(),
59            n_locals,
60            closures: HashSet::new(),
61        }
62    }
63
64    // add an edge into an existing node
65    pub fn add_node_edge(&mut self, src: Local, dst: Local, op: EdgeOp) -> EdgeIdx {
66        let seq = self.nodes[dst].seq;
67        let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
68        self.nodes[dst].in_edges.push(edge_idx);
69        self.nodes[src].out_edges.push(edge_idx);
70        edge_idx
71    }
72
73    // add an edge into an existing node with const value as src
74    pub fn add_const_edge(
75        &mut self,
76        src_desc: String,
77        src_ty: String,
78        dst: Local,
79        op: EdgeOp,
80    ) -> EdgeIdx {
81        let seq = self.nodes[dst].seq;
82        let mut const_node = GraphNode::new();
83        const_node.ops[0] = NodeOp::Const(src_desc, src_ty);
84        let src = self.nodes.push(const_node);
85        let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
86        self.nodes[dst].in_edges.push(edge_idx);
87        edge_idx
88    }
89
90    pub fn add_operand(&mut self, operand: &Operand, dst: Local) {
91        match operand {
92            Operand::Copy(place) => {
93                let src = self.parse_place(place);
94                self.add_node_edge(src, dst, EdgeOp::Copy);
95            }
96            Operand::Move(place) => {
97                let src = self.parse_place(place);
98                self.add_node_edge(src, dst, EdgeOp::Move);
99            }
100            Operand::Constant(boxed_const_op) => {
101                let src_desc = boxed_const_op.const_.to_string();
102                let src_ty = match boxed_const_op.const_ {
103                    Const::Val(_, ty) => ty.to_string(),
104                    Const::Unevaluated(_, ty) => ty.to_string(),
105                    Const::Ty(ty, _) => ty.to_string(),
106                };
107                self.add_const_edge(src_desc, src_ty, dst, EdgeOp::Const);
108            }
109        }
110    }
111
112    pub fn parse_place(&mut self, place: &Place) -> Local {
113        fn parse_one_step(graph: &mut Graph, src: Local, place_elem: PlaceElem) -> Local {
114            let dst = graph.nodes.push(GraphNode::new());
115            match place_elem {
116                PlaceElem::Deref => {
117                    graph.add_node_edge(src, dst, EdgeOp::Deref);
118                }
119                PlaceElem::Field(field_idx, _) => {
120                    graph.add_node_edge(src, dst, EdgeOp::Field(field_idx.as_usize()));
121                }
122                PlaceElem::Downcast(symbol, _) => {
123                    graph.add_node_edge(src, dst, EdgeOp::Downcast(symbol.unwrap().to_string()));
124                }
125                PlaceElem::Index(idx) => {
126                    graph.add_node_edge(src, dst, EdgeOp::Index);
127                    graph.add_node_edge(idx, dst, EdgeOp::Nop);
128                }
129                PlaceElem::ConstantIndex { .. } => {
130                    graph.add_node_edge(src, dst, EdgeOp::ConstIndex);
131                }
132                PlaceElem::Subslice { .. } => {
133                    graph.add_node_edge(src, dst, EdgeOp::SubSlice);
134                }
135                PlaceElem::Subtype(..) => {
136                    graph.add_node_edge(src, dst, EdgeOp::SubType);
137                }
138                _ => {
139                    println!("{:?}", place_elem);
140                    todo!()
141                }
142            }
143            dst
144        }
145        let mut ret = place.local;
146        for place_elem in place.projection {
147            // if there are projections, then add marker nodes
148            ret = parse_one_step(self, ret, place_elem);
149        }
150        ret
151    }
152
153    pub fn add_statm_to_graph(&mut self, statement: &Statement) {
154        if let StatementKind::Assign(boxed_statm) = &statement.kind {
155            let place = boxed_statm.0;
156            let dst = self.parse_place(&place);
157            self.nodes[dst].span = statement.source_info.span;
158            let rvalue = &boxed_statm.1;
159            let seq = self.nodes[dst].seq;
160            if seq == self.nodes[dst].ops.len() {
161                //warning: we do not check whether seq > len
162                self.nodes[dst].ops.push(NodeOp::Nop);
163            }
164            match rvalue {
165                Rvalue::Use(op) => {
166                    self.add_operand(op, dst);
167                    self.nodes[dst].ops[seq] = NodeOp::Use;
168                }
169                Rvalue::Repeat(op, _) => {
170                    self.add_operand(op, dst);
171                    self.nodes[dst].ops[seq] = NodeOp::Repeat;
172                }
173                Rvalue::Ref(_, borrow_kind, place) => {
174                    let op = match borrow_kind {
175                        BorrowKind::Shared => EdgeOp::Immut,
176                        BorrowKind::Mut { .. } => EdgeOp::Mut,
177                        BorrowKind::Fake(_) => EdgeOp::Nop, // todo
178                    };
179                    let src = self.parse_place(place);
180                    self.add_node_edge(src, dst, op);
181                    self.nodes[dst].ops[seq] = NodeOp::Ref;
182                }
183                Rvalue::Len(place) => {
184                    let src = self.parse_place(place);
185                    self.add_node_edge(src, dst, EdgeOp::Nop);
186                    self.nodes[dst].ops[seq] = NodeOp::Len;
187                }
188                Rvalue::Cast(_cast_kind, operand, _) => {
189                    self.add_operand(operand, dst);
190                    self.nodes[dst].ops[seq] = NodeOp::Cast;
191                }
192                Rvalue::BinaryOp(_, operands) => {
193                    self.add_operand(&operands.0, dst);
194                    self.add_operand(&operands.1, dst);
195                    self.nodes[dst].ops[seq] = NodeOp::CheckedBinaryOp;
196                }
197                Rvalue::Aggregate(boxed_kind, operands) => {
198                    for operand in operands.iter() {
199                        self.add_operand(operand, dst);
200                    }
201                    match **boxed_kind {
202                        AggregateKind::Array(_) => {
203                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Array)
204                        }
205                        AggregateKind::Tuple => {
206                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Tuple)
207                        }
208                        AggregateKind::Adt(def_id, ..) => {
209                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Adt(def_id))
210                        }
211                        AggregateKind::Closure(def_id, ..) => {
212                            self.closures.insert(def_id);
213                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Closure(def_id))
214                        }
215                        AggregateKind::Coroutine(def_id, ..) => {
216                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Coroutine(def_id))
217                        }
218                        AggregateKind::RawPtr(_, _mutability) => {
219                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::RawPtr)
220                            // We temporarily have not taken mutability into account
221                        }
222                        _ => {
223                            println!("{:?}", boxed_kind);
224                            todo!()
225                        }
226                    }
227                }
228                Rvalue::UnaryOp(_, operand) => {
229                    self.add_operand(operand, dst);
230                    self.nodes[dst].ops[seq] = NodeOp::UnaryOp;
231                }
232                Rvalue::NullaryOp(_, ty) => {
233                    self.add_const_edge(ty.to_string(), ty.to_string(), dst, EdgeOp::Nop);
234                    self.nodes[dst].ops[seq] = NodeOp::NullaryOp;
235                }
236                Rvalue::ThreadLocalRef(_) => {
237                    //todo!()
238                }
239                Rvalue::Discriminant(place) => {
240                    let src = self.parse_place(place);
241                    self.add_node_edge(src, dst, EdgeOp::Nop);
242                    self.nodes[dst].ops[seq] = NodeOp::Discriminant;
243                }
244                Rvalue::ShallowInitBox(operand, _) => {
245                    self.add_operand(operand, dst);
246                    self.nodes[dst].ops[seq] = NodeOp::ShallowInitBox;
247                }
248                Rvalue::CopyForDeref(place) => {
249                    let src = self.parse_place(place);
250                    self.add_node_edge(src, dst, EdgeOp::Nop);
251                    self.nodes[dst].ops[seq] = NodeOp::CopyForDeref;
252                }
253                Rvalue::RawPtr(_, place) => {
254                    let src = self.parse_place(place);
255                    self.add_node_edge(src, dst, EdgeOp::Nop); // Mutability?
256                    self.nodes[dst].ops[seq] = NodeOp::RawPtr;
257                }
258                _ => todo!(),
259            };
260            self.nodes[dst].seq = seq + 1;
261        }
262    }
263
264    pub fn add_terminator_to_graph(&mut self, terminator: &Terminator) {
265        if let TerminatorKind::Call {
266            func,
267            args,
268            destination,
269            ..
270        } = &terminator.kind
271        {
272            let dst = destination.local;
273            let seq = self.nodes[dst].seq;
274            if seq == self.nodes[dst].ops.len() {
275                self.nodes[dst].ops.push(NodeOp::Nop);
276            }
277            match func {
278                Operand::Constant(boxed_cnst) => {
279                    if let Const::Val(_, ty) = boxed_cnst.const_ {
280                        if let TyKind::FnDef(def_id, _) = ty.kind() {
281                            for op in args.iter() {
282                                //rustc version related
283                                self.add_operand(&op.node, dst);
284                            }
285                            self.nodes[dst].ops[seq] = NodeOp::Call(*def_id);
286                        }
287                    }
288                }
289                Operand::Move(_) => {
290                    self.add_operand(func, dst); //the func is a place
291                    for op in args.iter() {
292                        //rustc version related
293                        self.add_operand(&op.node, dst);
294                    }
295                    self.nodes[dst].ops[seq] = NodeOp::CallOperand;
296                }
297                _ => {
298                    println!("{:?}", func);
299                    todo!();
300                }
301            }
302            self.nodes[dst].span = terminator.source_info.span;
303            self.nodes[dst].seq = seq + 1;
304        }
305    }
306
307    // Because a node(local) may have multiple ops, we need to decide whether to strictly collect equivalent locals or not
308    // For the former, all the ops should meet the equivalent condition.
309    // For the later, if only one op meets the condition, we still take it into consideration.
310    pub fn collect_equivalent_locals(&self, local: Local, strict: bool) -> HashSet<Local> {
311        let mut set = HashSet::new();
312        let root = Cell::new(local);
313        let reduce_func = if strict {
314            DFSStatus::and
315        } else {
316            DFSStatus::or
317        };
318        let mut find_root_operator = |graph: &Graph, idx: Local| -> DFSStatus {
319            let node = &graph.nodes[idx];
320            node.ops
321                .iter()
322                .map(|op| {
323                    match op {
324                        NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
325                            //Nop means an orphan node or a parameter
326                            root.set(idx);
327                            DFSStatus::Continue
328                        }
329                        NodeOp::Call(_) => {
330                            //We are moving towards upside. Thus we can record the call node and stop dfs.
331                            //We stop because the return value does not equal to parameters
332                            root.set(idx);
333                            DFSStatus::Stop
334                        }
335                        _ => DFSStatus::Stop,
336                    }
337                })
338                .reduce(reduce_func)
339                .unwrap()
340        };
341        let mut find_equivalent_operator = |graph: &Graph, idx: Local| -> DFSStatus {
342            let node = &graph.nodes[idx];
343            if set.contains(&idx) {
344                return DFSStatus::Stop;
345            }
346            node.ops
347                .iter()
348                .map(|op| match op {
349                    NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
350                        set.insert(idx);
351                        DFSStatus::Continue
352                    }
353                    NodeOp::Call(_) => {
354                        if idx == root.get() {
355                            set.insert(idx);
356                            DFSStatus::Continue
357                        } else {
358                            // We are moving towards downside. Thus we stop dfs right now.
359                            DFSStatus::Stop
360                        }
361                    }
362                    _ => DFSStatus::Stop,
363                })
364                .reduce(reduce_func)
365                .unwrap()
366        };
367        // Algorithm: dfs along upside to find the root node, and then dfs along downside to collect equivalent locals
368        let mut seen = HashSet::new();
369        self.dfs(
370            local,
371            Direction::Upside,
372            &mut find_root_operator,
373            &mut Self::equivalent_edge_validator,
374            true,
375            &mut seen,
376        );
377        seen.clear();
378        self.dfs(
379            root.get(),
380            Direction::Downside,
381            &mut find_equivalent_operator,
382            &mut Self::equivalent_edge_validator,
383            true,
384            &mut seen,
385        );
386        set
387    }
388
389    pub fn collect_ancestor_locals(&self, local: Local, self_included: bool) -> HashSet<Local> {
390        let mut ret = HashSet::new();
391        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
392            ret.insert(idx);
393            DFSStatus::Continue
394        };
395        let mut seen = HashSet::new();
396        self.dfs(
397            local,
398            Direction::Upside,
399            &mut node_operator,
400            &mut Graph::always_true_edge_validator,
401            true,
402            &mut seen,
403        );
404        if !self_included {
405            ret.remove(&local);
406        }
407        ret
408    }
409
410    pub fn collect_descending_locals(&self, local: Local, self_included: bool) -> HashSet<Local> {
411        let mut ret = HashSet::new();
412        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
413            ret.insert(idx);
414            DFSStatus::Continue
415        };
416        let mut seen = HashSet::new();
417        self.dfs(
418            local,
419            Direction::Downside,
420            &mut node_operator,
421            &mut Graph::always_true_edge_validator,
422            true,
423            &mut seen,
424        );
425        if !self_included {
426            ret.remove(&local);
427        }
428        ret
429    }
430
431    pub fn get_field_sequence(&self, local: Local) -> Option<(Local, Vec<usize>)> {
432        let mut fields = vec![];
433        let var = Cell::new(local);
434        let mut node_operator = |graph: &Graph, idx: Local| -> DFSStatus {
435            if graph.is_marker(idx) {
436                DFSStatus::Continue
437            } else {
438                var.set(idx);
439                DFSStatus::Stop
440            }
441        };
442        let mut edge_validator = |graph: &Graph, idx: EdgeIdx| -> DFSStatus {
443            if let EdgeOp::Field(field) = graph.edges[idx].op {
444                fields.insert(0, field);
445                DFSStatus::Continue
446            } else {
447                DFSStatus::Stop
448            }
449        };
450        let mut seen = HashSet::new();
451        self.dfs(
452            local,
453            Direction::Upside,
454            &mut node_operator,
455            &mut edge_validator,
456            false,
457            &mut seen,
458        );
459        if fields.is_empty() {
460            None
461        } else {
462            Some((var.get(), fields))
463        }
464    }
465
466    pub fn is_connected(&self, idx_1: Local, idx_2: Local) -> bool {
467        let target = idx_2;
468        let find = Cell::new(false);
469        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
470            find.set(idx == target);
471            if find.get() {
472                DFSStatus::Stop
473            } else {
474                // if not found, move on
475                DFSStatus::Continue
476            }
477        };
478        let mut seen = HashSet::new();
479        self.dfs(
480            idx_1,
481            Direction::Downside,
482            &mut node_operator,
483            &mut Self::always_true_edge_validator,
484            false,
485            &mut seen,
486        );
487        seen.clear();
488        if !find.get() {
489            self.dfs(
490                idx_1,
491                Direction::Upside,
492                &mut node_operator,
493                &mut Self::always_true_edge_validator,
494                false,
495                &mut seen,
496            );
497        }
498        find.get()
499    }
500
501    // Whether there exists dataflow between each parameter and the return value
502    pub fn param_return_deps(&self) -> IndexVec<Local, bool> {
503        let _0 = Local::from_usize(0);
504        let deps = (0..self.argc + 1) //the length is argc + 1, because _0 depends on _0 itself.
505            .map(|i| {
506                let _i = Local::from_usize(i);
507                self.is_connected(_i, _0)
508            })
509            .collect();
510        deps
511    }
512
513    // This function uses precedence traversal.
514    // The node operator and edge validator decide how far the traversal can reach.
515    // `traverse_all` decides if a branch finds the target successfully, whether the traversal will continue or not.
516    // For example, if you need to instantly stop the traversal once finding a certain node, then set `traverse_all` to false.
517    // If you want to traverse all the reachable nodes which are decided by the operator and validator, then set `traverse_all` to true.
518    pub fn dfs<F, G>(
519        &self,
520        now: Local,
521        direction: Direction,
522        node_operator: &mut F,
523        edge_validator: &mut G,
524        traverse_all: bool,
525        seen: &mut HashSet<Local>,
526    ) -> (DFSStatus, bool)
527    where
528        F: FnMut(&Graph, Local) -> DFSStatus,
529        G: FnMut(&Graph, EdgeIdx) -> DFSStatus,
530    {
531        if seen.contains(&now) {
532            return (DFSStatus::Stop, false);
533        }
534        seen.insert(now);
535        macro_rules! traverse {
536            ($edges: ident, $field: ident) => {
537                for edge_idx in self.nodes[now].$edges.iter() {
538                    let edge = &self.edges[*edge_idx];
539                    if matches!(edge_validator(self, *edge_idx), DFSStatus::Continue) {
540                        let (dfs_status, result) = self.dfs(
541                            edge.$field,
542                            direction,
543                            node_operator,
544                            edge_validator,
545                            traverse_all,
546                            seen,
547                        );
548
549                        if matches!(dfs_status, DFSStatus::Stop) && result && !traverse_all {
550                            return (DFSStatus::Stop, true);
551                        }
552                    }
553                }
554            };
555        }
556
557        if matches!(node_operator(self, now), DFSStatus::Continue) {
558            match direction {
559                Direction::Upside => {
560                    traverse!(in_edges, src);
561                }
562                Direction::Downside => {
563                    traverse!(out_edges, dst);
564                }
565                Direction::Both => {
566                    traverse!(in_edges, src);
567                    traverse!(out_edges, dst);
568                }
569            };
570            (DFSStatus::Continue, false)
571        } else {
572            (DFSStatus::Stop, true)
573        }
574    }
575
576    pub fn get_upside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
577        if let Some(edge_idx) = self.nodes[node_idx].in_edges.get(order) {
578            Some(self.edges[*edge_idx].src)
579        } else {
580            None
581        }
582    }
583
584    pub fn get_downside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
585        if let Some(edge_idx) = self.nodes[node_idx].out_edges.get(order) {
586            Some(self.edges[*edge_idx].dst)
587        } else {
588            None
589        }
590    }
591
592    // if strict is set to false, we return the first node that wraps the target span and at least one end overlaps
593    pub fn query_node_by_span(&self, span: Span, strict: bool) -> Option<(Local, &GraphNode)> {
594        for (node_idx, node) in self.nodes.iter_enumerated() {
595            if strict {
596                if node.span == span {
597                    return Some((node_idx, node));
598                }
599            } else {
600                if !relative_pos_range(node.span, span).eq(0..0)
601                    && (node.span.lo() == span.lo() || node.span.hi() == span.hi())
602                {
603                    return Some((node_idx, node));
604                }
605            }
606        }
607        None
608    }
609
610    pub fn is_marker(&self, idx: Local) -> bool {
611        idx >= Local::from_usize(self.n_locals)
612    }
613}
614
615impl Graph {
616    pub fn equivalent_edge_validator(graph: &Graph, idx: EdgeIdx) -> DFSStatus {
617        match graph.edges[idx].op {
618            EdgeOp::Copy | EdgeOp::Move | EdgeOp::Mut | EdgeOp::Immut | EdgeOp::Deref => {
619                DFSStatus::Continue
620            }
621            EdgeOp::Nop
622            | EdgeOp::Const
623            | EdgeOp::Downcast(_)
624            | EdgeOp::Field(_)
625            | EdgeOp::Index
626            | EdgeOp::ConstIndex
627            | EdgeOp::SubSlice
628            | EdgeOp::SubType => DFSStatus::Stop,
629        }
630    }
631
632    pub fn always_true_edge_validator(_: &Graph, _: EdgeIdx) -> DFSStatus {
633        DFSStatus::Continue
634    }
635}
636
/// Which edges a graph traversal follows.
#[derive(Clone, Copy)]
pub enum Direction {
    /// Follow incoming edges, from a node towards its sources.
    Upside,
    /// Follow outgoing edges, from a node towards its destinations.
    Downside,
    /// Follow incoming edges first, then outgoing edges.
    Both,
}
643
/// Verdict returned by node operators and edge validators during a graph traversal.
pub enum DFSStatus {
    /// Keep traversing past this node or edge; acts as logical `true`.
    Continue,
    /// Do not traverse past this node or edge; acts as logical `false`.
    Stop,
}
648
649impl DFSStatus {
650    pub fn and(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
651        if matches!(s1, DFSStatus::Stop) || matches!(s2, DFSStatus::Stop) {
652            DFSStatus::Stop
653        } else {
654            DFSStatus::Continue
655        }
656    }
657
658    pub fn or(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
659        if matches!(s1, DFSStatus::Continue) || matches!(s2, DFSStatus::Continue) {
660            DFSStatus::Continue
661        } else {
662            DFSStatus::Stop
663        }
664    }
665}