rapx/analysis/utils/
fn_info.rs

1#[allow(unused)]
2use crate::analysis::senryx::contracts::property::PropertyContract;
3use crate::analysis::senryx::matcher::parse_unsafe_api;
4use crate::analysis::unsafety_isolation::generate_dot::NodeType;
5use crate::rap_debug;
6use crate::rap_warn;
7use rustc_data_structures::fx::FxHashMap;
8use rustc_hir::def::DefKind;
9use rustc_hir::def_id::DefId;
10use rustc_hir::Attribute;
11use rustc_hir::ImplItemKind;
12use rustc_middle::mir::BinOp;
13use rustc_middle::mir::Local;
14use rustc_middle::mir::{BasicBlock, Terminator};
15use rustc_middle::ty::{AssocKind, Mutability, Ty, TyCtxt, TyKind};
16use rustc_middle::{
17    mir::{Operand, TerminatorKind},
18    ty,
19};
20use rustc_span::def_id::LocalDefId;
21use rustc_span::kw;
22use rustc_span::sym;
23use std::collections::HashMap;
24use std::collections::HashSet;
25use std::fmt::Debug;
26use std::hash::Hash;
27use syn::Expr;
28
29pub fn generate_node_ty(tcx: TyCtxt<'_>, def_id: DefId) -> NodeType {
30    (def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
31}
32
33pub fn check_visibility(tcx: TyCtxt<'_>, func_defid: DefId) -> bool {
34    if !tcx.visibility(func_defid).is_public() {
35        return false;
36    }
37    // if func_defid.is_local() {
38    //     if let Some(local_defid) = func_defid.as_local() {
39    //         let module_moddefid = tcx.parent_module_from_def_id(local_defid);
40    //         let module_defid = module_moddefid.to_def_id();
41    //         if !tcx.visibility(module_defid).is_public() {
42    //             // println!("module def id {:?}",UigUnit::get_cleaned_def_path_name(tcx, module_defid));
43    //             return Self::is_re_exported(tcx, func_defid, module_moddefid.to_local_def_id());
44    //         }
45    //     }
46    // }
47    true
48}
49
50pub fn is_re_exported(tcx: TyCtxt<'_>, target_defid: DefId, module_defid: LocalDefId) -> bool {
51    for child in tcx.module_children_local(module_defid) {
52        if child.vis.is_public() {
53            if let Some(def_id) = child.res.opt_def_id() {
54                if def_id == target_defid {
55                    return true;
56                }
57            }
58        }
59    }
60    false
61}
62
/// Prints every element of `set` on its own line using its `Debug`
/// representation (in the set's iteration order), followed by a dashed
/// separator line.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|entry| println!("{:?}", entry));
    println!("---------------");
}
69
70pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
71    let def_id_str = format!("{:?}", def_id);
72    let mut parts: Vec<&str> = def_id_str
73        .split("::")
74        // .filter(|part| !part.contains("{")) // 去除包含 "{" 的部分
75        .collect();
76
77    let mut remove_first = false;
78    if let Some(first_part) = parts.get_mut(0) {
79        if first_part.contains("core") {
80            *first_part = "core";
81        } else if first_part.contains("std") {
82            *first_part = "std";
83        } else if first_part.contains("alloc") {
84            *first_part = "alloc";
85        } else {
86            remove_first = true;
87        }
88    }
89    if remove_first && !parts.is_empty() {
90        parts.remove(0);
91    }
92
93    let new_parts: Vec<String> = parts
94        .into_iter()
95        .filter_map(|s| {
96            if s.contains("{") {
97                if remove_first {
98                    get_struct_name(tcx, def_id)
99                } else {
100                    None
101                }
102            } else {
103                Some(s.to_string())
104            }
105        })
106        .collect();
107
108    let mut cleaned_path = new_parts.join("::");
109    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
110    cleaned_path
111}
112
113pub fn get_sp_json() -> serde_json::Value {
114    let json_data: serde_json::Value =
115        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sps.json"))
116            .expect("Unable to parse JSON");
117    json_data
118}
119
120pub fn get_std_api_signature_json() -> serde_json::Value {
121    let json_data: serde_json::Value =
122        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sig.json"))
123            .expect("Unable to parse JSON");
124    json_data
125}
126
127pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
128    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
129    let json_data: serde_json::Value = get_sp_json();
130
131    if let Some(function_info) = json_data.get(&cleaned_path_name) {
132        if let Some(sp_list) = function_info.get("0") {
133            let mut result = HashSet::new();
134            if let Some(sp_array) = sp_list.as_array() {
135                for sp in sp_array {
136                    if let Some(sp_name) = sp.as_str() {
137                        result.insert(sp_name.to_string());
138                    }
139                }
140            }
141            return result;
142        }
143    }
144    HashSet::new()
145}
146
147pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
148    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
149        if let Some(impl_id) = assoc_item.impl_container(tcx) {
150            let ty = tcx.type_of(impl_id).skip_binder();
151            let type_name = ty.to_string();
152            let struct_name = type_name
153                .split('<')
154                .next()
155                .unwrap_or("")
156                .split("::")
157                .last()
158                .unwrap_or("")
159                .to_string();
160
161            return Some(struct_name);
162        }
163    }
164    None
165}
166
167pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
168    let poly_fn_sig = tcx.fn_sig(def_id);
169    let fn_sig = poly_fn_sig.skip_binder();
170    fn_sig.safety() == rustc_hir::Safety::Unsafe
171}
172
173//retval: 0-constructor, 1-method, 2-function
174pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> usize {
175    let mut node_type = 2;
176    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
177        match assoc_item.kind {
178            AssocKind::Fn { has_self, .. } => {
179                if has_self {
180                    node_type = 1;
181                } else {
182                    let fn_sig = tcx.fn_sig(def_id).skip_binder();
183                    let output = fn_sig.output().skip_binder();
184                    // return type is 'Self'
185                    if output.is_param(0) {
186                        node_type = 0;
187                    }
188                    // return type is struct's name
189                    if let Some(impl_id) = assoc_item.impl_container(tcx) {
190                        let ty = tcx.type_of(impl_id).skip_binder();
191                        if output == ty {
192                            node_type = 0;
193                        }
194                    }
195                    match output.kind() {
196                        TyKind::Ref(_, ref_ty, _) => {
197                            if ref_ty.is_param(0) {
198                                node_type = 0;
199                            }
200                            if let Some(impl_id) = assoc_item.impl_container(tcx) {
201                                let ty = tcx.type_of(impl_id).skip_binder();
202                                if *ref_ty == ty {
203                                    node_type = 0;
204                                }
205                            }
206                        }
207                        TyKind::Adt(adt_def, substs) => {
208                            if adt_def.is_enum()
209                                && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
210                                    || tcx.is_diagnostic_item(sym::Result, adt_def.did())
211                                    || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
212                            {
213                                let inner_ty = substs.type_at(0);
214                                if inner_ty.is_param(0) {
215                                    node_type = 0;
216                                }
217                                if let Some(impl_id) = assoc_item.impl_container(tcx) {
218                                    let ty_impl = tcx.type_of(impl_id).skip_binder();
219                                    if inner_ty == ty_impl {
220                                        node_type = 0;
221                                    }
222                                }
223                            }
224                        }
225                        _ => {}
226                    }
227                }
228            }
229            _ => todo!(),
230        }
231    }
232    node_type
233}
234
235pub fn get_adt_ty(tcx: TyCtxt<'_>, def_id: DefId) -> Option<Ty> {
236    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
237        if let Some(impl_id) = assoc_item.impl_container(tcx) {
238            return Some(tcx.type_of(impl_id).skip_binder());
239        }
240    }
241    None
242}
243
244pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<NodeType> {
245    let mut cons = Vec::new();
246    if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == 0 {
247        return cons;
248    }
249    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
250        if let Some(impl_id) = assoc_item.impl_container(tcx) {
251            // get struct ty
252            let ty = tcx.type_of(impl_id).skip_binder();
253            if let Some(adt_def) = ty.ty_adt_def() {
254                let adt_def_id = adt_def.did();
255                let impls = tcx.inherent_impls(adt_def_id);
256                for impl_def_id in impls {
257                    for item in tcx.associated_item_def_ids(impl_def_id) {
258                        if (tcx.def_kind(item) == DefKind::Fn
259                            || tcx.def_kind(item) == DefKind::AssocFn)
260                            && get_type(tcx, *item) == 0
261                        {
262                            cons.push(generate_node_ty(tcx, *item));
263                        }
264                    }
265                }
266            }
267        }
268    }
269    cons
270}
271
272pub fn get_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
273    let mut callees = HashSet::new();
274    if tcx.is_mir_available(def_id) {
275        let body = tcx.optimized_mir(def_id);
276        for bb in body.basic_blocks.iter() {
277            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
278                if let Operand::Constant(func_constant) = func {
279                    if let ty::FnDef(ref callee_def_id, _) = func_constant.const_.ty().kind() {
280                        if check_safety(tcx, *callee_def_id)
281                        // && check_visibility(tcx, *callee_def_id)
282                        {
283                            let sp_set = get_sp(tcx, *callee_def_id);
284                            if sp_set.len() != 0 {
285                                callees.insert(*callee_def_id);
286                            }
287                        }
288                    }
289                }
290            }
291        }
292    }
293    callees
294}
295
296// return all the impls def id of corresponding struct
297pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
298    let mut impls = Vec::new();
299    for impl_item_id in tcx.hir_crate_items(()).impl_items() {
300        let impl_item = tcx.hir_impl_item(impl_item_id);
301        match impl_item.kind {
302            ImplItemKind::Type(ty) => {
303                if let rustc_hir::TyKind::Path(ref qpath) = ty.kind {
304                    if let rustc_hir::QPath::Resolved(_, path) = qpath {
305                        if let rustc_hir::def::Res::Def(_, ref def_id) = path.res {
306                            if *def_id == struct_def_id {
307                                impls.push(impl_item.owner_id.to_def_id());
308                            }
309                        }
310                    }
311                }
312            }
313            _ => (),
314        }
315    }
316    impls
317}
318
319pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
320    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
321        if let Some(impl_id) = assoc_item.impl_container(tcx) {
322            // get struct ty
323            let ty = tcx.type_of(impl_id).skip_binder();
324            if let Some(adt_def) = ty.ty_adt_def() {
325                return Some(adt_def.did());
326            }
327        }
328    }
329    None
330}
331
332// get the pointee or wrapped type
333pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
334    // progress_info!("get_pointee: > {:?} as type: {:?}", matched_ty, matched_ty.kind());
335    let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
336        get_pointee(*ty_mut)
337    } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
338        get_pointee(*referred_ty)
339    } else {
340        matched_ty
341    };
342    pointee
343}
344
345pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
346    if let ty::RawPtr(_, _) = matched_ty.kind() {
347        return true;
348    }
349    false
350}
351
352pub fn is_ref(matched_ty: Ty<'_>) -> bool {
353    if let ty::Ref(_, _, _) = matched_ty.kind() {
354        return true;
355    }
356    false
357}
358
359pub fn is_slice(matched_ty: Ty<'_>) -> Option<Ty> {
360    if let ty::Slice(inner) = matched_ty.kind() {
361        return Some(*inner);
362    }
363    None
364}
365
366pub fn has_mut_self_param(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
367    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
368        match assoc_item.kind {
369            AssocKind::Fn { has_self, .. } => {
370                if has_self {
371                    let body = tcx.optimized_mir(def_id);
372                    let fst_arg = body.local_decls[Local::from_usize(1)].clone();
373                    let ty = fst_arg.ty;
374                    let is_mut_ref =
375                        matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
376                    return fst_arg.mutability.is_mut() || is_mut_ref;
377                }
378            }
379            _ => (),
380        }
381    }
382    false
383}
384
385// Input the adt def id
386// Return set of (mutable method def_id, fields can be modified)
387pub fn get_all_mutable_methods(tcx: TyCtxt<'_>, def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
388    let mut results = HashMap::new();
389    let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
390    let public_fields = adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
391    let impl_vec = adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
392    for impl_id in impl_vec {
393        let associated_items = tcx.associated_items(impl_id);
394        for item in associated_items.in_definition_order() {
395            if let AssocKind::Fn {
396                name: _,
397                has_self: _,
398            } = item.kind
399            {
400                let item_def_id = item.def_id;
401                if has_mut_self_param(tcx, item_def_id) {
402                    // TODO: using dataflow to detect field modificaiton, combined with public fields
403                    let modified_fields = public_fields.clone();
404                    results.insert(item_def_id, modified_fields);
405                }
406            }
407        }
408    }
409    results
410}
411
412// Check each field's visibility, return the public fields vec
413pub fn get_public_fields(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<usize> {
414    let adt_def = tcx.adt_def(def_id);
415    adt_def
416        .all_fields()
417        .enumerate()
418        .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
419        .collect()
420}
421
// general function for displaying hashmap
/// Prints the entries of `map` as `key: value` lines (both via `Debug`),
/// ordered by key and indented by two spaces per `level`.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = "  ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|left, right| left.0.cmp(right.0));

    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
438
439pub fn get_all_std_unsafe_callees(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<String> {
440    let mut results = Vec::new();
441    let body = tcx.optimized_mir(def_id);
442    let bb_len = body.basic_blocks.len();
443    for i in 0..bb_len {
444        let callees = match_std_unsafe_callee(
445            tcx,
446            body.basic_blocks[BasicBlock::from_usize(i)]
447                .clone()
448                .terminator(),
449        );
450        results.extend(callees);
451    }
452    results
453}
454
455pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<usize> {
456    let mut results = Vec::new();
457    let body = tcx.optimized_mir(def_id);
458    let bb_len = body.basic_blocks.len();
459    for i in 0..bb_len {
460        if match_std_unsafe_callee(
461            tcx,
462            body.basic_blocks[BasicBlock::from_usize(i)]
463                .clone()
464                .terminator(),
465        )
466        .is_empty()
467        {
468            results.push(i);
469        }
470    }
471    results
472}
473
474pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
475    let mut results = Vec::new();
476    if let TerminatorKind::Call { func, .. } = &terminator.kind {
477        if let Operand::Constant(func_constant) = func {
478            if let ty::FnDef(ref callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
479                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
480                if parse_unsafe_api(&func_name).is_some() {
481                    results.push(func_name);
482                }
483            }
484        }
485    }
486    results
487}
488
489// Bug definition: (1) strict -> weak & dst is mutable;
490//                 (2) _ -> strict
491pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
492    (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
493}
494
495// strict ty: bool, str, adt fields containing bool or str;
496pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
497    let ty = get_pointee(ori_ty);
498    let mut flag = false;
499    if let TyKind::Adt(adt_def, substs) = ty.kind() {
500        if adt_def.is_struct() {
501            for field_def in adt_def.all_fields() {
502                flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
503            }
504        }
505    }
506    ty.is_bool() || ty.is_str() || flag
507}
508
509pub fn reverse_op(op: BinOp) -> BinOp {
510    match op {
511        BinOp::Lt => BinOp::Ge,
512        BinOp::Ge => BinOp::Lt,
513        BinOp::Le => BinOp::Gt,
514        BinOp::Gt => BinOp::Le,
515        BinOp::Eq => BinOp::Eq,
516        BinOp::Ne => BinOp::Ne,
517        _ => op,
518    }
519}
520
521/// Same with `generate_contract_from_annotation` but does not contain field types.
522pub fn generate_contract_from_annotation_without_field_types(
523    tcx: TyCtxt<'_>,
524    def_id: DefId,
525) -> Vec<(usize, Vec<usize>, PropertyContract)> {
526    let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
527
528    contracts_with_ty
529        .into_iter()
530        .map(|(local_id, fields_with_ty, contract)| {
531            let fields: Vec<usize> = fields_with_ty
532                .into_iter()
533                .map(|(field_idx, _)| field_idx)
534                .collect();
535            (local_id, fields, contract)
536        })
537        .collect()
538}
539
540/// Get the annotation in tag-std style.
541/// Then generate the contractual invariant states (CIS) for the args.
542/// This function will recognize the args name and record states to MIR variable (represent by usize).
543/// Return value means Vec<(local_id, fields of this local, contracts)>
544pub fn generate_contract_from_annotation(
545    tcx: TyCtxt<'_>,
546    def_id: DefId,
547) -> Vec<(usize, Vec<(usize, Ty)>, PropertyContract)> {
548    const REGISTER_TOOL: &str = "rapx";
549    let tool_attrs = tcx.get_all_attrs(def_id).filter(|attr| {
550        if let Attribute::Unparsed(tool_attr) = attr
551            && tool_attr.path.segments[0].as_str() == REGISTER_TOOL
552        {
553            return true;
554        }
555        false
556    });
557    let mut results = Vec::new();
558    for attr in tool_attrs {
559        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
560        let safety_attr =
561            safety_parser::property_attr::parse_inner_attr_from_str(attr_str.as_str()).unwrap();
562        let attr_name = safety_attr.name;
563        let attr_kind = safety_attr.kind;
564        let contract = PropertyContract::new(tcx, def_id, attr_kind, attr_name, &safety_attr.expr);
565        let (local, fields) = parse_cis_local(tcx, def_id, safety_attr.expr);
566        results.push((local, fields, contract));
567    }
568    // if results.len() > 0 {
569    //     rap_warn!("results:\n{:?}", results);
570    // }
571    results
572}
573
574/// Parse attr.expr into local id and local fields.
575///
576/// Example:
577/// ```
578/// #[rapx::inner(property = ValidPtr(ptr, u32, 1), kind = "precond")]
579/// #[rapx::inner(property = ValidNum(region.size>=0), kind = "precond")]
580/// pub fn xor_secret_region(ptr: *mut u32, region:SecretRegion) -> u32 {...}
581/// ```
582///
583/// The first attribute will be parsed as (1, []).
584///     -> "1" means the first arg "ptr", "[]" means no fields.
585/// The second attribute will be parsed as (2, [1]).
586///     -> "2" means the second arg "region", "[1]" means "size" is region's second field.
587///
588/// If this function doesn't have args, then it will return default pattern: (0, Vec::new())
589pub fn parse_cis_local(
590    tcx: TyCtxt<'_>,
591    def_id: DefId,
592    expr: Vec<Expr>,
593) -> (usize, Vec<(usize, Ty<'_>)>) {
594    // match expr with cis local
595    for e in expr {
596        if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
597            return (base, fields);
598        }
599    }
600    (0, Vec::new())
601}
602
603/// parse single expr into (local, fields, ty)
604pub fn parse_expr_into_local_and_ty<'tcx>(
605    tcx: TyCtxt<'tcx>,
606    def_id: DefId,
607    expr: &Expr,
608) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
609    if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
610        let (param_names, param_tys) = parse_signature(tcx, def_id);
611        if param_names[0] == "0".to_string() {
612            return None;
613        }
614        if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
615            let mut current_ty = param_tys[param_index];
616            let mut field_indices = Vec::new();
617            for field_name in fields {
618                // peel the ref and ptr
619                let peeled_ty = current_ty.peel_refs();
620                if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
621                    let variant = adt_def.non_enum_variant();
622                    // 1. if field_name is number, then parse it as usize
623                    if let Ok(field_idx) = field_name.parse::<usize>() {
624                        if field_idx < variant.fields.len() {
625                            current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
626                                .ty(tcx, arg_list);
627                            field_indices.push((field_idx, current_ty));
628                            continue;
629                        }
630                    }
631                    // 2. if field_name is String, then compare it with current ty's field names
632                    if let Some((idx, _)) = variant
633                        .fields
634                        .iter()
635                        .enumerate()
636                        .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
637                    {
638                        current_ty =
639                            variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
640                        field_indices.push((idx, current_ty));
641                    }
642                    // 3. if field_name can not match any fields, then break
643                    else {
644                        break; // TODO:
645                    }
646                }
647                // if current ty is not Adt, then break the loop
648                else {
649                    break; // TODO:
650                }
651            }
652            // It's different from default one, we return the result as param_index+1 because param_index count from 0.
653            // But 0 in MIR is the ret index, the args' indexes begin from 1.
654            return Some((param_index + 1, field_indices, current_ty));
655        }
656    }
657    None
658}
659
660/// Return the Vecs of args' names and types
661/// This function will handle outside def_id by different way.
662pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
663    // 0. If the def id is local
664    if def_id.as_local().is_some() {
665        return parse_local_signature(tcx, def_id);
666    } else {
667        rap_debug!("{:?} is not local def id.", def_id);
668        return parse_outside_signature(tcx, def_id);
669    };
670}
671
672/// Return the Vecs of args' names and types of outside functions.
673fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
674    let sig = tcx.fn_sig(def_id).skip_binder();
675    let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
676
677    // 1. check pre-defined std unsafe api signature
678    if let Some(args_name) = get_known_std_names(tcx, def_id) {
679        // rap_warn!(
680        //     "function {:?} has arg: {:?}, arg types: {:?}",
681        //     def_id,
682        //     args_name,
683        //     param_tys
684        // );
685        return (args_name, param_tys);
686    }
687
688    // 2. TODO: If can not find known std apis, then use numbers like `0`,`1`,... to represent args.
689    let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
690    rap_debug!(
691        "function {:?} has arg: {:?}, arg types: {:?}",
692        def_id,
693        args_name,
694        param_tys
695    );
696    return (args_name, param_tys);
697}
698
699/// We use a json to record known std apis' arg names.
700/// This function will search the json and return the names.
701/// Notes: If std gets updated, the json may still record old ones.
702fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
703    let std_func_name = get_cleaned_def_path_name(tcx, def_id);
704    let json_data: serde_json::Value = get_std_api_signature_json();
705
706    if let Some(arg_info) = json_data.get(&std_func_name) {
707        if let Some(args_name) = arg_info.as_array() {
708            // set default value to arg name
709            if args_name.len() == 0 {
710                return Some(vec!["0".to_string()]);
711            }
712            // iterate and collect
713            let mut result = Vec::new();
714            for arg in args_name {
715                if let Some(sp_name) = arg.as_str() {
716                    result.push(sp_name.to_string());
717                }
718            }
719            return Some(result);
720        }
721    }
722    None
723}
724
725/// Return the Vecs of args' names and types of local functions.
726pub fn parse_local_signature(tcx: TyCtxt<'_>, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
727    // 1. parse local def_id and get arg list
728    let local_def_id = def_id.as_local().unwrap();
729    let hir_body = tcx.hir_body_owned_by(local_def_id);
730    if hir_body.params.len() == 0 {
731        return (vec!["0".to_string()], Vec::new());
732    }
733    // 2. contruct the vec of param and param ty
734    let params = hir_body.params;
735    let typeck_results = tcx.typeck_body(hir_body.id());
736    let mut param_names = Vec::new();
737    let mut param_tys = Vec::new();
738    for param in params {
739        match param.pat.kind {
740            rustc_hir::PatKind::Binding(_, _, ident, _) => {
741                param_names.push(ident.name.to_string());
742                let ty = typeck_results.pat_ty(param.pat);
743                param_tys.push(ty);
744            }
745            _ => {
746                param_names.push(String::new());
747                param_tys.push(typeck_results.pat_ty(param.pat));
748            }
749        }
750    }
751    (param_names, param_tys)
752}
753
754/// return the (ident, its fields) of the expr.
755///
756/// illustrated cases :
757///    ptr	-> ("ptr", [])
758///    region.size	-> ("region", ["size"])
759///    tuple.0.value -> ("tuple", ["0", "value"])
760pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
761    match expr {
762        Expr::Path(syn::ExprPath { path, .. }) => {
763            if path.segments.len() == 1 {
764                rap_debug!("expr2 {:?}", expr);
765                let ident = path.segments[0].ident.to_string();
766                Some((ident, Vec::new()))
767            } else {
768                None
769            }
770        }
771        // get the base and fields recursively
772        Expr::Field(syn::ExprField { base, member, .. }) => {
773            let (base_ident, mut fields) =
774                if let Some((base_ident, fields)) = access_ident_recursive(base) {
775                    (base_ident, fields)
776                } else {
777                    return None;
778                };
779            let field_name = match member {
780                syn::Member::Named(ident) => ident.to_string(),
781                syn::Member::Unnamed(index) => index.index.to_string(),
782            };
783            fields.push(field_name);
784            Some((base_ident, fields))
785        }
786        _ => None,
787    }
788}
789
790/// parse expr into number.
791pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
792    if let Expr::Lit(expr_lit) = expr {
793        if let syn::Lit::Int(lit_int) = &expr_lit.lit {
794            return lit_int.base10_parse::<usize>().ok();
795        }
796    }
797    None
798}
799
800/// Match a type identifier string to a concrete Rust type
801///
802/// This function attempts to match a given type identifier (e.g., "u32", "T", "MyStruct")
803/// to a type in the provided parameter type list. It handles:
804/// 1. Built-in primitive types (u32, usize, etc.)
805/// 2. Generic type parameters (T, U, etc.)
806/// 3. User-defined types found in the parameter list
807///
808/// Arguments:
809/// - `tcx`: Type context for querying compiler information
810/// - `type_ident`: String representing the type identifier to match
811/// - `param_ty`: List of parameter types from the function signature
812///
813/// Returns:
814/// - `Some(Ty)` if a matching type is found
815/// - `None` if no match is found
816pub fn match_ty_with_ident(tcx: TyCtxt<'_>, def_id: DefId, type_ident: String) -> Option<Ty> {
817    // 1. First check for built-in primitive types
818    if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
819        return Some(primitive_ty);
820    }
821    // 2. Check if the identifier matches any generic type parameter
822    return find_generic_param(tcx, def_id, type_ident.clone());
823    // 3. Check if the identifier matches any user-defined type in the parameters
824    // find_user_defined_type(tcx, def_id, type_ident)
825}
826
827/// Match built-in primitive types from String
828fn match_primitive_type(tcx: TyCtxt<'_>, type_ident: String) -> Option<Ty> {
829    match type_ident.as_str() {
830        "i8" => Some(tcx.types.i8),
831        "i16" => Some(tcx.types.i16),
832        "i32" => Some(tcx.types.i32),
833        "i64" => Some(tcx.types.i64),
834        "i128" => Some(tcx.types.i128),
835        "isize" => Some(tcx.types.isize),
836        "u8" => Some(tcx.types.u8),
837        "u16" => Some(tcx.types.u16),
838        "u32" => Some(tcx.types.u32),
839        "u64" => Some(tcx.types.u64),
840        "u128" => Some(tcx.types.u128),
841        "usize" => Some(tcx.types.usize),
842        "f16" => Some(tcx.types.f16),
843        "f32" => Some(tcx.types.f32),
844        "f64" => Some(tcx.types.f64),
845        "f128" => Some(tcx.types.f128),
846        "bool" => Some(tcx.types.bool),
847        "char" => Some(tcx.types.char),
848        "str" => Some(tcx.types.str_),
849        _ => None,
850    }
851}
852
853/// Find generic type parameters in the parameter list
854fn find_generic_param(tcx: TyCtxt<'_>, def_id: DefId, type_ident: String) -> Option<Ty> {
855    rap_debug!(
856        "Searching for generic param: {} in {:?}",
857        type_ident,
858        def_id
859    );
860    let (_, param_tys) = parse_signature(tcx, def_id);
861    rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
862    // 递归查找泛型参数
863    for &ty in &param_tys {
864        if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
865            return Some(found);
866        }
867    }
868
869    None
870}
871
872/// Iterate the args' types recursively and find the matched generic one.
873fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
874    match ty.kind() {
875        TyKind::Param(param_ty) => {
876            if param_ty.name.as_str() == type_ident {
877                return Some(ty);
878            }
879        }
880        TyKind::RawPtr(ty, _)
881        | TyKind::Ref(_, ty, _)
882        | TyKind::Slice(ty)
883        | TyKind::Array(ty, _) => {
884            if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
885                return Some(found);
886            }
887        }
888        TyKind::Tuple(tys) => {
889            for tuple_ty in tys.iter() {
890                if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
891                    return Some(found);
892                }
893            }
894        }
895        TyKind::Adt(adt_def, substs) => {
896            let name = tcx.item_name(adt_def.did()).to_string();
897            if name == type_ident {
898                return Some(ty);
899            }
900            for field in adt_def.all_fields() {
901                let field_ty = field.ty(tcx, substs);
902                if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
903                    return Some(found);
904                }
905            }
906        }
907        _ => {}
908    }
909    None
910}
911
912// /// Find user-defined types in the parameter list
913// fn find_user_defined_type(tcx: TyCtxt<'_>, def_id: DefId, type_ident: String) -> Option<Ty> {
914//     let param_ty = parse_signature(tcx, def_id).1;
915//     param_ty.iter().find_map(|&ty| {
916//         // Peel off references and pointers to get to the underlying type
917//         let peeled_ty = ty.peel_refs();
918//         match peeled_ty.kind() {
919//             TyKind::Adt(adt_def, _raw_list) => {
920//                 // Compare the type name to our identifier
921//                 let name = tcx.item_name(adt_def.did()).to_string();
922//                 if name == type_ident {
923//                     return Some(peeled_ty);
924//                 }
925//             }
926//             _ => {}
927//         }
928//         None
929//     })
930// }
931
932pub fn reflect_generic<'tcx>(
933    generic_mapping: &FxHashMap<String, Ty<'tcx>>,
934    ty: Ty<'tcx>,
935) -> Ty<'tcx> {
936    match ty.kind() {
937        TyKind::Param(param_ty) => {
938            let generic_name = param_ty.name.to_string();
939            if let Some(actual_ty) = generic_mapping.get(&generic_name) {
940                return *actual_ty;
941            }
942        }
943        _ => {}
944    }
945    ty
946}
947
948// fn find_first_chain_in_expr(expr: &Expr, param_names: &[String]) -> Option<(String, Vec<String>)> {
949//     if let Some(chain) = access_ident_recursive(expr) {
950//         if param_names.contains(&chain.0) {
951//             return Some(chain);
952//         }
953//     }
954
955//     let mut children = Vec::new();
956//     match expr {
957//         Expr::Array(expr) => children.extend(expr.elems.iter()),
958//         Expr::Assign(expr) => {
959//             children.push(&expr.left);
960//             children.push(&expr.right);
961//         }
962//         Expr::Binary(expr) => {
963//             children.push(&expr.left);
964//             children.push(&expr.right);
965//         }
966//         Expr::Call(expr) => {
967//             children.push(&expr.func);
968//             children.extend(expr.args.iter());
969//         }
970//         Expr::Cast(expr) => children.push(&expr.expr),
971//         Expr::Field(expr) => children.push(&expr.base),
972//         Expr::Group(expr) => children.push(&expr.expr),
973//         Expr::Index(expr) => {
974//             children.push(&expr.expr);
975//             children.push(&expr.index);
976//         }
977//         Expr::Let(expr) => children.push(&expr.expr),
978//         // Expr::Lit(expr) => children.push(&expr.lit),
979//         Expr::MethodCall(expr) => {
980//             children.push(&expr.receiver);
981//             children.extend(expr.args.iter());
982//         }
983//         Expr::Paren(expr) => children.push(&expr.expr),
984//         Expr::Range(expr) => {
985//             if let Some(start) = &expr.start {
986//                 children.push(start);
987//             }
988//             if let Some(end) = &expr.end {
989//                 children.push(end);
990//             }
991//             // if let Some(limits) = &expr.limits {
992//             //     let _ = limits;
993//             // }
994//         }
995//         Expr::Reference(expr) => children.push(&expr.expr),
996//         Expr::Repeat(expr) => {
997//             children.push(&expr.expr);
998//             children.push(&expr.len);
999//         }
1000//         // Expr::Struct(expr) => {
1001//         //     children.push(&expr.path);
1002//         //     for field in &expr.fields {
1003//         //         children.push(&field.expr);
1004//         //     }
1005//         //     if let Some(rest) = &expr.rest {
1006//         //         children.push(rest);
1007//         //     }
1008//         // }
1009//         Expr::Try(expr) => children.push(&expr.expr),
1010//         Expr::Tuple(expr) => children.extend(expr.elems.iter()),
1011//         Expr::Unary(expr) => children.push(&expr.expr),
1012//         _ => {}
1013//     }
1014
1015//     for child in children {
1016//         if let Some(chain) = find_first_chain_in_expr(child, param_names) {
1017//             return Some(chain);
1018//         }
1019//     }
1020
1021//     None
1022// }