rapx/analysis/utils/
fn_info.rs

1use super::draw_dot::render_dot_string;
2use crate::analysis::{
3    core::dataflow::{DataFlowAnalysis, default::DataFlowAnalyzer},
4    senryx::{
5        contracts::{property, property::PropertyContract},
6        matcher::parse_unsafe_api,
7    },
8};
9use crate::def_id::*;
10use crate::{rap_debug, rap_warn};
11use rustc_ast::ItemKind;
12use rustc_data_structures::fx::{FxHashMap, FxHashSet};
13use rustc_hir::{
14    Attribute, ImplItemKind, Safety,
15    def::DefKind,
16    def_id::{CrateNum, DefId, DefIndex},
17};
18use rustc_middle::{
19    hir::place::PlaceBase,
20    mir::{
21        BasicBlock, BinOp, Body, Local, Operand, Place, PlaceElem, PlaceRef, ProjectionElem,
22        Rvalue, StatementKind, Terminator, TerminatorKind,
23    },
24    ty,
25    ty::{AssocKind, ConstKind, Mutability, Ty, TyCtxt, TyKind},
26};
27use rustc_span::{def_id::LocalDefId, kw, sym};
28use serde::de;
29use serde::{Deserialize, Serialize};
30use std::{
31    collections::{HashMap, HashSet},
32    fmt::Debug,
33    hash::Hash,
34};
35use syn::Expr;
36
/// Coarse classification of a function-like item.
///
/// `get_type` in this file produces `Fn`, `Method` and `Constructor`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum FnKind {
    /// Free function, or any associated function that is neither a method nor a constructor.
    Fn,
    /// Associated function that takes a `self` receiver.
    Method,
    /// Associated function without `self` whose return type is (or wraps) the implementing type.
    Constructor,
    /// NOTE(review): presumably marks compiler intrinsics; nothing in the visible part of this
    /// file produces it -- confirm against the callers.
    Intrinsic,
}
44
45#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
46pub struct FnInfo {
47    pub def_id: DefId,
48    pub fn_safety: Safety,
49    pub fn_kind: FnKind,
50}
51
52impl FnInfo {
53    pub fn new(def_id: DefId, fn_safety: Safety, fn_kind: FnKind) -> Self {
54        FnInfo {
55            def_id,
56            fn_safety,
57            fn_kind,
58        }
59    }
60}
61
62#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
63pub struct AdtInfo {
64    pub def_id: DefId,
65    pub literal_cons_enabled: bool,
66}
67
68impl AdtInfo {
69    pub fn new(def_id: DefId, literal_cons_enabled: bool) -> Self {
70        AdtInfo {
71            def_id,
72            literal_cons_enabled,
73        }
74    }
75}
76
77pub fn check_visibility(tcx: TyCtxt, func_defid: DefId) -> bool {
78    if !tcx.visibility(func_defid).is_public() {
79        return false;
80    }
81    true
82}
83
84pub fn is_re_exported(tcx: TyCtxt, target_defid: DefId, module_defid: LocalDefId) -> bool {
85    for child in tcx.module_children_local(module_defid) {
86        if child.vis.is_public() {
87            if let Some(def_id) = child.res.opt_def_id() {
88                if def_id == target_defid {
89                    return true;
90                }
91            }
92        }
93    }
94    false
95}
96
/// Prints every element of `set` on its own line (using `Debug`), followed by a separator line.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|element| println!("{:?}", element));
    println!("---------------");
}
103
104pub fn get_cleaned_def_path_name_ori(tcx: TyCtxt, def_id: DefId) -> String {
105    let def_id_str = format!("{:?}", def_id);
106    let mut parts: Vec<&str> = def_id_str.split("::").collect();
107
108    let mut remove_first = false;
109    if let Some(first_part) = parts.get_mut(0) {
110        if first_part.contains("core") {
111            *first_part = "core";
112        } else if first_part.contains("std") {
113            *first_part = "std";
114        } else if first_part.contains("alloc") {
115            *first_part = "alloc";
116        } else {
117            remove_first = true;
118        }
119    }
120    if remove_first && !parts.is_empty() {
121        parts.remove(0);
122    }
123
124    let new_parts: Vec<String> = parts
125        .into_iter()
126        .filter_map(|s| {
127            if s.contains("{") {
128                if remove_first {
129                    get_struct_name(tcx, def_id)
130                } else {
131                    None
132                }
133            } else {
134                Some(s.to_string())
135            }
136        })
137        .collect();
138
139    let mut cleaned_path = new_parts.join("::");
140    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
141    cleaned_path
142}
143
144pub fn get_sp_tags_json() -> serde_json::Value {
145    let json_data: serde_json::Value =
146        serde_json::from_str(include_str!("data/std_sps.json")).expect("Unable to parse JSON");
147    json_data
148}
149
150pub fn get_std_api_signature_json() -> serde_json::Value {
151    let json_data: serde_json::Value =
152        serde_json::from_str(include_str!("data/std_sig.json")).expect("Unable to parse JSON");
153    json_data
154}
155
156pub fn get_sp_tags_and_args_json() -> serde_json::Value {
157    let json_data: serde_json::Value =
158        serde_json::from_str(include_str!("data/std_sps_args.json")).expect("Unable to parse JSON");
159    json_data
160}
161
/// One contract record deserialized from the JSON returned by `get_sp_tags_and_args_json`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContractEntry {
    /// Name of the property tag.
    pub tag: String,
    /// Raw argument strings; `generate_contract_from_std_annotation_json` expects the first one
    /// to be an argument index and parses each one as a Rust expression.
    pub args: Vec<String>,
}
167
168pub fn get_std_contracts(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<ContractEntry> {
169    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
170    let json_data: serde_json::Value = get_sp_tags_and_args_json();
171
172    if let Some(entries) = json_data.get(&cleaned_path_name) {
173        if let Ok(contracts) = serde_json::from_value::<Vec<ContractEntry>>(entries.clone()) {
174            return contracts;
175        }
176    }
177    Vec::new()
178}
179
180pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
181    let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
182    let json_data: serde_json::Value = get_sp_tags_json();
183
184    if let Some(function_info) = json_data.get(&cleaned_path_name) {
185        if let Some(sp_list) = function_info.get("0") {
186            let mut result = HashSet::new();
187            if let Some(sp_array) = sp_list.as_array() {
188                for sp in sp_array {
189                    if let Some(sp_name) = sp.as_str() {
190                        result.insert(sp_name.to_string());
191                    }
192                }
193            }
194            return result;
195        }
196    }
197    HashSet::new()
198}
199
200pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
201    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
202        if let Some(impl_id) = assoc_item.impl_container(tcx) {
203            let ty = tcx.type_of(impl_id).skip_binder();
204            let type_name = ty.to_string();
205            let struct_name = type_name
206                .split('<')
207                .next()
208                .unwrap_or("")
209                .split("::")
210                .last()
211                .unwrap_or("")
212                .to_string();
213
214            return Some(struct_name);
215        }
216    }
217    None
218}
219
220pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> Safety {
221    let poly_fn_sig = tcx.fn_sig(def_id);
222    let fn_sig = poly_fn_sig.skip_binder();
223    fn_sig.safety()
224}
225
226pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> FnKind {
227    let mut node_type = 2;
228    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
229        match assoc_item.kind {
230            AssocKind::Fn { has_self, .. } => {
231                if has_self {
232                    return FnKind::Method;
233                } else {
234                    let fn_sig = tcx.fn_sig(def_id).skip_binder();
235                    let output = fn_sig.output().skip_binder();
236                    // return type is 'Self'
237                    if output.is_param(0) {
238                        return FnKind::Constructor;
239                    }
240                    // return type is struct's name
241                    if let Some(impl_id) = assoc_item.impl_container(tcx) {
242                        let ty = tcx.type_of(impl_id).skip_binder();
243                        if output == ty {
244                            return FnKind::Constructor;
245                        }
246                    }
247                    match output.kind() {
248                        TyKind::Ref(_, ref_ty, _) => {
249                            if ref_ty.is_param(0) {
250                                return FnKind::Constructor;
251                            }
252                            if let Some(impl_id) = assoc_item.impl_container(tcx) {
253                                let ty = tcx.type_of(impl_id).skip_binder();
254                                if *ref_ty == ty {
255                                    return FnKind::Constructor;
256                                }
257                            }
258                        }
259                        TyKind::Adt(adt_def, substs) => {
260                            if adt_def.is_enum()
261                                && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
262                                    || tcx.is_diagnostic_item(sym::Result, adt_def.did())
263                                    || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
264                            {
265                                let inner_ty = substs.type_at(0);
266                                if inner_ty.is_param(0) {
267                                    return FnKind::Constructor;
268                                }
269                                if let Some(impl_id) = assoc_item.impl_container(tcx) {
270                                    let ty_impl = tcx.type_of(impl_id).skip_binder();
271                                    if inner_ty == ty_impl {
272                                        return FnKind::Constructor;
273                                    }
274                                }
275                            }
276                        }
277                        _ => {}
278                    }
279                }
280            }
281            _ => todo!(),
282        }
283    }
284    return FnKind::Fn;
285}
286
287pub fn get_adt_ty(tcx: TyCtxt, def_id: DefId) -> Option<Ty> {
288    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
289        if let Some(impl_id) = assoc_item.impl_container(tcx) {
290            return Some(tcx.type_of(impl_id).skip_binder());
291        }
292    }
293    None
294}
295
296// check whether this adt contains a literal constructor
297// result: adt_def_id, is_literal
298pub fn get_adt_via_method(tcx: TyCtxt<'_>, method_def_id: DefId) -> Option<AdtInfo> {
299    let assoc_item = tcx.opt_associated_item(method_def_id)?;
300    let impl_id = assoc_item.impl_container(tcx)?;
301    let ty = tcx.type_of(impl_id).skip_binder();
302    let adt_def = ty.ty_adt_def()?;
303    let adt_def_id = adt_def.did();
304
305    let all_fields: Vec<_> = adt_def.all_fields().collect();
306    let total_count = all_fields.len();
307
308    if total_count == 0 {
309        return Some(AdtInfo::new(adt_def_id, true));
310    }
311
312    let pub_count = all_fields
313        .iter()
314        .filter(|field| tcx.visibility(field.did).is_public())
315        .count();
316
317    if pub_count == 0 {
318        return None;
319    }
320    Some(AdtInfo::new(adt_def_id, pub_count == total_count))
321}
322
323fn place_has_raw_deref<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: &Place<'tcx>) -> bool {
324    let mut local = place.local;
325    for proj in place.projection.iter() {
326        if let ProjectionElem::Deref = proj.kind() {
327            let ty = body.local_decls[local].ty;
328            if let TyKind::RawPtr(_, _) = ty.kind() {
329                return true;
330            }
331        }
332    }
333    false
334}
335
336/// Analyzes the MIR of the given function to collect all local variables
337/// that are involved in dereferencing raw pointers (`*const T` or `*mut T`).
338pub fn get_rawptr_deref(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<Local> {
339    let mut raw_ptrs = HashSet::new();
340    if tcx.is_mir_available(def_id) {
341        let body = tcx.optimized_mir(def_id);
342        for bb in body.basic_blocks.iter() {
343            for stmt in &bb.statements {
344                if let StatementKind::Assign(box (lhs, rhs)) = &stmt.kind {
345                    if place_has_raw_deref(tcx, &body, lhs) {
346                        raw_ptrs.insert(lhs.local);
347                    }
348                    if let Rvalue::Use(op) = rhs {
349                        match op {
350                            Operand::Copy(place) | Operand::Move(place) => {
351                                if place_has_raw_deref(tcx, &body, place) {
352                                    raw_ptrs.insert(place.local);
353                                }
354                            }
355                            _ => {}
356                        }
357                    }
358                    if let Rvalue::Ref(_, _, place) = rhs {
359                        if place_has_raw_deref(tcx, &body, place) {
360                            raw_ptrs.insert(place.local);
361                        }
362                    }
363                }
364            }
365            if let Some(terminator) = &bb.terminator {
366                match &terminator.kind {
367                    rustc_middle::mir::TerminatorKind::Call { args, .. } => {
368                        for arg in args {
369                            match arg.node {
370                                Operand::Copy(place) | Operand::Move(place) => {
371                                    if place_has_raw_deref(tcx, &body, &place) {
372                                        raw_ptrs.insert(place.local);
373                                    }
374                                }
375                                _ => {}
376                            }
377                        }
378                    }
379                    _ => {}
380                }
381            }
382        }
383    }
384    raw_ptrs
385}
386
387/* Example mir of static mutable access.
388
389static mut COUNTER: i32 = {
390    let mut _0: i32;
391
392    bb0: {
393        _0 = const 0_i32;
394        return;
395    }
396}
397
398fn main() -> () {
399    let mut _0: ();
400    let mut _1: *mut i32;
401
402    bb0: {
403        StorageLive(_1);
404        _1 = const {alloc1: *mut i32};
405        (*_1) = const 1_i32;
406        StorageDead(_1);
407        return;
408    }
409}
410
411alloc1 (static: COUNTER, size: 4, align: 4) {
412    00 00 00 00                                     │ ....
413}
414
415*/
416
417/// Collects pairs of global static variables and their corresponding local variables
418/// within a function's MIR that are assigned from statics.
419pub fn collect_global_local_pairs(tcx: TyCtxt<'_>, def_id: DefId) -> HashMap<DefId, Vec<Local>> {
420    let mut globals: HashMap<DefId, Vec<Local>> = HashMap::new();
421
422    if !tcx.is_mir_available(def_id) {
423        return globals;
424    }
425
426    let body = tcx.optimized_mir(def_id);
427
428    for bb in body.basic_blocks.iter() {
429        for stmt in &bb.statements {
430            if let StatementKind::Assign(box (lhs, rhs)) = &stmt.kind {
431                if let Rvalue::Use(Operand::Constant(c)) = rhs {
432                    if let Some(static_def_id) = c.check_static_ptr(tcx) {
433                        globals.entry(static_def_id).or_default().push(lhs.local);
434                    }
435                }
436            }
437        }
438    }
439
440    globals
441}
442
443pub fn get_unsafe_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
444    let mut unsafe_callees = HashSet::new();
445    if tcx.is_mir_available(def_id) {
446        let body = tcx.optimized_mir(def_id);
447        for bb in body.basic_blocks.iter() {
448            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
449                if let Operand::Constant(func_constant) = func {
450                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
451                        if check_safety(tcx, *callee_def_id) == Safety::Unsafe {
452                            unsafe_callees.insert(*callee_def_id);
453                        }
454                    }
455                }
456            }
457        }
458    }
459    unsafe_callees
460}
461
462pub fn get_all_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
463    let mut callees = HashSet::new();
464    if tcx.is_mir_available(def_id) {
465        let body = tcx.optimized_mir(def_id);
466        for bb in body.basic_blocks.iter() {
467            if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
468                if let Operand::Constant(func_constant) = func {
469                    if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
470                        callees.insert(*callee_def_id);
471                    }
472                }
473            }
474        }
475    }
476    callees
477}
478
479// return all the impls def id of corresponding struct
480pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
481    let mut impls = Vec::new();
482    for item_id in tcx.hir_crate_items(()).free_items() {
483        let item = tcx.hir_item(item_id);
484        if let rustc_hir::ItemKind::Impl(impl_details) = &item.kind {
485            if let rustc_hir::TyKind::Path(rustc_hir::QPath::Resolved(_, path)) =
486                &impl_details.self_ty.kind
487            {
488                if let rustc_hir::def::Res::Def(_, def_id) = path.res {
489                    if def_id == struct_def_id {
490                        impls.push(item_id.owner_id.to_def_id());
491                    }
492                }
493            }
494        }
495    }
496    impls
497}
498
499pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
500    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
501        if let Some(impl_id) = assoc_item.impl_container(tcx) {
502            // get struct ty
503            let ty = tcx.type_of(impl_id).skip_binder();
504            if let Some(adt_def) = ty.ty_adt_def() {
505                return Some(adt_def.did());
506            }
507        }
508    }
509    None
510}
511
512// get the pointee or wrapped type
513pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
514    // progress_info!("get_pointee: > {:?} as type: {:?}", matched_ty, matched_ty.kind());
515    let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
516        get_pointee(*ty_mut)
517    } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
518        get_pointee(*referred_ty)
519    } else {
520        matched_ty
521    };
522    pointee
523}
524
525pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
526    if let ty::RawPtr(_, _) = matched_ty.kind() {
527        return true;
528    }
529    false
530}
531
532pub fn is_ref(matched_ty: Ty<'_>) -> bool {
533    if let ty::Ref(_, _, _) = matched_ty.kind() {
534        return true;
535    }
536    false
537}
538
539pub fn is_slice(matched_ty: Ty) -> Option<Ty> {
540    if let ty::Slice(inner) = matched_ty.kind() {
541        return Some(*inner);
542    }
543    None
544}
545
546pub fn has_mut_self_param(tcx: TyCtxt, def_id: DefId) -> bool {
547    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
548        match assoc_item.kind {
549            AssocKind::Fn { has_self, .. } => {
550                if has_self && tcx.is_mir_available(def_id) {
551                    let body = tcx.optimized_mir(def_id);
552                    let fst_arg = body.local_decls[Local::from_usize(1)].clone();
553                    let ty = fst_arg.ty;
554                    let is_mut_ref =
555                        matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
556                    return fst_arg.mutability.is_mut() || is_mut_ref;
557                }
558            }
559            _ => (),
560        }
561    }
562    false
563}
564
565// Check each field's visibility, return the public fields vec
566pub fn get_public_fields(tcx: TyCtxt, def_id: DefId) -> HashSet<usize> {
567    let adt_def = tcx.adt_def(def_id);
568    adt_def
569        .all_fields()
570        .enumerate()
571        .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
572        .collect()
573}
574
// General helper for displaying a hashmap: prints one `key: value` line per entry (using
// `Debug`), ordered by key and indented by two spaces per `level`.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = "  ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|left, right| left.0.cmp(right.0));

    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
591
592pub fn match_std_unsafe_chains_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
593    let mut results = Vec::new();
594    if let TerminatorKind::Call { func, .. } = &terminator.kind {
595        if let Operand::Constant(func_constant) = func {
596            if let ty::FnDef(callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
597                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
598            }
599        }
600    }
601    results
602}
603
604pub fn get_all_std_unsafe_callees(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
605    let mut results = Vec::new();
606    let body = tcx.optimized_mir(def_id);
607    let bb_len = body.basic_blocks.len();
608    for i in 0..bb_len {
609        let callees = match_std_unsafe_callee(
610            tcx,
611            body.basic_blocks[BasicBlock::from_usize(i)]
612                .clone()
613                .terminator(),
614        );
615        results.extend(callees);
616    }
617    results
618}
619
620pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt, def_id: DefId) -> Vec<usize> {
621    let mut results = Vec::new();
622    let body = tcx.optimized_mir(def_id);
623    let bb_len = body.basic_blocks.len();
624    for i in 0..bb_len {
625        if match_std_unsafe_callee(
626            tcx,
627            body.basic_blocks[BasicBlock::from_usize(i)]
628                .clone()
629                .terminator(),
630        )
631        .is_empty()
632        {
633            results.push(i);
634        }
635    }
636    results
637}
638
639pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
640    let mut results = Vec::new();
641    if let TerminatorKind::Call { func, .. } = &terminator.kind {
642        if let Operand::Constant(func_constant) = func {
643            if let ty::FnDef(callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
644                let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
645                // rap_info!("{func_name}");
646                if parse_unsafe_api(&func_name).is_some() {
647                    results.push(func_name);
648                }
649            }
650        }
651    }
652    results
653}
654
655// Bug definition: (1) strict -> weak & dst is mutable;
656//                 (2) _ -> strict
657pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
658    (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
659}
660
661// strict ty: bool, str, adt fields containing bool or str;
662pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
663    let ty = get_pointee(ori_ty);
664    let mut flag = false;
665    if let TyKind::Adt(adt_def, substs) = ty.kind() {
666        if adt_def.is_struct() {
667            for field_def in adt_def.all_fields() {
668                flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
669            }
670        }
671    }
672    ty.is_bool() || ty.is_str() || flag
673}
674
675pub fn reverse_op(op: BinOp) -> BinOp {
676    match op {
677        BinOp::Lt => BinOp::Ge,
678        BinOp::Ge => BinOp::Lt,
679        BinOp::Le => BinOp::Gt,
680        BinOp::Gt => BinOp::Le,
681        BinOp::Eq => BinOp::Eq,
682        BinOp::Ne => BinOp::Ne,
683        _ => op,
684    }
685}
686
687/// Generate contracts from pre-defined std-lib JSON configuration (std_sps_args.json).
688pub fn generate_contract_from_std_annotation_json(
689    tcx: TyCtxt<'_>,
690    def_id: DefId,
691) -> Vec<(usize, Vec<usize>, PropertyContract<'_>)> {
692    let mut results = Vec::new();
693    let std_contracts = get_std_contracts(tcx, def_id);
694
695    for entry in std_contracts {
696        let tag_name = entry.tag;
697        let raw_args = entry.args;
698
699        if raw_args.is_empty() {
700            continue;
701        }
702
703        let arg_index_str = &raw_args[0];
704        let local_id = if let Ok(arg_idx) = arg_index_str.parse::<usize>() {
705            arg_idx
706        } else {
707            rap_error!(
708                "JSON Contract Error: First argument must be an arg index number, got {}",
709                arg_index_str
710            );
711            continue;
712        };
713
714        let mut exprs: Vec<Expr> = Vec::new();
715        for arg_str in &raw_args {
716            match syn::parse_str::<Expr>(arg_str) {
717                Ok(expr) => exprs.push(expr),
718                Err(_) => {
719                    rap_error!(
720                        "JSON Contract Error: Failed to parse arg '{}' as Rust Expr for tag {}",
721                        arg_str,
722                        tag_name
723                    );
724                }
725            }
726        }
727
728        // Robustness check of arguments transition
729        if exprs.len() != raw_args.len() {
730            rap_error!(
731                "Parse std API args error: Failed to parse arg '{:?}'",
732                raw_args
733            );
734            continue;
735        }
736        let fields: Vec<usize> = Vec::new();
737        let contract = PropertyContract::new(tcx, def_id, &tag_name, &exprs);
738        results.push((local_id, fields, contract));
739    }
740
741    // rap_warn!("Get contract {:?}.", results);
742    results
743}
744
745/// Same with `generate_contract_from_annotation` but does not contain field types.
746pub fn generate_contract_from_annotation_without_field_types(
747    tcx: TyCtxt,
748    def_id: DefId,
749) -> Vec<(usize, Vec<usize>, PropertyContract)> {
750    let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
751
752    contracts_with_ty
753        .into_iter()
754        .map(|(local_id, fields_with_ty, contract)| {
755            let fields: Vec<usize> = fields_with_ty
756                .into_iter()
757                .map(|(field_idx, _)| field_idx)
758                .collect();
759            (local_id, fields, contract)
760        })
761        .collect()
762}
763
764/// Filter the function which contains "rapx::proof"
765pub fn is_verify_target_func(tcx: TyCtxt, def_id: DefId) -> bool {
766    for attr in tcx.get_all_attrs(def_id).into_iter() {
767        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
768        // Find proof placeholder
769        if attr_str.contains("#[rapx::proof(proof)]") {
770            return true;
771        }
772    }
773    false
774}
775
776/// Get the annotation in tag-std style.
777/// Then generate the contractual invariant states (CIS) for the args.
778/// This function will recognize the args name and record states to MIR variable (represent by usize).
779/// Return value means Vec<(local_id, fields of this local, contracts)>
780pub fn generate_contract_from_annotation(
781    tcx: TyCtxt,
782    def_id: DefId,
783) -> Vec<(usize, Vec<(usize, Ty)>, PropertyContract)> {
784    const REGISTER_TOOL: &str = "rapx";
785    let tool_attrs = tcx.get_all_attrs(def_id).into_iter().filter(|attr| {
786        if let Attribute::Unparsed(tool_attr) = attr {
787            if tool_attr.path.segments[0].as_str() == REGISTER_TOOL {
788                return true;
789            }
790        }
791        false
792    });
793    let mut results = Vec::new();
794    for attr in tool_attrs {
795        let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
796        // Find proof placeholder, skip it
797        if attr_str.contains("#[rapx::proof(proof)]") {
798            continue;
799        }
800        rap_debug!("{:?}", attr_str);
801        let safety_attr = safety_parser::safety::parse_attr_and_get_properties(attr_str.as_str());
802        for par in safety_attr.iter() {
803            for property in par.tags.iter() {
804                let tag_name = property.tag.name();
805                let exprs = property.args.clone().into_vec();
806                let contract = PropertyContract::new(tcx, def_id, tag_name, &exprs);
807                let (local, fields) = parse_cis_local(tcx, def_id, exprs);
808                results.push((local, fields, contract));
809            }
810        }
811    }
812    // if results.len() > 0 {
813    //     rap_warn!("results:\n{:?}", results);
814    // }
815    results
816}
817
818/// Parse attr.expr into local id and local fields.
819///
820/// Example:
821/// ```
822/// #[rapx::inner(property = ValidPtr(ptr, u32, 1), kind = "precond")]
823/// #[rapx::inner(property = ValidNum(region.size>=0), kind = "precond")]
824/// pub fn xor_secret_region(ptr: *mut u32, region:SecretRegion) -> u32 {...}
825/// ```
826///
827/// The first attribute will be parsed as (1, []).
828///     -> "1" means the first arg "ptr", "[]" means no fields.
829/// The second attribute will be parsed as (2, [1]).
830///     -> "2" means the second arg "region", "[1]" means "size" is region's second field.
831///
832/// If this function doesn't have args, then it will return default pattern: (0, Vec::new())
833pub fn parse_cis_local(tcx: TyCtxt, def_id: DefId, expr: Vec<Expr>) -> (usize, Vec<(usize, Ty)>) {
834    // match expr with cis local
835    for e in expr {
836        if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
837            return (base, fields);
838        }
839    }
840    (0, Vec::new())
841}
842
843/// parse single expr into (local, fields, ty)
844pub fn parse_expr_into_local_and_ty<'tcx>(
845    tcx: TyCtxt<'tcx>,
846    def_id: DefId,
847    expr: &Expr,
848) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
849    if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
850        let (param_names, param_tys) = parse_signature(tcx, def_id);
851        if param_names[0] == "0".to_string() {
852            return None;
853        }
854        if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
855            let mut current_ty = param_tys[param_index];
856            let mut field_indices = Vec::new();
857            for field_name in fields {
858                // peel the ref and ptr
859                let peeled_ty = current_ty.peel_refs();
860                if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
861                    let variant = adt_def.non_enum_variant();
862                    // 1. if field_name is number, then parse it as usize
863                    if let Ok(field_idx) = field_name.parse::<usize>() {
864                        if field_idx < variant.fields.len() {
865                            current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
866                                .ty(tcx, arg_list);
867                            field_indices.push((field_idx, current_ty));
868                            continue;
869                        }
870                    }
871                    // 2. if field_name is String, then compare it with current ty's field names
872                    if let Some((idx, _)) = variant
873                        .fields
874                        .iter()
875                        .enumerate()
876                        .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
877                    {
878                        current_ty =
879                            variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
880                        field_indices.push((idx, current_ty));
881                    }
882                    // 3. if field_name can not match any fields, then break
883                    else {
884                        break; // TODO:
885                    }
886                }
887                // if current ty is not Adt, then break the loop
888                else {
889                    break; // TODO:
890                }
891            }
892            // It's different from default one, we return the result as param_index+1 because param_index count from 0.
893            // But 0 in MIR is the ret index, the args' indexes begin from 1.
894            return Some((param_index + 1, field_indices, current_ty));
895        }
896    }
897    None
898}
899
900/// Return the Vecs of args' names and types
901/// This function will handle outside def_id by different way.
902pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
903    // 0. If the def id is local
904    if def_id.as_local().is_some() {
905        return parse_local_signature(tcx, def_id);
906    } else {
907        rap_debug!("{:?} is not local def id.", def_id);
908        return parse_outside_signature(tcx, def_id);
909    };
910}
911
912/// Return the Vecs of args' names and types of outside functions.
913fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
914    let sig = tcx.fn_sig(def_id).skip_binder();
915    let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
916
917    // 1. check pre-defined std unsafe api signature
918    if let Some(args_name) = get_known_std_names(tcx, def_id) {
919        // rap_warn!(
920        //     "function {:?} has arg: {:?}, arg types: {:?}",
921        //     def_id,
922        //     args_name,
923        //     param_tys
924        // );
925        return (args_name, param_tys);
926    }
927
928    // 2. TODO: If can not find known std apis, then use numbers like `0`,`1`,... to represent args.
929    let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
930    rap_debug!(
931        "function {:?} has arg: {:?}, arg types: {:?}",
932        def_id,
933        args_name,
934        param_tys
935    );
936    return (args_name, param_tys);
937}
938
939/// We use a json to record known std apis' arg names.
940/// This function will search the json and return the names.
941/// Notes: If std gets updated, the json may still record old ones.
942fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
943    let std_func_name = get_cleaned_def_path_name(tcx, def_id);
944    let json_data: serde_json::Value = get_std_api_signature_json();
945
946    if let Some(arg_info) = json_data.get(&std_func_name) {
947        if let Some(args_name) = arg_info.as_array() {
948            // set default value to arg name
949            if args_name.len() == 0 {
950                return Some(vec!["0".to_string()]);
951            }
952            // iterate and collect
953            let mut result = Vec::new();
954            for arg in args_name {
955                if let Some(sp_name) = arg.as_str() {
956                    result.push(sp_name.to_string());
957                }
958            }
959            return Some(result);
960        }
961    }
962    None
963}
964
965/// Return the Vecs of args' names and types of local functions.
966pub fn parse_local_signature(tcx: TyCtxt, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
967    // 1. parse local def_id and get arg list
968    let local_def_id = def_id.as_local().unwrap();
969    let hir_body = tcx.hir_body_owned_by(local_def_id);
970    if hir_body.params.len() == 0 {
971        return (vec!["0".to_string()], Vec::new());
972    }
973    // 2. contruct the vec of param and param ty
974    let params = hir_body.params;
975    let typeck_results = tcx.typeck_body(hir_body.id());
976    let mut param_names = Vec::new();
977    let mut param_tys = Vec::new();
978    for param in params {
979        match param.pat.kind {
980            rustc_hir::PatKind::Binding(_, _, ident, _) => {
981                param_names.push(ident.name.to_string());
982                let ty = typeck_results.pat_ty(param.pat);
983                param_tys.push(ty);
984            }
985            _ => {
986                param_names.push(String::new());
987                param_tys.push(typeck_results.pat_ty(param.pat));
988            }
989        }
990    }
991    (param_names, param_tys)
992}
993
994/// return the (ident, its fields) of the expr.
995///
996/// illustrated cases :
997///    ptr	-> ("ptr", [])
998///    region.size	-> ("region", ["size"])
999///    tuple.0.value -> ("tuple", ["0", "value"])
1000pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
1001    match expr {
1002        Expr::Path(syn::ExprPath { path, .. }) => {
1003            if path.segments.len() == 1 {
1004                rap_debug!("expr2 {:?}", expr);
1005                let ident = path.segments[0].ident.to_string();
1006                Some((ident, Vec::new()))
1007            } else {
1008                None
1009            }
1010        }
1011        // get the base and fields recursively
1012        Expr::Field(syn::ExprField { base, member, .. }) => {
1013            let (base_ident, mut fields) =
1014                if let Some((base_ident, fields)) = access_ident_recursive(base) {
1015                    (base_ident, fields)
1016                } else {
1017                    return None;
1018                };
1019            let field_name = match member {
1020                syn::Member::Named(ident) => ident.to_string(),
1021                syn::Member::Unnamed(index) => index.index.to_string(),
1022            };
1023            fields.push(field_name);
1024            Some((base_ident, fields))
1025        }
1026        _ => None,
1027    }
1028}
1029
1030/// parse expr into number.
1031pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
1032    if let Expr::Lit(expr_lit) = expr {
1033        if let syn::Lit::Int(lit_int) = &expr_lit.lit {
1034            return lit_int.base10_parse::<usize>().ok();
1035        }
1036    }
1037    None
1038}
1039
1040/// Match a type identifier string to a concrete Rust type
1041///
1042/// This function attempts to match a given type identifier (e.g., "u32", "T", "MyStruct")
1043/// to a type in the provided parameter type list. It handles:
1044/// 1. Built-in primitive types (u32, usize, etc.)
1045/// 2. Generic type parameters (T, U, etc.)
1046/// 3. User-defined types found in the parameter list
1047///
1048/// Arguments:
1049/// - `tcx`: Type context for querying compiler information
1050/// - `type_ident`: String representing the type identifier to match
1051/// - `param_ty`: List of parameter types from the function signature
1052///
1053/// Returns:
1054/// - `Some(Ty)` if a matching type is found
1055/// - `None` if no match is found
1056pub fn match_ty_with_ident(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
1057    // 1. First check for built-in primitive types
1058    if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
1059        return Some(primitive_ty);
1060    }
1061    // 2. Check if the identifier matches any generic type parameter
1062    return find_generic_param(tcx, def_id, type_ident.clone());
1063    // 3. Check if the identifier matches any user-defined type in the parameters
1064    // find_user_defined_type(tcx, def_id, type_ident)
1065}
1066
1067/// Match built-in primitive types from String
1068fn match_primitive_type(tcx: TyCtxt, type_ident: String) -> Option<Ty> {
1069    match type_ident.as_str() {
1070        "i8" => Some(tcx.types.i8),
1071        "i16" => Some(tcx.types.i16),
1072        "i32" => Some(tcx.types.i32),
1073        "i64" => Some(tcx.types.i64),
1074        "i128" => Some(tcx.types.i128),
1075        "isize" => Some(tcx.types.isize),
1076        "u8" => Some(tcx.types.u8),
1077        "u16" => Some(tcx.types.u16),
1078        "u32" => Some(tcx.types.u32),
1079        "u64" => Some(tcx.types.u64),
1080        "u128" => Some(tcx.types.u128),
1081        "usize" => Some(tcx.types.usize),
1082        "f16" => Some(tcx.types.f16),
1083        "f32" => Some(tcx.types.f32),
1084        "f64" => Some(tcx.types.f64),
1085        "f128" => Some(tcx.types.f128),
1086        "bool" => Some(tcx.types.bool),
1087        "char" => Some(tcx.types.char),
1088        "str" => Some(tcx.types.str_),
1089        _ => None,
1090    }
1091}
1092
1093/// Find generic type parameters in the parameter list
1094fn find_generic_param(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
1095    rap_debug!(
1096        "Searching for generic param: {} in {:?}",
1097        type_ident,
1098        def_id
1099    );
1100    let (_, param_tys) = parse_signature(tcx, def_id);
1101    rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
1102    // 递归查找泛型参数
1103    for &ty in &param_tys {
1104        if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
1105            return Some(found);
1106        }
1107    }
1108
1109    None
1110}
1111
1112/// Iterate the args' types recursively and find the matched generic one.
1113fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
1114    match ty.kind() {
1115        TyKind::Param(param_ty) => {
1116            if param_ty.name.as_str() == type_ident {
1117                return Some(ty);
1118            }
1119        }
1120        TyKind::RawPtr(ty, _)
1121        | TyKind::Ref(_, ty, _)
1122        | TyKind::Slice(ty)
1123        | TyKind::Array(ty, _) => {
1124            if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
1125                return Some(found);
1126            }
1127        }
1128        TyKind::Tuple(tys) => {
1129            for tuple_ty in tys.iter() {
1130                if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
1131                    return Some(found);
1132                }
1133            }
1134        }
1135        TyKind::Adt(adt_def, substs) => {
1136            let name = tcx.item_name(adt_def.did()).to_string();
1137            if name == type_ident {
1138                return Some(ty);
1139            }
1140            for field in adt_def.all_fields() {
1141                let field_ty = field.ty(tcx, substs);
1142                if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
1143                    return Some(found);
1144                }
1145            }
1146        }
1147        _ => {}
1148    }
1149    None
1150}
1151
1152pub fn reflect_generic<'tcx>(
1153    generic_mapping: &FxHashMap<String, Ty<'tcx>>,
1154    func_name: &str,
1155    ty: Ty<'tcx>,
1156) -> Ty<'tcx> {
1157    let mut actual_ty = ty;
1158    match ty.kind() {
1159        TyKind::Param(param_ty) => {
1160            let generic_name = param_ty.name.to_string();
1161            if let Some(actual_ty_from_map) = generic_mapping.get(&generic_name) {
1162                actual_ty = *actual_ty_from_map;
1163            }
1164        }
1165        _ => {}
1166    }
1167    rap_debug!(
1168        "peel generic ty for {:?}, actual_ty is {:?}",
1169        func_name,
1170        actual_ty
1171    );
1172    actual_ty
1173}
1174
1175// src_var = 0: for constructor
1176// src_var = 1: for methods
1177pub fn has_tainted_fields(tcx: TyCtxt, def_id: DefId, src_var: u32) -> bool {
1178    let mut dataflow_analyzer = DataFlowAnalyzer::new(tcx, false);
1179    dataflow_analyzer.build_graph(def_id);
1180
1181    let body = tcx.optimized_mir(def_id);
1182    let params = &body.args_iter().collect::<Vec<_>>();
1183    rap_info!("params {:?}", params);
1184    let self_local = Local::from(src_var);
1185
1186    let flowing_params: Vec<Local> = params
1187        .iter()
1188        .filter(|&&param_local| {
1189            dataflow_analyzer.has_flow_between(def_id, self_local, param_local)
1190                && self_local != param_local
1191        })
1192        .copied()
1193        .collect();
1194
1195    if !flowing_params.is_empty() {
1196        rap_info!(
1197            "Taint flow found from self to other parameters: {:?}",
1198            flowing_params
1199        );
1200        true
1201    } else {
1202        false
1203    }
1204}
1205
1206// 修改返回值类型为调用链的向量
1207pub fn get_all_std_unsafe_chains(tcx: TyCtxt, def_id: DefId) -> Vec<Vec<String>> {
1208    let mut results = Vec::new();
1209    let mut visited = HashSet::new(); // 避免循环调用
1210    let mut current_chain = Vec::new();
1211
1212    // 开始DFS遍历
1213    dfs_find_unsafe_chains(tcx, def_id, &mut current_chain, &mut results, &mut visited);
1214    results
1215}
1216
1217// DFS递归查找unsafe调用链
1218fn dfs_find_unsafe_chains(
1219    tcx: TyCtxt,
1220    def_id: DefId,
1221    current_chain: &mut Vec<String>,
1222    results: &mut Vec<Vec<String>>,
1223    visited: &mut HashSet<DefId>,
1224) {
1225    // 避免循环调用
1226    if visited.contains(&def_id) {
1227        return;
1228    }
1229    visited.insert(def_id);
1230
1231    let current_func_name = get_cleaned_def_path_name(tcx, def_id);
1232    current_chain.push(current_func_name.clone());
1233
1234    // 获取当前函数的所有unsafe callee
1235    let unsafe_callees = find_unsafe_callees_in_function(tcx, def_id);
1236
1237    if unsafe_callees.is_empty() {
1238        // 如果没有更多的unsafe callee,保存当前链
1239        results.push(current_chain.clone());
1240    } else {
1241        // 对每个unsafe callee继续DFS
1242        for (callee_def_id, callee_name) in unsafe_callees {
1243            dfs_find_unsafe_chains(tcx, callee_def_id, current_chain, results, visited);
1244        }
1245    }
1246
1247    // 回溯
1248    current_chain.pop();
1249    visited.remove(&def_id);
1250}
1251
1252fn find_unsafe_callees_in_function(tcx: TyCtxt, def_id: DefId) -> Vec<(DefId, String)> {
1253    let mut callees = Vec::new();
1254
1255    if let Some(body) = try_get_mir(tcx, def_id) {
1256        for bb in body.basic_blocks.iter() {
1257            if let Some(terminator) = &bb.terminator {
1258                if let Some((callee_def_id, callee_name)) = extract_unsafe_callee(tcx, terminator) {
1259                    callees.push((callee_def_id, callee_name));
1260                }
1261            }
1262        }
1263    }
1264
1265    callees
1266}
1267
1268fn extract_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Option<(DefId, String)> {
1269    if let TerminatorKind::Call { func, .. } = &terminator.kind {
1270        if let Operand::Constant(func_constant) = func {
1271            if let ty::FnDef(callee_def_id, _) = func_constant.const_.ty().kind() {
1272                if check_safety(tcx, *callee_def_id) == Safety::Unsafe {
1273                    let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
1274                    return Some((*callee_def_id, func_name));
1275                }
1276            }
1277        }
1278    }
1279    None
1280}
1281
1282fn try_get_mir(tcx: TyCtxt<'_>, def_id: DefId) -> Option<&rustc_middle::mir::Body<'_>> {
1283    if tcx.is_mir_available(def_id) {
1284        Some(tcx.optimized_mir(def_id))
1285    } else {
1286        None
1287    }
1288}
1289
1290pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
1291    let def_id_str = format!("{:?}", def_id);
1292    let mut parts: Vec<&str> = def_id_str.split("::").collect();
1293
1294    let mut remove_first = false;
1295    if let Some(first_part) = parts.get_mut(0) {
1296        if first_part.contains("core") {
1297            *first_part = "core";
1298        } else if first_part.contains("std") {
1299            *first_part = "std";
1300        } else if first_part.contains("alloc") {
1301            *first_part = "alloc";
1302        } else {
1303            remove_first = true;
1304        }
1305    }
1306    if remove_first && !parts.is_empty() {
1307        parts.remove(0);
1308    }
1309
1310    let new_parts: Vec<String> = parts
1311        .into_iter()
1312        .filter_map(|s| {
1313            if s.contains("{") {
1314                if remove_first {
1315                    get_struct_name(tcx, def_id)
1316                } else {
1317                    None
1318                }
1319            } else {
1320                Some(s.to_string())
1321            }
1322        })
1323        .collect();
1324
1325    let mut cleaned_path = new_parts.join("::");
1326    cleaned_path = cleaned_path.trim_end_matches(')').to_string();
1327    cleaned_path
1328    // tcx.def_path_str(def_id)
1329    //     .replace("::", "_")
1330    //     .replace("<", "_")
1331    //     .replace(">", "_")
1332    //     .replace(",", "_")
1333    //     .replace(" ", "")
1334    //     .replace("__", "_")
1335}
1336
/// Print the given unsafe call chains to stdout.
///
/// Nothing is printed for an empty list. Every chain is numbered from 1 and each
/// function in it is printed on its own line, indented by two spaces per depth
/// level (plus one extra space below the root) and prefixed with `-> `.
pub fn print_unsafe_chains(chains: &[Vec<String>]) {
    if chains.is_empty() {
        return;
    }

    println!("==============================");
    println!("Found {} unsafe call chain(s):", chains.len());
    for (chain_no, chain) in chains.iter().enumerate() {
        println!("Chain {}:", chain_no + 1);
        for (depth, func_name) in chain.iter().enumerate() {
            let spacer = if depth == 0 { "" } else { " " };
            println!("{}{}-> {}", "  ".repeat(depth), spacer, func_name);
        }
        println!();
    }
}
1353
1354pub fn get_all_std_fns_by_rustc_public(tcx: TyCtxt) -> Vec<DefId> {
1355    let mut all_std_fn_def = Vec::new();
1356    let mut results = Vec::new();
1357    let mut core_fn_def: Vec<_> = rustc_public::find_crates("core")
1358        .iter()
1359        .flat_map(|krate| krate.fn_defs())
1360        .collect();
1361    let mut std_fn_def: Vec<_> = rustc_public::find_crates("std")
1362        .iter()
1363        .flat_map(|krate| krate.fn_defs())
1364        .collect();
1365    let mut alloc_fn_def: Vec<_> = rustc_public::find_crates("alloc")
1366        .iter()
1367        .flat_map(|krate| krate.fn_defs())
1368        .collect();
1369    all_std_fn_def.append(&mut core_fn_def);
1370    all_std_fn_def.append(&mut std_fn_def);
1371    all_std_fn_def.append(&mut alloc_fn_def);
1372
1373    for fn_def in &all_std_fn_def {
1374        let def_id = crate::def_id::to_internal(fn_def, tcx);
1375        results.push(def_id);
1376    }
1377    results
1378}
1379
1380pub fn generate_mir_cfg_dot<'tcx>(
1381    tcx: TyCtxt<'tcx>,
1382    def_id: DefId,
1383    alias_sets: &Vec<FxHashSet<usize>>,
1384) -> Result<(), std::io::Error> {
1385    let mir = tcx.optimized_mir(def_id);
1386
1387    let mut dot_content = String::new();
1388
1389    let alias_info_str = format!("Alias Sets: {:?}", alias_sets);
1390
1391    dot_content.push_str(&format!(
1392        "digraph mir_cfg_{} {{\n",
1393        get_cleaned_def_path_name(tcx, def_id)
1394    ));
1395
1396    dot_content.push_str(&format!(
1397        "    label = \"MIR CFG for {}\\n{}\\n\";\n",
1398        tcx.def_path_str(def_id),
1399        alias_info_str.replace("\"", "\\\"")
1400    ));
1401    dot_content.push_str("    labelloc = \"t\";\n");
1402    dot_content.push_str("    node [shape=box, fontname=\"Courier\", align=\"left\"];\n\n");
1403
1404    for (bb_index, bb_data) in mir.basic_blocks.iter_enumerated() {
1405        let mut lines: Vec<String> = bb_data
1406            .statements
1407            .iter()
1408            .map(|stmt| format!("{:?}", stmt))
1409            .collect();
1410
1411        let mut node_style = String::new();
1412
1413        if let Some(terminator) = &bb_data.terminator {
1414            let mut is_drop_related = false;
1415
1416            match &terminator.kind {
1417                TerminatorKind::Drop { .. } => {
1418                    is_drop_related = true;
1419                }
1420                TerminatorKind::Call { func, .. } => {
1421                    if let Operand::Constant(c) = func {
1422                        if let ty::FnDef(def_id, _) = *c.ty().kind() {
1423                            if def_id == drop()
1424                                || def_id == drop_in_place()
1425                                || def_id == manually_drop()
1426                                || dealloc_opt().map(|f| f == def_id).unwrap_or(false)
1427                            {
1428                                is_drop_related = true;
1429                            }
1430                        }
1431                    }
1432                }
1433                _ => {}
1434            }
1435
1436            if is_drop_related {
1437                node_style = ", style=\"filled\", fillcolor=\"#ffdddd\", color=\"red\"".to_string();
1438            }
1439
1440            lines.push(format!("{:?}", terminator.kind));
1441        } else {
1442            lines.push("(no terminator)".to_string());
1443        }
1444
1445        let label_content = lines.join("\\l");
1446
1447        let node_label = format!("BB{}:\\l{}\\l", bb_index.index(), label_content);
1448
1449        dot_content.push_str(&format!(
1450            "    BB{} [label=\"{}\"{}];\n",
1451            bb_index.index(),
1452            node_label.replace("\"", "\\\""),
1453            node_style
1454        ));
1455
1456        if let Some(terminator) = &bb_data.terminator {
1457            for target in terminator.successors() {
1458                let edge_label = match terminator.kind {
1459                    _ => "".to_string(),
1460                };
1461
1462                dot_content.push_str(&format!(
1463                    "    BB{} -> BB{} [label=\"{}\"];\n",
1464                    bb_index.index(),
1465                    target.index(),
1466                    edge_label
1467                ));
1468            }
1469        }
1470    }
1471    dot_content.push_str("}\n");
1472    let name = get_cleaned_def_path_name(tcx, def_id);
1473    render_dot_string(name, dot_content);
1474    rap_debug!("render dot for {:?}", def_id);
1475    Ok(())
1476}
1477
1478// Input the adt def id
1479// Return set of (mutable method def_id, fields can be modified)
1480pub fn get_all_mutable_methods(tcx: TyCtxt, src_def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
1481    let mut std_results = HashMap::new();
1482    if get_type(tcx, src_def_id) == FnKind::Constructor {
1483        return std_results;
1484    }
1485    let all_std_fn_def = get_all_std_fns_by_rustc_public(tcx);
1486    let target_adt_def = get_adt_def_id_by_adt_method(tcx, src_def_id);
1487    let mut is_std = false;
1488    for &def_id in &all_std_fn_def {
1489        let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
1490        if adt_def.is_some() && adt_def == target_adt_def && src_def_id != def_id {
1491            if has_mut_self_param(tcx, def_id) {
1492                std_results.insert(def_id, HashSet::new());
1493            }
1494            is_std = true;
1495        }
1496    }
1497    if is_std {
1498        return std_results;
1499    }
1500    let mut results = HashMap::new();
1501    let public_fields = target_adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
1502    let impl_vec = target_adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
1503    for impl_id in impl_vec {
1504        if !matches!(tcx.def_kind(impl_id), rustc_hir::def::DefKind::Impl { .. }) {
1505            continue;
1506        }
1507        let associated_items = tcx.associated_items(impl_id);
1508        for item in associated_items.in_definition_order() {
1509            if let ty::AssocKind::Fn {
1510                name: _,
1511                has_self: _,
1512            } = item.kind
1513            {
1514                let item_def_id = item.def_id;
1515                if has_mut_self_param(tcx, item_def_id) {
1516                    let modified_fields = public_fields.clone();
1517                    results.insert(item_def_id, modified_fields);
1518                }
1519            }
1520        }
1521    }
1522    results
1523}
1524
1525pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<DefId> {
1526    let mut cons = Vec::new();
1527    if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == FnKind::Constructor {
1528        return cons;
1529    }
1530    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
1531        if let Some(impl_id) = assoc_item.impl_container(tcx) {
1532            // get struct ty
1533            let ty = tcx.type_of(impl_id).skip_binder();
1534            if let Some(adt_def) = ty.ty_adt_def() {
1535                let adt_def_id = adt_def.did();
1536                let impls = tcx.inherent_impls(adt_def_id);
1537                for impl_def_id in impls {
1538                    for item in tcx.associated_item_def_ids(impl_def_id) {
1539                        if (tcx.def_kind(item) == DefKind::Fn
1540                            || tcx.def_kind(item) == DefKind::AssocFn)
1541                            && get_type(tcx, *item) == FnKind::Constructor
1542                        {
1543                            cons.push(*item);
1544                        }
1545                    }
1546                }
1547            }
1548        }
1549    }
1550    cons
1551}
1552
1553pub fn append_fn_with_types(tcx: TyCtxt, def_id: DefId) -> FnInfo {
1554    FnInfo::new(def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
1555}
1556pub fn search_constructor(tcx: TyCtxt, def_id: DefId) -> Vec<DefId> {
1557    let mut constructors = Vec::new();
1558    if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
1559        if let Some(impl_id) = assoc_item.impl_container(tcx) {
1560            // get struct ty
1561            let ty = tcx.type_of(impl_id).skip_binder();
1562            if let Some(adt_def) = ty.ty_adt_def() {
1563                let adt_def_id = adt_def.did();
1564                let impl_vec = get_impls_for_struct(tcx, adt_def_id);
1565                for impl_id in impl_vec {
1566                    let associated_items = tcx.associated_items(impl_id);
1567                    for item in associated_items.in_definition_order() {
1568                        if let ty::AssocKind::Fn {
1569                            name: _,
1570                            has_self: _,
1571                        } = item.kind
1572                        {
1573                            let item_def_id = item.def_id;
1574                            if get_type(tcx, item_def_id) == FnKind::Constructor {
1575                                constructors.push(item_def_id);
1576                            }
1577                        }
1578                    }
1579                }
1580            }
1581        }
1582    }
1583    constructors
1584}
1585
1586pub fn get_ptr_deref_dummy_def_id(tcx: TyCtxt<'_>) -> Option<DefId> {
1587    tcx.hir_crate_items(()).free_items().find_map(|item_id| {
1588        let def_id = item_id.owner_id.to_def_id();
1589        let name = tcx.opt_item_name(def_id)?;
1590
1591        (name.as_str() == "__raw_ptr_deref_dummy").then_some(def_id)
1592    })
1593}