1use crate::analysis::senryx::matcher::parse_unsafe_api;
2use crate::analysis::unsafety_isolation::generate_dot::NodeType;
3use rustc_hir::def::DefKind;
4use rustc_hir::def_id::DefId;
5use rustc_hir::ImplItemKind;
6use rustc_middle::mir::Local;
7use rustc_middle::mir::{BasicBlock, Terminator};
8use rustc_middle::ty::{AssocKind, Mutability, Ty, TyCtxt, TyKind};
9use rustc_middle::{
10 mir::{Operand, TerminatorKind},
11 ty,
12};
13use rustc_span::def_id::LocalDefId;
14use rustc_span::kw;
15use rustc_span::sym;
16use std::collections::HashMap;
17use std::collections::HashSet;
18use std::fmt::Debug;
19use std::hash::Hash;
20
21pub fn generate_node_ty(tcx: TyCtxt<'_>, def_id: DefId) -> NodeType {
22 (def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
23}
24
25pub fn check_visibility(tcx: TyCtxt<'_>, func_defid: DefId) -> bool {
26 if !tcx.visibility(func_defid).is_public() {
27 return false;
28 }
29 true
40}
41
42pub fn is_re_exported(tcx: TyCtxt<'_>, target_defid: DefId, module_defid: LocalDefId) -> bool {
43 for child in tcx.module_children_local(module_defid) {
44 if child.vis.is_public() {
45 if let Some(def_id) = child.res.opt_def_id() {
46 if def_id == target_defid {
47 return true;
48 }
49 }
50 }
51 }
52 false
53}
54
/// Prints every element of `set` on its own line using `Debug` formatting,
/// followed by a dashed separator line. Element order follows the set's
/// iteration order.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|entry| println!("{:?}", entry));
    println!("---------------");
}
61
62pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
63 let def_id_str = format!("{:?}", def_id);
64 let mut parts: Vec<&str> = def_id_str
65 .split("::")
66 .collect();
68
69 let mut remove_first = false;
70 if let Some(first_part) = parts.get_mut(0) {
71 if first_part.contains("core") {
72 *first_part = "core";
73 } else if first_part.contains("std") {
74 *first_part = "std";
75 } else if first_part.contains("alloc") {
76 *first_part = "alloc";
77 } else {
78 remove_first = true;
79 }
80 }
81 if remove_first && !parts.is_empty() {
82 parts.remove(0);
83 }
84
85 let new_parts: Vec<String> = parts
86 .into_iter()
87 .filter_map(|s| {
88 if s.contains("{") {
89 if remove_first {
90 get_struct_name(tcx, def_id)
91 } else {
92 None
93 }
94 } else {
95 Some(s.to_string())
96 }
97 })
98 .collect();
99
100 let mut cleaned_path = new_parts.join("::");
101 cleaned_path = cleaned_path.trim_end_matches(')').to_string();
102 cleaned_path
103}
104
105pub fn get_sp_json() -> serde_json::Value {
106 let json_data: serde_json::Value =
107 serde_json::from_str(include_str!("../unsafety_isolation/data/std_sps.json"))
108 .expect("Unable to parse JSON");
109 json_data
110}
111
112pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
113 let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
114 let json_data: serde_json::Value = get_sp_json();
115
116 if let Some(function_info) = json_data.get(&cleaned_path_name) {
117 if let Some(sp_list) = function_info.get("0") {
118 let mut result = HashSet::new();
119 if let Some(sp_array) = sp_list.as_array() {
120 for sp in sp_array {
121 if let Some(sp_name) = sp.as_str() {
122 result.insert(sp_name.to_string());
123 }
124 }
125 }
126 return result;
127 }
128 }
129 HashSet::new()
130}
131
132pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
133 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
134 if let Some(impl_id) = assoc_item.impl_container(tcx) {
135 let ty = tcx.type_of(impl_id).skip_binder();
136 let type_name = ty.to_string();
137 let struct_name = type_name
138 .split('<')
139 .next()
140 .unwrap_or("")
141 .split("::")
142 .last()
143 .unwrap_or("")
144 .to_string();
145
146 return Some(struct_name);
147 }
148 }
149 None
150}
151
152pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
153 let poly_fn_sig = tcx.fn_sig(def_id);
154 let fn_sig = poly_fn_sig.skip_binder();
155 fn_sig.safety() == rustc_hir::Safety::Unsafe
156}
157
158pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> usize {
160 let mut node_type = 2;
161 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
162 match assoc_item.kind {
163 AssocKind::Fn { has_self, .. } => {
164 if has_self {
165 node_type = 1;
166 } else {
167 let fn_sig = tcx.fn_sig(def_id).skip_binder();
168 let output = fn_sig.output().skip_binder();
169 if output.is_param(0) {
171 node_type = 0;
172 }
173 if let Some(impl_id) = assoc_item.impl_container(tcx) {
175 let ty = tcx.type_of(impl_id).skip_binder();
176 if output == ty {
177 node_type = 0;
178 }
179 }
180 match output.kind() {
181 TyKind::Ref(_, ref_ty, _) => {
182 if ref_ty.is_param(0) {
183 node_type = 0;
184 }
185 if let Some(impl_id) = assoc_item.impl_container(tcx) {
186 let ty = tcx.type_of(impl_id).skip_binder();
187 if *ref_ty == ty {
188 node_type = 0;
189 }
190 }
191 }
192 TyKind::Adt(adt_def, substs) => {
193 if adt_def.is_enum()
194 && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
195 || tcx.is_diagnostic_item(sym::Result, adt_def.did())
196 || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
197 {
198 let inner_ty = substs.type_at(0);
199 if inner_ty.is_param(0) {
200 node_type = 0;
201 }
202 if let Some(impl_id) = assoc_item.impl_container(tcx) {
203 let ty_impl = tcx.type_of(impl_id).skip_binder();
204 if inner_ty == ty_impl {
205 node_type = 0;
206 }
207 }
208 }
209 }
210 _ => {}
211 }
212 }
213 }
214 _ => todo!(),
215 }
216 }
217 node_type
218}
219
220pub fn get_adt_ty(tcx: TyCtxt<'_>, def_id: DefId) -> Option<Ty> {
221 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
222 if let Some(impl_id) = assoc_item.impl_container(tcx) {
223 return Some(tcx.type_of(impl_id).skip_binder());
224 }
225 }
226 None
227}
228
229pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<NodeType> {
230 let mut cons = Vec::new();
231 if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == 0 {
232 return cons;
233 }
234 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
235 if let Some(impl_id) = assoc_item.impl_container(tcx) {
236 let ty = tcx.type_of(impl_id).skip_binder();
238 if let Some(adt_def) = ty.ty_adt_def() {
239 let adt_def_id = adt_def.did();
240 let impls = tcx.inherent_impls(adt_def_id);
241 for impl_def_id in impls {
242 for item in tcx.associated_item_def_ids(impl_def_id) {
243 if (tcx.def_kind(item) == DefKind::Fn
244 || tcx.def_kind(item) == DefKind::AssocFn)
245 && get_type(tcx, *item) == 0
246 {
247 cons.push(generate_node_ty(tcx, *item));
248 }
249 }
250 }
251 }
252 }
253 }
254 cons
255}
256
257pub fn get_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
258 let mut callees = HashSet::new();
259 if tcx.is_mir_available(def_id) {
260 let body = tcx.optimized_mir(def_id);
261 for bb in body.basic_blocks.iter() {
262 if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
263 if let Operand::Constant(func_constant) = func {
264 if let ty::FnDef(ref callee_def_id, _) = func_constant.const_.ty().kind() {
265 if check_safety(tcx, *callee_def_id)
266 {
268 let sp_set = get_sp(tcx, *callee_def_id);
269 if sp_set.len() != 0 {
270 callees.insert(*callee_def_id);
271 }
272 }
273 }
274 }
275 }
276 }
277 }
278 callees
279}
280
281pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
283 let mut impls = Vec::new();
284 for impl_item_id in tcx.hir_crate_items(()).impl_items() {
285 let impl_item = tcx.hir_impl_item(impl_item_id);
286 match impl_item.kind {
287 ImplItemKind::Type(ty) => {
288 if let rustc_hir::TyKind::Path(ref qpath) = ty.kind {
289 if let rustc_hir::QPath::Resolved(_, path) = qpath {
290 if let rustc_hir::def::Res::Def(_, ref def_id) = path.res {
291 if *def_id == struct_def_id {
292 impls.push(impl_item.owner_id.to_def_id());
293 }
294 }
295 }
296 }
297 }
298 _ => (),
299 }
300 }
301 impls
302}
303
304pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
305 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
306 if let Some(impl_id) = assoc_item.impl_container(tcx) {
307 let ty = tcx.type_of(impl_id).skip_binder();
309 if let Some(adt_def) = ty.ty_adt_def() {
310 return Some(adt_def.did());
311 }
312 }
313 }
314 None
315}
316
317pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
319 let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
321 get_pointee(*ty_mut)
322 } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
323 get_pointee(*referred_ty)
324 } else {
325 matched_ty
326 };
327 pointee
328}
329
330pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
331 if let ty::RawPtr(_, _) = matched_ty.kind() {
332 return true;
333 }
334 false
335}
336
337pub fn is_ref(matched_ty: Ty<'_>) -> bool {
338 if let ty::Ref(_, _, _) = matched_ty.kind() {
339 return true;
340 }
341 false
342}
343
344pub fn has_mut_self_param(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
345 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
346 match assoc_item.kind {
347 AssocKind::Fn { has_self, .. } => {
348 if has_self {
349 let body = tcx.optimized_mir(def_id);
350 let fst_arg = body.local_decls[Local::from_usize(1)].clone();
351 let ty = fst_arg.ty;
352 let is_mut_ref =
353 matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
354 return fst_arg.mutability.is_mut() || is_mut_ref;
355 }
356 }
357 _ => (),
358 }
359 }
360 false
361}
362
363pub fn get_all_mutable_methods(tcx: TyCtxt<'_>, def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
366 let mut results = HashMap::new();
367 let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
368 let public_fields = adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
369 let impl_vec = adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
370 for impl_id in impl_vec {
371 let associated_items = tcx.associated_items(impl_id);
372 for item in associated_items.in_definition_order() {
373 if let AssocKind::Fn {
374 name: _,
375 has_self: _,
376 } = item.kind
377 {
378 let item_def_id = item.def_id;
379 if has_mut_self_param(tcx, item_def_id) {
380 let modified_fields = public_fields.clone();
382 results.insert(item_def_id, modified_fields);
383 }
384 }
385 }
386 }
387 results
388}
389
390pub fn get_public_fields(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<usize> {
392 let adt_def = tcx.adt_def(def_id);
393 adt_def
394 .all_fields()
395 .enumerate()
396 .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
397 .collect()
398}
399
/// Prints each `key: value` pair of `map` (both `Debug`-formatted) on its
/// own line, in ascending key order, prefixed by `level` spaces.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let padding = " ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|left, right| left.0.cmp(right.0));
    for (key, value) in entries {
        println!("{}{:?}: {:?}", padding, key, value);
    }
}
416
417pub fn get_all_std_unsafe_callees(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<String> {
418 let mut results = Vec::new();
419 let body = tcx.optimized_mir(def_id);
420 let bb_len = body.basic_blocks.len();
421 for i in 0..bb_len {
422 let callees = match_std_unsafe_callee(
423 tcx,
424 body.basic_blocks[BasicBlock::from_usize(i)]
425 .clone()
426 .terminator(),
427 );
428 results.extend(callees);
429 }
430 results
431}
432
433pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<usize> {
434 let mut results = Vec::new();
435 let body = tcx.optimized_mir(def_id);
436 let bb_len = body.basic_blocks.len();
437 for i in 0..bb_len {
438 if match_std_unsafe_callee(
439 tcx,
440 body.basic_blocks[BasicBlock::from_usize(i)]
441 .clone()
442 .terminator(),
443 )
444 .is_empty()
445 {
446 results.push(i);
447 }
448 }
449 results
450}
451
452pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
453 let mut results = Vec::new();
454 if let TerminatorKind::Call { func, .. } = &terminator.kind {
455 if let Operand::Constant(func_constant) = func {
456 if let ty::FnDef(ref callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
457 let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
458 if parse_unsafe_api(&func_name).is_some() {
459 results.push(func_name);
460 }
461 }
462 }
463 }
464 results
465}
466
467pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
470 (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
471}
472
473pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
475 let ty = get_pointee(ori_ty);
476 let mut flag = false;
477 if let TyKind::Adt(adt_def, substs) = ty.kind() {
478 if adt_def.is_struct() {
479 for field_def in adt_def.all_fields() {
480 flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
481 }
482 }
483 }
484 ty.is_bool() || ty.is_str() || flag
485}