rapx/analysis/core/callgraph/
default.rs1use rustc_hir::{def::DefKind, def_id::DefId};
2use rustc_middle::{
3    mir::{self, Body},
4    ty::TyCtxt,
5};
6use std::collections::HashSet;
7use std::{collections::HashMap, hash::Hash};
8
9use super::visitor::CallGraphVisitor;
10use crate::{
11    analysis::core::callgraph::{CallGraph, CallGraphAnalysis},
12    rap_debug, rap_info, Analysis,
13};
14
15pub struct CallGraphAnalyzer<'tcx> {
16    pub tcx: TyCtxt<'tcx>,
17    pub graph: CallGraphInfo<'tcx>,
18}
19
20impl<'tcx> Analysis for CallGraphAnalyzer<'tcx> {
21    fn name(&self) -> &'static str {
22        "Default call graph analysis algorithm."
23    }
24
25    fn run(&mut self) {
26        self.start();
27    }
28
29    fn reset(&mut self) {
30        todo!();
31    }
32}
33
34impl<'tcx> CallGraphAnalysis for CallGraphAnalyzer<'tcx> {
35    fn get_callgraph(&mut self) -> CallGraph {
36        let fn_calls: HashMap<DefId, Vec<DefId>> = self
37            .graph
38            .fn_calls
39            .clone()
40            .into_iter()
41            .map(|(caller, callees)| {
42                let caller_id = self
43                    .graph
44                    .functions
45                    .get(&caller)
46                    .expect("Key must exist in functions map")
47                    .def_id;
48
49                let callees_id = callees
50                    .into_iter()
51                    .map(|(callee, _)| {
52                        self.graph
53                            .functions
54                            .get(&callee)
55                            .expect("Value must exist in functions map")
56                            .def_id
57                    })
58                    .collect::<Vec<_>>();
59                (caller_id, callees_id)
60            })
61            .collect();
62        CallGraph { fn_calls }
63    }
64}
65
66impl<'tcx> CallGraphAnalyzer<'tcx> {
67    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
68        Self {
69            tcx: tcx,
70            graph: CallGraphInfo::new(),
71        }
72    }
73
74    pub fn start(&mut self) {
75        for local_def_id in self.tcx.iter_local_def_id() {
76            if self.tcx.hir_maybe_body_owned_by(local_def_id).is_some() {
77                let def_id = local_def_id.to_def_id();
78                if self.tcx.is_mir_available(def_id) {
79                    let def_kind = self.tcx.def_kind(def_id);
80
81                    let body: &Body<'_> = match def_kind {
82                        DefKind::Fn | DefKind::AssocFn => &self.tcx.optimized_mir(def_id),
83                        DefKind::Const
84                        | DefKind::Static { .. }
85                        | DefKind::AssocConst
86                        | DefKind::InlineConst
87                        | DefKind::AnonConst => {
88                            &self.tcx.mir_for_ctfe(def_id)
90                        }
91                        _ => {
93                            rap_debug!("Skipping def_id {:?} with kind {:?}", def_id, def_kind);
94                            continue;
95                        }
96                    };
97
98                    let mut call_graph_visitor =
99                        CallGraphVisitor::new(self.tcx, def_id.into(), body, &mut self.graph);
100                    call_graph_visitor.visit();
101                }
102            }
103        }
104    }
105
106    pub fn get_callee_def_path(&self, def_path: String) -> Option<HashSet<String>> {
107        self.graph.get_callees_path(&def_path)
108    }
109}
110
111#[derive(Debug, Clone, Eq, PartialEq, Hash)]
112pub struct Node {
113    def_id: DefId,
114    def_path: String,
115}
116
117impl Node {
118    pub fn new(def_id: DefId, def_path: &String) -> Self {
119        Self {
120            def_id: def_id,
121            def_path: def_path.clone(),
122        }
123    }
124
125    pub fn get_def_id(&self) -> DefId {
126        self.def_id
127    }
128
129    pub fn get_def_path(&self) -> String {
130        self.def_path.clone()
131    }
132}
133
134pub struct CallGraphInfo<'tcx> {
135    pub functions: HashMap<usize, Node>, pub fn_calls: HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>>, pub node_registry: HashMap<String, usize>,                                       }
139
140impl<'tcx> CallGraphInfo<'tcx> {
141    pub fn new() -> Self {
142        Self {
143            functions: HashMap::new(),
144            fn_calls: HashMap::new(),
145            node_registry: HashMap::new(),
146        }
147    }
148
149    pub fn get_node_num(&self) -> usize {
150        self.functions.len()
151    }
152
153    pub fn get_callees_path(&self, caller_def_path: &String) -> Option<HashSet<String>> {
154        let mut callees_path: HashSet<String> = HashSet::new();
155        if let Some(caller_id) = self.node_registry.get(caller_def_path) {
156            if let Some(callees) = self.fn_calls.get(caller_id) {
157                for (id, _terminator) in callees {
158                    if let Some(callee_node) = self.functions.get(id) {
159                        callees_path.insert(callee_node.get_def_path());
160                    }
161                }
162            }
163            Some(callees_path)
164        } else {
165            None
166        }
167    }
168
169    pub fn add_node(&mut self, def_id: DefId, def_path: &String) -> usize {
171        if let Some(old_id) = self.node_registry.get(def_path) {
172            *old_id
173        } else {
174            let new_id = self.node_registry.len();
175            let node = Node::new(def_id, def_path);
176            self.node_registry.insert(def_path.clone(), new_id);
177            self.functions.insert(new_id, node);
178            new_id
179        }
180    }
181
182    pub fn add_funciton_call_edge(
183        &mut self,
184        caller_id: usize,
185        callee_id: usize,
186        terminator_stmt: Option<&'tcx mir::Terminator<'tcx>>,
187    ) {
188        let entry = self.fn_calls.entry(caller_id).or_insert_with(Vec::new);
189        entry.push((callee_id, terminator_stmt));
190    }
191
192    pub fn get_node_by_path(&self, def_path: &String) -> Option<usize> {
193        self.node_registry.get(def_path).copied()
194    }
195    pub fn get_callers_map(
196        &self,
197    ) -> HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>> {
198        let mut callers_map: HashMap<usize, Vec<(usize, Option<&'tcx mir::Terminator<'tcx>>)>> =
199            HashMap::new();
200
201        for (&caller_id, calls_vec) in &self.fn_calls {
202            for (callee_id, terminator) in calls_vec {
203                callers_map
204                    .entry(*callee_id)
205                    .or_insert_with(Vec::new)
206                    .push((caller_id, *terminator));
207            }
208        }
209        callers_map
210    }
211
212    pub fn display(&self) {
213        rap_info!("CallGraph Analysis:");
214        for (caller_id, callees) in &self.fn_calls {
215            if let Some(caller_node) = self.functions.get(caller_id) {
216                for (callee_id, terminator) in callees {
217                    if let Some(callee_node) = self.functions.get(callee_id) {
218                        let caller_def_path = caller_node.get_def_path();
219                        let callee_def_path = callee_node.get_def_path();
220                        if let Some(terminator_stmt) = terminator {
221                            rap_info!(
222                                "{}:{} -> {}:{} @ {:?}",
223                                caller_id,
224                                caller_def_path,
225                                *callee_id,
226                                callee_def_path,
227                                terminator_stmt.kind
228                            );
229                        } else {
230                            rap_info!(
231                                " (Virtual) {}:{} -> {}:{}",
232                                caller_id,
233                                caller_def_path,
234                                *callee_id,
235                                callee_def_path,
236                            );
237                        }
238                    }
239                }
240            }
241        }
242    }
243
244    pub fn get_reverse_post_order(&self) -> Vec<DefId> {
245        let mut visited = HashSet::new();
246        let mut post_order_ids = Vec::new(); for &node_id in self.functions.keys() {
250            if !visited.contains(&node_id) {
251                self.dfs_post_order(node_id, &mut visited, &mut post_order_ids);
252            }
253        }
254
255        let mut analysis_order: Vec<DefId> = post_order_ids
257            .into_iter()
258            .map(|id| {
259                self.functions
260                    .get(&id)
261                    .expect("Node ID must exist in functions map")
262                    .def_id
263            })
264            .collect();
265
266        analysis_order.reverse();
268
269        analysis_order
270    }
271
272    fn dfs_post_order(
274        &self,
275        node_id: usize,
276        visited: &mut HashSet<usize>,
277        post_order_ids: &mut Vec<usize>,
278    ) {
279        visited.insert(node_id);
281
282        if let Some(callees) = self.fn_calls.get(&node_id) {
284            for (callee_id, _terminator) in callees {
285                if !visited.contains(callee_id) {
286                    self.dfs_post_order(*callee_id, visited, post_order_ids);
287                }
288            }
289        }
290
291        post_order_ids.push(node_id);
293    }
294}