rapx/analysis/upg/mod.rs

1/*
2 * This module generates the unsafety propagation graph for each Rust module in the target crate.
3 */
4pub mod draw_dot;
5pub mod fn_collector;
6pub mod hir_visitor;
7pub mod std_upg;
8pub mod upg_graph;
9pub mod upg_unit;
10
11use crate::{
12    analysis::utils::{draw_dot::render_dot_graphs, fn_info::*},
13    utils::source::{get_fn_name_byid, get_module_name},
14};
15use fn_collector::FnCollector;
16use hir_visitor::ContainsUnsafe;
17use rustc_hir::{Safety, def_id::DefId};
18use rustc_middle::{mir::Local, ty::TyCtxt};
19use std::collections::{HashMap, HashSet};
20use upg_graph::{UPGEdge, UPGraph};
21use upg_unit::UPGUnit;
22
/// Selects which kind of crate [`UPGAnalysis::start`] analyses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetCrate {
    /// The Rust standard library; `start` delegates to `audit_std_unsafe`.
    Std,
    /// Any other crate; `start` builds UPG units and renders one graph per module.
    Other,
}
28
29pub struct UPGAnalysis<'tcx> {
30    pub tcx: TyCtxt<'tcx>,
31    pub upgs: Vec<UPGUnit>,
32}
33
34impl<'tcx> UPGAnalysis<'tcx> {
35    pub fn new(tcx: TyCtxt<'tcx>) -> Self {
36        Self {
37            tcx,
38            upgs: Vec::new(),
39        }
40    }
41
42    pub fn start(&mut self, ins: TargetCrate) {
43        match ins {
44            TargetCrate::Std => {
45                self.audit_std_unsafe();
46                return;
47            }
48            _ => {
49                /* Type of collected data: FxHashMap<Option<HirId>, Vec<(BodyId, Span)>>;
50                 * For a function, the Vec contains only one entry;
51                 * For implementations of structs and traits, the Vec contains all associated
52                 * function entries.
53                 */
54                let fns = FnCollector::collect(self.tcx);
55                for vec in fns.values() {
56                    for (body_id, _span) in vec {
57                        // each function or associated function in
58                        // structs and traits
59                        let (fn_unsafe, block_unsafe) =
60                            ContainsUnsafe::contains_unsafe(self.tcx, *body_id);
61                        // map the function body_id back to its def_id;
62                        let def_id = self.tcx.hir_body_owner_def_id(*body_id).to_def_id();
63                        if fn_unsafe | block_unsafe {
64                            self.insert_upg(def_id);
65                        }
66                    }
67                }
68                self.generate_graph_dots();
69            }
70        }
71    }
72
73    pub fn insert_upg(&mut self, def_id: DefId) {
74        let callees = get_unsafe_callees(self.tcx, def_id);
75        let raw_ptrs = get_rawptr_deref(self.tcx, def_id);
76        let global_locals = collect_global_local_pairs(self.tcx, def_id);
77        let static_muts: HashSet<DefId> = global_locals.keys().copied().collect();
78
79        /*Static mutable access is in nature via raw ptr; We have to prune the duplication.*/
80        let global_locals_set: HashSet<Local> = global_locals.values().flatten().copied().collect();
81        let raw_ptrs_filtered: HashSet<Local> =
82            raw_ptrs.difference(&global_locals_set).copied().collect();
83
84        let constructors = get_cons(self.tcx, def_id);
85        let caller_typed = append_fn_with_types(self.tcx, def_id);
86        let mut callees_typed = HashSet::new();
87        for callee in &callees {
88            callees_typed.insert(append_fn_with_types(self.tcx, *callee));
89        }
90        let mut cons_typed = HashSet::new();
91        for con in &constructors {
92            cons_typed.insert(append_fn_with_types(self.tcx, *con));
93        }
94
95        // Skip processing if the caller is the dummy raw pointer dereference function
96        let caller_name = get_fn_name_byid(&def_id);
97        if let Some(_) = caller_name.find("__raw_ptr_deref_dummy") {
98            return;
99        }
100
101        // If the function is entirely safe (no unsafe code, no unsafe callees, no raw pointer dereferences, and no static mutable accesses), skip further analysis
102        if check_safety(self.tcx, def_id) == Safety::Safe
103            && callees.is_empty()
104            && raw_ptrs.is_empty()
105            && static_muts.is_empty()
106        {
107            return;
108        }
109        let mut_methods_set = get_all_mutable_methods(self.tcx, def_id);
110        let mut_methods = mut_methods_set.keys().copied().collect();
111        let upg = UPGUnit::new(
112            caller_typed,
113            callees_typed,
114            raw_ptrs_filtered,
115            static_muts,
116            cons_typed,
117            mut_methods,
118        );
119        self.upgs.push(upg);
120    }
121
122    /// Main function to aggregate data and render DOT graphs per module.
123    pub fn generate_graph_dots(&self) {
124        let mut modules_data: HashMap<String, UPGraph> = HashMap::new();
125
126        let mut collect_unit = |unit: &UPGUnit| {
127            let caller_id = unit.caller.def_id;
128            let module_name = get_module_name(self.tcx, caller_id);
129            rap_info!("module name: {:?}", module_name);
130
131            let module_data = modules_data.entry(module_name).or_insert_with(UPGraph::new);
132
133            module_data.add_node(self.tcx, unit.caller, None);
134
135            if let Some(adt) = get_adt_via_method(self.tcx, caller_id) {
136                if adt.literal_cons_enabled {
137                    let adt_node_type = FnInfo::new(adt.def_id, Safety::Safe, FnKind::Constructor);
138                    let label = format!("Literal Constructor: {}", self.tcx.item_name(adt.def_id));
139                    module_data.add_node(self.tcx, adt_node_type, Some(label));
140                    if unit.caller.fn_kind == FnKind::Method {
141                        module_data.add_edge(adt.def_id, caller_id, UPGEdge::ConsToMethod);
142                    }
143                } else {
144                    let adt_node_type = FnInfo::new(adt.def_id, Safety::Safe, FnKind::Method);
145                    let label = format!(
146                        "MutMethod Introduced by PubFields: {}",
147                        self.tcx.item_name(adt.def_id)
148                    );
149                    module_data.add_node(self.tcx, adt_node_type, Some(label));
150                    if unit.caller.fn_kind == FnKind::Method {
151                        module_data.add_edge(adt.def_id, caller_id, UPGEdge::MutToCaller);
152                    }
153                }
154            }
155
156            // Edge from associated item (constructor) to the method.
157            for cons in &unit.caller_cons {
158                module_data.add_node(self.tcx, *cons, None);
159                module_data.add_edge(cons.def_id, unit.caller.def_id, UPGEdge::ConsToMethod);
160            }
161
162            // Edge from mutable access to the caller.
163            for mut_method_id in &unit.mut_methods {
164                let node_type = get_type(self.tcx, *mut_method_id);
165                let fn_safety = check_safety(self.tcx, *mut_method_id);
166                let node = FnInfo::new(*mut_method_id, fn_safety, node_type);
167
168                module_data.add_node(self.tcx, node, None);
169                module_data.add_edge(*mut_method_id, unit.caller.def_id, UPGEdge::MutToCaller);
170            }
171
172            // Edge representing a call from caller to callee.
173            for callee in &unit.callees {
174                module_data.add_node(self.tcx, *callee, None);
175                module_data.add_edge(unit.caller.def_id, callee.def_id, UPGEdge::CallerToCallee);
176            }
177
178            rap_debug!("raw ptrs: {:?}", unit.raw_ptrs);
179            if !unit.raw_ptrs.is_empty() {
180                let all_raw_ptrs = unit
181                    .raw_ptrs
182                    .iter()
183                    .map(|p| format!("{:?}", p))
184                    .collect::<Vec<_>>()
185                    .join(", ");
186
187                match get_ptr_deref_dummy_def_id(self.tcx) {
188                    Some(dummy_fn_def_id) => {
189                        let rawptr_deref_fn =
190                            FnInfo::new(dummy_fn_def_id, Safety::Unsafe, FnKind::Intrinsic);
191                        module_data.add_node(
192                            self.tcx,
193                            rawptr_deref_fn,
194                            Some(format!("Raw ptr deref: {}", all_raw_ptrs)),
195                        );
196                        module_data.add_edge(
197                            unit.caller.def_id,
198                            dummy_fn_def_id,
199                            UPGEdge::CallerToCallee,
200                        );
201                    }
202                    None => {
203                        rap_info!("fail to find the dummy ptr deref id.");
204                    }
205                }
206            }
207
208            rap_debug!("static muts: {:?}", unit.static_muts);
209            for def_id in &unit.static_muts {
210                let node = FnInfo::new(*def_id, Safety::Unsafe, FnKind::Intrinsic);
211                module_data.add_node(self.tcx, node, None);
212                module_data.add_edge(unit.caller.def_id, *def_id, UPGEdge::CallerToCallee);
213            }
214        };
215
216        // Aggregate all Units
217        for upg in &self.upgs {
218            collect_unit(upg);
219        }
220
221        // Generate string of dot
222        let mut final_dots = Vec::new();
223        for (mod_name, data) in modules_data {
224            let dot = data.upg_unit_string(&mod_name);
225            final_dots.push((mod_name, dot));
226        }
227        rap_info!("{:?}", final_dots); // Output required for tests; do not change.
228        render_dot_graphs(final_dots);
229    }
230}