rapx/analysis/core/api_dependency/
mono.rs

1use super::graph::TyWrapper;
2use super::utils::{self, fn_sig_with_generic_args};
3use crate::analysis::utils::def_path::path_str_def_id;
4use crate::{rap_debug, rap_trace};
5use rand::seq::SliceRandom;
6use rand::Rng;
7use rustc_hir::def_id::DefId;
8use rustc_hir::LangItem;
9use rustc_infer::infer::DefineOpaqueTypes;
10use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
11use rustc_infer::traits::{ImplSource, Obligation, ObligationCause};
12use rustc_middle::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt, TypingEnv};
13use rustc_span::DUMMY_SP;
14use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt as _;
15use std::collections::HashSet;
16
/// Upper bound on the number of candidate instantiations kept in a set at
/// each step; larger sets are shuffled and truncated by
/// `MonoSet::random_sample`.
const MAX_STEP_SET_SIZE: usize = 1000;
18
19#[derive(Clone, Debug, Hash, PartialEq, Eq)]
20pub struct Mono<'tcx> {
21    pub value: Vec<ty::GenericArg<'tcx>>,
22}
23
24impl<'tcx> FromIterator<ty::GenericArg<'tcx>> for Mono<'tcx> {
25    fn from_iter<T>(iter: T) -> Self
26    where
27        T: IntoIterator<Item = ty::GenericArg<'tcx>>,
28    {
29        Mono {
30            value: iter.into_iter().collect(),
31        }
32    }
33}
34
35impl<'tcx> Mono<'tcx> {
36    pub fn new(identity: &[ty::GenericArg<'tcx>]) -> Self {
37        Mono {
38            value: Vec::from(identity),
39        }
40    }
41
42    fn has_infer_types(&self) -> bool {
43        self.value.iter().any(|arg| match arg.kind() {
44            ty::GenericArgKind::Type(ty) => ty.has_infer_types(),
45            _ => false,
46        })
47    }
48
49    fn mut_arg_at(&mut self, idx: usize) -> &mut ty::GenericArg<'tcx> {
50        &mut self.value[idx]
51    }
52
53    fn merge(&self, other: &Mono<'tcx>, tcx: TyCtxt<'tcx>) -> Option<Mono<'tcx>> {
54        assert!(self.value.len() == other.value.len());
55        let mut res = Vec::new();
56        for i in 0..self.value.len() {
57            let arg = self.value[i];
58            let other_arg = other.value[i];
59            let new_arg = if let Some(ty) = arg.as_type() {
60                let other_ty = other_arg.expect_ty();
61                if ty.is_ty_var() && other_ty.is_ty_var() {
62                    arg
63                } else if ty.is_ty_var() {
64                    other_arg
65                } else if other_ty.is_ty_var() {
66                    arg
67                } else if utils::is_ty_eq(ty, other_ty, tcx) {
68                    arg
69                } else {
70                    return None;
71                }
72            } else {
73                arg
74            };
75            res.push(new_arg);
76        }
77        Some(Mono { value: res })
78    }
79
80    fn fill_unbound_var(&self, tcx: TyCtxt<'tcx>) -> Vec<Mono<'tcx>> {
81        let candidates = get_unbound_generic_candidates(tcx);
82        let mut res = vec![self.clone()];
83        rap_trace!("fill unbound: {:?}", self);
84
85        for (i, arg) in self.value.iter().enumerate() {
86            if let Some(ty) = arg.as_type() {
87                if ty.is_ty_var() {
88                    let mut last = Vec::new();
89                    std::mem::swap(&mut res, &mut last);
90                    last.into_iter().for_each(|mono| {
91                        for candidate in &candidates {
92                            let mut new_mono = mono.clone();
93                            *new_mono.mut_arg_at(i) = (*candidate).into();
94                            res.push(new_mono);
95                        }
96                    });
97                }
98            }
99        }
100        res
101    }
102}
103
104#[derive(Clone, Debug, Default)]
105pub struct MonoSet<'tcx> {
106    pub monos: Vec<Mono<'tcx>>,
107}
108
109impl<'tcx> MonoSet<'tcx> {
110    pub fn all(identity: &[ty::GenericArg<'tcx>]) -> MonoSet<'tcx> {
111        MonoSet {
112            monos: vec![Mono::new(identity)],
113        }
114    }
115
116    pub fn empty() -> MonoSet<'tcx> {
117        MonoSet { monos: Vec::new() }
118    }
119
120    pub fn count(&self) -> usize {
121        self.monos.len()
122    }
123
124    pub fn at(&self, no: usize) -> &Mono<'tcx> {
125        &self.monos[no]
126    }
127
128    pub fn is_empty(&self) -> bool {
129        self.monos.is_empty()
130    }
131
132    pub fn new() -> MonoSet<'tcx> {
133        MonoSet { monos: Vec::new() }
134    }
135
136    pub fn insert(&mut self, mono: Mono<'tcx>) {
137        self.monos.push(mono);
138    }
139
140    pub fn merge(&mut self, other: &MonoSet<'tcx>, tcx: TyCtxt<'tcx>) -> MonoSet<'tcx> {
141        let mut res = MonoSet::new();
142
143        for args in self.monos.iter() {
144            for other_args in other.monos.iter() {
145                let merged = args.merge(other_args, tcx);
146                if let Some(mono) = merged {
147                    res.insert(mono);
148                }
149            }
150        }
151        res
152    }
153
154    fn filter_unbound_solution(mut self) -> Self {
155        self.monos.retain(|mono| mono.has_infer_types());
156        self
157    }
158
159    // if the unbound generic type is still exist (this could happen
160    // if `T` has no trait bounds at all)
161    // we substitute the unbound generic type with predefined type candidates
162    fn instantiate_unbound(&self, tcx: TyCtxt<'tcx>) -> Self {
163        let mut res = MonoSet::new();
164        for mono in &self.monos {
165            let filled = mono.fill_unbound_var(tcx);
166            res.monos.extend(filled);
167        }
168        res
169    }
170
171    fn erase_region_var(&mut self, tcx: TyCtxt<'tcx>) {
172        for mono in &mut self.monos {
173            mono.value
174                .iter_mut()
175                .for_each(|arg| *arg = tcx.erase_and_anonymize_regions(*arg))
176        }
177    }
178
179    pub fn filter(mut self, f: impl Fn(&Mono<'tcx>) -> bool) -> Self {
180        self.monos.retain(|args| f(args));
181        self
182    }
183
184    fn filter_by_trait_bound(mut self, fn_did: DefId, tcx: TyCtxt<'tcx>) -> Self {
185        let early_fn_sig = tcx.fn_sig(fn_did);
186        self.monos
187            .retain(|args| is_args_fit_trait_bound(fn_did, &args.value, tcx));
188        self
189    }
190
191    pub fn random_sample<R: Rng>(&mut self, rng: &mut R) {
192        if self.monos.len() <= MAX_STEP_SET_SIZE {
193            return;
194        }
195        self.monos.shuffle(rng);
196        self.monos.truncate(MAX_STEP_SET_SIZE);
197    }
198}
199
200/// try to unfiy lhs = rhs,
201/// e.g.,
202/// try_unify(Vec<T>, Vec<i32>, ...) = Some(i32)
203/// try_unify(Vec<T>, i32, ...) = None
204fn unify_ty<'tcx>(
205    lhs: Ty<'tcx>,
206    rhs: Ty<'tcx>,
207    identity: &[ty::GenericArg<'tcx>],
208    infcx: &InferCtxt<'tcx>,
209    cause: &ObligationCause<'tcx>,
210    param_env: ty::ParamEnv<'tcx>,
211) -> Option<Mono<'tcx>> {
212    // rap_info!("check {} = {}", lhs, rhs);
213    infcx.probe(|_| {
214        match infcx
215            .at(cause, param_env)
216            .eq(DefineOpaqueTypes::Yes, lhs, rhs)
217        {
218            Ok(infer_ok) => {
219                // rap_trace!("[infer_ok] {} = {} : {:?}", lhs, rhs, infer_ok);
220                let mono = identity
221                    .iter()
222                    .map(|arg| match arg.kind() {
223                        ty::GenericArgKind::Lifetime(region) => {
224                            infcx.resolve_vars_if_possible(region).into()
225                        }
226                        ty::GenericArgKind::Type(ty) => infcx.resolve_vars_if_possible(ty).into(),
227                        ty::GenericArgKind::Const(ct) => infcx.resolve_vars_if_possible(ct).into(),
228                    })
229                    .collect();
230                Some(mono)
231            }
232            Err(e) => {
233                // rap_trace!("[infer_err] {} = {} : {:?}", lhs, rhs, e);
234                None
235            }
236        }
237    })
238}
239
240fn is_args_fit_trait_bound<'tcx>(
241    fn_did: DefId,
242    args: &[ty::GenericArg<'tcx>],
243    tcx: TyCtxt<'tcx>,
244) -> bool {
245    let args = tcx.mk_args(args);
246    // rap_info!(
247    //     "fn: {:?} args: {:?} identity: {:?}",
248    //     fn_did,
249    //     args,
250    //     ty::GenericArgs::identity_for_item(tcx, fn_did)
251    // );
252    let infcx = tcx.infer_ctxt().build(ty::TypingMode::PostAnalysis);
253    let pred = tcx.predicates_of(fn_did);
254    let inst_pred = pred.instantiate(tcx, args);
255    let param_env = tcx.param_env(fn_did);
256    rap_trace!(
257        "[trait bound] check {}",
258        tcx.def_path_str_with_args(fn_did, args)
259    );
260
261    for pred in inst_pred.predicates.iter() {
262        let obligation = Obligation::new(
263            tcx,
264            ObligationCause::dummy(),
265            param_env,
266            pred.as_predicate(),
267        );
268
269        let res = infcx.evaluate_obligation(&obligation);
270        match res {
271            Ok(eva) => {
272                if !eva.may_apply() {
273                    rap_trace!("[trait bound] check fail for {pred:?}");
274                    return false;
275                }
276            }
277            Err(_) => {
278                rap_trace!("[trait bound] check fail for {pred:?}");
279                return false;
280            }
281        }
282    }
283    rap_trace!("[trait bound] check succ");
284    true
285}
286
287fn is_fn_solvable<'tcx>(fn_did: DefId, tcx: TyCtxt<'tcx>) -> bool {
288    for pred in tcx
289        .predicates_of(fn_did)
290        .instantiate_identity(tcx)
291        .predicates
292    {
293        if let Some(pred) = pred.as_trait_clause() {
294            let trait_did = pred.skip_binder().trait_ref.def_id;
295            if tcx.is_lang_item(trait_did, LangItem::Fn)
296                || tcx.is_lang_item(trait_did, LangItem::FnMut)
297                || tcx.is_lang_item(trait_did, LangItem::FnOnce)
298            {
299                return false;
300            }
301        }
302    }
303    true
304}
305
306fn get_mono_set<'tcx>(
307    fn_did: DefId,
308    available_ty: &HashSet<TyWrapper<'tcx>>,
309    tcx: TyCtxt<'tcx>,
310) -> MonoSet<'tcx> {
311    let mut rng = rand::rng();
312
313    // sample from reachable types
314    rap_debug!("[get_mono_set] fn_did: {:?}", fn_did);
315    let infcx = tcx
316        .infer_ctxt()
317        .ignoring_regions()
318        .build(ty::TypingMode::PostAnalysis);
319    let param_env = tcx.param_env(fn_did);
320    let dummy_cause = ObligationCause::dummy();
321    let fresh_args = infcx.fresh_args_for_item(DUMMY_SP, fn_did);
322    // this replace generic types in fn_sig to infer var, e.g. fn(Vec<T>, i32) => fn(Vec<?0>, i32)
323    let fn_sig = fn_sig_with_generic_args(fn_did, fresh_args, tcx);
324    let generics = tcx.generics_of(fn_did);
325
326    // print fresh_args for debugging
327    for i in 0..fresh_args.len() {
328        rap_trace!(
329            "[get_mono_set] arg#{}: {:?} -> {:?}",
330            i,
331            generics.param_at(i, tcx).name,
332            fresh_args[i]
333        );
334    }
335
336    let mut s = MonoSet::all(&fresh_args);
337
338    rap_trace!("[get_mono_set] initialize s: {:?}", s);
339
340    let mut cnt = 0;
341
342    for input_ty in fn_sig.inputs().iter() {
343        cnt += 1;
344        if !input_ty.has_infer_types() {
345            continue;
346        }
347        rap_trace!("[get_mono_set] input_ty#{}: {:?}", cnt - 1, input_ty);
348
349        let mut reachable_set =
350            available_ty
351                .iter()
352                .fold(MonoSet::new(), |mut reachable_set, ty| {
353                    if let Some(mono) = unify_ty(
354                        *input_ty,
355                        (*ty).into(),
356                        &fresh_args,
357                        &infcx,
358                        &dummy_cause,
359                        param_env,
360                    ) {
361                        reachable_set.insert(mono);
362                    }
363                    reachable_set
364                });
365        reachable_set.random_sample(&mut rng);
366        rap_debug!(
367            "[get_mono_set] size: s = {}, input = {}",
368            s.count(),
369            reachable_set.count()
370        );
371        s = s.merge(&reachable_set, tcx);
372        s.random_sample(&mut rng);
373    }
374
375    rap_trace!("[get_mono_set] after input types: {:?}", s);
376
377    let mut res = MonoSet::new();
378
379    for mono in s.monos {
380        solve_unbound_type_generics(
381            fn_did,
382            mono,
383            &mut res,
384            // &fresh_args,
385            &infcx,
386            &dummy_cause,
387            param_env,
388            tcx,
389        );
390    }
391
392    // erase infer region var
393    res.erase_region_var(tcx);
394
395    res
396}
397
398fn is_special_std_ty<'tcx>(def_id: DefId, tcx: TyCtxt<'tcx>) -> bool {
399    let allowed_std_ty = [
400        tcx.lang_items().string().unwrap(),
401        path_str_def_id(tcx, "std::vec::Vec"),
402    ];
403
404    allowed_std_ty.contains(&def_id)
405}
406
407fn solve_unbound_type_generics<'tcx>(
408    did: DefId,
409    mono: Mono<'tcx>,
410    res: &mut MonoSet<'tcx>,
411    infcx: &InferCtxt<'tcx>,
412    cause: &ObligationCause<'tcx>,
413    param_env: ty::ParamEnv<'tcx>,
414    tcx: TyCtxt<'tcx>,
415) {
416    if !mono.has_infer_types() {
417        res.insert(mono);
418        return;
419    }
420    let args = tcx.mk_args(&mono.value);
421    let preds = tcx.predicates_of(did).instantiate(tcx, args);
422    let mut mset = MonoSet::all(args);
423    rap_debug!("[solve_unbound] did = {did:?}, mset={mset:?}");
424    for pred in preds.predicates.iter() {
425        if let Some(trait_pred) = pred.as_trait_clause() {
426            let trait_pred = trait_pred.skip_binder();
427
428            rap_trace!("[solve_unbound] pred: {:?}", trait_pred);
429
430            let trait_def_id = trait_pred.trait_ref.def_id;
431            // ignore Sized trait
432            if tcx.is_lang_item(trait_def_id, LangItem::Sized)
433                || tcx.is_lang_item(trait_def_id, LangItem::Copy)
434            {
435                continue;
436            }
437
438            let mut p = MonoSet::new();
439
440            for impl_did in tcx
441                .all_impls(trait_def_id)
442                .chain(tcx.inherent_impls(trait_def_id).iter().map(|did| *did))
443            {
444                // format: <arg0 as Trait<arg1, arg2>>
445                let impl_trait_ref = tcx.impl_trait_ref(impl_did).unwrap().skip_binder();
446
447                rap_trace!("impl_trait_ref: {}", impl_trait_ref);
448                // filter irrelevant implementation. We only consider implementation if:
449                // 1. it is local
450                // 2. it is not local, but its' self_ty is a primitive
451                if !impl_did.is_local() && !impl_trait_ref.self_ty().is_primitive() {
452                    continue;
453                }
454                // rap_trace!("impl_trait_ref: {}", impl_trait_ref);
455
456                if let Some(mono) = unify_trait(
457                    trait_pred.trait_ref,
458                    impl_trait_ref,
459                    args,
460                    &infcx,
461                    &cause,
462                    param_env,
463                    tcx,
464                ) {
465                    p.insert(mono);
466                }
467            }
468            mset = mset.merge(&p, tcx);
469            rap_trace!("[solve_unbound] mset: {:?}", mset);
470        }
471    }
472
473    rap_trace!("[solve_unbound] (final) mset: {:?}", mset);
474    for mono in mset.monos {
475        res.insert(mono);
476    }
477}
478
479/// only handle the case that rhs does not have any infer types
480/// e.g., `<T as Into<U>> == <Foo as Into<Bar>> => Some(T=Foo, U=Bar))`
481fn unify_trait<'tcx>(
482    lhs: ty::TraitRef<'tcx>,
483    rhs: ty::TraitRef<'tcx>,
484    identity: &[ty::GenericArg<'tcx>],
485    infcx: &InferCtxt<'tcx>,
486    cause: &ObligationCause<'tcx>,
487    param_env: ty::ParamEnv<'tcx>,
488    tcx: TyCtxt<'tcx>,
489) -> Option<Mono<'tcx>> {
490    rap_trace!("[unify_trait] lhs: {:?}, rhs: {:?}", lhs, rhs);
491    if lhs.def_id != rhs.def_id {
492        return None;
493    }
494
495    assert!(lhs.args.len() == rhs.args.len());
496    let mut s = Mono::new(identity);
497    for (lhs_arg, rhs_arg) in lhs.args.iter().zip(rhs.args.iter()) {
498        if let (Some(lhs_ty), Some(rhs_ty)) = (lhs_arg.as_type(), rhs_arg.as_type()) {
499            if rhs_ty.has_infer_types() || rhs_ty.has_param() {
500                // if rhs has infer types, we cannot unify it with lhs
501                return None;
502            }
503            let mono = unify_ty(lhs_ty, rhs_ty, identity, infcx, cause, param_env)?;
504            rap_trace!("[unify_trait] unified mono: {:?}", mono);
505            s = s.merge(&mono, tcx)?;
506        }
507    }
508    Some(s)
509}
510
511pub fn resolve_mono_apis<'tcx>(
512    fn_did: DefId,
513    available_ty: &HashSet<TyWrapper<'tcx>>,
514    tcx: TyCtxt<'tcx>,
515) -> MonoSet<'tcx> {
516    // 1. check solvable condition
517    if !is_fn_solvable(fn_did, tcx) {
518        return MonoSet::empty();
519    }
520
521    // 2. get mono set from available types
522    let ret = get_mono_set(fn_did, &available_ty, tcx).instantiate_unbound(tcx);
523
524    // 3. check trait bound
525    let ret = ret.filter_by_trait_bound(fn_did, tcx);
526
527    ret
528}
529
530pub fn add_transform_tys<'tcx>(available_ty: &mut HashSet<TyWrapper<'tcx>>, tcx: TyCtxt<'tcx>) {
531    let mut new_tys = Vec::new();
532    available_ty.iter().for_each(|ty| {
533        new_tys.push(
534            Ty::new_ref(
535                tcx,
536                tcx.lifetimes.re_erased,
537                (*ty).into(),
538                ty::Mutability::Not,
539            )
540            .into(),
541        );
542        new_tys.push(Ty::new_ref(
543            tcx,
544            tcx.lifetimes.re_erased,
545            (*ty).into(),
546            ty::Mutability::Mut,
547        ));
548        new_tys.push(Ty::new_ref(
549            tcx,
550            tcx.lifetimes.re_erased,
551            Ty::new_slice(tcx, (*ty).into()),
552            ty::Mutability::Not,
553        ));
554        new_tys.push(Ty::new_ref(
555            tcx,
556            tcx.lifetimes.re_erased,
557            Ty::new_slice(tcx, (*ty).into()),
558            ty::Mutability::Mut,
559        ));
560    });
561
562    new_tys.into_iter().for_each(|ty| {
563        available_ty.insert(ty.into());
564    });
565}
566
567pub fn eliminate_infer_var<'tcx>(
568    fn_did: DefId,
569    args: &[ty::GenericArg<'tcx>],
570    tcx: TyCtxt<'tcx>,
571) -> Vec<ty::GenericArg<'tcx>> {
572    let mut res = Vec::new();
573    let identity = ty::GenericArgs::identity_for_item(tcx, fn_did);
574    for (i, arg) in args.iter().enumerate() {
575        if let Some(ty) = arg.as_type() {
576            if ty.is_ty_var() {
577                res.push(identity[i]);
578            } else {
579                res.push(*arg);
580            }
581        } else {
582            res.push(*arg);
583        }
584    }
585    res
586}
587
588/// if type parameter is unbound, e.g., `T` in `fn foo<T>()`,
589/// we use some predefined types to substitute it
590pub fn get_unbound_generic_candidates<'tcx>(tcx: TyCtxt<'tcx>) -> Vec<ty::Ty<'tcx>> {
591    vec![
592        tcx.types.bool,
593        tcx.types.char,
594        tcx.types.u8,
595        tcx.types.i8,
596        tcx.types.i32,
597        tcx.types.u32,
598        // tcx.types.i64,
599        // tcx.types.u64,
600        tcx.types.f32,
601        // tcx.types.f64,
602        Ty::new_imm_ref(
603            tcx,
604            tcx.lifetimes.re_erased,
605            Ty::new_slice(tcx, tcx.types.u8),
606        ),
607        Ty::new_mut_ref(
608            tcx,
609            tcx.lifetimes.re_erased,
610            Ty::new_slice(tcx, tcx.types.u8),
611        ),
612    ]
613}
614
615pub fn get_impls<'tcx>(
616    tcx: TyCtxt<'tcx>,
617    fn_did: DefId,
618    args: GenericArgsRef<'tcx>,
619) -> HashSet<DefId> {
620    let mut impls = HashSet::new();
621    let preds = tcx.predicates_of(fn_did).instantiate(tcx, args);
622    for (pred, _) in preds {
623        if let Some(trait_pred) = pred.as_trait_clause() {
624            let trait_ref: rustc_type_ir::TraitRef<TyCtxt<'tcx>> =
625                trait_pred.skip_binder().trait_ref;
626            // ignore Sized trait
627            // if tcx.is_lang_item(trait_ref.def_id, LangItem::Sized)
628            //     || tcx.def_path_str(trait_ref.def_id) == "std::default::Default"
629            // {
630            //     continue;
631            // }
632
633            let res = tcx.codegen_select_candidate(
634                TypingEnv::fully_monomorphized().as_query_input(trait_ref),
635            );
636            if let Ok(source) = res {
637                match source {
638                    ImplSource::UserDefined(data) => {
639                        if data.impl_def_id.is_local() {
640                            impls.insert(data.impl_def_id);
641                        }
642                    }
643                    _ => {}
644                }
645            }
646            // rap_debug!("{:?} => {:?}", trait_ref, res);
647        }
648    }
649    rap_trace!("fn: {:?} args: {:?} impls: {:?}", fn_did, args, impls);
650    impls
651}