rapx/analysis/core/callgraph/
visitor.rs

1use super::default::CallGraph;
2use rustc_hir::def_id::DefId;
3use rustc_middle::mir;
4use rustc_middle::ty::{FnDef, Instance, InstanceKind, TyCtxt, TypingEnv};
5use std::collections::HashSet;
6
7pub struct CallGraphVisitor<'b, 'tcx> {
8    tcx: TyCtxt<'tcx>,
9    def_id: DefId,
10    body: &'tcx mir::Body<'tcx>,
11    call_graph_info: &'b mut CallGraph<'tcx>,
12}
13
14impl<'b, 'tcx> CallGraphVisitor<'b, 'tcx> {
15    pub fn new(
16        tcx: TyCtxt<'tcx>,
17        def_id: DefId,
18        body: &'tcx mir::Body<'tcx>,
19        call_graph_info: &'b mut CallGraph<'tcx>,
20    ) -> Self {
21        Self {
22            tcx: tcx,
23            def_id: def_id,
24            body: body,
25            call_graph_info: call_graph_info,
26        }
27    }
28
29    fn add_fn_call(&mut self, callee_def_id: DefId, terminator: &'tcx mir::Terminator<'tcx>) {
30        self.call_graph_info.register_fn(callee_def_id);
31        self.call_graph_info.add_funciton_call(
32            self.def_id.clone(),
33            callee_def_id,
34            Some(terminator),
35        );
36    }
37
38    fn handle_fn_call(
39        &mut self,
40        callee_def_id: DefId,
41        is_virtual: bool,
42        terminator: &'tcx mir::Terminator<'tcx>,
43    ) {
44        if is_virtual {
45            // Handle dynamic dispatch for trait objects
46            self.handle_virtual_call(callee_def_id, terminator);
47        } else {
48            self.add_fn_call(callee_def_id, terminator);
49        }
50    }
51
52    fn handle_virtual_call(
53        &mut self,
54        stub_def_id: DefId, // Callee is the dynamic call stub, i.e. the fn definition in trait
55        terminator: &'tcx mir::Terminator<'tcx>,
56    ) {
57        // Step 1: Add an edge from caller to the virtual function (stub);
58        // If the DefId exists, we assume that stub has been analyzed.
59        let visited = !self.call_graph_info.register_fn(stub_def_id);
60        self.add_fn_call(stub_def_id, terminator);
61
62        // If this function has already been analyzed, return;
63        if visited {
64            return;
65        }
66
67        // Step 2: Find all impls of the virtual function;
68        let mut candidates: HashSet<DefId> = HashSet::new();
69        if let Some(trait_def_id) = self.tcx.trait_of_assoc(stub_def_id) {
70            rap_debug!(
71                "[Callgraph] Virtual fn {:?} belongs to trait {:?}",
72                stub_def_id,
73                trait_def_id
74            );
75            for impl_id in self.tcx.all_impls(trait_def_id) {
76                let impl_map = self.tcx.impl_item_implementor_ids(impl_id);
77                if let Some(candidate_def_id) = impl_map.get(&stub_def_id) {
78                    candidates.insert(*candidate_def_id);
79                }
80            }
81        }
82        rap_debug!(
83            "[Callgraph] Implementors of {:?}: {:?}",
84            stub_def_id,
85            candidates
86        );
87
88        // Step 3: For each implementor, add an edge from the stub to it.
89        for candidate_def_id in candidates {
90            self.add_fn_call(candidate_def_id, terminator);
91        }
92    }
93
94    pub fn visit(&mut self) {
95        self.call_graph_info.register_fn(self.def_id);
96        for (_, data) in self.body.basic_blocks.iter().enumerate() {
97            let terminator = data.terminator();
98            self.visit_terminator(&terminator);
99        }
100    }
101
102    fn visit_terminator(&mut self, terminator: &'tcx mir::Terminator<'tcx>) {
103        if let mir::TerminatorKind::Call { func, .. } = &terminator.kind {
104            if let mir::Operand::Constant(constant) = func {
105                if let FnDef(callee_def_id, callee_substs) = constant.const_.ty().kind() {
106                    let ty_env = TypingEnv::post_analysis(self.tcx, self.def_id);
107                    if let Ok(Some(instance)) =
108                        Instance::try_resolve(self.tcx, ty_env, *callee_def_id, callee_substs)
109                    {
110                        let mut is_virtual = false;
111                        // Try to analysis the specific type of callee.
112                        let instance_def_id = match instance.def {
113                            InstanceKind::Item(def_id) => Some(def_id),
114                            InstanceKind::Intrinsic(def_id) => Some(def_id),
115                            InstanceKind::VTableShim(def_id) => Some(def_id),
116                            InstanceKind::ReifyShim(def_id, _) => Some(def_id),
117                            InstanceKind::FnPtrShim(def_id, _) => Some(def_id),
118                            InstanceKind::Virtual(def_id, _) => {
119                                is_virtual = true;
120                                Some(def_id)
121                            }
122                            InstanceKind::ClosureOnceShim { call_once, .. } => Some(call_once),
123                            InstanceKind::ConstructCoroutineInClosureShim {
124                                coroutine_closure_def_id,
125                                ..
126                            } => Some(coroutine_closure_def_id),
127                            InstanceKind::ThreadLocalShim(def_id) => Some(def_id),
128                            InstanceKind::DropGlue(def_id, _) => Some(def_id),
129                            InstanceKind::FnPtrAddrShim(def_id, _) => Some(def_id),
130                            InstanceKind::AsyncDropGlueCtorShim(def_id, _) => Some(def_id),
131                            InstanceKind::CloneShim(def_id, _) => {
132                                if !self.tcx.is_closure_like(def_id) {
133                                    // Not a closure
134                                    Some(def_id)
135                                } else {
136                                    None
137                                }
138                            }
139                            _ => todo!(),
140                        };
141                        if let Some(instance_def_id) = instance_def_id {
142                            self.handle_fn_call(instance_def_id, is_virtual, terminator);
143                        }
144                    } else {
145                        // Although failing to get specific type, callee is still useful.
146                        self.handle_fn_call(*callee_def_id, false, terminator);
147                    }
148                }
149            }
150        }
151    }
152}