rapx/analysis/core/callgraph/
default.rs1use rustc_hir::{def::DefKind, def_id::DefId};
2use rustc_middle::{
3 mir::{self, Body},
4 ty::TyCtxt,
5};
6use std::collections::HashSet;
7use std::{collections::HashMap, hash::Hash};
8
9use super::visitor::CallGraphVisitor;
10use crate::{
11 Analysis,
12 analysis::core::callgraph::{CallGraph, CallGraphAnalysis},
13};
14
15pub struct CallGraphAnalyzer<'tcx> {
16 pub tcx: TyCtxt<'tcx>,
17 pub graph: CallGraphInfo<'tcx>,
18}
19
20impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
21 fn name(&self) -> &'static str {
22 "Default call graph analysis algorithm."
23 }
24
25 fn run(&mut self) {
26 self.start();
27 }
28
29 fn reset(&mut self) {
30 todo!();
31 }
32}
33
34impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
35 fn get_callgraph(&mut self) -> CallGraph {
36 let fn_calls: HashMap<DefId, Vec<DefId>> = self
37 .graph
38 .fn_calls
39 .clone()
40 .into_iter()
41 .map(|(caller, callees)| {
42 let caller_id = self
43 .graph
44 .functions
45 .get(&caller)
46 .expect("Key must exist in functions map")
47 .def_id;
48
49 let callees_id = callees
50 .into_iter()
51 .map(|(callee, _)| {
52 self.graph
53 .functions
54 .get(&callee)
55 .expect("Value must exist in functions map")
56 .def_id
57 })
58 .collect::<Vec<_>>();
59 (caller_id, callees_id)
60 })
61 .collect();
62 CallGraph { fn_calls }
63 }
64}
65
66impl<'tcx> CallGraphAnalyzer<'tcx> {
67 pub fn new(tcx: TyCtxt<'tcx>) -> Self {
68 Self {
69 tcx: tcx,
70 graph: CallGraphInfo::new(),
71 }
72 }
73
74 pub fn start(&mut self) {
75 for local_def_id in self.tcx.iter_local_def_id() {
76 if self.tcx.hir_maybe_body_owned_by(local_def_id).is_some() {
77 let def_id = local_def_id.to_def_id();
78 if self.tcx.is_mir_available(def_id) {
79 let def_kind = self.tcx.def_kind(def_id);
80
81 let body: &Body<'_> = match def_kind {
82 DefKind::Fn | DefKind::AssocFn => &self.tcx.optimized_mir(def_id),
83 DefKind::Const
84 | DefKind::Static { .. }
85 | DefKind::AssocConst
86 | DefKind::InlineConst
87 | DefKind::AnonConst => {
88 &self.tcx.mir_for_ctfe(def_id)
90 }
91 _ => {
93 rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
94 continue;
95 }
96 };
97
98 let mut call_graph_visitor =
99 CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
100 call_graph_visitor.visit();
101 }
102 }
103 }
104 }
105
106 pub fn get_callee_def_path(&self, def_path: String) -> Option<HashSet<String>> {
107 self.graph.get_callees_path(&def_path)
108 }
109}
110
111#[derive(Debug, Clone, Eq, PartialEq, Hash)]
112pub struct Node {
113 def_id: DefId,
114 def_path: String,
115}
116
117impl Node {
118 pub fn new(def_id: DefId, def_path: &String) -> Self {
119 Self {
120 def_id: def_id,
121 def_path: def_path.clone(),
122 }
123 }
124
125 pub fn get_def_id(&self) -> DefId {
126 self.def_id
127 }
128
129 pub fn get_def_path(&self) -> String {
130 self.def_path.clone()
131 }
132}
133
134pub struct CallGraphInfo<'tcx> {
135 pub functions: HashMap<usize, Node>, pub fn_calls: HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>>, pub node_registry: HashMap<String, usize>, }
139
140impl<'tcx> CallGraphInfo<'tcx> {
141 pub fn new() -> Self {
142 Self {
143 functions: HashMap::new(),
144 fn_calls: HashMap::new(),
145 node_registry: HashMap::new(),
146 }
147 }
148
149 pub fn get_node_num(&self) -> usize {
150 self.functions.len()
151 }
152
153 pub fn get_callees_path(&self, caller_def_path: &String) -> Option<HashSet<String>> {
154 let mut callees_path: HashSet<String> = HashSet::new();
155 if let Some(caller_id) = self.node_registry.get(caller_def_path) {
156 if let Some(callees) = self.fn_calls.get(caller_id) {
157 for (id, _terminator) in callees {
158 if let Some(callee_node) = self.functions.get(id) {
159 callees_path.insert(callee_node.get_def_path());
160 }
161 }
162 }
163 Some(callees_path)
164 } else {
165 None
166 }
167 }
168
169 pub fn add_node(&mut self, def_id: DefId, def_path: &String) -> usize {
171 if let Some(old_id) = self.node_registry.get(def_path) {
172 *old_id
173 } else {
174 let new_id = self.node_registry.len();
175 let node = Node::new(def_id, def_path);
176 self.node_registry.insert(def_path.clone(), new_id);
177 self.functions.insert(new_id, node);
178 new_id
179 }
180 }
181
182 pub fn add_funciton_call_edge(
183 &mut self,
184 caller_id: usize,
185 callee_id: usize,
186 terminator_stmt: Option<&'tcx mir::Terminator<'tcx>>,
187 ) {
188 let entry = self.fn_calls.entry(caller_id).or_insert_with(Vec::new);
189 entry.push((callee_id, terminator_stmt));
190 }
191
192 pub fn get_node_by_path(&self, def_path: &String) -> Option<usize> {
193 self.node_registry.get(def_path).copied()
194 }
195 pub fn get_callers_map(
196 &self,
197 ) -> HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>> {
198 let mut callers_map: HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>> =
199 HashMap::new();
200
201 for (&caller_id, calls_vec) in &self.fn_calls {
202 for (callee_id, terminator) in calls_vec {
203 callers_map
204 .entry(*callee_id)
205 .or_insert_with(Vec::new)
206 .push((caller_id, *terminator));
207 }
208 }
209 callers_map
210 }
211
212 pub fn display(&self) {
213 rap_info!("CallGraph Analysis:");
214 for (caller_id, callees) in &self.fn_calls {
215 if let Some(caller_node) = self.functions.get(caller_id) {
216 for (callee_id, terminator) in callees {
217 if let Some(callee_node) = self.functions.get(callee_id) {
218 let caller_def_path = caller_node.get_def_path();
219 let callee_def_path = callee_node.get_def_path();
220 if let Some(terminator_stmt) = terminator {
221 rap_info!(
222 "{}:{} -> {}:{} @ {:?}",
223 caller_id,
224 caller_def_path,
225 *callee_id,
226 callee_def_path,
227 terminator_stmt.kind
228 );
229 } else {
230 rap_info!(
231 " (Virtual) {}:{} -> {}:{}",
232 caller_id,
233 caller_def_path,
234 *callee_id,
235 callee_def_path,
236 );
237 }
238 }
239 }
240 }
241 }
242 }
243
244 pub fn get_reverse_post_order(&self) -> Vec<DefId> {
245 let mut visited = HashSet::new();
246 let mut post_order_ids = Vec::new(); for &node_id in self.functions.keys() {
250 if !visited.contains(&node_id) {
251 self.dfs_post_order(node_id, &mut visited, &mut post_order_ids);
252 }
253 }
254
255 let mut analysis_order: Vec<DefId> = post_order_ids
257 .into_iter()
258 .map(|id| {
259 self.functions
260 .get(&id)
261 .expect("Node ID must exist in functions map")
262 .def_id
263 })
264 .collect();
265
266 analysis_order.reverse();
268
269 analysis_order
270 }
271
272 fn dfs_post_order(
274 &self,
275 node_id: usize,
276 visited: &mut HashSet<usize>,
277 post_order_ids: &mut Vec<usize>,
278 ) {
279 visited.insert(node_id);
281
282 if let Some(callees) = self.fn_calls.get(&node_id) {
284 for (callee_id, _terminator) in callees {
285 if !visited.contains(callee_id) {
286 self.dfs_post_order(*callee_id, visited, post_order_ids);
287 }
288 }
289 }
290
291 post_order_ids.push(node_id);
293 }
294}