// rapx/analysis/core/api_dependency/graph/ty_wrapper.rs

1use std::hash::Hash;
2use std::ops::Deref;
3
4use super::transform::TransformKind;
5use rustc_infer::infer::TyCtxtInferExt;
6use rustc_infer::traits::{Obligation, ObligationCause};
7use rustc_middle::traits;
8use rustc_middle::ty::{self, Ty, TyCtxt};
9use rustc_trait_selection::infer::InferCtxtExt;
10use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
11
12/// TyWrapper is a wrapper of rustc_middle::ty::Ty
13#[derive(Clone, Copy, Eq, Debug)]
14pub struct TyWrapper<'tcx> {
15    ty: Ty<'tcx>,
16}
17
18impl<'tcx> TyWrapper<'tcx> {
19    pub fn ty(&self) -> Ty<'tcx> {
20        self.ty
21    }
22
23    pub fn into_ref(&self, tcx: TyCtxt<'tcx>) -> TyWrapper<'tcx> {
24        Ty::new_ref(tcx, tcx.lifetimes.re_erased, self.ty, ty::Mutability::Not).into()
25    }
26
27    pub fn into_ref_mut(&self, tcx: TyCtxt<'tcx>) -> TyWrapper<'tcx> {
28        Ty::new_ref(tcx, tcx.lifetimes.re_erased, self.ty, ty::Mutability::Mut).into()
29    }
30
31    pub fn transform(&self, kind: TransformKind, tcx: TyCtxt<'tcx>) -> TyWrapper<'tcx> {
32        match kind {
33            TransformKind::Ref(mutability) => {
34                let ty = match mutability {
35                    ty::Mutability::Not => self.into_ref(tcx),
36                    ty::Mutability::Mut => self.into_ref_mut(tcx),
37                };
38                ty
39            }
40            _ => {
41                todo!();
42            }
43        }
44    }
45}
46
47impl<'tcx> From<Ty<'tcx>> for TyWrapper<'tcx> {
48    fn from(ty: ty::Ty<'tcx>) -> TyWrapper<'tcx> {
49        TyWrapper { ty }
50    }
51}
52
53impl<'tcx> Into<Ty<'tcx>> for TyWrapper<'tcx> {
54    fn into(self) -> Ty<'tcx> {
55        self.ty
56    }
57}
58
59impl PartialEq for TyWrapper<'_> {
60    fn eq(&self, other: &Self) -> bool {
61        eq_ty(self.ty, other.ty)
62    }
63}
64
65impl Hash for TyWrapper<'_> {
66    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
67        hash_ty(self.ty, state, &mut 0);
68    }
69}
70
71fn eq_ty<'tcx>(lhs: Ty<'tcx>, rhs: Ty<'tcx>) -> bool {
72    match (lhs.kind(), rhs.kind()) {
73        (ty::TyKind::Adt(adt_def1, generic_arg1), ty::TyKind::Adt(adt_def2, generic_arg2)) => {
74            if adt_def1.did() != adt_def2.did() {
75                return false;
76            }
77            for (arg1, arg2) in generic_arg1.iter().zip(generic_arg2.iter()) {
78                match (arg1.kind(), arg2.kind()) {
79                    (ty::GenericArgKind::Lifetime(_), ty::GenericArgKind::Lifetime(_)) => continue,
80                    (ty::GenericArgKind::Type(ty1), ty::GenericArgKind::Type(ty2)) => {
81                        if !eq_ty(ty1, ty2) {
82                            return false;
83                        }
84                    }
85                    (ty::GenericArgKind::Const(ct1), ty::GenericArgKind::Const(ct2)) => {
86                        if ct1 != ct2 {
87                            return false;
88                        }
89                    }
90                    _ => return false,
91                }
92            }
93            true
94        }
95        (
96            ty::TyKind::RawPtr(inner_ty1, mutability1),
97            ty::TyKind::RawPtr(inner_ty2, mutability2),
98        )
99        | (
100            ty::TyKind::Ref(_, inner_ty1, mutability1),
101            ty::TyKind::Ref(_, inner_ty2, mutability2),
102        ) => mutability1 == mutability2 && eq_ty(*inner_ty1, *inner_ty2),
103        (ty::TyKind::Array(inner_ty1, _), ty::TyKind::Array(inner_ty2, _))
104        | (ty::TyKind::Pat(inner_ty1, _), ty::TyKind::Pat(inner_ty2, _))
105        | (ty::TyKind::Slice(inner_ty1), ty::TyKind::Slice(inner_ty2)) => {
106            eq_ty(*inner_ty1, *inner_ty2)
107        }
108        (ty::TyKind::Tuple(tys1), ty::TyKind::Tuple(tys2)) => {
109            if tys1.len() != tys2.len() {
110                return false;
111            }
112            tys1.iter()
113                .zip(tys2.iter())
114                .all(|(ty1, ty2)| eq_ty(ty1, ty2))
115        }
116        _ => lhs == rhs,
117    }
118}
119
120fn traverse_ty_with_lifetime<'tcx, F: Fn(ty::Region, usize)>(ty: Ty<'tcx>, no: &mut usize, f: &F) {
121    match ty.kind() {
122        ty::TyKind::Adt(adt_def, generic_arg) => {
123            for arg in generic_arg.iter() {
124                match arg.kind() {
125                    ty::GenericArgKind::Lifetime(lt) => {
126                        *no = *no + 1;
127                        f(lt, *no);
128                    }
129                    ty::GenericArgKind::Type(ty) => {
130                        traverse_ty_with_lifetime(ty, no, f);
131                    }
132                    ty::GenericArgKind::Const(ct) => {}
133                }
134            }
135        }
136
137        ty::TyKind::RawPtr(inner_ty, mutability) => {
138            traverse_ty_with_lifetime(*inner_ty, no, f);
139        }
140
141        ty::TyKind::Ref(region, inner_ty, mutability) => {
142            *no = *no + 1;
143            f(*region, *no);
144            traverse_ty_with_lifetime(*inner_ty, no, f);
145        }
146        ty::TyKind::Array(inner_ty, _)
147        | ty::TyKind::Pat(inner_ty, _)
148        | ty::TyKind::Slice(inner_ty) => {
149            traverse_ty_with_lifetime(*inner_ty, no, f);
150        }
151        ty::TyKind::Tuple(tys) => {
152            for inner_ty in tys.iter() {
153                traverse_ty_with_lifetime(inner_ty, no, f);
154            }
155        }
156        _ => {
157            unreachable!("unexpected ty kind");
158        }
159    }
160}
161
162// hashing Ty<'tcx>, but ignore the difference of lifetimes
163fn hash_ty<'tcx, H: std::hash::Hasher>(ty: Ty<'tcx>, state: &mut H, no: &mut usize) {
164    std::mem::discriminant(ty.kind()).hash(state);
165
166    // hash the content
167    match ty.kind() {
168        ty::TyKind::Adt(adt_def, generic_arg) => {
169            adt_def.did().hash(state);
170            for arg in generic_arg.iter() {
171                match arg.kind() {
172                    ty::GenericArgKind::Lifetime(lt) => {
173                        *no = *no + 1;
174                        no.hash(state);
175                    }
176                    ty::GenericArgKind::Type(ty) => {
177                        hash_ty(ty, state, no);
178                    }
179                    ty::GenericArgKind::Const(ct) => {
180                        ct.hash(state);
181                    }
182                }
183            }
184        }
185
186        ty::TyKind::RawPtr(inner_ty, mutability) => {
187            mutability.hash(state);
188            hash_ty(*inner_ty, state, no);
189        }
190        ty::TyKind::Ref(_, inner_ty, mutability) => {
191            mutability.hash(state);
192            *no = *no + 1;
193            no.hash(state);
194            hash_ty(*inner_ty, state, no);
195        }
196        ty::TyKind::Array(inner_ty, _) | ty::TyKind::Slice(inner_ty) => {
197            hash_ty(*inner_ty, state, no);
198        }
199        ty::TyKind::Tuple(tys) => {
200            for inner_ty in tys.iter() {
201                hash_ty(inner_ty, state, no);
202            }
203        }
204        _ => {
205            ty.hash(state);
206        }
207    }
208}
209
210pub fn desc_ty_str<'tcx>(ty: Ty<'tcx>, no: &mut usize, tcx: TyCtxt<'tcx>) -> String {
211    match ty.kind() {
212        ty::TyKind::Adt(adt_def, generic_arg) => {
213            let mut ty_str = tcx.def_path_str(adt_def.did());
214            if !generic_arg.is_empty() {
215                ty_str += "<";
216                ty_str += &generic_arg
217                    .iter()
218                    .map(|arg| match arg.kind() {
219                        ty::GenericArgKind::Lifetime(lt) => {
220                            let current_no = *no;
221                            *no = *no + 1;
222                            format!("'#{:?}", current_no)
223                        }
224                        ty::GenericArgKind::Type(ty) => desc_ty_str(ty, no, tcx),
225                        ty::GenericArgKind::Const(ct) => format!("{:?}", ct),
226                    })
227                    .collect::<Vec<String>>()
228                    .join(", ");
229                ty_str += ">";
230            }
231            ty_str
232        }
233
234        ty::TyKind::RawPtr(inner_ty, mutability) => {
235            format!(
236                "*{} {}",
237                mutability.ptr_str(),
238                desc_ty_str(*inner_ty, no, tcx)
239            )
240        }
241        ty::TyKind::Ref(_, inner_ty, mutability) => {
242            let current_no = *no;
243            *no = *no + 1;
244            format!(
245                "&'#{} {}{}",
246                current_no,
247                mutability.prefix_str(),
248                desc_ty_str(*inner_ty, no, tcx)
249            )
250        }
251        ty::TyKind::Array(inner_ty, len) => {
252            format!("[{};{}]", desc_ty_str(*inner_ty, no, tcx), len)
253        }
254
255        ty::TyKind::Slice(inner_ty) => {
256            format!("[{}]", desc_ty_str(*inner_ty, no, tcx))
257        }
258        ty::TyKind::Tuple(tys) => format!(
259            "({})",
260            tys.iter()
261                .map(|ty| desc_ty_str(ty, no, tcx,))
262                .collect::<Vec<String>>()
263                .join(", "),
264        ),
265        ty::TyKind::Pat(inner_ty, _) => {
266            unreachable!();
267        }
268        _ => format!("{:?}", ty),
269    }
270}
271
272impl<'tcx> TyWrapper<'tcx> {
273    pub fn desc_str(&self, tcx: TyCtxt<'tcx>) -> String {
274        desc_ty_str(self.ty, &mut 0, tcx)
275    }
276}