rapx/analysis/core/api_dep/graph/ty_wrapper.rs

1use std::hash::Hash;
2
3use super::lifetime::Lifetime;
4use rustc_middle::ty::{self, Ty, TyCtxt};
5
6/// TyWrapper is a wrapper of rustc_middle::ty::Ty,
7/// for the purpose of attaching custom lifetime information to Ty.
8#[derive(Clone, Eq, Debug)]
9pub struct TyWrapper<'tcx> {
10    ty: ty::Ty<'tcx>,
11    lifetime_map: Vec<Lifetime>,
12}
13
14impl<'tcx> From<Ty<'tcx>> for TyWrapper<'tcx> {
15    fn from(ty: ty::Ty<'tcx>) -> TyWrapper<'tcx> {
16        // TODO: initialize lifetime_map
17        let lifetime_map = Vec::new();
18        TyWrapper { ty, lifetime_map }
19    }
20}
21
22impl PartialEq for TyWrapper<'_> {
23    fn eq(&self, other: &Self) -> bool {
24        eq_ty(self.ty, other.ty)
25    }
26}
27
28impl Hash for TyWrapper<'_> {
29    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
30        hash_ty(self.ty, state, &mut 0);
31    }
32}
33
34fn eq_ty<'tcx>(lhs: Ty<'tcx>, rhs: Ty<'tcx>) -> bool {
35    match (lhs.kind(), rhs.kind()) {
36        (ty::TyKind::Adt(adt_def1, generic_arg1), ty::TyKind::Adt(adt_def2, generic_arg2)) => {
37            if adt_def1.did() != adt_def2.did() {
38                return false;
39            }
40            for (arg1, arg2) in generic_arg1.iter().zip(generic_arg2.iter()) {
41                match (arg1.kind(), arg2.kind()) {
42                    (ty::GenericArgKind::Lifetime(_), ty::GenericArgKind::Lifetime(_)) => continue,
43                    (ty::GenericArgKind::Type(ty1), ty::GenericArgKind::Type(ty2)) => {
44                        if !eq_ty(ty1, ty2) {
45                            return false;
46                        }
47                    }
48                    (ty::GenericArgKind::Const(ct1), ty::GenericArgKind::Const(ct2)) => {
49                        if ct1 != ct2 {
50                            return false;
51                        }
52                    }
53                    _ => return false,
54                }
55            }
56            true
57        }
58        (
59            ty::TyKind::RawPtr(inner_ty1, mutability1),
60            ty::TyKind::RawPtr(inner_ty2, mutability2),
61        )
62        | (
63            ty::TyKind::Ref(_, inner_ty1, mutability1),
64            ty::TyKind::Ref(_, inner_ty2, mutability2),
65        ) => mutability1 == mutability2 && eq_ty(*inner_ty1, *inner_ty2),
66        (ty::TyKind::Array(inner_ty1, _), ty::TyKind::Array(inner_ty2, _))
67        | (ty::TyKind::Pat(inner_ty1, _), ty::TyKind::Pat(inner_ty2, _))
68        | (ty::TyKind::Slice(inner_ty1), ty::TyKind::Slice(inner_ty2)) => {
69            eq_ty(*inner_ty1, *inner_ty2)
70        }
71        (ty::TyKind::Tuple(tys1), ty::TyKind::Tuple(tys2)) => {
72            if tys1.len() != tys2.len() {
73                return false;
74            }
75            tys1.iter()
76                .zip(tys2.iter())
77                .all(|(ty1, ty2)| eq_ty(ty1, ty2))
78        }
79        _ => lhs == rhs,
80    }
81}
82
83fn traverse_ty_with_lifetime<'tcx, F: Fn(ty::Region, usize)>(ty: Ty<'tcx>, no: &mut usize, f: &F) {
84    match ty.kind() {
85        ty::TyKind::Adt(adt_def, generic_arg) => {
86            for arg in generic_arg.iter() {
87                match arg.kind() {
88                    ty::GenericArgKind::Lifetime(lt) => {
89                        *no = *no + 1;
90                        f(lt, *no);
91                    }
92                    ty::GenericArgKind::Type(ty) => {
93                        traverse_ty_with_lifetime(ty, no, f);
94                    }
95                    ty::GenericArgKind::Const(ct) => {}
96                }
97            }
98        }
99
100        ty::TyKind::RawPtr(inner_ty, mutability) => {
101            traverse_ty_with_lifetime(*inner_ty, no, f);
102        }
103
104        ty::TyKind::Ref(region, inner_ty, mutability) => {
105            *no = *no + 1;
106            f(*region, *no);
107            traverse_ty_with_lifetime(*inner_ty, no, f);
108        }
109        ty::TyKind::Array(inner_ty, _)
110        | ty::TyKind::Pat(inner_ty, _)
111        | ty::TyKind::Slice(inner_ty) => {
112            traverse_ty_with_lifetime(*inner_ty, no, f);
113        }
114        ty::TyKind::Tuple(tys) => {
115            for inner_ty in tys.iter() {
116                traverse_ty_with_lifetime(inner_ty, no, f);
117            }
118        }
119        _ => {
120            unreachable!("unexpected ty kind");
121        }
122    }
123}
124
125// hashing Ty<'tcx>, but ignore the difference of lifetimes
126fn hash_ty<'tcx, H: std::hash::Hasher>(ty: Ty<'tcx>, state: &mut H, no: &mut usize) {
127    // hash the variant
128    match ty.kind() {
129        ty::TyKind::Adt(..) => 0,
130        ty::TyKind::RawPtr(..) => 1,
131        ty::TyKind::Ref(..) => 2,
132        ty::TyKind::Array(..) => 3,
133        ty::TyKind::Pat(..) => 4,
134        ty::TyKind::Slice(..) => 5,
135        ty::TyKind::Tuple(..) => 6,
136        _ => {
137            ty.hash(state);
138            return;
139        }
140    }
141    .hash(state);
142
143    // hash the content
144    match ty.kind() {
145        ty::TyKind::Adt(adt_def, generic_arg) => {
146            adt_def.did().hash(state);
147            for arg in generic_arg.iter() {
148                match arg.kind() {
149                    ty::GenericArgKind::Lifetime(lt) => {
150                        *no = *no + 1;
151                        no.hash(state);
152                    }
153                    ty::GenericArgKind::Type(ty) => {
154                        hash_ty(ty, state, no);
155                    }
156                    ty::GenericArgKind::Const(ct) => {
157                        ct.hash(state);
158                    }
159                }
160            }
161        }
162
163        ty::TyKind::RawPtr(inner_ty, mutability) => {
164            mutability.hash(state);
165            hash_ty(*inner_ty, state, no);
166        }
167        ty::TyKind::Ref(_, inner_ty, mutability) => {
168            mutability.hash(state);
169            *no = *no + 1;
170            no.hash(state);
171            hash_ty(*inner_ty, state, no);
172        }
173        ty::TyKind::Array(inner_ty, _)
174        | ty::TyKind::Pat(inner_ty, _)
175        | ty::TyKind::Slice(inner_ty) => {
176            hash_ty(*inner_ty, state, no);
177        }
178        ty::TyKind::Tuple(tys) => {
179            for inner_ty in tys.iter() {
180                hash_ty(inner_ty, state, no);
181            }
182        }
183        _ => {
184            unreachable!("unexpected ty kind");
185        }
186    }
187}
188
189pub fn desc_ty_str<'tcx>(ty: Ty<'tcx>, no: &mut usize, tcx: TyCtxt<'tcx>) -> String {
190    match ty.kind() {
191        ty::TyKind::Adt(adt_def, generic_arg) => {
192            let mut ty_str = tcx.def_path_str(adt_def.did());
193            if !generic_arg.is_empty() {
194                ty_str += "<";
195                ty_str += &generic_arg
196                    .iter()
197                    .map(|arg| match arg.kind() {
198                        ty::GenericArgKind::Lifetime(lt) => {
199                            let current_no = *no;
200                            *no = *no + 1;
201                            format!("'#{:?}", current_no)
202                        }
203                        ty::GenericArgKind::Type(ty) => desc_ty_str(ty, no, tcx),
204                        ty::GenericArgKind::Const(ct) => format!("{:?}", ct),
205                    })
206                    .collect::<Vec<String>>()
207                    .join(", ");
208                ty_str += ">";
209            }
210            ty_str
211        }
212
213        ty::TyKind::RawPtr(inner_ty, mutability) => {
214            let mut_str = {
215                match mutability {
216                    ty::Mutability::Mut => "mut",
217                    ty::Mutability::Not => "const",
218                }
219            };
220            format!("*{} {}", mut_str, desc_ty_str(*inner_ty, no, tcx))
221        }
222        ty::TyKind::Ref(_, inner_ty, mutability) => {
223            let mut_str = {
224                match mutability {
225                    ty::Mutability::Mut => "mut",
226                    ty::Mutability::Not => "",
227                }
228            };
229            let current_no = *no;
230            *no = *no + 1;
231            format!(
232                "&'#{}{} {}",
233                current_no,
234                mut_str,
235                desc_ty_str(*inner_ty, no, tcx)
236            )
237        }
238        ty::TyKind::Array(inner_ty, _)
239        | ty::TyKind::Pat(inner_ty, _)
240        | ty::TyKind::Slice(inner_ty) => desc_ty_str(*inner_ty, no, tcx),
241        ty::TyKind::Tuple(tys) => format!(
242            "({})",
243            tys.iter()
244                .map(|ty| desc_ty_str(ty, no, tcx,))
245                .collect::<Vec<String>>()
246                .join(", "),
247        ),
248        _ => format!("{:?}", ty),
249    }
250}
251
252impl<'tcx> TyWrapper<'tcx> {
253    pub fn desc_str(&self, tcx: TyCtxt<'tcx>) -> String {
254        desc_ty_str(self.ty, &mut 0, tcx)
255    }
256}