1pub mod draw_dot;
5pub mod fn_collector;
6pub mod hir_visitor;
7pub mod std_upg;
8pub mod upg_graph;
9pub mod upg_unit;
10
11use crate::{
12 analysis::utils::{draw_dot::render_dot_graphs, fn_info::*},
13 utils::source::{get_fn_name_byid, get_module_name},
14};
15use fn_collector::FnCollector;
16use hir_visitor::ContainsUnsafe;
17use rustc_hir::{Safety, def_id::DefId};
18use rustc_middle::{mir::Local, ty::TyCtxt};
19use std::collections::{HashMap, HashSet};
20use upg_graph::{UPGEdge, UPGraph};
21use upg_unit::UPGUnit;
22
/// Selects which analysis `UPGAnalysis::start` runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TargetCrate {
    /// Run the standard-library audit (`audit_std_unsafe`).
    Std,
    /// Build UPG units for every body gathered by `FnCollector` that is an
    /// unsafe function or contains an unsafe block, then render the graphs.
    Other,
}
28
29pub struct UPGAnalysis<'tcx> {
30 pub tcx: TyCtxt<'tcx>,
31 pub upgs: Vec<UPGUnit>,
32}
33
34impl<'tcx> UPGAnalysis<'tcx> {
35 pub fn new(tcx: TyCtxt<'tcx>) -> Self {
36 Self {
37 tcx,
38 upgs: Vec::new(),
39 }
40 }
41
42 pub fn start(&mut self, ins: TargetCrate) {
43 match ins {
44 TargetCrate::Std => {
45 self.audit_std_unsafe();
46 return;
47 }
48 _ => {
49 let fns = FnCollector::collect(self.tcx);
55 for vec in fns.values() {
56 for (body_id, _span) in vec {
57 let (fn_unsafe, block_unsafe) =
60 ContainsUnsafe::contains_unsafe(self.tcx, *body_id);
61 let def_id = self.tcx.hir_body_owner_def_id(*body_id).to_def_id();
63 if fn_unsafe | block_unsafe {
64 self.insert_upg(def_id);
65 }
66 }
67 }
68 self.generate_graph_dots();
69 }
70 }
71 }
72
73 pub fn insert_upg(&mut self, def_id: DefId) {
74 let callees = get_unsafe_callees(self.tcx, def_id);
75 let raw_ptrs = get_rawptr_deref(self.tcx, def_id);
76 let global_locals = collect_global_local_pairs(self.tcx, def_id);
77 let static_muts: HashSet<DefId> = global_locals.keys().copied().collect();
78
79 let global_locals_set: HashSet<Local> = global_locals.values().flatten().copied().collect();
81 let raw_ptrs_filtered: HashSet<Local> =
82 raw_ptrs.difference(&global_locals_set).copied().collect();
83
84 let constructors = get_cons(self.tcx, def_id);
85 let caller_typed = append_fn_with_types(self.tcx, def_id);
86 let mut callees_typed = HashSet::new();
87 for callee in &callees {
88 callees_typed.insert(append_fn_with_types(self.tcx, *callee));
89 }
90 let mut cons_typed = HashSet::new();
91 for con in &constructors {
92 cons_typed.insert(append_fn_with_types(self.tcx, *con));
93 }
94
95 let caller_name = get_fn_name_byid(&def_id);
97 if let Some(_) = caller_name.find("__raw_ptr_deref_dummy") {
98 return;
99 }
100
101 if check_safety(self.tcx, def_id) == Safety::Safe
103 && callees.is_empty()
104 && raw_ptrs.is_empty()
105 && static_muts.is_empty()
106 {
107 return;
108 }
109 let mut_methods_set = get_all_mutable_methods(self.tcx, def_id);
110 let mut_methods = mut_methods_set.keys().copied().collect();
111 let upg = UPGUnit::new(
112 caller_typed,
113 callees_typed,
114 raw_ptrs_filtered,
115 static_muts,
116 cons_typed,
117 mut_methods,
118 );
119 self.upgs.push(upg);
120 }
121
122 pub fn generate_graph_dots(&self) {
124 let mut modules_data: HashMap<String, UPGraph> = HashMap::new();
125
126 let mut collect_unit = |unit: &UPGUnit| {
127 let caller_id = unit.caller.def_id;
128 let module_name = get_module_name(self.tcx, caller_id);
129 rap_info!("module name: {:?}", module_name);
130
131 let module_data = modules_data.entry(module_name).or_insert_with(UPGraph::new);
132
133 module_data.add_node(self.tcx, unit.caller, None);
134
135 if let Some(adt) = get_adt_via_method(self.tcx, caller_id) {
136 if adt.literal_cons_enabled {
137 let adt_node_type = FnInfo::new(adt.def_id, Safety::Safe, FnKind::Constructor);
138 let label = format!("Literal Constructor: {}", self.tcx.item_name(adt.def_id));
139 module_data.add_node(self.tcx, adt_node_type, Some(label));
140 if unit.caller.fn_kind == FnKind::Method {
141 module_data.add_edge(adt.def_id, caller_id, UPGEdge::ConsToMethod);
142 }
143 } else {
144 let adt_node_type = FnInfo::new(adt.def_id, Safety::Safe, FnKind::Method);
145 let label = format!(
146 "MutMethod Introduced by PubFields: {}",
147 self.tcx.item_name(adt.def_id)
148 );
149 module_data.add_node(self.tcx, adt_node_type, Some(label));
150 if unit.caller.fn_kind == FnKind::Method {
151 module_data.add_edge(adt.def_id, caller_id, UPGEdge::MutToCaller);
152 }
153 }
154 }
155
156 for cons in &unit.caller_cons {
158 module_data.add_node(self.tcx, *cons, None);
159 module_data.add_edge(cons.def_id, unit.caller.def_id, UPGEdge::ConsToMethod);
160 }
161
162 for mut_method_id in &unit.mut_methods {
164 let node_type = get_type(self.tcx, *mut_method_id);
165 let fn_safety = check_safety(self.tcx, *mut_method_id);
166 let node = FnInfo::new(*mut_method_id, fn_safety, node_type);
167
168 module_data.add_node(self.tcx, node, None);
169 module_data.add_edge(*mut_method_id, unit.caller.def_id, UPGEdge::MutToCaller);
170 }
171
172 for callee in &unit.callees {
174 module_data.add_node(self.tcx, *callee, None);
175 module_data.add_edge(unit.caller.def_id, callee.def_id, UPGEdge::CallerToCallee);
176 }
177
178 rap_debug!("raw ptrs: {:?}", unit.raw_ptrs);
179 if !unit.raw_ptrs.is_empty() {
180 let all_raw_ptrs = unit
181 .raw_ptrs
182 .iter()
183 .map(|p| format!("{:?}", p))
184 .collect::<Vec<_>>()
185 .join(", ");
186
187 match get_ptr_deref_dummy_def_id(self.tcx) {
188 Some(dummy_fn_def_id) => {
189 let rawptr_deref_fn =
190 FnInfo::new(dummy_fn_def_id, Safety::Unsafe, FnKind::Intrinsic);
191 module_data.add_node(
192 self.tcx,
193 rawptr_deref_fn,
194 Some(format!("Raw ptr deref: {}", all_raw_ptrs)),
195 );
196 module_data.add_edge(
197 unit.caller.def_id,
198 dummy_fn_def_id,
199 UPGEdge::CallerToCallee,
200 );
201 }
202 None => {
203 rap_info!("fail to find the dummy ptr deref id.");
204 }
205 }
206 }
207
208 rap_debug!("static muts: {:?}", unit.static_muts);
209 for def_id in &unit.static_muts {
210 let node = FnInfo::new(*def_id, Safety::Unsafe, FnKind::Intrinsic);
211 module_data.add_node(self.tcx, node, None);
212 module_data.add_edge(unit.caller.def_id, *def_id, UPGEdge::CallerToCallee);
213 }
214 };
215
216 for upg in &self.upgs {
218 collect_unit(upg);
219 }
220
221 let mut final_dots = Vec::new();
223 for (mod_name, data) in modules_data {
224 let dot = data.upg_unit_string(&mod_name);
225 final_dots.push((mod_name, dot));
226 }
227 rap_info!("{:?}", final_dots); render_dot_graphs(final_dots);
229 }
230}