rapx/analysis/core/callgraph/
default.rs1use rustc_hir::{def::DefKind, def_id::DefId};
2use rustc_middle::{
3 mir::{self, Body},
4 ty::TyCtxt,
5};
6use std::collections::HashSet;
7use std::{collections::HashMap, hash::Hash};
8
9use super::visitor::CallGraphVisitor;
10use crate::{
11 analysis::core::callgraph::{CallGraph, CallGraphAnalysis},
12 rap_debug, rap_info, Analysis,
13};
14
15pub struct CallGraphAnalyzer<'tcx> {
16 pub tcx: TyCtxt<'tcx>,
17 pub graph: CallGraphInfo<'tcx>,
18}
19
20impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
21 fn name(&self) -> &'static str {
22 "Default call graph analysis algorithm."
23 }
24
25 fn run(&mut self) {
26 let mut analysis = CallGraphAnalyzer::new(self.tcx);
27 analysis.start();
28 }
29
30 fn reset(&mut self) {
31 todo!();
32 }
33}
34
35impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
36 fn get_callgraph(&mut self) -> CallGraph {
37 let fn_calls: HashMap<DefId, Vec<DefId>> = self
38 .graph
39 .fn_calls
40 .clone()
41 .into_iter()
42 .map(|(caller, callees)| {
43 let caller_id = self
44 .graph
45 .functions
46 .get(&caller)
47 .expect("Key must exist in functions map")
48 .def_id;
49
50 let callees_id = callees
51 .into_iter()
52 .map(|(callee, _)| {
53 self.graph
54 .functions
55 .get(&callee)
56 .expect("Value must exist in functions map")
57 .def_id
58 })
59 .collect::<Vec<_>>();
60 (caller_id, callees_id)
61 })
62 .collect();
63 CallGraph { fn_calls }
64 }
65}
66
67impl<'tcx> CallGraphAnalyzer<'tcx> {
68 pub fn new(tcx: TyCtxt<'tcx>) -> Self {
69 Self {
70 tcx: tcx,
71 graph: CallGraphInfo::new(),
72 }
73 }
74
75 pub fn start(&mut self) {
76 for local_def_id in self.tcx.iter_local_def_id() {
77 if self.tcx.hir_maybe_body_owned_by(local_def_id).is_some() {
78 let def_id = local_def_id.to_def_id();
79 if self.tcx.is_mir_available(def_id) {
80 let def_kind = self.tcx.def_kind(def_id);
81
82 let body: &Body<'_> = match def_kind {
83 DefKind::Fn | DefKind::AssocFn => &self.tcx.optimized_mir(def_id),
84 DefKind::Const
85 | DefKind::Static { .. }
86 | DefKind::AssocConst
87 | DefKind::InlineConst
88 | DefKind::AnonConst => {
89 &self.tcx.mir_for_ctfe(def_id)
91 }
92 _ => {
94 rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
95 continue;
96 }
97 };
98
99 let mut call_graph_visitor =
100 CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
101 call_graph_visitor.visit();
102 }
103 }
104 }
105 }
106
107 pub fn get_callee_def_path(&self, def_path: String) -> Option<HashSet<String>> {
108 self.graph.get_callees_path(&def_path)
109 }
110}
111
112#[derive(Debug, Clone, Eq, PartialEq, Hash)]
113pub struct Node {
114 def_id: DefId,
115 def_path: String,
116}
117
118impl Node {
119 pub fn new(def_id: DefId, def_path: &String) -> Self {
120 Self {
121 def_id: def_id,
122 def_path: def_path.clone(),
123 }
124 }
125
126 pub fn get_def_id(&self) -> DefId {
127 self.def_id
128 }
129
130 pub fn get_def_path(&self) -> String {
131 self.def_path.clone()
132 }
133}
134
135pub struct CallGraphInfo<'tcx> {
136 pub functions: HashMap<usize, Node>, pub fn_calls: HashMap<usize, Vec<(usize, &'tcx mir::Terminator<'tcx>)>>, pub node_registry: HashMap<String, usize>, }
140
141impl<'tcx> CallGraphInfo<'tcx> {
142 pub fn new() -> Self {
143 Self {
144 functions: HashMap::new(),
145 fn_calls: HashMap::new(),
146 node_registry: HashMap::new(),
147 }
148 }
149
150 pub fn get_node_num(&self) -> usize {
151 self.functions.len()
152 }
153
154 pub fn get_callees_path(&self, caller_def_path: &String) -> Option<HashSet<String>> {
155 let mut callees_path: HashSet<String> = HashSet::new();
156 if let Some(caller_id) = self.node_registry.get(caller_def_path) {
157 if let Some(callees) = self.fn_calls.get(caller_id) {
158 for (id, _terminator) in callees {
159 if let Some(callee_node) = self.functions.get(id) {
160 callees_path.insert(callee_node.get_def_path());
161 }
162 }
163 }
164 Some(callees_path)
165 } else {
166 None
167 }
168 }
169
170 pub fn add_node(&mut self, def_id: DefId, def_path: &String) {
171 if self.node_registry.get(def_path).is_none() {
172 let id = self.node_registry.len();
173 let node = Node::new(def_id, def_path);
174 self.node_registry.insert(def_path.clone(), id);
175 self.functions.insert(id, node);
176 }
177 }
178
179 pub fn add_funciton_call_edge(
180 &mut self,
181 caller_id: usize,
182 callee_id: usize,
183 terminator_stmt: &'tcx mir::Terminator<'tcx>,
184 ) {
185 let entry = self.fn_calls.entry(caller_id).or_insert_with(Vec::new);
186 entry.push((callee_id, terminator_stmt));
187 }
188
189 pub fn get_node_by_path(&self, def_path: &String) -> Option<usize> {
190 self.node_registry.get(def_path).copied()
191 }
192 pub fn get_callers_map(&self) -> HashMap<usize, Vec<(usize, &'tcx mir::Terminator<'tcx>)>> {
193 let mut callers_map: HashMap<usize, Vec<(usize, &'tcx mir::Terminator<'tcx>)>> =
194 HashMap::new();
195
196 for (&caller_id, calls_vec) in &self.fn_calls {
197 for (callee_id, terminator) in calls_vec {
198 callers_map
199 .entry(*callee_id)
200 .or_insert_with(Vec::new)
201 .push((caller_id, *terminator));
202 }
203 }
204 callers_map
205 }
206
207 pub fn display(&self) {
208 rap_info!("CallGraph Analysis:");
209 for (caller_id, callees) in &self.fn_calls {
210 if let Some(caller_node) = self.functions.get(caller_id) {
211 for (callee_id, terminator_stmt) in callees {
212 if let Some(callee_node) = self.functions.get(callee_id) {
213 let caller_def_path = caller_node.get_def_path();
214 let callee_def_path = callee_node.get_def_path();
215 rap_info!(
216 "{}:{} -> {}:{} @ {:?}",
217 caller_id,
218 caller_def_path,
219 *callee_id,
220 callee_def_path,
221 terminator_stmt.kind
222 );
223 }
224 }
225 }
226 }
227 }
228
229 pub fn get_reverse_post_order(&self) -> Vec<DefId> {
230 let mut visited = HashSet::new();
231 let mut post_order_ids = Vec::new(); for &node_id in self.functions.keys() {
235 if !visited.contains(&node_id) {
236 self.dfs_post_order(node_id, &mut visited, &mut post_order_ids);
237 }
238 }
239
240 let mut analysis_order: Vec<DefId> = post_order_ids
242 .into_iter()
243 .map(|id| {
244 self.functions
245 .get(&id)
246 .expect("Node ID must exist in functions map")
247 .def_id
248 })
249 .collect();
250
251 analysis_order.reverse();
253
254 analysis_order
255 }
256
257 fn dfs_post_order(
259 &self,
260 node_id: usize,
261 visited: &mut HashSet<usize>,
262 post_order_ids: &mut Vec<usize>,
263 ) {
264 visited.insert(node_id);
266
267 if let Some(callees) = self.fn_calls.get(&node_id) {
269 for (callee_id, _terminator) in callees {
270 if !visited.contains(callee_id) {
271 self.dfs_post_order(*callee_id, visited, post_order_ids);
272 }
273 }
274 }
275
276 post_order_ids.push(node_id);
278 }
279}