rapx/analysis/core/callgraph/
default.rs

1use rustc_hir::{def::DefKind, def_id::DefId};
2use rustc_middle::{
3    mir::{self, Body},
4    ty::TyCtxt,
5};
6use std::collections::HashSet;
7use std::{collections::HashMap, hash::Hash};
8
9use super::visitor::CallGraphVisitor;
10use crate::{
11    analysis::core::callgraph::{CallGraph, CallGraphAnalysis},
12    rap_debug, rap_info, Analysis,
13};
14
15pub struct CallGraphAnalyzer<'tcx> {
16    pub tcx: TyCtxt<'tcx>,
17    pub graph: CallGraphInfo<'tcx>,
18}
19
20impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
21    fn name(&self) -> &'static str {
22        "Default call graph analysis algorithm."
23    }
24
25    fn run(&mut self) {
26        let mut analysis = CallGraphAnalyzer::new(self.tcx);
27        analysis.start();
28    }
29
30    fn reset(&mut self) {
31        todo!();
32    }
33}
34
35impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
36    fn get_callgraph(&mut self) -> CallGraph {
37        let fn_calls: HashMap<DefId, Vec<DefId>> = self
38            .graph
39            .fn_calls
40            .clone()
41            .into_iter()
42            .map(|(caller, callees)| {
43                let caller_id = self
44                    .graph
45                    .functions
46                    .get(&caller)
47                    .expect("Key must exist in functions map")
48                    .def_id;
49
50                let callees_id = callees
51                    .into_iter()
52                    .map(|(callee, _)| {
53                        self.graph
54                            .functions
55                            .get(&callee)
56                            .expect("Value must exist in functions map")
57                            .def_id
58                    })
59                    .collect::<Vec<_>>();
60                (caller_id, callees_id)
61            })
62            .collect();
63        CallGraph { fn_calls }
64    }
65}
66
67impl<'tcx> CallGraphAnalyzer<'tcx> {
68    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
69        Self {
70            tcx: tcx,
71            graph: CallGraphInfo::new(),
72        }
73    }
74
75    pub fn start(&mut self) {
76        for local_def_id in self.tcx.iter_local_def_id() {
77            if self.tcx.hir_maybe_body_owned_by(local_def_id).is_some() {
78                let def_id = local_def_id.to_def_id();
79                if self.tcx.is_mir_available(def_id) {
80                    let def_kind = self.tcx.def_kind(def_id);
81
82                    let body: &Body<'_> = match def_kind {
83                        DefKind::Fn | DefKind::AssocFn => &self.tcx.optimized_mir(def_id),
84                        DefKind::Const
85                        | DefKind::Static { .. }
86                        | DefKind::AssocConst
87                        | DefKind::InlineConst
88                        | DefKind::AnonConst => {
89                            // NOTE: safer fallback for constants
90                            &self.tcx.mir_for_ctfe(def_id)
91                        }
92                        // These don't have MIR or shouldn't be visited
93                        _ => {
94                            rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
95                            continue;
96                        }
97                    };
98
99                    let mut call_graph_visitor =
100                        CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
101                    call_graph_visitor.visit();
102                }
103            }
104        }
105    }
106
107    pub fn get_callee_def_path(&self, def_path: String) -> Option<HashSet<String>> {
108        self.graph.get_callees_path(&def_path)
109    }
110}
111
112#[derive(Debug, Clone, Eq, PartialEq, Hash)]
113pub struct Node {
114    def_id: DefId,
115    def_path: String,
116}
117
118impl Node {
119    pub fn new(def_id: DefId, def_path: &String) -> Self {
120        Self {
121            def_id: def_id,
122            def_path: def_path.clone(),
123        }
124    }
125
126    pub fn get_def_id(&self) -> DefId {
127        self.def_id
128    }
129
130    pub fn get_def_path(&self) -> String {
131        self.def_path.clone()
132    }
133}
134
135pub struct CallGraphInfo<'tcx> {
136    pub functions: HashMap<usize, Node>, // id -> node
137    pub fn_calls: HashMap<usize, Vec<(usize, &'tcx mir::Terminator<'tcx>)>>, // caller_id -> Vec<(callee_id, terminator)>
138    pub node_registry: HashMap<String, usize>,                               // path -> id
139}
140
141impl<'tcx> CallGraphInfo<'tcx> {
142    pub fn new() -> Self {
143        Self {
144            functions: HashMap::new(),
145            fn_calls: HashMap::new(),
146            node_registry: HashMap::new(),
147        }
148    }
149
150    pub fn get_node_num(&self) -> usize {
151        self.functions.len()
152    }
153
154    pub fn get_callees_path(&self, caller_def_path: &String) -> Option<HashSet<String>> {
155        let mut callees_path: HashSet<String> = HashSet::new();
156        if let Some(caller_id) = self.node_registry.get(caller_def_path) {
157            if let Some(callees) = self.fn_calls.get(caller_id) {
158                for (id, _terminator) in callees {
159                    if let Some(callee_node) = self.functions.get(id) {
160                        callees_path.insert(callee_node.get_def_path());
161                    }
162                }
163            }
164            Some(callees_path)
165        } else {
166            None
167        }
168    }
169
170    pub fn add_node(&mut self, def_id: DefId, def_path: &String) {
171        if self.node_registry.get(def_path).is_none() {
172            let id = self.node_registry.len();
173            let node = Node::new(def_id, def_path);
174            self.node_registry.insert(def_path.clone(), id);
175            self.functions.insert(id, node);
176        }
177    }
178
179    pub fn add_funciton_call_edge(
180        &mut self,
181        caller_id: usize,
182        callee_id: usize,
183        terminator_stmt: &'tcx mir::Terminator<'tcx>,
184    ) {
185        let entry = self.fn_calls.entry(caller_id).or_insert_with(Vec::new);
186        entry.push((callee_id, terminator_stmt));
187    }
188
189    pub fn get_node_by_path(&self, def_path: &String) -> Option<usize> {
190        self.node_registry.get(def_path).copied()
191    }
192    pub fn get_callers_map(&self) -> HashMap<usize, Vec<(usize, &'tcx mir::Terminator<'tcx>)>> {
193        let mut callers_map: HashMap<usize, Vec<(usize, &'tcx mir::Terminator<'tcx>)>> =
194            HashMap::new();
195
196        for (&caller_id, calls_vec) in &self.fn_calls {
197            for (callee_id, terminator) in calls_vec {
198                callers_map
199                    .entry(*callee_id)
200                    .or_insert_with(Vec::new)
201                    .push((caller_id, *terminator));
202            }
203        }
204        callers_map
205    }
206
207    pub fn display(&self) {
208        rap_info!("CallGraph Analysis:");
209        for (caller_id, callees) in &self.fn_calls {
210            if let Some(caller_node) = self.functions.get(caller_id) {
211                for (callee_id, terminator_stmt) in callees {
212                    if let Some(callee_node) = self.functions.get(callee_id) {
213                        let caller_def_path = caller_node.get_def_path();
214                        let callee_def_path = callee_node.get_def_path();
215                        rap_info!(
216                            "{}:{} -> {}:{} @ {:?}",
217                            caller_id,
218                            caller_def_path,
219                            *callee_id,
220                            callee_def_path,
221                            terminator_stmt.kind
222                        );
223                    }
224                }
225            }
226        }
227    }
228
229    pub fn get_reverse_post_order(&self) -> Vec<DefId> {
230        let mut visited = HashSet::new();
231        let mut post_order_ids = Vec::new(); // Will store the post-order traversal of `usize` IDs
232
233        // Iterate over all functions defined in the graph to handle disconnected components
234        for &node_id in self.functions.keys() {
235            if !visited.contains(&node_id) {
236                self.dfs_post_order(node_id, &mut visited, &mut post_order_ids);
237            }
238        }
239
240        // Map the ordered `usize` IDs back to `DefId`s for the analysis pipeline
241        let mut analysis_order: Vec<DefId> = post_order_ids
242            .into_iter()
243            .map(|id| {
244                self.functions
245                    .get(&id)
246                    .expect("Node ID must exist in functions map")
247                    .def_id
248            })
249            .collect();
250
251        // Reversing the post-order gives a topological sort (bottom-up)
252        analysis_order.reverse();
253
254        analysis_order
255    }
256
257    /// Helper function to perform a recursive depth-first search.
258    fn dfs_post_order(
259        &self,
260        node_id: usize,
261        visited: &mut HashSet<usize>,
262        post_order_ids: &mut Vec<usize>,
263    ) {
264        // Mark the current node as visited
265        visited.insert(node_id);
266
267        // Visit all callees (children) of the current node
268        if let Some(callees) = self.fn_calls.get(&node_id) {
269            for (callee_id, _terminator) in callees {
270                if !visited.contains(callee_id) {
271                    self.dfs_post_order(*callee_id, visited, post_order_ids);
272                }
273            }
274        }
275
276        // After visiting all children, add the current node to the post-order list
277        post_order_ids.push(node_id);
278    }
279}