rapx/analysis/core/callgraph/
default.rs1use rustc_hir::{def::DefKind, def_id::DefId};
2use rustc_middle::{
3 mir::{self, Body},
4 ty::TyCtxt,
5};
6use std::collections::HashMap;
7use std::collections::HashSet;
8
9use super::visitor::CallGraphVisitor;
10use crate::{
11 Analysis,
12 analysis::core::callgraph::{CallGraphAnalysis, FnCallMap},
13};
14
15pub struct CallGraphAnalyzer<'tcx> {
16 pub tcx: TyCtxt<'tcx>,
17 pub graph: CallGraph<'tcx>,
18}
19
20impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
21 fn name(&self) -> &'static str {
22 "Default call graph analysis algorithm."
23 }
24
25 fn run(&mut self) {
26 self.start();
27 }
28
29 fn reset(&mut self) {
30 todo!();
31 }
32}
33
34impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
35 fn get_fn_calls(&self) -> FnCallMap {
36 let fn_calls: HashMap<DefId, Vec<DefId>> = self
37 .graph
38 .fn_calls
39 .clone()
40 .into_iter()
41 .map(|(caller, callees)| {
42 let callee_ids = callees.into_iter().map(|(did, _)| did).collect::<Vec<_>>();
43 (caller, callee_ids)
44 })
45 .collect();
46 fn_calls
47 }
48}
49
50impl<'tcx> CallGraphAnalyzer<'tcx> {
51 pub fn new(tcx: TyCtxt<'tcx>) -> Self {
52 Self {
53 tcx: tcx,
54 graph: CallGraph::new(tcx),
55 }
56 }
57
58 pub fn start(&mut self) {
59 for local_def_id in self.tcx.mir_keys(()) {
60 let def_id = local_def_id.to_def_id();
61 if self.tcx.is_mir_available(def_id) {
62 let def_kind = self.tcx.def_kind(def_id);
63
64 let body: &Body<'_> = match def_kind {
65 DefKind::Fn | DefKind::AssocFn | DefKind::Closure => {
66 &self.tcx.optimized_mir(def_id)
67 }
68 DefKind::Const
69 | DefKind::Static { .. }
70 | DefKind::AssocConst
71 | DefKind::InlineConst
72 | DefKind::AnonConst => {
73 &self.tcx.mir_for_ctfe(def_id)
75 }
76 _ => {
78 rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
79 continue;
80 }
81 };
82
83 let mut call_graph_visitor =
84 CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
85 call_graph_visitor.visit();
86 }
87 }
88 }
89}
90
91pub type CallMap<'tcx> = HashMap<DefId, Vec<(DefId, Option<&'tcx mir::Terminator<'tcx>>)>>;
92
93pub struct CallGraph<'tcx> {
94 pub tcx: TyCtxt<'tcx>,
95 pub functions: HashSet<DefId>, pub fn_calls: CallMap<'tcx>, }
98
99impl<'tcx> CallGraph<'tcx> {
101 pub fn new(tcx: TyCtxt<'tcx>) -> Self {
102 Self {
103 tcx,
104 functions: HashSet::new(),
105 fn_calls: HashMap::new(),
106 }
107 }
108
109 pub fn register_fn(&mut self, def_id: DefId) -> bool {
111 if let Some(_) = self.functions.iter().find(|func_id| **func_id == def_id) {
112 false
113 } else {
114 self.functions.insert(def_id);
115 true
116 }
117 }
118
119 pub fn add_funciton_call(
121 &mut self,
122 caller_id: DefId,
123 callee_id: DefId,
124 terminator_stmt: Option<&'tcx mir::Terminator<'tcx>>,
125 ) {
126 let entry = self.fn_calls.entry(caller_id).or_insert_with(Vec::new);
127 entry.push((callee_id, terminator_stmt));
128 }
129}
130
131impl<'tcx> CallGraph<'tcx> {
133 pub fn get_reverse_post_order(&self) -> Vec<DefId> {
134 let mut result = self.get_post_order();
135 result.reverse();
136 result
137 }
138
139 pub fn get_post_order(&self) -> Vec<DefId> {
140 let mut visited = HashSet::new();
141 let mut post_order_ids = Vec::new(); for &func_def_id in self.functions.iter() {
145 if !visited.contains(&func_def_id) {
146 self.dfs_post_order(func_def_id, &mut visited, &mut post_order_ids);
147 }
148 }
149
150 post_order_ids
151 }
152
153 fn dfs_post_order(
155 &self,
156 func_def_id: DefId,
157 visited: &mut HashSet<DefId>,
158 post_order_ids: &mut Vec<DefId>,
159 ) {
160 visited.insert(func_def_id);
162
163 if let Some(callees) = self.fn_calls.get(&func_def_id) {
165 for (callee_id, _terminator) in callees {
166 if !visited.contains(callee_id) {
167 self.dfs_post_order(*callee_id, visited, post_order_ids);
168 }
169 }
170 }
171
172 post_order_ids.push(func_def_id);
174 }
175
176 pub fn get_callers_map(&self) -> CallMap<'tcx> {
178 let mut callers_map: CallMap<'tcx> = HashMap::new();
179
180 for (&caller_id, calls_vec) in &self.fn_calls {
181 for (callee_id, terminator) in calls_vec {
182 callers_map
183 .entry(*callee_id)
184 .or_insert_with(Vec::new)
185 .push((caller_id, *terminator));
186 }
187 }
188 callers_map
189 }
190
191 pub fn get_callees(&self, caller_def_id: DefId) -> Vec<DefId> {
193 if let Some(callees) = self.fn_calls.get(&caller_def_id) {
194 callees
195 .clone()
196 .into_iter()
197 .map(|(did, _)| did)
198 .collect::<Vec<_>>()
199 } else {
200 vec![]
201 }
202 }
203
204 pub fn get_callees_recursive(&self, caller_def_id: DefId) -> Vec<DefId> {
206 let mut visited = HashSet::new();
207 let mut result = Vec::new();
208 self.dfs_post_order(caller_def_id, &mut visited, &mut result);
209 result
210 }
211
212 pub fn get_callers(&self, callee_def_id: DefId) -> Vec<DefId> {
214 let callers_map = self.get_callers_map();
215 if let Some(callers) = callers_map.get(&callee_def_id) {
216 callers
217 .clone()
218 .into_iter()
219 .map(|(did, _)| did)
220 .collect::<Vec<_>>()
221 } else {
222 vec![]
223 }
224 }
225}