rapx/analysis/utils/
fn_info.rs

1use crate::analysis::core::dataflow::DataFlowAnalysis;
2use crate::analysis::core::dataflow::default::DataFlowAnalyzer;
3use crate::analysis::senryx::contracts::property;
4#[allow(unused)]
5use crate::analysis::senryx::contracts::property::PropertyContract;
6use crate::analysis::senryx::matcher::parse_unsafe_api;
7use crate::analysis::unsafety_isolation::UnsafetyIsolationCheck;
8use crate::analysis::unsafety_isolation::draw_dot::render_dot_graphs;
9use crate::analysis::unsafety_isolation::generate_dot::NodeType;
10use crate::rap_debug;
11use crate::rap_warn;
12use rustc_data_structures::fx::FxHashMap;
13use rustc_hir::Attribute;
14use rustc_hir::ImplItemKind;
15use rustc_hir::def::DefKind;
16use rustc_hir::def_id::DefId;
17use rustc_middle::mir::BinOp;
18use rustc_middle::mir::Local;
19use rustc_middle::mir::{BasicBlock, Terminator};
20use rustc_middle::ty::{AssocKind, Mutability, Ty, TyCtxt, TyKind};
21use rustc_middle::{
22    mir::{Operand, TerminatorKind},
23    ty,
24};
25use rustc_span::def_id::LocalDefId;
26use rustc_span::kw;
27use rustc_span::sym;
28use std::collections::HashMap;
29use std::collections::HashSet;
30use std::fmt::Debug;
31use std::hash::Hash;
32use syn::Expr;
33
34pub fn generate_node_ty(tcx: TyCtxt, def_id: DefId) -> NodeType {
35    (def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
36}
37
38pub fn check_visibility(tcx: TyCtxt, func_defid: DefId) -> bool {
39    if !tcx.visibility(func_defid).is_public() {
40        return false;
41    }
42    // if func_defid.is_local() {
43    //     if let Some(local_defid) = func_defid.as_local() {
44    //         let module_moddefid = tcx.parent_module_from_def_id(local_defid);
45    //         let module_defid = module_moddefid.to_def_id();
46    //         if !tcx.visibility(module_defid).is_public() {
47    //             // println!("module def id {:?}",UigUnit::get_cleaned_def_path_name(tcx, module_defid));
48    //             return Self::is_re_exported(tcx, func_defid, module_moddefid.to_local_def_id());
49    //         }
50    //     }
51    // }
52    true
53}
54
55pub fn is_re_exported(tcx: TyCtxt, target_defid: DefId, module_defid: LocalDefId) -> bool {
56    for child in tcx.module_children_local(module_defid) {
57        if child.vis.is_public() {
58            if let Some(def_id) = child.res.opt_def_id() {
59                if def_id == target_defid {
60                    return true;
61                }
62            }
63        }
64    }
65    false
66}
67
/// Prints every element of `set` on its own line using `Debug` formatting,
/// followed by a dashed separator line. Iteration order is that of the set.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|item| println!("{:?}", item));
    println!("---------------");
}
74
75pub fn get_cleaned_def_path_name_ori(tcx: TyCtxt, def_id: DefId) -> String {
76    let def_id_str = format!("{:?}", def_id);
77    let mut parts: Vec<&str> = def_id_str.split("::").collect();
78
79    let mut remove_first = false;
80    if let Some(first_part) = parts.get_mut(0) {
81        if first_part.contains("core") {
82            *first_part = "core";
83        } else if first_part.contains("std") {
84            *first_part = "std";
85        } else if first_part.contains("alloc") {
86            *first_part = "alloc";
87        } else {
88            remove_first = true;
89        }
90    }
91    if remove_first && !parts.is_empty() {
92        parts.remove(0);
93    }
94
95    let new_parts: Vec<String> = parts
96        .into_iter()
97        .filter_map(|s| {
98            if s.contains("{") {
99                if remove_first {
100                    get_struct_name(tcx, def_id)
101                } else {
102                    None
103                }
104            } else {
105                Some(s.to_string())
106            }
107        })
108        .collect();
109
110    let mut cleaned_path = new_parts.join("::");
111    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
112    cleaned_path
113}
114
115pub fn get_sp_json() -> serde_json::Value {
116    let json_data: serde_json::Value =
117        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sps.json"))
118            .expect("Unable to parse JSON");
119    json_data
120}
121
122pub fn get_std_api_signature_json() -> serde_json::Value {
123    let json_data: serde_json::Value =
124        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sig.json"))
125            .expect("Unable to parse JSON");
126    json_data
127}
128
129pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
130    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
131    let json_data: serde_json::Value = get_sp_json();
132
133    if let Some(function_info) = json_data.get(&cleaned_path_name) {
134        if let Some(sp_list) = function_info.get("0") {
135            let mut result = HashSet::new();
136            if let Some(sp_array) = sp_list.as_array() {
137                for sp in sp_array {
138                    if let Some(sp_name) = sp.as_str() {
139                        result.insert(sp_name.to_string());
140                    }
141                }
142            }
143            return result;
144        }
145    }
146    HashSet::new()
147}
148
149pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
150    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
151        if let Some(impl_id) = assoc_item.impl_container(tcx) {
152            let ty = tcx.type_of(impl_id).skip_binder();
153            let type_name = ty.to_string();
154            let struct_name = type_name
155                .split('<')
156                .next()
157                .unwrap_or("")
158                .split("::")
159                .last()
160                .unwrap_or("")
161                .to_string();
162
163            return Some(struct_name);
164        }
165    }
166    None
167}
168
169pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
170    let poly_fn_sig = tcx.fn_sig(def_id);
171    let fn_sig = poly_fn_sig.skip_binder();
172    fn_sig.safety() == rustc_hir::Safety::Unsafe
173}
174
175//retval: 0-constructor, 1-method, 2-function
176pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> usize {
177    let mut node_type = 2;
178    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
179        match assoc_item.kind {
180            AssocKind::Fn { has_self, .. } => {
181                if has_self {
182                    node_type = 1;
183                } else {
184                    let fn_sig = tcx.fn_sig(def_id).skip_binder();
185                    let output = fn_sig.output().skip_binder();
186                    // return type is 'Self'
187                    if output.is_param(0) {
188                        node_type = 0;
189                    }
190                    // return type is struct's name
191                    if let Some(impl_id) = assoc_item.impl_container(tcx) {
192                        let ty = tcx.type_of(impl_id).skip_binder();
193                        if output == ty {
194                            node_type = 0;
195                        }
196                    }
197                    match output.kind() {
198                        TyKind::Ref(_, ref_ty, _) => {
199                            if ref_ty.is_param(0) {
200                                node_type = 0;
201                            }
202                            if let Some(impl_id) = assoc_item.impl_container(tcx) {
203                                let ty = tcx.type_of(impl_id).skip_binder();
204                                if *ref_ty == ty {
205                                    node_type = 0;
206                                }
207                            }
208                        }
209                        TyKind::Adt(adt_def, substs) => {
210                            if adt_def.is_enum()
211                                && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
212                                    || tcx.is_diagnostic_item(sym::Result, adt_def.did())
213                                    || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
214                            {
215                                let inner_ty = substs.type_at(0);
216                                if inner_ty.is_param(0) {
217                                    node_type = 0;
218                                }
219                                if let Some(impl_id) = assoc_item.impl_container(tcx) {
220                                    let ty_impl = tcx.type_of(impl_id).skip_binder();
221                                    if inner_ty == ty_impl {
222                                        node_type = 0;
223                                    }
224                                }
225                            }
226                        }
227                        _ => {}
228                    }
229                }
230            }
231            _ => todo!(),
232        }
233    }
234    node_type
235}
236
237pub fn get_adt_ty(tcx: TyCtxt, def_id: DefId) -> Option<Ty> {
238    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
239        if let Some(impl_id) = assoc_item.impl_container(tcx) {
240            return Some(tcx.type_of(impl_id).skip_binder());
241        }
242    }
243    None
244}
245
246pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<NodeType> {
247    let mut cons = Vec::new();
248    if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == 0 {
249        return cons;
250    }
251    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
252        if let Some(impl_id) = assoc_item.impl_container(tcx) {
253            // get struct ty
254            let ty = tcx.type_of(impl_id).skip_binder();
255            if let Some(adt_def) = ty.ty_adt_def() {
256                let adt_def_id = adt_def.did();
257                let impls = tcx.inherent_impls(adt_def_id);
258                for impl_def_id in impls {
259                    for item in tcx.associated_item_def_ids(impl_def_id) {
260                        if (tcx.def_kind(item) == DefKind::Fn
261                            || tcx.def_kind(item) == DefKind::AssocFn)
262                            && get_type(tcx, *item) == 0
263                        {
264                            cons.push(generate_node_ty(tcx, *item));
265                        }
266                    }
267                }
268            }
269        }
270    }
271    cons
272}
273
274pub fn get_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
275    let mut callees = HashSet::new();
276    if tcx.is_mir_available(def_id) {
277        let body = tcx.optimized_mir(def_id);
278        for bb in body.basic_blocks.iter() {
279            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
280                if let Operand::Constant(func_constant) = func {
281                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
282                        if check_safety(tcx, *callee_def_id) {
283                            callees.insert(*callee_def_id);
284                        }
285                    }
286                }
287            }
288        }
289    }
290    callees
291}
292
293pub fn get_all_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
294    let mut callees = HashSet::new();
295    if tcx.is_mir_available(def_id) {
296        let body = tcx.optimized_mir(def_id);
297        for bb in body.basic_blocks.iter() {
298            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
299                if let Operand::Constant(func_constant) = func {
300                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
301                        callees.insert(*callee_def_id);
302                    }
303                }
304            }
305        }
306    }
307    callees
308}
309
310// return all the impls def id of corresponding struct
311pub fn get_impl_items_of_struct(
312    tcx: TyCtxt<'_>,
313    struct_def_id: DefId,
314) -> Vec<&rustc_hir::ImplItem<'_>> {
315    let mut impls = Vec::new();
316    for impl_item_id in tcx.hir_crate_items(()).impl_items() {
317        let impl_item = tcx.hir_impl_item(impl_item_id);
318        impls.push(impl_item);
319    }
320    impls
321}
322
323// return all the impls def id of corresponding struct
324pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
325    let mut impls = Vec::new();
326    for impl_item_id in tcx.hir_crate_items(()).impl_items() {
327        let impl_item = tcx.hir_impl_item(impl_item_id);
328        match impl_item.kind {
329            ImplItemKind::Type(ty) => {
330                if let rustc_hir::TyKind::Path(ref qpath) = ty.kind {
331                    if let rustc_hir::QPath::Resolved(_, path) = qpath {
332                        if let rustc_hir::def::Res::Def(_, ref def_id) = path.res {
333                            if *def_id == struct_def_id {
334                                impls.push(impl_item.owner_id.to_def_id());
335                            }
336                        }
337                    }
338                }
339            }
340            _ => (),
341        }
342    }
343    impls
344}
345
346pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
347    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
348        if let Some(impl_id) = assoc_item.impl_container(tcx) {
349            // get struct ty
350            let ty = tcx.type_of(impl_id).skip_binder();
351            if let Some(adt_def) = ty.ty_adt_def() {
352                return Some(adt_def.did());
353            }
354        }
355    }
356    None
357}
358
359// get the pointee or wrapped type
360pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
361    // progress_info!("get_pointee: > {:?} as type: {:?}", matched_ty, matched_ty.kind());
362    let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
363        get_pointee(*ty_mut)
364    } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
365        get_pointee(*referred_ty)
366    } else {
367        matched_ty
368    };
369    pointee
370}
371
372pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
373    if let ty::RawPtr(_, _) = matched_ty.kind() {
374        return true;
375    }
376    false
377}
378
379pub fn is_ref(matched_ty: Ty<'_>) -> bool {
380    if let ty::Ref(_, _, _) = matched_ty.kind() {
381        return true;
382    }
383    false
384}
385
386pub fn is_slice(matched_ty: Ty) -> Option<Ty> {
387    if let ty::Slice(inner) = matched_ty.kind() {
388        return Some(*inner);
389    }
390    None
391}
392
393pub fn has_mut_self_param(tcx: TyCtxt, def_id: DefId) -> bool {
394    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
395        match assoc_item.kind {
396            AssocKind::Fn { has_self, .. } => {
397                if has_self && tcx.is_mir_available(def_id) {
398                    let body = tcx.optimized_mir(def_id);
399                    let fst_arg = body.local_decls[Local::from_usize(1)].clone();
400                    let ty = fst_arg.ty;
401                    let is_mut_ref =
402                        matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
403                    return fst_arg.mutability.is_mut() || is_mut_ref;
404                }
405            }
406            _ => (),
407        }
408    }
409    false
410}
411
412// Input the adt def id
413// Return set of (mutable method def_id, fields can be modified)
414pub fn get_all_mutable_methods(tcx: TyCtxt, src_def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
415    let mut results = HashMap::new();
416    let all_std_fn_def = get_all_std_fns_by_rustc_public(tcx);
417    let target_adt_def = get_adt_def_id_by_adt_method(tcx, src_def_id);
418    let mut uig_entrance = UnsafetyIsolationCheck::new(tcx);
419    let mut is_std = false;
420    for &def_id in &all_std_fn_def {
421        let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
422        if adt_def.is_some() && adt_def == target_adt_def && src_def_id != def_id {
423            if has_mut_self_param(tcx, def_id) {
424                results.insert(def_id, HashSet::new());
425            }
426            is_std = true;
427        }
428    }
429    if is_std {
430        return results;
431    }
432
433    let public_fields = target_adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
434    let impl_vec = target_adt_def.map_or_else(Vec::new, |def| get_impl_items_of_struct(tcx, def));
435    for item in impl_vec {
436        if let rustc_hir::ImplItemKind::Fn(fnsig, body) = item.kind {
437            let item_def_id = item.owner_id.to_def_id();
438            if has_mut_self_param(tcx, item_def_id) {
439                // TODO: using dataflow to detect field modificaiton, combined with publi c fields
440                let modified_fields = public_fields.clone();
441                results.insert(item_def_id, modified_fields);
442            }
443        }
444        // }
445    }
446    results
447}
448
449// Check each field's visibility, return the public fields vec
450pub fn get_public_fields(tcx: TyCtxt, def_id: DefId) -> HashSet<usize> {
451    let adt_def = tcx.adt_def(def_id);
452    adt_def
453        .all_fields()
454        .enumerate()
455        .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
456        .collect()
457}
458
// General helper for printing a hashmap: entries are printed one per line in
// ascending key order, each prefixed by two spaces per `level`.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = "  ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|left, right| left.0.cmp(right.0));

    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
475
476// pub fn get_all_std_unsafe_chains(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
477//     let mut results = Vec::new();
478//     let body = tcx.optimized_mir(def_id);
479//     let bb_len = body.basic_blocks.len();
480//     for i in 0..bb_len {
481//         let callees = match_std_unsafe_chains_callee(
482//             tcx,
483//             body.basic_blocks[BasicBlock::from_usize(i)]
484//                 .clone()
485//                 .terminator(),
486//         );
487//         results.extend(callees);
488//     }
489//     results
490// }
491
492pub fn match_std_unsafe_chains_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
493    let mut results = Vec::new();
494    if let TerminatorKind::Call { func, .. } = &terminator.kind {
495        if let Operand::Constant(func_constant) = func {
496            if let ty::FnDef(callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
497                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
498            }
499        }
500    }
501    results
502}
503
504pub fn get_all_std_unsafe_callees(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
505    let mut results = Vec::new();
506    let body = tcx.optimized_mir(def_id);
507    let bb_len = body.basic_blocks.len();
508    for i in 0..bb_len {
509        let callees = match_std_unsafe_callee(
510            tcx,
511            body.basic_blocks[BasicBlock::from_usize(i)]
512                .clone()
513                .terminator(),
514        );
515        results.extend(callees);
516    }
517    results
518}
519
520pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt, def_id: DefId) -> Vec<usize> {
521    let mut results = Vec::new();
522    let body = tcx.optimized_mir(def_id);
523    let bb_len = body.basic_blocks.len();
524    for i in 0..bb_len {
525        if match_std_unsafe_callee(
526            tcx,
527            body.basic_blocks[BasicBlock::from_usize(i)]
528                .clone()
529                .terminator(),
530        )
531        .is_empty()
532        {
533            results.push(i);
534        }
535    }
536    results
537}
538
539pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
540    let mut results = Vec::new();
541    if let TerminatorKind::Call { func, .. } = &terminator.kind {
542        if let Operand::Constant(func_constant) = func {
543            if let ty::FnDef(callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
544                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
545                if parse_unsafe_api(&func_name).is_some() {
546                    results.push(func_name);
547                }
548            }
549        }
550    }
551    results
552}
553
554// Bug definition: (1) strict -> weak & dst is mutable;
555//                 (2) _ -> strict
556pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
557    (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
558}
559
560// strict ty: bool, str, adt fields containing bool or str;
561pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
562    let ty = get_pointee(ori_ty);
563    let mut flag = false;
564    if let TyKind::Adt(adt_def, substs) = ty.kind() {
565        if adt_def.is_struct() {
566            for field_def in adt_def.all_fields() {
567                flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
568            }
569        }
570    }
571    ty.is_bool() || ty.is_str() || flag
572}
573
574pub fn reverse_op(op: BinOp) -> BinOp {
575    match op {
576        BinOp::Lt => BinOp::Ge,
577        BinOp::Ge => BinOp::Lt,
578        BinOp::Le => BinOp::Gt,
579        BinOp::Gt => BinOp::Le,
580        BinOp::Eq => BinOp::Eq,
581        BinOp::Ne => BinOp::Ne,
582        _ => op,
583    }
584}
585
586/// Same with `generate_contract_from_annotation` but does not contain field types.
587pub fn generate_contract_from_annotation_without_field_types(
588    tcx: TyCtxt,
589    def_id: DefId,
590) -> Vec<(usize, Vec<usize>, PropertyContract)> {
591    let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
592
593    contracts_with_ty
594        .into_iter()
595        .map(|(local_id, fields_with_ty, contract)| {
596            let fields: Vec<usize> = fields_with_ty
597                .into_iter()
598                .map(|(field_idx, _)| field_idx)
599                .collect();
600            (local_id, fields, contract)
601        })
602        .collect()
603}
604
605/// Filter the function which contains "rapx::proof"
606pub fn is_verify_target_func(tcx: TyCtxt, def_id: DefId) -> bool {
607    for attr in tcx.get_all_attrs(def_id).into_iter() {
608        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
609        // Find proof placeholder
610        if attr_str.contains("#[rapx::proof(proof)]") {
611            return true;
612        }
613    }
614    false
615}
616
617/// Get the annotation in tag-std style.
618/// Then generate the contractual invariant states (CIS) for the args.
619/// This function will recognize the args name and record states to MIR variable (represent by usize).
620/// Return value means Vec<(local_id, fields of this local, contracts)>
621pub fn generate_contract_from_annotation(
622    tcx: TyCtxt,
623    def_id: DefId,
624) -> Vec<(usize, Vec<(usize, Ty)>, PropertyContract)> {
625    const REGISTER_TOOL: &str = "rapx";
626    let tool_attrs = tcx.get_all_attrs(def_id).into_iter().filter(|attr| {
627        if let Attribute::Unparsed(tool_attr) = attr {
628            if tool_attr.path.segments[0].as_str() == REGISTER_TOOL {
629                return true;
630            }
631        }
632        false
633    });
634    let mut results = Vec::new();
635    for attr in tool_attrs {
636        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
637        // Find proof placeholder, skip it
638        if attr_str.contains("#[rapx::proof(proof)]") {
639            continue;
640        }
641        rap_debug!("{:?}", attr_str);
642        let safety_attr = safety_parser::safety::parse_attr_and_get_properties(attr_str.as_str());
643        for par in safety_attr.iter() {
644            for property in par.tags.iter() {
645                let tag_name = property.tag.name();
646                let exprs = property.args.clone().into_vec();
647                let contract = PropertyContract::new(tcx, def_id, tag_name, &exprs);
648                let (local, fields) = parse_cis_local(tcx, def_id, exprs);
649                results.push((local, fields, contract));
650            }
651        }
652    }
653    // if results.len() > 0 {
654    //     rap_warn!("results:\n{:?}", results);
655    // }
656    results
657}
658
659/// Parse attr.expr into local id and local fields.
660///
661/// Example:
662/// ```
663/// #[rapx::inner(property = ValidPtr(ptr, u32, 1), kind = "precond")]
664/// #[rapx::inner(property = ValidNum(region.size>=0), kind = "precond")]
665/// pub fn xor_secret_region(ptr: *mut u32, region:SecretRegion) -> u32 {...}
666/// ```
667///
668/// The first attribute will be parsed as (1, []).
669///     -> "1" means the first arg "ptr", "[]" means no fields.
670/// The second attribute will be parsed as (2, [1]).
671///     -> "2" means the second arg "region", "[1]" means "size" is region's second field.
672///
673/// If this function doesn't have args, then it will return default pattern: (0, Vec::new())
674pub fn parse_cis_local(tcx: TyCtxt, def_id: DefId, expr: Vec<Expr>) -> (usize, Vec<(usize, Ty)>) {
675    // match expr with cis local
676    for e in expr {
677        if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
678            return (base, fields);
679        }
680    }
681    (0, Vec::new())
682}
683
684/// parse single expr into (local, fields, ty)
685pub fn parse_expr_into_local_and_ty<'tcx>(
686    tcx: TyCtxt<'tcx>,
687    def_id: DefId,
688    expr: &Expr,
689) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
690    if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
691        let (param_names, param_tys) = parse_signature(tcx, def_id);
692        if param_names[0] == "0".to_string() {
693            return None;
694        }
695        if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
696            let mut current_ty = param_tys[param_index];
697            let mut field_indices = Vec::new();
698            for field_name in fields {
699                // peel the ref and ptr
700                let peeled_ty = current_ty.peel_refs();
701                if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
702                    let variant = adt_def.non_enum_variant();
703                    // 1. if field_name is number, then parse it as usize
704                    if let Ok(field_idx) = field_name.parse::<usize>() {
705                        if field_idx < variant.fields.len() {
706                            current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
707                                .ty(tcx, arg_list);
708                            field_indices.push((field_idx, current_ty));
709                            continue;
710                        }
711                    }
712                    // 2. if field_name is String, then compare it with current ty's field names
713                    if let Some((idx, _)) = variant
714                        .fields
715                        .iter()
716                        .enumerate()
717                        .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
718                    {
719                        current_ty =
720                            variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
721                        field_indices.push((idx, current_ty));
722                    }
723                    // 3. if field_name can not match any fields, then break
724                    else {
725                        break; // TODO:
726                    }
727                }
728                // if current ty is not Adt, then break the loop
729                else {
730                    break; // TODO:
731                }
732            }
733            // It's different from default one, we return the result as param_index+1 because param_index count from 0.
734            // But 0 in MIR is the ret index, the args' indexes begin from 1.
735            return Some((param_index + 1, field_indices, current_ty));
736        }
737    }
738    None
739}
740
741/// Return the Vecs of args' names and types
742/// This function will handle outside def_id by different way.
743pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
744    // 0. If the def id is local
745    if def_id.as_local().is_some() {
746        return parse_local_signature(tcx, def_id);
747    } else {
748        rap_debug!("{:?} is not local def id.", def_id);
749        return parse_outside_signature(tcx, def_id);
750    };
751}
752
753/// Return the Vecs of args' names and types of outside functions.
754fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
755    let sig = tcx.fn_sig(def_id).skip_binder();
756    let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
757
758    // 1. check pre-defined std unsafe api signature
759    if let Some(args_name) = get_known_std_names(tcx, def_id) {
760        // rap_warn!(
761        //     "function {:?} has arg: {:?}, arg types: {:?}",
762        //     def_id,
763        //     args_name,
764        //     param_tys
765        // );
766        return (args_name, param_tys);
767    }
768
769    // 2. TODO: If can not find known std apis, then use numbers like `0`,`1`,... to represent args.
770    let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
771    rap_debug!(
772        "function {:?} has arg: {:?}, arg types: {:?}",
773        def_id,
774        args_name,
775        param_tys
776    );
777    return (args_name, param_tys);
778}
779
780/// We use a json to record known std apis' arg names.
781/// This function will search the json and return the names.
782/// Notes: If std gets updated, the json may still record old ones.
783fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
784    let std_func_name = get_cleaned_def_path_name(tcx, def_id);
785    let json_data: serde_json::Value = get_std_api_signature_json();
786
787    if let Some(arg_info) = json_data.get(&std_func_name) {
788        if let Some(args_name) = arg_info.as_array() {
789            // set default value to arg name
790            if args_name.len() == 0 {
791                return Some(vec!["0".to_string()]);
792            }
793            // iterate and collect
794            let mut result = Vec::new();
795            for arg in args_name {
796                if let Some(sp_name) = arg.as_str() {
797                    result.push(sp_name.to_string());
798                }
799            }
800            return Some(result);
801        }
802    }
803    None
804}
805
806/// Return the Vecs of args' names and types of local functions.
807pub fn parse_local_signature(tcx: TyCtxt, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
808    // 1. parse local def_id and get arg list
809    let local_def_id = def_id.as_local().unwrap();
810    let hir_body = tcx.hir_body_owned_by(local_def_id);
811    if hir_body.params.len() == 0 {
812        return (vec!["0".to_string()], Vec::new());
813    }
814    // 2. contruct the vec of param and param ty
815    let params = hir_body.params;
816    let typeck_results = tcx.typeck_body(hir_body.id());
817    let mut param_names = Vec::new();
818    let mut param_tys = Vec::new();
819    for param in params {
820        match param.pat.kind {
821            rustc_hir::PatKind::Binding(_, _, ident, _) => {
822                param_names.push(ident.name.to_string());
823                let ty = typeck_results.pat_ty(param.pat);
824                param_tys.push(ty);
825            }
826            _ => {
827                param_names.push(String::new());
828                param_tys.push(typeck_results.pat_ty(param.pat));
829            }
830        }
831    }
832    (param_names, param_tys)
833}
834
835/// return the (ident, its fields) of the expr.
836///
837/// illustrated cases :
838///    ptr	-> ("ptr", [])
839///    region.size	-> ("region", ["size"])
840///    tuple.0.value -> ("tuple", ["0", "value"])
841pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
842    match expr {
843        Expr::Path(syn::ExprPath { path, .. }) => {
844            if path.segments.len() == 1 {
845                rap_debug!("expr2 {:?}", expr);
846                let ident = path.segments[0].ident.to_string();
847                Some((ident, Vec::new()))
848            } else {
849                None
850            }
851        }
852        // get the base and fields recursively
853        Expr::Field(syn::ExprField { base, member, .. }) => {
854            let (base_ident, mut fields) =
855                if let Some((base_ident, fields)) = access_ident_recursive(base) {
856                    (base_ident, fields)
857                } else {
858                    return None;
859                };
860            let field_name = match member {
861                syn::Member::Named(ident) => ident.to_string(),
862                syn::Member::Unnamed(index) => index.index.to_string(),
863            };
864            fields.push(field_name);
865            Some((base_ident, fields))
866        }
867        _ => None,
868    }
869}
870
871/// parse expr into number.
872pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
873    if let Expr::Lit(expr_lit) = expr {
874        if let syn::Lit::Int(lit_int) = &expr_lit.lit {
875            return lit_int.base10_parse::<usize>().ok();
876        }
877    }
878    None
879}
880
881/// Match a type identifier string to a concrete Rust type
882///
883/// This function attempts to match a given type identifier (e.g., "u32", "T", "MyStruct")
884/// to a type in the provided parameter type list. It handles:
885/// 1. Built-in primitive types (u32, usize, etc.)
886/// 2. Generic type parameters (T, U, etc.)
887/// 3. User-defined types found in the parameter list
888///
889/// Arguments:
890/// - `tcx`: Type context for querying compiler information
891/// - `type_ident`: String representing the type identifier to match
892/// - `param_ty`: List of parameter types from the function signature
893///
894/// Returns:
895/// - `Some(Ty)` if a matching type is found
896/// - `None` if no match is found
897pub fn match_ty_with_ident(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
898    // 1. First check for built-in primitive types
899    if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
900        return Some(primitive_ty);
901    }
902    // 2. Check if the identifier matches any generic type parameter
903    return find_generic_param(tcx, def_id, type_ident.clone());
904    // 3. Check if the identifier matches any user-defined type in the parameters
905    // find_user_defined_type(tcx, def_id, type_ident)
906}
907
908/// Match built-in primitive types from String
909fn match_primitive_type(tcx: TyCtxt, type_ident: String) -> Option<Ty> {
910    match type_ident.as_str() {
911        "i8" => Some(tcx.types.i8),
912        "i16" => Some(tcx.types.i16),
913        "i32" => Some(tcx.types.i32),
914        "i64" => Some(tcx.types.i64),
915        "i128" => Some(tcx.types.i128),
916        "isize" => Some(tcx.types.isize),
917        "u8" => Some(tcx.types.u8),
918        "u16" => Some(tcx.types.u16),
919        "u32" => Some(tcx.types.u32),
920        "u64" => Some(tcx.types.u64),
921        "u128" => Some(tcx.types.u128),
922        "usize" => Some(tcx.types.usize),
923        "f16" => Some(tcx.types.f16),
924        "f32" => Some(tcx.types.f32),
925        "f64" => Some(tcx.types.f64),
926        "f128" => Some(tcx.types.f128),
927        "bool" => Some(tcx.types.bool),
928        "char" => Some(tcx.types.char),
929        "str" => Some(tcx.types.str_),
930        _ => None,
931    }
932}
933
934/// Find generic type parameters in the parameter list
935fn find_generic_param(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
936    rap_debug!(
937        "Searching for generic param: {} in {:?}",
938        type_ident,
939        def_id
940    );
941    let (_, param_tys) = parse_signature(tcx, def_id);
942    rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
943    // 递归查找泛型参数
944    for &ty in &param_tys {
945        if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
946            return Some(found);
947        }
948    }
949
950    None
951}
952
953/// Iterate the args' types recursively and find the matched generic one.
954fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
955    match ty.kind() {
956        TyKind::Param(param_ty) => {
957            if param_ty.name.as_str() == type_ident {
958                return Some(ty);
959            }
960        }
961        TyKind::RawPtr(ty, _)
962        | TyKind::Ref(_, ty, _)
963        | TyKind::Slice(ty)
964        | TyKind::Array(ty, _) => {
965            if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
966                return Some(found);
967            }
968        }
969        TyKind::Tuple(tys) => {
970            for tuple_ty in tys.iter() {
971                if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
972                    return Some(found);
973                }
974            }
975        }
976        TyKind::Adt(adt_def, substs) => {
977            let name = tcx.item_name(adt_def.did()).to_string();
978            if name == type_ident {
979                return Some(ty);
980            }
981            for field in adt_def.all_fields() {
982                let field_ty = field.ty(tcx, substs);
983                if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
984                    return Some(found);
985                }
986            }
987        }
988        _ => {}
989    }
990    None
991}
992
993// /// Find user-defined types in the parameter list
994// fn find_user_defined_type(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
995//     let param_ty = parse_signature(tcx, def_id).1;
996//     param_ty.iter().find_map(|&ty| {
997//         // Peel off references and pointers to get to the underlying type
998//         let peeled_ty = ty.peel_refs();
999//         match peeled_ty.kind() {
1000//             TyKind::Adt(adt_def, _raw_list) => {
1001//                 // Compare the type name to our identifier
1002//                 let name = tcx.item_name(adt_def.did()).to_string();
1003//                 if name == type_ident {
1004//                     return Some(peeled_ty);
1005//                 }
1006//             }
1007//             _ => {}
1008//         }
1009//         None
1010//     })
1011// }
1012
1013pub fn reflect_generic<'tcx>(
1014    generic_mapping: &FxHashMap<String, Ty<'tcx>>,
1015    ty: Ty<'tcx>,
1016) -> Ty<'tcx> {
1017    match ty.kind() {
1018        TyKind::Param(param_ty) => {
1019            let generic_name = param_ty.name.to_string();
1020            if let Some(actual_ty) = generic_mapping.get(&generic_name) {
1021                return *actual_ty;
1022            }
1023        }
1024        _ => {}
1025    }
1026    ty
1027}
1028
1029// src_var = 0: for constructor
1030// src_var = 1: for methods
1031pub fn has_tainted_fields(tcx: TyCtxt, def_id: DefId, src_var: u32) -> bool {
1032    let mut dataflow_analyzer = DataFlowAnalyzer::new(tcx, false);
1033    dataflow_analyzer.build_graph(def_id);
1034
1035    let body = tcx.optimized_mir(def_id);
1036    let params = &body.args_iter().collect::<Vec<_>>();
1037    rap_info!("params {:?}", params);
1038    let self_local = Local::from(src_var);
1039
1040    let flowing_params: Vec<Local> = params
1041        .iter()
1042        .filter(|&&param_local| {
1043            dataflow_analyzer.has_flow_between(def_id, self_local, param_local)
1044                && self_local != param_local
1045        })
1046        .copied()
1047        .collect();
1048
1049    if !flowing_params.is_empty() {
1050        rap_info!(
1051            "Taint flow found from self to other parameters: {:?}",
1052            flowing_params
1053        );
1054        true
1055    } else {
1056        false
1057    }
1058}
1059
1060// 修改返回值类型为调用链的向量
1061pub fn get_all_std_unsafe_chains(tcx: TyCtxt, def_id: DefId) -> Vec<Vec<String>> {
1062    let mut results = Vec::new();
1063    let mut visited = HashSet::new(); // 避免循环调用
1064    let mut current_chain = Vec::new();
1065
1066    // 开始DFS遍历
1067    dfs_find_unsafe_chains(tcx, def_id, &mut current_chain, &mut results, &mut visited);
1068    results
1069}
1070
1071// DFS递归查找unsafe调用链
1072fn dfs_find_unsafe_chains(
1073    tcx: TyCtxt,
1074    def_id: DefId,
1075    current_chain: &mut Vec<String>,
1076    results: &mut Vec<Vec<String>>,
1077    visited: &mut HashSet<DefId>,
1078) {
1079    // 避免循环调用
1080    if visited.contains(&def_id) {
1081        return;
1082    }
1083    visited.insert(def_id);
1084
1085    let current_func_name = get_cleaned_def_path_name(tcx, def_id);
1086    current_chain.push(current_func_name.clone());
1087
1088    // 获取当前函数的所有unsafe callee
1089    let unsafe_callees = find_unsafe_callees_in_function(tcx, def_id);
1090
1091    if unsafe_callees.is_empty() {
1092        // 如果没有更多的unsafe callee,保存当前链
1093        results.push(current_chain.clone());
1094    } else {
1095        // 对每个unsafe callee继续DFS
1096        for (callee_def_id, callee_name) in unsafe_callees {
1097            dfs_find_unsafe_chains(tcx, callee_def_id, current_chain, results, visited);
1098        }
1099    }
1100
1101    // 回溯
1102    current_chain.pop();
1103    visited.remove(&def_id);
1104}
1105
1106// 在函数中查找所有unsafe callee
1107fn find_unsafe_callees_in_function(tcx: TyCtxt, def_id: DefId) -> Vec<(DefId, String)> {
1108    let mut callees = Vec::new();
1109
1110    if let Some(body) = try_get_mir(tcx, def_id) {
1111        for bb in body.basic_blocks.iter() {
1112            if let Some(terminator) = &bb.terminator {
1113                if let Some((callee_def_id, callee_name)) = extract_unsafe_callee(tcx, terminator) {
1114                    callees.push((callee_def_id, callee_name));
1115                }
1116            }
1117        }
1118    }
1119
1120    callees
1121}
1122
1123// 从terminator中提取unsafe callee
1124fn extract_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Option<(DefId, String)> {
1125    if let TerminatorKind::Call { func, .. } = &terminator.kind {
1126        if let Operand::Constant(func_constant) = func {
1127            if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
1128                if check_safety(tcx, *callee_def_id) {
1129                    let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
1130                    return Some((*callee_def_id, func_name));
1131                }
1132            }
1133        }
1134    }
1135    None
1136}
1137
1138// 安全地获取MIR,处理可能无法获取MIR的情况
1139fn try_get_mir(tcx: TyCtxt<'_>, def_id: DefId) -> Option<&rustc_middle::mir::Body<'_>> {
1140    if tcx.is_mir_available(def_id) {
1141        Some(tcx.optimized_mir(def_id))
1142    } else {
1143        None
1144    }
1145}
1146
1147// 清理def path名称的辅助函数
1148pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
1149    // 这里实现你的路径清理逻辑
1150    tcx.def_path_str(def_id)
1151}
1152
/// Prints the given call chains to standard output.
///
/// Nothing is printed for an empty slice. Otherwise a separator line and a summary are printed,
/// followed by each chain under a 1-based `Chain N:` heading. Each function in a chain is
/// printed on its own line, indented by two spaces per nesting level, with one extra space
/// before the arrow on every level after the first.
pub fn print_unsafe_chains(chains: &[Vec<String>]) {
    if chains.is_empty() {
        return;
    }

    println!("==============================");
    println!("Found {} unsafe call chain(s):", chains.len());
    for (chain_no, chain) in (1usize..).zip(chains) {
        println!("Chain {}:", chain_no);
        for (depth, func_name) in chain.iter().enumerate() {
            let indent = "  ".repeat(depth);
            let spacer = if depth == 0 { "" } else { " " };
            println!("{}{}-> {}", indent, spacer, func_name);
        }
        println!();
    }
}
1170
1171pub fn get_all_std_fns_by_rustc_public(tcx: TyCtxt) -> Vec<DefId> {
1172    let mut all_std_fn_def = Vec::new();
1173    let mut results = Vec::new();
1174    let mut core_fn_def: Vec<_> = rustc_public::find_crates("core")
1175        .iter()
1176        .flat_map(|krate| krate.fn_defs())
1177        .collect();
1178    let mut std_fn_def: Vec<_> = rustc_public::find_crates("std")
1179        .iter()
1180        .flat_map(|krate| krate.fn_defs())
1181        .collect();
1182    let mut alloc_fn_def: Vec<_> = rustc_public::find_crates("alloc")
1183        .iter()
1184        .flat_map(|krate| krate.fn_defs())
1185        .collect();
1186    all_std_fn_def.append(&mut core_fn_def);
1187    all_std_fn_def.append(&mut std_fn_def);
1188    all_std_fn_def.append(&mut alloc_fn_def);
1189
1190    for fn_def in &all_std_fn_def {
1191        let def_id = crate::def_id::to_internal(fn_def, tcx);
1192        results.push(def_id);
1193    }
1194    results
1195}
1196
1197// pub fn generate_uig_for_one_struct(tcx: TyCtxt, def_id: DefId, adt_def_id: DefId) {
1198//     let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
1199//     let mut uig_entrance = UnsafetyIsolationCheck::new(tcx);
1200//     let impls = tcx.inherent_impls(adt_def.unwrap());
1201//     let impl_results = get_impl_items_of_struct(tcx, adt_def.unwrap());
1202//     println!("impls {:?}", impl_results);
1203//     for impl_def_id in impls {
1204//         // println!("impls {:?}", tcx.associated_item_def_ids(impl_def_id));
1205//         for item in tcx.associated_item_def_ids(impl_def_id) {
1206//             if tcx.def_kind(item) == DefKind::Fn || tcx.def_kind(item) == DefKind::AssocFn {
1207//                 println!("item {:?}", item);
1208//                 uig_entrance.insert_uig(*item, get_callees(tcx, *item), get_cons(tcx, *item));
1209//             }
1210//         }
1211//     }
1212
1213//     let mut dot_strs = Vec::new();
1214//     for uig in &uig_entrance.uigs {
1215//         let dot_str = uig.generate_dot_str();
1216//         dot_strs.push(dot_str);
1217//     }
1218//     for uig in &uig_entrance.single {
1219//         let dot_str = uig.generate_dot_str();
1220//         dot_strs.push(dot_str);
1221//     }
1222//     render_dot_graphs(dot_strs);
1223// }