rapx/analysis/core/callgraph/
visitor.rs1use super::default::CallGraphInfo;
2use regex::Regex;
3use rustc_hir::def_id::DefId;
4use rustc_middle::mir;
5use rustc_middle::ty::{FnDef, Instance, InstanceKind, TyCtxt, TypingEnv};
6use std::collections::HashSet;
7
8pub struct CallGraphVisitor<'b, 'tcx> {
9 tcx: TyCtxt<'tcx>,
10 def_id: DefId,
11 body: &'tcx mir::Body<'tcx>,
12 call_graph_info: &'b mut CallGraphInfo<'tcx>,
13}
14
15impl<'b, 'tcx> CallGraphVisitor<'b, 'tcx> {
16 pub fn new(
17 tcx: TyCtxt<'tcx>,
18 def_id: DefId,
19 body: &'tcx mir::Body<'tcx>,
20 call_graph_info: &'b mut CallGraphInfo<'tcx>,
21 ) -> Self {
22 Self {
23 tcx: tcx,
24 def_id: def_id,
25 body: body,
26 call_graph_info: call_graph_info,
27 }
28 }
29
30 pub fn add_in_call_graph(
31 &mut self,
32 caller_def_path: &String,
33 callee_def_id: DefId,
34 callee_def_path: &String,
35 terminator: &'tcx mir::Terminator<'tcx>,
36 ) {
37 if let Some(caller_id) = self.call_graph_info.get_node_by_path(caller_def_path) {
38 if let Some(callee_id) = self.call_graph_info.get_node_by_path(callee_def_path) {
39 self.call_graph_info
40 .add_funciton_call_edge(caller_id, callee_id, Some(terminator));
41 } else {
42 self.call_graph_info
43 .add_node(callee_def_id, callee_def_path);
44 if let Some(callee_id) = self.call_graph_info.get_node_by_path(callee_def_path) {
45 self.call_graph_info.add_funciton_call_edge(
46 caller_id,
47 callee_id,
48 Some(terminator),
49 );
50 }
51 }
52 }
53 }
54
55 pub fn visit(&mut self) {
56 let caller_path_str = self.tcx.def_path_str(self.def_id);
57 self.call_graph_info.add_node(self.def_id, &caller_path_str);
58 for (_, data) in self.body.basic_blocks.iter().enumerate() {
59 let terminator = data.terminator();
60 self.visit_terminator(&terminator);
61 }
62 }
63
64 fn add_to_call_graph(
65 &mut self,
66 callee_def_id: DefId,
67 is_virtual: Option<bool>,
68 terminator: &'tcx mir::Terminator<'tcx>,
69 ) {
70 let caller_def_path = self.tcx.def_path_str(self.def_id);
71 let mut callee_def_path = self.tcx.def_path_str(callee_def_id);
72
73 if let Some(true) = is_virtual {
74 let re = Regex::new(r"(?<dyn>\w+)::(?<func>\w+)").unwrap();
76 if let Some(caps) = re.captures(&callee_def_path) {
77 callee_def_path = format!("(dyn trait) <* as {}>::{}", &caps["dyn"], &caps["func"]);
78 };
79 self.handle_virtual_call(
80 &caller_def_path,
81 callee_def_id,
82 &callee_def_path,
83 terminator,
84 );
85 } else {
86 if callee_def_id == self.def_id {
88 println!("Warning! Find a recursion function which may cause stackoverflow!")
90 }
91 self.add_in_call_graph(
92 &caller_def_path,
93 callee_def_id,
94 &callee_def_path,
95 terminator,
96 );
97 }
98 }
99
100 fn visit_terminator(&mut self, terminator: &'tcx mir::Terminator<'tcx>) {
101 if let mir::TerminatorKind::Call { func, .. } = &terminator.kind {
102 if let mir::Operand::Constant(constant) = func {
103 if let FnDef(callee_def_id, callee_substs) = constant.const_.ty().kind() {
104 let ty_env = TypingEnv::post_analysis(self.tcx, self.def_id);
105 if let Ok(Some(instance)) =
106 Instance::try_resolve(self.tcx, ty_env, *callee_def_id, callee_substs)
107 {
108 let mut is_virtual = false;
109 let instance_def_id = match instance.def {
111 InstanceKind::Item(def_id) => Some(def_id),
112 InstanceKind::Intrinsic(def_id) => Some(def_id),
113 InstanceKind::VTableShim(def_id) => Some(def_id),
114 InstanceKind::ReifyShim(def_id, _) => Some(def_id),
115 InstanceKind::FnPtrShim(def_id, _) => Some(def_id),
116 InstanceKind::Virtual(def_id, _) => {
117 is_virtual = true;
118 Some(def_id)
119 }
120 InstanceKind::ClosureOnceShim { call_once, .. } => Some(call_once),
121 InstanceKind::ConstructCoroutineInClosureShim {
122 coroutine_closure_def_id,
123 ..
124 } => Some(coroutine_closure_def_id),
125 InstanceKind::ThreadLocalShim(def_id) => Some(def_id),
126 InstanceKind::DropGlue(def_id, _) => Some(def_id),
127 InstanceKind::FnPtrAddrShim(def_id, _) => Some(def_id),
128 InstanceKind::AsyncDropGlueCtorShim(def_id, _) => Some(def_id),
129 InstanceKind::CloneShim(def_id, _) => {
130 if !self.tcx.is_closure_like(def_id) {
131 Some(def_id)
133 } else {
134 None
135 }
136 }
137 _ => todo!(),
138 };
139 if let Some(instance_def_id) = instance_def_id {
140 self.add_to_call_graph(instance_def_id, Some(is_virtual), terminator);
141 }
142 } else {
143 self.add_to_call_graph(*callee_def_id, None, terminator);
145 }
146 }
147 }
148 }
149 }
150
151 fn handle_virtual_call(
152 &mut self,
153 caller_def_path: &String,
154 stub_def_id: DefId, stub_def_path: &String,
156 terminator: &'tcx mir::Terminator<'tcx>,
157 ) {
158 let mut visited = false;
160 let stub_id = if let Some(id) = self.call_graph_info.get_node_by_path(stub_def_path) {
161 visited = true;
163 id
164 } else {
165 self.call_graph_info.add_node(stub_def_id, stub_def_path)
166 };
167 let caller_id = self
168 .call_graph_info
169 .get_node_by_path(caller_def_path)
170 .unwrap(); self.call_graph_info
172 .add_funciton_call_edge(caller_id, stub_id, Some(terminator));
173
174 if visited {
176 return;
177 }
178
179 let mut candidates: HashSet<DefId> = HashSet::new();
181 if let Some(trait_def_id) = self.tcx.trait_of_assoc(stub_def_id) {
182 rap_debug!(
183 "[Callgraph] Virtual fn {:?} belongs to trait {:?}",
184 stub_def_id,
185 trait_def_id
186 );
187 for impl_id in self.tcx.all_impls(trait_def_id) {
188 let impl_map = self.tcx.impl_item_implementor_ids(impl_id);
189 if let Some(candidate_def_id) = impl_map.get(&stub_def_id) {
190 candidates.insert(*candidate_def_id);
191 }
192 }
193 }
194 rap_debug!(
195 "[Callgraph] Implementors of {:?}: {:?}",
196 stub_def_id,
197 candidates
198 );
199
200 for candidate_def_id in candidates {
202 let candidate_def_path = self.tcx.def_path_str(candidate_def_id);
203 let callee_id =
204 if let Some(id) = self.call_graph_info.get_node_by_path(&candidate_def_path) {
205 id
206 } else {
207 self.call_graph_info
208 .add_node(candidate_def_id, &candidate_def_path)
209 };
210 self.call_graph_info
211 .add_funciton_call_edge(stub_id, callee_id, None);
212 }
213 }
214}