rapx/analysis/core/callgraph/
visitor.rs1use super::default::CallGraphInfo;
2use regex::Regex;
3use rustc_hir::def_id::DefId;
4use rustc_middle::mir;
5use rustc_middle::ty::{FnDef, Instance, InstanceKind, TyCtxt, TypingEnv};
6
7pub struct CallGraphVisitor<'b, 'tcx> {
8 tcx: TyCtxt<'tcx>,
9 def_id: DefId,
10 body: &'tcx mir::Body<'tcx>,
11 call_graph_info: &'b mut CallGraphInfo<'tcx>,
12}
13
14impl<'b, 'tcx> CallGraphVisitor<'b, 'tcx> {
15 pub fn new(
16 tcx: TyCtxt<'tcx>,
17 def_id: DefId,
18 body: &'tcx mir::Body<'tcx>,
19 call_graph_info: &'b mut CallGraphInfo<'tcx>,
20 ) -> Self {
21 Self {
22 tcx: tcx,
23 def_id: def_id,
24 body: body,
25 call_graph_info: call_graph_info,
26 }
27 }
28
29 pub fn add_in_call_graph(
30 &mut self,
31 caller_def_path: &String,
32 callee_def_id: DefId,
33 callee_def_path: &String,
34 terminator: &'tcx mir::Terminator<'tcx>,
35 ) {
36 if let Some(caller_id) = self.call_graph_info.get_node_by_path(caller_def_path) {
37 if let Some(callee_id) = self.call_graph_info.get_node_by_path(callee_def_path) {
38 self.call_graph_info
39 .add_funciton_call_edge(caller_id, callee_id, terminator);
40 } else {
41 self.call_graph_info
42 .add_node(callee_def_id, callee_def_path);
43 if let Some(callee_id) = self.call_graph_info.get_node_by_path(callee_def_path) {
44 self.call_graph_info
45 .add_funciton_call_edge(caller_id, callee_id, terminator);
46 }
47 }
48 }
49 }
50
51 pub fn visit(&mut self) {
52 let caller_path_str = self.tcx.def_path_str(self.def_id);
53 self.call_graph_info.add_node(self.def_id, &caller_path_str);
54 for (_, data) in self.body.basic_blocks.iter().enumerate() {
55 let terminator = data.terminator();
56 self.visit_terminator(&terminator);
57 }
58 }
59
60 fn add_to_call_graph(
61 &mut self,
62 callee_def_id: DefId,
63 is_virtual: Option<bool>,
64 terminator: &'tcx mir::Terminator<'tcx>,
65 ) {
66 let caller_def_path = self.tcx.def_path_str(self.def_id);
67 let mut callee_def_path = self.tcx.def_path_str(callee_def_id);
68 if let Some(judge) = is_virtual {
69 if judge {
70 let re = Regex::new(r"(?<dyn>\w+)::(?<func>\w+)").unwrap();
71 let Some(caps) = re.captures(&callee_def_path) else {
72 return;
73 };
74 callee_def_path = format!("(dyn trait) <* as {}>::{}", &caps["dyn"], &caps["func"]);
75 }
76 }
77
78 if callee_def_id == self.def_id {
80 println!("Warning! Find a recursion function which may cause stackoverflow!")
82 }
83 self.add_in_call_graph(
84 &caller_def_path,
85 callee_def_id,
86 &callee_def_path,
87 terminator,
88 );
89 }
90
91 fn visit_terminator(&mut self, terminator: &'tcx mir::Terminator<'tcx>) {
92 if let mir::TerminatorKind::Call { func, .. } = &terminator.kind {
93 if let mir::Operand::Constant(constant) = func {
94 if let FnDef(callee_def_id, callee_substs) = constant.const_.ty().kind() {
95 let ty_env = TypingEnv::post_analysis(self.tcx, self.def_id);
96 if let Ok(Some(instance)) =
97 Instance::try_resolve(self.tcx, ty_env, *callee_def_id, callee_substs)
98 {
99 let mut is_virtual = false;
100 let instance_def_id = match instance.def {
102 InstanceKind::Item(def_id) => Some(def_id),
103 InstanceKind::Intrinsic(def_id) => Some(def_id),
104 InstanceKind::VTableShim(def_id) => Some(def_id),
105 InstanceKind::ReifyShim(def_id, _) => Some(def_id),
106 InstanceKind::FnPtrShim(def_id, _) => Some(def_id),
107 InstanceKind::Virtual(def_id, _) => {
108 is_virtual = true;
109 Some(def_id)
110 }
111 InstanceKind::ClosureOnceShim { call_once, .. } => Some(call_once),
112 InstanceKind::ConstructCoroutineInClosureShim {
113 coroutine_closure_def_id,
114 ..
115 } => Some(coroutine_closure_def_id),
116 InstanceKind::ThreadLocalShim(def_id) => Some(def_id),
117 InstanceKind::DropGlue(def_id, _) => Some(def_id),
118 InstanceKind::FnPtrAddrShim(def_id, _) => Some(def_id),
119 InstanceKind::AsyncDropGlueCtorShim(def_id, _) => Some(def_id),
120 InstanceKind::CloneShim(def_id, _) => {
121 if !self.tcx.is_closure_like(def_id) {
122 Some(def_id)
124 } else {
125 None
126 }
127 }
128 _ => todo!(),
129 };
130 if let Some(instance_def_id) = instance_def_id {
131 self.add_to_call_graph(instance_def_id, Some(is_virtual), terminator);
132 }
133 } else {
134 self.add_to_call_graph(*callee_def_id, None, terminator);
136 }
137 }
138 }
139 }
140 }
141}