rapx/analysis/core/callgraph/default.rs

1use rustc_hir::{def::DefKind, def_id::DefId};
2use rustc_middle::{
3    mir::{self, Body},
4    ty::TyCtxt,
5};
6use std::collections::HashMap;
7use std::collections::HashSet;
8
9use super::visitor::CallGraphVisitor;
10use crate::{
11    Analysis,
12    analysis::core::callgraph::{CallGraphAnalysis, FnCallMap},
13};
14
15pub struct CallGraphAnalyzer<'tcx> {
16    pub tcx: TyCtxt<'tcx>,
17    pub graph: CallGraph<'tcx>,
18}
19
20impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
21    fn name(&self) -> &'static str {
22        "Default call graph analysis algorithm."
23    }
24
25    fn run(&mut self) {
26        self.start();
27    }
28
29    fn reset(&mut self) {
30        todo!();
31    }
32}
33
34impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
35    fn get_fn_calls(&self) -> FnCallMap {
36        let fn_calls: HashMap<DefId, Vec<DefId>> = self
37            .graph
38            .fn_calls
39            .clone()
40            .into_iter()
41            .map(|(caller, callees)| {
42                let callee_ids = callees.into_iter().map(|(did, _)| did).collect::<Vec<_>>();
43                (caller, callee_ids)
44            })
45            .collect();
46        fn_calls
47    }
48}
49
50impl<'tcx> CallGraphAnalyzer<'tcx> {
51    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
52        Self {
53            tcx: tcx,
54            graph: CallGraph::new(tcx),
55        }
56    }
57
58    pub fn start(&mut self) {
59        for local_def_id in self.tcx.mir_keys(()) {
60            let def_id = local_def_id.to_def_id();
61            if self.tcx.is_mir_available(def_id) {
62                let def_kind = self.tcx.def_kind(def_id);
63
64                let body: &Body<'_> = match def_kind {
65                    DefKind::Fn | DefKind::AssocFn | DefKind::Closure => {
66                        &self.tcx.optimized_mir(def_id)
67                    }
68                    DefKind::Const
69                    | DefKind::Static { .. }
70                    | DefKind::AssocConst
71                    | DefKind::InlineConst
72                    | DefKind::AnonConst => {
73                        // NOTE: safer fallback for constants
74                        &self.tcx.mir_for_ctfe(def_id)
75                    }
76                    // These don't have MIR or shouldn't be visited
77                    _ => {
78                        rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
79                        continue;
80                    }
81                };
82
83                let mut call_graph_visitor =
84                    CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
85                call_graph_visitor.visit();
86            }
87        }
88    }
89}
90
91pub type CallMap<'tcx> = HashMap<DefId, Vec<(DefId, Option<&'tcx mir::Terminator<'tcx>>)>>;
92
93pub struct CallGraph<'tcx> {
94    pub tcx: TyCtxt<'tcx>,
95    pub functions: HashSet<DefId>, // Function-like, including closures
96    pub fn_calls: CallMap<'tcx>,   // caller -> Vec<(callee, terminator)>
97}
98
99/// Internal apis for constructing a call graph
100impl<'tcx> CallGraph<'tcx> {
101    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
102        Self {
103            tcx,
104            functions: HashSet::new(),
105            fn_calls: HashMap::new(),
106        }
107    }
108
109    /// Register a function to the call graph. Return true on insert, false if that DefId already exists.
110    pub fn register_fn(&mut self, def_id: DefId) -> bool {
111        if let Some(_) = self.functions.iter().find(|func_id| **func_id == def_id) {
112            false
113        } else {
114            self.functions.insert(def_id);
115            true
116        }
117    }
118
119    /// Add a function call to the call graph.
120    pub fn add_funciton_call(
121        &mut self,
122        caller_id: DefId,
123        callee_id: DefId,
124        terminator_stmt: Option<&'tcx mir::Terminator<'tcx>>,
125    ) {
126        let entry = self.fn_calls.entry(caller_id).or_insert_with(Vec::new);
127        entry.push((callee_id, terminator_stmt));
128    }
129}
130
131/// Public apis to get information from the call graph
132impl<'tcx> CallGraph<'tcx> {
133    pub fn get_reverse_post_order(&self) -> Vec<DefId> {
134        let mut result = self.get_post_order();
135        result.reverse();
136        result
137    }
138
139    pub fn get_post_order(&self) -> Vec<DefId> {
140        let mut visited = HashSet::new();
141        let mut post_order_ids = Vec::new(); // Will store the post-order traversal of `usize` IDs
142
143        // Iterate over all functions defined in the graph to handle disconnected components
144        for &func_def_id in self.functions.iter() {
145            if !visited.contains(&func_def_id) {
146                self.dfs_post_order(func_def_id, &mut visited, &mut post_order_ids);
147            }
148        }
149
150        post_order_ids
151    }
152
153    /// Helper function to perform a recursive depth-first search.
154    fn dfs_post_order(
155        &self,
156        func_def_id: DefId,
157        visited: &mut HashSet<DefId>,
158        post_order_ids: &mut Vec<DefId>,
159    ) {
160        // Mark the current node as visited
161        visited.insert(func_def_id);
162
163        // Visit all callees (children) of the current node
164        if let Some(callees) = self.fn_calls.get(&func_def_id) {
165            for (callee_id, _terminator) in callees {
166                if !visited.contains(callee_id) {
167                    self.dfs_post_order(*callee_id, visited, post_order_ids);
168                }
169            }
170        }
171
172        // After visiting all children, add the current node to the post-order list
173        post_order_ids.push(func_def_id);
174    }
175
176    /// Get a reversed (callee -> Vec<Caller>) call map.
177    pub fn get_callers_map(&self) -> CallMap<'tcx> {
178        let mut callers_map: CallMap<'tcx> = HashMap::new();
179
180        for (&caller_id, calls_vec) in &self.fn_calls {
181            for (callee_id, terminator) in calls_vec {
182                callers_map
183                    .entry(*callee_id)
184                    .or_insert_with(Vec::new)
185                    .push((caller_id, *terminator));
186            }
187        }
188        callers_map
189    }
190
191    /// Get all direct callees' DefId of the caller function
192    pub fn get_callees(&self, caller_def_id: DefId) -> Vec<DefId> {
193        if let Some(callees) = self.fn_calls.get(&caller_def_id) {
194            callees
195                .clone()
196                .into_iter()
197                .map(|(did, _)| did)
198                .collect::<Vec<_>>()
199        } else {
200            vec![]
201        }
202    }
203
204    /// Get all recursively reachable callee's DefId
205    pub fn get_callees_recursive(&self, caller_def_id: DefId) -> Vec<DefId> {
206        let mut visited = HashSet::new();
207        let mut result = Vec::new();
208        self.dfs_post_order(caller_def_id, &mut visited, &mut result);
209        result
210    }
211
212    /// Get all direct callers' DefId of the callee function
213    pub fn get_callers(&self, callee_def_id: DefId) -> Vec<DefId> {
214        let callers_map = self.get_callers_map();
215        if let Some(callers) = callers_map.get(&callee_def_id) {
216            callers
217                .clone()
218                .into_iter()
219                .map(|(did, _)| did)
220                .collect::<Vec<_>>()
221        } else {
222            vec![]
223        }
224    }
225}