rapx/analysis/core/callgraph/
visitor.rs

1use super::default::CallGraphInfo;
2use regex::Regex;
3use rustc_hir::def_id::DefId;
4use rustc_middle::mir;
5use rustc_middle::ty::{FnDef, Instance, InstanceKind, TyCtxt, TypingEnv};
6
7pub struct CallGraphVisitor<'b, 'tcx> {
8    tcx: TyCtxt<'tcx>,
9    def_id: DefId,
10    body: &'tcx mir::Body<'tcx>,
11    call_graph_info: &'b mut CallGraphInfo<'tcx>,
12}
13
14impl<'b, 'tcx> CallGraphVisitor<'b, 'tcx> {
15    pub fn new(
16        tcx: TyCtxt<'tcx>,
17        def_id: DefId,
18        body: &'tcx mir::Body<'tcx>,
19        call_graph_info: &'b mut CallGraphInfo<'tcx>,
20    ) -> Self {
21        Self {
22            tcx: tcx,
23            def_id: def_id,
24            body: body,
25            call_graph_info: call_graph_info,
26        }
27    }
28
29    pub fn add_in_call_graph(
30        &mut self,
31        caller_def_path: &String,
32        callee_def_id: DefId,
33        callee_def_path: &String,
34        terminator: &'tcx mir::Terminator<'tcx>,
35    ) {
36        if let Some(caller_id) = self.call_graph_info.get_node_by_path(caller_def_path) {
37            if let Some(callee_id) = self.call_graph_info.get_node_by_path(callee_def_path) {
38                self.call_graph_info
39                    .add_funciton_call_edge(caller_id, callee_id, terminator);
40            } else {
41                self.call_graph_info
42                    .add_node(callee_def_id, callee_def_path);
43                if let Some(callee_id) = self.call_graph_info.get_node_by_path(callee_def_path) {
44                    self.call_graph_info
45                        .add_funciton_call_edge(caller_id, callee_id, terminator);
46                }
47            }
48        }
49    }
50
51    pub fn visit(&mut self) {
52        let caller_path_str = self.tcx.def_path_str(self.def_id);
53        self.call_graph_info.add_node(self.def_id, &caller_path_str);
54        for (_, data) in self.body.basic_blocks.iter().enumerate() {
55            let terminator = data.terminator();
56            self.visit_terminator(&terminator);
57        }
58    }
59
60    fn add_to_call_graph(
61        &mut self,
62        callee_def_id: DefId,
63        is_virtual: Option<bool>,
64        terminator: &'tcx mir::Terminator<'tcx>,
65    ) {
66        let caller_def_path = self.tcx.def_path_str(self.def_id);
67        let mut callee_def_path = self.tcx.def_path_str(callee_def_id);
68        if let Some(judge) = is_virtual {
69            if judge {
70                let re = Regex::new(r"(?<dyn>\w+)::(?<func>\w+)").unwrap();
71                let Some(caps) = re.captures(&callee_def_path) else {
72                    return;
73                };
74                callee_def_path = format!("(dyn trait) <* as {}>::{}", &caps["dyn"], &caps["func"]);
75            }
76        }
77
78        // let callee_location = self.tcx.def_span(callee_def_id);
79        if callee_def_id == self.def_id {
80            // Recursion
81            println!("Warning! Find a recursion function which may cause stackoverflow!")
82        }
83        self.add_in_call_graph(
84            &caller_def_path,
85            callee_def_id,
86            &callee_def_path,
87            terminator,
88        );
89    }
90
91    fn visit_terminator(&mut self, terminator: &'tcx mir::Terminator<'tcx>) {
92        if let mir::TerminatorKind::Call { func, .. } = &terminator.kind {
93            if let mir::Operand::Constant(constant) = func {
94                if let FnDef(callee_def_id, callee_substs) = constant.const_.ty().kind() {
95                    let ty_env = TypingEnv::post_analysis(self.tcx, self.def_id);
96                    if let Ok(Some(instance)) =
97                        Instance::try_resolve(self.tcx, ty_env, *callee_def_id, callee_substs)
98                    {
99                        let mut is_virtual = false;
100                        // Try to analysis the specific type of callee.
101                        let instance_def_id = match instance.def {
102                            InstanceKind::Item(def_id) => Some(def_id),
103                            InstanceKind::Intrinsic(def_id) => Some(def_id),
104                            InstanceKind::VTableShim(def_id) => Some(def_id),
105                            InstanceKind::ReifyShim(def_id, _) => Some(def_id),
106                            InstanceKind::FnPtrShim(def_id, _) => Some(def_id),
107                            InstanceKind::Virtual(def_id, _) => {
108                                is_virtual = true;
109                                Some(def_id)
110                            }
111                            InstanceKind::ClosureOnceShim { call_once, .. } => Some(call_once),
112                            InstanceKind::ConstructCoroutineInClosureShim {
113                                coroutine_closure_def_id,
114                                ..
115                            } => Some(coroutine_closure_def_id),
116                            InstanceKind::ThreadLocalShim(def_id) => Some(def_id),
117                            InstanceKind::DropGlue(def_id, _) => Some(def_id),
118                            InstanceKind::FnPtrAddrShim(def_id, _) => Some(def_id),
119                            InstanceKind::AsyncDropGlueCtorShim(def_id, _) => Some(def_id),
120                            InstanceKind::CloneShim(def_id, _) => {
121                                if !self.tcx.is_closure_like(def_id) {
122                                    // Not a closure
123                                    Some(def_id)
124                                } else {
125                                    None
126                                }
127                            }
128                            _ => todo!(),
129                        };
130                        if let Some(instance_def_id) = instance_def_id {
131                            self.add_to_call_graph(instance_def_id, Some(is_virtual), terminator);
132                        }
133                    } else {
134                        // Although failing to get specific type, callee is still useful.
135                        self.add_to_call_graph(*callee_def_id, None, terminator);
136                    }
137                }
138            }
139        }
140    }
141}