// rapx/analysis/core/callgraph/visitor.rs

use std::collections::HashSet;
use std::sync::LazyLock;

use regex::Regex;
use rustc_hir::def_id::DefId;
use rustc_middle::mir;
use rustc_middle::ty::{FnDef, Instance, InstanceKind, TyCtxt, TypingEnv};

use super::default::CallGraphInfo;

8pub struct CallGraphVisitor<'b, 'tcx> {
9    tcx: TyCtxt<'tcx>,
10    def_id: DefId,
11    body: &'tcx mir::Body<'tcx>,
12    call_graph_info: &'b mut CallGraphInfo<'tcx>,
13}
14
15impl<'b, 'tcx> CallGraphVisitor<'b, 'tcx> {
16    pub fn new(
17        tcx: TyCtxt<'tcx>,
18        def_id: DefId,
19        body: &'tcx mir::Body<'tcx>,
20        call_graph_info: &'b mut CallGraphInfo<'tcx>,
21    ) -> Self {
22        Self {
23            tcx: tcx,
24            def_id: def_id,
25            body: body,
26            call_graph_info: call_graph_info,
27        }
28    }
29
30    pub fn add_in_call_graph(
31        &mut self,
32        caller_def_path: &String,
33        callee_def_id: DefId,
34        callee_def_path: &String,
35        terminator: &'tcx mir::Terminator<'tcx>,
36    ) {
37        if let Some(caller_id) = self.call_graph_info.get_node_by_path(caller_def_path) {
38            if let Some(callee_id) = self.call_graph_info.get_node_by_path(callee_def_path) {
39                self.call_graph_info
40                    .add_funciton_call_edge(caller_id, callee_id, Some(terminator));
41            } else {
42                self.call_graph_info
43                    .add_node(callee_def_id, callee_def_path);
44                if let Some(callee_id) = self.call_graph_info.get_node_by_path(callee_def_path) {
45                    self.call_graph_info.add_funciton_call_edge(
46                        caller_id,
47                        callee_id,
48                        Some(terminator),
49                    );
50                }
51            }
52        }
53    }
54
55    pub fn visit(&mut self) {
56        let caller_path_str = self.tcx.def_path_str(self.def_id);
57        self.call_graph_info.add_node(self.def_id, &caller_path_str);
58        for (_, data) in self.body.basic_blocks.iter().enumerate() {
59            let terminator = data.terminator();
60            self.visit_terminator(&terminator);
61        }
62    }
63
64    fn add_to_call_graph(
65        &mut self,
66        callee_def_id: DefId,
67        is_virtual: Option<bool>,
68        terminator: &'tcx mir::Terminator<'tcx>,
69    ) {
70        let caller_def_path = self.tcx.def_path_str(self.def_id);
71        let mut callee_def_path = self.tcx.def_path_str(callee_def_id);
72
73        if let Some(true) = is_virtual {
74            // Handle dynamic dispatch for trait objects
75            let re = Regex::new(r"(?<dyn>\w+)::(?<func>\w+)").unwrap();
76            if let Some(caps) = re.captures(&callee_def_path) {
77                callee_def_path = format!("(dyn trait) <* as {}>::{}", &caps["dyn"], &caps["func"]);
78            };
79            self.handle_virtual_call(
80                &caller_def_path,
81                callee_def_id,
82                &callee_def_path,
83                terminator,
84            );
85        } else {
86            // let callee_location = self.tcx.def_span(callee_def_id);
87            if callee_def_id == self.def_id {
88                // Recursion
89                println!("Warning! Find a recursion function which may cause stackoverflow!")
90            }
91            self.add_in_call_graph(
92                &caller_def_path,
93                callee_def_id,
94                &callee_def_path,
95                terminator,
96            );
97        }
98    }
99
100    fn visit_terminator(&mut self, terminator: &'tcx mir::Terminator<'tcx>) {
101        if let mir::TerminatorKind::Call { func, .. } = &terminator.kind {
102            if let mir::Operand::Constant(constant) = func {
103                if let FnDef(callee_def_id, callee_substs) = constant.const_.ty().kind() {
104                    let ty_env = TypingEnv::post_analysis(self.tcx, self.def_id);
105                    if let Ok(Some(instance)) =
106                        Instance::try_resolve(self.tcx, ty_env, *callee_def_id, callee_substs)
107                    {
108                        let mut is_virtual = false;
109                        // Try to analysis the specific type of callee.
110                        let instance_def_id = match instance.def {
111                            InstanceKind::Item(def_id) => Some(def_id),
112                            InstanceKind::Intrinsic(def_id) => Some(def_id),
113                            InstanceKind::VTableShim(def_id) => Some(def_id),
114                            InstanceKind::ReifyShim(def_id, _) => Some(def_id),
115                            InstanceKind::FnPtrShim(def_id, _) => Some(def_id),
116                            InstanceKind::Virtual(def_id, _) => {
117                                is_virtual = true;
118                                Some(def_id)
119                            }
120                            InstanceKind::ClosureOnceShim { call_once, .. } => Some(call_once),
121                            InstanceKind::ConstructCoroutineInClosureShim {
122                                coroutine_closure_def_id,
123                                ..
124                            } => Some(coroutine_closure_def_id),
125                            InstanceKind::ThreadLocalShim(def_id) => Some(def_id),
126                            InstanceKind::DropGlue(def_id, _) => Some(def_id),
127                            InstanceKind::FnPtrAddrShim(def_id, _) => Some(def_id),
128                            InstanceKind::AsyncDropGlueCtorShim(def_id, _) => Some(def_id),
129                            InstanceKind::CloneShim(def_id, _) => {
130                                if !self.tcx.is_closure_like(def_id) {
131                                    // Not a closure
132                                    Some(def_id)
133                                } else {
134                                    None
135                                }
136                            }
137                            _ => todo!(),
138                        };
139                        if let Some(instance_def_id) = instance_def_id {
140                            self.add_to_call_graph(instance_def_id, Some(is_virtual), terminator);
141                        }
142                    } else {
143                        // Although failing to get specific type, callee is still useful.
144                        self.add_to_call_graph(*callee_def_id, None, terminator);
145                    }
146                }
147            }
148        }
149    }
150
151    fn handle_virtual_call(
152        &mut self,
153        caller_def_path: &String,
154        stub_def_id: DefId, // Callee is the dynamic call stub, i.e. the fn definition in trait
155        stub_def_path: &String,
156        terminator: &'tcx mir::Terminator<'tcx>,
157    ) {
158        // Step 1: Add an edge from caller to the virtual function (stub);
159        let mut visited = false;
160        let stub_id = if let Some(id) = self.call_graph_info.get_node_by_path(stub_def_path) {
161            // Node exists, suggesting we have already analyzed this virtual function
162            visited = true;
163            id
164        } else {
165            self.call_graph_info.add_node(stub_def_id, stub_def_path)
166        };
167        let caller_id = self
168            .call_graph_info
169            .get_node_by_path(caller_def_path)
170            .unwrap(); // This must be Some since the caller must have been added to graph
171        self.call_graph_info
172            .add_funciton_call_edge(caller_id, stub_id, Some(terminator));
173
174        // If this function has already been analyzed, return;
175        if visited {
176            return;
177        }
178
179        // Step 2: Find all impls of the virtual function;
180        let mut candidates: HashSet<DefId> = HashSet::new();
181        if let Some(trait_def_id) = self.tcx.trait_of_assoc(stub_def_id) {
182            rap_debug!(
183                "[Callgraph] Virtual fn {:?} belongs to trait {:?}",
184                stub_def_id,
185                trait_def_id
186            );
187            for impl_id in self.tcx.all_impls(trait_def_id) {
188                let impl_map = self.tcx.impl_item_implementor_ids(impl_id);
189                if let Some(candidate_def_id) = impl_map.get(&stub_def_id) {
190                    candidates.insert(*candidate_def_id);
191                }
192            }
193        }
194        rap_debug!(
195            "[Callgraph] Implementors of {:?}: {:?}",
196            stub_def_id,
197            candidates
198        );
199
200        // Step 3: For each implementor, add an edge from the stub to it.
201        for candidate_def_id in candidates {
202            let candidate_def_path = self.tcx.def_path_str(candidate_def_id);
203            let callee_id =
204                if let Some(id) = self.call_graph_info.get_node_by_path(&candidate_def_path) {
205                    id
206                } else {
207                    self.call_graph_info
208                        .add_node(candidate_def_id, &candidate_def_path)
209                };
210            self.call_graph_info
211                .add_funciton_call_edge(stub_id, callee_id, None);
212        }
213    }
214}