// rapx/analysis/core/dataflow/graph.rs

use std::cell::Cell;
use std::collections::HashSet;

use rustc_hir::def_id::DefId;
use rustc_index::IndexVec;
use rustc_middle::mir::{
    AggregateKind, BinOp, BorrowKind, Const, Local, Operand, Place, PlaceElem, Rvalue, Statement,
    StatementKind, Terminator, TerminatorKind,
};
use rustc_middle::ty::TyKind;
use rustc_span::{Span, DUMMY_SP};

use crate::utils::log::relative_pos_range;
14
15#[derive(Clone, Debug)]
16pub enum NodeOp {
17    //warning: the fields are related to the version of the backend rustc version
18    Nop,
19    Err,
20    Const(String, String),
21    //Rvalue
22    Use,
23    Repeat,
24    Ref,
25    ThreadLocalRef,
26    AddressOf,
27    Len,
28    Cast,
29    BinaryOp,
30    CheckedBinaryOp,
31    NullaryOp,
32    UnaryOp,
33    Discriminant,
34    Aggregate(AggKind),
35    ShallowInitBox,
36    CopyForDeref,
37    RawPtr,
38    //TerminatorKind
39    Call(DefId),
40    CallOperand, // the first in_edge is the func
41}
42
/// How a value travels along a [`GraphEdge`].
#[derive(Clone, Debug)]
pub enum EdgeOp {
    /// No specific transfer (auxiliary edge).
    Nop,

    // Operand transfers.
    Move,
    Copy,
    Const,

    // Borrows, by mutability.
    Immut,
    Mut,

    // Place projections.
    Deref,
    /// Field projection; carries the debug rendering of the field index.
    Field(String),
    /// Enum downcast; carries the variant name.
    Downcast(String),
    Index,
    ConstIndex,
    SubSlice,
    SubType,
}
62
63#[derive(Clone)]
64pub struct GraphEdge {
65    pub src: Local,
66    pub dst: Local,
67    pub op: EdgeOp,
68    pub seq: usize,
69}
70
71#[derive(Clone)]
72pub struct GraphNode {
73    pub ops: Vec<NodeOp>,
74    pub span: Span, //the corresponding code span
75    pub seq: usize, //the sequence number, edges with the same seq number are added in the same batch within a statement or terminator
76    pub out_edges: Vec<EdgeIdx>,
77    pub in_edges: Vec<EdgeIdx>,
78}
79
80impl GraphNode {
81    pub fn new() -> Self {
82        Self {
83            ops: vec![NodeOp::Nop],
84            span: DUMMY_SP,
85            seq: 0,
86            out_edges: vec![],
87            in_edges: vec![],
88        }
89    }
90}
91
92pub type EdgeIdx = usize;
93pub type GraphNodes = IndexVec<Local, GraphNode>;
94pub type GraphEdges = IndexVec<EdgeIdx, GraphEdge>;
95pub struct Graph {
96    pub def_id: DefId,
97    pub span: Span,
98    pub argc: usize,
99    pub nodes: GraphNodes, //constsis of locals in mir and newly created markers
100    pub edges: GraphEdges,
101    pub n_locals: usize,
102    pub closures: HashSet<DefId>,
103}
104
105impl Graph {
106    pub fn new(def_id: DefId, span: Span, argc: usize, n_locals: usize) -> Self {
107        Self {
108            def_id,
109            span,
110            argc,
111            nodes: GraphNodes::from_elem_n(GraphNode::new(), n_locals),
112            edges: GraphEdges::new(),
113            n_locals,
114            closures: HashSet::new(),
115        }
116    }
117
118    // add an edge into an existing node
119    pub fn add_node_edge(&mut self, src: Local, dst: Local, op: EdgeOp) -> EdgeIdx {
120        let seq = self.nodes[dst].seq;
121        let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
122        self.nodes[dst].in_edges.push(edge_idx);
123        self.nodes[src].out_edges.push(edge_idx);
124        edge_idx
125    }
126
127    // add an edge into an existing node with const value as src
128    pub fn add_const_edge(
129        &mut self,
130        src_desc: String,
131        src_ty: String,
132        dst: Local,
133        op: EdgeOp,
134    ) -> EdgeIdx {
135        let seq = self.nodes[dst].seq;
136        let mut const_node = GraphNode::new();
137        const_node.ops[0] = NodeOp::Const(src_desc, src_ty);
138        let src = self.nodes.push(const_node);
139        let edge_idx = self.edges.push(GraphEdge { src, dst, op, seq });
140        self.nodes[dst].in_edges.push(edge_idx);
141        edge_idx
142    }
143
144    pub fn add_operand(&mut self, operand: &Operand, dst: Local) {
145        match operand {
146            Operand::Copy(place) => {
147                let src = self.parse_place(place);
148                self.add_node_edge(src, dst, EdgeOp::Copy);
149            }
150            Operand::Move(place) => {
151                let src = self.parse_place(place);
152                self.add_node_edge(src, dst, EdgeOp::Move);
153            }
154            Operand::Constant(boxed_const_op) => {
155                let src_desc = boxed_const_op.const_.to_string();
156                let src_ty = match boxed_const_op.const_ {
157                    Const::Val(_, ty) => ty.to_string(),
158                    Const::Unevaluated(_, ty) => ty.to_string(),
159                    Const::Ty(ty, _) => ty.to_string(),
160                };
161                self.add_const_edge(src_desc, src_ty, dst, EdgeOp::Const);
162            }
163        }
164    }
165
166    pub fn parse_place(&mut self, place: &Place) -> Local {
167        fn parse_one_step(graph: &mut Graph, src: Local, place_elem: PlaceElem) -> Local {
168            let dst = graph.nodes.push(GraphNode::new());
169            match place_elem {
170                PlaceElem::Deref => {
171                    graph.add_node_edge(src, dst, EdgeOp::Deref);
172                }
173                PlaceElem::Field(field_idx, _) => {
174                    graph.add_node_edge(src, dst, EdgeOp::Field(format!("{:?}", field_idx)));
175                }
176                PlaceElem::Downcast(symbol, _) => {
177                    graph.add_node_edge(src, dst, EdgeOp::Downcast(symbol.unwrap().to_string()));
178                }
179                PlaceElem::Index(idx) => {
180                    graph.add_node_edge(src, dst, EdgeOp::Index);
181                    graph.add_node_edge(idx, dst, EdgeOp::Nop);
182                }
183                PlaceElem::ConstantIndex { .. } => {
184                    graph.add_node_edge(src, dst, EdgeOp::ConstIndex);
185                }
186                PlaceElem::Subslice { .. } => {
187                    graph.add_node_edge(src, dst, EdgeOp::SubSlice);
188                }
189                PlaceElem::Subtype(..) => {
190                    graph.add_node_edge(src, dst, EdgeOp::SubType);
191                }
192                _ => {
193                    println!("{:?}", place_elem);
194                    todo!()
195                }
196            }
197            dst
198        }
199        let mut ret = place.local;
200        for place_elem in place.projection {
201            // if there are projections, then add marker nodes
202            ret = parse_one_step(self, ret, place_elem);
203        }
204        ret
205    }
206
207    pub fn add_statm_to_graph(&mut self, statement: &Statement) {
208        if let StatementKind::Assign(boxed_statm) = &statement.kind {
209            let place = boxed_statm.0;
210            let dst = self.parse_place(&place);
211            self.nodes[dst].span = statement.source_info.span;
212            let rvalue = &boxed_statm.1;
213            let seq = self.nodes[dst].seq;
214            if seq == self.nodes[dst].ops.len() {
215                //warning: we do not check whether seq > len
216                self.nodes[dst].ops.push(NodeOp::Nop);
217            }
218            match rvalue {
219                Rvalue::Use(op) => {
220                    self.add_operand(op, dst);
221                    self.nodes[dst].ops[seq] = NodeOp::Use;
222                }
223                Rvalue::Repeat(op, _) => {
224                    self.add_operand(op, dst);
225                    self.nodes[dst].ops[seq] = NodeOp::Repeat;
226                }
227                Rvalue::Ref(_, borrow_kind, place) => {
228                    let op = match borrow_kind {
229                        BorrowKind::Shared => EdgeOp::Immut,
230                        BorrowKind::Mut { .. } => EdgeOp::Mut,
231                        BorrowKind::Fake(_) => EdgeOp::Nop, // todo
232                    };
233                    let src = self.parse_place(place);
234                    self.add_node_edge(src, dst, op);
235                    self.nodes[dst].ops[seq] = NodeOp::Ref;
236                }
237                Rvalue::Len(place) => {
238                    let src = self.parse_place(place);
239                    self.add_node_edge(src, dst, EdgeOp::Nop);
240                    self.nodes[dst].ops[seq] = NodeOp::Len;
241                }
242                Rvalue::Cast(_cast_kind, operand, _) => {
243                    self.add_operand(operand, dst);
244                    self.nodes[dst].ops[seq] = NodeOp::Cast;
245                }
246                Rvalue::BinaryOp(_, operands) => {
247                    self.add_operand(&operands.0, dst);
248                    self.add_operand(&operands.1, dst);
249                    self.nodes[dst].ops[seq] = NodeOp::CheckedBinaryOp;
250                }
251                Rvalue::Aggregate(boxed_kind, operands) => {
252                    for operand in operands.iter() {
253                        self.add_operand(operand, dst);
254                    }
255                    match **boxed_kind {
256                        AggregateKind::Array(_) => {
257                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Array)
258                        }
259                        AggregateKind::Tuple => {
260                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Tuple)
261                        }
262                        AggregateKind::Adt(def_id, ..) => {
263                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Adt(def_id))
264                        }
265                        AggregateKind::Closure(def_id, ..) => {
266                            self.closures.insert(def_id);
267                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Closure(def_id))
268                        }
269                        AggregateKind::Coroutine(def_id, ..) => {
270                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::Coroutine(def_id))
271                        }
272                        AggregateKind::RawPtr(_, _mutability) => {
273                            self.nodes[dst].ops[seq] = NodeOp::Aggregate(AggKind::RawPtr)
274                            // We temporarily have not taken mutability into account
275                        }
276                        _ => {
277                            println!("{:?}", boxed_kind);
278                            todo!()
279                        }
280                    }
281                }
282                Rvalue::UnaryOp(_, operand) => {
283                    self.add_operand(operand, dst);
284                    self.nodes[dst].ops[seq] = NodeOp::UnaryOp;
285                }
286                Rvalue::NullaryOp(_, ty) => {
287                    self.add_const_edge(ty.to_string(), ty.to_string(), dst, EdgeOp::Nop);
288                    self.nodes[dst].ops[seq] = NodeOp::NullaryOp;
289                }
290                Rvalue::ThreadLocalRef(_) => {
291                    //todo!()
292                }
293                Rvalue::Discriminant(place) => {
294                    let src = self.parse_place(place);
295                    self.add_node_edge(src, dst, EdgeOp::Nop);
296                    self.nodes[dst].ops[seq] = NodeOp::Discriminant;
297                }
298                Rvalue::ShallowInitBox(operand, _) => {
299                    self.add_operand(operand, dst);
300                    self.nodes[dst].ops[seq] = NodeOp::ShallowInitBox;
301                }
302                Rvalue::CopyForDeref(place) => {
303                    let src = self.parse_place(place);
304                    self.add_node_edge(src, dst, EdgeOp::Nop);
305                    self.nodes[dst].ops[seq] = NodeOp::CopyForDeref;
306                }
307                Rvalue::RawPtr(_, place) => {
308                    let src = self.parse_place(place);
309                    self.add_node_edge(src, dst, EdgeOp::Nop); // Mutability?
310                    self.nodes[dst].ops[seq] = NodeOp::RawPtr;
311                }
312                _ => todo!(),
313            };
314            self.nodes[dst].seq = seq + 1;
315        }
316    }
317
318    pub fn add_terminator_to_graph(&mut self, terminator: &Terminator) {
319        if let TerminatorKind::Call {
320            func,
321            args,
322            destination,
323            ..
324        } = &terminator.kind
325        {
326            let dst = destination.local;
327            let seq = self.nodes[dst].seq;
328            if seq == self.nodes[dst].ops.len() {
329                self.nodes[dst].ops.push(NodeOp::Nop);
330            }
331            match func {
332                Operand::Constant(boxed_cnst) => {
333                    if let Const::Val(_, ty) = boxed_cnst.const_ {
334                        if let TyKind::FnDef(def_id, _) = ty.kind() {
335                            for op in args.iter() {
336                                //rustc version related
337                                self.add_operand(&op.node, dst);
338                            }
339                            self.nodes[dst].ops[seq] = NodeOp::Call(*def_id);
340                        }
341                    }
342                }
343                Operand::Move(_) => {
344                    self.add_operand(func, dst); //the func is a place
345                    for op in args.iter() {
346                        //rustc version related
347                        self.add_operand(&op.node, dst);
348                    }
349                    self.nodes[dst].ops[seq] = NodeOp::CallOperand;
350                }
351                _ => {
352                    println!("{:?}", func);
353                    todo!();
354                }
355            }
356            self.nodes[dst].span = terminator.source_info.span;
357            self.nodes[dst].seq = seq + 1;
358        }
359    }
360
361    // Because a node(local) may have multiple ops, we need to decide whether to strictly collect equivalent locals or not
362    // For the former, all the ops should meet the equivalent condition.
363    // For the later, if only one op meets the condition, we still take it into consideration.
364    pub fn collect_equivalent_locals(&self, local: Local, strict: bool) -> HashSet<Local> {
365        let mut set = HashSet::new();
366        let mut root = local;
367        let reduce_func = if strict {
368            DFSStatus::and
369        } else {
370            DFSStatus::or
371        };
372        let mut find_root_operator = |graph: &Graph, idx: Local| -> DFSStatus {
373            let node = &graph.nodes[idx];
374            node.ops
375                .iter()
376                .map(|op| {
377                    match op {
378                        NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
379                            //Nop means an orphan node or a parameter
380                            root = idx;
381                            DFSStatus::Continue
382                        }
383                        _ => DFSStatus::Stop,
384                    }
385                })
386                .reduce(reduce_func)
387                .unwrap()
388        };
389        let mut find_equivalent_operator = |graph: &Graph, idx: Local| -> DFSStatus {
390            let node = &graph.nodes[idx];
391            if set.contains(&idx) {
392                return DFSStatus::Stop;
393            }
394            node.ops
395                .iter()
396                .map(|op| match op {
397                    NodeOp::Nop | NodeOp::Use | NodeOp::Ref => {
398                        set.insert(idx);
399                        DFSStatus::Continue
400                    }
401                    _ => DFSStatus::Stop,
402                })
403                .reduce(reduce_func)
404                .unwrap()
405        };
406        // Algorithm: dfs along upside to find the root node, and then dfs along downside to collect equivalent locals
407        let mut seen = HashSet::new();
408        self.dfs(
409            local,
410            Direction::Upside,
411            &mut find_root_operator,
412            &mut Self::equivalent_edge_validator,
413            true,
414            &mut seen,
415        );
416        seen.clear();
417        self.dfs(
418            root,
419            Direction::Downside,
420            &mut find_equivalent_operator,
421            &mut Self::equivalent_edge_validator,
422            true,
423            &mut seen,
424        );
425        set
426    }
427
428    pub fn collect_ancestor_locals(&self, local: Local, self_included: bool) -> HashSet<Local> {
429        let mut ret = HashSet::new();
430        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
431            ret.insert(idx);
432            DFSStatus::Continue
433        };
434        let mut seen = HashSet::new();
435        self.dfs(
436            local,
437            Direction::Upside,
438            &mut node_operator,
439            &mut Graph::always_true_edge_validator,
440            true,
441            &mut seen,
442        );
443        if !self_included {
444            ret.remove(&local);
445        }
446        ret
447    }
448
449    pub fn is_connected(&self, idx_1: Local, idx_2: Local) -> bool {
450        let target = idx_2;
451        let find = Cell::new(false);
452        let mut node_operator = |_: &Graph, idx: Local| -> DFSStatus {
453            find.set(idx == target);
454            if find.get() {
455                DFSStatus::Stop
456            } else {
457                // if not found, move on
458                DFSStatus::Continue
459            }
460        };
461        let mut seen = HashSet::new();
462        self.dfs(
463            idx_1,
464            Direction::Downside,
465            &mut node_operator,
466            &mut Self::always_true_edge_validator,
467            false,
468            &mut seen,
469        );
470        seen.clear();
471        if !find.get() {
472            self.dfs(
473                idx_1,
474                Direction::Upside,
475                &mut node_operator,
476                &mut Self::always_true_edge_validator,
477                false,
478                &mut seen,
479            );
480        }
481        find.get()
482    }
483
484    // Whether there exists dataflow between each parameter and the return value
485    pub fn param_return_deps(&self) -> IndexVec<Local, bool> {
486        let _0 = Local::from_usize(0);
487        let deps = (0..self.argc + 1) //the length is argc + 1, because _0 depends on _0 itself.
488            .map(|i| {
489                let _i = Local::from_usize(i);
490                self.is_connected(_i, _0)
491            })
492            .collect();
493        deps
494    }
495
496    // This function uses precedence traversal.
497    // The node operator and edge validator decide how far the traversal can reach.
498    // `traverse_all` decides if a branch finds the target successfully, whether the traversal will continue or not.
499    // For example, if you need to instantly stop the traversal once finding a certain node, then set `traverse_all` to false.
500    // If you want to traverse all the reachable nodes which are decided by the operator and validator, then set `traverse_all` to true.
501    pub fn dfs<F, G>(
502        &self,
503        now: Local,
504        direction: Direction,
505        node_operator: &mut F,
506        edge_validator: &mut G,
507        traverse_all: bool,
508        seen: &mut HashSet<Local>,
509    ) -> DFSStatus
510    where
511        F: FnMut(&Graph, Local) -> DFSStatus,
512        G: FnMut(&Graph, EdgeIdx) -> DFSStatus,
513    {
514        if seen.contains(&now) {
515            return DFSStatus::Stop;
516        }
517        seen.insert(now);
518        macro_rules! traverse {
519            ($edges: ident, $field: ident) => {
520                for edge_idx in self.nodes[now].$edges.iter() {
521                    let edge = &self.edges[*edge_idx];
522                    if matches!(edge_validator(self, *edge_idx), DFSStatus::Continue) {
523                        let result = self.dfs(
524                            edge.$field,
525                            direction,
526                            node_operator,
527                            edge_validator,
528                            traverse_all,
529                            seen,
530                        );
531                        if matches!(result, DFSStatus::Stop) && !traverse_all {
532                            return DFSStatus::Stop;
533                        }
534                    }
535                }
536            };
537        }
538
539        if matches!(node_operator(self, now), DFSStatus::Continue) {
540            match direction {
541                Direction::Upside => {
542                    traverse!(in_edges, src);
543                }
544                Direction::Downside => {
545                    traverse!(out_edges, dst);
546                }
547                Direction::Both => {
548                    traverse!(in_edges, src);
549                    traverse!(out_edges, dst);
550                }
551            };
552            DFSStatus::Continue
553        } else {
554            DFSStatus::Stop
555        }
556    }
557
558    pub fn get_upside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
559        if let Some(edge_idx) = self.nodes[node_idx].in_edges.get(order) {
560            Some(self.edges[*edge_idx].src)
561        } else {
562            None
563        }
564    }
565
566    pub fn get_downside_idx(&self, node_idx: Local, order: usize) -> Option<Local> {
567        if let Some(edge_idx) = self.nodes[node_idx].out_edges.get(order) {
568            Some(self.edges[*edge_idx].dst)
569        } else {
570            None
571        }
572    }
573
574    // if strict is set to false, we return the first node that wraps the target span and at least one end overlaps
575    pub fn query_node_by_span(&self, span: Span, strict: bool) -> Option<(Local, &GraphNode)> {
576        for (node_idx, node) in self.nodes.iter_enumerated() {
577            if strict {
578                if node.span == span {
579                    return Some((node_idx, node));
580                }
581            } else {
582                if !relative_pos_range(node.span, span).eq(0..0)
583                    && (node.span.lo() == span.lo() || node.span.hi() == span.hi())
584                {
585                    return Some((node_idx, node));
586                }
587            }
588        }
589        None
590    }
591
592    pub fn is_marker(&self, idx: Local) -> bool {
593        idx >= Local::from_usize(self.n_locals)
594    }
595}
596
597impl Graph {
598    pub fn equivalent_edge_validator(graph: &Graph, idx: EdgeIdx) -> DFSStatus {
599        match graph.edges[idx].op {
600            EdgeOp::Copy | EdgeOp::Move | EdgeOp::Mut | EdgeOp::Immut | EdgeOp::Deref => {
601                DFSStatus::Continue
602            }
603            EdgeOp::Nop
604            | EdgeOp::Const
605            | EdgeOp::Downcast(_)
606            | EdgeOp::Field(_)
607            | EdgeOp::Index
608            | EdgeOp::ConstIndex
609            | EdgeOp::SubSlice
610            | EdgeOp::SubType => DFSStatus::Stop,
611        }
612    }
613
614    pub fn always_true_edge_validator(_: &Graph, _: EdgeIdx) -> DFSStatus {
615        DFSStatus::Continue
616    }
617}
618
/// Which edges `Graph::dfs` follows from each visited node.
#[derive(Clone, Copy)]
pub enum Direction {
    /// Follow in-edges, towards the sources of the dataflow.
    Upside,
    /// Follow out-edges, towards where the value flows.
    Downside,
    /// Follow in-edges first, then out-edges.
    Both,
}
625
/// Verdict returned by node operators, edge validators and `Graph::dfs`.
///
/// `Continue` plays the role of `true` and `Stop` the role of `false` in [`DFSStatus::and`]
/// and [`DFSStatus::or`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DFSStatus {
    /// Keep exploring (logical true).
    Continue,
    /// Prune / stop here (logical false).
    Stop,
}

impl DFSStatus {
    /// Logical AND: `Continue` only when both inputs are `Continue`.
    pub fn and(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
        match (s1, s2) {
            (DFSStatus::Continue, DFSStatus::Continue) => DFSStatus::Continue,
            _ => DFSStatus::Stop,
        }
    }

    /// Logical OR: `Stop` only when both inputs are `Stop`.
    pub fn or(s1: DFSStatus, s2: DFSStatus) -> DFSStatus {
        match (s1, s2) {
            (DFSStatus::Stop, DFSStatus::Stop) => DFSStatus::Stop,
            _ => DFSStatus::Continue,
        }
    }
}
648
649#[derive(Clone, Copy, Debug)]
650pub enum AggKind {
651    Array,
652    Tuple,
653    Adt(DefId),
654    Closure(DefId),
655    Coroutine(DefId),
656    RawPtr,
657}