rustc_trait_selection/traits/
engine.rs1use std::cell::RefCell;
2use std::fmt::Debug;
3
4use rustc_data_structures::fx::FxIndexSet;
5use rustc_errors::ErrorGuaranteed;
6use rustc_hir::def_id::{DefId, LocalDefId};
7use rustc_infer::infer::at::ToTrace;
8use rustc_infer::infer::canonical::{
9 Canonical, CanonicalQueryResponse, CanonicalVarValues, QueryResponse,
10};
11use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk, RegionResolutionError, TypeTrace};
12use rustc_infer::traits::PredicateObligations;
13use rustc_macros::extension;
14use rustc_middle::arena::ArenaAllocatable;
15use rustc_middle::traits::query::NoSolution;
16use rustc_middle::ty::error::TypeError;
17use rustc_middle::ty::relate::Relate;
18use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, Upcast, Variance};
19
20use super::{FromSolverError, FulfillmentContext, ScrubbedTraitError, TraitEngine};
21use crate::error_reporting::InferCtxtErrorExt;
22use crate::regions::InferCtxtRegionExt;
23use crate::solve::{FulfillmentCtxt as NextFulfillmentCtxt, NextSolverError};
24use crate::traits::fulfill::OldSolverError;
25use crate::traits::{
26 FulfillmentError, NormalizeExt, Obligation, ObligationCause, PredicateObligation,
27 StructurallyNormalizeExt,
28};
29
30#[extension(pub trait TraitEngineExt<'tcx, E>)]
31impl<'tcx, E> dyn TraitEngine<'tcx, E>
32where
33 E: FromSolverError<'tcx, NextSolverError<'tcx>> + FromSolverError<'tcx, OldSolverError<'tcx>>,
34{
35 fn new(infcx: &InferCtxt<'tcx>) -> Box<Self> {
36 if infcx.next_trait_solver() {
37 Box::new(NextFulfillmentCtxt::new(infcx))
38 } else {
39 assert!(
40 !infcx.tcx.next_trait_solver_globally(),
41 "using old solver even though new solver is enabled globally"
42 );
43 Box::new(FulfillmentContext::new(infcx))
44 }
45 }
46}
47
48pub struct ObligationCtxt<'a, 'tcx, E = ScrubbedTraitError<'tcx>> {
51 pub infcx: &'a InferCtxt<'tcx>,
52 engine: RefCell<Box<dyn TraitEngine<'tcx, E>>>,
53}
54
55impl<'a, 'tcx> ObligationCtxt<'a, 'tcx, FulfillmentError<'tcx>> {
56 pub fn new_with_diagnostics(infcx: &'a InferCtxt<'tcx>) -> Self {
57 Self { infcx, engine: RefCell::new(<dyn TraitEngine<'tcx, _>>::new(infcx)) }
58 }
59}
60
61impl<'a, 'tcx> ObligationCtxt<'a, 'tcx, ScrubbedTraitError<'tcx>> {
62 pub fn new(infcx: &'a InferCtxt<'tcx>) -> Self {
63 Self { infcx, engine: RefCell::new(<dyn TraitEngine<'tcx, _>>::new(infcx)) }
64 }
65}
66
67impl<'a, 'tcx, E> ObligationCtxt<'a, 'tcx, E>
68where
69 E: 'tcx,
70{
71 pub fn register_obligation(&self, obligation: PredicateObligation<'tcx>) {
72 self.engine.borrow_mut().register_predicate_obligation(self.infcx, obligation);
73 }
74
75 pub fn register_obligations(
76 &self,
77 obligations: impl IntoIterator<Item = PredicateObligation<'tcx>>,
78 ) {
79 for obligation in obligations {
82 self.engine.borrow_mut().register_predicate_obligation(self.infcx, obligation)
83 }
84 }
85
86 pub fn register_infer_ok_obligations<T>(&self, infer_ok: InferOk<'tcx, T>) -> T {
87 let InferOk { value, obligations } = infer_ok;
88 self.engine.borrow_mut().register_predicate_obligations(self.infcx, obligations);
89 value
90 }
91
92 pub fn register_bound(
96 &self,
97 cause: ObligationCause<'tcx>,
98 param_env: ty::ParamEnv<'tcx>,
99 ty: Ty<'tcx>,
100 def_id: DefId,
101 ) {
102 let tcx = self.infcx.tcx;
103 let trait_ref = ty::TraitRef::new(tcx, def_id, [ty]);
104 self.register_obligation(Obligation {
105 cause,
106 recursion_depth: 0,
107 param_env,
108 predicate: trait_ref.upcast(tcx),
109 });
110 }
111
112 pub fn normalize<T: TypeFoldable<TyCtxt<'tcx>>>(
113 &self,
114 cause: &ObligationCause<'tcx>,
115 param_env: ty::ParamEnv<'tcx>,
116 value: T,
117 ) -> T {
118 let infer_ok = self.infcx.at(cause, param_env).normalize(value);
119 self.register_infer_ok_obligations(infer_ok)
120 }
121
122 pub fn eq<T: ToTrace<'tcx>>(
123 &self,
124 cause: &ObligationCause<'tcx>,
125 param_env: ty::ParamEnv<'tcx>,
126 expected: T,
127 actual: T,
128 ) -> Result<(), TypeError<'tcx>> {
129 self.infcx
130 .at(cause, param_env)
131 .eq(DefineOpaqueTypes::Yes, expected, actual)
132 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
133 }
134
135 pub fn eq_trace<T: Relate<TyCtxt<'tcx>>>(
136 &self,
137 cause: &ObligationCause<'tcx>,
138 param_env: ty::ParamEnv<'tcx>,
139 trace: TypeTrace<'tcx>,
140 expected: T,
141 actual: T,
142 ) -> Result<(), TypeError<'tcx>> {
143 self.infcx
144 .at(cause, param_env)
145 .eq_trace(DefineOpaqueTypes::Yes, trace, expected, actual)
146 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
147 }
148
149 pub fn sub<T: ToTrace<'tcx>>(
151 &self,
152 cause: &ObligationCause<'tcx>,
153 param_env: ty::ParamEnv<'tcx>,
154 expected: T,
155 actual: T,
156 ) -> Result<(), TypeError<'tcx>> {
157 self.infcx
158 .at(cause, param_env)
159 .sub(DefineOpaqueTypes::Yes, expected, actual)
160 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
161 }
162
163 pub fn relate<T: ToTrace<'tcx>>(
164 &self,
165 cause: &ObligationCause<'tcx>,
166 param_env: ty::ParamEnv<'tcx>,
167 variance: Variance,
168 expected: T,
169 actual: T,
170 ) -> Result<(), TypeError<'tcx>> {
171 self.infcx
172 .at(cause, param_env)
173 .relate(DefineOpaqueTypes::Yes, expected, variance, actual)
174 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
175 }
176
177 pub fn sup<T: ToTrace<'tcx>>(
179 &self,
180 cause: &ObligationCause<'tcx>,
181 param_env: ty::ParamEnv<'tcx>,
182 expected: T,
183 actual: T,
184 ) -> Result<(), TypeError<'tcx>> {
185 self.infcx
186 .at(cause, param_env)
187 .sup(DefineOpaqueTypes::Yes, expected, actual)
188 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
189 }
190
191 pub fn lub<T: ToTrace<'tcx>>(
193 &self,
194 cause: &ObligationCause<'tcx>,
195 param_env: ty::ParamEnv<'tcx>,
196 expected: T,
197 actual: T,
198 ) -> Result<T, TypeError<'tcx>> {
199 self.infcx
200 .at(cause, param_env)
201 .lub(expected, actual)
202 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
203 }
204
205 #[must_use]
206 pub fn select_where_possible(&self) -> Vec<E> {
207 self.engine.borrow_mut().select_where_possible(self.infcx)
208 }
209
210 #[must_use]
211 pub fn select_all_or_error(&self) -> Vec<E> {
212 self.engine.borrow_mut().select_all_or_error(self.infcx)
213 }
214
215 #[must_use]
223 pub fn into_pending_obligations(self) -> PredicateObligations<'tcx> {
224 self.engine.borrow().pending_obligations()
225 }
226
227 pub fn resolve_regions_and_report_errors(
232 self,
233 body_id: LocalDefId,
234 param_env: ty::ParamEnv<'tcx>,
235 assumed_wf_tys: impl IntoIterator<Item = Ty<'tcx>>,
236 ) -> Result<(), ErrorGuaranteed> {
237 let errors = self.infcx.resolve_regions(body_id, param_env, assumed_wf_tys);
238 if errors.is_empty() {
239 Ok(())
240 } else {
241 Err(self.infcx.err_ctxt().report_region_errors(body_id, &errors))
242 }
243 }
244
245 #[must_use]
250 pub fn resolve_regions(
251 self,
252 body_id: LocalDefId,
253 param_env: ty::ParamEnv<'tcx>,
254 assumed_wf_tys: impl IntoIterator<Item = Ty<'tcx>>,
255 ) -> Vec<RegionResolutionError<'tcx>> {
256 self.infcx.resolve_regions(body_id, param_env, assumed_wf_tys)
257 }
258}
259
260impl<'tcx> ObligationCtxt<'_, 'tcx, FulfillmentError<'tcx>> {
261 pub fn assumed_wf_types_and_report_errors(
262 &self,
263 param_env: ty::ParamEnv<'tcx>,
264 def_id: LocalDefId,
265 ) -> Result<FxIndexSet<Ty<'tcx>>, ErrorGuaranteed> {
266 self.assumed_wf_types(param_env, def_id)
267 .map_err(|errors| self.infcx.err_ctxt().report_fulfillment_errors(errors))
268 }
269}
270
271impl<'tcx> ObligationCtxt<'_, 'tcx, ScrubbedTraitError<'tcx>> {
272 pub fn make_canonicalized_query_response<T>(
273 &self,
274 inference_vars: CanonicalVarValues<'tcx>,
275 answer: T,
276 ) -> Result<CanonicalQueryResponse<'tcx, T>, NoSolution>
277 where
278 T: Debug + TypeFoldable<TyCtxt<'tcx>>,
279 Canonical<'tcx, QueryResponse<'tcx, T>>: ArenaAllocatable<'tcx>,
280 {
281 self.infcx.make_canonicalized_query_response(
282 inference_vars,
283 answer,
284 &mut **self.engine.borrow_mut(),
285 )
286 }
287}
288
289impl<'tcx, E> ObligationCtxt<'_, 'tcx, E>
290where
291 E: FromSolverError<'tcx, NextSolverError<'tcx>>,
292{
293 pub fn assumed_wf_types(
294 &self,
295 param_env: ty::ParamEnv<'tcx>,
296 def_id: LocalDefId,
297 ) -> Result<FxIndexSet<Ty<'tcx>>, Vec<E>> {
298 let tcx = self.infcx.tcx;
299 let mut implied_bounds = FxIndexSet::default();
300 let mut errors = Vec::new();
301 for &(ty, span) in tcx.assumed_wf_types(def_id) {
302 let cause = ObligationCause::misc(span, def_id);
315 match self
316 .infcx
317 .at(&cause, param_env)
318 .deeply_normalize(ty, &mut **self.engine.borrow_mut())
319 {
320 Ok(normalized) => drop(implied_bounds.insert(normalized)),
322 Err(normalization_errors) => errors.extend(normalization_errors),
323 };
324 }
325
326 if errors.is_empty() { Ok(implied_bounds) } else { Err(errors) }
327 }
328
329 pub fn deeply_normalize<T: TypeFoldable<TyCtxt<'tcx>>>(
330 &self,
331 cause: &ObligationCause<'tcx>,
332 param_env: ty::ParamEnv<'tcx>,
333 value: T,
334 ) -> Result<T, Vec<E>> {
335 self.infcx.at(cause, param_env).deeply_normalize(value, &mut **self.engine.borrow_mut())
336 }
337
338 pub fn structurally_normalize_ty(
339 &self,
340 cause: &ObligationCause<'tcx>,
341 param_env: ty::ParamEnv<'tcx>,
342 value: Ty<'tcx>,
343 ) -> Result<Ty<'tcx>, Vec<E>> {
344 self.infcx
345 .at(cause, param_env)
346 .structurally_normalize_ty(value, &mut **self.engine.borrow_mut())
347 }
348
349 pub fn structurally_normalize_const(
350 &self,
351 cause: &ObligationCause<'tcx>,
352 param_env: ty::ParamEnv<'tcx>,
353 value: ty::Const<'tcx>,
354 ) -> Result<ty::Const<'tcx>, Vec<E>> {
355 self.infcx
356 .at(cause, param_env)
357 .structurally_normalize_const(value, &mut **self.engine.borrow_mut())
358 }
359
360 pub fn structurally_normalize_term(
361 &self,
362 cause: &ObligationCause<'tcx>,
363 param_env: ty::ParamEnv<'tcx>,
364 value: ty::Term<'tcx>,
365 ) -> Result<ty::Term<'tcx>, Vec<E>> {
366 self.infcx
367 .at(cause, param_env)
368 .structurally_normalize_term(value, &mut **self.engine.borrow_mut())
369 }
370}