rapx/analysis/utils/
fn_info.rs

1use crate::analysis::core::dataflow::default::DataFlowAnalyzer;
2use crate::analysis::core::dataflow::DataFlowAnalysis;
3use crate::analysis::senryx::contracts::property;
4#[allow(unused)]
5use crate::analysis::senryx::contracts::property::PropertyContract;
6use crate::analysis::senryx::matcher::parse_unsafe_api;
7use crate::analysis::unsafety_isolation::draw_dot::render_dot_graphs;
8use crate::analysis::unsafety_isolation::generate_dot::NodeType;
9use crate::analysis::unsafety_isolation::UnsafetyIsolationCheck;
10use crate::rap_debug;
11use crate::rap_warn;
12use rustc_data_structures::fx::FxHashMap;
13use rustc_hir::def::DefKind;
14use rustc_hir::def_id::DefId;
15use rustc_hir::Attribute;
16use rustc_hir::ImplItemKind;
17use rustc_middle::mir::BinOp;
18use rustc_middle::mir::Local;
19use rustc_middle::mir::{BasicBlock, Terminator};
20use rustc_middle::ty::{AssocKind, Mutability, Ty, TyCtxt, TyKind};
21use rustc_middle::{
22    mir::{Operand, TerminatorKind},
23    ty,
24};
25use rustc_span::def_id::LocalDefId;
26use rustc_span::kw;
27use rustc_span::sym;
28use std::collections::HashMap;
29use std::collections::HashSet;
30use std::fmt::Debug;
31use std::hash::Hash;
32use syn::Expr;
33
34pub fn generate_node_ty(tcx: TyCtxt, def_id: DefId) -> NodeType {
35    (def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
36}
37
38pub fn check_visibility(tcx: TyCtxt, func_defid: DefId) -> bool {
39    if !tcx.visibility(func_defid).is_public() {
40        return false;
41    }
42    // if func_defid.is_local() {
43    //     if let Some(local_defid) = func_defid.as_local() {
44    //         let module_moddefid = tcx.parent_module_from_def_id(local_defid);
45    //         let module_defid = module_moddefid.to_def_id();
46    //         if !tcx.visibility(module_defid).is_public() {
47    //             // println!("module def id {:?}",UigUnit::get_cleaned_def_path_name(tcx, module_defid));
48    //             return Self::is_re_exported(tcx, func_defid, module_moddefid.to_local_def_id());
49    //         }
50    //     }
51    // }
52    true
53}
54
55pub fn is_re_exported(tcx: TyCtxt, target_defid: DefId, module_defid: LocalDefId) -> bool {
56    for child in tcx.module_children_local(module_defid) {
57        if child.vis.is_public() {
58            if let Some(def_id) = child.res.opt_def_id() {
59                if def_id == target_defid {
60                    return true;
61                }
62            }
63        }
64    }
65    false
66}
67
/// Print every element of `set` on its own line using its `Debug` form,
/// followed by a separator line. Iteration order is the set's (unspecified) order.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|element| println!("{:?}", element));
    println!("---------------");
}
74
75pub fn get_cleaned_def_path_name_ori(tcx: TyCtxt, def_id: DefId) -> String {
76    let def_id_str = format!("{:?}", def_id);
77    let mut parts: Vec<&str> = def_id_str.split("::").collect();
78
79    let mut remove_first = false;
80    if let Some(first_part) = parts.get_mut(0) {
81        if first_part.contains("core") {
82            *first_part = "core";
83        } else if first_part.contains("std") {
84            *first_part = "std";
85        } else if first_part.contains("alloc") {
86            *first_part = "alloc";
87        } else {
88            remove_first = true;
89        }
90    }
91    if remove_first && !parts.is_empty() {
92        parts.remove(0);
93    }
94
95    let new_parts: Vec<String> = parts
96        .into_iter()
97        .filter_map(|s| {
98            if s.contains("{") {
99                if remove_first {
100                    get_struct_name(tcx, def_id)
101                } else {
102                    None
103                }
104            } else {
105                Some(s.to_string())
106            }
107        })
108        .collect();
109
110    let mut cleaned_path = new_parts.join("::");
111    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
112    cleaned_path
113}
114
115pub fn get_sp_json() -> serde_json::Value {
116    let json_data: serde_json::Value =
117        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sps.json"))
118            .expect("Unable to parse JSON");
119    json_data
120}
121
122pub fn get_std_api_signature_json() -> serde_json::Value {
123    let json_data: serde_json::Value =
124        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sig.json"))
125            .expect("Unable to parse JSON");
126    json_data
127}
128
129pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
130    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
131    let json_data: serde_json::Value = get_sp_json();
132
133    if let Some(function_info) = json_data.get(&cleaned_path_name) {
134        if let Some(sp_list) = function_info.get("0") {
135            let mut result = HashSet::new();
136            if let Some(sp_array) = sp_list.as_array() {
137                for sp in sp_array {
138                    if let Some(sp_name) = sp.as_str() {
139                        result.insert(sp_name.to_string());
140                    }
141                }
142            }
143            return result;
144        }
145    }
146    HashSet::new()
147}
148
149pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
150    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
151        if let Some(impl_id) = assoc_item.impl_container(tcx) {
152            let ty = tcx.type_of(impl_id).skip_binder();
153            let type_name = ty.to_string();
154            let struct_name = type_name
155                .split('<')
156                .next()
157                .unwrap_or("")
158                .split("::")
159                .last()
160                .unwrap_or("")
161                .to_string();
162
163            return Some(struct_name);
164        }
165    }
166    None
167}
168
169pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
170    let poly_fn_sig = tcx.fn_sig(def_id);
171    let fn_sig = poly_fn_sig.skip_binder();
172    fn_sig.safety() == rustc_hir::Safety::Unsafe
173}
174
175//retval: 0-constructor, 1-method, 2-function
176pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> usize {
177    let mut node_type = 2;
178    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
179        match assoc_item.kind {
180            AssocKind::Fn { has_self, .. } => {
181                if has_self {
182                    node_type = 1;
183                } else {
184                    let fn_sig = tcx.fn_sig(def_id).skip_binder();
185                    let output = fn_sig.output().skip_binder();
186                    // return type is 'Self'
187                    if output.is_param(0) {
188                        node_type = 0;
189                    }
190                    // return type is struct's name
191                    if let Some(impl_id) = assoc_item.impl_container(tcx) {
192                        let ty = tcx.type_of(impl_id).skip_binder();
193                        if output == ty {
194                            node_type = 0;
195                        }
196                    }
197                    match output.kind() {
198                        TyKind::Ref(_, ref_ty, _) => {
199                            if ref_ty.is_param(0) {
200                                node_type = 0;
201                            }
202                            if let Some(impl_id) = assoc_item.impl_container(tcx) {
203                                let ty = tcx.type_of(impl_id).skip_binder();
204                                if *ref_ty == ty {
205                                    node_type = 0;
206                                }
207                            }
208                        }
209                        TyKind::Adt(adt_def, substs) => {
210                            if adt_def.is_enum()
211                                && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
212                                    || tcx.is_diagnostic_item(sym::Result, adt_def.did())
213                                    || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
214                            {
215                                let inner_ty = substs.type_at(0);
216                                if inner_ty.is_param(0) {
217                                    node_type = 0;
218                                }
219                                if let Some(impl_id) = assoc_item.impl_container(tcx) {
220                                    let ty_impl = tcx.type_of(impl_id).skip_binder();
221                                    if inner_ty == ty_impl {
222                                        node_type = 0;
223                                    }
224                                }
225                            }
226                        }
227                        _ => {}
228                    }
229                }
230            }
231            _ => todo!(),
232        }
233    }
234    node_type
235}
236
237pub fn get_adt_ty(tcx: TyCtxt, def_id: DefId) -> Option<Ty> {
238    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
239        if let Some(impl_id) = assoc_item.impl_container(tcx) {
240            return Some(tcx.type_of(impl_id).skip_binder());
241        }
242    }
243    None
244}
245
246pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<NodeType> {
247    let mut cons = Vec::new();
248    if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == 0 {
249        return cons;
250    }
251    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
252        if let Some(impl_id) = assoc_item.impl_container(tcx) {
253            // get struct ty
254            let ty = tcx.type_of(impl_id).skip_binder();
255            if let Some(adt_def) = ty.ty_adt_def() {
256                let adt_def_id = adt_def.did();
257                let impls = tcx.inherent_impls(adt_def_id);
258                for impl_def_id in impls {
259                    for item in tcx.associated_item_def_ids(impl_def_id) {
260                        if (tcx.def_kind(item) == DefKind::Fn
261                            || tcx.def_kind(item) == DefKind::AssocFn)
262                            && get_type(tcx, *item) == 0
263                        {
264                            cons.push(generate_node_ty(tcx, *item));
265                        }
266                    }
267                }
268            }
269        }
270    }
271    cons
272}
273
274pub fn get_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
275    let mut callees = HashSet::new();
276    if tcx.is_mir_available(def_id) {
277        let body = tcx.optimized_mir(def_id);
278        for bb in body.basic_blocks.iter() {
279            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
280                if let Operand::Constant(func_constant) = func {
281                    if let ty::FnDef(ref callee_def_id, _) = func_constant.const_.ty().kind() {
282                        if check_safety(tcx, *callee_def_id) {
283                            callees.insert(*callee_def_id);
284                        }
285                    }
286                }
287            }
288        }
289    }
290    callees
291}
292
293// return all the impls def id of corresponding struct
294pub fn get_impl_items_of_struct(
295    tcx: TyCtxt<'_>,
296    struct_def_id: DefId,
297) -> Vec<&rustc_hir::ImplItem<'_>> {
298    let mut impls = Vec::new();
299    for impl_item_id in tcx.hir_crate_items(()).impl_items() {
300        let impl_item = tcx.hir_impl_item(impl_item_id);
301        impls.push(impl_item);
302    }
303    impls
304}
305
306// return all the impls def id of corresponding struct
307pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
308    let mut impls = Vec::new();
309    for impl_item_id in tcx.hir_crate_items(()).impl_items() {
310        let impl_item = tcx.hir_impl_item(impl_item_id);
311        match impl_item.kind {
312            ImplItemKind::Type(ty) => {
313                if let rustc_hir::TyKind::Path(ref qpath) = ty.kind {
314                    if let rustc_hir::QPath::Resolved(_, path) = qpath {
315                        if let rustc_hir::def::Res::Def(_, ref def_id) = path.res {
316                            if *def_id == struct_def_id {
317                                impls.push(impl_item.owner_id.to_def_id());
318                            }
319                        }
320                    }
321                }
322            }
323            _ => (),
324        }
325    }
326    impls
327}
328
329pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
330    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
331        if let Some(impl_id) = assoc_item.impl_container(tcx) {
332            // get struct ty
333            let ty = tcx.type_of(impl_id).skip_binder();
334            if let Some(adt_def) = ty.ty_adt_def() {
335                return Some(adt_def.did());
336            }
337        }
338    }
339    None
340}
341
342// get the pointee or wrapped type
343pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
344    // progress_info!("get_pointee: > {:?} as type: {:?}", matched_ty, matched_ty.kind());
345    let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
346        get_pointee(*ty_mut)
347    } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
348        get_pointee(*referred_ty)
349    } else {
350        matched_ty
351    };
352    pointee
353}
354
355pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
356    if let ty::RawPtr(_, _) = matched_ty.kind() {
357        return true;
358    }
359    false
360}
361
362pub fn is_ref(matched_ty: Ty<'_>) -> bool {
363    if let ty::Ref(_, _, _) = matched_ty.kind() {
364        return true;
365    }
366    false
367}
368
369pub fn is_slice(matched_ty: Ty) -> Option<Ty> {
370    if let ty::Slice(inner) = matched_ty.kind() {
371        return Some(*inner);
372    }
373    None
374}
375
376pub fn has_mut_self_param(tcx: TyCtxt, def_id: DefId) -> bool {
377    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
378        match assoc_item.kind {
379            AssocKind::Fn { has_self, .. } => {
380                if has_self && tcx.is_mir_available(def_id) {
381                    let body = tcx.optimized_mir(def_id);
382                    let fst_arg = body.local_decls[Local::from_usize(1)].clone();
383                    let ty = fst_arg.ty;
384                    let is_mut_ref =
385                        matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
386                    return fst_arg.mutability.is_mut() || is_mut_ref;
387                }
388            }
389            _ => (),
390        }
391    }
392    false
393}
394
395// Input the adt def id
396// Return set of (mutable method def_id, fields can be modified)
397pub fn get_all_mutable_methods(tcx: TyCtxt, src_def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
398    let mut results = HashMap::new();
399    let all_std_fn_def = get_all_std_fns_by_rustc_public(tcx);
400    let target_adt_def = get_adt_def_id_by_adt_method(tcx, src_def_id);
401    let mut uig_entrance = UnsafetyIsolationCheck::new(tcx);
402    let mut is_std = false;
403    for &def_id in &all_std_fn_def {
404        let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
405        if adt_def.is_some() && adt_def == target_adt_def && src_def_id != def_id {
406            if has_mut_self_param(tcx, def_id) {
407                results.insert(def_id, HashSet::new());
408            }
409            is_std = true;
410        }
411    }
412    if is_std {
413        return results;
414    }
415
416    let public_fields = target_adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
417    let impl_vec = target_adt_def.map_or_else(Vec::new, |def| get_impl_items_of_struct(tcx, def));
418    for item in impl_vec {
419        if let rustc_hir::ImplItemKind::Fn(fnsig, body) = item.kind {
420            let item_def_id = item.owner_id.to_def_id();
421            if has_mut_self_param(tcx, item_def_id) {
422                // TODO: using dataflow to detect field modificaiton, combined with publi c fields
423                let modified_fields = public_fields.clone();
424                results.insert(item_def_id, modified_fields);
425            }
426        }
427        // }
428    }
429    results
430}
431
432// Check each field's visibility, return the public fields vec
433pub fn get_public_fields(tcx: TyCtxt, def_id: DefId) -> HashSet<usize> {
434    let adt_def = tcx.adt_def(def_id);
435    adt_def
436        .all_fields()
437        .enumerate()
438        .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
439        .collect()
440}
441
/// Print every `key: value` pair of `map` in ascending key order.
///
/// Each line is indented by `level` repetitions of two spaces, and both key and
/// value use their `Debug` form.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = "  ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|(left, _), (right, _)| left.cmp(right));
    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
458
459// pub fn get_all_std_unsafe_chains(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
460//     let mut results = Vec::new();
461//     let body = tcx.optimized_mir(def_id);
462//     let bb_len = body.basic_blocks.len();
463//     for i in 0..bb_len {
464//         let callees = match_std_unsafe_chains_callee(
465//             tcx,
466//             body.basic_blocks[BasicBlock::from_usize(i)]
467//                 .clone()
468//                 .terminator(),
469//         );
470//         results.extend(callees);
471//     }
472//     results
473// }
474
475pub fn match_std_unsafe_chains_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
476    let mut results = Vec::new();
477    if let TerminatorKind::Call { func, .. } = &terminator.kind {
478        if let Operand::Constant(func_constant) = func {
479            if let ty::FnDef(ref callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
480                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
481            }
482        }
483    }
484    results
485}
486
487pub fn get_all_std_unsafe_callees(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
488    let mut results = Vec::new();
489    let body = tcx.optimized_mir(def_id);
490    let bb_len = body.basic_blocks.len();
491    for i in 0..bb_len {
492        let callees = match_std_unsafe_callee(
493            tcx,
494            body.basic_blocks[BasicBlock::from_usize(i)]
495                .clone()
496                .terminator(),
497        );
498        results.extend(callees);
499    }
500    results
501}
502
503pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt, def_id: DefId) -> Vec<usize> {
504    let mut results = Vec::new();
505    let body = tcx.optimized_mir(def_id);
506    let bb_len = body.basic_blocks.len();
507    for i in 0..bb_len {
508        if match_std_unsafe_callee(
509            tcx,
510            body.basic_blocks[BasicBlock::from_usize(i)]
511                .clone()
512                .terminator(),
513        )
514        .is_empty()
515        {
516            results.push(i);
517        }
518    }
519    results
520}
521
522pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
523    let mut results = Vec::new();
524    if let TerminatorKind::Call { func, .. } = &terminator.kind {
525        if let Operand::Constant(func_constant) = func {
526            if let ty::FnDef(ref callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
527                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
528                if parse_unsafe_api(&func_name).is_some() {
529                    results.push(func_name);
530                }
531            }
532        }
533    }
534    results
535}
536
537// Bug definition: (1) strict -> weak & dst is mutable;
538//                 (2) _ -> strict
539pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
540    (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
541}
542
543// strict ty: bool, str, adt fields containing bool or str;
544pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
545    let ty = get_pointee(ori_ty);
546    let mut flag = false;
547    if let TyKind::Adt(adt_def, substs) = ty.kind() {
548        if adt_def.is_struct() {
549            for field_def in adt_def.all_fields() {
550                flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
551            }
552        }
553    }
554    ty.is_bool() || ty.is_str() || flag
555}
556
557pub fn reverse_op(op: BinOp) -> BinOp {
558    match op {
559        BinOp::Lt => BinOp::Ge,
560        BinOp::Ge => BinOp::Lt,
561        BinOp::Le => BinOp::Gt,
562        BinOp::Gt => BinOp::Le,
563        BinOp::Eq => BinOp::Eq,
564        BinOp::Ne => BinOp::Ne,
565        _ => op,
566    }
567}
568
569/// Same with `generate_contract_from_annotation` but does not contain field types.
570pub fn generate_contract_from_annotation_without_field_types(
571    tcx: TyCtxt,
572    def_id: DefId,
573) -> Vec<(usize, Vec<usize>, PropertyContract)> {
574    let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
575
576    contracts_with_ty
577        .into_iter()
578        .map(|(local_id, fields_with_ty, contract)| {
579            let fields: Vec<usize> = fields_with_ty
580                .into_iter()
581                .map(|(field_idx, _)| field_idx)
582                .collect();
583            (local_id, fields, contract)
584        })
585        .collect()
586}
587
588/// Filter the function which contains "rapx::proof"
589pub fn is_verify_target_func(tcx: TyCtxt, def_id: DefId) -> bool {
590    for attr in tcx.get_all_attrs(def_id).into_iter() {
591        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
592        // Find proof placeholder
593        if attr_str.contains("#[rapx::proof(proof)]") {
594            return true;
595        }
596    }
597    false
598}
599
600/// Get the annotation in tag-std style.
601/// Then generate the contractual invariant states (CIS) for the args.
602/// This function will recognize the args name and record states to MIR variable (represent by usize).
603/// Return value means Vec<(local_id, fields of this local, contracts)>
604pub fn generate_contract_from_annotation(
605    tcx: TyCtxt,
606    def_id: DefId,
607) -> Vec<(usize, Vec<(usize, Ty)>, PropertyContract)> {
608    const REGISTER_TOOL: &str = "rapx";
609    let tool_attrs = tcx.get_all_attrs(def_id).into_iter().filter(|attr| {
610        if let Attribute::Unparsed(tool_attr) = attr {
611            if tool_attr.path.segments[0].as_str() == REGISTER_TOOL {
612                return true;
613            }
614        }
615        false
616    });
617    let mut results = Vec::new();
618    for attr in tool_attrs {
619        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
620        // Find proof placeholder, skip it
621        if attr_str.contains("#[rapx::proof(proof)]") {
622            continue;
623        }
624        rap_debug!("{:?}", attr_str);
625        let safety_attr = safety_parser::safety::parse_attr_and_get_properties(attr_str.as_str());
626        for par in safety_attr.iter() {
627            for property in par.tags.iter() {
628                let tag_name = property.tag.name();
629                let exprs = property.args.clone().into_vec();
630                let contract = PropertyContract::new(tcx, def_id, tag_name, &exprs);
631                let (local, fields) = parse_cis_local(tcx, def_id, exprs);
632                results.push((local, fields, contract));
633            }
634        }
635    }
636    // if results.len() > 0 {
637    //     rap_warn!("results:\n{:?}", results);
638    // }
639    results
640}
641
642/// Parse attr.expr into local id and local fields.
643///
644/// Example:
645/// ```
646/// #[rapx::inner(property = ValidPtr(ptr, u32, 1), kind = "precond")]
647/// #[rapx::inner(property = ValidNum(region.size>=0), kind = "precond")]
648/// pub fn xor_secret_region(ptr: *mut u32, region:SecretRegion) -> u32 {...}
649/// ```
650///
651/// The first attribute will be parsed as (1, []).
652///     -> "1" means the first arg "ptr", "[]" means no fields.
653/// The second attribute will be parsed as (2, [1]).
654///     -> "2" means the second arg "region", "[1]" means "size" is region's second field.
655///
656/// If this function doesn't have args, then it will return default pattern: (0, Vec::new())
657pub fn parse_cis_local(tcx: TyCtxt, def_id: DefId, expr: Vec<Expr>) -> (usize, Vec<(usize, Ty)>) {
658    // match expr with cis local
659    for e in expr {
660        if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
661            return (base, fields);
662        }
663    }
664    (0, Vec::new())
665}
666
667/// parse single expr into (local, fields, ty)
668pub fn parse_expr_into_local_and_ty<'tcx>(
669    tcx: TyCtxt<'tcx>,
670    def_id: DefId,
671    expr: &Expr,
672) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
673    if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
674        let (param_names, param_tys) = parse_signature(tcx, def_id);
675        if param_names[0] == "0".to_string() {
676            return None;
677        }
678        if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
679            let mut current_ty = param_tys[param_index];
680            let mut field_indices = Vec::new();
681            for field_name in fields {
682                // peel the ref and ptr
683                let peeled_ty = current_ty.peel_refs();
684                if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
685                    let variant = adt_def.non_enum_variant();
686                    // 1. if field_name is number, then parse it as usize
687                    if let Ok(field_idx) = field_name.parse::<usize>() {
688                        if field_idx < variant.fields.len() {
689                            current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
690                                .ty(tcx, arg_list);
691                            field_indices.push((field_idx, current_ty));
692                            continue;
693                        }
694                    }
695                    // 2. if field_name is String, then compare it with current ty's field names
696                    if let Some((idx, _)) = variant
697                        .fields
698                        .iter()
699                        .enumerate()
700                        .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
701                    {
702                        current_ty =
703                            variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
704                        field_indices.push((idx, current_ty));
705                    }
706                    // 3. if field_name can not match any fields, then break
707                    else {
708                        break; // TODO:
709                    }
710                }
711                // if current ty is not Adt, then break the loop
712                else {
713                    break; // TODO:
714                }
715            }
716            // It's different from default one, we return the result as param_index+1 because param_index count from 0.
717            // But 0 in MIR is the ret index, the args' indexes begin from 1.
718            return Some((param_index + 1, field_indices, current_ty));
719        }
720    }
721    None
722}
723
724/// Return the Vecs of args' names and types
725/// This function will handle outside def_id by different way.
726pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
727    // 0. If the def id is local
728    if def_id.as_local().is_some() {
729        return parse_local_signature(tcx, def_id);
730    } else {
731        rap_debug!("{:?} is not local def id.", def_id);
732        return parse_outside_signature(tcx, def_id);
733    };
734}
735
736/// Return the Vecs of args' names and types of outside functions.
737fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
738    let sig = tcx.fn_sig(def_id).skip_binder();
739    let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
740
741    // 1. check pre-defined std unsafe api signature
742    if let Some(args_name) = get_known_std_names(tcx, def_id) {
743        // rap_warn!(
744        //     "function {:?} has arg: {:?}, arg types: {:?}",
745        //     def_id,
746        //     args_name,
747        //     param_tys
748        // );
749        return (args_name, param_tys);
750    }
751
752    // 2. TODO: If can not find known std apis, then use numbers like `0`,`1`,... to represent args.
753    let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
754    rap_debug!(
755        "function {:?} has arg: {:?}, arg types: {:?}",
756        def_id,
757        args_name,
758        param_tys
759    );
760    return (args_name, param_tys);
761}
762
763/// We use a json to record known std apis' arg names.
764/// This function will search the json and return the names.
765/// Notes: If std gets updated, the json may still record old ones.
766fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
767    let std_func_name = get_cleaned_def_path_name(tcx, def_id);
768    let json_data: serde_json::Value = get_std_api_signature_json();
769
770    if let Some(arg_info) = json_data.get(&std_func_name) {
771        if let Some(args_name) = arg_info.as_array() {
772            // set default value to arg name
773            if args_name.len() == 0 {
774                return Some(vec!["0".to_string()]);
775            }
776            // iterate and collect
777            let mut result = Vec::new();
778            for arg in args_name {
779                if let Some(sp_name) = arg.as_str() {
780                    result.push(sp_name.to_string());
781                }
782            }
783            return Some(result);
784        }
785    }
786    None
787}
788
789/// Return the Vecs of args' names and types of local functions.
790pub fn parse_local_signature(tcx: TyCtxt, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
791    // 1. parse local def_id and get arg list
792    let local_def_id = def_id.as_local().unwrap();
793    let hir_body = tcx.hir_body_owned_by(local_def_id);
794    if hir_body.params.len() == 0 {
795        return (vec!["0".to_string()], Vec::new());
796    }
797    // 2. contruct the vec of param and param ty
798    let params = hir_body.params;
799    let typeck_results = tcx.typeck_body(hir_body.id());
800    let mut param_names = Vec::new();
801    let mut param_tys = Vec::new();
802    for param in params {
803        match param.pat.kind {
804            rustc_hir::PatKind::Binding(_, _, ident, _) => {
805                param_names.push(ident.name.to_string());
806                let ty = typeck_results.pat_ty(param.pat);
807                param_tys.push(ty);
808            }
809            _ => {
810                param_names.push(String::new());
811                param_tys.push(typeck_results.pat_ty(param.pat));
812            }
813        }
814    }
815    (param_names, param_tys)
816}
817
818/// return the (ident, its fields) of the expr.
819///
820/// illustrated cases :
821///    ptr	-> ("ptr", [])
822///    region.size	-> ("region", ["size"])
823///    tuple.0.value -> ("tuple", ["0", "value"])
824pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
825    match expr {
826        Expr::Path(syn::ExprPath { path, .. }) => {
827            if path.segments.len() == 1 {
828                rap_debug!("expr2 {:?}", expr);
829                let ident = path.segments[0].ident.to_string();
830                Some((ident, Vec::new()))
831            } else {
832                None
833            }
834        }
835        // get the base and fields recursively
836        Expr::Field(syn::ExprField { base, member, .. }) => {
837            let (base_ident, mut fields) =
838                if let Some((base_ident, fields)) = access_ident_recursive(base) {
839                    (base_ident, fields)
840                } else {
841                    return None;
842                };
843            let field_name = match member {
844                syn::Member::Named(ident) => ident.to_string(),
845                syn::Member::Unnamed(index) => index.index.to_string(),
846            };
847            fields.push(field_name);
848            Some((base_ident, fields))
849        }
850        _ => None,
851    }
852}
853
854/// parse expr into number.
855pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
856    if let Expr::Lit(expr_lit) = expr {
857        if let syn::Lit::Int(lit_int) = &expr_lit.lit {
858            return lit_int.base10_parse::<usize>().ok();
859        }
860    }
861    None
862}
863
864/// Match a type identifier string to a concrete Rust type
865///
866/// This function attempts to match a given type identifier (e.g., "u32", "T", "MyStruct")
867/// to a type in the provided parameter type list. It handles:
868/// 1. Built-in primitive types (u32, usize, etc.)
869/// 2. Generic type parameters (T, U, etc.)
870/// 3. User-defined types found in the parameter list
871///
872/// Arguments:
873/// - `tcx`: Type context for querying compiler information
874/// - `type_ident`: String representing the type identifier to match
875/// - `param_ty`: List of parameter types from the function signature
876///
877/// Returns:
878/// - `Some(Ty)` if a matching type is found
879/// - `None` if no match is found
880pub fn match_ty_with_ident(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
881    // 1. First check for built-in primitive types
882    if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
883        return Some(primitive_ty);
884    }
885    // 2. Check if the identifier matches any generic type parameter
886    return find_generic_param(tcx, def_id, type_ident.clone());
887    // 3. Check if the identifier matches any user-defined type in the parameters
888    // find_user_defined_type(tcx, def_id, type_ident)
889}
890
891/// Match built-in primitive types from String
892fn match_primitive_type(tcx: TyCtxt, type_ident: String) -> Option<Ty> {
893    match type_ident.as_str() {
894        "i8" => Some(tcx.types.i8),
895        "i16" => Some(tcx.types.i16),
896        "i32" => Some(tcx.types.i32),
897        "i64" => Some(tcx.types.i64),
898        "i128" => Some(tcx.types.i128),
899        "isize" => Some(tcx.types.isize),
900        "u8" => Some(tcx.types.u8),
901        "u16" => Some(tcx.types.u16),
902        "u32" => Some(tcx.types.u32),
903        "u64" => Some(tcx.types.u64),
904        "u128" => Some(tcx.types.u128),
905        "usize" => Some(tcx.types.usize),
906        "f16" => Some(tcx.types.f16),
907        "f32" => Some(tcx.types.f32),
908        "f64" => Some(tcx.types.f64),
909        "f128" => Some(tcx.types.f128),
910        "bool" => Some(tcx.types.bool),
911        "char" => Some(tcx.types.char),
912        "str" => Some(tcx.types.str_),
913        _ => None,
914    }
915}
916
917/// Find generic type parameters in the parameter list
918fn find_generic_param(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
919    rap_debug!(
920        "Searching for generic param: {} in {:?}",
921        type_ident,
922        def_id
923    );
924    let (_, param_tys) = parse_signature(tcx, def_id);
925    rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
926    // 递归查找泛型参数
927    for &ty in &param_tys {
928        if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
929            return Some(found);
930        }
931    }
932
933    None
934}
935
936/// Iterate the args' types recursively and find the matched generic one.
937fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
938    match ty.kind() {
939        TyKind::Param(param_ty) => {
940            if param_ty.name.as_str() == type_ident {
941                return Some(ty);
942            }
943        }
944        TyKind::RawPtr(ty, _)
945        | TyKind::Ref(_, ty, _)
946        | TyKind::Slice(ty)
947        | TyKind::Array(ty, _) => {
948            if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
949                return Some(found);
950            }
951        }
952        TyKind::Tuple(tys) => {
953            for tuple_ty in tys.iter() {
954                if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
955                    return Some(found);
956                }
957            }
958        }
959        TyKind::Adt(adt_def, substs) => {
960            let name = tcx.item_name(adt_def.did()).to_string();
961            if name == type_ident {
962                return Some(ty);
963            }
964            for field in adt_def.all_fields() {
965                let field_ty = field.ty(tcx, substs);
966                if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
967                    return Some(found);
968                }
969            }
970        }
971        _ => {}
972    }
973    None
974}
975
976// /// Find user-defined types in the parameter list
977// fn find_user_defined_type(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
978//     let param_ty = parse_signature(tcx, def_id).1;
979//     param_ty.iter().find_map(|&ty| {
980//         // Peel off references and pointers to get to the underlying type
981//         let peeled_ty = ty.peel_refs();
982//         match peeled_ty.kind() {
983//             TyKind::Adt(adt_def, _raw_list) => {
984//                 // Compare the type name to our identifier
985//                 let name = tcx.item_name(adt_def.did()).to_string();
986//                 if name == type_ident {
987//                     return Some(peeled_ty);
988//                 }
989//             }
990//             _ => {}
991//         }
992//         None
993//     })
994// }
995
996pub fn reflect_generic<'tcx>(
997    generic_mapping: &FxHashMap<String, Ty<'tcx>>,
998    ty: Ty<'tcx>,
999) -> Ty<'tcx> {
1000    match ty.kind() {
1001        TyKind::Param(param_ty) => {
1002            let generic_name = param_ty.name.to_string();
1003            if let Some(actual_ty) = generic_mapping.get(&generic_name) {
1004                return *actual_ty;
1005            }
1006        }
1007        _ => {}
1008    }
1009    ty
1010}
1011
1012// src_var = 0: for constructor
1013// src_var = 1: for methods
1014pub fn has_tainted_fields(tcx: TyCtxt, def_id: DefId, src_var: u32) -> bool {
1015    let mut dataflow_analyzer = DataFlowAnalyzer::new(tcx, false);
1016    dataflow_analyzer.build_graph(def_id);
1017
1018    let body = tcx.optimized_mir(def_id);
1019    let params = &body.args_iter().collect::<Vec<_>>();
1020    rap_info!("params {:?}", params);
1021    let self_local = Local::from(src_var);
1022
1023    let flowing_params: Vec<Local> = params
1024        .iter()
1025        .filter(|&&param_local| {
1026            dataflow_analyzer.has_flow_between(def_id, self_local, param_local)
1027                && self_local != param_local
1028        })
1029        .copied()
1030        .collect();
1031
1032    if !flowing_params.is_empty() {
1033        rap_info!(
1034            "Taint flow found from self to other parameters: {:?}",
1035            flowing_params
1036        );
1037        true
1038    } else {
1039        false
1040    }
1041}
1042
1043// 修改返回值类型为调用链的向量
1044pub fn get_all_std_unsafe_chains(tcx: TyCtxt, def_id: DefId) -> Vec<Vec<String>> {
1045    let mut results = Vec::new();
1046    let mut visited = HashSet::new(); // 避免循环调用
1047    let mut current_chain = Vec::new();
1048
1049    // 开始DFS遍历
1050    dfs_find_unsafe_chains(tcx, def_id, &mut current_chain, &mut results, &mut visited);
1051    results
1052}
1053
1054// DFS递归查找unsafe调用链
1055fn dfs_find_unsafe_chains(
1056    tcx: TyCtxt,
1057    def_id: DefId,
1058    current_chain: &mut Vec<String>,
1059    results: &mut Vec<Vec<String>>,
1060    visited: &mut HashSet<DefId>,
1061) {
1062    // 避免循环调用
1063    if visited.contains(&def_id) {
1064        return;
1065    }
1066    visited.insert(def_id);
1067
1068    let current_func_name = get_cleaned_def_path_name(tcx, def_id);
1069    current_chain.push(current_func_name.clone());
1070
1071    // 获取当前函数的所有unsafe callee
1072    let unsafe_callees = find_unsafe_callees_in_function(tcx, def_id);
1073
1074    if unsafe_callees.is_empty() {
1075        // 如果没有更多的unsafe callee,保存当前链
1076        results.push(current_chain.clone());
1077    } else {
1078        // 对每个unsafe callee继续DFS
1079        for (callee_def_id, callee_name) in unsafe_callees {
1080            dfs_find_unsafe_chains(tcx, callee_def_id, current_chain, results, visited);
1081        }
1082    }
1083
1084    // 回溯
1085    current_chain.pop();
1086    visited.remove(&def_id);
1087}
1088
1089// 在函数中查找所有unsafe callee
1090fn find_unsafe_callees_in_function(tcx: TyCtxt, def_id: DefId) -> Vec<(DefId, String)> {
1091    let mut callees = Vec::new();
1092
1093    if let Some(body) = try_get_mir(tcx, def_id) {
1094        for bb in body.basic_blocks.iter() {
1095            if let Some(terminator) = &bb.terminator {
1096                if let Some((callee_def_id, callee_name)) = extract_unsafe_callee(tcx, terminator) {
1097                    callees.push((callee_def_id, callee_name));
1098                }
1099            }
1100        }
1101    }
1102
1103    callees
1104}
1105
1106// 从terminator中提取unsafe callee
1107fn extract_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Option<(DefId, String)> {
1108    if let TerminatorKind::Call { func, .. } = &terminator.kind {
1109        if let Operand::Constant(func_constant) = func {
1110            if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
1111                if check_safety(tcx, *callee_def_id) {
1112                    let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
1113                    return Some((*callee_def_id, func_name));
1114                }
1115            }
1116        }
1117    }
1118    None
1119}
1120
1121// 安全地获取MIR,处理可能无法获取MIR的情况
1122fn try_get_mir(tcx: TyCtxt<'_>, def_id: DefId) -> Option<&rustc_middle::mir::Body<'_>> {
1123    if tcx.is_mir_available(def_id) {
1124        Some(tcx.optimized_mir(def_id))
1125    } else {
1126        None
1127    }
1128}
1129
1130// 清理def path名称的辅助函数
1131pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
1132    // 这里实现你的路径清理逻辑
1133    tcx.def_path_str(def_id)
1134}
1135
/// Pretty-print unsafe call chains to stdout.
///
/// Prints nothing for an empty slice. Otherwise it prints a separator line and
/// the chain count. Each chain follows, numbered from 1, with one function per
/// line, indented two spaces per depth level and prefixed by `-> `. A blank
/// line follows each chain.
pub fn print_unsafe_chains(chains: &[Vec<String>]) {
    if chains.is_empty() {
        return;
    }

    println!("==============================");
    println!("Found {} unsafe call chain(s):", chains.len());
    for (chain_no, chain) in (1..).zip(chains) {
        println!("Chain {}:", chain_no);
        for (depth, name) in chain.iter().enumerate() {
            // Nested entries get one extra space before the arrow.
            let gap = if depth == 0 { "" } else { " " };
            println!("{}{}-> {}", "  ".repeat(depth), gap, name);
        }
        println!();
    }
}
1153
1154pub fn get_all_std_fns_by_rustc_public(tcx: TyCtxt) -> Vec<DefId> {
1155    let mut all_std_fn_def = Vec::new();
1156    let mut results = Vec::new();
1157    let mut core_fn_def: Vec<_> = rustc_public::find_crates("core")
1158        .iter()
1159        .flat_map(|krate| krate.fn_defs())
1160        .collect();
1161    let mut std_fn_def: Vec<_> = rustc_public::find_crates("std")
1162        .iter()
1163        .flat_map(|krate| krate.fn_defs())
1164        .collect();
1165    let mut alloc_fn_def: Vec<_> = rustc_public::find_crates("alloc")
1166        .iter()
1167        .flat_map(|krate| krate.fn_defs())
1168        .collect();
1169    all_std_fn_def.append(&mut core_fn_def);
1170    all_std_fn_def.append(&mut std_fn_def);
1171    all_std_fn_def.append(&mut alloc_fn_def);
1172
1173    for fn_def in &all_std_fn_def {
1174        let def_id = crate::def_id::to_internal(fn_def, tcx);
1175        results.push(def_id);
1176    }
1177    results
1178}
1179
1180// pub fn generate_uig_for_one_struct(tcx: TyCtxt, def_id: DefId, adt_def_id: DefId) {
1181//     let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
1182//     let mut uig_entrance = UnsafetyIsolationCheck::new(tcx);
1183//     let impls = tcx.inherent_impls(adt_def.unwrap());
1184//     let impl_results = get_impl_items_of_struct(tcx, adt_def.unwrap());
1185//     println!("impls {:?}", impl_results);
1186//     for impl_def_id in impls {
1187//         // println!("impls {:?}", tcx.associated_item_def_ids(impl_def_id));
1188//         for item in tcx.associated_item_def_ids(impl_def_id) {
1189//             if tcx.def_kind(item) == DefKind::Fn || tcx.def_kind(item) == DefKind::AssocFn {
1190//                 println!("item {:?}", item);
1191//                 uig_entrance.insert_uig(*item, get_callees(tcx, *item), get_cons(tcx, *item));
1192//             }
1193//         }
1194//     }
1195
1196//     let mut dot_strs = Vec::new();
1197//     for uig in &uig_entrance.uigs {
1198//         let dot_str = uig.generate_dot_str();
1199//         dot_strs.push(dot_str);
1200//     }
1201//     for uig in &uig_entrance.single {
1202//         let dot_str = uig.generate_dot_str();
1203//         dot_strs.push(dot_str);
1204//     }
1205//     render_dot_graphs(dot_strs);
1206// }