rapx/analysis/utils/
fn_info.rs

1use crate::analysis::senryx::matcher::parse_unsafe_api;
2use crate::analysis::unsafety_isolation::generate_dot::NodeType;
3use rustc_hir::def::DefKind;
4use rustc_hir::def_id::DefId;
5use rustc_hir::ImplItemKind;
6use rustc_middle::mir::Local;
7use rustc_middle::mir::{BasicBlock, Terminator};
8use rustc_middle::ty::{AssocKind, Mutability, Ty, TyCtxt, TyKind};
9use rustc_middle::{
10    mir::{Operand, TerminatorKind},
11    ty,
12};
13use rustc_span::def_id::LocalDefId;
14use rustc_span::kw;
15use rustc_span::sym;
16use std::collections::HashMap;
17use std::collections::HashSet;
18use std::fmt::Debug;
19use std::hash::Hash;
20
21pub fn generate_node_ty(tcx: TyCtxt<'_>, def_id: DefId) -> NodeType {
22    (def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
23}
24
25pub fn check_visibility(tcx: TyCtxt<'_>, func_defid: DefId) -> bool {
26    if !tcx.visibility(func_defid).is_public() {
27        return false;
28    }
29    // if func_defid.is_local() {
30    //     if let Some(local_defid) = func_defid.as_local() {
31    //         let module_moddefid = tcx.parent_module_from_def_id(local_defid);
32    //         let module_defid = module_moddefid.to_def_id();
33    //         if !tcx.visibility(module_defid).is_public() {
34    //             // println!("module def id {:?}",UigUnit::get_cleaned_def_path_name(tcx, module_defid));
35    //             return Self::is_re_exported(tcx, func_defid, module_moddefid.to_local_def_id());
36    //         }
37    //     }
38    // }
39    true
40}
41
42pub fn is_re_exported(tcx: TyCtxt<'_>, target_defid: DefId, module_defid: LocalDefId) -> bool {
43    for child in tcx.module_children_local(module_defid) {
44        if child.vis.is_public() {
45            if let Some(def_id) = child.res.opt_def_id() {
46                if def_id == target_defid {
47                    return true;
48                }
49            }
50        }
51    }
52    false
53}
54
/// Prints every element of `set` on its own line using `Debug` formatting,
/// followed by a dashed separator line.
///
/// The elements appear in the set's iteration order, which is unspecified.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|element| println!("{:?}", element));
    println!("---------------");
}
61
62pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
63    let def_id_str = format!("{:?}", def_id);
64    let mut parts: Vec<&str> = def_id_str
65        .split("::")
66        // .filter(|part| !part.contains("{")) // 去除包含 "{" 的部分
67        .collect();
68
69    let mut remove_first = false;
70    if let Some(first_part) = parts.get_mut(0) {
71        if first_part.contains("core") {
72            *first_part = "core";
73        } else if first_part.contains("std") {
74            *first_part = "std";
75        } else if first_part.contains("alloc") {
76            *first_part = "alloc";
77        } else {
78            remove_first = true;
79        }
80    }
81    if remove_first && !parts.is_empty() {
82        parts.remove(0);
83    }
84
85    let new_parts: Vec<String> = parts
86        .into_iter()
87        .filter_map(|s| {
88            if s.contains("{") {
89                if remove_first {
90                    get_struct_name(tcx, def_id)
91                } else {
92                    None
93                }
94            } else {
95                Some(s.to_string())
96            }
97        })
98        .collect();
99
100    let mut cleaned_path = new_parts.join("::");
101    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
102    cleaned_path
103}
104
105pub fn get_sp_json() -> serde_json::Value {
106    let json_data: serde_json::Value =
107        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sps.json"))
108            .expect("Unable to parse JSON");
109    json_data
110}
111
112pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
113    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
114    let json_data: serde_json::Value = get_sp_json();
115
116    if let Some(function_info) = json_data.get(&cleaned_path_name) {
117        if let Some(sp_list) = function_info.get("0") {
118            let mut result = HashSet::new();
119            if let Some(sp_array) = sp_list.as_array() {
120                for sp in sp_array {
121                    if let Some(sp_name) = sp.as_str() {
122                        result.insert(sp_name.to_string());
123                    }
124                }
125            }
126            return result;
127        }
128    }
129    HashSet::new()
130}
131
132pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
133    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
134        if let Some(impl_id) = assoc_item.impl_container(tcx) {
135            let ty = tcx.type_of(impl_id).skip_binder();
136            let type_name = ty.to_string();
137            let struct_name = type_name
138                .split('<')
139                .next()
140                .unwrap_or("")
141                .split("::")
142                .last()
143                .unwrap_or("")
144                .to_string();
145
146            return Some(struct_name);
147        }
148    }
149    None
150}
151
152pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
153    let poly_fn_sig = tcx.fn_sig(def_id);
154    let fn_sig = poly_fn_sig.skip_binder();
155    fn_sig.safety() == rustc_hir::Safety::Unsafe
156}
157
158//retval: 0-constructor, 1-method, 2-function
159pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> usize {
160    let mut node_type = 2;
161    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
162        match assoc_item.kind {
163            AssocKind::Fn { has_self, .. } => {
164                if has_self {
165                    node_type = 1;
166                } else {
167                    let fn_sig = tcx.fn_sig(def_id).skip_binder();
168                    let output = fn_sig.output().skip_binder();
169                    // return type is 'Self'
170                    if output.is_param(0) {
171                        node_type = 0;
172                    }
173                    // return type is struct's name
174                    if let Some(impl_id) = assoc_item.impl_container(tcx) {
175                        let ty = tcx.type_of(impl_id).skip_binder();
176                        if output == ty {
177                            node_type = 0;
178                        }
179                    }
180                    match output.kind() {
181                        TyKind::Ref(_, ref_ty, _) => {
182                            if ref_ty.is_param(0) {
183                                node_type = 0;
184                            }
185                            if let Some(impl_id) = assoc_item.impl_container(tcx) {
186                                let ty = tcx.type_of(impl_id).skip_binder();
187                                if *ref_ty == ty {
188                                    node_type = 0;
189                                }
190                            }
191                        }
192                        TyKind::Adt(adt_def, substs) => {
193                            if adt_def.is_enum()
194                                && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
195                                    || tcx.is_diagnostic_item(sym::Result, adt_def.did())
196                                    || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
197                            {
198                                let inner_ty = substs.type_at(0);
199                                if inner_ty.is_param(0) {
200                                    node_type = 0;
201                                }
202                                if let Some(impl_id) = assoc_item.impl_container(tcx) {
203                                    let ty_impl = tcx.type_of(impl_id).skip_binder();
204                                    if inner_ty == ty_impl {
205                                        node_type = 0;
206                                    }
207                                }
208                            }
209                        }
210                        _ => {}
211                    }
212                }
213            }
214            _ => todo!(),
215        }
216    }
217    node_type
218}
219
220pub fn get_adt_ty(tcx: TyCtxt<'_>, def_id: DefId) -> Option<Ty> {
221    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
222        if let Some(impl_id) = assoc_item.impl_container(tcx) {
223            return Some(tcx.type_of(impl_id).skip_binder());
224        }
225    }
226    None
227}
228
229pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<NodeType> {
230    let mut cons = Vec::new();
231    if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == 0 {
232        return cons;
233    }
234    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
235        if let Some(impl_id) = assoc_item.impl_container(tcx) {
236            // get struct ty
237            let ty = tcx.type_of(impl_id).skip_binder();
238            if let Some(adt_def) = ty.ty_adt_def() {
239                let adt_def_id = adt_def.did();
240                let impls = tcx.inherent_impls(adt_def_id);
241                for impl_def_id in impls {
242                    for item in tcx.associated_item_def_ids(impl_def_id) {
243                        if (tcx.def_kind(item) == DefKind::Fn
244                            || tcx.def_kind(item) == DefKind::AssocFn)
245                            && get_type(tcx, *item) == 0
246                        {
247                            cons.push(generate_node_ty(tcx, *item));
248                        }
249                    }
250                }
251            }
252        }
253    }
254    cons
255}
256
257pub fn get_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
258    let mut callees = HashSet::new();
259    if tcx.is_mir_available(def_id) {
260        let body = tcx.optimized_mir(def_id);
261        for bb in body.basic_blocks.iter() {
262            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
263                if let Operand::Constant(func_constant) = func {
264                    if let ty::FnDef(ref callee_def_id, _) = func_constant.const_.ty().kind() {
265                        if check_safety(tcx, *callee_def_id)
266                        // && check_visibility(tcx, *callee_def_id)
267                        {
268                            let sp_set = get_sp(tcx, *callee_def_id);
269                            if sp_set.len() != 0 {
270                                callees.insert(*callee_def_id);
271                            }
272                        }
273                    }
274                }
275            }
276        }
277    }
278    callees
279}
280
281// return all the impls def id of corresponding struct
282pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
283    let mut impls = Vec::new();
284    for impl_item_id in tcx.hir_crate_items(()).impl_items() {
285        let impl_item = tcx.hir_impl_item(impl_item_id);
286        match impl_item.kind {
287            ImplItemKind::Type(ty) => {
288                if let rustc_hir::TyKind::Path(ref qpath) = ty.kind {
289                    if let rustc_hir::QPath::Resolved(_, path) = qpath {
290                        if let rustc_hir::def::Res::Def(_, ref def_id) = path.res {
291                            if *def_id == struct_def_id {
292                                impls.push(impl_item.owner_id.to_def_id());
293                            }
294                        }
295                    }
296                }
297            }
298            _ => (),
299        }
300    }
301    impls
302}
303
304pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
305    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
306        if let Some(impl_id) = assoc_item.impl_container(tcx) {
307            // get struct ty
308            let ty = tcx.type_of(impl_id).skip_binder();
309            if let Some(adt_def) = ty.ty_adt_def() {
310                return Some(adt_def.did());
311            }
312        }
313    }
314    None
315}
316
317// get the pointee or wrapped type
318pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
319    // progress_info!("get_pointee: > {:?} as type: {:?}", matched_ty, matched_ty.kind());
320    let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
321        get_pointee(*ty_mut)
322    } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
323        get_pointee(*referred_ty)
324    } else {
325        matched_ty
326    };
327    pointee
328}
329
330pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
331    if let ty::RawPtr(_, _) = matched_ty.kind() {
332        return true;
333    }
334    false
335}
336
337pub fn is_ref(matched_ty: Ty<'_>) -> bool {
338    if let ty::Ref(_, _, _) = matched_ty.kind() {
339        return true;
340    }
341    false
342}
343
344pub fn has_mut_self_param(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
345    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
346        match assoc_item.kind {
347            AssocKind::Fn { has_self, .. } => {
348                if has_self {
349                    let body = tcx.optimized_mir(def_id);
350                    let fst_arg = body.local_decls[Local::from_usize(1)].clone();
351                    let ty = fst_arg.ty;
352                    let is_mut_ref =
353                        matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
354                    return fst_arg.mutability.is_mut() || is_mut_ref;
355                }
356            }
357            _ => (),
358        }
359    }
360    false
361}
362
363// Input the adt def id
364// Return set of (mutable method def_id, fields can be modified)
365pub fn get_all_mutable_methods(tcx: TyCtxt<'_>, def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
366    let mut results = HashMap::new();
367    let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
368    let public_fields = adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
369    let impl_vec = adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
370    for impl_id in impl_vec {
371        let associated_items = tcx.associated_items(impl_id);
372        for item in associated_items.in_definition_order() {
373            if let AssocKind::Fn {
374                name: _,
375                has_self: _,
376            } = item.kind
377            {
378                let item_def_id = item.def_id;
379                if has_mut_self_param(tcx, item_def_id) {
380                    // TODO: using dataflow to detect field modificaiton, combined with public fields
381                    let modified_fields = public_fields.clone();
382                    results.insert(item_def_id, modified_fields);
383                }
384            }
385        }
386    }
387    results
388}
389
390// Check each field's visibility, return the public fields vec
391pub fn get_public_fields(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<usize> {
392    let adt_def = tcx.adt_def(def_id);
393    adt_def
394        .all_fields()
395        .enumerate()
396        .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
397        .collect()
398}
399
/// Prints the entries of `map` in ascending key order, one `key: value` pair
/// per line in `Debug` formatting, each line indented by `level` times two
/// spaces.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = "  ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|left, right| left.0.cmp(right.0));

    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
416
417pub fn get_all_std_unsafe_callees(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<String> {
418    let mut results = Vec::new();
419    let body = tcx.optimized_mir(def_id);
420    let bb_len = body.basic_blocks.len();
421    for i in 0..bb_len {
422        let callees = match_std_unsafe_callee(
423            tcx,
424            body.basic_blocks[BasicBlock::from_usize(i)]
425                .clone()
426                .terminator(),
427        );
428        results.extend(callees);
429    }
430    results
431}
432
433pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<usize> {
434    let mut results = Vec::new();
435    let body = tcx.optimized_mir(def_id);
436    let bb_len = body.basic_blocks.len();
437    for i in 0..bb_len {
438        if match_std_unsafe_callee(
439            tcx,
440            body.basic_blocks[BasicBlock::from_usize(i)]
441                .clone()
442                .terminator(),
443        )
444        .is_empty()
445        {
446            results.push(i);
447        }
448    }
449    results
450}
451
452pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
453    let mut results = Vec::new();
454    if let TerminatorKind::Call { func, .. } = &terminator.kind {
455        if let Operand::Constant(func_constant) = func {
456            if let ty::FnDef(ref callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
457                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
458                if parse_unsafe_api(&func_name).is_some() {
459                    results.push(func_name);
460                }
461            }
462        }
463    }
464    results
465}
466
467// Bug definition: (1) strict -> weak & dst is mutable;
468//                 (2) _ -> strict
469pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
470    (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
471}
472
473// strict ty: bool, str, adt fields containing bool or str;
474pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
475    let ty = get_pointee(ori_ty);
476    let mut flag = false;
477    if let TyKind::Adt(adt_def, substs) = ty.kind() {
478        if adt_def.is_struct() {
479            for field_def in adt_def.all_fields() {
480                flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
481            }
482        }
483    }
484    ty.is_bool() || ty.is_str() || flag
485}