// rapx/analysis/senryx/generic_check.rs

use std::collections::{HashMap, HashSet};

use if_chain::if_chain;
use rustc_hir::{def_id::DefId, hir_id::OwnerId, ImplPolarity, ItemId, ItemKind};
use rustc_middle::ty::{FloatTy, IntTy, ParamEnv, Ty, TyCtxt, TyKind, UintTy};
// use crate::rap_info;

8pub struct GenericChecker<'tcx> {
9    // tcx: TyCtxt<'tcx>,
10    trait_map: HashMap<String, HashSet<Ty<'tcx>>>,
11}
12
13impl<'tcx> GenericChecker<'tcx> {
14    pub fn new(tcx: TyCtxt<'tcx>, p_env: ParamEnv<'tcx>) -> Self {
15        let mut trait_bnd_map_for_generic: HashMap<String, HashSet<String>> = HashMap::new();
16        let mut satisfied_ty_map_for_generic: HashMap<String, HashSet<Ty<'tcx>>> = HashMap::new();
17
18        for cb in p_env.caller_bounds() {
19            // cb: Binder(TraitPredicate(<Self as trait>, ..)
20            // Focus on the trait bound applied to our generic parameter
21
22            if let Some(trait_pred) = cb.as_trait_clause() {
23                let trait_def_id = trait_pred.def_id();
24                let generic_name = trait_pred.self_ty().skip_binder().to_string();
25                let satisfied_ty_set = satisfied_ty_map_for_generic
26                    .entry(generic_name.clone())
27                    .or_insert_with(|| HashSet::new());
28                let trait_name = tcx.def_path_str(trait_def_id);
29                let trait_bnd_set = trait_bnd_map_for_generic
30                    .entry(generic_name)
31                    .or_insert_with(|| HashSet::new());
32                trait_bnd_set.insert(trait_name.clone());
33
34                // for each implementation
35                for def_id in tcx.all_impls(trait_def_id) {
36                    // impl_id: LocalDefId
37                    if !def_id.is_local() {
38                        continue;
39                    }
40                    let impl_owner_id = tcx
41                        .hir_owner_node(OwnerId {
42                            def_id: def_id.expect_local(),
43                        })
44                        .def_id();
45
46                    let item = tcx.hir_item(ItemId {
47                        owner_id: impl_owner_id,
48                    });
49                    if_chain! {
50                        if let ItemKind::Impl(impl_item) = item.kind;
51                        if let Some(trait_impl_header) = impl_item.of_trait;
52                        if trait_impl_header.polarity == ImplPolarity::Positive;
53                        if let Some(binder) = tcx.impl_trait_ref(def_id);
54                        then {
55                            let trait_ref = binder.skip_binder();
56                            let impl_ty = trait_ref.self_ty();
57                            match impl_ty.kind() {
58                                TyKind::Adt(adt_def, _impl_trait_substs) => {
59                                    let adt_did = adt_def.did();
60                                    let adt_ty = tcx.type_of(adt_did).skip_binder();
61                                    // rap_info!("{} is implemented on adt({:?})", trait_name, adt_ty);
62                                    satisfied_ty_set.insert(adt_ty);
63                                },
64                                TyKind::Param(p_ty) => {
65                                    let _param_ty = p_ty.to_ty(tcx);
66                                },
67                                _ => {
68                                    // rap_info!("{} is implemented on {:?}", trait_name, impl_ty);
69                                    satisfied_ty_set.insert(impl_ty);
70                                },
71                            }
72                        }
73                    }
74                }
75
76                // handle known external trait e.g., Pod
77                if trait_name == "bytemuck::Pod" || trait_name == "plain::Plain" {
78                    let ty_bnd = Self::get_satisfied_ty_for_pod(tcx);
79                    satisfied_ty_set.extend(&ty_bnd);
80                    // rap_info!("current trait bound type set: {:?}", satisfied_ty_set);
81                }
82            }
83        }
84
85        // check trait_bnd_set
86        let std_trait_set = HashSet::from([
87            String::from("std::marker::Copy"),
88            String::from("std::clone::Clone"),
89            String::from("std::marker::Sized"),
90        ]);
91        // if all trait_bound is std::marker, then we could assume it to be arbitrary type
92        // to avoid messing up with build type manually
93        // we just clear the satisfied ty set
94        for (key, satisfied_ty_set) in &mut satisfied_ty_map_for_generic {
95            let trait_bnd_set = trait_bnd_map_for_generic
96                .entry(key.clone())
97                .or_insert_with(|| HashSet::new());
98            if trait_bnd_set.is_subset(&std_trait_set) {
99                satisfied_ty_set.clear();
100            }
101        }
102
103        // rap_info!("trait bound type map: {:?}", satisfied_ty_map_for_generic);
104
105        GenericChecker {
106            trait_map: satisfied_ty_map_for_generic,
107        }
108    }
109
110    pub fn get_satisfied_ty_map(&self) -> HashMap<String, HashSet<Ty<'tcx>>> {
111        self.trait_map.clone()
112    }
113
114    fn get_satisfied_ty_for_pod(tcx: TyCtxt<'tcx>) -> HashSet<Ty<'tcx>> {
115        let mut satisfied_ty_set_for_pod: HashSet<Ty<'tcx>> = HashSet::new();
116        // f64, u64, i8, i32, u8, i16, u16, u32, usize, i128, isize, i64, u128, f32
117        let pod_ty = [
118            tcx.mk_ty_from_kind(TyKind::Int(IntTy::Isize)),
119            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I8)),
120            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I16)),
121            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I32)),
122            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I64)),
123            tcx.mk_ty_from_kind(TyKind::Int(IntTy::I128)),
124            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::Usize)),
125            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U8)),
126            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U16)),
127            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U32)),
128            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U64)),
129            tcx.mk_ty_from_kind(TyKind::Uint(UintTy::U128)),
130            tcx.mk_ty_from_kind(TyKind::Float(FloatTy::F32)),
131            tcx.mk_ty_from_kind(TyKind::Float(FloatTy::F64)),
132        ];
133
134        for pt in pod_ty.iter() {
135            satisfied_ty_set_for_pod.insert(*pt);
136        }
137        satisfied_ty_set_for_pod.clone()
138    }
139}