rapx/analysis/core/callgraph/
default.rs

1use rustc_hir::{def::DefKind, def_id::DefId};
2use rustc_middle::{
3    mir::{self, Body},
4    ty::TyCtxt,
5};
6use std::collections::HashSet;
7use std::{collections::HashMap, hash::Hash};
8
9use super::visitor::CallGraphVisitor;
10use crate::{
11    Analysis,
12    analysis::core::callgraph::{CallGraph, CallGraphAnalysis},
13    rap_debug, rap_info,
14};
15
16pub struct CallGraphAnalyzer<'tcx> {
17    pub tcx: TyCtxt<'tcx>,
18    pub graph: CallGraphInfo<'tcx>,
19}
20
21impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
22    fn name(&self) -> &'static str {
23        "Default call graph analysis algorithm."
24    }
25
26    fn run(&mut self) {
27        self.start();
28    }
29
30    fn reset(&mut self) {
31        todo!();
32    }
33}
34
35impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
36    fn get_callgraph(&mut self) -> CallGraph {
37        let fn_calls: HashMap<DefId, Vec<DefId>> = self
38            .graph
39            .fn_calls
40            .clone()
41            .into_iter()
42            .map(|(caller, callees)| {
43                let caller_id = self
44                    .graph
45                    .functions
46                    .get(&caller)
47                    .expect("Key must exist in functions map")
48                    .def_id;
49
50                let callees_id = callees
51                    .into_iter()
52                    .map(|(callee, _)| {
53                        self.graph
54                            .functions
55                            .get(&callee)
56                            .expect("Value must exist in functions map")
57                            .def_id
58                    })
59                    .collect::<Vec<_>>();
60                (caller_id, callees_id)
61            })
62            .collect();
63        CallGraph { fn_calls }
64    }
65}
66
67impl<'tcx> CallGraphAnalyzer<'tcx> {
68    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
69        Self {
70            tcx: tcx,
71            graph: CallGraphInfo::new(),
72        }
73    }
74
75    pub fn start(&mut self) {
76        for local_def_id in self.tcx.iter_local_def_id() {
77            if self.tcx.hir_maybe_body_owned_by(local_def_id).is_some() {
78                let def_id = local_def_id.to_def_id();
79                if self.tcx.is_mir_available(def_id) {
80                    let def_kind = self.tcx.def_kind(def_id);
81
82                    let body: &Body<'_> = match def_kind {
83                        DefKind::Fn | DefKind::AssocFn => &self.tcx.optimized_mir(def_id),
84                        DefKind::Const
85                        | DefKind::Static { .. }
86                        | DefKind::AssocConst
87                        | DefKind::InlineConst
88                        | DefKind::AnonConst => {
89                            // NOTE: safer fallback for constants
90                            &self.tcx.mir_for_ctfe(def_id)
91                        }
92                        // These don't have MIR or shouldn't be visited
93                        _ => {
94                            rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
95                            continue;
96                        }
97                    };
98
99                    let mut call_graph_visitor =
100                        CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
101                    call_graph_visitor.visit();
102                }
103            }
104        }
105    }
106
107    pub fn get_callee_def_path(&self, def_path: String) -> Option<HashSet<String>> {
108        self.graph.get_callees_path(&def_path)
109    }
110}
111
112#[derive(Debug, Clone, Eq, PartialEq, Hash)]
113pub struct Node {
114    def_id: DefId,
115    def_path: String,
116}
117
118impl Node {
119    pub fn new(def_id: DefId, def_path: &String) -> Self {
120        Self {
121            def_id: def_id,
122            def_path: def_path.clone(),
123        }
124    }
125
126    pub fn get_def_id(&self) -> DefId {
127        self.def_id
128    }
129
130    pub fn get_def_path(&self) -> String {
131        self.def_path.clone()
132    }
133}
134
135pub struct CallGraphInfo<'tcx> {
136    pub functions: HashMap<usize, Node>, // id -> node
137    pub fn_calls: HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>>, // caller_id -> Vec<(callee_id, terminator)>
138    pub node_registry: HashMap<String, usize>,                                       // path -> id
139}
140
141impl<'tcx> CallGraphInfo<'tcx> {
142    pub fn new() -> Self {
143        Self {
144            functions: HashMap::new(),
145            fn_calls: HashMap::new(),
146            node_registry: HashMap::new(),
147        }
148    }
149
150    pub fn get_node_num(&self) -> usize {
151        self.functions.len()
152    }
153
154    pub fn get_callees_path(&self, caller_def_path: &String) -> Option<HashSet<String>> {
155        let mut callees_path: HashSet<String> = HashSet::new();
156        if let Some(caller_id) = self.node_registry.get(caller_def_path) {
157            if let Some(callees) = self.fn_calls.get(caller_id) {
158                for (id, _terminator) in callees {
159                    if let Some(callee_node) = self.functions.get(id) {
160                        callees_path.insert(callee_node.get_def_path());
161                    }
162                }
163            }
164            Some(callees_path)
165        } else {
166            None
167        }
168    }
169
170    /// Add a node and return its id. If node already exists, only return its id.
171    pub fn add_node(&mut self, def_id: DefId, def_path: &String) -> usize {
172        if let Some(old_id) = self.node_registry.get(def_path) {
173            *old_id
174        } else {
175            let new_id = self.node_registry.len();
176            let node = Node::new(def_id, def_path);
177            self.node_registry.insert(def_path.clone(), new_id);
178            self.functions.insert(new_id, node);
179            new_id
180        }
181    }
182
183    pub fn add_funciton_call_edge(
184        &mut self,
185        caller_id: usize,
186        callee_id: usize,
187        terminator_stmt: Option<&'tcx mir::Terminator<'tcx>>,
188    ) {
189        let entry = self.fn_calls.entry(caller_id).or_insert_with(Vec::new);
190        entry.push((callee_id, terminator_stmt));
191    }
192
193    pub fn get_node_by_path(&self, def_path: &String) -> Option<usize> {
194        self.node_registry.get(def_path).copied()
195    }
196    pub fn get_callers_map(
197        &self,
198    ) -> HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>> {
199        let mut callers_map: HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>> =
200            HashMap::new();
201
202        for (&caller_id, calls_vec) in &self.fn_calls {
203            for (callee_id, terminator) in calls_vec {
204                callers_map
205                    .entry(*callee_id)
206                    .or_insert_with(Vec::new)
207                    .push((caller_id, *terminator));
208            }
209        }
210        callers_map
211    }
212
213    pub fn display(&self) {
214        rap_info!("CallGraph Analysis:");
215        for (caller_id, callees) in &self.fn_calls {
216            if let Some(caller_node) = self.functions.get(caller_id) {
217                for (callee_id, terminator) in callees {
218                    if let Some(callee_node) = self.functions.get(callee_id) {
219                        let caller_def_path = caller_node.get_def_path();
220                        let callee_def_path = callee_node.get_def_path();
221                        if let Some(terminator_stmt) = terminator {
222                            rap_info!(
223                                "{}:{} -> {}:{} @ {:?}",
224                                caller_id,
225                                caller_def_path,
226                                *callee_id,
227                                callee_def_path,
228                                terminator_stmt.kind
229                            );
230                        } else {
231                            rap_info!(
232                                " (Virtual) {}:{} -> {}:{}",
233                                caller_id,
234                                caller_def_path,
235                                *callee_id,
236                                callee_def_path,
237                            );
238                        }
239                    }
240                }
241            }
242        }
243    }
244
245    pub fn get_reverse_post_order(&self) -> Vec<DefId> {
246        let mut visited = HashSet::new();
247        let mut post_order_ids = Vec::new(); // Will store the post-order traversal of `usize` IDs
248
249        // Iterate over all functions defined in the graph to handle disconnected components
250        for &node_id in self.functions.keys() {
251            if !visited.contains(&node_id) {
252                self.dfs_post_order(node_id, &mut visited, &mut post_order_ids);
253            }
254        }
255
256        // Map the ordered `usize` IDs back to `DefId`s for the analysis pipeline
257        let mut analysis_order: Vec<DefId> = post_order_ids
258            .into_iter()
259            .map(|id| {
260                self.functions
261                    .get(&id)
262                    .expect("Node ID must exist in functions map")
263                    .def_id
264            })
265            .collect();
266
267        // Reversing the post-order gives a topological sort (bottom-up)
268        analysis_order.reverse();
269
270        analysis_order
271    }
272
273    /// Helper function to perform a recursive depth-first search.
274    fn dfs_post_order(
275        &self,
276        node_id: usize,
277        visited: &mut HashSet<usize>,
278        post_order_ids: &mut Vec<usize>,
279    ) {
280        // Mark the current node as visited
281        visited.insert(node_id);
282
283        // Visit all callees (children) of the current node
284        if let Some(callees) = self.fn_calls.get(&node_id) {
285            for (callee_id, _terminator) in callees {
286                if !visited.contains(callee_id) {
287                    self.dfs_post_order(*callee_id, visited, post_order_ids);
288                }
289            }
290        }
291
292        // After visiting all children, add the current node to the post-order list
293        post_order_ids.push(node_id);
294    }
295}