rapx/analysis/core/api_dependency/mono.rs

1#![allow(warnings, unused)]
2
3use super::graph::TyWrapper;
4use super::utils::{self, fn_sig_with_generic_args};
5use crate::analysis::utils::def_path::path_str_def_id;
6use crate::{rap_debug, rap_trace};
7use rand::Rng;
8use rand::seq::SliceRandom;
9use rustc_hir::LangItem;
10use rustc_hir::def_id::DefId;
11use rustc_infer::infer::DefineOpaqueTypes;
12use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
13use rustc_infer::traits::{ImplSource, Obligation, ObligationCause};
14use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt, TypingEnv};
15use rustc_span::DUMMY_SP;
16use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
17use std::collections::HashSet;
18
/// Upper bound on the number of candidate `Mono`s kept after each sampling step
/// (see `MonoSet::random_sample`).
const MAX_STEP_SET_SIZE: usize = 1000;
20
21#[derive(Clone, Debug, Hash, PartialEq, Eq)]
22pub struct Mono<'tcx> {
23    pub value: Vec<ty::GenericArg<'tcx>>,
24}
25
26impl<'tcx> FromIterator<ty::GenericArg<'tcx>> for Mono<'tcx> {
27    fn from_iter<T>(iter: T) -> Self
28    where
29        T: IntoIterator<Item = ty::GenericArg<'tcx>>,
30    {
31        Mono {
32            value: iter.into_iter().collect(),
33        }
34    }
35}
36
37impl<'tcx> Mono<'tcx> {
38    pub fn new(identity: &[ty::GenericArg<'tcx>]) -> Self {
39        Mono {
40            value: Vec::from(identity),
41        }
42    }
43
44    fn has_infer_types(&self) -> bool {
45        self.value.iter().any(|arg| match arg.kind() {
46            ty::GenericArgKind::Type(ty) => ty.has_infer_types(),
47            _ => false,
48        })
49    }
50
51    fn mut_arg_at(&mut self, idx: usize) -> &mut ty::GenericArg<'tcx> {
52        &mut self.value[idx]
53    }
54
55    fn merge(&self, other: &Mono<'tcx>, tcx: TyCtxt<'tcx>) -> Option<Mono<'tcx>> {
56        assert!(self.value.len() == other.value.len());
57        let mut res = Vec::new();
58        for i in 0..self.value.len() {
59            let arg = self.value[i];
60            let other_arg = other.value[i];
61            let new_arg = if let Some(ty) = arg.as_type() {
62                let other_ty = other_arg.expect_ty();
63                if ty.is_ty_var() && other_ty.is_ty_var() {
64                    arg
65                } else if ty.is_ty_var() {
66                    other_arg
67                } else if other_ty.is_ty_var() {
68                    arg
69                } else if utils::is_ty_eq(ty, other_ty, tcx) {
70                    arg
71                } else {
72                    return None;
73                }
74            } else {
75                arg
76            };
77            res.push(new_arg);
78        }
79        Some(Mono { value: res })
80    }
81
82    fn fill_unbound_var(&self, tcx: TyCtxt<'tcx>) -> Vec<Mono<'tcx>> {
83        let candidates = get_unbound_generic_candidates(tcx);
84        let mut res = vec![self.clone()];
85        rap_trace!("fill unbound: {:?}", self);
86
87        for (i, arg) in self.value.iter().enumerate() {
88            if let Some(ty) = arg.as_type() {
89                if ty.is_ty_var() {
90                    let mut last = Vec::new();
91                    std::mem::swap(&mut res, &mut last);
92                    last.into_iter().for_each(|mono| {
93                        for candidate in &candidates {
94                            let mut new_mono = mono.clone();
95                            *new_mono.mut_arg_at(i) = (*candidate).into();
96                            res.push(new_mono);
97                        }
98                    });
99                }
100            }
101        }
102        res
103    }
104}
105
106#[derive(Clone, Debug, Default)]
107pub struct MonoSet<'tcx> {
108    pub monos: Vec<Mono<'tcx>>,
109}
110
111impl<'tcx> MonoSet<'tcx> {
112    pub fn all(identity: &[ty::GenericArg<'tcx>]) -> MonoSet<'tcx> {
113        MonoSet {
114            monos: vec![Mono::new(identity)],
115        }
116    }
117
118    pub fn empty() -> MonoSet<'tcx> {
119        MonoSet { monos: Vec::new() }
120    }
121
122    pub fn count(&self) -> usize {
123        self.monos.len()
124    }
125
126    pub fn at(&self, no: usize) -> &Mono<'tcx> {
127        &self.monos[no]
128    }
129
130    pub fn is_empty(&self) -> bool {
131        self.monos.is_empty()
132    }
133
134    pub fn new() -> MonoSet<'tcx> {
135        MonoSet { monos: Vec::new() }
136    }
137
138    pub fn insert(&mut self, mono: Mono<'tcx>) {
139        self.monos.push(mono);
140    }
141
142    pub fn merge(&mut self, other: &MonoSet<'tcx>, tcx: TyCtxt<'tcx>) -> MonoSet<'tcx> {
143        let mut res = MonoSet::new();
144
145        for args in self.monos.iter() {
146            for other_args in other.monos.iter() {
147                let merged = args.merge(other_args, tcx);
148                if let Some(mono) = merged {
149                    res.insert(mono);
150                }
151            }
152        }
153        res
154    }
155
156    fn filter_unbound_solution(mut self) -> Self {
157        self.monos.retain(|mono| mono.has_infer_types());
158        self
159    }
160
161    // if the unbound generic type is still exist (this could happen
162    // if `T` has no trait bounds at all)
163    // we substitute the unbound generic type with predefined type candidates
164    fn instantiate_unbound(&self, tcx: TyCtxt<'tcx>) -> Self {
165        let mut res = MonoSet::new();
166        for mono in &self.monos {
167            let filled = mono.fill_unbound_var(tcx);
168            res.monos.extend(filled);
169        }
170        res
171    }
172
173    fn erase_region_var(&mut self, tcx: TyCtxt<'tcx>) {
174        for mono in &mut self.monos {
175            mono.value
176                .iter_mut()
177                .for_each(|arg| *arg = tcx.erase_and_anonymize_regions(*arg))
178        }
179    }
180
181    pub fn filter(mut self, f: impl Fn(&Mono<'tcx>) -> bool) -> Self {
182        self.monos.retain(|args| f(args));
183        self
184    }
185
186    fn filter_by_trait_bound(mut self, fn_did: DefId, tcx: TyCtxt<'tcx>) -> Self {
187        //let early_fn_sig = tcx.fn_sig(fn_did);
188        self.monos
189            .retain(|args| is_args_fit_trait_bound(fn_did, &args.value, tcx));
190        self
191    }
192
193    pub fn random_sample<R: Rng>(&mut self, rng: &mut R) {
194        if self.monos.len() <= MAX_STEP_SET_SIZE {
195            return;
196        }
197        self.monos.shuffle(rng);
198        self.monos.truncate(MAX_STEP_SET_SIZE);
199    }
200}
201
202/// try to unfiy lhs = rhs,
203/// e.g.,
204/// try_unify(Vec<T>, Vec<i32>, ...) = Some(i32)
205/// try_unify(Vec<T>, i32, ...) = None
206fn unify_ty<'tcx>(
207    lhs: Ty<'tcx>,
208    rhs: Ty<'tcx>,
209    identity: &[ty::GenericArg<'tcx>],
210    infcx: &InferCtxt<'tcx>,
211    cause: &ObligationCause<'tcx>,
212    param_env: ty::ParamEnv<'tcx>,
213) -> Option<Mono<'tcx>> {
214    // rap_info!("check {} = {}", lhs, rhs);
215    infcx.probe(|_| {
216        match infcx
217            .at(cause, param_env)
218            .eq(DefineOpaqueTypes::Yes, lhs, rhs)
219        {
220            Ok(_infer_ok) => {
221                // rap_trace!("[infer_ok] {} = {} : {:?}", lhs, rhs, infer_ok);
222                let mono = identity
223                    .iter()
224                    .map(|arg| match arg.kind() {
225                        ty::GenericArgKind::Lifetime(region) => {
226                            infcx.resolve_vars_if_possible(region).into()
227                        }
228                        ty::GenericArgKind::Type(ty) => infcx.resolve_vars_if_possible(ty).into(),
229                        ty::GenericArgKind::Const(ct) => infcx.resolve_vars_if_possible(ct).into(),
230                    })
231                    .collect();
232                Some(mono)
233            }
234            Err(_e) => {
235                // rap_trace!("[infer_err] {} = {} : {:?}", lhs, rhs, e);
236                None
237            }
238        }
239    })
240}
241
242fn is_args_fit_trait_bound<'tcx>(
243    fn_did: DefId,
244    args: &[ty::GenericArg<'tcx>],
245    tcx: TyCtxt<'tcx>,
246) -> bool {
247    let args = tcx.mk_args(args);
248    // rap_info!(
249    //     "fn: {:?} args: {:?} identity: {:?}",
250    //     fn_did,
251    //     args,
252    //     ty::GenericArgs::identity_for_item(tcx, fn_did)
253    // );
254    let infcx = tcx.infer_ctxt().build(ty::TypingMode::PostAnalysis);
255    let pred = tcx.predicates_of(fn_did);
256    let inst_pred = pred.instantiate(tcx, args);
257    let param_env = tcx.param_env(fn_did);
258    rap_trace!(
259        "[trait bound] check {}",
260        tcx.def_path_str_with_args(fn_did, args)
261    );
262
263    for pred in inst_pred.predicates.iter() {
264        let obligation = Obligation::new(
265            tcx,
266            ObligationCause::dummy(),
267            param_env,
268            pred.as_predicate(),
269        );
270
271        let res = infcx.evaluate_obligation(&obligation);
272        match res {
273            Ok(eva) => {
274                if !eva.may_apply() {
275                    rap_trace!("[trait bound] check fail for {pred:?}");
276                    return false;
277                }
278            }
279            Err(_) => {
280                rap_trace!("[trait bound] check fail for {pred:?}");
281                return false;
282            }
283        }
284    }
285    rap_trace!("[trait bound] check succ");
286    true
287}
288
289fn is_fn_solvable<'tcx>(fn_did: DefId, tcx: TyCtxt<'tcx>) -> bool {
290    for pred in tcx
291        .predicates_of(fn_did)
292        .instantiate_identity(tcx)
293        .predicates
294    {
295        if let Some(pred) = pred.as_trait_clause() {
296            let trait_did = pred.skip_binder().trait_ref.def_id;
297            if tcx.is_lang_item(trait_did, LangItem::Fn)
298                || tcx.is_lang_item(trait_did, LangItem::FnMut)
299                || tcx.is_lang_item(trait_did, LangItem::FnOnce)
300            {
301                return false;
302            }
303        }
304    }
305    true
306}
307
308fn get_mono_set<'tcx>(
309    fn_did: DefId,
310    available_ty: &HashSet<TyWrapper<'tcx>>,
311    tcx: TyCtxt<'tcx>,
312) -> MonoSet<'tcx> {
313    let mut rng = rand::rng();
314
315    // sample from reachable types
316    rap_debug!("[get_mono_set] fn_did: {:?}", fn_did);
317    let infcx = tcx
318        .infer_ctxt()
319        .ignoring_regions()
320        .build(ty::TypingMode::PostAnalysis);
321    let param_env = tcx.param_env(fn_did);
322    let dummy_cause = ObligationCause::dummy();
323    let fresh_args = infcx.fresh_args_for_item(DUMMY_SP, fn_did);
324    // this replace generic types in fn_sig to infer var, e.g. fn(Vec<T>, i32) => fn(Vec<?0>, i32)
325    let fn_sig = fn_sig_with_generic_args(fn_did, fresh_args, tcx);
326    let generics = tcx.generics_of(fn_did);
327
328    // print fresh_args for debugging
329    for i in 0..fresh_args.len() {
330        rap_trace!(
331            "[get_mono_set] arg#{}: {:?} -> {:?}",
332            i,
333            generics.param_at(i, tcx).name,
334            fresh_args[i]
335        );
336    }
337
338    let mut s = MonoSet::all(&fresh_args);
339
340    rap_trace!("[get_mono_set] initialize s: {:?}", s);
341
342    let mut cnt = 0;
343
344    for input_ty in fn_sig.inputs().iter() {
345        cnt += 1;
346        if !input_ty.has_infer_types() {
347            continue;
348        }
349        rap_trace!("[get_mono_set] input_ty#{}: {:?}", cnt - 1, input_ty);
350
351        let mut reachable_set =
352            available_ty
353                .iter()
354                .fold(MonoSet::new(), |mut reachable_set, ty| {
355                    if let Some(mono) = unify_ty(
356                        *input_ty,
357                        (*ty).into(),
358                        &fresh_args,
359                        &infcx,
360                        &dummy_cause,
361                        param_env,
362                    ) {
363                        reachable_set.insert(mono);
364                    }
365                    reachable_set
366                });
367        reachable_set.random_sample(&mut rng);
368        rap_debug!(
369            "[get_mono_set] size: s = {}, input = {}",
370            s.count(),
371            reachable_set.count()
372        );
373        s = s.merge(&reachable_set, tcx);
374        s.random_sample(&mut rng);
375    }
376
377    rap_trace!("[get_mono_set] after input types: {:?}", s);
378
379    let mut res = MonoSet::new();
380
381    for mono in s.monos {
382        solve_unbound_type_generics(
383            fn_did,
384            mono,
385            &mut res,
386            // &fresh_args,
387            &infcx,
388            &dummy_cause,
389            param_env,
390            tcx,
391        );
392    }
393
394    // erase infer region var
395    res.erase_region_var(tcx);
396
397    res
398}
399
400fn is_special_std_ty<'tcx>(def_id: DefId, tcx: TyCtxt<'tcx>) -> bool {
401    let allowed_std_ty = [
402        tcx.lang_items().string().unwrap(),
403        path_str_def_id(tcx, "std::vec::Vec"),
404    ];
405
406    allowed_std_ty.contains(&def_id)
407}
408
409fn solve_unbound_type_generics<'tcx>(
410    did: DefId,
411    mono: Mono<'tcx>,
412    res: &mut MonoSet<'tcx>,
413    infcx: &InferCtxt<'tcx>,
414    cause: &ObligationCause<'tcx>,
415    param_env: ty::ParamEnv<'tcx>,
416    tcx: TyCtxt<'tcx>,
417) {
418    if !mono.has_infer_types() {
419        res.insert(mono);
420        return;
421    }
422    let args = tcx.mk_args(&mono.value);
423    let preds = tcx.predicates_of(did).instantiate(tcx, args);
424    let mut mset = MonoSet::all(args);
425    rap_debug!("[solve_unbound] did = {did:?}, mset={mset:?}");
426    for pred in preds.predicates.iter() {
427        if let Some(trait_pred) = pred.as_trait_clause() {
428            let trait_pred = trait_pred.skip_binder();
429
430            rap_trace!("[solve_unbound] pred: {:?}", trait_pred);
431
432            let trait_def_id = trait_pred.trait_ref.def_id;
433            // ignore Sized trait
434            if tcx.is_lang_item(trait_def_id, LangItem::Sized)
435                || tcx.is_lang_item(trait_def_id, LangItem::Copy)
436            {
437                continue;
438            }
439
440            let mut p = MonoSet::new();
441
442            for impl_did in tcx
443                .all_impls(trait_def_id)
444                .chain(tcx.inherent_impls(trait_def_id).iter().map(|did| *did))
445            {
446                // format: <arg0 as Trait<arg1, arg2>>
447                let impl_trait_ref = tcx.impl_trait_ref(impl_did).unwrap().skip_binder();
448
449                rap_trace!("impl_trait_ref: {}", impl_trait_ref);
450                // filter irrelevant implementation. We only consider implementation if:
451                // 1. it is local
452                // 2. it is not local, but its' self_ty is a primitive
453                if !impl_did.is_local() && !impl_trait_ref.self_ty().is_primitive() {
454                    continue;
455                }
456                // rap_trace!("impl_trait_ref: {}", impl_trait_ref);
457
458                if let Some(mono) = unify_trait(
459                    trait_pred.trait_ref,
460                    impl_trait_ref,
461                    args,
462                    &infcx,
463                    &cause,
464                    param_env,
465                    tcx,
466                ) {
467                    p.insert(mono);
468                }
469            }
470            mset = mset.merge(&p, tcx);
471            rap_trace!("[solve_unbound] mset: {:?}", mset);
472        }
473    }
474
475    rap_trace!("[solve_unbound] (final) mset: {:?}", mset);
476    for mono in mset.monos {
477        res.insert(mono);
478    }
479}
480
481/// only handle the case that rhs does not have any infer types
482/// e.g., `<T as Into<U>> == <Foo as Into<Bar>> => Some(T=Foo, U=Bar))`
483fn unify_trait<'tcx>(
484    lhs: ty::TraitRef<'tcx>,
485    rhs: ty::TraitRef<'tcx>,
486    identity: &[ty::GenericArg<'tcx>],
487    infcx: &InferCtxt<'tcx>,
488    cause: &ObligationCause<'tcx>,
489    param_env: ty::ParamEnv<'tcx>,
490    tcx: TyCtxt<'tcx>,
491) -> Option<Mono<'tcx>> {
492    rap_trace!("[unify_trait] lhs: {:?}, rhs: {:?}", lhs, rhs);
493    if lhs.def_id != rhs.def_id {
494        return None;
495    }
496
497    assert!(lhs.args.len() == rhs.args.len());
498    let mut s = Mono::new(identity);
499    for (lhs_arg, rhs_arg) in lhs.args.iter().zip(rhs.args.iter()) {
500        if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_arg.as_type(), rhs_arg.as_type()) {
501            if rhs_ty.has_infer_types() || rhs_ty.has_param() {
502                // if rhs has infer types, we cannot unify it with lhs
503                return None;
504            }
505            let mono = unify_ty(lhs_ty, rhs_ty, identity, infcx, cause, param_env)?;
506            rap_trace!("[unify_trait] unified mono: {:?}", mono);
507            s = s.merge(&mono, tcx)?;
508        }
509    }
510    Some(s)
511}
512
513pub fn resolve_mono_apis<'tcx>(
514    fn_did: DefId,
515    available_ty: &HashSet<TyWrapper<'tcx>>,
516    tcx: TyCtxt<'tcx>,
517) -> MonoSet<'tcx> {
518    // 1. check solvable condition
519    if !is_fn_solvable(fn_did, tcx) {
520        return MonoSet::empty();
521    }
522
523    // 2. get mono set from available types
524    let ret = get_mono_set(fn_did, &available_ty, tcx).instantiate_unbound(tcx);
525
526    // 3. check trait bound
527    let ret = ret.filter_by_trait_bound(fn_did, tcx);
528
529    ret
530}
531
532pub fn add_transform_tys<'tcx>(available_ty: &mut HashSet<TyWrapper<'tcx>>, tcx: TyCtxt<'tcx>) {
533    let mut new_tys = Vec::new();
534    available_ty.iter().for_each(|ty| {
535        new_tys.push(
536            Ty::new_ref(
537                tcx,
538                tcx.lifetimes.re_erased,
539                (*ty).into(),
540                ty::Mutability::Not,
541            )
542            .into(),
543        );
544        new_tys.push(Ty::new_ref(
545            tcx,
546            tcx.lifetimes.re_erased,
547            (*ty).into(),
548            ty::Mutability::Mut,
549        ));
550        new_tys.push(Ty::new_ref(
551            tcx,
552            tcx.lifetimes.re_erased,
553            Ty::new_slice(tcx, (*ty).into()),
554            ty::Mutability::Not,
555        ));
556        new_tys.push(Ty::new_ref(
557            tcx,
558            tcx.lifetimes.re_erased,
559            Ty::new_slice(tcx, (*ty).into()),
560            ty::Mutability::Mut,
561        ));
562    });
563
564    new_tys.into_iter().for_each(|ty| {
565        available_ty.insert(ty.into());
566    });
567}
568
569pub fn eliminate_infer_var<'tcx>(
570    fn_did: DefId,
571    args: &[ty::GenericArg<'tcx>],
572    tcx: TyCtxt<'tcx>,
573) -> Vec<ty::GenericArg<'tcx>> {
574    let mut res = Vec::new();
575    let identity = ty::GenericArgs::identity_for_item(tcx, fn_did);
576    for (i, arg) in args.iter().enumerate() {
577        if let Some(ty) = arg.as_type() {
578            if ty.is_ty_var() {
579                res.push(identity[i]);
580            } else {
581                res.push(*arg);
582            }
583        } else {
584            res.push(*arg);
585        }
586    }
587    res
588}
589
590/// if type parameter is unbound, e.g., `T` in `fn foo<T>()`,
591/// we use some predefined types to substitute it
592pub fn get_unbound_generic_candidates<'tcx>(tcx: TyCtxt<'tcx>) -> Vec<ty::Ty<'tcx>> {
593    vec![
594        tcx.types.bool,
595        tcx.types.char,
596        tcx.types.u8,
597        tcx.types.i8,
598        tcx.types.i32,
599        tcx.types.u32,
600        // tcx.types.i64,
601        // tcx.types.u64,
602        tcx.types.f32,
603        // tcx.types.f64,
604        Ty::new_imm_ref(
605            tcx,
606            tcx.lifetimes.re_erased,
607            Ty::new_slice(tcx, tcx.types.u8),
608        ),
609        Ty::new_mut_ref(
610            tcx,
611            tcx.lifetimes.re_erased,
612            Ty::new_slice(tcx, tcx.types.u8),
613        ),
614    ]
615}
616
617pub fn get_impls<'tcx>(
618    tcx: TyCtxt<'tcx>,
619    fn_did: DefId,
620    args: GenericArgsRef<'tcx>,
621) -> HashSet<DefId> {
622    let mut impls = HashSet::new();
623    let preds = tcx.predicates_of(fn_did).instantiate(tcx, args);
624    for (pred, _) in preds {
625        if let Some(trait_pred) = pred.as_trait_clause() {
626            let trait_ref: rustc_type_ir::TraitRef<TyCtxt<'tcx>> =
627                trait_pred.skip_binder().trait_ref;
628            // ignore Sized trait
629            // if tcx.is_lang_item(trait_ref.def_id, LangItem::Sized)
630            //     || tcx.def_path_str(trait_ref.def_id) == "std::default::Default"
631            // {
632            //     continue;
633            // }
634
635            let res = tcx.codegen_select_candidate(
636                TypingEnv::fully_monomorphized().as_query_input(trait_ref),
637            );
638            if let Ok(source) = res {
639                match source {
640                    ImplSource::UserDefined(data) => {
641                        if data.impl_def_id.is_local() {
642                            impls.insert(data.impl_def_id);
643                        }
644                    }
645                    _ => {}
646                }
647            }
648            // rap_debug!("{:?} => {:?}", trait_ref, res);
649        }
650    }
651    rap_trace!("fn: {:?} args: {:?} impls: {:?}", fn_did, args, impls);
652    impls
653}