rapx/analysis/utils/fn_info.rs

1#[allow(unused)]
2use crate::analysis::senryx::contracts::property::PropertyContract;
3use crate::analysis::senryx::matcher::parse_unsafe_api;
4use crate::analysis::unsafety_isolation::generate_dot::NodeType;
5use crate::rap_debug;
6use crate::rap_warn;
7use rustc_data_structures::fx::FxHashMap;
8use rustc_hir::def::DefKind;
9use rustc_hir::def_id::DefId;
10use rustc_hir::Attribute;
11use rustc_hir::ImplItemKind;
12use rustc_middle::mir::BinOp;
13use rustc_middle::mir::Local;
14use rustc_middle::mir::{BasicBlock, Terminator};
15use rustc_middle::ty::{AssocKind, Mutability, Ty, TyCtxt, TyKind};
16use rustc_middle::{
17    mir::{Operand, TerminatorKind},
18    ty,
19};
20use rustc_span::def_id::LocalDefId;
21use rustc_span::kw;
22use rustc_span::sym;
23use std::collections::HashMap;
24use std::collections::HashSet;
25use std::fmt::Debug;
26use std::hash::Hash;
27use syn::Expr;
28
29pub fn generate_node_ty(tcx: TyCtxt, def_id: DefId) -> NodeType {
30    (def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
31}
32
33pub fn check_visibility(tcx: TyCtxt, func_defid: DefId) -> bool {
34    if !tcx.visibility(func_defid).is_public() {
35        return false;
36    }
37    // if func_defid.is_local() {
38    //     if let Some(local_defid) = func_defid.as_local() {
39    //         let module_moddefid = tcx.parent_module_from_def_id(local_defid);
40    //         let module_defid = module_moddefid.to_def_id();
41    //         if !tcx.visibility(module_defid).is_public() {
42    //             // println!("module def id {:?}",UigUnit::get_cleaned_def_path_name(tcx, module_defid));
43    //             return Self::is_re_exported(tcx, func_defid, module_moddefid.to_local_def_id());
44    //         }
45    //     }
46    // }
47    true
48}
49
50pub fn is_re_exported(tcx: TyCtxt, target_defid: DefId, module_defid: LocalDefId) -> bool {
51    for child in tcx.module_children_local(module_defid) {
52        if child.vis.is_public() {
53            if let Some(def_id) = child.res.opt_def_id() {
54                if def_id == target_defid {
55                    return true;
56                }
57            }
58        }
59    }
60    false
61}
62
/// Print every element of `set` on its own line (Debug format), then a separator line.
///
/// Iteration order is whatever the `HashSet` yields.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|entry| println!("{:?}", entry));
    println!("---------------");
}
69
70pub fn get_cleaned_def_path_name(tcx: TyCtxt, def_id: DefId) -> String {
71    let def_id_str = format!("{:?}", def_id);
72    let mut parts: Vec<&str> = def_id_str.split("::").collect();
73
74    let mut remove_first = false;
75    if let Some(first_part) = parts.get_mut(0) {
76        if first_part.contains("core") {
77            *first_part = "core";
78        } else if first_part.contains("std") {
79            *first_part = "std";
80        } else if first_part.contains("alloc") {
81            *first_part = "alloc";
82        } else {
83            remove_first = true;
84        }
85    }
86    if remove_first && !parts.is_empty() {
87        parts.remove(0);
88    }
89
90    let new_parts: Vec<String> = parts
91        .into_iter()
92        .filter_map(|s| {
93            if s.contains("{") {
94                if remove_first {
95                    get_struct_name(tcx, def_id)
96                } else {
97                    None
98                }
99            } else {
100                Some(s.to_string())
101            }
102        })
103        .collect();
104
105    let mut cleaned_path = new_parts.join("::");
106    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
107    cleaned_path
108}
109
110pub fn get_sp_json() -> serde_json::Value {
111    let json_data: serde_json::Value =
112        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sps.json"))
113            .expect("Unable to parse JSON");
114    json_data
115}
116
117pub fn get_std_api_signature_json() -> serde_json::Value {
118    let json_data: serde_json::Value =
119        serde_json::from_str(include_str!("../unsafety_isolation/data/std_sig.json"))
120            .expect("Unable to parse JSON");
121    json_data
122}
123
124pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
125    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
126    let json_data: serde_json::Value = get_sp_json();
127
128    if let Some(function_info) = json_data.get(&cleaned_path_name) {
129        if let Some(sp_list) = function_info.get("0") {
130            let mut result = HashSet::new();
131            if let Some(sp_array) = sp_list.as_array() {
132                for sp in sp_array {
133                    if let Some(sp_name) = sp.as_str() {
134                        result.insert(sp_name.to_string());
135                    }
136                }
137            }
138            return result;
139        }
140    }
141    HashSet::new()
142}
143
144pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
145    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
146        if let Some(impl_id) = assoc_item.impl_container(tcx) {
147            let ty = tcx.type_of(impl_id).skip_binder();
148            let type_name = ty.to_string();
149            let struct_name = type_name
150                .split('<')
151                .next()
152                .unwrap_or("")
153                .split("::")
154                .last()
155                .unwrap_or("")
156                .to_string();
157
158            return Some(struct_name);
159        }
160    }
161    None
162}
163
164pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
165    let poly_fn_sig = tcx.fn_sig(def_id);
166    let fn_sig = poly_fn_sig.skip_binder();
167    fn_sig.safety() == rustc_hir::Safety::Unsafe
168}
169
170//retval: 0-constructor, 1-method, 2-function
171pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> usize {
172    let mut node_type = 2;
173    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
174        match assoc_item.kind {
175            AssocKind::Fn { has_self, .. } => {
176                if has_self {
177                    node_type = 1;
178                } else {
179                    let fn_sig = tcx.fn_sig(def_id).skip_binder();
180                    let output = fn_sig.output().skip_binder();
181                    // return type is 'Self'
182                    if output.is_param(0) {
183                        node_type = 0;
184                    }
185                    // return type is struct's name
186                    if let Some(impl_id) = assoc_item.impl_container(tcx) {
187                        let ty = tcx.type_of(impl_id).skip_binder();
188                        if output == ty {
189                            node_type = 0;
190                        }
191                    }
192                    match output.kind() {
193                        TyKind::Ref(_, ref_ty, _) => {
194                            if ref_ty.is_param(0) {
195                                node_type = 0;
196                            }
197                            if let Some(impl_id) = assoc_item.impl_container(tcx) {
198                                let ty = tcx.type_of(impl_id).skip_binder();
199                                if *ref_ty == ty {
200                                    node_type = 0;
201                                }
202                            }
203                        }
204                        TyKind::Adt(adt_def, substs) => {
205                            if adt_def.is_enum()
206                                && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
207                                    || tcx.is_diagnostic_item(sym::Result, adt_def.did())
208                                    || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
209                            {
210                                let inner_ty = substs.type_at(0);
211                                if inner_ty.is_param(0) {
212                                    node_type = 0;
213                                }
214                                if let Some(impl_id) = assoc_item.impl_container(tcx) {
215                                    let ty_impl = tcx.type_of(impl_id).skip_binder();
216                                    if inner_ty == ty_impl {
217                                        node_type = 0;
218                                    }
219                                }
220                            }
221                        }
222                        _ => {}
223                    }
224                }
225            }
226            _ => todo!(),
227        }
228    }
229    node_type
230}
231
232pub fn get_adt_ty(tcx: TyCtxt, def_id: DefId) -> Option<Ty> {
233    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
234        if let Some(impl_id) = assoc_item.impl_container(tcx) {
235            return Some(tcx.type_of(impl_id).skip_binder());
236        }
237    }
238    None
239}
240
241pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<NodeType> {
242    let mut cons = Vec::new();
243    if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == 0 {
244        return cons;
245    }
246    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
247        if let Some(impl_id) = assoc_item.impl_container(tcx) {
248            // get struct ty
249            let ty = tcx.type_of(impl_id).skip_binder();
250            if let Some(adt_def) = ty.ty_adt_def() {
251                let adt_def_id = adt_def.did();
252                let impls = tcx.inherent_impls(adt_def_id);
253                for impl_def_id in impls {
254                    for item in tcx.associated_item_def_ids(impl_def_id) {
255                        if (tcx.def_kind(item) == DefKind::Fn
256                            || tcx.def_kind(item) == DefKind::AssocFn)
257                            && get_type(tcx, *item) == 0
258                        {
259                            cons.push(generate_node_ty(tcx, *item));
260                        }
261                    }
262                }
263            }
264        }
265    }
266    cons
267}
268
269pub fn get_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
270    let mut callees = HashSet::new();
271    if tcx.is_mir_available(def_id) {
272        let body = tcx.optimized_mir(def_id);
273        for bb in body.basic_blocks.iter() {
274            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
275                if let Operand::Constant(func_constant) = func {
276                    if let ty::FnDef(ref callee_def_id, _) = func_constant.const_.ty().kind() {
277                        if check_safety(tcx, *callee_def_id)
278                        // && check_visibility(tcx, *callee_def_id)
279                        {
280                            let sp_set = get_sp(tcx, *callee_def_id);
281                            if sp_set.len() != 0 {
282                                callees.insert(*callee_def_id);
283                            }
284                        }
285                    }
286                }
287            }
288        }
289    }
290    callees
291}
292
293// return all the impls def id of corresponding struct
294pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
295    let mut impls = Vec::new();
296    for impl_item_id in tcx.hir_crate_items(()).impl_items() {
297        let impl_item = tcx.hir_impl_item(impl_item_id);
298        match impl_item.kind {
299            ImplItemKind::Type(ty) => {
300                if let rustc_hir::TyKind::Path(ref qpath) = ty.kind {
301                    if let rustc_hir::QPath::Resolved(_, path) = qpath {
302                        if let rustc_hir::def::Res::Def(_, ref def_id) = path.res {
303                            if *def_id == struct_def_id {
304                                impls.push(impl_item.owner_id.to_def_id());
305                            }
306                        }
307                    }
308                }
309            }
310            _ => (),
311        }
312    }
313    impls
314}
315
316pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
317    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
318        if let Some(impl_id) = assoc_item.impl_container(tcx) {
319            // get struct ty
320            let ty = tcx.type_of(impl_id).skip_binder();
321            if let Some(adt_def) = ty.ty_adt_def() {
322                return Some(adt_def.did());
323            }
324        }
325    }
326    None
327}
328
329// get the pointee or wrapped type
330pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
331    // progress_info!("get_pointee: > {:?} as type: {:?}", matched_ty, matched_ty.kind());
332    let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
333        get_pointee(*ty_mut)
334    } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
335        get_pointee(*referred_ty)
336    } else {
337        matched_ty
338    };
339    pointee
340}
341
342pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
343    if let ty::RawPtr(_, _) = matched_ty.kind() {
344        return true;
345    }
346    false
347}
348
349pub fn is_ref(matched_ty: Ty<'_>) -> bool {
350    if let ty::Ref(_, _, _) = matched_ty.kind() {
351        return true;
352    }
353    false
354}
355
356pub fn is_slice(matched_ty: Ty) -> Option<Ty> {
357    if let ty::Slice(inner) = matched_ty.kind() {
358        return Some(*inner);
359    }
360    None
361}
362
363pub fn has_mut_self_param(tcx: TyCtxt, def_id: DefId) -> bool {
364    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
365        match assoc_item.kind {
366            AssocKind::Fn { has_self, .. } => {
367                if has_self {
368                    let body = tcx.optimized_mir(def_id);
369                    let fst_arg = body.local_decls[Local::from_usize(1)].clone();
370                    let ty = fst_arg.ty;
371                    let is_mut_ref =
372                        matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
373                    return fst_arg.mutability.is_mut() || is_mut_ref;
374                }
375            }
376            _ => (),
377        }
378    }
379    false
380}
381
382// Input the adt def id
383// Return set of (mutable method def_id, fields can be modified)
384pub fn get_all_mutable_methods(tcx: TyCtxt, def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
385    let mut results = HashMap::new();
386    let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
387    let public_fields = adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
388    let impl_vec = adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
389    for impl_id in impl_vec {
390        let associated_items = tcx.associated_items(impl_id);
391        for item in associated_items.in_definition_order() {
392            if let AssocKind::Fn {
393                name: _,
394                has_self: _,
395            } = item.kind
396            {
397                let item_def_id = item.def_id;
398                if has_mut_self_param(tcx, item_def_id) {
399                    // TODO: using dataflow to detect field modificaiton, combined with public fields
400                    let modified_fields = public_fields.clone();
401                    results.insert(item_def_id, modified_fields);
402                }
403            }
404        }
405    }
406    results
407}
408
409// Check each field's visibility, return the public fields vec
410pub fn get_public_fields(tcx: TyCtxt, def_id: DefId) -> HashSet<usize> {
411    let adt_def = tcx.adt_def(def_id);
412    adt_def
413        .all_fields()
414        .enumerate()
415        .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
416        .collect()
417}
418
/// Print `map` as `key: value` lines (Debug format), sorted by key.
///
/// Each line is prefixed with two spaces per `level` of indentation.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = "  ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|(left, _), (right, _)| left.cmp(right));
    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
435
436pub fn get_all_std_unsafe_callees(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
437    let mut results = Vec::new();
438    let body = tcx.optimized_mir(def_id);
439    let bb_len = body.basic_blocks.len();
440    for i in 0..bb_len {
441        let callees = match_std_unsafe_callee(
442            tcx,
443            body.basic_blocks[BasicBlock::from_usize(i)]
444                .clone()
445                .terminator(),
446        );
447        results.extend(callees);
448    }
449    results
450}
451
452pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt, def_id: DefId) -> Vec<usize> {
453    let mut results = Vec::new();
454    let body = tcx.optimized_mir(def_id);
455    let bb_len = body.basic_blocks.len();
456    for i in 0..bb_len {
457        if match_std_unsafe_callee(
458            tcx,
459            body.basic_blocks[BasicBlock::from_usize(i)]
460                .clone()
461                .terminator(),
462        )
463        .is_empty()
464        {
465            results.push(i);
466        }
467    }
468    results
469}
470
471pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
472    let mut results = Vec::new();
473    if let TerminatorKind::Call { func, .. } = &terminator.kind {
474        if let Operand::Constant(func_constant) = func {
475            if let ty::FnDef(ref callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
476                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
477                if parse_unsafe_api(&func_name).is_some() {
478                    results.push(func_name);
479                }
480            }
481        }
482    }
483    results
484}
485
486// Bug definition: (1) strict -> weak & dst is mutable;
487//                 (2) _ -> strict
488pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
489    (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
490}
491
492// strict ty: bool, str, adt fields containing bool or str;
493pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
494    let ty = get_pointee(ori_ty);
495    let mut flag = false;
496    if let TyKind::Adt(adt_def, substs) = ty.kind() {
497        if adt_def.is_struct() {
498            for field_def in adt_def.all_fields() {
499                flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
500            }
501        }
502    }
503    ty.is_bool() || ty.is_str() || flag
504}
505
506pub fn reverse_op(op: BinOp) -> BinOp {
507    match op {
508        BinOp::Lt => BinOp::Ge,
509        BinOp::Ge => BinOp::Lt,
510        BinOp::Le => BinOp::Gt,
511        BinOp::Gt => BinOp::Le,
512        BinOp::Eq => BinOp::Eq,
513        BinOp::Ne => BinOp::Ne,
514        _ => op,
515    }
516}
517
518/// Same with `generate_contract_from_annotation` but does not contain field types.
519pub fn generate_contract_from_annotation_without_field_types(
520    tcx: TyCtxt,
521    def_id: DefId,
522) -> Vec<(usize, Vec<usize>, PropertyContract)> {
523    let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
524
525    contracts_with_ty
526        .into_iter()
527        .map(|(local_id, fields_with_ty, contract)| {
528            let fields: Vec<usize> = fields_with_ty
529                .into_iter()
530                .map(|(field_idx, _)| field_idx)
531                .collect();
532            (local_id, fields, contract)
533        })
534        .collect()
535}
536
537/// Filter the function which contains "rapx::proof"
538pub fn is_verify_target_func(tcx: TyCtxt, def_id: DefId) -> bool {
539    const REGISTER_TOOL: &str = "rapx";
540    for attr in tcx.get_all_attrs(def_id).into_iter() {
541        if let Attribute::Unparsed(tool_attr) = attr {
542            if tool_attr.path.segments[0].as_str() == REGISTER_TOOL
543                && tool_attr.path.segments[1].as_str() == "proof"
544            {
545                return true;
546            }
547        }
548    }
549    false
550}
551
552/// Get the annotation in tag-std style.
553/// Then generate the contractual invariant states (CIS) for the args.
554/// This function will recognize the args name and record states to MIR variable (represent by usize).
555/// Return value means Vec<(local_id, fields of this local, contracts)>
556pub fn generate_contract_from_annotation(
557    tcx: TyCtxt,
558    def_id: DefId,
559) -> Vec<(usize, Vec<(usize, Ty)>, PropertyContract)> {
560    const REGISTER_TOOL: &str = "rapx";
561    let tool_attrs = tcx.get_all_attrs(def_id).into_iter().filter(|attr| {
562        if let Attribute::Unparsed(tool_attr) = attr {
563            if tool_attr.path.segments[0].as_str() == REGISTER_TOOL
564                && tool_attr.path.segments[1].as_str() != "proof"
565            {
566                return true;
567            }
568        }
569        false
570    });
571    let mut results = Vec::new();
572    for attr in tool_attrs {
573        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
574        let safety_attr =
575            safety_parser::property_attr::parse_inner_attr_from_str(attr_str.as_str()).unwrap();
576        let attr_name = safety_attr.name;
577        let attr_kind = safety_attr.kind;
578        let contract = PropertyContract::new(tcx, def_id, attr_kind, attr_name, &safety_attr.expr);
579        let (local, fields) = parse_cis_local(tcx, def_id, safety_attr.expr);
580        results.push((local, fields, contract));
581    }
582    // if results.len() > 0 {
583    //     rap_warn!("results:\n{:?}", results);
584    // }
585    results
586}
587
588/// Parse attr.expr into local id and local fields.
589///
590/// Example:
591/// ```
592/// #[rapx::inner(property = ValidPtr(ptr, u32, 1), kind = "precond")]
593/// #[rapx::inner(property = ValidNum(region.size>=0), kind = "precond")]
594/// pub fn xor_secret_region(ptr: *mut u32, region:SecretRegion) -> u32 {...}
595/// ```
596///
597/// The first attribute will be parsed as (1, []).
598///     -> "1" means the first arg "ptr", "[]" means no fields.
599/// The second attribute will be parsed as (2, [1]).
600///     -> "2" means the second arg "region", "[1]" means "size" is region's second field.
601///
602/// If this function doesn't have args, then it will return default pattern: (0, Vec::new())
603pub fn parse_cis_local(tcx: TyCtxt, def_id: DefId, expr: Vec<Expr>) -> (usize, Vec<(usize, Ty)>) {
604    // match expr with cis local
605    for e in expr {
606        if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
607            return (base, fields);
608        }
609    }
610    (0, Vec::new())
611}
612
613/// parse single expr into (local, fields, ty)
614pub fn parse_expr_into_local_and_ty<'tcx>(
615    tcx: TyCtxt<'tcx>,
616    def_id: DefId,
617    expr: &Expr,
618) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
619    if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
620        let (param_names, param_tys) = parse_signature(tcx, def_id);
621        if param_names[0] == "0".to_string() {
622            return None;
623        }
624        if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
625            let mut current_ty = param_tys[param_index];
626            let mut field_indices = Vec::new();
627            for field_name in fields {
628                // peel the ref and ptr
629                let peeled_ty = current_ty.peel_refs();
630                if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
631                    let variant = adt_def.non_enum_variant();
632                    // 1. if field_name is number, then parse it as usize
633                    if let Ok(field_idx) = field_name.parse::<usize>() {
634                        if field_idx < variant.fields.len() {
635                            current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
636                                .ty(tcx, arg_list);
637                            field_indices.push((field_idx, current_ty));
638                            continue;
639                        }
640                    }
641                    // 2. if field_name is String, then compare it with current ty's field names
642                    if let Some((idx, _)) = variant
643                        .fields
644                        .iter()
645                        .enumerate()
646                        .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
647                    {
648                        current_ty =
649                            variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
650                        field_indices.push((idx, current_ty));
651                    }
652                    // 3. if field_name can not match any fields, then break
653                    else {
654                        break; // TODO:
655                    }
656                }
657                // if current ty is not Adt, then break the loop
658                else {
659                    break; // TODO:
660                }
661            }
662            // It's different from default one, we return the result as param_index+1 because param_index count from 0.
663            // But 0 in MIR is the ret index, the args' indexes begin from 1.
664            return Some((param_index + 1, field_indices, current_ty));
665        }
666    }
667    None
668}
669
670/// Return the Vecs of args' names and types
671/// This function will handle outside def_id by different way.
672pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
673    // 0. If the def id is local
674    if def_id.as_local().is_some() {
675        return parse_local_signature(tcx, def_id);
676    } else {
677        rap_debug!("{:?} is not local def id.", def_id);
678        return parse_outside_signature(tcx, def_id);
679    };
680}
681
682/// Return the Vecs of args' names and types of outside functions.
683fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
684    let sig = tcx.fn_sig(def_id).skip_binder();
685    let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
686
687    // 1. check pre-defined std unsafe api signature
688    if let Some(args_name) = get_known_std_names(tcx, def_id) {
689        // rap_warn!(
690        //     "function {:?} has arg: {:?}, arg types: {:?}",
691        //     def_id,
692        //     args_name,
693        //     param_tys
694        // );
695        return (args_name, param_tys);
696    }
697
698    // 2. TODO: If can not find known std apis, then use numbers like `0`,`1`,... to represent args.
699    let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
700    rap_debug!(
701        "function {:?} has arg: {:?}, arg types: {:?}",
702        def_id,
703        args_name,
704        param_tys
705    );
706    return (args_name, param_tys);
707}
708
709/// We use a json to record known std apis' arg names.
710/// This function will search the json and return the names.
711/// Notes: If std gets updated, the json may still record old ones.
712fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
713    let std_func_name = get_cleaned_def_path_name(tcx, def_id);
714    let json_data: serde_json::Value = get_std_api_signature_json();
715
716    if let Some(arg_info) = json_data.get(&std_func_name) {
717        if let Some(args_name) = arg_info.as_array() {
718            // set default value to arg name
719            if args_name.len() == 0 {
720                return Some(vec!["0".to_string()]);
721            }
722            // iterate and collect
723            let mut result = Vec::new();
724            for arg in args_name {
725                if let Some(sp_name) = arg.as_str() {
726                    result.push(sp_name.to_string());
727                }
728            }
729            return Some(result);
730        }
731    }
732    None
733}
734
735/// Return the Vecs of args' names and types of local functions.
736pub fn parse_local_signature(tcx: TyCtxt, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
737    // 1. parse local def_id and get arg list
738    let local_def_id = def_id.as_local().unwrap();
739    let hir_body = tcx.hir_body_owned_by(local_def_id);
740    if hir_body.params.len() == 0 {
741        return (vec!["0".to_string()], Vec::new());
742    }
743    // 2. contruct the vec of param and param ty
744    let params = hir_body.params;
745    let typeck_results = tcx.typeck_body(hir_body.id());
746    let mut param_names = Vec::new();
747    let mut param_tys = Vec::new();
748    for param in params {
749        match param.pat.kind {
750            rustc_hir::PatKind::Binding(_, _, ident, _) => {
751                param_names.push(ident.name.to_string());
752                let ty = typeck_results.pat_ty(param.pat);
753                param_tys.push(ty);
754            }
755            _ => {
756                param_names.push(String::new());
757                param_tys.push(typeck_results.pat_ty(param.pat));
758            }
759        }
760    }
761    (param_names, param_tys)
762}
763
764/// return the (ident, its fields) of the expr.
765///
766/// illustrated cases :
767///    ptr	-> ("ptr", [])
768///    region.size	-> ("region", ["size"])
769///    tuple.0.value -> ("tuple", ["0", "value"])
770pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
771    match expr {
772        Expr::Path(syn::ExprPath { path, .. }) => {
773            if path.segments.len() == 1 {
774                rap_debug!("expr2 {:?}", expr);
775                let ident = path.segments[0].ident.to_string();
776                Some((ident, Vec::new()))
777            } else {
778                None
779            }
780        }
781        // get the base and fields recursively
782        Expr::Field(syn::ExprField { base, member, .. }) => {
783            let (base_ident, mut fields) =
784                if let Some((base_ident, fields)) = access_ident_recursive(base) {
785                    (base_ident, fields)
786                } else {
787                    return None;
788                };
789            let field_name = match member {
790                syn::Member::Named(ident) => ident.to_string(),
791                syn::Member::Unnamed(index) => index.index.to_string(),
792            };
793            fields.push(field_name);
794            Some((base_ident, fields))
795        }
796        _ => None,
797    }
798}
799
800/// parse expr into number.
801pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
802    if let Expr::Lit(expr_lit) = expr {
803        if let syn::Lit::Int(lit_int) = &expr_lit.lit {
804            return lit_int.base10_parse::<usize>().ok();
805        }
806    }
807    None
808}
809
810/// Match a type identifier string to a concrete Rust type
811///
812/// This function attempts to match a given type identifier (e.g., "u32", "T", "MyStruct")
813/// to a type in the provided parameter type list. It handles:
814/// 1. Built-in primitive types (u32, usize, etc.)
815/// 2. Generic type parameters (T, U, etc.)
816/// 3. User-defined types found in the parameter list
817///
818/// Arguments:
819/// - `tcx`: Type context for querying compiler information
820/// - `type_ident`: String representing the type identifier to match
821/// - `param_ty`: List of parameter types from the function signature
822///
823/// Returns:
824/// - `Some(Ty)` if a matching type is found
825/// - `None` if no match is found
826pub fn match_ty_with_ident(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
827    // 1. First check for built-in primitive types
828    if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
829        return Some(primitive_ty);
830    }
831    // 2. Check if the identifier matches any generic type parameter
832    return find_generic_param(tcx, def_id, type_ident.clone());
833    // 3. Check if the identifier matches any user-defined type in the parameters
834    // find_user_defined_type(tcx, def_id, type_ident)
835}
836
837/// Match built-in primitive types from String
838fn match_primitive_type(tcx: TyCtxt, type_ident: String) -> Option<Ty> {
839    match type_ident.as_str() {
840        "i8" => Some(tcx.types.i8),
841        "i16" => Some(tcx.types.i16),
842        "i32" => Some(tcx.types.i32),
843        "i64" => Some(tcx.types.i64),
844        "i128" => Some(tcx.types.i128),
845        "isize" => Some(tcx.types.isize),
846        "u8" => Some(tcx.types.u8),
847        "u16" => Some(tcx.types.u16),
848        "u32" => Some(tcx.types.u32),
849        "u64" => Some(tcx.types.u64),
850        "u128" => Some(tcx.types.u128),
851        "usize" => Some(tcx.types.usize),
852        "f16" => Some(tcx.types.f16),
853        "f32" => Some(tcx.types.f32),
854        "f64" => Some(tcx.types.f64),
855        "f128" => Some(tcx.types.f128),
856        "bool" => Some(tcx.types.bool),
857        "char" => Some(tcx.types.char),
858        "str" => Some(tcx.types.str_),
859        _ => None,
860    }
861}
862
863/// Find generic type parameters in the parameter list
864fn find_generic_param(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
865    rap_debug!(
866        "Searching for generic param: {} in {:?}",
867        type_ident,
868        def_id
869    );
870    let (_, param_tys) = parse_signature(tcx, def_id);
871    rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
872    // 递归查找泛型参数
873    for &ty in &param_tys {
874        if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
875            return Some(found);
876        }
877    }
878
879    None
880}
881
882/// Iterate the args' types recursively and find the matched generic one.
883fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
884    match ty.kind() {
885        TyKind::Param(param_ty) => {
886            if param_ty.name.as_str() == type_ident {
887                return Some(ty);
888            }
889        }
890        TyKind::RawPtr(ty, _)
891        | TyKind::Ref(_, ty, _)
892        | TyKind::Slice(ty)
893        | TyKind::Array(ty, _) => {
894            if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
895                return Some(found);
896            }
897        }
898        TyKind::Tuple(tys) => {
899            for tuple_ty in tys.iter() {
900                if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
901                    return Some(found);
902                }
903            }
904        }
905        TyKind::Adt(adt_def, substs) => {
906            let name = tcx.item_name(adt_def.did()).to_string();
907            if name == type_ident {
908                return Some(ty);
909            }
910            for field in adt_def.all_fields() {
911                let field_ty = field.ty(tcx, substs);
912                if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
913                    return Some(found);
914                }
915            }
916        }
917        _ => {}
918    }
919    None
920}
921
922// /// Find user-defined types in the parameter list
923// fn find_user_defined_type(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
924//     let param_ty = parse_signature(tcx, def_id).1;
925//     param_ty.iter().find_map(|&ty| {
926//         // Peel off references and pointers to get to the underlying type
927//         let peeled_ty = ty.peel_refs();
928//         match peeled_ty.kind() {
929//             TyKind::Adt(adt_def, _raw_list) => {
930//                 // Compare the type name to our identifier
931//                 let name = tcx.item_name(adt_def.did()).to_string();
932//                 if name == type_ident {
933//                     return Some(peeled_ty);
934//                 }
935//             }
936//             _ => {}
937//         }
938//         None
939//     })
940// }
941
942pub fn reflect_generic<'tcx>(
943    generic_mapping: &FxHashMap<String, Ty<'tcx>>,
944    ty: Ty<'tcx>,
945) -> Ty<'tcx> {
946    match ty.kind() {
947        TyKind::Param(param_ty) => {
948            let generic_name = param_ty.name.to_string();
949            if let Some(actual_ty) = generic_mapping.get(&generic_name) {
950                return *actual_ty;
951            }
952        }
953        _ => {}
954    }
955    ty
956}
957
958// fn find_first_chain_in_expr(expr: &Expr, param_names: &[String]) -> Option<(String, Vec<String>)> {
959//     if let Some(chain) = access_ident_recursive(expr) {
960//         if param_names.contains(&chain.0) {
961//             return Some(chain);
962//         }
963//     }
964
965//     let mut children = Vec::new();
966//     match expr {
967//         Expr::Array(expr) => children.extend(expr.elems.iter()),
968//         Expr::Assign(expr) => {
969//             children.push(&expr.left);
970//             children.push(&expr.right);
971//         }
972//         Expr::Binary(expr) => {
973//             children.push(&expr.left);
974//             children.push(&expr.right);
975//         }
976//         Expr::Call(expr) => {
977//             children.push(&expr.func);
978//             children.extend(expr.args.iter());
979//         }
980//         Expr::Cast(expr) => children.push(&expr.expr),
981//         Expr::Field(expr) => children.push(&expr.base),
982//         Expr::Group(expr) => children.push(&expr.expr),
983//         Expr::Index(expr) => {
984//             children.push(&expr.expr);
985//             children.push(&expr.index);
986//         }
987//         Expr::Let(expr) => children.push(&expr.expr),
988//         // Expr::Lit(expr) => children.push(&expr.lit),
989//         Expr::MethodCall(expr) => {
990//             children.push(&expr.receiver);
991//             children.extend(expr.args.iter());
992//         }
993//         Expr::Paren(expr) => children.push(&expr.expr),
994//         Expr::Range(expr) => {
995//             if let Some(start) = &expr.start {
996//                 children.push(start);
997//             }
998//             if let Some(end) = &expr.end {
999//                 children.push(end);
1000//             }
1001//             // if let Some(limits) = &expr.limits {
1002//             //     let _ = limits;
1003//             // }
1004//         }
1005//         Expr::Reference(expr) => children.push(&expr.expr),
1006//         Expr::Repeat(expr) => {
1007//             children.push(&expr.expr);
1008//             children.push(&expr.len);
1009//         }
1010//         // Expr::Struct(expr) => {
1011//         //     children.push(&expr.path);
1012//         //     for field in &expr.fields {
1013//         //         children.push(&field.expr);
1014//         //     }
1015//         //     if let Some(rest) = &expr.rest {
1016//         //         children.push(rest);
1017//         //     }
1018//         // }
1019//         Expr::Try(expr) => children.push(&expr.expr),
1020//         Expr::Tuple(expr) => children.extend(expr.elems.iter()),
1021//         Expr::Unary(expr) => children.push(&expr.expr),
1022//         _ => {}
1023//     }
1024
1025//     for child in children {
1026//         if let Some(chain) = find_first_chain_in_expr(child, param_names) {
1027//             return Some(chain);
1028//         }
1029//     }
1030
1031//     None
1032// }