1use super::graph::TyWrapper;
2use super::utils::{self, fn_sig_with_generic_args};
3use crate::analysis::utils::def_path::path_str_def_id;
4use crate::{rap_debug, rap_trace};
5use rand::seq::SliceRandom;
6use rand::Rng;
7use rustc_hir::def_id::DefId;
8use rustc_hir::LangItem;
9use rustc_infer::infer::DefineOpaqueTypes;
10use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
11use rustc_infer::traits::{ImplSource, Obligation, ObligationCause};
12use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt, TypingEnv};
13use rustc_span::DUMMY_SP;
14use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
15use std::collections::HashSet;
16
/// Upper bound on how many monos a `MonoSet` keeps after `MonoSet::random_sample`;
/// larger sets are shuffled and truncated to this size to keep the search tractable.
const MAX_STEP_SET_SIZE: usize = 1000;
18
19#[derive(Clone, Debug, Hash, PartialEq, Eq)]
20pub struct Mono<'tcx> {
21 pub value: Vec<ty::GenericArg<'tcx>>,
22}
23
24impl<'tcx> FromIterator<ty::GenericArg<'tcx>> for Mono<'tcx> {
25 fn from_iter<T>(iter: T) -> Self
26 where
27 T: IntoIterator<Item = ty::GenericArg<'tcx>>,
28 {
29 Mono {
30 value: iter.into_iter().collect(),
31 }
32 }
33}
34
35impl<'tcx> Mono<'tcx> {
36 pub fn new(identity: &[ty::GenericArg<'tcx>]) -> Self {
37 Mono {
38 value: Vec::from(identity),
39 }
40 }
41
42 fn has_infer_types(&self) -> bool {
43 self.value.iter().any(|arg| match arg.kind() {
44 ty::GenericArgKind::Type(ty) => ty.has_infer_types(),
45 _ => false,
46 })
47 }
48
49 fn mut_arg_at(&mut self, idx: usize) -> &mut ty::GenericArg<'tcx> {
50 &mut self.value[idx]
51 }
52
53 fn merge(&self, other: &Mono<'tcx>, tcx: TyCtxt<'tcx>) -> Option<Mono<'tcx>> {
54 assert!(self.value.len() == other.value.len());
55 let mut res = Vec::new();
56 for i in 0..self.value.len() {
57 let arg = self.value[i];
58 let other_arg = other.value[i];
59 let new_arg = if let Some(ty) = arg.as_type() {
60 let other_ty = other_arg.expect_ty();
61 if ty.is_ty_var() && other_ty.is_ty_var() {
62 arg
63 } else if ty.is_ty_var() {
64 other_arg
65 } else if other_ty.is_ty_var() {
66 arg
67 } else if utils::is_ty_eq(ty, other_ty, tcx) {
68 arg
69 } else {
70 return None;
71 }
72 } else {
73 arg
74 };
75 res.push(new_arg);
76 }
77 Some(Mono { value: res })
78 }
79
80 fn fill_unbound_var(&self, tcx: TyCtxt<'tcx>) -> Vec<Mono<'tcx>> {
81 let candidates = get_unbound_generic_candidates(tcx);
82 let mut res = vec![self.clone()];
83 rap_trace!("fill unbound: {:?}", self);
84
85 for (i, arg) in self.value.iter().enumerate() {
86 if let Some(ty) = arg.as_type() {
87 if ty.is_ty_var() {
88 let mut last = Vec::new();
89 std::mem::swap(&mut res, &mut last);
90 last.into_iter().for_each(|mono| {
91 for candidate in &candidates {
92 let mut new_mono = mono.clone();
93 *new_mono.mut_arg_at(i) = (*candidate).into();
94 res.push(new_mono);
95 }
96 });
97 }
98 }
99 }
100 res
101 }
102}
103
104#[derive(Clone, Debug, Default)]
105pub struct MonoSet<'tcx> {
106 pub monos: Vec<Mono<'tcx>>,
107}
108
109impl<'tcx> MonoSet<'tcx> {
110 pub fn all(identity: &[ty::GenericArg<'tcx>]) -> MonoSet<'tcx> {
111 MonoSet {
112 monos: vec![Mono::new(identity)],
113 }
114 }
115
116 pub fn empty() -> MonoSet<'tcx> {
117 MonoSet { monos: Vec::new() }
118 }
119
120 pub fn count(&self) -> usize {
121 self.monos.len()
122 }
123
124 pub fn at(&self, no: usize) -> &Mono<'tcx> {
125 &self.monos[no]
126 }
127
128 pub fn is_empty(&self) -> bool {
129 self.monos.is_empty()
130 }
131
132 pub fn new() -> MonoSet<'tcx> {
133 MonoSet { monos: Vec::new() }
134 }
135
136 pub fn insert(&mut self, mono: Mono<'tcx>) {
137 self.monos.push(mono);
138 }
139
140 pub fn merge(&mut self, other: &MonoSet<'tcx>, tcx: TyCtxt<'tcx>) -> MonoSet<'tcx> {
141 let mut res = MonoSet::new();
142
143 for args in self.monos.iter() {
144 for other_args in other.monos.iter() {
145 let merged = args.merge(other_args, tcx);
146 if let Some(mono) = merged {
147 res.insert(mono);
148 }
149 }
150 }
151 res
152 }
153
154 fn filter_unbound_solution(mut self) -> Self {
155 self.monos.retain(|mono| mono.has_infer_types());
156 self
157 }
158
159 fn instantiate_unbound(&self, tcx: TyCtxt<'tcx>) -> Self {
163 let mut res = MonoSet::new();
164 for mono in &self.monos {
165 let filled = mono.fill_unbound_var(tcx);
166 res.monos.extend(filled);
167 }
168 res
169 }
170
171 fn erase_region_var(&mut self, tcx: TyCtxt<'tcx>) {
172 for mono in &mut self.monos {
173 mono.value
174 .iter_mut()
175 .for_each(|arg| *arg = tcx.erase_and_anonymize_regions(*arg))
176 }
177 }
178
179 pub fn filter(mut self, f: impl Fn(&Mono<'tcx>) -> bool) -> Self {
180 self.monos.retain(|args| f(args));
181 self
182 }
183
184 fn filter_by_trait_bound(mut self, fn_did: DefId, tcx: TyCtxt<'tcx>) -> Self {
185 let early_fn_sig = tcx.fn_sig(fn_did);
186 self.monos
187 .retain(|args| is_args_fit_trait_bound(fn_did, &args.value, tcx));
188 self
189 }
190
191 pub fn random_sample<R: Rng>(&mut self, rng: &mut R) {
192 if self.monos.len() <= MAX_STEP_SET_SIZE {
193 return;
194 }
195 self.monos.shuffle(rng);
196 self.monos.truncate(MAX_STEP_SET_SIZE);
197 }
198}
199
200fn unify_ty<'tcx>(
205 lhs: Ty<'tcx>,
206 rhs: Ty<'tcx>,
207 identity: &[ty::GenericArg<'tcx>],
208 infcx: &InferCtxt<'tcx>,
209 cause: &ObligationCause<'tcx>,
210 param_env: ty::ParamEnv<'tcx>,
211) -> Option<Mono<'tcx>> {
212 infcx.probe(|_| {
214 match infcx
215 .at(cause, param_env)
216 .eq(DefineOpaqueTypes::Yes, lhs, rhs)
217 {
218 Ok(infer_ok) => {
219 let mono = identity
221 .iter()
222 .map(|arg| match arg.kind() {
223 ty::GenericArgKind::Lifetime(region) => {
224 infcx.resolve_vars_if_possible(region).into()
225 }
226 ty::GenericArgKind::Type(ty) => infcx.resolve_vars_if_possible(ty).into(),
227 ty::GenericArgKind::Const(ct) => infcx.resolve_vars_if_possible(ct).into(),
228 })
229 .collect();
230 Some(mono)
231 }
232 Err(e) => {
233 None
235 }
236 }
237 })
238}
239
240fn is_args_fit_trait_bound<'tcx>(
241 fn_did: DefId,
242 args: &[ty::GenericArg<'tcx>],
243 tcx: TyCtxt<'tcx>,
244) -> bool {
245 let args = tcx.mk_args(args);
246 let infcx = tcx.infer_ctxt().build(ty::TypingMode::PostAnalysis);
253 let pred = tcx.predicates_of(fn_did);
254 let inst_pred = pred.instantiate(tcx, args);
255 let param_env = tcx.param_env(fn_did);
256 rap_trace!(
257 "[trait bound] check {}",
258 tcx.def_path_str_with_args(fn_did, args)
259 );
260
261 for pred in inst_pred.predicates.iter() {
262 let obligation = Obligation::new(
263 tcx,
264 ObligationCause::dummy(),
265 param_env,
266 pred.as_predicate(),
267 );
268
269 let res = infcx.evaluate_obligation(&obligation);
270 match res {
271 Ok(eva) => {
272 if !eva.may_apply() {
273 rap_trace!("[trait bound] check fail for {pred:?}");
274 return false;
275 }
276 }
277 Err(_) => {
278 rap_trace!("[trait bound] check fail for {pred:?}");
279 return false;
280 }
281 }
282 }
283 rap_trace!("[trait bound] check succ");
284 true
285}
286
287fn is_fn_solvable<'tcx>(fn_did: DefId, tcx: TyCtxt<'tcx>) -> bool {
288 for pred in tcx
289 .predicates_of(fn_did)
290 .instantiate_identity(tcx)
291 .predicates
292 {
293 if let Some(pred) = pred.as_trait_clause() {
294 let trait_did = pred.skip_binder().trait_ref.def_id;
295 if tcx.is_lang_item(trait_did, LangItem::Fn)
296 || tcx.is_lang_item(trait_did, LangItem::FnMut)
297 || tcx.is_lang_item(trait_did, LangItem::FnOnce)
298 {
299 return false;
300 }
301 }
302 }
303 true
304}
305
306fn get_mono_set<'tcx>(
307 fn_did: DefId,
308 available_ty: &HashSet<TyWrapper<'tcx>>,
309 tcx: TyCtxt<'tcx>,
310) -> MonoSet<'tcx> {
311 let mut rng = rand::rng();
312
313 rap_debug!("[get_mono_set] fn_did: {:?}", fn_did);
315 let infcx = tcx
316 .infer_ctxt()
317 .ignoring_regions()
318 .build(ty::TypingMode::PostAnalysis);
319 let param_env = tcx.param_env(fn_did);
320 let dummy_cause = ObligationCause::dummy();
321 let fresh_args = infcx.fresh_args_for_item(DUMMY_SP, fn_did);
322 let fn_sig = fn_sig_with_generic_args(fn_did, fresh_args, tcx);
324 let generics = tcx.generics_of(fn_did);
325
326 for i in 0..fresh_args.len() {
328 rap_trace!(
329 "[get_mono_set] arg#{}: {:?} -> {:?}",
330 i,
331 generics.param_at(i, tcx).name,
332 fresh_args[i]
333 );
334 }
335
336 let mut s = MonoSet::all(&fresh_args);
337
338 rap_trace!("[get_mono_set] initialize s: {:?}", s);
339
340 let mut cnt = 0;
341
342 for input_ty in fn_sig.inputs().iter() {
343 cnt += 1;
344 if !input_ty.has_infer_types() {
345 continue;
346 }
347 rap_trace!("[get_mono_set] input_ty#{}: {:?}", cnt - 1, input_ty);
348
349 let mut reachable_set =
350 available_ty
351 .iter()
352 .fold(MonoSet::new(), |mut reachable_set, ty| {
353 if let Some(mono) = unify_ty(
354 *input_ty,
355 (*ty).into(),
356 &fresh_args,
357 &infcx,
358 &dummy_cause,
359 param_env,
360 ) {
361 reachable_set.insert(mono);
362 }
363 reachable_set
364 });
365 reachable_set.random_sample(&mut rng);
366 rap_debug!(
367 "[get_mono_set] size: s = {}, input = {}",
368 s.count(),
369 reachable_set.count()
370 );
371 s = s.merge(&reachable_set, tcx);
372 s.random_sample(&mut rng);
373 }
374
375 rap_trace!("[get_mono_set] after input types: {:?}", s);
376
377 let mut res = MonoSet::new();
378
379 for mono in s.monos {
380 solve_unbound_type_generics(
381 fn_did,
382 mono,
383 &mut res,
384 &infcx,
386 &dummy_cause,
387 param_env,
388 tcx,
389 );
390 }
391
392 res.erase_region_var(tcx);
394
395 res
396}
397
398fn is_special_std_ty<'tcx>(def_id: DefId, tcx: TyCtxt<'tcx>) -> bool {
399 let allowed_std_ty = [
400 tcx.lang_items().string().unwrap(),
401 path_str_def_id(tcx, "std::vec::Vec"),
402 ];
403
404 allowed_std_ty.contains(&def_id)
405}
406
407fn solve_unbound_type_generics<'tcx>(
408 did: DefId,
409 mono: Mono<'tcx>,
410 res: &mut MonoSet<'tcx>,
411 infcx: &InferCtxt<'tcx>,
412 cause: &ObligationCause<'tcx>,
413 param_env: ty::ParamEnv<'tcx>,
414 tcx: TyCtxt<'tcx>,
415) {
416 if !mono.has_infer_types() {
417 res.insert(mono);
418 return;
419 }
420 let args = tcx.mk_args(&mono.value);
421 let preds = tcx.predicates_of(did).instantiate(tcx, args);
422 let mut mset = MonoSet::all(args);
423 rap_debug!("[solve_unbound] did = {did:?}, mset={mset:?}");
424 for pred in preds.predicates.iter() {
425 if let Some(trait_pred) = pred.as_trait_clause() {
426 let trait_pred = trait_pred.skip_binder();
427
428 rap_trace!("[solve_unbound] pred: {:?}", trait_pred);
429
430 let trait_def_id = trait_pred.trait_ref.def_id;
431 if tcx.is_lang_item(trait_def_id, LangItem::Sized)
433 || tcx.is_lang_item(trait_def_id, LangItem::Copy)
434 {
435 continue;
436 }
437
438 let mut p = MonoSet::new();
439
440 for impl_did in tcx
441 .all_impls(trait_def_id)
442 .chain(tcx.inherent_impls(trait_def_id).iter().map(|did| *did))
443 {
444 let impl_trait_ref = tcx.impl_trait_ref(impl_did).unwrap().skip_binder();
446
447 rap_trace!("impl_trait_ref: {}", impl_trait_ref);
448 if !impl_did.is_local() && !impl_trait_ref.self_ty().is_primitive() {
452 continue;
453 }
454 if let Some(mono) = unify_trait(
457 trait_pred.trait_ref,
458 impl_trait_ref,
459 args,
460 &infcx,
461 &cause,
462 param_env,
463 tcx,
464 ) {
465 p.insert(mono);
466 }
467 }
468 mset = mset.merge(&p, tcx);
469 rap_trace!("[solve_unbound] mset: {:?}", mset);
470 }
471 }
472
473 rap_trace!("[solve_unbound] (final) mset: {:?}", mset);
474 for mono in mset.monos {
475 res.insert(mono);
476 }
477}
478
479fn unify_trait<'tcx>(
482 lhs: ty::TraitRef<'tcx>,
483 rhs: ty::TraitRef<'tcx>,
484 identity: &[ty::GenericArg<'tcx>],
485 infcx: &InferCtxt<'tcx>,
486 cause: &ObligationCause<'tcx>,
487 param_env: ty::ParamEnv<'tcx>,
488 tcx: TyCtxt<'tcx>,
489) -> Option<Mono<'tcx>> {
490 rap_trace!("[unify_trait] lhs: {:?}, rhs: {:?}", lhs, rhs);
491 if lhs.def_id != rhs.def_id {
492 return None;
493 }
494
495 assert!(lhs.args.len() == rhs.args.len());
496 let mut s = Mono::new(identity);
497 for (lhs_arg, rhs_arg) in lhs.args.iter().zip(rhs.args.iter()) {
498 if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_arg.as_type(), rhs_arg.as_type()) {
499 if rhs_ty.has_infer_types() || rhs_ty.has_param() {
500 return None;
502 }
503 let mono = unify_ty(lhs_ty, rhs_ty, identity, infcx, cause, param_env)?;
504 rap_trace!("[unify_trait] unified mono: {:?}", mono);
505 s = s.merge(&mono, tcx)?;
506 }
507 }
508 Some(s)
509}
510
511pub fn resolve_mono_apis<'tcx>(
512 fn_did: DefId,
513 available_ty: &HashSet<TyWrapper<'tcx>>,
514 tcx: TyCtxt<'tcx>,
515) -> MonoSet<'tcx> {
516 if !is_fn_solvable(fn_did, tcx) {
518 return MonoSet::empty();
519 }
520
521 let ret = get_mono_set(fn_did, &available_ty, tcx).instantiate_unbound(tcx);
523
524 let ret = ret.filter_by_trait_bound(fn_did, tcx);
526
527 ret
528}
529
530pub fn add_transform_tys<'tcx>(available_ty: &mut HashSet<TyWrapper<'tcx>>, tcx: TyCtxt<'tcx>) {
531 let mut new_tys = Vec::new();
532 available_ty.iter().for_each(|ty| {
533 new_tys.push(
534 Ty::new_ref(
535 tcx,
536 tcx.lifetimes.re_erased,
537 (*ty).into(),
538 ty::Mutability::Not,
539 )
540 .into(),
541 );
542 new_tys.push(Ty::new_ref(
543 tcx,
544 tcx.lifetimes.re_erased,
545 (*ty).into(),
546 ty::Mutability::Mut,
547 ));
548 new_tys.push(Ty::new_ref(
549 tcx,
550 tcx.lifetimes.re_erased,
551 Ty::new_slice(tcx, (*ty).into()),
552 ty::Mutability::Not,
553 ));
554 new_tys.push(Ty::new_ref(
555 tcx,
556 tcx.lifetimes.re_erased,
557 Ty::new_slice(tcx, (*ty).into()),
558 ty::Mutability::Mut,
559 ));
560 });
561
562 new_tys.into_iter().for_each(|ty| {
563 available_ty.insert(ty.into());
564 });
565}
566
567pub fn eliminate_infer_var<'tcx>(
568 fn_did: DefId,
569 args: &[ty::GenericArg<'tcx>],
570 tcx: TyCtxt<'tcx>,
571) -> Vec<ty::GenericArg<'tcx>> {
572 let mut res = Vec::new();
573 let identity = ty::GenericArgs::identity_for_item(tcx, fn_did);
574 for (i, arg) in args.iter().enumerate() {
575 if let Some(ty) = arg.as_type() {
576 if ty.is_ty_var() {
577 res.push(identity[i]);
578 } else {
579 res.push(*arg);
580 }
581 } else {
582 res.push(*arg);
583 }
584 }
585 res
586}
587
588pub fn get_unbound_generic_candidates<'tcx>(tcx: TyCtxt<'tcx>) -> Vec<ty::Ty<'tcx>> {
591 vec![
592 tcx.types.bool,
593 tcx.types.char,
594 tcx.types.u8,
595 tcx.types.i8,
596 tcx.types.i32,
597 tcx.types.u32,
598 tcx.types.f32,
601 Ty::new_imm_ref(
603 tcx,
604 tcx.lifetimes.re_erased,
605 Ty::new_slice(tcx, tcx.types.u8),
606 ),
607 Ty::new_mut_ref(
608 tcx,
609 tcx.lifetimes.re_erased,
610 Ty::new_slice(tcx, tcx.types.u8),
611 ),
612 ]
613}
614
615pub fn get_impls<'tcx>(
616 tcx: TyCtxt<'tcx>,
617 fn_did: DefId,
618 args: GenericArgsRef<'tcx>,
619) -> HashSet<DefId> {
620 let mut impls = HashSet::new();
621 let preds = tcx.predicates_of(fn_did).instantiate(tcx, args);
622 for (pred, _) in preds {
623 if let Some(trait_pred) = pred.as_trait_clause() {
624 let trait_ref: rustc_type_ir::TraitRef<TyCtxt<'tcx>> =
625 trait_pred.skip_binder().trait_ref;
626 let res = tcx.codegen_select_candidate(
634 TypingEnv::fully_monomorphized().as_query_input(trait_ref),
635 );
636 if let Ok(source) = res {
637 match source {
638 ImplSource::UserDefined(data) => {
639 if data.impl_def_id.is_local() {
640 impls.insert(data.impl_def_id);
641 }
642 }
643 _ => {}
644 }
645 }
646 }
648 }
649 rap_trace!("fn: {:?} args: {:?} impls: {:?}", fn_did, args, impls);
650 impls
651}