rapx/analysis/core/callgraph/
default.rs1use rustc_hir::{def::DefKind, def_id::DefId};
2use rustc_middle::{
3 mir::{self, Body},
4 ty::TyCtxt,
5};
6use std::collections::HashSet;
7use std::{collections::HashMap, hash::Hash};
8
9use super::visitor::CallGraphVisitor;
10use crate::{
11 Analysis,
12 analysis::core::callgraph::{CallGraph, CallGraphAnalysis},
13 rap_debug, rap_info,
14};
15
16pub struct CallGraphAnalyzer<'tcx> {
17 pub tcx: TyCtxt<'tcx>,
18 pub graph: CallGraphInfo<'tcx>,
19}
20
21impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
22 fn name(&self) -> &'static str {
23 "Default call graph analysis algorithm."
24 }
25
26 fn run(&mut self) {
27 self.start();
28 }
29
30 fn reset(&mut self) {
31 todo!();
32 }
33}
34
35impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
36 fn get_callgraph(&mut self) -> CallGraph {
37 let fn_calls: HashMap<DefId, Vec<DefId>> = self
38 .graph
39 .fn_calls
40 .clone()
41 .into_iter()
42 .map(|(caller, callees)| {
43 let caller_id = self
44 .graph
45 .functions
46 .get(&caller)
47 .expect("Key must exist in functions map")
48 .def_id;
49
50 let callees_id = callees
51 .into_iter()
52 .map(|(callee, _)| {
53 self.graph
54 .functions
55 .get(&callee)
56 .expect("Value must exist in functions map")
57 .def_id
58 })
59 .collect::<Vec<_>>();
60 (caller_id, callees_id)
61 })
62 .collect();
63 CallGraph { fn_calls }
64 }
65}
66
67impl<'tcx> CallGraphAnalyzer<'tcx> {
68 pub fn new(tcx: TyCtxt<'tcx>) -> Self {
69 Self {
70 tcx: tcx,
71 graph: CallGraphInfo::new(),
72 }
73 }
74
75 pub fn start(&mut self) {
76 for local_def_id in self.tcx.iter_local_def_id() {
77 if self.tcx.hir_maybe_body_owned_by(local_def_id).is_some() {
78 let def_id = local_def_id.to_def_id();
79 if self.tcx.is_mir_available(def_id) {
80 let def_kind = self.tcx.def_kind(def_id);
81
82 let body: &Body<'_> = match def_kind {
83 DefKind::Fn | DefKind::AssocFn => &self.tcx.optimized_mir(def_id),
84 DefKind::Const
85 | DefKind::Static { .. }
86 | DefKind::AssocConst
87 | DefKind::InlineConst
88 | DefKind::AnonConst => {
89 &self.tcx.mir_for_ctfe(def_id)
91 }
92 _ => {
94 rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
95 continue;
96 }
97 };
98
99 let mut call_graph_visitor =
100 CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
101 call_graph_visitor.visit();
102 }
103 }
104 }
105 }
106
107 pub fn get_callee_def_path(&self, def_path: String) -> Option<HashSet<String>> {
108 self.graph.get_callees_path(&def_path)
109 }
110}
111
112#[derive(Debug, Clone, Eq, PartialEq, Hash)]
113pub struct Node {
114 def_id: DefId,
115 def_path: String,
116}
117
118impl Node {
119 pub fn new(def_id: DefId, def_path: &String) -> Self {
120 Self {
121 def_id: def_id,
122 def_path: def_path.clone(),
123 }
124 }
125
126 pub fn get_def_id(&self) -> DefId {
127 self.def_id
128 }
129
130 pub fn get_def_path(&self) -> String {
131 self.def_path.clone()
132 }
133}
134
135pub struct CallGraphInfo<'tcx> {
136 pub functions: HashMap<usize, Node>, pub fn_calls: HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>>, pub node_registry: HashMap<String, usize>, }
140
141impl<'tcx> CallGraphInfo<'tcx> {
142 pub fn new() -> Self {
143 Self {
144 functions: HashMap::new(),
145 fn_calls: HashMap::new(),
146 node_registry: HashMap::new(),
147 }
148 }
149
150 pub fn get_node_num(&self) -> usize {
151 self.functions.len()
152 }
153
154 pub fn get_callees_path(&self, caller_def_path: &String) -> Option<HashSet<String>> {
155 let mut callees_path: HashSet<String> = HashSet::new();
156 if let Some(caller_id) = self.node_registry.get(caller_def_path) {
157 if let Some(callees) = self.fn_calls.get(caller_id) {
158 for (id, _terminator) in callees {
159 if let Some(callee_node) = self.functions.get(id) {
160 callees_path.insert(callee_node.get_def_path());
161 }
162 }
163 }
164 Some(callees_path)
165 } else {
166 None
167 }
168 }
169
170 pub fn add_node(&mut self, def_id: DefId, def_path: &String) -> usize {
172 if let Some(old_id) = self.node_registry.get(def_path) {
173 *old_id
174 } else {
175 let new_id = self.node_registry.len();
176 let node = Node::new(def_id, def_path);
177 self.node_registry.insert(def_path.clone(), new_id);
178 self.functions.insert(new_id, node);
179 new_id
180 }
181 }
182
183 pub fn add_funciton_call_edge(
184 &mut self,
185 caller_id: usize,
186 callee_id: usize,
187 terminator_stmt: Option<&'tcx mir::Terminator<'tcx>>,
188 ) {
189 let entry = self.fn_calls.entry(caller_id).or_insert_with(Vec::new);
190 entry.push((callee_id, terminator_stmt));
191 }
192
193 pub fn get_node_by_path(&self, def_path: &String) -> Option<usize> {
194 self.node_registry.get(def_path).copied()
195 }
196 pub fn get_callers_map(
197 &self,
198 ) -> HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>> {
199 let mut callers_map: HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>> =
200 HashMap::new();
201
202 for (&caller_id, calls_vec) in &self.fn_calls {
203 for (callee_id, terminator) in calls_vec {
204 callers_map
205 .entry(*callee_id)
206 .or_insert_with(Vec::new)
207 .push((caller_id, *terminator));
208 }
209 }
210 callers_map
211 }
212
213 pub fn display(&self) {
214 rap_info!("CallGraph Analysis:");
215 for (caller_id, callees) in &self.fn_calls {
216 if let Some(caller_node) = self.functions.get(caller_id) {
217 for (callee_id, terminator) in callees {
218 if let Some(callee_node) = self.functions.get(callee_id) {
219 let caller_def_path = caller_node.get_def_path();
220 let callee_def_path = callee_node.get_def_path();
221 if let Some(terminator_stmt) = terminator {
222 rap_info!(
223 "{}:{} -> {}:{} @ {:?}",
224 caller_id,
225 caller_def_path,
226 *callee_id,
227 callee_def_path,
228 terminator_stmt.kind
229 );
230 } else {
231 rap_info!(
232 " (Virtual) {}:{} -> {}:{}",
233 caller_id,
234 caller_def_path,
235 *callee_id,
236 callee_def_path,
237 );
238 }
239 }
240 }
241 }
242 }
243 }
244
245 pub fn get_reverse_post_order(&self) -> Vec<DefId> {
246 let mut visited = HashSet::new();
247 let mut post_order_ids = Vec::new(); for &node_id in self.functions.keys() {
251 if !visited.contains(&node_id) {
252 self.dfs_post_order(node_id, &mut visited, &mut post_order_ids);
253 }
254 }
255
256 let mut analysis_order: Vec<DefId> = post_order_ids
258 .into_iter()
259 .map(|id| {
260 self.functions
261 .get(&id)
262 .expect("Node ID must exist in functions map")
263 .def_id
264 })
265 .collect();
266
267 analysis_order.reverse();
269
270 analysis_order
271 }
272
273 fn dfs_post_order(
275 &self,
276 node_id: usize,
277 visited: &mut HashSet<usize>,
278 post_order_ids: &mut Vec<usize>,
279 ) {
280 visited.insert(node_id);
282
283 if let Some(callees) = self.fn_calls.get(&node_id) {
285 for (callee_id, _terminator) in callees {
286 if !visited.contains(callee_id) {
287 self.dfs_post_order(*callee_id, visited, post_order_ids);
288 }
289 }
290 }
291
292 post_order_ids.push(node_id);
294 }
295}