rapx/analysis/core/call_graph/
call_graph_visitor.rs

1use super::call_graph_helper::CallGraphInfo;
2use regex::Regex;
3use rustc_hir::def_id::DefId;
4use rustc_middle::mir;
5use rustc_middle::ty::{FnDef, Instance, InstanceKind, TyCtxt, TypingEnv};
6
7pub struct CallGraphVisitor<'b, 'tcx> {
8    tcx: TyCtxt<'tcx>,
9    def_id: DefId,
10    body: &'tcx mir::Body<'tcx>,
11    call_graph_info: &'b mut CallGraphInfo,
12}
13
14impl<'b, 'tcx> CallGraphVisitor<'b, 'tcx> {
15    pub fn new(
16        tcx: TyCtxt<'tcx>,
17        def_id: DefId,
18        body: &'tcx mir::Body<'tcx>,
19        call_graph_info: &'b mut CallGraphInfo,
20    ) -> Self {
21        Self {
22            tcx: tcx,
23            def_id: def_id,
24            body: body,
25            call_graph_info: call_graph_info,
26        }
27    }
28
29    pub fn add_in_call_graph(
30        &mut self,
31        caller_def_path: &String,
32        callee_def_id: DefId,
33        callee_def_path: &String,
34    ) {
35        if let Some(caller_id) = self.call_graph_info.get_noed_by_path(caller_def_path) {
36            if let Some(callee_id) = self.call_graph_info.get_noed_by_path(callee_def_path) {
37                self.call_graph_info
38                    .add_funciton_call_edge(caller_id, callee_id);
39            } else {
40                self.call_graph_info
41                    .add_node(callee_def_id, callee_def_path);
42                if let Some(callee_id) = self.call_graph_info.get_noed_by_path(callee_def_path) {
43                    self.call_graph_info
44                        .add_funciton_call_edge(caller_id, callee_id);
45                }
46            }
47        }
48    }
49
50    pub fn visit(&mut self) {
51        let caller_path_str = self.tcx.def_path_str(self.def_id);
52        self.call_graph_info.add_node(self.def_id, &caller_path_str);
53        for (_, data) in self.body.basic_blocks.iter().enumerate() {
54            let terminator = data.terminator();
55            self.visit_terminator(&terminator);
56        }
57    }
58
59    fn add_to_call_graph(&mut self, callee_def_id: DefId, is_virtual: Option<bool>) {
60        let caller_def_path = self.tcx.def_path_str(self.def_id);
61        let mut callee_def_path = self.tcx.def_path_str(callee_def_id);
62        if let Some(judge) = is_virtual {
63            if judge {
64                let re = Regex::new(r"(?<dyn>\w+)::(?<func>\w+)").unwrap();
65                let Some(caps) = re.captures(&callee_def_path) else {
66                    return;
67                };
68                callee_def_path = format!("(dyn trait) <* as {}>::{}", &caps["dyn"], &caps["func"]);
69            }
70        }
71
72        // let callee_location = self.tcx.def_span(callee_def_id);
73        if callee_def_id == self.def_id {
74            // Recursion
75            println!("Warning! Find a recursion function which may cause stackoverflow!")
76        }
77        self.add_in_call_graph(&caller_def_path, callee_def_id, &callee_def_path);
78    }
79
80    fn visit_terminator(&mut self, terminator: &mir::Terminator<'tcx>) {
81        if let mir::TerminatorKind::Call { func, .. } = &terminator.kind {
82            if let mir::Operand::Constant(constant) = func {
83                if let FnDef(callee_def_id, callee_substs) = constant.const_.ty().kind() {
84                    let ty_env = TypingEnv::post_analysis(self.tcx, self.def_id);
85                    if let Ok(Some(instance)) =
86                        Instance::try_resolve(self.tcx, ty_env, *callee_def_id, callee_substs)
87                    {
88                        let mut is_virtual = false;
89                        // Try to analysis the specific type of callee.
90                        let instance_def_id = match instance.def {
91                            InstanceKind::Item(def_id) => Some(def_id),
92                            InstanceKind::Intrinsic(def_id) => Some(def_id),
93                            InstanceKind::VTableShim(def_id) => Some(def_id),
94                            InstanceKind::ReifyShim(def_id, _) => Some(def_id),
95                            InstanceKind::FnPtrShim(def_id, _) => Some(def_id),
96                            InstanceKind::Virtual(def_id, _) => {
97                                is_virtual = true;
98                                Some(def_id)
99                            }
100                            InstanceKind::ClosureOnceShim { call_once, .. } => Some(call_once),
101                            InstanceKind::ConstructCoroutineInClosureShim {
102                                coroutine_closure_def_id,
103                                ..
104                            } => Some(coroutine_closure_def_id),
105                            InstanceKind::ThreadLocalShim(def_id) => Some(def_id),
106                            InstanceKind::DropGlue(def_id, _) => Some(def_id),
107                            InstanceKind::FnPtrAddrShim(def_id, _) => Some(def_id),
108                            InstanceKind::AsyncDropGlueCtorShim(def_id, _) => Some(def_id),
109                            InstanceKind::CloneShim(def_id, _) => {
110                                if !self.tcx.is_closure_like(def_id) {
111                                    // Not a closure
112                                    Some(def_id)
113                                } else {
114                                    None
115                                }
116                            }
117                            _ => todo!(),
118                        };
119                        if let Some(instance_def_id) = instance_def_id {
120                            self.add_to_call_graph(instance_def_id, Some(is_virtual));
121                        }
122                    } else {
123                        // Although failing to get specific type, callee is still useful.
124                        self.add_to_call_graph(*callee_def_id, None);
125                    }
126                }
127            }
128        }
129    }
130}