1use std::convert::Infallible;
49use std::mem;
50use std::sync::Arc;
51
52use rustc_index::{Idx, IndexVec};
53use thin_vec::ThinVec;
54use tracing::{debug, instrument};
55
56use crate::inherent::*;
57use crate::visit::{TypeVisitable, TypeVisitableExt as _};
58use crate::{self as ty, Interner, TypeFlags};
59
/// A value that can be rebuilt by a [`TypeFolder`] (infallibly) or by a
/// [`FallibleTypeFolder`].
///
/// Folding consumes `self` and produces a value of the same type, giving the
/// folder a chance to replace nested foldable components along the way. The
/// container impls in this module (tuples, `Option`, `Result`, `Arc`, `Box`,
/// `Vec`, ...) simply fold each component in order.
pub trait TypeFoldable<I: Interner>: TypeVisitable<I> + Clone {
    /// Fallible counterpart of [`TypeFoldable::fold_with`].
    ///
    /// The container impls in this module propagate the first error reported
    /// by `folder` with `?`.
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error>;

    /// Folds `self` with the infallible `folder`, returning the rebuilt value.
    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self;
}
99
/// Structural folding: folds the *contents* of `self` rather than handing
/// `self` as a whole to the folder.
///
/// The default `fold_binder`/`fold_ty`/`fold_const`/`fold_predicate`/
/// `fold_clauses` hooks of [`TypeFolder`] (and their `try_` counterparts on
/// [`FallibleTypeFolder`]) delegate to these methods. That is how a folder
/// recurses into a value it has decided not to replace directly.
pub trait TypeSuperFoldable<I: Interner>: TypeFoldable<I> {
    /// Fallible counterpart of [`TypeSuperFoldable::super_fold_with`].
    fn try_super_fold_with<F: FallibleTypeFolder<I>>(
        self,
        folder: &mut F,
    ) -> Result<Self, F::Error>;

    /// Folds each component of `self` with `folder` and rebuilds `self` from
    /// the results.
    fn super_fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self;
}
118
/// An infallible folder over binders, types, regions, consts, predicates and
/// clauses.
///
/// Every hook has a default, so implementors override only the ones they care
/// about. The defaults for binders, types, consts, predicates and clauses
/// recurse structurally via [`TypeSuperFoldable::super_fold_with`]. The
/// default for regions returns the region unchanged.
pub trait TypeFolder<I: Interner>: Sized {
    /// The interner associated with this folder.
    fn cx(&self) -> I;

    /// Folds a binder.
    ///
    /// The default folds its contents without any depth bookkeeping. Folders
    /// that track De Bruijn depth, such as `Shifter` and `RegionFolder` in this
    /// module, override it to shift their `current_index` in before recursing
    /// and back out afterwards.
    fn fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T>
    where
        T: TypeFoldable<I>,
    {
        t.super_fold_with(self)
    }

    /// Folds a type. The default recurses into the type's components.
    fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
        t.super_fold_with(self)
    }

    /// Folds a region.
    ///
    /// The default is the identity and, unlike the other hooks, does not call
    /// `super_fold_with`. Presumably this is because a region has no foldable
    /// sub-components.
    fn fold_region(&mut self, r: I::Region) -> I::Region {
        r
    }

    /// Folds a const. The default recurses into the const's components.
    fn fold_const(&mut self, c: I::Const) -> I::Const {
        c.super_fold_with(self)
    }

    /// Folds a predicate. The default recurses into its components.
    fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
        p.super_fold_with(self)
    }

    /// Folds a list of clauses. The default recurses into its components.
    fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
        c.super_fold_with(self)
    }
}
160
/// Fallible counterpart of [`TypeFolder`]: every hook returns
/// `Result<_, Self::Error>`.
///
/// The defaults mirror `TypeFolder`'s:
/// - binders, types, consts, predicates and clauses recurse via
///   [`TypeSuperFoldable::try_super_fold_with`];
/// - regions are returned unchanged wrapped in `Ok`.
pub trait FallibleTypeFolder<I: Interner>: Sized {
    /// The error type a folding step may produce.
    type Error;

    /// The interner associated with this folder.
    fn cx(&self) -> I;

    /// Folds a binder. The default folds its contents without depth
    /// bookkeeping.
    fn try_fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> Result<ty::Binder<I, T>, Self::Error>
    where
        T: TypeFoldable<I>,
    {
        t.try_super_fold_with(self)
    }

    /// Folds a type. The default recurses into the type's components.
    fn try_fold_ty(&mut self, t: I::Ty) -> Result<I::Ty, Self::Error> {
        t.try_super_fold_with(self)
    }

    /// Folds a region. The default never fails and returns the region
    /// unchanged, without recursing.
    fn try_fold_region(&mut self, r: I::Region) -> Result<I::Region, Self::Error> {
        Ok(r)
    }

    /// Folds a const. The default recurses into the const's components.
    fn try_fold_const(&mut self, c: I::Const) -> Result<I::Const, Self::Error> {
        c.try_super_fold_with(self)
    }

    /// Folds a predicate. The default recurses into its components.
    fn try_fold_predicate(&mut self, p: I::Predicate) -> Result<I::Predicate, Self::Error> {
        p.try_super_fold_with(self)
    }

    /// Folds a list of clauses. The default recurses into its components.
    fn try_fold_clauses(&mut self, c: I::Clauses) -> Result<I::Clauses, Self::Error> {
        c.try_super_fold_with(self)
    }
}
202
203impl<I: Interner, T: TypeFoldable<I>, U: TypeFoldable<I>> TypeFoldable<I> for (T, U) {
207 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<(T, U), F::Error> {
208 Ok((self.0.try_fold_with(folder)?, self.1.try_fold_with(folder)?))
209 }
210
211 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
212 (self.0.fold_with(folder), self.1.fold_with(folder))
213 }
214}
215
216impl<I: Interner, A: TypeFoldable<I>, B: TypeFoldable<I>, C: TypeFoldable<I>> TypeFoldable<I>
217 for (A, B, C)
218{
219 fn try_fold_with<F: FallibleTypeFolder<I>>(
220 self,
221 folder: &mut F,
222 ) -> Result<(A, B, C), F::Error> {
223 Ok((
224 self.0.try_fold_with(folder)?,
225 self.1.try_fold_with(folder)?,
226 self.2.try_fold_with(folder)?,
227 ))
228 }
229
230 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
231 (self.0.fold_with(folder), self.1.fold_with(folder), self.2.fold_with(folder))
232 }
233}
234
235impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Option<T> {
236 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
237 Ok(match self {
238 Some(v) => Some(v.try_fold_with(folder)?),
239 None => None,
240 })
241 }
242
243 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
244 Some(self?.fold_with(folder))
245 }
246}
247
248impl<I: Interner, T: TypeFoldable<I>, E: TypeFoldable<I>> TypeFoldable<I> for Result<T, E> {
249 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
250 Ok(match self {
251 Ok(v) => Ok(v.try_fold_with(folder)?),
252 Err(e) => Err(e.try_fold_with(folder)?),
253 })
254 }
255
256 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
257 match self {
258 Ok(v) => Ok(v.fold_with(folder)),
259 Err(e) => Err(e.fold_with(folder)),
260 }
261 }
262}
263
/// Applies `fold` to the value inside `arc`, reusing the existing allocation
/// whenever possible.
///
/// `Arc::make_mut` clones `T` into a fresh allocation only if the `Arc` is not
/// already unique. After that, the value is moved out, folded, and moved back
/// into the same allocation. If `fold` returns `Err`, the allocation is freed
/// and the error is returned.
fn fold_arc<T: Clone, E>(
    mut arc: Arc<T>,
    fold: impl FnOnce(T) -> Result<T, E>,
) -> Result<Arc<T>, E> {
    unsafe {
        // Make `arc` the only strong reference with no weak references
        // attached, cloning `T` if needed. This is done while it is still a
        // plain `Arc<T>`, so a panic inside `T::clone` drops everything
        // normally and leaks nothing.
        Arc::make_mut(&mut arc);

        // SAFETY: `ManuallyDrop<T>` is `repr(transparent)` over `T`, so it
        // has the same size and alignment. That satisfies `Arc::from_raw`'s
        // requirement for a pointer obtained from an `Arc<T>`.
        let ptr = Arc::into_raw(arc).cast::<mem::ManuallyDrop<T>>();
        let mut unique = Arc::from_raw(ptr);

        // SAFETY: `make_mut` above left no other `Arc` or `Weak` pointers to
        // this allocation, so `get_mut` cannot return `None`.
        let slot = Arc::get_mut(&mut unique).unwrap_unchecked();

        // SAFETY: the value is moved out exactly once and `slot` is
        // overwritten before it is read again. On the error path (`?`), or if
        // `fold` panics, `unique` is dropped as an `Arc<ManuallyDrop<T>>`.
        // That frees the allocation without dropping the moved-out value a
        // second time.
        let owned = mem::ManuallyDrop::take(slot);
        let folded = fold(owned)?;
        *slot = mem::ManuallyDrop::new(folded);

        // SAFETY: same layout argument as above, in the reverse direction
        // (`ManuallyDrop<T>` -> `T`).
        Ok(Arc::from_raw(Arc::into_raw(unique).cast()))
    }
}
301
302impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Arc<T> {
303 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
304 fold_arc(self, |t| t.try_fold_with(folder))
305 }
306
307 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
308 match fold_arc::<T, Infallible>(self, |t| Ok(t.fold_with(folder))) {
309 Ok(t) => t,
310 }
311 }
312}
313
314impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Box<T> {
315 fn try_fold_with<F: FallibleTypeFolder<I>>(mut self, folder: &mut F) -> Result<Self, F::Error> {
316 *self = (*self).try_fold_with(folder)?;
317 Ok(self)
318 }
319
320 fn fold_with<F: TypeFolder<I>>(mut self, folder: &mut F) -> Self {
321 *self = (*self).fold_with(folder);
322 self
323 }
324}
325
326impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Vec<T> {
327 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
328 self.into_iter().map(|t| t.try_fold_with(folder)).collect()
329 }
330
331 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
332 self.into_iter().map(|t| t.fold_with(folder)).collect()
333 }
334}
335
336impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for ThinVec<T> {
337 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
338 self.into_iter().map(|t| t.try_fold_with(folder)).collect()
339 }
340
341 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
342 self.into_iter().map(|t| t.fold_with(folder)).collect()
343 }
344}
345
346impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Box<[T]> {
347 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
348 Vec::from(self).try_fold_with(folder).map(Vec::into_boxed_slice)
349 }
350
351 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
352 Vec::into_boxed_slice(Vec::from(self).fold_with(folder))
353 }
354}
355
356impl<I: Interner, T: TypeFoldable<I>, Ix: Idx> TypeFoldable<I> for IndexVec<Ix, T> {
357 fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
358 self.raw.try_fold_with(folder).map(IndexVec::from_raw)
359 }
360
361 fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
362 IndexVec::from_raw(self.raw.fold_with(folder))
363 }
364}
365
/// Folder used by [`shift_vars`].
///
/// It shifts every bound type, region and const whose De Bruijn index is at or
/// above `current_index` outward by `amount` binder levels. These are the
/// variables not bound by a binder entered during the fold.
struct Shifter<I: Interner> {
    cx: I,
    /// Binder depth at the current fold position. It starts at
    /// `ty::INNERMOST` and is shifted in by one for each binder entered.
    current_index: ty::DebruijnIndex,
    /// Number of binder levels to shift qualifying bound variables by.
    amount: u32,
}
380
381impl<I: Interner> Shifter<I> {
382 fn new(cx: I, amount: u32) -> Self {
383 Shifter { cx, current_index: ty::INNERMOST, amount }
384 }
385}
386
387impl<I: Interner> TypeFolder<I> for Shifter<I> {
388 fn cx(&self) -> I {
389 self.cx
390 }
391
392 fn fold_binder<T: TypeFoldable<I>>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T> {
393 self.current_index.shift_in(1);
394 let t = t.super_fold_with(self);
395 self.current_index.shift_out(1);
396 t
397 }
398
399 fn fold_region(&mut self, r: I::Region) -> I::Region {
400 match r.kind() {
401 ty::ReBound(debruijn, br) if debruijn >= self.current_index => {
402 let debruijn = debruijn.shifted_in(self.amount);
403 Region::new_bound(self.cx, debruijn, br)
404 }
405 _ => r,
406 }
407 }
408
409 fn fold_ty(&mut self, ty: I::Ty) -> I::Ty {
410 match ty.kind() {
411 ty::Bound(debruijn, bound_ty) if debruijn >= self.current_index => {
412 let debruijn = debruijn.shifted_in(self.amount);
413 Ty::new_bound(self.cx, debruijn, bound_ty)
414 }
415
416 _ if ty.has_vars_bound_at_or_above(self.current_index) => ty.super_fold_with(self),
417 _ => ty,
418 }
419 }
420
421 fn fold_const(&mut self, ct: I::Const) -> I::Const {
422 match ct.kind() {
423 ty::ConstKind::Bound(debruijn, bound_ct) if debruijn >= self.current_index => {
424 let debruijn = debruijn.shifted_in(self.amount);
425 Const::new_bound(self.cx, debruijn, bound_ct)
426 }
427 _ => ct.super_fold_with(self),
428 }
429 }
430
431 fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
432 if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
433 }
434}
435
436pub fn shift_region<I: Interner>(cx: I, region: I::Region, amount: u32) -> I::Region {
437 match region.kind() {
438 ty::ReBound(debruijn, br) if amount > 0 => {
439 Region::new_bound(cx, debruijn.shifted_in(amount), br)
440 }
441 _ => region,
442 }
443}
444
445#[instrument(level = "trace", skip(cx), ret)]
446pub fn shift_vars<I: Interner, T>(cx: I, value: T, amount: u32) -> T
447where
448 T: TypeFoldable<I>,
449{
450 if amount == 0 || !value.has_escaping_bound_vars() {
451 value
452 } else {
453 value.fold_with(&mut Shifter::new(cx, amount))
454 }
455}
456
457pub fn fold_regions<I: Interner, T>(
461 cx: I,
462 value: T,
463 f: impl FnMut(I::Region, ty::DebruijnIndex) -> I::Region,
464) -> T
465where
466 T: TypeFoldable<I>,
467{
468 value.fold_with(&mut RegionFolder::new(cx, f))
469}
470
/// Folder that applies a closure to every region in a value except those bound
/// by a binder inside that value.
///
/// It is constructed by [`fold_regions`].
pub struct RegionFolder<I, F> {
    cx: I,

    /// Number of binders entered so far, starting at `ty::INNERMOST`. A
    /// `ReBound` region whose index is below this is bound within the value
    /// being folded and is left untouched.
    current_index: ty::DebruijnIndex,

    /// Called with each region to fold and the current binder depth. Its
    /// return value replaces the region.
    fold_region_fn: F,
}
491
492impl<I, F> RegionFolder<I, F> {
493 #[inline]
494 pub fn new(cx: I, fold_region_fn: F) -> RegionFolder<I, F> {
495 RegionFolder { cx, current_index: ty::INNERMOST, fold_region_fn }
496 }
497}
498
499impl<I, F> TypeFolder<I> for RegionFolder<I, F>
500where
501 I: Interner,
502 F: FnMut(I::Region, ty::DebruijnIndex) -> I::Region,
503{
504 fn cx(&self) -> I {
505 self.cx
506 }
507
508 fn fold_binder<T: TypeFoldable<I>>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T> {
509 self.current_index.shift_in(1);
510 let t = t.super_fold_with(self);
511 self.current_index.shift_out(1);
512 t
513 }
514
515 #[instrument(skip(self), level = "debug", ret)]
516 fn fold_region(&mut self, r: I::Region) -> I::Region {
517 match r.kind() {
518 ty::ReBound(debruijn, _) if debruijn < self.current_index => {
519 debug!(?self.current_index, "skipped bound region");
520 r
521 }
522 _ => {
523 debug!(?self.current_index, "folding free region");
524 (self.fold_region_fn)(r, self.current_index)
525 }
526 }
527 }
528
529 fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
530 if t.has_type_flags(
531 TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
532 ) {
533 t.super_fold_with(self)
534 } else {
535 t
536 }
537 }
538
539 fn fold_const(&mut self, ct: I::Const) -> I::Const {
540 if ct.has_type_flags(
541 TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
542 ) {
543 ct.super_fold_with(self)
544 } else {
545 ct
546 }
547 }
548
549 fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
550 if p.has_type_flags(
551 TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED,
552 ) {
553 p.super_fold_with(self)
554 } else {
555 p
556 }
557 }
558}