1#![allow(warnings, unused)]
2
3use super::graph::TyWrapper;
4use super::utils::{self, fn_sig_with_generic_args};
5use crate::analysis::utils::def_path::path_str_def_id;
6use crate::{rap_debug, rap_trace};
7use rand::Rng;
8use rand::seq::SliceRandom;
9use rustc_hir::LangItem;
10use rustc_hir::def_id::DefId;
11use rustc_infer::infer::DefineOpaqueTypes;
12use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
13use rustc_infer::traits::{ImplSource, Obligation, ObligationCause};
14use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt, TypingEnv};
15use rustc_span::DUMMY_SP;
16use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
17use std::collections::HashSet;
18
/// Maximum number of monos a `MonoSet` keeps after `MonoSet::random_sample`.
static MAX_STEP_SET_SIZE: usize = 1_000;
20
21#[derive(Clone, Debug, Hash, PartialEq, Eq)]
22pub struct Mono<'tcx> {
23 pub value: Vec<ty::GenericArg<'tcx>>,
24}
25
26impl<'tcx> FromIterator<ty::GenericArg<'tcx>> for Mono<'tcx> {
27 fn from_iter<T>(iter: T) -> Self
28 where
29 T: IntoIterator<Item = ty::GenericArg<'tcx>>,
30 {
31 Mono {
32 value: iter.into_iter().collect(),
33 }
34 }
35}
36
37impl<'tcx> Mono<'tcx> {
38 pub fn new(identity: &[ty::GenericArg<'tcx>]) -> Self {
39 Mono {
40 value: Vec::from(identity),
41 }
42 }
43
44 fn has_infer_types(&self) -> bool {
45 self.value.iter().any(|arg| match arg.kind() {
46 ty::GenericArgKind::Type(ty) => ty.has_infer_types(),
47 _ => false,
48 })
49 }
50
51 fn mut_arg_at(&mut self, idx: usize) -> &mut ty::GenericArg<'tcx> {
52 &mut self.value[idx]
53 }
54
55 fn merge(&self, other: &Mono<'tcx>, tcx: TyCtxt<'tcx>) -> Option<Mono<'tcx>> {
56 assert!(self.value.len() == other.value.len());
57 let mut res = Vec::new();
58 for i in 0..self.value.len() {
59 let arg = self.value[i];
60 let other_arg = other.value[i];
61 let new_arg = if let Some(ty) = arg.as_type() {
62 let other_ty = other_arg.expect_ty();
63 if ty.is_ty_var() && other_ty.is_ty_var() {
64 arg
65 } else if ty.is_ty_var() {
66 other_arg
67 } else if other_ty.is_ty_var() {
68 arg
69 } else if utils::is_ty_eq(ty, other_ty, tcx) {
70 arg
71 } else {
72 return None;
73 }
74 } else {
75 arg
76 };
77 res.push(new_arg);
78 }
79 Some(Mono { value: res })
80 }
81
82 fn fill_unbound_var(&self, tcx: TyCtxt<'tcx>) -> Vec<Mono<'tcx>> {
83 let candidates = get_unbound_generic_candidates(tcx);
84 let mut res = vec![self.clone()];
85 rap_trace!("fill unbound: {:?}", self);
86
87 for (i, arg) in self.value.iter().enumerate() {
88 if let Some(ty) = arg.as_type() {
89 if ty.is_ty_var() {
90 let mut last = Vec::new();
91 std::mem::swap(&mut res, &mut last);
92 last.into_iter().for_each(|mono| {
93 for candidate in &candidates {
94 let mut new_mono = mono.clone();
95 *new_mono.mut_arg_at(i) = (*candidate).into();
96 res.push(new_mono);
97 }
98 });
99 }
100 }
101 }
102 res
103 }
104}
105
106#[derive(Clone, Debug, Default)]
107pub struct MonoSet<'tcx> {
108 pub monos: Vec<Mono<'tcx>>,
109}
110
111impl<'tcx> MonoSet<'tcx> {
112 pub fn all(identity: &[ty::GenericArg<'tcx>]) -> MonoSet<'tcx> {
113 MonoSet {
114 monos: vec![Mono::new(identity)],
115 }
116 }
117
118 pub fn empty() -> MonoSet<'tcx> {
119 MonoSet { monos: Vec::new() }
120 }
121
122 pub fn count(&self) -> usize {
123 self.monos.len()
124 }
125
126 pub fn at(&self, no: usize) -> &Mono<'tcx> {
127 &self.monos[no]
128 }
129
130 pub fn is_empty(&self) -> bool {
131 self.monos.is_empty()
132 }
133
134 pub fn new() -> MonoSet<'tcx> {
135 MonoSet { monos: Vec::new() }
136 }
137
138 pub fn insert(&mut self, mono: Mono<'tcx>) {
139 self.monos.push(mono);
140 }
141
142 pub fn merge(&mut self, other: &MonoSet<'tcx>, tcx: TyCtxt<'tcx>) -> MonoSet<'tcx> {
143 let mut res = MonoSet::new();
144
145 for args in self.monos.iter() {
146 for other_args in other.monos.iter() {
147 let merged = args.merge(other_args, tcx);
148 if let Some(mono) = merged {
149 res.insert(mono);
150 }
151 }
152 }
153 res
154 }
155
156 fn filter_unbound_solution(mut self) -> Self {
157 self.monos.retain(|mono| mono.has_infer_types());
158 self
159 }
160
161 fn instantiate_unbound(&self, tcx: TyCtxt<'tcx>) -> Self {
165 let mut res = MonoSet::new();
166 for mono in &self.monos {
167 let filled = mono.fill_unbound_var(tcx);
168 res.monos.extend(filled);
169 }
170 res
171 }
172
173 fn erase_region_var(&mut self, tcx: TyCtxt<'tcx>) {
174 for mono in &mut self.monos {
175 mono.value
176 .iter_mut()
177 .for_each(|arg| *arg = tcx.erase_and_anonymize_regions(*arg))
178 }
179 }
180
181 pub fn filter(mut self, f: impl Fn(&Mono<'tcx>) -> bool) -> Self {
182 self.monos.retain(|args| f(args));
183 self
184 }
185
186 fn filter_by_trait_bound(mut self, fn_did: DefId, tcx: TyCtxt<'tcx>) -> Self {
187 self.monos
189 .retain(|args| is_args_fit_trait_bound(fn_did, &args.value, tcx));
190 self
191 }
192
193 pub fn random_sample<R: Rng>(&mut self, rng: &mut R) {
194 if self.monos.len() <= MAX_STEP_SET_SIZE {
195 return;
196 }
197 self.monos.shuffle(rng);
198 self.monos.truncate(MAX_STEP_SET_SIZE);
199 }
200}
201
202fn unify_ty<'tcx>(
207 lhs: Ty<'tcx>,
208 rhs: Ty<'tcx>,
209 identity: &[ty::GenericArg<'tcx>],
210 infcx: &InferCtxt<'tcx>,
211 cause: &ObligationCause<'tcx>,
212 param_env: ty::ParamEnv<'tcx>,
213) -> Option<Mono<'tcx>> {
214 infcx.probe(|_| {
216 match infcx
217 .at(cause, param_env)
218 .eq(DefineOpaqueTypes::Yes, lhs, rhs)
219 {
220 Ok(_infer_ok) => {
221 let mono = identity
223 .iter()
224 .map(|arg| match arg.kind() {
225 ty::GenericArgKind::Lifetime(region) => {
226 infcx.resolve_vars_if_possible(region).into()
227 }
228 ty::GenericArgKind::Type(ty) => infcx.resolve_vars_if_possible(ty).into(),
229 ty::GenericArgKind::Const(ct) => infcx.resolve_vars_if_possible(ct).into(),
230 })
231 .collect();
232 Some(mono)
233 }
234 Err(_e) => {
235 None
237 }
238 }
239 })
240}
241
242fn is_args_fit_trait_bound<'tcx>(
243 fn_did: DefId,
244 args: &[ty::GenericArg<'tcx>],
245 tcx: TyCtxt<'tcx>,
246) -> bool {
247 let args = tcx.mk_args(args);
248 let infcx = tcx.infer_ctxt().build(ty::TypingMode::PostAnalysis);
255 let pred = tcx.predicates_of(fn_did);
256 let inst_pred = pred.instantiate(tcx, args);
257 let param_env = tcx.param_env(fn_did);
258 rap_trace!(
259 "[trait bound] check {}",
260 tcx.def_path_str_with_args(fn_did, args)
261 );
262
263 for pred in inst_pred.predicates.iter() {
264 let obligation = Obligation::new(
265 tcx,
266 ObligationCause::dummy(),
267 param_env,
268 pred.as_predicate(),
269 );
270
271 let res = infcx.evaluate_obligation(&obligation);
272 match res {
273 Ok(eva) => {
274 if !eva.may_apply() {
275 rap_trace!("[trait bound] check fail for {pred:?}");
276 return false;
277 }
278 }
279 Err(_) => {
280 rap_trace!("[trait bound] check fail for {pred:?}");
281 return false;
282 }
283 }
284 }
285 rap_trace!("[trait bound] check succ");
286 true
287}
288
289fn is_fn_solvable<'tcx>(fn_did: DefId, tcx: TyCtxt<'tcx>) -> bool {
290 for pred in tcx
291 .predicates_of(fn_did)
292 .instantiate_identity(tcx)
293 .predicates
294 {
295 if let Some(pred) = pred.as_trait_clause() {
296 let trait_did = pred.skip_binder().trait_ref.def_id;
297 if tcx.is_lang_item(trait_did, LangItem::Fn)
298 || tcx.is_lang_item(trait_did, LangItem::FnMut)
299 || tcx.is_lang_item(trait_did, LangItem::FnOnce)
300 {
301 return false;
302 }
303 }
304 }
305 true
306}
307
308fn get_mono_set<'tcx>(
309 fn_did: DefId,
310 available_ty: &HashSet<TyWrapper<'tcx>>,
311 tcx: TyCtxt<'tcx>,
312) -> MonoSet<'tcx> {
313 let mut rng = rand::rng();
314
315 rap_debug!("[get_mono_set] fn_did: {:?}", fn_did);
317 let infcx = tcx
318 .infer_ctxt()
319 .ignoring_regions()
320 .build(ty::TypingMode::PostAnalysis);
321 let param_env = tcx.param_env(fn_did);
322 let dummy_cause = ObligationCause::dummy();
323 let fresh_args = infcx.fresh_args_for_item(DUMMY_SP, fn_did);
324 let fn_sig = fn_sig_with_generic_args(fn_did, fresh_args, tcx);
326 let generics = tcx.generics_of(fn_did);
327
328 for i in 0..fresh_args.len() {
330 rap_trace!(
331 "[get_mono_set] arg#{}: {:?} -> {:?}",
332 i,
333 generics.param_at(i, tcx).name,
334 fresh_args[i]
335 );
336 }
337
338 let mut s = MonoSet::all(&fresh_args);
339
340 rap_trace!("[get_mono_set] initialize s: {:?}", s);
341
342 let mut cnt = 0;
343
344 for input_ty in fn_sig.inputs().iter() {
345 cnt += 1;
346 if !input_ty.has_infer_types() {
347 continue;
348 }
349 rap_trace!("[get_mono_set] input_ty#{}: {:?}", cnt - 1, input_ty);
350
351 let mut reachable_set =
352 available_ty
353 .iter()
354 .fold(MonoSet::new(), |mut reachable_set, ty| {
355 if let Some(mono) = unify_ty(
356 *input_ty,
357 (*ty).into(),
358 &fresh_args,
359 &infcx,
360 &dummy_cause,
361 param_env,
362 ) {
363 reachable_set.insert(mono);
364 }
365 reachable_set
366 });
367 reachable_set.random_sample(&mut rng);
368 rap_debug!(
369 "[get_mono_set] size: s = {}, input = {}",
370 s.count(),
371 reachable_set.count()
372 );
373 s = s.merge(&reachable_set, tcx);
374 s.random_sample(&mut rng);
375 }
376
377 rap_trace!("[get_mono_set] after input types: {:?}", s);
378
379 let mut res = MonoSet::new();
380
381 for mono in s.monos {
382 solve_unbound_type_generics(
383 fn_did,
384 mono,
385 &mut res,
386 &infcx,
388 &dummy_cause,
389 param_env,
390 tcx,
391 );
392 }
393
394 res.erase_region_var(tcx);
396
397 res
398}
399
400fn is_special_std_ty<'tcx>(def_id: DefId, tcx: TyCtxt<'tcx>) -> bool {
401 let allowed_std_ty = [
402 tcx.lang_items().string().unwrap(),
403 path_str_def_id(tcx, "std::vec::Vec"),
404 ];
405
406 allowed_std_ty.contains(&def_id)
407}
408
409fn solve_unbound_type_generics<'tcx>(
410 did: DefId,
411 mono: Mono<'tcx>,
412 res: &mut MonoSet<'tcx>,
413 infcx: &InferCtxt<'tcx>,
414 cause: &ObligationCause<'tcx>,
415 param_env: ty::ParamEnv<'tcx>,
416 tcx: TyCtxt<'tcx>,
417) {
418 if !mono.has_infer_types() {
419 res.insert(mono);
420 return;
421 }
422 let args = tcx.mk_args(&mono.value);
423 let preds = tcx.predicates_of(did).instantiate(tcx, args);
424 let mut mset = MonoSet::all(args);
425 rap_debug!("[solve_unbound] did = {did:?}, mset={mset:?}");
426 for pred in preds.predicates.iter() {
427 if let Some(trait_pred) = pred.as_trait_clause() {
428 let trait_pred = trait_pred.skip_binder();
429
430 rap_trace!("[solve_unbound] pred: {:?}", trait_pred);
431
432 let trait_def_id = trait_pred.trait_ref.def_id;
433 if tcx.is_lang_item(trait_def_id, LangItem::Sized)
435 || tcx.is_lang_item(trait_def_id, LangItem::Copy)
436 {
437 continue;
438 }
439
440 let mut p = MonoSet::new();
441
442 for impl_did in tcx
443 .all_impls(trait_def_id)
444 .chain(tcx.inherent_impls(trait_def_id).iter().map(|did| *did))
445 {
446 let impl_trait_ref = tcx.impl_trait_ref(impl_did).unwrap().skip_binder();
448
449 rap_trace!("impl_trait_ref: {}", impl_trait_ref);
450 if !impl_did.is_local() && !impl_trait_ref.self_ty().is_primitive() {
454 continue;
455 }
456 if let Some(mono) = unify_trait(
459 trait_pred.trait_ref,
460 impl_trait_ref,
461 args,
462 &infcx,
463 &cause,
464 param_env,
465 tcx,
466 ) {
467 p.insert(mono);
468 }
469 }
470 mset = mset.merge(&p, tcx);
471 rap_trace!("[solve_unbound] mset: {:?}", mset);
472 }
473 }
474
475 rap_trace!("[solve_unbound] (final) mset: {:?}", mset);
476 for mono in mset.monos {
477 res.insert(mono);
478 }
479}
480
481fn unify_trait<'tcx>(
484 lhs: ty::TraitRef<'tcx>,
485 rhs: ty::TraitRef<'tcx>,
486 identity: &[ty::GenericArg<'tcx>],
487 infcx: &InferCtxt<'tcx>,
488 cause: &ObligationCause<'tcx>,
489 param_env: ty::ParamEnv<'tcx>,
490 tcx: TyCtxt<'tcx>,
491) -> Option<Mono<'tcx>> {
492 rap_trace!("[unify_trait] lhs: {:?}, rhs: {:?}", lhs, rhs);
493 if lhs.def_id != rhs.def_id {
494 return None;
495 }
496
497 assert!(lhs.args.len() == rhs.args.len());
498 let mut s = Mono::new(identity);
499 for (lhs_arg, rhs_arg) in lhs.args.iter().zip(rhs.args.iter()) {
500 if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_arg.as_type(), rhs_arg.as_type()) {
501 if rhs_ty.has_infer_types() || rhs_ty.has_param() {
502 return None;
504 }
505 let mono = unify_ty(lhs_ty, rhs_ty, identity, infcx, cause, param_env)?;
506 rap_trace!("[unify_trait] unified mono: {:?}", mono);
507 s = s.merge(&mono, tcx)?;
508 }
509 }
510 Some(s)
511}
512
513pub fn resolve_mono_apis<'tcx>(
514 fn_did: DefId,
515 available_ty: &HashSet<TyWrapper<'tcx>>,
516 tcx: TyCtxt<'tcx>,
517) -> MonoSet<'tcx> {
518 if !is_fn_solvable(fn_did, tcx) {
520 return MonoSet::empty();
521 }
522
523 let ret = get_mono_set(fn_did, &available_ty, tcx).instantiate_unbound(tcx);
525
526 let ret = ret.filter_by_trait_bound(fn_did, tcx);
528
529 ret
530}
531
532pub fn add_transform_tys<'tcx>(available_ty: &mut HashSet<TyWrapper<'tcx>>, tcx: TyCtxt<'tcx>) {
533 let mut new_tys = Vec::new();
534 available_ty.iter().for_each(|ty| {
535 new_tys.push(
536 Ty::new_ref(
537 tcx,
538 tcx.lifetimes.re_erased,
539 (*ty).into(),
540 ty::Mutability::Not,
541 )
542 .into(),
543 );
544 new_tys.push(Ty::new_ref(
545 tcx,
546 tcx.lifetimes.re_erased,
547 (*ty).into(),
548 ty::Mutability::Mut,
549 ));
550 new_tys.push(Ty::new_ref(
551 tcx,
552 tcx.lifetimes.re_erased,
553 Ty::new_slice(tcx, (*ty).into()),
554 ty::Mutability::Not,
555 ));
556 new_tys.push(Ty::new_ref(
557 tcx,
558 tcx.lifetimes.re_erased,
559 Ty::new_slice(tcx, (*ty).into()),
560 ty::Mutability::Mut,
561 ));
562 });
563
564 new_tys.into_iter().for_each(|ty| {
565 available_ty.insert(ty.into());
566 });
567}
568
569pub fn eliminate_infer_var<'tcx>(
570 fn_did: DefId,
571 args: &[ty::GenericArg<'tcx>],
572 tcx: TyCtxt<'tcx>,
573) -> Vec<ty::GenericArg<'tcx>> {
574 let mut res = Vec::new();
575 let identity = ty::GenericArgs::identity_for_item(tcx, fn_did);
576 for (i, arg) in args.iter().enumerate() {
577 if let Some(ty) = arg.as_type() {
578 if ty.is_ty_var() {
579 res.push(identity[i]);
580 } else {
581 res.push(*arg);
582 }
583 } else {
584 res.push(*arg);
585 }
586 }
587 res
588}
589
590pub fn get_unbound_generic_candidates<'tcx>(tcx: TyCtxt<'tcx>) -> Vec<ty::Ty<'tcx>> {
593 vec![
594 tcx.types.bool,
595 tcx.types.char,
596 tcx.types.u8,
597 tcx.types.i8,
598 tcx.types.i32,
599 tcx.types.u32,
600 tcx.types.f32,
603 Ty::new_imm_ref(
605 tcx,
606 tcx.lifetimes.re_erased,
607 Ty::new_slice(tcx, tcx.types.u8),
608 ),
609 Ty::new_mut_ref(
610 tcx,
611 tcx.lifetimes.re_erased,
612 Ty::new_slice(tcx, tcx.types.u8),
613 ),
614 ]
615}
616
617pub fn get_impls<'tcx>(
618 tcx: TyCtxt<'tcx>,
619 fn_did: DefId,
620 args: GenericArgsRef<'tcx>,
621) -> HashSet<DefId> {
622 let mut impls = HashSet::new();
623 let preds = tcx.predicates_of(fn_did).instantiate(tcx, args);
624 for (pred, _) in preds {
625 if let Some(trait_pred) = pred.as_trait_clause() {
626 let trait_ref: rustc_type_ir::TraitRef<TyCtxt<'tcx>> =
627 trait_pred.skip_binder().trait_ref;
628 let res = tcx.codegen_select_candidate(
636 TypingEnv::fully_monomorphized().as_query_input(trait_ref),
637 );
638 if let Ok(source) = res {
639 match source {
640 ImplSource::UserDefined(data) => {
641 if data.impl_def_id.is_local() {
642 impls.insert(data.impl_def_id);
643 }
644 }
645 _ => {}
646 }
647 }
648 }
650 }
651 rap_trace!("fn: {:?} args: {:?} impls: {:?}", fn_did, args, impls);
652 impls
653}