rapx/analysis/utils/
fn_info.rs

1use super::draw_dot::render_dot_string;
2use crate::analysis::{
3    core::dataflow::{DataFlowAnalysis, default::DataFlowAnalyzer},
4    senryx::{
5        contracts::{property, property::PropertyContract},
6        matcher::parse_unsafe_api,
7    },
8};
9use crate::{rap_debug, rap_warn};
10use rustc_ast::ItemKind;
11use rustc_data_structures::fx::FxHashMap;
12use rustc_hir::{
13    Attribute, ImplItemKind, Safety,
14    def::DefKind,
15    def_id::{CrateNum, DefId, DefIndex},
16};
17use rustc_middle::{
18    hir::place::PlaceBase,
19    mir::{
20        BasicBlock, BinOp, Body, Local, Operand, Place, PlaceElem, PlaceRef, ProjectionElem,
21        Rvalue, StatementKind, Terminator, TerminatorKind,
22    },
23    ty,
24    ty::{AssocKind, ConstKind, Mutability, Ty, TyCtxt, TyKind},
25};
26use rustc_span::{def_id::LocalDefId, kw, sym};
27use serde::de;
28use std::{
29    collections::{HashMap, HashSet},
30    fmt::Debug,
31    hash::Hash,
32};
33use syn::Expr;
34
/// Coarse classification of a function-like item, as computed by `get_type`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum FnKind {
    /// Fallback kind: anything `get_type` does not classify as a method or constructor.
    Fn,
    /// Associated function that takes a `self` receiver.
    Method,
    /// Associated function without `self` whose return type is (or wraps) the implementing type.
    Constructor,
    /// NOTE(review): not produced by any code visible in this file; presumably marks
    /// compiler intrinsics — confirm against the callers.
    Intrinsic,
}
42
/// Compact, copyable summary of a function: its identity, safety and kind.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct FnInfo {
    /// `DefId` of the function being described.
    pub def_id: DefId,
    /// Safety qualifier recorded for the function.
    pub fn_safety: Safety,
    /// Classification of the function (see `FnKind`).
    pub fn_kind: FnKind,
}
49
50impl FnInfo {
51    pub fn new(def_id: DefId, fn_safety: Safety, fn_kind: FnKind) -> Self {
52        FnInfo {
53            def_id,
54            fn_safety,
55            fn_kind,
56        }
57    }
58}
59
/// Compact, copyable summary of an ADT (struct/enum/union).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct AdtInfo {
    /// `DefId` of the ADT.
    pub def_id: DefId,
    /// Whether the ADT can be built with a literal expression; `get_adt_via_method`
    /// sets this when the ADT has no fields or all of its fields are public.
    pub literal_cons_enabled: bool,
}
65
66impl AdtInfo {
67    pub fn new(def_id: DefId, literal_cons_enabled: bool) -> Self {
68        AdtInfo {
69            def_id,
70            literal_cons_enabled,
71        }
72    }
73}
74
75pub fn check_visibility(tcx: TyCtxt, func_defid: DefId) -> bool {
76    if !tcx.visibility(func_defid).is_public() {
77        return false;
78    }
79    // if func_defid.is_local() {
80    //     if let Some(local_defid) = func_defid.as_local() {
81    //         let module_moddefid = tcx.parent_module_from_def_id(local_defid);
82    //         let module_defid = module_moddefid.to_def_id();
83    //         if !tcx.visibility(module_defid).is_public() {
84    //             // println!("module def id {:?}",UigUnit::get_cleaned_def_path_name(tcx, module_defid));
85    //             return Self::is_re_exported(tcx, func_defid, module_moddefid.to_local_def_id());
86    //         }
87    //     }
88    // }
89    true
90}
91
92pub fn is_re_exported(tcx: TyCtxt, target_defid: DefId, module_defid: LocalDefId) -> bool {
93    for child in tcx.module_children_local(module_defid) {
94        if child.vis.is_public() {
95            if let Some(def_id) = child.res.opt_def_id() {
96                if def_id == target_defid {
97                    return true;
98                }
99            }
100        }
101    }
102    false
103}
104
/// Prints every element of `set` on its own line (using `Debug`, in the set's
/// iteration order), followed by a separator line.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|element| println!("{:?}", element));
    println!("---------------");
}
111
112pub fn get_cleaned_def_path_name_ori(tcx: TyCtxt, def_id: DefId) -> String {
113    let def_id_str = format!("{:?}", def_id);
114    let mut parts: Vec<&str> = def_id_str.split("::").collect();
115
116    let mut remove_first = false;
117    if let Some(first_part) = parts.get_mut(0) {
118        if first_part.contains("core") {
119            *first_part = "core";
120        } else if first_part.contains("std") {
121            *first_part = "std";
122        } else if first_part.contains("alloc") {
123            *first_part = "alloc";
124        } else {
125            remove_first = true;
126        }
127    }
128    if remove_first && !parts.is_empty() {
129        parts.remove(0);
130    }
131
132    let new_parts: Vec<String> = parts
133        .into_iter()
134        .filter_map(|s| {
135            if s.contains("{") {
136                if remove_first {
137                    get_struct_name(tcx, def_id)
138                } else {
139                    None
140                }
141            } else {
142                Some(s.to_string())
143            }
144        })
145        .collect();
146
147    let mut cleaned_path = new_parts.join("::");
148    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
149    cleaned_path
150}
151
152pub fn get_sp_json() -> serde_json::Value {
153    let json_data: serde_json::Value =
154        serde_json::from_str(include_str!("data/std_sps.json")).expect("Unable to parse JSON");
155    json_data
156}
157
158pub fn get_std_api_signature_json() -> serde_json::Value {
159    let json_data: serde_json::Value =
160        serde_json::from_str(include_str!("data/std_sig.json")).expect("Unable to parse JSON");
161    json_data
162}
163
164pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
165    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
166    let json_data: serde_json::Value = get_sp_json();
167
168    if let Some(function_info) = json_data.get(&cleaned_path_name) {
169        if let Some(sp_list) = function_info.get("0") {
170            let mut result = HashSet::new();
171            if let Some(sp_array) = sp_list.as_array() {
172                for sp in sp_array {
173                    if let Some(sp_name) = sp.as_str() {
174                        result.insert(sp_name.to_string());
175                    }
176                }
177            }
178            return result;
179        }
180    }
181    HashSet::new()
182}
183
184pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
185    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
186        if let Some(impl_id) = assoc_item.impl_container(tcx) {
187            let ty = tcx.type_of(impl_id).skip_binder();
188            let type_name = ty.to_string();
189            let struct_name = type_name
190                .split('<')
191                .next()
192                .unwrap_or("")
193                .split("::")
194                .last()
195                .unwrap_or("")
196                .to_string();
197
198            return Some(struct_name);
199        }
200    }
201    None
202}
203
204pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> Safety {
205    let poly_fn_sig = tcx.fn_sig(def_id);
206    let fn_sig = poly_fn_sig.skip_binder();
207    fn_sig.safety()
208}
209
210pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> FnKind {
211    let mut node_type = 2;
212    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
213        match assoc_item.kind {
214            AssocKind::Fn { has_self, .. } => {
215                if has_self {
216                    return FnKind::Method;
217                } else {
218                    let fn_sig = tcx.fn_sig(def_id).skip_binder();
219                    let output = fn_sig.output().skip_binder();
220                    // return type is 'Self'
221                    if output.is_param(0) {
222                        return FnKind::Constructor;
223                    }
224                    // return type is struct's name
225                    if let Some(impl_id) = assoc_item.impl_container(tcx) {
226                        let ty = tcx.type_of(impl_id).skip_binder();
227                        if output == ty {
228                            return FnKind::Constructor;
229                        }
230                    }
231                    match output.kind() {
232                        TyKind::Ref(_, ref_ty, _) => {
233                            if ref_ty.is_param(0) {
234                                return FnKind::Constructor;
235                            }
236                            if let Some(impl_id) = assoc_item.impl_container(tcx) {
237                                let ty = tcx.type_of(impl_id).skip_binder();
238                                if *ref_ty == ty {
239                                    return FnKind::Constructor;
240                                }
241                            }
242                        }
243                        TyKind::Adt(adt_def, substs) => {
244                            if adt_def.is_enum()
245                                && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
246                                    || tcx.is_diagnostic_item(sym::Result, adt_def.did())
247                                    || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
248                            {
249                                let inner_ty = substs.type_at(0);
250                                if inner_ty.is_param(0) {
251                                    return FnKind::Constructor;
252                                }
253                                if let Some(impl_id) = assoc_item.impl_container(tcx) {
254                                    let ty_impl = tcx.type_of(impl_id).skip_binder();
255                                    if inner_ty == ty_impl {
256                                        return FnKind::Constructor;
257                                    }
258                                }
259                            }
260                        }
261                        _ => {}
262                    }
263                }
264            }
265            _ => todo!(),
266        }
267    }
268    return FnKind::Fn;
269}
270
271pub fn get_adt_ty(tcx: TyCtxt, def_id: DefId) -> Option<Ty> {
272    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
273        if let Some(impl_id) = assoc_item.impl_container(tcx) {
274            return Some(tcx.type_of(impl_id).skip_binder());
275        }
276    }
277    None
278}
279
280// check whether this adt contains a literal constructor
281// result: adt_def_id, is_literal
282pub fn get_adt_via_method(tcx: TyCtxt<'_>, method_def_id: DefId) -> Option<AdtInfo> {
283    let assoc_item = tcx.opt_associated_item(method_def_id)?;
284    let impl_id = assoc_item.impl_container(tcx)?;
285    let ty = tcx.type_of(impl_id).skip_binder();
286    let adt_def = ty.ty_adt_def()?;
287    let adt_def_id = adt_def.did();
288
289    let all_fields: Vec<_> = adt_def.all_fields().collect();
290    let total_count = all_fields.len();
291
292    if total_count == 0 {
293        return Some(AdtInfo::new(adt_def_id, true));
294    }
295
296    let pub_count = all_fields
297        .iter()
298        .filter(|field| tcx.visibility(field.did).is_public())
299        .count();
300
301    if pub_count == 0 {
302        return None;
303    }
304    Some(AdtInfo::new(adt_def_id, pub_count == total_count))
305}
306
307fn place_has_raw_deref<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: &Place<'tcx>) -> bool {
308    let mut local = place.local;
309    for proj in place.projection.iter() {
310        if let ProjectionElem::Deref = proj.kind() {
311            let ty = body.local_decls[local].ty;
312            if let TyKind::RawPtr(_, _) = ty.kind() {
313                return true;
314            }
315        }
316    }
317    false
318}
319
320pub fn get_rawptr_deref(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<Local> {
321    let mut raw_ptrs = HashSet::new();
322    if tcx.is_mir_available(def_id) {
323        let body = tcx.optimized_mir(def_id);
324        for bb in body.basic_blocks.iter() {
325            for stmt in &bb.statements {
326                if let StatementKind::Assign(box (lhs, rhs)) = &stmt.kind {
327                    if place_has_raw_deref(tcx, &body, lhs) {
328                        raw_ptrs.insert(lhs.local);
329                    }
330                    if let Rvalue::Use(op) = rhs {
331                        match op {
332                            Operand::Copy(place) | Operand::Move(place) => {
333                                if place_has_raw_deref(tcx, &body, place) {
334                                    raw_ptrs.insert(place.local);
335                                }
336                            }
337                            _ => {}
338                        }
339                    }
340                    if let Rvalue::Ref(_, _, place) = rhs {
341                        if place_has_raw_deref(tcx, &body, place) {
342                            raw_ptrs.insert(place.local);
343                        }
344                    }
345                }
346            }
347            if let Some(terminator) = &bb.terminator {
348                match &terminator.kind {
349                    rustc_middle::mir::TerminatorKind::Call { args, .. } => {
350                        for arg in args {
351                            match arg.node {
352                                Operand::Copy(place) | Operand::Move(place) => {
353                                    if place_has_raw_deref(tcx, &body, &place) {
354                                        raw_ptrs.insert(place.local);
355                                    }
356                                }
357                                _ => {}
358                            }
359                        }
360                    }
361                    _ => {}
362                }
363            }
364        }
365    }
366    raw_ptrs
367}
368
369/* Example mir of static mutable access.
370
371static mut COUNTER: i32 = {
372    let mut _0: i32;
373
374    bb0: {
375        _0 = const 0_i32;
376        return;
377    }
378}
379
380fn main() -> () {
381    let mut _0: ();
382    let mut _1: *mut i32;
383
384    bb0: {
385        StorageLive(_1);
386        _1 = const {alloc1: *mut i32};
387        (*_1) = const 1_i32;
388        StorageDead(_1);
389        return;
390    }
391}
392
393alloc1 (static: COUNTER, size: 4, align: 4) {
394    00 00 00 00                                     │ ....
395}
396
397*/
398
399fn place_is_static_mut<'tcx>(
400    tcx: ty::TyCtxt<'tcx>,
401    body: &Body<'tcx>,
402    place: PlaceRef<'tcx>,
403    static_muts: &HashSet<(DefId, Local)>,
404) -> bool {
405    if static_muts.is_empty() {
406        return false;
407    }
408    if static_muts
409        .iter()
410        .any(|(_def_id, local)| *local == place.local)
411        && place.as_local().is_none()
412    {
413        for (place_ref, proj) in place.iter_projections() {
414            match proj {
415                PlaceElem::Deref => return true,
416                _ => {}
417            }
418        }
419    }
420    return false;
421}
422
423pub fn collect_global_local_pairs(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<(DefId, Local)> {
424    let mut globals = HashSet::new();
425
426    if !tcx.is_mir_available(def_id) {
427        return globals;
428    }
429
430    let body = tcx.optimized_mir(def_id);
431
432    for bb in body.basic_blocks.iter() {
433        for stmt in &bb.statements {
434            if let StatementKind::Assign(box (lhs, rhs)) = &stmt.kind {
435                match rhs {
436                    Rvalue::Use(op) => match op {
437                        Operand::Constant(box (cons_op)) => {
438                            if let Some(def_id) = cons_op.check_static_ptr(tcx) {
439                                globals.insert((def_id, lhs.local));
440                            }
441                        }
442                        _ => {}
443                    },
444                    _ => {}
445                }
446            }
447        }
448    }
449    globals
450}
451
452pub fn get_static_mut_accesses(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<Local> {
453    let static_muts = collect_global_local_pairs(tcx, def_id);
454    let mut globals = HashSet::new();
455
456    if !tcx.is_mir_available(def_id) {
457        return globals;
458    }
459
460    let body = tcx.optimized_mir(def_id);
461
462    for bb in body.basic_blocks.iter() {
463        for stmt in &bb.statements {
464            if let StatementKind::Assign(box (lhs, rhs)) = &stmt.kind {
465                if place_is_static_mut(tcx, body, lhs.as_ref(), &static_muts) {
466                    globals.insert(lhs.local);
467                }
468
469                // RHS
470                match rhs {
471                    Rvalue::Use(op) => match op {
472                        Operand::Copy(place) | Operand::Move(place) => {
473                            if place_is_static_mut(tcx, body, place.as_ref(), &static_muts) {
474                                globals.insert(place.local);
475                            }
476                        }
477                        _ => {}
478                    },
479
480                    Rvalue::Ref(_, _, place) => {
481                        if place_is_static_mut(tcx, body, place.as_ref(), &static_muts) {
482                            globals.insert(place.local);
483                        }
484                    }
485
486                    _ => {}
487                }
488            }
489        }
490
491        // ---------------------------------------
492        // Terminators
493        // ---------------------------------------
494        if let Some(term) = &bb.terminator {
495            if let TerminatorKind::Call {
496                args, destination, ..
497            } = &term.kind
498            {
499                // args
500                for arg in args {
501                    match &arg.node {
502                        Operand::Copy(place) | Operand::Move(place) => {
503                            if place_is_static_mut(tcx, body, place.as_ref(), &static_muts) {
504                                globals.insert(place.local);
505                            }
506                        }
507                        _ => {}
508                    }
509                }
510
511                // destination
512                if place_is_static_mut(tcx, body, destination.as_ref(), &static_muts) {
513                    globals.insert(destination.local);
514                }
515            }
516        }
517    }
518
519    globals
520}
521
522pub fn get_unsafe_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
523    let mut unsafe_callees = HashSet::new();
524    if tcx.is_mir_available(def_id) {
525        let body = tcx.optimized_mir(def_id);
526        for bb in body.basic_blocks.iter() {
527            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
528                if let Operand::Constant(func_constant) = func {
529                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
530                        if check_safety(tcx, *callee_def_id) == Safety::Unsafe {
531                            unsafe_callees.insert(*callee_def_id);
532                        }
533                    }
534                }
535            }
536        }
537    }
538    unsafe_callees
539}
540
541pub fn get_all_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
542    let mut callees = HashSet::new();
543    if tcx.is_mir_available(def_id) {
544        let body = tcx.optimized_mir(def_id);
545        for bb in body.basic_blocks.iter() {
546            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
547                if let Operand::Constant(func_constant) = func {
548                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
549                        callees.insert(*callee_def_id);
550                    }
551                }
552            }
553        }
554    }
555    callees
556}
557
558// return all the impls def id of corresponding struct
559pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
560    let mut impls = Vec::new();
561    for item_id in tcx.hir_crate_items(()).free_items() {
562        let item = tcx.hir_item(item_id);
563        if let rustc_hir::ItemKind::Impl(impl_details) = &item.kind {
564            if let rustc_hir::TyKind::Path(rustc_hir::QPath::Resolved(_, path)) =
565                &impl_details.self_ty.kind
566            {
567                if let rustc_hir::def::Res::Def(_, def_id) = path.res {
568                    if def_id == struct_def_id {
569                        impls.push(item_id.owner_id.to_def_id());
570                    }
571                }
572            }
573        }
574    }
575    impls
576}
577
578pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
579    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
580        if let Some(impl_id) = assoc_item.impl_container(tcx) {
581            // get struct ty
582            let ty = tcx.type_of(impl_id).skip_binder();
583            if let Some(adt_def) = ty.ty_adt_def() {
584                return Some(adt_def.did());
585            }
586        }
587    }
588    None
589}
590
591// get the pointee or wrapped type
592pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
593    // progress_info!("get_pointee: > {:?} as type: {:?}", matched_ty, matched_ty.kind());
594    let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
595        get_pointee(*ty_mut)
596    } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
597        get_pointee(*referred_ty)
598    } else {
599        matched_ty
600    };
601    pointee
602}
603
604pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
605    if let ty::RawPtr(_, _) = matched_ty.kind() {
606        return true;
607    }
608    false
609}
610
611pub fn is_ref(matched_ty: Ty<'_>) -> bool {
612    if let ty::Ref(_, _, _) = matched_ty.kind() {
613        return true;
614    }
615    false
616}
617
618pub fn is_slice(matched_ty: Ty) -> Option<Ty> {
619    if let ty::Slice(inner) = matched_ty.kind() {
620        return Some(*inner);
621    }
622    None
623}
624
625pub fn has_mut_self_param(tcx: TyCtxt, def_id: DefId) -> bool {
626    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
627        match assoc_item.kind {
628            AssocKind::Fn { has_self, .. } => {
629                if has_self && tcx.is_mir_available(def_id) {
630                    let body = tcx.optimized_mir(def_id);
631                    let fst_arg = body.local_decls[Local::from_usize(1)].clone();
632                    let ty = fst_arg.ty;
633                    let is_mut_ref =
634                        matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
635                    return fst_arg.mutability.is_mut() || is_mut_ref;
636                }
637            }
638            _ => (),
639        }
640    }
641    false
642}
643
644// Check each field's visibility, return the public fields vec
645pub fn get_public_fields(tcx: TyCtxt, def_id: DefId) -> HashSet<usize> {
646    let adt_def = tcx.adt_def(def_id);
647    adt_def
648        .all_fields()
649        .enumerate()
650        .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
651        .collect()
652}
653
// Print every entry of `map` as `key: value` (Debug formatting), one per line, in
// ascending key order. Each line is indented by two spaces per `level`.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = "  ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    // Keys are unique, so ordering by key alone fully determines the output order.
    entries.sort_by(|left, right| left.0.cmp(right.0));

    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
670
671// pub fn get_all_std_unsafe_chains(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
672//     let mut results = Vec::new();
673//     let body = tcx.optimized_mir(def_id);
674//     let bb_len = body.basic_blocks.len();
675//     for i in 0..bb_len {
676//         let callees = match_std_unsafe_chains_callee(
677//             tcx,
678//             body.basic_blocks[BasicBlock::from_usize(i)]
679//                 .clone()
680//                 .terminator(),
681//         );
682//         results.extend(callees);
683//     }
684//     results
685// }
686
687pub fn match_std_unsafe_chains_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
688    let mut results = Vec::new();
689    if let TerminatorKind::Call { func, .. } = &terminator.kind {
690        if let Operand::Constant(func_constant) = func {
691            if let ty::FnDef(callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
692                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
693            }
694        }
695    }
696    results
697}
698
699pub fn get_all_std_unsafe_callees(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
700    let mut results = Vec::new();
701    let body = tcx.optimized_mir(def_id);
702    let bb_len = body.basic_blocks.len();
703    for i in 0..bb_len {
704        let callees = match_std_unsafe_callee(
705            tcx,
706            body.basic_blocks[BasicBlock::from_usize(i)]
707                .clone()
708                .terminator(),
709        );
710        results.extend(callees);
711    }
712    results
713}
714
715pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt, def_id: DefId) -> Vec<usize> {
716    let mut results = Vec::new();
717    let body = tcx.optimized_mir(def_id);
718    let bb_len = body.basic_blocks.len();
719    for i in 0..bb_len {
720        if match_std_unsafe_callee(
721            tcx,
722            body.basic_blocks[BasicBlock::from_usize(i)]
723                .clone()
724                .terminator(),
725        )
726        .is_empty()
727        {
728            results.push(i);
729        }
730    }
731    results
732}
733
734pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
735    let mut results = Vec::new();
736    if let TerminatorKind::Call { func, .. } = &terminator.kind {
737        if let Operand::Constant(func_constant) = func {
738            if let ty::FnDef(callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
739                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
740                if parse_unsafe_api(&func_name).is_some() {
741                    results.push(func_name);
742                }
743            }
744        }
745    }
746    results
747}
748
749// Bug definition: (1) strict -> weak & dst is mutable;
750//                 (2) _ -> strict
751pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
752    (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
753}
754
755// strict ty: bool, str, adt fields containing bool or str;
756pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
757    let ty = get_pointee(ori_ty);
758    let mut flag = false;
759    if let TyKind::Adt(adt_def, substs) = ty.kind() {
760        if adt_def.is_struct() {
761            for field_def in adt_def.all_fields() {
762                flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
763            }
764        }
765    }
766    ty.is_bool() || ty.is_str() || flag
767}
768
769pub fn reverse_op(op: BinOp) -> BinOp {
770    match op {
771        BinOp::Lt => BinOp::Ge,
772        BinOp::Ge => BinOp::Lt,
773        BinOp::Le => BinOp::Gt,
774        BinOp::Gt => BinOp::Le,
775        BinOp::Eq => BinOp::Eq,
776        BinOp::Ne => BinOp::Ne,
777        _ => op,
778    }
779}
780
781/// Same with `generate_contract_from_annotation` but does not contain field types.
782pub fn generate_contract_from_annotation_without_field_types(
783    tcx: TyCtxt,
784    def_id: DefId,
785) -> Vec<(usize, Vec<usize>, PropertyContract)> {
786    let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
787
788    contracts_with_ty
789        .into_iter()
790        .map(|(local_id, fields_with_ty, contract)| {
791            let fields: Vec<usize> = fields_with_ty
792                .into_iter()
793                .map(|(field_idx, _)| field_idx)
794                .collect();
795            (local_id, fields, contract)
796        })
797        .collect()
798}
799
800/// Filter the function which contains "rapx::proof"
801pub fn is_verify_target_func(tcx: TyCtxt, def_id: DefId) -> bool {
802    for attr in tcx.get_all_attrs(def_id).into_iter() {
803        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
804        // Find proof placeholder
805        if attr_str.contains("#[rapx::proof(proof)]") {
806            return true;
807        }
808    }
809    false
810}
811
812/// Get the annotation in tag-std style.
813/// Then generate the contractual invariant states (CIS) for the args.
814/// This function will recognize the args name and record states to MIR variable (represent by usize).
815/// Return value means Vec<(local_id, fields of this local, contracts)>
816pub fn generate_contract_from_annotation(
817    tcx: TyCtxt,
818    def_id: DefId,
819) -> Vec<(usize, Vec<(usize, Ty)>, PropertyContract)> {
820    const REGISTER_TOOL: &str = "rapx";
821    let tool_attrs = tcx.get_all_attrs(def_id).into_iter().filter(|attr| {
822        if let Attribute::Unparsed(tool_attr) = attr {
823            if tool_attr.path.segments[0].as_str() == REGISTER_TOOL {
824                return true;
825            }
826        }
827        false
828    });
829    let mut results = Vec::new();
830    for attr in tool_attrs {
831        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
832        // Find proof placeholder, skip it
833        if attr_str.contains("#[rapx::proof(proof)]") {
834            continue;
835        }
836        rap_debug!("{:?}", attr_str);
837        let safety_attr = safety_parser::safety::parse_attr_and_get_properties(attr_str.as_str());
838        for par in safety_attr.iter() {
839            for property in par.tags.iter() {
840                let tag_name = property.tag.name();
841                let exprs = property.args.clone().into_vec();
842                let contract = PropertyContract::new(tcx, def_id, tag_name, &exprs);
843                let (local, fields) = parse_cis_local(tcx, def_id, exprs);
844                results.push((local, fields, contract));
845            }
846        }
847    }
848    // if results.len() > 0 {
849    //     rap_warn!("results:\n{:?}", results);
850    // }
851    results
852}
853
854/// Parse attr.expr into local id and local fields.
855///
856/// Example:
857/// ```
858/// #[rapx::inner(property = ValidPtr(ptr, u32, 1), kind = "precond")]
859/// #[rapx::inner(property = ValidNum(region.size>=0), kind = "precond")]
860/// pub fn xor_secret_region(ptr: *mut u32, region:SecretRegion) -> u32 {...}
861/// ```
862///
863/// The first attribute will be parsed as (1, []).
864///     -> "1" means the first arg "ptr", "[]" means no fields.
865/// The second attribute will be parsed as (2, [1]).
866///     -> "2" means the second arg "region", "[1]" means "size" is region's second field.
867///
868/// If this function doesn't have args, then it will return default pattern: (0, Vec::new())
869pub fn parse_cis_local(tcx: TyCtxt, def_id: DefId, expr: Vec<Expr>) -> (usize, Vec<(usize, Ty)>) {
870    // match expr with cis local
871    for e in expr {
872        if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
873            return (base, fields);
874        }
875    }
876    (0, Vec::new())
877}
878
879/// parse single expr into (local, fields, ty)
880pub fn parse_expr_into_local_and_ty<'tcx>(
881    tcx: TyCtxt<'tcx>,
882    def_id: DefId,
883    expr: &Expr,
884) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
885    if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
886        let (param_names, param_tys) = parse_signature(tcx, def_id);
887        if param_names[0] == "0".to_string() {
888            return None;
889        }
890        if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
891            let mut current_ty = param_tys[param_index];
892            let mut field_indices = Vec::new();
893            for field_name in fields {
894                // peel the ref and ptr
895                let peeled_ty = current_ty.peel_refs();
896                if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
897                    let variant = adt_def.non_enum_variant();
898                    // 1. if field_name is number, then parse it as usize
899                    if let Ok(field_idx) = field_name.parse::<usize>() {
900                        if field_idx < variant.fields.len() {
901                            current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
902                                .ty(tcx, arg_list);
903                            field_indices.push((field_idx, current_ty));
904                            continue;
905                        }
906                    }
907                    // 2. if field_name is String, then compare it with current ty's field names
908                    if let Some((idx, _)) = variant
909                        .fields
910                        .iter()
911                        .enumerate()
912                        .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
913                    {
914                        current_ty =
915                            variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
916                        field_indices.push((idx, current_ty));
917                    }
918                    // 3. if field_name can not match any fields, then break
919                    else {
920                        break; // TODO:
921                    }
922                }
923                // if current ty is not Adt, then break the loop
924                else {
925                    break; // TODO:
926                }
927            }
928            // It's different from default one, we return the result as param_index+1 because param_index count from 0.
929            // But 0 in MIR is the ret index, the args' indexes begin from 1.
930            return Some((param_index + 1, field_indices, current_ty));
931        }
932    }
933    None
934}
935
936/// Return the Vecs of args' names and types
937/// This function will handle outside def_id by different way.
938pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
939    // 0. If the def id is local
940    if def_id.as_local().is_some() {
941        return parse_local_signature(tcx, def_id);
942    } else {
943        rap_debug!("{:?} is not local def id.", def_id);
944        return parse_outside_signature(tcx, def_id);
945    };
946}
947
948/// Return the Vecs of args' names and types of outside functions.
949fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
950    let sig = tcx.fn_sig(def_id).skip_binder();
951    let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
952
953    // 1. check pre-defined std unsafe api signature
954    if let Some(args_name) = get_known_std_names(tcx, def_id) {
955        // rap_warn!(
956        //     "function {:?} has arg: {:?}, arg types: {:?}",
957        //     def_id,
958        //     args_name,
959        //     param_tys
960        // );
961        return (args_name, param_tys);
962    }
963
964    // 2. TODO: If can not find known std apis, then use numbers like `0`,`1`,... to represent args.
965    let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
966    rap_debug!(
967        "function {:?} has arg: {:?}, arg types: {:?}",
968        def_id,
969        args_name,
970        param_tys
971    );
972    return (args_name, param_tys);
973}
974
975/// We use a json to record known std apis' arg names.
976/// This function will search the json and return the names.
977/// Notes: If std gets updated, the json may still record old ones.
978fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
979    let std_func_name = get_cleaned_def_path_name(tcx, def_id);
980    let json_data: serde_json::Value = get_std_api_signature_json();
981
982    if let Some(arg_info) = json_data.get(&std_func_name) {
983        if let Some(args_name) = arg_info.as_array() {
984            // set default value to arg name
985            if args_name.len() == 0 {
986                return Some(vec!["0".to_string()]);
987            }
988            // iterate and collect
989            let mut result = Vec::new();
990            for arg in args_name {
991                if let Some(sp_name) = arg.as_str() {
992                    result.push(sp_name.to_string());
993                }
994            }
995            return Some(result);
996        }
997    }
998    None
999}
1000
1001/// Return the Vecs of args' names and types of local functions.
1002pub fn parse_local_signature(tcx: TyCtxt, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
1003    // 1. parse local def_id and get arg list
1004    let local_def_id = def_id.as_local().unwrap();
1005    let hir_body = tcx.hir_body_owned_by(local_def_id);
1006    if hir_body.params.len() == 0 {
1007        return (vec!["0".to_string()], Vec::new());
1008    }
1009    // 2. contruct the vec of param and param ty
1010    let params = hir_body.params;
1011    let typeck_results = tcx.typeck_body(hir_body.id());
1012    let mut param_names = Vec::new();
1013    let mut param_tys = Vec::new();
1014    for param in params {
1015        match param.pat.kind {
1016            rustc_hir::PatKind::Binding(_, _, ident, _) => {
1017                param_names.push(ident.name.to_string());
1018                let ty = typeck_results.pat_ty(param.pat);
1019                param_tys.push(ty);
1020            }
1021            _ => {
1022                param_names.push(String::new());
1023                param_tys.push(typeck_results.pat_ty(param.pat));
1024            }
1025        }
1026    }
1027    (param_names, param_tys)
1028}
1029
1030/// return the (ident, its fields) of the expr.
1031///
1032/// illustrated cases :
1033///    ptr	-> ("ptr", [])
1034///    region.size	-> ("region", ["size"])
1035///    tuple.0.value -> ("tuple", ["0", "value"])
1036pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
1037    match expr {
1038        Expr::Path(syn::ExprPath { path, .. }) => {
1039            if path.segments.len() == 1 {
1040                rap_debug!("expr2 {:?}", expr);
1041                let ident = path.segments[0].ident.to_string();
1042                Some((ident, Vec::new()))
1043            } else {
1044                None
1045            }
1046        }
1047        // get the base and fields recursively
1048        Expr::Field(syn::ExprField { base, member, .. }) => {
1049            let (base_ident, mut fields) =
1050                if let Some((base_ident, fields)) = access_ident_recursive(base) {
1051                    (base_ident, fields)
1052                } else {
1053                    return None;
1054                };
1055            let field_name = match member {
1056                syn::Member::Named(ident) => ident.to_string(),
1057                syn::Member::Unnamed(index) => index.index.to_string(),
1058            };
1059            fields.push(field_name);
1060            Some((base_ident, fields))
1061        }
1062        _ => None,
1063    }
1064}
1065
1066/// parse expr into number.
1067pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
1068    if let Expr::Lit(expr_lit) = expr {
1069        if let syn::Lit::Int(lit_int) = &expr_lit.lit {
1070            return lit_int.base10_parse::<usize>().ok();
1071        }
1072    }
1073    None
1074}
1075
1076/// Match a type identifier string to a concrete Rust type
1077///
1078/// This function attempts to match a given type identifier (e.g., "u32", "T", "MyStruct")
1079/// to a type in the provided parameter type list. It handles:
1080/// 1. Built-in primitive types (u32, usize, etc.)
1081/// 2. Generic type parameters (T, U, etc.)
1082/// 3. User-defined types found in the parameter list
1083///
1084/// Arguments:
1085/// - `tcx`: Type context for querying compiler information
1086/// - `type_ident`: String representing the type identifier to match
1087/// - `param_ty`: List of parameter types from the function signature
1088///
1089/// Returns:
1090/// - `Some(Ty)` if a matching type is found
1091/// - `None` if no match is found
1092pub fn match_ty_with_ident(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
1093    // 1. First check for built-in primitive types
1094    if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
1095        return Some(primitive_ty);
1096    }
1097    // 2. Check if the identifier matches any generic type parameter
1098    return find_generic_param(tcx, def_id, type_ident.clone());
1099    // 3. Check if the identifier matches any user-defined type in the parameters
1100    // find_user_defined_type(tcx, def_id, type_ident)
1101}
1102
1103/// Match built-in primitive types from String
1104fn match_primitive_type(tcx: TyCtxt, type_ident: String) -> Option<Ty> {
1105    match type_ident.as_str() {
1106        "i8" => Some(tcx.types.i8),
1107        "i16" => Some(tcx.types.i16),
1108        "i32" => Some(tcx.types.i32),
1109        "i64" => Some(tcx.types.i64),
1110        "i128" => Some(tcx.types.i128),
1111        "isize" => Some(tcx.types.isize),
1112        "u8" => Some(tcx.types.u8),
1113        "u16" => Some(tcx.types.u16),
1114        "u32" => Some(tcx.types.u32),
1115        "u64" => Some(tcx.types.u64),
1116        "u128" => Some(tcx.types.u128),
1117        "usize" => Some(tcx.types.usize),
1118        "f16" => Some(tcx.types.f16),
1119        "f32" => Some(tcx.types.f32),
1120        "f64" => Some(tcx.types.f64),
1121        "f128" => Some(tcx.types.f128),
1122        "bool" => Some(tcx.types.bool),
1123        "char" => Some(tcx.types.char),
1124        "str" => Some(tcx.types.str_),
1125        _ => None,
1126    }
1127}
1128
1129/// Find generic type parameters in the parameter list
1130fn find_generic_param(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
1131    rap_debug!(
1132        "Searching for generic param: {} in {:?}",
1133        type_ident,
1134        def_id
1135    );
1136    let (_, param_tys) = parse_signature(tcx, def_id);
1137    rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
1138    // 递归查找泛型参数
1139    for &ty in &param_tys {
1140        if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
1141            return Some(found);
1142        }
1143    }
1144
1145    None
1146}
1147
1148/// Iterate the args' types recursively and find the matched generic one.
1149fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
1150    match ty.kind() {
1151        TyKind::Param(param_ty) => {
1152            if param_ty.name.as_str() == type_ident {
1153                return Some(ty);
1154            }
1155        }
1156        TyKind::RawPtr(ty, _)
1157        | TyKind::Ref(_, ty, _)
1158        | TyKind::Slice(ty)
1159        | TyKind::Array(ty, _) => {
1160            if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
1161                return Some(found);
1162            }
1163        }
1164        TyKind::Tuple(tys) => {
1165            for tuple_ty in tys.iter() {
1166                if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
1167                    return Some(found);
1168                }
1169            }
1170        }
1171        TyKind::Adt(adt_def, substs) => {
1172            let name = tcx.item_name(adt_def.did()).to_string();
1173            if name == type_ident {
1174                return Some(ty);
1175            }
1176            for field in adt_def.all_fields() {
1177                let field_ty = field.ty(tcx, substs);
1178                if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
1179                    return Some(found);
1180                }
1181            }
1182        }
1183        _ => {}
1184    }
1185    None
1186}
1187
1188pub fn reflect_generic<'tcx>(
1189    generic_mapping: &FxHashMap<String, Ty<'tcx>>,
1190    ty: Ty<'tcx>,
1191) -> Ty<'tcx> {
1192    match ty.kind() {
1193        TyKind::Param(param_ty) => {
1194            let generic_name = param_ty.name.to_string();
1195            if let Some(actual_ty) = generic_mapping.get(&generic_name) {
1196                return *actual_ty;
1197            }
1198        }
1199        _ => {}
1200    }
1201    ty
1202}
1203
1204// src_var = 0: for constructor
1205// src_var = 1: for methods
1206pub fn has_tainted_fields(tcx: TyCtxt, def_id: DefId, src_var: u32) -> bool {
1207    let mut dataflow_analyzer = DataFlowAnalyzer::new(tcx, false);
1208    dataflow_analyzer.build_graph(def_id);
1209
1210    let body = tcx.optimized_mir(def_id);
1211    let params = &body.args_iter().collect::<Vec<_>>();
1212    rap_info!("params {:?}", params);
1213    let self_local = Local::from(src_var);
1214
1215    let flowing_params: Vec<Local> = params
1216        .iter()
1217        .filter(|&&param_local| {
1218            dataflow_analyzer.has_flow_between(def_id, self_local, param_local)
1219                && self_local != param_local
1220        })
1221        .copied()
1222        .collect();
1223
1224    if !flowing_params.is_empty() {
1225        rap_info!(
1226            "Taint flow found from self to other parameters: {:?}",
1227            flowing_params
1228        );
1229        true
1230    } else {
1231        false
1232    }
1233}
1234
1235// 修改返回值类型为调用链的向量
1236pub fn get_all_std_unsafe_chains(tcx: TyCtxt, def_id: DefId) -> Vec<Vec<String>> {
1237    let mut results = Vec::new();
1238    let mut visited = HashSet::new(); // 避免循环调用
1239    let mut current_chain = Vec::new();
1240
1241    // 开始DFS遍历
1242    dfs_find_unsafe_chains(tcx, def_id, &mut current_chain, &mut results, &mut visited);
1243    results
1244}
1245
1246// DFS递归查找unsafe调用链
1247fn dfs_find_unsafe_chains(
1248    tcx: TyCtxt,
1249    def_id: DefId,
1250    current_chain: &mut Vec<String>,
1251    results: &mut Vec<Vec<String>>,
1252    visited: &mut HashSet<DefId>,
1253) {
1254    // 避免循环调用
1255    if visited.contains(&def_id) {
1256        return;
1257    }
1258    visited.insert(def_id);
1259
1260    let current_func_name = get_cleaned_def_path_name(tcx, def_id);
1261    current_chain.push(current_func_name.clone());
1262
1263    // 获取当前函数的所有unsafe callee
1264    let unsafe_callees = find_unsafe_callees_in_function(tcx, def_id);
1265
1266    if unsafe_callees.is_empty() {
1267        // 如果没有更多的unsafe callee,保存当前链
1268        results.push(current_chain.clone());
1269    } else {
1270        // 对每个unsafe callee继续DFS
1271        for (callee_def_id, callee_name) in unsafe_callees {
1272            dfs_find_unsafe_chains(tcx, callee_def_id, current_chain, results, visited);
1273        }
1274    }
1275
1276    // 回溯
1277    current_chain.pop();
1278    visited.remove(&def_id);
1279}
1280
1281fn find_unsafe_callees_in_function(tcx: TyCtxt, def_id: DefId) -> Vec<(DefId, String)> {
1282    let mut callees = Vec::new();
1283
1284    if let Some(body) = try_get_mir(tcx, def_id) {
1285        for bb in body.basic_blocks.iter() {
1286            if let Some(terminator) = &bb.terminator {
1287                if let Some((callee_def_id, callee_name)) = extract_unsafe_callee(tcx, terminator) {
1288                    callees.push((callee_def_id, callee_name));
1289                }
1290            }
1291        }
1292    }
1293
1294    callees
1295}
1296
1297fn extract_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Option<(DefId, String)> {
1298    if let TerminatorKind::Call { func, .. } = &terminator.kind {
1299        if let Operand::Constant(func_constant) = func {
1300            if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
1301                if check_safety(tcx, *callee_def_id) == Safety::Unsafe {
1302                    let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
1303                    return Some((*callee_def_id, func_name));
1304                }
1305            }
1306        }
1307    }
1308    None
1309}
1310
1311fn try_get_mir(tcx: TyCtxt<'_>, def_id: DefId) -> Option<&rustc_middle::mir::Body<'_>> {
1312    if tcx.is_mir_available(def_id) {
1313        Some(tcx.optimized_mir(def_id))
1314    } else {
1315        None
1316    }
1317}
1318
1319pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
1320    tcx.def_path_str(def_id)
1321        .replace("::", "_")
1322        .replace("<", "_")
1323        .replace(">", "_")
1324        .replace(",", "_")
1325        .replace(" ", "")
1326        .replace("__", "_")
1327}
1328
/// Print the given unsafe call chains to stdout.
///
/// Nothing is printed for an empty list. Every chain is printed as a numbered block
/// in which each function is indented two more spaces per call depth.
pub fn print_unsafe_chains(chains: &[Vec<String>]) {
    if chains.is_empty() {
        return;
    }

    println!("==============================");
    println!("Found {} unsafe call chain(s):", chains.len());
    let mut chain_no = 0;
    for chain in chains {
        chain_no += 1;
        println!("Chain {}:", chain_no);
        for (depth, func_name) in chain.iter().enumerate() {
            // Every level below the root gets one extra space before the arrow.
            let spacer = if depth == 0 { "" } else { " " };
            println!("{}{}-> {}", "  ".repeat(depth), spacer, func_name);
        }
        println!();
    }
}
1345
1346pub fn get_all_std_fns_by_rustc_public(tcx: TyCtxt) -> Vec<DefId> {
1347    let mut all_std_fn_def = Vec::new();
1348    let mut results = Vec::new();
1349    let mut core_fn_def: Vec<_> = rustc_public::find_crates("core")
1350        .iter()
1351        .flat_map(|krate| krate.fn_defs())
1352        .collect();
1353    let mut std_fn_def: Vec<_> = rustc_public::find_crates("std")
1354        .iter()
1355        .flat_map(|krate| krate.fn_defs())
1356        .collect();
1357    let mut alloc_fn_def: Vec<_> = rustc_public::find_crates("alloc")
1358        .iter()
1359        .flat_map(|krate| krate.fn_defs())
1360        .collect();
1361    all_std_fn_def.append(&mut core_fn_def);
1362    all_std_fn_def.append(&mut std_fn_def);
1363    all_std_fn_def.append(&mut alloc_fn_def);
1364
1365    for fn_def in &all_std_fn_def {
1366        let def_id = crate::def_id::to_internal(fn_def, tcx);
1367        results.push(def_id);
1368    }
1369    results
1370}
1371
1372pub fn generate_mir_cfg_dot(tcx: TyCtxt<'_>, def_id: DefId) -> Result<(), std::io::Error> {
1373    let mir = tcx.optimized_mir(def_id);
1374
1375    let mut dot_content = String::new();
1376
1377    // Setup Header
1378    dot_content.push_str(&format!(
1379        "digraph mir_cfg_{} {{\n",
1380        get_cleaned_def_path_name(tcx, def_id)
1381    ));
1382    dot_content.push_str(&format!(
1383        "    label = \"MIR CFG for {}\";\n",
1384        tcx.def_path_str(def_id)
1385    ));
1386    dot_content.push_str("    labelloc = \"t\";\n");
1387    dot_content.push_str("    node [shape=box, fontname=\"Courier\", align=\"left\"];\n\n");
1388
1389    for (bb_index, bb_data) in mir.basic_blocks.iter_enumerated() {
1390        let mut lines: Vec<String> = bb_data
1391            .statements
1392            .iter()
1393            .map(|stmt| format!("{:?}", stmt))
1394            .collect();
1395
1396        let mut node_style = String::new();
1397
1398        if let Some(terminator) = &bb_data.terminator {
1399            if let TerminatorKind::Drop { .. } = terminator.kind {
1400                node_style = ", style=\"filled\", fillcolor=\"#ffdddd\", color=\"red\"".to_string();
1401            }
1402
1403            lines.push(format!("{:?}", terminator.kind));
1404        } else {
1405            lines.push("(no terminator)".to_string());
1406        }
1407
1408        let label_content = lines.join("\\l");
1409
1410        let node_label = format!("BB{}:\\l{}\\l", bb_index.index(), label_content);
1411
1412        dot_content.push_str(&format!(
1413            "    BB{} [label=\"{}\"{}];\n",
1414            bb_index.index(),
1415            node_label.replace("\"", "\\\""),
1416            node_style
1417        ));
1418
1419        if let Some(terminator) = &bb_data.terminator {
1420            for target in terminator.successors() {
1421                let edge_label = match terminator.kind {
1422                    _ => "".to_string(),
1423                };
1424
1425                dot_content.push_str(&format!(
1426                    "    BB{} -> BB{} [label=\"{}\"];\n",
1427                    bb_index.index(),
1428                    target.index(),
1429                    edge_label
1430                ));
1431            }
1432        }
1433    }
1434    dot_content.push_str("}\n");
1435    let name = get_cleaned_def_path_name(tcx, def_id);
1436    render_dot_string(name, dot_content);
1437    rap_debug!("render dot for {:?}", def_id);
1438    Ok(())
1439}
1440
/// Group the indices of `alias_map` by their representative.
///
/// `alias_map[i]` is the representative of index `i`. The result contains one sorted
/// group of indices per distinct representative, and the groups are ordered by their
/// smallest member.
pub fn convert_alias_to_sets(alias_map: Vec<usize>) -> Vec<Vec<usize>> {
    let mut members_by_rep: HashMap<usize, Vec<usize>> = HashMap::new();
    for (local_id, representative) in alias_map.into_iter().enumerate() {
        members_by_rep.entry(representative).or_default().push(local_id);
    }

    let mut sets: Vec<Vec<usize>> = members_by_rep
        .into_values()
        .map(|mut members| {
            members.sort();
            members
        })
        .collect();
    // Groups are never empty, so the first member is always present.
    sets.sort_by_key(|members| members[0]);
    sets
}
1460
1461// Input the adt def id
1462// Return set of (mutable method def_id, fields can be modified)
1463pub fn get_all_mutable_methods(tcx: TyCtxt, src_def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
1464    let mut std_results = HashMap::new();
1465    if get_type(tcx, src_def_id) == FnKind::Constructor {
1466        return std_results;
1467    }
1468    let all_std_fn_def = get_all_std_fns_by_rustc_public(tcx);
1469    let target_adt_def = get_adt_def_id_by_adt_method(tcx, src_def_id);
1470    let mut is_std = false;
1471    for &def_id in &all_std_fn_def {
1472        let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
1473        if adt_def.is_some() && adt_def == target_adt_def && src_def_id != def_id {
1474            if has_mut_self_param(tcx, def_id) {
1475                std_results.insert(def_id, HashSet::new());
1476            }
1477            is_std = true;
1478        }
1479    }
1480    if is_std {
1481        return std_results;
1482    }
1483    let mut results = HashMap::new();
1484    let public_fields = target_adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
1485    let impl_vec = target_adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
1486    for impl_id in impl_vec {
1487        if !matches!(tcx.def_kind(impl_id), rustc_hir::def::DefKind::Impl { .. }) {
1488            continue;
1489        }
1490        let associated_items = tcx.associated_items(impl_id);
1491        for item in associated_items.in_definition_order() {
1492            if let ty::AssocKind::Fn {
1493                name: _,
1494                has_self: _,
1495            } = item.kind
1496            {
1497                let item_def_id = item.def_id;
1498                if has_mut_self_param(tcx, item_def_id) {
1499                    let modified_fields = public_fields.clone();
1500                    results.insert(item_def_id, modified_fields);
1501                }
1502            }
1503        }
1504    }
1505    results
1506}
1507
1508pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<DefId> {
1509    let mut cons = Vec::new();
1510    if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == FnKind::Constructor {
1511        return cons;
1512    }
1513    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
1514        if let Some(impl_id) = assoc_item.impl_container(tcx) {
1515            // get struct ty
1516            let ty = tcx.type_of(impl_id).skip_binder();
1517            if let Some(adt_def) = ty.ty_adt_def() {
1518                let adt_def_id = adt_def.did();
1519                let impls = tcx.inherent_impls(adt_def_id);
1520                for impl_def_id in impls {
1521                    for item in tcx.associated_item_def_ids(impl_def_id) {
1522                        if (tcx.def_kind(item) == DefKind::Fn
1523                            || tcx.def_kind(item) == DefKind::AssocFn)
1524                            && get_type(tcx, *item) == FnKind::Constructor
1525                        {
1526                            cons.push(*item);
1527                        }
1528                    }
1529                }
1530            }
1531        }
1532    }
1533    cons
1534}
1535
1536pub fn append_fn_with_types(tcx: TyCtxt, def_id: DefId) -> FnInfo {
1537    FnInfo::new(def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
1538}
1539pub fn search_constructor(tcx: TyCtxt, def_id: DefId) -> Vec<DefId> {
1540    let mut constructors = Vec::new();
1541    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
1542        if let Some(impl_id) = assoc_item.impl_container(tcx) {
1543            // get struct ty
1544            let ty = tcx.type_of(impl_id).skip_binder();
1545            if let Some(adt_def) = ty.ty_adt_def() {
1546                let adt_def_id = adt_def.did();
1547                let impl_vec = get_impls_for_struct(tcx, adt_def_id);
1548                for impl_id in impl_vec {
1549                    let associated_items = tcx.associated_items(impl_id);
1550                    for item in associated_items.in_definition_order() {
1551                        if let ty::AssocKind::Fn {
1552                            name: _,
1553                            has_self: _,
1554                        } = item.kind
1555                        {
1556                            let item_def_id = item.def_id;
1557                            if get_type(tcx, item_def_id) == FnKind::Constructor {
1558                                constructors.push(item_def_id);
1559                            }
1560                        }
1561                    }
1562                }
1563            }
1564        }
1565    }
1566    constructors
1567}
1568
1569pub fn get_ptr_deref_dummy_def_id(tcx: TyCtxt<'_>) -> Option<DefId> {
1570    tcx.hir_crate_items(()).free_items().find_map(|item_id| {
1571        let def_id = item_id.owner_id.to_def_id();
1572        let name = tcx.opt_item_name(def_id)?;
1573
1574        (name.as_str() == "__raw_ptr_deref_dummy").then_some(def_id)
1575    })
1576}