rapx/analysis/senryx/generic_check.rs

1use std::collections::{HashMap, HashSet};
2
3use if_chain::if_chain;
4use rustc_hir::{hir_id::OwnerId, ImplPolarity, ItemId, ItemKind};
5use rustc_middle::ty::{FloatTy, IntTy, ParamEnv, Ty, TyCtxt, TyKind, UintTy};
6// use crate::rap_info;
7
8pub struct GenericChecker<'tcx> {
9    // tcx: TyCtxt<'tcx>,
10    trait_map: HashMap<String, HashSet<Ty<'tcx>>>,
11}
12
13impl<'tcx> GenericChecker<'tcx> {
14    pub fn new(tcx: TyCtxt<'tcx>, p_env: ParamEnv<'tcx>) -> Self {
15        let mut trait_bnd_map_for_generic: HashMap<String, HashSet<String>> = HashMap::new();
16        let mut satisfied_ty_map_for_generic: HashMap<String, HashSet<Ty<'tcx>>> = HashMap::new();
17
18        for cb in p_env.caller_bounds() {
19            // cb: Binder(TraitPredicate(<Self as trait>, ..)
20            // Focus on the trait bound applied to our generic parameter
21
22            if let Some(trait_pred) = cb.as_trait_clause() {
23                let trait_def_id = trait_pred.def_id();
24                let generic_name = trait_pred.self_ty().skip_binder().to_string();
25                let satisfied_ty_set = satisfied_ty_map_for_generic
26                    .entry(generic_name.clone())
27                    .or_insert_with(|| HashSet::new());
28                let trait_name = tcx.def_path_str(trait_def_id);
29                let trait_bnd_set = trait_bnd_map_for_generic
30                    .entry(generic_name)
31                    .or_insert_with(|| HashSet::new());
32                trait_bnd_set.insert(trait_name.clone());
33
34                // for each implementation
35                for def_id in tcx.all_impls(trait_def_id) {
36                    // impl_id: LocalDefId
37                    if !def_id.is_local() {
38                        continue;
39                    }
40                    let impl_owner_id = tcx
41                        .hir_owner_node(OwnerId {
42                            def_id: def_id.expect_local(),
43                        })
44                        .def_id();
45
46                    let item = tcx.hir_item(ItemId {
47                        owner_id: impl_owner_id,
48                    });
49                    if_chain! {
50                        if let ItemKind::Impl(impl_item) = item.kind;
51                        if impl_item.polarity == ImplPolarity::Positive;
52                        if let Some(binder) = tcx.impl_trait_ref(def_id);
53                        then {
54                            let trait_ref = binder.skip_binder();
55                            let impl_ty = trait_ref.self_ty();
56                            match impl_ty.kind() {
57                                TyKind::Adt(adt_def, _impl_trait_substs) => {
58                                    let adt_did = adt_def.did();
59                                    let adt_ty = tcx.type_of(adt_did).skip_binder();
60                                    // rap_info!("{} is implemented on adt({:?})", trait_name, adt_ty);
61                                    satisfied_ty_set.insert(adt_ty);
62                                },
63                                TyKind::Param(p_ty) => {
64                                    let _param_ty = p_ty.to_ty(tcx);
65                                },
66                                _ => {
67                                    // rap_info!("{} is implemented on {:?}", trait_name, impl_ty);
68                                    satisfied_ty_set.insert(impl_ty);
69                                },
70                            }
71                        }
72                    }
73                }
74
75                // handle known external trait e.g., Pod
76                if trait_name == "bytemuck::Pod" || trait_name == "plain::Plain" {
77                    let ty_bnd = Self::get_satisfied_ty_for_pod(tcx);
78                    satisfied_ty_set.extend(&ty_bnd);
79                    // rap_info!("current trait bound type set: {:?}", satisfied_ty_set);
80                }
81            }
82        }
83
84        // check trait_bnd_set
85        let std_trait_set = HashSet::from([
86            String::from("std::marker::Copy"),
87            String::from("std::clone::Clone"),
88            String::from("std::marker::Sized"),
89        ]);
90        // if all trait_bound is std::marker, then we could assume it to be arbitrary type
91        // to avoid messing up with build type manually
92        // we just clear the satisfied ty set
93        for (key, satisfied_ty_set) in &mut satisfied_ty_map_for_generic {
94            let trait_bnd_set = trait_bnd_map_for_generic
95                .entry(key.clone())
96                .or_insert_with(|| HashSet::new());
97            if trait_bnd_set.is_subset(&std_trait_set) {
98                satisfied_ty_set.clear();
99            }
100        }
101
102        // rap_info!("trait bound type map: {:?}", satisfied_ty_map_for_generic);
103
104        GenericChecker {
105            trait_map: satisfied_ty_map_for_generic,
106        }
107    }
108
109    pub fn get_satisfied_ty_map(&self) -> HashMap<String, HashSet<Ty<'tcx>>> {
110        self.trait_map.clone()
111    }
112
113    fn get_satisfied_ty_for_pod(tcx: TyCtxt<'tcx>) -> HashSet<Ty<'tcx>> {
114        let mut satisfied_ty_set_for_pod: HashSet<Ty<'tcx>> = HashSet::new();
115        // f64, u64, i8, i32, u8, i16, u16, u32, usize, i128, isize, i64, u128, f32
116        let pod_ty = [
117            tcx.mk_ty_from_kind(TyKind::Int(IntTy::Isize)),
118            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I8)),
119            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I16)),
120            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I32)),
121            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I64)),
122            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I128)),
123            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::Usize)),
124            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U8)),
125            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U16)),
126            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U32)),
127            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U64)),
128            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U128)),
129            tcx.mk_ty_from_kind(TyKind::Float(FloatTy::F32)),
130            tcx.mk_ty_from_kind(TyKind::Float(FloatTy::F64)),
131        ];
132
133        for pt in pod_ty.iter() {
134            satisfied_ty_set_for_pod.insert(*pt);
135        }
136        satisfied_ty_set_for_pod.clone()
137    }
138}